summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2015-06-24 17:35:07 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2015-06-24 17:35:07 -0400
commit94e2e3c6f15e6a84bedf0780e1502dfc7550bda7 (patch)
treea7fcbba2ee2b0ef5db4424cb25d91a63a3e06694
parent69588e424a8e0fa2b367851219f2ed1f634e8ba2 (diff)
downloadalembic-94e2e3c6f15e6a84bedf0780e1502dfc7550bda7.tar.gz
- get the new autogen tests entirely passing and cleaned up
-rw-r--r--alembic/autogenerate/api.py6
-rw-r--r--alembic/operations/base.py7
-rw-r--r--alembic/operations/ops.py61
-rw-r--r--alembic/operations/toimpl.py12
-rw-r--r--tests/_autogen_fixtures.py251
-rw-r--r--tests/test_autogen_composition.py160
-rw-r--r--tests/test_autogen_diffs.py240
-rw-r--r--tests/test_autogen_fks.py4
-rw-r--r--tests/test_autogen_indexes.py2
9 files changed, 328 insertions, 415 deletions
diff --git a/alembic/autogenerate/api.py b/alembic/autogenerate/api.py
index c4f5be6..7f0b089 100644
--- a/alembic/autogenerate/api.py
+++ b/alembic/autogenerate/api.py
@@ -116,11 +116,9 @@ def compare_metadata(context, metadata):
return diffs
-def _render_migration_diffs(
- context, template_args, imports, include_symbol=None,
- include_object=None, include_schemas=False):
+def _render_migration_diffs(context, template_args, imports):
- autogen_context, connection = _autogen_context(context, imports)
+ autogen_context = _autogen_context(context, imports)
diffs = []
_produce_net_changes(autogen_context, diffs)
diff --git a/alembic/operations/base.py b/alembic/operations/base.py
index 5162101..e577497 100644
--- a/alembic/operations/base.py
+++ b/alembic/operations/base.py
@@ -806,7 +806,7 @@ class Operations(object):
self.impl.create_table(table)
return table
- def drop_table(self, name, **kw):
+ def drop_table(self, name, schema=None, **kw):
"""Issue a "drop table" instruction using the current
migration context.
@@ -828,9 +828,10 @@ class Operations(object):
:class:`sqlalchemy.schema.Table` object created for the command.
"""
- self.impl.drop_table(
- self.schema_obj.table(name, **kw)
+ op = ops.DropTableOp(
+ name, schema=schema, table_kw=kw
)
+ self.invoke(op)
def create_index(self, name, table_name, columns, schema=None,
unique=False, quote=None, **kw):
diff --git a/alembic/operations/ops.py b/alembic/operations/ops.py
index e34ff09..3567352 100644
--- a/alembic/operations/ops.py
+++ b/alembic/operations/ops.py
@@ -7,44 +7,18 @@ to_impl = util.Dispatcher()
class MigrateOperation(object):
"""base class for migration command and organization objects."""
- def dispatch_for(self, handler):
- raise NotImplementedError()
-
class AddConstraintOp(MigrateOperation):
pass
class DropConstraintOp(MigrateOperation):
- pass
-
-
-class DropConstraintByNameOp(DropConstraintOp):
def __init__(self, name, table_name, type_=None, schema=None):
self.name = name
self.table_name = table_name
self.type_ = type_
self.schema = schema
- def dispatch_for(self, handler):
- return handler.drop_constraint
-
-
-class AddConstraintObjOp(AddConstraintOp):
- def __init__(self, constraint):
- self.constraint = constraint
-
- def dispatch_for(self, handler):
- return handler.add_constraint_obj
-
-
-class DropConstraintObjOp(DropConstraintOp):
- def __init__(self, constraint):
- self.constraint = constraint
-
- def dispatch_for(self, handler):
- return handler.drop_constraint_obj
-
class CreateUniqueConstraintOp(AddConstraintOp):
def __init__(self, name, local_cols, **kw):
@@ -52,9 +26,6 @@ class CreateUniqueConstraintOp(AddConstraintOp):
self.local_cols = local_cols
self.kw = kw
- def dispatch_for(self, handler):
- return handler.create_unique_constraint
-
class CreateCheckConstraintOp(AddConstraintOp):
def __init__(
@@ -65,9 +36,6 @@ class CreateCheckConstraintOp(AddConstraintOp):
self.schema = schema
self.kw = kw
- def dispatch_for(self, handler):
- return handler.create_check_constraint
-
class CreateIndexOp(MigrateOperation):
def __init__(
@@ -81,9 +49,6 @@ class CreateIndexOp(MigrateOperation):
self.quote = quote
self.kw = kw
- def dispatch_for(self, handler):
- return handler.create_index
-
class DropIndexOp(MigrateOperation):
def __init__(self, name, table_name=None, schema=None):
@@ -91,9 +56,6 @@ class DropIndexOp(MigrateOperation):
self.table_name = table_name
self.schema = schema
- def dispatch_for(self, handler):
- return handler.drop_index
-
class CreateTableOp(MigrateOperation):
def __init__(self, name, *columns, **kw):
@@ -101,17 +63,12 @@ class CreateTableOp(MigrateOperation):
self.columns = columns
self.kw = kw
- def dispatch_for(self, handler):
- return handler.create_table
-
class DropTableOp(MigrateOperation):
- def __init__(self, name, **kw):
+ def __init__(self, name, schema=None, table_kw=None):
self.name = name
- self.kw = kw
-
- def dispatch_for(self, handler):
- return handler.drop_table
+ self.schema = schema
+ self.table_kw = table_kw or {}
class AlterTableOp(MigrateOperation):
@@ -120,9 +77,6 @@ class AlterTableOp(MigrateOperation):
self.table_name = table_name
self.schema = schema
- def dispatch_for(self, handler):
- return handler.alter_table
-
class RenameTableOp(AlterTableOp):
@@ -130,9 +84,6 @@ class RenameTableOp(AlterTableOp):
super(RenameTableOp, self).__init__(old_table_name, schema=schema)
self.new_table_name = new_table_name
- def dispatch_for(self, handler):
- return handler.rename_table
-
class AlterColumnOp(AlterTableOp):
@@ -154,9 +105,6 @@ class AlterColumnOp(AlterTableOp):
modify_type = None
kw = None
- def dispatch_for(self, handler):
- return handler.alter_column
-
class AddColumnOp(AlterTableOp):
@@ -178,9 +126,6 @@ class BulkInsertOp(MigrateOperation):
self.rows = rows
self.multiinsert = multiinsert
- def dispatch_for(self, handler):
- return handler.bulk_insert
-
class OpContainer(MigrateOperation):
def __init__(self, ops):
diff --git a/alembic/operations/toimpl.py b/alembic/operations/toimpl.py
index b3688ff..ed50f30 100644
--- a/alembic/operations/toimpl.py
+++ b/alembic/operations/toimpl.py
@@ -2,7 +2,7 @@ from . import ops
from sqlalchemy import schema as sa_schema
-@ops.to_impl.dispatch_for(ops.AlterColumnOp, 'default')
+@ops.to_impl.dispatch_for(ops.AlterColumnOp)
def alter_column(operations, operation):
compiler = operations.impl.dialect.statement_compiler(
@@ -60,3 +60,13 @@ def alter_column(operations, operation):
for constraint in t.constraints:
if _count_constraint(constraint):
operations.impl.add_constraint(constraint)
+
+
+@ops.to_impl.dispatch_for(ops.DropTableOp)
+def drop_table(operations, operation):
+ operations.impl.drop_table(
+ operations.schema_obj.table(
+ operation.name,
+ schema=operation.schema,
+ **operation.table_kw)
+ )
diff --git a/tests/_autogen_fixtures.py b/tests/_autogen_fixtures.py
new file mode 100644
index 0000000..7ef6cbf
--- /dev/null
+++ b/tests/_autogen_fixtures.py
@@ -0,0 +1,251 @@
+from sqlalchemy import MetaData, Column, Table, Integer, String, Text, \
+ Numeric, CHAR, ForeignKey, Index, UniqueConstraint, CheckConstraint, text
+from sqlalchemy.engine.reflection import Inspector
+
+from alembic import autogenerate
+from alembic.migration import MigrationContext
+from alembic.testing import config
+from alembic.testing.env import staging_env, clear_staging_env
+from alembic.testing import eq_
+from alembic.ddl.base import _fk_spec
+
+names_in_this_test = set()
+
+from sqlalchemy import event
+
+
+@event.listens_for(Table, "after_parent_attach")
+def new_table(table, parent):
+ names_in_this_test.add(table.name)
+
+
+def _default_include_object(obj, name, type_, reflected, compare_to):
+ if type_ == "table":
+ return name in names_in_this_test
+ else:
+ return True
+
+_default_object_filters = [
+ _default_include_object
+]
+
+
+class ModelOne(object):
+ __requires__ = ('unique_constraint_reflection', )
+
+ schema = None
+
+ @classmethod
+ def _get_db_schema(cls):
+ schema = cls.schema
+
+ m = MetaData(schema=schema)
+
+ Table('user', m,
+ Column('id', Integer, primary_key=True),
+ Column('name', String(50)),
+ Column('a1', Text),
+ Column("pw", String(50)),
+ Index('pw_idx', 'pw')
+ )
+
+ Table('address', m,
+ Column('id', Integer, primary_key=True),
+ Column('email_address', String(100), nullable=False),
+ )
+
+ Table('order', m,
+ Column('order_id', Integer, primary_key=True),
+ Column("amount", Numeric(8, 2), nullable=False,
+ server_default=text("0")),
+ CheckConstraint('amount >= 0', name='ck_order_amount')
+ )
+
+ Table('extra', m,
+ Column("x", CHAR),
+ Column('uid', Integer, ForeignKey('user.id'))
+ )
+
+ return m
+
+ @classmethod
+ def _get_model_schema(cls):
+ schema = cls.schema
+
+ m = MetaData(schema=schema)
+
+ Table('user', m,
+ Column('id', Integer, primary_key=True),
+ Column('name', String(50), nullable=False),
+ Column('a1', Text, server_default="x")
+ )
+
+ Table('address', m,
+ Column('id', Integer, primary_key=True),
+ Column('email_address', String(100), nullable=False),
+ Column('street', String(50)),
+ UniqueConstraint('email_address', name="uq_email")
+ )
+
+ Table('order', m,
+ Column('order_id', Integer, primary_key=True),
+ Column('amount', Numeric(10, 2), nullable=True,
+ server_default=text("0")),
+ Column('user_id', Integer, ForeignKey('user.id')),
+ CheckConstraint('amount > -1', name='ck_order_amount'),
+ )
+
+ Table('item', m,
+ Column('id', Integer, primary_key=True),
+ Column('description', String(100)),
+ Column('order_id', Integer, ForeignKey('order.order_id')),
+ CheckConstraint('len(description) > 5')
+ )
+ return m
+
+
+class _ComparesFKs(object):
+ def _assert_fk_diff(
+ self, diff, type_, source_table, source_columns,
+ target_table, target_columns, name=None, conditional_name=None,
+ source_schema=None):
+ # the public API for ForeignKeyConstraint was not very rich
+ # in 0.7, 0.8, so here we use the well-known but slightly
+ # private API to get at its elements
+ (fk_source_schema, fk_source_table,
+ fk_source_columns, fk_target_schema, fk_target_table,
+ fk_target_columns) = _fk_spec(diff[1])
+
+ eq_(diff[0], type_)
+ eq_(fk_source_table, source_table)
+ eq_(fk_source_columns, source_columns)
+ eq_(fk_target_table, target_table)
+ eq_(fk_source_schema, source_schema)
+
+ eq_([elem.column.name for elem in diff[1].elements],
+ target_columns)
+ if conditional_name is not None:
+ if config.requirements.no_fk_names.enabled:
+ eq_(diff[1].name, None)
+ elif conditional_name == 'servergenerated':
+ fks = Inspector.from_engine(self.bind).\
+ get_foreign_keys(source_table)
+ server_fk_name = fks[0]['name']
+ eq_(diff[1].name, server_fk_name)
+ else:
+ eq_(diff[1].name, conditional_name)
+ else:
+ eq_(diff[1].name, name)
+
+
+class AutogenTest(_ComparesFKs):
+
+ def _flatten_diffs(self, diffs):
+ for d in diffs:
+ if isinstance(d, list):
+ for fd in self._flatten_diffs(d):
+ yield fd
+ else:
+ yield d
+
+ @classmethod
+ def _get_bind(cls):
+ return config.db
+
+ configure_opts = {}
+
+ @classmethod
+ def setup_class(cls):
+ staging_env()
+ cls.bind = cls._get_bind()
+ cls.m1 = cls._get_db_schema()
+ cls.m1.create_all(cls.bind)
+ cls.m2 = cls._get_model_schema()
+
+ @classmethod
+ def teardown_class(cls):
+ cls.m1.drop_all(cls.bind)
+ clear_staging_env()
+
+ def setUp(self):
+ self.conn = conn = self.bind.connect()
+ ctx_opts = {
+ 'compare_type': True,
+ 'compare_server_default': True,
+ 'target_metadata': self.m2,
+ 'upgrade_token': "upgrades",
+ 'downgrade_token': "downgrades",
+ 'alembic_module_prefix': 'op.',
+ 'sqlalchemy_module_prefix': 'sa.',
+ }
+ if self.configure_opts:
+ ctx_opts.update(self.configure_opts)
+ self.context = context = MigrationContext.configure(
+ connection=conn,
+ opts=ctx_opts
+ )
+
+ connection = context.bind
+ self.autogen_context = {
+ 'imports': set(),
+ 'connection': connection,
+ 'dialect': connection.dialect,
+ 'context': context
+ }
+
+ def tearDown(self):
+ self.conn.close()
+
+
+class AutogenFixtureTest(_ComparesFKs):
+
+ def _fixture(
+ self, m1, m2, include_schemas=False,
+ opts=None, object_filters=_default_object_filters):
+ self.metadata, model_metadata = m1, m2
+ self.metadata.create_all(self.bind)
+
+ with self.bind.connect() as conn:
+ ctx_opts = {
+ 'compare_type': True,
+ 'compare_server_default': True,
+ 'target_metadata': model_metadata,
+ 'upgrade_token': "upgrades",
+ 'downgrade_token': "downgrades",
+ 'alembic_module_prefix': 'op.',
+ 'sqlalchemy_module_prefix': 'sa.',
+ }
+ if opts:
+ ctx_opts.update(opts)
+ self.context = context = MigrationContext.configure(
+ connection=conn,
+ opts=ctx_opts
+ )
+
+ connection = context.bind
+ autogen_context = {
+ 'imports': set(),
+ 'connection': connection,
+ 'dialect': connection.dialect,
+ 'context': context,
+ 'metadata': model_metadata,
+ 'object_filters': object_filters,
+ 'include_schemas': include_schemas
+ }
+ diffs = []
+ autogenerate._produce_net_changes(
+ autogen_context, diffs
+ )
+ return diffs
+
+ reports_unnamed_constraints = False
+
+ def setUp(self):
+ staging_env()
+ self.bind = config.db
+
+ def tearDown(self):
+ if hasattr(self, 'metadata'):
+ self.metadata.drop_all(self.bind)
+ clear_staging_env()
+
diff --git a/tests/test_autogen_composition.py b/tests/test_autogen_composition.py
index 7806049..b1717ab 100644
--- a/tests/test_autogen_composition.py
+++ b/tests/test_autogen_composition.py
@@ -1,140 +1,11 @@
import re
-from sqlalchemy import MetaData, Column, Table, Integer, String, Text, \
- Numeric, CHAR, ForeignKey, INTEGER, Index, UniqueConstraint, \
- TypeDecorator, CheckConstraint, text, PrimaryKeyConstraint
-
from alembic import autogenerate
from alembic.migration import MigrationContext
from alembic.testing import TestBase
-from alembic.testing import config
-from alembic.testing.env import staging_env, clear_staging_env
from alembic.testing import eq_
-
-class ModelOne(object):
- __requires__ = ('unique_constraint_reflection', )
-
- schema = None
-
- @classmethod
- def _get_db_schema(cls):
- schema = cls.schema
-
- m = MetaData(schema=schema)
-
- Table('user', m,
- Column('id', Integer, primary_key=True),
- Column('name', String(50)),
- Column('a1', Text),
- Column("pw", String(50)),
- Index('pw_idx', 'pw')
- )
-
- Table('address', m,
- Column('id', Integer, primary_key=True),
- Column('email_address', String(100), nullable=False),
- )
-
- Table('order', m,
- Column('order_id', Integer, primary_key=True),
- Column("amount", Numeric(8, 2), nullable=False,
- server_default=text("0")),
- CheckConstraint('amount >= 0', name='ck_order_amount')
- )
-
- Table('extra', m,
- Column("x", CHAR),
- Column('uid', Integer, ForeignKey('user.id'))
- )
-
- return m
-
- @classmethod
- def _get_model_schema(cls):
- schema = cls.schema
-
- m = MetaData(schema=schema)
-
- Table('user', m,
- Column('id', Integer, primary_key=True),
- Column('name', String(50), nullable=False),
- Column('a1', Text, server_default="x")
- )
-
- Table('address', m,
- Column('id', Integer, primary_key=True),
- Column('email_address', String(100), nullable=False),
- Column('street', String(50)),
- UniqueConstraint('email_address', name="uq_email")
- )
-
- Table('order', m,
- Column('order_id', Integer, primary_key=True),
- Column('amount', Numeric(10, 2), nullable=True,
- server_default=text("0")),
- Column('user_id', Integer, ForeignKey('user.id')),
- CheckConstraint('amount > -1', name='ck_order_amount'),
- )
-
- Table('item', m,
- Column('id', Integer, primary_key=True),
- Column('description', String(100)),
- Column('order_id', Integer, ForeignKey('order.order_id')),
- CheckConstraint('len(description) > 5')
- )
- return m
-
-
-class AutogenTest(object):
-
- @classmethod
- def _get_bind(cls):
- return config.db
-
- configure_opts = {}
-
- @classmethod
- def setup_class(cls):
- staging_env()
- cls.bind = cls._get_bind()
- cls.m1 = cls._get_db_schema()
- cls.m1.create_all(cls.bind)
- cls.m2 = cls._get_model_schema()
-
- @classmethod
- def teardown_class(cls):
- cls.m1.drop_all(cls.bind)
- clear_staging_env()
-
- def setUp(self):
- self.conn = conn = self.bind.connect()
- ctx_opts = {
- 'compare_type': True,
- 'compare_server_default': True,
- 'target_metadata': self.m2,
- 'upgrade_token': "upgrades",
- 'downgrade_token': "downgrades",
- 'alembic_module_prefix': 'op.',
- 'sqlalchemy_module_prefix': 'sa.',
- }
- if self.configure_opts:
- ctx_opts.update(self.configure_opts)
- self.context = context = MigrationContext.configure(
- connection=conn,
- opts=ctx_opts
- )
-
- connection = context.bind
- self.autogen_context = {
- 'imports': set(),
- 'connection': connection,
- 'dialect': connection.dialect,
- 'context': context
- }
-
- def tearDown(self):
- self.conn.close()
+from ._autogen_fixtures import AutogenTest, ModelOne, _default_include_object
class AutogenerateDiffTest(ModelOne, AutogenTest, TestBase):
@@ -152,7 +23,7 @@ class AutogenerateDiffTest(ModelOne, AutogenTest, TestBase):
}
)
template_args = {}
- autogenerate._produce_migration_diffs(context, template_args, set())
+ autogenerate._render_migration_diffs(context, template_args, set())
eq_(re.sub(r"u'", "'", template_args['upgrades']),
"""### commands auto generated by Alembic - please adjust! ###
@@ -174,13 +45,14 @@ class AutogenerateDiffTest(ModelOne, AutogenTest, TestBase):
'downgrade_token': "downgrades",
'alembic_module_prefix': 'op.',
'sqlalchemy_module_prefix': 'sa.',
- 'render_as_batch': True
+ 'render_as_batch': True,
+ 'include_symbol': lambda name, schema: False
}
)
template_args = {}
- autogenerate._produce_migration_diffs(
+ autogenerate._render_migration_diffs(
context, template_args, set(),
- include_symbol=lambda name, schema: False
+
)
eq_(re.sub(r"u'", "'", template_args['upgrades']),
"""### commands auto generated by Alembic - please adjust! ###
@@ -195,7 +67,7 @@ class AutogenerateDiffTest(ModelOne, AutogenTest, TestBase):
"""test a full render including indentation"""
template_args = {}
- autogenerate._produce_migration_diffs(
+ autogenerate._render_migration_diffs(
self.context, template_args, set())
eq_(re.sub(r"u'", "'", template_args['upgrades']),
"""### commands auto generated by Alembic - please adjust! ###
@@ -263,7 +135,7 @@ nullable=True))
template_args = {}
self.context.opts['render_as_batch'] = True
- autogenerate._produce_migration_diffs(
+ autogenerate._render_migration_diffs(
self.context, template_args, set())
eq_(re.sub(r"u'", "'", template_args['upgrades']),
@@ -338,7 +210,6 @@ nullable=True))
### end Alembic commands ###""")
-
class AutogenerateDiffTestWSchema(ModelOne, AutogenTest, TestBase):
__only_on__ = 'postgresql'
schema = "test_schema"
@@ -354,12 +225,13 @@ class AutogenerateDiffTestWSchema(ModelOne, AutogenTest, TestBase):
'downgrade_token': "downgrades",
'alembic_module_prefix': 'op.',
'sqlalchemy_module_prefix': 'sa.',
+ 'include_symbol': lambda name, schema: False
}
)
template_args = {}
- autogenerate._produce_migration_diffs(
+ autogenerate._render_migration_diffs(
context, template_args, set(),
- include_symbol=lambda name, schema: False
+
)
eq_(re.sub(r"u'", "'", template_args['upgrades']),
"""### commands auto generated by Alembic - please adjust! ###
@@ -374,10 +246,12 @@ class AutogenerateDiffTestWSchema(ModelOne, AutogenTest, TestBase):
"""test a full render including indentation (include and schema)"""
template_args = {}
- autogenerate._produce_migration_diffs(
- self.context, template_args, set(),
- include_object=_default_include_object,
- include_schemas=True
+ self.context.opts.update({
+ 'include_object': _default_include_object,
+ 'include_schemas': True
+ })
+ autogenerate._render_migration_diffs(
+ self.context, template_args, set()
)
eq_(re.sub(r"u'", "'", template_args['upgrades']),
diff --git a/tests/test_autogen_diffs.py b/tests/test_autogen_diffs.py
index 467c624..f32fd84 100644
--- a/tests/test_autogen_diffs.py
+++ b/tests/test_autogen_diffs.py
@@ -12,178 +12,13 @@ from alembic.testing import TestBase
from alembic.testing import config
from alembic.testing import assert_raises_message
from alembic.testing.mock import Mock
-from alembic.testing.env import staging_env, clear_staging_env
from alembic.testing import eq_
-from alembic.ddl.base import _fk_spec
from alembic.util import CommandError
+from ._autogen_fixtures import \
+ AutogenTest, AutogenFixtureTest, _default_object_filters
py3k = sys.version_info >= (3, )
-names_in_this_test = set()
-
-
-def _default_include_object(obj, name, type_, reflected, compare_to):
- if type_ == "table":
- return name in names_in_this_test
- else:
- return True
-
-_default_object_filters = [
- _default_include_object
-]
-from sqlalchemy import event
-
-
-@event.listens_for(Table, "after_parent_attach")
-def new_table(table, parent):
- names_in_this_test.add(table.name)
-
-
-class _ComparesFKs(object):
- def _assert_fk_diff(
- self, diff, type_, source_table, source_columns,
- target_table, target_columns, name=None, conditional_name=None,
- source_schema=None):
- # the public API for ForeignKeyConstraint was not very rich
- # in 0.7, 0.8, so here we use the well-known but slightly
- # private API to get at its elements
- (fk_source_schema, fk_source_table,
- fk_source_columns, fk_target_schema, fk_target_table,
- fk_target_columns) = _fk_spec(diff[1])
-
- eq_(diff[0], type_)
- eq_(fk_source_table, source_table)
- eq_(fk_source_columns, source_columns)
- eq_(fk_target_table, target_table)
- eq_(fk_source_schema, source_schema)
-
- eq_([elem.column.name for elem in diff[1].elements],
- target_columns)
- if conditional_name is not None:
- if config.requirements.no_fk_names.enabled:
- eq_(diff[1].name, None)
- elif conditional_name == 'servergenerated':
- fks = Inspector.from_engine(self.bind).\
- get_foreign_keys(source_table)
- server_fk_name = fks[0]['name']
- eq_(diff[1].name, server_fk_name)
- else:
- eq_(diff[1].name, conditional_name)
- else:
- eq_(diff[1].name, name)
-
-
-class AutogenTest(_ComparesFKs):
-
- def _flatten_diffs(self, diffs):
- for d in diffs:
- if isinstance(d, list):
- for fd in self._flatten_diffs(d):
- yield fd
- else:
- yield d
-
- @classmethod
- def _get_bind(cls):
- return config.db
-
- configure_opts = {}
-
- @classmethod
- def setup_class(cls):
- staging_env()
- cls.bind = cls._get_bind()
- cls.m1 = cls._get_db_schema()
- cls.m1.create_all(cls.bind)
- cls.m2 = cls._get_model_schema()
-
- @classmethod
- def teardown_class(cls):
- cls.m1.drop_all(cls.bind)
- clear_staging_env()
-
- def setUp(self):
- self.conn = conn = self.bind.connect()
- ctx_opts = {
- 'compare_type': True,
- 'compare_server_default': True,
- 'target_metadata': self.m2,
- 'upgrade_token': "upgrades",
- 'downgrade_token': "downgrades",
- 'alembic_module_prefix': 'op.',
- 'sqlalchemy_module_prefix': 'sa.',
- }
- if self.configure_opts:
- ctx_opts.update(self.configure_opts)
- self.context = context = MigrationContext.configure(
- connection=conn,
- opts=ctx_opts
- )
-
- connection = context.bind
- self.autogen_context = {
- 'imports': set(),
- 'connection': connection,
- 'dialect': connection.dialect,
- 'context': context
- }
-
- def tearDown(self):
- self.conn.close()
-
-
-class AutogenFixtureTest(_ComparesFKs):
-
- def _fixture(
- self, m1, m2, include_schemas=False,
- opts=None, object_filters=_default_object_filters):
- self.metadata, model_metadata = m1, m2
- self.metadata.create_all(self.bind)
-
- with self.bind.connect() as conn:
- ctx_opts = {
- 'compare_type': True,
- 'compare_server_default': True,
- 'target_metadata': model_metadata,
- 'upgrade_token': "upgrades",
- 'downgrade_token': "downgrades",
- 'alembic_module_prefix': 'op.',
- 'sqlalchemy_module_prefix': 'sa.',
- }
- if opts:
- ctx_opts.update(opts)
- self.context = context = MigrationContext.configure(
- connection=conn,
- opts=ctx_opts
- )
-
- connection = context.bind
- autogen_context = {
- 'imports': set(),
- 'connection': connection,
- 'dialect': connection.dialect,
- 'context': context,
- 'metadata': model_metadata,
- 'object_filters': object_filters,
- 'include_schemas': include_schemas
- }
- diffs = []
- autogenerate._produce_net_changes(
- autogen_context, diffs
- )
- return diffs
-
- reports_unnamed_constraints = False
-
- def setUp(self):
- staging_env()
- self.bind = config.db
-
- def tearDown(self):
- if hasattr(self, 'metadata'):
- self.metadata.drop_all(self.bind)
- clear_staging_env()
-
class AutogenCrossSchemaTest(AutogenTest, TestBase):
__only_on__ = 'postgresql'
@@ -228,8 +63,6 @@ class AutogenCrossSchemaTest(AutogenTest, TestBase):
return m
def test_default_schema_omitted_upgrade(self):
- metadata = self.m2
- connection = self.context.bind
diffs = []
def include_object(obj, name, type_, reflected, compare_to):
@@ -237,17 +70,17 @@ class AutogenCrossSchemaTest(AutogenTest, TestBase):
return name == "t3"
else:
return True
- autogenerate._produce_net_changes(connection, metadata, diffs,
- self.autogen_context,
- object_filters=[include_object],
- include_schemas=True
- )
+ self.autogen_context.update({
+ 'object_filters': [include_object],
+ 'include_schemas': True,
+ 'metadata': self.m2
+ })
+ autogenerate._produce_net_changes(self.autogen_context, diffs)
+
eq_(diffs[0][0], "add_table")
eq_(diffs[0][1].schema, None)
def test_alt_schema_included_upgrade(self):
- metadata = self.m2
- connection = self.context.bind
diffs = []
def include_object(obj, name, type_, reflected, compare_to):
@@ -255,17 +88,18 @@ class AutogenCrossSchemaTest(AutogenTest, TestBase):
return name == "t4"
else:
return True
- autogenerate._produce_net_changes(connection, metadata, diffs,
- self.autogen_context,
- object_filters=[include_object],
- include_schemas=True
- )
+
+ self.autogen_context.update({
+ 'object_filters': [include_object],
+ 'include_schemas': True,
+ 'metadata': self.m2
+ })
+ autogenerate._produce_net_changes(self.autogen_context, diffs)
+
eq_(diffs[0][0], "add_table")
eq_(diffs[0][1].schema, config.test_schema)
def test_default_schema_omitted_downgrade(self):
- metadata = self.m2
- connection = self.context.bind
diffs = []
def include_object(obj, name, type_, reflected, compare_to):
@@ -273,17 +107,17 @@ class AutogenCrossSchemaTest(AutogenTest, TestBase):
return name == "t1"
else:
return True
- autogenerate._produce_net_changes(connection, metadata, diffs,
- self.autogen_context,
- object_filters=[include_object],
- include_schemas=True
- )
+ self.autogen_context.update({
+ 'object_filters': [include_object],
+ 'include_schemas': True,
+ 'metadata': self.m2
+ })
+ autogenerate._produce_net_changes(self.autogen_context, diffs)
+
eq_(diffs[0][0], "remove_table")
eq_(diffs[0][1].schema, None)
def test_alt_schema_included_downgrade(self):
- metadata = self.m2
- connection = self.context.bind
diffs = []
def include_object(obj, name, type_, reflected, compare_to):
@@ -291,11 +125,12 @@ class AutogenCrossSchemaTest(AutogenTest, TestBase):
return name == "t2"
else:
return True
- autogenerate._produce_net_changes(connection, metadata, diffs,
- self.autogen_context,
- object_filters=[include_object],
- include_schemas=True
- )
+ self.autogen_context.update({
+ 'object_filters': [include_object],
+ 'include_schemas': True,
+ 'metadata': self.m2
+ })
+ autogenerate._produce_net_changes(self.autogen_context, diffs)
eq_(diffs[0][0], "remove_table")
eq_(diffs[0][1].schema, config.test_schema)
@@ -646,14 +481,14 @@ class AutogenerateDiffTestWSchema(ModelOne, AutogenTest, TestBase):
"""test generation of diff rules"""
metadata = self.m2
- connection = self.context.bind
diffs = []
- autogenerate._produce_net_changes(
- connection, metadata, diffs,
- self.autogen_context,
- object_filters=_default_object_filters,
- include_schemas=True
- )
+
+ self.autogen_context.update({
+ 'object_filters': _default_object_filters,
+ 'include_schemas': True,
+ 'metadata': self.m2
+ })
+ autogenerate._produce_net_changes(self.autogen_context, diffs)
eq_(
diffs[0],
@@ -707,7 +542,6 @@ class AutogenerateDiffTestWSchema(ModelOne, AutogenTest, TestBase):
eq_(diffs[10][3].name, 'pw')
-
class AutogenerateCustomCompareTypeTest(AutogenTest, TestBase):
__only_on__ = 'sqlite'
diff --git a/tests/test_autogen_fks.py b/tests/test_autogen_fks.py
index 90d25c4..525bed5 100644
--- a/tests/test_autogen_fks.py
+++ b/tests/test_autogen_fks.py
@@ -1,5 +1,5 @@
import sys
-from alembic.testing import TestBase, config
+from alembic.testing import TestBase
from sqlalchemy import MetaData, Column, Table, Integer, String, \
ForeignKeyConstraint
@@ -7,7 +7,7 @@ from alembic.testing import eq_
py3k = sys.version_info >= (3, )
-from .test_autogenerate import AutogenFixtureTest
+from ._autogen_fixtures import AutogenFixtureTest
class AutogenerateForeignKeysTest(AutogenFixtureTest, TestBase):
diff --git a/tests/test_autogen_indexes.py b/tests/test_autogen_indexes.py
index 1f92649..8ee33bc 100644
--- a/tests/test_autogen_indexes.py
+++ b/tests/test_autogen_indexes.py
@@ -12,7 +12,7 @@ from alembic.testing.env import staging_env
py3k = sys.version_info >= (3, )
-from .test_autogenerate import AutogenFixtureTest
+from ._autogen_fixtures import AutogenFixtureTest
class NoUqReflection(object):