Diffstat (limited to 'oslo_db/tests/sqlalchemy')
-rw-r--r--  oslo_db/tests/sqlalchemy/__init__.py                 0
-rw-r--r--  oslo_db/tests/sqlalchemy/test_engine_connect.py     68
-rw-r--r--  oslo_db/tests/sqlalchemy/test_exc_filters.py       833
-rw-r--r--  oslo_db/tests/sqlalchemy/test_handle_error.py      194
-rw-r--r--  oslo_db/tests/sqlalchemy/test_migrate_cli.py       222
-rw-r--r--  oslo_db/tests/sqlalchemy/test_migration_common.py  212
-rw-r--r--  oslo_db/tests/sqlalchemy/test_migrations.py        309
-rw-r--r--  oslo_db/tests/sqlalchemy/test_models.py            146
-rw-r--r--  oslo_db/tests/sqlalchemy/test_options.py           127
-rw-r--r--  oslo_db/tests/sqlalchemy/test_sqlalchemy.py        655
-rw-r--r--  oslo_db/tests/sqlalchemy/test_utils.py            1090
11 files changed, 3856 insertions(+), 0 deletions(-)
diff --git a/oslo_db/tests/sqlalchemy/__init__.py b/oslo_db/tests/sqlalchemy/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/oslo_db/tests/sqlalchemy/__init__.py
diff --git a/oslo_db/tests/sqlalchemy/test_engine_connect.py b/oslo_db/tests/sqlalchemy/test_engine_connect.py
new file mode 100644
index 0000000..c75511f
--- /dev/null
+++ b/oslo_db/tests/sqlalchemy/test_engine_connect.py
@@ -0,0 +1,68 @@
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+"""Test the compatibility layer for the engine_connect() event.
+
+This event is added as of SQLAlchemy 0.9.0; oslo_db provides a compatibility
+layer for prior SQLAlchemy versions.
+
+"""
+
+import mock
+from oslotest import base as test_base
+import sqlalchemy as sqla
+
+from oslo_db.sqlalchemy.compat import engine_connect
+
+
+class EngineConnectTest(test_base.BaseTestCase):
+
+ def setUp(self):
+ super(EngineConnectTest, self).setUp()
+
+ self.engine = engine = sqla.create_engine("sqlite://")
+ self.addCleanup(engine.dispose)
+
+ def test_connect_event(self):
+ engine = self.engine
+
+ listener = mock.Mock()
+ engine_connect(engine, listener)
+
+ conn = engine.connect()
+ self.assertEqual(
+ listener.mock_calls,
+ [mock.call(conn, False)]
+ )
+
+ conn.close()
+
+ conn2 = engine.connect()
+ conn2.close()
+ self.assertEqual(
+ listener.mock_calls,
+ [mock.call(conn, False), mock.call(conn2, False)]
+ )
+
+ def test_branch(self):
+ engine = self.engine
+
+ listener = mock.Mock()
+ engine_connect(engine, listener)
+
+ conn = engine.connect()
+ branched = conn.connect()
+ conn.close()
+ self.assertEqual(
+ listener.mock_calls,
+ [mock.call(conn, False), mock.call(branched, True)]
+ )
diff --git a/oslo_db/tests/sqlalchemy/test_exc_filters.py b/oslo_db/tests/sqlalchemy/test_exc_filters.py
new file mode 100644
index 0000000..157a183
--- /dev/null
+++ b/oslo_db/tests/sqlalchemy/test_exc_filters.py
@@ -0,0 +1,833 @@
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+"""Test exception filters applied to engines."""
+
+import contextlib
+import itertools
+
+import mock
+from oslotest import base as oslo_test_base
+import six
+import sqlalchemy as sqla
+from sqlalchemy.orm import mapper
+
+from oslo_db import exception
+from oslo_db.sqlalchemy import compat
+from oslo_db.sqlalchemy import exc_filters
+from oslo_db.sqlalchemy import session
+from oslo_db.sqlalchemy import test_base
+from oslo_db.tests import utils as test_utils
+
+_TABLE_NAME = '__tmp__test__tmp__'
+
+
+class _SQLAExceptionMatcher(object):
+ def assertInnerException(
+ self,
+ matched, exception_type, message, sql=None, params=None):
+
+ exc = matched.inner_exception
+ self.assertSQLAException(exc, exception_type, message, sql, params)
+
+ def assertSQLAException(
+ self,
+ exc, exception_type, message, sql=None, params=None):
+ if isinstance(exception_type, (type, tuple)):
+ self.assertTrue(issubclass(exc.__class__, exception_type))
+ else:
+ self.assertEqual(exc.__class__.__name__, exception_type)
+ self.assertEqual(str(exc.orig).lower(), message.lower())
+ if sql is not None:
+ self.assertEqual(exc.statement, sql)
+ if params is not None:
+ self.assertEqual(exc.params, params)
+
+
+class TestsExceptionFilter(_SQLAExceptionMatcher, oslo_test_base.BaseTestCase):
+
+ class Error(Exception):
+ """DBAPI base error.
+
+ This exception and subclasses are used in a mock context
+ within these tests.
+
+ """
+
+ class OperationalError(Error):
+ pass
+
+ class InterfaceError(Error):
+ pass
+
+ class InternalError(Error):
+ pass
+
+ class IntegrityError(Error):
+ pass
+
+ class ProgrammingError(Error):
+ pass
+
+ class TransactionRollbackError(OperationalError):
+ """Special psycopg2-only error class.
+
+ SQLAlchemy has an issue with this per issue #3075:
+
+ https://bitbucket.org/zzzeek/sqlalchemy/issue/3075/
+
+ """
+
+ def setUp(self):
+ super(TestsExceptionFilter, self).setUp()
+ self.engine = sqla.create_engine("sqlite://")
+ exc_filters.register_engine(self.engine)
+ self.engine.connect().close() # initialize
+
+ @contextlib.contextmanager
+ def _dbapi_fixture(self, dialect_name):
+ engine = self.engine
+ with test_utils.nested(
+ mock.patch.object(engine.dialect.dbapi,
+ "Error",
+ self.Error),
+ mock.patch.object(engine.dialect, "name", dialect_name),
+ ):
+ yield
+
+ @contextlib.contextmanager
+ def _fixture(self, dialect_name, exception, is_disconnect=False):
+
+ def do_execute(self, cursor, statement, parameters, **kw):
+ raise exception
+
+ engine = self.engine
+
+ # ensure the engine has done its initial checks against the
+ # DB as we are going to be removing its ability to execute a
+ # statement
+ self.engine.connect().close()
+
+ with test_utils.nested(
+ mock.patch.object(engine.dialect, "do_execute", do_execute),
+ # replace the whole DBAPI rather than patching "Error"
+ # as some DBAPIs might not be patchable (?)
+ mock.patch.object(engine.dialect,
+ "dbapi",
+ mock.Mock(Error=self.Error)),
+ mock.patch.object(engine.dialect, "name", dialect_name),
+ mock.patch.object(engine.dialect,
+ "is_disconnect",
+ lambda *args: is_disconnect)
+ ):
+ yield
+
+ def _run_test(self, dialect_name, statement, raises, expected,
+ is_disconnect=False, params=()):
+ with self._fixture(dialect_name, raises, is_disconnect=is_disconnect):
+ with self.engine.connect() as conn:
+ matched = self.assertRaises(
+ expected, conn.execute, statement, params
+ )
+ return matched
+
+
+class TestFallthroughsAndNonDBAPI(TestsExceptionFilter):
+
+ def test_generic_dbapi(self):
+ matched = self._run_test(
+ "mysql", "select you_made_a_programming_error",
+ self.ProgrammingError("Error 123, you made a mistake"),
+ exception.DBError
+ )
+ self.assertInnerException(
+ matched,
+ "ProgrammingError",
+ "Error 123, you made a mistake",
+ 'select you_made_a_programming_error', ())
+
+ def test_generic_dbapi_disconnect(self):
+ matched = self._run_test(
+ "mysql", "select the_db_disconnected",
+ self.InterfaceError("connection lost"),
+ exception.DBConnectionError,
+ is_disconnect=True
+ )
+ self.assertInnerException(
+ matched,
+ "InterfaceError", "connection lost",
+ "select the_db_disconnected", ()),
+
+ def test_operational_dbapi_disconnect(self):
+ matched = self._run_test(
+ "mysql", "select the_db_disconnected",
+ self.OperationalError("connection lost"),
+ exception.DBConnectionError,
+ is_disconnect=True
+ )
+ self.assertInnerException(
+ matched,
+ "OperationalError", "connection lost",
+ "select the_db_disconnected", ()),
+
+ def test_operational_error_asis(self):
+ """Test operational errors.
+
+ Test that SQLAlchemy OperationalErrors that aren't disconnects
+ are passed through without wrapping.
+ """
+
+ matched = self._run_test(
+ "mysql", "select some_operational_error",
+ self.OperationalError("some op error"),
+ sqla.exc.OperationalError
+ )
+ self.assertSQLAException(
+ matched,
+ "OperationalError", "some op error"
+ )
+
+ def test_unicode_encode(self):
+ # intentionally generate a UnicodeEncodeError, as its
+ # constructor is quite complicated and seems to be non-public
+ # or at least not documented anywhere.
+ uee_ref = None
+ try:
+ six.u('\u2435').encode('ascii')
+ except UnicodeEncodeError as uee:
+ # Python3.x added new scoping rules here (sadly)
+ # http://legacy.python.org/dev/peps/pep-3110/#semantic-changes
+ uee_ref = uee
+
+ self._run_test(
+ "postgresql", six.u('select \u2435'),
+ uee_ref,
+ exception.DBInvalidUnicodeParameter
+ )
+
+ def test_garden_variety(self):
+ matched = self._run_test(
+ "mysql", "select some_thing_that_breaks",
+ AttributeError("mysqldb has an attribute error"),
+ exception.DBError
+ )
+ self.assertEqual("mysqldb has an attribute error", matched.args[0])
+
+
+class TestReferenceErrorSQLite(_SQLAExceptionMatcher, test_base.DbTestCase):
+
+ def setUp(self):
+ super(TestReferenceErrorSQLite, self).setUp()
+
+ meta = sqla.MetaData(bind=self.engine)
+
+ table_1 = sqla.Table(
+ "resource_foo", meta,
+ sqla.Column("id", sqla.Integer, primary_key=True),
+ sqla.Column("foo", sqla.Integer),
+ mysql_engine='InnoDB',
+ mysql_charset='utf8',
+ )
+ table_1.create()
+
+ self.table_2 = sqla.Table(
+ "resource_entity", meta,
+ sqla.Column("id", sqla.Integer, primary_key=True),
+ sqla.Column("foo_id", sqla.Integer,
+ sqla.ForeignKey("resource_foo.id", name="foo_fkey")),
+ mysql_engine='InnoDB',
+ mysql_charset='utf8',
+ )
+ self.table_2.create()
+
+ def test_raise(self):
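+ # SQLite enforces foreign key constraints only when this pragma is on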
+ self.engine.execute("PRAGMA foreign_keys = ON;")
+
+ matched = self.assertRaises(
+ exception.DBReferenceError,
+ self.engine.execute,
+ self.table_2.insert({'id': 1, 'foo_id': 2})
+ )
+
+ self.assertInnerException(
+ matched,
+ "IntegrityError",
+ "FOREIGN KEY constraint failed",
+ 'INSERT INTO resource_entity (id, foo_id) VALUES (?, ?)',
+ (1, 2)
+ )
+
+ self.assertIsNone(matched.table)
+ self.assertIsNone(matched.constraint)
+ self.assertIsNone(matched.key)
+ self.assertIsNone(matched.key_table)
+
+
+class TestReferenceErrorPostgreSQL(TestReferenceErrorSQLite,
+ test_base.PostgreSQLOpportunisticTestCase):
+ def test_raise(self):
+ params = {'id': 1, 'foo_id': 2}
+ matched = self.assertRaises(
+ exception.DBReferenceError,
+ self.engine.execute,
+ self.table_2.insert(params)
+ )
+ self.assertInnerException(
+ matched,
+ "IntegrityError",
+ "insert or update on table \"resource_entity\" "
+ "violates foreign key constraint \"foo_fkey\"\nDETAIL: Key "
+ "(foo_id)=(2) is not present in table \"resource_foo\".\n",
+ "INSERT INTO resource_entity (id, foo_id) VALUES (%(id)s, "
+ "%(foo_id)s)",
+ params,
+ )
+
+ self.assertEqual("resource_entity", matched.table)
+ self.assertEqual("foo_fkey", matched.constraint)
+ self.assertEqual("foo_id", matched.key)
+ self.assertEqual("resource_foo", matched.key_table)
+
+
+class TestReferenceErrorMySQL(TestReferenceErrorSQLite,
+ test_base.MySQLOpportunisticTestCase):
+ def test_raise(self):
+ matched = self.assertRaises(
+ exception.DBReferenceError,
+ self.engine.execute,
+ self.table_2.insert({'id': 1, 'foo_id': 2})
+ )
+
+ self.assertInnerException(
+ matched,
+ "IntegrityError",
+ "(1452, 'Cannot add or update a child row: a "
+ "foreign key constraint fails (`{0}`.`resource_entity`, "
+ "CONSTRAINT `foo_fkey` FOREIGN KEY (`foo_id`) REFERENCES "
+ "`resource_foo` (`id`))')".format(self.engine.url.database),
+ "INSERT INTO resource_entity (id, foo_id) VALUES (%s, %s)",
+ (1, 2)
+ )
+ self.assertEqual("resource_entity", matched.table)
+ self.assertEqual("foo_fkey", matched.constraint)
+ self.assertEqual("foo_id", matched.key)
+ self.assertEqual("resource_foo", matched.key_table)
+
+ def test_raise_ansi_quotes(self):
+ self.engine.execute("SET SESSION sql_mode = 'ANSI';")
+ matched = self.assertRaises(
+ exception.DBReferenceError,
+ self.engine.execute,
+ self.table_2.insert({'id': 1, 'foo_id': 2})
+ )
+
+ self.assertInnerException(
+ matched,
+ "IntegrityError",
+ '(1452, \'Cannot add or update a child row: a '
+ 'foreign key constraint fails ("{0}"."resource_entity", '
+ 'CONSTRAINT "foo_fkey" FOREIGN KEY ("foo_id") REFERENCES '
+ '"resource_foo" ("id"))\')'.format(self.engine.url.database),
+ "INSERT INTO resource_entity (id, foo_id) VALUES (%s, %s)",
+ (1, 2)
+ )
+ self.assertEqual("resource_entity", matched.table)
+ self.assertEqual("foo_fkey", matched.constraint)
+ self.assertEqual("foo_id", matched.key)
+ self.assertEqual("resource_foo", matched.key_table)
+
+
+class TestDuplicate(TestsExceptionFilter):
+
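+ # The duplicate-entry filter parses each backend's IntegrityError message
+ # and exposes the offending columns and value on the resulting
+ # DBDuplicateEntry, which is what expected_columns/expected_value verify.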
+ def _run_dupe_constraint_test(self, dialect_name, message,
+ expected_columns=['a', 'b'],
+ expected_value=None):
+ matched = self._run_test(
+ dialect_name, "insert into table some_values",
+ self.IntegrityError(message),
+ exception.DBDuplicateEntry
+ )
+ self.assertEqual(expected_columns, matched.columns)
+ self.assertEqual(expected_value, matched.value)
+
+ def _not_dupe_constraint_test(self, dialect_name, statement, message,
+ expected_cls):
+ matched = self._run_test(
+ dialect_name, statement,
+ self.IntegrityError(message),
+ expected_cls
+ )
+ self.assertInnerException(
+ matched,
+ "IntegrityError",
+ str(self.IntegrityError(message)),
+ statement
+ )
+
+ def test_sqlite(self):
+ self._run_dupe_constraint_test("sqlite", 'column a, b are not unique')
+
+ def test_sqlite_3_7_16_or_3_8_2_and_higher(self):
+ self._run_dupe_constraint_test(
+ "sqlite",
+ 'UNIQUE constraint failed: tbl.a, tbl.b')
+
+ def test_sqlite_dupe_primary_key(self):
+ self._run_dupe_constraint_test(
+ "sqlite",
+ "PRIMARY KEY must be unique 'insert into t values(10)'",
+ expected_columns=[])
+
+ def test_mysql_mysqldb(self):
+ self._run_dupe_constraint_test(
+ "mysql",
+ '(1062, "Duplicate entry '
+ '\'2-3\' for key \'uniq_tbl0a0b\'")', expected_value='2-3')
+
+ def test_mysql_mysqlconnector(self):
+ self._run_dupe_constraint_test(
+ "mysql",
+ '1062 (23000): Duplicate entry '
+ '\'2-3\' for key \'uniq_tbl0a0b\'', expected_value='2-3')
+
+ def test_postgresql(self):
+ self._run_dupe_constraint_test(
+ 'postgresql',
+ 'duplicate key value violates unique constraint '
+ '"uniq_tbl0a0b"'
+ '\nDETAIL: Key (a, b)=(2, 3) already exists.\n',
+ expected_value='2, 3'
+ )
+
+ def test_mysql_single(self):
+ self._run_dupe_constraint_test(
+ "mysql",
+ "1062 (23000): Duplicate entry '2' for key 'b'",
+ expected_columns=['b'],
+ expected_value='2'
+ )
+
+ def test_postgresql_single(self):
+ self._run_dupe_constraint_test(
+ 'postgresql',
+ 'duplicate key value violates unique constraint "uniq_tbl0b"\n'
+ 'DETAIL: Key (b)=(2) already exists.\n',
+ expected_columns=['b'],
+ expected_value='2'
+ )
+
+ def test_unsupported_backend(self):
+ self._not_dupe_constraint_test(
+ "nonexistent", "insert into table some_values",
+ self.IntegrityError("constraint violation"),
+ exception.DBError
+ )
+
+ def test_ibm_db_sa(self):
+ self._run_dupe_constraint_test(
+ 'ibm_db_sa',
+ 'SQL0803N One or more values in the INSERT statement, UPDATE '
+ 'statement, or foreign key update caused by a DELETE statement are'
+ ' not valid because the primary key, unique constraint or unique '
+ 'index identified by "2" constrains table "NOVA.KEY_PAIRS" from '
+ 'having duplicate values for the index key.',
+ expected_columns=[]
+ )
+
+ def test_ibm_db_sa_notadupe(self):
+ self._not_dupe_constraint_test(
+ 'ibm_db_sa',
+ 'ALTER TABLE instance_types ADD CONSTRAINT '
+ 'uniq_name_x_deleted UNIQUE (name, deleted)',
+ 'SQL0542N The column named "NAME" cannot be a column of a '
+ 'primary key or unique key constraint because it can contain null '
+ 'values.',
+ exception.DBError
+ )
+
+
+class TestDeadlock(TestsExceptionFilter):
+ statement = ('SELECT quota_usages.created_at AS '
+ 'quota_usages_created_at FROM quota_usages '
+ 'WHERE quota_usages.project_id = %(project_id_1)s '
+ 'AND quota_usages.deleted = %(deleted_1)s FOR UPDATE')
+ params = {
+ 'project_id_1': '8891d4478bbf48ad992f050cdf55e9b5',
+ 'deleted_1': 0
+ }
+
+ def _run_deadlock_detect_test(
+ self, dialect_name, message,
+ orig_exception_cls=TestsExceptionFilter.OperationalError):
+ self._run_test(
+ dialect_name, self.statement,
+ orig_exception_cls(message),
+ exception.DBDeadlock,
+ params=self.params
+ )
+
+ def _not_deadlock_test(
+ self, dialect_name, message,
+ expected_cls, expected_dbapi_cls,
+ orig_exception_cls=TestsExceptionFilter.OperationalError):
+
+ matched = self._run_test(
+ dialect_name, self.statement,
+ orig_exception_cls(message),
+ expected_cls,
+ params=self.params
+ )
+
+ if isinstance(matched, exception.DBError):
+ matched = matched.inner_exception
+
+ self.assertEqual(matched.orig.__class__.__name__, expected_dbapi_cls)
+
+ def test_mysql_mysqldb_deadlock(self):
+ self._run_deadlock_detect_test(
+ "mysql",
+ "(1213, 'Deadlock found when trying "
+ "to get lock; try restarting "
+ "transaction')"
+ )
+
+ def test_mysql_mysqldb_galera_deadlock(self):
+ self._run_deadlock_detect_test(
+ "mysql",
+ "(1205, 'Lock wait timeout exceeded; "
+ "try restarting transaction')"
+ )
+
+ def test_mysql_mysqlconnector_deadlock(self):
+ self._run_deadlock_detect_test(
+ "mysql",
+ "1213 (40001): Deadlock found when trying to get lock; try "
+ "restarting transaction",
+ orig_exception_cls=self.InternalError
+ )
+
+ def test_mysql_not_deadlock(self):
+ self._not_deadlock_test(
+ "mysql",
+ "(1005, 'some other error')",
+ sqla.exc.OperationalError, # note OperationalErrors are sent through
+ "OperationalError",
+ )
+
+ def test_postgresql_deadlock(self):
+ self._run_deadlock_detect_test(
+ "postgresql",
+ "deadlock detected",
+ orig_exception_cls=self.TransactionRollbackError
+ )
+
+ def test_postgresql_not_deadlock(self):
+ self._not_deadlock_test(
+ "postgresql",
+ 'relation "fake" does not exist',
+ # can be either depending on #3075
+ (exception.DBError, sqla.exc.OperationalError),
+ "TransactionRollbackError",
+ orig_exception_cls=self.TransactionRollbackError
+ )
+
+ def test_ibm_db_sa_deadlock(self):
+ self._run_deadlock_detect_test(
+ "ibm_db_sa",
+ "SQL0911N The current transaction has been "
+ "rolled back because of a deadlock or timeout",
+ # use the base Error class because we don't know what actual error
+ # class DB2's driver would raise for this
+ orig_exception_cls=self.Error
+ )
+
+ def test_ibm_db_sa_not_deadlock(self):
+ self._not_deadlock_test(
+ "ibm_db_sa",
+ "SQL01234B Some other error.",
+ exception.DBError,
+ "Error",
+ orig_exception_cls=self.Error
+ )
+
+
+class IntegrationTest(test_base.DbTestCase):
+ """Test an actual error-raising round trips against the database."""
+
+ def setUp(self):
+ super(IntegrationTest, self).setUp()
+ meta = sqla.MetaData()
+ self.test_table = sqla.Table(
+ _TABLE_NAME, meta,
+ sqla.Column('id', sqla.Integer,
+ primary_key=True, nullable=False),
+ sqla.Column('counter', sqla.Integer,
+ nullable=False),
+ sqla.UniqueConstraint('counter',
+ name='uniq_counter'))
+ self.test_table.create(self.engine)
+ self.addCleanup(self.test_table.drop, self.engine)
+
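+ # classical SQLAlchemy mapping of a plain class onto the test table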
+ class Foo(object):
+ def __init__(self, counter):
+ self.counter = counter
+ mapper(Foo, self.test_table)
+ self.Foo = Foo
+
+ def test_flush_wrapper_duplicate_entry(self):
+ """test a duplicate entry exception."""
+
+ _session = self.sessionmaker()
+
+ with _session.begin():
+ foo = self.Foo(counter=1)
+ _session.add(foo)
+
+ _session.begin()
+ self.addCleanup(_session.rollback)
+ foo = self.Foo(counter=1)
+ _session.add(foo)
+ self.assertRaises(exception.DBDuplicateEntry, _session.flush)
+
+ def test_autoflush_wrapper_duplicate_entry(self):
+ """Test a duplicate entry exception raised.
+
+ test a duplicate entry exception raised via query.all()-> autoflush
+ """
+
+ _session = self.sessionmaker()
+
+ with _session.begin():
+ foo = self.Foo(counter=1)
+ _session.add(foo)
+
+ _session.begin()
+ self.addCleanup(_session.rollback)
+ foo = self.Foo(counter=1)
+ _session.add(foo)
+ self.assertTrue(_session.autoflush)
+ self.assertRaises(exception.DBDuplicateEntry,
+ _session.query(self.Foo).all)
+
+ def test_flush_wrapper_plain_integrity_error(self):
+ """test a plain integrity error wrapped as DBError."""
+
+ _session = self.sessionmaker()
+
+ with _session.begin():
+ foo = self.Foo(counter=1)
+ _session.add(foo)
+
+ _session.begin()
+ self.addCleanup(_session.rollback)
+ foo = self.Foo(counter=None)
+ _session.add(foo)
+ self.assertRaises(exception.DBError, _session.flush)
+
+ def test_flush_wrapper_operational_error(self):
+ """test an operational error from flush() raised as-is."""
+
+ _session = self.sessionmaker()
+
+ with _session.begin():
+ foo = self.Foo(counter=1)
+ _session.add(foo)
+
+ _session.begin()
+ self.addCleanup(_session.rollback)
+ foo = self.Foo(counter=sqla.func.imfake(123))
+ _session.add(foo)
+ matched = self.assertRaises(sqla.exc.OperationalError, _session.flush)
+ self.assertTrue("no such function" in str(matched))
+
+ def test_query_wrapper_operational_error(self):
+ """test an operational error from query.all() raised as-is."""
+
+ _session = self.sessionmaker()
+
+ _session.begin()
+ self.addCleanup(_session.rollback)
+ q = _session.query(self.Foo).filter(
+ self.Foo.counter == sqla.func.imfake(123))
+ matched = self.assertRaises(sqla.exc.OperationalError, q.all)
+ self.assertTrue("no such function" in str(matched))
+
+
+class TestDBDisconnected(TestsExceptionFilter):
+
+ @contextlib.contextmanager
+ def _fixture(
+ self,
+ dialect_name, exception, num_disconnects, is_disconnect=True):
+ engine = self.engine
+
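+ # wire up oslo.db's ping listener, which tests each connection with a
+ # lightweight SELECT as it is handed out and converts disconnect
+ # errors into DBConnectionError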
+ compat.engine_connect(engine, session._connect_ping_listener)
+
+ real_do_execute = engine.dialect.do_execute
+ counter = itertools.count(1)
+
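+ # fail the first num_disconnects executions, then fall through to the
+ # real do_execute() so the connection can recover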
+ def fake_do_execute(self, *arg, **kw):
+ if next(counter) > num_disconnects:
+ return real_do_execute(self, *arg, **kw)
+ else:
+ raise exception
+
+ with self._dbapi_fixture(dialect_name):
+ with test_utils.nested(
+ mock.patch.object(engine.dialect,
+ "do_execute",
+ fake_do_execute),
+ mock.patch.object(engine.dialect,
+ "is_disconnect",
+ mock.Mock(return_value=is_disconnect))
+ ):
+ yield
+
+ def _test_ping_listener_disconnected(
+ self, dialect_name, exc_obj, is_disconnect=True):
+ with self._fixture(dialect_name, exc_obj, 1, is_disconnect):
+ conn = self.engine.connect()
+ with conn.begin():
+ self.assertEqual(conn.scalar(sqla.select([1])), 1)
+ self.assertFalse(conn.closed)
+ self.assertFalse(conn.invalidated)
+ self.assertTrue(conn.in_transaction())
+
+ with self._fixture(dialect_name, exc_obj, 2, is_disconnect):
+ self.assertRaises(
+ exception.DBConnectionError,
+ self.engine.connect
+ )
+
+ # test implicit execution
+ with self._fixture(dialect_name, exc_obj, 1):
+ self.assertEqual(self.engine.scalar(sqla.select([1])), 1)
+
+ def test_mysql_ping_listener_disconnected(self):
+ for code in [2006, 2013, 2014, 2045, 2055]:
+ self._test_ping_listener_disconnected(
+ "mysql",
+ self.OperationalError('%d MySQL server has gone away' % code)
+ )
+
+ def test_mysql_ping_listener_disconnected_regex_only(self):
+ # intentionally set the is_disconnect flag to False
+ # in the "sqlalchemy" layer to make sure the regexp
+ # on _is_db_connection_error is catching
+ for code in [2002, 2003, 2006, 2013]:
+ self._test_ping_listener_disconnected(
+ "mysql",
+ self.OperationalError('%d MySQL server has gone away' % code),
+ is_disconnect=False
+ )
+
+ def test_db2_ping_listener_disconnected(self):
+ self._test_ping_listener_disconnected(
+ "ibm_db_sa",
+ self.OperationalError(
+ 'SQL30081N: DB2 Server connection is no longer active')
+ )
+
+ def test_db2_ping_listener_disconnected_regex_only(self):
+ self._test_ping_listener_disconnected(
+ "ibm_db_sa",
+ self.OperationalError(
+ 'SQL30081N: DB2 Server connection is no longer active'),
+ is_disconnect=False
+ )
+
+
+class TestDBConnectRetry(TestsExceptionFilter):
+
+ def _run_test(self, dialect_name, exception, count, retries):
+ counter = itertools.count()
+
+ engine = self.engine
+
+ # empty out the connection pool
+ engine.dispose()
+
+ connect_fn = engine.dialect.connect
+
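+ # raise the given DBAPI error for the first ``count`` connection
+ # attempts, then delegate to the real connect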
+ def cant_connect(*arg, **kw):
+ if next(counter) < count:
+ raise exception
+ else:
+ return connect_fn(*arg, **kw)
+
+ with self._dbapi_fixture(dialect_name):
+ with mock.patch.object(engine.dialect, "connect", cant_connect):
+ return session._test_connection(engine, retries, .01)
+
+ def test_connect_no_retries(self):
+ conn = self._run_test(
+ "mysql",
+ self.OperationalError("Error: (2003) something wrong"),
+ 2, 0
+ )
+ # didn't connect because nothing was tried
+ self.assertIsNone(conn)
+
+ def test_connect_infinite_retries(self):
+ conn = self._run_test(
+ "mysql",
+ self.OperationalError("Error: (2003) something wrong"),
+ 2, -1
+ )
+ # conn is good
+ self.assertEqual(conn.scalar(sqla.select([1])), 1)
+
+ def test_connect_retry_past_failure(self):
+ conn = self._run_test(
+ "mysql",
+ self.OperationalError("Error: (2003) something wrong"),
+ 2, 3
+ )
+ # conn is good
+ self.assertEqual(conn.scalar(sqla.select([1])), 1)
+
+ def test_connect_retry_not_candidate_exception(self):
+ self.assertRaises(
+ sqla.exc.OperationalError, # remember, we pass OperationalErrors
+ # through at the moment :)
+ self._run_test,
+ "mysql",
+ self.OperationalError("Error: (2015) I can't connect period"),
+ 2, 3
+ )
+
+ def test_connect_retry_stops_in_failure(self):
+ self.assertRaises(
+ exception.DBConnectionError,
+ self._run_test,
+ "mysql",
+ self.OperationalError("Error: (2003) something wrong"),
+ 3, 2
+ )
+
+ def test_db2_error_positive(self):
+ conn = self._run_test(
+ "ibm_db_sa",
+ self.OperationalError("blah blah -30081 blah blah"),
+ 2, -1
+ )
+ # conn is good
+ self.assertEqual(conn.scalar(sqla.select([1])), 1)
+
+ def test_db2_error_negative(self):
+ self.assertRaises(
+ sqla.exc.OperationalError,
+ self._run_test,
+ "ibm_db_sa",
+ self.OperationalError("blah blah -39981 blah blah"),
+ 2, 3
+ )
diff --git a/oslo_db/tests/sqlalchemy/test_handle_error.py b/oslo_db/tests/sqlalchemy/test_handle_error.py
new file mode 100644
index 0000000..83322ef
--- /dev/null
+++ b/oslo_db/tests/sqlalchemy/test_handle_error.py
@@ -0,0 +1,194 @@
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+"""Test the compatibility layer for the handle_error() event.
+
+This event is added as of SQLAlchemy 0.9.7; oslo_db provides a compatibility
+layer for prior SQLAlchemy versions.
+
+"""
+
+import mock
+from oslotest import base as test_base
+import sqlalchemy as sqla
+from sqlalchemy.sql import column
+from sqlalchemy.sql import literal
+from sqlalchemy.sql import select
+from sqlalchemy.types import Integer
+from sqlalchemy.types import TypeDecorator
+
+from oslo_db.sqlalchemy.compat import handle_error
+from oslo_db.sqlalchemy.compat import utils
+from oslo_db.tests import utils as test_utils
+
+
+class MyException(Exception):
+ pass
+
+
+class ExceptionReraiseTest(test_base.BaseTestCase):
+
+ def setUp(self):
+ super(ExceptionReraiseTest, self).setUp()
+
+ self.engine = engine = sqla.create_engine("sqlite://")
+ self.addCleanup(engine.dispose)
+
+ def _fixture(self):
+ engine = self.engine
+
+ def err(context):
+ if "ERROR ONE" in str(context.statement):
+ raise MyException("my exception")
+ handle_error(engine, err)
+
+ def test_exception_event_altered(self):
+ self._fixture()
+
+ with mock.patch.object(self.engine.dialect.execution_ctx_cls,
+ "handle_dbapi_exception") as patched:
+
+ matchee = self.assertRaises(
+ MyException,
+ self.engine.execute, "SELECT 'ERROR ONE' FROM I_DONT_EXIST"
+ )
+ self.assertEqual(1, patched.call_count)
+ self.assertEqual("my exception", matchee.args[0])
+
+ def test_exception_event_non_altered(self):
+ self._fixture()
+
+ with mock.patch.object(self.engine.dialect.execution_ctx_cls,
+ "handle_dbapi_exception") as patched:
+
+ self.assertRaises(
+ sqla.exc.DBAPIError,
+ self.engine.execute, "SELECT 'ERROR TWO' FROM I_DONT_EXIST"
+ )
+ self.assertEqual(1, patched.call_count)
+
+ def test_is_disconnect_not_interrupted(self):
+ self._fixture()
+
+ with test_utils.nested(
+ mock.patch.object(
+ self.engine.dialect.execution_ctx_cls,
+ "handle_dbapi_exception"
+ ),
+ mock.patch.object(
+ self.engine.dialect, "is_disconnect",
+ lambda *args: True
+ )
+ ) as (handle_dbapi_exception, is_disconnect):
+ with self.engine.connect() as conn:
+ self.assertRaises(
+ MyException,
+ conn.execute, "SELECT 'ERROR ONE' FROM I_DONT_EXIST"
+ )
+ self.assertEqual(1, handle_dbapi_exception.call_count)
+ self.assertTrue(conn.invalidated)
+
+ def test_no_is_disconnect_not_invalidated(self):
+ self._fixture()
+
+ with test_utils.nested(
+ mock.patch.object(
+ self.engine.dialect.execution_ctx_cls,
+ "handle_dbapi_exception"
+ ),
+ mock.patch.object(
+ self.engine.dialect, "is_disconnect",
+ lambda *args: False
+ )
+ ) as (handle_dbapi_exception, is_disconnect):
+ with self.engine.connect() as conn:
+ self.assertRaises(
+ MyException,
+ conn.execute, "SELECT 'ERROR ONE' FROM I_DONT_EXIST"
+ )
+ self.assertEqual(1, handle_dbapi_exception.call_count)
+ self.assertFalse(conn.invalidated)
+
+ def test_exception_event_ad_hoc_context(self):
+ engine = self.engine
+
+ nope = MyException("nope")
+
+ class MyType(TypeDecorator):
+ impl = Integer
+
+ def process_bind_param(self, value, dialect):
+ raise nope
+
+ listener = mock.Mock(return_value=None)
+ handle_error(engine, listener)
+
+ self.assertRaises(
+ sqla.exc.StatementError,
+ engine.execute,
+ select([1]).where(column('foo') == literal('bar', MyType))
+ )
+
+ ctx = listener.mock_calls[0][1][0]
+ self.assertTrue(ctx.statement.startswith("SELECT 1 "))
+ self.assertIs(ctx.is_disconnect, False)
+ self.assertIs(ctx.original_exception, nope)
+
+ def _test_alter_disconnect(self, orig_error, evt_value):
+ engine = self.engine
+
+ def evt(ctx):
+ ctx.is_disconnect = evt_value
+ handle_error(engine, evt)
+
+ # if we are on SQLAlchemy older than 0.9.7 and we are expecting to
+ # take an "is disconnect" exception and make it not a disconnect,
+ # that isn't supported because the wrapped handler has already
+ # done the invalidation.
+ expect_failure = not utils.sqla_097 and orig_error and not evt_value
+
+ with mock.patch.object(engine.dialect,
+ "is_disconnect",
+ mock.Mock(return_value=orig_error)):
+
+ with engine.connect() as c:
+ conn_rec = c.connection._connection_record
+ try:
+ c.execute("SELECT x FROM nonexistent")
+ assert False
+ except sqla.exc.StatementError as st:
+ self.assertFalse(expect_failure)
+
+ # check the exception's invalidation flag
+ self.assertEqual(st.connection_invalidated, evt_value)
+
+ # check the Connection object's invalidation flag
+ self.assertEqual(c.invalidated, evt_value)
+
+ # this is the ConnectionRecord object; it's invalidated
+ # when its .connection member is None
+ self.assertEqual(conn_rec.connection is None, evt_value)
+
+ except NotImplementedError as ne:
+ self.assertTrue(expect_failure)
+ self.assertEqual(
+ str(ne),
+ "Can't reset 'disconnect' status of exception once it "
+ "is set with this version of SQLAlchemy")
+
+ def test_alter_disconnect_to_true(self):
+ self._test_alter_disconnect(False, True)
+ self._test_alter_disconnect(True, True)
+
+ def test_alter_disconnect_to_false(self):
+ self._test_alter_disconnect(True, False)
+ self._test_alter_disconnect(False, False)
diff --git a/oslo_db/tests/sqlalchemy/test_migrate_cli.py b/oslo_db/tests/sqlalchemy/test_migrate_cli.py
new file mode 100644
index 0000000..c1ab53c
--- /dev/null
+++ b/oslo_db/tests/sqlalchemy/test_migrate_cli.py
@@ -0,0 +1,222 @@
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+import mock
+from oslotest import base as test_base
+
+from oslo_db.sqlalchemy.migration_cli import ext_alembic
+from oslo_db.sqlalchemy.migration_cli import ext_migrate
+from oslo_db.sqlalchemy.migration_cli import manager
+
+
+class MockWithCmp(mock.MagicMock):
+
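+ # the migration manager runs its extensions in sorted order; defining
+ # __lt__ lets these mocks be compared via their ``order`` attribute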
+ order = 0
+
+ def __init__(self, *args, **kwargs):
+ super(MockWithCmp, self).__init__(*args, **kwargs)
+
+ self.__lt__ = lambda self, other: self.order < other.order
+
+
+@mock.patch(('oslo_db.sqlalchemy.migration_cli.'
+ 'ext_alembic.alembic.command'))
+class TestAlembicExtension(test_base.BaseTestCase):
+
+ def setUp(self):
+ self.migration_config = {'alembic_ini_path': '.',
+ 'db_url': 'sqlite://'}
+ self.alembic = ext_alembic.AlembicExtension(self.migration_config)
+ super(TestAlembicExtension, self).setUp()
+
+ def test_check_enabled_true(self, command):
+ """Check enabled returns True
+
+ Verifies that enabled returns True on a non-empty
+ alembic_ini_path conf variable.
+ """
+ self.assertTrue(self.alembic.enabled)
+
+ def test_check_enabled_false(self, command):
+ """Check enabled returns False
+
+ Verifies that enabled returns False on an empty alembic_ini_path variable.
+ """
+ self.migration_config['alembic_ini_path'] = ''
+ alembic = ext_alembic.AlembicExtension(self.migration_config)
+ self.assertFalse(alembic.enabled)
+
+ def test_upgrade_none(self, command):
+ self.alembic.upgrade(None)
+ command.upgrade.assert_called_once_with(self.alembic.config, 'head')
+
+ def test_upgrade_normal(self, command):
+ self.alembic.upgrade('131daa')
+ command.upgrade.assert_called_once_with(self.alembic.config, '131daa')
+
+ def test_downgrade_none(self, command):
+ self.alembic.downgrade(None)
+ command.downgrade.assert_called_once_with(self.alembic.config, 'base')
+
+ def test_downgrade_int(self, command):
+ self.alembic.downgrade(111)
+ command.downgrade.assert_called_once_with(self.alembic.config, 'base')
+
+ def test_downgrade_normal(self, command):
+ self.alembic.downgrade('131daa')
+ command.downgrade.assert_called_once_with(
+ self.alembic.config, '131daa')
+
+ def test_revision(self, command):
+ self.alembic.revision(message='test', autogenerate=True)
+ command.revision.assert_called_once_with(
+ self.alembic.config, message='test', autogenerate=True)
+
+ def test_stamp(self, command):
+ self.alembic.stamp('stamp')
+ command.stamp.assert_called_once_with(
+ self.alembic.config, revision='stamp')
+
+ def test_version(self, command):
+ version = self.alembic.version()
+ self.assertIsNone(version)
+
+
+@mock.patch(('oslo_db.sqlalchemy.migration_cli.'
+ 'ext_migrate.migration'))
+class TestMigrateExtension(test_base.BaseTestCase):
+
+ def setUp(self):
+ self.migration_config = {'migration_repo_path': '.',
+ 'db_url': 'sqlite://'}
+ self.migrate = ext_migrate.MigrateExtension(self.migration_config)
+ super(TestMigrateExtension, self).setUp()
+
+ def test_check_enabled_true(self, migration):
+ self.assertTrue(self.migrate.enabled)
+
+ def test_check_enabled_false(self, migration):
+ self.migration_config['migration_repo_path'] = ''
+ migrate = ext_migrate.MigrateExtension(self.migration_config)
+ self.assertFalse(migrate.enabled)
+
+ def test_upgrade_head(self, migration):
+ self.migrate.upgrade('head')
+ migration.db_sync.assert_called_once_with(
+ self.migrate.engine, self.migrate.repository, None, init_version=0)
+
+ def test_upgrade_normal(self, migration):
+ self.migrate.upgrade(111)
+ migration.db_sync.assert_called_once_with(
+ mock.ANY, self.migrate.repository, 111, init_version=0)
+
+ def test_downgrade_init_version_from_base(self, migration):
+ self.migrate.downgrade('base')
+ migration.db_sync.assert_called_once_with(
+ self.migrate.engine, self.migrate.repository, mock.ANY,
+ init_version=mock.ANY)
+
+ def test_downgrade_init_version_from_none(self, migration):
+ self.migrate.downgrade(None)
+ migration.db_sync.assert_called_once_with(
+ self.migrate.engine, self.migrate.repository, mock.ANY,
+ init_version=mock.ANY)
+
+ def test_downgrade_normal(self, migration):
+ self.migrate.downgrade(101)
+ migration.db_sync.assert_called_once_with(
+ self.migrate.engine, self.migrate.repository, 101, init_version=0)
+
+ def test_version(self, migration):
+ self.migrate.version()
+ migration.db_version.assert_called_once_with(
+ self.migrate.engine, self.migrate.repository, init_version=0)
+
+ def test_change_init_version(self, migration):
+ self.migration_config['init_version'] = 101
+ migrate = ext_migrate.MigrateExtension(self.migration_config)
+ migrate.downgrade(None)
+ migration.db_sync.assert_called_once_with(
+ migrate.engine,
+ self.migrate.repository,
+ self.migration_config['init_version'],
+ init_version=self.migration_config['init_version'])
+
+
+class TestMigrationManager(test_base.BaseTestCase):
+
+ def setUp(self):
+ self.migration_config = {'alembic_ini_path': '.',
+ 'migrate_repo_path': '.',
+ 'db_url': 'sqlite://'}
+ self.migration_manager = manager.MigrationManager(
+ self.migration_config)
+ self.ext = mock.Mock()
+ self.migration_manager._manager.extensions = [self.ext]
+ super(TestMigrationManager, self).setUp()
+
+ def test_manager_update(self):
+ self.migration_manager.upgrade('head')
+ self.ext.obj.upgrade.assert_called_once_with('head')
+
+ def test_manager_update_revision_none(self):
+ self.migration_manager.upgrade(None)
+ self.ext.obj.upgrade.assert_called_once_with(None)
+
+ def test_downgrade_normal_revision(self):
+ self.migration_manager.downgrade('111abcd')
+ self.ext.obj.downgrade.assert_called_once_with('111abcd')
+
+ def test_version(self):
+ self.migration_manager.version()
+ self.ext.obj.version.assert_called_once_with()
+
+ def test_revision_message_autogenerate(self):
+ self.migration_manager.revision('test', True)
+ self.ext.obj.revision.assert_called_once_with('test', True)
+
+ def test_revision_only_message(self):
+ self.migration_manager.revision('test', False)
+ self.ext.obj.revision.assert_called_once_with('test', False)
+
+ def test_stamp(self):
+ self.migration_manager.stamp('stamp')
+ self.ext.obj.stamp.assert_called_once_with('stamp')
+
+
+class TestMigrationRightOrder(test_base.BaseTestCase):
+
+ def setUp(self):
+ self.migration_config = {'alembic_ini_path': '.',
+ 'migrate_repo_path': '.',
+ 'db_url': 'sqlite://'}
+ self.migration_manager = manager.MigrationManager(
+ self.migration_config)
+ self.first_ext = MockWithCmp()
+ self.first_ext.obj.order = 1
+ self.first_ext.obj.upgrade.return_value = 100
+ self.first_ext.obj.downgrade.return_value = 0
+ self.second_ext = MockWithCmp()
+ self.second_ext.obj.order = 2
+ self.second_ext.obj.upgrade.return_value = 200
+ self.second_ext.obj.downgrade.return_value = 100
+ self.migration_manager._manager.extensions = [self.first_ext,
+ self.second_ext]
+ super(TestMigrationRightOrder, self).setUp()
+
+ def test_upgrade_right_order(self):
+ results = self.migration_manager.upgrade(None)
+ self.assertEqual(results, [100, 200])
+
+ def test_downgrade_right_order(self):
+ results = self.migration_manager.downgrade(None)
+ self.assertEqual(results, [100, 0])
diff --git a/oslo_db/tests/sqlalchemy/test_migration_common.py b/oslo_db/tests/sqlalchemy/test_migration_common.py
new file mode 100644
index 0000000..95efab1
--- /dev/null
+++ b/oslo_db/tests/sqlalchemy/test_migration_common.py
@@ -0,0 +1,212 @@
+# Copyright 2013 Mirantis Inc.
+# All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+#
+
+import os
+import tempfile
+
+from migrate import exceptions as migrate_exception
+from migrate.versioning import api as versioning_api
+import mock
+import sqlalchemy
+
+from oslo_db import exception as db_exception
+from oslo_db.sqlalchemy import migration
+from oslo_db.sqlalchemy import test_base
+from oslo_db.tests import utils as test_utils
+
+
+class TestMigrationCommon(test_base.DbTestCase):
+ def setUp(self):
+ super(TestMigrationCommon, self).setUp()
+
+ migration._REPOSITORY = None
+ self.path = tempfile.mkdtemp('test_migration')
+ self.path1 = tempfile.mkdtemp('test_migration')
+ self.return_value = '/home/openstack/migrations'
+ self.return_value1 = '/home/extension/migrations'
+ self.init_version = 1
+ self.test_version = 123
+
+ self.patcher_repo = mock.patch.object(migration, 'Repository')
+ self.repository = self.patcher_repo.start()
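+ # successive Repository() constructions return distinct paths so the
+ # tests can tell the two migration repos apart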
+ self.repository.side_effect = [self.return_value, self.return_value1]
+
+ self.mock_api_db = mock.patch.object(versioning_api, 'db_version')
+ self.mock_api_db_version = self.mock_api_db.start()
+ self.mock_api_db_version.return_value = self.test_version
+
+ def tearDown(self):
+ os.rmdir(self.path)
+ self.mock_api_db.stop()
+ self.patcher_repo.stop()
+ super(TestMigrationCommon, self).tearDown()
+
+ def test_find_migrate_repo_path_not_found(self):
+ self.assertRaises(
+ db_exception.DbMigrationError,
+ migration._find_migrate_repo,
+ "/foo/bar/",
+ )
+ self.assertIsNone(migration._REPOSITORY)
+
+ def test_find_migrate_repo_called_once(self):
+ my_repository = migration._find_migrate_repo(self.path)
+ self.repository.assert_called_once_with(self.path)
+ self.assertEqual(my_repository, self.return_value)
+
+ def test_find_migrate_repo_called_few_times(self):
+ repo1 = migration._find_migrate_repo(self.path)
+ repo2 = migration._find_migrate_repo(self.path1)
+ self.assertNotEqual(repo1, repo2)
+
+ def test_db_version_control(self):
+ with test_utils.nested(
+ mock.patch.object(migration, '_find_migrate_repo'),
+ mock.patch.object(versioning_api, 'version_control'),
+ ) as (mock_find_repo, mock_version_control):
+ mock_find_repo.return_value = self.return_value
+
+ version = migration.db_version_control(
+ self.engine, self.path, self.test_version)
+
+ self.assertEqual(version, self.test_version)
+ mock_version_control.assert_called_once_with(
+ self.engine, self.return_value, self.test_version)
+
+ def test_db_version_return(self):
+ ret_val = migration.db_version(self.engine, self.path,
+ self.init_version)
+ self.assertEqual(ret_val, self.test_version)
+
+ def test_db_version_raise_not_controlled_error_first(self):
+ with mock.patch.object(migration, 'db_version_control') as mock_ver:
+
+ self.mock_api_db_version.side_effect = [
+ migrate_exception.DatabaseNotControlledError('oups'),
+ self.test_version]
+
+ ret_val = migration.db_version(self.engine, self.path,
+ self.init_version)
+ self.assertEqual(ret_val, self.test_version)
+ mock_ver.assert_called_once_with(self.engine, self.path,
+ version=self.init_version)
+
+ def test_db_version_raise_not_controlled_error_tables(self):
+ with mock.patch.object(sqlalchemy, 'MetaData') as mock_meta:
+ self.mock_api_db_version.side_effect = \
+ migrate_exception.DatabaseNotControlledError('oups')
+ my_meta = mock.MagicMock()
+ my_meta.tables = {'a': 1, 'b': 2}
+ mock_meta.return_value = my_meta
+
+ self.assertRaises(
+ db_exception.DbMigrationError, migration.db_version,
+ self.engine, self.path, self.init_version)
+
+ @mock.patch.object(versioning_api, 'version_control')
+ def test_db_version_raise_not_controlled_error_no_tables(self, mock_vc):
+ with mock.patch.object(sqlalchemy, 'MetaData') as mock_meta:
+ self.mock_api_db_version.side_effect = (
+ migrate_exception.DatabaseNotControlledError('oups'),
+ self.init_version)
+ my_meta = mock.MagicMock()
+ my_meta.tables = {}
+ mock_meta.return_value = my_meta
+ migration.db_version(self.engine, self.path, self.init_version)
+
+ mock_vc.assert_called_once_with(self.engine, self.return_value1,
+ self.init_version)
+
+ def test_db_sync_wrong_version(self):
+ self.assertRaises(db_exception.DbMigrationError,
+ migration.db_sync, self.engine, self.path, 'foo')
+
+ def test_db_sync_upgrade(self):
+ init_ver = 55
+ with test_utils.nested(
+ mock.patch.object(migration, '_find_migrate_repo'),
+ mock.patch.object(versioning_api, 'upgrade')
+ ) as (mock_find_repo, mock_upgrade):
+
+ mock_find_repo.return_value = self.return_value
+ self.mock_api_db_version.return_value = self.test_version - 1
+
+ migration.db_sync(self.engine, self.path, self.test_version,
+ init_ver)
+
+ mock_upgrade.assert_called_once_with(
+ self.engine, self.return_value, self.test_version)
+
+ def test_db_sync_downgrade(self):
+ with test_utils.nested(
+ mock.patch.object(migration, '_find_migrate_repo'),
+ mock.patch.object(versioning_api, 'downgrade')
+ ) as (mock_find_repo, mock_downgrade):
+
+ mock_find_repo.return_value = self.return_value
+ self.mock_api_db_version.return_value = self.test_version + 1
+
+ migration.db_sync(self.engine, self.path, self.test_version)
+
+ mock_downgrade.assert_called_once_with(
+ self.engine, self.return_value, self.test_version)
+
+ def test_db_sync_sanity_called(self):
+ with test_utils.nested(
+ mock.patch.object(migration, '_find_migrate_repo'),
+ mock.patch.object(migration, '_db_schema_sanity_check'),
+ mock.patch.object(versioning_api, 'downgrade')
+ ) as (mock_find_repo, mock_sanity, mock_downgrade):
+
+ mock_find_repo.return_value = self.return_value
+ migration.db_sync(self.engine, self.path, self.test_version)
+
+ mock_sanity.assert_called_once_with(self.engine)
+
+ def test_db_sync_sanity_skipped(self):
+ with test_utils.nested(
+ mock.patch.object(migration, '_find_migrate_repo'),
+ mock.patch.object(migration, '_db_schema_sanity_check'),
+ mock.patch.object(versioning_api, 'downgrade')
+ ) as (mock_find_repo, mock_sanity, mock_downgrade):
+
+ mock_find_repo.return_value = self.return_value
+ migration.db_sync(self.engine, self.path, self.test_version,
+ sanity_check=False)
+
+ self.assertFalse(mock_sanity.called)
+
+ def test_db_sanity_table_not_utf8(self):
+ with mock.patch.object(self, 'engine') as mock_eng:
+ type(mock_eng).name = mock.PropertyMock(return_value='mysql')
+ mock_eng.execute.return_value = [['table_A', 'latin1'],
+ ['table_B', 'latin1']]
+
+ self.assertRaises(ValueError, migration._db_schema_sanity_check,
+ mock_eng)
+
+ def test_db_sanity_table_not_utf8_exclude_migrate_tables(self):
+ with mock.patch.object(self, 'engine') as mock_eng:
+ type(mock_eng).name = mock.PropertyMock(return_value='mysql')
+ # NOTE(morganfainberg): Check both lower and upper case versions
+ # of the migration table names (validate case insensitivity in
+ # the sanity check).
+ mock_eng.execute.return_value = [['migrate_version', 'latin1'],
+ ['alembic_version', 'latin1'],
+ ['MIGRATE_VERSION', 'latin1'],
+ ['ALEMBIC_VERSION', 'latin1']]
+
+ migration._db_schema_sanity_check(mock_eng)
diff --git a/oslo_db/tests/sqlalchemy/test_migrations.py b/oslo_db/tests/sqlalchemy/test_migrations.py
new file mode 100644
index 0000000..a372d8b
--- /dev/null
+++ b/oslo_db/tests/sqlalchemy/test_migrations.py
@@ -0,0 +1,309 @@
+# Copyright 2010-2011 OpenStack Foundation
+# Copyright 2012-2013 IBM Corp.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+import fixtures
+import mock
+from oslotest import base as test
+import six
+import sqlalchemy as sa
+import sqlalchemy.ext.declarative as sa_decl
+
+from oslo_db import exception as exc
+from oslo_db.sqlalchemy import test_base
+from oslo_db.sqlalchemy import test_migrations as migrate
+
+
+class TestWalkVersions(test.BaseTestCase, migrate.WalkVersionsMixin):
+ migration_api = mock.MagicMock()
+ REPOSITORY = mock.MagicMock()
+ engine = mock.MagicMock()
+ INIT_VERSION = 4
+
+ @property
+ def migrate_engine(self):
+ return self.engine
+
+ def test_migrate_up(self):
+ self.migration_api.db_version.return_value = 141
+
+ self.migrate_up(141)
+
+ self.migration_api.upgrade.assert_called_with(
+ self.engine, self.REPOSITORY, 141)
+ self.migration_api.db_version.assert_called_with(
+ self.engine, self.REPOSITORY)
+
+ def test_migrate_up_fail(self):
+ version = 141
+ self.migration_api.db_version.return_value = version
+ expected_output = (u"Failed to migrate to version %(version)s on "
+ "engine %(engine)s\n" %
+ {'version': version, 'engine': self.engine})
+
+ with mock.patch.object(self.migration_api,
+ 'upgrade',
+ side_effect=exc.DbMigrationError):
+ log = self.useFixture(fixtures.FakeLogger())
+ self.assertRaises(exc.DbMigrationError, self.migrate_up, version)
+ self.assertEqual(expected_output, log.output)
+
+ def test_migrate_up_with_data(self):
+ test_value = {"a": 1, "b": 2}
+ self.migration_api.db_version.return_value = 141
+ self._pre_upgrade_141 = mock.MagicMock()
+ self._pre_upgrade_141.return_value = test_value
+ self._check_141 = mock.MagicMock()
+
+ self.migrate_up(141, True)
+
+ self._pre_upgrade_141.assert_called_with(self.engine)
+ self._check_141.assert_called_with(self.engine, test_value)
+
+ def test_migrate_down(self):
+ self.migration_api.db_version.return_value = 42
+
+ self.assertTrue(self.migrate_down(42))
+ self.migration_api.db_version.assert_called_with(
+ self.engine, self.REPOSITORY)
+
+ def test_migrate_down_not_implemented(self):
+ with mock.patch.object(self.migration_api,
+ 'downgrade',
+ side_effect=NotImplementedError):
+ self.assertFalse(self.migrate_down(self.engine, 42))
+
+ def test_migrate_down_with_data(self):
+ self._post_downgrade_043 = mock.MagicMock()
+ self.migration_api.db_version.return_value = 42
+
+ self.migrate_down(42, True)
+
+ self._post_downgrade_043.assert_called_with(self.engine)
+
+ @mock.patch.object(migrate.WalkVersionsMixin, 'migrate_up')
+ @mock.patch.object(migrate.WalkVersionsMixin, 'migrate_down')
+ def test_walk_versions_all_default(self, migrate_up, migrate_down):
+ self.REPOSITORY.latest = 20
+ self.migration_api.db_version.return_value = self.INIT_VERSION
+
+ self.walk_versions()
+
+ self.migration_api.version_control.assert_called_with(
+ self.engine, self.REPOSITORY, self.INIT_VERSION)
+ self.migration_api.db_version.assert_called_with(
+ self.engine, self.REPOSITORY)
+
+ versions = range(self.INIT_VERSION + 1, self.REPOSITORY.latest + 1)
+ upgraded = [mock.call(v, with_data=True)
+ for v in versions]
+ self.assertEqual(self.migrate_up.call_args_list, upgraded)
+
+ downgraded = [mock.call(v - 1) for v in reversed(versions)]
+ self.assertEqual(self.migrate_down.call_args_list, downgraded)
+
+ @mock.patch.object(migrate.WalkVersionsMixin, 'migrate_up')
+ @mock.patch.object(migrate.WalkVersionsMixin, 'migrate_down')
+ def test_walk_versions_all_true(self, migrate_up, migrate_down):
+ self.REPOSITORY.latest = 20
+ self.migration_api.db_version.return_value = self.INIT_VERSION
+
+ self.walk_versions(snake_walk=True, downgrade=True)
+
+ versions = range(self.INIT_VERSION + 1, self.REPOSITORY.latest + 1)
+ upgraded = []
+ for v in versions:
+ upgraded.append(mock.call(v, with_data=True))
+ upgraded.append(mock.call(v))
+ upgraded.extend([mock.call(v) for v in reversed(versions)])
+ self.assertEqual(upgraded, self.migrate_up.call_args_list)
+
+ downgraded_1 = [mock.call(v - 1, with_data=True) for v in versions]
+ downgraded_2 = []
+ for v in reversed(versions):
+ downgraded_2.append(mock.call(v - 1))
+ downgraded_2.append(mock.call(v - 1))
+ downgraded = downgraded_1 + downgraded_2
+ self.assertEqual(self.migrate_down.call_args_list, downgraded)
+
+ @mock.patch.object(migrate.WalkVersionsMixin, 'migrate_up')
+ @mock.patch.object(migrate.WalkVersionsMixin, 'migrate_down')
+ def test_walk_versions_true_false(self, migrate_up, migrate_down):
+ self.REPOSITORY.latest = 20
+ self.migration_api.db_version.return_value = self.INIT_VERSION
+
+ self.walk_versions(snake_walk=True, downgrade=False)
+
+ versions = range(self.INIT_VERSION + 1, self.REPOSITORY.latest + 1)
+
+ upgraded = []
+ for v in versions:
+ upgraded.append(mock.call(v, with_data=True))
+ upgraded.append(mock.call(v))
+ self.assertEqual(upgraded, self.migrate_up.call_args_list)
+
+ downgraded = [mock.call(v - 1, with_data=True) for v in versions]
+ self.assertEqual(self.migrate_down.call_args_list, downgraded)
+
+ @mock.patch.object(migrate.WalkVersionsMixin, 'migrate_up')
+ @mock.patch.object(migrate.WalkVersionsMixin, 'migrate_down')
+ def test_walk_versions_all_false(self, migrate_up, migrate_down):
+ self.REPOSITORY.latest = 20
+ self.migration_api.db_version.return_value = self.INIT_VERSION
+
+ self.walk_versions(snake_walk=False, downgrade=False)
+
+ versions = range(self.INIT_VERSION + 1, self.REPOSITORY.latest + 1)
+
+ upgraded = [mock.call(v, with_data=True) for v in versions]
+ self.assertEqual(upgraded, self.migrate_up.call_args_list)
+
+
+class ModelsMigrationSyncMixin(test.BaseTestCase):
+
+ def setUp(self):
+ super(ModelsMigrationSyncMixin, self).setUp()
+
+ self.metadata = sa.MetaData()
+ self.metadata_migrations = sa.MetaData()
+
+ sa.Table(
+ 'testtbl', self.metadata_migrations,
+ sa.Column('id', sa.Integer, primary_key=True),
+ sa.Column('spam', sa.String(10), nullable=False),
+ sa.Column('eggs', sa.DateTime),
+ sa.Column('foo', sa.Boolean,
+ server_default=sa.sql.expression.true()),
+ sa.Column('bool_wo_default', sa.Boolean),
+ sa.Column('bar', sa.Numeric(10, 5)),
+ sa.Column('defaulttest', sa.Integer, server_default='5'),
+ sa.Column('defaulttest2', sa.String(8), server_default=''),
+ sa.Column('defaulttest3', sa.String(5), server_default="test"),
+ sa.Column('defaulttest4', sa.Enum('first', 'second',
+ name='testenum'),
+ server_default="first"),
+ sa.Column('fk_check', sa.String(36), nullable=False),
+ sa.UniqueConstraint('spam', 'eggs', name='uniq_cons'),
+ )
+
+ BASE = sa_decl.declarative_base(metadata=self.metadata)
+
+ class TestModel(BASE):
+ __tablename__ = 'testtbl'
+ __table_args__ = (
+ sa.UniqueConstraint('spam', 'eggs', name='uniq_cons'),
+ )
+
+ id = sa.Column('id', sa.Integer, primary_key=True)
+ spam = sa.Column('spam', sa.String(10), nullable=False)
+ eggs = sa.Column('eggs', sa.DateTime)
+ foo = sa.Column('foo', sa.Boolean,
+ server_default=sa.sql.expression.true())
+ fk_check = sa.Column('fk_check', sa.String(36), nullable=False)
+ bool_wo_default = sa.Column('bool_wo_default', sa.Boolean)
+ defaulttest = sa.Column('defaulttest',
+ sa.Integer, server_default='5')
+ defaulttest2 = sa.Column('defaulttest2', sa.String(8),
+ server_default='')
+ defaulttest3 = sa.Column('defaulttest3', sa.String(5),
+ server_default="test")
+ defaulttest4 = sa.Column('defaulttest4', sa.Enum('first', 'second',
+ name='testenum'),
+ server_default="first")
+ bar = sa.Column('bar', sa.Numeric(10, 5))
+
+ class ModelThatShouldNotBeCompared(BASE):
+ __tablename__ = 'testtbl2'
+
+ id = sa.Column('id', sa.Integer, primary_key=True)
+ spam = sa.Column('spam', sa.String(10), nullable=False)
+
+ def get_metadata(self):
+ return self.metadata
+
+ def get_engine(self):
+ return self.engine
+
+ def db_sync(self, engine):
+ self.metadata_migrations.create_all(bind=engine)
+
+ def include_object(self, object_, name, type_, reflected, compare_to):
+ if type_ == 'table':
+ return name == 'testtbl'
+ else:
+ return True
+
+ def _test_models_not_sync(self):
+ self.metadata_migrations.clear()
+ sa.Table(
+ 'table', self.metadata_migrations,
+ sa.Column('fk_check', sa.String(36), nullable=False),
+ sa.PrimaryKeyConstraint('fk_check'),
+ mysql_engine='InnoDB'
+ )
+ sa.Table(
+ 'testtbl', self.metadata_migrations,
+ sa.Column('id', sa.Integer, primary_key=True),
+ sa.Column('spam', sa.String(8), nullable=True),
+ sa.Column('eggs', sa.DateTime),
+ sa.Column('foo', sa.Boolean,
+ server_default=sa.sql.expression.false()),
+ sa.Column('bool_wo_default', sa.Boolean, unique=True),
+ sa.Column('bar', sa.BigInteger),
+ sa.Column('defaulttest', sa.Integer, server_default='7'),
+ sa.Column('defaulttest2', sa.String(8), server_default=''),
+ sa.Column('defaulttest3', sa.String(5), server_default="fake"),
+ sa.Column('defaulttest4',
+ sa.Enum('first', 'second', name='testenum'),
+ server_default="first"),
+ sa.Column('fk_check', sa.String(36), nullable=False),
+ sa.UniqueConstraint('spam', 'foo', name='uniq_cons'),
+ sa.ForeignKeyConstraint(['fk_check'], ['table.fk_check']),
+ mysql_engine='InnoDB'
+ )
+
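+ # The schema above intentionally diverges from the model: 'spam'
+ # changes length and nullability, 'foo' flips its server default,
+ # 'bar' changes type, several defaults differ, and the unique
+ # constraint covers different columns.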
+ msg = six.text_type(self.assertRaises(AssertionError,
+ self.test_models_sync))
+ # NOTE(I159): Check that the table and columns are mentioned.
+ # The message is not valid JSON, so we can't parse it for a full
+ # comparison, and since the ordering of its items is not
+ # guaranteed we can't use a regexp either.
+ self.assertTrue(msg.startswith(
+ 'Models and migration scripts aren\'t in sync:'))
+ self.assertIn('testtbl', msg)
+ self.assertIn('spam', msg)
+ self.assertIn('eggs', msg) # part of the changed unique constraint
+ self.assertIn('foo', msg)
+ self.assertIn('bar', msg)
+ self.assertIn('bool_wo_default', msg)
+ self.assertIn('defaulttest', msg)
+ self.assertIn('defaulttest3', msg)
+ self.assertIn('drop_key', msg)
+
+
+class ModelsMigrationsSyncMysql(ModelsMigrationSyncMixin,
+ migrate.ModelsMigrationsSync,
+ test_base.MySQLOpportunisticTestCase):
+
+ def test_models_not_sync(self):
+ self._test_models_not_sync()
+
+
+class ModelsMigrationsSyncPsql(ModelsMigrationSyncMixin,
+ migrate.ModelsMigrationsSync,
+ test_base.PostgreSQLOpportunisticTestCase):
+
+ def test_models_not_sync(self):
+ self._test_models_not_sync()
diff --git a/oslo_db/tests/sqlalchemy/test_models.py b/oslo_db/tests/sqlalchemy/test_models.py
new file mode 100644
index 0000000..4a45576
--- /dev/null
+++ b/oslo_db/tests/sqlalchemy/test_models.py
@@ -0,0 +1,146 @@
+# Copyright 2012 Cloudscaling Group, Inc.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+import collections
+
+from oslotest import base as oslo_test
+from sqlalchemy import Column
+from sqlalchemy import Integer, String
+from sqlalchemy.ext.declarative import declarative_base
+
+from oslo_db.sqlalchemy import models
+from oslo_db.sqlalchemy import test_base
+
+
+BASE = declarative_base()
+
+
+class ModelBaseTest(test_base.DbTestCase):
+ def setUp(self):
+ super(ModelBaseTest, self).setUp()
+ self.mb = models.ModelBase()
+ self.ekm = ExtraKeysModel()
+
+ def test_modelbase_has_dict_methods(self):
+ dict_methods = ('__getitem__',
+ '__setitem__',
+ '__contains__',
+ 'get',
+ 'update',
+ 'save',
+ 'iteritems')
+ for method in dict_methods:
+ self.assertTrue(hasattr(models.ModelBase, method),
+ "Method %s() is not found" % method)
+
+ def test_modelbase_is_iterable(self):
+ self.assertTrue(issubclass(models.ModelBase, collections.Iterable))
+
+ def test_modelbase_set(self):
+ self.mb['world'] = 'hello'
+ self.assertEqual(self.mb['world'], 'hello')
+
+ def test_modelbase_update(self):
+ h = {'a': '1', 'b': '2'}
+ self.mb.update(h)
+ for key in h.keys():
+ self.assertEqual(self.mb[key], h[key])
+
+ def test_modelbase_contains(self):
+ mb = models.ModelBase()
+ h = {'a': '1', 'b': '2'}
+ mb.update(h)
+ for key in h.keys():
+ # Test 'in' syntax (instead of using .assertIn)
+ self.assertTrue(key in mb)
+
+ self.assertFalse('non-existent-key' in mb)
+
+ def test_modelbase_iteritems(self):
+ h = {'a': '1', 'b': '2'}
+ expected = {
+ 'id': None,
+ 'smth': None,
+ 'name': 'NAME',
+ 'a': '1',
+ 'b': '2',
+ }
+ self.ekm.update(h)
+ self.assertEqual(dict(self.ekm.iteritems()), expected)
+
+ def test_modelbase_iter(self):
+ expected = {
+ 'id': None,
+ 'smth': None,
+ 'name': 'NAME',
+ }
+ i = iter(self.ekm)
+ found_items = 0
+ while True:
+ r = next(i, None)
+ if r is None:
+ break
+ self.assertEqual(expected[r[0]], r[1])
+ found_items += 1
+
+ self.assertEqual(len(expected), found_items)
+
+ def test_modelbase_several_iters(self):
+ mb = ExtraKeysModel()
+ it1 = iter(mb)
+ it2 = iter(mb)
+
+ self.assertFalse(it1 is it2)
+ self.assertEqual(dict(it1), dict(mb))
+ self.assertEqual(dict(it2), dict(mb))
+
+ def test_extra_keys_empty(self):
+ """Test verifies that by default extra_keys return empty list."""
+ self.assertEqual(self.mb._extra_keys, [])
+
+ def test_extra_keys_defined(self):
+ """Property _extra_keys will return list with attributes names."""
+ self.assertEqual(self.ekm._extra_keys, ['name'])
+
+ def test_model_with_extra_keys(self):
+ data = dict(self.ekm)
+ self.assertEqual(data, {'smth': None,
+ 'id': None,
+ 'name': 'NAME'})
+
+
+class ExtraKeysModel(BASE, models.ModelBase):
+ __tablename__ = 'test_model'
+
+ id = Column(Integer, primary_key=True)
+ smth = Column(String(255))
+
+ @property
+ def name(self):
+ return 'NAME'
+
+ @property
+ def _extra_keys(self):
+ return ['name']
+
+
+class TimestampMixinTest(oslo_test.BaseTestCase):
+
+ def test_timestampmixin_attr(self):
+ methods = ('created_at',
+ 'updated_at')
+ for method in methods:
+ self.assertTrue(hasattr(models.TimestampMixin, method),
+ "Method %s() is not found" % method)
diff --git a/oslo_db/tests/sqlalchemy/test_options.py b/oslo_db/tests/sqlalchemy/test_options.py
new file mode 100644
index 0000000..22a6e4f
--- /dev/null
+++ b/oslo_db/tests/sqlalchemy/test_options.py
@@ -0,0 +1,127 @@
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+from oslo.config import cfg
+from oslo.config import fixture as config
+
+from oslo_db import options
+from oslo_db.tests import utils as test_utils
+
+
+class DbApiOptionsTestCase(test_utils.BaseTestCase):
+ def setUp(self):
+ super(DbApiOptionsTestCase, self).setUp()
+
+ config_fixture = self.useFixture(config.Config())
+ self.conf = config_fixture.conf
+ self.conf.register_opts(options.database_opts, group='database')
+ self.config = config_fixture.config
+
+ def test_deprecated_session_parameters(self):
+ path = self.create_tempfiles([["tmp", b"""[DEFAULT]
+sql_connection=x://y.z
+sql_min_pool_size=10
+sql_max_pool_size=20
+sql_max_retries=30
+sql_retry_interval=40
+sql_max_overflow=50
+sql_connection_debug=60
+sql_connection_trace=True
+"""]])[0]
+ self.conf(['--config-file', path])
+ self.assertEqual(self.conf.database.connection, 'x://y.z')
+ self.assertEqual(self.conf.database.min_pool_size, 10)
+ self.assertEqual(self.conf.database.max_pool_size, 20)
+ self.assertEqual(self.conf.database.max_retries, 30)
+ self.assertEqual(self.conf.database.retry_interval, 40)
+ self.assertEqual(self.conf.database.max_overflow, 50)
+ self.assertEqual(self.conf.database.connection_debug, 60)
+ self.assertEqual(self.conf.database.connection_trace, True)
+
+ def test_session_parameters(self):
+ path = self.create_tempfiles([["tmp", b"""[database]
+connection=x://y.z
+min_pool_size=10
+max_pool_size=20
+max_retries=30
+retry_interval=40
+max_overflow=50
+connection_debug=60
+connection_trace=True
+pool_timeout=7
+"""]])[0]
+ self.conf(['--config-file', path])
+ self.assertEqual(self.conf.database.connection, 'x://y.z')
+ self.assertEqual(self.conf.database.min_pool_size, 10)
+ self.assertEqual(self.conf.database.max_pool_size, 20)
+ self.assertEqual(self.conf.database.max_retries, 30)
+ self.assertEqual(self.conf.database.retry_interval, 40)
+ self.assertEqual(self.conf.database.max_overflow, 50)
+ self.assertEqual(self.conf.database.connection_debug, 60)
+ self.assertEqual(self.conf.database.connection_trace, True)
+ self.assertEqual(self.conf.database.pool_timeout, 7)
+
+ def test_dbapi_database_deprecated_parameters(self):
+ path = self.create_tempfiles([['tmp', b'[DATABASE]\n'
+ b'sql_connection=fake_connection\n'
+ b'sql_idle_timeout=100\n'
+ b'sql_min_pool_size=99\n'
+ b'sql_max_pool_size=199\n'
+ b'sql_max_retries=22\n'
+ b'reconnect_interval=17\n'
+ b'sqlalchemy_max_overflow=101\n'
+ b'sqlalchemy_pool_timeout=5\n'
+ ]])[0]
+ self.conf(['--config-file', path])
+ self.assertEqual(self.conf.database.connection, 'fake_connection')
+ self.assertEqual(self.conf.database.idle_timeout, 100)
+ self.assertEqual(self.conf.database.min_pool_size, 99)
+ self.assertEqual(self.conf.database.max_pool_size, 199)
+ self.assertEqual(self.conf.database.max_retries, 22)
+ self.assertEqual(self.conf.database.retry_interval, 17)
+ self.assertEqual(self.conf.database.max_overflow, 101)
+ self.assertEqual(self.conf.database.pool_timeout, 5)
+
+ def test_dbapi_database_deprecated_parameters_sql(self):
+ path = self.create_tempfiles([['tmp', b'[sql]\n'
+ b'connection=test_sql_connection\n'
+ b'idle_timeout=99\n'
+ ]])[0]
+ self.conf(['--config-file', path])
+ self.assertEqual(self.conf.database.connection, 'test_sql_connection')
+ self.assertEqual(self.conf.database.idle_timeout, 99)
+
+ def test_deprecated_dbapi_parameters(self):
+ path = self.create_tempfiles([['tmp', b'[DEFAULT]\n'
+ b'db_backend=test_123\n'
+ ]])[0]
+
+ self.conf(['--config-file', path])
+ self.assertEqual(self.conf.database.backend, 'test_123')
+
+ def test_dbapi_parameters(self):
+ path = self.create_tempfiles([['tmp', b'[database]\n'
+ b'backend=test_123\n'
+ ]])[0]
+
+ self.conf(['--config-file', path])
+ self.assertEqual(self.conf.database.backend, 'test_123')
+
+ def test_set_defaults(self):
+ conf = cfg.ConfigOpts()
+
+ options.set_defaults(conf,
+ connection='sqlite:///:memory:')
+
+ self.assertTrue(len(conf.database.items()) > 1)
+ self.assertEqual('sqlite:///:memory:', conf.database.connection)
diff --git a/oslo_db/tests/sqlalchemy/test_sqlalchemy.py b/oslo_db/tests/sqlalchemy/test_sqlalchemy.py
new file mode 100644
index 0000000..84cbbcf
--- /dev/null
+++ b/oslo_db/tests/sqlalchemy/test_sqlalchemy.py
@@ -0,0 +1,655 @@
+# coding=utf-8
+
+# Copyright (c) 2012 Rackspace Hosting
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+"""Unit tests for SQLAlchemy specific code."""
+
+import logging
+
+import fixtures
+import mock
+from oslo.config import cfg
+from oslotest import base as oslo_test
+import sqlalchemy
+from sqlalchemy import Column, MetaData, Table
+from sqlalchemy.engine import url
+from sqlalchemy import Integer, String
+from sqlalchemy.ext.declarative import declarative_base
+
+from oslo_db import exception
+from oslo_db import options as db_options
+from oslo_db.sqlalchemy import models
+from oslo_db.sqlalchemy import session
+from oslo_db.sqlalchemy import test_base
+
+
+BASE = declarative_base()
+_TABLE_NAME = '__tmp__test__tmp__'
+
+_REGEXP_TABLE_NAME = _TABLE_NAME + "regexp"
+
+
+class RegexpTable(BASE, models.ModelBase):
+ __tablename__ = _REGEXP_TABLE_NAME
+ id = Column(Integer, primary_key=True)
+ bar = Column(String(255))
+
+
+class RegexpFilterTestCase(test_base.DbTestCase):
+
+ def setUp(self):
+ super(RegexpFilterTestCase, self).setUp()
+ meta = MetaData()
+ meta.bind = self.engine
+ test_table = Table(_REGEXP_TABLE_NAME, meta,
+ Column('id', Integer, primary_key=True,
+ nullable=False),
+ Column('bar', String(255)))
+ test_table.create()
+ self.addCleanup(test_table.drop)
+
+ def _test_regexp_filter(self, regexp, expected):
+ _session = self.sessionmaker()
+ with _session.begin():
+ for i in ['10', '20', u'♥']:
+ tbl = RegexpTable()
+ tbl.update({'bar': i})
+ tbl.save(session=_session)
+
+ regexp_op = RegexpTable.bar.op('REGEXP')(regexp)
+ result = _session.query(RegexpTable).filter(regexp_op).all()
+ self.assertEqual([r.bar for r in result], expected)
+
+ def test_regexp_filter(self):
+ self._test_regexp_filter('10', ['10'])
+
+ def test_regexp_filter_nomatch(self):
+ self._test_regexp_filter('11', [])
+
+ def test_regexp_filter_unicode(self):
+ self._test_regexp_filter(u'♥', [u'♥'])
+
+ def test_regexp_filter_unicode_nomatch(self):
+ self._test_regexp_filter(u'♦', [])
+
+
+class SQLiteSavepointTest(test_base.DbTestCase):
+ def setUp(self):
+ super(SQLiteSavepointTest, self).setUp()
+ meta = MetaData()
+ self.test_table = Table(
+ "test_table", meta,
+ Column('id', Integer, primary_key=True),
+ Column('data', String(10)))
+ self.test_table.create(self.engine)
+ self.addCleanup(self.test_table.drop, self.engine)
+
+ def test_plain_transaction(self):
+ conn = self.engine.connect()
+ trans = conn.begin()
+ conn.execute(
+ self.test_table.insert(),
+ {'data': 'data 1'}
+ )
+ self.assertEqual(
+ [(1, 'data 1')],
+ self.engine.execute(
+ self.test_table.select().
+ order_by(self.test_table.c.id)
+ ).fetchall()
+ )
+ trans.rollback()
+ self.assertEqual(
+ 0,
+ self.engine.scalar(self.test_table.count())
+ )
+
+ def test_savepoint_middle(self):
+ with self.engine.begin() as conn:
+ conn.execute(
+ self.test_table.insert(),
+ {'data': 'data 1'}
+ )
+
+ savepoint = conn.begin_nested()
+ conn.execute(
+ self.test_table.insert(),
+ {'data': 'data 2'}
+ )
+ savepoint.rollback()
+
+ conn.execute(
+ self.test_table.insert(),
+ {'data': 'data 3'}
+ )
+
+ self.assertEqual(
+ [(1, 'data 1'), (2, 'data 3')],
+ self.engine.execute(
+ self.test_table.select().
+ order_by(self.test_table.c.id)
+ ).fetchall()
+ )
+
+ def test_savepoint_beginning(self):
+ with self.engine.begin() as conn:
+ savepoint = conn.begin_nested()
+ conn.execute(
+ self.test_table.insert(),
+ {'data': 'data 1'}
+ )
+ savepoint.rollback()
+
+ conn.execute(
+ self.test_table.insert(),
+ {'data': 'data 2'}
+ )
+
+ self.assertEqual(
+ [(1, 'data 2')],
+ self.engine.execute(
+ self.test_table.select().
+ order_by(self.test_table.c.id)
+ ).fetchall()
+ )
+
+
+class FakeDBAPIConnection(object):
+ def cursor(self):
+ return FakeCursor()
+
+
+class FakeCursor(object):
+ def execute(self, sql):
+ pass
+
+
+class FakeConnectionProxy(object):
+ pass
+
+
+class FakeConnectionRec(object):
+ pass
+
+
+class OperationalError(Exception):
+ pass
+
+
+class ProgrammingError(Exception):
+ pass
+
+
+class FakeDB2Engine(object):
+
+ class Dialect(object):
+
+ def is_disconnect(self, e, *args):
+ expected_error = ('SQL30081N: DB2 Server connection is no longer '
+ 'active')
+ return (str(e) == expected_error)
+
+ dialect = Dialect()
+ name = 'ibm_db_sa'
+
+ def dispose(self):
+ pass
+
+
+class MySQLModeTestCase(test_base.MySQLOpportunisticTestCase):
+
+ def __init__(self, *args, **kwargs):
+ super(MySQLModeTestCase, self).__init__(*args, **kwargs)
+ # By default, run in empty SQL mode.
+ # Subclasses override this with specific modes.
+ self.mysql_mode = ''
+
+ def setUp(self):
+ super(MySQLModeTestCase, self).setUp()
+
+ self.engine = session.create_engine(self.engine.url,
+ mysql_sql_mode=self.mysql_mode)
+ self.connection = self.engine.connect()
+
+ meta = MetaData()
+ meta.bind = self.engine
+ self.test_table = Table(_TABLE_NAME + "mode", meta,
+ Column('id', Integer, primary_key=True),
+ Column('bar', String(255)))
+ self.test_table.create()
+
+ self.addCleanup(self.test_table.drop)
+ self.addCleanup(self.connection.close)
+
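+ # Insert a value longer than the column allows and return what the
+ # database actually stored.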
+ def _test_string_too_long(self, value):
+ with self.connection.begin():
+ self.connection.execute(self.test_table.insert(),
+ bar=value)
+ result = self.connection.execute(self.test_table.select())
+ return result.fetchone()['bar']
+
+ def test_string_too_long(self):
+ value = 'a' * 512
+ # String is too long.
+ # With no SQL mode set, this gets truncated.
+ self.assertNotEqual(value,
+ self._test_string_too_long(value))
+
+
+class MySQLStrictAllTablesModeTestCase(MySQLModeTestCase):
+ "Test data integrity enforcement in MySQL STRICT_ALL_TABLES mode."
+
+ def __init__(self, *args, **kwargs):
+ super(MySQLStrictAllTablesModeTestCase, self).__init__(*args, **kwargs)
+ self.mysql_mode = 'STRICT_ALL_TABLES'
+
+ def test_string_too_long(self):
+ value = 'a' * 512
+ # String is too long.
+ # With STRICT_ALL_TABLES or TRADITIONAL mode set, this is an error.
+ self.assertRaises(exception.DBError,
+ self._test_string_too_long, value)
+
+
+class MySQLTraditionalModeTestCase(MySQLStrictAllTablesModeTestCase):
+ """Test data integrity enforcement in MySQL TRADITIONAL mode.
+
+ Since TRADITIONAL includes STRICT_ALL_TABLES, this inherits all
+ STRICT_ALL_TABLES mode tests.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super(MySQLTraditionalModeTestCase, self).__init__(*args, **kwargs)
+ self.mysql_mode = 'TRADITIONAL'
+
+
+class EngineFacadeTestCase(oslo_test.BaseTestCase):
+ def setUp(self):
+ super(EngineFacadeTestCase, self).setUp()
+
+ self.facade = session.EngineFacade('sqlite://')
+
+ def test_get_engine(self):
+ eng1 = self.facade.get_engine()
+ eng2 = self.facade.get_engine()
+
+ self.assertIs(eng1, eng2)
+
+ def test_get_session(self):
+ ses1 = self.facade.get_session()
+ ses2 = self.facade.get_session()
+
+ self.assertIsNot(ses1, ses2)
+
+ def test_get_session_arguments_override_default_settings(self):
+ ses = self.facade.get_session(autocommit=False, expire_on_commit=True)
+
+ self.assertFalse(ses.autocommit)
+ self.assertTrue(ses.expire_on_commit)
+
+ @mock.patch('oslo_db.sqlalchemy.session.get_maker')
+ @mock.patch('oslo_db.sqlalchemy.session.create_engine')
+ def test_creation_from_config(self, create_engine, get_maker):
+ conf = cfg.ConfigOpts()
+ conf.register_opts(db_options.database_opts, group='database')
+
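+ # Only a handful of options are overridden; the remaining engine
+ # arguments are accepted as mock.ANY below.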
+ overrides = {
+ 'connection': 'sqlite:///:memory:',
+ 'slave_connection': None,
+ 'connection_debug': 100,
+ 'max_pool_size': 10,
+ 'mysql_sql_mode': 'TRADITIONAL',
+ }
+ for optname, optvalue in overrides.items():
+ conf.set_override(optname, optvalue, group='database')
+
+ session.EngineFacade.from_config(conf,
+ autocommit=False,
+ expire_on_commit=True)
+
+ create_engine.assert_called_once_with(
+ sql_connection='sqlite:///:memory:',
+ connection_debug=100,
+ max_pool_size=10,
+ mysql_sql_mode='TRADITIONAL',
+ sqlite_fk=False,
+ idle_timeout=mock.ANY,
+ retry_interval=mock.ANY,
+ max_retries=mock.ANY,
+ max_overflow=mock.ANY,
+ connection_trace=mock.ANY,
+ sqlite_synchronous=mock.ANY,
+ pool_timeout=mock.ANY,
+ thread_checkin=mock.ANY,
+ )
+ get_maker.assert_called_once_with(engine=create_engine(),
+ autocommit=False,
+ expire_on_commit=True)
+
+ def test_slave_connection(self):
+ paths = self.create_tempfiles([('db.master', ''), ('db.slave', '')],
+ ext='')
+ master_path = 'sqlite:///' + paths[0]
+ slave_path = 'sqlite:///' + paths[1]
+
+ facade = session.EngineFacade(
+ sql_connection=master_path,
+ slave_connection=slave_path
+ )
+
+ master = facade.get_engine()
+ self.assertEqual(master_path, str(master.url))
+ slave = facade.get_engine(use_slave=True)
+ self.assertEqual(slave_path, str(slave.url))
+
+ master_session = facade.get_session()
+ self.assertEqual(master_path, str(master_session.bind.url))
+ slave_session = facade.get_session(use_slave=True)
+ self.assertEqual(slave_path, str(slave_session.bind.url))
+
+ def test_slave_connection_string_not_provided(self):
+ master_path = 'sqlite:///' + self.create_tempfiles(
+ [('db.master', '')], ext='')[0]
+
+ facade = session.EngineFacade(sql_connection=master_path)
+
+ master = facade.get_engine()
+ slave = facade.get_engine(use_slave=True)
+ self.assertIs(master, slave)
+ self.assertEqual(master_path, str(master.url))
+
+ master_session = facade.get_session()
+ self.assertEqual(master_path, str(master_session.bind.url))
+ slave_session = facade.get_session(use_slave=True)
+ self.assertEqual(master_path, str(slave_session.bind.url))
+
+
+class SQLiteConnectTest(oslo_test.BaseTestCase):
+
+ def _fixture(self, **kw):
+ return session.create_engine("sqlite://", **kw)
+
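+ # sqlite_fk toggles SQLite's "PRAGMA foreign_keys": 1 means foreign
+ # key enforcement is on, 0 means it is off.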
+ def test_sqlite_fk_listener(self):
+ engine = self._fixture(sqlite_fk=True)
+ self.assertEqual(
+ engine.scalar("pragma foreign_keys"),
+ 1
+ )
+
+ engine = self._fixture(sqlite_fk=False)
+
+ self.assertEqual(
+ engine.scalar("pragma foreign_keys"),
+ 0
+ )
+
+ def test_sqlite_synchronous_listener(self):
+ engine = self._fixture()
+
+ # "The default setting is synchronous=FULL." (e.g. 2)
+ # http://www.sqlite.org/pragma.html#pragma_synchronous
+ self.assertEqual(
+ engine.scalar("pragma synchronous"),
+ 2
+ )
+
+ engine = self._fixture(sqlite_synchronous=False)
+
+ self.assertEqual(
+ engine.scalar("pragma synchronous"),
+ 0
+ )
+
+
+class MysqlConnectTest(test_base.MySQLOpportunisticTestCase):
+
+ def _fixture(self, sql_mode):
+ return session.create_engine(self.engine.url, mysql_sql_mode=sql_mode)
+
+ def _assert_sql_mode(self, engine, sql_mode_present, sql_mode_non_present):
+ mode = engine.execute("SHOW VARIABLES LIKE 'sql_mode'").fetchone()[1]
+ self.assertIn(sql_mode_present, mode)
+ if sql_mode_non_present:
+ self.assertNotIn(sql_mode_non_present, mode)
+
+ def test_set_mode_traditional(self):
+ engine = self._fixture(sql_mode='TRADITIONAL')
+ self._assert_sql_mode(engine, "TRADITIONAL", "ANSI")
+
+ def test_set_mode_ansi(self):
+ engine = self._fixture(sql_mode='ANSI')
+ self._assert_sql_mode(engine, "ANSI", "TRADITIONAL")
+
+ def test_set_mode_no_mode(self):
+ # If _mysql_set_mode_callback is called with sql_mode=None, then
+ # the SQL mode is NOT set on the connection.
+
+ expected = self.engine.execute(
+ "SHOW VARIABLES LIKE 'sql_mode'").fetchone()[1]
+
+ engine = self._fixture(sql_mode=None)
+ self._assert_sql_mode(engine, expected, None)
+
+ def test_fail_detect_mode(self):
+ # If "SHOW VARIABLES LIKE 'sql_mode'" results in no row, then
+ # we get a log indicating can't detect the mode.
+
+ log = self.useFixture(fixtures.FakeLogger(level=logging.WARN))
+
+ mysql_conn = self.engine.raw_connection()
+ self.addCleanup(mysql_conn.close)
+ mysql_conn.detach()
+ mysql_cursor = mysql_conn.cursor()
+
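+ # Redirect the sql_mode query to a variable that does not exist so
+ # that the detection query returns no row.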
+ def execute(statement, parameters=()):
+ if "SHOW VARIABLES LIKE 'sql_mode'" in statement:
+ statement = "SHOW VARIABLES LIKE 'i_dont_exist'"
+ return mysql_cursor.execute(statement, parameters)
+
+ test_engine = sqlalchemy.create_engine(self.engine.url,
+ _initialize=False)
+
+ with mock.patch.object(
+ test_engine.pool, '_creator',
+ mock.Mock(
+ return_value=mock.Mock(
+ cursor=mock.Mock(
+ return_value=mock.Mock(
+ execute=execute,
+ fetchone=mysql_cursor.fetchone,
+ fetchall=mysql_cursor.fetchall
+ )
+ )
+ )
+ )
+ ):
+ session._init_events.dispatch_on_drivername("mysql")(test_engine)
+
+ test_engine.raw_connection()
+ self.assertIn('Unable to detect effective SQL mode',
+ log.output)
+
+ def test_logs_real_mode(self):
+ # If "SHOW VARIABLES LIKE 'sql_mode'" results in a value, then
+ # we get a log with the value.
+
+ log = self.useFixture(fixtures.FakeLogger(level=logging.DEBUG))
+
+ engine = self._fixture(sql_mode='TRADITIONAL')
+
+ actual_mode = engine.execute(
+ "SHOW VARIABLES LIKE 'sql_mode'").fetchone()[1]
+
+ self.assertIn('MySQL server mode set to %s' % actual_mode,
+ log.output)
+
+ def test_warning_when_not_traditional(self):
+ # If "SHOW VARIABLES LIKE 'sql_mode'" results in a value that doesn't
+ # include 'TRADITIONAL', then a warning is logged.
+
+ log = self.useFixture(fixtures.FakeLogger(level=logging.WARN))
+ self._fixture(sql_mode='ANSI')
+
+ self.assertIn("consider enabling TRADITIONAL or STRICT_ALL_TABLES",
+ log.output)
+
+ def test_no_warning_when_traditional(self):
+ # If "SHOW VARIABLES LIKE 'sql_mode'" results in a value that includes
+ # 'TRADITIONAL', then no warning is logged.
+
+ log = self.useFixture(fixtures.FakeLogger(level=logging.WARN))
+ self._fixture(sql_mode='TRADITIONAL')
+
+ self.assertNotIn("consider enabling TRADITIONAL or STRICT_ALL_TABLES",
+ log.output)
+
+ def test_no_warning_when_strict_all_tables(self):
+ # If "SHOW VARIABLES LIKE 'sql_mode'" results in a value that includes
+ # 'STRICT_ALL_TABLES', then no warning is logged.
+
+ log = self.useFixture(fixtures.FakeLogger(level=logging.WARN))
+ self._fixture(sql_mode='STRICT_ALL_TABLES')
+
+ self.assertNotIn("consider enabling TRADITIONAL or STRICT_ALL_TABLES",
+ log.output)
+
+
+class CreateEngineTest(oslo_test.BaseTestCase):
+ """Test that dialect-specific arguments/ listeners are set up correctly.
+
+ """
+
+ def setUp(self):
+ super(CreateEngineTest, self).setUp()
+ self.args = {'connect_args': {}}
+
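+ # MySQL URLs use a QueuePool; max_pool_size and max_overflow map to
+ # SQLAlchemy's pool_size and max_overflow arguments.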
+ def test_queuepool_args(self):
+ session._init_connection_args(
+ url.make_url("mysql://u:p@host/test"), self.args,
+ max_pool_size=10, max_overflow=10)
+ self.assertEqual(self.args['pool_size'], 10)
+ self.assertEqual(self.args['max_overflow'], 10)
+
+ def test_sqlite_memory_pool_args(self):
+ for _url in ("sqlite://", "sqlite:///:memory:"):
+ session._init_connection_args(
+ url.make_url(_url), self.args,
+ max_pool_size=10, max_overflow=10)
+
+ # queuepool arguments are not present
+ self.assertNotIn('pool_size', self.args)
+ self.assertNotIn('max_overflow', self.args)
+
+ self.assertEqual(self.args['connect_args']['check_same_thread'],
+ False)
+
+ # an in-memory database requires an explicit pool class
+ self.assertIn('poolclass', self.args)
+
+ def test_sqlite_file_pool_args(self):
+ session._init_connection_args(
+ url.make_url("sqlite:///somefile.db"), self.args,
+ max_pool_size=10, max_overflow=10)
+
+ # queuepool arguments are not present
+ self.assertNotIn('pool_size', self.args)
+ self.assertNotIn('max_overflow', self.args)
+
+ self.assertFalse(self.args['connect_args'])
+
+ # NullPool is the default for file-based connections, so there is
+ # no need to specify it
+ self.assertNotIn('poolclass', self.args)
+
+ def test_mysql_connect_args_default(self):
+ session._init_connection_args(
+ url.make_url("mysql://u:p@host/test"), self.args)
+ self.assertEqual(self.args['connect_args'],
+ {'charset': 'utf8', 'use_unicode': 0})
+
+ def test_mysql_oursql_connect_args_default(self):
+ session._init_connection_args(
+ url.make_url("mysql+oursql://u:p@host/test"), self.args)
+ self.assertEqual(self.args['connect_args'],
+ {'charset': 'utf8', 'use_unicode': 0})
+
+ def test_mysql_mysqldb_connect_args_default(self):
+ session._init_connection_args(
+ url.make_url("mysql+mysqldb://u:p@host/test"), self.args)
+ self.assertEqual(self.args['connect_args'],
+ {'charset': 'utf8', 'use_unicode': 0})
+
+ def test_postgresql_connect_args_default(self):
+ session._init_connection_args(
+ url.make_url("postgresql://u:p@host/test"), self.args)
+ self.assertEqual(self.args['client_encoding'], 'utf8')
+ self.assertFalse(self.args['connect_args'])
+
+ def test_mysqlconnector_raise_on_warnings_default(self):
+ session._init_connection_args(
+ url.make_url("mysql+mysqlconnector://u:p@host/test"),
+ self.args)
+ self.assertEqual(self.args['connect_args']['raise_on_warnings'], False)
+
+ def test_mysqlconnector_raise_on_warnings_override(self):
+ session._init_connection_args(
+ url.make_url(
+ "mysql+mysqlconnector://u:p@host/test"
+ "?raise_on_warnings=true"),
+ self.args
+ )
+
+ self.assertNotIn('raise_on_warnings', self.args['connect_args'])
+
+ def test_thread_checkin(self):
+ with mock.patch("sqlalchemy.event.listens_for"):
+ with mock.patch("sqlalchemy.event.listen") as listen_evt:
+ session._init_events.dispatch_on_drivername(
+ "sqlite")(mock.Mock())
+
+ self.assertEqual(
+ listen_evt.mock_calls[0][1][-1],
+ session._thread_yield
+ )
+
+
+# NOTE(dhellmann): This test no longer works as written. The code in
+# oslo_db.sqlalchemy.session filters out lines from modules under
+# oslo_db, and now this test is under oslo_db, so the test filename
+# does not appear in the context for the error message. LP #1405376
+
+# class PatchStacktraceTest(test_base.DbTestCase):
+
+# def test_trace(self):
+# engine = self.engine
+# session._add_trace_comments(engine)
+# conn = engine.connect()
+# with mock.patch.object(engine.dialect, "do_execute") as mock_exec:
+
+# conn.execute("select * from table")
+
+# call = mock_exec.mock_calls[0]
+
+# # we're the caller, see that we're in there
+# self.assertIn("oslo_db/tests/sqlalchemy/test_sqlalchemy.py",
+# call[1][1])
diff --git a/oslo_db/tests/sqlalchemy/test_utils.py b/oslo_db/tests/sqlalchemy/test_utils.py
new file mode 100644
index 0000000..509cf48
--- /dev/null
+++ b/oslo_db/tests/sqlalchemy/test_utils.py
@@ -0,0 +1,1090 @@
+# Copyright (c) 2013 Boris Pavlovic (boris@pavlovic.me).
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+import uuid
+
+import fixtures
+import mock
+from oslotest import base as test_base
+from oslotest import moxstubout
+import six
+from six.moves.urllib import parse
+import sqlalchemy
+from sqlalchemy.dialects import mysql
+from sqlalchemy import Boolean, Index, Integer, DateTime, String, SmallInteger
+from sqlalchemy import MetaData, Table, Column, ForeignKey
+from sqlalchemy.engine import reflection
+from sqlalchemy.engine import url as sa_url
+from sqlalchemy.exc import OperationalError
+from sqlalchemy.ext.declarative import declarative_base
+from sqlalchemy.sql import select
+from sqlalchemy.types import UserDefinedType, NullType
+
+from oslo_db import exception
+from oslo_db.sqlalchemy import models
+from oslo_db.sqlalchemy import provision
+from oslo_db.sqlalchemy import session
+from oslo_db.sqlalchemy import test_base as db_test_base
+from oslo_db.sqlalchemy import utils
+from oslo_db.tests import utils as test_utils
+
+
+SA_VERSION = tuple(map(int, sqlalchemy.__version__.split('.')))
+
+
+class TestSanitizeDbUrl(test_base.BaseTestCase):
+
+ def test_url_with_cred(self):
+ db_url = 'myproto://johndoe:secret@localhost/myschema'
+ expected = 'myproto://****:****@localhost/myschema'
+ actual = utils.sanitize_db_url(db_url)
+ self.assertEqual(expected, actual)
+
+ def test_url_with_no_cred(self):
+ db_url = 'sqlite:///mysqlitefile'
+ actual = utils.sanitize_db_url(db_url)
+ self.assertEqual(db_url, actual)
+
+
+class CustomType(UserDefinedType):
+ """Dummy column type for testing unsupported types."""
+ def get_col_spec(self):
+ return "CustomType"
+
+
+class FakeModel(object):
+ def __init__(self, values):
+ self.values = values
+
+ def __getattr__(self, name):
+ try:
+ value = self.values[name]
+ except KeyError:
+ raise AttributeError(name)
+ return value
+
+ def __getitem__(self, key):
+ if key in self.values:
+ return self.values[key]
+ else:
+ raise NotImplementedError()
+
+ def __repr__(self):
+ return '<FakeModel: %s>' % self.values
+
+
+class TestPaginateQuery(test_base.BaseTestCase):
+ def setUp(self):
+ super(TestPaginateQuery, self).setUp()
+ mox_fixture = self.useFixture(moxstubout.MoxStubout())
+ self.mox = mox_fixture.mox
+ self.query = self.mox.CreateMockAnything()
+ self.mox.StubOutWithMock(sqlalchemy, 'asc')
+ self.mox.StubOutWithMock(sqlalchemy, 'desc')
+ self.marker = FakeModel({
+ 'user_id': 'user',
+ 'project_id': 'p',
+ 'snapshot_id': 's',
+ })
+ self.model = FakeModel({
+ 'user_id': 'user',
+ 'project_id': 'project',
+ 'snapshot_id': 'snapshot',
+ })
+
+ def test_paginate_query_no_pagination_no_sort_dirs(self):
+ sqlalchemy.asc('user').AndReturn('asc_3')
+ self.query.order_by('asc_3').AndReturn(self.query)
+ sqlalchemy.asc('project').AndReturn('asc_2')
+ self.query.order_by('asc_2').AndReturn(self.query)
+ sqlalchemy.asc('snapshot').AndReturn('asc_1')
+ self.query.order_by('asc_1').AndReturn(self.query)
+ self.query.limit(5).AndReturn(self.query)
+ self.mox.ReplayAll()
+ utils.paginate_query(self.query, self.model, 5,
+ ['user_id', 'project_id', 'snapshot_id'])
+
+ def test_paginate_query_no_pagination(self):
+ sqlalchemy.asc('user').AndReturn('asc')
+ self.query.order_by('asc').AndReturn(self.query)
+ sqlalchemy.desc('project').AndReturn('desc')
+ self.query.order_by('desc').AndReturn(self.query)
+ self.query.limit(5).AndReturn(self.query)
+ self.mox.ReplayAll()
+ utils.paginate_query(self.query, self.model, 5,
+ ['user_id', 'project_id'],
+ sort_dirs=['asc', 'desc'])
+
+ def test_paginate_query_attribute_error(self):
+ sqlalchemy.asc('user').AndReturn('asc')
+ self.query.order_by('asc').AndReturn(self.query)
+ self.mox.ReplayAll()
+ self.assertRaises(exception.InvalidSortKey,
+ utils.paginate_query, self.query,
+ self.model, 5, ['user_id', 'non-existent key'])
+
+ def test_paginate_query_assertion_error(self):
+ self.mox.ReplayAll()
+ self.assertRaises(AssertionError,
+ utils.paginate_query, self.query,
+ self.model, 5, ['user_id'],
+ marker=self.marker,
+ sort_dir='asc', sort_dirs=['asc'])
+
+ def test_paginate_query_assertion_error_2(self):
+ self.mox.ReplayAll()
+ self.assertRaises(AssertionError,
+ utils.paginate_query, self.query,
+ self.model, 5, ['user_id'],
+ marker=self.marker,
+ sort_dir=None, sort_dirs=['asc', 'desk'])
+
+ def test_paginate_query(self):
+ sqlalchemy.asc('user').AndReturn('asc_1')
+ self.query.order_by('asc_1').AndReturn(self.query)
+ sqlalchemy.desc('project').AndReturn('desc_1')
+ self.query.order_by('desc_1').AndReturn(self.query)
+ self.mox.StubOutWithMock(sqlalchemy.sql, 'and_')
+ sqlalchemy.sql.and_(False).AndReturn('some_crit')
+ sqlalchemy.sql.and_(True, False).AndReturn('another_crit')
+ self.mox.StubOutWithMock(sqlalchemy.sql, 'or_')
+ sqlalchemy.sql.or_('some_crit', 'another_crit').AndReturn('some_f')
+ self.query.filter('some_f').AndReturn(self.query)
+ self.query.limit(5).AndReturn(self.query)
+ self.mox.ReplayAll()
+ utils.paginate_query(self.query, self.model, 5,
+ ['user_id', 'project_id'],
+ marker=self.marker,
+ sort_dirs=['asc', 'desc'])
+
+ def test_paginate_query_value_error(self):
+ sqlalchemy.asc('user').AndReturn('asc_1')
+ self.query.order_by('asc_1').AndReturn(self.query)
+ self.mox.ReplayAll()
+ self.assertRaises(ValueError, utils.paginate_query,
+ self.query, self.model, 5, ['user_id', 'project_id'],
+ marker=self.marker, sort_dirs=['asc', 'mixed'])
+
+
+class TestMigrationUtils(db_test_base.DbTestCase):
+
+ """Class for testing utils that are used in db migrations."""
+
+ def setUp(self):
+ super(TestMigrationUtils, self).setUp()
+ self.meta = MetaData(bind=self.engine)
+ self.conn = self.engine.connect()
+ self.addCleanup(self.meta.drop_all)
+ self.addCleanup(self.conn.close)
+
+ def _populate_db_for_drop_duplicate_entries(self, engine, meta,
+ table_name):
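+ # Rows sharing the same (b, c) pair are duplicates; the row with
+ # the highest id in each group is the one expected to be kept.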
+ values = [
+ {'id': 11, 'a': 3, 'b': 10, 'c': 'abcdef'},
+ {'id': 12, 'a': 5, 'b': 10, 'c': 'abcdef'},
+ {'id': 13, 'a': 6, 'b': 10, 'c': 'abcdef'},
+ {'id': 14, 'a': 7, 'b': 10, 'c': 'abcdef'},
+ {'id': 21, 'a': 1, 'b': 20, 'c': 'aa'},
+ {'id': 31, 'a': 1, 'b': 20, 'c': 'bb'},
+ {'id': 41, 'a': 1, 'b': 30, 'c': 'aef'},
+ {'id': 42, 'a': 2, 'b': 30, 'c': 'aef'},
+ {'id': 43, 'a': 3, 'b': 30, 'c': 'aef'}
+ ]
+
+ test_table = Table(table_name, meta,
+ Column('id', Integer, primary_key=True,
+ nullable=False),
+ Column('a', Integer),
+ Column('b', Integer),
+ Column('c', String(255)),
+ Column('deleted', Integer, default=0),
+ Column('deleted_at', DateTime),
+ Column('updated_at', DateTime))
+
+ test_table.create()
+ engine.execute(test_table.insert(), values)
+ return test_table, values
+
+ def test_drop_old_duplicate_entries_from_table(self):
+ table_name = "__test_tmp_table__"
+
+ test_table, values = self._populate_db_for_drop_duplicate_entries(
+ self.engine, self.meta, table_name)
+ utils.drop_old_duplicate_entries_from_table(
+ self.engine, table_name, False, 'b', 'c')
+
+ uniq_values = set()
+ expected_ids = []
+ for value in sorted(values, key=lambda x: x['id'], reverse=True):
+ uniq_value = (('b', value['b']), ('c', value['c']))
+ if uniq_value in uniq_values:
+ continue
+ uniq_values.add(uniq_value)
+ expected_ids.append(value['id'])
+
+ real_ids = [row[0] for row in
+ self.engine.execute(select([test_table.c.id])).fetchall()]
+
+ self.assertEqual(len(real_ids), len(expected_ids))
+ for id_ in expected_ids:
+ self.assertTrue(id_ in real_ids)
+
+ def test_drop_dup_entries_in_file_conn(self):
+ table_name = "__test_tmp_table__"
+ tmp_db_file = self.create_tempfiles([['name', '']], ext='.sql')[0]
+ in_file_engine = session.EngineFacade(
+ 'sqlite:///%s' % tmp_db_file).get_engine()
+ meta = MetaData()
+ meta.bind = in_file_engine
+ test_table, values = self._populate_db_for_drop_duplicate_entries(
+ in_file_engine, meta, table_name)
+ utils.drop_old_duplicate_entries_from_table(
+ in_file_engine, table_name, False, 'b', 'c')
+
+ def test_drop_old_duplicate_entries_from_table_soft_delete(self):
+ table_name = "__test_tmp_table__"
+
+ table, values = self._populate_db_for_drop_duplicate_entries(
+ self.engine, self.meta, table_name)
+ utils.drop_old_duplicate_entries_from_table(self.engine, table_name,
+ True, 'b', 'c')
+ uniq_values = set()
+ expected_values = []
+ soft_deleted_values = []
+
+ for value in sorted(values, key=lambda x: x['id'], reverse=True):
+ uniq_value = (('b', value['b']), ('c', value['c']))
+ if uniq_value in uniq_values:
+ soft_deleted_values.append(value)
+ continue
+ uniq_values.add(uniq_value)
+ expected_values.append(value)
+
+ base_select = table.select()
+
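+ # Soft-deleted rows are marked by setting the 'deleted' column to
+ # the row's id; live rows keep deleted != id.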
+ rows_select = base_select.where(table.c.deleted != table.c.id)
+ row_ids = [row['id'] for row in
+ self.engine.execute(rows_select).fetchall()]
+ self.assertEqual(len(row_ids), len(expected_values))
+ for value in expected_values:
+ self.assertTrue(value['id'] in row_ids)
+
+ deleted_rows_select = base_select.where(
+ table.c.deleted == table.c.id)
+ deleted_rows_ids = [row['id'] for row in
+ self.engine.execute(
+ deleted_rows_select).fetchall()]
+ self.assertEqual(len(deleted_rows_ids),
+ len(values) - len(row_ids))
+ for value in soft_deleted_values:
+ self.assertTrue(value['id'] in deleted_rows_ids)
+
+ def test_change_deleted_column_type_does_not_drop_index(self):
+ table_name = 'abc'
+
+ indexes = {
+ 'idx_a_deleted': ['a', 'deleted'],
+ 'idx_b_deleted': ['b', 'deleted'],
+ 'idx_a': ['a']
+ }
+
+ index_instances = [Index(name, *columns)
+ for name, columns in six.iteritems(indexes)]
+
+ table = Table(table_name, self.meta,
+ Column('id', Integer, primary_key=True),
+ Column('a', String(255)),
+ Column('b', String(255)),
+ Column('deleted', Boolean),
+ *index_instances)
+ table.create()
+ utils.change_deleted_column_type_to_id_type(self.engine, table_name)
+ utils.change_deleted_column_type_to_boolean(self.engine, table_name)
+
+ insp = reflection.Inspector.from_engine(self.engine)
+ real_indexes = insp.get_indexes(table_name)
+ self.assertEqual(len(real_indexes), 3)
+ for index in real_indexes:
+ name = index['name']
+ self.assertIn(name, indexes)
+ self.assertEqual(set(index['column_names']),
+ set(indexes[name]))
+
+ def test_change_deleted_column_type_to_id_type_integer(self):
+ table_name = 'abc'
+ table = Table(table_name, self.meta,
+ Column('id', Integer, primary_key=True),
+ Column('deleted', Boolean))
+ table.create()
+ utils.change_deleted_column_type_to_id_type(self.engine, table_name)
+
+ table = utils.get_table(self.engine, table_name)
+ self.assertIsInstance(table.c.deleted.type, Integer)
+
+ def test_change_deleted_column_type_to_id_type_string(self):
+ table_name = 'abc'
+ table = Table(table_name, self.meta,
+ Column('id', String(255), primary_key=True),
+ Column('deleted', Boolean))
+ table.create()
+ utils.change_deleted_column_type_to_id_type(self.engine, table_name)
+
+ table = utils.get_table(self.engine, table_name)
+ self.assertIsInstance(table.c.deleted.type, String)
+
+ @db_test_base.backend_specific('sqlite')
+ def test_change_deleted_column_type_to_id_type_custom(self):
+ table_name = 'abc'
+ table = Table(table_name, self.meta,
+ Column('id', Integer, primary_key=True),
+ Column('foo', CustomType),
+ Column('deleted', Boolean))
+ table.create()
+
+ # reflection of custom types has been fixed upstream
+ if SA_VERSION < (0, 9, 0):
+ self.assertRaises(exception.ColumnError,
+ utils.change_deleted_column_type_to_id_type,
+ self.engine, table_name)
+
+ fooColumn = Column('foo', CustomType())
+ utils.change_deleted_column_type_to_id_type(self.engine, table_name,
+ foo=fooColumn)
+
+ table = utils.get_table(self.engine, table_name)
+ # NOTE(boris-42): There is no way to check that the foo column has
+ # type CustomType, since SQLAlchemy reflects it as NullType.
+ # This has been fixed upstream in recent SA versions.
+ if SA_VERSION < (0, 9, 0):
+ self.assertIsInstance(table.c.foo.type, NullType)
+ self.assertIsInstance(table.c.deleted.type, Integer)
+
+ def test_change_deleted_column_type_to_boolean(self):
+ expected_types = {'mysql': mysql.TINYINT,
+ 'ibm_db_sa': SmallInteger}
+ table_name = 'abc'
+ table = Table(table_name, self.meta,
+ Column('id', Integer, primary_key=True),
+ Column('deleted', Integer))
+ table.create()
+
+ utils.change_deleted_column_type_to_boolean(self.engine, table_name)
+
+ table = utils.get_table(self.engine, table_name)
+ self.assertIsInstance(table.c.deleted.type,
+ expected_types.get(self.engine.name, Boolean))
+
+ def test_change_deleted_column_type_to_boolean_with_fc(self):
+ expected_types = {'mysql': mysql.TINYINT,
+ 'ibm_db_sa': SmallInteger}
+ table_name_1 = 'abc'
+ table_name_2 = 'bcd'
+
+ table_1 = Table(table_name_1, self.meta,
+ Column('id', Integer, primary_key=True),
+ Column('deleted', Integer))
+ table_1.create()
+
+ table_2 = Table(table_name_2, self.meta,
+ Column('id', Integer, primary_key=True),
+ Column('foreign_id', Integer,
+ ForeignKey('%s.id' % table_name_1)),
+ Column('deleted', Integer))
+ table_2.create()
+
+ utils.change_deleted_column_type_to_boolean(self.engine, table_name_2)
+
+ table = utils.get_table(self.engine, table_name_2)
+ self.assertIsInstance(table.c.deleted.type,
+ expected_types.get(self.engine.name, Boolean))
+
+ @db_test_base.backend_specific('sqlite')
+ def test_change_deleted_column_type_to_boolean_type_custom(self):
+ table_name = 'abc'
+ table = Table(table_name, self.meta,
+ Column('id', Integer, primary_key=True),
+ Column('foo', CustomType),
+ Column('deleted', Integer))
+ table.create()
+
+ # reflection of custom types has been fixed upstream
+ if SA_VERSION < (0, 9, 0):
+ self.assertRaises(exception.ColumnError,
+ utils.change_deleted_column_type_to_boolean,
+ self.engine, table_name)
+
+ fooColumn = Column('foo', CustomType())
+ utils.change_deleted_column_type_to_boolean(self.engine, table_name,
+ foo=fooColumn)
+
+ table = utils.get_table(self.engine, table_name)
+ # NOTE(boris-42): There is no way to check that the foo column has
+ # type CustomType, since SQLAlchemy reflects it as NullType.
+ # This has been fixed upstream in recent SA versions.
+ if SA_VERSION < (0, 9, 0):
+ self.assertIsInstance(table.c.foo.type, NullType)
+ self.assertIsInstance(table.c.deleted.type, Boolean)
+
+ @db_test_base.backend_specific('sqlite')
+ def test_change_deleted_column_type_sqlite_drops_check_constraint(self):
+ table_name = 'abc'
+ table = Table(table_name, self.meta,
+ Column('id', Integer, primary_key=True),
+ Column('deleted', Boolean))
+ table.create()
+
+ utils._change_deleted_column_type_to_id_type_sqlite(self.engine,
+ table_name)
+ table = Table(table_name, self.meta, autoload=True)
+ # NOTE(I159): if the CHECK constraint has been dropped (expected
+ # behavior), any integer value can be inserted, otherwise only 1 or 0.
+ self.engine.execute(table.insert({'deleted': 10}))
+
+ def test_insert_from_select(self):
+ insert_table_name = "__test_insert_to_table__"
+ select_table_name = "__test_select_from_table__"
+ uuidstrs = []
+ for unused in range(10):
+ uuidstrs.append(uuid.uuid4().hex)
+ insert_table = Table(
+ insert_table_name, self.meta,
+ Column('id', Integer, primary_key=True,
+ nullable=False, autoincrement=True),
+ Column('uuid', String(36), nullable=False))
+ select_table = Table(
+ select_table_name, self.meta,
+ Column('id', Integer, primary_key=True,
+ nullable=False, autoincrement=True),
+ Column('uuid', String(36), nullable=False))
+
+ insert_table.create()
+ select_table.create()
+ # Add 10 rows to select_table
+ for uuidstr in uuidstrs:
+ ins_stmt = select_table.insert().values(uuid=uuidstr)
+ self.conn.execute(ins_stmt)
+
+ # Select 4 rows in one chunk from select_table
+ column = select_table.c.id
+ query_insert = select([select_table],
+ select_table.c.id < 5).order_by(column)
+ insert_statement = utils.InsertFromSelect(insert_table,
+ query_insert)
+ result_insert = self.conn.execute(insert_statement)
+ # Verify we insert 4 rows
+ self.assertEqual(result_insert.rowcount, 4)
+
+ query_all = select([insert_table]).where(
+ insert_table.c.uuid.in_(uuidstrs))
+ rows = self.conn.execute(query_all).fetchall()
+ # Verify we really have 4 rows in insert_table
+ self.assertEqual(len(rows), 4)
+
+
+class PostgresqlTestMigrations(TestMigrationUtils,
+ db_test_base.PostgreSQLOpportunisticTestCase):
+
+ """Test migrations on PostgreSQL."""
+ pass
+
+
+class MySQLTestMigrations(TestMigrationUtils,
+ db_test_base.MySQLOpportunisticTestCase):
+
+ """Test migrations on MySQL."""
+ pass
+
+
+class TestConnectionUtils(test_utils.BaseTestCase):
+
+ def setUp(self):
+ super(TestConnectionUtils, self).setUp()
+
+ self.full_credentials = {'backend': 'postgresql',
+ 'database': 'test',
+ 'user': 'dude',
+ 'passwd': 'pass'}
+
+ self.connect_string = 'postgresql://dude:pass@localhost/test'
+
+ def test_connect_string(self):
+ connect_string = utils.get_connect_string(**self.full_credentials)
+ self.assertEqual(connect_string, self.connect_string)
+
+ def test_connect_string_sqlite(self):
+ sqlite_credentials = {'backend': 'sqlite', 'database': 'test.db'}
+ connect_string = utils.get_connect_string(**sqlite_credentials)
+ self.assertEqual(connect_string, 'sqlite:///test.db')
+
+ def test_is_backend_avail(self):
+ self.mox.StubOutWithMock(sqlalchemy.engine.base.Engine, 'connect')
+ fake_connection = self.mox.CreateMockAnything()
+ fake_connection.close()
+ sqlalchemy.engine.base.Engine.connect().AndReturn(fake_connection)
+ self.mox.ReplayAll()
+
+ self.assertTrue(utils.is_backend_avail(**self.full_credentials))
+
+ def test_is_backend_unavail(self):
+ log = self.useFixture(fixtures.FakeLogger())
+ err = OperationalError("Can't connect to database", None, None)
+ error_msg = "The postgresql backend is unavailable: %s\n" % err
+ self.mox.StubOutWithMock(sqlalchemy.engine.base.Engine, 'connect')
+ sqlalchemy.engine.base.Engine.connect().AndRaise(err)
+ self.mox.ReplayAll()
+ self.assertFalse(utils.is_backend_avail(**self.full_credentials))
+ self.assertEqual(error_msg, log.output)
+
+ def test_ensure_backend_available(self):
+ self.mox.StubOutWithMock(sqlalchemy.engine.base.Engine, 'connect')
+ fake_connection = self.mox.CreateMockAnything()
+ fake_connection.close()
+ sqlalchemy.engine.base.Engine.connect().AndReturn(fake_connection)
+ self.mox.ReplayAll()
+
+ eng = provision.Backend._ensure_backend_available(self.connect_string)
+ self.assertIsInstance(eng, sqlalchemy.engine.base.Engine)
+ self.assertEqual(self.connect_string, str(eng.url))
+
+ def test_ensure_backend_available_no_connection_raises(self):
+ log = self.useFixture(fixtures.FakeLogger())
+ err = OperationalError("Can't connect to database", None, None)
+ self.mox.StubOutWithMock(sqlalchemy.engine.base.Engine, 'connect')
+ sqlalchemy.engine.base.Engine.connect().AndRaise(err)
+ self.mox.ReplayAll()
+
+ exc = self.assertRaises(
+ exception.BackendNotAvailable,
+ provision.Backend._ensure_backend_available, self.connect_string
+ )
+ self.assertEqual("Could not connect", str(exc))
+ self.assertEqual(
+ "The postgresql backend is unavailable: %s" % err,
+ log.output.strip())
+
+ def test_ensure_backend_available_no_dbapi_raises(self):
+ log = self.useFixture(fixtures.FakeLogger())
+ self.mox.StubOutWithMock(sqlalchemy, 'create_engine')
+ sqlalchemy.create_engine(
+ sa_url.make_url(self.connect_string)).AndRaise(
+ ImportError("Can't import DBAPI module foobar"))
+ self.mox.ReplayAll()
+
+ exc = self.assertRaises(
+ exception.BackendNotAvailable,
+ provision.Backend._ensure_backend_available, self.connect_string
+ )
+ self.assertEqual("No DBAPI installed", str(exc))
+ self.assertEqual(
+ "The postgresql backend is unavailable: Can't import "
+ "DBAPI module foobar", log.output.strip())
+
+ def test_get_db_connection_info(self):
+ conn_pieces = parse.urlparse(self.connect_string)
+ self.assertEqual(utils.get_db_connection_info(conn_pieces),
+ ('dude', 'pass', 'test', 'localhost'))
+
+ def test_connect_string_host(self):
+ self.full_credentials['host'] = 'myhost'
+ connect_string = utils.get_connect_string(**self.full_credentials)
+ self.assertEqual(connect_string, 'postgresql://dude:pass@myhost/test')
+
+
+class MyModelSoftDeletedProjectId(declarative_base(), models.ModelBase,
+ models.SoftDeleteMixin):
+ __tablename__ = 'soft_deleted_project_id_test_model'
+ id = Column(Integer, primary_key=True)
+ project_id = Column(Integer)
+
+
+class MyModel(declarative_base(), models.ModelBase):
+ __tablename__ = 'test_model'
+ id = Column(Integer, primary_key=True)
+
+
+class MyModelSoftDeleted(declarative_base(), models.ModelBase,
+ models.SoftDeleteMixin):
+ __tablename__ = 'soft_deleted_test_model'
+ id = Column(Integer, primary_key=True)
+
+
+class TestModelQuery(test_base.BaseTestCase):
+
+ def setUp(self):
+ super(TestModelQuery, self).setUp()
+
+ self.session = mock.MagicMock()
+ self.session.query.return_value = self.session.query
+ self.session.query.filter.return_value = self.session.query
+
+ def test_wrong_model(self):
+ self.assertRaises(TypeError, utils.model_query,
+ FakeModel, session=self.session)
+
+ def test_no_soft_deleted(self):
+ self.assertRaises(ValueError, utils.model_query,
+ MyModel, session=self.session, deleted=True)
+
+ def test_deleted_false(self):
+ mock_query = utils.model_query(
+ MyModelSoftDeleted, session=self.session, deleted=False)
+
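+ # deleted=False should produce a filter on the column's default
+ # (not-deleted) value.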
+ deleted_filter = mock_query.filter.call_args[0][0]
+ self.assertEqual(str(deleted_filter),
+ 'soft_deleted_test_model.deleted = :deleted_1')
+ self.assertEqual(deleted_filter.right.value,
+ MyModelSoftDeleted.__mapper__.c.deleted.default.arg)
+
+ def test_deleted_true(self):
+ mock_query = utils.model_query(
+ MyModelSoftDeleted, session=self.session, deleted=True)
+
+ deleted_filter = mock_query.filter.call_args[0][0]
+ self.assertEqual(str(deleted_filter),
+ 'soft_deleted_test_model.deleted != :deleted_1')
+ self.assertEqual(deleted_filter.right.value,
+ MyModelSoftDeleted.__mapper__.c.deleted.default.arg)
+
+ @mock.patch.object(utils, "_read_deleted_filter")
+ def test_no_deleted_value(self, _read_deleted_filter):
+ utils.model_query(MyModelSoftDeleted, session=self.session)
+ self.assertEqual(_read_deleted_filter.call_count, 0)
+
+ def test_project_filter(self):
+ project_id = 10
+
+ mock_query = utils.model_query(
+ MyModelSoftDeletedProjectId, session=self.session,
+ project_only=True, project_id=project_id)
+
+ deleted_filter = mock_query.filter.call_args[0][0]
+ self.assertEqual(
+ str(deleted_filter),
+ 'soft_deleted_project_id_test_model.project_id = :project_id_1')
+ self.assertEqual(deleted_filter.right.value, project_id)
+
+ def test_project_filter_wrong_model(self):
+ self.assertRaises(ValueError, utils.model_query,
+ MyModelSoftDeleted, session=self.session,
+ project_id=10)
+
+ def test_project_filter_allow_none(self):
+ mock_query = utils.model_query(
+ MyModelSoftDeletedProjectId,
+ session=self.session, project_id=(10, None))
+
+ self.assertEqual(
+ str(mock_query.filter.call_args[0][0]),
+ 'soft_deleted_project_id_test_model.project_id'
+ ' IN (:project_id_1, NULL)'
+ )
+
+ def test_model_query_common(self):
+ utils.model_query(MyModel, args=(MyModel.id,), session=self.session)
+ self.session.query.assert_called_with(MyModel.id)
+
+
+class TestUtils(db_test_base.DbTestCase):
+ def setUp(self):
+ super(TestUtils, self).setUp()
+ meta = MetaData(bind=self.engine)
+ self.test_table = Table(
+ 'test_table',
+ meta,
+ Column('a', Integer),
+ Column('b', Integer)
+ )
+ self.test_table.create()
+ self.addCleanup(meta.drop_all)
+
+ def test_index_exists(self):
+ self.assertFalse(utils.index_exists(self.engine, 'test_table',
+ 'new_index'))
+ Index('new_index', self.test_table.c.a).create(self.engine)
+ self.assertTrue(utils.index_exists(self.engine, 'test_table',
+ 'new_index'))
+
+ def test_add_index(self):
+ self.assertFalse(utils.index_exists(self.engine, 'test_table',
+ 'new_index'))
+ utils.add_index(self.engine, 'test_table', 'new_index', ('a',))
+ self.assertTrue(utils.index_exists(self.engine, 'test_table',
+ 'new_index'))
+
+ def test_add_existing_index(self):
+ Index('new_index', self.test_table.c.a).create(self.engine)
+ self.assertRaises(ValueError, utils.add_index, self.engine,
+ 'test_table', 'new_index', ('a',))
+
+ def test_drop_index(self):
+ Index('new_index', self.test_table.c.a).create(self.engine)
+ utils.drop_index(self.engine, 'test_table', 'new_index')
+ self.assertFalse(utils.index_exists(self.engine, 'test_table',
+ 'new_index'))
+
+ def test_drop_nonexistent_index(self):
+ self.assertRaises(ValueError, utils.drop_index, self.engine,
+ 'test_table', 'new_index')
+
+ @mock.patch('oslo_db.sqlalchemy.utils.drop_index')
+ @mock.patch('oslo_db.sqlalchemy.utils.add_index')
+ def test_change_index_columns(self, add_index, drop_index):
+ utils.change_index_columns(self.engine, 'test_table', 'a_index',
+ ('a',))
+ utils.drop_index.assert_called_once_with(self.engine, 'test_table',
+ 'a_index')
+ utils.add_index.assert_called_once_with(self.engine, 'test_table',
+ 'a_index', ('a',))
+
+ def test_column_exists(self):
+ for col in ['a', 'b']:
+ self.assertTrue(utils.column_exists(self.engine, 'test_table',
+ col))
+ self.assertFalse(utils.column_exists(self.engine, 'test_table',
+ 'fake_column'))
+
+
+class TestUtilsMysqlOpportunistically(
+ TestUtils, db_test_base.MySQLOpportunisticTestCase):
+ pass
+
+
+class TestUtilsPostgresqlOpportunistically(
+ TestUtils, db_test_base.PostgreSQLOpportunisticTestCase):
+ pass
+
+
+class TestDialectFunctionDispatcher(test_base.BaseTestCase):
+ def _single_fixture(self):
+ callable_fn = mock.Mock()
+
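+ # Register a '*' default plus per-dialect variants; dispatch_for()
+ # returns the same dispatcher object, so registrations chain.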
+ dispatcher = orig = utils.dispatch_for_dialect("*")(
+ callable_fn.default)
+ dispatcher = dispatcher.dispatch_for("sqlite")(callable_fn.sqlite)
+ dispatcher = dispatcher.dispatch_for("mysql+mysqldb")(
+ callable_fn.mysql_mysqldb)
+ dispatcher = dispatcher.dispatch_for("postgresql")(
+ callable_fn.postgresql)
+
+ self.assertTrue(dispatcher is orig)
+
+ return dispatcher, callable_fn
+
+ def _multiple_fixture(self):
+ callable_fn = mock.Mock()
+
+ for targ in [
+ callable_fn.default,
+ callable_fn.sqlite,
+ callable_fn.mysql_mysqldb,
+ callable_fn.postgresql,
+ callable_fn.postgresql_psycopg2,
+ callable_fn.pyodbc
+ ]:
+ targ.return_value = None
+
+ dispatcher = orig = utils.dispatch_for_dialect("*", multiple=True)(
+ callable_fn.default)
+ dispatcher = dispatcher.dispatch_for("sqlite")(callable_fn.sqlite)
+ dispatcher = dispatcher.dispatch_for("mysql+mysqldb")(
+ callable_fn.mysql_mysqldb)
+ dispatcher = dispatcher.dispatch_for("postgresql+*")(
+ callable_fn.postgresql)
+ dispatcher = dispatcher.dispatch_for("postgresql+psycopg2")(
+ callable_fn.postgresql_psycopg2)
+ dispatcher = dispatcher.dispatch_for("*+pyodbc")(
+ callable_fn.pyodbc)
+
+ self.assertTrue(dispatcher is orig)
+
+ return dispatcher, callable_fn
+
+ def test_single(self):
+
+ dispatcher, callable_fn = self._single_fixture()
+ dispatcher("sqlite://", 1)
+ dispatcher("postgresql+psycopg2://u:p@h/t", 2)
+ dispatcher("mysql://u:p@h/t", 3)
+ dispatcher("mysql+mysqlconnector://u:p@h/t", 4)
+
+ self.assertEqual(
+ [
+ mock.call.sqlite('sqlite://', 1),
+ mock.call.postgresql("postgresql+psycopg2://u:p@h/t", 2),
+ mock.call.mysql_mysqldb("mysql://u:p@h/t", 3),
+ mock.call.default("mysql+mysqlconnector://u:p@h/t", 4)
+ ],
+ callable_fn.mock_calls)
+
+ def test_single_kwarg(self):
+ dispatcher, callable_fn = self._single_fixture()
+ dispatcher("sqlite://", foo='bar')
+ dispatcher("postgresql+psycopg2://u:p@h/t", 1, x='y')
+
+ self.assertEqual(
+ [
+ mock.call.sqlite('sqlite://', foo='bar'),
+ mock.call.postgresql(
+ "postgresql+psycopg2://u:p@h/t",
+ 1, x='y'),
+ ],
+ callable_fn.mock_calls)
+
+ def test_dispatch_on_target(self):
+ callable_fn = mock.Mock()
+
+ @utils.dispatch_for_dialect("*")
+ def default_fn(url, x, y):
+ callable_fn.default(url, x, y)
+
+ @default_fn.dispatch_for("sqlite")
+ def sqlite_fn(url, x, y):
+ callable_fn.sqlite(url, x, y)
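+            # chain explicitly to the handler registered for "*"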
+ default_fn.dispatch_on_drivername("*")(url, x, y)
+
+ default_fn("sqlite://", 4, 5)
+ self.assertEqual(
+ [
+ mock.call.sqlite("sqlite://", 4, 5),
+ mock.call.default("sqlite://", 4, 5)
+ ],
+ callable_fn.mock_calls
+ )
+
+ def test_single_no_dispatcher(self):
+ callable_fn = mock.Mock()
+
+ dispatcher = utils.dispatch_for_dialect("sqlite")(callable_fn.sqlite)
+ dispatcher = dispatcher.dispatch_for("mysql")(callable_fn.mysql)
+ exc = self.assertRaises(
+ ValueError,
+ dispatcher, "postgresql://s:t@localhost/test"
+ )
+ self.assertEqual(
+ "No default function found for driver: 'postgresql+psycopg2'",
+ str(exc)
+ )
+
+ def test_multiple_no_dispatcher(self):
+ callable_fn = mock.Mock()
+
+ dispatcher = utils.dispatch_for_dialect("sqlite", multiple=True)(
+ callable_fn.sqlite)
+ dispatcher = dispatcher.dispatch_for("mysql")(callable_fn.mysql)
+ dispatcher("postgresql://s:t@localhost/test")
+ self.assertEqual(
+ [], callable_fn.mock_calls
+ )
+
+ def test_multiple_no_driver(self):
+ callable_fn = mock.Mock(
+ default=mock.Mock(return_value=None),
+ sqlite=mock.Mock(return_value=None)
+ )
+
+ dispatcher = utils.dispatch_for_dialect("*", multiple=True)(
+ callable_fn.default)
+ dispatcher = dispatcher.dispatch_for("sqlite")(
+ callable_fn.sqlite)
+
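+        # dispatch_on_drivername routes on a bare drivername string
+        # rather than on a URL or Engine.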
+ dispatcher.dispatch_on_drivername("sqlite")("foo")
+ self.assertEqual(
+ [mock.call.sqlite("foo"), mock.call.default("foo")],
+ callable_fn.mock_calls
+ )
+
+ def test_multiple_nesting(self):
+ callable_fn = mock.Mock(
+ default=mock.Mock(return_value=None),
+ mysql=mock.Mock(return_value=None)
+ )
+
+ dispatcher = utils.dispatch_for_dialect("*", multiple=True)(
+ callable_fn.default)
+
+ dispatcher = dispatcher.dispatch_for("mysql+mysqlconnector")(
+ dispatcher.dispatch_for("mysql+mysqldb")(
+ callable_fn.mysql
+ )
+ )
+
+ mysqldb_url = sqlalchemy.engine.url.make_url("mysql+mysqldb://")
+ mysqlconnector_url = sqlalchemy.engine.url.make_url(
+ "mysql+mysqlconnector://")
+ sqlite_url = sqlalchemy.engine.url.make_url("sqlite://")
+
+ dispatcher(mysqldb_url, 1)
+ dispatcher(mysqlconnector_url, 2)
+ dispatcher(sqlite_url, 3)
+
+ self.assertEqual(
+ [
+ mock.call.mysql(mysqldb_url, 1),
+ mock.call.default(mysqldb_url, 1),
+ mock.call.mysql(mysqlconnector_url, 2),
+ mock.call.default(mysqlconnector_url, 2),
+ mock.call.default(sqlite_url, 3)
+ ],
+ callable_fn.mock_calls
+ )
+
+ def test_single_retval(self):
+ dispatcher, callable_fn = self._single_fixture()
+ callable_fn.mysql_mysqldb.return_value = 5
+
+ self.assertEqual(
+            5, dispatcher("mysql://u:p@h/t", 3)
+ )
+
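+    # The first dispatch argument may be a URL string, a sqlalchemy URL
+    # object, or an Engine; anything else raises ValueError (see
+    # test_invalid_target below).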
+ def test_engine(self):
+ eng = sqlalchemy.create_engine("sqlite:///path/to/my/db.db")
+ dispatcher, callable_fn = self._single_fixture()
+
+ dispatcher(eng)
+ self.assertEqual(
+ [mock.call.sqlite(eng)],
+ callable_fn.mock_calls
+ )
+
+ def test_url(self):
+ url = sqlalchemy.engine.url.make_url(
+ "mysql+mysqldb://scott:tiger@localhost/test")
+ dispatcher, callable_fn = self._single_fixture()
+
+ dispatcher(url, 15)
+ self.assertEqual(
+ [mock.call.mysql_mysqldb(url, 15)],
+ callable_fn.mock_calls
+ )
+
+ def test_invalid_target(self):
+ dispatcher, callable_fn = self._single_fixture()
+
+ exc = self.assertRaises(
+ ValueError,
+ dispatcher, 20
+ )
+ self.assertEqual("Invalid target type: 20", str(exc))
+
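+    # Dispatch expressions take the form "database[+driver]", where either
+    # part may be the wildcard "*"; a bare "+driver" with no database part
+    # is rejected.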
+ def test_invalid_dispatch(self):
+ callable_fn = mock.Mock()
+
+ dispatcher = utils.dispatch_for_dialect("*")(callable_fn.default)
+
+ exc = self.assertRaises(
+ ValueError,
+ dispatcher.dispatch_for("+pyodbc"), callable_fn.pyodbc
+ )
+ self.assertEqual(
+ "Couldn't parse database[+driver]: '+pyodbc'",
+ str(exc)
+ )
+
+ def test_single_only_one_target(self):
+ callable_fn = mock.Mock()
+
+ dispatcher = utils.dispatch_for_dialect("*")(callable_fn.default)
+ dispatcher = dispatcher.dispatch_for("sqlite")(callable_fn.sqlite)
+
+ exc = self.assertRaises(
+ TypeError,
+ dispatcher.dispatch_for("sqlite"), callable_fn.sqlite2
+ )
+ self.assertEqual(
+ "Multiple functions for expression 'sqlite'", str(exc)
+ )
+
+ def test_multiple(self):
+ dispatcher, callable_fn = self._multiple_fixture()
+
+ dispatcher("postgresql+pyodbc://", 1)
+ dispatcher("mysql://", 2)
+ dispatcher("ibm_db_sa+db2://", 3)
+ dispatcher("postgresql+psycopg2://", 4)
+
+ # TODO(zzzeek): there is a deterministic order here, but we might
+ # want to tweak it, or maybe provide options. default first?
+ # most specific first? is *+pyodbc or postgresql+* more specific?
+ self.assertEqual(
+ [
+ mock.call.postgresql('postgresql+pyodbc://', 1),
+ mock.call.pyodbc('postgresql+pyodbc://', 1),
+ mock.call.default('postgresql+pyodbc://', 1),
+ mock.call.mysql_mysqldb('mysql://', 2),
+ mock.call.default('mysql://', 2),
+ mock.call.default('ibm_db_sa+db2://', 3),
+ mock.call.postgresql_psycopg2('postgresql+psycopg2://', 4),
+ mock.call.postgresql('postgresql+psycopg2://', 4),
+ mock.call.default('postgresql+psycopg2://', 4),
+ ],
+ callable_fn.mock_calls
+ )
+
+ def test_multiple_no_return_value(self):
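+        # A function registered on a multiple=True dispatcher must return
+        # None; a non-None return value raises TypeError.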
+ dispatcher, callable_fn = self._multiple_fixture()
+ callable_fn.sqlite.return_value = 5
+
+ exc = self.assertRaises(
+ TypeError,
+ dispatcher, "sqlite://"
+ )
+ self.assertEqual(
+ "Return value not allowed for multiple filtered function",
+ str(exc)
+ )
+
+
+class TestGetInnoDBTables(db_test_base.MySQLOpportunisticTestCase):
+
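+    # get_non_innodb_tables(engine, skip_tables=...) reports MySQL tables
+    # not backed by InnoDB; per the tests below, 'migrate_version' is
+    # among the names skipped by default.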
+ def test_all_tables_use_innodb(self):
+ self.engine.execute("CREATE TABLE customers "
+ "(a INT, b CHAR (20), INDEX (a)) ENGINE=InnoDB")
+ self.assertEqual([], utils.get_non_innodb_tables(self.engine))
+
+ def test_all_tables_use_innodb_false(self):
+ self.engine.execute("CREATE TABLE employee "
+ "(i INT) ENGINE=MEMORY")
+ self.assertEqual(['employee'],
+ utils.get_non_innodb_tables(self.engine))
+
+ def test_skip_tables_use_default_value(self):
+ self.engine.execute("CREATE TABLE migrate_version "
+ "(i INT) ENGINE=MEMORY")
+ self.assertEqual([],
+ utils.get_non_innodb_tables(self.engine))
+
+ def test_skip_tables_use_passed_value(self):
+ self.engine.execute("CREATE TABLE some_table "
+ "(i INT) ENGINE=MEMORY")
+ self.assertEqual([],
+ utils.get_non_innodb_tables(
+ self.engine, skip_tables=('some_table',)))
+
+ def test_skip_tables_use_empty_list(self):
+ self.engine.execute("CREATE TABLE some_table_3 "
+ "(i INT) ENGINE=MEMORY")
+ self.assertEqual(['some_table_3'],
+ utils.get_non_innodb_tables(
+ self.engine, skip_tables=()))
+
+ def test_skip_tables_use_several_values(self):
+ self.engine.execute("CREATE TABLE some_table_1 "
+ "(i INT) ENGINE=MEMORY")
+ self.engine.execute("CREATE TABLE some_table_2 "
+ "(i INT) ENGINE=MEMORY")
+ self.assertEqual([],
+ utils.get_non_innodb_tables(
+ self.engine,
+ skip_tables=('some_table_1', 'some_table_2')))