use separate StorageFactory objects for different database / database types
This commit is contained in:
		
							parent
							
								
									fbe8d99d74
								
							
						
					
					
						commit
						e3bb5b2880
					
				
					 6 changed files with 70 additions and 88 deletions
				
			
		|  | @ -9,58 +9,46 @@ from sqlalchemy.dialects.sqlite import JSON | ||||||
| import threading | import threading | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| # predefined db-specific definitions, usable for SQLite; |  | ||||||
| # may be overriden by import of ``scopes.storage.db.<dbname>`` |  | ||||||
| 
 |  | ||||||
| def sessionFactory(engine): |  | ||||||
|     return engine.connect |  | ||||||
| 
 |  | ||||||
| def getEngine(dbtype, dbname, user, pw, host='localhost', port=5432, **kw): |  | ||||||
|     return create_engine('%s:///%s' % (dbtype, dbname), **kw) |  | ||||||
| 
 |  | ||||||
| def mark_changed(session): |  | ||||||
|     pass |  | ||||||
| 
 |  | ||||||
| def commit(conn): |  | ||||||
|     conn.commit() |  | ||||||
| 
 |  | ||||||
| IdType = Integer |  | ||||||
| JsonType = JSON |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| class StorageFactory(object): | class StorageFactory(object): | ||||||
| 
 | 
 | ||||||
|     engine = Session = None |     def sessionFactory(self): | ||||||
|  |          return self.engine.connect | ||||||
| 
 | 
 | ||||||
|     sessionFactory = sessionFactory |     @staticmethod | ||||||
|     getEngine = getEngine |     def getEngine(dbtype, dbname, user, pw, host='localhost', port=5432, **kw): | ||||||
|     mark_changed = mark_changed |         return create_engine('%s:///%s' % (dbtype, dbname), **kw) | ||||||
|     commit = commit |  | ||||||
|     IdType = IdType |  | ||||||
|     JsonType = JsonType |  | ||||||
| 
 | 
 | ||||||
|     def __call__(self, schema=None): |     @staticmethod | ||||||
|         st = Storage(schema=schema) |     def mark_changed(session): | ||||||
|         st.setup(self) |         pass | ||||||
|         return st |  | ||||||
| 
 | 
 | ||||||
|     def setup(self, config): |     @staticmethod | ||||||
|  |     def commit(conn): | ||||||
|  |         conn.commit() | ||||||
|  | 
 | ||||||
|  |     IdType = Integer | ||||||
|  |     JsonType = JSON | ||||||
|  | 
 | ||||||
|  |     def __init__(self, config): | ||||||
|         self.engine = self.getEngine(config.dbengine, config.dbname,  |         self.engine = self.getEngine(config.dbengine, config.dbname,  | ||||||
|                                      config.dbuser, config.dbpassword)  |                                      config.dbuser, config.dbpassword)  | ||||||
|         self.Session = self.sessionFactory |         self.Session = self.sessionFactory() | ||||||
|  | 
 | ||||||
|  |     def __call__(self, schema=None): | ||||||
|  |         return Storage(self, schema=schema) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| # you may put something like this in your code: | # you may put something like this in your code: | ||||||
| #scopes.storage.common.factory = StorageFactory(config) | #factory = StorageFactory(config) | ||||||
| # and then call at appropriate places: | # and then call at appropriate places: | ||||||
| #storage = scopes.storage.common.factory(schema=...) | #storage = scopes.storage.common.factory(schema=...) | ||||||
| 
 | 
 | ||||||
| 
 |  | ||||||
| class Storage(object): | class Storage(object): | ||||||
| 
 | 
 | ||||||
|     def __init__(self, schema=None): |     def __init__(self, db, schema=None): | ||||||
|         self.engine = engine |         self.db = db | ||||||
|         self.session = Session() |         self.engine = db.engine | ||||||
|  |         self.session = db.Session() | ||||||
|         self.schema = schema |         self.schema = schema | ||||||
|         self.metadata = MetaData(schema=schema) |         self.metadata = MetaData(schema=schema) | ||||||
|         self.containers = {} |         self.containers = {} | ||||||
|  |  | ||||||
|  | @ -9,26 +9,29 @@ from sqlalchemy.orm import scoped_session, sessionmaker | ||||||
| import transaction | import transaction | ||||||
| from zope.sqlalchemy import register, mark_changed | from zope.sqlalchemy import register, mark_changed | ||||||
| 
 | 
 | ||||||
