diff options
-rw-r--r-- | alembic/autogenerate/compare.py | 136 | ||||
-rw-r--r-- | alembic/operations/ops.py | 145 |
2 files changed, 209 insertions, 72 deletions
diff --git a/alembic/autogenerate/compare.py b/alembic/autogenerate/compare.py index cd6b696..e02b785 100644 --- a/alembic/autogenerate/compare.py +++ b/alembic/autogenerate/compare.py @@ -1,6 +1,7 @@ from sqlalchemy import schema as sa_schema, types as sqltypes from sqlalchemy.engine.reflection import Inspector from sqlalchemy import event +from ..operations import ops import logging from ..util import compat from ..util import sqla_compat @@ -13,7 +14,7 @@ from alembic.ddl.base import _fk_spec log = logging.getLogger(__name__) -def _produce_net_changes(autogen_context, diffs): +def _produce_net_changes(autogen_context, upgrade_ops): metadata = autogen_context['metadata'] connection = autogen_context['connection'] @@ -51,7 +52,7 @@ def _produce_net_changes(autogen_context, diffs): _compare_tables(conn_table_names, metadata_table_names, object_filters, - inspector, metadata, diffs, autogen_context) + inspector, metadata, upgrade_ops, autogen_context) def _run_filters(object_, name, type_, reflected, compare_to, object_filters): @@ -64,7 +65,7 @@ def _run_filters(object_, name, type_, reflected, compare_to, object_filters): def _compare_tables(conn_table_names, metadata_table_names, object_filters, - inspector, metadata, diffs, autogen_context): + inspector, metadata, upgrade_ops, autogen_context): default_schema = inspector.bind.dialect.default_schema_name @@ -97,12 +98,14 @@ def _compare_tables(conn_table_names, metadata_table_names, metadata_table = tname_to_table[(s, tname)] if _run_filters( metadata_table, tname, "table", False, None, object_filters): - diffs.append(("add_table", metadata_table)) + upgrade_ops.ops.append( + ops.CreateTableOp.from_table(metadata_table)) log.info("Detected added table %r", name) _compare_indexes_and_uniques(s, tname, object_filters, None, metadata_table, - diffs, autogen_context, inspector) + upgrade_ops, + autogen_context, inspector) removal_metadata = sa_schema.MetaData() for s, tname in conn_table_names.difference(metadata_table_names): @@ -118,7 +121,9 @@ def _compare_tables(conn_table_names, metadata_table_names, _compat_autogen_column_reflect(inspector)) inspector.reflecttable(t, None) if _run_filters(t, tname, "table", True, None, object_filters): - diffs.append(("remove_table", t)) + upgrade_ops.ops.append( + ops.DropTableOp.from_table(t) + ) log.info("Detected removed table %r", name) existing_tables = conn_table_names.intersection(metadata_table_names) @@ -147,19 +152,26 @@ def _compare_tables(conn_table_names, metadata_table_names, if _run_filters( metadata_table, tname, "table", False, conn_table, object_filters): + + modify_table_ops = ops.ModifyTableOps(tname, [], schema=s) with _compare_columns( s, tname, object_filters, conn_table, metadata_table, - diffs, autogen_context, inspector): + modify_table_ops, autogen_context, inspector): _compare_indexes_and_uniques(s, tname, object_filters, conn_table, metadata_table, - diffs, autogen_context, inspector) + modify_table_ops, + autogen_context, inspector) _compare_foreign_keys(s, tname, object_filters, conn_table, - metadata_table, diffs, autogen_context, + metadata_table, + modify_table_ops, autogen_context, inspector) + if not modify_table_ops.is_empty(): + upgrade_ops.ops.append(modify_table_ops) + # TODO: # table constraints # sequences @@ -203,7 +215,7 @@ def _make_foreign_key(params, conn_table): @contextlib.contextmanager def _compare_columns(schema, tname, object_filters, conn_table, metadata_table, - diffs, autogen_context, inspector): + modify_table_ops, autogen_context, inspector): name = '%s.%s' % (schema, tname) if schema else tname metadata_cols_by_name = dict((c.name, c) for c in metadata_table.c) conn_col_names = dict((c.name, c) for c in conn_table.c) @@ -212,8 +224,9 @@ def _compare_columns(schema, tname, object_filters, conn_table, metadata_table, for cname in metadata_col_names.difference(conn_col_names): if _run_filters(metadata_cols_by_name[cname], cname, "column", False, None, object_filters): - diffs.append( - ("add_column", schema, tname, metadata_cols_by_name[cname]) + modify_table_ops.ops.append( + ops.AddColumnOp.from_column_and_tablename( + schema, tname, metadata_cols_by_name[cname]) ) log.info("Detected added column '%s.%s'", name, cname) @@ -224,34 +237,37 @@ def _compare_columns(schema, tname, object_filters, conn_table, metadata_table, metadata_col, colname, "column", False, conn_col, object_filters): continue - col_diff = [] + alter_column_op = ops.AlterColumnOp( + tname, cname, schema=schema) _compare_type(schema, tname, colname, conn_col, metadata_col, - col_diff, autogen_context + alter_column_op, autogen_context ) # work around SQLAlchemy issue #3023 if not metadata_col.primary_key: _compare_nullable(schema, tname, colname, conn_col, metadata_col.nullable, - col_diff, autogen_context + alter_column_op, autogen_context ) _compare_server_default(schema, tname, colname, conn_col, metadata_col, - col_diff, autogen_context + alter_column_op, autogen_context ) - if col_diff: - diffs.append(col_diff) + if alter_column_op.has_changes(): + modify_table_ops.ops.append(alter_column_op) yield for cname in set(conn_col_names).difference(metadata_col_names): if _run_filters(conn_table.c[cname], cname, "column", True, None, object_filters): - diffs.append( - ("remove_column", schema, tname, conn_table.c[cname]) + modify_table_ops.ops.append( + ops.DropColumnOp.from_column_and_tablename( + schema, tname, conn_table.c[cname] + ) ) log.info("Detected removed column '%s.%s'", name, cname) @@ -311,7 +327,7 @@ class _fk_constraint_sig(_constraint_sig): def _compare_indexes_and_uniques(schema, tname, object_filters, conn_table, - metadata_table, diffs, + metadata_table, modify_ops, autogen_context, inspector): is_create_table = conn_table is None @@ -413,7 +429,9 @@ def _compare_indexes_and_uniques(schema, tname, object_filters, conn_table, if obj.is_index: if _run_filters( obj.const, obj.name, "index", False, None, object_filters): - diffs.append(("add_index", obj.const)) + modify_ops.ops.append( + ops.CreateIndexOp.from_index(obj.const) + ) log.info("Detected added index '%s' on %s", obj.name, ', '.join([ "'%s'" % obj.column_names @@ -429,7 +447,9 @@ def _compare_indexes_and_uniques(schema, tname, object_filters, conn_table, if _run_filters( obj.const, obj.name, "unique_constraint", False, None, object_filters): - diffs.append(("add_constraint", obj.const)) + modify_ops.ops.append( + ops.AddConstraintOp.from_constraint(obj.const) + ) log.info("Detected added unique constraint '%s' on %s", obj.name, ', '.join([ "'%s'" % obj.column_names @@ -445,14 +465,18 @@ def _compare_indexes_and_uniques(schema, tname, object_filters, conn_table, if _run_filters( obj.const, obj.name, "index", True, None, object_filters): - diffs.append(("remove_index", obj.const)) + modify_ops.ops.append( + ops.DropIndexOp.from_index(obj.const) + ) log.info( "Detected removed index '%s' on '%s'", obj.name, tname) else: if _run_filters( obj.const, obj.name, "unique_constraint", True, None, object_filters): - diffs.append(("remove_constraint", obj.const)) + modify_ops.ops.append( + ops.DropConstraintOp.from_constraint(obj.const) + ) log.info("Detected removed unique constraint '%s' on '%s'", obj.name, tname ) @@ -465,8 +489,12 @@ def _compare_indexes_and_uniques(schema, tname, object_filters, conn_table, log.info("Detected changed index '%s' on '%s':%s", old.name, tname, ', '.join(msg) ) - diffs.append(("remove_index", old.const)) - diffs.append(("add_index", new.const)) + modify_ops.ops.append( + ops.DropIndexOp.from_index(old.const) + ) + modify_ops.ops.append( + ops.CreateIndexOp.from_index(new.const) + ) else: if _run_filters( new.const, new.name, @@ -474,8 +502,12 @@ def _compare_indexes_and_uniques(schema, tname, object_filters, conn_table, log.info("Detected changed unique constraint '%s' on '%s':%s", old.name, tname, ', '.join(msg) ) - diffs.append(("remove_constraint", old.const)) - diffs.append(("add_constraint", new.const)) + modify_ops.ops.append( + ops.DropConstraintOp.from_constraint(old.const) + ) + modify_ops.ops.append( + ops.AddConstraintOp.from_constraint(new.const) + ) for added_name in sorted(set(metadata_names).difference(conn_names)): obj = metadata_names[added_name] @@ -529,19 +561,13 @@ def _compare_indexes_and_uniques(schema, tname, object_filters, conn_table, def _compare_nullable(schema, tname, cname, conn_col, - metadata_col_nullable, diffs, + metadata_col_nullable, alter_column_op, autogen_context): conn_col_nullable = conn_col.nullable + alter_column_op.existing_nullable = conn_col_nullable + if conn_col_nullable is not metadata_col_nullable: - diffs.append( - ("modify_nullable", schema, tname, cname, - { - "existing_type": conn_col.type, - "existing_server_default": conn_col.server_default, - }, - conn_col_nullable, - metadata_col_nullable), - ) + alter_column_op.modify_nullable = metadata_col_nullable log.info("Detected %s on column '%s.%s'", "NULL" if metadata_col_nullable else "NOT NULL", tname, @@ -550,10 +576,11 @@ def _compare_nullable(schema, tname, cname, conn_col, def _compare_type(schema, tname, cname, conn_col, - metadata_col, diffs, + metadata_col, alter_column_op, autogen_context): conn_type = conn_col.type + alter_column_op.existing_type = conn_type metadata_type = metadata_col.type if conn_type._type_affinity is sqltypes.NullType: log.info("Couldn't determine database type " @@ -567,16 +594,7 @@ def _compare_type(schema, tname, cname, conn_col, isdiff = autogen_context['context']._compare_type(conn_col, metadata_col) if isdiff: - - diffs.append( - ("modify_type", schema, tname, cname, - { - "existing_nullable": conn_col.nullable, - "existing_server_default": conn_col.server_default, - }, - conn_type, - metadata_type), - ) + alter_column_op.modify_type = metadata_type log.info("Detected type change from %r to %r on '%s.%s'", conn_type, metadata_type, tname, cname ) @@ -606,7 +624,7 @@ def _render_server_default_for_compare(metadata_default, def _compare_server_default(schema, tname, cname, conn_col, metadata_col, - diffs, autogen_context): + alter_column_op, autogen_context): metadata_default = metadata_col.server_default conn_col_default = conn_col.server_default @@ -618,22 +636,15 @@ def _compare_server_default(schema, tname, cname, conn_col, metadata_col, rendered_conn_default = conn_col.server_default.arg.text \ if conn_col.server_default else None + alter_column_op.existing_server_default = rendered_conn_default + isdiff = autogen_context['context']._compare_server_default( conn_col, metadata_col, rendered_metadata_default, rendered_conn_default ) if isdiff: - conn_col_default = rendered_conn_default - diffs.append( - ("modify_default", schema, tname, cname, - { - "existing_nullable": conn_col.nullable, - "existing_type": conn_col.type, - }, - conn_col_default, - metadata_default), - ) + alter_column_op.modify_server_default = metadata_default log.info("Detected server default on column '%s.%s'", tname, cname @@ -641,7 +652,8 @@ def _compare_server_default(schema, tname, cname, conn_col, metadata_col, def _compare_foreign_keys(schema, tname, object_filters, conn_table, - metadata_table, diffs, autogen_context, inspector): + metadata_table, modify_table_ops, + autogen_context, inspector): # if we're doing CREATE TABLE, all FKs are created # inline within the table def diff --git a/alembic/operations/ops.py b/alembic/operations/ops.py index 16cccb6..21d4163 100644 --- a/alembic/operations/ops.py +++ b/alembic/operations/ops.py @@ -3,6 +3,7 @@ from ..util import sqla_compat from . import schemaobj from sqlalchemy.types import NULLTYPE from .base import Operations, BatchOperations +import re class MigrateOperation(object): @@ -34,6 +35,10 @@ class MigrateOperation(object): class AddConstraintOp(MigrateOperation): """Represent an add constraint operation.""" + @property + def type_(self): + raise NotImplementedError() + @classmethod def from_constraint(cls, constraint): funcs = { @@ -45,17 +50,33 @@ class AddConstraintOp(MigrateOperation): } return funcs[constraint.__visit_name__](constraint) + def reverse(self): + return DropConstraintOp( + self.constraint_name, self.table_name, type_=self.type_, + schema=self.schema) + @Operations.register_operation("drop_constraint") @BatchOperations.register_operation("drop_constraint", "batch_drop_constraint") class DropConstraintOp(MigrateOperation): """Represent a drop constraint operation.""" - def __init__(self, constraint_name, table_name, type_=None, schema=None): + def __init__( + self, + constraint_name, table_name, type_=None, schema=None, + _orig_constraint=None): self.constraint_name = constraint_name self.table_name = table_name self.constraint_type = type_ self.schema = schema + self._orig_constraint = _orig_constraint + + def reverse(self): + if self._orig_constraint is None: + raise ValueError( + "operation is not reversible; " + "original constraint is not present") + return AddConstraintOp.from_constraint(self._orig_constraint) @classmethod def from_constraint(cls, constraint): @@ -124,8 +145,11 @@ class DropConstraintOp(MigrateOperation): class CreatePrimaryKeyOp(AddConstraintOp): """Represent a create primary key operation.""" + type_ = "primary_key_constraint" + def __init__( - self, constraint_name, table_name, columns, schema=None, **kw): + self, constraint_name, table_name, columns, + schema=None, **kw): self.constraint_name = constraint_name self.table_name = table_name self.columns = columns @@ -222,8 +246,11 @@ class CreatePrimaryKeyOp(AddConstraintOp): class CreateUniqueConstraintOp(AddConstraintOp): """Represent a create unique constraint operation.""" + type_ = "unique_constraint" + def __init__( - self, constraint_name, table_name, columns, schema=None, **kw): + self, constraint_name, table_name, + columns, schema=None, **kw): self.constraint_name = constraint_name self.table_name = table_name self.columns = columns @@ -338,6 +365,8 @@ class CreateUniqueConstraintOp(AddConstraintOp): class CreateForeignKeyOp(AddConstraintOp): """Represent a create foreign key constraint operation.""" + type_ = "foreign_key_constraint" + def __init__( self, constraint_name, source_table, referent_table, local_cols, remote_cols, **kw): @@ -499,6 +528,8 @@ class CreateForeignKeyOp(AddConstraintOp): class CreateCheckConstraintOp(AddConstraintOp): """Represent a create check constraint operation.""" + type_ = "check_constraint" + def __init__( self, constraint_name, table_name, condition, schema=None, **kw): self.constraint_name = constraint_name @@ -515,7 +546,7 @@ class CreateCheckConstraintOp(AddConstraintOp): constraint.name, constraint_table.name, constraint.condition, - schema=constraint_table.schema + schema=constraint_table.schema, ) def to_constraint(self, migration_context=None): @@ -616,6 +647,9 @@ class CreateIndexOp(MigrateOperation): self.kw = kw self._orig_index = _orig_index + def reverse(self): + return DropIndexOp.from_index(self.to_index()) + @classmethod def from_index(cls, index): return cls( @@ -721,10 +755,19 @@ class CreateIndexOp(MigrateOperation): class DropIndexOp(MigrateOperation): """Represent a drop index operation.""" - def __init__(self, index_name, table_name=None, schema=None): + def __init__( + self, index_name, table_name=None, schema=None, _orig_index=None): self.index_name = index_name self.table_name = table_name self.schema = schema + self._orig_index = _orig_index + + def reverse(self): + if self._orig_index is None: + raise ValueError( + "operation is not reversible; " + "original index is not present") + return CreateIndexOp.from_index(self._orig_index) @classmethod def from_index(cls, index): @@ -732,6 +775,7 @@ class DropIndexOp(MigrateOperation): index.name, index.table.name, schema=index.table.schema, + _orig_index=index ) def to_index(self, migration_context=None): @@ -799,6 +843,9 @@ class CreateTableOp(MigrateOperation): self.kw = kw self._orig_table = _orig_table + def reverse(self): + return DropTableOp.from_table(self.to_table()) + @classmethod def from_table(cls, table): return cls( @@ -913,14 +960,23 @@ class CreateTableOp(MigrateOperation): class DropTableOp(MigrateOperation): """Represent a drop table operation.""" - def __init__(self, table_name, schema=None, table_kw=None): + def __init__( + self, table_name, schema=None, table_kw=None, _orig_table=None): self.table_name = table_name self.schema = schema self.table_kw = table_kw or {} + self._orig_table = _orig_table + + def reverse(self): + if self._orig_table is None: + raise ValueError( + "operation is not reversible; " + "original table is not present") + return CreateTableOp.from_table(self._orig_table) @classmethod def from_table(cls, table): - return cls(table.name, schema=table.schema) + return cls(table.name, schema=table.schema, _orig_table=table) def to_table(self, migration_context): schema_obj = schemaobj.SchemaObjects(migration_context) @@ -1021,6 +1077,43 @@ class AlterColumnOp(AlterTableOp): self.modify_type = modify_type self.kw = kw + def has_changes(self): + hc1 = self.modify_nullable is not None or \ + self.modify_server_default is not False or \ + self.modify_type is not None + if hc1: + return True + for kw in self.kw: + if kw.startswith('modify_'): + return True + else: + return False + + def reverse(self): + + kw = self.kw.copy() + kw['existing_type'] = self.existing_type + kw['existing_nullable'] = self.existing_nullable + kw['existing_server_default'] = self.existing_server_default + kw['modify_type'] = self.modify_type + kw['modify_nullable'] = self.modify_nullable + kw['modify_server_default'] = self.modify_server_default + + all_keys = set(m.group(1) for m in [ + re.match(r'(?:existing_|modify_)(.+))', k) + for k in kw + ] if m) + + for k in all_keys: + swap = kw['existing_%s' % k] + kw['existing_%s' % k] = kw['modify_%s' % k] + kw['modify_%s' % k] = swap + + return self.__class__( + self.table_name, self.column_name, schema=self.schema, + **kw + ) + @classmethod @util._with_legacy_names([('name', 'new_column_name')]) def alter_column( @@ -1169,6 +1262,10 @@ class AddColumnOp(AlterTableOp): super(AddColumnOp, self).__init__(table_name, schema=schema) self.column = column + def reverse(self): + return DropColumnOp.from_column_and_tablename( + self.schema, self.table_name, self.column) + @classmethod def from_column(cls, col): return cls(col.table.name, col, schema=col.table.schema) @@ -1257,14 +1354,26 @@ class AddColumnOp(AlterTableOp): class DropColumnOp(AlterTableOp): """Represent a drop column operation.""" - def __init__(self, table_name, column_name, schema=None, **kw): + def __init__( + self, table_name, column_name, schema=None, + _orig_column=None, **kw): super(DropColumnOp, self).__init__(table_name, schema=schema) self.column_name = column_name self.kw = kw + self._orig_column = _orig_column + + def reverse(self): + if self._orig_column is None: + raise ValueError( + "operation is not reversible; " + "original column is not present") + + return AddColumnOp.from_column_and_tablename( + self.schema, self.table_name, self._orig_column) @classmethod def from_column_and_tablename(cls, schema, tname, col): - return cls(tname, col.name, schema=schema) + return cls(tname, col.name, schema=schema, _orig_column=col) def to_column(self, migration_context=None): schema_obj = schemaobj.SchemaObjects(migration_context) @@ -1514,6 +1623,9 @@ class OpContainer(MigrateOperation): def __init__(self, ops=()): self.ops = ops + def is_empty(self): + return not self.ops + class ModifyTableOps(OpContainer): """Contains a sequence of operations that all apply to a single Table.""" @@ -1534,6 +1646,13 @@ class UpgradeOps(OpContainer): """ + def reverse(self): + return DowngradeOps( + ops=list(reversed( + op.reverse() for op in self.ops + )) + ) + class DowngradeOps(OpContainer): """contains a sequence of operations that would apply to the @@ -1545,6 +1664,13 @@ class DowngradeOps(OpContainer): """ + def reverse(self): + return UpgradeOps( + ops=list(reversed( + op.reverse() for op in self.ops + )) + ) + class MigrationScript(MigrateOperation): """represents a migration script. @@ -1575,4 +1701,3 @@ class MigrationScript(MigrateOperation): self.version_path = version_path self.upgrade_ops = upgrade_ops self.downgrade_ops = downgrade_ops - |