diff options
-rw-r--r-- | oslo/db/sqlalchemy/utils.py | 33 | ||||
-rw-r--r-- | tests/sqlalchemy/test_utils.py | 44 |
2 files changed, 77 insertions, 0 deletions
diff --git a/oslo/db/sqlalchemy/utils.py b/oslo/db/sqlalchemy/utils.py index 5abd4ea..a4a79cd 100644 --- a/oslo/db/sqlalchemy/utils.py +++ b/oslo/db/sqlalchemy/utils.py @@ -36,6 +36,7 @@ from sqlalchemy import Integer from sqlalchemy import MetaData from sqlalchemy.sql.expression import literal_column from sqlalchemy.sql.expression import UpdateBase +from sqlalchemy.sql import text from sqlalchemy import String from sqlalchemy import Table from sqlalchemy.types import NullType @@ -983,3 +984,35 @@ class DialectMultiFunctionDispatcher(DialectFunctionDispatcher): "multiple filtered function") dispatch_for_dialect = DialectFunctionDispatcher.dispatch_for_dialect + + +def get_non_innodb_tables(connectable, skip_tables=('migrate_version', + 'alembic_version')): + """Get a list of tables which don't use InnoDB storage engine. + + :param connectable: a SQLAlchemy Engine or a Connection instance + :param skip_tables: a list of tables which might have a different + storage engine + """ + + query_str = """ + SELECT table_name + FROM information_schema.tables + WHERE table_schema = :database AND + engine != 'InnoDB' + """ + + params = {} + if skip_tables: + params = dict( + ('skip_%s' % i, table_name) + for i, table_name in enumerate(skip_tables) + ) + + placeholders = ', '.join(':' + p for p in params) + query_str += ' AND table_name NOT IN (%s)' % placeholders + + params['database'] = connectable.engine.url.database + query = text(query_str) + noninnodb = connectable.execute(query, **params) + return [i[0] for i in noninnodb] diff --git a/tests/sqlalchemy/test_utils.py b/tests/sqlalchemy/test_utils.py index 88c3817..a4cd3f2 100644 --- a/tests/sqlalchemy/test_utils.py +++ b/tests/sqlalchemy/test_utils.py @@ -1094,3 +1094,47 @@ class TestDialectFunctionDispatcher(test_base.BaseTestCase): "Return value not allowed for multiple filtered function", str(exc) ) + + +class TestGetInnoDBTables(db_test_base.MySQLOpportunisticTestCase): + + def test_all_tables_use_innodb(self): + self.engine.execute("CREATE TABLE customers " + "(a INT, b CHAR (20), INDEX (a)) ENGINE=InnoDB") + self.assertEqual([], utils.get_non_innodb_tables(self.engine)) + + def test_all_tables_use_innodb_false(self): + self.engine.execute("CREATE TABLE employee " + "(i INT) ENGINE=MEMORY") + self.assertEqual(['employee'], + utils.get_non_innodb_tables(self.engine)) + + def test_skip_tables_use_default_value(self): + self.engine.execute("CREATE TABLE migrate_version " + "(i INT) ENGINE=MEMORY") + self.assertEqual([], + utils.get_non_innodb_tables(self.engine)) + + def test_skip_tables_use_passed_value(self): + self.engine.execute("CREATE TABLE some_table " + "(i INT) ENGINE=MEMORY") + self.assertEqual([], + utils.get_non_innodb_tables( + self.engine, skip_tables=('some_table',))) + + def test_skip_tables_use_empty_list(self): + self.engine.execute("CREATE TABLE some_table_3 " + "(i INT) ENGINE=MEMORY") + self.assertEqual(['some_table_3'], + utils.get_non_innodb_tables( + self.engine, skip_tables=())) + + def test_skip_tables_use_several_values(self): + self.engine.execute("CREATE TABLE some_table_1 " + "(i INT) ENGINE=MEMORY") + self.engine.execute("CREATE TABLE some_table_2 " + "(i INT) ENGINE=MEMORY") + self.assertEqual([], + utils.get_non_innodb_tables( + self.engine, + skip_tables=('some_table_1', 'some_table_2'))) |