summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJenkins <jenkins@review.openstack.org>2014-08-20 00:47:37 +0000
committerGerrit Code Review <review@openstack.org>2014-08-20 00:47:37 +0000
commit68c0b62972ef6a8c5d8fe92a75894e34d3947932 (patch)
tree428e947c70aacfee7da0d2b6b8bd623026dcc8f7
parentec05831e3d7566d1ee3795b047bc790a61a1059e (diff)
parentee176a82a4ca21a8a02478d219914f95028b798d (diff)
downloadoslo-db-68c0b62972ef6a8c5d8fe92a75894e34d3947932.tar.gz
Merge "Implement a dialect-level function dispatch system"0.4.0
-rw-r--r--oslo/db/sqlalchemy/utils.py206
-rw-r--r--tests/sqlalchemy/test_utils.py245
2 files changed, 451 insertions, 0 deletions
diff --git a/oslo/db/sqlalchemy/utils.py b/oslo/db/sqlalchemy/utils.py
index 179e416..5abd4ea 100644
--- a/oslo/db/sqlalchemy/utils.py
+++ b/oslo/db/sqlalchemy/utils.py
@@ -16,15 +16,19 @@
# License for the specific language governing permissions and limitations
# under the License.
+import collections
import logging
import re
from oslo.utils import timeutils
+import six
import sqlalchemy
from sqlalchemy import Boolean
from sqlalchemy import CheckConstraint
from sqlalchemy import Column
+from sqlalchemy.engine import Connectable
from sqlalchemy.engine import reflection
+from sqlalchemy.engine import url as sa_url
from sqlalchemy.ext.compiler import compiles
from sqlalchemy import func
from sqlalchemy import Index
@@ -777,3 +781,205 @@ def column_exists(engine, table_name, column):
"""
t = get_table(engine, table_name)
return column in t.c
+
+
+class DialectFunctionDispatcher(object):
+ @classmethod
+ def dispatch_for_dialect(cls, expr, multiple=False):
+ """Provide dialect-specific functionality within distinct functions.
+
+ e.g.::
+
+ @dispatch_for_dialect("*")
+ def set_special_option(engine):
+ pass
+
+ @set_special_option.dispatch_for("sqlite")
+ def set_sqlite_special_option(engine):
+ return engine.execute("sqlite thing")
+
+ @set_special_option.dispatch_for("mysql+mysqldb")
+ def set_mysqldb_special_option(engine):
+ return engine.execute("mysqldb thing")
+
+ After the above registration, the ``set_special_option()`` function
+ is now a dispatcher, given a SQLAlchemy ``Engine``, ``Connection``,
+ URL string, or ``sqlalchemy.engine.URL`` object::
+
+ eng = create_engine('...')
+ result = set_special_option(eng)
+
+ The filter system supports two modes, "multiple" and "single".
+ The default is "single", and requires that one and only one function
+ match for a given backend. In this mode, the function may also
+ have a return value, which will be returned by the top level
+ call.
+
+ "multiple" mode, on the other hand, does not support return
+ arguments, but allows for any number of matching functions, where
+ each function will be called::
+
+ # the initial call sets this up as a "multiple" dispatcher
+ @dispatch_for_dialect("*", multiple=True)
+ def set_options(engine):
+ # set options that apply to *all* engines
+
+ @set_options.dispatch_for("postgresql")
+ def set_postgresql_options(engine):
+ # set options that apply to all Postgresql engines
+
+ @set_options.dispatch_for("postgresql+psycopg2")
+ def set_postgresql_psycopg2_options(engine):
+ # set options that apply only to "postgresql+psycopg2"
+
+ @set_options.dispatch_for("*+pyodbc")
+ def set_pyodbc_options(engine):
+ # set options that apply to all pyodbc backends
+
+ Note that in both modes, any number of additional arguments can be
+ accepted by member functions. For example, to populate a dictionary of
+ options, it may be passed in::
+
+ @dispatch_for_dialect("*", multiple=True)
+ def set_engine_options(url, opts):
+ pass
+
+ @set_engine_options.dispatch_for("mysql+mysqldb")
+ def _mysql_set_default_charset_to_utf8(url, opts):
+ opts.setdefault('charset', 'utf-8')
+
+ @set_engine_options.dispatch_for("sqlite")
+ def _set_sqlite_in_memory_check_same_thread(url, opts):
+ if url.database in (None, 'memory'):
+ opts['check_same_thread'] = False
+
+ opts = {}
+ set_engine_options(url, opts)
+
+ The driver specifiers are of the form:
+ ``<database | *>[+<driver | *>]``. That is, database name or "*",
+ followed by an optional ``+`` sign with driver or "*". Omitting
+ the driver name implies all drivers for that database.
+
+ """
+ if multiple:
+ cls = DialectMultiFunctionDispatcher
+ else:
+ cls = DialectSingleFunctionDispatcher
+ return cls().dispatch_for(expr)
+
+ _db_plus_driver_reg = re.compile(r'([^+]+?)(?:\+(.+))?$')
+
+ def dispatch_for(self, expr):
+ def decorate(fn):
+ dbname, driver = self._parse_dispatch(expr)
+ self._register(expr, dbname, driver, fn)
+ return self
+ return decorate
+
+ def _parse_dispatch(self, text):
+ m = self._db_plus_driver_reg.match(text)
+ if not m:
+ raise ValueError("Couldn't parse database[+driver]: %r" % text)
+ return m.group(1) or '*', m.group(2) or '*'
+
+ def __call__(self, *arg, **kw):
+ target = arg[0]
+ return self._dispatch_on(
+ self._url_from_target(target), target, arg, kw)
+
+ def _url_from_target(self, target):
+ if isinstance(target, Connectable):
+ return target.engine.url
+ elif isinstance(target, six.string_types):
+ if "://" not in target:
+ target_url = sa_url.make_url("%s://" % target)
+ else:
+ target_url = sa_url.make_url(target)
+ return target_url
+ elif isinstance(target, sa_url.URL):
+ return target
+ else:
+ raise ValueError("Invalid target type: %r" % target)
+
+ def dispatch_on_drivername(self, drivername):
+ """Return a sub-dispatcher for the given drivername.
+
+ This provides a means of calling a different function, such as the
+ "*" function, for a given target object that normally refers
+ to a sub-function.
+
+ """
+ dbname, driver = self._db_plus_driver_reg.match(drivername).group(1, 2)
+
+ def go(*arg, **kw):
+ return self._dispatch_on_db_driver(dbname, "*", arg, kw)
+
+ return go
+
+ def _dispatch_on(self, url, target, arg, kw):
+ dbname, driver = self._db_plus_driver_reg.match(
+ url.drivername).group(1, 2)
+ if not driver:
+ driver = url.get_dialect().driver
+
+ return self._dispatch_on_db_driver(dbname, driver, arg, kw)
+
+ def _invoke_fn(self, fn, arg, kw):
+ return fn(*arg, **kw)
+
+
+class DialectSingleFunctionDispatcher(DialectFunctionDispatcher):
+ def __init__(self):
+ self.reg = collections.defaultdict(dict)
+
+ def _register(self, expr, dbname, driver, fn):
+ fn_dict = self.reg[dbname]
+ if driver in fn_dict:
+ raise TypeError("Multiple functions for expression %r" % expr)
+ fn_dict[driver] = fn
+
+ def _matches(self, dbname, driver):
+ for db in (dbname, '*'):
+ subdict = self.reg[db]
+ for drv in (driver, '*'):
+ if drv in subdict:
+ return subdict[drv]
+ else:
+ raise ValueError(
+ "No default function found for driver: %r" %
+ ("%s+%s" % (dbname, driver)))
+
+ def _dispatch_on_db_driver(self, dbname, driver, arg, kw):
+ fn = self._matches(dbname, driver)
+ return self._invoke_fn(fn, arg, kw)
+
+
+class DialectMultiFunctionDispatcher(DialectFunctionDispatcher):
+ def __init__(self):
+ self.reg = collections.defaultdict(
+ lambda: collections.defaultdict(list))
+
+ def _register(self, expr, dbname, driver, fn):
+ self.reg[dbname][driver].append(fn)
+
+ def _matches(self, dbname, driver):
+ if driver != '*':
+ drivers = (driver, '*')
+ else:
+ drivers = ('*', )
+
+ for db in (dbname, '*'):
+ subdict = self.reg[db]
+ for drv in drivers:
+ for fn in subdict[drv]:
+ yield fn
+
+ def _dispatch_on_db_driver(self, dbname, driver, arg, kw):
+ for fn in self._matches(dbname, driver):
+ if self._invoke_fn(fn, arg, kw) is not None:
+ raise TypeError(
+ "Return value not allowed for "
+ "multiple filtered function")
+
+dispatch_for_dialect = DialectFunctionDispatcher.dispatch_for_dialect
diff --git a/tests/sqlalchemy/test_utils.py b/tests/sqlalchemy/test_utils.py
index cdb132b..88c3817 100644
--- a/tests/sqlalchemy/test_utils.py
+++ b/tests/sqlalchemy/test_utils.py
@@ -849,3 +849,248 @@ class TestUtilsMysqlOpportunistically(
class TestUtilsPostgresqlOpportunistically(
TestUtils, db_test_base.PostgreSQLOpportunisticTestCase):
pass
+
+
+class TestDialectFunctionDispatcher(test_base.BaseTestCase):
+ def _single_fixture(self):
+ callable_fn = mock.Mock()
+
+ dispatcher = orig = utils.dispatch_for_dialect("*")(
+ callable_fn.default)
+ dispatcher = dispatcher.dispatch_for("sqlite")(callable_fn.sqlite)
+ dispatcher = dispatcher.dispatch_for("mysql+mysqldb")(
+ callable_fn.mysql_mysqldb)
+ dispatcher = dispatcher.dispatch_for("postgresql")(
+ callable_fn.postgresql)
+
+ self.assertTrue(dispatcher is orig)
+
+ return dispatcher, callable_fn
+
+ def _multiple_fixture(self):
+ callable_fn = mock.Mock()
+
+ for targ in [
+ callable_fn.default,
+ callable_fn.sqlite,
+ callable_fn.mysql_mysqldb,
+ callable_fn.postgresql,
+ callable_fn.postgresql_psycopg2,
+ callable_fn.pyodbc
+ ]:
+ targ.return_value = None
+
+ dispatcher = orig = utils.dispatch_for_dialect("*", multiple=True)(
+ callable_fn.default)
+ dispatcher = dispatcher.dispatch_for("sqlite")(callable_fn.sqlite)
+ dispatcher = dispatcher.dispatch_for("mysql+mysqldb")(
+ callable_fn.mysql_mysqldb)
+ dispatcher = dispatcher.dispatch_for("postgresql+*")(
+ callable_fn.postgresql)
+ dispatcher = dispatcher.dispatch_for("postgresql+psycopg2")(
+ callable_fn.postgresql_psycopg2)
+ dispatcher = dispatcher.dispatch_for("*+pyodbc")(
+ callable_fn.pyodbc)
+
+ self.assertTrue(dispatcher is orig)
+
+ return dispatcher, callable_fn
+
+ def test_single(self):
+
+ dispatcher, callable_fn = self._single_fixture()
+ dispatcher("sqlite://", 1)
+ dispatcher("postgresql+psycopg2://u:p@h/t", 2)
+ dispatcher("mysql://u:p@h/t", 3)
+ dispatcher("mysql+mysqlconnector://u:p@h/t", 4)
+
+ self.assertEqual(
+ [
+ mock.call.sqlite('sqlite://', 1),
+ mock.call.postgresql("postgresql+psycopg2://u:p@h/t", 2),
+ mock.call.mysql_mysqldb("mysql://u:p@h/t", 3),
+ mock.call.default("mysql+mysqlconnector://u:p@h/t", 4)
+ ],
+ callable_fn.mock_calls)
+
+ def test_single_kwarg(self):
+ dispatcher, callable_fn = self._single_fixture()
+ dispatcher("sqlite://", foo='bar')
+ dispatcher("postgresql+psycopg2://u:p@h/t", 1, x='y')
+
+ self.assertEqual(
+ [
+ mock.call.sqlite('sqlite://', foo='bar'),
+ mock.call.postgresql(
+ "postgresql+psycopg2://u:p@h/t",
+ 1, x='y'),
+ ],
+ callable_fn.mock_calls)
+
+ def test_dispatch_on_target(self):
+ callable_fn = mock.Mock()
+
+ @utils.dispatch_for_dialect("*")
+ def default_fn(url, x, y):
+ callable_fn.default(url, x, y)
+
+ @default_fn.dispatch_for("sqlite")
+ def sqlite_fn(url, x, y):
+ callable_fn.sqlite(url, x, y)
+ default_fn.dispatch_on_drivername("*")(url, x, y)
+
+ default_fn("sqlite://", 4, 5)
+ self.assertEqual(
+ [
+ mock.call.sqlite("sqlite://", 4, 5),
+ mock.call.default("sqlite://", 4, 5)
+ ],
+ callable_fn.mock_calls
+ )
+
+ def test_single_no_dispatcher(self):
+ callable_fn = mock.Mock()
+
+ dispatcher = utils.dispatch_for_dialect("sqlite")(callable_fn.sqlite)
+ dispatcher = dispatcher.dispatch_for("mysql")(callable_fn.mysql)
+ exc = self.assertRaises(
+ ValueError,
+ dispatcher, "postgresql://s:t@localhost/test"
+ )
+ self.assertEqual(
+ "No default function found for driver: 'postgresql+psycopg2'",
+ str(exc)
+ )
+
+ def test_multiple_no_dispatcher(self):
+ callable_fn = mock.Mock()
+
+ dispatcher = utils.dispatch_for_dialect("sqlite", multiple=True)(
+ callable_fn.sqlite)
+ dispatcher = dispatcher.dispatch_for("mysql")(callable_fn.mysql)
+ dispatcher("postgresql://s:t@localhost/test")
+ self.assertEqual(
+ [], callable_fn.mock_calls
+ )
+
+ def test_multiple_no_driver(self):
+ callable_fn = mock.Mock(
+ default=mock.Mock(return_value=None),
+ sqlite=mock.Mock(return_value=None)
+ )
+
+ dispatcher = utils.dispatch_for_dialect("*", multiple=True)(
+ callable_fn.default)
+ dispatcher = dispatcher.dispatch_for("sqlite")(
+ callable_fn.sqlite)
+
+ dispatcher.dispatch_on_drivername("sqlite")("foo")
+ self.assertEqual(
+ [mock.call.sqlite("foo"), mock.call.default("foo")],
+ callable_fn.mock_calls
+ )
+
+ def test_single_retval(self):
+ dispatcher, callable_fn = self._single_fixture()
+ callable_fn.mysql_mysqldb.return_value = 5
+
+ self.assertEqual(
+ dispatcher("mysql://u:p@h/t", 3), 5
+ )
+
+ def test_engine(self):
+ eng = sqlalchemy.create_engine("sqlite:///path/to/my/db.db")
+ dispatcher, callable_fn = self._single_fixture()
+
+ dispatcher(eng)
+ self.assertEqual(
+ [mock.call.sqlite(eng)],
+ callable_fn.mock_calls
+ )
+
+ def test_url(self):
+ url = sqlalchemy.engine.url.make_url(
+ "mysql+mysqldb://scott:tiger@localhost/test")
+ dispatcher, callable_fn = self._single_fixture()
+
+ dispatcher(url, 15)
+ self.assertEqual(
+ [mock.call.mysql_mysqldb(url, 15)],
+ callable_fn.mock_calls
+ )
+
+ def test_invalid_target(self):
+ dispatcher, callable_fn = self._single_fixture()
+
+ exc = self.assertRaises(
+ ValueError,
+ dispatcher, 20
+ )
+ self.assertEqual("Invalid target type: 20", str(exc))
+
+ def test_invalid_dispatch(self):
+ callable_fn = mock.Mock()
+
+ dispatcher = utils.dispatch_for_dialect("*")(callable_fn.default)
+
+ exc = self.assertRaises(
+ ValueError,
+ dispatcher.dispatch_for("+pyodbc"), callable_fn.pyodbc
+ )
+ self.assertEqual(
+ "Couldn't parse database[+driver]: '+pyodbc'",
+ str(exc)
+ )
+
+ def test_single_only_one_target(self):
+ callable_fn = mock.Mock()
+
+ dispatcher = utils.dispatch_for_dialect("*")(callable_fn.default)
+ dispatcher = dispatcher.dispatch_for("sqlite")(callable_fn.sqlite)
+
+ exc = self.assertRaises(
+ TypeError,
+ dispatcher.dispatch_for("sqlite"), callable_fn.sqlite2
+ )
+ self.assertEqual(
+ "Multiple functions for expression 'sqlite'", str(exc)
+ )
+
+ def test_multiple(self):
+ dispatcher, callable_fn = self._multiple_fixture()
+
+ dispatcher("postgresql+pyodbc://", 1)
+ dispatcher("mysql://", 2)
+ dispatcher("ibm_db_sa+db2://", 3)
+ dispatcher("postgresql+psycopg2://", 4)
+
+ # TODO(zzzeek): there is a deterministic order here, but we might
+ # want to tweak it, or maybe provide options. default first?
+ # most specific first? is *+pyodbc or postgresql+* more specific?
+ self.assertEqual(
+ [
+ mock.call.postgresql('postgresql+pyodbc://', 1),
+ mock.call.pyodbc('postgresql+pyodbc://', 1),
+ mock.call.default('postgresql+pyodbc://', 1),
+ mock.call.mysql_mysqldb('mysql://', 2),
+ mock.call.default('mysql://', 2),
+ mock.call.default('ibm_db_sa+db2://', 3),
+ mock.call.postgresql_psycopg2('postgresql+psycopg2://', 4),
+ mock.call.postgresql('postgresql+psycopg2://', 4),
+ mock.call.default('postgresql+psycopg2://', 4),
+ ],
+ callable_fn.mock_calls
+ )
+
+ def test_multiple_no_return_value(self):
+ dispatcher, callable_fn = self._multiple_fixture()
+ callable_fn.sqlite.return_value = 5
+
+ exc = self.assertRaises(
+ TypeError,
+ dispatcher, "sqlite://"
+ )
+ self.assertEqual(
+ "Return value not allowed for multiple filtered function",
+ str(exc)
+ )