From 626d762cbdd9c0898bd8f1a0d73c11335501187a Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Sun, 10 Aug 2014 15:46:18 -0400 Subject: Consolidate sqlite and mysql event listeners This patch refactors the SQLite and MySQL event listener setup in session.py to occur inline within the new _init_events() hooks for each of these backends. This is to improve code readability and reduce unneeded boilerplate. In particular, _mysql_set_mode_callback() was implemented using four functions, all private to session.py, unused anywhere else and also not consistently prefixed with "_mysql". These functions have been consolidated into a single event listen plus a quick check of SQL mode. In particular, the system now emits the logging check for "sql_mode" in all cases. On the SQLite side, three functions, again not all prefixed with _sqlite, are consolidated into a single event listening function wihtin _init_events(). The test suite is altered in several ways. The tests for connect arguments no longer inject a sqlite engine into session.create_engine(); now that MySQL does a particular statement every time, this was no longer feasible, and in general other setup functions might emit SQL also, so the CreateEngineTest tests now test directly into session._init_connection_args. The tests for MySQL modes and SQLite connection settings move into being plain round-trip tests, reducing lots of mocking logic and ultimately doing a more complete test against a real backend. Change-Id: I04ca3fbb1f46a64ed6552496ae995f77aa0b2315 --- oslo/db/sqlalchemy/session.py | 140 +++++-------- tests/sqlalchemy/test_sqlalchemy.py | 383 ++++++++++++++---------------------- 2 files changed, 195 insertions(+), 328 deletions(-) diff --git a/oslo/db/sqlalchemy/session.py b/oslo/db/sqlalchemy/session.py index 6478384..ea381e5 100644 --- a/oslo/db/sqlalchemy/session.py +++ b/oslo/db/sqlalchemy/session.py @@ -278,7 +278,6 @@ Efficient use of soft deletes: """ -import functools import itertools import logging import re @@ -300,30 +299,6 @@ from oslo.db.sqlalchemy import utils LOG = logging.getLogger(__name__) -def _sqlite_foreign_keys_listener(dbapi_con, con_record): - """Ensures that the foreign key constraints are enforced in SQLite. - - The foreign key constraints are disabled by default in SQLite, - so the foreign key constraints will be enabled here for every - database connection - """ - dbapi_con.execute('pragma foreign_keys=ON') - - -def _synchronous_switch_listener(dbapi_conn, connection_rec): - """Switch sqlite connections to non-synchronous mode.""" - dbapi_conn.execute("PRAGMA synchronous = OFF") - - -def _add_regexp_listener(dbapi_con, con_record): - """Add REGEXP function to sqlite connections.""" - - def regexp(expr, item): - reg = re.compile(expr) - return reg.search(six.text_type(item)) is not None - dbapi_con.create_function('regexp', 2, regexp) - - def _thread_yield(dbapi_con, con_record): """Ensure other greenthreads get a chance to be executed. @@ -356,65 +331,6 @@ def _begin_ping_listener(connection): connection.scalar(select([1])) -def _set_session_sql_mode(dbapi_con, connection_rec, sql_mode=None): - """Set the sql_mode session variable. - - MySQL supports several server modes. The default is None, but sessions - may choose to enable server modes like TRADITIONAL, ANSI, - several STRICT_* modes and others. - - Note: passing in '' (empty string) for sql_mode clears - the SQL mode for the session, overriding a potentially set - server default. - """ - - cursor = dbapi_con.cursor() - cursor.execute("SET SESSION sql_mode = %s", [sql_mode]) - - -def _mysql_get_effective_sql_mode(engine): - """Returns the effective SQL mode for connections from the engine pool. - - Returns ``None`` if the mode isn't available, otherwise returns the mode. - - """ - # Get the real effective SQL mode. Even when unset by - # our own config, the server may still be operating in a specific - # SQL mode as set by the server configuration. - # Also note that the checkout listener will be called on execute to - # set the mode if it's registered. - row = engine.execute("SHOW VARIABLES LIKE 'sql_mode'").fetchone() - if row is None: - return - return row[1] - - -def _mysql_check_effective_sql_mode(engine): - """Logs a message based on the effective SQL mode for MySQL connections.""" - realmode = _mysql_get_effective_sql_mode(engine) - - if realmode is None: - LOG.warning(_LW('Unable to detect effective SQL mode')) - return - - LOG.debug('MySQL server mode set to %s', realmode) - # 'TRADITIONAL' mode enables several other modes, so - # we need a substring match here - if not ('TRADITIONAL' in realmode.upper() or - 'STRICT_ALL_TABLES' in realmode.upper()): - LOG.warning(_LW("MySQL SQL mode is '%s', " - "consider enabling TRADITIONAL or STRICT_ALL_TABLES"), - realmode) - - -def _mysql_set_mode_callback(engine, sql_mode): - if sql_mode is not None: - mode_callback = functools.partial(_set_session_sql_mode, - sql_mode=sql_mode) - sqlalchemy.event.listen(engine, 'connect', mode_callback) - _mysql_check_effective_sql_mode(engine) - - def create_engine(sql_connection, sqlite_fk=False, mysql_sql_mode=None, idle_timeout=3600, connection_debug=0, max_pool_size=None, max_overflow=None, @@ -507,6 +423,8 @@ def _init_connection_args(url, engine_args, **kw): @utils.dispatch_for_dialect('*', multiple=True) def _init_events(engine, thread_checkin=True, connection_trace=False, **kw): + """Set up event listeners for all database backends.""" + if connection_trace: _add_trace_comments(engine) @@ -516,21 +434,55 @@ def _init_events(engine, thread_checkin=True, connection_trace=False, **kw): @_init_events.dispatch_for("mysql") def _init_events(engine, mysql_sql_mode=None, **kw): + """Set up event listeners for MySQL.""" + if mysql_sql_mode is not None: - _mysql_set_mode_callback(engine, mysql_sql_mode) + @sqlalchemy.event.listens_for(engine, "connect") + def _set_session_sql_mode(dbapi_con, connection_rec): + cursor = dbapi_con.cursor() + cursor.execute("SET SESSION sql_mode = %s", [mysql_sql_mode]) + + realmode = engine.execute("SHOW VARIABLES LIKE 'sql_mode'").fetchone() + if realmode is None: + LOG.warning(_LW('Unable to detect effective SQL mode')) + else: + realmode = realmode[1] + LOG.debug('MySQL server mode set to %s', realmode) + if 'TRADITIONAL' not in realmode.upper() and \ + 'STRICT_ALL_TABLES' not in realmode.upper(): + LOG.warning( + _LW( + "MySQL SQL mode is '%s', " + "consider enabling TRADITIONAL or STRICT_ALL_TABLES"), + realmode) @_init_events.dispatch_for("sqlite") def _init_events(engine, sqlite_synchronous=True, sqlite_fk=False, **kw): - if not sqlite_synchronous: - sqlalchemy.event.listen(engine, 'connect', - _synchronous_switch_listener) - sqlalchemy.event.listen(engine, 'connect', _add_regexp_listener) - - if sqlite_fk: - sqlalchemy.event.listen( - engine, 'connect', - _sqlite_foreign_keys_listener) + """Set up event listeners for SQLite. + + This includes several settings made on connections as they are + created, as well as transactional control extensions. + + """ + + def regexp(expr, item): + reg = re.compile(expr) + return reg.search(six.text_type(item)) is not None + + @sqlalchemy.event.listens_for(engine, "connect") + def _sqlite_connect_events(dbapi_con, con_record): + + # Add REGEXP functionality on SQLite connections + dbapi_con.create_function('regexp', 2, regexp) + + if not sqlite_synchronous: + # Switch sqlite connections to non-synchronous mode + dbapi_con.execute("PRAGMA synchronous = OFF") + + if sqlite_fk: + # Ensures that the foreign key constraints are enforced in SQLite. + dbapi_con.execute('pragma foreign_keys=ON') def _test_connection(engine, max_retries, retry_interval): diff --git a/tests/sqlalchemy/test_sqlalchemy.py b/tests/sqlalchemy/test_sqlalchemy.py index 6c3caa0..88ec605 100644 --- a/tests/sqlalchemy/test_sqlalchemy.py +++ b/tests/sqlalchemy/test_sqlalchemy.py @@ -16,7 +16,7 @@ # under the License. """Unit tests for SQLAlchemy specific code.""" -import contextlib + import logging import fixtures @@ -25,6 +25,7 @@ from oslo.config import cfg from oslotest import base as oslo_test import sqlalchemy from sqlalchemy import Column, MetaData, Table +from sqlalchemy.engine import url from sqlalchemy import Integer, String from sqlalchemy.ext.declarative import declarative_base @@ -295,164 +296,141 @@ class EngineFacadeTestCase(oslo_test.BaseTestCase): self.assertEqual(master_path, str(slave_session.bind.url)) -class MysqlSetCallbackTest(oslo_test.BaseTestCase): - - class FakeCursor(object): - def __init__(self, execs): - self._execs = execs - - def execute(self, sql, arg): - self._execs.append(sql % arg) +class SQLiteConnectTest(oslo_test.BaseTestCase): - class FakeDbapiCon(object): - def __init__(self, execs): - self._execs = execs + def _fixture(self, **kw): + return session.create_engine("sqlite://", **kw) - def cursor(self): - return MysqlSetCallbackTest.FakeCursor(self._execs) + def test_sqlite_fk_listener(self): + engine = self._fixture(sqlite_fk=True) + self.assertEqual( + engine.scalar("pragma foreign_keys"), + 1 + ) - class FakeResultSet(object): - def __init__(self, realmode): - self._realmode = realmode + engine = self._fixture(sqlite_fk=False) - def fetchone(self): - return ['ignored', self._realmode] + self.assertEqual( + engine.scalar("pragma foreign_keys"), + 0 + ) - class FakeEngine(object): - def __init__(self, realmode=None): - self._cbs = {} - self._execs = [] - self._realmode = realmode - self._connected = False + def test_sqlite_synchronous_listener(self): + engine = self._fixture() - def set_callback(self, name, cb): - self._cbs[name] = cb + # "The default setting is synchronous=FULL." (e.g. 2) + # http://www.sqlite.org/pragma.html#pragma_synchronous + self.assertEqual( + engine.scalar("pragma synchronous"), + 2 + ) - def connect(self, **kwargs): - cb = self._cbs.get('connect', lambda *x, **y: None) - dbapi_con = MysqlSetCallbackTest.FakeDbapiCon(self._execs) - connection_rec = None # Not used. - cb(dbapi_con, connection_rec) + engine = self._fixture(sqlite_synchronous=False) - def execute(self, sql): - if not self._connected: - self.connect() - self._connected = True - self._execs.append(sql) - return MysqlSetCallbackTest.FakeResultSet(self._realmode) + self.assertEqual( + engine.scalar("pragma synchronous"), + 0 + ) - def stub_listen(engine, name, cb): - engine.set_callback(name, cb) - @mock.patch.object(sqlalchemy.event, 'listen', side_effect=stub_listen) - def _call_set_callback(self, listen_mock, sql_mode=None, realmode=None): - engine = self.FakeEngine(realmode=realmode) +class MysqlConnectTest(test_base.MySQLOpportunisticTestCase): - self.stream = self.useFixture(fixtures.FakeLogger( - format="%(levelname)8s [%(name)s] %(message)s", - level=logging.DEBUG, - nuke_handlers=True - )) + def _fixture(self, sql_mode): + return session.create_engine(self.engine.url, mysql_sql_mode=sql_mode) - session._mysql_set_mode_callback(engine, sql_mode=sql_mode) - return engine + def _assert_sql_mode(self, engine, sql_mode_present, sql_mode_non_present): + mode = engine.execute("SHOW VARIABLES LIKE 'sql_mode'").fetchone()[1] + self.assertTrue( + sql_mode_present in mode + ) + if sql_mode_non_present: + self.assertTrue( + sql_mode_non_present not in mode + ) def test_set_mode_traditional(self): - # If _mysql_set_mode_callback is called with an sql_mode, then the SQL - # mode is set on the connection. - - engine = self._call_set_callback(sql_mode='TRADITIONAL') - - exp_calls = [ - "SET SESSION sql_mode = ['TRADITIONAL']", - "SHOW VARIABLES LIKE 'sql_mode'" - ] - self.assertEqual(exp_calls, engine._execs) + engine = self._fixture(sql_mode='TRADITIONAL') + self._assert_sql_mode(engine, "TRADITIONAL", "ANSI") def test_set_mode_ansi(self): - # If _mysql_set_mode_callback is called with an sql_mode, then the SQL - # mode is set on the connection. - - engine = self._call_set_callback(sql_mode='ANSI') - - exp_calls = [ - "SET SESSION sql_mode = ['ANSI']", - "SHOW VARIABLES LIKE 'sql_mode'" - ] - self.assertEqual(exp_calls, engine._execs) + engine = self._fixture(sql_mode='ANSI') + self._assert_sql_mode(engine, "ANSI", "TRADITIONAL") def test_set_mode_no_mode(self): # If _mysql_set_mode_callback is called with sql_mode=None, then # the SQL mode is NOT set on the connection. - engine = self._call_set_callback() + expected = self.engine.execute( + "SHOW VARIABLES LIKE 'sql_mode'").fetchone()[1] - exp_calls = [ - "SHOW VARIABLES LIKE 'sql_mode'" - ] - self.assertEqual(exp_calls, engine._execs) + engine = self._fixture(sql_mode=None) + self._assert_sql_mode(engine, expected, None) def test_fail_detect_mode(self): # If "SHOW VARIABLES LIKE 'sql_mode'" results in no row, then # we get a log indicating can't detect the mode. - self._call_set_callback() + log = self.useFixture(fixtures.FakeLogger(level=logging.WARN)) + + engine = self._fixture(sql_mode=None) + + @sqlalchemy.event.listens_for( + engine, "before_cursor_execute", retval=True) + def replace_stmt( + conn, cursor, statement, parameters, + context, executemany): + if "SHOW VARIABLES LIKE 'sql_mode'" in statement: + statement = "SHOW VARIABLES LIKE 'i_dont_exist'" + return statement, parameters + + session._init_events.dispatch_on_drivername("mysql")(engine) self.assertIn('Unable to detect effective SQL mode', - self.stream.output) + log.output) def test_logs_real_mode(self): # If "SHOW VARIABLES LIKE 'sql_mode'" results in a value, then # we get a log with the value. - self._call_set_callback(realmode='SOMETHING') + log = self.useFixture(fixtures.FakeLogger(level=logging.DEBUG)) + + engine = self._fixture(sql_mode='TRADITIONAL') - self.assertIn('MySQL server mode set to SOMETHING', - self.stream.output) + actual_mode = engine.execute( + "SHOW VARIABLES LIKE 'sql_mode'").fetchone()[1] + + self.assertIn('MySQL server mode set to %s' % actual_mode, + log.output) def test_warning_when_not_traditional(self): # If "SHOW VARIABLES LIKE 'sql_mode'" results in a value that doesn't # include 'TRADITIONAL', then a warning is logged. - self._call_set_callback(realmode='NOT_TRADIT') + log = self.useFixture(fixtures.FakeLogger(level=logging.WARN)) + self._fixture(sql_mode='ANSI') self.assertIn("consider enabling TRADITIONAL or STRICT_ALL_TABLES", - self.stream.output) + log.output) def test_no_warning_when_traditional(self): # If "SHOW VARIABLES LIKE 'sql_mode'" results in a value that includes # 'TRADITIONAL', then no warning is logged. - self._call_set_callback(realmode='TRADITIONAL') + log = self.useFixture(fixtures.FakeLogger(level=logging.WARN)) + self._fixture(sql_mode='TRADITIONAL') self.assertNotIn("consider enabling TRADITIONAL or STRICT_ALL_TABLES", - self.stream.output) + log.output) def test_no_warning_when_strict_all_tables(self): # If "SHOW VARIABLES LIKE 'sql_mode'" results in a value that includes # 'STRICT_ALL_TABLES', then no warning is logged. - self._call_set_callback(realmode='STRICT_ALL_TABLES') + log = self.useFixture(fixtures.FakeLogger(level=logging.WARN)) + self._fixture(sql_mode='TRADITIONAL') self.assertNotIn("consider enabling TRADITIONAL or STRICT_ALL_TABLES", - self.stream.output) - - def test_multiple_executes(self): - # We should only set the sql_mode on a connection once. - - engine = self._call_set_callback(sql_mode='TRADITIONAL', - realmode='TRADITIONAL') - - engine.execute('SELECT * FROM foo') - engine.execute('SELECT * FROM bar') - - exp_calls = [ - "SET SESSION sql_mode = ['TRADITIONAL']", - "SHOW VARIABLES LIKE 'sql_mode'", - "SELECT * FROM foo", - "SELECT * FROM bar", - ] - self.assertEqual(exp_calls, engine._execs) + log.output) class CreateEngineTest(oslo_test.BaseTestCase): @@ -460,158 +438,95 @@ class CreateEngineTest(oslo_test.BaseTestCase): """ - @contextlib.contextmanager - def _fixture(self): - # use a SQLite engine so that the tests don't rely on - # specific DBAPIs being installed... - real_engine = sqlalchemy.create_engine("sqlite://") - - def muck_real_engine(url, **kw): - # ...but alter the URL on the engine so that the dispatch - # system goes off the fake URL we are using. - real_engine.url = url - return real_engine - - with contextlib.nested( - mock.patch.object( - session.sqlalchemy, "create_engine", - mock.Mock(side_effect=muck_real_engine)), - mock.patch.object(session.sqlalchemy.event, "listen") - ) as (create_engine, listen_evt): - yield create_engine, listen_evt - - def _assert_event_listened(self, listen_evt, evt_name, listen_fn): - call_set = set([ - call[1][1:] for call in listen_evt.mock_calls - ]) - self.assertTrue( - (evt_name, listen_fn) in call_set - ) - - def _assert_event_not_listened(self, listen_evt, evt_name, listen_fn): - call_set = set([ - call[1][1:] for call in listen_evt.mock_calls - ]) - self.assertTrue( - (evt_name, listen_fn) not in call_set - ) - def test_queuepool_args(self): - with self._fixture() as (create_engine, listen_evt): - session.create_engine( - "mysql://u:p@host/test", - max_pool_size=10, max_overflow=10) - self.assertEqual(create_engine.mock_calls[0][2]['pool_size'], 10) - self.assertEqual(create_engine.mock_calls[0][2][ - 'max_overflow'], 10) + args = {} + session._init_connection_args( + url.make_url("mysql://u:p@host/test"), args, + max_pool_size=10, max_overflow=10) + self.assertEqual(args['pool_size'], 10) + self.assertEqual(args['max_overflow'], 10) def test_sqlite_memory_pool_args(self): - for url in ("sqlite://", "sqlite:///:memory:"): - with self._fixture() as (create_engine, listen_evt): - session.create_engine( - url, - max_pool_size=10, max_overflow=10) - - # queuepool arguments are not peresnet - self.assertTrue( - 'pool_size' not in create_engine.mock_calls[0][2]) - self.assertTrue( - 'max_overflow' not in create_engine.mock_calls[0][2]) - - self.assertEqual( - create_engine.mock_calls[0][2]['connect_args'], - {'check_same_thread': False} - ) - - # due to memory connection - self.assertTrue('poolclass' in create_engine.mock_calls[0][2]) - - def test_sqlite_file_pool_args(self): - with self._fixture() as (create_engine, listen_evt): - session.create_engine( - "sqlite:///somefile.db", + for _url in ("sqlite://", "sqlite:///:memory:"): + args = {} + session._init_connection_args( + url.make_url(_url), args, max_pool_size=10, max_overflow=10) # queuepool arguments are not peresnet - self.assertTrue('pool_size' not in create_engine.mock_calls[0][2]) self.assertTrue( - 'max_overflow' not in create_engine.mock_calls[0][2]) - + 'pool_size' not in args) self.assertTrue( - 'connect_args' not in create_engine.mock_calls[0][2] - ) - - # NullPool is the default for file based connections, - # no need to specify this - self.assertTrue('poolclass' not in create_engine.mock_calls[0][2]) + 'max_overflow' not in args) - def test_mysql_connect_args_default(self): - with self._fixture() as (create_engine, listen_evt): - session.create_engine("mysql+mysqldb://u:p@host/test") - self.assertTrue( - 'connect_args' not in create_engine.mock_calls[0][2] - ) - - def test_mysqlconnector_raise_on_warnings_default(self): - with self._fixture() as (create_engine, listen_evt): - session.create_engine("mysql+mysqlconnector://u:p@host/test") self.assertEqual( - create_engine.mock_calls[0][2]['connect_args'], - {'raise_on_warnings': False} + args['connect_args'], + {'check_same_thread': False} ) - def test_mysqlconnector_raise_on_warnings_override(self): - with self._fixture() as (create_engine, listen_evt): - session.create_engine( - "mysql+mysqlconnector://u:p@host/test?raise_on_warnings=true") - self.assertTrue( - 'connect_args' not in create_engine.mock_calls[0][2] - ) - - def test_thread_checkin(self): - with self._fixture() as (create_engine, listen_evt): - session.create_engine("mysql+mysqldb://u:p@host/test") - - self._assert_event_listened( - listen_evt, "checkin", session._thread_yield) - - def test_sqlite_fk_listener(self): - with self._fixture() as (create_engine, listen_evt): - session.create_engine("sqlite://", sqlite_fk=True) - - self._assert_event_listened( - listen_evt, "connect", session._sqlite_foreign_keys_listener) + # due to memory connection + self.assertTrue('poolclass' in args) - with self._fixture() as (create_engine, listen_evt): - session.create_engine("mysql://", sqlite_fk=True) + def test_sqlite_file_pool_args(self): + args = {} + session._init_connection_args( + url.make_url("sqlite:///somefile.db"), args, + max_pool_size=10, max_overflow=10) - self._assert_event_not_listened( - listen_evt, "connect", session._sqlite_foreign_keys_listener) + # queuepool arguments are not peresnet + self.assertTrue('pool_size' not in args) + self.assertTrue( + 'max_overflow' not in args) - with self._fixture() as (create_engine, listen_evt): - session.create_engine("sqlite://", sqlite_fk=False) + self.assertTrue( + 'connect_args' not in args + ) - self._assert_event_not_listened( - listen_evt, "connect", session._sqlite_foreign_keys_listener) + # NullPool is the default for file based connections, + # no need to specify this + self.assertTrue('poolclass' not in args) - def test_sqlite_synchronous_listener(self): - with self._fixture() as (create_engine, listen_evt): - session.create_engine("sqlite://") - - self._assert_event_not_listened( - listen_evt, "connect", session._synchronous_switch_listener) + def test_mysql_connect_args_default(self): + args = {} + session._init_connection_args( + url.make_url("mysql+mysqldb://u:p@host/test"), args) + self.assertTrue( + 'connect_args' not in args + ) - with self._fixture() as (create_engine, listen_evt): - session.create_engine("mysql://", sqlite_synchronous=True) + def test_mysqlconnector_raise_on_warnings_default(self): + args = {} + session._init_connection_args( + url.make_url("mysql+mysqlconnector://u:p@host/test"), + args) + self.assertEqual( + args, + {'connect_args': {'raise_on_warnings': False}} + ) - self._assert_event_not_listened( - listen_evt, "connect", session._synchronous_switch_listener) + def test_mysqlconnector_raise_on_warnings_override(self): + args = {} + session._init_connection_args( + url.make_url( + "mysql+mysqlconnector://u:p@host/test" + "?raise_on_warnings=true"), + args + ) - with self._fixture() as (create_engine, listen_evt): - session.create_engine("sqlite://", sqlite_synchronous=False) + self.assertTrue( + 'connect_args' not in args + ) - self._assert_event_listened( - listen_evt, "connect", session._synchronous_switch_listener) + def test_thread_checkin(self): + with mock.patch("sqlalchemy.event.listens_for"): + with mock.patch("sqlalchemy.event.listen") as listen_evt: + session._init_events.dispatch_on_drivername( + "sqlite")(mock.Mock()) + + self.assertEqual( + listen_evt.mock_calls[0][1][-1], + session._thread_yield + ) class PatchStacktraceTest(test_base.DbTestCase): -- cgit v1.2.1