diff options
Diffstat (limited to 'oslo_db/tests/sqlalchemy')
-rw-r--r-- | oslo_db/tests/sqlalchemy/__init__.py | 0 | ||||
-rw-r--r-- | oslo_db/tests/sqlalchemy/test_engine_connect.py | 68 | ||||
-rw-r--r-- | oslo_db/tests/sqlalchemy/test_exc_filters.py | 833 | ||||
-rw-r--r-- | oslo_db/tests/sqlalchemy/test_handle_error.py | 194 | ||||
-rw-r--r-- | oslo_db/tests/sqlalchemy/test_migrate_cli.py | 222 | ||||
-rw-r--r-- | oslo_db/tests/sqlalchemy/test_migration_common.py | 212 | ||||
-rw-r--r-- | oslo_db/tests/sqlalchemy/test_migrations.py | 309 | ||||
-rw-r--r-- | oslo_db/tests/sqlalchemy/test_models.py | 146 | ||||
-rw-r--r-- | oslo_db/tests/sqlalchemy/test_options.py | 127 | ||||
-rw-r--r-- | oslo_db/tests/sqlalchemy/test_sqlalchemy.py | 655 | ||||
-rw-r--r-- | oslo_db/tests/sqlalchemy/test_utils.py | 1090 |
11 files changed, 3856 insertions, 0 deletions
diff --git a/oslo_db/tests/sqlalchemy/__init__.py b/oslo_db/tests/sqlalchemy/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/oslo_db/tests/sqlalchemy/__init__.py diff --git a/oslo_db/tests/sqlalchemy/test_engine_connect.py b/oslo_db/tests/sqlalchemy/test_engine_connect.py new file mode 100644 index 0000000..c75511f --- /dev/null +++ b/oslo_db/tests/sqlalchemy/test_engine_connect.py @@ -0,0 +1,68 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +"""Test the compatibility layer for the engine_connect() event. + +This event is added as of SQLAlchemy 0.9.0; oslo_db provides a compatibility +layer for prior SQLAlchemy versions. + +""" + +import mock +from oslotest import base as test_base +import sqlalchemy as sqla + +from oslo_db.sqlalchemy.compat import engine_connect + + +class EngineConnectTest(test_base.BaseTestCase): + + def setUp(self): + super(EngineConnectTest, self).setUp() + + self.engine = engine = sqla.create_engine("sqlite://") + self.addCleanup(engine.dispose) + + def test_connect_event(self): + engine = self.engine + + listener = mock.Mock() + engine_connect(engine, listener) + + conn = engine.connect() + self.assertEqual( + listener.mock_calls, + [mock.call(conn, False)] + ) + + conn.close() + + conn2 = engine.connect() + conn2.close() + self.assertEqual( + listener.mock_calls, + [mock.call(conn, False), mock.call(conn2, False)] + ) + + def test_branch(self): + engine = self.engine + + listener = mock.Mock() + engine_connect(engine, listener) + + conn = engine.connect() + branched = conn.connect() + conn.close() + self.assertEqual( + listener.mock_calls, + [mock.call(conn, False), mock.call(branched, True)] + ) diff --git a/oslo_db/tests/sqlalchemy/test_exc_filters.py b/oslo_db/tests/sqlalchemy/test_exc_filters.py new file mode 100644 index 0000000..157a183 --- /dev/null +++ b/oslo_db/tests/sqlalchemy/test_exc_filters.py @@ -0,0 +1,833 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +"""Test exception filters applied to engines.""" + +import contextlib +import itertools + +import mock +from oslotest import base as oslo_test_base +import six +import sqlalchemy as sqla +from sqlalchemy.orm import mapper + +from oslo_db import exception +from oslo_db.sqlalchemy import compat +from oslo_db.sqlalchemy import exc_filters +from oslo_db.sqlalchemy import session +from oslo_db.sqlalchemy import test_base +from oslo_db.tests import utils as test_utils + +_TABLE_NAME = '__tmp__test__tmp__' + + +class _SQLAExceptionMatcher(object): + def assertInnerException( + self, + matched, exception_type, message, sql=None, params=None): + + exc = matched.inner_exception + self.assertSQLAException(exc, exception_type, message, sql, params) + + def assertSQLAException( + self, + exc, exception_type, message, sql=None, params=None): + if isinstance(exception_type, (type, tuple)): + self.assertTrue(issubclass(exc.__class__, exception_type)) + else: + self.assertEqual(exc.__class__.__name__, exception_type) + self.assertEqual(str(exc.orig).lower(), message.lower()) + if sql is not None: + self.assertEqual(exc.statement, sql) + if params is not None: + self.assertEqual(exc.params, params) + + +class TestsExceptionFilter(_SQLAExceptionMatcher, oslo_test_base.BaseTestCase): + + class Error(Exception): + """DBAPI base error. + + This exception and subclasses are used in a mock context + within these tests. + + """ + + class OperationalError(Error): + pass + + class InterfaceError(Error): + pass + + class InternalError(Error): + pass + + class IntegrityError(Error): + pass + + class ProgrammingError(Error): + pass + + class TransactionRollbackError(OperationalError): + """Special psycopg2-only error class. + + SQLAlchemy has an issue with this per issue #3075: + + https://bitbucket.org/zzzeek/sqlalchemy/issue/3075/ + + """ + + def setUp(self): + super(TestsExceptionFilter, self).setUp() + self.engine = sqla.create_engine("sqlite://") + exc_filters.register_engine(self.engine) + self.engine.connect().close() # initialize + + @contextlib.contextmanager + def _dbapi_fixture(self, dialect_name): + engine = self.engine + with test_utils.nested( + mock.patch.object(engine.dialect.dbapi, + "Error", + self.Error), + mock.patch.object(engine.dialect, "name", dialect_name), + ): + yield + + @contextlib.contextmanager + def _fixture(self, dialect_name, exception, is_disconnect=False): + + def do_execute(self, cursor, statement, parameters, **kw): + raise exception + + engine = self.engine + + # ensure the engine has done its initial checks against the + # DB as we are going to be removing its ability to execute a + # statement + self.engine.connect().close() + + with test_utils.nested( + mock.patch.object(engine.dialect, "do_execute", do_execute), + # replace the whole DBAPI rather than patching "Error" + # as some DBAPIs might not be patchable (?) + mock.patch.object(engine.dialect, + "dbapi", + mock.Mock(Error=self.Error)), + mock.patch.object(engine.dialect, "name", dialect_name), + mock.patch.object(engine.dialect, + "is_disconnect", + lambda *args: is_disconnect) + ): + yield + + def _run_test(self, dialect_name, statement, raises, expected, + is_disconnect=False, params=()): + with self._fixture(dialect_name, raises, is_disconnect=is_disconnect): + with self.engine.connect() as conn: + matched = self.assertRaises( + expected, conn.execute, statement, params + ) + return matched + + +class TestFallthroughsAndNonDBAPI(TestsExceptionFilter): + + def test_generic_dbapi(self): + matched = self._run_test( + "mysql", "select you_made_a_programming_error", + self.ProgrammingError("Error 123, you made a mistake"), + exception.DBError + ) + self.assertInnerException( + matched, + "ProgrammingError", + "Error 123, you made a mistake", + 'select you_made_a_programming_error', ()) + + def test_generic_dbapi_disconnect(self): + matched = self._run_test( + "mysql", "select the_db_disconnected", + self.InterfaceError("connection lost"), + exception.DBConnectionError, + is_disconnect=True + ) + self.assertInnerException( + matched, + "InterfaceError", "connection lost", + "select the_db_disconnected", ()), + + def test_operational_dbapi_disconnect(self): + matched = self._run_test( + "mysql", "select the_db_disconnected", + self.OperationalError("connection lost"), + exception.DBConnectionError, + is_disconnect=True + ) + self.assertInnerException( + matched, + "OperationalError", "connection lost", + "select the_db_disconnected", ()), + + def test_operational_error_asis(self): + """Test operational errors. + + test that SQLAlchemy OperationalErrors that aren't disconnects + are passed through without wrapping. + """ + + matched = self._run_test( + "mysql", "select some_operational_error", + self.OperationalError("some op error"), + sqla.exc.OperationalError + ) + self.assertSQLAException( + matched, + "OperationalError", "some op error" + ) + + def test_unicode_encode(self): + # intentionally generate a UnicodeEncodeError, as its + # constructor is quite complicated and seems to be non-public + # or at least not documented anywhere. + uee_ref = None + try: + six.u('\u2435').encode('ascii') + except UnicodeEncodeError as uee: + # Python3.x added new scoping rules here (sadly) + # http://legacy.python.org/dev/peps/pep-3110/#semantic-changes + uee_ref = uee + + self._run_test( + "postgresql", six.u('select \u2435'), + uee_ref, + exception.DBInvalidUnicodeParameter + ) + + def test_garden_variety(self): + matched = self._run_test( + "mysql", "select some_thing_that_breaks", + AttributeError("mysqldb has an attribute error"), + exception.DBError + ) + self.assertEqual("mysqldb has an attribute error", matched.args[0]) + + +class TestReferenceErrorSQLite(_SQLAExceptionMatcher, test_base.DbTestCase): + + def setUp(self): + super(TestReferenceErrorSQLite, self).setUp() + + meta = sqla.MetaData(bind=self.engine) + + table_1 = sqla.Table( + "resource_foo", meta, + sqla.Column("id", sqla.Integer, primary_key=True), + sqla.Column("foo", sqla.Integer), + mysql_engine='InnoDB', + mysql_charset='utf8', + ) + table_1.create() + + self.table_2 = sqla.Table( + "resource_entity", meta, + sqla.Column("id", sqla.Integer, primary_key=True), + sqla.Column("foo_id", sqla.Integer, + sqla.ForeignKey("resource_foo.id", name="foo_fkey")), + mysql_engine='InnoDB', + mysql_charset='utf8', + ) + self.table_2.create() + + def test_raise(self): + self.engine.execute("PRAGMA foreign_keys = ON;") + + matched = self.assertRaises( + exception.DBReferenceError, + self.engine.execute, + self.table_2.insert({'id': 1, 'foo_id': 2}) + ) + + self.assertInnerException( + matched, + "IntegrityError", + "FOREIGN KEY constraint failed", + 'INSERT INTO resource_entity (id, foo_id) VALUES (?, ?)', + (1, 2) + ) + + self.assertIsNone(matched.table) + self.assertIsNone(matched.constraint) + self.assertIsNone(matched.key) + self.assertIsNone(matched.key_table) + + +class TestReferenceErrorPostgreSQL(TestReferenceErrorSQLite, + test_base.PostgreSQLOpportunisticTestCase): + def test_raise(self): + params = {'id': 1, 'foo_id': 2} + matched = self.assertRaises( + exception.DBReferenceError, + self.engine.execute, + self.table_2.insert(params) + ) + self.assertInnerException( + matched, + "IntegrityError", + "insert or update on table \"resource_entity\" " + "violates foreign key constraint \"foo_fkey\"\nDETAIL: Key " + "(foo_id)=(2) is not present in table \"resource_foo\".\n", + "INSERT INTO resource_entity (id, foo_id) VALUES (%(id)s, " + "%(foo_id)s)", + params, + ) + + self.assertEqual("resource_entity", matched.table) + self.assertEqual("foo_fkey", matched.constraint) + self.assertEqual("foo_id", matched.key) + self.assertEqual("resource_foo", matched.key_table) + + +class TestReferenceErrorMySQL(TestReferenceErrorSQLite, + test_base.MySQLOpportunisticTestCase): + def test_raise(self): + matched = self.assertRaises( + exception.DBReferenceError, + self.engine.execute, + self.table_2.insert({'id': 1, 'foo_id': 2}) + ) + + self.assertInnerException( + matched, + "IntegrityError", + "(1452, 'Cannot add or update a child row: a " + "foreign key constraint fails (`{0}`.`resource_entity`, " + "CONSTRAINT `foo_fkey` FOREIGN KEY (`foo_id`) REFERENCES " + "`resource_foo` (`id`))')".format(self.engine.url.database), + "INSERT INTO resource_entity (id, foo_id) VALUES (%s, %s)", + (1, 2) + ) + self.assertEqual("resource_entity", matched.table) + self.assertEqual("foo_fkey", matched.constraint) + self.assertEqual("foo_id", matched.key) + self.assertEqual("resource_foo", matched.key_table) + + def test_raise_ansi_quotes(self): + self.engine.execute("SET SESSION sql_mode = 'ANSI';") + matched = self.assertRaises( + exception.DBReferenceError, + self.engine.execute, + self.table_2.insert({'id': 1, 'foo_id': 2}) + ) + + self.assertInnerException( + matched, + "IntegrityError", + '(1452, \'Cannot add or update a child row: a ' + 'foreign key constraint fails ("{0}"."resource_entity", ' + 'CONSTRAINT "foo_fkey" FOREIGN KEY ("foo_id") REFERENCES ' + '"resource_foo" ("id"))\')'.format(self.engine.url.database), + "INSERT INTO resource_entity (id, foo_id) VALUES (%s, %s)", + (1, 2) + ) + self.assertEqual("resource_entity", matched.table) + self.assertEqual("foo_fkey", matched.constraint) + self.assertEqual("foo_id", matched.key) + self.assertEqual("resource_foo", matched.key_table) + + +class TestDuplicate(TestsExceptionFilter): + + def _run_dupe_constraint_test(self, dialect_name, message, + expected_columns=['a', 'b'], + expected_value=None): + matched = self._run_test( + dialect_name, "insert into table some_values", + self.IntegrityError(message), + exception.DBDuplicateEntry + ) + self.assertEqual(expected_columns, matched.columns) + self.assertEqual(expected_value, matched.value) + + def _not_dupe_constraint_test(self, dialect_name, statement, message, + expected_cls): + matched = self._run_test( + dialect_name, statement, + self.IntegrityError(message), + expected_cls + ) + self.assertInnerException( + matched, + "IntegrityError", + str(self.IntegrityError(message)), + statement + ) + + def test_sqlite(self): + self._run_dupe_constraint_test("sqlite", 'column a, b are not unique') + + def test_sqlite_3_7_16_or_3_8_2_and_higher(self): + self._run_dupe_constraint_test( + "sqlite", + 'UNIQUE constraint failed: tbl.a, tbl.b') + + def test_sqlite_dupe_primary_key(self): + self._run_dupe_constraint_test( + "sqlite", + "PRIMARY KEY must be unique 'insert into t values(10)'", + expected_columns=[]) + + def test_mysql_mysqldb(self): + self._run_dupe_constraint_test( + "mysql", + '(1062, "Duplicate entry ' + '\'2-3\' for key \'uniq_tbl0a0b\'")', expected_value='2-3') + + def test_mysql_mysqlconnector(self): + self._run_dupe_constraint_test( + "mysql", + '1062 (23000): Duplicate entry ' + '\'2-3\' for key \'uniq_tbl0a0b\'")', expected_value='2-3') + + def test_postgresql(self): + self._run_dupe_constraint_test( + 'postgresql', + 'duplicate key value violates unique constraint' + '"uniq_tbl0a0b"' + '\nDETAIL: Key (a, b)=(2, 3) already exists.\n', + expected_value='2, 3' + ) + + def test_mysql_single(self): + self._run_dupe_constraint_test( + "mysql", + "1062 (23000): Duplicate entry '2' for key 'b'", + expected_columns=['b'], + expected_value='2' + ) + + def test_postgresql_single(self): + self._run_dupe_constraint_test( + 'postgresql', + 'duplicate key value violates unique constraint "uniq_tbl0b"\n' + 'DETAIL: Key (b)=(2) already exists.\n', + expected_columns=['b'], + expected_value='2' + ) + + def test_unsupported_backend(self): + self._not_dupe_constraint_test( + "nonexistent", "insert into table some_values", + self.IntegrityError("constraint violation"), + exception.DBError + ) + + def test_ibm_db_sa(self): + self._run_dupe_constraint_test( + 'ibm_db_sa', + 'SQL0803N One or more values in the INSERT statement, UPDATE ' + 'statement, or foreign key update caused by a DELETE statement are' + ' not valid because the primary key, unique constraint or unique ' + 'index identified by "2" constrains table "NOVA.KEY_PAIRS" from ' + 'having duplicate values for the index key.', + expected_columns=[] + ) + + def test_ibm_db_sa_notadupe(self): + self._not_dupe_constraint_test( + 'ibm_db_sa', + 'ALTER TABLE instance_types ADD CONSTRAINT ' + 'uniq_name_x_deleted UNIQUE (name, deleted)', + 'SQL0542N The column named "NAME" cannot be a column of a ' + 'primary key or unique key constraint because it can contain null ' + 'values.', + exception.DBError + ) + + +class TestDeadlock(TestsExceptionFilter): + statement = ('SELECT quota_usages.created_at AS ' + 'quota_usages_created_at FROM quota_usages ' + 'WHERE quota_usages.project_id = %(project_id_1)s ' + 'AND quota_usages.deleted = %(deleted_1)s FOR UPDATE') + params = { + 'project_id_1': '8891d4478bbf48ad992f050cdf55e9b5', + 'deleted_1': 0 + } + + def _run_deadlock_detect_test( + self, dialect_name, message, + orig_exception_cls=TestsExceptionFilter.OperationalError): + self._run_test( + dialect_name, self.statement, + orig_exception_cls(message), + exception.DBDeadlock, + params=self.params + ) + + def _not_deadlock_test( + self, dialect_name, message, + expected_cls, expected_dbapi_cls, + orig_exception_cls=TestsExceptionFilter.OperationalError): + + matched = self._run_test( + dialect_name, self.statement, + orig_exception_cls(message), + expected_cls, + params=self.params + ) + + if isinstance(matched, exception.DBError): + matched = matched.inner_exception + + self.assertEqual(matched.orig.__class__.__name__, expected_dbapi_cls) + + def test_mysql_mysqldb_deadlock(self): + self._run_deadlock_detect_test( + "mysql", + "(1213, 'Deadlock found when trying " + "to get lock; try restarting " + "transaction')" + ) + + def test_mysql_mysqldb_galera_deadlock(self): + self._run_deadlock_detect_test( + "mysql", + "(1205, 'Lock wait timeout exceeded; " + "try restarting transaction')" + ) + + def test_mysql_mysqlconnector_deadlock(self): + self._run_deadlock_detect_test( + "mysql", + "1213 (40001): Deadlock found when trying to get lock; try " + "restarting transaction", + orig_exception_cls=self.InternalError + ) + + def test_mysql_not_deadlock(self): + self._not_deadlock_test( + "mysql", + "(1005, 'some other error')", + sqla.exc.OperationalError, # note OperationalErrors are sent thru + "OperationalError", + ) + + def test_postgresql_deadlock(self): + self._run_deadlock_detect_test( + "postgresql", + "deadlock detected", + orig_exception_cls=self.TransactionRollbackError + ) + + def test_postgresql_not_deadlock(self): + self._not_deadlock_test( + "postgresql", + 'relation "fake" does not exist', + # can be either depending on #3075 + (exception.DBError, sqla.exc.OperationalError), + "TransactionRollbackError", + orig_exception_cls=self.TransactionRollbackError + ) + + def test_ibm_db_sa_deadlock(self): + self._run_deadlock_detect_test( + "ibm_db_sa", + "SQL0911N The current transaction has been " + "rolled back because of a deadlock or timeout", + # use the lowest class b.c. I don't know what actual error + # class DB2's driver would raise for this + orig_exception_cls=self.Error + ) + + def test_ibm_db_sa_not_deadlock(self): + self._not_deadlock_test( + "ibm_db_sa", + "SQL01234B Some other error.", + exception.DBError, + "Error", + orig_exception_cls=self.Error + ) + + +class IntegrationTest(test_base.DbTestCase): + """Test an actual error-raising round trips against the database.""" + + def setUp(self): + super(IntegrationTest, self).setUp() + meta = sqla.MetaData() + self.test_table = sqla.Table( + _TABLE_NAME, meta, + sqla.Column('id', sqla.Integer, + primary_key=True, nullable=False), + sqla.Column('counter', sqla.Integer, + nullable=False), + sqla.UniqueConstraint('counter', + name='uniq_counter')) + self.test_table.create(self.engine) + self.addCleanup(self.test_table.drop, self.engine) + + class Foo(object): + def __init__(self, counter): + self.counter = counter + mapper(Foo, self.test_table) + self.Foo = Foo + + def test_flush_wrapper_duplicate_entry(self): + """test a duplicate entry exception.""" + + _session = self.sessionmaker() + + with _session.begin(): + foo = self.Foo(counter=1) + _session.add(foo) + + _session.begin() + self.addCleanup(_session.rollback) + foo = self.Foo(counter=1) + _session.add(foo) + self.assertRaises(exception.DBDuplicateEntry, _session.flush) + + def test_autoflush_wrapper_duplicate_entry(self): + """Test a duplicate entry exception raised. + + test a duplicate entry exception raised via query.all()-> autoflush + """ + + _session = self.sessionmaker() + + with _session.begin(): + foo = self.Foo(counter=1) + _session.add(foo) + + _session.begin() + self.addCleanup(_session.rollback) + foo = self.Foo(counter=1) + _session.add(foo) + self.assertTrue(_session.autoflush) + self.assertRaises(exception.DBDuplicateEntry, + _session.query(self.Foo).all) + + def test_flush_wrapper_plain_integrity_error(self): + """test a plain integrity error wrapped as DBError.""" + + _session = self.sessionmaker() + + with _session.begin(): + foo = self.Foo(counter=1) + _session.add(foo) + + _session.begin() + self.addCleanup(_session.rollback) + foo = self.Foo(counter=None) + _session.add(foo) + self.assertRaises(exception.DBError, _session.flush) + + def test_flush_wrapper_operational_error(self): + """test an operational error from flush() raised as-is.""" + + _session = self.sessionmaker() + + with _session.begin(): + foo = self.Foo(counter=1) + _session.add(foo) + + _session.begin() + self.addCleanup(_session.rollback) + foo = self.Foo(counter=sqla.func.imfake(123)) + _session.add(foo) + matched = self.assertRaises(sqla.exc.OperationalError, _session.flush) + self.assertTrue("no such function" in str(matched)) + + def test_query_wrapper_operational_error(self): + """test an operational error from query.all() raised as-is.""" + + _session = self.sessionmaker() + + _session.begin() + self.addCleanup(_session.rollback) + q = _session.query(self.Foo).filter( + self.Foo.counter == sqla.func.imfake(123)) + matched = self.assertRaises(sqla.exc.OperationalError, q.all) + self.assertTrue("no such function" in str(matched)) + + +class TestDBDisconnected(TestsExceptionFilter): + + @contextlib.contextmanager + def _fixture( + self, + dialect_name, exception, num_disconnects, is_disconnect=True): + engine = self.engine + + compat.engine_connect(engine, session._connect_ping_listener) + + real_do_execute = engine.dialect.do_execute + counter = itertools.count(1) + + def fake_do_execute(self, *arg, **kw): + if next(counter) > num_disconnects: + return real_do_execute(self, *arg, **kw) + else: + raise exception + + with self._dbapi_fixture(dialect_name): + with test_utils.nested( + mock.patch.object(engine.dialect, + "do_execute", + fake_do_execute), + mock.patch.object(engine.dialect, + "is_disconnect", + mock.Mock(return_value=is_disconnect)) + ): + yield + + def _test_ping_listener_disconnected( + self, dialect_name, exc_obj, is_disconnect=True): + with self._fixture(dialect_name, exc_obj, 1, is_disconnect): + conn = self.engine.connect() + with conn.begin(): + self.assertEqual(conn.scalar(sqla.select([1])), 1) + self.assertFalse(conn.closed) + self.assertFalse(conn.invalidated) + self.assertTrue(conn.in_transaction()) + + with self._fixture(dialect_name, exc_obj, 2, is_disconnect): + self.assertRaises( + exception.DBConnectionError, + self.engine.connect + ) + + # test implicit execution + with self._fixture(dialect_name, exc_obj, 1): + self.assertEqual(self.engine.scalar(sqla.select([1])), 1) + + def test_mysql_ping_listener_disconnected(self): + for code in [2006, 2013, 2014, 2045, 2055]: + self._test_ping_listener_disconnected( + "mysql", + self.OperationalError('%d MySQL server has gone away' % code) + ) + + def test_mysql_ping_listener_disconnected_regex_only(self): + # intentionally set the is_disconnect flag to False + # in the "sqlalchemy" layer to make sure the regexp + # on _is_db_connection_error is catching + for code in [2002, 2003, 2006, 2013]: + self._test_ping_listener_disconnected( + "mysql", + self.OperationalError('%d MySQL server has gone away' % code), + is_disconnect=False + ) + + def test_db2_ping_listener_disconnected(self): + self._test_ping_listener_disconnected( + "ibm_db_sa", + self.OperationalError( + 'SQL30081N: DB2 Server connection is no longer active') + ) + + def test_db2_ping_listener_disconnected_regex_only(self): + self._test_ping_listener_disconnected( + "ibm_db_sa", + self.OperationalError( + 'SQL30081N: DB2 Server connection is no longer active'), + is_disconnect=False + ) + + +class TestDBConnectRetry(TestsExceptionFilter): + + def _run_test(self, dialect_name, exception, count, retries): + counter = itertools.count() + + engine = self.engine + + # empty out the connection pool + engine.dispose() + + connect_fn = engine.dialect.connect + + def cant_connect(*arg, **kw): + if next(counter) < count: + raise exception + else: + return connect_fn(*arg, **kw) + + with self._dbapi_fixture(dialect_name): + with mock.patch.object(engine.dialect, "connect", cant_connect): + return session._test_connection(engine, retries, .01) + + def test_connect_no_retries(self): + conn = self._run_test( + "mysql", + self.OperationalError("Error: (2003) something wrong"), + 2, 0 + ) + # didnt connect because nothing was tried + self.assertIsNone(conn) + + def test_connect_inifinite_retries(self): + conn = self._run_test( + "mysql", + self.OperationalError("Error: (2003) something wrong"), + 2, -1 + ) + # conn is good + self.assertEqual(conn.scalar(sqla.select([1])), 1) + + def test_connect_retry_past_failure(self): + conn = self._run_test( + "mysql", + self.OperationalError("Error: (2003) something wrong"), + 2, 3 + ) + # conn is good + self.assertEqual(conn.scalar(sqla.select([1])), 1) + + def test_connect_retry_not_candidate_exception(self): + self.assertRaises( + sqla.exc.OperationalError, # remember, we pass OperationalErrors + # through at the moment :) + self._run_test, + "mysql", + self.OperationalError("Error: (2015) I can't connect period"), + 2, 3 + ) + + def test_connect_retry_stops_infailure(self): + self.assertRaises( + exception.DBConnectionError, + self._run_test, + "mysql", + self.OperationalError("Error: (2003) something wrong"), + 3, 2 + ) + + def test_db2_error_positive(self): + conn = self._run_test( + "ibm_db_sa", + self.OperationalError("blah blah -30081 blah blah"), + 2, -1 + ) + # conn is good + self.assertEqual(conn.scalar(sqla.select([1])), 1) + + def test_db2_error_negative(self): + self.assertRaises( + sqla.exc.OperationalError, + self._run_test, + "ibm_db_sa", + self.OperationalError("blah blah -39981 blah blah"), + 2, 3 + ) diff --git a/oslo_db/tests/sqlalchemy/test_handle_error.py b/oslo_db/tests/sqlalchemy/test_handle_error.py new file mode 100644 index 0000000..83322ef --- /dev/null +++ b/oslo_db/tests/sqlalchemy/test_handle_error.py @@ -0,0 +1,194 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +"""Test the compatibility layer for the handle_error() event. + +This event is added as of SQLAlchemy 0.9.7; oslo_db provides a compatibility +layer for prior SQLAlchemy versions. + +""" + +import mock +from oslotest import base as test_base +import sqlalchemy as sqla +from sqlalchemy.sql import column +from sqlalchemy.sql import literal +from sqlalchemy.sql import select +from sqlalchemy.types import Integer +from sqlalchemy.types import TypeDecorator + +from oslo_db.sqlalchemy.compat import handle_error +from oslo_db.sqlalchemy.compat import utils +from oslo_db.tests import utils as test_utils + + +class MyException(Exception): + pass + + +class ExceptionReraiseTest(test_base.BaseTestCase): + + def setUp(self): + super(ExceptionReraiseTest, self).setUp() + + self.engine = engine = sqla.create_engine("sqlite://") + self.addCleanup(engine.dispose) + + def _fixture(self): + engine = self.engine + + def err(context): + if "ERROR ONE" in str(context.statement): + raise MyException("my exception") + handle_error(engine, err) + + def test_exception_event_altered(self): + self._fixture() + + with mock.patch.object(self.engine.dialect.execution_ctx_cls, + "handle_dbapi_exception") as patched: + + matchee = self.assertRaises( + MyException, + self.engine.execute, "SELECT 'ERROR ONE' FROM I_DONT_EXIST" + ) + self.assertEqual(1, patched.call_count) + self.assertEqual("my exception", matchee.args[0]) + + def test_exception_event_non_altered(self): + self._fixture() + + with mock.patch.object(self.engine.dialect.execution_ctx_cls, + "handle_dbapi_exception") as patched: + + self.assertRaises( + sqla.exc.DBAPIError, + self.engine.execute, "SELECT 'ERROR TWO' FROM I_DONT_EXIST" + ) + self.assertEqual(1, patched.call_count) + + def test_is_disconnect_not_interrupted(self): + self._fixture() + + with test_utils.nested( + mock.patch.object( + self.engine.dialect.execution_ctx_cls, + "handle_dbapi_exception" + ), + mock.patch.object( + self.engine.dialect, "is_disconnect", + lambda *args: True + ) + ) as (handle_dbapi_exception, is_disconnect): + with self.engine.connect() as conn: + self.assertRaises( + MyException, + conn.execute, "SELECT 'ERROR ONE' FROM I_DONT_EXIST" + ) + self.assertEqual(1, handle_dbapi_exception.call_count) + self.assertTrue(conn.invalidated) + + def test_no_is_disconnect_not_invalidated(self): + self._fixture() + + with test_utils.nested( + mock.patch.object( + self.engine.dialect.execution_ctx_cls, + "handle_dbapi_exception" + ), + mock.patch.object( + self.engine.dialect, "is_disconnect", + lambda *args: False + ) + ) as (handle_dbapi_exception, is_disconnect): + with self.engine.connect() as conn: + self.assertRaises( + MyException, + conn.execute, "SELECT 'ERROR ONE' FROM I_DONT_EXIST" + ) + self.assertEqual(1, handle_dbapi_exception.call_count) + self.assertFalse(conn.invalidated) + + def test_exception_event_ad_hoc_context(self): + engine = self.engine + + nope = MyException("nope") + + class MyType(TypeDecorator): + impl = Integer + + def process_bind_param(self, value, dialect): + raise nope + + listener = mock.Mock(return_value=None) + handle_error(engine, listener) + + self.assertRaises( + sqla.exc.StatementError, + engine.execute, + select([1]).where(column('foo') == literal('bar', MyType)) + ) + + ctx = listener.mock_calls[0][1][0] + self.assertTrue(ctx.statement.startswith("SELECT 1 ")) + self.assertIs(ctx.is_disconnect, False) + self.assertIs(ctx.original_exception, nope) + + def _test_alter_disconnect(self, orig_error, evt_value): + engine = self.engine + + def evt(ctx): + ctx.is_disconnect = evt_value + handle_error(engine, evt) + + # if we are under sqla 0.9.7, and we are expecting to take + # an "is disconnect" exception and make it not a disconnect, + # that isn't supported b.c. the wrapped handler has already + # done the invalidation. + expect_failure = not utils.sqla_097 and orig_error and not evt_value + + with mock.patch.object(engine.dialect, + "is_disconnect", + mock.Mock(return_value=orig_error)): + + with engine.connect() as c: + conn_rec = c.connection._connection_record + try: + c.execute("SELECT x FROM nonexistent") + assert False + except sqla.exc.StatementError as st: + self.assertFalse(expect_failure) + + # check the exception's invalidation flag + self.assertEqual(st.connection_invalidated, evt_value) + + # check the Connection object's invalidation flag + self.assertEqual(c.invalidated, evt_value) + + # this is the ConnectionRecord object; it's invalidated + # when its .connection member is None + self.assertEqual(conn_rec.connection is None, evt_value) + + except NotImplementedError as ne: + self.assertTrue(expect_failure) + self.assertEqual( + str(ne), + "Can't reset 'disconnect' status of exception once it " + "is set with this version of SQLAlchemy") + + def test_alter_disconnect_to_true(self): + self._test_alter_disconnect(False, True) + self._test_alter_disconnect(True, True) + + def test_alter_disconnect_to_false(self): + self._test_alter_disconnect(True, False) + self._test_alter_disconnect(False, False) diff --git a/oslo_db/tests/sqlalchemy/test_migrate_cli.py b/oslo_db/tests/sqlalchemy/test_migrate_cli.py new file mode 100644 index 0000000..c1ab53c --- /dev/null +++ b/oslo_db/tests/sqlalchemy/test_migrate_cli.py @@ -0,0 +1,222 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import mock +from oslotest import base as test_base + +from oslo_db.sqlalchemy.migration_cli import ext_alembic +from oslo_db.sqlalchemy.migration_cli import ext_migrate +from oslo_db.sqlalchemy.migration_cli import manager + + +class MockWithCmp(mock.MagicMock): + + order = 0 + + def __init__(self, *args, **kwargs): + super(MockWithCmp, self).__init__(*args, **kwargs) + + self.__lt__ = lambda self, other: self.order < other.order + + +@mock.patch(('oslo_db.sqlalchemy.migration_cli.' + 'ext_alembic.alembic.command')) +class TestAlembicExtension(test_base.BaseTestCase): + + def setUp(self): + self.migration_config = {'alembic_ini_path': '.', + 'db_url': 'sqlite://'} + self.alembic = ext_alembic.AlembicExtension(self.migration_config) + super(TestAlembicExtension, self).setUp() + + def test_check_enabled_true(self, command): + """Check enabled returns True + + Verifies that enabled returns True on non empty + alembic_ini_path conf variable + """ + self.assertTrue(self.alembic.enabled) + + def test_check_enabled_false(self, command): + """Check enabled returns False + + Verifies enabled returns False on empty alembic_ini_path variable + """ + self.migration_config['alembic_ini_path'] = '' + alembic = ext_alembic.AlembicExtension(self.migration_config) + self.assertFalse(alembic.enabled) + + def test_upgrade_none(self, command): + self.alembic.upgrade(None) + command.upgrade.assert_called_once_with(self.alembic.config, 'head') + + def test_upgrade_normal(self, command): + self.alembic.upgrade('131daa') + command.upgrade.assert_called_once_with(self.alembic.config, '131daa') + + def test_downgrade_none(self, command): + self.alembic.downgrade(None) + command.downgrade.assert_called_once_with(self.alembic.config, 'base') + + def test_downgrade_int(self, command): + self.alembic.downgrade(111) + command.downgrade.assert_called_once_with(self.alembic.config, 'base') + + def test_downgrade_normal(self, command): + self.alembic.downgrade('131daa') + command.downgrade.assert_called_once_with( + self.alembic.config, '131daa') + + def test_revision(self, command): + self.alembic.revision(message='test', autogenerate=True) + command.revision.assert_called_once_with( + self.alembic.config, message='test', autogenerate=True) + + def test_stamp(self, command): + self.alembic.stamp('stamp') + command.stamp.assert_called_once_with( + self.alembic.config, revision='stamp') + + def test_version(self, command): + version = self.alembic.version() + self.assertIsNone(version) + + +@mock.patch(('oslo_db.sqlalchemy.migration_cli.' + 'ext_migrate.migration')) +class TestMigrateExtension(test_base.BaseTestCase): + + def setUp(self): + self.migration_config = {'migration_repo_path': '.', + 'db_url': 'sqlite://'} + self.migrate = ext_migrate.MigrateExtension(self.migration_config) + super(TestMigrateExtension, self).setUp() + + def test_check_enabled_true(self, migration): + self.assertTrue(self.migrate.enabled) + + def test_check_enabled_false(self, migration): + self.migration_config['migration_repo_path'] = '' + migrate = ext_migrate.MigrateExtension(self.migration_config) + self.assertFalse(migrate.enabled) + + def test_upgrade_head(self, migration): + self.migrate.upgrade('head') + migration.db_sync.assert_called_once_with( + self.migrate.engine, self.migrate.repository, None, init_version=0) + + def test_upgrade_normal(self, migration): + self.migrate.upgrade(111) + migration.db_sync.assert_called_once_with( + mock.ANY, self.migrate.repository, 111, init_version=0) + + def test_downgrade_init_version_from_base(self, migration): + self.migrate.downgrade('base') + migration.db_sync.assert_called_once_with( + self.migrate.engine, self.migrate.repository, mock.ANY, + init_version=mock.ANY) + + def test_downgrade_init_version_from_none(self, migration): + self.migrate.downgrade(None) + migration.db_sync.assert_called_once_with( + self.migrate.engine, self.migrate.repository, mock.ANY, + init_version=mock.ANY) + + def test_downgrade_normal(self, migration): + self.migrate.downgrade(101) + migration.db_sync.assert_called_once_with( + self.migrate.engine, self.migrate.repository, 101, init_version=0) + + def test_version(self, migration): + self.migrate.version() + migration.db_version.assert_called_once_with( + self.migrate.engine, self.migrate.repository, init_version=0) + + def test_change_init_version(self, migration): + self.migration_config['init_version'] = 101 + migrate = ext_migrate.MigrateExtension(self.migration_config) + migrate.downgrade(None) + migration.db_sync.assert_called_once_with( + migrate.engine, + self.migrate.repository, + self.migration_config['init_version'], + init_version=self.migration_config['init_version']) + + +class TestMigrationManager(test_base.BaseTestCase): + + def setUp(self): + self.migration_config = {'alembic_ini_path': '.', + 'migrate_repo_path': '.', + 'db_url': 'sqlite://'} + self.migration_manager = manager.MigrationManager( + self.migration_config) + self.ext = mock.Mock() + self.migration_manager._manager.extensions = [self.ext] + super(TestMigrationManager, self).setUp() + + def test_manager_update(self): + self.migration_manager.upgrade('head') + self.ext.obj.upgrade.assert_called_once_with('head') + + def test_manager_update_revision_none(self): + self.migration_manager.upgrade(None) + self.ext.obj.upgrade.assert_called_once_with(None) + + def test_downgrade_normal_revision(self): + self.migration_manager.downgrade('111abcd') + self.ext.obj.downgrade.assert_called_once_with('111abcd') + + def test_version(self): + self.migration_manager.version() + self.ext.obj.version.assert_called_once_with() + + def test_revision_message_autogenerate(self): + self.migration_manager.revision('test', True) + self.ext.obj.revision.assert_called_once_with('test', True) + + def test_revision_only_message(self): + self.migration_manager.revision('test', False) + self.ext.obj.revision.assert_called_once_with('test', False) + + def test_stamp(self): + self.migration_manager.stamp('stamp') + self.ext.obj.stamp.assert_called_once_with('stamp') + + +class TestMigrationRightOrder(test_base.BaseTestCase): + + def setUp(self): + self.migration_config = {'alembic_ini_path': '.', + 'migrate_repo_path': '.', + 'db_url': 'sqlite://'} + self.migration_manager = manager.MigrationManager( + self.migration_config) + self.first_ext = MockWithCmp() + self.first_ext.obj.order = 1 + self.first_ext.obj.upgrade.return_value = 100 + self.first_ext.obj.downgrade.return_value = 0 + self.second_ext = MockWithCmp() + self.second_ext.obj.order = 2 + self.second_ext.obj.upgrade.return_value = 200 + self.second_ext.obj.downgrade.return_value = 100 + self.migration_manager._manager.extensions = [self.first_ext, + self.second_ext] + super(TestMigrationRightOrder, self).setUp() + + def test_upgrade_right_order(self): + results = self.migration_manager.upgrade(None) + self.assertEqual(results, [100, 200]) + + def test_downgrade_right_order(self): + results = self.migration_manager.downgrade(None) + self.assertEqual(results, [100, 0]) diff --git a/oslo_db/tests/sqlalchemy/test_migration_common.py b/oslo_db/tests/sqlalchemy/test_migration_common.py new file mode 100644 index 0000000..95efab1 --- /dev/null +++ b/oslo_db/tests/sqlalchemy/test_migration_common.py @@ -0,0 +1,212 @@ +# Copyright 2013 Mirantis Inc. +# All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +# + +import os +import tempfile + +from migrate import exceptions as migrate_exception +from migrate.versioning import api as versioning_api +import mock +import sqlalchemy + +from oslo_db import exception as db_exception +from oslo_db.sqlalchemy import migration +from oslo_db.sqlalchemy import test_base +from oslo_db.tests import utils as test_utils + + +class TestMigrationCommon(test_base.DbTestCase): + def setUp(self): + super(TestMigrationCommon, self).setUp() + + migration._REPOSITORY = None + self.path = tempfile.mkdtemp('test_migration') + self.path1 = tempfile.mkdtemp('test_migration') + self.return_value = '/home/openstack/migrations' + self.return_value1 = '/home/extension/migrations' + self.init_version = 1 + self.test_version = 123 + + self.patcher_repo = mock.patch.object(migration, 'Repository') + self.repository = self.patcher_repo.start() + self.repository.side_effect = [self.return_value, self.return_value1] + + self.mock_api_db = mock.patch.object(versioning_api, 'db_version') + self.mock_api_db_version = self.mock_api_db.start() + self.mock_api_db_version.return_value = self.test_version + + def tearDown(self): + os.rmdir(self.path) + self.mock_api_db.stop() + self.patcher_repo.stop() + super(TestMigrationCommon, self).tearDown() + + def test_find_migrate_repo_path_not_found(self): + self.assertRaises( + db_exception.DbMigrationError, + migration._find_migrate_repo, + "/foo/bar/", + ) + self.assertIsNone(migration._REPOSITORY) + + def test_find_migrate_repo_called_once(self): + my_repository = migration._find_migrate_repo(self.path) + self.repository.assert_called_once_with(self.path) + self.assertEqual(my_repository, self.return_value) + + def test_find_migrate_repo_called_few_times(self): + repo1 = migration._find_migrate_repo(self.path) + repo2 = migration._find_migrate_repo(self.path1) + self.assertNotEqual(repo1, repo2) + + def test_db_version_control(self): + with test_utils.nested( + mock.patch.object(migration, '_find_migrate_repo'), + mock.patch.object(versioning_api, 'version_control'), + ) as (mock_find_repo, mock_version_control): + mock_find_repo.return_value = self.return_value + + version = migration.db_version_control( + self.engine, self.path, self.test_version) + + self.assertEqual(version, self.test_version) + mock_version_control.assert_called_once_with( + self.engine, self.return_value, self.test_version) + + def test_db_version_return(self): + ret_val = migration.db_version(self.engine, self.path, + self.init_version) + self.assertEqual(ret_val, self.test_version) + + def test_db_version_raise_not_controlled_error_first(self): + with mock.patch.object(migration, 'db_version_control') as mock_ver: + + self.mock_api_db_version.side_effect = [ + migrate_exception.DatabaseNotControlledError('oups'), + self.test_version] + + ret_val = migration.db_version(self.engine, self.path, + self.init_version) + self.assertEqual(ret_val, self.test_version) + mock_ver.assert_called_once_with(self.engine, self.path, + version=self.init_version) + + def test_db_version_raise_not_controlled_error_tables(self): + with mock.patch.object(sqlalchemy, 'MetaData') as mock_meta: + self.mock_api_db_version.side_effect = \ + migrate_exception.DatabaseNotControlledError('oups') + my_meta = mock.MagicMock() + my_meta.tables = {'a': 1, 'b': 2} + mock_meta.return_value = my_meta + + self.assertRaises( + db_exception.DbMigrationError, migration.db_version, + self.engine, self.path, self.init_version) + + @mock.patch.object(versioning_api, 'version_control') + def test_db_version_raise_not_controlled_error_no_tables(self, mock_vc): + with mock.patch.object(sqlalchemy, 'MetaData') as mock_meta: + self.mock_api_db_version.side_effect = ( + migrate_exception.DatabaseNotControlledError('oups'), + self.init_version) + my_meta = mock.MagicMock() + my_meta.tables = {} + mock_meta.return_value = my_meta + migration.db_version(self.engine, self.path, self.init_version) + + mock_vc.assert_called_once_with(self.engine, self.return_value1, + self.init_version) + + def test_db_sync_wrong_version(self): + self.assertRaises(db_exception.DbMigrationError, + migration.db_sync, self.engine, self.path, 'foo') + + def test_db_sync_upgrade(self): + init_ver = 55 + with test_utils.nested( + mock.patch.object(migration, '_find_migrate_repo'), + mock.patch.object(versioning_api, 'upgrade') + ) as (mock_find_repo, mock_upgrade): + + mock_find_repo.return_value = self.return_value + self.mock_api_db_version.return_value = self.test_version - 1 + + migration.db_sync(self.engine, self.path, self.test_version, + init_ver) + + mock_upgrade.assert_called_once_with( + self.engine, self.return_value, self.test_version) + + def test_db_sync_downgrade(self): + with test_utils.nested( + mock.patch.object(migration, '_find_migrate_repo'), + mock.patch.object(versioning_api, 'downgrade') + ) as (mock_find_repo, mock_downgrade): + + mock_find_repo.return_value = self.return_value + self.mock_api_db_version.return_value = self.test_version + 1 + + migration.db_sync(self.engine, self.path, self.test_version) + + mock_downgrade.assert_called_once_with( + self.engine, self.return_value, self.test_version) + + def test_db_sync_sanity_called(self): + with test_utils.nested( + mock.patch.object(migration, '_find_migrate_repo'), + mock.patch.object(migration, '_db_schema_sanity_check'), + mock.patch.object(versioning_api, 'downgrade') + ) as (mock_find_repo, mock_sanity, mock_downgrade): + + mock_find_repo.return_value = self.return_value + migration.db_sync(self.engine, self.path, self.test_version) + + mock_sanity.assert_called_once_with(self.engine) + + def test_db_sync_sanity_skipped(self): + with test_utils.nested( + mock.patch.object(migration, '_find_migrate_repo'), + mock.patch.object(migration, '_db_schema_sanity_check'), + mock.patch.object(versioning_api, 'downgrade') + ) as (mock_find_repo, mock_sanity, mock_downgrade): + + mock_find_repo.return_value = self.return_value + migration.db_sync(self.engine, self.path, self.test_version, + sanity_check=False) + + self.assertFalse(mock_sanity.called) + + def test_db_sanity_table_not_utf8(self): + with mock.patch.object(self, 'engine') as mock_eng: + type(mock_eng).name = mock.PropertyMock(return_value='mysql') + mock_eng.execute.return_value = [['table_A', 'latin1'], + ['table_B', 'latin1']] + + self.assertRaises(ValueError, migration._db_schema_sanity_check, + mock_eng) + + def test_db_sanity_table_not_utf8_exclude_migrate_tables(self): + with mock.patch.object(self, 'engine') as mock_eng: + type(mock_eng).name = mock.PropertyMock(return_value='mysql') + # NOTE(morganfainberg): Check both lower and upper case versions + # of the migration table names (validate case insensitivity in + # the sanity check. + mock_eng.execute.return_value = [['migrate_version', 'latin1'], + ['alembic_version', 'latin1'], + ['MIGRATE_VERSION', 'latin1'], + ['ALEMBIC_VERSION', 'latin1']] + + migration._db_schema_sanity_check(mock_eng) diff --git a/oslo_db/tests/sqlalchemy/test_migrations.py b/oslo_db/tests/sqlalchemy/test_migrations.py new file mode 100644 index 0000000..a372d8b --- /dev/null +++ b/oslo_db/tests/sqlalchemy/test_migrations.py @@ -0,0 +1,309 @@ +# Copyright 2010-2011 OpenStack Foundation +# Copyright 2012-2013 IBM Corp. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import fixtures +import mock +from oslotest import base as test +import six +import sqlalchemy as sa +import sqlalchemy.ext.declarative as sa_decl + +from oslo_db import exception as exc +from oslo_db.sqlalchemy import test_base +from oslo_db.sqlalchemy import test_migrations as migrate + + +class TestWalkVersions(test.BaseTestCase, migrate.WalkVersionsMixin): + migration_api = mock.MagicMock() + REPOSITORY = mock.MagicMock() + engine = mock.MagicMock() + INIT_VERSION = 4 + + @property + def migrate_engine(self): + return self.engine + + def test_migrate_up(self): + self.migration_api.db_version.return_value = 141 + + self.migrate_up(141) + + self.migration_api.upgrade.assert_called_with( + self.engine, self.REPOSITORY, 141) + self.migration_api.db_version.assert_called_with( + self.engine, self.REPOSITORY) + + def test_migrate_up_fail(self): + version = 141 + self.migration_api.db_version.return_value = version + expected_output = (u"Failed to migrate to version %(version)s on " + "engine %(engine)s\n" % + {'version': version, 'engine': self.engine}) + + with mock.patch.object(self.migration_api, + 'upgrade', + side_effect=exc.DbMigrationError): + log = self.useFixture(fixtures.FakeLogger()) + self.assertRaises(exc.DbMigrationError, self.migrate_up, version) + self.assertEqual(expected_output, log.output) + + def test_migrate_up_with_data(self): + test_value = {"a": 1, "b": 2} + self.migration_api.db_version.return_value = 141 + self._pre_upgrade_141 = mock.MagicMock() + self._pre_upgrade_141.return_value = test_value + self._check_141 = mock.MagicMock() + + self.migrate_up(141, True) + + self._pre_upgrade_141.assert_called_with(self.engine) + self._check_141.assert_called_with(self.engine, test_value) + + def test_migrate_down(self): + self.migration_api.db_version.return_value = 42 + + self.assertTrue(self.migrate_down(42)) + self.migration_api.db_version.assert_called_with( + self.engine, self.REPOSITORY) + + def test_migrate_down_not_implemented(self): + with mock.patch.object(self.migration_api, + 'downgrade', + side_effect=NotImplementedError): + self.assertFalse(self.migrate_down(self.engine, 42)) + + def test_migrate_down_with_data(self): + self._post_downgrade_043 = mock.MagicMock() + self.migration_api.db_version.return_value = 42 + + self.migrate_down(42, True) + + self._post_downgrade_043.assert_called_with(self.engine) + + @mock.patch.object(migrate.WalkVersionsMixin, 'migrate_up') + @mock.patch.object(migrate.WalkVersionsMixin, 'migrate_down') + def test_walk_versions_all_default(self, migrate_up, migrate_down): + self.REPOSITORY.latest = 20 + self.migration_api.db_version.return_value = self.INIT_VERSION + + self.walk_versions() + + self.migration_api.version_control.assert_called_with( + self.engine, self.REPOSITORY, self.INIT_VERSION) + self.migration_api.db_version.assert_called_with( + self.engine, self.REPOSITORY) + + versions = range(self.INIT_VERSION + 1, self.REPOSITORY.latest + 1) + upgraded = [mock.call(v, with_data=True) + for v in versions] + self.assertEqual(self.migrate_up.call_args_list, upgraded) + + downgraded = [mock.call(v - 1) for v in reversed(versions)] + self.assertEqual(self.migrate_down.call_args_list, downgraded) + + @mock.patch.object(migrate.WalkVersionsMixin, 'migrate_up') + @mock.patch.object(migrate.WalkVersionsMixin, 'migrate_down') + def test_walk_versions_all_true(self, migrate_up, migrate_down): + self.REPOSITORY.latest = 20 + self.migration_api.db_version.return_value = self.INIT_VERSION + + self.walk_versions(snake_walk=True, downgrade=True) + + versions = range(self.INIT_VERSION + 1, self.REPOSITORY.latest + 1) + upgraded = [] + for v in versions: + upgraded.append(mock.call(v, with_data=True)) + upgraded.append(mock.call(v)) + upgraded.extend([mock.call(v) for v in reversed(versions)]) + self.assertEqual(upgraded, self.migrate_up.call_args_list) + + downgraded_1 = [mock.call(v - 1, with_data=True) for v in versions] + downgraded_2 = [] + for v in reversed(versions): + downgraded_2.append(mock.call(v - 1)) + downgraded_2.append(mock.call(v - 1)) + downgraded = downgraded_1 + downgraded_2 + self.assertEqual(self.migrate_down.call_args_list, downgraded) + + @mock.patch.object(migrate.WalkVersionsMixin, 'migrate_up') + @mock.patch.object(migrate.WalkVersionsMixin, 'migrate_down') + def test_walk_versions_true_false(self, migrate_up, migrate_down): + self.REPOSITORY.latest = 20 + self.migration_api.db_version.return_value = self.INIT_VERSION + + self.walk_versions(snake_walk=True, downgrade=False) + + versions = range(self.INIT_VERSION + 1, self.REPOSITORY.latest + 1) + + upgraded = [] + for v in versions: + upgraded.append(mock.call(v, with_data=True)) + upgraded.append(mock.call(v)) + self.assertEqual(upgraded, self.migrate_up.call_args_list) + + downgraded = [mock.call(v - 1, with_data=True) for v in versions] + self.assertEqual(self.migrate_down.call_args_list, downgraded) + + @mock.patch.object(migrate.WalkVersionsMixin, 'migrate_up') + @mock.patch.object(migrate.WalkVersionsMixin, 'migrate_down') + def test_walk_versions_all_false(self, migrate_up, migrate_down): + self.REPOSITORY.latest = 20 + self.migration_api.db_version.return_value = self.INIT_VERSION + + self.walk_versions(snake_walk=False, downgrade=False) + + versions = range(self.INIT_VERSION + 1, self.REPOSITORY.latest + 1) + + upgraded = [mock.call(v, with_data=True) for v in versions] + self.assertEqual(upgraded, self.migrate_up.call_args_list) + + +class ModelsMigrationSyncMixin(test.BaseTestCase): + + def setUp(self): + super(ModelsMigrationSyncMixin, self).setUp() + + self.metadata = sa.MetaData() + self.metadata_migrations = sa.MetaData() + + sa.Table( + 'testtbl', self.metadata_migrations, + sa.Column('id', sa.Integer, primary_key=True), + sa.Column('spam', sa.String(10), nullable=False), + sa.Column('eggs', sa.DateTime), + sa.Column('foo', sa.Boolean, + server_default=sa.sql.expression.true()), + sa.Column('bool_wo_default', sa.Boolean), + sa.Column('bar', sa.Numeric(10, 5)), + sa.Column('defaulttest', sa.Integer, server_default='5'), + sa.Column('defaulttest2', sa.String(8), server_default=''), + sa.Column('defaulttest3', sa.String(5), server_default="test"), + sa.Column('defaulttest4', sa.Enum('first', 'second', + name='testenum'), + server_default="first"), + sa.Column('fk_check', sa.String(36), nullable=False), + sa.UniqueConstraint('spam', 'eggs', name='uniq_cons'), + ) + + BASE = sa_decl.declarative_base(metadata=self.metadata) + + class TestModel(BASE): + __tablename__ = 'testtbl' + __table_args__ = ( + sa.UniqueConstraint('spam', 'eggs', name='uniq_cons'), + ) + + id = sa.Column('id', sa.Integer, primary_key=True) + spam = sa.Column('spam', sa.String(10), nullable=False) + eggs = sa.Column('eggs', sa.DateTime) + foo = sa.Column('foo', sa.Boolean, + server_default=sa.sql.expression.true()) + fk_check = sa.Column('fk_check', sa.String(36), nullable=False) + bool_wo_default = sa.Column('bool_wo_default', sa.Boolean) + defaulttest = sa.Column('defaulttest', + sa.Integer, server_default='5') + defaulttest2 = sa.Column('defaulttest2', sa.String(8), + server_default='') + defaulttest3 = sa.Column('defaulttest3', sa.String(5), + server_default="test") + defaulttest4 = sa.Column('defaulttest4', sa.Enum('first', 'second', + name='testenum'), + server_default="first") + bar = sa.Column('bar', sa.Numeric(10, 5)) + + class ModelThatShouldNotBeCompared(BASE): + __tablename__ = 'testtbl2' + + id = sa.Column('id', sa.Integer, primary_key=True) + spam = sa.Column('spam', sa.String(10), nullable=False) + + def get_metadata(self): + return self.metadata + + def get_engine(self): + return self.engine + + def db_sync(self, engine): + self.metadata_migrations.create_all(bind=engine) + + def include_object(self, object_, name, type_, reflected, compare_to): + if type_ == 'table': + return name == 'testtbl' + else: + return True + + def _test_models_not_sync(self): + self.metadata_migrations.clear() + sa.Table( + 'table', self.metadata_migrations, + sa.Column('fk_check', sa.String(36), nullable=False), + sa.PrimaryKeyConstraint('fk_check'), + mysql_engine='InnoDB' + ) + sa.Table( + 'testtbl', self.metadata_migrations, + sa.Column('id', sa.Integer, primary_key=True), + sa.Column('spam', sa.String(8), nullable=True), + sa.Column('eggs', sa.DateTime), + sa.Column('foo', sa.Boolean, + server_default=sa.sql.expression.false()), + sa.Column('bool_wo_default', sa.Boolean, unique=True), + sa.Column('bar', sa.BigInteger), + sa.Column('defaulttest', sa.Integer, server_default='7'), + sa.Column('defaulttest2', sa.String(8), server_default=''), + sa.Column('defaulttest3', sa.String(5), server_default="fake"), + sa.Column('defaulttest4', + sa.Enum('first', 'second', name='testenum'), + server_default="first"), + sa.Column('fk_check', sa.String(36), nullable=False), + sa.UniqueConstraint('spam', 'foo', name='uniq_cons'), + sa.ForeignKeyConstraint(['fk_check'], ['table.fk_check']), + mysql_engine='InnoDB' + ) + + msg = six.text_type(self.assertRaises(AssertionError, + self.test_models_sync)) + # NOTE(I159): Check mentioning of the table and columns. + # The log is invalid json, so we can't parse it and check it for + # full compliance. We have no guarantee of the log items ordering, + # so we can't use regexp. + self.assertTrue(msg.startswith( + 'Models and migration scripts aren\'t in sync:')) + self.assertIn('testtbl', msg) + self.assertIn('spam', msg) + self.assertIn('eggs', msg) # test that the unique constraint is added + self.assertIn('foo', msg) + self.assertIn('bar', msg) + self.assertIn('bool_wo_default', msg) + self.assertIn('defaulttest', msg) + self.assertIn('defaulttest3', msg) + self.assertIn('drop_key', msg) + + +class ModelsMigrationsSyncMysql(ModelsMigrationSyncMixin, + migrate.ModelsMigrationsSync, + test_base.MySQLOpportunisticTestCase): + + def test_models_not_sync(self): + self._test_models_not_sync() + + +class ModelsMigrationsSyncPsql(ModelsMigrationSyncMixin, + migrate.ModelsMigrationsSync, + test_base.PostgreSQLOpportunisticTestCase): + + def test_models_not_sync(self): + self._test_models_not_sync() diff --git a/oslo_db/tests/sqlalchemy/test_models.py b/oslo_db/tests/sqlalchemy/test_models.py new file mode 100644 index 0000000..4a45576 --- /dev/null +++ b/oslo_db/tests/sqlalchemy/test_models.py @@ -0,0 +1,146 @@ +# Copyright 2012 Cloudscaling Group, Inc. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import collections + +from oslotest import base as oslo_test +from sqlalchemy import Column +from sqlalchemy import Integer, String +from sqlalchemy.ext.declarative import declarative_base + +from oslo_db.sqlalchemy import models +from oslo_db.sqlalchemy import test_base + + +BASE = declarative_base() + + +class ModelBaseTest(test_base.DbTestCase): + def setUp(self): + super(ModelBaseTest, self).setUp() + self.mb = models.ModelBase() + self.ekm = ExtraKeysModel() + + def test_modelbase_has_dict_methods(self): + dict_methods = ('__getitem__', + '__setitem__', + '__contains__', + 'get', + 'update', + 'save', + 'iteritems') + for method in dict_methods: + self.assertTrue(hasattr(models.ModelBase, method), + "Method %s() is not found" % method) + + def test_modelbase_is_iterable(self): + self.assertTrue(issubclass(models.ModelBase, collections.Iterable)) + + def test_modelbase_set(self): + self.mb['world'] = 'hello' + self.assertEqual(self.mb['world'], 'hello') + + def test_modelbase_update(self): + h = {'a': '1', 'b': '2'} + self.mb.update(h) + for key in h.keys(): + self.assertEqual(self.mb[key], h[key]) + + def test_modelbase_contains(self): + mb = models.ModelBase() + h = {'a': '1', 'b': '2'} + mb.update(h) + for key in h.keys(): + # Test 'in' syntax (instead of using .assertIn) + self.assertTrue(key in mb) + + self.assertFalse('non-existent-key' in mb) + + def test_modelbase_iteritems(self): + h = {'a': '1', 'b': '2'} + expected = { + 'id': None, + 'smth': None, + 'name': 'NAME', + 'a': '1', + 'b': '2', + } + self.ekm.update(h) + self.assertEqual(dict(self.ekm.iteritems()), expected) + + def test_modelbase_iter(self): + expected = { + 'id': None, + 'smth': None, + 'name': 'NAME', + } + i = iter(self.ekm) + found_items = 0 + while True: + r = next(i, None) + if r is None: + break + self.assertEqual(expected[r[0]], r[1]) + found_items += 1 + + self.assertEqual(len(expected), found_items) + + def test_modelbase_several_iters(self): + mb = ExtraKeysModel() + it1 = iter(mb) + it2 = iter(mb) + + self.assertFalse(it1 is it2) + self.assertEqual(dict(it1), dict(mb)) + self.assertEqual(dict(it2), dict(mb)) + + def test_extra_keys_empty(self): + """Test verifies that by default extra_keys return empty list.""" + self.assertEqual(self.mb._extra_keys, []) + + def test_extra_keys_defined(self): + """Property _extra_keys will return list with attributes names.""" + self.assertEqual(self.ekm._extra_keys, ['name']) + + def test_model_with_extra_keys(self): + data = dict(self.ekm) + self.assertEqual(data, {'smth': None, + 'id': None, + 'name': 'NAME'}) + + +class ExtraKeysModel(BASE, models.ModelBase): + __tablename__ = 'test_model' + + id = Column(Integer, primary_key=True) + smth = Column(String(255)) + + @property + def name(self): + return 'NAME' + + @property + def _extra_keys(self): + return ['name'] + + +class TimestampMixinTest(oslo_test.BaseTestCase): + + def test_timestampmixin_attr(self): + methods = ('created_at', + 'updated_at') + for method in methods: + self.assertTrue(hasattr(models.TimestampMixin, method), + "Method %s() is not found" % method) diff --git a/oslo_db/tests/sqlalchemy/test_options.py b/oslo_db/tests/sqlalchemy/test_options.py new file mode 100644 index 0000000..22a6e4f --- /dev/null +++ b/oslo_db/tests/sqlalchemy/test_options.py @@ -0,0 +1,127 @@ +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from oslo.config import cfg +from oslo.config import fixture as config + +from oslo_db import options +from oslo_db.tests import utils as test_utils + + +class DbApiOptionsTestCase(test_utils.BaseTestCase): + def setUp(self): + super(DbApiOptionsTestCase, self).setUp() + + config_fixture = self.useFixture(config.Config()) + self.conf = config_fixture.conf + self.conf.register_opts(options.database_opts, group='database') + self.config = config_fixture.config + + def test_deprecated_session_parameters(self): + path = self.create_tempfiles([["tmp", b"""[DEFAULT] +sql_connection=x://y.z +sql_min_pool_size=10 +sql_max_pool_size=20 +sql_max_retries=30 +sql_retry_interval=40 +sql_max_overflow=50 +sql_connection_debug=60 +sql_connection_trace=True +"""]])[0] + self.conf(['--config-file', path]) + self.assertEqual(self.conf.database.connection, 'x://y.z') + self.assertEqual(self.conf.database.min_pool_size, 10) + self.assertEqual(self.conf.database.max_pool_size, 20) + self.assertEqual(self.conf.database.max_retries, 30) + self.assertEqual(self.conf.database.retry_interval, 40) + self.assertEqual(self.conf.database.max_overflow, 50) + self.assertEqual(self.conf.database.connection_debug, 60) + self.assertEqual(self.conf.database.connection_trace, True) + + def test_session_parameters(self): + path = self.create_tempfiles([["tmp", b"""[database] +connection=x://y.z +min_pool_size=10 +max_pool_size=20 +max_retries=30 +retry_interval=40 +max_overflow=50 +connection_debug=60 +connection_trace=True +pool_timeout=7 +"""]])[0] + self.conf(['--config-file', path]) + self.assertEqual(self.conf.database.connection, 'x://y.z') + self.assertEqual(self.conf.database.min_pool_size, 10) + self.assertEqual(self.conf.database.max_pool_size, 20) + self.assertEqual(self.conf.database.max_retries, 30) + self.assertEqual(self.conf.database.retry_interval, 40) + self.assertEqual(self.conf.database.max_overflow, 50) + self.assertEqual(self.conf.database.connection_debug, 60) + self.assertEqual(self.conf.database.connection_trace, True) + self.assertEqual(self.conf.database.pool_timeout, 7) + + def test_dbapi_database_deprecated_parameters(self): + path = self.create_tempfiles([['tmp', b'[DATABASE]\n' + b'sql_connection=fake_connection\n' + b'sql_idle_timeout=100\n' + b'sql_min_pool_size=99\n' + b'sql_max_pool_size=199\n' + b'sql_max_retries=22\n' + b'reconnect_interval=17\n' + b'sqlalchemy_max_overflow=101\n' + b'sqlalchemy_pool_timeout=5\n' + ]])[0] + self.conf(['--config-file', path]) + self.assertEqual(self.conf.database.connection, 'fake_connection') + self.assertEqual(self.conf.database.idle_timeout, 100) + self.assertEqual(self.conf.database.min_pool_size, 99) + self.assertEqual(self.conf.database.max_pool_size, 199) + self.assertEqual(self.conf.database.max_retries, 22) + self.assertEqual(self.conf.database.retry_interval, 17) + self.assertEqual(self.conf.database.max_overflow, 101) + self.assertEqual(self.conf.database.pool_timeout, 5) + + def test_dbapi_database_deprecated_parameters_sql(self): + path = self.create_tempfiles([['tmp', b'[sql]\n' + b'connection=test_sql_connection\n' + b'idle_timeout=99\n' + ]])[0] + self.conf(['--config-file', path]) + self.assertEqual(self.conf.database.connection, 'test_sql_connection') + self.assertEqual(self.conf.database.idle_timeout, 99) + + def test_deprecated_dbapi_parameters(self): + path = self.create_tempfiles([['tmp', b'[DEFAULT]\n' + b'db_backend=test_123\n' + ]])[0] + + self.conf(['--config-file', path]) + self.assertEqual(self.conf.database.backend, 'test_123') + + def test_dbapi_parameters(self): + path = self.create_tempfiles([['tmp', b'[database]\n' + b'backend=test_123\n' + ]])[0] + + self.conf(['--config-file', path]) + self.assertEqual(self.conf.database.backend, 'test_123') + + def test_set_defaults(self): + conf = cfg.ConfigOpts() + + options.set_defaults(conf, + connection='sqlite:///:memory:') + + self.assertTrue(len(conf.database.items()) > 1) + self.assertEqual('sqlite:///:memory:', conf.database.connection) diff --git a/oslo_db/tests/sqlalchemy/test_sqlalchemy.py b/oslo_db/tests/sqlalchemy/test_sqlalchemy.py new file mode 100644 index 0000000..84cbbcf --- /dev/null +++ b/oslo_db/tests/sqlalchemy/test_sqlalchemy.py @@ -0,0 +1,655 @@ +# coding=utf-8 + +# Copyright (c) 2012 Rackspace Hosting +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +"""Unit tests for SQLAlchemy specific code.""" + +import logging + +import fixtures +import mock +from oslo.config import cfg +from oslotest import base as oslo_test +import sqlalchemy +from sqlalchemy import Column, MetaData, Table +from sqlalchemy.engine import url +from sqlalchemy import Integer, String +from sqlalchemy.ext.declarative import declarative_base + +from oslo_db import exception +from oslo_db import options as db_options +from oslo_db.sqlalchemy import models +from oslo_db.sqlalchemy import session +from oslo_db.sqlalchemy import test_base + + +BASE = declarative_base() +_TABLE_NAME = '__tmp__test__tmp__' + +_REGEXP_TABLE_NAME = _TABLE_NAME + "regexp" + + +class RegexpTable(BASE, models.ModelBase): + __tablename__ = _REGEXP_TABLE_NAME + id = Column(Integer, primary_key=True) + bar = Column(String(255)) + + +class RegexpFilterTestCase(test_base.DbTestCase): + + def setUp(self): + super(RegexpFilterTestCase, self).setUp() + meta = MetaData() + meta.bind = self.engine + test_table = Table(_REGEXP_TABLE_NAME, meta, + Column('id', Integer, primary_key=True, + nullable=False), + Column('bar', String(255))) + test_table.create() + self.addCleanup(test_table.drop) + + def _test_regexp_filter(self, regexp, expected): + _session = self.sessionmaker() + with _session.begin(): + for i in ['10', '20', u'♥']: + tbl = RegexpTable() + tbl.update({'bar': i}) + tbl.save(session=_session) + + regexp_op = RegexpTable.bar.op('REGEXP')(regexp) + result = _session.query(RegexpTable).filter(regexp_op).all() + self.assertEqual([r.bar for r in result], expected) + + def test_regexp_filter(self): + self._test_regexp_filter('10', ['10']) + + def test_regexp_filter_nomatch(self): + self._test_regexp_filter('11', []) + + def test_regexp_filter_unicode(self): + self._test_regexp_filter(u'♥', [u'♥']) + + def test_regexp_filter_unicode_nomatch(self): + self._test_regexp_filter(u'♦', []) + + +class SQLiteSavepointTest(test_base.DbTestCase): + def setUp(self): + super(SQLiteSavepointTest, self).setUp() + meta = MetaData() + self.test_table = Table( + "test_table", meta, + Column('id', Integer, primary_key=True), + Column('data', String(10))) + self.test_table.create(self.engine) + self.addCleanup(self.test_table.drop, self.engine) + + def test_plain_transaction(self): + conn = self.engine.connect() + trans = conn.begin() + conn.execute( + self.test_table.insert(), + {'data': 'data 1'} + ) + self.assertEqual( + [(1, 'data 1')], + self.engine.execute( + self.test_table.select(). + order_by(self.test_table.c.id) + ).fetchall() + ) + trans.rollback() + self.assertEqual( + 0, + self.engine.scalar(self.test_table.count()) + ) + + def test_savepoint_middle(self): + with self.engine.begin() as conn: + conn.execute( + self.test_table.insert(), + {'data': 'data 1'} + ) + + savepoint = conn.begin_nested() + conn.execute( + self.test_table.insert(), + {'data': 'data 2'} + ) + savepoint.rollback() + + conn.execute( + self.test_table.insert(), + {'data': 'data 3'} + ) + + self.assertEqual( + [(1, 'data 1'), (2, 'data 3')], + self.engine.execute( + self.test_table.select(). + order_by(self.test_table.c.id) + ).fetchall() + ) + + def test_savepoint_beginning(self): + with self.engine.begin() as conn: + savepoint = conn.begin_nested() + conn.execute( + self.test_table.insert(), + {'data': 'data 1'} + ) + savepoint.rollback() + + conn.execute( + self.test_table.insert(), + {'data': 'data 2'} + ) + + self.assertEqual( + [(1, 'data 2')], + self.engine.execute( + self.test_table.select(). + order_by(self.test_table.c.id) + ).fetchall() + ) + + +class FakeDBAPIConnection(): + def cursor(self): + return FakeCursor() + + +class FakeCursor(): + def execute(self, sql): + pass + + +class FakeConnectionProxy(): + pass + + +class FakeConnectionRec(): + pass + + +class OperationalError(Exception): + pass + + +class ProgrammingError(Exception): + pass + + +class FakeDB2Engine(object): + + class Dialect(): + + def is_disconnect(self, e, *args): + expected_error = ('SQL30081N: DB2 Server connection is no longer ' + 'active') + return (str(e) == expected_error) + + dialect = Dialect() + name = 'ibm_db_sa' + + def dispose(self): + pass + + +class MySQLModeTestCase(test_base.MySQLOpportunisticTestCase): + + def __init__(self, *args, **kwargs): + super(MySQLModeTestCase, self).__init__(*args, **kwargs) + # By default, run in empty SQL mode. + # Subclasses override this with specific modes. + self.mysql_mode = '' + + def setUp(self): + super(MySQLModeTestCase, self).setUp() + + self.engine = session.create_engine(self.engine.url, + mysql_sql_mode=self.mysql_mode) + self.connection = self.engine.connect() + + meta = MetaData() + meta.bind = self.engine + self.test_table = Table(_TABLE_NAME + "mode", meta, + Column('id', Integer, primary_key=True), + Column('bar', String(255))) + self.test_table.create() + + self.addCleanup(self.test_table.drop) + self.addCleanup(self.connection.close) + + def _test_string_too_long(self, value): + with self.connection.begin(): + self.connection.execute(self.test_table.insert(), + bar=value) + result = self.connection.execute(self.test_table.select()) + return result.fetchone()['bar'] + + def test_string_too_long(self): + value = 'a' * 512 + # String is too long. + # With no SQL mode set, this gets truncated. + self.assertNotEqual(value, + self._test_string_too_long(value)) + + +class MySQLStrictAllTablesModeTestCase(MySQLModeTestCase): + "Test data integrity enforcement in MySQL STRICT_ALL_TABLES mode." + + def __init__(self, *args, **kwargs): + super(MySQLStrictAllTablesModeTestCase, self).__init__(*args, **kwargs) + self.mysql_mode = 'STRICT_ALL_TABLES' + + def test_string_too_long(self): + value = 'a' * 512 + # String is too long. + # With STRICT_ALL_TABLES or TRADITIONAL mode set, this is an error. + self.assertRaises(exception.DBError, + self._test_string_too_long, value) + + +class MySQLTraditionalModeTestCase(MySQLStrictAllTablesModeTestCase): + """Test data integrity enforcement in MySQL TRADITIONAL mode. + + Since TRADITIONAL includes STRICT_ALL_TABLES, this inherits all + STRICT_ALL_TABLES mode tests. + """ + + def __init__(self, *args, **kwargs): + super(MySQLTraditionalModeTestCase, self).__init__(*args, **kwargs) + self.mysql_mode = 'TRADITIONAL' + + +class EngineFacadeTestCase(oslo_test.BaseTestCase): + def setUp(self): + super(EngineFacadeTestCase, self).setUp() + + self.facade = session.EngineFacade('sqlite://') + + def test_get_engine(self): + eng1 = self.facade.get_engine() + eng2 = self.facade.get_engine() + + self.assertIs(eng1, eng2) + + def test_get_session(self): + ses1 = self.facade.get_session() + ses2 = self.facade.get_session() + + self.assertIsNot(ses1, ses2) + + def test_get_session_arguments_override_default_settings(self): + ses = self.facade.get_session(autocommit=False, expire_on_commit=True) + + self.assertFalse(ses.autocommit) + self.assertTrue(ses.expire_on_commit) + + @mock.patch('oslo_db.sqlalchemy.session.get_maker') + @mock.patch('oslo_db.sqlalchemy.session.create_engine') + def test_creation_from_config(self, create_engine, get_maker): + conf = cfg.ConfigOpts() + conf.register_opts(db_options.database_opts, group='database') + + overrides = { + 'connection': 'sqlite:///:memory:', + 'slave_connection': None, + 'connection_debug': 100, + 'max_pool_size': 10, + 'mysql_sql_mode': 'TRADITIONAL', + } + for optname, optvalue in overrides.items(): + conf.set_override(optname, optvalue, group='database') + + session.EngineFacade.from_config(conf, + autocommit=False, + expire_on_commit=True) + + create_engine.assert_called_once_with( + sql_connection='sqlite:///:memory:', + connection_debug=100, + max_pool_size=10, + mysql_sql_mode='TRADITIONAL', + sqlite_fk=False, + idle_timeout=mock.ANY, + retry_interval=mock.ANY, + max_retries=mock.ANY, + max_overflow=mock.ANY, + connection_trace=mock.ANY, + sqlite_synchronous=mock.ANY, + pool_timeout=mock.ANY, + thread_checkin=mock.ANY, + ) + get_maker.assert_called_once_with(engine=create_engine(), + autocommit=False, + expire_on_commit=True) + + def test_slave_connection(self): + paths = self.create_tempfiles([('db.master', ''), ('db.slave', '')], + ext='') + master_path = 'sqlite:///' + paths[0] + slave_path = 'sqlite:///' + paths[1] + + facade = session.EngineFacade( + sql_connection=master_path, + slave_connection=slave_path + ) + + master = facade.get_engine() + self.assertEqual(master_path, str(master.url)) + slave = facade.get_engine(use_slave=True) + self.assertEqual(slave_path, str(slave.url)) + + master_session = facade.get_session() + self.assertEqual(master_path, str(master_session.bind.url)) + slave_session = facade.get_session(use_slave=True) + self.assertEqual(slave_path, str(slave_session.bind.url)) + + def test_slave_connection_string_not_provided(self): + master_path = 'sqlite:///' + self.create_tempfiles( + [('db.master', '')], ext='')[0] + + facade = session.EngineFacade(sql_connection=master_path) + + master = facade.get_engine() + slave = facade.get_engine(use_slave=True) + self.assertIs(master, slave) + self.assertEqual(master_path, str(master.url)) + + master_session = facade.get_session() + self.assertEqual(master_path, str(master_session.bind.url)) + slave_session = facade.get_session(use_slave=True) + self.assertEqual(master_path, str(slave_session.bind.url)) + + +class SQLiteConnectTest(oslo_test.BaseTestCase): + + def _fixture(self, **kw): + return session.create_engine("sqlite://", **kw) + + def test_sqlite_fk_listener(self): + engine = self._fixture(sqlite_fk=True) + self.assertEqual( + engine.scalar("pragma foreign_keys"), + 1 + ) + + engine = self._fixture(sqlite_fk=False) + + self.assertEqual( + engine.scalar("pragma foreign_keys"), + 0 + ) + + def test_sqlite_synchronous_listener(self): + engine = self._fixture() + + # "The default setting is synchronous=FULL." (e.g. 2) + # http://www.sqlite.org/pragma.html#pragma_synchronous + self.assertEqual( + engine.scalar("pragma synchronous"), + 2 + ) + + engine = self._fixture(sqlite_synchronous=False) + + self.assertEqual( + engine.scalar("pragma synchronous"), + 0 + ) + + +class MysqlConnectTest(test_base.MySQLOpportunisticTestCase): + + def _fixture(self, sql_mode): + return session.create_engine(self.engine.url, mysql_sql_mode=sql_mode) + + def _assert_sql_mode(self, engine, sql_mode_present, sql_mode_non_present): + mode = engine.execute("SHOW VARIABLES LIKE 'sql_mode'").fetchone()[1] + self.assertTrue( + sql_mode_present in mode + ) + if sql_mode_non_present: + self.assertTrue( + sql_mode_non_present not in mode + ) + + def test_set_mode_traditional(self): + engine = self._fixture(sql_mode='TRADITIONAL') + self._assert_sql_mode(engine, "TRADITIONAL", "ANSI") + + def test_set_mode_ansi(self): + engine = self._fixture(sql_mode='ANSI') + self._assert_sql_mode(engine, "ANSI", "TRADITIONAL") + + def test_set_mode_no_mode(self): + # If _mysql_set_mode_callback is called with sql_mode=None, then + # the SQL mode is NOT set on the connection. + + expected = self.engine.execute( + "SHOW VARIABLES LIKE 'sql_mode'").fetchone()[1] + + engine = self._fixture(sql_mode=None) + self._assert_sql_mode(engine, expected, None) + + def test_fail_detect_mode(self): + # If "SHOW VARIABLES LIKE 'sql_mode'" results in no row, then + # we get a log indicating can't detect the mode. + + log = self.useFixture(fixtures.FakeLogger(level=logging.WARN)) + + mysql_conn = self.engine.raw_connection() + self.addCleanup(mysql_conn.close) + mysql_conn.detach() + mysql_cursor = mysql_conn.cursor() + + def execute(statement, parameters=()): + if "SHOW VARIABLES LIKE 'sql_mode'" in statement: + statement = "SHOW VARIABLES LIKE 'i_dont_exist'" + return mysql_cursor.execute(statement, parameters) + + test_engine = sqlalchemy.create_engine(self.engine.url, + _initialize=False) + + with mock.patch.object( + test_engine.pool, '_creator', + mock.Mock( + return_value=mock.Mock( + cursor=mock.Mock( + return_value=mock.Mock( + execute=execute, + fetchone=mysql_cursor.fetchone, + fetchall=mysql_cursor.fetchall + ) + ) + ) + ) + ): + session._init_events.dispatch_on_drivername("mysql")(test_engine) + + test_engine.raw_connection() + self.assertIn('Unable to detect effective SQL mode', + log.output) + + def test_logs_real_mode(self): + # If "SHOW VARIABLES LIKE 'sql_mode'" results in a value, then + # we get a log with the value. + + log = self.useFixture(fixtures.FakeLogger(level=logging.DEBUG)) + + engine = self._fixture(sql_mode='TRADITIONAL') + + actual_mode = engine.execute( + "SHOW VARIABLES LIKE 'sql_mode'").fetchone()[1] + + self.assertIn('MySQL server mode set to %s' % actual_mode, + log.output) + + def test_warning_when_not_traditional(self): + # If "SHOW VARIABLES LIKE 'sql_mode'" results in a value that doesn't + # include 'TRADITIONAL', then a warning is logged. + + log = self.useFixture(fixtures.FakeLogger(level=logging.WARN)) + self._fixture(sql_mode='ANSI') + + self.assertIn("consider enabling TRADITIONAL or STRICT_ALL_TABLES", + log.output) + + def test_no_warning_when_traditional(self): + # If "SHOW VARIABLES LIKE 'sql_mode'" results in a value that includes + # 'TRADITIONAL', then no warning is logged. + + log = self.useFixture(fixtures.FakeLogger(level=logging.WARN)) + self._fixture(sql_mode='TRADITIONAL') + + self.assertNotIn("consider enabling TRADITIONAL or STRICT_ALL_TABLES", + log.output) + + def test_no_warning_when_strict_all_tables(self): + # If "SHOW VARIABLES LIKE 'sql_mode'" results in a value that includes + # 'STRICT_ALL_TABLES', then no warning is logged. + + log = self.useFixture(fixtures.FakeLogger(level=logging.WARN)) + self._fixture(sql_mode='TRADITIONAL') + + self.assertNotIn("consider enabling TRADITIONAL or STRICT_ALL_TABLES", + log.output) + + +class CreateEngineTest(oslo_test.BaseTestCase): + """Test that dialect-specific arguments/ listeners are set up correctly. + + """ + + def setUp(self): + super(CreateEngineTest, self).setUp() + self.args = {'connect_args': {}} + + def test_queuepool_args(self): + session._init_connection_args( + url.make_url("mysql://u:p@host/test"), self.args, + max_pool_size=10, max_overflow=10) + self.assertEqual(self.args['pool_size'], 10) + self.assertEqual(self.args['max_overflow'], 10) + + def test_sqlite_memory_pool_args(self): + for _url in ("sqlite://", "sqlite:///:memory:"): + session._init_connection_args( + url.make_url(_url), self.args, + max_pool_size=10, max_overflow=10) + + # queuepool arguments are not peresnet + self.assertTrue( + 'pool_size' not in self.args) + self.assertTrue( + 'max_overflow' not in self.args) + + self.assertEqual(self.args['connect_args']['check_same_thread'], + False) + + # due to memory connection + self.assertTrue('poolclass' in self.args) + + def test_sqlite_file_pool_args(self): + session._init_connection_args( + url.make_url("sqlite:///somefile.db"), self.args, + max_pool_size=10, max_overflow=10) + + # queuepool arguments are not peresnet + self.assertTrue('pool_size' not in self.args) + self.assertTrue( + 'max_overflow' not in self.args) + + self.assertFalse(self.args['connect_args']) + + # NullPool is the default for file based connections, + # no need to specify this + self.assertTrue('poolclass' not in self.args) + + def test_mysql_connect_args_default(self): + session._init_connection_args( + url.make_url("mysql://u:p@host/test"), self.args) + self.assertEqual(self.args['connect_args'], + {'charset': 'utf8', 'use_unicode': 0}) + + def test_mysql_oursql_connect_args_default(self): + session._init_connection_args( + url.make_url("mysql+oursql://u:p@host/test"), self.args) + self.assertEqual(self.args['connect_args'], + {'charset': 'utf8', 'use_unicode': 0}) + + def test_mysql_mysqldb_connect_args_default(self): + session._init_connection_args( + url.make_url("mysql+mysqldb://u:p@host/test"), self.args) + self.assertEqual(self.args['connect_args'], + {'charset': 'utf8', 'use_unicode': 0}) + + def test_postgresql_connect_args_default(self): + session._init_connection_args( + url.make_url("postgresql://u:p@host/test"), self.args) + self.assertEqual(self.args['client_encoding'], 'utf8') + self.assertFalse(self.args['connect_args']) + + def test_mysqlconnector_raise_on_warnings_default(self): + session._init_connection_args( + url.make_url("mysql+mysqlconnector://u:p@host/test"), + self.args) + self.assertEqual(self.args['connect_args']['raise_on_warnings'], False) + + def test_mysqlconnector_raise_on_warnings_override(self): + session._init_connection_args( + url.make_url( + "mysql+mysqlconnector://u:p@host/test" + "?raise_on_warnings=true"), + self.args + ) + + self.assertFalse('raise_on_warnings' in self.args['connect_args']) + + def test_thread_checkin(self): + with mock.patch("sqlalchemy.event.listens_for"): + with mock.patch("sqlalchemy.event.listen") as listen_evt: + session._init_events.dispatch_on_drivername( + "sqlite")(mock.Mock()) + + self.assertEqual( + listen_evt.mock_calls[0][1][-1], + session._thread_yield + ) + + +# NOTE(dhellmann): This test no longer works as written. The code in +# oslo_db.sqlalchemy.session filters out lines from modules under +# oslo_db, and now this test is under oslo_db, so the test filename +# does not appear in the context for the error message. LP #1405376 + +# class PatchStacktraceTest(test_base.DbTestCase): + +# def test_trace(self): +# engine = self.engine +# session._add_trace_comments(engine) +# conn = engine.connect() +# with mock.patch.object(engine.dialect, "do_execute") as mock_exec: + +# conn.execute("select * from table") + +# call = mock_exec.mock_calls[0] + +# # we're the caller, see that we're in there +# self.assertIn("oslo_db/tests/sqlalchemy/test_sqlalchemy.py", +# call[1][1]) diff --git a/oslo_db/tests/sqlalchemy/test_utils.py b/oslo_db/tests/sqlalchemy/test_utils.py new file mode 100644 index 0000000..509cf48 --- /dev/null +++ b/oslo_db/tests/sqlalchemy/test_utils.py @@ -0,0 +1,1090 @@ +# Copyright (c) 2013 Boris Pavlovic (boris@pavlovic.me). +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import uuid + +import fixtures +import mock +from oslotest import base as test_base +from oslotest import moxstubout +import six +from six.moves.urllib import parse +import sqlalchemy +from sqlalchemy.dialects import mysql +from sqlalchemy import Boolean, Index, Integer, DateTime, String, SmallInteger +from sqlalchemy import MetaData, Table, Column, ForeignKey +from sqlalchemy.engine import reflection +from sqlalchemy.engine import url as sa_url +from sqlalchemy.exc import OperationalError +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.sql import select +from sqlalchemy.types import UserDefinedType, NullType + +from oslo_db import exception +from oslo_db.sqlalchemy import models +from oslo_db.sqlalchemy import provision +from oslo_db.sqlalchemy import session +from oslo_db.sqlalchemy import test_base as db_test_base +from oslo_db.sqlalchemy import utils +from oslo_db.tests import utils as test_utils + + +SA_VERSION = tuple(map(int, sqlalchemy.__version__.split('.'))) + + +class TestSanitizeDbUrl(test_base.BaseTestCase): + + def test_url_with_cred(self): + db_url = 'myproto://johndoe:secret@localhost/myschema' + expected = 'myproto://****:****@localhost/myschema' + actual = utils.sanitize_db_url(db_url) + self.assertEqual(expected, actual) + + def test_url_with_no_cred(self): + db_url = 'sqlite:///mysqlitefile' + actual = utils.sanitize_db_url(db_url) + self.assertEqual(db_url, actual) + + +class CustomType(UserDefinedType): + """Dummy column type for testing unsupported types.""" + def get_col_spec(self): + return "CustomType" + + +class FakeModel(object): + def __init__(self, values): + self.values = values + + def __getattr__(self, name): + try: + value = self.values[name] + except KeyError: + raise AttributeError(name) + return value + + def __getitem__(self, key): + if key in self.values: + return self.values[key] + else: + raise NotImplementedError() + + def __repr__(self): + return '<FakeModel: %s>' % self.values + + +class TestPaginateQuery(test_base.BaseTestCase): + def setUp(self): + super(TestPaginateQuery, self).setUp() + mox_fixture = self.useFixture(moxstubout.MoxStubout()) + self.mox = mox_fixture.mox + self.query = self.mox.CreateMockAnything() + self.mox.StubOutWithMock(sqlalchemy, 'asc') + self.mox.StubOutWithMock(sqlalchemy, 'desc') + self.marker = FakeModel({ + 'user_id': 'user', + 'project_id': 'p', + 'snapshot_id': 's', + }) + self.model = FakeModel({ + 'user_id': 'user', + 'project_id': 'project', + 'snapshot_id': 'snapshot', + }) + + def test_paginate_query_no_pagination_no_sort_dirs(self): + sqlalchemy.asc('user').AndReturn('asc_3') + self.query.order_by('asc_3').AndReturn(self.query) + sqlalchemy.asc('project').AndReturn('asc_2') + self.query.order_by('asc_2').AndReturn(self.query) + sqlalchemy.asc('snapshot').AndReturn('asc_1') + self.query.order_by('asc_1').AndReturn(self.query) + self.query.limit(5).AndReturn(self.query) + self.mox.ReplayAll() + utils.paginate_query(self.query, self.model, 5, + ['user_id', 'project_id', 'snapshot_id']) + + def test_paginate_query_no_pagination(self): + sqlalchemy.asc('user').AndReturn('asc') + self.query.order_by('asc').AndReturn(self.query) + sqlalchemy.desc('project').AndReturn('desc') + self.query.order_by('desc').AndReturn(self.query) + self.query.limit(5).AndReturn(self.query) + self.mox.ReplayAll() + utils.paginate_query(self.query, self.model, 5, + ['user_id', 'project_id'], + sort_dirs=['asc', 'desc']) + + def test_paginate_query_attribute_error(self): + sqlalchemy.asc('user').AndReturn('asc') + self.query.order_by('asc').AndReturn(self.query) + self.mox.ReplayAll() + self.assertRaises(exception.InvalidSortKey, + utils.paginate_query, self.query, + self.model, 5, ['user_id', 'non-existent key']) + + def test_paginate_query_assertion_error(self): + self.mox.ReplayAll() + self.assertRaises(AssertionError, + utils.paginate_query, self.query, + self.model, 5, ['user_id'], + marker=self.marker, + sort_dir='asc', sort_dirs=['asc']) + + def test_paginate_query_assertion_error_2(self): + self.mox.ReplayAll() + self.assertRaises(AssertionError, + utils.paginate_query, self.query, + self.model, 5, ['user_id'], + marker=self.marker, + sort_dir=None, sort_dirs=['asc', 'desk']) + + def test_paginate_query(self): + sqlalchemy.asc('user').AndReturn('asc_1') + self.query.order_by('asc_1').AndReturn(self.query) + sqlalchemy.desc('project').AndReturn('desc_1') + self.query.order_by('desc_1').AndReturn(self.query) + self.mox.StubOutWithMock(sqlalchemy.sql, 'and_') + sqlalchemy.sql.and_(False).AndReturn('some_crit') + sqlalchemy.sql.and_(True, False).AndReturn('another_crit') + self.mox.StubOutWithMock(sqlalchemy.sql, 'or_') + sqlalchemy.sql.or_('some_crit', 'another_crit').AndReturn('some_f') + self.query.filter('some_f').AndReturn(self.query) + self.query.limit(5).AndReturn(self.query) + self.mox.ReplayAll() + utils.paginate_query(self.query, self.model, 5, + ['user_id', 'project_id'], + marker=self.marker, + sort_dirs=['asc', 'desc']) + + def test_paginate_query_value_error(self): + sqlalchemy.asc('user').AndReturn('asc_1') + self.query.order_by('asc_1').AndReturn(self.query) + self.mox.ReplayAll() + self.assertRaises(ValueError, utils.paginate_query, + self.query, self.model, 5, ['user_id', 'project_id'], + marker=self.marker, sort_dirs=['asc', 'mixed']) + + +class TestMigrationUtils(db_test_base.DbTestCase): + + """Class for testing utils that are used in db migrations.""" + + def setUp(self): + super(TestMigrationUtils, self).setUp() + self.meta = MetaData(bind=self.engine) + self.conn = self.engine.connect() + self.addCleanup(self.meta.drop_all) + self.addCleanup(self.conn.close) + + def _populate_db_for_drop_duplicate_entries(self, engine, meta, + table_name): + values = [ + {'id': 11, 'a': 3, 'b': 10, 'c': 'abcdef'}, + {'id': 12, 'a': 5, 'b': 10, 'c': 'abcdef'}, + {'id': 13, 'a': 6, 'b': 10, 'c': 'abcdef'}, + {'id': 14, 'a': 7, 'b': 10, 'c': 'abcdef'}, + {'id': 21, 'a': 1, 'b': 20, 'c': 'aa'}, + {'id': 31, 'a': 1, 'b': 20, 'c': 'bb'}, + {'id': 41, 'a': 1, 'b': 30, 'c': 'aef'}, + {'id': 42, 'a': 2, 'b': 30, 'c': 'aef'}, + {'id': 43, 'a': 3, 'b': 30, 'c': 'aef'} + ] + + test_table = Table(table_name, meta, + Column('id', Integer, primary_key=True, + nullable=False), + Column('a', Integer), + Column('b', Integer), + Column('c', String(255)), + Column('deleted', Integer, default=0), + Column('deleted_at', DateTime), + Column('updated_at', DateTime)) + + test_table.create() + engine.execute(test_table.insert(), values) + return test_table, values + + def test_drop_old_duplicate_entries_from_table(self): + table_name = "__test_tmp_table__" + + test_table, values = self._populate_db_for_drop_duplicate_entries( + self.engine, self.meta, table_name) + utils.drop_old_duplicate_entries_from_table( + self.engine, table_name, False, 'b', 'c') + + uniq_values = set() + expected_ids = [] + for value in sorted(values, key=lambda x: x['id'], reverse=True): + uniq_value = (('b', value['b']), ('c', value['c'])) + if uniq_value in uniq_values: + continue + uniq_values.add(uniq_value) + expected_ids.append(value['id']) + + real_ids = [row[0] for row in + self.engine.execute(select([test_table.c.id])).fetchall()] + + self.assertEqual(len(real_ids), len(expected_ids)) + for id_ in expected_ids: + self.assertTrue(id_ in real_ids) + + def test_drop_dup_entries_in_file_conn(self): + table_name = "__test_tmp_table__" + tmp_db_file = self.create_tempfiles([['name', '']], ext='.sql')[0] + in_file_engine = session.EngineFacade( + 'sqlite:///%s' % tmp_db_file).get_engine() + meta = MetaData() + meta.bind = in_file_engine + test_table, values = self._populate_db_for_drop_duplicate_entries( + in_file_engine, meta, table_name) + utils.drop_old_duplicate_entries_from_table( + in_file_engine, table_name, False, 'b', 'c') + + def test_drop_old_duplicate_entries_from_table_soft_delete(self): + table_name = "__test_tmp_table__" + + table, values = self._populate_db_for_drop_duplicate_entries( + self.engine, self.meta, table_name) + utils.drop_old_duplicate_entries_from_table(self.engine, table_name, + True, 'b', 'c') + uniq_values = set() + expected_values = [] + soft_deleted_values = [] + + for value in sorted(values, key=lambda x: x['id'], reverse=True): + uniq_value = (('b', value['b']), ('c', value['c'])) + if uniq_value in uniq_values: + soft_deleted_values.append(value) + continue + uniq_values.add(uniq_value) + expected_values.append(value) + + base_select = table.select() + + rows_select = base_select.where(table.c.deleted != table.c.id) + row_ids = [row['id'] for row in + self.engine.execute(rows_select).fetchall()] + self.assertEqual(len(row_ids), len(expected_values)) + for value in expected_values: + self.assertTrue(value['id'] in row_ids) + + deleted_rows_select = base_select.where( + table.c.deleted == table.c.id) + deleted_rows_ids = [row['id'] for row in + self.engine.execute( + deleted_rows_select).fetchall()] + self.assertEqual(len(deleted_rows_ids), + len(values) - len(row_ids)) + for value in soft_deleted_values: + self.assertTrue(value['id'] in deleted_rows_ids) + + def test_change_deleted_column_type_does_not_drop_index(self): + table_name = 'abc' + + indexes = { + 'idx_a_deleted': ['a', 'deleted'], + 'idx_b_deleted': ['b', 'deleted'], + 'idx_a': ['a'] + } + + index_instances = [Index(name, *columns) + for name, columns in six.iteritems(indexes)] + + table = Table(table_name, self.meta, + Column('id', Integer, primary_key=True), + Column('a', String(255)), + Column('b', String(255)), + Column('deleted', Boolean), + *index_instances) + table.create() + utils.change_deleted_column_type_to_id_type(self.engine, table_name) + utils.change_deleted_column_type_to_boolean(self.engine, table_name) + + insp = reflection.Inspector.from_engine(self.engine) + real_indexes = insp.get_indexes(table_name) + self.assertEqual(len(real_indexes), 3) + for index in real_indexes: + name = index['name'] + self.assertIn(name, indexes) + self.assertEqual(set(index['column_names']), + set(indexes[name])) + + def test_change_deleted_column_type_to_id_type_integer(self): + table_name = 'abc' + table = Table(table_name, self.meta, + Column('id', Integer, primary_key=True), + Column('deleted', Boolean)) + table.create() + utils.change_deleted_column_type_to_id_type(self.engine, table_name) + + table = utils.get_table(self.engine, table_name) + self.assertTrue(isinstance(table.c.deleted.type, Integer)) + + def test_change_deleted_column_type_to_id_type_string(self): + table_name = 'abc' + table = Table(table_name, self.meta, + Column('id', String(255), primary_key=True), + Column('deleted', Boolean)) + table.create() + utils.change_deleted_column_type_to_id_type(self.engine, table_name) + + table = utils.get_table(self.engine, table_name) + self.assertTrue(isinstance(table.c.deleted.type, String)) + + @db_test_base.backend_specific('sqlite') + def test_change_deleted_column_type_to_id_type_custom(self): + table_name = 'abc' + table = Table(table_name, self.meta, + Column('id', Integer, primary_key=True), + Column('foo', CustomType), + Column('deleted', Boolean)) + table.create() + + # reflection of custom types has been fixed upstream + if SA_VERSION < (0, 9, 0): + self.assertRaises(exception.ColumnError, + utils.change_deleted_column_type_to_id_type, + self.engine, table_name) + + fooColumn = Column('foo', CustomType()) + utils.change_deleted_column_type_to_id_type(self.engine, table_name, + foo=fooColumn) + + table = utils.get_table(self.engine, table_name) + # NOTE(boris-42): There is no way to check has foo type CustomType. + # but sqlalchemy will set it to NullType. This has + # been fixed upstream in recent SA versions + if SA_VERSION < (0, 9, 0): + self.assertTrue(isinstance(table.c.foo.type, NullType)) + self.assertTrue(isinstance(table.c.deleted.type, Integer)) + + def test_change_deleted_column_type_to_boolean(self): + expected_types = {'mysql': mysql.TINYINT, + 'ibm_db_sa': SmallInteger} + table_name = 'abc' + table = Table(table_name, self.meta, + Column('id', Integer, primary_key=True), + Column('deleted', Integer)) + table.create() + + utils.change_deleted_column_type_to_boolean(self.engine, table_name) + + table = utils.get_table(self.engine, table_name) + self.assertIsInstance(table.c.deleted.type, + expected_types.get(self.engine.name, Boolean)) + + def test_change_deleted_column_type_to_boolean_with_fc(self): + expected_types = {'mysql': mysql.TINYINT, + 'ibm_db_sa': SmallInteger} + table_name_1 = 'abc' + table_name_2 = 'bcd' + + table_1 = Table(table_name_1, self.meta, + Column('id', Integer, primary_key=True), + Column('deleted', Integer)) + table_1.create() + + table_2 = Table(table_name_2, self.meta, + Column('id', Integer, primary_key=True), + Column('foreign_id', Integer, + ForeignKey('%s.id' % table_name_1)), + Column('deleted', Integer)) + table_2.create() + + utils.change_deleted_column_type_to_boolean(self.engine, table_name_2) + + table = utils.get_table(self.engine, table_name_2) + self.assertIsInstance(table.c.deleted.type, + expected_types.get(self.engine.name, Boolean)) + + @db_test_base.backend_specific('sqlite') + def test_change_deleted_column_type_to_boolean_type_custom(self): + table_name = 'abc' + table = Table(table_name, self.meta, + Column('id', Integer, primary_key=True), + Column('foo', CustomType), + Column('deleted', Integer)) + table.create() + + # reflection of custom types has been fixed upstream + if SA_VERSION < (0, 9, 0): + self.assertRaises(exception.ColumnError, + utils.change_deleted_column_type_to_boolean, + self.engine, table_name) + + fooColumn = Column('foo', CustomType()) + utils.change_deleted_column_type_to_boolean(self.engine, table_name, + foo=fooColumn) + + table = utils.get_table(self.engine, table_name) + # NOTE(boris-42): There is no way to check has foo type CustomType. + # but sqlalchemy will set it to NullType. This has + # been fixed upstream in recent SA versions + if SA_VERSION < (0, 9, 0): + self.assertTrue(isinstance(table.c.foo.type, NullType)) + self.assertTrue(isinstance(table.c.deleted.type, Boolean)) + + @db_test_base.backend_specific('sqlite') + def test_change_deleted_column_type_sqlite_drops_check_constraint(self): + table_name = 'abc' + table = Table(table_name, self.meta, + Column('id', Integer, primary_key=True), + Column('deleted', Boolean)) + table.create() + + utils._change_deleted_column_type_to_id_type_sqlite(self.engine, + table_name) + table = Table(table_name, self.meta, autoload=True) + # NOTE(I159): if the CHECK constraint has been dropped (expected + # behavior), any integer value can be inserted, otherwise only 1 or 0. + self.engine.execute(table.insert({'deleted': 10})) + + def test_insert_from_select(self): + insert_table_name = "__test_insert_to_table__" + select_table_name = "__test_select_from_table__" + uuidstrs = [] + for unused in range(10): + uuidstrs.append(uuid.uuid4().hex) + insert_table = Table( + insert_table_name, self.meta, + Column('id', Integer, primary_key=True, + nullable=False, autoincrement=True), + Column('uuid', String(36), nullable=False)) + select_table = Table( + select_table_name, self.meta, + Column('id', Integer, primary_key=True, + nullable=False, autoincrement=True), + Column('uuid', String(36), nullable=False)) + + insert_table.create() + select_table.create() + # Add 10 rows to select_table + for uuidstr in uuidstrs: + ins_stmt = select_table.insert().values(uuid=uuidstr) + self.conn.execute(ins_stmt) + + # Select 4 rows in one chunk from select_table + column = select_table.c.id + query_insert = select([select_table], + select_table.c.id < 5).order_by(column) + insert_statement = utils.InsertFromSelect(insert_table, + query_insert) + result_insert = self.conn.execute(insert_statement) + # Verify we insert 4 rows + self.assertEqual(result_insert.rowcount, 4) + + query_all = select([insert_table]).where( + insert_table.c.uuid.in_(uuidstrs)) + rows = self.conn.execute(query_all).fetchall() + # Verify we really have 4 rows in insert_table + self.assertEqual(len(rows), 4) + + +class PostgesqlTestMigrations(TestMigrationUtils, + db_test_base.PostgreSQLOpportunisticTestCase): + + """Test migrations on PostgreSQL.""" + pass + + +class MySQLTestMigrations(TestMigrationUtils, + db_test_base.MySQLOpportunisticTestCase): + + """Test migrations on MySQL.""" + pass + + +class TestConnectionUtils(test_utils.BaseTestCase): + + def setUp(self): + super(TestConnectionUtils, self).setUp() + + self.full_credentials = {'backend': 'postgresql', + 'database': 'test', + 'user': 'dude', + 'passwd': 'pass'} + + self.connect_string = 'postgresql://dude:pass@localhost/test' + + def test_connect_string(self): + connect_string = utils.get_connect_string(**self.full_credentials) + self.assertEqual(connect_string, self.connect_string) + + def test_connect_string_sqlite(self): + sqlite_credentials = {'backend': 'sqlite', 'database': 'test.db'} + connect_string = utils.get_connect_string(**sqlite_credentials) + self.assertEqual(connect_string, 'sqlite:///test.db') + + def test_is_backend_avail(self): + self.mox.StubOutWithMock(sqlalchemy.engine.base.Engine, 'connect') + fake_connection = self.mox.CreateMockAnything() + fake_connection.close() + sqlalchemy.engine.base.Engine.connect().AndReturn(fake_connection) + self.mox.ReplayAll() + + self.assertTrue(utils.is_backend_avail(**self.full_credentials)) + + def test_is_backend_unavail(self): + log = self.useFixture(fixtures.FakeLogger()) + err = OperationalError("Can't connect to database", None, None) + error_msg = "The postgresql backend is unavailable: %s\n" % err + self.mox.StubOutWithMock(sqlalchemy.engine.base.Engine, 'connect') + sqlalchemy.engine.base.Engine.connect().AndRaise(err) + self.mox.ReplayAll() + self.assertFalse(utils.is_backend_avail(**self.full_credentials)) + self.assertEqual(error_msg, log.output) + + def test_ensure_backend_available(self): + self.mox.StubOutWithMock(sqlalchemy.engine.base.Engine, 'connect') + fake_connection = self.mox.CreateMockAnything() + fake_connection.close() + sqlalchemy.engine.base.Engine.connect().AndReturn(fake_connection) + self.mox.ReplayAll() + + eng = provision.Backend._ensure_backend_available(self.connect_string) + self.assertIsInstance(eng, sqlalchemy.engine.base.Engine) + self.assertEqual(self.connect_string, str(eng.url)) + + def test_ensure_backend_available_no_connection_raises(self): + log = self.useFixture(fixtures.FakeLogger()) + err = OperationalError("Can't connect to database", None, None) + self.mox.StubOutWithMock(sqlalchemy.engine.base.Engine, 'connect') + sqlalchemy.engine.base.Engine.connect().AndRaise(err) + self.mox.ReplayAll() + + exc = self.assertRaises( + exception.BackendNotAvailable, + provision.Backend._ensure_backend_available, self.connect_string + ) + self.assertEqual("Could not connect", str(exc)) + self.assertEqual( + "The postgresql backend is unavailable: %s" % err, + log.output.strip()) + + def test_ensure_backend_available_no_dbapi_raises(self): + log = self.useFixture(fixtures.FakeLogger()) + self.mox.StubOutWithMock(sqlalchemy, 'create_engine') + sqlalchemy.create_engine( + sa_url.make_url(self.connect_string)).AndRaise( + ImportError("Can't import DBAPI module foobar")) + self.mox.ReplayAll() + + exc = self.assertRaises( + exception.BackendNotAvailable, + provision.Backend._ensure_backend_available, self.connect_string + ) + self.assertEqual("No DBAPI installed", str(exc)) + self.assertEqual( + "The postgresql backend is unavailable: Can't import " + "DBAPI module foobar", log.output.strip()) + + def test_get_db_connection_info(self): + conn_pieces = parse.urlparse(self.connect_string) + self.assertEqual(utils.get_db_connection_info(conn_pieces), + ('dude', 'pass', 'test', 'localhost')) + + def test_connect_string_host(self): + self.full_credentials['host'] = 'myhost' + connect_string = utils.get_connect_string(**self.full_credentials) + self.assertEqual(connect_string, 'postgresql://dude:pass@myhost/test') + + +class MyModelSoftDeletedProjectId(declarative_base(), models.ModelBase, + models.SoftDeleteMixin): + __tablename__ = 'soft_deleted_project_id_test_model' + id = Column(Integer, primary_key=True) + project_id = Column(Integer) + + +class MyModel(declarative_base(), models.ModelBase): + __tablename__ = 'test_model' + id = Column(Integer, primary_key=True) + + +class MyModelSoftDeleted(declarative_base(), models.ModelBase, + models.SoftDeleteMixin): + __tablename__ = 'soft_deleted_test_model' + id = Column(Integer, primary_key=True) + + +class TestModelQuery(test_base.BaseTestCase): + + def setUp(self): + super(TestModelQuery, self).setUp() + + self.session = mock.MagicMock() + self.session.query.return_value = self.session.query + self.session.query.filter.return_value = self.session.query + + def test_wrong_model(self): + self.assertRaises(TypeError, utils.model_query, + FakeModel, session=self.session) + + def test_no_soft_deleted(self): + self.assertRaises(ValueError, utils.model_query, + MyModel, session=self.session, deleted=True) + + def test_deleted_false(self): + mock_query = utils.model_query( + MyModelSoftDeleted, session=self.session, deleted=False) + + deleted_filter = mock_query.filter.call_args[0][0] + self.assertEqual(str(deleted_filter), + 'soft_deleted_test_model.deleted = :deleted_1') + self.assertEqual(deleted_filter.right.value, + MyModelSoftDeleted.__mapper__.c.deleted.default.arg) + + def test_deleted_true(self): + mock_query = utils.model_query( + MyModelSoftDeleted, session=self.session, deleted=True) + + deleted_filter = mock_query.filter.call_args[0][0] + self.assertEqual(str(deleted_filter), + 'soft_deleted_test_model.deleted != :deleted_1') + self.assertEqual(deleted_filter.right.value, + MyModelSoftDeleted.__mapper__.c.deleted.default.arg) + + @mock.patch.object(utils, "_read_deleted_filter") + def test_no_deleted_value(self, _read_deleted_filter): + utils.model_query(MyModelSoftDeleted, session=self.session) + self.assertEqual(_read_deleted_filter.call_count, 0) + + def test_project_filter(self): + project_id = 10 + + mock_query = utils.model_query( + MyModelSoftDeletedProjectId, session=self.session, + project_only=True, project_id=project_id) + + deleted_filter = mock_query.filter.call_args[0][0] + self.assertEqual( + str(deleted_filter), + 'soft_deleted_project_id_test_model.project_id = :project_id_1') + self.assertEqual(deleted_filter.right.value, project_id) + + def test_project_filter_wrong_model(self): + self.assertRaises(ValueError, utils.model_query, + MyModelSoftDeleted, session=self.session, + project_id=10) + + def test_project_filter_allow_none(self): + mock_query = utils.model_query( + MyModelSoftDeletedProjectId, + session=self.session, project_id=(10, None)) + + self.assertEqual( + str(mock_query.filter.call_args[0][0]), + 'soft_deleted_project_id_test_model.project_id' + ' IN (:project_id_1, NULL)' + ) + + def test_model_query_common(self): + utils.model_query(MyModel, args=(MyModel.id,), session=self.session) + self.session.query.assert_called_with(MyModel.id) + + +class TestUtils(db_test_base.DbTestCase): + def setUp(self): + super(TestUtils, self).setUp() + meta = MetaData(bind=self.engine) + self.test_table = Table( + 'test_table', + meta, + Column('a', Integer), + Column('b', Integer) + ) + self.test_table.create() + self.addCleanup(meta.drop_all) + + def test_index_exists(self): + self.assertFalse(utils.index_exists(self.engine, 'test_table', + 'new_index')) + Index('new_index', self.test_table.c.a).create(self.engine) + self.assertTrue(utils.index_exists(self.engine, 'test_table', + 'new_index')) + + def test_add_index(self): + self.assertFalse(utils.index_exists(self.engine, 'test_table', + 'new_index')) + utils.add_index(self.engine, 'test_table', 'new_index', ('a',)) + self.assertTrue(utils.index_exists(self.engine, 'test_table', + 'new_index')) + + def test_add_existing_index(self): + Index('new_index', self.test_table.c.a).create(self.engine) + self.assertRaises(ValueError, utils.add_index, self.engine, + 'test_table', 'new_index', ('a',)) + + def test_drop_index(self): + Index('new_index', self.test_table.c.a).create(self.engine) + utils.drop_index(self.engine, 'test_table', 'new_index') + self.assertFalse(utils.index_exists(self.engine, 'test_table', + 'new_index')) + + def test_drop_unexisting_index(self): + self.assertRaises(ValueError, utils.drop_index, self.engine, + 'test_table', 'new_index') + + @mock.patch('oslo_db.sqlalchemy.utils.drop_index') + @mock.patch('oslo_db.sqlalchemy.utils.add_index') + def test_change_index_columns(self, add_index, drop_index): + utils.change_index_columns(self.engine, 'test_table', 'a_index', + ('a',)) + utils.drop_index.assert_called_once_with(self.engine, 'test_table', + 'a_index') + utils.add_index.assert_called_once_with(self.engine, 'test_table', + 'a_index', ('a',)) + + def test_column_exists(self): + for col in ['a', 'b']: + self.assertTrue(utils.column_exists(self.engine, 'test_table', + col)) + self.assertFalse(utils.column_exists(self.engine, 'test_table', + 'fake_column')) + + +class TestUtilsMysqlOpportunistically( + TestUtils, db_test_base.MySQLOpportunisticTestCase): + pass + + +class TestUtilsPostgresqlOpportunistically( + TestUtils, db_test_base.PostgreSQLOpportunisticTestCase): + pass + + +class TestDialectFunctionDispatcher(test_base.BaseTestCase): + def _single_fixture(self): + callable_fn = mock.Mock() + + dispatcher = orig = utils.dispatch_for_dialect("*")( + callable_fn.default) + dispatcher = dispatcher.dispatch_for("sqlite")(callable_fn.sqlite) + dispatcher = dispatcher.dispatch_for("mysql+mysqldb")( + callable_fn.mysql_mysqldb) + dispatcher = dispatcher.dispatch_for("postgresql")( + callable_fn.postgresql) + + self.assertTrue(dispatcher is orig) + + return dispatcher, callable_fn + + def _multiple_fixture(self): + callable_fn = mock.Mock() + + for targ in [ + callable_fn.default, + callable_fn.sqlite, + callable_fn.mysql_mysqldb, + callable_fn.postgresql, + callable_fn.postgresql_psycopg2, + callable_fn.pyodbc + ]: + targ.return_value = None + + dispatcher = orig = utils.dispatch_for_dialect("*", multiple=True)( + callable_fn.default) + dispatcher = dispatcher.dispatch_for("sqlite")(callable_fn.sqlite) + dispatcher = dispatcher.dispatch_for("mysql+mysqldb")( + callable_fn.mysql_mysqldb) + dispatcher = dispatcher.dispatch_for("postgresql+*")( + callable_fn.postgresql) + dispatcher = dispatcher.dispatch_for("postgresql+psycopg2")( + callable_fn.postgresql_psycopg2) + dispatcher = dispatcher.dispatch_for("*+pyodbc")( + callable_fn.pyodbc) + + self.assertTrue(dispatcher is orig) + + return dispatcher, callable_fn + + def test_single(self): + + dispatcher, callable_fn = self._single_fixture() + dispatcher("sqlite://", 1) + dispatcher("postgresql+psycopg2://u:p@h/t", 2) + dispatcher("mysql://u:p@h/t", 3) + dispatcher("mysql+mysqlconnector://u:p@h/t", 4) + + self.assertEqual( + [ + mock.call.sqlite('sqlite://', 1), + mock.call.postgresql("postgresql+psycopg2://u:p@h/t", 2), + mock.call.mysql_mysqldb("mysql://u:p@h/t", 3), + mock.call.default("mysql+mysqlconnector://u:p@h/t", 4) + ], + callable_fn.mock_calls) + + def test_single_kwarg(self): + dispatcher, callable_fn = self._single_fixture() + dispatcher("sqlite://", foo='bar') + dispatcher("postgresql+psycopg2://u:p@h/t", 1, x='y') + + self.assertEqual( + [ + mock.call.sqlite('sqlite://', foo='bar'), + mock.call.postgresql( + "postgresql+psycopg2://u:p@h/t", + 1, x='y'), + ], + callable_fn.mock_calls) + + def test_dispatch_on_target(self): + callable_fn = mock.Mock() + + @utils.dispatch_for_dialect("*") + def default_fn(url, x, y): + callable_fn.default(url, x, y) + + @default_fn.dispatch_for("sqlite") + def sqlite_fn(url, x, y): + callable_fn.sqlite(url, x, y) + default_fn.dispatch_on_drivername("*")(url, x, y) + + default_fn("sqlite://", 4, 5) + self.assertEqual( + [ + mock.call.sqlite("sqlite://", 4, 5), + mock.call.default("sqlite://", 4, 5) + ], + callable_fn.mock_calls + ) + + def test_single_no_dispatcher(self): + callable_fn = mock.Mock() + + dispatcher = utils.dispatch_for_dialect("sqlite")(callable_fn.sqlite) + dispatcher = dispatcher.dispatch_for("mysql")(callable_fn.mysql) + exc = self.assertRaises( + ValueError, + dispatcher, "postgresql://s:t@localhost/test" + ) + self.assertEqual( + "No default function found for driver: 'postgresql+psycopg2'", + str(exc) + ) + + def test_multiple_no_dispatcher(self): + callable_fn = mock.Mock() + + dispatcher = utils.dispatch_for_dialect("sqlite", multiple=True)( + callable_fn.sqlite) + dispatcher = dispatcher.dispatch_for("mysql")(callable_fn.mysql) + dispatcher("postgresql://s:t@localhost/test") + self.assertEqual( + [], callable_fn.mock_calls + ) + + def test_multiple_no_driver(self): + callable_fn = mock.Mock( + default=mock.Mock(return_value=None), + sqlite=mock.Mock(return_value=None) + ) + + dispatcher = utils.dispatch_for_dialect("*", multiple=True)( + callable_fn.default) + dispatcher = dispatcher.dispatch_for("sqlite")( + callable_fn.sqlite) + + dispatcher.dispatch_on_drivername("sqlite")("foo") + self.assertEqual( + [mock.call.sqlite("foo"), mock.call.default("foo")], + callable_fn.mock_calls + ) + + def test_multiple_nesting(self): + callable_fn = mock.Mock( + default=mock.Mock(return_value=None), + mysql=mock.Mock(return_value=None) + ) + + dispatcher = utils.dispatch_for_dialect("*", multiple=True)( + callable_fn.default) + + dispatcher = dispatcher.dispatch_for("mysql+mysqlconnector")( + dispatcher.dispatch_for("mysql+mysqldb")( + callable_fn.mysql + ) + ) + + mysqldb_url = sqlalchemy.engine.url.make_url("mysql+mysqldb://") + mysqlconnector_url = sqlalchemy.engine.url.make_url( + "mysql+mysqlconnector://") + sqlite_url = sqlalchemy.engine.url.make_url("sqlite://") + + dispatcher(mysqldb_url, 1) + dispatcher(mysqlconnector_url, 2) + dispatcher(sqlite_url, 3) + + self.assertEqual( + [ + mock.call.mysql(mysqldb_url, 1), + mock.call.default(mysqldb_url, 1), + mock.call.mysql(mysqlconnector_url, 2), + mock.call.default(mysqlconnector_url, 2), + mock.call.default(sqlite_url, 3) + ], + callable_fn.mock_calls + ) + + def test_single_retval(self): + dispatcher, callable_fn = self._single_fixture() + callable_fn.mysql_mysqldb.return_value = 5 + + self.assertEqual( + dispatcher("mysql://u:p@h/t", 3), 5 + ) + + def test_engine(self): + eng = sqlalchemy.create_engine("sqlite:///path/to/my/db.db") + dispatcher, callable_fn = self._single_fixture() + + dispatcher(eng) + self.assertEqual( + [mock.call.sqlite(eng)], + callable_fn.mock_calls + ) + + def test_url(self): + url = sqlalchemy.engine.url.make_url( + "mysql+mysqldb://scott:tiger@localhost/test") + dispatcher, callable_fn = self._single_fixture() + + dispatcher(url, 15) + self.assertEqual( + [mock.call.mysql_mysqldb(url, 15)], + callable_fn.mock_calls + ) + + def test_invalid_target(self): + dispatcher, callable_fn = self._single_fixture() + + exc = self.assertRaises( + ValueError, + dispatcher, 20 + ) + self.assertEqual("Invalid target type: 20", str(exc)) + + def test_invalid_dispatch(self): + callable_fn = mock.Mock() + + dispatcher = utils.dispatch_for_dialect("*")(callable_fn.default) + + exc = self.assertRaises( + ValueError, + dispatcher.dispatch_for("+pyodbc"), callable_fn.pyodbc + ) + self.assertEqual( + "Couldn't parse database[+driver]: '+pyodbc'", + str(exc) + ) + + def test_single_only_one_target(self): + callable_fn = mock.Mock() + + dispatcher = utils.dispatch_for_dialect("*")(callable_fn.default) + dispatcher = dispatcher.dispatch_for("sqlite")(callable_fn.sqlite) + + exc = self.assertRaises( + TypeError, + dispatcher.dispatch_for("sqlite"), callable_fn.sqlite2 + ) + self.assertEqual( + "Multiple functions for expression 'sqlite'", str(exc) + ) + + def test_multiple(self): + dispatcher, callable_fn = self._multiple_fixture() + + dispatcher("postgresql+pyodbc://", 1) + dispatcher("mysql://", 2) + dispatcher("ibm_db_sa+db2://", 3) + dispatcher("postgresql+psycopg2://", 4) + + # TODO(zzzeek): there is a deterministic order here, but we might + # want to tweak it, or maybe provide options. default first? + # most specific first? is *+pyodbc or postgresql+* more specific? + self.assertEqual( + [ + mock.call.postgresql('postgresql+pyodbc://', 1), + mock.call.pyodbc('postgresql+pyodbc://', 1), + mock.call.default('postgresql+pyodbc://', 1), + mock.call.mysql_mysqldb('mysql://', 2), + mock.call.default('mysql://', 2), + mock.call.default('ibm_db_sa+db2://', 3), + mock.call.postgresql_psycopg2('postgresql+psycopg2://', 4), + mock.call.postgresql('postgresql+psycopg2://', 4), + mock.call.default('postgresql+psycopg2://', 4), + ], + callable_fn.mock_calls + ) + + def test_multiple_no_return_value(self): + dispatcher, callable_fn = self._multiple_fixture() + callable_fn.sqlite.return_value = 5 + + exc = self.assertRaises( + TypeError, + dispatcher, "sqlite://" + ) + self.assertEqual( + "Return value not allowed for multiple filtered function", + str(exc) + ) + + +class TestGetInnoDBTables(db_test_base.MySQLOpportunisticTestCase): + + def test_all_tables_use_innodb(self): + self.engine.execute("CREATE TABLE customers " + "(a INT, b CHAR (20), INDEX (a)) ENGINE=InnoDB") + self.assertEqual([], utils.get_non_innodb_tables(self.engine)) + + def test_all_tables_use_innodb_false(self): + self.engine.execute("CREATE TABLE employee " + "(i INT) ENGINE=MEMORY") + self.assertEqual(['employee'], + utils.get_non_innodb_tables(self.engine)) + + def test_skip_tables_use_default_value(self): + self.engine.execute("CREATE TABLE migrate_version " + "(i INT) ENGINE=MEMORY") + self.assertEqual([], + utils.get_non_innodb_tables(self.engine)) + + def test_skip_tables_use_passed_value(self): + self.engine.execute("CREATE TABLE some_table " + "(i INT) ENGINE=MEMORY") + self.assertEqual([], + utils.get_non_innodb_tables( + self.engine, skip_tables=('some_table',))) + + def test_skip_tables_use_empty_list(self): + self.engine.execute("CREATE TABLE some_table_3 " + "(i INT) ENGINE=MEMORY") + self.assertEqual(['some_table_3'], + utils.get_non_innodb_tables( + self.engine, skip_tables=())) + + def test_skip_tables_use_several_values(self): + self.engine.execute("CREATE TABLE some_table_1 " + "(i INT) ENGINE=MEMORY") + self.engine.execute("CREATE TABLE some_table_2 " + "(i INT) ENGINE=MEMORY") + self.assertEqual([], + utils.get_non_innodb_tables( + self.engine, + skip_tables=('some_table_1', 'some_table_2'))) |