diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2015-07-16 19:00:55 -0400 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2015-07-16 19:00:55 -0400 |
commit | c5be31760e4bc57b898d1d69d4bb0dd7c2dc7eb7 (patch) | |
tree | f158a8b62d7679176f3f60d1b1bf188a6383e03c /alembic | |
parent | 96214629cdb13f1694831f36c48a7ec86dd8c7f6 (diff) | |
download | alembic-c5be31760e4bc57b898d1d69d4bb0dd7c2dc7eb7.tar.gz |
- rework all of autogenerate to build directly on alembic.operations.ops
objects; the "diffs" is now a legacy system that is exported from
the ops. A new model of comparison/rendering/ upgrade/downgrade
composition that is cleaner and much more extensible is introduced.
- autogenerate is now extensible as far as database objects compared
and rendered into scripts; any new operation directive can also be
registered into a series of hooks that allow custom database/model
comparison functions to run as well as to render new operation
directives into autogenerate scripts.
- write all new docs for the new system
fixes #306
Diffstat (limited to 'alembic')
-rw-r--r-- | alembic/autogenerate/__init__.py | 6 | ||||
-rw-r--r-- | alembic/autogenerate/api.py | 306 | ||||
-rw-r--r-- | alembic/autogenerate/compare.py | 336 | ||||
-rw-r--r-- | alembic/autogenerate/compose.py | 144 | ||||
-rw-r--r-- | alembic/autogenerate/generate.py | 92 | ||||
-rw-r--r-- | alembic/autogenerate/render.py | 70 | ||||
-rw-r--r-- | alembic/operations/ops.py | 256 | ||||
-rw-r--r-- | alembic/runtime/migration.py | 3 | ||||
-rw-r--r-- | alembic/util/langhelpers.py | 43 |
9 files changed, 724 insertions, 532 deletions
diff --git a/alembic/autogenerate/__init__.py b/alembic/autogenerate/__init__.py index 4272a7e..78520a8 100644 --- a/alembic/autogenerate/__init__.py +++ b/alembic/autogenerate/__init__.py @@ -1,7 +1,7 @@ from .api import ( # noqa compare_metadata, _render_migration_diffs, - produce_migrations, render_python_code + produce_migrations, render_python_code, + RevisionContext ) -from .compare import _produce_net_changes # noqa -from .generate import RevisionContext # noqa +from .compare import _produce_net_changes, comparators # noqa from .render import render_op_text, renderers # noqa
\ No newline at end of file diff --git a/alembic/autogenerate/api.py b/alembic/autogenerate/api.py index cff977b..e9af4cf 100644 --- a/alembic/autogenerate/api.py +++ b/alembic/autogenerate/api.py @@ -4,8 +4,9 @@ automatically.""" from ..operations import ops from . import render from . import compare -from . import compose from .. import util +from sqlalchemy.engine.reflection import Inspector +import contextlib def compare_metadata(context, metadata): @@ -98,20 +99,8 @@ def compare_metadata(context, metadata): """ - autogen_context = _autogen_context(context, metadata=metadata) - - # as_sql=True is nonsensical here. autogenerate requires a connection - # it can use to run queries against to get the database schema. - if context.as_sql: - raise util.CommandError( - "autogenerate can't use as_sql=True as it prevents querying " - "the database for schema information") - - diffs = [] - - compare._produce_net_changes(autogen_context, diffs) - - return diffs + migration_script = produce_migrations(context, metadata) + return migration_script.upgrade_ops.as_diffs() def produce_migrations(context, metadata): @@ -132,10 +121,7 @@ def produce_migrations(context, metadata): """ - autogen_context = _autogen_context(context, metadata=metadata) - diffs = [] - - compare._produce_net_changes(autogen_context, diffs) + autogen_context = AutogenContext(context, metadata=metadata) migration_script = ops.MigrationScript( rev_id=None, @@ -143,7 +129,7 @@ def produce_migrations(context, metadata): downgrade_ops=ops.DowngradeOps([]), ) - compose._to_migration_script(autogen_context, migration_script, diffs) + compare._populate_migration_script(autogen_context, migration_script) return migration_script @@ -152,6 +138,7 @@ def render_python_code( up_or_down_op, sqlalchemy_module_prefix='sa.', alembic_module_prefix='op.', + render_as_batch=False, imports=(), render_item=None, ): @@ -162,84 +149,239 @@ def render_python_code( autogenerate output of a user-defined :class:`.MigrationScript` structure. """ - autogen_context = { - 'opts': { - 'sqlalchemy_module_prefix': sqlalchemy_module_prefix, - 'alembic_module_prefix': alembic_module_prefix, - 'render_item': render_item, - }, - 'imports': set(imports) + opts = { + 'sqlalchemy_module_prefix': sqlalchemy_module_prefix, + 'alembic_module_prefix': alembic_module_prefix, + 'render_item': render_item, + 'render_as_batch': render_as_batch, } + + autogen_context = AutogenContext(None, opts=opts) + autogen_context.imports = set(imports) return render._indent(render._render_cmd_body( up_or_down_op, autogen_context)) - - -def _render_migration_diffs(context, template_args, imports): +def _render_migration_diffs(context, template_args): """legacy, used by test_autogen_composition at the moment""" - migration_script = produce_migrations(context, None) - - autogen_context = _autogen_context(context, imports=imports) - diffs = [] + autogen_context = AutogenContext(context) - compare._produce_net_changes(autogen_context, diffs) + upgrade_ops = ops.UpgradeOps([]) + compare._produce_net_changes(autogen_context, upgrade_ops) migration_script = ops.MigrationScript( rev_id=None, - imports=imports, - upgrade_ops=ops.UpgradeOps([]), - downgrade_ops=ops.DowngradeOps([]), + upgrade_ops=upgrade_ops, + downgrade_ops=upgrade_ops.reverse(), ) - compose._to_migration_script(autogen_context, migration_script, diffs) - render._render_migration_script( autogen_context, migration_script, template_args ) -def _autogen_context( - context, imports=None, metadata=None, include_symbol=None, - include_object=None, include_schemas=False): - - opts = context.opts - metadata = opts['target_metadata'] if metadata is None else metadata - include_schemas = opts.get('include_schemas', include_schemas) - - include_symbol = opts.get('include_symbol', include_symbol) - include_object = opts.get('include_object', include_object) - - object_filters = [] - if include_symbol: - def include_symbol_filter(object, name, type_, reflected, compare_to): - if type_ == "table": - return include_symbol(name, object.schema) - else: - return True - object_filters.append(include_symbol_filter) - if include_object: - object_filters.append(include_object) - - if metadata is None: - raise util.CommandError( - "Can't proceed with --autogenerate option; environment " - "script %s does not provide " - "a MetaData object to the context." % ( - context.script.env_py_location - )) - - opts = context.opts - connection = context.bind - return { - 'imports': imports if imports is not None else set(), - 'connection': connection, - 'dialect': connection.dialect, - 'context': context, - 'opts': opts, - 'metadata': metadata, - 'object_filters': object_filters, - 'include_schemas': include_schemas - } +class AutogenContext(object): + """Maintains configuration and state that's specific to an + autogenerate operation.""" + + metadata = None + """The :class:`~sqlalchemy.schema.MetaData` object + representing the destination. + + This object is the one that is passed within ``env.py`` + to the :paramref:`.EnvironmentContext.configure.target_metadata` + parameter. It represents the structure of :class:`.Table` and other + objects as stated in the current database model, and represents the + destination structure for the database being examined. + + While the :class:`~sqlalchemy.schema.MetaData` object is primarily + known as a collection of :class:`~sqlalchemy.schema.Table` objects, + it also has an :attr:`~sqlalchemy.schema.MetaData.info` dictionary + that may be used by end-user schemes to store additional schema-level + objects that are to be compared in custom autogeneration schemes. + + """ + + connection = None + """The :class:`~sqlalchemy.engine.base.Connection` object currently + connected to the database backend being compared. + + This is obtained from the :attr:`.MigrationContext.bind` and is + utimately set up in the ``env.py`` script. + + """ + + dialect = None + """The :class:`~sqlalchemy.engine.Dialect` object currently in use. + + This is normally obtained from the + :attr:`~sqlalchemy.engine.base.Connection.dialect` attribute. + + """ + + migration_context = None + """The :class:`.MigrationContext` established by the ``env.py`` script.""" + + def __init__(self, migration_context, metadata=None, opts=None): + + if migration_context is not None and migration_context.as_sql: + raise util.CommandError( + "autogenerate can't use as_sql=True as it prevents querying " + "the database for schema information") + + if opts is None: + opts = migration_context.opts + self.metadata = metadata = opts.get('target_metadata', None) \ + if metadata is None else metadata + + if metadata is None and \ + migration_context is not None and \ + migration_context.script is not None: + raise util.CommandError( + "Can't proceed with --autogenerate option; environment " + "script %s does not provide " + "a MetaData object to the context." % ( + migration_context.script.env_py_location + )) + + include_symbol = opts.get('include_symbol', None) + include_object = opts.get('include_object', None) + + object_filters = [] + if include_symbol: + def include_symbol_filter( + object, name, type_, reflected, compare_to): + if type_ == "table": + return include_symbol(name, object.schema) + else: + return True + object_filters.append(include_symbol_filter) + if include_object: + object_filters.append(include_object) + + self._object_filters = object_filters + + self.migration_context = migration_context + if self.migration_context is not None: + self.connection = self.migration_context.bind + self.dialect = self.migration_context.dialect + + self._imports = set() + self.opts = opts + self._has_batch = False + + @util.memoized_property + def inspector(self): + return Inspector.from_engine(self.connection) + + @contextlib.contextmanager + def _within_batch(self): + self._has_batch = True + yield + self._has_batch = False + + def run_filters(self, object_, name, type_, reflected, compare_to): + """Run the context's object filters and return True if the targets + should be part of the autogenerate operation. + + This method should be run for every kind of object encountered within + an autogenerate operation, giving the environment the chance + to filter what objects should be included in the comparison. + The filters here are produced directly via the + :paramref:`.EnvironmentContext.configure.include_object` + and :paramref:`.EnvironmentContext.configure.include_symbol` + functions, if present. + + """ + for fn in self._object_filters: + if not fn(object_, name, type_, reflected, compare_to): + return False + else: + return True + + +class RevisionContext(object): + """Maintains configuration and state that's specific to a revision + file generation operation.""" + + def __init__(self, config, script_directory, command_args): + self.config = config + self.script_directory = script_directory + self.command_args = command_args + self.template_args = { + 'config': config # Let templates use config for + # e.g. multiple databases + } + self.generated_revisions = [ + self._default_revision() + ] + + def _to_script(self, migration_script): + template_args = {} + for k, v in self.template_args.items(): + template_args.setdefault(k, v) + + if migration_script._autogen_context is not None: + render._render_migration_script( + migration_script._autogen_context, migration_script, + template_args + ) + + return self.script_directory.generate_revision( + migration_script.rev_id, + migration_script.message, + refresh=True, + head=migration_script.head, + splice=migration_script.splice, + branch_labels=migration_script.branch_label, + version_path=migration_script.version_path, + **template_args) + + def run_autogenerate(self, rev, context): + if self.command_args['sql']: + raise util.CommandError( + "Using --sql with --autogenerate does not make any sense") + if set(self.script_directory.get_revisions(rev)) != \ + set(self.script_directory.get_revisions("heads")): + raise util.CommandError("Target database is not up to date.") + + autogen_context = AutogenContext(context) + + migration_script = self.generated_revisions[0] + + compare._populate_migration_script(autogen_context, migration_script) + + hook = context.opts.get('process_revision_directives', None) + if hook: + hook(context, rev, self.generated_revisions) + + for migration_script in self.generated_revisions: + migration_script._autogen_context = autogen_context + + def run_no_autogenerate(self, rev, context): + hook = context.opts.get('process_revision_directives', None) + if hook: + hook(context, rev, self.generated_revisions) + + for migration_script in self.generated_revisions: + migration_script._autogen_context = None + + def _default_revision(self): + op = ops.MigrationScript( + rev_id=self.command_args['rev_id'] or util.rev_id(), + message=self.command_args['message'], + imports=set(), + upgrade_ops=ops.UpgradeOps([]), + downgrade_ops=ops.DowngradeOps([]), + head=self.command_args['head'], + splice=self.command_args['splice'], + branch_label=self.command_args['branch_label'], + version_path=self.command_args['version_path'] + ) + op._autogen_context = None + return op + def generate_scripts(self): + for generated_revision in self.generated_revisions: + yield self._to_script(generated_revision) diff --git a/alembic/autogenerate/compare.py b/alembic/autogenerate/compare.py index cd6b696..fdc3cae 100644 --- a/alembic/autogenerate/compare.py +++ b/alembic/autogenerate/compare.py @@ -1,7 +1,9 @@ from sqlalchemy import schema as sa_schema, types as sqltypes from sqlalchemy.engine.reflection import Inspector from sqlalchemy import event +from ..operations import ops import logging +from .. import util from ..util import compat from ..util import sqla_compat from sqlalchemy.util import OrderedSet @@ -13,15 +15,20 @@ from alembic.ddl.base import _fk_spec log = logging.getLogger(__name__) -def _produce_net_changes(autogen_context, diffs): +def _populate_migration_script(autogen_context, migration_script): + _produce_net_changes(autogen_context, migration_script.upgrade_ops) + migration_script.upgrade_ops.reverse_into(migration_script.downgrade_ops) - metadata = autogen_context['metadata'] - connection = autogen_context['connection'] - object_filters = autogen_context.get('object_filters', ()) - include_schemas = autogen_context.get('include_schemas', False) + +comparators = util.Dispatcher(uselist=True) + + +def _produce_net_changes(autogen_context, upgrade_ops): + + connection = autogen_context.connection + include_schemas = autogen_context.opts.get('include_schemas', False) inspector = Inspector.from_engine(connection) - conn_table_names = set() default_schema = connection.dialect.default_schema_name if include_schemas: @@ -34,14 +41,28 @@ def _produce_net_changes(autogen_context, diffs): else: schemas = [None] - version_table_schema = autogen_context['context'].version_table_schema - version_table = autogen_context['context'].version_table + comparators.dispatch("schema", autogen_context.dialect.name)( + autogen_context, upgrade_ops, schemas + ) + + +@comparators.dispatch_for("schema") +def _autogen_for_tables(autogen_context, upgrade_ops, schemas): + inspector = autogen_context.inspector + + metadata = autogen_context.metadata + + conn_table_names = set() + + version_table_schema = \ + autogen_context.migration_context.version_table_schema + version_table = autogen_context.migration_context.version_table for s in schemas: tables = set(inspector.get_table_names(schema=s)) if s == version_table_schema: tables = tables.difference( - [autogen_context['context'].version_table] + [autogen_context.migration_context.version_table] ) conn_table_names.update(zip([s] * len(tables), tables)) @@ -50,21 +71,11 @@ def _produce_net_changes(autogen_context, diffs): ).difference([(version_table_schema, version_table)]) _compare_tables(conn_table_names, metadata_table_names, - object_filters, - inspector, metadata, diffs, autogen_context) - - -def _run_filters(object_, name, type_, reflected, compare_to, object_filters): - for fn in object_filters: - if not fn(object_, name, type_, reflected, compare_to): - return False - else: - return True + inspector, metadata, upgrade_ops, autogen_context) def _compare_tables(conn_table_names, metadata_table_names, - object_filters, - inspector, metadata, diffs, autogen_context): + inspector, metadata, upgrade_ops, autogen_context): default_schema = inspector.bind.dialect.default_schema_name @@ -95,14 +106,19 @@ def _compare_tables(conn_table_names, metadata_table_names, for s, tname in metadata_table_names.difference(conn_table_names): name = '%s.%s' % (s, tname) if s else tname metadata_table = tname_to_table[(s, tname)] - if _run_filters( - metadata_table, tname, "table", False, None, object_filters): - diffs.append(("add_table", metadata_table)) + if autogen_context.run_filters( + metadata_table, tname, "table", False, None): + upgrade_ops.ops.append( + ops.CreateTableOp.from_table(metadata_table)) log.info("Detected added table %r", name) - _compare_indexes_and_uniques(s, tname, object_filters, - None, - metadata_table, - diffs, autogen_context, inspector) + modify_table_ops = ops.ModifyTableOps(tname, [], schema=s) + + comparators.dispatch("table")( + autogen_context, modify_table_ops, + s, tname, None, metadata_table + ) + if not modify_table_ops.is_empty(): + upgrade_ops.ops.append(modify_table_ops) removal_metadata = sa_schema.MetaData() for s, tname in conn_table_names.difference(metadata_table_names): @@ -114,11 +130,13 @@ def _compare_tables(conn_table_names, metadata_table_names, event.listen( t, "column_reflect", - autogen_context['context'].impl. + autogen_context.migration_context.impl. _compat_autogen_column_reflect(inspector)) inspector.reflecttable(t, None) - if _run_filters(t, tname, "table", True, None, object_filters): - diffs.append(("remove_table", t)) + if autogen_context.run_filters(t, tname, "table", True, None): + upgrade_ops.ops.append( + ops.DropTableOp.from_table(t) + ) log.info("Detected removed table %r", name) existing_tables = conn_table_names.intersection(metadata_table_names) @@ -133,7 +151,7 @@ def _compare_tables(conn_table_names, metadata_table_names, event.listen( t, "column_reflect", - autogen_context['context'].impl. + autogen_context.migration_context.impl. _compat_autogen_column_reflect(inspector)) inspector.reflecttable(t, None) conn_column_info[(s, tname)] = t @@ -144,25 +162,24 @@ def _compare_tables(conn_table_names, metadata_table_names, metadata_table = tname_to_table[(s, tname)] conn_table = existing_metadata.tables[name] - if _run_filters( + if autogen_context.run_filters( metadata_table, tname, "table", False, - conn_table, object_filters): + conn_table): + + modify_table_ops = ops.ModifyTableOps(tname, [], schema=s) with _compare_columns( - s, tname, object_filters, + s, tname, conn_table, metadata_table, - diffs, autogen_context, inspector): - _compare_indexes_and_uniques(s, tname, object_filters, - conn_table, - metadata_table, - diffs, autogen_context, inspector) - _compare_foreign_keys(s, tname, object_filters, conn_table, - metadata_table, diffs, autogen_context, - inspector) + modify_table_ops, autogen_context, inspector): - # TODO: - # table constraints - # sequences + comparators.dispatch("table")( + autogen_context, modify_table_ops, + s, tname, conn_table, metadata_table + ) + + if not modify_table_ops.is_empty(): + upgrade_ops.ops.append(modify_table_ops) def _make_index(params, conn_table): @@ -202,56 +219,51 @@ def _make_foreign_key(params, conn_table): @contextlib.contextmanager -def _compare_columns(schema, tname, object_filters, conn_table, metadata_table, - diffs, autogen_context, inspector): +def _compare_columns(schema, tname, conn_table, metadata_table, + modify_table_ops, autogen_context, inspector): name = '%s.%s' % (schema, tname) if schema else tname metadata_cols_by_name = dict((c.name, c) for c in metadata_table.c) conn_col_names = dict((c.name, c) for c in conn_table.c) metadata_col_names = OrderedSet(sorted(metadata_cols_by_name)) for cname in metadata_col_names.difference(conn_col_names): - if _run_filters(metadata_cols_by_name[cname], cname, - "column", False, None, object_filters): - diffs.append( - ("add_column", schema, tname, metadata_cols_by_name[cname]) + if autogen_context.run_filters( + metadata_cols_by_name[cname], cname, + "column", False, None): + modify_table_ops.ops.append( + ops.AddColumnOp.from_column_and_tablename( + schema, tname, metadata_cols_by_name[cname]) ) log.info("Detected added column '%s.%s'", name, cname) for colname in metadata_col_names.intersection(conn_col_names): metadata_col = metadata_cols_by_name[colname] conn_col = conn_table.c[colname] - if not _run_filters( + if not autogen_context.run_filters( metadata_col, colname, "column", False, - conn_col, object_filters): + conn_col): continue - col_diff = [] - _compare_type(schema, tname, colname, - conn_col, - metadata_col, - col_diff, autogen_context - ) - # work around SQLAlchemy issue #3023 - if not metadata_col.primary_key: - _compare_nullable(schema, tname, colname, - conn_col, - metadata_col.nullable, - col_diff, autogen_context - ) - _compare_server_default(schema, tname, colname, - conn_col, - metadata_col, - col_diff, autogen_context - ) - if col_diff: - diffs.append(col_diff) + alter_column_op = ops.AlterColumnOp( + tname, colname, schema=schema) + + comparators.dispatch("column")( + autogen_context, alter_column_op, + schema, tname, colname, conn_col, metadata_col + ) + + if alter_column_op.has_changes(): + modify_table_ops.ops.append(alter_column_op) yield for cname in set(conn_col_names).difference(metadata_col_names): - if _run_filters(conn_table.c[cname], cname, - "column", True, None, object_filters): - diffs.append( - ("remove_column", schema, tname, conn_table.c[cname]) + if autogen_context.run_filters( + conn_table.c[cname], cname, + "column", True, None): + modify_table_ops.ops.append( + ops.DropColumnOp.from_column_and_tablename( + schema, tname, conn_table.c[cname] + ) ) log.info("Detected removed column '%s.%s'", name, cname) @@ -310,10 +322,12 @@ class _fk_constraint_sig(_constraint_sig): ) -def _compare_indexes_and_uniques(schema, tname, object_filters, conn_table, - metadata_table, diffs, - autogen_context, inspector): +@comparators.dispatch_for("table") +def _compare_indexes_and_uniques( + autogen_context, modify_ops, schema, tname, conn_table, + metadata_table): + inspector = autogen_context.inspector is_create_table = conn_table is None # 1a. get raw indexes and unique constraints from metadata ... @@ -350,7 +364,7 @@ def _compare_indexes_and_uniques(schema, tname, object_filters, conn_table, # 3. give the dialect a chance to omit indexes and constraints that # we know are either added implicitly by the DB or that the DB # can't accurately report on - autogen_context['context'].impl.\ + autogen_context.migration_context.impl.\ correct_for_autogen_constraints( conn_uniques, conn_indexes, metadata_unique_constraints, @@ -411,9 +425,11 @@ def _compare_indexes_and_uniques(schema, tname, object_filters, conn_table, def obj_added(obj): if obj.is_index: - if _run_filters( - obj.const, obj.name, "index", False, None, object_filters): - diffs.append(("add_index", obj.const)) + if autogen_context.run_filters( + obj.const, obj.name, "index", False, None): + modify_ops.ops.append( + ops.CreateIndexOp.from_index(obj.const) + ) log.info("Detected added index '%s' on %s", obj.name, ', '.join([ "'%s'" % obj.column_names @@ -426,10 +442,12 @@ def _compare_indexes_and_uniques(schema, tname, object_filters, conn_table, if is_create_table: # unique constraints are created inline with table defs return - if _run_filters( + if autogen_context.run_filters( obj.const, obj.name, - "unique_constraint", False, None, object_filters): - diffs.append(("add_constraint", obj.const)) + "unique_constraint", False, None): + modify_ops.ops.append( + ops.AddConstraintOp.from_constraint(obj.const) + ) log.info("Detected added unique constraint '%s' on %s", obj.name, ', '.join([ "'%s'" % obj.column_names @@ -443,39 +461,51 @@ def _compare_indexes_and_uniques(schema, tname, object_filters, conn_table, # be sure what we're doing here return - if _run_filters( - obj.const, obj.name, "index", True, None, object_filters): - diffs.append(("remove_index", obj.const)) + if autogen_context.run_filters( + obj.const, obj.name, "index", True, None): + modify_ops.ops.append( + ops.DropIndexOp.from_index(obj.const) + ) log.info( "Detected removed index '%s' on '%s'", obj.name, tname) else: - if _run_filters( + if autogen_context.run_filters( obj.const, obj.name, - "unique_constraint", True, None, object_filters): - diffs.append(("remove_constraint", obj.const)) + "unique_constraint", True, None): + modify_ops.ops.append( + ops.DropConstraintOp.from_constraint(obj.const) + ) log.info("Detected removed unique constraint '%s' on '%s'", obj.name, tname ) def obj_changed(old, new, msg): if old.is_index: - if _run_filters( + if autogen_context.run_filters( new.const, new.name, "index", - False, old.const, object_filters): + False, old.const): log.info("Detected changed index '%s' on '%s':%s", old.name, tname, ', '.join(msg) ) - diffs.append(("remove_index", old.const)) - diffs.append(("add_index", new.const)) + modify_ops.ops.append( + ops.DropIndexOp.from_index(old.const) + ) + modify_ops.ops.append( + ops.CreateIndexOp.from_index(new.const) + ) else: - if _run_filters( + if autogen_context.run_filters( new.const, new.name, - "unique_constraint", False, old.const, object_filters): + "unique_constraint", False, old.const): log.info("Detected changed unique constraint '%s' on '%s':%s", old.name, tname, ', '.join(msg) ) - diffs.append(("remove_constraint", old.const)) - diffs.append(("add_constraint", new.const)) + modify_ops.ops.append( + ops.DropConstraintOp.from_constraint(old.const) + ) + modify_ops.ops.append( + ops.AddConstraintOp.from_constraint(new.const) + ) for added_name in sorted(set(metadata_names).difference(conn_names)): obj = metadata_names[added_name] @@ -528,20 +558,21 @@ def _compare_indexes_and_uniques(schema, tname, object_filters, conn_table, obj_added(unnamed_metadata_uniques[uq_sig]) -def _compare_nullable(schema, tname, cname, conn_col, - metadata_col_nullable, diffs, - autogen_context): +@comparators.dispatch_for("column") +def _compare_nullable( + autogen_context, alter_column_op, schema, tname, cname, conn_col, + metadata_col): + + # work around SQLAlchemy issue #3023 + if metadata_col.primary_key: + return + + metadata_col_nullable = metadata_col.nullable conn_col_nullable = conn_col.nullable + alter_column_op.existing_nullable = conn_col_nullable + if conn_col_nullable is not metadata_col_nullable: - diffs.append( - ("modify_nullable", schema, tname, cname, - { - "existing_type": conn_col.type, - "existing_server_default": conn_col.server_default, - }, - conn_col_nullable, - metadata_col_nullable), - ) + alter_column_op.modify_nullable = metadata_col_nullable log.info("Detected %s on column '%s.%s'", "NULL" if metadata_col_nullable else "NOT NULL", tname, @@ -549,11 +580,13 @@ def _compare_nullable(schema, tname, cname, conn_col, ) -def _compare_type(schema, tname, cname, conn_col, - metadata_col, diffs, - autogen_context): +@comparators.dispatch_for("column") +def _compare_type( + autogen_context, alter_column_op, schema, tname, cname, conn_col, + metadata_col): conn_type = conn_col.type + alter_column_op.existing_type = conn_type metadata_type = metadata_col.type if conn_type._type_affinity is sqltypes.NullType: log.info("Couldn't determine database type " @@ -564,19 +597,11 @@ def _compare_type(schema, tname, cname, conn_col, "the model; can't compare", tname, cname) return - isdiff = autogen_context['context']._compare_type(conn_col, metadata_col) + isdiff = autogen_context.migration_context._compare_type( + conn_col, metadata_col) if isdiff: - - diffs.append( - ("modify_type", schema, tname, cname, - { - "existing_nullable": conn_col.nullable, - "existing_server_default": conn_col.server_default, - }, - conn_type, - metadata_type), - ) + alter_column_op.modify_type = metadata_type log.info("Detected type change from %r to %r on '%s.%s'", conn_type, metadata_type, tname, cname ) @@ -594,7 +619,7 @@ def _render_server_default_for_compare(metadata_default, metadata_default = metadata_default.arg else: metadata_default = str(metadata_default.arg.compile( - dialect=autogen_context['dialect'])) + dialect=autogen_context.dialect)) if isinstance(metadata_default, compat.string_types): if metadata_col.type._type_affinity is sqltypes.String: metadata_default = re.sub(r"^'|'$", "", metadata_default) @@ -605,8 +630,10 @@ def _render_server_default_for_compare(metadata_default, return None -def _compare_server_default(schema, tname, cname, conn_col, metadata_col, - diffs, autogen_context): +@comparators.dispatch_for("column") +def _compare_server_default( + autogen_context, alter_column_op, schema, tname, cname, + conn_col, metadata_col): metadata_default = metadata_col.server_default conn_col_default = conn_col.server_default @@ -618,36 +645,31 @@ def _compare_server_default(schema, tname, cname, conn_col, metadata_col, rendered_conn_default = conn_col.server_default.arg.text \ if conn_col.server_default else None - isdiff = autogen_context['context']._compare_server_default( + alter_column_op.existing_server_default = conn_col_default + + isdiff = autogen_context.migration_context._compare_server_default( conn_col, metadata_col, rendered_metadata_default, rendered_conn_default ) if isdiff: - conn_col_default = rendered_conn_default - diffs.append( - ("modify_default", schema, tname, cname, - { - "existing_nullable": conn_col.nullable, - "existing_type": conn_col.type, - }, - conn_col_default, - metadata_default), - ) - log.info("Detected server default on column '%s.%s'", - tname, - cname - ) + alter_column_op.modify_server_default = metadata_default + log.info( + "Detected server default on column '%s.%s'", + tname, cname) -def _compare_foreign_keys(schema, tname, object_filters, conn_table, - metadata_table, diffs, autogen_context, inspector): +@comparators.dispatch_for("table") +def _compare_foreign_keys( + autogen_context, modify_table_ops, schema, tname, conn_table, + metadata_table): # if we're doing CREATE TABLE, all FKs are created # inline within the table def if conn_table is None: return + inspector = autogen_context.inspector metadata_fks = set( fk for fk in metadata_table.constraints if isinstance(fk, sa_schema.ForeignKeyConstraint) @@ -673,10 +695,12 @@ def _compare_foreign_keys(schema, tname, object_filters, conn_table, ) def _add_fk(obj, compare_to): - if _run_filters( + if autogen_context.run_filters( obj.const, obj.name, "foreign_key_constraint", False, - compare_to, object_filters): - diffs.append(('add_fk', const.const)) + compare_to): + modify_table_ops.ops.append( + ops.CreateForeignKeyOp.from_constraint(const.const) + ) log.info( "Detected added foreign key (%s)(%s) on table %s%s", @@ -686,10 +710,12 @@ def _compare_foreign_keys(schema, tname, object_filters, conn_table, obj.source_table) def _remove_fk(obj, compare_to): - if _run_filters( + if autogen_context.run_filters( obj.const, obj.name, "foreign_key_constraint", True, - compare_to, object_filters): - diffs.append(('remove_fk', obj.const)) + compare_to): + modify_table_ops.ops.append( + ops.DropConstraintOp.from_constraint(obj.const) + ) log.info( "Detected removed foreign key (%s)(%s) on table %s%s", ", ".join(obj.source_columns), @@ -713,5 +739,3 @@ def _compare_foreign_keys(schema, tname, object_filters, conn_table, compare_to = conn_fks_by_name[const.name].const \ if const.name in conn_fks_by_name else None _add_fk(const, compare_to) - - return diffs diff --git a/alembic/autogenerate/compose.py b/alembic/autogenerate/compose.py deleted file mode 100644 index b42b505..0000000 --- a/alembic/autogenerate/compose.py +++ /dev/null @@ -1,144 +0,0 @@ -import itertools -from ..operations import ops - - -def _to_migration_script(autogen_context, migration_script, diffs): - _to_upgrade_op( - autogen_context, - diffs, - migration_script.upgrade_ops, - ) - - _to_downgrade_op( - autogen_context, - diffs, - migration_script.downgrade_ops, - ) - - -def _to_upgrade_op(autogen_context, diffs, upgrade_ops): - return _to_updown_op(autogen_context, diffs, upgrade_ops, "upgrade") - - -def _to_downgrade_op(autogen_context, diffs, downgrade_ops): - return _to_updown_op(autogen_context, diffs, downgrade_ops, "downgrade") - - -def _to_updown_op(autogen_context, diffs, op_container, type_): - if not diffs: - return - - if type_ == 'downgrade': - diffs = reversed(diffs) - - dest = [op_container.ops] - - for (schema, tablename), subdiffs in _group_diffs_by_table(diffs): - subdiffs = list(subdiffs) - if tablename is not None: - table_ops = [] - op = ops.ModifyTableOps(tablename, table_ops, schema=schema) - dest[-1].append(op) - dest.append(table_ops) - for diff in subdiffs: - _produce_command(autogen_context, diff, dest[-1], type_) - if tablename is not None: - dest.pop(-1) - - -def _produce_command(autogen_context, diff, op_list, updown): - if isinstance(diff, tuple): - _produce_adddrop_command(updown, diff, op_list, autogen_context) - else: - _produce_modify_command(updown, diff, op_list, autogen_context) - - -def _produce_adddrop_command(updown, diff, op_list, autogen_context): - cmd_type = diff[0] - adddrop, cmd_type = cmd_type.split("_") - - cmd_args = diff[1:] - - _commands = { - "table": (ops.DropTableOp.from_table, ops.CreateTableOp.from_table), - "column": ( - ops.DropColumnOp.from_column_and_tablename, - ops.AddColumnOp.from_column_and_tablename), - "index": (ops.DropIndexOp.from_index, ops.CreateIndexOp.from_index), - "constraint": ( - ops.DropConstraintOp.from_constraint, - ops.AddConstraintOp.from_constraint), - "fk": ( - ops.DropConstraintOp.from_constraint, - ops.CreateForeignKeyOp.from_constraint) - } - - cmd_callables = _commands[cmd_type] - - if ( - updown == "upgrade" and adddrop == "add" - ) or ( - updown == "downgrade" and adddrop == "remove" - ): - op_list.append(cmd_callables[1](*cmd_args)) - else: - op_list.append(cmd_callables[0](*cmd_args)) - - -def _produce_modify_command(updown, diffs, op_list, autogen_context): - sname, tname, cname = diffs[0][1:4] - kw = {} - - _arg_struct = { - "modify_type": ("existing_type", "modify_type"), - "modify_nullable": ("existing_nullable", "modify_nullable"), - "modify_default": ("existing_server_default", "modify_server_default"), - } - for diff in diffs: - diff_kw = diff[4] - for arg in ("existing_type", - "existing_nullable", - "existing_server_default"): - if arg in diff_kw: - kw.setdefault(arg, diff_kw[arg]) - old_kw, new_kw = _arg_struct[diff[0]] - if updown == "upgrade": - kw[new_kw] = diff[-1] - kw[old_kw] = diff[-2] - else: - kw[new_kw] = diff[-2] - kw[old_kw] = diff[-1] - - if "modify_nullable" in kw: - kw.pop("existing_nullable", None) - if "modify_server_default" in kw: - kw.pop("existing_server_default", None) - - op_list.append( - ops.AlterColumnOp( - tname, cname, schema=sname, - **kw - ) - ) - - -def _group_diffs_by_table(diffs): - _adddrop = { - "table": lambda diff: (None, None), - "column": lambda diff: (diff[0], diff[1]), - "index": lambda diff: (diff[0].table.schema, diff[0].table.name), - "constraint": lambda diff: (diff[0].table.schema, diff[0].table.name), - "fk": lambda diff: (diff[0].parent.schema, diff[0].parent.name) - } - - def _derive_table(diff): - if isinstance(diff, tuple): - cmd_type = diff[0] - adddrop, cmd_type = cmd_type.split("_") - return _adddrop[cmd_type](diff[1:]) - else: - sname, tname = diff[0][1:3] - return sname, tname - - return itertools.groupby(diffs, _derive_table) - diff --git a/alembic/autogenerate/generate.py b/alembic/autogenerate/generate.py deleted file mode 100644 index c686156..0000000 --- a/alembic/autogenerate/generate.py +++ /dev/null @@ -1,92 +0,0 @@ -from .. import util -from . import api -from . import compose -from . import compare -from . import render -from ..operations import ops - - -class RevisionContext(object): - def __init__(self, config, script_directory, command_args): - self.config = config - self.script_directory = script_directory - self.command_args = command_args - self.template_args = { - 'config': config # Let templates use config for - # e.g. multiple databases - } - self.generated_revisions = [ - self._default_revision() - ] - - def _to_script(self, migration_script): - template_args = {} - for k, v in self.template_args.items(): - template_args.setdefault(k, v) - - if migration_script._autogen_context is not None: - render._render_migration_script( - migration_script._autogen_context, migration_script, - template_args - ) - - return self.script_directory.generate_revision( - migration_script.rev_id, - migration_script.message, - refresh=True, - head=migration_script.head, - splice=migration_script.splice, - branch_labels=migration_script.branch_label, - version_path=migration_script.version_path, - **template_args) - - def run_autogenerate(self, rev, context): - if self.command_args['sql']: - raise util.CommandError( - "Using --sql with --autogenerate does not make any sense") - if set(self.script_directory.get_revisions(rev)) != \ - set(self.script_directory.get_revisions("heads")): - raise util.CommandError("Target database is not up to date.") - - autogen_context = api._autogen_context(context) - - diffs = [] - compare._produce_net_changes(autogen_context, diffs) - - migration_script = self.generated_revisions[0] - - compose._to_migration_script(autogen_context, migration_script, diffs) - - hook = context.opts.get('process_revision_directives', None) - if hook: - hook(context, rev, self.generated_revisions) - - for migration_script in self.generated_revisions: - migration_script._autogen_context = autogen_context - - def run_no_autogenerate(self, rev, context): - hook = context.opts.get('process_revision_directives', None) - if hook: - hook(context, rev, self.generated_revisions) - - for migration_script in self.generated_revisions: - migration_script._autogen_context = None - - def _default_revision(self): - op = ops.MigrationScript( - rev_id=self.command_args['rev_id'] or util.rev_id(), - message=self.command_args['message'], - imports=set(), - upgrade_ops=ops.UpgradeOps([]), - downgrade_ops=ops.DowngradeOps([]), - head=self.command_args['head'], - splice=self.command_args['splice'], - branch_label=self.command_args['branch_label'], - version_path=self.command_args['version_path'] - ) - op._autogen_context = None - return op - - def generate_scripts(self): - for generated_revision in self.generated_revisions: - yield self._to_script(generated_revision) diff --git a/alembic/autogenerate/render.py b/alembic/autogenerate/render.py index c3f3df1..6f5f96c 100644 --- a/alembic/autogenerate/render.py +++ b/alembic/autogenerate/render.py @@ -30,8 +30,8 @@ def _indent(text): def _render_migration_script(autogen_context, migration_script, template_args): - opts = autogen_context['opts'] - imports = autogen_context['imports'] + opts = autogen_context.opts + imports = autogen_context._imports template_args[opts['upgrade_token']] = _indent(_render_cmd_body( migration_script.upgrade_ops, autogen_context)) template_args[opts['downgrade_token']] = _indent(_render_cmd_body( @@ -78,23 +78,26 @@ def render_op_text(autogen_context, op): @renderers.dispatch_for(ops.ModifyTableOps) def _render_modify_table(autogen_context, op): - opts = autogen_context['opts'] + opts = autogen_context.opts render_as_batch = opts.get('render_as_batch', False) if op.ops: lines = [] if render_as_batch: - lines.append( - "with op.batch_alter_table(%r, schema=%r) as batch_op:" - % (op.table_name, op.schema) - ) - autogen_context['batch_prefix'] = 'batch_op.' - for t_op in op.ops: - t_lines = render_op(autogen_context, t_op) - lines.extend(t_lines) - if render_as_batch: - del autogen_context['batch_prefix'] - lines.append("") + with autogen_context._within_batch(): + lines.append( + "with op.batch_alter_table(%r, schema=%r) as batch_op:" + % (op.table_name, op.schema) + ) + for t_op in op.ops: + t_lines = render_op(autogen_context, t_op) + lines.extend(t_lines) + lines.append("") + else: + for t_op in op.ops: + t_lines = render_op(autogen_context, t_op) + lines.extend(t_lines) + return lines else: return [ @@ -149,7 +152,7 @@ def _drop_table(autogen_context, op): def _add_index(autogen_context, op): index = op.to_index() - has_batch = 'batch_prefix' in autogen_context + has_batch = autogen_context._has_batch if has_batch: tmpl = "%(prefix)screate_index(%(name)r, [%(columns)s], "\ @@ -180,7 +183,7 @@ def _add_index(autogen_context, op): @renderers.dispatch_for(ops.DropIndexOp) def _drop_index(autogen_context, op): - has_batch = 'batch_prefix' in autogen_context + has_batch = autogen_context._has_batch if has_batch: tmpl = "%(prefix)sdrop_index(%(name)r)" @@ -243,7 +246,7 @@ def _add_check_constraint(constraint, autogen_context): @renderers.dispatch_for(ops.DropConstraintOp) def _drop_constraint(autogen_context, op): - if 'batch_prefix' in autogen_context: + if autogen_context._has_batch: template = "%(prefix)sdrop_constraint"\ "(%(name)r, type_=%(type)r)" else: @@ -266,7 +269,7 @@ def _drop_constraint(autogen_context, op): def _add_column(autogen_context, op): schema, tname, column = op.schema, op.table_name, op.column - if 'batch_prefix' in autogen_context: + if autogen_context._has_batch: template = "%(prefix)sadd_column(%(column)s)" else: template = "%(prefix)sadd_column(%(tname)r, %(column)s" @@ -287,7 +290,7 @@ def _drop_column(autogen_context, op): schema, tname, column_name = op.schema, op.table_name, op.column_name - if 'batch_prefix' in autogen_context: + if autogen_context._has_batch: template = "%(prefix)sdrop_column(%(cname)r)" else: template = "%(prefix)sdrop_column(%(tname)r, %(cname)r" @@ -319,7 +322,7 @@ def _alter_column(autogen_context, op): indent = " " * 11 - if 'batch_prefix' in autogen_context: + if autogen_context._has_batch: template = "%(prefix)salter_column(%(cname)r" else: template = "%(prefix)salter_column(%(tname)r, %(cname)r" @@ -343,16 +346,16 @@ def _alter_column(autogen_context, op): if nullable is not None: text += ",\n%snullable=%r" % ( indent, nullable,) - if existing_nullable is not None: + if nullable is None and existing_nullable is not None: text += ",\n%sexisting_nullable=%r" % ( indent, existing_nullable) - if existing_server_default: + if server_default is False and existing_server_default: rendered = _render_server_default( existing_server_default, autogen_context) text += ",\n%sexisting_server_default=%s" % ( indent, rendered) - if schema and "batch_prefix" not in autogen_context: + if schema and not autogen_context._has_batch: text += ",\n%sschema=%r" % (indent, schema) text += ")" return text @@ -409,7 +412,7 @@ def _render_potential_expr(value, autogen_context, wrap_in_text=True): return template % { "prefix": _sqlalchemy_autogenerate_prefix(autogen_context), "sql": compat.text_type( - value.compile(dialect=autogen_context['dialect'], + value.compile(dialect=autogen_context.dialect, **compile_kw) ) } @@ -432,7 +435,7 @@ def _get_index_rendered_expressions(idx, autogen_context): def _uq_constraint(constraint, autogen_context, alter): opts = [] - has_batch = 'batch_prefix' in autogen_context + has_batch = autogen_context._has_batch if constraint.deferrable: opts.append(("deferrable", str(constraint.deferrable))) @@ -467,7 +470,7 @@ def _uq_constraint(constraint, autogen_context, alter): def _user_autogenerate_prefix(autogen_context, target): - prefix = autogen_context['opts']['user_module_prefix'] + prefix = autogen_context.opts['user_module_prefix'] if prefix is None: return "%s." % target.__module__ else: @@ -475,20 +478,19 @@ def _user_autogenerate_prefix(autogen_context, target): def _sqlalchemy_autogenerate_prefix(autogen_context): - return autogen_context['opts']['sqlalchemy_module_prefix'] or '' + return autogen_context.opts['sqlalchemy_module_prefix'] or '' def _alembic_autogenerate_prefix(autogen_context): - if 'batch_prefix' in autogen_context: - return autogen_context['batch_prefix'] + if autogen_context._has_batch: + return 'batch_op.' else: - return autogen_context['opts']['alembic_module_prefix'] or '' + return autogen_context.opts['alembic_module_prefix'] or '' def _user_defined_render(type_, object_, autogen_context): - if 'opts' in autogen_context and \ - 'render_item' in autogen_context['opts']: - render = autogen_context['opts']['render_item'] + if 'render_item' in autogen_context.opts: + render = autogen_context.opts['render_item'] if render: rendered = render(type_, object_, autogen_context) if rendered is not False: @@ -547,7 +549,7 @@ def _repr_type(type_, autogen_context): return rendered mod = type(type_).__module__ - imports = autogen_context.get('imports', None) + imports = autogen_context._imports if mod.startswith("sqlalchemy.dialects"): dname = re.match(r"sqlalchemy\.dialects\.(\w+)", mod).group(1) if imports is not None: diff --git a/alembic/operations/ops.py b/alembic/operations/ops.py index 08a0551..71e8515 100644 --- a/alembic/operations/ops.py +++ b/alembic/operations/ops.py @@ -3,6 +3,7 @@ from ..util import sqla_compat from . import schemaobj from sqlalchemy.types import NULLTYPE from .base import Operations, BatchOperations +import re class MigrateOperation(object): @@ -34,6 +35,10 @@ class MigrateOperation(object): class AddConstraintOp(MigrateOperation): """Represent an add constraint operation.""" + @property + def constraint_type(self): + raise NotImplementedError() + @classmethod def from_constraint(cls, constraint): funcs = { @@ -45,17 +50,40 @@ class AddConstraintOp(MigrateOperation): } return funcs[constraint.__visit_name__](constraint) + def reverse(self): + return DropConstraintOp.from_constraint(self.to_constraint()) + + def to_diff_tuple(self): + return ("add_constraint", self.to_constraint()) + @Operations.register_operation("drop_constraint") @BatchOperations.register_operation("drop_constraint", "batch_drop_constraint") class DropConstraintOp(MigrateOperation): """Represent a drop constraint operation.""" - def __init__(self, constraint_name, table_name, type_=None, schema=None): + def __init__( + self, + constraint_name, table_name, type_=None, schema=None, + _orig_constraint=None): self.constraint_name = constraint_name self.table_name = table_name self.constraint_type = type_ self.schema = schema + self._orig_constraint = _orig_constraint + + def reverse(self): + if self._orig_constraint is None: + raise ValueError( + "operation is not reversible; " + "original constraint is not present") + return AddConstraintOp.from_constraint(self._orig_constraint) + + def to_diff_tuple(self): + if self.constraint_type == "foreignkey": + return ("remove_fk", self.to_constraint()) + else: + return ("remove_constraint", self.to_constraint()) @classmethod def from_constraint(cls, constraint): @@ -72,9 +100,18 @@ class DropConstraintOp(MigrateOperation): constraint.name, constraint_table.name, schema=constraint_table.schema, - type_=types[constraint.__visit_name__] + type_=types[constraint.__visit_name__], + _orig_constraint=constraint ) + def to_constraint(self): + if self._orig_constraint is not None: + return self._orig_constraint + else: + raise ValueError( + "constraint cannot be produced; " + "original constraint is not present") + @classmethod @util._with_legacy_names([("type", "type_")]) def drop_constraint( @@ -124,8 +161,11 @@ class DropConstraintOp(MigrateOperation): class CreatePrimaryKeyOp(AddConstraintOp): """Represent a create primary key operation.""" + constraint_type = "primarykey" + def __init__( - self, constraint_name, table_name, columns, schema=None, **kw): + self, constraint_name, table_name, columns, + schema=None, **kw): self.constraint_name = constraint_name self.table_name = table_name self.columns = columns @@ -225,8 +265,11 @@ class CreatePrimaryKeyOp(AddConstraintOp): class CreateUniqueConstraintOp(AddConstraintOp): """Represent a create unique constraint operation.""" + constraint_type = "unique" + def __init__( - self, constraint_name, table_name, columns, schema=None, **kw): + self, constraint_name, table_name, + columns, schema=None, **kw): self.constraint_name = constraint_name self.table_name = table_name self.columns = columns @@ -342,6 +385,8 @@ class CreateUniqueConstraintOp(AddConstraintOp): class CreateForeignKeyOp(AddConstraintOp): """Represent a create foreign key constraint operation.""" + constraint_type = "foreignkey" + def __init__( self, constraint_name, source_table, referent_table, local_cols, remote_cols, **kw): @@ -352,6 +397,9 @@ class CreateForeignKeyOp(AddConstraintOp): self.remote_cols = remote_cols self.kw = kw + def to_diff_tuple(self): + return ("add_fk", self.to_constraint()) + @classmethod def from_constraint(cls, constraint): kw = {} @@ -507,6 +555,8 @@ class CreateForeignKeyOp(AddConstraintOp): class CreateCheckConstraintOp(AddConstraintOp): """Represent a create check constraint operation.""" + constraint_type = "check" + def __init__( self, constraint_name, table_name, condition, schema=None, **kw): self.constraint_name = constraint_name @@ -523,7 +573,7 @@ class CreateCheckConstraintOp(AddConstraintOp): constraint.name, constraint_table.name, constraint.condition, - schema=constraint_table.schema + schema=constraint_table.schema, ) def to_constraint(self, migration_context=None): @@ -624,6 +674,12 @@ class CreateIndexOp(MigrateOperation): self.kw = kw self._orig_index = _orig_index + def reverse(self): + return DropIndexOp.from_index(self.to_index()) + + def to_diff_tuple(self): + return ("add_index", self.to_index()) + @classmethod def from_index(cls, index): return cls( @@ -729,10 +785,22 @@ class CreateIndexOp(MigrateOperation): class DropIndexOp(MigrateOperation): """Represent a drop index operation.""" - def __init__(self, index_name, table_name=None, schema=None): + def __init__( + self, index_name, table_name=None, schema=None, _orig_index=None): self.index_name = index_name self.table_name = table_name self.schema = schema + self._orig_index = _orig_index + + def to_diff_tuple(self): + return ("remove_index", self.to_index()) + + def reverse(self): + if self._orig_index is None: + raise ValueError( + "operation is not reversible; " + "original index is not present") + return CreateIndexOp.from_index(self._orig_index) @classmethod def from_index(cls, index): @@ -740,6 +808,7 @@ class DropIndexOp(MigrateOperation): index.name, index.table.name, schema=index.table.schema, + _orig_index=index ) def to_index(self, migration_context=None): @@ -807,6 +876,12 @@ class CreateTableOp(MigrateOperation): self.kw = kw self._orig_table = _orig_table + def reverse(self): + return DropTableOp.from_table(self.to_table()) + + def to_diff_tuple(self): + return ("add_table", self.to_table()) + @classmethod def from_table(cls, table): return cls( @@ -921,16 +996,30 @@ class CreateTableOp(MigrateOperation): class DropTableOp(MigrateOperation): """Represent a drop table operation.""" - def __init__(self, table_name, schema=None, table_kw=None): + def __init__( + self, table_name, schema=None, table_kw=None, _orig_table=None): self.table_name = table_name self.schema = schema self.table_kw = table_kw or {} + self._orig_table = _orig_table + + def to_diff_tuple(self): + return ("remove_table", self.to_table()) + + def reverse(self): + if self._orig_table is None: + raise ValueError( + "operation is not reversible; " + "original table is not present") + return CreateTableOp.from_table(self._orig_table) @classmethod def from_table(cls, table): - return cls(table.name, schema=table.schema) + return cls(table.name, schema=table.schema, _orig_table=table) - def to_table(self, migration_context): + def to_table(self, migration_context=None): + if self._orig_table is not None: + return self._orig_table schema_obj = schemaobj.SchemaObjects(migration_context) return schema_obj.table( self.table_name, @@ -1029,6 +1118,87 @@ class AlterColumnOp(AlterTableOp): self.modify_type = modify_type self.kw = kw + def to_diff_tuple(self): + col_diff = [] + schema, tname, cname = self.schema, self.table_name, self.column_name + + if self.modify_type is not None: + col_diff.append( + ("modify_type", schema, tname, cname, + { + "existing_nullable": self.existing_nullable, + "existing_server_default": self.existing_server_default, + }, + self.existing_type, + self.modify_type) + ) + + if self.modify_nullable is not None: + col_diff.append( + ("modify_nullable", schema, tname, cname, + { + "existing_type": self.existing_type, + "existing_server_default": self.existing_server_default + }, + self.existing_nullable, + self.modify_nullable) + ) + + if self.modify_server_default is not False: + col_diff.append( + ("modify_default", schema, tname, cname, + { + "existing_nullable": self.existing_nullable, + "existing_type": self.existing_type + }, + self.existing_server_default, + self.modify_server_default) + ) + + return col_diff + + def has_changes(self): + hc1 = self.modify_nullable is not None or \ + self.modify_server_default is not False or \ + self.modify_type is not None + if hc1: + return True + for kw in self.kw: + if kw.startswith('modify_'): + return True + else: + return False + + def reverse(self): + + kw = self.kw.copy() + kw['existing_type'] = self.existing_type + kw['existing_nullable'] = self.existing_nullable + kw['existing_server_default'] = self.existing_server_default + if self.modify_type is not None: + kw['modify_type'] = self.modify_type + if self.modify_nullable is not None: + kw['modify_nullable'] = self.modify_nullable + if self.modify_server_default is not False: + kw['modify_server_default'] = self.modify_server_default + + # TODO: make this a little simpler + all_keys = set(m.group(1) for m in [ + re.match(r'^(?:existing_|modify_)(.+)$', k) + for k in kw + ] if m) + + for k in all_keys: + if 'modify_%s' % k in kw: + swap = kw['existing_%s' % k] + kw['existing_%s' % k] = kw['modify_%s' % k] + kw['modify_%s' % k] = swap + + return self.__class__( + self.table_name, self.column_name, schema=self.schema, + **kw + ) + @classmethod @util._with_legacy_names([('name', 'new_column_name')]) def alter_column( @@ -1177,6 +1347,13 @@ class AddColumnOp(AlterTableOp): super(AddColumnOp, self).__init__(table_name, schema=schema) self.column = column + def reverse(self): + return DropColumnOp.from_column_and_tablename( + self.schema, self.table_name, self.column) + + def to_diff_tuple(self): + return ("add_column", self.schema, self.table_name, self.column) + @classmethod def from_column(cls, col): return cls(col.table.name, col, schema=col.table.schema) @@ -1265,14 +1442,30 @@ class AddColumnOp(AlterTableOp): class DropColumnOp(AlterTableOp): """Represent a drop column operation.""" - def __init__(self, table_name, column_name, schema=None, **kw): + def __init__( + self, table_name, column_name, schema=None, + _orig_column=None, **kw): super(DropColumnOp, self).__init__(table_name, schema=schema) self.column_name = column_name self.kw = kw + self._orig_column = _orig_column + + def to_diff_tuple(self): + return ( + "remove_column", self.schema, self.table_name, self.to_column()) + + def reverse(self): + if self._orig_column is None: + raise ValueError( + "operation is not reversible; " + "original column is not present") + + return AddColumnOp.from_column_and_tablename( + self.schema, self.table_name, self._orig_column) @classmethod def from_column_and_tablename(cls, schema, tname, col): - return cls(tname, col.name, schema=schema) + return cls(tname, col.name, schema=schema, _orig_column=col) def to_column(self, migration_context=None): schema_obj = schemaobj.SchemaObjects(migration_context) @@ -1522,6 +1715,21 @@ class OpContainer(MigrateOperation): def __init__(self, ops=()): self.ops = ops + def is_empty(self): + return not self.ops + + def as_diffs(self): + return list(OpContainer._ops_as_diffs(self)) + + @classmethod + def _ops_as_diffs(cls, migrations): + for op in migrations.ops: + if hasattr(op, 'ops'): + for sub_op in cls._ops_as_diffs(op): + yield sub_op + else: + yield op.to_diff_tuple() + class ModifyTableOps(OpContainer): """Contains a sequence of operations that all apply to a single Table.""" @@ -1531,6 +1739,15 @@ class ModifyTableOps(OpContainer): self.table_name = table_name self.schema = schema + def reverse(self): + return ModifyTableOps( + self.table_name, + ops=list(reversed( + [op.reverse() for op in self.ops] + )), + schema=self.schema + ) + class UpgradeOps(OpContainer): """contains a sequence of operations that would apply to the @@ -1542,6 +1759,15 @@ class UpgradeOps(OpContainer): """ + def reverse_into(self, downgrade_ops): + downgrade_ops.ops[:] = list(reversed( + [op.reverse() for op in self.ops] + )) + return downgrade_ops + + def reverse(self): + return self.reverse_into(DowngradeOps(ops=[])) + class DowngradeOps(OpContainer): """contains a sequence of operations that would apply to the @@ -1553,6 +1779,13 @@ class DowngradeOps(OpContainer): """ + def reverse(self): + return UpgradeOps( + ops=list(reversed( + [op.reverse() for op in self.ops] + )) + ) + class MigrationScript(MigrateOperation): """represents a migration script. @@ -1583,4 +1816,3 @@ class MigrationScript(MigrateOperation): self.version_path = version_path self.upgrade_ops = upgrade_ops self.downgrade_ops = downgrade_ops - diff --git a/alembic/runtime/migration.py b/alembic/runtime/migration.py index 84a3c7f..e811a36 100644 --- a/alembic/runtime/migration.py +++ b/alembic/runtime/migration.py @@ -118,6 +118,7 @@ class MigrationContext(object): connection=None, url=None, dialect_name=None, + dialect=None, environment_context=None, opts=None, ): @@ -152,7 +153,7 @@ class MigrationContext(object): elif dialect_name: url = sqla_url.make_url("%s://" % dialect_name) dialect = url.get_dialect()() - else: + elif not dialect: raise Exception("Connection, url, or dialect_name is required.") return MigrationContext(dialect, connection, opts, environment_context) diff --git a/alembic/util/langhelpers.py b/alembic/util/langhelpers.py index 1fb0942..6c92e3c 100644 --- a/alembic/util/langhelpers.py +++ b/alembic/util/langhelpers.py @@ -257,30 +257,57 @@ class immutabledict(dict): class Dispatcher(object): - def __init__(self): + def __init__(self, uselist=False): self._registry = {} + self.uselist = uselist def dispatch_for(self, target, qualifier='default'): def decorate(fn): - assert isinstance(target, type) - assert target not in self._registry - self._registry[(target, qualifier)] = fn + if self.uselist: + assert target not in self._registry + self._registry.setdefault((target, qualifier), []).append(fn) + else: + assert target not in self._registry + self._registry[(target, qualifier)] = fn return fn return decorate def dispatch(self, obj, qualifier='default'): - for spcls in type(obj).__mro__: + + if isinstance(obj, string_types): + targets = [obj] + elif isinstance(obj, type): + targets = obj.__mro__ + else: + targets = type(obj).__mro__ + + for spcls in targets: if qualifier != 'default' and (spcls, qualifier) in self._registry: - return self._registry[(spcls, qualifier)] + return self._fn_or_list(self._registry[(spcls, qualifier)]) elif (spcls, 'default') in self._registry: - return self._registry[(spcls, 'default')] + return self._fn_or_list(self._registry[(spcls, 'default')]) else: raise ValueError("no dispatch function for object: %s" % obj) + def _fn_or_list(self, fn_or_list): + if self.uselist: + def go(*arg, **kw): + for fn in fn_or_list: + fn(*arg, **kw) + return go + else: + return fn_or_list + def branch(self): """Return a copy of this dispatcher that is independently writable.""" d = Dispatcher() - d._registry.update(self._registry) + if self.uselist: + d._registry.update( + (k, [fn for fn in self._registry[k]]) + for k in self._registry + ) + else: + d._registry.update(self._registry) return d |