diff options
-rw-r--r-- | oslo/db/sqlalchemy/utils.py | 206 | ||||
-rw-r--r-- | tests/sqlalchemy/test_utils.py | 245 |
2 files changed, 451 insertions, 0 deletions
diff --git a/oslo/db/sqlalchemy/utils.py b/oslo/db/sqlalchemy/utils.py index 179e416..5abd4ea 100644 --- a/oslo/db/sqlalchemy/utils.py +++ b/oslo/db/sqlalchemy/utils.py @@ -16,15 +16,19 @@ # License for the specific language governing permissions and limitations # under the License. +import collections import logging import re from oslo.utils import timeutils +import six import sqlalchemy from sqlalchemy import Boolean from sqlalchemy import CheckConstraint from sqlalchemy import Column +from sqlalchemy.engine import Connectable from sqlalchemy.engine import reflection +from sqlalchemy.engine import url as sa_url from sqlalchemy.ext.compiler import compiles from sqlalchemy import func from sqlalchemy import Index @@ -777,3 +781,205 @@ def column_exists(engine, table_name, column): """ t = get_table(engine, table_name) return column in t.c + + +class DialectFunctionDispatcher(object): + @classmethod + def dispatch_for_dialect(cls, expr, multiple=False): + """Provide dialect-specific functionality within distinct functions. + + e.g.:: + + @dispatch_for_dialect("*") + def set_special_option(engine): + pass + + @set_special_option.dispatch_for("sqlite") + def set_sqlite_special_option(engine): + return engine.execute("sqlite thing") + + @set_special_option.dispatch_for("mysql+mysqldb") + def set_mysqldb_special_option(engine): + return engine.execute("mysqldb thing") + + After the above registration, the ``set_special_option()`` function + is now a dispatcher, given a SQLAlchemy ``Engine``, ``Connection``, + URL string, or ``sqlalchemy.engine.URL`` object:: + + eng = create_engine('...') + result = set_special_option(eng) + + The filter system supports two modes, "multiple" and "single". + The default is "single", and requires that one and only one function + match for a given backend. In this mode, the function may also + have a return value, which will be returned by the top level + call. + + "multiple" mode, on the other hand, does not support return + arguments, but allows for any number of matching functions, where + each function will be called:: + + # the initial call sets this up as a "multiple" dispatcher + @dispatch_for_dialect("*", multiple=True) + def set_options(engine): + # set options that apply to *all* engines + + @set_options.dispatch_for("postgresql") + def set_postgresql_options(engine): + # set options that apply to all Postgresql engines + + @set_options.dispatch_for("postgresql+psycopg2") + def set_postgresql_psycopg2_options(engine): + # set options that apply only to "postgresql+psycopg2" + + @set_options.dispatch_for("*+pyodbc") + def set_pyodbc_options(engine): + # set options that apply to all pyodbc backends + + Note that in both modes, any number of additional arguments can be + accepted by member functions. For example, to populate a dictionary of + options, it may be passed in:: + + @dispatch_for_dialect("*", multiple=True) + def set_engine_options(url, opts): + pass + + @set_engine_options.dispatch_for("mysql+mysqldb") + def _mysql_set_default_charset_to_utf8(url, opts): + opts.setdefault('charset', 'utf-8') + + @set_engine_options.dispatch_for("sqlite") + def _set_sqlite_in_memory_check_same_thread(url, opts): + if url.database in (None, 'memory'): + opts['check_same_thread'] = False + + opts = {} + set_engine_options(url, opts) + + The driver specifiers are of the form: + ``<database | *>[+<driver | *>]``. That is, database name or "*", + followed by an optional ``+`` sign with driver or "*". Omitting + the driver name implies all drivers for that database. + + """ + if multiple: + cls = DialectMultiFunctionDispatcher + else: + cls = DialectSingleFunctionDispatcher + return cls().dispatch_for(expr) + + _db_plus_driver_reg = re.compile(r'([^+]+?)(?:\+(.+))?$') + + def dispatch_for(self, expr): + def decorate(fn): + dbname, driver = self._parse_dispatch(expr) + self._register(expr, dbname, driver, fn) + return self + return decorate + + def _parse_dispatch(self, text): + m = self._db_plus_driver_reg.match(text) + if not m: + raise ValueError("Couldn't parse database[+driver]: %r" % text) + return m.group(1) or '*', m.group(2) or '*' + + def __call__(self, *arg, **kw): + target = arg[0] + return self._dispatch_on( + self._url_from_target(target), target, arg, kw) + + def _url_from_target(self, target): + if isinstance(target, Connectable): + return target.engine.url + elif isinstance(target, six.string_types): + if "://" not in target: + target_url = sa_url.make_url("%s://" % target) + else: + target_url = sa_url.make_url(target) + return target_url + elif isinstance(target, sa_url.URL): + return target + else: + raise ValueError("Invalid target type: %r" % target) + + def dispatch_on_drivername(self, drivername): + """Return a sub-dispatcher for the given drivername. + + This provides a means of calling a different function, such as the + "*" function, for a given target object that normally refers + to a sub-function. + + """ + dbname, driver = self._db_plus_driver_reg.match(drivername).group(1, 2) + + def go(*arg, **kw): + return self._dispatch_on_db_driver(dbname, "*", arg, kw) + + return go + + def _dispatch_on(self, url, target, arg, kw): + dbname, driver = self._db_plus_driver_reg.match( + url.drivername).group(1, 2) + if not driver: + driver = url.get_dialect().driver + + return self._dispatch_on_db_driver(dbname, driver, arg, kw) + + def _invoke_fn(self, fn, arg, kw): + return fn(*arg, **kw) + + +class DialectSingleFunctionDispatcher(DialectFunctionDispatcher): + def __init__(self): + self.reg = collections.defaultdict(dict) + + def _register(self, expr, dbname, driver, fn): + fn_dict = self.reg[dbname] + if driver in fn_dict: + raise TypeError("Multiple functions for expression %r" % expr) + fn_dict[driver] = fn + + def _matches(self, dbname, driver): + for db in (dbname, '*'): + subdict = self.reg[db] + for drv in (driver, '*'): + if drv in subdict: + return subdict[drv] + else: + raise ValueError( + "No default function found for driver: %r" % + ("%s+%s" % (dbname, driver))) + + def _dispatch_on_db_driver(self, dbname, driver, arg, kw): + fn = self._matches(dbname, driver) + return self._invoke_fn(fn, arg, kw) + + +class DialectMultiFunctionDispatcher(DialectFunctionDispatcher): + def __init__(self): + self.reg = collections.defaultdict( + lambda: collections.defaultdict(list)) + + def _register(self, expr, dbname, driver, fn): + self.reg[dbname][driver].append(fn) + + def _matches(self, dbname, driver): + if driver != '*': + drivers = (driver, '*') + else: + drivers = ('*', ) + + for db in (dbname, '*'): + subdict = self.reg[db] + for drv in drivers: + for fn in subdict[drv]: + yield fn + + def _dispatch_on_db_driver(self, dbname, driver, arg, kw): + for fn in self._matches(dbname, driver): + if self._invoke_fn(fn, arg, kw) is not None: + raise TypeError( + "Return value not allowed for " + "multiple filtered function") + +dispatch_for_dialect = DialectFunctionDispatcher.dispatch_for_dialect diff --git a/tests/sqlalchemy/test_utils.py b/tests/sqlalchemy/test_utils.py index e1285f9..cdb456c 100644 --- a/tests/sqlalchemy/test_utils.py +++ b/tests/sqlalchemy/test_utils.py @@ -853,3 +853,248 @@ class TestUtilsMysqlOpportunistically( class TestUtilsPostgresqlOpportunistically( TestUtils, db_test_base.PostgreSQLOpportunisticTestCase): pass + + +class TestDialectFunctionDispatcher(test_base.BaseTestCase): + def _single_fixture(self): + callable_fn = mock.Mock() + + dispatcher = orig = utils.dispatch_for_dialect("*")( + callable_fn.default) + dispatcher = dispatcher.dispatch_for("sqlite")(callable_fn.sqlite) + dispatcher = dispatcher.dispatch_for("mysql+mysqldb")( + callable_fn.mysql_mysqldb) + dispatcher = dispatcher.dispatch_for("postgresql")( + callable_fn.postgresql) + + self.assertTrue(dispatcher is orig) + + return dispatcher, callable_fn + + def _multiple_fixture(self): + callable_fn = mock.Mock() + + for targ in [ + callable_fn.default, + callable_fn.sqlite, + callable_fn.mysql_mysqldb, + callable_fn.postgresql, + callable_fn.postgresql_psycopg2, + callable_fn.pyodbc + ]: + targ.return_value = None + + dispatcher = orig = utils.dispatch_for_dialect("*", multiple=True)( + callable_fn.default) + dispatcher = dispatcher.dispatch_for("sqlite")(callable_fn.sqlite) + dispatcher = dispatcher.dispatch_for("mysql+mysqldb")( + callable_fn.mysql_mysqldb) + dispatcher = dispatcher.dispatch_for("postgresql+*")( + callable_fn.postgresql) + dispatcher = dispatcher.dispatch_for("postgresql+psycopg2")( + callable_fn.postgresql_psycopg2) + dispatcher = dispatcher.dispatch_for("*+pyodbc")( + callable_fn.pyodbc) + + self.assertTrue(dispatcher is orig) + + return dispatcher, callable_fn + + def test_single(self): + + dispatcher, callable_fn = self._single_fixture() + dispatcher("sqlite://", 1) + dispatcher("postgresql+psycopg2://u:p@h/t", 2) + dispatcher("mysql://u:p@h/t", 3) + dispatcher("mysql+mysqlconnector://u:p@h/t", 4) + + self.assertEqual( + [ + mock.call.sqlite('sqlite://', 1), + mock.call.postgresql("postgresql+psycopg2://u:p@h/t", 2), + mock.call.mysql_mysqldb("mysql://u:p@h/t", 3), + mock.call.default("mysql+mysqlconnector://u:p@h/t", 4) + ], + callable_fn.mock_calls) + + def test_single_kwarg(self): + dispatcher, callable_fn = self._single_fixture() + dispatcher("sqlite://", foo='bar') + dispatcher("postgresql+psycopg2://u:p@h/t", 1, x='y') + + self.assertEqual( + [ + mock.call.sqlite('sqlite://', foo='bar'), + mock.call.postgresql( + "postgresql+psycopg2://u:p@h/t", + 1, x='y'), + ], + callable_fn.mock_calls) + + def test_dispatch_on_target(self): + callable_fn = mock.Mock() + + @utils.dispatch_for_dialect("*") + def default_fn(url, x, y): + callable_fn.default(url, x, y) + + @default_fn.dispatch_for("sqlite") + def sqlite_fn(url, x, y): + callable_fn.sqlite(url, x, y) + default_fn.dispatch_on_drivername("*")(url, x, y) + + default_fn("sqlite://", 4, 5) + self.assertEqual( + [ + mock.call.sqlite("sqlite://", 4, 5), + mock.call.default("sqlite://", 4, 5) + ], + callable_fn.mock_calls + ) + + def test_single_no_dispatcher(self): + callable_fn = mock.Mock() + + dispatcher = utils.dispatch_for_dialect("sqlite")(callable_fn.sqlite) + dispatcher = dispatcher.dispatch_for("mysql")(callable_fn.mysql) + exc = self.assertRaises( + ValueError, + dispatcher, "postgresql://s:t@localhost/test" + ) + self.assertEqual( + "No default function found for driver: 'postgresql+psycopg2'", + str(exc) + ) + + def test_multiple_no_dispatcher(self): + callable_fn = mock.Mock() + + dispatcher = utils.dispatch_for_dialect("sqlite", multiple=True)( + callable_fn.sqlite) + dispatcher = dispatcher.dispatch_for("mysql")(callable_fn.mysql) + dispatcher("postgresql://s:t@localhost/test") + self.assertEqual( + [], callable_fn.mock_calls + ) + + def test_multiple_no_driver(self): + callable_fn = mock.Mock( + default=mock.Mock(return_value=None), + sqlite=mock.Mock(return_value=None) + ) + + dispatcher = utils.dispatch_for_dialect("*", multiple=True)( + callable_fn.default) + dispatcher = dispatcher.dispatch_for("sqlite")( + callable_fn.sqlite) + + dispatcher.dispatch_on_drivername("sqlite")("foo") + self.assertEqual( + [mock.call.sqlite("foo"), mock.call.default("foo")], + callable_fn.mock_calls + ) + + def test_single_retval(self): + dispatcher, callable_fn = self._single_fixture() + callable_fn.mysql_mysqldb.return_value = 5 + + self.assertEqual( + dispatcher("mysql://u:p@h/t", 3), 5 + ) + + def test_engine(self): + eng = sqlalchemy.create_engine("sqlite:///path/to/my/db.db") + dispatcher, callable_fn = self._single_fixture() + + dispatcher(eng) + self.assertEqual( + [mock.call.sqlite(eng)], + callable_fn.mock_calls + ) + + def test_url(self): + url = sqlalchemy.engine.url.make_url( + "mysql+mysqldb://scott:tiger@localhost/test") + dispatcher, callable_fn = self._single_fixture() + + dispatcher(url, 15) + self.assertEqual( + [mock.call.mysql_mysqldb(url, 15)], + callable_fn.mock_calls + ) + + def test_invalid_target(self): + dispatcher, callable_fn = self._single_fixture() + + exc = self.assertRaises( + ValueError, + dispatcher, 20 + ) + self.assertEqual("Invalid target type: 20", str(exc)) + + def test_invalid_dispatch(self): + callable_fn = mock.Mock() + + dispatcher = utils.dispatch_for_dialect("*")(callable_fn.default) + + exc = self.assertRaises( + ValueError, + dispatcher.dispatch_for("+pyodbc"), callable_fn.pyodbc + ) + self.assertEqual( + "Couldn't parse database[+driver]: '+pyodbc'", + str(exc) + ) + + def test_single_only_one_target(self): + callable_fn = mock.Mock() + + dispatcher = utils.dispatch_for_dialect("*")(callable_fn.default) + dispatcher = dispatcher.dispatch_for("sqlite")(callable_fn.sqlite) + + exc = self.assertRaises( + TypeError, + dispatcher.dispatch_for("sqlite"), callable_fn.sqlite2 + ) + self.assertEqual( + "Multiple functions for expression 'sqlite'", str(exc) + ) + + def test_multiple(self): + dispatcher, callable_fn = self._multiple_fixture() + + dispatcher("postgresql+pyodbc://", 1) + dispatcher("mysql://", 2) + dispatcher("ibm_db_sa+db2://", 3) + dispatcher("postgresql+psycopg2://", 4) + + # TODO(zzzeek): there is a deterministic order here, but we might + # want to tweak it, or maybe provide options. default first? + # most specific first? is *+pyodbc or postgresql+* more specific? + self.assertEqual( + [ + mock.call.postgresql('postgresql+pyodbc://', 1), + mock.call.pyodbc('postgresql+pyodbc://', 1), + mock.call.default('postgresql+pyodbc://', 1), + mock.call.mysql_mysqldb('mysql://', 2), + mock.call.default('mysql://', 2), + mock.call.default('ibm_db_sa+db2://', 3), + mock.call.postgresql_psycopg2('postgresql+psycopg2://', 4), + mock.call.postgresql('postgresql+psycopg2://', 4), + mock.call.default('postgresql+psycopg2://', 4), + ], + callable_fn.mock_calls + ) + + def test_multiple_no_return_value(self): + dispatcher, callable_fn = self._multiple_fixture() + callable_fn.sqlite.return_value = 5 + + exc = self.assertRaises( + TypeError, + dispatcher, "sqlite://" + ) + self.assertEqual( + "Return value not allowed for multiple filtered function", + str(exc) + ) |