summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2015-07-06 16:50:33 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2015-07-06 16:50:33 -0400
commit8ecaabf64c87d7d4aee8db4d7182bfcdaa90c0b3 (patch)
tree84a438f4d9c2cd7f9da64e93fb0c336e515de7cd
parenta294f8cc3f2e5fc2cad048bc4ce27c57554e2688 (diff)
downloadalembic-8ecaabf64c87d7d4aee8db4d7182bfcdaa90c0b3.tar.gz
- most reverse() methods rough draft
- replace most use of diffs in compare
-rw-r--r--alembic/autogenerate/compare.py136
-rw-r--r--alembic/operations/ops.py145
2 files changed, 209 insertions, 72 deletions
diff --git a/alembic/autogenerate/compare.py b/alembic/autogenerate/compare.py
index cd6b696..e02b785 100644
--- a/alembic/autogenerate/compare.py
+++ b/alembic/autogenerate/compare.py
@@ -1,6 +1,7 @@
from sqlalchemy import schema as sa_schema, types as sqltypes
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy import event
+from ..operations import ops
import logging
from ..util import compat
from ..util import sqla_compat
@@ -13,7 +14,7 @@ from alembic.ddl.base import _fk_spec
log = logging.getLogger(__name__)
-def _produce_net_changes(autogen_context, diffs):
+def _produce_net_changes(autogen_context, upgrade_ops):
metadata = autogen_context['metadata']
connection = autogen_context['connection']
@@ -51,7 +52,7 @@ def _produce_net_changes(autogen_context, diffs):
_compare_tables(conn_table_names, metadata_table_names,
object_filters,
- inspector, metadata, diffs, autogen_context)
+ inspector, metadata, upgrade_ops, autogen_context)
def _run_filters(object_, name, type_, reflected, compare_to, object_filters):
@@ -64,7 +65,7 @@ def _run_filters(object_, name, type_, reflected, compare_to, object_filters):
def _compare_tables(conn_table_names, metadata_table_names,
object_filters,
- inspector, metadata, diffs, autogen_context):
+ inspector, metadata, upgrade_ops, autogen_context):
default_schema = inspector.bind.dialect.default_schema_name
@@ -97,12 +98,14 @@ def _compare_tables(conn_table_names, metadata_table_names,
metadata_table = tname_to_table[(s, tname)]
if _run_filters(
metadata_table, tname, "table", False, None, object_filters):
- diffs.append(("add_table", metadata_table))
+ upgrade_ops.ops.append(
+ ops.CreateTableOp.from_table(metadata_table))
log.info("Detected added table %r", name)
_compare_indexes_and_uniques(s, tname, object_filters,
None,
metadata_table,
- diffs, autogen_context, inspector)
+ upgrade_ops,
+ autogen_context, inspector)
removal_metadata = sa_schema.MetaData()
for s, tname in conn_table_names.difference(metadata_table_names):
@@ -118,7 +121,9 @@ def _compare_tables(conn_table_names, metadata_table_names,
_compat_autogen_column_reflect(inspector))
inspector.reflecttable(t, None)
if _run_filters(t, tname, "table", True, None, object_filters):
- diffs.append(("remove_table", t))
+ upgrade_ops.ops.append(
+ ops.DropTableOp.from_table(t)
+ )
log.info("Detected removed table %r", name)
existing_tables = conn_table_names.intersection(metadata_table_names)
@@ -147,19 +152,26 @@ def _compare_tables(conn_table_names, metadata_table_names,
if _run_filters(
metadata_table, tname, "table", False,
conn_table, object_filters):
+
+ modify_table_ops = ops.ModifyTableOps(tname, [], schema=s)
with _compare_columns(
s, tname, object_filters,
conn_table,
metadata_table,
- diffs, autogen_context, inspector):
+ modify_table_ops, autogen_context, inspector):
_compare_indexes_and_uniques(s, tname, object_filters,
conn_table,
metadata_table,
- diffs, autogen_context, inspector)
+ modify_table_ops,
+ autogen_context, inspector)
_compare_foreign_keys(s, tname, object_filters, conn_table,
- metadata_table, diffs, autogen_context,
+ metadata_table,
+ modify_table_ops, autogen_context,
inspector)
+ if not modify_table_ops.is_empty():
+ upgrade_ops.ops.append(modify_table_ops)
+
# TODO:
# table constraints
# sequences
@@ -203,7 +215,7 @@ def _make_foreign_key(params, conn_table):
@contextlib.contextmanager
def _compare_columns(schema, tname, object_filters, conn_table, metadata_table,
- diffs, autogen_context, inspector):
+ modify_table_ops, autogen_context, inspector):
name = '%s.%s' % (schema, tname) if schema else tname
metadata_cols_by_name = dict((c.name, c) for c in metadata_table.c)
conn_col_names = dict((c.name, c) for c in conn_table.c)
@@ -212,8 +224,9 @@ def _compare_columns(schema, tname, object_filters, conn_table, metadata_table,
for cname in metadata_col_names.difference(conn_col_names):
if _run_filters(metadata_cols_by_name[cname], cname,
"column", False, None, object_filters):
- diffs.append(
- ("add_column", schema, tname, metadata_cols_by_name[cname])
+ modify_table_ops.ops.append(
+ ops.AddColumnOp.from_column_and_tablename(
+ schema, tname, metadata_cols_by_name[cname])
)
log.info("Detected added column '%s.%s'", name, cname)
@@ -224,34 +237,37 @@ def _compare_columns(schema, tname, object_filters, conn_table, metadata_table,
metadata_col, colname, "column", False,
conn_col, object_filters):
continue
- col_diff = []
+ alter_column_op = ops.AlterColumnOp(
+ tname, cname, schema=schema)
_compare_type(schema, tname, colname,
conn_col,
metadata_col,
- col_diff, autogen_context
+ alter_column_op, autogen_context
)
# work around SQLAlchemy issue #3023
if not metadata_col.primary_key:
_compare_nullable(schema, tname, colname,
conn_col,
metadata_col.nullable,
- col_diff, autogen_context
+ alter_column_op, autogen_context
)
_compare_server_default(schema, tname, colname,
conn_col,
metadata_col,
- col_diff, autogen_context
+ alter_column_op, autogen_context
)
- if col_diff:
- diffs.append(col_diff)
+ if alter_column_op.has_changes():
+ modify_table_ops.ops.append(alter_column_op)
yield
for cname in set(conn_col_names).difference(metadata_col_names):
if _run_filters(conn_table.c[cname], cname,
"column", True, None, object_filters):
- diffs.append(
- ("remove_column", schema, tname, conn_table.c[cname])
+ modify_table_ops.ops.append(
+ ops.DropColumnOp.from_column_and_tablename(
+ schema, tname, conn_table.c[cname]
+ )
)
log.info("Detected removed column '%s.%s'", name, cname)
@@ -311,7 +327,7 @@ class _fk_constraint_sig(_constraint_sig):
def _compare_indexes_and_uniques(schema, tname, object_filters, conn_table,
- metadata_table, diffs,
+ metadata_table, modify_ops,
autogen_context, inspector):
is_create_table = conn_table is None
@@ -413,7 +429,9 @@ def _compare_indexes_and_uniques(schema, tname, object_filters, conn_table,
if obj.is_index:
if _run_filters(
obj.const, obj.name, "index", False, None, object_filters):
- diffs.append(("add_index", obj.const))
+ modify_ops.ops.append(
+ ops.CreateIndexOp.from_index(obj.const)
+ )
log.info("Detected added index '%s' on %s",
obj.name, ', '.join([
"'%s'" % obj.column_names
@@ -429,7 +447,9 @@ def _compare_indexes_and_uniques(schema, tname, object_filters, conn_table,
if _run_filters(
obj.const, obj.name,
"unique_constraint", False, None, object_filters):
- diffs.append(("add_constraint", obj.const))
+ modify_ops.ops.append(
+ ops.AddConstraintOp.from_constraint(obj.const)
+ )
log.info("Detected added unique constraint '%s' on %s",
obj.name, ', '.join([
"'%s'" % obj.column_names
@@ -445,14 +465,18 @@ def _compare_indexes_and_uniques(schema, tname, object_filters, conn_table,
if _run_filters(
obj.const, obj.name, "index", True, None, object_filters):
- diffs.append(("remove_index", obj.const))
+ modify_ops.ops.append(
+ ops.DropIndexOp.from_index(obj.const)
+ )
log.info(
"Detected removed index '%s' on '%s'", obj.name, tname)
else:
if _run_filters(
obj.const, obj.name,
"unique_constraint", True, None, object_filters):
- diffs.append(("remove_constraint", obj.const))
+ modify_ops.ops.append(
+ ops.DropConstraintOp.from_constraint(obj.const)
+ )
log.info("Detected removed unique constraint '%s' on '%s'",
obj.name, tname
)
@@ -465,8 +489,12 @@ def _compare_indexes_and_uniques(schema, tname, object_filters, conn_table,
log.info("Detected changed index '%s' on '%s':%s",
old.name, tname, ', '.join(msg)
)
- diffs.append(("remove_index", old.const))
- diffs.append(("add_index", new.const))
+ modify_ops.ops.append(
+ ops.DropIndexOp.from_index(old.const)
+ )
+ modify_ops.ops.append(
+ ops.CreateIndexOp.from_index(new.const)
+ )
else:
if _run_filters(
new.const, new.name,
@@ -474,8 +502,12 @@ def _compare_indexes_and_uniques(schema, tname, object_filters, conn_table,
log.info("Detected changed unique constraint '%s' on '%s':%s",
old.name, tname, ', '.join(msg)
)
- diffs.append(("remove_constraint", old.const))
- diffs.append(("add_constraint", new.const))
+ modify_ops.ops.append(
+ ops.DropConstraintOp.from_constraint(old.const)
+ )
+ modify_ops.ops.append(
+ ops.AddConstraintOp.from_constraint(new.const)
+ )
for added_name in sorted(set(metadata_names).difference(conn_names)):
obj = metadata_names[added_name]
@@ -529,19 +561,13 @@ def _compare_indexes_and_uniques(schema, tname, object_filters, conn_table,
def _compare_nullable(schema, tname, cname, conn_col,
- metadata_col_nullable, diffs,
+ metadata_col_nullable, alter_column_op,
autogen_context):
conn_col_nullable = conn_col.nullable
+ alter_column_op.existing_nullable = conn_col_nullable
+
if conn_col_nullable is not metadata_col_nullable:
- diffs.append(
- ("modify_nullable", schema, tname, cname,
- {
- "existing_type": conn_col.type,
- "existing_server_default": conn_col.server_default,
- },
- conn_col_nullable,
- metadata_col_nullable),
- )
+ alter_column_op.modify_nullable = metadata_col_nullable
log.info("Detected %s on column '%s.%s'",
"NULL" if metadata_col_nullable else "NOT NULL",
tname,
@@ -550,10 +576,11 @@ def _compare_nullable(schema, tname, cname, conn_col,
def _compare_type(schema, tname, cname, conn_col,
- metadata_col, diffs,
+ metadata_col, alter_column_op,
autogen_context):
conn_type = conn_col.type
+ alter_column_op.existing_type = conn_type
metadata_type = metadata_col.type
if conn_type._type_affinity is sqltypes.NullType:
log.info("Couldn't determine database type "
@@ -567,16 +594,7 @@ def _compare_type(schema, tname, cname, conn_col,
isdiff = autogen_context['context']._compare_type(conn_col, metadata_col)
if isdiff:
-
- diffs.append(
- ("modify_type", schema, tname, cname,
- {
- "existing_nullable": conn_col.nullable,
- "existing_server_default": conn_col.server_default,
- },
- conn_type,
- metadata_type),
- )
+ alter_column_op.modify_type = metadata_type
log.info("Detected type change from %r to %r on '%s.%s'",
conn_type, metadata_type, tname, cname
)
@@ -606,7 +624,7 @@ def _render_server_default_for_compare(metadata_default,
def _compare_server_default(schema, tname, cname, conn_col, metadata_col,
- diffs, autogen_context):
+ alter_column_op, autogen_context):
metadata_default = metadata_col.server_default
conn_col_default = conn_col.server_default
@@ -618,22 +636,15 @@ def _compare_server_default(schema, tname, cname, conn_col, metadata_col,
rendered_conn_default = conn_col.server_default.arg.text \
if conn_col.server_default else None
+ alter_column_op.existing_server_default = rendered_conn_default
+
isdiff = autogen_context['context']._compare_server_default(
conn_col, metadata_col,
rendered_metadata_default,
rendered_conn_default
)
if isdiff:
- conn_col_default = rendered_conn_default
- diffs.append(
- ("modify_default", schema, tname, cname,
- {
- "existing_nullable": conn_col.nullable,
- "existing_type": conn_col.type,
- },
- conn_col_default,
- metadata_default),
- )
+ alter_column_op.modify_server_default = metadata_default
log.info("Detected server default on column '%s.%s'",
tname,
cname
@@ -641,7 +652,8 @@ def _compare_server_default(schema, tname, cname, conn_col, metadata_col,
def _compare_foreign_keys(schema, tname, object_filters, conn_table,
- metadata_table, diffs, autogen_context, inspector):
+ metadata_table, modify_table_ops,
+ autogen_context, inspector):
# if we're doing CREATE TABLE, all FKs are created
# inline within the table def
diff --git a/alembic/operations/ops.py b/alembic/operations/ops.py
index 16cccb6..21d4163 100644
--- a/alembic/operations/ops.py
+++ b/alembic/operations/ops.py
@@ -3,6 +3,7 @@ from ..util import sqla_compat
from . import schemaobj
from sqlalchemy.types import NULLTYPE
from .base import Operations, BatchOperations
+import re
class MigrateOperation(object):
@@ -34,6 +35,10 @@ class MigrateOperation(object):
class AddConstraintOp(MigrateOperation):
"""Represent an add constraint operation."""
+ @property
+ def type_(self):
+ raise NotImplementedError()
+
@classmethod
def from_constraint(cls, constraint):
funcs = {
@@ -45,17 +50,33 @@ class AddConstraintOp(MigrateOperation):
}
return funcs[constraint.__visit_name__](constraint)
+ def reverse(self):
+ return DropConstraintOp(
+ self.constraint_name, self.table_name, type_=self.type_,
+ schema=self.schema)
+
@Operations.register_operation("drop_constraint")
@BatchOperations.register_operation("drop_constraint", "batch_drop_constraint")
class DropConstraintOp(MigrateOperation):
"""Represent a drop constraint operation."""
- def __init__(self, constraint_name, table_name, type_=None, schema=None):
+ def __init__(
+ self,
+ constraint_name, table_name, type_=None, schema=None,
+ _orig_constraint=None):
self.constraint_name = constraint_name
self.table_name = table_name
self.constraint_type = type_
self.schema = schema
+ self._orig_constraint = _orig_constraint
+
+ def reverse(self):
+ if self._orig_constraint is None:
+ raise ValueError(
+ "operation is not reversible; "
+ "original constraint is not present")
+ return AddConstraintOp.from_constraint(self._orig_constraint)
@classmethod
def from_constraint(cls, constraint):
@@ -124,8 +145,11 @@ class DropConstraintOp(MigrateOperation):
class CreatePrimaryKeyOp(AddConstraintOp):
"""Represent a create primary key operation."""
+ type_ = "primary_key_constraint"
+
def __init__(
- self, constraint_name, table_name, columns, schema=None, **kw):
+ self, constraint_name, table_name, columns,
+ schema=None, **kw):
self.constraint_name = constraint_name
self.table_name = table_name
self.columns = columns
@@ -222,8 +246,11 @@ class CreatePrimaryKeyOp(AddConstraintOp):
class CreateUniqueConstraintOp(AddConstraintOp):
"""Represent a create unique constraint operation."""
+ type_ = "unique_constraint"
+
def __init__(
- self, constraint_name, table_name, columns, schema=None, **kw):
+ self, constraint_name, table_name,
+ columns, schema=None, **kw):
self.constraint_name = constraint_name
self.table_name = table_name
self.columns = columns
@@ -338,6 +365,8 @@ class CreateUniqueConstraintOp(AddConstraintOp):
class CreateForeignKeyOp(AddConstraintOp):
"""Represent a create foreign key constraint operation."""
+ type_ = "foreign_key_constraint"
+
def __init__(
self, constraint_name, source_table, referent_table, local_cols,
remote_cols, **kw):
@@ -499,6 +528,8 @@ class CreateForeignKeyOp(AddConstraintOp):
class CreateCheckConstraintOp(AddConstraintOp):
"""Represent a create check constraint operation."""
+ type_ = "check_constraint"
+
def __init__(
self, constraint_name, table_name, condition, schema=None, **kw):
self.constraint_name = constraint_name
@@ -515,7 +546,7 @@ class CreateCheckConstraintOp(AddConstraintOp):
constraint.name,
constraint_table.name,
constraint.condition,
- schema=constraint_table.schema
+ schema=constraint_table.schema,
)
def to_constraint(self, migration_context=None):
@@ -616,6 +647,9 @@ class CreateIndexOp(MigrateOperation):
self.kw = kw
self._orig_index = _orig_index
+ def reverse(self):
+ return DropIndexOp.from_index(self.to_index())
+
@classmethod
def from_index(cls, index):
return cls(
@@ -721,10 +755,19 @@ class CreateIndexOp(MigrateOperation):
class DropIndexOp(MigrateOperation):
"""Represent a drop index operation."""
- def __init__(self, index_name, table_name=None, schema=None):
+ def __init__(
+ self, index_name, table_name=None, schema=None, _orig_index=None):
self.index_name = index_name
self.table_name = table_name
self.schema = schema
+ self._orig_index = _orig_index
+
+ def reverse(self):
+ if self._orig_index is None:
+ raise ValueError(
+ "operation is not reversible; "
+ "original index is not present")
+ return CreateIndexOp.from_index(self._orig_index)
@classmethod
def from_index(cls, index):
@@ -732,6 +775,7 @@ class DropIndexOp(MigrateOperation):
index.name,
index.table.name,
schema=index.table.schema,
+ _orig_index=index
)
def to_index(self, migration_context=None):
@@ -799,6 +843,9 @@ class CreateTableOp(MigrateOperation):
self.kw = kw
self._orig_table = _orig_table
+ def reverse(self):
+ return DropTableOp.from_table(self.to_table())
+
@classmethod
def from_table(cls, table):
return cls(
@@ -913,14 +960,23 @@ class CreateTableOp(MigrateOperation):
class DropTableOp(MigrateOperation):
"""Represent a drop table operation."""
- def __init__(self, table_name, schema=None, table_kw=None):
+ def __init__(
+ self, table_name, schema=None, table_kw=None, _orig_table=None):
self.table_name = table_name
self.schema = schema
self.table_kw = table_kw or {}
+ self._orig_table = _orig_table
+
+ def reverse(self):
+ if self._orig_table is None:
+ raise ValueError(
+ "operation is not reversible; "
+ "original table is not present")
+ return CreateTableOp.from_table(self._orig_table)
@classmethod
def from_table(cls, table):
- return cls(table.name, schema=table.schema)
+ return cls(table.name, schema=table.schema, _orig_table=table)
def to_table(self, migration_context):
schema_obj = schemaobj.SchemaObjects(migration_context)
@@ -1021,6 +1077,43 @@ class AlterColumnOp(AlterTableOp):
self.modify_type = modify_type
self.kw = kw
+ def has_changes(self):
+ hc1 = self.modify_nullable is not None or \
+ self.modify_server_default is not False or \
+ self.modify_type is not None
+ if hc1:
+ return True
+ for kw in self.kw:
+ if kw.startswith('modify_'):
+ return True
+ else:
+ return False
+
+ def reverse(self):
+
+ kw = self.kw.copy()
+ kw['existing_type'] = self.existing_type
+ kw['existing_nullable'] = self.existing_nullable
+ kw['existing_server_default'] = self.existing_server_default
+ kw['modify_type'] = self.modify_type
+ kw['modify_nullable'] = self.modify_nullable
+ kw['modify_server_default'] = self.modify_server_default
+
+ all_keys = set(m.group(1) for m in [
+ re.match(r'(?:existing_|modify_)(.+))', k)
+ for k in kw
+ ] if m)
+
+ for k in all_keys:
+ swap = kw['existing_%s' % k]
+ kw['existing_%s' % k] = kw['modify_%s' % k]
+ kw['modify_%s' % k] = swap
+
+ return self.__class__(
+ self.table_name, self.column_name, schema=self.schema,
+ **kw
+ )
+
@classmethod
@util._with_legacy_names([('name', 'new_column_name')])
def alter_column(
@@ -1169,6 +1262,10 @@ class AddColumnOp(AlterTableOp):
super(AddColumnOp, self).__init__(table_name, schema=schema)
self.column = column
+ def reverse(self):
+ return DropColumnOp.from_column_and_tablename(
+ self.schema, self.table_name, self.column)
+
@classmethod
def from_column(cls, col):
return cls(col.table.name, col, schema=col.table.schema)
@@ -1257,14 +1354,26 @@ class AddColumnOp(AlterTableOp):
class DropColumnOp(AlterTableOp):
"""Represent a drop column operation."""
- def __init__(self, table_name, column_name, schema=None, **kw):
+ def __init__(
+ self, table_name, column_name, schema=None,
+ _orig_column=None, **kw):
super(DropColumnOp, self).__init__(table_name, schema=schema)
self.column_name = column_name
self.kw = kw
+ self._orig_column = _orig_column
+
+ def reverse(self):
+ if self._orig_column is None:
+ raise ValueError(
+ "operation is not reversible; "
+ "original column is not present")
+
+ return AddColumnOp.from_column_and_tablename(
+ self.schema, self.table_name, self._orig_column)
@classmethod
def from_column_and_tablename(cls, schema, tname, col):
- return cls(tname, col.name, schema=schema)
+ return cls(tname, col.name, schema=schema, _orig_column=col)
def to_column(self, migration_context=None):
schema_obj = schemaobj.SchemaObjects(migration_context)
@@ -1514,6 +1623,9 @@ class OpContainer(MigrateOperation):
def __init__(self, ops=()):
self.ops = ops
+ def is_empty(self):
+ return not self.ops
+
class ModifyTableOps(OpContainer):
"""Contains a sequence of operations that all apply to a single Table."""
@@ -1534,6 +1646,13 @@ class UpgradeOps(OpContainer):
"""
+ def reverse(self):
+ return DowngradeOps(
+ ops=list(reversed(
+ op.reverse() for op in self.ops
+ ))
+ )
+
class DowngradeOps(OpContainer):
"""contains a sequence of operations that would apply to the
@@ -1545,6 +1664,13 @@ class DowngradeOps(OpContainer):
"""
+ def reverse(self):
+ return UpgradeOps(
+ ops=list(reversed(
+ op.reverse() for op in self.ops
+ ))
+ )
+
class MigrationScript(MigrateOperation):
"""represents a migration script.
@@ -1575,4 +1701,3 @@ class MigrationScript(MigrateOperation):
self.version_path = version_path
self.upgrade_ops = upgrade_ops
self.downgrade_ops = downgrade_ops
-