use separate StorageFactory objects for different database / database types

This commit is contained in:
Helmut Merz 2024-03-09 08:57:49 +01:00
parent fbe8d99d74
commit e3bb5b2880
6 changed files with 70 additions and 88 deletions

View file

@ -9,58 +9,46 @@ from sqlalchemy.dialects.sqlite import JSON
import threading import threading
# predefined db-specific definitions, usable for SQLite; class StorageFactory(object):
# may be overriden by import of ``scopes.storage.db.<dbname>``
def sessionFactory(engine): def sessionFactory(self):
return engine.connect return self.engine.connect
@staticmethod
def getEngine(dbtype, dbname, user, pw, host='localhost', port=5432, **kw): def getEngine(dbtype, dbname, user, pw, host='localhost', port=5432, **kw):
return create_engine('%s:///%s' % (dbtype, dbname), **kw) return create_engine('%s:///%s' % (dbtype, dbname), **kw)
@staticmethod
def mark_changed(session): def mark_changed(session):
pass pass
@staticmethod
def commit(conn): def commit(conn):
conn.commit() conn.commit()
IdType = Integer IdType = Integer
JsonType = JSON JsonType = JSON
def __init__(self, config):
class StorageFactory(object):
engine = Session = None
sessionFactory = sessionFactory
getEngine = getEngine
mark_changed = mark_changed
commit = commit
IdType = IdType
JsonType = JsonType
def __call__(self, schema=None):
st = Storage(schema=schema)
st.setup(self)
return st
def setup(self, config):
self.engine = self.getEngine(config.dbengine, config.dbname, self.engine = self.getEngine(config.dbengine, config.dbname,
config.dbuser, config.dbpassword) config.dbuser, config.dbpassword)
self.Session = self.sessionFactory self.Session = self.sessionFactory()
def __call__(self, schema=None):
return Storage(self, schema=schema)
# you may put something like this in your code: # you may put something like this in your code:
#scopes.storage.common.factory = StorageFactory(config) #factory = StorageFactory(config)
# and then call at appropriate places: # and then call at appropriate places:
#storage = scopes.storage.common.factory(schema=...) #storage = scopes.storage.common.factory(schema=...)
class Storage(object): class Storage(object):
def __init__(self, schema=None): def __init__(self, db, schema=None):
self.engine = engine self.db = db
self.session = Session() self.engine = db.engine
self.session = db.Session()
self.schema = schema self.schema = schema
self.metadata = MetaData(schema=schema) self.metadata = MetaData(schema=schema)
self.containers = {} self.containers = {}

View file

@ -9,26 +9,29 @@ from sqlalchemy.orm import scoped_session, sessionmaker
import transaction import transaction
from zope.sqlalchemy import register, mark_changed from zope.sqlalchemy import register, mark_changed
from scopes.storage.common import StorageFactory
def sessionFactory(engine):
Session = scoped_session(sessionmaker(bind=engine, twophase=True)) class StorageFactory(StorageFactory):
def sessionFactory(self):
Session = scoped_session(sessionmaker(bind=self.engine, twophase=True))
register(Session) register(Session)
return Session return Session
@staticmethod
def getEngine(dbtype, dbname, user, pw, host='localhost', port=5432, **kw): def getEngine(dbtype, dbname, user, pw, host='localhost', port=5432, **kw):
return create_engine('%s://%s:%s@%s:%s/%s' % ( return create_engine('%s://%s:%s@%s:%s/%s' % (
dbtype, user, pw, host, port, dbname), **kw) dbtype, user, pw, host, port, dbname), **kw)
@staticmethod
def mark_changed(session):
return mark_changed(session)
@staticmethod
def commit(conn): def commit(conn):
transaction.commit() transaction.commit()
# patch `common` module IdType = BigInteger
import scopes.storage.common JsonType = JSONB
def init():
scopes.storage.common.IdType = BigInteger
scopes.storage.common.JsonType = JSONB
scopes.storage.common.sessionFactory = sessionFactory
scopes.storage.common.getEngine = getEngine
scopes.storage.common.mark_changed = mark_changed
scopes.storage.common.commit = commit

View file

