summaryrefslogtreecommitdiff
path: root/oslo/db/sqlalchemy/utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'oslo/db/sqlalchemy/utils.py')
-rw-r--r--oslo/db/sqlalchemy/utils.py206
1 files changed, 206 insertions, 0 deletions
diff --git a/oslo/db/sqlalchemy/utils.py b/oslo/db/sqlalchemy/utils.py
index 179e416..5abd4ea 100644
--- a/oslo/db/sqlalchemy/utils.py
+++ b/oslo/db/sqlalchemy/utils.py
@@ -16,15 +16,19 @@
# License for the specific language governing permissions and limitations
# under the License.
+import collections
import logging
import re
from oslo.utils import timeutils
+import six
import sqlalchemy
from sqlalchemy import Boolean
from sqlalchemy import CheckConstraint
from sqlalchemy import Column
+from sqlalchemy.engine import Connectable
from sqlalchemy.engine import reflection
+from sqlalchemy.engine import url as sa_url
from sqlalchemy.ext.compiler import compiles
from sqlalchemy import func
from sqlalchemy import Index
@@ -777,3 +781,205 @@ def column_exists(engine, table_name, column):
"""
t = get_table(engine, table_name)
return column in t.c
+
+
+class DialectFunctionDispatcher(object):
+ @classmethod
+ def dispatch_for_dialect(cls, expr, multiple=False):
+ """Provide dialect-specific functionality within distinct functions.
+
+ e.g.::
+
+ @dispatch_for_dialect("*")
+ def set_special_option(engine):
+ pass
+
+ @set_special_option.dispatch_for("sqlite")
+ def set_sqlite_special_option(engine):
+ return engine.execute("sqlite thing")
+
+ @set_special_option.dispatch_for("mysql+mysqldb")
+ def set_mysqldb_special_option(engine):
+ return engine.execute("mysqldb thing")
+
+ After the above registration, the ``set_special_option()`` function
+ is now a dispatcher, given a SQLAlchemy ``Engine``, ``Connection``,
+ URL string, or ``sqlalchemy.engine.URL`` object::
+
+ eng = create_engine('...')
+ result = set_special_option(eng)
+
+ The filter system supports two modes, "multiple" and "single".
+ The default is "single", and requires that one and only one function
+ match for a given backend. In this mode, the function may also
+ have a return value, which will be returned by the top level
+ call.
+
+ "multiple" mode, on the other hand, does not support return
+ arguments, but allows for any number of matching functions, where
+ each function will be called::
+
+ # the initial call sets this up as a "multiple" dispatcher
+ @dispatch_for_dialect("*", multiple=True)
+ def set_options(engine):
+ # set options that apply to *all* engines
+
+ @set_options.dispatch_for("postgresql")
+ def set_postgresql_options(engine):
+ # set options that apply to all Postgresql engines
+
+ @set_options.dispatch_for("postgresql+psycopg2")
+ def set_postgresql_psycopg2_options(engine):
+ # set options that apply only to "postgresql+psycopg2"
+
+ @set_options.dispatch_for("*+pyodbc")
+ def set_pyodbc_options(engine):
+ # set options that apply to all pyodbc backends
+
+ Note that in both modes, any number of additional arguments can be
+ accepted by member functions. For example, to populate a dictionary of
+ options, it may be passed in::
+
+ @dispatch_for_dialect("*", multiple=True)
+ def set_engine_options(url, opts):
+ pass
+
+ @set_engine_options.dispatch_for("mysql+mysqldb")
+ def _mysql_set_default_charset_to_utf8(url, opts):
+ opts.setdefault('charset', 'utf-8')
+
+ @set_engine_options.dispatch_for("sqlite")
+ def _set_sqlite_in_memory_check_same_thread(url, opts):
+ if url.database in (None, 'memory'):
+ opts['check_same_thread'] = False
+
+ opts = {}
+ set_engine_options(url, opts)
+
+ The driver specifiers are of the form:
+ ``<database | *>[+<driver | *>]``. That is, database name or "*",
+ followed by an optional ``+`` sign with driver or "*". Omitting
+ the driver name implies all drivers for that database.
+
+ """
+ if multiple:
+ cls = DialectMultiFunctionDispatcher
+ else:
+ cls = DialectSingleFunctionDispatcher
+ return cls().dispatch_for(expr)
+
+ _db_plus_driver_reg = re.compile(r'([^+]+?)(?:\+(.+))?$')
+
+ def dispatch_for(self, expr):
+ def decorate(fn):
+ dbname, driver = self._parse_dispatch(expr)
+ self._register(expr, dbname, driver, fn)
+ return self
+ return decorate
+
+ def _parse_dispatch(self, text):
+ m = self._db_plus_driver_reg.match(text)
+ if not m:
+ raise ValueError("Couldn't parse database[+driver]: %r" % text)
+ return m.group(1) or '*', m.group(2) or '*'
+
+ def __call__(self, *arg, **kw):
+ target = arg[0]
+ return self._dispatch_on(
+ self._url_from_target(target), target, arg, kw)
+
+ def _url_from_target(self, target):
+ if isinstance(target, Connectable):
+ return target.engine.url
+ elif isinstance(target, six.string_types):
+ if "://" not in target:
+ target_url = sa_url.make_url("%s://" % target)
+ else:
+ target_url = sa_url.make_url(target)
+ return target_url
+ elif isinstance(target, sa_url.URL):
+ return target
+ else:
+ raise ValueError("Invalid target type: %r" % target)
+
+ def dispatch_on_drivername(self, drivername):
+ """Return a sub-dispatcher for the given drivername.
+
+ This provides a means of calling a different function, such as the
+ "*" function, for a given target object that normally refers
+ to a sub-function.
+
+ """
+ dbname, driver = self._db_plus_driver_reg.match(drivername).group(1, 2)
+
+ def go(*arg, **kw):
+ return self._dispatch_on_db_driver(dbname, "*", arg, kw)
+
+ return go
+
+ def _dispatch_on(self, url, target, arg, kw):
+ dbname, driver = self._db_plus_driver_reg.match(
+ url.drivername).group(1, 2)
+ if not driver:
+ driver = url.get_dialect().driver
+
+ return self._dispatch_on_db_driver(dbname, driver, arg, kw)
+
+ def _invoke_fn(self, fn, arg, kw):
+ return fn(*arg, **kw)
+
+
+class DialectSingleFunctionDispatcher(DialectFunctionDispatcher):
+ def __init__(self):
+ self.reg = collections.defaultdict(dict)
+
+ def _register(self, expr, dbname, driver, fn):
+ fn_dict = self.reg[dbname]
+ if driver in fn_dict:
+ raise TypeError("Multiple functions for expression %r" % expr)
+ fn_dict[driver] = fn
+
+ def _matches(self, dbname, driver):
+ for db in (dbname, '*'):
+ subdict = self.reg[db]
+ for drv in (driver, '*'):
+ if drv in subdict:
+ return subdict[drv]
+ else:
+ raise ValueError(
+ "No default function found for driver: %r" %
+ ("%s+%s" % (dbname, driver)))
+
+ def _dispatch_on_db_driver(self, dbname, driver, arg, kw):
+ fn = self._matches(dbname, driver)
+ return self._invoke_fn(fn, arg, kw)
+
+
+class DialectMultiFunctionDispatcher(DialectFunctionDispatcher):
+ def __init__(self):
+ self.reg = collections.defaultdict(
+ lambda: collections.defaultdict(list))
+
+ def _register(self, expr, dbname, driver, fn):
+ self.reg[dbname][driver].append(fn)
+
+ def _matches(self, dbname, driver):
+ if driver != '*':
+ drivers = (driver, '*')
+ else:
+ drivers = ('*', )
+
+ for db in (dbname, '*'):
+ subdict = self.reg[db]
+ for drv in drivers:
+ for fn in subdict[drv]:
+ yield fn
+
+ def _dispatch_on_db_driver(self, dbname, driver, arg, kw):
+ for fn in self._matches(dbname, driver):
+ if self._invoke_fn(fn, arg, kw) is not None:
+ raise TypeError(
+ "Return value not allowed for "
+ "multiple filtered function")
+
+dispatch_for_dialect = DialectFunctionDispatcher.dispatch_for_dialect