use separate StorageFactory objects for different database / database types
This commit is contained in:
parent
fbe8d99d74
commit
e3bb5b2880
6 changed files with 70 additions and 88 deletions
|
@ -9,58 +9,46 @@ from sqlalchemy.dialects.sqlite import JSON
|
||||||
import threading
|
import threading
|
||||||
|
|
||||||
|
|
||||||
# predefined db-specific definitions, usable for SQLite;
|
class StorageFactory(object):
|
||||||
# may be overriden by import of ``scopes.storage.db.<dbname>``
|
|
||||||
|
|
||||||
def sessionFactory(engine):
|
def sessionFactory(self):
|
||||||
return engine.connect
|
return self.engine.connect
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
def getEngine(dbtype, dbname, user, pw, host='localhost', port=5432, **kw):
|
def getEngine(dbtype, dbname, user, pw, host='localhost', port=5432, **kw):
|
||||||
return create_engine('%s:///%s' % (dbtype, dbname), **kw)
|
return create_engine('%s:///%s' % (dbtype, dbname), **kw)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
def mark_changed(session):
|
def mark_changed(session):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
def commit(conn):
|
def commit(conn):
|
||||||
conn.commit()
|
conn.commit()
|
||||||
|
|
||||||
IdType = Integer
|
IdType = Integer
|
||||||
JsonType = JSON
|
JsonType = JSON
|
||||||
|
|
||||||
|
def __init__(self, config):
|
||||||
class StorageFactory(object):
|
|
||||||
|
|
||||||
engine = Session = None
|
|
||||||
|
|
||||||
sessionFactory = sessionFactory
|
|
||||||
getEngine = getEngine
|
|
||||||
mark_changed = mark_changed
|
|
||||||
commit = commit
|
|
||||||
IdType = IdType
|
|
||||||
JsonType = JsonType
|
|
||||||
|
|
||||||
def __call__(self, schema=None):
|
|
||||||
st = Storage(schema=schema)
|
|
||||||
st.setup(self)
|
|
||||||
return st
|
|
||||||
|
|
||||||
def setup(self, config):
|
|
||||||
self.engine = self.getEngine(config.dbengine, config.dbname,
|
self.engine = self.getEngine(config.dbengine, config.dbname,
|
||||||
config.dbuser, config.dbpassword)
|
config.dbuser, config.dbpassword)
|
||||||
self.Session = self.sessionFactory
|
self.Session = self.sessionFactory()
|
||||||
|
|
||||||
|
def __call__(self, schema=None):
|
||||||
|
return Storage(self, schema=schema)
|
||||||
|
|
||||||
|
|
||||||
# you may put something like this in your code:
|
# you may put something like this in your code:
|
||||||
#scopes.storage.common.factory = StorageFactory(config)
|
#factory = StorageFactory(config)
|
||||||
# and then call at appropriate places:
|
# and then call at appropriate places:
|
||||||
#storage = scopes.storage.common.factory(schema=...)
|
#storage = scopes.storage.common.factory(schema=...)
|
||||||
|
|
||||||
|
|
||||||
class Storage(object):
|
class Storage(object):
|
||||||
|
|
||||||
def __init__(self, schema=None):
|
def __init__(self, db, schema=None):
|
||||||
self.engine = engine
|
self.db = db
|
||||||
self.session = Session()
|
self.engine = db.engine
|
||||||
|
self.session = db.Session()
|
||||||
self.schema = schema
|
self.schema = schema
|
||||||
self.metadata = MetaData(schema=schema)
|
self.metadata = MetaData(schema=schema)
|
||||||
self.containers = {}
|
self.containers = {}
|
||||||
|
|
|
@ -9,26 +9,29 @@ from sqlalchemy.orm import scoped_session, sessionmaker
|
||||||
import transaction
|
import transaction
|
||||||
from zope.sqlalchemy import register, mark_changed
|
from zope.sqlalchemy import register, mark_changed
|
||||||
|
|
||||||
|
from scopes.storage.common import StorageFactory
|
||||||
|
|
||||||
def sessionFactory(engine):
|
|
||||||
Session = scoped_session(sessionmaker(bind=engine, twophase=True))
|
class StorageFactory(StorageFactory):
|
||||||
|
|
||||||
|
def sessionFactory(self):
|
||||||
|
Session = scoped_session(sessionmaker(bind=self.engine, twophase=True))
|
||||||
register(Session)
|
register(Session)
|
||||||
return Session
|
return Session
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
def getEngine(dbtype, dbname, user, pw, host='localhost', port=5432, **kw):
|
def getEngine(dbtype, dbname, user, pw, host='localhost', port=5432, **kw):
|
||||||
return create_engine('%s://%s:%s@%s:%s/%s' % (
|
return create_engine('%s://%s:%s@%s:%s/%s' % (
|
||||||
dbtype, user, pw, host, port, dbname), **kw)
|
dbtype, user, pw, host, port, dbname), **kw)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def mark_changed(session):
|
||||||
|
return mark_changed(session)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
def commit(conn):
|
def commit(conn):
|
||||||
transaction.commit()
|
transaction.commit()
|
||||||
|
|
||||||
# patch `common` module
|
IdType = BigInteger
|
||||||
import scopes.storage.common
|
JsonType = JSONB
|
||||||
def init():
|
|
||||||
scopes.storage.common.IdType = BigInteger
|
|
||||||
scopes.storage.common.JsonType = JSONB
|
|
||||||
scopes.storage.common.sessionFactory = sessionFactory
|
|
||||||
scopes.storage.common.getEngine = getEngine
|
|
||||||
scopes.storage.common.mark_changed = mark_changed
|
|
||||||
scopes.storage.common.commit = commit
|
|
||||||
|
|
||||||
|
|
|
@ -12,7 +12,6 @@ from sqlalchemy import Table, Column, Index
|
||||||
from sqlalchemy import DateTime, Text, func
|
from sqlalchemy import DateTime, Text, func
|
||||||
from sqlalchemy import and_
|
from sqlalchemy import and_
|
||||||
|
|
||||||
from scopes.storage.common import commit, IdType, JsonType, mark_changed
|
|
||||||
from scopes.storage.common import registerContainerClass
|
from scopes.storage.common import registerContainerClass
|
||||||
|
|
||||||
|
|
||||||
|
@ -77,6 +76,7 @@ class Container(object):
|
||||||
|
|
||||||
def __init__(self, storage):
|
def __init__(self, storage):
|
||||||
self.storage = storage
|
self.storage = storage
|
||||||
|
self.db = storage.db
|
||||||
self.session = storage.session
|
self.session = storage.session
|
||||||
self.table = self.getTable()
|
self.table = self.getTable()
|
||||||
|
|
||||||
|
@ -113,7 +113,7 @@ class Container(object):
|
||||||
values = self.setupValues(track, withTrackId)
|
values = self.setupValues(track, withTrackId)
|
||||||
stmt = t.insert().values(**values).returning(t.c.trackid)
|
stmt = t.insert().values(**values).returning(t.c.trackid)
|
||||||
trackId = self.session.execute(stmt).first()[0]
|
trackId = self.session.execute(stmt).first()[0]
|
||||||
mark_changed(self.session)
|
self.db.mark_changed(self.session)
|
||||||
return trackId
|
return trackId
|
||||||
|
|
||||||
def update(self, track):
|
def update(self, track):
|
||||||
|
@ -124,7 +124,7 @@ class Container(object):
|
||||||
stmt = t.update().values(**values).where(t.c.trackid == track.trackId)
|
stmt = t.update().values(**values).where(t.c.trackid == track.trackId)
|
||||||
n = self.session.execute(stmt).rowcount
|
n = self.session.execute(stmt).rowcount
|
||||||
if n > 0:
|
if n > 0:
|
||||||
mark_changed(self.session)
|
self.db.mark_changed(self.session)
|
||||||
return n
|
return n
|
||||||
|
|
||||||
def upsert(self, track):
|
def upsert(self, track):
|
||||||
|
@ -142,7 +142,7 @@ class Container(object):
|
||||||
stmt = self.table.delete().where(self.table.c.trackid == trackId)
|
stmt = self.table.delete().where(self.table.c.trackid == trackId)
|
||||||
n = self.session.execute(stmt).rowcount
|
n = self.session.execute(stmt).rowcount
|
||||||
if n > 0:
|
if n > 0:
|
||||||
mark_changed(self.session)
|
self.db.mark_changed(self.session)
|
||||||
return n
|
return n
|
||||||
|
|
||||||
def makeTrack(self, r):
|
def makeTrack(self, r):
|
||||||
|
@ -175,7 +175,7 @@ class Container(object):
|
||||||
|
|
||||||
def createTable(storage, tableName, headcols, indexes=None):
|
def createTable(storage, tableName, headcols, indexes=None):
|
||||||
metadata = storage.metadata
|
metadata = storage.metadata
|
||||||
cols = [Column('trackid', IdType, primary_key=True)]
|
cols = [Column('trackid', storage.db.IdType, primary_key=True)]
|
||||||
idxs = []
|
idxs = []
|
||||||
for ix, f in enumerate(headcols):
|
for ix, f in enumerate(headcols):
|
||||||
cols.append(Column(f.lower(), Text, nullable=False, server_default=''))
|
cols.append(Column(f.lower(), Text, nullable=False, server_default=''))
|
||||||
|
@ -185,7 +185,7 @@ def createTable(storage, tableName, headcols, indexes=None):
|
||||||
indexName = 'idx_%s_%d' % (tableName, (ix + 1))
|
indexName = 'idx_%s_%d' % (tableName, (ix + 1))
|
||||||
idxs.append(Index(indexName, *idef))
|
idxs.append(Index(indexName, *idef))
|
||||||
idxs.append(Index('idx_%s_ts' % tableName, 'timestamp'))
|
idxs.append(Index('idx_%s_ts' % tableName, 'timestamp'))
|
||||||
cols.append(Column('data', JsonType, nullable=False, server_default='{}'))
|
cols.append(Column('data', storage.db.JsonType, nullable=False, server_default='{}'))
|
||||||
table = Table(tableName, metadata, *(cols+idxs), extend_existing=True)
|
table = Table(tableName, metadata, *(cols+idxs), extend_existing=True)
|
||||||
metadata.create_all(storage.engine)
|
metadata.create_all(storage.engine)
|
||||||
return table
|
return table
|
||||||
|
|
|
@ -12,21 +12,19 @@ config.dbpassword = 'secret'
|
||||||
config.dbschema = 'testing'
|
config.dbschema = 'testing'
|
||||||
|
|
||||||
# PostgreSQL-specific settings
|
# PostgreSQL-specific settings
|
||||||
from scopes.storage.db import postgres
|
from scopes.storage.db.postgres import StorageFactory
|
||||||
postgres.init()
|
factory = StorageFactory(config)
|
||||||
|
storage = factory(schema='testing')
|
||||||
|
|
||||||
import tlib
|
import tlib
|
||||||
tlib.init(config)
|
|
||||||
#factory = postgres.StorageFactory(config)
|
|
||||||
#storage = factory(schema='testing')
|
|
||||||
|
|
||||||
class Test(unittest.TestCase):
|
class Test(unittest.TestCase):
|
||||||
|
|
||||||
def test_001_tracking(self):
|
def test_001_tracking(self):
|
||||||
tlib.test_tracking(self)
|
tlib.test_tracking(self, storage)
|
||||||
|
|
||||||
def test_002_folder(self):
|
def test_002_folder(self):
|
||||||
tlib.test_folder(self)
|
tlib.test_folder(self, storage)
|
||||||
|
|
||||||
def suite():
|
def suite():
|
||||||
return unittest.TestSuite((
|
return unittest.TestSuite((
|
||||||
|
|
|
@ -5,19 +5,22 @@
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import config
|
import config
|
||||||
|
config.dbengine = 'sqlite'
|
||||||
|
config.dbname = 'var/test.db'
|
||||||
|
|
||||||
|
from scopes.storage.common import StorageFactory
|
||||||
|
factory = StorageFactory(config)
|
||||||
|
storage = factory(schema=None)
|
||||||
|
|
||||||
import tlib
|
import tlib
|
||||||
tlib.init(config)
|
|
||||||
#from scopes.storage.common import StorageFactory
|
|
||||||
#factory = StorageFactory(config)
|
|
||||||
#storage = factory(schema=None)
|
|
||||||
|
|
||||||
class Test(unittest.TestCase):
|
class Test(unittest.TestCase):
|
||||||
|
|
||||||
def test_001_tracking(self):
|
def test_001_tracking(self):
|
||||||
tlib.test_tracking(self)
|
tlib.test_tracking(self, storage)
|
||||||
|
|
||||||
def test_002_folder(self):
|
def test_002_folder(self):
|
||||||
tlib.test_folder(self)
|
tlib.test_folder(self, storage)
|
||||||
|
|
||||||
def suite():
|
def suite():
|
||||||
return unittest.TestSuite((
|
return unittest.TestSuite((
|
||||||
|
|
|
@ -3,18 +3,8 @@
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from scopes.storage import folder, tracking
|
from scopes.storage import folder, tracking
|
||||||
|
|
||||||
import scopes.storage.common
|
|
||||||
from scopes.storage.common import commit, Storage, getEngine, sessionFactory
|
|
||||||
|
|
||||||
def init(config):
|
def test_tracking(self, storage):
|
||||||
global storage
|
|
||||||
engine = getEngine(config.dbengine, config.dbname, config.dbuser, config.dbpassword)
|
|
||||||
scopes.storage.common.engine = engine
|
|
||||||
scopes.storage.common.Session = sessionFactory(engine)
|
|
||||||
storage = Storage(schema=config.dbschema)
|
|
||||||
|
|
||||||
|
|
||||||
def test_tracking(self):
|
|
||||||
storage.dropTable('tracks')
|
storage.dropTable('tracks')
|
||||||
tracks = storage.create(tracking.Container)
|
tracks = storage.create(tracking.Container)
|
||||||
|
|
||||||
|
@ -60,10 +50,10 @@ def test_tracking(self):
|
||||||
self.assertEqual(n, 1)
|
self.assertEqual(n, 1)
|
||||||
self.assertEqual(tracks.get(31), None)
|
self.assertEqual(tracks.get(31), None)
|
||||||
|
|
||||||
commit(storage.session)
|
storage.db.commit(storage.session)
|
||||||
|
|
||||||
|
|
||||||
def test_folder(self):
|
def test_folder(self, storage):
|
||||||
storage.dropTable('folders')
|
storage.dropTable('folders')
|
||||||
root = folder.Root(storage)
|
root = folder.Root(storage)
|
||||||
self.assertEqual(list(root.keys()), [])
|
self.assertEqual(list(root.keys()), [])
|
||||||
|
@ -76,5 +66,5 @@ def test_folder(self):
|
||||||
self.assertEqual(ch1.parent, top.rid)
|
self.assertEqual(ch1.parent, top.rid)
|
||||||
self.assertEqual(list(top.keys()), ['child1'])
|
self.assertEqual(list(top.keys()), ['child1'])
|
||||||
|
|
||||||
commit(storage.session)
|
storage.db.commit(storage.session)
|
||||||
|
|
||||||
|
|
Loading…
Add table
Reference in a new issue