summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2014-07-29 15:09:16 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2014-08-21 17:51:18 -0400
commit5ab43eefa50186bd5c7520b341f7499e3fc0601a (patch)
tree9615d3dc4ce03ee9942565126a6bf8192f188365
parent468347515dbb7c0a693c9da61017555748ffb5cb (diff)
downloadoslo-db-5ab43eefa50186bd5c7520b341f7499e3fc0601a.tar.gz
Use dialect dispatch for engine initiailization.
This patch makes use of the new DialectFunctionDispatcher to remove all conditional logic from within session.create_engine(). The flags that result in various engine connect arguments as well as subsequent event listeners now fall under the realm of dispatched functions keyed to their target dialects. Additionally, this modernizes the SQLite foreign keys listener to use events. A test suite is added which via mocks asserts that arguments expected to be present as well as non-present for a variety of url situations is added. Change-Id: Iba5e77e1396eda46ac2c6d6cbca159cbfadf4d04
-rw-r--r--oslo/db/sqlalchemy/session.py116
-rw-r--r--tests/sqlalchemy/test_sqlalchemy.py160
2 files changed, 239 insertions, 37 deletions
diff --git a/oslo/db/sqlalchemy/session.py b/oslo/db/sqlalchemy/session.py
index f47c462..6478384 100644
--- a/oslo/db/sqlalchemy/session.py
+++ b/oslo/db/sqlalchemy/session.py
@@ -286,9 +286,8 @@ import time
from oslo.utils import timeutils
import six
-from sqlalchemy.interfaces import PoolListener
import sqlalchemy.orm
-from sqlalchemy.pool import NullPool, StaticPool
+from sqlalchemy import pool
from sqlalchemy.sql.expression import literal_column
from sqlalchemy.sql.expression import select
@@ -296,19 +295,19 @@ from oslo.db import exception
from oslo.db.openstack.common.gettextutils import _LW
from oslo.db import options
from oslo.db.sqlalchemy import exc_filters
+from oslo.db.sqlalchemy import utils
LOG = logging.getLogger(__name__)
-class SqliteForeignKeysListener(PoolListener):
+def _sqlite_foreign_keys_listener(dbapi_con, con_record):
"""Ensures that the foreign key constraints are enforced in SQLite.
The foreign key constraints are disabled by default in SQLite,
so the foreign key constraints will be enabled here for every
database connection
"""
- def connect(self, dbapi_con, con_record):
- dbapi_con.execute('pragma foreign_keys=ON')
+ dbapi_con.execute('pragma foreign_keys=ON')
def _synchronous_switch_listener(dbapi_conn, connection_rec):
@@ -424,7 +423,7 @@ def create_engine(sql_connection, sqlite_fk=False, mysql_sql_mode=None,
thread_checkin=True):
"""Return a new SQLAlchemy engine."""
- connection_dict = sqlalchemy.engine.url.make_url(sql_connection)
+ url = sqlalchemy.engine.url.make_url(sql_connection)
engine_args = {
"pool_recycle": idle_timeout,
@@ -441,54 +440,97 @@ def create_engine(sql_connection, sqlite_fk=False, mysql_sql_mode=None,
else:
logger.setLevel(logging.WARNING)
- if "sqlite" in connection_dict.drivername:
- if sqlite_fk:
- engine_args["listeners"] = [SqliteForeignKeysListener()]
- engine_args["poolclass"] = NullPool
+ _init_connection_args(
+ url, engine_args,
+ sqlite_fk=sqlite_fk,
+ max_pool_size=max_pool_size,
+ max_overflow=max_overflow,
+ pool_timeout=pool_timeout
+ )
+
+ engine = sqlalchemy.create_engine(url, **engine_args)
+
+ _init_events(
+ engine,
+ mysql_sql_mode=mysql_sql_mode,
+ sqlite_synchronous=sqlite_synchronous,
+ sqlite_fk=sqlite_fk,
+ thread_checkin=thread_checkin,
+ connection_trace=connection_trace
+ )
- if sql_connection == "sqlite://":
- engine_args["poolclass"] = StaticPool
- engine_args["connect_args"] = {'check_same_thread': False}
- else:
+ # register alternate exception handler
+ exc_filters.register_engine(engine)
+
+ # register on begin handler
+ sqlalchemy.event.listen(engine, "begin", _begin_ping_listener)
+
+ # initial connect + test
+ _test_connection(engine, max_retries, retry_interval)
+
+ return engine
+
+
+@utils.dispatch_for_dialect('*', multiple=True)
+def _init_connection_args(
+ url, engine_args,
+ max_pool_size=None, max_overflow=None, pool_timeout=None, **kw):
+
+ pool_class = url.get_dialect().get_pool_class(url)
+ if issubclass(pool_class, pool.QueuePool):
if max_pool_size is not None:
engine_args['pool_size'] = max_pool_size
if max_overflow is not None:
engine_args['max_overflow'] = max_overflow
if pool_timeout is not None:
engine_args['pool_timeout'] = pool_timeout
- if connection_dict.get_dialect().driver == 'mysqlconnector':
- # mysqlconnector engine (<1.0) incorrectly defaults to
- # raise_on_warnings=True
- # https://bitbucket.org/zzzeek/sqlalchemy/issue/2515
- engine_args['connect_args'] = {'raise_on_warnings': False}
- engine = sqlalchemy.create_engine(sql_connection, **engine_args)
- if thread_checkin:
- sqlalchemy.event.listen(engine, 'checkin', _thread_yield)
+@_init_connection_args.dispatch_for("sqlite")
+def _init_connection_args(url, engine_args, **kw):
+ pool_class = url.get_dialect().get_pool_class(url)
+ # singletonthreadpool is used for :memory: connections;
+ # replace it with StaticPool.
+ if issubclass(pool_class, pool.SingletonThreadPool):
+ engine_args["poolclass"] = pool.StaticPool
+ engine_args["connect_args"] = {'check_same_thread': False}
+
- if engine.name == 'mysql':
- if mysql_sql_mode is not None:
- _mysql_set_mode_callback(engine, mysql_sql_mode)
- elif 'sqlite' in connection_dict.drivername:
- if not sqlite_synchronous:
- sqlalchemy.event.listen(engine, 'connect',
- _synchronous_switch_listener)
- sqlalchemy.event.listen(engine, 'connect', _add_regexp_listener)
+@_init_connection_args.dispatch_for("mysql+mysqlconnector")
+def _init_connection_args(url, engine_args, **kw):
+ # mysqlconnector engine (<1.0) incorrectly defaults to
+ # raise_on_warnings=True
+ # https://bitbucket.org/zzzeek/sqlalchemy/issue/2515
+ if "raise_on_warnings" not in url.query:
+ engine_args['connect_args'] = {'raise_on_warnings': False}
+
+@utils.dispatch_for_dialect('*', multiple=True)
+def _init_events(engine, thread_checkin=True, connection_trace=False, **kw):
if connection_trace:
_add_trace_comments(engine)
- # register alternate exception handler
- exc_filters.register_engine(engine)
+ if thread_checkin:
+ sqlalchemy.event.listen(engine, 'checkin', _thread_yield)
- # register on begin handler
- sqlalchemy.event.listen(engine, "begin", _begin_ping_listener)
- # initial connect + test
- _test_connection(engine, max_retries, retry_interval)
+@_init_events.dispatch_for("mysql")
+def _init_events(engine, mysql_sql_mode=None, **kw):
+ if mysql_sql_mode is not None:
+ _mysql_set_mode_callback(engine, mysql_sql_mode)
- return engine
+
+@_init_events.dispatch_for("sqlite")
+def _init_events(engine, sqlite_synchronous=True, sqlite_fk=False, **kw):
+ if not sqlite_synchronous:
+ sqlalchemy.event.listen(engine, 'connect',
+ _synchronous_switch_listener)
+ sqlalchemy.event.listen(engine, 'connect', _add_regexp_listener)
+
+ if sqlite_fk:
+ sqlalchemy.event.listen(
+ engine, 'connect',
+ _sqlite_foreign_keys_listener)
def _test_connection(engine, max_retries, retry_interval):
diff --git a/tests/sqlalchemy/test_sqlalchemy.py b/tests/sqlalchemy/test_sqlalchemy.py
index ca0388c..6c3caa0 100644
--- a/tests/sqlalchemy/test_sqlalchemy.py
+++ b/tests/sqlalchemy/test_sqlalchemy.py
@@ -16,6 +16,7 @@
# under the License.
"""Unit tests for SQLAlchemy specific code."""
+import contextlib
import logging
import fixtures
@@ -454,6 +455,165 @@ class MysqlSetCallbackTest(oslo_test.BaseTestCase):
self.assertEqual(exp_calls, engine._execs)
+class CreateEngineTest(oslo_test.BaseTestCase):
+ """Test that dialect-specific arguments/ listeners are set up correctly.
+
+ """
+
+ @contextlib.contextmanager
+ def _fixture(self):
+ # use a SQLite engine so that the tests don't rely on
+ # specific DBAPIs being installed...
+ real_engine = sqlalchemy.create_engine("sqlite://")
+
+ def muck_real_engine(url, **kw):
+ # ...but alter the URL on the engine so that the dispatch
+ # system goes off the fake URL we are using.
+ real_engine.url = url
+ return real_engine
+
+ with contextlib.nested(
+ mock.patch.object(
+ session.sqlalchemy, "create_engine",
+ mock.Mock(side_effect=muck_real_engine)),
+ mock.patch.object(session.sqlalchemy.event, "listen")
+ ) as (create_engine, listen_evt):
+ yield create_engine, listen_evt
+
+ def _assert_event_listened(self, listen_evt, evt_name, listen_fn):
+ call_set = set([
+ call[1][1:] for call in listen_evt.mock_calls
+ ])
+ self.assertTrue(
+ (evt_name, listen_fn) in call_set
+ )
+
+ def _assert_event_not_listened(self, listen_evt, evt_name, listen_fn):
+ call_set = set([
+ call[1][1:] for call in listen_evt.mock_calls
+ ])
+ self.assertTrue(
+ (evt_name, listen_fn) not in call_set
+ )
+
+ def test_queuepool_args(self):
+ with self._fixture() as (create_engine, listen_evt):
+ session.create_engine(
+ "mysql://u:p@host/test",
+ max_pool_size=10, max_overflow=10)
+ self.assertEqual(create_engine.mock_calls[0][2]['pool_size'], 10)
+ self.assertEqual(create_engine.mock_calls[0][2][
+ 'max_overflow'], 10)
+
+ def test_sqlite_memory_pool_args(self):
+ for url in ("sqlite://", "sqlite:///:memory:"):
+ with self._fixture() as (create_engine, listen_evt):
+ session.create_engine(
+ url,
+ max_pool_size=10, max_overflow=10)
+
+ # queuepool arguments are not peresnet
+ self.assertTrue(
+ 'pool_size' not in create_engine.mock_calls[0][2])
+ self.assertTrue(
+ 'max_overflow' not in create_engine.mock_calls[0][2])
+
+ self.assertEqual(
+ create_engine.mock_calls[0][2]['connect_args'],
+ {'check_same_thread': False}
+ )
+
+ # due to memory connection
+ self.assertTrue('poolclass' in create_engine.mock_calls[0][2])
+
+ def test_sqlite_file_pool_args(self):
+ with self._fixture() as (create_engine, listen_evt):
+ session.create_engine(
+ "sqlite:///somefile.db",
+ max_pool_size=10, max_overflow=10)
+
+ # queuepool arguments are not peresnet
+ self.assertTrue('pool_size' not in create_engine.mock_calls[0][2])
+ self.assertTrue(
+ 'max_overflow' not in create_engine.mock_calls[0][2])
+
+ self.assertTrue(
+ 'connect_args' not in create_engine.mock_calls[0][2]
+ )
+
+ # NullPool is the default for file based connections,
+ # no need to specify this
+ self.assertTrue('poolclass' not in create_engine.mock_calls[0][2])
+
+ def test_mysql_connect_args_default(self):
+ with self._fixture() as (create_engine, listen_evt):
+ session.create_engine("mysql+mysqldb://u:p@host/test")
+ self.assertTrue(
+ 'connect_args' not in create_engine.mock_calls[0][2]
+ )
+
+ def test_mysqlconnector_raise_on_warnings_default(self):
+ with self._fixture() as (create_engine, listen_evt):
+ session.create_engine("mysql+mysqlconnector://u:p@host/test")
+ self.assertEqual(
+ create_engine.mock_calls[0][2]['connect_args'],
+ {'raise_on_warnings': False}
+ )
+
+ def test_mysqlconnector_raise_on_warnings_override(self):
+ with self._fixture() as (create_engine, listen_evt):
+ session.create_engine(
+ "mysql+mysqlconnector://u:p@host/test?raise_on_warnings=true")
+ self.assertTrue(
+ 'connect_args' not in create_engine.mock_calls[0][2]
+ )
+
+ def test_thread_checkin(self):
+ with self._fixture() as (create_engine, listen_evt):
+ session.create_engine("mysql+mysqldb://u:p@host/test")
+
+ self._assert_event_listened(
+ listen_evt, "checkin", session._thread_yield)
+
+ def test_sqlite_fk_listener(self):
+ with self._fixture() as (create_engine, listen_evt):
+ session.create_engine("sqlite://", sqlite_fk=True)
+
+ self._assert_event_listened(
+ listen_evt, "connect", session._sqlite_foreign_keys_listener)
+
+ with self._fixture() as (create_engine, listen_evt):
+ session.create_engine("mysql://", sqlite_fk=True)
+
+ self._assert_event_not_listened(
+ listen_evt, "connect", session._sqlite_foreign_keys_listener)
+
+ with self._fixture() as (create_engine, listen_evt):
+ session.create_engine("sqlite://", sqlite_fk=False)
+
+ self._assert_event_not_listened(
+ listen_evt, "connect", session._sqlite_foreign_keys_listener)
+
+ def test_sqlite_synchronous_listener(self):
+ with self._fixture() as (create_engine, listen_evt):
+ session.create_engine("sqlite://")
+
+ self._assert_event_not_listened(
+ listen_evt, "connect", session._synchronous_switch_listener)
+
+ with self._fixture() as (create_engine, listen_evt):
+ session.create_engine("mysql://", sqlite_synchronous=True)
+
+ self._assert_event_not_listened(
+ listen_evt, "connect", session._synchronous_switch_listener)
+
+ with self._fixture() as (create_engine, listen_evt):
+ session.create_engine("sqlite://", sqlite_synchronous=False)
+
+ self._assert_event_listened(
+ listen_evt, "connect", session._synchronous_switch_listener)
+
+
class PatchStacktraceTest(test_base.DbTestCase):
def test_trace(self):