summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2015-07-06 21:20:55 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2015-07-06 21:20:55 -0400
commit012d4e384440676ba050416581d58eecb5c30ac5 (patch)
tree801accd4531ca8f74c31d3603a818765280e93be
parent9eb8f3a085193067b15f89069ff1b40013a9f1d8 (diff)
downloadalembic-012d4e384440676ba050416581d58eecb5c30ac5.tar.gz
- all tests passing again. next step is do the compare API
-rw-r--r--alembic/autogenerate/api.py9
-rw-r--r--alembic/autogenerate/compare.py14
-rw-r--r--alembic/autogenerate/generate.py2
-rw-r--r--alembic/autogenerate/render.py4
-rw-r--r--alembic/operations/ops.py55
-rw-r--r--tests/test_postgresql.py18
6 files changed, 64 insertions, 38 deletions
diff --git a/alembic/autogenerate/api.py b/alembic/autogenerate/api.py
index 1277c20..b76146d 100644
--- a/alembic/autogenerate/api.py
+++ b/alembic/autogenerate/api.py
@@ -121,15 +121,14 @@ def produce_migrations(context, metadata):
autogen_context = _autogen_context(context, metadata=metadata)
- upgrade_ops = ops.UpgradeOps([])
- compare._produce_net_changes(autogen_context, upgrade_ops)
-
migration_script = ops.MigrationScript(
rev_id=None,
- upgrade_ops=upgrade_ops,
- downgrade_ops=upgrade_ops.reverse(),
+ upgrade_ops=ops.UpgradeOps([]),
+ downgrade_ops=ops.DowngradeOps([]),
)
+ compare._populate_migration_script(autogen_context, migration_script)
+
return migration_script
diff --git a/alembic/autogenerate/compare.py b/alembic/autogenerate/compare.py
index 06f442c..5bb2964 100644
--- a/alembic/autogenerate/compare.py
+++ b/alembic/autogenerate/compare.py
@@ -14,6 +14,11 @@ from alembic.ddl.base import _fk_spec
log = logging.getLogger(__name__)
+def _populate_migration_script(autogen_context, migration_script):
+ _produce_net_changes(autogen_context, migration_script.upgrade_ops)
+ migration_script.upgrade_ops.reverse_into(migration_script.downgrade_ops)
+
+
def _produce_net_changes(autogen_context, upgrade_ops):
metadata = autogen_context['metadata']
@@ -636,7 +641,7 @@ def _compare_server_default(schema, tname, cname, conn_col, metadata_col,
rendered_conn_default = conn_col.server_default.arg.text \
if conn_col.server_default else None
- alter_column_op.existing_server_default = rendered_conn_default
+ alter_column_op.existing_server_default = conn_col_default
isdiff = autogen_context['context']._compare_server_default(
conn_col, metadata_col,
@@ -645,10 +650,9 @@ def _compare_server_default(schema, tname, cname, conn_col, metadata_col,
)
if isdiff:
alter_column_op.modify_server_default = metadata_default
- log.info("Detected server default on column '%s.%s'",
- tname,
- cname
- )
+ log.info(
+ "Detected server default on column '%s.%s'",
+ tname, cname)
def _compare_foreign_keys(schema, tname, object_filters, conn_table,
diff --git a/alembic/autogenerate/generate.py b/alembic/autogenerate/generate.py
index 1c8bef4..041509e 100644
--- a/alembic/autogenerate/generate.py
+++ b/alembic/autogenerate/generate.py
@@ -51,7 +51,7 @@ class RevisionContext(object):
migration_script = self.generated_revisions[0]
- compare._produce_net_changes(autogen_context, migration_script)
+ compare._populate_migration_script(autogen_context, migration_script)
hook = context.opts.get('process_revision_directives', None)
if hook:
diff --git a/alembic/autogenerate/render.py b/alembic/autogenerate/render.py
index c3f3df1..39817ab 100644
--- a/alembic/autogenerate/render.py
+++ b/alembic/autogenerate/render.py
@@ -343,10 +343,10 @@ def _alter_column(autogen_context, op):
if nullable is not None:
text += ",\n%snullable=%r" % (
indent, nullable,)
- if existing_nullable is not None:
+ if nullable is None and existing_nullable is not None:
text += ",\n%sexisting_nullable=%r" % (
indent, existing_nullable)
- if existing_server_default:
+ if server_default is False and existing_server_default:
rendered = _render_server_default(
existing_server_default,
autogen_context)
diff --git a/alembic/operations/ops.py b/alembic/operations/ops.py
index e9c2497..3981fc2 100644
--- a/alembic/operations/ops.py
+++ b/alembic/operations/ops.py
@@ -36,7 +36,7 @@ class AddConstraintOp(MigrateOperation):
"""Represent an add constraint operation."""
@property
- def type_(self):
+ def constraint_type(self):
raise NotImplementedError()
@classmethod
@@ -80,10 +80,10 @@ class DropConstraintOp(MigrateOperation):
return AddConstraintOp.from_constraint(self._orig_constraint)
def to_diff_tuple(self):
- if self.type_ == "foreign_key_constraint":
+ if self.constraint_type == "foreignkey":
return ("remove_fk", self.to_constraint())
else:
- return ("drop_constraint", self.to_constraint())
+ return ("remove_constraint", self.to_constraint())
@classmethod
def from_constraint(cls, constraint):
@@ -104,6 +104,14 @@ class DropConstraintOp(MigrateOperation):
_orig_constraint=constraint
)
+ def to_constraint(self):
+ if self._orig_constraint is not None:
+ return self._orig_constraint
+ else:
+ raise ValueError(
+ "constraint cannot be produced; "
+ "original constraint is not present")
+
@classmethod
@util._with_legacy_names([("type", "type_")])
def drop_constraint(
@@ -153,7 +161,7 @@ class DropConstraintOp(MigrateOperation):
class CreatePrimaryKeyOp(AddConstraintOp):
"""Represent a create primary key operation."""
- type_ = "primary_key_constraint"
+ constraint_type = "primarykey"
def __init__(
self, constraint_name, table_name, columns,
@@ -254,7 +262,7 @@ class CreatePrimaryKeyOp(AddConstraintOp):
class CreateUniqueConstraintOp(AddConstraintOp):
"""Represent a create unique constraint operation."""
- type_ = "unique_constraint"
+ constraint_type = "unique"
def __init__(
self, constraint_name, table_name,
@@ -373,7 +381,7 @@ class CreateUniqueConstraintOp(AddConstraintOp):
class CreateForeignKeyOp(AddConstraintOp):
"""Represent a create foreign key constraint operation."""
- type_ = "foreign_key_constraint"
+ constraint_type = "foreignkey"
def __init__(
self, constraint_name, source_table, referent_table, local_cols,
@@ -539,7 +547,7 @@ class CreateForeignKeyOp(AddConstraintOp):
class CreateCheckConstraintOp(AddConstraintOp):
"""Represent a create check constraint operation."""
- type_ = "check_constraint"
+ constraint_type = "check"
def __init__(
self, constraint_name, table_name, condition, schema=None, **kw):
@@ -662,7 +670,7 @@ class CreateIndexOp(MigrateOperation):
return DropIndexOp.from_index(self.to_index())
def to_diff_tuple(self):
- return ("add_index", self.to_constraint())
+ return ("add_index", self.to_index())
@classmethod
def from_index(cls, index):
@@ -1002,6 +1010,8 @@ class DropTableOp(MigrateOperation):
return cls(table.name, schema=table.schema, _orig_table=table)
def to_table(self, migration_context=None):
+ if self._orig_table is not None:
+ return self._orig_table
schema_obj = schemaobj.SchemaObjects(migration_context)
return schema_obj.table(
self.table_name,
@@ -1157,19 +1167,24 @@ class AlterColumnOp(AlterTableOp):
kw['existing_type'] = self.existing_type
kw['existing_nullable'] = self.existing_nullable
kw['existing_server_default'] = self.existing_server_default
- kw['modify_type'] = self.modify_type
- kw['modify_nullable'] = self.modify_nullable
- kw['modify_server_default'] = self.modify_server_default
+ if self.modify_type is not None:
+ kw['modify_type'] = self.modify_type
+ if self.modify_nullable is not None:
+ kw['modify_nullable'] = self.modify_nullable
+ if self.modify_server_default is not False:
+ kw['modify_server_default'] = self.modify_server_default
+ # TODO: make this a little simpler
all_keys = set(m.group(1) for m in [
re.match(r'^(?:existing_|modify_)(.+)$', k)
for k in kw
] if m)
for k in all_keys:
- swap = kw['existing_%s' % k]
- kw['existing_%s' % k] = kw['modify_%s' % k]
- kw['modify_%s' % k] = swap
+ if 'modify_%s' % k in kw:
+ swap = kw['existing_%s' % k]
+ kw['existing_%s' % k] = kw['modify_%s' % k]
+ kw['modify_%s' % k] = swap
return self.__class__(
self.table_name, self.column_name, schema=self.schema,
@@ -1736,12 +1751,14 @@ class UpgradeOps(OpContainer):
"""
+ def reverse_into(self, downgrade_ops):
+ downgrade_ops.ops[:] = list(reversed(
+ [op.reverse() for op in self.ops]
+ ))
+ return downgrade_ops
+
def reverse(self):
- return DowngradeOps(
- ops=list(reversed(
- [op.reverse() for op in self.ops]
- ))
- )
+ return self.reverse_into(DowngradeOps(ops=[]))
class DowngradeOps(OpContainer):
diff --git a/tests/test_postgresql.py b/tests/test_postgresql.py
index e70d05a..0503ef2 100644
--- a/tests/test_postgresql.py
+++ b/tests/test_postgresql.py
@@ -8,6 +8,7 @@ from sqlalchemy.sql import table, column
from alembic.autogenerate.compare import \
_compare_server_default, _compare_tables, _render_server_default_for_compare
+from alembic.operations import ops
from alembic import command, util
from alembic.migration import MigrationContext
from alembic.script import ScriptDirectory
@@ -212,9 +213,10 @@ class PostgresqlDefaultCompareTest(TestBase):
cols = insp.get_columns(t1.name)
insp_col = Column("somecol", cols[0]['type'],
server_default=text(cols[0]['default']))
- diffs = []
+ op = ops.AlterColumnOp("test", "somecol")
_compare_server_default(None, "test", "somecol", insp_col,
- t2.c.somecol, diffs, self.autogen_context)
+ t2.c.somecol, op, self.autogen_context)
+ diffs = op.to_diff_tuple()
eq_(bool(diffs), diff_expected)
def _compare_default(
@@ -420,24 +422,28 @@ class PostgresqlDetectSerialTest(TestBase):
self.metadata.create_all(config.db)
insp = Inspector.from_engine(config.db)
- diffs = []
+
+ uo = ops.UpgradeOps(ops=[])
_compare_tables(
set([(None, 't')]), set([]),
[],
- insp, self.metadata, diffs, self.autogen_context)
+ insp, self.metadata, uo, self.autogen_context)
+ diffs = uo.as_diffs()
tab = diffs[0][1]
+
eq_(_render_server_default_for_compare(
tab.c.x.server_default, tab.c.x, self.autogen_context),
c_expected)
insp = Inspector.from_engine(config.db)
- diffs = []
+ uo = ops.UpgradeOps(ops=[])
m2 = MetaData()
Table('t', m2, Column('x', BigInteger()))
_compare_tables(
set([(None, 't')]), set([(None, 't')]),
[],
- insp, m2, diffs, self.autogen_context)
+ insp, m2, uo, self.autogen_context)
+ diffs = uo.as_diffs()
server_default = diffs[0][0][4]['existing_server_default']
eq_(_render_server_default_for_compare(
server_default, tab.c.x, self.autogen_context),