|  | from scopes.storage.common import StorageFactory | ||||||
| 
 | 
 | ||||||
| def sessionFactory(engine): | 
 | ||||||
|     Session = scoped_session(sessionmaker(bind=engine, twophase=True)) | class StorageFactory(StorageFactory): | ||||||
|  | 
 | ||||||
|  |     def sessionFactory(self): | ||||||
|  |         Session = scoped_session(sessionmaker(bind=self.engine, twophase=True)) | ||||||
|         register(Session) |         register(Session) | ||||||
|         return Session |         return Session | ||||||
| 
 | 
 | ||||||
| def getEngine(dbtype, dbname, user, pw, host='localhost', port=5432, **kw): |     @staticmethod | ||||||
|  |     def getEngine(dbtype, dbname, user, pw, host='localhost', port=5432, **kw): | ||||||
|         return create_engine('%s://%s:%s@%s:%s/%s' % ( |         return create_engine('%s://%s:%s@%s:%s/%s' % ( | ||||||
|             dbtype, user, pw, host, port, dbname), **kw) |             dbtype, user, pw, host, port, dbname), **kw) | ||||||
| 
 | 
 | ||||||
| def commit(conn): |     @staticmethod | ||||||
|  |     def mark_changed(session): | ||||||
|  |         return mark_changed(session) | ||||||
|  | 
 | ||||||
|  |     @staticmethod | ||||||
|  |     def commit(conn): | ||||||
|         transaction.commit() |         transaction.commit() | ||||||
| 
 | 
 | ||||||
| # patch `common` module |     IdType = BigInteger | ||||||
| import scopes.storage.common |     JsonType = JSONB | ||||||
| def init(): |  | ||||||
|     scopes.storage.common.IdType = BigInteger |  | ||||||
|     scopes.storage.common.JsonType = JSONB |  | ||||||
|     scopes.storage.common.sessionFactory = sessionFactory |  | ||||||
|     scopes.storage.common.getEngine = getEngine |  | ||||||
|     scopes.storage.common.mark_changed = mark_changed |  | ||||||
|     scopes.storage.common.commit = commit |  | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -12,7 +12,6 @@ from sqlalchemy import Table, Column, Index | ||||||
| from sqlalchemy import DateTime, Text, func | from sqlalchemy import DateTime, Text, func | ||||||
| from sqlalchemy import and_ | from sqlalchemy import and_ | ||||||
| 
 | 
 | ||||||
| from scopes.storage.common import commit, IdType, JsonType, mark_changed |  | ||||||
| from scopes.storage.common import registerContainerClass | from scopes.storage.common import registerContainerClass | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | @ -77,6 +76,7 @@ class Container(object): | ||||||
| 
 | 
 | ||||||
|     def __init__(self, storage): |     def __init__(self, storage): | ||||||
|         self.storage = storage |         self.storage = storage | ||||||
|  |         self.db = storage.db | ||||||
|         self.session = storage.session |         self.session = storage.session | ||||||
|         self.table = self.getTable() |         self.table = self.getTable() | ||||||
| 
 | 
 | ||||||
|  | @ -113,7 +113,7 @@ class Container(object): | ||||||
|         values = self.setupValues(track, withTrackId) |         values = self.setupValues(track, withTrackId) | ||||||
|         stmt = t.insert().values(**values).returning(t.c.trackid) |         stmt = t.insert().values(**values).returning(t.c.trackid) | ||||||
|         trackId = self.session.execute(stmt).first()[0] |         trackId = self.session.execute(stmt).first()[0] | ||||||
|         mark_changed(self.session) |         self.db.mark_changed(self.session) | ||||||
|         return trackId |         return trackId | ||||||
| 
 | 
 | ||||||
