diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2015-07-06 21:20:55 -0400 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2015-07-06 21:20:55 -0400 |
commit | 012d4e384440676ba050416581d58eecb5c30ac5 (patch) | |
tree | 801accd4531ca8f74c31d3603a818765280e93be | |
parent | 9eb8f3a085193067b15f89069ff1b40013a9f1d8 (diff) | |
download | alembic-012d4e384440676ba050416581d58eecb5c30ac5.tar.gz |
- all tests passing again. next step is do the compare API
-rw-r--r-- | alembic/autogenerate/api.py | 9 | ||||
-rw-r--r-- | alembic/autogenerate/compare.py | 14 | ||||
-rw-r--r-- | alembic/autogenerate/generate.py | 2 | ||||
-rw-r--r-- | alembic/autogenerate/render.py | 4 | ||||
-rw-r--r-- | alembic/operations/ops.py | 55 | ||||
-rw-r--r-- | tests/test_postgresql.py | 18 |
6 files changed, 64 insertions, 38 deletions
diff --git a/alembic/autogenerate/api.py b/alembic/autogenerate/api.py index 1277c20..b76146d 100644 --- a/alembic/autogenerate/api.py +++ b/alembic/autogenerate/api.py @@ -121,15 +121,14 @@ def produce_migrations(context, metadata): autogen_context = _autogen_context(context, metadata=metadata) - upgrade_ops = ops.UpgradeOps([]) - compare._produce_net_changes(autogen_context, upgrade_ops) - migration_script = ops.MigrationScript( rev_id=None, - upgrade_ops=upgrade_ops, - downgrade_ops=upgrade_ops.reverse(), + upgrade_ops=ops.UpgradeOps([]), + downgrade_ops=ops.DowngradeOps([]), ) + compare._populate_migration_script(autogen_context, migration_script) + return migration_script diff --git a/alembic/autogenerate/compare.py b/alembic/autogenerate/compare.py index 06f442c..5bb2964 100644 --- a/alembic/autogenerate/compare.py +++ b/alembic/autogenerate/compare.py @@ -14,6 +14,11 @@ from alembic.ddl.base import _fk_spec log = logging.getLogger(__name__) +def _populate_migration_script(autogen_context, migration_script): + _produce_net_changes(autogen_context, migration_script.upgrade_ops) + migration_script.upgrade_ops.reverse_into(migration_script.downgrade_ops) + + def _produce_net_changes(autogen_context, upgrade_ops): metadata = autogen_context['metadata'] @@ -636,7 +641,7 @@ def _compare_server_default(schema, tname, cname, conn_col, metadata_col, rendered_conn_default = conn_col.server_default.arg.text \ if conn_col.server_default else None - alter_column_op.existing_server_default = rendered_conn_default + alter_column_op.existing_server_default = conn_col_default isdiff = autogen_context['context']._compare_server_default( conn_col, metadata_col, @@ -645,10 +650,9 @@ def _compare_server_default(schema, tname, cname, conn_col, metadata_col, ) if isdiff: alter_column_op.modify_server_default = metadata_default - log.info("Detected server default on column '%s.%s'", - tname, - cname - ) + log.info( + "Detected server default on column '%s.%s'", + tname, cname) def _compare_foreign_keys(schema, tname, object_filters, conn_table, diff --git a/alembic/autogenerate/generate.py b/alembic/autogenerate/generate.py index 1c8bef4..041509e 100644 --- a/alembic/autogenerate/generate.py +++ b/alembic/autogenerate/generate.py @@ -51,7 +51,7 @@ class RevisionContext(object): migration_script = self.generated_revisions[0] - compare._produce_net_changes(autogen_context, migration_script) + compare._populate_migration_script(autogen_context, migration_script) hook = context.opts.get('process_revision_directives', None) if hook: diff --git a/alembic/autogenerate/render.py b/alembic/autogenerate/render.py index c3f3df1..39817ab 100644 --- a/alembic/autogenerate/render.py +++ b/alembic/autogenerate/render.py @@ -343,10 +343,10 @@ def _alter_column(autogen_context, op): if nullable is not None: text += ",\n%snullable=%r" % ( indent, nullable,) - if existing_nullable is not None: + if nullable is None and existing_nullable is not None: text += ",\n%sexisting_nullable=%r" % ( indent, existing_nullable) - if existing_server_default: + if server_default is False and existing_server_default: rendered = _render_server_default( existing_server_default, autogen_context) diff --git a/alembic/operations/ops.py b/alembic/operations/ops.py index e9c2497..3981fc2 100644 --- a/alembic/operations/ops.py +++ b/alembic/operations/ops.py @@ -36,7 +36,7 @@ class AddConstraintOp(MigrateOperation): """Represent an add constraint operation.""" @property - def type_(self): + def constraint_type(self): raise NotImplementedError() @classmethod @@ -80,10 +80,10 @@ class DropConstraintOp(MigrateOperation): return AddConstraintOp.from_constraint(self._orig_constraint) def to_diff_tuple(self): - if self.type_ == "foreign_key_constraint": + if self.constraint_type == "foreignkey": return ("remove_fk", self.to_constraint()) else: - return ("drop_constraint", self.to_constraint()) + return ("remove_constraint", self.to_constraint()) @classmethod def from_constraint(cls, constraint): @@ -104,6 +104,14 @@ class DropConstraintOp(MigrateOperation): _orig_constraint=constraint ) + def to_constraint(self): + if self._orig_constraint is not None: + return self._orig_constraint + else: + raise ValueError( + "constraint cannot be produced; " + "original constraint is not present") + @classmethod @util._with_legacy_names([("type", "type_")]) def drop_constraint( @@ -153,7 +161,7 @@ class DropConstraintOp(MigrateOperation): class CreatePrimaryKeyOp(AddConstraintOp): """Represent a create primary key operation.""" - type_ = "primary_key_constraint" + constraint_type = "primarykey" def __init__( self, constraint_name, table_name, columns, @@ -254,7 +262,7 @@ class CreatePrimaryKeyOp(AddConstraintOp): class CreateUniqueConstraintOp(AddConstraintOp): """Represent a create unique constraint operation.""" - type_ = "unique_constraint" + constraint_type = "unique" def __init__( self, constraint_name, table_name, @@ -373,7 +381,7 @@ class CreateUniqueConstraintOp(AddConstraintOp): class CreateForeignKeyOp(AddConstraintOp): """Represent a create foreign key constraint operation.""" - type_ = "foreign_key_constraint" + constraint_type = "foreignkey" def __init__( self, constraint_name, source_table, referent_table, local_cols, @@ -539,7 +547,7 @@ class CreateForeignKeyOp(AddConstraintOp): class CreateCheckConstraintOp(AddConstraintOp): """Represent a create check constraint operation.""" - type_ = "check_constraint" + constraint_type = "check" def __init__( self, constraint_name, table_name, condition, schema=None, **kw): @@ -662,7 +670,7 @@ class CreateIndexOp(MigrateOperation): return DropIndexOp.from_index(self.to_index()) def to_diff_tuple(self): - return ("add_index", self.to_constraint()) + return ("add_index", self.to_index()) @classmethod def from_index(cls, index): @@ -1002,6 +1010,8 @@ class DropTableOp(MigrateOperation): return cls(table.name, schema=table.schema, _orig_table=table) def to_table(self, migration_context=None): + if self._orig_table is not None: + return self._orig_table schema_obj = schemaobj.SchemaObjects(migration_context) return schema_obj.table( self.table_name, @@ -1157,19 +1167,24 @@ class AlterColumnOp(AlterTableOp): kw['existing_type'] = self.existing_type kw['existing_nullable'] = self.existing_nullable kw['existing_server_default'] = self.existing_server_default - kw['modify_type'] = self.modify_type - kw['modify_nullable'] = self.modify_nullable - kw['modify_server_default'] = self.modify_server_default + if self.modify_type is not None: + kw['modify_type'] = self.modify_type + if self.modify_nullable is not None: + kw['modify_nullable'] = self.modify_nullable + if self.modify_server_default is not False: + kw['modify_server_default'] = self.modify_server_default + # TODO: make this a little simpler all_keys = set(m.group(1) for m in [ re.match(r'^(?:existing_|modify_)(.+)$', k) for k in kw ] if m) for k in all_keys: - swap = kw['existing_%s' % k] - kw['existing_%s' % k] = kw['modify_%s' % k] - kw['modify_%s' % k] = swap + if 'modify_%s' % k in kw: + swap = kw['existing_%s' % k] + kw['existing_%s' % k] = kw['modify_%s' % k] + kw['modify_%s' % k] = swap return self.__class__( self.table_name, self.column_name, schema=self.schema, @@ -1736,12 +1751,14 @@ class UpgradeOps(OpContainer): """ + def reverse_into(self, downgrade_ops): + downgrade_ops.ops[:] = list(reversed( + [op.reverse() for op in self.ops] + )) + return downgrade_ops + def reverse(self): - return DowngradeOps( - ops=list(reversed( - [op.reverse() for op in self.ops] - )) - ) + return self.reverse_into(DowngradeOps(ops=[])) class DowngradeOps(OpContainer): diff --git a/tests/test_postgresql.py b/tests/test_postgresql.py index e70d05a..0503ef2 100644 --- a/tests/test_postgresql.py +++ b/tests/test_postgresql.py @@ -8,6 +8,7 @@ from sqlalchemy.sql import table, column from alembic.autogenerate.compare import \ _compare_server_default, _compare_tables, _render_server_default_for_compare +from alembic.operations import ops from alembic import command, util from alembic.migration import MigrationContext from alembic.script import ScriptDirectory @@ -212,9 +213,10 @@ class PostgresqlDefaultCompareTest(TestBase): cols = insp.get_columns(t1.name) insp_col = Column("somecol", cols[0]['type'], server_default=text(cols[0]['default'])) - diffs = [] + op = ops.AlterColumnOp("test", "somecol") _compare_server_default(None, "test", "somecol", insp_col, - t2.c.somecol, diffs, self.autogen_context) + t2.c.somecol, op, self.autogen_context) + diffs = op.to_diff_tuple() eq_(bool(diffs), diff_expected) def _compare_default( @@ -420,24 +422,28 @@ class PostgresqlDetectSerialTest(TestBase): self.metadata.create_all(config.db) insp = Inspector.from_engine(config.db) - diffs = [] + + uo = ops.UpgradeOps(ops=[]) _compare_tables( set([(None, 't')]), set([]), [], - insp, self.metadata, diffs, self.autogen_context) + insp, self.metadata, uo, self.autogen_context) + diffs = uo.as_diffs() tab = diffs[0][1] + eq_(_render_server_default_for_compare( tab.c.x.server_default, tab.c.x, self.autogen_context), c_expected) insp = Inspector.from_engine(config.db) - diffs = [] + uo = ops.UpgradeOps(ops=[]) m2 = MetaData() Table('t', m2, Column('x', BigInteger())) _compare_tables( set([(None, 't')]), set([(None, 't')]), [], - insp, m2, diffs, self.autogen_context) + insp, m2, uo, self.autogen_context) + diffs = uo.as_diffs() server_default = diffs[0][0][4]['existing_server_default'] eq_(_render_server_default_for_compare( server_default, tab.c.x, self.autogen_context), |