summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2014-07-24 18:09:29 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2014-08-12 17:07:32 -0400
commitee176a82a4ca21a8a02478d219914f95028b798d (patch)
treeeccc4f10dfc554407bcc7d919cd454115ad3f97a
parente2adc4384ab4a9ae172bdca9cfbaa5ab0ed08c68 (diff)
downloadoslo-db-ee176a82a4ca21a8a02478d219914f95028b798d.tar.gz
Implement a dialect-level function dispatch system
A new utility feature dispatch_for_dialect() is added, which is intended to provide a single solution for the issue of many areas within oslo.db where conditional per-database and/or per-driver behavior is needed, such as that of setting up special event listeners, connection options, or invoking special features. The dispatch_for_dialect() function is essentially a decorator, which when invoked provides subsequent "builder" decorators in the style of the Python @property object and others, so that any number of functions can be registered against database/driver matcher strings. The system supports "single" or "multiple" function call modes, wildcards for both database and driver independently of each other, support for URL strings, URL objects, Engine and Connection objects as lead targets, and additional arguments and return values for single mode. The agenda of this feature is to provide the basis for the removal of all remaining "if db.name == 'MYDB'" conditionals throughout oslo.db and elsewhere. Change-Id: I75201487c1481ffc83329a2eaaaa07acf5aac76e
-rw-r--r--oslo/db/sqlalchemy/utils.py206
-rw-r--r--tests/sqlalchemy/test_utils.py245
2 files changed, 451 insertions, 0 deletions
diff --git a/oslo/db/sqlalchemy/utils.py b/oslo/db/sqlalchemy/utils.py
index 179e416..5abd4ea 100644
--- a/oslo/db/sqlalchemy/utils.py
+++ b/oslo/db/sqlalchemy/utils.py
@@ -16,15 +16,19 @@
# License for the specific language governing permissions and limitations
# under the License.
+import collections
import logging
import re
from oslo.utils import timeutils
+import six
import sqlalchemy
from sqlalchemy import Boolean
from sqlalchemy import CheckConstraint
from sqlalchemy import Column
+from sqlalchemy.engine import Connectable
from sqlalchemy.engine import reflection
+from sqlalchemy.engine import url as sa_url
from sqlalchemy.ext.compiler import compiles
from sqlalchemy import func
from sqlalchemy import Index
@@ -777,3 +781,205 @@ def column_exists(engine, table_name, column):
"""
t = get_table(engine, table_name)
return column in t.c
+
+
+class DialectFunctionDispatcher(object):
+ @classmethod
+ def dispatch_for_dialect(cls, expr, multiple=False):
+ """Provide dialect-specific functionality within distinct functions.
+
+ e.g.::
+
+ @dispatch_for_dialect("*")
+ def set_special_option(engine):
+ pass
+
+ @set_special_option.dispatch_for("sqlite")
+ def set_sqlite_special_option(engine):
+ return engine.execute("sqlite thing")
+
+ @set_special_option.dispatch_for("mysql+mysqldb")
+ def set_mysqldb_special_option(engine):
+ return engine.execute("mysqldb thing")
+
+ After the above registration, the ``set_special_option()`` function
+ is now a dispatcher, given a SQLAlchemy ``Engine``, ``Connection``,
+ URL string, or ``sqlalchemy.engine.URL`` object::
+
+ eng = create_engine('...')
+ result = set_special_option(eng)
+
+ The filter system supports two modes, "multiple" and "single".
+ The default is "single", and requires that one and only one function
+ match for a given backend. In this mode, the function may also
+ have a return value, which will be returned by the top level
+ call.
+
+ "multiple" mode, on the other hand, does not support return
+ arguments, but allows for any number of matching functions, where
+ each function will be called::
+
+ # the initial call sets this up as a "multiple" dispatcher
+ @dispatch_for_dialect("*", multiple=True)
+ def set_options(engine):
+ # set options that apply to *all* engines
+
+ @set_options.dispatch_for("postgresql")
+ def set_postgresql_options(engine):
+ # set options that apply to all Postgresql engines
+
+ @set_options.dispatch_for("postgresql+psycopg2")
+ def set_postgresql_psycopg2_options(engine):
+ # set options that apply only to "postgresql+psycopg2"
+
+ @set_options.dispatch_for("*+pyodbc")
+ def set_pyodbc_options(engine):
+ # set options that apply to all pyodbc backends
+
+ Note that in both modes, any number of additional arguments can be
+ accepted by member functions. For example, to populate a dictionary of
+ options, it may be passed in::
+
+ @dispatch_for_dialect("*", multiple=True)
+ def set_engine_options(url, opts):
+ pass
+
+ @set_engine_options.dispatch_for("mysql+mysqldb")
+ def _mysql_set_default_charset_to_utf8(url, opts):
+ opts.setdefault('charset', 'utf-8')
+
+ @set_engine_options.dispatch_for("sqlite")
+ def _set_sqlite_in_memory_check_same_thread(url, opts):
+ if url.database in (None, 'memory'):
+ opts['check_same_thread'] = False
+
+ opts = {}
+ set_engine_options(url, opts)
+
+ The driver specifiers are of the form:
+ ``<database | *>[+<driver | *>]``. That is, database name or "*",
+ followed by an optional ``+`` sign with driver or "*". Omitting
+ the driver name implies all drivers for that database.
+
+ """
+ if multiple:
+ cls = DialectMultiFunctionDispatcher
+ else:
+ cls = DialectSingleFunctionDispatcher
+ return cls().dispatch_for(expr)
+
+ _db_plus_driver_reg = re.compile(r'([^+]+?)(?:\+(.+))?$')
+
+ def dispatch_for(self, expr):
+ def decorate(fn):
+ dbname, driver = self._parse_dispatch(expr)
+ self._register(expr, dbname, driver, fn)
+ return self
+ return decorate
+
+ def _parse_dispatch(self, text):
+ m = self._db_plus_driver_reg.match(text)
+ if not m:
+ raise ValueError("Couldn't parse database[+driver]: %r" % text)
+ return m.group(1) or '*', m.group(2) or '*'
+
+ def __call__(self, *arg, **kw):
+ target = arg[0]
+ return self._dispatch_on(
+ self._url_from_target(target), target, arg, kw)
+
+ def _url_from_target(self, target):
+ if isinstance(target, Connectable):
+ return target.engine.url
+ elif isinstance(target, six.string_types):
+ if "://" not in target:
+ target_url = sa_url.make_url("%s://" % target)
+ else:
+ target_url = sa_url.make_url(target)
+ return target_url
+ elif isinstance(target, sa_url.URL):
+ return target
+ else:
+ raise ValueError("Invalid target type: %r" % target)
+
+ def dispatch_on_drivername(self, drivername):
+ """Return a sub-dispatcher for the given drivername.
+
+ This provides a means of calling a different function, such as the
+ "*" function, for a given target object that normally refers
+ to a sub-function.
+
+ """
+ dbname, driver = self._db_plus_driver_reg.match(drivername).group(1, 2)
+
+ def go(*arg, **kw):
+ return self._dispatch_on_db_driver(dbname, "*", arg, kw)
+
+ return go
+
+ def _dispatch_on(self, url, target, arg, kw):
+ dbname, driver = self._db_plus_driver_reg.match(
+ url.drivername).group(1, 2)
+ if not driver:
+ driver = url.get_dialect().driver
+
+ return self._dispatch_on_db_driver(dbname, driver, arg, kw)
+
+ def _invoke_fn(self, fn, arg, kw):
+ return fn(*arg, **kw)
+
+
+class DialectSingleFunctionDispatcher(DialectFunctionDispatcher):
+ def __init__(self):
+ self.reg = collections.defaultdict(dict)
+
+ def _register(self, expr, dbname, driver, fn):
+ fn_dict = self.reg[dbname]
+ if driver in fn_dict:
+ raise TypeError("Multiple functions for expression %r" % expr)
+ fn_dict[driver] = fn
+
+ def _matches(self, dbname, driver):
+ for db in (dbname, '*'):
+ subdict = self.reg[db]
+ for drv in (driver, '*'):
+ if drv in subdict:
+ return subdict[drv]
+ else:
+ raise ValueError(
+ "No default function found for driver: %r" %
+ ("%s+%s" % (dbname, driver)))
+
+ def _dispatch_on_db_driver(self, dbname, driver, arg, kw):
+ fn = self._matches(dbname, driver)
+ return self._invoke_fn(fn, arg, kw)
+
+
+class DialectMultiFunctionDispatcher(DialectFunctionDispatcher):
+ def __init__(self):
+ self.reg = collections.defaultdict(
+ lambda: collections.defaultdict(list))
+
+ def _register(self, expr, dbname, driver, fn):
+ self.reg[dbname][driver].append(fn)
+
+ def _matches(self, dbname, driver):
+ if driver != '*':
+ drivers = (driver, '*')
+ else:
+ drivers = ('*', )
+
+ for db in (dbname, '*'):
+ subdict = self.reg[db]
+ for drv in drivers:
+ for fn in subdict[drv]:
+ yield fn
+
+ def _dispatch_on_db_driver(self, dbname, driver, arg, kw):
+ for fn in self._matches(dbname, driver):
+ if self._invoke_fn(fn, arg, kw) is not None:
+ raise TypeError(
+ "Return value not allowed for "
+ "multiple filtered function")
+
+dispatch_for_dialect = DialectFunctionDispatcher.dispatch_for_dialect
diff --git a/tests/sqlalchemy/test_utils.py b/tests/sqlalchemy/test_utils.py
index e1285f9..cdb456c 100644
--- a/tests/sqlalchemy/test_utils.py
+++ b/tests/sqlalchemy/test_utils.py
@@ -853,3 +853,248 @@ class TestUtilsMysqlOpportunistically(
class TestUtilsPostgresqlOpportunistically(
TestUtils, db_test_base.PostgreSQLOpportunisticTestCase):
pass
+
+
+class TestDialectFunctionDispatcher(test_base.BaseTestCase):
+ def _single_fixture(self):
+ callable_fn = mock.Mock()
+
+ dispatcher = orig = utils.dispatch_for_dialect("*")(
+ callable_fn.default)
+ dispatcher = dispatcher.dispatch_for("sqlite")(callable_fn.sqlite)
+ dispatcher = dispatcher.dispatch_for("mysql+mysqldb")(
+ callable_fn.mysql_mysqldb)
+ dispatcher = dispatcher.dispatch_for("postgresql")(
+ callable_fn.postgresql)
+
+ self.assertTrue(dispatcher is orig)
+
+ return dispatcher, callable_fn
+
+ def _multiple_fixture(self):
+ callable_fn = mock.Mock()
+
+ for targ in [
+ callable_fn.default,
+ callable_fn.sqlite,
+ callable_fn.mysql_mysqldb,
+ callable_fn.postgresql,
+ callable_fn.postgresql_psycopg2,
+ callable_fn.pyodbc
+ ]:
+ targ.return_value = None
+
+ dispatcher = orig = utils.dispatch_for_dialect("*", multiple=True)(
+ callable_fn.default)
+ dispatcher = dispatcher.dispatch_for("sqlite")(callable_fn.sqlite)
+ dispatcher = dispatcher.dispatch_for("mysql+mysqldb")(
+ callable_fn.mysql_mysqldb)
+ dispatcher = dispatcher.dispatch_for("postgresql+*")(
+ callable_fn.postgresql)
+ dispatcher = dispatcher.dispatch_for("postgresql+psycopg2")(
+ callable_fn.postgresql_psycopg2)
+ dispatcher = dispatcher.dispatch_for("*+pyodbc")(
+ callable_fn.pyodbc)
+
+ self.assertTrue(dispatcher is orig)
+
+ return dispatcher, callable_fn
+
+ def test_single(self):
+
+ dispatcher, callable_fn = self._single_fixture()
+ dispatcher("sqlite://", 1)
+ dispatcher("postgresql+psycopg2://u:p@h/t", 2)
+ dispatcher("mysql://u:p@h/t", 3)
+ dispatcher("mysql+mysqlconnector://u:p@h/t", 4)
+
+ self.assertEqual(
+ [
+ mock.call.sqlite('sqlite://', 1),
+ mock.call.postgresql("postgresql+psycopg2://u:p@h/t", 2),
+ mock.call.mysql_mysqldb("mysql://u:p@h/t", 3),
+ mock.call.default("mysql+mysqlconnector://u:p@h/t", 4)
+ ],
+ callable_fn.mock_calls)
+
+ def test_single_kwarg(self):
+ dispatcher, callable_fn = self._single_fixture()
+ dispatcher("sqlite://", foo='bar')
+ dispatcher("postgresql+psycopg2://u:p@h/t", 1, x='y')
+
+ self.assertEqual(
+ [
+ mock.call.sqlite('sqlite://', foo='bar'),
+ mock.call.postgresql(
+ "postgresql+psycopg2://u:p@h/t",
+ 1, x='y'),
+ ],
+ callable_fn.mock_calls)
+
+ def test_dispatch_on_target(self):
+ callable_fn = mock.Mock()
+
+ @utils.dispatch_for_dialect("*")
+ def default_fn(url, x, y):
+ callable_fn.default(url, x, y)
+
+ @default_fn.dispatch_for("sqlite")
+ def sqlite_fn(url, x, y):
+ callable_fn.sqlite(url, x, y)
+ default_fn.dispatch_on_drivername("*")(url, x, y)
+
+ default_fn("sqlite://", 4, 5)
+ self.assertEqual(
+ [
+ mock.call.sqlite("sqlite://", 4, 5),
+ mock.call.default("sqlite://", 4, 5)
+ ],
+ callable_fn.mock_calls
+ )
+
+ def test_single_no_dispatcher(self):
+ callable_fn = mock.Mock()
+
+ dispatcher = utils.dispatch_for_dialect("sqlite")(callable_fn.sqlite)
+ dispatcher = dispatcher.dispatch_for("mysql")(callable_fn.mysql)
+ exc = self.assertRaises(
+ ValueError,
+ dispatcher, "postgresql://s:t@localhost/test"
+ )
+ self.assertEqual(
+ "No default function found for driver: 'postgresql+psycopg2'",
+ str(exc)
+ )
+
+ def test_multiple_no_dispatcher(self):
+ callable_fn = mock.Mock()
+
+ dispatcher = utils.dispatch_for_dialect("sqlite", multiple=True)(
+ callable_fn.sqlite)
+ dispatcher = dispatcher.dispatch_for("mysql")(callable_fn.mysql)
+ dispatcher("postgresql://s:t@localhost/test")
+ self.assertEqual(
+ [], callable_fn.mock_calls
+ )
+
+ def test_multiple_no_driver(self):
+ callable_fn = mock.Mock(
+ default=mock.Mock(return_value=None),
+ sqlite=mock.Mock(return_value=None)
+ )
+
+ dispatcher = utils.dispatch_for_dialect("*", multiple=True)(
+ callable_fn.default)
+ dispatcher = dispatcher.dispatch_for("sqlite")(
+ callable_fn.sqlite)
+
+ dispatcher.dispatch_on_drivername("sqlite")("foo")
+ self.assertEqual(
+ [mock.call.sqlite("foo"), mock.call.default("foo")],
+ callable_fn.mock_calls
+ )
+
+ def test_single_retval(self):
+ dispatcher, callable_fn = self._single_fixture()
+ callable_fn.mysql_mysqldb.return_value = 5
+
+ self.assertEqual(
+ dispatcher("mysql://u:p@h/t", 3), 5
+ )
+
+ def test_engine(self):
+ eng = sqlalchemy.create_engine("sqlite:///path/to/my/db.db")
+ dispatcher, callable_fn = self._single_fixture()
+
+ dispatcher(eng)
+ self.assertEqual(
+ [mock.call.sqlite(eng)],
+ callable_fn.mock_calls
+ )
+
+ def test_url(self):
+ url = sqlalchemy.engine.url.make_url(
+ "mysql+mysqldb://scott:tiger@localhost/test")
+ dispatcher, callable_fn = self._single_fixture()
+
+ dispatcher(url, 15)
+ self.assertEqual(
+ [mock.call.mysql_mysqldb(url, 15)],
+ callable_fn.mock_calls
+ )
+
+ def test_invalid_target(self):
+ dispatcher, callable_fn = self._single_fixture()
+
+ exc = self.assertRaises(
+ ValueError,
+ dispatcher, 20
+ )
+ self.assertEqual("Invalid target type: 20", str(exc))
+
+ def test_invalid_dispatch(self):
+ callable_fn = mock.Mock()
+
+ dispatcher = utils.dispatch_for_dialect("*")(callable_fn.default)
+
+ exc = self.assertRaises(
+ ValueError,
+ dispatcher.dispatch_for("+pyodbc"), callable_fn.pyodbc
+ )
+ self.assertEqual(
+ "Couldn't parse database[+driver]: '+pyodbc'",
+ str(exc)
+ )
+
+ def test_single_only_one_target(self):
+ callable_fn = mock.Mock()
+
+ dispatcher = utils.dispatch_for_dialect("*")(callable_fn.default)
+ dispatcher = dispatcher.dispatch_for("sqlite")(callable_fn.sqlite)
+
+ exc = self.assertRaises(
+ TypeError,
+ dispatcher.dispatch_for("sqlite"), callable_fn.sqlite2
+ )
+ self.assertEqual(
+ "Multiple functions for expression 'sqlite'", str(exc)
+ )
+
+ def test_multiple(self):
+ dispatcher, callable_fn = self._multiple_fixture()
+
+ dispatcher("postgresql+pyodbc://", 1)
+ dispatcher("mysql://", 2)
+ dispatcher("ibm_db_sa+db2://", 3)
+ dispatcher("postgresql+psycopg2://", 4)
+
+ # TODO(zzzeek): there is a deterministic order here, but we might
+ # want to tweak it, or maybe provide options. default first?
+ # most specific first? is *+pyodbc or postgresql+* more specific?
+ self.assertEqual(
+ [
+ mock.call.postgresql('postgresql+pyodbc://', 1),
+ mock.call.pyodbc('postgresql+pyodbc://', 1),
+ mock.call.default('postgresql+pyodbc://', 1),
+ mock.call.mysql_mysqldb('mysql://', 2),
+ mock.call.default('mysql://', 2),
+ mock.call.default('ibm_db_sa+db2://', 3),
+ mock.call.postgresql_psycopg2('postgresql+psycopg2://', 4),
+ mock.call.postgresql('postgresql+psycopg2://', 4),
+ mock.call.default('postgresql+psycopg2://', 4),
+ ],
+ callable_fn.mock_calls
+ )
+
+ def test_multiple_no_return_value(self):
+ dispatcher, callable_fn = self._multiple_fixture()
+ callable_fn.sqlite.return_value = 5
+
+ exc = self.assertRaises(
+ TypeError,
+ dispatcher, "sqlite://"
+ )
+ self.assertEqual(
+ "Return value not allowed for multiple filtered function",
+ str(exc)
+ )