diff --git a/scopes/storage/common.py b/scopes/storage/common.py index 9e6142e..cd8a70c 100644 --- a/scopes/storage/common.py +++ b/scopes/storage/common.py @@ -27,10 +27,33 @@ def commit(conn): IdType = Integer JsonType = JSON -# put something like this in code before first creating a Storage object -#engine = getEngine('postgresql+psycopg', 'testdb', 'testuser', 'secret') -#scopes.storage.common.engine = engine -#scopes.storage.common.Session = sessionFactory(engine) + +class StorageFactory(object): + + engine = Session = None + + sessionFactory = sessionFactory + getEngine = getEngine + mark_changed = mark_changed + commit = commit + IdType = IdType + JsonType = JsonType + + def __call__(self, schema=None): + st = Storage(schema=schema) + st.setup(self) + return st + + def setup(self, config): + self.engine = self.getEngine(config.dbengine, config.dbname, + config.dbuser, config.dbpassword) + self.Session = self.sessionFactory + + +# you may put something like this in your code: +#scopes.storage.common.factory = StorageFactory(config) +# and then call at appropriate places: +#storage = scopes.storage.common.factory(schema=...) class Storage(object): diff --git a/scopes/storage/db/postgres.py b/scopes/storage/db/postgres.py index beaa34a..c8a6e5d 100644 --- a/scopes/storage/db/postgres.py +++ b/scopes/storage/db/postgres.py @@ -24,10 +24,11 @@ def commit(conn): # patch `common` module import scopes.storage.common -scopes.storage.common.IdType = BigInteger -scopes.storage.common.JsonType = JSONB -scopes.storage.common.sessionFactory = sessionFactory -scopes.storage.common.getEngine = getEngine -scopes.storage.common.mark_changed = mark_changed -scopes.storage.common.commit = commit +def init(): + scopes.storage.common.IdType = BigInteger + scopes.storage.common.JsonType = JSONB + scopes.storage.common.sessionFactory = sessionFactory + scopes.storage.common.getEngine = getEngine + scopes.storage.common.mark_changed = mark_changed + scopes.storage.common.commit = commit diff --git a/tests/postgres.py b/tests/postgres.py index f7fe524..c6e261e 100644 --- a/tests/postgres.py +++ b/tests/postgres.py @@ -2,11 +2,8 @@ """Tests for the 'scopes.storage' package - using PostgreSQL.""" -from datetime import datetime import unittest -# PostgreSQL-specific settings -import scopes.storage.db.postgres import config config.dbengine = 'postgresql+psycopg' config.dbname = 'testdb' @@ -14,7 +11,14 @@ config.dbuser = 'testuser' config.dbpassword = 'secret' config.dbschema = 'testing' +# PostgreSQL-specific settings +from scopes.storage.db import postgres +postgres.init() + import tlib +tlib.init(config) +#factory = postgres.StorageFactory(config) +#storage = factory(schema='testing') class Test(unittest.TestCase): diff --git a/tests/standard.py b/tests/test_standard.py similarity index 74% rename from tests/standard.py rename to tests/test_standard.py index cffa9fd..cb8554a 100644 --- a/tests/standard.py +++ b/tests/test_standard.py @@ -2,10 +2,14 @@ """Tests for the 'scopes.storage' package.""" -from datetime import datetime import unittest +import config import tlib +tlib.init(config) +#from scopes.storage.common import StorageFactory +#factory = StorageFactory(config) +#storage = factory(schema=None) class Test(unittest.TestCase): diff --git a/tests/tlib.py b/tests/tlib.py index dac6db3..ebd1bcd 100644 --- a/tests/tlib.py +++ b/tests/tlib.py @@ -1,17 +1,17 @@ """The real test implementations""" -import config from datetime import datetime from scopes.storage import folder, tracking import scopes.storage.common from scopes.storage.common import commit, Storage, getEngine, sessionFactory -engine = getEngine(config.dbengine, config.dbname, config.dbuser, config.dbpassword) -scopes.storage.common.engine = engine -scopes.storage.common.Session = sessionFactory(engine) - -storage = Storage(schema=config.dbschema) +def init(config): + global storage + engine = getEngine(config.dbengine, config.dbname, config.dbuser, config.dbpassword) + scopes.storage.common.engine = engine + scopes.storage.common.Session = sessionFactory(engine) + storage = Storage(schema=config.dbschema) def test_tracking(self):