diff options
-rw-r--r-- | alembic/operations/ops.py | 36 | ||||
-rw-r--r-- | alembic/testing/__init__.py | 2 | ||||
-rw-r--r-- | alembic/testing/assertions.py | 6 | ||||
-rw-r--r-- | tests/test_autogen_diffs.py | 135 |
4 files changed, 169 insertions, 10 deletions
diff --git a/alembic/operations/ops.py b/alembic/operations/ops.py index 314b49b..8735931 100644 --- a/alembic/operations/ops.py +++ b/alembic/operations/ops.py @@ -179,11 +179,12 @@ class CreatePrimaryKeyOp(AddConstraintOp): def __init__( self, constraint_name, table_name, columns, - schema=None, **kw): + schema=None, _orig_constraint=None, **kw): self.constraint_name = constraint_name self.table_name = table_name self.columns = columns self.schema = schema + self._orig_constraint = _orig_constraint self.kw = kw @classmethod @@ -193,11 +194,15 @@ class CreatePrimaryKeyOp(AddConstraintOp): return cls( constraint.name, constraint_table.name, + constraint.columns, schema=constraint_table.schema, - *constraint.columns + _orig_constraint=constraint ) def to_constraint(self, migration_context=None): + if self._orig_constraint is not None: + return self._orig_constraint + schema_obj = schemaobj.SchemaObjects(migration_context) return schema_obj.primary_key_constraint( self.constraint_name, self.table_name, @@ -289,11 +294,12 @@ class CreateUniqueConstraintOp(AddConstraintOp): def __init__( self, constraint_name, table_name, - columns, schema=None, **kw): + columns, schema=None, _orig_constraint=None, **kw): self.constraint_name = constraint_name self.table_name = table_name self.columns = columns self.schema = schema + self._orig_constraint = _orig_constraint self.kw = kw @classmethod @@ -311,10 +317,14 @@ class CreateUniqueConstraintOp(AddConstraintOp): constraint_table.name, [c.name for c in constraint.columns], schema=constraint_table.schema, + _orig_constraint=constraint, **kw ) def to_constraint(self, migration_context=None): + if self._orig_constraint is not None: + return self._orig_constraint + schema_obj = schemaobj.SchemaObjects(migration_context) return schema_obj.unique_constraint( self.constraint_name, self.table_name, self.columns, @@ -421,12 +431,13 @@ class CreateForeignKeyOp(AddConstraintOp): def __init__( self, constraint_name, source_table, referent_table, local_cols, - remote_cols, **kw): + remote_cols, _orig_constraint=None, **kw): self.constraint_name = constraint_name self.source_table = source_table self.referent_table = referent_table self.local_cols = local_cols self.remote_cols = remote_cols + self._orig_constraint = _orig_constraint self.kw = kw def to_diff_tuple(self): @@ -459,10 +470,13 @@ class CreateForeignKeyOp(AddConstraintOp): target_table, source_columns, target_columns, + _orig_constraint=constraint, **kw ) def to_constraint(self, migration_context=None): + if self._orig_constraint is not None: + return self._orig_constraint schema_obj = schemaobj.SchemaObjects(migration_context) return schema_obj.foreign_key_constraint( self.constraint_name, @@ -606,11 +620,13 @@ class CreateCheckConstraintOp(AddConstraintOp): constraint_type = "check" def __init__( - self, constraint_name, table_name, condition, schema=None, **kw): + self, constraint_name, table_name, + condition, schema=None, _orig_constraint=None, **kw): self.constraint_name = constraint_name self.table_name = table_name self.condition = condition self.schema = schema + self._orig_constraint = _orig_constraint self.kw = kw @classmethod @@ -620,11 +636,14 @@ class CreateCheckConstraintOp(AddConstraintOp): return cls( constraint.name, constraint_table.name, - constraint.condition, + constraint.sqltext, schema=constraint_table.schema, + _orig_constraint=constraint ) def to_constraint(self, migration_context=None): + if self._orig_constraint is not None: + return self._orig_constraint schema_obj = schemaobj.SchemaObjects(migration_context) return schema_obj.check_constraint( self.constraint_name, self.table_name, @@ -1444,6 +1463,9 @@ class AddColumnOp(AlterTableOp): def to_diff_tuple(self): return ("add_column", self.schema, self.table_name, self.column) + def to_column(self): + return self.column + @classmethod def from_column(cls, col): return cls(col.table.name, col, schema=col.table.schema) @@ -1558,6 +1580,8 @@ class DropColumnOp(AlterTableOp): return cls(tname, col.name, schema=schema, _orig_column=col) def to_column(self, migration_context=None): + if self._orig_column is not None: + return self._orig_column schema_obj = schemaobj.SchemaObjects(migration_context) return schema_obj.column(self.column_name, NULLTYPE) diff --git a/alembic/testing/__init__.py b/alembic/testing/__init__.py index b14fb88..553f501 100644 --- a/alembic/testing/__init__.py +++ b/alembic/testing/__init__.py @@ -1,5 +1,5 @@ from .fixtures import TestBase -from .assertions import eq_, ne_, is_, assert_raises_message, \ +from .assertions import eq_, ne_, is_, is_not_, assert_raises_message, \ eq_ignore_whitespace, assert_raises from .util import provide_metadata diff --git a/alembic/testing/assertions.py b/alembic/testing/assertions.py index 6acca21..b64725c 100644 --- a/alembic/testing/assertions.py +++ b/alembic/testing/assertions.py @@ -25,6 +25,10 @@ if not util.sqla_094: """Assert a is b, with repr messaging on failure.""" assert a is b, msg or "%r is not %r" % (a, b) + def is_not_(a, b, msg=None): + """Assert a is not b, with repr messaging on failure.""" + assert a is not b, msg or "%r is %r" % (a, b) + def assert_raises(except_cls, callable_, *args, **kw): try: callable_(*args, **kw) @@ -45,7 +49,7 @@ if not util.sqla_094: print(text_type(e).encode('utf-8')) else: - from sqlalchemy.testing.assertions import eq_, ne_, is_, \ + from sqlalchemy.testing.assertions import eq_, ne_, is_, is_not_, \ assert_raises_message, assert_raises diff --git a/tests/test_autogen_diffs.py b/tests/test_autogen_diffs.py index d176b91..ada6210 100644 --- a/tests/test_autogen_diffs.py +++ b/tests/test_autogen_diffs.py @@ -2,7 +2,8 @@ import sys from sqlalchemy import MetaData, Column, Table, Integer, String, Text, \ Numeric, CHAR, ForeignKey, INTEGER, Index, UniqueConstraint, \ - TypeDecorator, CheckConstraint, text, PrimaryKeyConstraint + TypeDecorator, CheckConstraint, text, PrimaryKeyConstraint, \ + ForeignKeyConstraint from sqlalchemy.types import NULLTYPE from sqlalchemy.engine.reflection import Inspector @@ -13,7 +14,7 @@ from alembic.testing import TestBase from alembic.testing import config from alembic.testing import assert_raises_message from alembic.testing.mock import Mock -from alembic.testing import eq_ +from alembic.testing import eq_, is_, is_not_ from alembic.util import CommandError from ._autogen_fixtures import AutogenTest, AutogenFixtureTest @@ -324,6 +325,10 @@ class AutogenerateDiffTest(ModelOne, AutogenTest, TestBase): eq_(diffs[10][0], 'remove_column') eq_(diffs[10][3].name, 'pw') + eq_(diffs[10][3].table.name, 'user') + assert isinstance( + diffs[10][3].type, String + ) def test_include_symbol(self): @@ -963,3 +968,129 @@ class PGCompareMetaData(ModelOne, AutogenTest, TestBase): eq_(diffs[5][0][0], 'modify_nullable') eq_(diffs[5][0][5], False) eq_(diffs[5][0][6], True) + + +class OrigObjectTest(TestBase): + def setUp(self): + self.metadata = m = MetaData() + t = Table( + 't', m, + Column('id', Integer(), primary_key=True), + Column('x', Integer()) + ) + self.ix = Index('ix1', t.c.id) + fk = ForeignKeyConstraint(['t_id'], ['t.id']) + q = Table( + 'q', m, + Column('t_id', Integer()), + fk + ) + self.table = t + self.fk = fk + self.ck = CheckConstraint(t.c.x > 5) + self.uq = UniqueConstraint(q.c.t_id) + self.pk = t.primary_key + + def test_drop_fk(self): + fk = self.fk + op = ops.DropConstraintOp.from_constraint(fk) + is_(op.to_constraint(), fk) + is_(op.reverse().to_constraint(), fk) + + def test_add_fk(self): + fk = self.fk + op = ops.AddConstraintOp.from_constraint(fk) + is_(op.to_constraint(), fk) + is_(op.reverse().to_constraint(), fk) + is_not_(None, op.to_constraint().table) + + def test_add_check(self): + ck = self.ck + op = ops.AddConstraintOp.from_constraint(ck) + is_(op.to_constraint(), ck) + is_(op.reverse().to_constraint(), ck) + is_not_(None, op.to_constraint().table) + + def test_drop_check(self): + ck = self.ck + op = ops.DropConstraintOp.from_constraint(ck) + is_(op.to_constraint(), ck) + is_(op.reverse().to_constraint(), ck) + is_not_(None, op.to_constraint().table) + + def test_add_unique(self): + uq = self.uq + op = ops.AddConstraintOp.from_constraint(uq) + is_(op.to_constraint(), uq) + is_(op.reverse().to_constraint(), uq) + is_not_(None, op.to_constraint().table) + + def test_drop_unique(self): + uq = self.uq + op = ops.DropConstraintOp.from_constraint(uq) + is_(op.to_constraint(), uq) + is_(op.reverse().to_constraint(), uq) + is_not_(None, op.to_constraint().table) + + def test_add_pk_no_orig(self): + op = ops.CreatePrimaryKeyOp('pk1', 't', ['x', 'y']) + pk = op.to_constraint() + eq_(pk.name, 'pk1') + eq_(pk.table.name, 't') + + def test_add_pk(self): + pk = self.pk + op = ops.AddConstraintOp.from_constraint(pk) + is_(op.to_constraint(), pk) + is_(op.reverse().to_constraint(), pk) + is_not_(None, op.to_constraint().table) + + def test_drop_pk(self): + pk = self.pk + op = ops.DropConstraintOp.from_constraint(pk) + is_(op.to_constraint(), pk) + is_(op.reverse().to_constraint(), pk) + is_not_(None, op.to_constraint().table) + + def test_drop_column(self): + t = self.table + + op = ops.DropColumnOp.from_column_and_tablename(None, 't', t.c.x) + is_(op.to_column(), t.c.x) + is_(op.reverse().to_column(), t.c.x) + is_not_(None, op.to_column().table) + + def test_add_column(self): + t = self.table + + op = ops.AddColumnOp.from_column_and_tablename(None, 't', t.c.x) + is_(op.to_column(), t.c.x) + is_(op.reverse().to_column(), t.c.x) + is_not_(None, op.to_column().table) + + def test_drop_table(self): + t = self.table + + op = ops.DropTableOp.from_table(t) + is_(op.to_table(), t) + is_(op.reverse().to_table(), t) + is_(self.metadata, op.to_table().metadata) + + def test_add_table(self): + t = self.table + + op = ops.CreateTableOp.from_table(t) + is_(op.to_table(), t) + is_(op.reverse().to_table(), t) + is_(self.metadata, op.to_table().metadata) + + def test_drop_index(self): + op = ops.DropIndexOp.from_index(self.ix) + is_(op.to_index(), self.ix) + is_(op.reverse().to_index(), self.ix) + + def test_create_index(self): + op = ops.CreateIndexOp.from_index(self.ix) + is_(op.to_index(), self.ix) + is_(op.reverse().to_index(), self.ix) + |