|     def update(self, track): |     def update(self, track): | ||||||
|  | @ -124,7 +124,7 @@ class Container(object): | ||||||
|         stmt = t.update().values(**values).where(t.c.trackid == track.trackId) |         stmt = t.update().values(**values).where(t.c.trackid == track.trackId) | ||||||
|         n = self.session.execute(stmt).rowcount |         n = self.session.execute(stmt).rowcount | ||||||
|         if n > 0: |         if n > 0: | ||||||
|             mark_changed(self.session) |             self.db.mark_changed(self.session) | ||||||
|         return n |         return n | ||||||
| 
 | 
 | ||||||
|     def upsert(self, track): |     def upsert(self, track): | ||||||
|  | @ -142,7 +142,7 @@ class Container(object): | ||||||
|         stmt = self.table.delete().where(self.table.c.trackid == trackId) |         stmt = self.table.delete().where(self.table.c.trackid == trackId) | ||||||
|         n = self.session.execute(stmt).rowcount |         n = self.session.execute(stmt).rowcount | ||||||
|         if n > 0: |         if n > 0: | ||||||
|             mark_changed(self.session) |             self.db.mark_changed(self.session) | ||||||
|         return n |         return n | ||||||
| 
 | 
 | ||||||
|     def makeTrack(self, r): |     def makeTrack(self, r): | ||||||
|  | @ -175,7 +175,7 @@ class Container(object): | ||||||
| 
 | 
 | ||||||
| def createTable(storage, tableName, headcols, indexes=None): | def createTable(storage, tableName, headcols, indexes=None): | ||||||
|     metadata = storage.metadata |     metadata = storage.metadata | ||||||
|     cols = [Column('trackid', IdType, primary_key=True)] |     cols = [Column('trackid', storage.db.IdType, primary_key=True)] | ||||||
|     idxs = [] |     idxs = [] | ||||||
|     for ix, f in enumerate(headcols): |     for ix, f in enumerate(headcols): | ||||||
|         cols.append(Column(f.lower(), Text, nullable=False, server_default='')) |         cols.append(Column(f.lower(), Text, nullable=False, server_default='')) | ||||||
|  | @ -185,7 +185,7 @@ def createTable(storage, tableName, headcols, indexes=None): | ||||||
|         indexName = 'idx_%s_%d' % (tableName, (ix + 1)) |         indexName = 'idx_%s_%d' % (tableName, (ix + 1)) | ||||||
|         idxs.append(Index(indexName, *idef)) |         idxs.append(Index(indexName, *idef)) | ||||||
|     idxs.append(Index('idx_%s_ts' % tableName, 'timestamp')) |     idxs.append(Index('idx_%s_ts' % tableName, 'timestamp')) | ||||||
|     cols.append(Column('data', JsonType, nullable=False, server_default='{}')) |     cols.append(Column('data', storage.db.JsonType, nullable=False, server_default='{}')) | ||||||
|     table = Table(tableName, metadata, *(cols+idxs), extend_existing=True) |     table = Table(tableName, metadata, *(cols+idxs), extend_existing=True) | ||||||
|     metadata.create_all(storage.engine) |     metadata.create_all(storage.engine) | ||||||
|     return table |     return table | ||||||
|  |  | ||||||
|  | @ -12,21 +12,19 @@ config.dbpassword = 'secret' | ||||||
| config.dbschema = 'testing' | config.dbschema = 'testing' | ||||||
| 
 | 
 | ||||||
| # PostgreSQL-specific settings | # PostgreSQL-specific settings | ||||||
| from scopes.storage.db import postgres  | from scopes.storage.db.postgres import StorageFactory  | ||||||
| postgres.init() | factory = StorageFactory(config) | ||||||
|  | storage = factory(schema='testing') | ||||||
| 
 | 
 | ||||||
| import tlib | import tlib | ||||||
| tlib.init(config) |  | ||||||
| #factory = postgres.StorageFactory(config) |  | ||||||
| #storage = factory(schema='testing') |  | ||||||
| 
 | 
 | ||||||
| class Test(unittest.TestCase): | class Test(unittest.TestCase): | ||||||
| 
 | 
 | ||||||
|     def test_001_tracking(self): |     def test_001_tracking(self): | ||||||
|         tlib.test_tracking(self) |         tlib.test_tracking(self, storage) | ||||||
| 
 | 
 | ||||||
