summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2015-07-16 19:00:55 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2015-07-16 19:00:55 -0400
commitc5be31760e4bc57b898d1d69d4bb0dd7c2dc7eb7 (patch)
treef158a8b62d7679176f3f60d1b1bf188a6383e03c
parent96214629cdb13f1694831f36c48a7ec86dd8c7f6 (diff)
downloadalembic-c5be31760e4bc57b898d1d69d4bb0dd7c2dc7eb7.tar.gz
- rework all of autogenerate to build directly on alembic.operations.ops
objects; the "diffs" is now a legacy system that is exported from the ops. A new model of comparison/rendering/ upgrade/downgrade composition that is cleaner and much more extensible is introduced. - autogenerate is now extensible as far as database objects compared and rendered into scripts; any new operation directive can also be registered into a series of hooks that allow custom database/model comparison functions to run as well as to render new operation directives into autogenerate scripts. - write all new docs for the new system fixes #306
-rw-r--r--alembic/autogenerate/__init__.py6
-rw-r--r--alembic/autogenerate/api.py306
-rw-r--r--alembic/autogenerate/compare.py336
-rw-r--r--alembic/autogenerate/compose.py144
-rw-r--r--alembic/autogenerate/generate.py92
-rw-r--r--alembic/autogenerate/render.py70
-rw-r--r--alembic/operations/ops.py256
-rw-r--r--alembic/runtime/migration.py3
-rw-r--r--alembic/util/langhelpers.py43
-rw-r--r--docs/build/api/autogenerate.rst255
-rw-r--r--docs/build/api/operations.rst66
-rw-r--r--docs/build/changelog.rst11
-rw-r--r--tests/_autogen_fixtures.py41
-rw-r--r--tests/test_autogen_composition.py21
-rw-r--r--tests/test_autogen_diffs.py203
-rw-r--r--tests/test_autogen_fks.py6
-rw-r--r--tests/test_autogen_indexes.py12
-rw-r--r--tests/test_autogen_render.py177
-rw-r--r--tests/test_postgresql.py67
19 files changed, 1263 insertions, 852 deletions
diff --git a/alembic/autogenerate/__init__.py b/alembic/autogenerate/__init__.py
index 4272a7e..78520a8 100644
--- a/alembic/autogenerate/__init__.py
+++ b/alembic/autogenerate/__init__.py
@@ -1,7 +1,7 @@
from .api import ( # noqa
compare_metadata, _render_migration_diffs,
- produce_migrations, render_python_code
+ produce_migrations, render_python_code,
+ RevisionContext
)
-from .compare import _produce_net_changes # noqa
-from .generate import RevisionContext # noqa
+from .compare import _produce_net_changes, comparators # noqa
from .render import render_op_text, renderers # noqa \ No newline at end of file
diff --git a/alembic/autogenerate/api.py b/alembic/autogenerate/api.py
index cff977b..e9af4cf 100644
--- a/alembic/autogenerate/api.py
+++ b/alembic/autogenerate/api.py
@@ -4,8 +4,9 @@ automatically."""
from ..operations import ops
from . import render
from . import compare
-from . import compose
from .. import util
+from sqlalchemy.engine.reflection import Inspector
+import contextlib
def compare_metadata(context, metadata):
@@ -98,20 +99,8 @@ def compare_metadata(context, metadata):
"""
- autogen_context = _autogen_context(context, metadata=metadata)
-
- # as_sql=True is nonsensical here. autogenerate requires a connection
- # it can use to run queries against to get the database schema.
- if context.as_sql:
- raise util.CommandError(
- "autogenerate can't use as_sql=True as it prevents querying "
- "the database for schema information")
-
- diffs = []
-
- compare._produce_net_changes(autogen_context, diffs)
-
- return diffs
+ migration_script = produce_migrations(context, metadata)
+ return migration_script.upgrade_ops.as_diffs()
def produce_migrations(context, metadata):
@@ -132,10 +121,7 @@ def produce_migrations(context, metadata):
"""
- autogen_context = _autogen_context(context, metadata=metadata)
- diffs = []
-
- compare._produce_net_changes(autogen_context, diffs)
+ autogen_context = AutogenContext(context, metadata=metadata)
migration_script = ops.MigrationScript(
rev_id=None,
@@ -143,7 +129,7 @@ def produce_migrations(context, metadata):
downgrade_ops=ops.DowngradeOps([]),
)
- compose._to_migration_script(autogen_context, migration_script, diffs)
+ compare._populate_migration_script(autogen_context, migration_script)
return migration_script
@@ -152,6 +138,7 @@ def render_python_code(
up_or_down_op,
sqlalchemy_module_prefix='sa.',
alembic_module_prefix='op.',
+ render_as_batch=False,
imports=(),
render_item=None,
):
@@ -162,84 +149,239 @@ def render_python_code(
autogenerate output of a user-defined :class:`.MigrationScript` structure.
"""
- autogen_context = {
- 'opts': {
- 'sqlalchemy_module_prefix': sqlalchemy_module_prefix,
- 'alembic_module_prefix': alembic_module_prefix,
- 'render_item': render_item,
- },
- 'imports': set(imports)
+ opts = {
+ 'sqlalchemy_module_prefix': sqlalchemy_module_prefix,
+ 'alembic_module_prefix': alembic_module_prefix,
+ 'render_item': render_item,
+ 'render_as_batch': render_as_batch,
}
+
+ autogen_context = AutogenContext(None, opts=opts)
+ autogen_context.imports = set(imports)
return render._indent(render._render_cmd_body(
up_or_down_op, autogen_context))
-
-
-def _render_migration_diffs(context, template_args, imports):
+def _render_migration_diffs(context, template_args):
"""legacy, used by test_autogen_composition at the moment"""
- migration_script = produce_migrations(context, None)
-
- autogen_context = _autogen_context(context, imports=imports)
- diffs = []
+ autogen_context = AutogenContext(context)
- compare._produce_net_changes(autogen_context, diffs)
+ upgrade_ops = ops.UpgradeOps([])
+ compare._produce_net_changes(autogen_context, upgrade_ops)
migration_script = ops.MigrationScript(
rev_id=None,
- imports=imports,
- upgrade_ops=ops.UpgradeOps([]),
- downgrade_ops=ops.DowngradeOps([]),
+ upgrade_ops=upgrade_ops,
+ downgrade_ops=upgrade_ops.reverse(),
)
- compose._to_migration_script(autogen_context, migration_script, diffs)
-
render._render_migration_script(
autogen_context, migration_script, template_args
)
-def _autogen_context(
- context, imports=None, metadata=None, include_symbol=None,
- include_object=None, include_schemas=False):
-
- opts = context.opts
- metadata = opts['target_metadata'] if metadata is None else metadata
- include_schemas = opts.get('include_schemas', include_schemas)
-
- include_symbol = opts.get('include_symbol', include_symbol)
- include_object = opts.get('include_object', include_object)
-
- object_filters = []
- if include_symbol:
- def include_symbol_filter(object, name, type_, reflected, compare_to):
- if type_ == "table":
- return include_symbol(name, object.schema)
- else:
- return True
- object_filters.append(include_symbol_filter)
- if include_object:
- object_filters.append(include_object)
-
- if metadata is None:
- raise util.CommandError(
- "Can't proceed with --autogenerate option; environment "
- "script %s does not provide "
- "a MetaData object to the context." % (
- context.script.env_py_location
- ))
-
- opts = context.opts
- connection = context.bind
- return {
- 'imports': imports if imports is not None else set(),
- 'connection': connection,
- 'dialect': connection.dialect,
- 'context': context,
- 'opts': opts,
- 'metadata': metadata,
- 'object_filters': object_filters,
- 'include_schemas': include_schemas
- }
+class AutogenContext(object):
+ """Maintains configuration and state that's specific to an
+ autogenerate operation."""
+
+ metadata = None
+ """The :class:`~sqlalchemy.schema.MetaData` object
+ representing the destination.
+
+ This object is the one that is passed within ``env.py``
+ to the :paramref:`.EnvironmentContext.configure.target_metadata`
+ parameter. It represents the structure of :class:`.Table` and other
+ objects as stated in the current database model, and represents the
+ destination structure for the database being examined.
+
+ While the :class:`~sqlalchemy.schema.MetaData` object is primarily
+ known as a collection of :class:`~sqlalchemy.schema.Table` objects,
+ it also has an :attr:`~sqlalchemy.schema.MetaData.info` dictionary
+ that may be used by end-user schemes to store additional schema-level
+ objects that are to be compared in custom autogeneration schemes.
+
+ """
+
+ connection = None
+ """The :class:`~sqlalchemy.engine.base.Connection` object currently
+ connected to the database backend being compared.
+
+ This is obtained from the :attr:`.MigrationContext.bind` and is
+    ultimately set up in the ``env.py`` script.
+
+ """
+
+ dialect = None
+ """The :class:`~sqlalchemy.engine.Dialect` object currently in use.
+
+ This is normally obtained from the
+ :attr:`~sqlalchemy.engine.base.Connection.dialect` attribute.
+
+ """
+
+ migration_context = None
+ """The :class:`.MigrationContext` established by the ``env.py`` script."""
+
+ def __init__(self, migration_context, metadata=None, opts=None):
+
+ if migration_context is not None and migration_context.as_sql:
+ raise util.CommandError(
+ "autogenerate can't use as_sql=True as it prevents querying "
+ "the database for schema information")
+
+ if opts is None:
+ opts = migration_context.opts
+ self.metadata = metadata = opts.get('target_metadata', None) \
+ if metadata is None else metadata
+
+ if metadata is None and \
+ migration_context is not None and \
+ migration_context.script is not None:
+ raise util.CommandError(
+ "Can't proceed with --autogenerate option; environment "
+ "script %s does not provide "
+ "a MetaData object to the context." % (
+ migration_context.script.env_py_location
+ ))
+
+ include_symbol = opts.get('include_symbol', None)
+ include_object = opts.get('include_object', None)
+
+ object_filters = []
+ if include_symbol:
+ def include_symbol_filter(
+ object, name, type_, reflected, compare_to):
+ if type_ == "table":
+ return include_symbol(name, object.schema)
+ else:
+ return True
+ object_filters.append(include_symbol_filter)
+ if include_object:
+ object_filters.append(include_object)
+
+ self._object_filters = object_filters
+
+ self.migration_context = migration_context
+ if self.migration_context is not None:
+ self.connection = self.migration_context.bind
+ self.dialect = self.migration_context.dialect
+
+ self._imports = set()
+ self.opts = opts
+ self._has_batch = False
+
+ @util.memoized_property
+ def inspector(self):
+ return Inspector.from_engine(self.connection)
+
+ @contextlib.contextmanager
+ def _within_batch(self):
+ self._has_batch = True
+ yield
+ self._has_batch = False
+
+ def run_filters(self, object_, name, type_, reflected, compare_to):
+ """Run the context's object filters and return True if the targets
+ should be part of the autogenerate operation.
+
+ This method should be run for every kind of object encountered within
+ an autogenerate operation, giving the environment the chance
+ to filter what objects should be included in the comparison.
+ The filters here are produced directly via the
+ :paramref:`.EnvironmentContext.configure.include_object`
+ and :paramref:`.EnvironmentContext.configure.include_symbol`
+ functions, if present.
+
+ """
+ for fn in self._object_filters:
+ if not fn(object_, name, type_, reflected, compare_to):
+ return False
+ else:
+ return True
+
+
+class RevisionContext(object):
+ """Maintains configuration and state that's specific to a revision
+ file generation operation."""
+
+ def __init__(self, config, script_directory, command_args):
+ self.config = config
+ self.script_directory = script_directory
+ self.command_args = command_args
+ self.template_args = {
+ 'config': config # Let templates use config for
+ # e.g. multiple databases
+ }
+ self.generated_revisions = [
+ self._default_revision()
+ ]
+
+ def _to_script(self, migration_script):
+ template_args = {}
+ for k, v in self.template_args.items():
+ template_args.setdefault(k, v)
+
+ if migration_script._autogen_context is not None:
+ render._render_migration_script(
+ migration_script._autogen_context, migration_script,
+ template_args
+ )
+
+ return self.script_directory.generate_revision(
+ migration_script.rev_id,
+ migration_script.message,
+ refresh=True,
+ head=migration_script.head,
+ splice=migration_script.splice,
+ branch_labels=migration_script.branch_label,
+ version_path=migration_script.version_path,
+ **template_args)
+
+ def run_autogenerate(self, rev, context):
+ if self.command_args['sql']:
+ raise util.CommandError(
+ "Using --sql with --autogenerate does not make any sense")
+ if set(self.script_directory.get_revisions(rev)) != \
+ set(self.script_directory.get_revisions("heads")):
+ raise util.CommandError("Target database is not up to date.")
+
+ autogen_context = AutogenContext(context)
+
+ migration_script = self.generated_revisions[0]
+
+ compare._populate_migration_script(autogen_context, migration_script)
+
+ hook = context.opts.get('process_revision_directives', None)
+ if hook:
+ hook(context, rev, self.generated_revisions)
+
+ for migration_script in self.generated_revisions:
+ migration_script._autogen_context = autogen_context
+
+ def run_no_autogenerate(self, rev, context):
+ hook = context.opts.get('process_revision_directives', None)
+ if hook:
+ hook(context, rev, self.generated_revisions)
+
+ for migration_script in self.generated_revisions:
+ migration_script._autogen_context = None
+
+ def _default_revision(self):
+ op = ops.MigrationScript(
+ rev_id=self.command_args['rev_id'] or util.rev_id(),
+ message=self.command_args['message'],
+ imports=set(),
+ upgrade_ops=ops.UpgradeOps([]),
+ downgrade_ops=ops.DowngradeOps([]),
+ head=self.command_args['head'],
+ splice=self.command_args['splice'],
+ branch_label=self.command_args['branch_label'],
+ version_path=self.command_args['version_path']
+ )
+ op._autogen_context = None
+ return op
+ def generate_scripts(self):
+ for generated_revision in self.generated_revisions:
+ yield self._to_script(generated_revision)
diff --git a/alembic/autogenerate/compare.py b/alembic/autogenerate/compare.py
index cd6b696..fdc3cae 100644
--- a/alembic/autogenerate/compare.py
+++ b/alembic/autogenerate/compare.py
@@ -1,7 +1,9 @@
from sqlalchemy import schema as sa_schema, types as sqltypes
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy import event
+from ..operations import ops
import logging
+from .. import util
from ..util import compat
from ..util import sqla_compat
from sqlalchemy.util import OrderedSet
@@ -13,15 +15,20 @@ from alembic.ddl.base import _fk_spec
log = logging.getLogger(__name__)
-def _produce_net_changes(autogen_context, diffs):
+def _populate_migration_script(autogen_context, migration_script):
+ _produce_net_changes(autogen_context, migration_script.upgrade_ops)
+ migration_script.upgrade_ops.reverse_into(migration_script.downgrade_ops)
- metadata = autogen_context['metadata']
- connection = autogen_context['connection']
- object_filters = autogen_context.get('object_filters', ())
- include_schemas = autogen_context.get('include_schemas', False)
+
+comparators = util.Dispatcher(uselist=True)
+
+
+def _produce_net_changes(autogen_context, upgrade_ops):
+
+ connection = autogen_context.connection
+ include_schemas = autogen_context.opts.get('include_schemas', False)
inspector = Inspector.from_engine(connection)
- conn_table_names = set()
default_schema = connection.dialect.default_schema_name
if include_schemas:
@@ -34,14 +41,28 @@ def _produce_net_changes(autogen_context, diffs):
else:
schemas = [None]
- version_table_schema = autogen_context['context'].version_table_schema
- version_table = autogen_context['context'].version_table
+ comparators.dispatch("schema", autogen_context.dialect.name)(
+ autogen_context, upgrade_ops, schemas
+ )
+
+
+@comparators.dispatch_for("schema")
+def _autogen_for_tables(autogen_context, upgrade_ops, schemas):
+ inspector = autogen_context.inspector
+
+ metadata = autogen_context.metadata
+
+ conn_table_names = set()
+
+ version_table_schema = \
+ autogen_context.migration_context.version_table_schema
+ version_table = autogen_context.migration_context.version_table
for s in schemas:
tables = set(inspector.get_table_names(schema=s))
if s == version_table_schema:
tables = tables.difference(
- [autogen_context['context'].version_table]
+ [autogen_context.migration_context.version_table]
)
conn_table_names.update(zip([s] * len(tables), tables))
@@ -50,21 +71,11 @@ def _produce_net_changes(autogen_context, diffs):
).difference([(version_table_schema, version_table)])
_compare_tables(conn_table_names, metadata_table_names,
- object_filters,
- inspector, metadata, diffs, autogen_context)
-
-
-def _run_filters(object_, name, type_, reflected, compare_to, object_filters):
- for fn in object_filters:
- if not fn(object_, name, type_, reflected, compare_to):
- return False
- else:
- return True
+ inspector, metadata, upgrade_ops, autogen_context)
def _compare_tables(conn_table_names, metadata_table_names,
- object_filters,
- inspector, metadata, diffs, autogen_context):
+ inspector, metadata, upgrade_ops, autogen_context):
default_schema = inspector.bind.dialect.default_schema_name
@@ -95,14 +106,19 @@ def _compare_tables(conn_table_names, metadata_table_names,
for s, tname in metadata_table_names.difference(conn_table_names):
name = '%s.%s' % (s, tname) if s else tname
metadata_table = tname_to_table[(s, tname)]
- if _run_filters(
- metadata_table, tname, "table", False, None, object_filters):
- diffs.append(("add_table", metadata_table))
+ if autogen_context.run_filters(
+ metadata_table, tname, "table", False, None):
+ upgrade_ops.ops.append(
+ ops.CreateTableOp.from_table(metadata_table))
log.info("Detected added table %r", name)
- _compare_indexes_and_uniques(s, tname, object_filters,
- None,
- metadata_table,
- diffs, autogen_context, inspector)
+ modify_table_ops = ops.ModifyTableOps(tname, [], schema=s)
+
+ comparators.dispatch("table")(
+ autogen_context, modify_table_ops,
+ s, tname, None, metadata_table
+ )
+ if not modify_table_ops.is_empty():
+ upgrade_ops.ops.append(modify_table_ops)
removal_metadata = sa_schema.MetaData()
for s, tname in conn_table_names.difference(metadata_table_names):
@@ -114,11 +130,13 @@ def _compare_tables(conn_table_names, metadata_table_names,
event.listen(
t,
"column_reflect",
- autogen_context['context'].impl.
+ autogen_context.migration_context.impl.
_compat_autogen_column_reflect(inspector))
inspector.reflecttable(t, None)
- if _run_filters(t, tname, "table", True, None, object_filters):
- diffs.append(("remove_table", t))
+ if autogen_context.run_filters(t, tname, "table", True, None):
+ upgrade_ops.ops.append(
+ ops.DropTableOp.from_table(t)
+ )
log.info("Detected removed table %r", name)
existing_tables = conn_table_names.intersection(metadata_table_names)
@@ -133,7 +151,7 @@ def _compare_tables(conn_table_names, metadata_table_names,
event.listen(
t,
"column_reflect",
- autogen_context['context'].impl.
+ autogen_context.migration_context.impl.
_compat_autogen_column_reflect(inspector))
inspector.reflecttable(t, None)
conn_column_info[(s, tname)] = t
@@ -144,25 +162,24 @@ def _compare_tables(conn_table_names, metadata_table_names,
metadata_table = tname_to_table[(s, tname)]
conn_table = existing_metadata.tables[name]
- if _run_filters(
+ if autogen_context.run_filters(
metadata_table, tname, "table", False,
- conn_table, object_filters):
+ conn_table):
+
+ modify_table_ops = ops.ModifyTableOps(tname, [], schema=s)
with _compare_columns(
- s, tname, object_filters,
+ s, tname,
conn_table,
metadata_table,
- diffs, autogen_context, inspector):
- _compare_indexes_and_uniques(s, tname, object_filters,
- conn_table,
- metadata_table,
- diffs, autogen_context, inspector)
- _compare_foreign_keys(s, tname, object_filters, conn_table,
- metadata_table, diffs, autogen_context,
- inspector)
+ modify_table_ops, autogen_context, inspector):
- # TODO:
- # table constraints
- # sequences
+ comparators.dispatch("table")(
+ autogen_context, modify_table_ops,
+ s, tname, conn_table, metadata_table
+ )
+
+ if not modify_table_ops.is_empty():
+ upgrade_ops.ops.append(modify_table_ops)
def _make_index(params, conn_table):
@@ -202,56 +219,51 @@ def _make_foreign_key(params, conn_table):
@contextlib.contextmanager
-def _compare_columns(schema, tname, object_filters, conn_table, metadata_table,
- diffs, autogen_context, inspector):
+def _compare_columns(schema, tname, conn_table, metadata_table,
+ modify_table_ops, autogen_context, inspector):
name = '%s.%s' % (schema, tname) if schema else tname
metadata_cols_by_name = dict((c.name, c) for c in metadata_table.c)
conn_col_names = dict((c.name, c) for c in conn_table.c)
metadata_col_names = OrderedSet(sorted(metadata_cols_by_name))
for cname in metadata_col_names.difference(conn_col_names):
- if _run_filters(metadata_cols_by_name[cname], cname,
- "column", False, None, object_filters):
- diffs.append(
- ("add_column", schema, tname, metadata_cols_by_name[cname])
+ if autogen_context.run_filters(
+ metadata_cols_by_name[cname], cname,
+ "column", False, None):
+ modify_table_ops.ops.append(
+ ops.AddColumnOp.from_column_and_tablename(
+ schema, tname, metadata_cols_by_name[cname])
)
log.info("Detected added column '%s.%s'", name, cname)
for colname in metadata_col_names.intersection(conn_col_names):
metadata_col = metadata_cols_by_name[colname]
conn_col = conn_table.c[colname]
- if not _run_filters(
+ if not autogen_context.run_filters(
metadata_col, colname, "column", False,
- conn_col, object_filters):
+ conn_col):
continue
- col_diff = []
- _compare_type(schema, tname, colname,
- conn_col,
- metadata_col,
- col_diff, autogen_context
- )
- # work around SQLAlchemy issue #3023
- if not metadata_col.primary_key:
- _compare_nullable(schema, tname, colname,
- conn_col,
- metadata_col.nullable,
- col_diff, autogen_context
- )
- _compare_server_default(schema, tname, colname,
- conn_col,
- metadata_col,
- col_diff, autogen_context
- )
- if col_diff:
- diffs.append(col_diff)
+ alter_column_op = ops.AlterColumnOp(
+ tname, colname, schema=schema)
+
+ comparators.dispatch("column")(
+ autogen_context, alter_column_op,
+ schema, tname, colname, conn_col, metadata_col
+ )
+
+ if alter_column_op.has_changes():
+ modify_table_ops.ops.append(alter_column_op)
yield
for cname in set(conn_col_names).difference(metadata_col_names):
- if _run_filters(conn_table.c[cname], cname,
- "column", True, None, object_filters):
- diffs.append(
- ("remove_column", schema, tname, conn_table.c[cname])
+ if autogen_context.run_filters(
+ conn_table.c[cname], cname,
+ "column", True, None):
+ modify_table_ops.ops.append(
+ ops.DropColumnOp.from_column_and_tablename(
+ schema, tname, conn_table.c[cname]
+ )
)
log.info("Detected removed column '%s.%s'", name, cname)
@@ -310,10 +322,12 @@ class _fk_constraint_sig(_constraint_sig):
)
-def _compare_indexes_and_uniques(schema, tname, object_filters, conn_table,
- metadata_table, diffs,
- autogen_context, inspector):
+@comparators.dispatch_for("table")
+def _compare_indexes_and_uniques(
+ autogen_context, modify_ops, schema, tname, conn_table,
+ metadata_table):
+ inspector = autogen_context.inspector
is_create_table = conn_table is None
# 1a. get raw indexes and unique constraints from metadata ...
@@ -350,7 +364,7 @@ def _compare_indexes_and_uniques(schema, tname, object_filters, conn_table,
# 3. give the dialect a chance to omit indexes and constraints that
# we know are either added implicitly by the DB or that the DB
# can't accurately report on
- autogen_context['context'].impl.\
+ autogen_context.migration_context.impl.\
correct_for_autogen_constraints(
conn_uniques, conn_indexes,
metadata_unique_constraints,
@@ -411,9 +425,11 @@ def _compare_indexes_and_uniques(schema, tname, object_filters, conn_table,
def obj_added(obj):
if obj.is_index:
- if _run_filters(
- obj.const, obj.name, "index", False, None, object_filters):
- diffs.append(("add_index", obj.const))
+ if autogen_context.run_filters(
+ obj.const, obj.name, "index", False, None):
+ modify_ops.ops.append(
+ ops.CreateIndexOp.from_index(obj.const)
+ )
log.info("Detected added index '%s' on %s",
obj.name, ', '.join([
"'%s'" % obj.column_names
@@ -426,10 +442,12 @@ def _compare_indexes_and_uniques(schema, tname, object_filters, conn_table,
if is_create_table:
# unique constraints are created inline with table defs
return
- if _run_filters(
+ if autogen_context.run_filters(
obj.const, obj.name,
- "unique_constraint", False, None, object_filters):
- diffs.append(("add_constraint", obj.const))
+ "unique_constraint", False, None):
+ modify_ops.ops.append(
+ ops.AddConstraintOp.from_constraint(obj.const)
+ )
log.info("Detected added unique constraint '%s' on %s",
obj.name, ', '.join([
"'%s'" % obj.column_names
@@ -443,39 +461,51 @@ def _compare_indexes_and_uniques(schema, tname, object_filters, conn_table,
# be sure what we're doing here
return
- if _run_filters(
- obj.const, obj.name, "index", True, None, object_filters):
- diffs.append(("remove_index", obj.const))
+ if autogen_context.run_filters(
+ obj.const, obj.name, "index", True, None):
+ modify_ops.ops.append(
+ ops.DropIndexOp.from_index(obj.const)
+ )
log.info(
"Detected removed index '%s' on '%s'", obj.name, tname)
else:
- if _run_filters(
+ if autogen_context.run_filters(
obj.const, obj.name,
- "unique_constraint", True, None, object_filters):
- diffs.append(("remove_constraint", obj.const))
+ "unique_constraint", True, None):
+ modify_ops.ops.append(
+ ops.DropConstraintOp.from_constraint(obj.const)
+ )
log.info("Detected removed unique constraint '%s' on '%s'",
obj.name, tname
)
def obj_changed(old, new, msg):
if old.is_index:
- if _run_filters(
+ if autogen_context.run_filters(
new.const, new.name, "index",
- False, old.const, object_filters):
+ False, old.const):
log.info("Detected changed index '%s' on '%s':%s",
old.name, tname, ', '.join(msg)
)
- diffs.append(("remove_index", old.const))
- diffs.append(("add_index", new.const))
+ modify_ops.ops.append(
+ ops.DropIndexOp.from_index(old.const)
+ )
+ modify_ops.ops.append(
+ ops.CreateIndexOp.from_index(new.const)
+ )
else:
- if _run_filters(
+ if autogen_context.run_filters(
new.const, new.name,
- "unique_constraint", False, old.const, object_filters):
+ "unique_constraint", False, old.const):
log.info("Detected changed unique constraint '%s' on '%s':%s",
old.name, tname, ', '.join(msg)
)
- diffs.append(("remove_constraint", old.const))
- diffs.append(("add_constraint", new.const))
+ modify_ops.ops.append(
+ ops.DropConstraintOp.from_constraint(old.const)
+ )
+ modify_ops.ops.append(
+ ops.AddConstraintOp.from_constraint(new.const)
+ )
for added_name in sorted(set(metadata_names).difference(conn_names)):
obj = metadata_names[added_name]
@@ -528,20 +558,21 @@ def _compare_indexes_and_uniques(schema, tname, object_filters, conn_table,
obj_added(unnamed_metadata_uniques[uq_sig])
-def _compare_nullable(schema, tname, cname, conn_col,
- metadata_col_nullable, diffs,
- autogen_context):
+@comparators.dispatch_for("column")
+def _compare_nullable(
+ autogen_context, alter_column_op, schema, tname, cname, conn_col,
+ metadata_col):
+
+ # work around SQLAlchemy issue #3023
+ if metadata_col.primary_key:
+ return
+
+ metadata_col_nullable = metadata_col.nullable
conn_col_nullable = conn_col.nullable
+ alter_column_op.existing_nullable = conn_col_nullable
+
if conn_col_nullable is not metadata_col_nullable:
- diffs.append(
- ("modify_nullable", schema, tname, cname,
- {
- "existing_type": conn_col.type,
- "existing_server_default": conn_col.server_default,
- },
- conn_col_nullable,
- metadata_col_nullable),
- )
+ alter_column_op.modify_nullable = metadata_col_nullable
log.info("Detected %s on column '%s.%s'",
"NULL" if metadata_col_nullable else "NOT NULL",
tname,
@@ -549,11 +580,13 @@ def _compare_nullable(schema, tname, cname, conn_col,
)
-def _compare_type(schema, tname, cname, conn_col,
- metadata_col, diffs,
- autogen_context):
+@comparators.dispatch_for("column")
+def _compare_type(
+ autogen_context, alter_column_op, schema, tname, cname, conn_col,
+ metadata_col):
conn_type = conn_col.type
+ alter_column_op.existing_type = conn_type
metadata_type = metadata_col.type
if conn_type._type_affinity is sqltypes.NullType:
log.info("Couldn't determine database type "
@@ -564,19 +597,11 @@ def _compare_type(schema, tname, cname, conn_col,
"the model; can't compare", tname, cname)
return
- isdiff = autogen_context['context']._compare_type(conn_col, metadata_col)
+ isdiff = autogen_context.migration_context._compare_type(
+ conn_col, metadata_col)
if isdiff:
-
- diffs.append(
- ("modify_type", schema, tname, cname,
- {
- "existing_nullable": conn_col.nullable,
- "existing_server_default": conn_col.server_default,
- },
- conn_type,
- metadata_type),
- )
+ alter_column_op.modify_type = metadata_type
log.info("Detected type change from %r to %r on '%s.%s'",
conn_type, metadata_type, tname, cname
)
@@ -594,7 +619,7 @@ def _render_server_default_for_compare(metadata_default,
metadata_default = metadata_default.arg
else:
metadata_default = str(metadata_default.arg.compile(
- dialect=autogen_context['dialect']))
+ dialect=autogen_context.dialect))
if isinstance(metadata_default, compat.string_types):
if metadata_col.type._type_affinity is sqltypes.String:
metadata_default = re.sub(r"^'|'$", "", metadata_default)
@@ -605,8 +630,10 @@ def _render_server_default_for_compare(metadata_default,
return None
-def _compare_server_default(schema, tname, cname, conn_col, metadata_col,
- diffs, autogen_context):
+@comparators.dispatch_for("column")
+def _compare_server_default(
+ autogen_context, alter_column_op, schema, tname, cname,
+ conn_col, metadata_col):
metadata_default = metadata_col.server_default
conn_col_default = conn_col.server_default
@@ -618,36 +645,31 @@ def _compare_server_default(schema, tname, cname, conn_col, metadata_col,
rendered_conn_default = conn_col.server_default.arg.text \
if conn_col.server_default else None
- isdiff = autogen_context['context']._compare_server_default(
+ alter_column_op.existing_server_default = conn_col_default
+
+ isdiff = autogen_context.migration_context._compare_server_default(
conn_col, metadata_col,
rendered_metadata_default,
rendered_conn_default
)
if isdiff:
- conn_col_default = rendered_conn_default
- diffs.append(
- ("modify_default", schema, tname, cname,
- {
- "existing_nullable": conn_col.nullable,
- "existing_type": conn_col.type,
- },
- conn_col_default,
- metadata_default),
- )
- log.info("Detected server default on column '%s.%s'",
- tname,
- cname
- )
+ alter_column_op.modify_server_default = metadata_default
+ log.info(
+ "Detected server default on column '%s.%s'",
+ tname, cname)
-def _compare_foreign_keys(schema, tname, object_filters, conn_table,
- metadata_table, diffs, autogen_context, inspector):
+@comparators.dispatch_for("table")
+def _compare_foreign_keys(
+ autogen_context, modify_table_ops, schema, tname, conn_table,
+ metadata_table):
# if we're doing CREATE TABLE, all FKs are created
# inline within the table def
if conn_table is None:
return
+ inspector = autogen_context.inspector
metadata_fks = set(
fk for fk in metadata_table.constraints
if isinstance(fk, sa_schema.ForeignKeyConstraint)
@@ -673,10 +695,12 @@ def _compare_foreign_keys(schema, tname, object_filters, conn_table,
)
def _add_fk(obj, compare_to):
- if _run_filters(
+ if autogen_context.run_filters(
obj.const, obj.name, "foreign_key_constraint", False,
- compare_to, object_filters):
- diffs.append(('add_fk', const.const))
+ compare_to):
+ modify_table_ops.ops.append(
+ ops.CreateForeignKeyOp.from_constraint(const.const)
+ )
log.info(
"Detected added foreign key (%s)(%s) on table %s%s",
@@ -686,10 +710,12 @@ def _compare_foreign_keys(schema, tname, object_filters, conn_table,
obj.source_table)
def _remove_fk(obj, compare_to):
- if _run_filters(
+ if autogen_context.run_filters(
obj.const, obj.name, "foreign_key_constraint", True,
- compare_to, object_filters):
- diffs.append(('remove_fk', obj.const))
+ compare_to):
+ modify_table_ops.ops.append(
+ ops.DropConstraintOp.from_constraint(obj.const)
+ )
log.info(
"Detected removed foreign key (%s)(%s) on table %s%s",
", ".join(obj.source_columns),
@@ -713,5 +739,3 @@ def _compare_foreign_keys(schema, tname, object_filters, conn_table,
compare_to = conn_fks_by_name[const.name].const \
if const.name in conn_fks_by_name else None
_add_fk(const, compare_to)
-
- return diffs
diff --git a/alembic/autogenerate/compose.py b/alembic/autogenerate/compose.py
deleted file mode 100644
index b42b505..0000000
--- a/alembic/autogenerate/compose.py
+++ /dev/null
@@ -1,144 +0,0 @@
-import itertools
-from ..operations import ops
-
-
-def _to_migration_script(autogen_context, migration_script, diffs):
- _to_upgrade_op(
- autogen_context,
- diffs,
- migration_script.upgrade_ops,
- )
-
- _to_downgrade_op(
- autogen_context,
- diffs,
- migration_script.downgrade_ops,
- )
-
-
-def _to_upgrade_op(autogen_context, diffs, upgrade_ops):
- return _to_updown_op(autogen_context, diffs, upgrade_ops, "upgrade")
-
-
-def _to_downgrade_op(autogen_context, diffs, downgrade_ops):
- return _to_updown_op(autogen_context, diffs, downgrade_ops, "downgrade")
-
-
-def _to_updown_op(autogen_context, diffs, op_container, type_):
- if not diffs:
- return
-
- if type_ == 'downgrade':
- diffs = reversed(diffs)
-
- dest = [op_container.ops]
-
- for (schema, tablename), subdiffs in _group_diffs_by_table(diffs):
- subdiffs = list(subdiffs)
- if tablename is not None:
- table_ops = []
- op = ops.ModifyTableOps(tablename, table_ops, schema=schema)
- dest[-1].append(op)
- dest.append(table_ops)
- for diff in subdiffs:
- _produce_command(autogen_context, diff, dest[-1], type_)
- if tablename is not None:
- dest.pop(-1)
-
-
-def _produce_command(autogen_context, diff, op_list, updown):
- if isinstance(diff, tuple):
- _produce_adddrop_command(updown, diff, op_list, autogen_context)
- else:
- _produce_modify_command(updown, diff, op_list, autogen_context)
-
-
-def _produce_adddrop_command(updown, diff, op_list, autogen_context):
- cmd_type = diff[0]
- adddrop, cmd_type = cmd_type.split("_")
-
- cmd_args = diff[1:]
-
- _commands = {
- "table": (ops.DropTableOp.from_table, ops.CreateTableOp.from_table),
- "column": (
- ops.DropColumnOp.from_column_and_tablename,
- ops.AddColumnOp.from_column_and_tablename),
- "index": (ops.DropIndexOp.from_index, ops.CreateIndexOp.from_index),
- "constraint": (
- ops.DropConstraintOp.from_constraint,
- ops.AddConstraintOp.from_constraint),
- "fk": (
- ops.DropConstraintOp.from_constraint,
- ops.CreateForeignKeyOp.from_constraint)
- }
-
- cmd_callables = _commands[cmd_type]
-
- if (
- updown == "upgrade" and adddrop == "add"
- ) or (
- updown == "downgrade" and adddrop == "remove"
- ):
- op_list.append(cmd_callables[1](*cmd_args))
- else:
- op_list.append(cmd_callables[0](*cmd_args))
-
-
-def _produce_modify_command(updown, diffs, op_list, autogen_context):
- sname, tname, cname = diffs[0][1:4]
- kw = {}
-
- _arg_struct = {
- "modify_type": ("existing_type", "modify_type"),
- "modify_nullable": ("existing_nullable", "modify_nullable"),
- "modify_default": ("existing_server_default", "modify_server_default"),
- }
- for diff in diffs:
- diff_kw = diff[4]
- for arg in ("existing_type",
- "existing_nullable",
- "existing_server_default"):
- if arg in diff_kw:
- kw.setdefault(arg, diff_kw[arg])
- old_kw, new_kw = _arg_struct[diff[0]]
- if updown == "upgrade":
- kw[new_kw] = diff[-1]
- kw[old_kw] = diff[-2]
- else:
- kw[new_kw] = diff[-2]
- kw[old_kw] = diff[-1]
-
- if "modify_nullable" in kw:
- kw.pop("existing_nullable", None)
- if "modify_server_default" in kw:
- kw.pop("existing_server_default", None)
-
- op_list.append(
- ops.AlterColumnOp(
- tname, cname, schema=sname,
- **kw
- )
- )
-
-
-def _group_diffs_by_table(diffs):
- _adddrop = {
- "table": lambda diff: (None, None),
- "column": lambda diff: (diff[0], diff[1]),
- "index": lambda diff: (diff[0].table.schema, diff[0].table.name),
- "constraint": lambda diff: (diff[0].table.schema, diff[0].table.name),
- "fk": lambda diff: (diff[0].parent.schema, diff[0].parent.name)
- }
-
- def _derive_table(diff):
- if isinstance(diff, tuple):
- cmd_type = diff[0]
- adddrop, cmd_type = cmd_type.split("_")
- return _adddrop[cmd_type](diff[1:])
- else:
- sname, tname = diff[0][1:3]
- return sname, tname
-
- return itertools.groupby(diffs, _derive_table)
-
diff --git a/alembic/autogenerate/generate.py b/alembic/autogenerate/generate.py
deleted file mode 100644
index c686156..0000000
--- a/alembic/autogenerate/generate.py
+++ /dev/null
@@ -1,92 +0,0 @@
-from .. import util
-from . import api
-from . import compose
-from . import compare
-from . import render
-from ..operations import ops
-
-
-class RevisionContext(object):
- def __init__(self, config, script_directory, command_args):
- self.config = config
- self.script_directory = script_directory
- self.command_args = command_args
- self.template_args = {
- 'config': config # Let templates use config for
- # e.g. multiple databases
- }
- self.generated_revisions = [
- self._default_revision()
- ]
-
- def _to_script(self, migration_script):
- template_args = {}
- for k, v in self.template_args.items():
- template_args.setdefault(k, v)
-
- if migration_script._autogen_context is not None:
- render._render_migration_script(
- migration_script._autogen_context, migration_script,
- template_args
- )
-
- return self.script_directory.generate_revision(
- migration_script.rev_id,
- migration_script.message,
- refresh=True,
- head=migration_script.head,
- splice=migration_script.splice,
- branch_labels=migration_script.branch_label,
- version_path=migration_script.version_path,
- **template_args)
-
- def run_autogenerate(self, rev, context):
- if self.command_args['sql']:
- raise util.CommandError(
- "Using --sql with --autogenerate does not make any sense")
- if set(self.script_directory.get_revisions(rev)) != \
- set(self.script_directory.get_revisions("heads")):
- raise util.CommandError("Target database is not up to date.")
-
- autogen_context = api._autogen_context(context)
-
- diffs = []
- compare._produce_net_changes(autogen_context, diffs)
-
- migration_script = self.generated_revisions[0]
-
- compose._to_migration_script(autogen_context, migration_script, diffs)
-
- hook = context.opts.get('process_revision_directives', None)
- if hook:
- hook(context, rev, self.generated_revisions)
-
- for migration_script in self.generated_revisions:
- migration_script._autogen_context = autogen_context
-
- def run_no_autogenerate(self, rev, context):
- hook = context.opts.get('process_revision_directives', None)
- if hook:
- hook(context, rev, self.generated_revisions)
-
- for migration_script in self.generated_revisions:
- migration_script._autogen_context = None
-
- def _default_revision(self):
- op = ops.MigrationScript(
- rev_id=self.command_args['rev_id'] or util.rev_id(),
- message=self.command_args['message'],
- imports=set(),
- upgrade_ops=ops.UpgradeOps([]),
- downgrade_ops=ops.DowngradeOps([]),
- head=self.command_args['head'],
- splice=self.command_args['splice'],
- branch_label=self.command_args['branch_label'],
- version_path=self.command_args['version_path']
- )
- op._autogen_context = None
- return op
-
- def generate_scripts(self):
- for generated_revision in self.generated_revisions:
- yield self._to_script(generated_revision)
diff --git a/alembic/autogenerate/render.py b/alembic/autogenerate/render.py
index c3f3df1..6f5f96c 100644
--- a/alembic/autogenerate/render.py
+++ b/alembic/autogenerate/render.py
@@ -30,8 +30,8 @@ def _indent(text):
def _render_migration_script(autogen_context, migration_script, template_args):
- opts = autogen_context['opts']
- imports = autogen_context['imports']
+ opts = autogen_context.opts
+ imports = autogen_context._imports
template_args[opts['upgrade_token']] = _indent(_render_cmd_body(
migration_script.upgrade_ops, autogen_context))
template_args[opts['downgrade_token']] = _indent(_render_cmd_body(
@@ -78,23 +78,26 @@ def render_op_text(autogen_context, op):
@renderers.dispatch_for(ops.ModifyTableOps)
def _render_modify_table(autogen_context, op):
- opts = autogen_context['opts']
+ opts = autogen_context.opts
render_as_batch = opts.get('render_as_batch', False)
if op.ops:
lines = []
if render_as_batch:
- lines.append(
- "with op.batch_alter_table(%r, schema=%r) as batch_op:"
- % (op.table_name, op.schema)
- )
- autogen_context['batch_prefix'] = 'batch_op.'
- for t_op in op.ops:
- t_lines = render_op(autogen_context, t_op)
- lines.extend(t_lines)
- if render_as_batch:
- del autogen_context['batch_prefix']
- lines.append("")
+ with autogen_context._within_batch():
+ lines.append(
+ "with op.batch_alter_table(%r, schema=%r) as batch_op:"
+ % (op.table_name, op.schema)
+ )
+ for t_op in op.ops:
+ t_lines = render_op(autogen_context, t_op)
+ lines.extend(t_lines)
+ lines.append("")
+ else:
+ for t_op in op.ops:
+ t_lines = render_op(autogen_context, t_op)
+ lines.extend(t_lines)
+
return lines
else:
return [
@@ -149,7 +152,7 @@ def _drop_table(autogen_context, op):
def _add_index(autogen_context, op):
index = op.to_index()
- has_batch = 'batch_prefix' in autogen_context
+ has_batch = autogen_context._has_batch
if has_batch:
tmpl = "%(prefix)screate_index(%(name)r, [%(columns)s], "\
@@ -180,7 +183,7 @@ def _add_index(autogen_context, op):
@renderers.dispatch_for(ops.DropIndexOp)
def _drop_index(autogen_context, op):
- has_batch = 'batch_prefix' in autogen_context
+ has_batch = autogen_context._has_batch
if has_batch:
tmpl = "%(prefix)sdrop_index(%(name)r)"
@@ -243,7 +246,7 @@ def _add_check_constraint(constraint, autogen_context):
@renderers.dispatch_for(ops.DropConstraintOp)
def _drop_constraint(autogen_context, op):
- if 'batch_prefix' in autogen_context:
+ if autogen_context._has_batch:
template = "%(prefix)sdrop_constraint"\
"(%(name)r, type_=%(type)r)"
else:
@@ -266,7 +269,7 @@ def _drop_constraint(autogen_context, op):
def _add_column(autogen_context, op):
schema, tname, column = op.schema, op.table_name, op.column
- if 'batch_prefix' in autogen_context:
+ if autogen_context._has_batch:
template = "%(prefix)sadd_column(%(column)s)"
else:
template = "%(prefix)sadd_column(%(tname)r, %(column)s"
@@ -287,7 +290,7 @@ def _drop_column(autogen_context, op):
schema, tname, column_name = op.schema, op.table_name, op.column_name
- if 'batch_prefix' in autogen_context:
+ if autogen_context._has_batch:
template = "%(prefix)sdrop_column(%(cname)r)"
else:
template = "%(prefix)sdrop_column(%(tname)r, %(cname)r"
@@ -319,7 +322,7 @@ def _alter_column(autogen_context, op):
indent = " " * 11
- if 'batch_prefix' in autogen_context:
+ if autogen_context._has_batch:
template = "%(prefix)salter_column(%(cname)r"
else:
template = "%(prefix)salter_column(%(tname)r, %(cname)r"
@@ -343,16 +346,16 @@ def _alter_column(autogen_context, op):
if nullable is not None:
text += ",\n%snullable=%r" % (
indent, nullable,)
- if existing_nullable is not None:
+ if nullable is None and existing_nullable is not None:
text += ",\n%sexisting_nullable=%r" % (
indent, existing_nullable)
- if existing_server_default:
+ if server_default is False and existing_server_default:
rendered = _render_server_default(
existing_server_default,
autogen_context)
text += ",\n%sexisting_server_default=%s" % (
indent, rendered)
- if schema and "batch_prefix" not in autogen_context:
+ if schema and not autogen_context._has_batch:
text += ",\n%sschema=%r" % (indent, schema)
text += ")"
return text
@@ -409,7 +412,7 @@ def _render_potential_expr(value, autogen_context, wrap_in_text=True):
return template % {
"prefix": _sqlalchemy_autogenerate_prefix(autogen_context),
"sql": compat.text_type(
- value.compile(dialect=autogen_context['dialect'],
+ value.compile(dialect=autogen_context.dialect,
**compile_kw)
)
}
@@ -432,7 +435,7 @@ def _get_index_rendered_expressions(idx, autogen_context):
def _uq_constraint(constraint, autogen_context, alter):
opts = []
- has_batch = 'batch_prefix' in autogen_context
+ has_batch = autogen_context._has_batch
if constraint.deferrable:
opts.append(("deferrable", str(constraint.deferrable)))
@@ -467,7 +470,7 @@ def _uq_constraint(constraint, autogen_context, alter):
def _user_autogenerate_prefix(autogen_context, target):
- prefix = autogen_context['opts']['user_module_prefix']
+ prefix = autogen_context.opts['user_module_prefix']
if prefix is None:
return "%s." % target.__module__
else:
@@ -475,20 +478,19 @@ def _user_autogenerate_prefix(autogen_context, target):
def _sqlalchemy_autogenerate_prefix(autogen_context):
- return autogen_context['opts']['sqlalchemy_module_prefix'] or ''
+ return autogen_context.opts['sqlalchemy_module_prefix'] or ''
def _alembic_autogenerate_prefix(autogen_context):
- if 'batch_prefix' in autogen_context:
- return autogen_context['batch_prefix']
+ if autogen_context._has_batch:
+ return 'batch_op.'
else:
- return autogen_context['opts']['alembic_module_prefix'] or ''
+ return autogen_context.opts['alembic_module_prefix'] or ''
def _user_defined_render(type_, object_, autogen_context):
- if 'opts' in autogen_context and \
- 'render_item' in autogen_context['opts']:
- render = autogen_context['opts']['render_item']
+ if 'render_item' in autogen_context.opts:
+ render = autogen_context.opts['render_item']
if render:
rendered = render(type_, object_, autogen_context)
if rendered is not False:
@@ -547,7 +549,7 @@ def _repr_type(type_, autogen_context):
return rendered
mod = type(type_).__module__
- imports = autogen_context.get('imports', None)
+ imports = autogen_context._imports
if mod.startswith("sqlalchemy.dialects"):
dname = re.match(r"sqlalchemy\.dialects\.(\w+)", mod).group(1)
if imports is not None:
diff --git a/alembic/operations/ops.py b/alembic/operations/ops.py
index 08a0551..71e8515 100644
--- a/alembic/operations/ops.py
+++ b/alembic/operations/ops.py
@@ -3,6 +3,7 @@ from ..util import sqla_compat
from . import schemaobj
from sqlalchemy.types import NULLTYPE
from .base import Operations, BatchOperations
+import re
class MigrateOperation(object):
@@ -34,6 +35,10 @@ class MigrateOperation(object):
class AddConstraintOp(MigrateOperation):
"""Represent an add constraint operation."""
+ @property
+ def constraint_type(self):
+ raise NotImplementedError()
+
@classmethod
def from_constraint(cls, constraint):
funcs = {
@@ -45,17 +50,40 @@ class AddConstraintOp(MigrateOperation):
}
return funcs[constraint.__visit_name__](constraint)
+ def reverse(self):
+ return DropConstraintOp.from_constraint(self.to_constraint())
+
+ def to_diff_tuple(self):
+ return ("add_constraint", self.to_constraint())
+
@Operations.register_operation("drop_constraint")
@BatchOperations.register_operation("drop_constraint", "batch_drop_constraint")
class DropConstraintOp(MigrateOperation):
"""Represent a drop constraint operation."""
- def __init__(self, constraint_name, table_name, type_=None, schema=None):
+ def __init__(
+ self,
+ constraint_name, table_name, type_=None, schema=None,
+ _orig_constraint=None):
self.constraint_name = constraint_name
self.table_name = table_name
self.constraint_type = type_
self.schema = schema
+ self._orig_constraint = _orig_constraint
+
+ def reverse(self):
+ if self._orig_constraint is None:
+ raise ValueError(
+ "operation is not reversible; "
+ "original constraint is not present")
+ return AddConstraintOp.from_constraint(self._orig_constraint)
+
+ def to_diff_tuple(self):
+ if self.constraint_type == "foreignkey":
+ return ("remove_fk", self.to_constraint())
+ else:
+ return ("remove_constraint", self.to_constraint())
@classmethod
def from_constraint(cls, constraint):
@@ -72,9 +100,18 @@ class DropConstraintOp(MigrateOperation):
constraint.name,
constraint_table.name,
schema=constraint_table.schema,
- type_=types[constraint.__visit_name__]
+ type_=types[constraint.__visit_name__],
+ _orig_constraint=constraint
)
+ def to_constraint(self):
+ if self._orig_constraint is not None:
+ return self._orig_constraint
+ else:
+ raise ValueError(
+ "constraint cannot be produced; "
+ "original constraint is not present")
+
@classmethod
@util._with_legacy_names([("type", "type_")])
def drop_constraint(
@@ -124,8 +161,11 @@ class DropConstraintOp(MigrateOperation):
class CreatePrimaryKeyOp(AddConstraintOp):
"""Represent a create primary key operation."""
+ constraint_type = "primarykey"
+
def __init__(
- self, constraint_name, table_name, columns, schema=None, **kw):
+ self, constraint_name, table_name, columns,
+ schema=None, **kw):
self.constraint_name = constraint_name
self.table_name = table_name
self.columns = columns
@@ -225,8 +265,11 @@ class CreatePrimaryKeyOp(AddConstraintOp):
class CreateUniqueConstraintOp(AddConstraintOp):
"""Represent a create unique constraint operation."""
+ constraint_type = "unique"
+
def __init__(
- self, constraint_name, table_name, columns, schema=None, **kw):
+ self, constraint_name, table_name,
+ columns, schema=None, **kw):
self.constraint_name = constraint_name
self.table_name = table_name
self.columns = columns
@@ -342,6 +385,8 @@ class CreateUniqueConstraintOp(AddConstraintOp):
class CreateForeignKeyOp(AddConstraintOp):
"""Represent a create foreign key constraint operation."""
+ constraint_type = "foreignkey"
+
def __init__(
self, constraint_name, source_table, referent_table, local_cols,
remote_cols, **kw):
@@ -352,6 +397,9 @@ class CreateForeignKeyOp(AddConstraintOp):
self.remote_cols = remote_cols
self.kw = kw
+ def to_diff_tuple(self):
+ return ("add_fk", self.to_constraint())
+
@classmethod
def from_constraint(cls, constraint):
kw = {}
@@ -507,6 +555,8 @@ class CreateForeignKeyOp(AddConstraintOp):
class CreateCheckConstraintOp(AddConstraintOp):
"""Represent a create check constraint operation."""
+ constraint_type = "check"
+
def __init__(
self, constraint_name, table_name, condition, schema=None, **kw):
self.constraint_name = constraint_name
@@ -523,7 +573,7 @@ class CreateCheckConstraintOp(AddConstraintOp):
constraint.name,
constraint_table.name,
constraint.condition,
- schema=constraint_table.schema
+ schema=constraint_table.schema,
)
def to_constraint(self, migration_context=None):
@@ -624,6 +674,12 @@ class CreateIndexOp(MigrateOperation):
self.kw = kw
self._orig_index = _orig_index
+ def reverse(self):
+ return DropIndexOp.from_index(self.to_index())
+
+ def to_diff_tuple(self):
+ return ("add_index", self.to_index())
+
@classmethod
def from_index(cls, index):
return cls(
@@ -729,10 +785,22 @@ class CreateIndexOp(MigrateOperation):
class DropIndexOp(MigrateOperation):
"""Represent a drop index operation."""
- def __init__(self, index_name, table_name=None, schema=None):
+ def __init__(
+ self, index_name, table_name=None, schema=None, _orig_index=None):
self.index_name = index_name
self.table_name = table_name
self.schema = schema
+ self._orig_index = _orig_index
+
+ def to_diff_tuple(self):
+ return ("remove_index", self.to_index())
+
+ def reverse(self):
+ if self._orig_index is None:
+ raise ValueError(
+ "operation is not reversible; "
+ "original index is not present")
+ return CreateIndexOp.from_index(self._orig_index)
@classmethod
def from_index(cls, index):
@@ -740,6 +808,7 @@ class DropIndexOp(MigrateOperation):
index.name,
index.table.name,
schema=index.table.schema,
+ _orig_index=index
)
def to_index(self, migration_context=None):
@@ -807,6 +876,12 @@ class CreateTableOp(MigrateOperation):
self.kw = kw
self._orig_table = _orig_table
+ def reverse(self):
+ return DropTableOp.from_table(self.to_table())
+
+ def to_diff_tuple(self):
+ return ("add_table", self.to_table())
+
@classmethod
def from_table(cls, table):
return cls(
@@ -921,16 +996,30 @@ class CreateTableOp(MigrateOperation):
class DropTableOp(MigrateOperation):
"""Represent a drop table operation."""
- def __init__(self, table_name, schema=None, table_kw=None):
+ def __init__(
+ self, table_name, schema=None, table_kw=None, _orig_table=None):
self.table_name = table_name
self.schema = schema
self.table_kw = table_kw or {}
+ self._orig_table = _orig_table
+
+ def to_diff_tuple(self):
+ return ("remove_table", self.to_table())
+
+ def reverse(self):
+ if self._orig_table is None:
+ raise ValueError(
+ "operation is not reversible; "
+ "original table is not present")
+ return CreateTableOp.from_table(self._orig_table)
@classmethod
def from_table(cls, table):
- return cls(table.name, schema=table.schema)
+ return cls(table.name, schema=table.schema, _orig_table=table)
- def to_table(self, migration_context):
+ def to_table(self, migration_context=None):
+ if self._orig_table is not None:
+ return self._orig_table
schema_obj = schemaobj.SchemaObjects(migration_context)
return schema_obj.table(
self.table_name,
@@ -1029,6 +1118,87 @@ class AlterColumnOp(AlterTableOp):
self.modify_type = modify_type
self.kw = kw
+ def to_diff_tuple(self):
+ col_diff = []
+ schema, tname, cname = self.schema, self.table_name, self.column_name
+
+ if self.modify_type is not None:
+ col_diff.append(
+ ("modify_type", schema, tname, cname,
+ {
+ "existing_nullable": self.existing_nullable,
+ "existing_server_default": self.existing_server_default,
+ },
+ self.existing_type,
+ self.modify_type)
+ )
+
+ if self.modify_nullable is not None:
+ col_diff.append(
+ ("modify_nullable", schema, tname, cname,
+ {
+ "existing_type": self.existing_type,
+ "existing_server_default": self.existing_server_default
+ },
+ self.existing_nullable,
+ self.modify_nullable)
+ )
+
+ if self.modify_server_default is not False:
+ col_diff.append(
+ ("modify_default", schema, tname, cname,
+ {
+ "existing_nullable": self.existing_nullable,
+ "existing_type": self.existing_type
+ },
+ self.existing_server_default,
+ self.modify_server_default)
+ )
+
+ return col_diff
+
+ def has_changes(self):
+ hc1 = self.modify_nullable is not None or \
+ self.modify_server_default is not False or \
+ self.modify_type is not None
+ if hc1:
+ return True
+ for kw in self.kw:
+ if kw.startswith('modify_'):
+ return True
+ else:
+ return False
+
+ def reverse(self):
+
+ kw = self.kw.copy()
+ kw['existing_type'] = self.existing_type
+ kw['existing_nullable'] = self.existing_nullable
+ kw['existing_server_default'] = self.existing_server_default
+ if self.modify_type is not None:
+ kw['modify_type'] = self.modify_type
+ if self.modify_nullable is not None:
+ kw['modify_nullable'] = self.modify_nullable
+ if self.modify_server_default is not False:
+ kw['modify_server_default'] = self.modify_server_default
+
+ # TODO: make this a little simpler
+ all_keys = set(m.group(1) for m in [
+ re.match(r'^(?:existing_|modify_)(.+)$', k)
+ for k in kw
+ ] if m)
+
+ for k in all_keys:
+ if 'modify_%s' % k in kw:
+ swap = kw['existing_%s' % k]
+ kw['existing_%s' % k] = kw['modify_%s' % k]
+ kw['modify_%s' % k] = swap
+
+ return self.__class__(
+ self.table_name, self.column_name, schema=self.schema,
+ **kw
+ )
+
@classmethod
@util._with_legacy_names([('name', 'new_column_name')])
def alter_column(
@@ -1177,6 +1347,13 @@ class AddColumnOp(AlterTableOp):
super(AddColumnOp, self).__init__(table_name, schema=schema)
self.column = column
+ def reverse(self):
+ return DropColumnOp.from_column_and_tablename(
+ self.schema, self.table_name, self.column)
+
+ def to_diff_tuple(self):
+ return ("add_column", self.schema, self.table_name, self.column)
+
@classmethod
def from_column(cls, col):
return cls(col.table.name, col, schema=col.table.schema)
@@ -1265,14 +1442,30 @@ class AddColumnOp(AlterTableOp):
class DropColumnOp(AlterTableOp):
"""Represent a drop column operation."""
- def __init__(self, table_name, column_name, schema=None, **kw):
+ def __init__(
+ self, table_name, column_name, schema=None,
+ _orig_column=None, **kw):
super(DropColumnOp, self).__init__(table_name, schema=schema)
self.column_name = column_name
self.kw = kw
+ self._orig_column = _orig_column
+
+ def to_diff_tuple(self):
+ return (
+ "remove_column", self.schema, self.table_name, self.to_column())
+
+ def reverse(self):
+ if self._orig_column is None:
+ raise ValueError(
+ "operation is not reversible; "
+ "original column is not present")
+
+ return AddColumnOp.from_column_and_tablename(
+ self.schema, self.table_name, self._orig_column)
@classmethod
def from_column_and_tablename(cls, schema, tname, col):
- return cls(tname, col.name, schema=schema)
+ return cls(tname, col.name, schema=schema, _orig_column=col)
def to_column(self, migration_context=None):
schema_obj = schemaobj.SchemaObjects(migration_context)
@@ -1522,6 +1715,21 @@ class OpContainer(MigrateOperation):
def __init__(self, ops=()):
self.ops = ops
+ def is_empty(self):
+ return not self.ops
+
+ def as_diffs(self):
+ return list(OpContainer._ops_as_diffs(self))
+
+ @classmethod
+ def _ops_as_diffs(cls, migrations):
+ for op in migrations.ops:
+ if hasattr(op, 'ops'):
+ for sub_op in cls._ops_as_diffs(op):
+ yield sub_op
+ else:
+ yield op.to_diff_tuple()
+
class ModifyTableOps(OpContainer):
"""Contains a sequence of operations that all apply to a single Table."""
@@ -1531,6 +1739,15 @@ class ModifyTableOps(OpContainer):
self.table_name = table_name
self.schema = schema
+ def reverse(self):
+ return ModifyTableOps(
+ self.table_name,
+ ops=list(reversed(
+ [op.reverse() for op in self.ops]
+ )),
+ schema=self.schema
+ )
+
class UpgradeOps(OpContainer):
"""contains a sequence of operations that would apply to the
@@ -1542,6 +1759,15 @@ class UpgradeOps(OpContainer):
"""
+ def reverse_into(self, downgrade_ops):
+ downgrade_ops.ops[:] = list(reversed(
+ [op.reverse() for op in self.ops]
+ ))
+ return downgrade_ops
+
+ def reverse(self):
+ return self.reverse_into(DowngradeOps(ops=[]))
+
class DowngradeOps(OpContainer):
"""contains a sequence of operations that would apply to the
@@ -1553,6 +1779,13 @@ class DowngradeOps(OpContainer):
"""
+ def reverse(self):
+ return UpgradeOps(
+ ops=list(reversed(
+ [op.reverse() for op in self.ops]
+ ))
+ )
+
class MigrationScript(MigrateOperation):
"""represents a migration script.
@@ -1583,4 +1816,3 @@ class MigrationScript(MigrateOperation):
self.version_path = version_path
self.upgrade_ops = upgrade_ops
self.downgrade_ops = downgrade_ops
-
diff --git a/alembic/runtime/migration.py b/alembic/runtime/migration.py
index 84a3c7f..e811a36 100644
--- a/alembic/runtime/migration.py
+++ b/alembic/runtime/migration.py
@@ -118,6 +118,7 @@ class MigrationContext(object):
connection=None,
url=None,
dialect_name=None,
+ dialect=None,
environment_context=None,
opts=None,
):
@@ -152,7 +153,7 @@ class MigrationContext(object):
elif dialect_name:
url = sqla_url.make_url("%s://" % dialect_name)
dialect = url.get_dialect()()
- else:
+ elif not dialect:
raise Exception("Connection, url, or dialect_name is required.")
return MigrationContext(dialect, connection, opts, environment_context)
diff --git a/alembic/util/langhelpers.py b/alembic/util/langhelpers.py
index 1fb0942..6c92e3c 100644
--- a/alembic/util/langhelpers.py
+++ b/alembic/util/langhelpers.py
@@ -257,30 +257,57 @@ class immutabledict(dict):
class Dispatcher(object):
- def __init__(self):
+ def __init__(self, uselist=False):
self._registry = {}
+ self.uselist = uselist
def dispatch_for(self, target, qualifier='default'):
def decorate(fn):
- assert isinstance(target, type)
- assert target not in self._registry
- self._registry[(target, qualifier)] = fn
+ if self.uselist:
+ assert target not in self._registry
+ self._registry.setdefault((target, qualifier), []).append(fn)
+ else:
+ assert target not in self._registry
+ self._registry[(target, qualifier)] = fn
return fn
return decorate
def dispatch(self, obj, qualifier='default'):
- for spcls in type(obj).__mro__:
+
+ if isinstance(obj, string_types):
+ targets = [obj]
+ elif isinstance(obj, type):
+ targets = obj.__mro__
+ else:
+ targets = type(obj).__mro__
+
+ for spcls in targets:
if qualifier != 'default' and (spcls, qualifier) in self._registry:
- return self._registry[(spcls, qualifier)]
+ return self._fn_or_list(self._registry[(spcls, qualifier)])
elif (spcls, 'default') in self._registry:
- return self._registry[(spcls, 'default')]
+ return self._fn_or_list(self._registry[(spcls, 'default')])
else:
raise ValueError("no dispatch function for object: %s" % obj)
+ def _fn_or_list(self, fn_or_list):
+ if self.uselist:
+ def go(*arg, **kw):
+ for fn in fn_or_list:
+ fn(*arg, **kw)
+ return go
+ else:
+ return fn_or_list
+
def branch(self):
"""Return a copy of this dispatcher that is independently
writable."""
d = Dispatcher()
- d._registry.update(self._registry)
+ if self.uselist:
+ d._registry.update(
+ (k, [fn for fn in self._registry[k]])
+ for k in self._registry
+ )
+ else:
+ d._registry.update(self._registry)
return d
diff --git a/docs/build/api/autogenerate.rst b/docs/build/api/autogenerate.rst
index b024ab1..8b026e8 100644
--- a/docs/build/api/autogenerate.rst
+++ b/docs/build/api/autogenerate.rst
@@ -4,7 +4,8 @@
Autogeneration
==============
-The autogenerate system has two areas of API that are public:
+The autogeneration system has a wide range of public API, including
+the following areas:
1. The ability to do a "diff" of a :class:`~sqlalchemy.schema.MetaData` object against
a database, and receive a data structure back. This structure
@@ -15,9 +16,22 @@ The autogenerate system has two areas of API that are public:
revision scripts, including support for multiple revision scripts
generated in one pass.
+3. The ability to add new operation directives to autogeneration, including
+ custom schema/model comparison functions and revision script rendering.
+
Getting Diffs
==============
+The simplest API autogenerate provides is the "schema comparison" API;
+these are simple functions that will run all registered "comparison" functions
+between a :class:`~sqlalchemy.schema.MetaData` object and a database
+backend to produce a structure showing how they differ. The two
+functions provided are :func:`.compare_metadata`, which is more of the
+"legacy" function that produces diff tuples, and :func:`.produce_migrations`,
+which produces a structure consisting of operation directives detailed in
+:ref:`alembic.operations.toplevel`.
+
+
.. autofunction:: alembic.autogenerate.compare_metadata
.. autofunction:: alembic.autogenerate.produce_migrations
@@ -184,6 +198,8 @@ to whatever is in this list.
.. autofunction:: alembic.autogenerate.render_python_code
+.. _autogen_custom_ops:
+
Autogenerating Custom Operation Directives
==========================================
@@ -192,16 +208,180 @@ subclasses of :class:`.MigrateOperation` in order to add new ``op.``
directives. In the preceding section :ref:`customizing_revision`, we
also learned that these same :class:`.MigrateOperation` structures are at
the base of how the autogenerate system knows what Python code to render.
-How to connect these two systems, so that our own custom operation
-directives can be used? First off, we'd probably be implementing
-a :paramref:`.EnvironmentContext.configure.process_revision_directives`
-plugin as described previously, so that we can add our own directives
-to the autogenerate stream. What if we wanted to add our ``CreateSequenceOp``
-to the autogenerate structure? We basically need to define an autogenerate
-renderer for it, as follows::
+Using this knowledge, we can create additional functions that plug into
+the autogenerate system so that our new operations can be generated
+into migration scripts when ``alembic revision --autogenerate`` is run.
+
+The following sections will detail an example of this using the
+``CreateSequenceOp`` and ``DropSequenceOp`` directives
+we created in :ref:`operation_plugins`, which correspond to the
+SQLAlchemy :class:`~sqlalchemy.schema.Sequence` construct.
+
+.. versionadded:: 0.8.0 - custom operations can be added to the
+ autogenerate system to support new kinds of database objects.
+
+Tracking our Object with the Model
+----------------------------------
+
+The basic job of an autogenerate comparison function is to inspect
+a series of objects in the database and compare them against a series
+of objects defined in our model. By "in our model", we mean anything
+defined in Python code that we want to track, however most commonly
+we're talking about a series of :class:`~sqlalchemy.schema.Table`
+objects present in a :class:`~sqlalchemy.schema.MetaData` collection.
+
+Let's propose a simple way of seeing what :class:`~sqlalchemy.schema.Sequence`
+objects we want to ensure exist in the database when autogenerate
+runs. While these objects do have some integrations with
+:class:`~sqlalchemy.schema.Table` and :class:`~sqlalchemy.schema.MetaData`
+already, let's assume they don't, as the example here intends to illustrate
+how we would do this for most any kind of custom construct. We
+associate the object with the :attr:`~sqlalchemy.schema.MetaData.info`
+collection of :class:`~sqlalchemy.schema.MetaData`, which is a dictionary
+we can use for anything, which we also know will be passed to the autogenerate
+process::
+
+ from sqlalchemy.schema import Sequence
+
+ def add_sequence_to_model(sequence, metadata):
+ metadata.info.setdefault("sequences", set()).add(
+ (sequence.schema, sequence.name)
+ )
+
+ my_seq = Sequence("my_sequence")
+ add_sequence_to_model(my_seq, model_metadata)
+
+The :attr:`~sqlalchemy.schema.MetaData.info`
+dictionary is a good place to put things that we want our autogeneration
+routines to be able to locate, which can include any object such as
+custom DDL objects representing views, triggers, special constraints,
+or anything else we want to support.
+
- # note: this is a continuation of the example from the
- # "Operation Plugins" section
+Registering a Comparison Function
+---------------------------------
+
+We now need to register a comparison hook, which will be used
+to compare the database to our model and produce ``CreateSequenceOp``
+and ``DropSequenceOp`` directives to be included in our migration
+script. Note that we are assuming a
+Postgresql backend::
+
+ from alembic.autogenerate import comparators
+
+ @comparators.dispatch_for("schema")
+ def compare_sequences(autogen_context, upgrade_ops, schemas):
+ all_conn_sequences = set()
+
+ for sch in schemas:
+
+ all_conn_sequences.update([
+ (sch, row[0]) for row in
+ autogen_context.connection.execute(
+ "SELECT relname FROM pg_class c join "
+ "pg_namespace n on n.oid=c.relnamespace where "
+ "relkind='S' and n.nspname=%(nspname)s",
+
+ # note that we consider a schema of 'None' in our
+ # model to be the "default" name in the PG database;
+ # this usually is the name 'public'
+ nspname=autogen_context.dialect.default_schema_name
+ if sch is None else sch
+ )
+ ])
+
+ # get the collection of Sequence objects we're storing with
+ # our MetaData
+ metadata_sequences = autogen_context.metadata.info.setdefault(
+ "sequences", set())
+
+ # for new names, produce CreateSequenceOp directives
+ for sch, name in metadata_sequences.difference(all_conn_sequences):
+ upgrade_ops.ops.append(
+ CreateSequenceOp(name, schema=sch)
+ )
+
+ # for names that are going away, produce DropSequenceOp
+ # directives
+ for sch, name in all_conn_sequences.difference(metadata_sequences):
+ upgrade_ops.ops.append(
+ DropSequenceOp(name, schema=sch)
+ )
+
+Above, we've built a new function ``compare_sequences()`` and registered
+it as a "schema" level comparison function with autogenerate. The
+job that it performs is that it compares the list of sequence names
+present in each database schema with that of a list of sequence names
+that we are maintaining in our :class:`~sqlalchemy.schema.MetaData` object.
+
+When autogenerate completes, it will have a series of
+``CreateSequenceOp`` and ``DropSequenceOp`` directives in the list of
+"upgrade" operations; the list of "downgrade" operations is generated
+directly from these using the
+``CreateSequenceOp.reverse()`` and ``DropSequenceOp.reverse()`` methods
+that we've implemented on these objects.
+
+The registration of our function at the scope of "schema" means our
+autogenerate comparison function is called outside of the context
+of any specific table or column. The three available scopes
+are "schema", "table", and "column", summarized as follows:
+
+* **Schema level** - these hooks are passed an :class:`.AutogenContext`,
+ an :class:`.UpgradeOps` collection, and a collection of string schema
+ names to be operated upon. If the
+ :class:`.UpgradeOps` collection contains changes after all
+ hooks are run, it is included in the migration script:
+
+ ::
+
+ @comparators.dispatch_for("schema")
+ def compare_schema_level(autogen_context, upgrade_ops, schemas):
+ pass
+
+* **Table level** - these hooks are passed an :class:`.AutogenContext`,
+ a :class:`.ModifyTableOps` collection, a schema name, table name,
+ a :class:`~sqlalchemy.schema.Table` reflected from the database if any
+ or ``None``, and a :class:`~sqlalchemy.schema.Table` present in the
+ local :class:`~sqlalchemy.schema.MetaData`. If the
+ :class:`.ModifyTableOps` collection contains changes after all
+ hooks are run, it is included in the migration script:
+
+ ::
+
+ @comparators.dispatch_for("table")
+ def compare_table_level(autogen_context, modify_ops,
+ schemaname, tablename, conn_table, metadata_table):
+ pass
+
+* **Column level** - these hooks are passed an :class:`.AutogenContext`,
+ an :class:`.AlterColumnOp` object, a schema name, table name,
+ column name, a :class:`~sqlalchemy.schema.Column` reflected from the
+ database and a :class:`~sqlalchemy.schema.Column` present in the
+ local table. If the :class:`.AlterColumnOp` contains changes after
+ all hooks are run, it is included in the migration script;
+ a "change" is considered to be present if any of the ``modify_`` attributes
+ are set to a non-default value, or there are any keys
+ in the ``.kw`` collection with the prefix ``"modify_"``:
+
+ ::
+
+ @comparators.dispatch_for("column")
+ def compare_column_level(autogen_context, alter_column_op,
+ schemaname, tname, cname, conn_col, metadata_col):
+ pass
+
+The :class:`.AutogenContext` passed to these hooks is documented below.
+
+.. autoclass:: alembic.autogenerate.api.AutogenContext
+ :members:
+
+Creating a Render Function
+--------------------------
+
+The second autogenerate integration hook is to provide a "render" function;
+since the autogenerate
+system renders Python code, we need to build a function that renders
+the correct "op" instructions for our directive::
from alembic.autogenerate import renderers
@@ -209,27 +389,52 @@ renderer for it, as follows::
def render_create_sequence(autogen_context, op):
return "op.create_sequence(%r, **%r)" % (
op.sequence_name,
- op.kw
+ {"schema": op.schema}
)
-With our render function established, we can our ``CreateSequenceOp``
-generated in an autogenerate context using the :func:`.render_python_code`
-debugging function in conjunction with an :class:`.UpgradeOps` structure::
- from alembic.operations import ops
- from alembic.autogenerate import render_python_code
+ @renderers.dispatch_for(DropSequenceOp)
+ def render_drop_sequence(autogen_context, op):
+ return "op.drop_sequence(%r, **%r)" % (
+ op.sequence_name,
+ {"schema": op.schema}
+ )
- upgrade_ops = ops.UpgradeOps(
- ops=[
- CreateSequenceOp("my_seq")
- ]
- )
+The above functions will render Python code corresponding to the
+presence of ``CreateSequenceOp`` and ``DropSequenceOp`` instructions
+in the list that our comparison function generates.
- print(render_python_code(upgrade_ops))
+Running It
+----------
-Which produces::
+All the above code can be organized however the developer sees fit;
+the only thing that needs to make it work is that when the
+Alembic environment ``env.py`` is invoked, it either imports modules
+which contain all the above routines, or they are locally present,
+or some combination thereof.
- ### commands auto generated by Alembic - please adjust! ###
- op.create_sequence('my_seq', **{})
+If we then have code in our model (which of course also needs to be invoked
+when ``env.py`` runs!) like this::
+
+ from sqlalchemy.schema import Sequence
+
+ my_seq_1 = Sequence("my_sequence_1")
+ add_sequence_to_model(my_seq_1, target_metadata)
+
+When we first run ``alembic revision --autogenerate``, we'll see this
+in our migration file::
+
+ def upgrade():
+ ### commands auto generated by Alembic - please adjust! ###
+ op.create_sequence('my_sequence_1', **{'schema': None})
### end Alembic commands ###
+
+ def downgrade():
+ ### commands auto generated by Alembic - please adjust! ###
+ op.drop_sequence('my_sequence_1', **{'schema': None})
+ ### end Alembic commands ###
+
+These are our custom directives that will be invoked when ``alembic upgrade``
+or ``alembic downgrade`` is run.
+
diff --git a/docs/build/api/operations.rst b/docs/build/api/operations.rst
index d9ff238..2eb8358 100644
--- a/docs/build/api/operations.rst
+++ b/docs/build/api/operations.rst
@@ -1,7 +1,7 @@
.. _alembic.operations.toplevel:
=====================
-The Operations Object
+Operation Directives
=====================
Within migration scripts, actual database migration operations are handled
@@ -48,9 +48,9 @@ migration scripts::
class CreateSequenceOp(MigrateOperation):
"""Create a SEQUENCE."""
- def __init__(self, sequence_name, **kw):
+ def __init__(self, sequence_name, schema=None):
self.sequence_name = sequence_name
- self.kw = kw
+ self.schema = schema
@classmethod
def create_sequence(cls, operations, sequence_name, **kw):
@@ -59,20 +59,58 @@ migration scripts::
op = CreateSequenceOp(sequence_name, **kw)
return operations.invoke(op)
-Above, the ``CreateSequenceOp`` class represents a new operation that will
-be available as ``op.create_sequence()``. The reason the operation
-is represented as a stateful class is so that an operation and a specific
+ def reverse(self):
+ # only needed to support autogenerate
+ return DropSequenceOp(self.sequence_name, schema=self.schema)
+
+ @Operations.register_operation("drop_sequence")
+ class DropSequenceOp(MigrateOperation):
+ """Drop a SEQUENCE."""
+
+ def __init__(self, sequence_name, schema=None):
+ self.sequence_name = sequence_name
+ self.schema = schema
+
+ @classmethod
+ def drop_sequence(cls, operations, sequence_name, **kw):
+ """Issue a "DROP SEQUENCE" instruction."""
+
+ op = DropSequenceOp(sequence_name, **kw)
+ return operations.invoke(op)
+
+ def reverse(self):
+ # only needed to support autogenerate
+ return CreateSequenceOp(self.sequence_name, schema=self.schema)
+
+Above, the ``CreateSequenceOp`` and ``DropSequenceOp`` classes represent
+new operations that will
+be available as ``op.create_sequence()`` and ``op.drop_sequence()``.
+The reason the operations
+are represented as stateful classes is so that an operation and a specific
set of arguments can be represented generically; the state can then correspond
to different kinds of operations, such as invoking the instruction against
a database, or autogenerating Python code for the operation into a
script.
-In order to establish the migrate-script behavior of the new operation,
+In order to establish the migrate-script behavior of the new operations,
we use the :meth:`.Operations.implementation_for` decorator::
@Operations.implementation_for(CreateSequenceOp)
def create_sequence(operations, operation):
- operations.execute("CREATE SEQUENCE %s" % operation.sequence_name)
+ if operation.schema is not None:
+ name = "%s.%s" % (operation.schema, operation.sequence_name)
+ else:
+ name = operation.sequence_name
+ operations.execute("CREATE SEQUENCE %s" % name)
+
+
+ @Operations.implementation_for(DropSequenceOp)
+ def drop_sequence(operations, operation):
+ if operation.schema is not None:
+ name = "%s.%s" % (operation.schema, operation.sequence_name)
+ else:
+ name = operation.sequence_name
+ operations.execute("DROP SEQUENCE %s" % name)
Above, we use the simplest possible technique of invoking our DDL, which
is just to call :meth:`.Operations.execute` with literal SQL. If this is
@@ -80,16 +118,24 @@ all a custom operation needs, then this is fine. However, options for
more comprehensive support include building out a custom SQL construct,
as documented at :ref:`sqlalchemy.ext.compiler_toplevel`.
-With the above two steps, a migration script can now use a new method
-``op.create_sequence()`` that will proxy to our object as a classmethod::
+With the above two steps, a migration script can now use new methods
+``op.create_sequence()`` and ``op.drop_sequence()`` that will proxy to
+our objects as classmethods::
def upgrade():
op.create_sequence("my_sequence")
+ def downgrade():
+ op.drop_sequence("my_sequence")
+
The registration of new operations only needs to occur in time for the
``env.py`` script to invoke :meth:`.MigrationContext.run_migrations`;
within the module level of the ``env.py`` script is sufficient.
+.. seealso::
+
+ :ref:`autogen_custom_ops` - how to add autogenerate support to
+ custom operations.
.. versionadded:: 0.8 - the migration operations available via the
:class:`.Operations` class as well as the ``alembic.op`` namespace
diff --git a/docs/build/changelog.rst b/docs/build/changelog.rst
index 691402d..424bf8f 100644
--- a/docs/build/changelog.rst
+++ b/docs/build/changelog.rst
@@ -27,7 +27,7 @@ Changelog
.. change::
:tags: feature, autogenerate
- :tickets: 301
+ :tickets: 301, 306
The internal system for autogenerate been reworked to build upon
the extensible system of operation objects present in
@@ -38,9 +38,12 @@ Changelog
:paramref:`.EnvironmentContext.configure.process_revision_directives`
allows end-user code to fully customize what autogenerate will do,
including not just full manipulation of the Python steps to take
- but also what file or files will be written and where. It is also
- possible to write a system that reads an autogenerate stream and
- invokes it directly against a database without writing any files.
+ but also what file or files will be written and where. Additionally,
+ autogenerate is now extensible as far as database objects compared
+ and rendered into scripts; any new operation directive can also be
+ registered into a series of hooks that allow custom database/model
+ comparison functions to run as well as to render new operation
+ directives into autogenerate scripts.
.. seealso::
diff --git a/tests/_autogen_fixtures.py b/tests/_autogen_fixtures.py
index 7ef6cbf..e668885 100644
--- a/tests/_autogen_fixtures.py
+++ b/tests/_autogen_fixtures.py
@@ -2,12 +2,14 @@ from sqlalchemy import MetaData, Column, Table, Integer, String, Text, \
Numeric, CHAR, ForeignKey, Index, UniqueConstraint, CheckConstraint, text
from sqlalchemy.engine.reflection import Inspector
+from alembic.operations import ops
from alembic import autogenerate
from alembic.migration import MigrationContext
from alembic.testing import config
from alembic.testing.env import staging_env, clear_staging_env
from alembic.testing import eq_
from alembic.ddl.base import _fk_spec
+from alembic.autogenerate import api
names_in_this_test = set()
@@ -25,9 +27,7 @@ def _default_include_object(obj, name, type_, reflected, compare_to):
else:
return True
-_default_object_filters = [
- _default_include_object
-]
+_default_object_filters = _default_include_object
class ModelOne(object):
@@ -177,6 +177,7 @@ class AutogenTest(_ComparesFKs):
'downgrade_token': "downgrades",
'alembic_module_prefix': 'op.',
'sqlalchemy_module_prefix': 'sa.',
+ 'include_object': _default_object_filters
}
if self.configure_opts:
ctx_opts.update(self.configure_opts)
@@ -185,17 +186,18 @@ class AutogenTest(_ComparesFKs):
opts=ctx_opts
)
- connection = context.bind
- self.autogen_context = {
- 'imports': set(),
- 'connection': connection,
- 'dialect': connection.dialect,
- 'context': context
- }
+ self.autogen_context = api.AutogenContext(context, self.m2)
def tearDown(self):
self.conn.close()
+ def _update_context(self, object_filters=None, include_schemas=None):
+ if include_schemas is not None:
+ self.autogen_context.opts['include_schemas'] = include_schemas
+ if object_filters is not None:
+ self.autogen_context._object_filters = [object_filters]
+ return self.autogen_context
+
class AutogenFixtureTest(_ComparesFKs):
@@ -214,6 +216,8 @@ class AutogenFixtureTest(_ComparesFKs):
'downgrade_token': "downgrades",
'alembic_module_prefix': 'op.',
'sqlalchemy_module_prefix': 'sa.',
+ 'include_object': object_filters,
+ 'include_schemas': include_schemas
}
if opts:
ctx_opts.update(opts)
@@ -222,21 +226,12 @@ class AutogenFixtureTest(_ComparesFKs):
opts=ctx_opts
)
- connection = context.bind
- autogen_context = {
- 'imports': set(),
- 'connection': connection,
- 'dialect': connection.dialect,
- 'context': context,
- 'metadata': model_metadata,
- 'object_filters': object_filters,
- 'include_schemas': include_schemas
- }
- diffs = []
+ autogen_context = api.AutogenContext(context, model_metadata)
+ uo = ops.UpgradeOps(ops=[])
autogenerate._produce_net_changes(
- autogen_context, diffs
+ autogen_context, uo
)
- return diffs
+ return uo.as_diffs()
reports_unnamed_constraints = False
diff --git a/tests/test_autogen_composition.py b/tests/test_autogen_composition.py
index b1717ab..6d1f55b 100644
--- a/tests/test_autogen_composition.py
+++ b/tests/test_autogen_composition.py
@@ -23,7 +23,7 @@ class AutogenerateDiffTest(ModelOne, AutogenTest, TestBase):
}
)
template_args = {}
- autogenerate._render_migration_diffs(context, template_args, set())
+ autogenerate._render_migration_diffs(context, template_args)
eq_(re.sub(r"u'", "'", template_args['upgrades']),
"""### commands auto generated by Alembic - please adjust! ###
@@ -50,10 +50,8 @@ class AutogenerateDiffTest(ModelOne, AutogenTest, TestBase):
}
)
template_args = {}
- autogenerate._render_migration_diffs(
- context, template_args, set(),
+ autogenerate._render_migration_diffs(context, template_args)
- )
eq_(re.sub(r"u'", "'", template_args['upgrades']),
"""### commands auto generated by Alembic - please adjust! ###
pass
@@ -67,8 +65,7 @@ class AutogenerateDiffTest(ModelOne, AutogenTest, TestBase):
"""test a full render including indentation"""
template_args = {}
- autogenerate._render_migration_diffs(
- self.context, template_args, set())
+ autogenerate._render_migration_diffs(self.context, template_args)
eq_(re.sub(r"u'", "'", template_args['upgrades']),
"""### commands auto generated by Alembic - please adjust! ###
op.create_table('item',
@@ -135,8 +132,7 @@ nullable=True))
template_args = {}
self.context.opts['render_as_batch'] = True
- autogenerate._render_migration_diffs(
- self.context, template_args, set())
+ autogenerate._render_migration_diffs(self.context, template_args)
eq_(re.sub(r"u'", "'", template_args['upgrades']),
"""### commands auto generated by Alembic - please adjust! ###
@@ -229,10 +225,8 @@ class AutogenerateDiffTestWSchema(ModelOne, AutogenTest, TestBase):
}
)
template_args = {}
- autogenerate._render_migration_diffs(
- context, template_args, set(),
+ autogenerate._render_migration_diffs(context, template_args)
- )
eq_(re.sub(r"u'", "'", template_args['upgrades']),
"""### commands auto generated by Alembic - please adjust! ###
pass
@@ -250,9 +244,7 @@ class AutogenerateDiffTestWSchema(ModelOne, AutogenTest, TestBase):
'include_object': _default_include_object,
'include_schemas': True
})
- autogenerate._render_migration_diffs(
- self.context, template_args, set()
- )
+ autogenerate._render_migration_diffs(self.context, template_args)
eq_(re.sub(r"u'", "'", template_args['upgrades']),
"""### commands auto generated by Alembic - please adjust! ###
@@ -326,3 +318,4 @@ name='extra_uid_fkey'),
)
op.drop_table('item', schema='%(schema)s')
### end Alembic commands ###""" % {"schema": self.schema})
+
diff --git a/tests/test_autogen_diffs.py b/tests/test_autogen_diffs.py
index f32fd84..d176b91 100644
--- a/tests/test_autogen_diffs.py
+++ b/tests/test_autogen_diffs.py
@@ -6,6 +6,7 @@ from sqlalchemy import MetaData, Column, Table, Integer, String, Text, \
from sqlalchemy.types import NULLTYPE
from sqlalchemy.engine.reflection import Inspector
+from alembic.operations import ops
from alembic import autogenerate
from alembic.migration import MigrationContext
from alembic.testing import TestBase
@@ -14,8 +15,7 @@ from alembic.testing import assert_raises_message
from alembic.testing.mock import Mock
from alembic.testing import eq_
from alembic.util import CommandError
-from ._autogen_fixtures import \
- AutogenTest, AutogenFixtureTest, _default_object_filters
+from ._autogen_fixtures import AutogenTest, AutogenFixtureTest
py3k = sys.version_info >= (3, )
@@ -63,25 +63,24 @@ class AutogenCrossSchemaTest(AutogenTest, TestBase):
return m
def test_default_schema_omitted_upgrade(self):
- diffs = []
def include_object(obj, name, type_, reflected, compare_to):
if type_ == "table":
return name == "t3"
else:
return True
- self.autogen_context.update({
- 'object_filters': [include_object],
- 'include_schemas': True,
- 'metadata': self.m2
- })
- autogenerate._produce_net_changes(self.autogen_context, diffs)
+ self._update_context(
+ object_filters=include_object,
+ include_schemas=True,
+ )
+ uo = ops.UpgradeOps(ops=[])
+ autogenerate._produce_net_changes(self.autogen_context, uo)
+ diffs = uo.as_diffs()
eq_(diffs[0][0], "add_table")
eq_(diffs[0][1].schema, None)
def test_alt_schema_included_upgrade(self):
- diffs = []
def include_object(obj, name, type_, reflected, compare_to):
if type_ == "table":
@@ -89,48 +88,48 @@ class AutogenCrossSchemaTest(AutogenTest, TestBase):
else:
return True
- self.autogen_context.update({
- 'object_filters': [include_object],
- 'include_schemas': True,
- 'metadata': self.m2
- })
- autogenerate._produce_net_changes(self.autogen_context, diffs)
+ self._update_context(
+ object_filters=include_object,
+ include_schemas=True,
+ )
+ uo = ops.UpgradeOps(ops=[])
+ autogenerate._produce_net_changes(self.autogen_context, uo)
+ diffs = uo.as_diffs()
eq_(diffs[0][0], "add_table")
eq_(diffs[0][1].schema, config.test_schema)
def test_default_schema_omitted_downgrade(self):
- diffs = []
-
def include_object(obj, name, type_, reflected, compare_to):
if type_ == "table":
return name == "t1"
else:
return True
- self.autogen_context.update({
- 'object_filters': [include_object],
- 'include_schemas': True,
- 'metadata': self.m2
- })
- autogenerate._produce_net_changes(self.autogen_context, diffs)
+ self._update_context(
+ object_filters=include_object,
+ include_schemas=True,
+ )
+ uo = ops.UpgradeOps(ops=[])
+ autogenerate._produce_net_changes(self.autogen_context, uo)
+ diffs = uo.as_diffs()
eq_(diffs[0][0], "remove_table")
eq_(diffs[0][1].schema, None)
def test_alt_schema_included_downgrade(self):
- diffs = []
def include_object(obj, name, type_, reflected, compare_to):
if type_ == "table":
return name == "t2"
else:
return True
- self.autogen_context.update({
- 'object_filters': [include_object],
- 'include_schemas': True,
- 'metadata': self.m2
- })
- autogenerate._produce_net_changes(self.autogen_context, diffs)
+ self._update_context(
+ object_filters=include_object,
+ include_schemas=True,
+ )
+ uo = ops.UpgradeOps(ops=[])
+ autogenerate._produce_net_changes(self.autogen_context, uo)
+ diffs = uo.as_diffs()
eq_(diffs[0][0], "remove_table")
eq_(diffs[0][1].schema, config.test_schema)
@@ -268,14 +267,14 @@ class AutogenerateDiffTest(ModelOne, AutogenTest, TestBase):
"""test generation of diff rules"""
metadata = self.m2
- diffs = []
- ctx = self.autogen_context.copy()
- ctx['metadata'] = self.m2
- ctx['object_filters'] = _default_object_filters
+ uo = ops.UpgradeOps(ops=[])
+ ctx = self.autogen_context
+
autogenerate._produce_net_changes(
- ctx, diffs
+ ctx, uo
)
+ diffs = uo.as_diffs()
eq_(
diffs[0],
('add_table', metadata.tables['item'])
@@ -396,21 +395,25 @@ class AutogenerateDiffTest(ModelOne, AutogenTest, TestBase):
eq_(alter_cols, set(['user_id', 'order', 'user']))
def test_skip_null_type_comparison_reflected(self):
- diff = []
- autogenerate.compare._compare_type(None, "sometable", "somecol",
- Column("somecol", NULLTYPE),
- Column("somecol", Integer()),
- diff, self.autogen_context
- )
+ ac = ops.AlterColumnOp("sometable", "somecol")
+ autogenerate.compare._compare_type(
+ self.autogen_context, ac,
+ None, "sometable", "somecol",
+ Column("somecol", NULLTYPE),
+ Column("somecol", Integer()),
+ )
+ diff = ac.to_diff_tuple()
assert not diff
def test_skip_null_type_comparison_local(self):
- diff = []
- autogenerate.compare._compare_type(None, "sometable", "somecol",
- Column("somecol", Integer()),
- Column("somecol", NULLTYPE),
- diff, self.autogen_context
- )
+ ac = ops.AlterColumnOp("sometable", "somecol")
+ autogenerate.compare._compare_type(
+ self.autogen_context, ac,
+ None, "sometable", "somecol",
+ Column("somecol", Integer()),
+ Column("somecol", NULLTYPE),
+ )
+ diff = ac.to_diff_tuple()
assert not diff
def test_custom_type_compare(self):
@@ -420,20 +423,24 @@ class AutogenerateDiffTest(ModelOne, AutogenTest, TestBase):
def compare_against_backend(self, dialect, conn_type):
return isinstance(conn_type, Integer)
- diff = []
- autogenerate.compare._compare_type(None, "sometable", "somecol",
- Column("somecol", INTEGER()),
- Column("somecol", MyType()),
- diff, self.autogen_context
- )
- assert not diff
+ ac = ops.AlterColumnOp("sometable", "somecol")
+ autogenerate.compare._compare_type(
+ self.autogen_context, ac,
+ None, "sometable", "somecol",
+ Column("somecol", INTEGER()),
+ Column("somecol", MyType()),
+ )
+
+ assert not ac.has_changes()
- diff = []
- autogenerate.compare._compare_type(None, "sometable", "somecol",
- Column("somecol", String()),
- Column("somecol", MyType()),
- diff, self.autogen_context
- )
+ ac = ops.AlterColumnOp("sometable", "somecol")
+ autogenerate.compare._compare_type(
+ self.autogen_context, ac,
+ None, "sometable", "somecol",
+ Column("somecol", String()),
+ Column("somecol", MyType()),
+ )
+ diff = ac.to_diff_tuple()
eq_(
diff[0][0:4],
('modify_type', None, 'sometable', 'somecol')
@@ -449,26 +456,26 @@ class AutogenerateDiffTest(ModelOne, AutogenTest, TestBase):
else:
return dialect.type_descriptor(CHAR(32))
- diff = []
+ uo = ops.AlterColumnOp('sometable', 'somecol')
autogenerate.compare._compare_type(
+ self.autogen_context, uo,
None, "sometable", "somecol",
Column("somecol", Integer, nullable=True),
- Column("somecol", MyType()),
- diff, self.autogen_context
+ Column("somecol", MyType())
)
- assert not diff
+ assert not uo.has_changes()
def test_dont_barf_on_already_reflected(self):
- diffs = []
from sqlalchemy.util import OrderedSet
inspector = Inspector.from_engine(self.bind)
+ uo = ops.UpgradeOps(ops=[])
autogenerate.compare._compare_tables(
OrderedSet([(None, 'extra'), (None, 'user')]),
- OrderedSet(), [], inspector,
- MetaData(), diffs, self.autogen_context
+ OrderedSet(), inspector,
+ MetaData(), uo, self.autogen_context
)
eq_(
- [(rec[0], rec[1].name) for rec in diffs],
+ [(rec[0], rec[1].name) for rec in uo.as_diffs()],
[('remove_table', 'extra'), ('remove_table', 'user')]
)
@@ -481,14 +488,14 @@ class AutogenerateDiffTestWSchema(ModelOne, AutogenTest, TestBase):
"""test generation of diff rules"""
metadata = self.m2
- diffs = []
- self.autogen_context.update({
- 'object_filters': _default_object_filters,
- 'include_schemas': True,
- 'metadata': self.m2
- })
- autogenerate._produce_net_changes(self.autogen_context, diffs)
+ self._update_context(
+ include_schemas=True,
+ )
+ uo = ops.UpgradeOps(ops=[])
+ autogenerate._produce_net_changes(self.autogen_context, uo)
+
+ diffs = uo.as_diffs()
eq_(
diffs[0],
@@ -567,10 +574,10 @@ class AutogenerateCustomCompareTypeTest(AutogenTest, TestBase):
my_compare_type = Mock()
self.context._user_compare_type = my_compare_type
- diffs = []
- ctx = self.autogen_context.copy()
- ctx['metadata'] = self.m2
- autogenerate._produce_net_changes(ctx, diffs)
+ uo = ops.UpgradeOps(ops=[])
+
+ ctx = self.autogen_context
+ autogenerate._produce_net_changes(ctx, uo)
first_table = self.m2.tables['sometable']
first_column = first_table.columns['id']
@@ -593,8 +600,7 @@ class AutogenerateCustomCompareTypeTest(AutogenTest, TestBase):
self.context._user_compare_type = my_compare_type
diffs = []
- ctx = self.autogen_context.copy()
- ctx['metadata'] = self.m2
+ ctx = self.autogen_context
diffs = []
autogenerate._produce_net_changes(ctx, diffs)
@@ -605,10 +611,10 @@ class AutogenerateCustomCompareTypeTest(AutogenTest, TestBase):
my_compare_type.return_value = True
self.context._user_compare_type = my_compare_type
- ctx = self.autogen_context.copy()
- ctx['metadata'] = self.m2
- diffs = []
- autogenerate._produce_net_changes(ctx, diffs)
+ ctx = self.autogen_context
+ uo = ops.UpgradeOps(ops=[])
+ autogenerate._produce_net_changes(ctx, uo)
+ diffs = uo.as_diffs()
eq_(diffs[0][0][0], 'modify_type')
eq_(diffs[1][0][0], 'modify_type')
@@ -636,8 +642,7 @@ class PKConstraintUpgradesIgnoresNullableTest(AutogenTest, TestBase):
def test_no_change(self):
diffs = []
- ctx = self.autogen_context.copy()
- ctx['metadata'] = self.m2
+ ctx = self.autogen_context
autogenerate._produce_net_changes(ctx, diffs)
eq_(diffs, [])
@@ -674,11 +679,11 @@ class AutogenKeyTest(AutogenTest, TestBase):
def test_autogen(self):
- diffs = []
+ uo = ops.UpgradeOps(ops=[])
- ctx = self.autogen_context.copy()
- ctx['metadata'] = self.m2
- autogenerate._produce_net_changes(ctx, diffs)
+ ctx = self.autogen_context
+ autogenerate._produce_net_changes(ctx, uo)
+ diffs = uo.as_diffs()
eq_(diffs[0][0], "add_table")
eq_(diffs[0][1].name, "sometable")
eq_(diffs[1][0], "add_column")
@@ -705,8 +710,7 @@ class AutogenVersionTableTest(AutogenTest, TestBase):
def test_no_version_table(self):
diffs = []
- ctx = self.autogen_context.copy()
- ctx['metadata'] = self.m2
+ ctx = self.autogen_context
autogenerate._produce_net_changes(ctx, diffs)
eq_(diffs, [])
@@ -717,8 +721,7 @@ class AutogenVersionTableTest(AutogenTest, TestBase):
self.version_table_name,
self.m2, Column('x', Integer), schema=self.version_table_schema)
- ctx = self.autogen_context.copy()
- ctx['metadata'] = self.m2
+ ctx = self.autogen_context
autogenerate._produce_net_changes(ctx, diffs)
eq_(diffs, [])
@@ -769,10 +772,10 @@ class AutogenerateDiffOrderTest(AutogenTest, TestBase):
before their parent tables
"""
- ctx = self.autogen_context.copy()
- ctx['metadata'] = self.m2
- diffs = []
- autogenerate._produce_net_changes(ctx, diffs)
+ ctx = self.autogen_context
+ uo = ops.UpgradeOps(ops=[])
+ autogenerate._produce_net_changes(ctx, uo)
+ diffs = uo.as_diffs()
eq_(diffs[0][0], 'add_table')
eq_(diffs[0][1].name, "parent")
diff --git a/tests/test_autogen_fks.py b/tests/test_autogen_fks.py
index 525bed5..174a538 100644
--- a/tests/test_autogen_fks.py
+++ b/tests/test_autogen_fks.py
@@ -351,7 +351,7 @@ class IncludeHooksTest(AutogenFixtureTest, TestBase):
type_ == 'foreign_key_constraint'
and reflected and name == 'fk1')
- diffs = self._fixture(m1, m2, object_filters=[include_object])
+ diffs = self._fixture(m1, m2, object_filters=include_object)
self._assert_fk_diff(
diffs[0], "remove_fk",
@@ -390,7 +390,7 @@ class IncludeHooksTest(AutogenFixtureTest, TestBase):
type_ == 'foreign_key_constraint'
and not reflected and name == 'fk1')
- diffs = self._fixture(m1, m2, object_filters=[include_object])
+ diffs = self._fixture(m1, m2, object_filters=include_object)
self._assert_fk_diff(
diffs[0], "add_fk",
@@ -456,7 +456,7 @@ class IncludeHooksTest(AutogenFixtureTest, TestBase):
and name == 'fk1'
)
- diffs = self._fixture(m1, m2, object_filters=[include_object])
+ diffs = self._fixture(m1, m2, object_filters=include_object)
self._assert_fk_diff(
diffs[0], "remove_fk",
diff --git a/tests/test_autogen_indexes.py b/tests/test_autogen_indexes.py
index 8ee33bc..9b6cd44 100644
--- a/tests/test_autogen_indexes.py
+++ b/tests/test_autogen_indexes.py
@@ -798,7 +798,7 @@ class IncludeHooksTest(AutogenFixtureTest, TestBase):
isinstance(object_, Index) and
type_ == 'index' and reflected and name == 'ix1')
- diffs = self._fixture(m1, m2, object_filters=[include_object])
+ diffs = self._fixture(m1, m2, object_filters=include_object)
eq_(diffs[0][0], 'remove_index')
eq_(diffs[0][1].name, 'ix2')
@@ -825,7 +825,7 @@ class IncludeHooksTest(AutogenFixtureTest, TestBase):
isinstance(object_, UniqueConstraint) and
type_ == 'unique_constraint' and reflected and name == 'uq1')
- diffs = self._fixture(m1, m2, object_filters=[include_object])
+ diffs = self._fixture(m1, m2, object_filters=include_object)
eq_(diffs[0][0], 'remove_constraint')
eq_(diffs[0][1].name, 'uq2')
@@ -846,7 +846,7 @@ class IncludeHooksTest(AutogenFixtureTest, TestBase):
isinstance(object_, Index) and
type_ == 'index' and not reflected and name == 'ix1')
- diffs = self._fixture(m1, m2, object_filters=[include_object])
+ diffs = self._fixture(m1, m2, object_filters=include_object)
eq_(diffs[0][0], 'add_index')
eq_(diffs[0][1].name, 'ix2')
@@ -871,7 +871,7 @@ class IncludeHooksTest(AutogenFixtureTest, TestBase):
type_ == 'unique_constraint' and
not reflected and name == 'uq1')
- diffs = self._fixture(m1, m2, object_filters=[include_object])
+ diffs = self._fixture(m1, m2, object_filters=include_object)
eq_(diffs[0][0], 'add_constraint')
eq_(diffs[0][1].name, 'uq2')
@@ -899,7 +899,7 @@ class IncludeHooksTest(AutogenFixtureTest, TestBase):
type_ == 'index' and not reflected and name == 'ix1'
and isinstance(compare_to, Index))
- diffs = self._fixture(m1, m2, object_filters=[include_object])
+ diffs = self._fixture(m1, m2, object_filters=include_object)
eq_(diffs[0][0], 'remove_index')
eq_(diffs[0][1].name, 'ix2')
@@ -935,7 +935,7 @@ class IncludeHooksTest(AutogenFixtureTest, TestBase):
not reflected and name == 'uq1'
and isinstance(compare_to, UniqueConstraint))
- diffs = self._fixture(m1, m2, object_filters=[include_object])
+ diffs = self._fixture(m1, m2, object_filters=include_object)
eq_(diffs[0][0], 'remove_constraint')
eq_(diffs[0][1].name, 'uq2')
diff --git a/tests/test_autogen_render.py b/tests/test_autogen_render.py
index 4a49d5c..a73cff5 100644
--- a/tests/test_autogen_render.py
+++ b/tests/test_autogen_render.py
@@ -14,6 +14,8 @@ from sqlalchemy.types import UserDefinedType
from sqlalchemy.dialects import mysql, postgresql
from sqlalchemy.engine.default import DefaultDialect
from sqlalchemy.sql import and_, column, literal_column, false
+from alembic.migration import MigrationContext
+from alembic.autogenerate import api
from alembic.testing.mock import patch
@@ -32,22 +34,30 @@ class AutogenRenderTest(TestBase):
"""test individual directives"""
- @classmethod
- def setup_class(cls):
- cls.autogen_context = {
- 'opts': {
- 'sqlalchemy_module_prefix': 'sa.',
- 'alembic_module_prefix': 'op.',
- },
- 'dialect': mysql.dialect()
- }
- cls.pg_autogen_context = {
- 'opts': {
- 'sqlalchemy_module_prefix': 'sa.',
- 'alembic_module_prefix': 'op.',
- },
- 'dialect': postgresql.dialect()
+ def setUp(self):
+ ctx_opts = {
+ 'sqlalchemy_module_prefix': 'sa.',
+ 'alembic_module_prefix': 'op.',
+ 'target_metadata': MetaData()
}
+ context = MigrationContext.configure(
+ dialect_name="mysql",
+ opts=ctx_opts
+ )
+
+ self.autogen_context = api.AutogenContext(context)
+
+ context = MigrationContext.configure(
+ dialect_name="postgresql",
+ opts=ctx_opts
+ )
+ self.pg_autogen_context = api.AutogenContext(context)
+
+ context = MigrationContext.configure(
+ dialect=DefaultDialect(),
+ opts=ctx_opts
+ )
+ self.default_autogen_context = api.AutogenContext(context)
def test_render_add_index(self):
"""
@@ -812,10 +822,10 @@ unique=False, """
return "col(%s)" % obj.name
return "render:%s" % type_
- autogen_context = {"opts": {
- 'render_item': render,
- 'alembic_module_prefix': 'sa.'
- }}
+ self.autogen_context.opts.update(
+ render_item=render,
+ alembic_module_prefix='sa.'
+ )
t = Table('t', MetaData(),
Column('x', Integer),
@@ -824,7 +834,7 @@ unique=False, """
ForeignKeyConstraint(['x'], ['y'])
)
op_obj = ops.CreateTableOp.from_table(t)
- result = autogenerate.render_op_text(autogen_context, op_obj)
+ result = autogenerate.render_op_text(self.autogen_context, op_obj)
eq_ignore_whitespace(
result,
"sa.create_table('t',"
@@ -1087,28 +1097,13 @@ unique=False, """
def test_repr_plain_sqla_type(self):
type_ = Integer()
- autogen_context = {
- 'opts': {
- 'sqlalchemy_module_prefix': 'sa.',
- 'alembic_module_prefix': 'op.',
- },
- 'dialect': mysql.dialect()
- }
-
eq_ignore_whitespace(
- autogenerate.render._repr_type(type_, autogen_context),
+ autogenerate.render._repr_type(type_, self.autogen_context),
"sa.Integer()"
)
def test_repr_custom_type_w_sqla_prefix(self):
- autogen_context = {
- 'opts': {
- 'sqlalchemy_module_prefix': 'sa.',
- 'alembic_module_prefix': 'op.',
- 'user_module_prefix': None
- },
- 'dialect': mysql.dialect()
- }
+ self.autogen_context.opts['user_module_prefix'] = None
class MyType(UserDefinedType):
pass
@@ -1118,7 +1113,7 @@ unique=False, """
type_ = MyType()
eq_ignore_whitespace(
- autogenerate.render._repr_type(type_, autogen_context),
+ autogenerate.render._repr_type(type_, self.autogen_context),
"sqlalchemy_util.types.MyType()"
)
@@ -1129,17 +1124,10 @@ unique=False, """
return "MYTYPE"
type_ = MyType()
- autogen_context = {
- 'opts': {
- 'sqlalchemy_module_prefix': 'sa.',
- 'alembic_module_prefix': 'op.',
- 'user_module_prefix': None
- },
- 'dialect': mysql.dialect()
- }
+ self.autogen_context.opts['user_module_prefix'] = None
eq_ignore_whitespace(
- autogenerate.render._repr_type(type_, autogen_context),
+ autogenerate.render._repr_type(type_, self.autogen_context),
"tests.test_autogen_render.MyType()"
)
@@ -1152,17 +1140,11 @@ unique=False, """
return "MYTYPE"
type_ = MyType()
- autogen_context = {
- 'opts': {
- 'sqlalchemy_module_prefix': 'sa.',
- 'alembic_module_prefix': 'op.',
- 'user_module_prefix': 'user.',
- },
- 'dialect': mysql.dialect()
- }
+
+ self.autogen_context.opts['user_module_prefix'] = 'user.'
eq_ignore_whitespace(
- autogenerate.render._repr_type(type_, autogen_context),
+ autogenerate.render._repr_type(type_, self.autogen_context),
"user.MyType()"
)
@@ -1171,20 +1153,14 @@ unique=False, """
from sqlalchemy.dialects.mysql import VARCHAR
type_ = VARCHAR(20, charset='utf8', national=True)
- autogen_context = {
- 'opts': {
- 'sqlalchemy_module_prefix': 'sa.',
- 'alembic_module_prefix': 'op.',
- 'user_module_prefix': None,
- },
- 'imports': set(),
- 'dialect': mysql.dialect()
- }
+
+ self.autogen_context.opts['user_module_prefix'] = None
+
eq_ignore_whitespace(
- autogenerate.render._repr_type(type_, autogen_context),
+ autogenerate.render._repr_type(type_, self.autogen_context),
"mysql.VARCHAR(charset='utf8', national=True, length=20)"
)
- eq_(autogen_context['imports'],
+ eq_(self.autogen_context._imports,
set(['from sqlalchemy.dialects import mysql'])
)
@@ -1204,19 +1180,12 @@ unique=False, """
)
def test_render_server_default_native_boolean(self):
- autogen_context = {
- 'opts': {
- 'sqlalchemy_module_prefix': 'sa.',
- 'alembic_module_prefix': 'op.',
- },
- 'dialect': postgresql.dialect()
- }
c = Column(
'updated_at', Boolean(),
server_default=false(),
nullable=False)
result = autogenerate.render._render_column(
- c, autogen_context,
+ c, self.autogen_context,
)
eq_ignore_whitespace(
result,
@@ -1231,17 +1200,10 @@ unique=False, """
'updated_at', Boolean(),
server_default=false(),
nullable=False)
- dialect = DefaultDialect()
- autogen_context = {
- 'opts': {
- 'sqlalchemy_module_prefix': 'sa.',
- 'alembic_module_prefix': 'op.',
- },
- 'dialect': dialect
- }
+
result = autogenerate.render._render_column(
- c, autogen_context
+ c, self.default_autogen_context
)
eq_ignore_whitespace(
result,
@@ -1296,16 +1258,6 @@ unique=False, """
class RenderNamingConventionTest(TestBase):
__requires__ = ('sqlalchemy_094',)
- @classmethod
- def setup_class(cls):
- cls.autogen_context = {
- 'opts': {
- 'sqlalchemy_module_prefix': 'sa.',
- 'alembic_module_prefix': 'op.',
- },
- 'dialect': postgresql.dialect()
- }
-
def setUp(self):
convention = {
@@ -1322,6 +1274,17 @@ class RenderNamingConventionTest(TestBase):
naming_convention=convention
)
+ ctx_opts = {
+ 'sqlalchemy_module_prefix': 'sa.',
+ 'alembic_module_prefix': 'op.',
+ 'target_metadata': MetaData()
+ }
+ context = MigrationContext.configure(
+ dialect_name="postgresql",
+ opts=ctx_opts
+ )
+ self.autogen_context = api.AutogenContext(context)
+
def test_schema_type_boolean(self):
t = Table('t', self.metadata, Column('c', Boolean(name='xyz')))
op_obj = ops.AddColumnOp.from_column(t.c.c)
@@ -1457,3 +1420,29 @@ class RenderNamingConventionTest(TestBase):
"sa.CheckConstraint(!U'im a constraint', name=op.f('ck_t_cc1'))"
)
+ def test_create_table_plus_add_index_in_modify(self):
+ uo = ops.UpgradeOps(ops=[
+ ops.CreateTableOp(
+ "sometable",
+ [Column('x', Integer), Column('y', Integer)]
+ ),
+ ops.ModifyTableOps(
+ "sometable", ops=[
+ ops.CreateIndexOp('ix1', 'sometable', ['x', 'y'])
+ ]
+ )
+ ])
+
+ eq_(
+ autogenerate.render_python_code(uo, render_as_batch=True),
+ "### commands auto generated by Alembic - please adjust! ###\n"
+ " op.create_table('sometable',\n"
+ " sa.Column('x', sa.Integer(), nullable=True),\n"
+ " sa.Column('y', sa.Integer(), nullable=True)\n"
+ " )\n"
+ " with op.batch_alter_table('sometable', schema=None) "
+ "as batch_op:\n"
+ " batch_op.create_index("
+ "'ix1', ['x', 'y'], unique=False)\n\n"
+ " ### end Alembic commands ###"
+ )
diff --git a/tests/test_postgresql.py b/tests/test_postgresql.py
index e70d05a..576d957 100644
--- a/tests/test_postgresql.py
+++ b/tests/test_postgresql.py
@@ -8,9 +8,11 @@ from sqlalchemy.sql import table, column
from alembic.autogenerate.compare import \
_compare_server_default, _compare_tables, _render_server_default_for_compare
+from alembic.operations import ops
from alembic import command, util
from alembic.migration import MigrationContext
from alembic.script import ScriptDirectory
+from alembic.autogenerate import api
from alembic.testing import eq_, provide_metadata
from alembic.testing.env import staging_env, clear_staging_env, \
@@ -162,34 +164,22 @@ class PostgresqlDefaultCompareTest(TestBase):
def setup_class(cls):
cls.bind = config.db
staging_env()
- context = MigrationContext.configure(
+ cls.migration_context = MigrationContext.configure(
connection=cls.bind.connect(),
opts={
'compare_type': True,
'compare_server_default': True
}
)
- connection = context.bind
- cls.autogen_context = {
- 'imports': set(),
- 'connection': connection,
- 'dialect': connection.dialect,
- 'context': context,
- 'opts': {
- 'compare_type': True,
- 'compare_server_default': True,
- 'alembic_module_prefix': 'op.',
- 'sqlalchemy_module_prefix': 'sa.',
- }
- }
+
+ def setUp(self):
+ self.metadata = MetaData(self.bind)
+ self.autogen_context = api.AutogenContext(self.migration_context)
@classmethod
def teardown_class(cls):
clear_staging_env()
- def setUp(self):
- self.metadata = MetaData(self.bind)
-
def tearDown(self):
self.metadata.drop_all()
@@ -212,9 +202,12 @@ class PostgresqlDefaultCompareTest(TestBase):
cols = insp.get_columns(t1.name)
insp_col = Column("somecol", cols[0]['type'],
server_default=text(cols[0]['default']))
- diffs = []
- _compare_server_default(None, "test", "somecol", insp_col,
- t2.c.somecol, diffs, self.autogen_context)
+ op = ops.AlterColumnOp("test", "somecol")
+ _compare_server_default(
+ self.autogen_context, op,
+ None, "test", "somecol", insp_col, t2.c.somecol)
+
+ diffs = op.to_diff_tuple()
eq_(bool(diffs), diff_expected)
def _compare_default(
@@ -225,7 +218,7 @@ class PostgresqlDefaultCompareTest(TestBase):
t1.create(self.bind, checkfirst=True)
insp = Inspector.from_engine(self.bind)
cols = insp.get_columns(t1.name)
- ctx = self.autogen_context['context']
+ ctx = self.autogen_context.migration_context
return ctx.impl.compare_server_default(
None,
@@ -385,26 +378,16 @@ class PostgresqlDetectSerialTest(TestBase):
cls.bind = config.db
cls.conn = cls.bind.connect()
staging_env()
- context = MigrationContext.configure(
+ cls.migration_context = MigrationContext.configure(
connection=cls.conn,
opts={
'compare_type': True,
'compare_server_default': True
}
)
- connection = context.bind
- cls.autogen_context = {
- 'imports': set(),
- 'connection': connection,
- 'dialect': connection.dialect,
- 'context': context,
- 'opts': {
- 'compare_type': True,
- 'compare_server_default': True,
- 'alembic_module_prefix': 'op.',
- 'sqlalchemy_module_prefix': 'sa.',
- }
- }
+
+ def setUp(self):
+ self.autogen_context = api.AutogenContext(self.migration_context)
@classmethod
def teardown_class(cls):
@@ -420,24 +403,26 @@ class PostgresqlDetectSerialTest(TestBase):
self.metadata.create_all(config.db)
insp = Inspector.from_engine(config.db)
- diffs = []
+
+ uo = ops.UpgradeOps(ops=[])
_compare_tables(
set([(None, 't')]), set([]),
- [],
- insp, self.metadata, diffs, self.autogen_context)
+ insp, self.metadata, uo, self.autogen_context)
+ diffs = uo.as_diffs()
tab = diffs[0][1]
+
eq_(_render_server_default_for_compare(
tab.c.x.server_default, tab.c.x, self.autogen_context),
c_expected)
insp = Inspector.from_engine(config.db)
- diffs = []
+ uo = ops.UpgradeOps(ops=[])
m2 = MetaData()
Table('t', m2, Column('x', BigInteger()))
_compare_tables(
set([(None, 't')]), set([(None, 't')]),
- [],
- insp, m2, diffs, self.autogen_context)
+ insp, m2, uo, self.autogen_context)
+ diffs = uo.as_diffs()
server_default = diffs[0][0][4]['existing_server_default']
eq_(_render_server_default_for_compare(
server_default, tab.c.x, self.autogen_context),