From e3bb5b288016b761883030c1eeb1eb6323168e6b Mon Sep 17 00:00:00 2001 From: Helmut Merz Date: Sat, 9 Mar 2024 08:57:49 +0100 Subject: [PATCH] use separate StorageFactory objects for different database / database types --- scopes/storage/common.py | 62 ++++++++++++++--------------------- scopes/storage/db/postgres.py | 39 ++++++++++++---------- scopes/storage/tracking.py | 12 +++---- tests/postgres.py | 12 +++---- tests/test_standard.py | 15 +++++---- tests/tlib.py | 18 +++------- 6 files changed, 70 insertions(+), 88 deletions(-) diff --git a/scopes/storage/common.py b/scopes/storage/common.py index cd8a70c..e26e4b1 100644 --- a/scopes/storage/common.py +++ b/scopes/storage/common.py @@ -9,58 +9,46 @@ from sqlalchemy.dialects.sqlite import JSON import threading -# predefined db-specific definitions, usable for SQLite; -# may be overriden by import of ``scopes.storage.db.`` - -def sessionFactory(engine): - return engine.connect - -def getEngine(dbtype, dbname, user, pw, host='localhost', port=5432, **kw): - return create_engine('%s:///%s' % (dbtype, dbname), **kw) - -def mark_changed(session): - pass - -def commit(conn): - conn.commit() - -IdType = Integer -JsonType = JSON - - class StorageFactory(object): - engine = Session = None + def sessionFactory(self): + return self.engine.connect - sessionFactory = sessionFactory - getEngine = getEngine - mark_changed = mark_changed - commit = commit - IdType = IdType - JsonType = JsonType + @staticmethod + def getEngine(dbtype, dbname, user, pw, host='localhost', port=5432, **kw): + return create_engine('%s:///%s' % (dbtype, dbname), **kw) - def __call__(self, schema=None): - st = Storage(schema=schema) - st.setup(self) - return st + @staticmethod + def mark_changed(session): + pass - def setup(self, config): + @staticmethod + def commit(conn): + conn.commit() + + IdType = Integer + JsonType = JSON + + def __init__(self, config): self.engine = self.getEngine(config.dbengine, config.dbname, config.dbuser, config.dbpassword) - self.Session = self.sessionFactory + self.Session = self.sessionFactory() + + def __call__(self, schema=None): + return Storage(self, schema=schema) # you may put something like this in your code: -#scopes.storage.common.factory = StorageFactory(config) +#factory = StorageFactory(config) # and then call at appropriate places: #storage = scopes.storage.common.factory(schema=...) - class Storage(object): - def __init__(self, schema=None): - self.engine = engine - self.session = Session() + def __init__(self, db, schema=None): + self.db = db + self.engine = db.engine + self.session = db.Session() self.schema = schema self.metadata = MetaData(schema=schema) self.containers = {} diff --git a/scopes/storage/db/postgres.py b/scopes/storage/db/postgres.py index c8a6e5d..ccc835a 100644 --- a/scopes/storage/db/postgres.py +++ b/scopes/storage/db/postgres.py @@ -9,26 +9,29 @@ from sqlalchemy.orm import scoped_session, sessionmaker import transaction from zope.sqlalchemy import register, mark_changed +from scopes.storage.common import StorageFactory -def sessionFactory(engine): - Session = scoped_session(sessionmaker(bind=engine, twophase=True)) - register(Session) - return Session -def getEngine(dbtype, dbname, user, pw, host='localhost', port=5432, **kw): - return create_engine('%s://%s:%s@%s:%s/%s' % ( - dbtype, user, pw, host, port, dbname), **kw) +class StorageFactory(StorageFactory): -def commit(conn): - transaction.commit() + def sessionFactory(self): + Session = scoped_session(sessionmaker(bind=self.engine, twophase=True)) + register(Session) + return Session -# patch `common` module -import scopes.storage.common -def init(): - scopes.storage.common.IdType = BigInteger - scopes.storage.common.JsonType = JSONB - scopes.storage.common.sessionFactory = sessionFactory - scopes.storage.common.getEngine = getEngine - scopes.storage.common.mark_changed = mark_changed - scopes.storage.common.commit = commit + @staticmethod + def getEngine(dbtype, dbname, user, pw, host='localhost', port=5432, **kw): + return create_engine('%s://%s:%s@%s:%s/%s' % ( + dbtype, user, pw, host, port, dbname), **kw) + + @staticmethod + def mark_changed(session): + return mark_changed(session) + + @staticmethod + def commit(conn): + transaction.commit() + + IdType = BigInteger + JsonType = JSONB diff --git a/scopes/storage/tracking.py b/scopes/storage/tracking.py index 76b7cf5..8e07aa3 100644 --- a/scopes/storage/tracking.py +++ b/scopes/storage/tracking.py @@ -12,7 +12,6 @@ from sqlalchemy import Table, Column, Index from sqlalchemy import DateTime, Text, func from sqlalchemy import and_ -from scopes.storage.common import commit, IdType, JsonType, mark_changed from scopes.storage.common import registerContainerClass @@ -77,6 +76,7 @@ class Container(object): def __init__(self, storage): self.storage = storage + self.db = storage.db self.session = storage.session self.table = self.getTable() @@ -113,7 +113,7 @@ class Container(object): values = self.setupValues(track, withTrackId) stmt = t.insert().values(**values).returning(t.c.trackid) trackId = self.session.execute(stmt).first()[0] - mark_changed(self.session) + self.db.mark_changed(self.session) return trackId def update(self, track): @@ -124,7 +124,7 @@ class Container(object): stmt = t.update().values(**values).where(t.c.trackid == track.trackId) n = self.session.execute(stmt).rowcount if n > 0: - mark_changed(self.session) + self.db.mark_changed(self.session) return n def upsert(self, track): @@ -142,7 +142,7 @@ class Container(object): stmt = self.table.delete().where(self.table.c.trackid == trackId) n = self.session.execute(stmt).rowcount if n > 0: - mark_changed(self.session) + self.db.mark_changed(self.session) return n def makeTrack(self, r): @@ -175,7 +175,7 @@ class Container(object): def createTable(storage, tableName, headcols, indexes=None): metadata = storage.metadata - cols = [Column('trackid', IdType, primary_key=True)] + cols = [Column('trackid', storage.db.IdType, primary_key=True)] idxs = [] for ix, f in enumerate(headcols): cols.append(Column(f.lower(), Text, nullable=False, server_default='')) @@ -185,7 +185,7 @@ def createTable(storage, tableName, headcols, indexes=None): indexName = 'idx_%s_%d' % (tableName, (ix + 1)) idxs.append(Index(indexName, *idef)) idxs.append(Index('idx_%s_ts' % tableName, 'timestamp')) - cols.append(Column('data', JsonType, nullable=False, server_default='{}')) + cols.append(Column('data', storage.db.JsonType, nullable=False, server_default='{}')) table = Table(tableName, metadata, *(cols+idxs), extend_existing=True) metadata.create_all(storage.engine) return table diff --git a/tests/postgres.py b/tests/postgres.py index c6e261e..f3a6309 100644 --- a/tests/postgres.py +++ b/tests/postgres.py @@ -12,21 +12,19 @@ config.dbpassword = 'secret' config.dbschema = 'testing' # PostgreSQL-specific settings -from scopes.storage.db import postgres -postgres.init() +from scopes.storage.db.postgres import StorageFactory +factory = StorageFactory(config) +storage = factory(schema='testing') import tlib -tlib.init(config) -#factory = postgres.StorageFactory(config) -#storage = factory(schema='testing') class Test(unittest.TestCase): def test_001_tracking(self): - tlib.test_tracking(self) + tlib.test_tracking(self, storage) def test_002_folder(self): - tlib.test_folder(self) + tlib.test_folder(self, storage) def suite(): return unittest.TestSuite(( diff --git a/tests/test_standard.py b/tests/test_standard.py index cb8554a..6ff4a7e 100644 --- a/tests/test_standard.py +++ b/tests/test_standard.py @@ -5,19 +5,22 @@ import unittest import config +config.dbengine = 'sqlite' +config.dbname = 'var/test.db' + +from scopes.storage.common import StorageFactory +factory = StorageFactory(config) +storage = factory(schema=None) + import tlib -tlib.init(config) -#from scopes.storage.common import StorageFactory -#factory = StorageFactory(config) -#storage = factory(schema=None) class Test(unittest.TestCase): def test_001_tracking(self): - tlib.test_tracking(self) + tlib.test_tracking(self, storage) def test_002_folder(self): - tlib.test_folder(self) + tlib.test_folder(self, storage) def suite(): return unittest.TestSuite(( diff --git a/tests/tlib.py b/tests/tlib.py index ebd1bcd..692e070 100644 --- a/tests/tlib.py +++ b/tests/tlib.py @@ -3,18 +3,8 @@ from datetime import datetime from scopes.storage import folder, tracking -import scopes.storage.common -from scopes.storage.common import commit, Storage, getEngine, sessionFactory -def init(config): - global storage - engine = getEngine(config.dbengine, config.dbname, config.dbuser, config.dbpassword) - scopes.storage.common.engine = engine - scopes.storage.common.Session = sessionFactory(engine) - storage = Storage(schema=config.dbschema) - - -def test_tracking(self): +def test_tracking(self, storage): storage.dropTable('tracks') tracks = storage.create(tracking.Container) @@ -60,10 +50,10 @@ def test_tracking(self): self.assertEqual(n, 1) self.assertEqual(tracks.get(31), None) - commit(storage.session) + storage.db.commit(storage.session) -def test_folder(self): +def test_folder(self, storage): storage.dropTable('folders') root = folder.Root(storage) self.assertEqual(list(root.keys()), []) @@ -76,5 +66,5 @@ def test_folder(self): self.assertEqual(ch1.parent, top.rid) self.assertEqual(list(top.keys()), ['child1']) - commit(storage.session) + storage.db.commit(storage.session)