From 404de362a3cf8fb88e827fe0b4bf6b2d5baea405 Mon Sep 17 00:00:00 2001 From: Roman Podoliaka Date: Tue, 13 May 2014 12:45:48 +0300 Subject: Add a base test case for DB schema comparison Add a base test case that allows to check if the DB schema obtained by applying of migration scripts is equal to the one produced from models definitions. It's very important that those two stay in sync to be able to use models definitions for generation of the initial DB schema. Though, due to the using of sqlalchemy-migrate (since we had to write migration scripts manually and alembic can generate them automatically) we have many divergences in projects right now, which must be detected and fixed. Fortunately, we can now rely on alembic + sqlalchemy implementation of schema comparison instead of writing our own tools. Blueprint: sync-models-with-migrations Change-Id: I2d2c35987426dacb1f566569d23a80eee3575a58 --- oslo/db/sqlalchemy/test_migrations.py | 181 ++++++++++++++++++++++++++++++++++ tests/sqlalchemy/test_migrations.py | 108 +++++++++++++++++++- 2 files changed, 287 insertions(+), 2 deletions(-) diff --git a/oslo/db/sqlalchemy/test_migrations.py b/oslo/db/sqlalchemy/test_migrations.py index 5972d03..564564d 100644 --- a/oslo/db/sqlalchemy/test_migrations.py +++ b/oslo/db/sqlalchemy/test_migrations.py @@ -14,17 +14,28 @@ # License for the specific language governing permissions and limitations # under the License. +import abc import functools import logging import os +import pprint import subprocess +import alembic +import alembic.autogenerate +import alembic.migration import lockfile from oslotest import base as test_base +import pkg_resources as pkg +import six from six import moves from six.moves.urllib import parse import sqlalchemy +from sqlalchemy.engine import reflection import sqlalchemy.exc +from sqlalchemy import schema +import sqlalchemy.sql.expression as expr +import sqlalchemy.types as types from oslo.db.openstack.common.gettextutils import _LE from oslo.db.sqlalchemy import utils @@ -268,3 +279,173 @@ class WalkVersionsMixin(object): "engine %(engine)s"), {'version': version, 'engine': engine}) raise + + +@six.add_metaclass(abc.ABCMeta) +class ModelsMigrationsSync(object): + """A helper class for comparison of DB migration scripts and models. + + It's intended to be inherited by test cases in target projects. They have + to provide implementations for methods used internally in the test (as + we have no way to implement them here). + + test_model_sync() will run migration scripts for the engine provided and + then compare the given metadata to the one reflected from the database. + The difference between MODELS and MIGRATION scripts will be printed and + the test will fail, if the difference is not empty. + + Method include_object() can be overridden to exclude some tables from + comparison (e.g. migrate_repo). + + """ + + @abc.abstractmethod + def db_sync(self, engine): + """Run migration scripts with the given engine instance. + + This method must be implemented in subclasses and run migration scripts + for a DB the given engine is connected to. + + """ + + @abc.abstractmethod + def get_engine(self): + """Return the engine instance to be used when running tests. + + This method must be implemented in subclasses and return an engine + instance to be used when running tests. + + """ + + @abc.abstractmethod + def get_metadata(self): + """Return the metadata instance to be used for schema comparison. + + This method must be implemented in subclasses and return the metadata + instance attached to the BASE model. + + """ + + def include_object(self, object_, name, type_, reflected, compare_to): + """Return True for objects that should be compared. + + :param object_: a SchemaItem object such as a Table or Column object + :param name: the name of the object + :param type_: a string describing the type of object (e.g. "table") + :param reflected: True if the given object was produced based on + table reflection, False if it's from a local + MetaData object + :param compare_to: the object being compared against, if available, + else None + + """ + + return True + + def compare_type(self, ctxt, insp_col, meta_col, insp_type, meta_type): + """Return True if types are different, False if not. + + Return None to allow the default implementation to compare these types. + + :param ctxt: alembic MigrationContext instance + :param insp_col: reflected column + :param meta_col: column from model + :param insp_type: reflected column type + :param meta_type: column type from model + + """ + + # some backends (e.g. mysql) don't provide native boolean type + BOOLEAN_METADATA = (types.BOOLEAN, types.Boolean) + BOOLEAN_SQL = BOOLEAN_METADATA + (types.INTEGER, types.Integer) + + if issubclass(type(meta_type), BOOLEAN_METADATA): + return not issubclass(type(insp_type), BOOLEAN_SQL) + + return None # tells alembic to use the default comparison method + + def compare_server_default(self, ctxt, ins_col, meta_col, + insp_def, meta_def, rendered_meta_def): + """Compare default values between model and db table. + + Return True if the defaults are different, False if not, or None to + allow the default implementation to compare these defaults. + + :param ctxt: alembic MigrationContext instance + :param insp_col: reflected column + :param meta_col: column from model + :param insp_def: reflected column default value + :param meta_def: column default value from model + :param rendered_meta_def: rendered column default value (from model) + + """ + + if (ctxt.dialect.name == 'mysql' and + issubclass(type(meta_col.type), sqlalchemy.Boolean)): + + if meta_def is None or insp_def is None: + return meta_def != insp_def + + return not ( + isinstance(meta_def.arg, expr.True_) and insp_def == "'1'" or + isinstance(meta_def.arg, expr.False_) and insp_def == "'0'" + ) + + return None # tells alembic to use the default comparison method + + def _cleanup(self): + engine = self.get_engine() + with engine.begin() as conn: + inspector = reflection.Inspector.from_engine(engine) + metadata = schema.MetaData() + tbs = [] + all_fks = [] + + for table_name in inspector.get_table_names(): + fks = [] + for fk in inspector.get_foreign_keys(table_name): + if not fk['name']: + continue + fks.append( + schema.ForeignKeyConstraint((), (), name=fk['name']) + ) + table = schema.Table(table_name, metadata, *fks) + tbs.append(table) + all_fks.extend(fks) + + for fkc in all_fks: + conn.execute(schema.DropConstraint(fkc)) + + for table in tbs: + conn.execute(schema.DropTable(table)) + + def test_models_sync(self): + # recent versions of sqlalchemy and alembic are needed for running of + # this test, but we already have them in requirements + try: + pkg.require('sqlalchemy>=0.8.4', 'alembic>=0.6.2') + except (pkg.VersionConflict, pkg.DistributionNotFound) as e: + self.skipTest('sqlalchemy>=0.8.4 and alembic>=0.6.3 are required' + ' for running of this test: %s' % e) + + # drop all tables after a test run + self.addCleanup(self._cleanup) + + # run migration scripts + self.db_sync(self.get_engine()) + + with self.get_engine().connect() as conn: + opts = { + 'include_object': self.include_object, + 'compare_type': self.compare_type, + 'compare_server_default': self.compare_server_default, + } + mc = alembic.migration.MigrationContext.configure(conn, opts=opts) + + # compare schemas and fail with diff, if it's not empty + diff = alembic.autogenerate.compare_metadata(mc, + self.get_metadata()) + if diff: + msg = pprint.pformat(diff, indent=2, width=20) + self.fail( + "Models and migration scripts aren't in sync:\n%s" % msg) diff --git a/tests/sqlalchemy/test_migrations.py b/tests/sqlalchemy/test_migrations.py index f39bb73..6f0d747 100644 --- a/tests/sqlalchemy/test_migrations.py +++ b/tests/sqlalchemy/test_migrations.py @@ -15,12 +15,16 @@ # under the License. import mock -from oslotest import base as test_base +from oslotest import base as test +import six +import sqlalchemy as sa +import sqlalchemy.ext.declarative as sa_decl +from oslo.db.sqlalchemy import test_base from oslo.db.sqlalchemy import test_migrations as migrate -class TestWalkVersions(test_base.BaseTestCase, migrate.WalkVersionsMixin): +class TestWalkVersions(test.BaseTestCase, migrate.WalkVersionsMixin): def setUp(self): super(TestWalkVersions, self).setUp() self.migration_api = mock.MagicMock() @@ -152,3 +156,103 @@ class TestWalkVersions(test_base.BaseTestCase, migrate.WalkVersionsMixin): mock.call(self.engine, v, with_data=True) for v in versions ] self.assertEqual(upgraded, self._migrate_up.call_args_list) + + +class ModelsMigrationSyncMixin(test.BaseTestCase): + + def setUp(self): + super(ModelsMigrationSyncMixin, self).setUp() + + self.metadata = sa.MetaData() + self.metadata_migrations = sa.MetaData() + + sa.Table( + 'testtbl', self.metadata_migrations, + sa.Column('id', sa.Integer, primary_key=True), + sa.Column('spam', sa.String(10), nullable=False), + sa.Column('eggs', sa.DateTime), + sa.Column('foo', sa.Boolean, + server_default=sa.sql.expression.true()), + sa.Column('bool_wo_default', sa.Boolean), + sa.Column('bar', sa.Numeric(10, 5)), + sa.UniqueConstraint('spam', 'eggs', name='uniq_cons'), + ) + + BASE = sa_decl.declarative_base(metadata=self.metadata) + + class TestModel(BASE): + __tablename__ = 'testtbl' + __table_args__ = ( + sa.UniqueConstraint('spam', 'eggs', name='uniq_cons'), + ) + + id = sa.Column('id', sa.Integer, primary_key=True) + spam = sa.Column('spam', sa.String(10), nullable=False) + eggs = sa.Column('eggs', sa.DateTime) + foo = sa.Column('foo', sa.Boolean, + server_default=sa.sql.expression.true()) + bool_wo_default = sa.Column('bool_wo_default', sa.Boolean) + bar = sa.Column('bar', sa.Numeric(10, 5)) + + class ModelThatShouldNotBeCompared(BASE): + __tablename__ = 'testtbl2' + + id = sa.Column('id', sa.Integer, primary_key=True) + spam = sa.Column('spam', sa.String(10), nullable=False) + + def get_metadata(self): + return self.metadata + + def get_engine(self): + return self.engine + + def db_sync(self, engine): + self.metadata_migrations.create_all(bind=engine) + + def include_object(self, object_, name, type_, reflected, compare_to): + return type_ == 'table' and name == 'testtbl' or type_ == 'column' + + def _test_models_not_sync(self): + self.metadata_migrations.clear() + sa.Table( + 'testtbl', self.metadata_migrations, + sa.Column('id', sa.Integer, primary_key=True), + sa.Column('spam', sa.String(8), nullable=True), + sa.Column('eggs', sa.DateTime), + sa.Column('foo', sa.Boolean, + server_default=sa.sql.expression.false()), + sa.Column('bool_wo_default', sa.Boolean, unique=True), + sa.Column('bar', sa.BigInteger), + sa.UniqueConstraint('spam', 'foo', name='uniq_cons'), + ) + + msg = six.text_type(self.assertRaises(AssertionError, + self.test_models_sync)) + # NOTE(I159): Check mentioning of the table and columns. + # The log is invalid json, so we can't parse it and check it for + # full compliance. We have no guarantee of the log items ordering, + # so we can't use regexp. + self.assertTrue(msg.startswith( + 'Models and migration scripts aren\'t in sync:')) + self.assertIn('testtbl', msg) + self.assertIn('spam', msg) + self.assertIn('eggs', msg) + self.assertIn('foo', msg) + self.assertIn('bar', msg) + self.assertIn('bool_wo_default', msg) + + +class ModelsMigrationsSyncMysql(ModelsMigrationSyncMixin, + migrate.ModelsMigrationsSync, + test_base.MySQLOpportunisticTestCase): + + def test_models_not_sync(self): + self._test_models_not_sync() + + +class ModelsMigrationsSyncPsql(ModelsMigrationSyncMixin, + migrate.ModelsMigrationsSync, + test_base.PostgreSQLOpportunisticTestCase): + + def test_models_not_sync(self): + self._test_models_not_sync() -- cgit v1.2.1