@ -12,7 +12,6 @@ from sqlalchemy import Table, Column, Index
from sqlalchemy import DateTime, Text, func from sqlalchemy import DateTime, Text, func
from sqlalchemy import and_ from sqlalchemy import and_
from scopes.storage.common import commit, IdType, JsonType, mark_changed
from scopes.storage.common import registerContainerClass from scopes.storage.common import registerContainerClass
@ -77,6 +76,7 @@ class Container(object):
def __init__(self, storage): def __init__(self, storage):
self.storage = storage self.storage = storage
self.db = storage.db
self.session = storage.session self.session = storage.session
self.table = self.getTable() self.table = self.getTable()
@ -113,7 +113,7 @@ class Container(object):
values = self.setupValues(track, withTrackId) values = self.setupValues(track, withTrackId)
stmt = t.insert().values(**values).returning(t.c.trackid) stmt = t.insert().values(**values).returning(t.c.trackid)
trackId = self.session.execute(stmt).first()[0] trackId = self.session.execute(stmt).first()[0]
mark_changed(self.session) self.db.mark_changed(self.session)
return trackId return trackId
def update(self, track): def update(self, track):
@ -124,7 +124,7 @@ class Container(object):
stmt = t.update().values(**values).where(t.c.trackid == track.trackId) stmt = t.update().values(**values).where(t.c.trackid == track.trackId)
n = self.session.execute(stmt).rowcount n = self.session.execute(stmt).rowcount
if n > 0: if n > 0:
mark_changed(self.session) self.db.mark_changed(self.session)
return n return n
def upsert(self, track): def upsert(self, track):
@ -142,7 +142,7 @@ class Container(object):
stmt = self.table.delete().where(self.table.c.trackid == trackId) stmt = self.table.delete().where(self.table.c.trackid == trackId)
n = self.session.execute(stmt).rowcount n = self.session.execute(stmt).rowcount
if n > 0: if n > 0:
mark_changed(self.session) self.db.mark_changed(self.session)
return n return n
def makeTrack(self, r): def makeTrack(self, r):
@ -175,7 +175,7 @@ class Container(object):
def createTable(storage, tableName, headcols, indexes=None): def createTable(storage, tableName, headcols, indexes=None):
metadata = storage.metadata metadata = storage.metadata
cols = [Column('trackid', IdType, primary_key=True)] cols = [Column('trackid', storage.db.IdType, primary_key=True)]
idxs = [] idxs = []
for ix, f in enumerate(headcols): for ix, f in enumerate(headcols):
cols.append(Column(f.lower(), Text, nullable=False, server_default='')) cols.append(Column(f.lower(), Text, nullable=False, server_default=''))
@ -185,7 +185,7 @@ def createTable(storage, tableName, headcols, indexes=None):
indexName = 'idx_%s_%d' % (tableName, (ix + 1)) indexName = 'idx_%s_%d' % (tableName, (ix + 1))
idxs.append(Index(indexName, *idef)) idxs.append(Index(indexName, *idef))
idxs.append(Index('idx_%s_ts' % tableName, 'timestamp')) idxs.append(Index('idx_%s_ts' % tableName, 'timestamp'))
cols.append(Column('data', JsonType, nullable=False, server_default='{}')) cols.append(Column('data', storage.db.JsonType, nullable=False, server_default='{}'))
table = Table(tableName, metadata, *(cols+idxs), extend_existing=True) table = Table(tableName, metadata, *(cols+idxs), extend_existing=True)
metadata.create_all(storage.engine) metadata.create_all(storage.engine)
return table return table

View file

@ -12,21 +12,19 @@ config.dbpassword = 'secret'
config.dbschema = 'testing' config.dbschema = 'testing'
# PostgreSQL-specific settings # PostgreSQL-specific settings
from scopes.storage.db import postgres from scopes.storage.db.postgres import StorageFactory
postgres.init() factory = StorageFactory(config)
storage = factory(schema='testing')
import tlib import tlib
tlib.init(config)
#factory = postgres.StorageFactory(config)
#storage = factory(schema='testing')
class Test(unittest.TestCase): class Test(unittest.TestCase):
def test_001_tracking(self): def test_001_tracking(self):
tlib.test_tracking(self) tlib.test_tracking(self, storage)
def test_002_folder(self): def test_002_folder(self):
tlib.test_folder(self) tlib.test_folder(self, storage)
def suite(): def suite():
return unittest.TestSuite(( return unittest.TestSuite((

View file

@ -5,19 +5,22 @@
import unittest import unittest
import config import config
config.dbengine = 'sqlite'
config.dbname = 'var/test.db'
from scopes.storage.common import StorageFactory
factory = StorageFactory(config)
storage = factory(schema=None)
import tlib import tlib
tlib.init(config)
#from scopes.storage.common import StorageFactory
#factory = StorageFactory(config)
#storage = factory(schema=None)
class Test(unittest.TestCase): class Test(unittest.TestCase):
def test_001_tracking(self): def test_001_tracking(self):
tlib.test_tracking(self) tlib.test_tracking(self, storage)
def test_002_folder(self): def test_002_folder(self):
tlib.test_folder(self) tlib.test_folder(self, storage)
def suite(): def suite():
return unittest.TestSuite(( return unittest.TestSuite((

View file

@ -3,18 +3,8 @@
from datetime import datetime from datetime import datetime
from scopes.storage import folder, tracking from scopes.storage import folder, tracking
import scopes.storage.common
from scopes.storage.common import commit, Storage, getEngine, sessionFactory
def init(config): def test_tracking(self, storage):
global storage
engine = getEngine(config.dbengine, config.dbname, config.dbuser, config.dbpassword)
scopes.storage.common.engine = engine
scopes.storage.common.Session = sessionFactory(engine)
storage = Storage(schema=config.dbschema)
def test_tracking(self):
storage.dropTable('tracks') storage.dropTable('tracks')
tracks = storage.create(tracking.Container) tracks = storage.create(tracking.Container)
@ -60,10 +50,10 @@ def test_tracking(self):
self.assertEqual(n, 1) self.assertEqual(n, 1)
self.assertEqual(tracks.get(31), None) self.assertEqual(tracks.get(31), None)
commit(storage.session) storage.db.commit(storage.session)
def test_folder(self): def test_folder(self, storage):
storage.dropTable('folders') storage.dropTable('folders')
root = folder.Root(storage) root = folder.Root(storage)
self.assertEqual(list(root.keys()), []) self.assertEqual(list(root.keys()), [])
@ -76,5 +66,5 @@ def test_folder(self):
self.assertEqual(ch1.parent, top.rid) self.assertEqual(ch1.parent, top.rid)
self.assertEqual(list(top.keys()), ['child1']) self.assertEqual(list(top.keys()), ['child1'])
commit(storage.session) storage.db.commit(storage.session)