diff options
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r-- | lib/sqlalchemy/testing/config.py | 15 | ||||
-rw-r--r-- | lib/sqlalchemy/testing/plugin/plugin_base.py | 18 | ||||
-rw-r--r-- | lib/sqlalchemy/testing/plugin/provision.py | 166 | ||||
-rw-r--r-- | lib/sqlalchemy/testing/plugin/pytestplugin.py | 24 |
4 files changed, 181 insertions, 42 deletions
diff --git a/lib/sqlalchemy/testing/config.py b/lib/sqlalchemy/testing/config.py index 84344eb31..b24483bb7 100644 --- a/lib/sqlalchemy/testing/config.py +++ b/lib/sqlalchemy/testing/config.py @@ -12,7 +12,8 @@ db = None db_url = None db_opts = None file_config = None - +test_schema = None +test_schema_2 = None _current = None @@ -22,12 +23,14 @@ class Config(object): self.db_opts = db_opts self.options = options self.file_config = file_config + self.test_schema = "test_schema" + self.test_schema_2 = "test_schema_2" _stack = collections.deque() _configs = {} @classmethod - def register(cls, db, db_opts, options, file_config, namespace): + def register(cls, db, db_opts, options, file_config): """add a config as one of the global configs. If there are no configs set up yet, this config also @@ -35,18 +38,18 @@ class Config(object): """ cfg = Config(db, db_opts, options, file_config) - global _current - if not _current: - cls.set_as_current(cfg, namespace) cls._configs[cfg.db.name] = cfg cls._configs[(cfg.db.name, cfg.db.dialect)] = cfg cls._configs[cfg.db] = cfg + return cfg @classmethod def set_as_current(cls, config, namespace): - global db, _current, db_url + global db, _current, db_url, test_schema, test_schema_2 _current = config db_url = config.db.url + test_schema = config.test_schema + test_schema_2 = config.test_schema_2 namespace.db = db = config.db @classmethod diff --git a/lib/sqlalchemy/testing/plugin/plugin_base.py b/lib/sqlalchemy/testing/plugin/plugin_base.py index f16a0828f..095e3f369 100644 --- a/lib/sqlalchemy/testing/plugin/plugin_base.py +++ b/lib/sqlalchemy/testing/plugin/plugin_base.py @@ -103,7 +103,7 @@ def setup_options(make_option): def configure_follower(follower_ident): global FOLLOWER_IDENT - FOLLOWER_IDENT = "test_%s" % follower_ident + FOLLOWER_IDENT = follower_ident def read_config(): @@ -221,18 +221,20 @@ def _engine_uri(options, file_config): if not db_urls: db_urls.append(file_config.get('db', 'default')) + from . import provision + for db_url in db_urls: - if FOLLOWER_IDENT: - from sqlalchemy.engine import url - db_url = url.make_url(db_url) - db_url.database = FOLLOWER_IDENT - eng = engines.testing_engine(db_url, db_opts) - eng.connect().close() - config.Config.register(eng, db_opts, options, file_config, testing) + cfg = provision.setup_config( + db_url, db_opts, options, file_config, FOLLOWER_IDENT) + + if not config._current: + cfg.set_as_current(cfg, testing) config.db_opts = db_opts + + @post def _engine_pool(options, file_config): if options.mockpool: diff --git a/lib/sqlalchemy/testing/plugin/provision.py b/lib/sqlalchemy/testing/plugin/provision.py index e6790f877..7c54cd643 100644 --- a/lib/sqlalchemy/testing/plugin/provision.py +++ b/lib/sqlalchemy/testing/plugin/provision.py @@ -1,11 +1,73 @@ from sqlalchemy.engine import url as sa_url +from sqlalchemy import text +from sqlalchemy.util import compat +from .. import config, engines +import os + + +class register(object): + def __init__(self): + self.fns = {} + + @classmethod + def init(cls, fn): + return register().for_db("*")(fn) + + def for_db(self, dbname): + def decorate(fn): + self.fns[dbname] = fn + return self + return decorate + + def __call__(self, cfg, *arg): + if isinstance(cfg, compat.string_types): + url = sa_url.make_url(cfg) + elif isinstance(cfg, sa_url.URL): + url = cfg + else: + url = cfg.db.url + backend = url.get_backend_name() + if backend in self.fns: + return self.fns[backend](cfg, *arg) + else: + return self.fns['*'](cfg, *arg) def create_follower_db(follower_ident): - from .. import config, engines - follower_ident = "test_%s" % follower_ident + for cfg in _configs_for_db_operation(): + url = cfg.db.url + backend = url.get_backend_name() + _create_db(cfg, cfg.db, follower_ident) + + new_url = sa_url.make_url(str(url)) + + new_url.database = follower_ident + + +def configure_follower(follower_ident): + for cfg in config.Config.all_configs(): + _configure_follower(cfg, follower_ident) + + +def setup_config(db_url, db_opts, options, file_config, follower_ident): + if follower_ident: + db_url = _follower_url_from_main(db_url, follower_ident) + eng = engines.testing_engine(db_url, db_opts) + eng.connect().close() + cfg = config.Config.register(eng, db_opts, options, file_config) + if follower_ident: + _configure_follower(cfg, follower_ident) + return cfg + + +def drop_follower_db(follower_ident): + for cfg in _configs_for_db_operation(): + url = cfg.db.url + _drop_db(cfg, cfg.db, follower_ident) + +def _configs_for_db_operation(): hosts = set() for cfg in config.Config.all_configs(): @@ -19,47 +81,109 @@ def create_follower_db(follower_ident): url.username, url.host, url.database) if host_conf not in hosts: - if backend.startswith("postgresql"): - _pg_create_db(cfg.db, follower_ident) - elif backend.startswith("mysql"): - _mysql_create_db(cfg.db, follower_ident) + yield cfg + hosts.add(host_conf) - new_url = sa_url.make_url(str(url)) + for cfg in config.Config.all_configs(): + cfg.db.dispose() - new_url.database = follower_ident - eng = engines.testing_engine(new_url, cfg.db_opts) - if backend.startswith("postgresql"): - _pg_init_db(eng) - elif backend.startswith("mysql"): - _mysql_init_db(eng) +@register.init +def _create_db(cfg, eng, ident): + raise NotImplementedError("no DB creation routine for cfg: %s" % eng.url) - hosts.add(host_conf) + +@register.init +def _drop_db(cfg, eng, ident): + raise NotImplementedError("no DB drop routine for cfg: %s" % eng.url) + + +@register.init +def _configure_follower(cfg, ident): + pass + + +@register.init +def _follower_url_from_main(url, ident): + url = sa_url.make_url(url) + url.database = ident + return url + + +@_follower_url_from_main.for_db("sqlite") +def _sqlite_follower_url_from_main(url, ident): + return sa_url.make_url("sqlite:///%s.db" % ident) -def _pg_create_db(eng, ident): +@_create_db.for_db("postgresql") +def _pg_create_db(cfg, eng, ident): with eng.connect().execution_options( isolation_level="AUTOCOMMIT") as conn: try: - conn.execute("DROP DATABASE %s" % ident) + _pg_drop_db(cfg, conn, ident) except: pass currentdb = conn.scalar("select current_database()") conn.execute("CREATE DATABASE %s TEMPLATE %s" % (ident, currentdb)) -def _pg_init_db(eng): +@_create_db.for_db("mysql") +def _mysql_create_db(cfg, eng, ident): + with eng.connect() as conn: + try: + _mysql_drop_db(cfg, conn, ident) + except: + pass + conn.execute("CREATE DATABASE %s" % ident) + conn.execute("CREATE DATABASE %s_test_schema" % ident) + conn.execute("CREATE DATABASE %s_test_schema_2" % ident) + + +@_configure_follower.for_db("mysql") +def _mysql_configure_follower(config, ident): + config.test_schema = "%s_test_schema" % ident + config.test_schema_2 = "%s_test_schema_2" % ident + + +@_create_db.for_db("sqlite") +def _sqlite_create_db(cfg, eng, ident): pass -def _mysql_create_db(eng, ident): +@_drop_db.for_db("postgresql") +def _pg_drop_db(cfg, eng, ident): + with eng.connect().execution_options( + isolation_level="AUTOCOMMIT") as conn: + conn.execute( + text( + "select pg_terminate_backend(pid) from pg_stat_activity " + "where usename=current_user and pid != pg_backend_pid() " + "and datname=:dname" + ), dname=ident) + conn.execute("DROP DATABASE %s" % ident) + + +@_drop_db.for_db("sqlite") +def _sqlite_drop_db(cfg, eng, ident): + os.remove("%s.db" % ident) + + +@_drop_db.for_db("mysql") +def _mysql_drop_db(cfg, eng, ident): with eng.connect() as conn: try: + conn.execute("DROP DATABASE %s_test_schema" % ident) + except: + pass + try: + conn.execute("DROP DATABASE %s_test_schema_2" % ident) + except: + pass + try: conn.execute("DROP DATABASE %s" % ident) except: pass - conn.execute("CREATE DATABASE %s" % ident) -def _mysql_init_db(eng): - pass + + diff --git a/lib/sqlalchemy/testing/plugin/pytestplugin.py b/lib/sqlalchemy/testing/plugin/pytestplugin.py index 7bef644d9..7671c800c 100644 --- a/lib/sqlalchemy/testing/plugin/pytestplugin.py +++ b/lib/sqlalchemy/testing/plugin/pytestplugin.py @@ -5,6 +5,12 @@ from . import plugin_base import collections import itertools +try: + import xdist + has_xdist = True +except ImportError: + has_xdist = False + def pytest_addoption(parser): group = parser.getgroup("sqlalchemy") @@ -37,15 +43,19 @@ def pytest_configure(config): plugin_base.post_begin() -_follower_count = itertools.count(1) +if has_xdist: + _follower_count = itertools.count(1) + def pytest_configure_node(node): + # the master for each node fills slaveinput dictionary + # which pytest-xdist will transfer to the subprocess + node.slaveinput["follower_ident"] = "test_%s" % next(_follower_count) + from . import provision + provision.create_follower_db(node.slaveinput["follower_ident"]) -def pytest_configure_node(node): - # the master for each node fills slaveinput dictionary - # which pytest-xdist will transfer to the subprocess - node.slaveinput["follower_ident"] = next(_follower_count) - from . import provision - provision.create_follower_db(node.slaveinput["follower_ident"]) + def pytest_testnodedown(node, error): + from . import provision + provision.drop_follower_db(node.slaveinput["follower_ident"]) def pytest_collection_modifyitems(session, config, items): |