|     def test_002_folder(self): |     def test_002_folder(self): | ||||||
|         tlib.test_folder(self) |         tlib.test_folder(self, storage) | ||||||
| 
 | 
 | ||||||
| def suite(): | def suite(): | ||||||
|     return unittest.TestSuite(( |     return unittest.TestSuite(( | ||||||
|  |  | ||||||
|  | @ -5,19 +5,22 @@ | ||||||
| import unittest | import unittest | ||||||
| 
 | 
 | ||||||
| import config | import config | ||||||
|  | config.dbengine = 'sqlite' | ||||||
|  | config.dbname = 'var/test.db' | ||||||
|  | 
 | ||||||
|  | from scopes.storage.common import StorageFactory | ||||||
|  | factory = StorageFactory(config) | ||||||
|  | storage = factory(schema=None) | ||||||
|  | 
 | ||||||
| import tlib | import tlib | ||||||
| tlib.init(config) |  | ||||||
| #from scopes.storage.common import StorageFactory |  | ||||||
| #factory = StorageFactory(config) |  | ||||||
| #storage = factory(schema=None) |  | ||||||
| 
 | 
 | ||||||
| class Test(unittest.TestCase): | class Test(unittest.TestCase): | ||||||
| 
 | 
 | ||||||
|     def test_001_tracking(self): |     def test_001_tracking(self): | ||||||
|         tlib.test_tracking(self) |         tlib.test_tracking(self, storage) | ||||||
| 
 | 
 | ||||||
|     def test_002_folder(self): |     def test_002_folder(self): | ||||||
|         tlib.test_folder(self) |         tlib.test_folder(self, storage) | ||||||
| 
 | 
 | ||||||
| def suite(): | def suite(): | ||||||
|     return unittest.TestSuite(( |     return unittest.TestSuite(( | ||||||
|  |  | ||||||
|  | @ -3,18 +3,8 @@ | ||||||
| from datetime import datetime | from datetime import datetime | ||||||
| from scopes.storage import folder, tracking | from scopes.storage import folder, tracking | ||||||
| 
 | 
 | ||||||
| import scopes.storage.common |  | ||||||
| from scopes.storage.common import commit, Storage, getEngine, sessionFactory |  | ||||||
| 
 | 
 | ||||||
| def init(config): | def test_tracking(self, storage): | ||||||
|     global storage |  | ||||||
|     engine = getEngine(config.dbengine, config.dbname, config.dbuser, config.dbpassword)  |  | ||||||
|     scopes.storage.common.engine = engine |  | ||||||
|     scopes.storage.common.Session = sessionFactory(engine) |  | ||||||
|     storage = Storage(schema=config.dbschema) |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| def test_tracking(self): |  | ||||||
|         storage.dropTable('tracks') |         storage.dropTable('tracks') | ||||||
|         tracks = storage.create(tracking.Container) |         tracks = storage.create(tracking.Container) | ||||||
| 
 | 
 | ||||||
|  | @ -60,10 +50,10 @@ def test_tracking(self): | ||||||
|         self.assertEqual(n, 1) |         self.assertEqual(n, 1) | ||||||
|         self.assertEqual(tracks.get(31), None) |         self.assertEqual(tracks.get(31), None) | ||||||
| 
 | 
 | ||||||
|         commit(storage.session) |         storage.db.commit(storage.session) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def test_folder(self): | def test_folder(self, storage): | ||||||
|         storage.dropTable('folders') |         storage.dropTable('folders') | ||||||
|         root = folder.Root(storage) |         root = folder.Root(storage) | ||||||
|         self.assertEqual(list(root.keys()), []) |         self.assertEqual(list(root.keys()), []) | ||||||
|  | @ -76,5 +66,5 @@ def test_folder(self): | ||||||
|         self.assertEqual(ch1.parent, top.rid) |         self.assertEqual(ch1.parent, top.rid) | ||||||
|         self.assertEqual(list(top.keys()), ['child1']) |         self.assertEqual(list(top.keys()), ['child1']) | ||||||
| 
 | 
 | ||||||
|         commit(storage.session) |         storage.db.commit(storage.session) | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
		Loading…
	
	Add table
		
		Reference in a new issue