summaryrefslogtreecommitdiff
path: root/alembic
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2015-07-16 19:00:55 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2015-07-16 19:00:55 -0400
commitc5be31760e4bc57b898d1d69d4bb0dd7c2dc7eb7 (patch)
treef158a8b62d7679176f3f60d1b1bf188a6383e03c /alembic
parent96214629cdb13f1694831f36c48a7ec86dd8c7f6 (diff)
downloadalembic-c5be31760e4bc57b898d1d69d4bb0dd7c2dc7eb7.tar.gz
- rework all of autogenerate to build directly on alembic.operations.ops
objects; the "diffs" is now a legacy system that is exported from the ops. A new model of comparison/rendering/ upgrade/downgrade composition that is cleaner and much more extensible is introduced. - autogenerate is now extensible as far as database objects compared and rendered into scripts; any new operation directive can also be registered into a series of hooks that allow custom database/model comparison functions to run as well as to render new operation directives into autogenerate scripts. - write all new docs for the new system fixes #306
Diffstat (limited to 'alembic')
-rw-r--r--alembic/autogenerate/__init__.py6
-rw-r--r--alembic/autogenerate/api.py306
-rw-r--r--alembic/autogenerate/compare.py336
-rw-r--r--alembic/autogenerate/compose.py144
-rw-r--r--alembic/autogenerate/generate.py92
-rw-r--r--alembic/autogenerate/render.py70
-rw-r--r--alembic/operations/ops.py256
-rw-r--r--alembic/runtime/migration.py3
-rw-r--r--alembic/util/langhelpers.py43
9 files changed, 724 insertions, 532 deletions
diff --git a/alembic/autogenerate/__init__.py b/alembic/autogenerate/__init__.py
index 4272a7e..78520a8 100644
--- a/alembic/autogenerate/__init__.py
+++ b/alembic/autogenerate/__init__.py
@@ -1,7 +1,7 @@
from .api import ( # noqa
compare_metadata, _render_migration_diffs,
- produce_migrations, render_python_code
+ produce_migrations, render_python_code,
+ RevisionContext
)
-from .compare import _produce_net_changes # noqa
-from .generate import RevisionContext # noqa
+from .compare import _produce_net_changes, comparators # noqa
from .render import render_op_text, renderers # noqa \ No newline at end of file
diff --git a/alembic/autogenerate/api.py b/alembic/autogenerate/api.py
index cff977b..e9af4cf 100644
--- a/alembic/autogenerate/api.py
+++ b/alembic/autogenerate/api.py
@@ -4,8 +4,9 @@ automatically."""
from ..operations import ops
from . import render
from . import compare
-from . import compose
from .. import util
+from sqlalchemy.engine.reflection import Inspector
+import contextlib
def compare_metadata(context, metadata):
@@ -98,20 +99,8 @@ def compare_metadata(context, metadata):
"""
- autogen_context = _autogen_context(context, metadata=metadata)
-
- # as_sql=True is nonsensical here. autogenerate requires a connection
- # it can use to run queries against to get the database schema.
- if context.as_sql:
- raise util.CommandError(
- "autogenerate can't use as_sql=True as it prevents querying "
- "the database for schema information")
-
- diffs = []
-
- compare._produce_net_changes(autogen_context, diffs)
-
- return diffs
+ migration_script = produce_migrations(context, metadata)
+ return migration_script.upgrade_ops.as_diffs()
def produce_migrations(context, metadata):
@@ -132,10 +121,7 @@ def produce_migrations(context, metadata):
"""
- autogen_context = _autogen_context(context, metadata=metadata)
- diffs = []
-
- compare._produce_net_changes(autogen_context, diffs)
+ autogen_context = AutogenContext(context, metadata=metadata)
migration_script = ops.MigrationScript(
rev_id=None,
@@ -143,7 +129,7 @@ def produce_migrations(context, metadata):
downgrade_ops=ops.DowngradeOps([]),
)
- compose._to_migration_script(autogen_context, migration_script, diffs)
+ compare._populate_migration_script(autogen_context, migration_script)
return migration_script
@@ -152,6 +138,7 @@ def render_python_code(
up_or_down_op,
sqlalchemy_module_prefix='sa.',
alembic_module_prefix='op.',
+ render_as_batch=False,
imports=(),
render_item=None,
):
@@ -162,84 +149,239 @@ def render_python_code(
autogenerate output of a user-defined :class:`.MigrationScript` structure.
"""
- autogen_context = {
- 'opts': {
- 'sqlalchemy_module_prefix': sqlalchemy_module_prefix,
- 'alembic_module_prefix': alembic_module_prefix,
- 'render_item': render_item,
- },
- 'imports': set(imports)
+ opts = {
+ 'sqlalchemy_module_prefix': sqlalchemy_module_prefix,
+ 'alembic_module_prefix': alembic_module_prefix,
+ 'render_item': render_item,
+ 'render_as_batch': render_as_batch,
}
+
+ autogen_context = AutogenContext(None, opts=opts)
+ autogen_context.imports = set(imports)
return render._indent(render._render_cmd_body(
up_or_down_op, autogen_context))
-
-
-def _render_migration_diffs(context, template_args, imports):
+def _render_migration_diffs(context, template_args):
"""legacy, used by test_autogen_composition at the moment"""
- migration_script = produce_migrations(context, None)
-
- autogen_context = _autogen_context(context, imports=imports)
- diffs = []
+ autogen_context = AutogenContext(context)
- compare._produce_net_changes(autogen_context, diffs)
+ upgrade_ops = ops.UpgradeOps([])
+ compare._produce_net_changes(autogen_context, upgrade_ops)
migration_script = ops.MigrationScript(
rev_id=None,
- imports=imports,
- upgrade_ops=ops.UpgradeOps([]),
- downgrade_ops=ops.DowngradeOps([]),
+ upgrade_ops=upgrade_ops,
+ downgrade_ops=upgrade_ops.reverse(),
)
- compose._to_migration_script(autogen_context, migration_script, diffs)
-
render._render_migration_script(
autogen_context, migration_script, template_args
)
-def _autogen_context(
- context, imports=None, metadata=None, include_symbol=None,
- include_object=None, include_schemas=False):
-
- opts = context.opts
- metadata = opts['target_metadata'] if metadata is None else metadata
- include_schemas = opts.get('include_schemas', include_schemas)
-
- include_symbol = opts.get('include_symbol', include_symbol)
- include_object = opts.get('include_object', include_object)
-
- object_filters = []
- if include_symbol:
- def include_symbol_filter(object, name, type_, reflected, compare_to):
- if type_ == "table":
- return include_symbol(name, object.schema)
- else:
- return True
- object_filters.append(include_symbol_filter)
- if include_object:
- object_filters.append(include_object)
-
- if metadata is None:
- raise util.CommandError(
- "Can't proceed with --autogenerate option; environment "
- "script %s does not provide "
- "a MetaData object to the context." % (
- context.script.env_py_location
- ))
-
- opts = context.opts
- connection = context.bind
- return {
- 'imports': imports if imports is not None else set(),
- 'connection': connection,
- 'dialect': connection.dialect,
- 'context': context,
- 'opts': opts,
- 'metadata': metadata,
- 'object_filters': object_filters,
- 'include_schemas': include_schemas
- }
+class AutogenContext(object):
+ """Maintains configuration and state that's specific to an
+ autogenerate operation."""
+
+ metadata = None
+ """The :class:`~sqlalchemy.schema.MetaData` object
+ representing the destination.
+
+ This object is the one that is passed within ``env.py``
+ to the :paramref:`.EnvironmentContext.configure.target_metadata`
+ parameter. It represents the structure of :class:`.Table` and other
+ objects as stated in the current database model, and represents the
+ destination structure for the database being examined.
+
+ While the :class:`~sqlalchemy.schema.MetaData` object is primarily
+ known as a collection of :class:`~sqlalchemy.schema.Table` objects,
+ it also has an :attr:`~sqlalchemy.schema.MetaData.info` dictionary
+ that may be used by end-user schemes to store additional schema-level
+ objects that are to be compared in custom autogeneration schemes.
+
+ """
+
+ connection = None
+ """The :class:`~sqlalchemy.engine.base.Connection` object currently
+ connected to the database backend being compared.
+
+ This is obtained from the :attr:`.MigrationContext.bind` and is
+    ultimately set up in the ``env.py`` script.
+
+ """
+
+ dialect = None
+ """The :class:`~sqlalchemy.engine.Dialect` object currently in use.
+
+ This is normally obtained from the
+ :attr:`~sqlalchemy.engine.base.Connection.dialect` attribute.
+
+ """
+
+ migration_context = None
+ """The :class:`.MigrationContext` established by the ``env.py`` script."""
+
+ def __init__(self, migration_context, metadata=None, opts=None):
+
+ if migration_context is not None and migration_context.as_sql:
+ raise util.CommandError(
+ "autogenerate can't use as_sql=True as it prevents querying "
+ "the database for schema information")
+
+ if opts is None:
+ opts = migration_context.opts
+ self.metadata = metadata = opts.get('target_metadata', None) \
+ if metadata is None else metadata
+
+ if metadata is None and \
+ migration_context is not None and \
+ migration_context.script is not None:
+ raise util.CommandError(
+ "Can't proceed with --autogenerate option; environment "
+ "script %s does not provide "
+ "a MetaData object to the context." % (
+ migration_context.script.env_py_location
+ ))
+
+ include_symbol = opts.get('include_symbol', None)
+ include_object = opts.get('include_object', None)
+
+ object_filters = []
+ if include_symbol:
+ def include_symbol_filter(
+ object, name, type_, reflected, compare_to):
+ if type_ == "table":
+ return include_symbol(name, object.schema)
+ else:
+ return True
+ object_filters.append(include_symbol_filter)
+ if include_object:
+ object_filters.append(include_object)
+
+ self._object_filters = object_filters
+
+ self.migration_context = migration_context
+ if self.migration_context is not None:
+ self.connection = self.migration_context.bind
+ self.dialect = self.migration_context.dialect
+
+ self._imports = set()
+ self.opts = opts
+ self._has_batch = False
+
+ @util.memoized_property
+ def inspector(self):
+ return Inspector.from_engine(self.connection)
+
+ @contextlib.contextmanager
+ def _within_batch(self):
+ self._has_batch = True
+ yield
+ self._has_batch = False
+
+ def run_filters(self, object_, name, type_, reflected, compare_to):
+ """Run the context's object filters and return True if the targets
+ should be part of the autogenerate operation.
+
+ This method should be run for every kind of object encountered within
+ an autogenerate operation, giving the environment the chance
+ to filter what objects should be included in the comparison.
+ The filters here are produced directly via the
+ :paramref:`.EnvironmentContext.configure.include_object`
+ and :paramref:`.EnvironmentContext.configure.include_symbol`
+ functions, if present.
+
+ """
+ for fn in self._object_filters:
+ if not fn(object_, name, type_, reflected, compare_to):
+ return False
+ else:
+ return True
+
+
+class RevisionContext(object):
+ """Maintains configuration and state that's specific to a revision
+ file generation operation."""
+
+ def __init__(self, config, script_directory, command_args):
+ self.config = config
+ self.script_directory = script_directory
+ self.command_args = command_args
+ self.template_args = {
+ 'config': config # Let templates use config for
+ # e.g. multiple databases
+ }
+ self.generated_revisions = [
+ self._default_revision()
+ ]
+
+ def _to_script(self, migration_script):
+ template_args = {}
+ for k, v in self.template_args.items():
+ template_args.setdefault(k, v)
+
+ if migration_script._autogen_context is not None:
+ render._render_migration_script(
+ migration_script._autogen_context, migration_script,
+ template_args
+ )
+
+ return self.script_directory.generate_revision(
+ migration_script.rev_id,
+ migration_script.message,
+ refresh=True,
+ head=migration_script.head,
+ splice=migration_script.splice,
+ branch_labels=migration_script.branch_label,
+ version_path=migration_script.version_path,
+ **template_args)
+
+ def run_autogenerate(self, rev, context):
+ if self.command_args['sql']:
+ raise util.CommandError(
+ "Using --sql with --autogenerate does not make any sense")
+ if set(self.script_directory.get_revisions(rev)) != \
+ set(self.script_directory.get_revisions("heads")):
+ raise util.CommandError("Target database is not up to date.")
+
+ autogen_context = AutogenContext(context)
+
+ migration_script = self.generated_revisions[0]
+
+ compare._populate_migration_script(autogen_context, migration_script)
+
+ hook = context.opts.get('process_revision_directives', None)
+ if hook:
+ hook(context, rev, self.generated_revisions)
+
+ for migration_script in self.generated_revisions:
+ migration_script._autogen_context = autogen_context
+
+ def run_no_autogenerate(self, rev, context):
+ hook = context.opts.get('process_revision_directives', None)
+ if hook:
+ hook(context, rev, self.generated_revisions)
+
+ for migration_script in self.generated_revisions:
+ migration_script._autogen_context = None
+
+ def _default_revision(self):
+ op = ops.MigrationScript(
+ rev_id=self.command_args['rev_id'] or util.rev_id(),
+ message=self.command_args['message'],
+ imports=set(),
+ upgrade_ops=ops.UpgradeOps([]),
+ downgrade_ops=ops.DowngradeOps([]),
+ head=self.command_args['head'],
+ splice=self.command_args['splice'],
+ branch_label=self.command_args['branch_label'],
+ version_path=self.command_args['version_path']
+ )
+ op._autogen_context = None
+ return op
+ def generate_scripts(self):
+ for generated_revision in self.generated_revisions:
+ yield self._to_script(generated_revision)
diff --git a/alembic/autogenerate/compare.py b/alembic/autogenerate/compare.py
index cd6b696..fdc3cae 100644
--- a/alembic/autogenerate/compare.py
+++ b/alembic/autogenerate/compare.py
@@ -1,7 +1,9 @@
from sqlalchemy import schema as sa_schema, types as sqltypes
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy import event
+from ..operations import ops
import logging
+from .. import util
from ..util import compat
from ..util import sqla_compat
from sqlalchemy.util import OrderedSet
@@ -13,15 +15,20 @@ from alembic.ddl.base import _fk_spec
log = logging.getLogger(__name__)
-def _produce_net_changes(autogen_context, diffs):
+def _populate_migration_script(autogen_context, migration_script):
+ _produce_net_changes(autogen_context, migration_script.upgrade_ops)
+ migration_script.upgrade_ops.reverse_into(migration_script.downgrade_ops)
- metadata = autogen_context['metadata']
- connection = autogen_context['connection']
- object_filters = autogen_context.get('object_filters', ())
- include_schemas = autogen_context.get('include_schemas', False)
+
+comparators = util.Dispatcher(uselist=True)
+
+
+def _produce_net_changes(autogen_context, upgrade_ops):
+
+ connection = autogen_context.connection
+ include_schemas = autogen_context.opts.get('include_schemas', False)
inspector = Inspector.from_engine(connection)
- conn_table_names = set()
default_schema = connection.dialect.default_schema_name
if include_schemas:
@@ -34,14 +41,28 @@ def _produce_net_changes(autogen_context, diffs):
else:
schemas = [None]
- version_table_schema = autogen_context['context'].version_table_schema
- version_table = autogen_context['context'].version_table
+ comparators.dispatch("schema", autogen_context.dialect.name)(
+ autogen_context, upgrade_ops, schemas
+ )
+
+
+@comparators.dispatch_for("schema")
+def _autogen_for_tables(autogen_context, upgrade_ops, schemas):
+ inspector = autogen_context.inspector
+
+ metadata = autogen_context.metadata
+
+ conn_table_names = set()
+
+ version_table_schema = \
+ autogen_context.migration_context.version_table_schema
+ version_table = autogen_context.migration_context.version_table
for s in schemas:
tables = set(inspector.get_table_names(schema=s))
if s == version_table_schema:
tables = tables.difference(
- [autogen_context['context'].version_table]
+ [autogen_context.migration_context.version_table]
)
conn_table_names.update(zip([s] * len(tables), tables))
@@ -50,21 +71,11 @@ def _produce_net_changes(autogen_context, diffs):
).difference([(version_table_schema, version_table)])
_compare_tables(conn_table_names, metadata_table_names,
- object_filters,
- inspector, metadata, diffs, autogen_context)
-
-
-def _run_filters(object_, name, type_, reflected, compare_to, object_filters):
- for fn in object_filters:
- if not fn(object_, name, type_, reflected, compare_to):
- return False
- else:
- return True
+ inspector, metadata, upgrade_ops, autogen_context)
def _compare_tables(conn_table_names, metadata_table_names,
- object_filters,
- inspector, metadata, diffs, autogen_context):
+ inspector, metadata, upgrade_ops, autogen_context):
default_schema = inspector.bind.dialect.default_schema_name
@@ -95,14 +106,19 @@ def _compare_tables(conn_table_names, metadata_table_names,
for s, tname in metadata_table_names.difference(conn_table_names):
name = '%s.%s' % (s, tname) if s else tname
metadata_table = tname_to_table[(s, tname)]
- if _run_filters(
- metadata_table, tname, "table", False, None, object_filters):
- diffs.append(("add_table", metadata_table))
+ if autogen_context.run_filters(
+ metadata_table, tname, "table", False, None):
+ upgrade_ops.ops.append(
+ ops.CreateTableOp.from_table(metadata_table))
log.info("Detected added table %r", name)
- _compare_indexes_and_uniques(s, tname, object_filters,
- None,
- metadata_table,
- diffs, autogen_context, inspector)
+ modify_table_ops = ops.ModifyTableOps(tname, [], schema=s)
+
+ comparators.dispatch("table")(
+ autogen_context, modify_table_ops,
+ s, tname, None, metadata_table
+ )
+ if not modify_table_ops.is_empty():
+ upgrade_ops.ops.append(modify_table_ops)
removal_metadata = sa_schema.MetaData()
for s, tname in conn_table_names.difference(metadata_table_names):
@@ -114,11 +130,13 @@ def _compare_tables(conn_table_names, metadata_table_names,
event.listen(
t,
"column_reflect",
- autogen_context['context'].impl.
+ autogen_context.migration_context.impl.
_compat_autogen_column_reflect(inspector))
inspector.reflecttable(t, None)
- if _run_filters(t, tname, "table", True, None, object_filters):
- diffs.append(("remove_table", t))
+ if autogen_context.run_filters(t, tname, "table", True, None):
+ upgrade_ops.ops.append(
+ ops.DropTableOp.from_table(t)
+ )
log.info("Detected removed table %r", name)
existing_tables = conn_table_names.intersection(metadata_table_names)
@@ -133,7 +151,7 @@ def _compare_tables(conn_table_names, metadata_table_names,
event.listen(
t,
"column_reflect",
- autogen_context['context'].impl.
+ autogen_context.migration_context.impl.
_compat_autogen_column_reflect(inspector))
inspector.reflecttable(t, None)
conn_column_info[(s, tname)] = t
@@ -144,25 +162,24 @@ def _compare_tables(conn_table_names, metadata_table_names,
metadata_table = tname_to_table[(s, tname)]
conn_table = existing_metadata.tables[name]
- if _run_filters(
+ if autogen_context.run_filters(
metadata_table, tname, "table", False,
- conn_table, object_filters):
+ conn_table):
+
+ modify_table_ops = ops.ModifyTableOps(tname, [], schema=s)
with _compare_columns(
- s, tname, object_filters,
+ s, tname,
conn_table,
metadata_table,
- diffs, autogen_context, inspector):
- _compare_indexes_and_uniques(s, tname, object_filters,
- conn_table,
- metadata_table,
- diffs, autogen_context, inspector)
- _compare_foreign_keys(s, tname, object_filters, conn_table,
- metadata_table, diffs, autogen_context,
- inspector)
+ modify_table_ops, autogen_context, inspector):
- # TODO:
- # table constraints
- # sequences
+ comparators.dispatch("table")(
+ autogen_context, modify_table_ops,
+ s, tname, conn_table, metadata_table
+ )
+
+ if not modify_table_ops.is_empty():
+ upgrade_ops.ops.append(modify_table_ops)
def _make_index(params, conn_table):
@@ -202,56 +219,51 @@ def _make_foreign_key(params, conn_table):
@contextlib.contextmanager
-def _compare_columns(schema, tname, object_filters, conn_table, metadata_table,
- diffs, autogen_context, inspector):
+def _compare_columns(schema, tname, conn_table, metadata_table,
+ modify_table_ops, autogen_context, inspector):
name = '%s.%s' % (schema, tname) if schema else tname
metadata_cols_by_name = dict((c.name, c) for c in metadata_table.c)
conn_col_names = dict((c.name, c) for c in conn_table.c)
metadata_col_names = OrderedSet(sorted(metadata_cols_by_name))
for cname in metadata_col_names.difference(conn_col_names):
- if _run_filters(metadata_cols_by_name[cname], cname,
- "column", False, None, object_filters):
- diffs.append(
- ("add_column", schema, tname, metadata_cols_by_name[cname])
+ if autogen_context.run_filters(
+ metadata_cols_by_name[cname], cname,
+ "column", False, None):
+ modify_table_ops.ops.append(
+ ops.AddColumnOp.from_column_and_tablename(
+ schema, tname, metadata_cols_by_name[cname])
)
log.info("Detected added column '%s.%s'", name, cname)
for colname in metadata_col_names.intersection(conn_col_names):
metadata_col = metadata_cols_by_name[colname]
conn_col = conn_table.c[colname]
- if not _run_filters(
+ if not autogen_context.run_filters(
metadata_col, colname, "column", False,
- conn_col, object_filters):
+ conn_col):
continue
- col_diff = []
- _compare_type(schema, tname, colname,
- conn_col,
- metadata_col,
- col_diff, autogen_context
- )
- # work around SQLAlchemy issue #3023
- if not metadata_col.primary_key:
- _compare_nullable(schema, tname, colname,
- conn_col,
- metadata_col.nullable,
- col_diff, autogen_context
- )
- _compare_server_default(schema, tname, colname,
- conn_col,
- metadata_col,
- col_diff, autogen_context
- )
- if col_diff:
- diffs.append(col_diff)
+ alter_column_op = ops.AlterColumnOp(
+ tname, colname, schema=schema)
+
+ comparators.dispatch("column")(
+ autogen_context, alter_column_op,
+ schema, tname, colname, conn_col, metadata_col
+ )
+
+ if alter_column_op.has_changes():
+ modify_table_ops.ops.append(alter_column_op)
yield
for cname in set(conn_col_names).difference(metadata_col_names):
- if _run_filters(conn_table.c[cname], cname,
- "column", True, None, object_filters):
- diffs.append(
- ("remove_column", schema, tname, conn_table.c[cname])
+ if autogen_context.run_filters(
+ conn_table.c[cname], cname,
+ "column", True, None):
+ modify_table_ops.ops.append(
+ ops.DropColumnOp.from_column_and_tablename(
+ schema, tname, conn_table.c[cname]
+ )
)
log.info("Detected removed column '%s.%s'", name, cname)
@@ -310,10 +322,12 @@ class _fk_constraint_sig(_constraint_sig):
)
-def _compare_indexes_and_uniques(schema, tname, object_filters, conn_table,
- metadata_table, diffs,
- autogen_context, inspector):
+@comparators.dispatch_for("table")
+def _compare_indexes_and_uniques(
+ autogen_context, modify_ops, schema, tname, conn_table,
+ metadata_table):
+ inspector = autogen_context.inspector
is_create_table = conn_table is None
# 1a. get raw indexes and unique constraints from metadata ...
@@ -350,7 +364,7 @@ def _compare_indexes_and_uniques(schema, tname, object_filters, conn_table,
# 3. give the dialect a chance to omit indexes and constraints that
# we know are either added implicitly by the DB or that the DB
# can't accurately report on
- autogen_context['context'].impl.\
+ autogen_context.migration_context.impl.\
correct_for_autogen_constraints(
conn_uniques, conn_indexes,
metadata_unique_constraints,
@@ -411,9 +425,11 @@ def _compare_indexes_and_uniques(schema, tname, object_filters, conn_table,
def obj_added(obj):
if obj.is_index:
- if _run_filters(
- obj.const, obj.name, "index", False, None, object_filters):
- diffs.append(("add_index", obj.const))
+ if autogen_context.run_filters(
+ obj.const, obj.name, "index", False, None):
+ modify_ops.ops.append(
+ ops.CreateIndexOp.from_index(obj.const)
+ )
log.info("Detected added index '%s' on %s",
obj.name, ', '.join([
"'%s'" % obj.column_names
@@ -426,10 +442,12 @@ def _compare_indexes_and_uniques(schema, tname, object_filters, conn_table,
if is_create_table:
# unique constraints are created inline with table defs
return
- if _run_filters(
+ if autogen_context.run_filters(
obj.const, obj.name,
- "unique_constraint", False, None, object_filters):
- diffs.append(("add_constraint", obj.const))
+ "unique_constraint", False, None):
+ modify_ops.ops.append(
+ ops.AddConstraintOp.from_constraint(obj.const)
+ )
log.info("Detected added unique constraint '%s' on %s",
obj.name, ', '.join([
"'%s'" % obj.column_names
@@ -443,39 +461,51 @@ def _compare_indexes_and_uniques(schema, tname, object_filters, conn_table,
# be sure what we're doing here
return
- if _run_filters(
- obj.const, obj.name, "index", True, None, object_filters):
- diffs.append(("remove_index", obj.const))
+ if autogen_context.run_filters(
+ obj.const, obj.name, "index", True, None):
+ modify_ops.ops.append(
+ ops.DropIndexOp.from_index(obj.const)
+ )
log.info(
"Detected removed index '%s' on '%s'", obj.name, tname)
else:
- if _run_filters(
+ if autogen_context.run_filters(
obj.const, obj.name,
- "unique_constraint", True, None, object_filters):
- diffs.append(("remove_constraint", obj.const))
+ "unique_constraint", True, None):
+ modify_ops.ops.append(
+ ops.DropConstraintOp.from_constraint(obj.const)
+ )
log.info("Detected removed unique constraint '%s' on '%s'",
obj.name, tname
)
def obj_changed(old, new, msg):
if old.is_index:
- if _run_filters(
+ if autogen_context.run_filters(
new.const, new.name, "index",
- False, old.const, object_filters):
+ False, old.const):
log.info("Detected changed index '%s' on '%s':%s",
old.name, tname, ', '.join(msg)
)
- diffs.append(("remove_index", old.const))
- diffs.append(("add_index", new.const))
+ modify_ops.ops.append(
+ ops.DropIndexOp.from_index(old.const)
+ )
+ modify_ops.ops.append(
+ ops.CreateIndexOp.from_index(new.const)
+ )
else:
- if _run_filters(
+ if autogen_context.run_filters(
new.const, new.name,
- "unique_constraint", False, old.const, object_filters):
+ "unique_constraint", False, old.const):
log.info("Detected changed unique constraint '%s' on '%s':%s",
old.name, tname, ', '.join(msg)
)
- diffs.append(("remove_constraint", old.const))
- diffs.append(("add_constraint", new.const))
+ modify_ops.ops.append(
+ ops.DropConstraintOp.from_constraint(old.const)
+ )
+ modify_ops.ops.append(
+ ops.AddConstraintOp.from_constraint(new.const)
+ )
for added_name in sorted(set(metadata_names).difference(conn_names)):
obj = metadata_names[added_name]
@@ -528,20 +558,21 @@ def _compare_indexes_and_uniques(schema, tname, object_filters, conn_table,
obj_added(unnamed_metadata_uniques[uq_sig])
-def _compare_nullable(schema, tname, cname, conn_col,
- metadata_col_nullable, diffs,
- autogen_context):
+@comparators.dispatch_for("column")
+def _compare_nullable(
+ autogen_context, alter_column_op, schema, tname, cname, conn_col,
+ metadata_col):
+
+ # work around SQLAlchemy issue #3023
+ if metadata_col.primary_key:
+ return
+
+ metadata_col_nullable = metadata_col.nullable
conn_col_nullable = conn_col.nullable
+ alter_column_op.existing_nullable = conn_col_nullable
+
if conn_col_nullable is not metadata_col_nullable:
- diffs.append(
- ("modify_nullable", schema, tname, cname,
- {
- "existing_type": conn_col.type,
- "existing_server_default": conn_col.server_default,
- },
- conn_col_nullable,
- metadata_col_nullable),
- )
+ alter_column_op.modify_nullable = metadata_col_nullable
log.info("Detected %s on column '%s.%s'",
"NULL" if metadata_col_nullable else "NOT NULL",
tname,
@@ -549,11 +580,13 @@ def _compare_nullable(schema, tname, cname, conn_col,
)
-def _compare_type(schema, tname, cname, conn_col,
- metadata_col, diffs,
- autogen_context):
+@comparators.dispatch_for("column")
+def _compare_type(
+ autogen_context, alter_column_op, schema, tname, cname, conn_col,
+ metadata_col):
conn_type = conn_col.type
+ alter_column_op.existing_type = conn_type
metadata_type = metadata_col.type
if conn_type._type_affinity is sqltypes.NullType:
log.info("Couldn't determine database type "
@@ -564,19 +597,11 @@ def _compare_type(schema, tname, cname, conn_col,
"the model; can't compare", tname, cname)
return
- isdiff = autogen_context['context']._compare_type(conn_col, metadata_col)
+ isdiff = autogen_context.migration_context._compare_type(
+ conn_col, metadata_col)
if isdiff:
-
- diffs.append(
- ("modify_type", schema, tname, cname,
- {
- "existing_nullable": conn_col.nullable,
- "existing_server_default": conn_col.server_default,
- },
- conn_type,
- metadata_type),
- )
+ alter_column_op.modify_type = metadata_type
log.info("Detected type change from %r to %r on '%s.%s'",
conn_type, metadata_type, tname, cname
)
@@ -594,7 +619,7 @@ def _render_server_default_for_compare(metadata_default,
metadata_default = metadata_default.arg
else:
metadata_default = str(metadata_default.arg.compile(
- dialect=autogen_context['dialect']))
+ dialect=autogen_context.dialect))
if isinstance(metadata_default, compat.string_types):
if metadata_col.type._type_affinity is sqltypes.String:
metadata_default = re.sub(r"^'|'$", "", metadata_default)
@@ -605,8 +630,10 @@ def _render_server_default_for_compare(metadata_default,
return None
-def _compare_server_default(schema, tname, cname, conn_col, metadata_col,
- diffs, autogen_context):
+@comparators.dispatch_for("column")
+def _compare_server_default(
+ autogen_context, alter_column_op, schema, tname, cname,
+ conn_col, metadata_col):
metadata_default = metadata_col.server_default
conn_col_default = conn_col.server_default
@@ -618,36 +645,31 @@ def _compare_server_default(schema, tname, cname, conn_col, metadata_col,
rendered_conn_default = conn_col.server_default.arg.text \
if conn_col.server_default else None
- isdiff = autogen_context['context']._compare_server_default(
+ alter_column_op.existing_server_default = conn_col_default
+
+ isdiff = autogen_context.migration_context._compare_server_default(
conn_col, metadata_col,
rendered_metadata_default,
rendered_conn_default
)
if isdiff:
- conn_col_default = rendered_conn_default
- diffs.append(
- ("modify_default", schema, tname, cname,
- {
- "existing_nullable": conn_col.nullable,
- "existing_type": conn_col.type,
- },
- conn_col_default,
- metadata_default),
- )
- log.info("Detected server default on column '%s.%s'",
- tname,
- cname
- )
+ alter_column_op.modify_server_default = metadata_default
+ log.info(
+ "Detected server default on column '%s.%s'",
+ tname, cname)
-def _compare_foreign_keys(schema, tname, object_filters, conn_table,
- metadata_table, diffs, autogen_context, inspector):
+@comparators.dispatch_for("table")
+def _compare_foreign_keys(
+ autogen_context, modify_table_ops, schema, tname, conn_table,
+ metadata_table):
# if we're doing CREATE TABLE, all FKs are created
# inline within the table def
if conn_table is None:
return
+ inspector = autogen_context.inspector
metadata_fks = set(
fk for fk in metadata_table.constraints
if isinstance(fk, sa_schema.ForeignKeyConstraint)
@@ -673,10 +695,12 @@ def _compare_foreign_keys(schema, tname, object_filters, conn_table,
)
def _add_fk(obj, compare_to):
- if _run_filters(
+ if autogen_context.run_filters(
obj.const, obj.name, "foreign_key_constraint", False,
- compare_to, object_filters):
- diffs.append(('add_fk', const.const))
+ compare_to):
+ modify_table_ops.ops.append(
+ ops.CreateForeignKeyOp.from_constraint(const.const)
+ )
log.info(
"Detected added foreign key (%s)(%s) on table %s%s",
@@ -686,10 +710,12 @@ def _compare_foreign_keys(schema, tname, object_filters, conn_table,
obj.source_table)
def _remove_fk(obj, compare_to):
- if _run_filters(
+ if autogen_context.run_filters(
obj.const, obj.name, "foreign_key_constraint", True,
- compare_to, object_filters):
- diffs.append(('remove_fk', obj.const))
+ compare_to):
+ modify_table_ops.ops.append(
+ ops.DropConstraintOp.from_constraint(obj.const)
+ )
log.info(
"Detected removed foreign key (%s)(%s) on table %s%s",
", ".join(obj.source_columns),
@@ -713,5 +739,3 @@ def _compare_foreign_keys(schema, tname, object_filters, conn_table,
compare_to = conn_fks_by_name[const.name].const \
if const.name in conn_fks_by_name else None
_add_fk(const, compare_to)
-
- return diffs
diff --git a/alembic/autogenerate/compose.py b/alembic/autogenerate/compose.py
deleted file mode 100644
index b42b505..0000000
--- a/alembic/autogenerate/compose.py
+++ /dev/null
@@ -1,144 +0,0 @@
-import itertools
-from ..operations import ops
-
-
-def _to_migration_script(autogen_context, migration_script, diffs):
- _to_upgrade_op(
- autogen_context,
- diffs,
- migration_script.upgrade_ops,
- )
-
- _to_downgrade_op(
- autogen_context,
- diffs,
- migration_script.downgrade_ops,
- )
-
-
-def _to_upgrade_op(autogen_context, diffs, upgrade_ops):
- return _to_updown_op(autogen_context, diffs, upgrade_ops, "upgrade")
-
-
-def _to_downgrade_op(autogen_context, diffs, downgrade_ops):
- return _to_updown_op(autogen_context, diffs, downgrade_ops, "downgrade")
-
-
-def _to_updown_op(autogen_context, diffs, op_container, type_):
- if not diffs:
- return
-
- if type_ == 'downgrade':
- diffs = reversed(diffs)
-
- dest = [op_container.ops]
-
- for (schema, tablename), subdiffs in _group_diffs_by_table(diffs):
- subdiffs = list(subdiffs)
- if tablename is not None:
- table_ops = []
- op = ops.ModifyTableOps(tablename, table_ops, schema=schema)
- dest[-1].append(op)
- dest.append(table_ops)
- for diff in subdiffs:
- _produce_command(autogen_context, diff, dest[-1], type_)
- if tablename is not None:
- dest.pop(-1)
-
-
-def _produce_command(autogen_context, diff, op_list, updown):
- if isinstance(diff, tuple):
- _produce_adddrop_command(updown, diff, op_list, autogen_context)
- else:
- _produce_modify_command(updown, diff, op_list, autogen_context)
-
-
-def _produce_adddrop_command(updown, diff, op_list, autogen_context):
- cmd_type = diff[0]
- adddrop, cmd_type = cmd_type.split("_")
-
- cmd_args = diff[1:]
-
- _commands = {
- "table": (ops.DropTableOp.from_table, ops.CreateTableOp.from_table),
- "column": (
- ops.DropColumnOp.from_column_and_tablename,
- ops.AddColumnOp.from_column_and_tablename),
- "index": (ops.DropIndexOp.from_index, ops.CreateIndexOp.from_index),
- "constraint": (
- ops.DropConstraintOp.from_constraint,
- ops.AddConstraintOp.from_constraint),
- "fk": (
- ops.DropConstraintOp.from_constraint,
- ops.CreateForeignKeyOp.from_constraint)
- }
-
- cmd_callables = _commands[cmd_type]
-
- if (
- updown == "upgrade" and adddrop == "add"
- ) or (
- updown == "downgrade" and adddrop == "remove"
- ):
- op_list.append(cmd_callables[1](*cmd_args))
- else:
- op_list.append(cmd_callables[0](*cmd_args))
-
-
-def _produce_modify_command(updown, diffs, op_list, autogen_context):
- sname, tname, cname = diffs[0][1:4]
- kw = {}
-
- _arg_struct = {
- "modify_type": ("existing_type", "modify_type"),
- "modify_nullable": ("existing_nullable", "modify_nullable"),
- "modify_default": ("existing_server_default", "modify_server_default"),
- }
- for diff in diffs:
- diff_kw = diff[4]
- for arg in ("existing_type",
- "existing_nullable",
- "existing_server_default"):
- if arg in diff_kw:
- kw.setdefault(arg, diff_kw[arg])
- old_kw, new_kw = _arg_struct[diff[0]]
- if updown == "upgrade":
- kw[new_kw] = diff[-1]
- kw[old_kw] = diff[-2]
- else:
- kw[new_kw] = diff[-2]
- kw[old_kw] = diff[-1]
-
- if "modify_nullable" in kw:
- kw.pop("existing_nullable", None)
- if "modify_server_default" in kw:
- kw.pop("existing_server_default", None)
-
- op_list.append(
- ops.AlterColumnOp(
- tname, cname, schema=sname,
- **kw
- )
- )
-
-
-def _group_diffs_by_table(diffs):
- _adddrop = {
- "table": lambda diff: (None, None),
- "column": lambda diff: (diff[0], diff[1]),
- "index": lambda diff: (diff[0].table.schema, diff[0].table.name),
- "constraint": lambda diff: (diff[0].table.schema, diff[0].table.name),
- "fk": lambda diff: (diff[0].parent.schema, diff[0].parent.name)
- }
-
- def _derive_table(diff):
- if isinstance(diff, tuple):
- cmd_type = diff[0]
- adddrop, cmd_type = cmd_type.split("_")
- return _adddrop[cmd_type](diff[1:])
- else:
- sname, tname = diff[0][1:3]
- return sname, tname
-
- return itertools.groupby(diffs, _derive_table)
-
diff --git a/alembic/autogenerate/generate.py b/alembic/autogenerate/generate.py
deleted file mode 100644
index c686156..0000000
--- a/alembic/autogenerate/generate.py
+++ /dev/null
@@ -1,92 +0,0 @@
-from .. import util
-from . import api
-from . import compose
-from . import compare
-from . import render
-from ..operations import ops
-
-
-class RevisionContext(object):
- def __init__(self, config, script_directory, command_args):
- self.config = config
- self.script_directory = script_directory
- self.command_args = command_args
- self.template_args = {
- 'config': config # Let templates use config for
- # e.g. multiple databases
- }
- self.generated_revisions = [
- self._default_revision()
- ]
-
- def _to_script(self, migration_script):
- template_args = {}
- for k, v in self.template_args.items():
- template_args.setdefault(k, v)
-
- if migration_script._autogen_context is not None:
- render._render_migration_script(
- migration_script._autogen_context, migration_script,
- template_args
- )
-
- return self.script_directory.generate_revision(
- migration_script.rev_id,
- migration_script.message,
- refresh=True,
- head=migration_script.head,
- splice=migration_script.splice,
- branch_labels=migration_script.branch_label,
- version_path=migration_script.version_path,
- **template_args)
-
- def run_autogenerate(self, rev, context):
- if self.command_args['sql']:
- raise util.CommandError(
- "Using --sql with --autogenerate does not make any sense")
- if set(self.script_directory.get_revisions(rev)) != \
- set(self.script_directory.get_revisions("heads")):
- raise util.CommandError("Target database is not up to date.")
-
- autogen_context = api._autogen_context(context)
-
- diffs = []
- compare._produce_net_changes(autogen_context, diffs)
-
- migration_script = self.generated_revisions[0]
-
- compose._to_migration_script(autogen_context, migration_script, diffs)
-
- hook = context.opts.get('process_revision_directives', None)
- if hook:
- hook(context, rev, self.generated_revisions)
-
- for migration_script in self.generated_revisions:
- migration_script._autogen_context = autogen_context
-
- def run_no_autogenerate(self, rev, context):
- hook = context.opts.get('process_revision_directives', None)
- if hook:
- hook(context, rev, self.generated_revisions)
-
- for migration_script in self.generated_revisions:
- migration_script._autogen_context = None
-
- def _default_revision(self):
- op = ops.MigrationScript(
- rev_id=self.command_args['rev_id'] or util.rev_id(),
- message=self.command_args['message'],
- imports=set(),
- upgrade_ops=ops.UpgradeOps([]),
- downgrade_ops=ops.DowngradeOps([]),
- head=self.command_args['head'],
- splice=self.command_args['splice'],
- branch_label=self.command_args['branch_label'],
- version_path=self.command_args['version_path']
- )
- op._autogen_context = None
- return op
-
- def generate_scripts(self):
- for generated_revision in self.generated_revisions:
- yield self._to_script(generated_revision)
diff --git a/alembic/autogenerate/render.py b/alembic/autogenerate/render.py
index c3f3df1..6f5f96c 100644
--- a/alembic/autogenerate/render.py
+++ b/alembic/autogenerate/render.py
@@ -30,8 +30,8 @@ def _indent(text):
def _render_migration_script(autogen_context, migration_script, template_args):
- opts = autogen_context['opts']
- imports = autogen_context['imports']
+ opts = autogen_context.opts
+ imports = autogen_context._imports
template_args[opts['upgrade_token']] = _indent(_render_cmd_body(
migration_script.upgrade_ops, autogen_context))
template_args[opts['downgrade_token']] = _indent(_render_cmd_body(
@@ -78,23 +78,26 @@ def render_op_text(autogen_context, op):
@renderers.dispatch_for(ops.ModifyTableOps)
def _render_modify_table(autogen_context, op):
- opts = autogen_context['opts']
+ opts = autogen_context.opts
render_as_batch = opts.get('render_as_batch', False)
if op.ops:
lines = []
if render_as_batch:
- lines.append(
- "with op.batch_alter_table(%r, schema=%r) as batch_op:"
- % (op.table_name, op.schema)
- )
- autogen_context['batch_prefix'] = 'batch_op.'
- for t_op in op.ops:
- t_lines = render_op(autogen_context, t_op)
- lines.extend(t_lines)
- if render_as_batch:
- del autogen_context['batch_prefix']
- lines.append("")
+ with autogen_context._within_batch():
+ lines.append(
+ "with op.batch_alter_table(%r, schema=%r) as batch_op:"
+ % (op.table_name, op.schema)
+ )
+ for t_op in op.ops:
+ t_lines = render_op(autogen_context, t_op)
+ lines.extend(t_lines)
+ lines.append("")
+ else:
+ for t_op in op.ops:
+ t_lines = render_op(autogen_context, t_op)
+ lines.extend(t_lines)
+
return lines
else:
return [
@@ -149,7 +152,7 @@ def _drop_table(autogen_context, op):
def _add_index(autogen_context, op):
index = op.to_index()
- has_batch = 'batch_prefix' in autogen_context
+ has_batch = autogen_context._has_batch
if has_batch:
tmpl = "%(prefix)screate_index(%(name)r, [%(columns)s], "\
@@ -180,7 +183,7 @@ def _add_index(autogen_context, op):
@renderers.dispatch_for(ops.DropIndexOp)
def _drop_index(autogen_context, op):
- has_batch = 'batch_prefix' in autogen_context
+ has_batch = autogen_context._has_batch
if has_batch:
tmpl = "%(prefix)sdrop_index(%(name)r)"
@@ -243,7 +246,7 @@ def _add_check_constraint(constraint, autogen_context):
@renderers.dispatch_for(ops.DropConstraintOp)
def _drop_constraint(autogen_context, op):
- if 'batch_prefix' in autogen_context:
+ if autogen_context._has_batch:
template = "%(prefix)sdrop_constraint"\
"(%(name)r, type_=%(type)r)"
else:
@@ -266,7 +269,7 @@ def _drop_constraint(autogen_context, op):
def _add_column(autogen_context, op):
schema, tname, column = op.schema, op.table_name, op.column
- if 'batch_prefix' in autogen_context:
+ if autogen_context._has_batch:
template = "%(prefix)sadd_column(%(column)s)"
else:
template = "%(prefix)sadd_column(%(tname)r, %(column)s"
@@ -287,7 +290,7 @@ def _drop_column(autogen_context, op):
schema, tname, column_name = op.schema, op.table_name, op.column_name
- if 'batch_prefix' in autogen_context:
+ if autogen_context._has_batch:
template = "%(prefix)sdrop_column(%(cname)r)"
else:
template = "%(prefix)sdrop_column(%(tname)r, %(cname)r"
@@ -319,7 +322,7 @@ def _alter_column(autogen_context, op):
indent = " " * 11
- if 'batch_prefix' in autogen_context:
+ if autogen_context._has_batch:
template = "%(prefix)salter_column(%(cname)r"
else:
template = "%(prefix)salter_column(%(tname)r, %(cname)r"
@@ -343,16 +346,16 @@ def _alter_column(autogen_context, op):
if nullable is not None:
text += ",\n%snullable=%r" % (
indent, nullable,)
- if existing_nullable is not None:
+ if nullable is None and existing_nullable is not None:
text += ",\n%sexisting_nullable=%r" % (
indent, existing_nullable)
- if existing_server_default:
+ if server_default is False and existing_server_default:
rendered = _render_server_default(
existing_server_default,
autogen_context)
text += ",\n%sexisting_server_default=%s" % (
indent, rendered)
- if schema and "batch_prefix" not in autogen_context:
+ if schema and not autogen_context._has_batch:
text += ",\n%sschema=%r" % (indent, schema)
text += ")"
return text
@@ -409,7 +412,7 @@ def _render_potential_expr(value, autogen_context, wrap_in_text=True):
return template % {
"prefix": _sqlalchemy_autogenerate_prefix(autogen_context),
"sql": compat.text_type(
- value.compile(dialect=autogen_context['dialect'],
+ value.compile(dialect=autogen_context.dialect,
**compile_kw)
)
}
@@ -432,7 +435,7 @@ def _get_index_rendered_expressions(idx, autogen_context):
def _uq_constraint(constraint, autogen_context, alter):
opts = []
- has_batch = 'batch_prefix' in autogen_context
+ has_batch = autogen_context._has_batch
if constraint.deferrable:
opts.append(("deferrable", str(constraint.deferrable)))
@@ -467,7 +470,7 @@ def _uq_constraint(constraint, autogen_context, alter):
def _user_autogenerate_prefix(autogen_context, target):
- prefix = autogen_context['opts']['user_module_prefix']
+ prefix = autogen_context.opts['user_module_prefix']
if prefix is None:
return "%s." % target.__module__
else:
@@ -475,20 +478,19 @@ def _user_autogenerate_prefix(autogen_context, target):
def _sqlalchemy_autogenerate_prefix(autogen_context):
- return autogen_context['opts']['sqlalchemy_module_prefix'] or ''
+ return autogen_context.opts['sqlalchemy_module_prefix'] or ''
def _alembic_autogenerate_prefix(autogen_context):
- if 'batch_prefix' in autogen_context:
- return autogen_context['batch_prefix']
+ if autogen_context._has_batch:
+ return 'batch_op.'
else:
- return autogen_context['opts']['alembic_module_prefix'] or ''
+ return autogen_context.opts['alembic_module_prefix'] or ''
def _user_defined_render(type_, object_, autogen_context):
- if 'opts' in autogen_context and \
- 'render_item' in autogen_context['opts']:
- render = autogen_context['opts']['render_item']
+ if 'render_item' in autogen_context.opts:
+ render = autogen_context.opts['render_item']
if render:
rendered = render(type_, object_, autogen_context)
if rendered is not False:
@@ -547,7 +549,7 @@ def _repr_type(type_, autogen_context):
return rendered
mod = type(type_).__module__
- imports = autogen_context.get('imports', None)
+ imports = autogen_context._imports
if mod.startswith("sqlalchemy.dialects"):
dname = re.match(r"sqlalchemy\.dialects\.(\w+)", mod).group(1)
if imports is not None:
diff --git a/alembic/operations/ops.py b/alembic/operations/ops.py
index 08a0551..71e8515 100644
--- a/alembic/operations/ops.py
+++ b/alembic/operations/ops.py
@@ -3,6 +3,7 @@ from ..util import sqla_compat
from . import schemaobj
from sqlalchemy.types import NULLTYPE
from .base import Operations, BatchOperations
+import re
class MigrateOperation(object):
@@ -34,6 +35,10 @@ class MigrateOperation(object):
class AddConstraintOp(MigrateOperation):
"""Represent an add constraint operation."""
+ @property
+ def constraint_type(self):
+ raise NotImplementedError()
+
@classmethod
def from_constraint(cls, constraint):
funcs = {
@@ -45,17 +50,40 @@ class AddConstraintOp(MigrateOperation):
}
return funcs[constraint.__visit_name__](constraint)
+ def reverse(self):
+ return DropConstraintOp.from_constraint(self.to_constraint())
+
+ def to_diff_tuple(self):
+ return ("add_constraint", self.to_constraint())
+
@Operations.register_operation("drop_constraint")
@BatchOperations.register_operation("drop_constraint", "batch_drop_constraint")
class DropConstraintOp(MigrateOperation):
"""Represent a drop constraint operation."""
- def __init__(self, constraint_name, table_name, type_=None, schema=None):
+ def __init__(
+ self,
+ constraint_name, table_name, type_=None, schema=None,
+ _orig_constraint=None):
self.constraint_name = constraint_name
self.table_name = table_name
self.constraint_type = type_
self.schema = schema
+ self._orig_constraint = _orig_constraint
+
+ def reverse(self):
+ if self._orig_constraint is None:
+ raise ValueError(
+ "operation is not reversible; "
+ "original constraint is not present")
+ return AddConstraintOp.from_constraint(self._orig_constraint)
+
+ def to_diff_tuple(self):
+ if self.constraint_type == "foreignkey":
+ return ("remove_fk", self.to_constraint())
+ else:
+ return ("remove_constraint", self.to_constraint())
@classmethod
def from_constraint(cls, constraint):
@@ -72,9 +100,18 @@ class DropConstraintOp(MigrateOperation):
constraint.name,
constraint_table.name,
schema=constraint_table.schema,
- type_=types[constraint.__visit_name__]
+ type_=types[constraint.__visit_name__],
+ _orig_constraint=constraint
)
+ def to_constraint(self):
+ if self._orig_constraint is not None:
+ return self._orig_constraint
+ else:
+ raise ValueError(
+ "constraint cannot be produced; "
+ "original constraint is not present")
+
@classmethod
@util._with_legacy_names([("type", "type_")])
def drop_constraint(
@@ -124,8 +161,11 @@ class DropConstraintOp(MigrateOperation):
class CreatePrimaryKeyOp(AddConstraintOp):
"""Represent a create primary key operation."""
+ constraint_type = "primarykey"
+
def __init__(
- self, constraint_name, table_name, columns, schema=None, **kw):
+ self, constraint_name, table_name, columns,
+ schema=None, **kw):
self.constraint_name = constraint_name
self.table_name = table_name
self.columns = columns
@@ -225,8 +265,11 @@ class CreatePrimaryKeyOp(AddConstraintOp):
class CreateUniqueConstraintOp(AddConstraintOp):
"""Represent a create unique constraint operation."""
+ constraint_type = "unique"
+
def __init__(
- self, constraint_name, table_name, columns, schema=None, **kw):
+ self, constraint_name, table_name,
+ columns, schema=None, **kw):
self.constraint_name = constraint_name
self.table_name = table_name
self.columns = columns
@@ -342,6 +385,8 @@ class CreateUniqueConstraintOp(AddConstraintOp):
class CreateForeignKeyOp(AddConstraintOp):
"""Represent a create foreign key constraint operation."""
+ constraint_type = "foreignkey"
+
def __init__(
self, constraint_name, source_table, referent_table, local_cols,
remote_cols, **kw):
@@ -352,6 +397,9 @@ class CreateForeignKeyOp(AddConstraintOp):
self.remote_cols = remote_cols
self.kw = kw
+ def to_diff_tuple(self):
+ return ("add_fk", self.to_constraint())
+
@classmethod
def from_constraint(cls, constraint):
kw = {}
@@ -507,6 +555,8 @@ class CreateForeignKeyOp(AddConstraintOp):
class CreateCheckConstraintOp(AddConstraintOp):
"""Represent a create check constraint operation."""
+ constraint_type = "check"
+
def __init__(
self, constraint_name, table_name, condition, schema=None, **kw):
self.constraint_name = constraint_name
@@ -523,7 +573,7 @@ class CreateCheckConstraintOp(AddConstraintOp):
constraint.name,
constraint_table.name,
constraint.condition,
- schema=constraint_table.schema
+ schema=constraint_table.schema,
)
def to_constraint(self, migration_context=None):
@@ -624,6 +674,12 @@ class CreateIndexOp(MigrateOperation):
self.kw = kw
self._orig_index = _orig_index
+ def reverse(self):
+ return DropIndexOp.from_index(self.to_index())
+
+ def to_diff_tuple(self):
+ return ("add_index", self.to_index())
+
@classmethod
def from_index(cls, index):
return cls(
@@ -729,10 +785,22 @@ class CreateIndexOp(MigrateOperation):
class DropIndexOp(MigrateOperation):
"""Represent a drop index operation."""
- def __init__(self, index_name, table_name=None, schema=None):
+ def __init__(
+ self, index_name, table_name=None, schema=None, _orig_index=None):
self.index_name = index_name
self.table_name = table_name
self.schema = schema
+ self._orig_index = _orig_index
+
+ def to_diff_tuple(self):
+ return ("remove_index", self.to_index())
+
+ def reverse(self):
+ if self._orig_index is None:
+ raise ValueError(
+ "operation is not reversible; "
+ "original index is not present")
+ return CreateIndexOp.from_index(self._orig_index)
@classmethod
def from_index(cls, index):
@@ -740,6 +808,7 @@ class DropIndexOp(MigrateOperation):
index.name,
index.table.name,
schema=index.table.schema,
+ _orig_index=index
)
def to_index(self, migration_context=None):
@@ -807,6 +876,12 @@ class CreateTableOp(MigrateOperation):
self.kw = kw
self._orig_table = _orig_table
+ def reverse(self):
+ return DropTableOp.from_table(self.to_table())
+
+ def to_diff_tuple(self):
+ return ("add_table", self.to_table())
+
@classmethod
def from_table(cls, table):
return cls(
@@ -921,16 +996,30 @@ class CreateTableOp(MigrateOperation):
class DropTableOp(MigrateOperation):
"""Represent a drop table operation."""
- def __init__(self, table_name, schema=None, table_kw=None):
+ def __init__(
+ self, table_name, schema=None, table_kw=None, _orig_table=None):
self.table_name = table_name
self.schema = schema
self.table_kw = table_kw or {}
+ self._orig_table = _orig_table
+
+ def to_diff_tuple(self):
+ return ("remove_table", self.to_table())
+
+ def reverse(self):
+ if self._orig_table is None:
+ raise ValueError(
+ "operation is not reversible; "
+ "original table is not present")
+ return CreateTableOp.from_table(self._orig_table)
@classmethod
def from_table(cls, table):
- return cls(table.name, schema=table.schema)
+ return cls(table.name, schema=table.schema, _orig_table=table)
- def to_table(self, migration_context):
+ def to_table(self, migration_context=None):
+ if self._orig_table is not None:
+ return self._orig_table
schema_obj = schemaobj.SchemaObjects(migration_context)
return schema_obj.table(
self.table_name,
@@ -1029,6 +1118,87 @@ class AlterColumnOp(AlterTableOp):
self.modify_type = modify_type
self.kw = kw
+ def to_diff_tuple(self):
+ col_diff = []
+ schema, tname, cname = self.schema, self.table_name, self.column_name
+
+ if self.modify_type is not None:
+ col_diff.append(
+ ("modify_type", schema, tname, cname,
+ {
+ "existing_nullable": self.existing_nullable,
+ "existing_server_default": self.existing_server_default,
+ },
+ self.existing_type,
+ self.modify_type)
+ )
+
+ if self.modify_nullable is not None:
+ col_diff.append(
+ ("modify_nullable", schema, tname, cname,
+ {
+ "existing_type": self.existing_type,
+ "existing_server_default": self.existing_server_default
+ },
+ self.existing_nullable,
+ self.modify_nullable)
+ )
+
+ if self.modify_server_default is not False:
+ col_diff.append(
+ ("modify_default", schema, tname, cname,
+ {
+ "existing_nullable": self.existing_nullable,
+ "existing_type": self.existing_type
+ },
+ self.existing_server_default,
+ self.modify_server_default)
+ )
+
+ return col_diff
+
+ def has_changes(self):
+ hc1 = self.modify_nullable is not None or \
+ self.modify_server_default is not False or \
+ self.modify_type is not None
+ if hc1:
+ return True
+ for kw in self.kw:
+ if kw.startswith('modify_'):
+ return True
+ else:
+ return False
+
+ def reverse(self):
+
+ kw = self.kw.copy()
+ kw['existing_type'] = self.existing_type
+ kw['existing_nullable'] = self.existing_nullable
+ kw['existing_server_default'] = self.existing_server_default
+ if self.modify_type is not None:
+ kw['modify_type'] = self.modify_type
+ if self.modify_nullable is not None:
+ kw['modify_nullable'] = self.modify_nullable
+ if self.modify_server_default is not False:
+ kw['modify_server_default'] = self.modify_server_default
+
+ # TODO: make this a little simpler
+ all_keys = set(m.group(1) for m in [
+ re.match(r'^(?:existing_|modify_)(.+)$', k)
+ for k in kw
+ ] if m)
+
+ for k in all_keys:
+ if 'modify_%s' % k in kw:
+ swap = kw['existing_%s' % k]
+ kw['existing_%s' % k] = kw['modify_%s' % k]
+ kw['modify_%s' % k] = swap
+
+ return self.__class__(
+ self.table_name, self.column_name, schema=self.schema,
+ **kw
+ )
+
@classmethod
@util._with_legacy_names([('name', 'new_column_name')])
def alter_column(
@@ -1177,6 +1347,13 @@ class AddColumnOp(AlterTableOp):
super(AddColumnOp, self).__init__(table_name, schema=schema)
self.column = column
+ def reverse(self):
+ return DropColumnOp.from_column_and_tablename(
+ self.schema, self.table_name, self.column)
+
+ def to_diff_tuple(self):
+ return ("add_column", self.schema, self.table_name, self.column)
+
@classmethod
def from_column(cls, col):
return cls(col.table.name, col, schema=col.table.schema)
@@ -1265,14 +1442,30 @@ class AddColumnOp(AlterTableOp):
class DropColumnOp(AlterTableOp):
"""Represent a drop column operation."""
- def __init__(self, table_name, column_name, schema=None, **kw):
+ def __init__(
+ self, table_name, column_name, schema=None,
+ _orig_column=None, **kw):
super(DropColumnOp, self).__init__(table_name, schema=schema)
self.column_name = column_name
self.kw = kw
+ self._orig_column = _orig_column
+
+ def to_diff_tuple(self):
+ return (
+ "remove_column", self.schema, self.table_name, self.to_column())
+
+ def reverse(self):
+ if self._orig_column is None:
+ raise ValueError(
+ "operation is not reversible; "
+ "original column is not present")
+
+ return AddColumnOp.from_column_and_tablename(
+ self.schema, self.table_name, self._orig_column)
@classmethod
def from_column_and_tablename(cls, schema, tname, col):
- return cls(tname, col.name, schema=schema)
+ return cls(tname, col.name, schema=schema, _orig_column=col)
def to_column(self, migration_context=None):
schema_obj = schemaobj.SchemaObjects(migration_context)
@@ -1522,6 +1715,21 @@ class OpContainer(MigrateOperation):
def __init__(self, ops=()):
self.ops = ops
+ def is_empty(self):
+ return not self.ops
+
+ def as_diffs(self):
+ return list(OpContainer._ops_as_diffs(self))
+
+ @classmethod
+ def _ops_as_diffs(cls, migrations):
+ for op in migrations.ops:
+ if hasattr(op, 'ops'):
+ for sub_op in cls._ops_as_diffs(op):
+ yield sub_op
+ else:
+ yield op.to_diff_tuple()
+
class ModifyTableOps(OpContainer):
"""Contains a sequence of operations that all apply to a single Table."""
@@ -1531,6 +1739,15 @@ class ModifyTableOps(OpContainer):
self.table_name = table_name
self.schema = schema
+ def reverse(self):
+ return ModifyTableOps(
+ self.table_name,
+ ops=list(reversed(
+ [op.reverse() for op in self.ops]
+ )),
+ schema=self.schema
+ )
+
class UpgradeOps(OpContainer):
"""contains a sequence of operations that would apply to the
@@ -1542,6 +1759,15 @@ class UpgradeOps(OpContainer):
"""
+ def reverse_into(self, downgrade_ops):
+ downgrade_ops.ops[:] = list(reversed(
+ [op.reverse() for op in self.ops]
+ ))
+ return downgrade_ops
+
+ def reverse(self):
+ return self.reverse_into(DowngradeOps(ops=[]))
+
class DowngradeOps(OpContainer):
"""contains a sequence of operations that would apply to the
@@ -1553,6 +1779,13 @@ class DowngradeOps(OpContainer):
"""
+ def reverse(self):
+ return UpgradeOps(
+ ops=list(reversed(
+ [op.reverse() for op in self.ops]
+ ))
+ )
+
class MigrationScript(MigrateOperation):
"""represents a migration script.
@@ -1583,4 +1816,3 @@ class MigrationScript(MigrateOperation):
self.version_path = version_path
self.upgrade_ops = upgrade_ops
self.downgrade_ops = downgrade_ops
-
diff --git a/alembic/runtime/migration.py b/alembic/runtime/migration.py
index 84a3c7f..e811a36 100644
--- a/alembic/runtime/migration.py
+++ b/alembic/runtime/migration.py
@@ -118,6 +118,7 @@ class MigrationContext(object):
connection=None,
url=None,
dialect_name=None,
+ dialect=None,
environment_context=None,
opts=None,
):
@@ -152,7 +153,7 @@ class MigrationContext(object):
elif dialect_name:
url = sqla_url.make_url("%s://" % dialect_name)
dialect = url.get_dialect()()
- else:
+ elif not dialect:
raise Exception("Connection, url, or dialect_name is required.")
return MigrationContext(dialect, connection, opts, environment_context)
diff --git a/alembic/util/langhelpers.py b/alembic/util/langhelpers.py
index 1fb0942..6c92e3c 100644
--- a/alembic/util/langhelpers.py
+++ b/alembic/util/langhelpers.py
@@ -257,30 +257,57 @@ class immutabledict(dict):
class Dispatcher(object):
- def __init__(self):
+ def __init__(self, uselist=False):
self._registry = {}
+ self.uselist = uselist
def dispatch_for(self, target, qualifier='default'):
def decorate(fn):
- assert isinstance(target, type)
- assert target not in self._registry
- self._registry[(target, qualifier)] = fn
+ if self.uselist:
+ assert target not in self._registry
+ self._registry.setdefault((target, qualifier), []).append(fn)
+ else:
+ assert target not in self._registry
+ self._registry[(target, qualifier)] = fn
return fn
return decorate
def dispatch(self, obj, qualifier='default'):
- for spcls in type(obj).__mro__:
+
+ if isinstance(obj, string_types):
+ targets = [obj]
+ elif isinstance(obj, type):
+ targets = obj.__mro__
+ else:
+ targets = type(obj).__mro__
+
+ for spcls in targets:
if qualifier != 'default' and (spcls, qualifier) in self._registry:
- return self._registry[(spcls, qualifier)]
+ return self._fn_or_list(self._registry[(spcls, qualifier)])
elif (spcls, 'default') in self._registry:
- return self._registry[(spcls, 'default')]
+ return self._fn_or_list(self._registry[(spcls, 'default')])
else:
raise ValueError("no dispatch function for object: %s" % obj)
+ def _fn_or_list(self, fn_or_list):
+ if self.uselist:
+ def go(*arg, **kw):
+ for fn in fn_or_list:
+ fn(*arg, **kw)
+ return go
+ else:
+ return fn_or_list
+
def branch(self):
"""Return a copy of this dispatcher that is independently
writable."""
d = Dispatcher()
- d._registry.update(self._registry)
+ if self.uselist:
+ d._registry.update(
+ (k, [fn for fn in self._registry[k]])
+ for k in self._registry
+ )
+ else:
+ d._registry.update(self._registry)
return d