diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2015-07-16 19:00:55 -0400 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2015-07-16 19:00:55 -0400 |
commit | c5be31760e4bc57b898d1d69d4bb0dd7c2dc7eb7 (patch) | |
tree | f158a8b62d7679176f3f60d1b1bf188a6383e03c | |
parent | 96214629cdb13f1694831f36c48a7ec86dd8c7f6 (diff) | |
download | alembic-c5be31760e4bc57b898d1d69d4bb0dd7c2dc7eb7.tar.gz |
- rework all of autogenerate to build directly on alembic.operations.ops
objects; the "diffs" is now a legacy system that is exported from
the ops. A new model of comparison/rendering/ upgrade/downgrade
composition that is cleaner and much more extensible is introduced.
- autogenerate is now extensible as far as database objects compared
and rendered into scripts; any new operation directive can also be
registered into a series of hooks that allow custom database/model
comparison functions to run as well as to render new operation
directives into autogenerate scripts.
- write all new docs for the new system
fixes #306
-rw-r--r-- | alembic/autogenerate/__init__.py | 6 | ||||
-rw-r--r-- | alembic/autogenerate/api.py | 306 | ||||
-rw-r--r-- | alembic/autogenerate/compare.py | 336 | ||||
-rw-r--r-- | alembic/autogenerate/compose.py | 144 | ||||
-rw-r--r-- | alembic/autogenerate/generate.py | 92 | ||||
-rw-r--r-- | alembic/autogenerate/render.py | 70 | ||||
-rw-r--r-- | alembic/operations/ops.py | 256 | ||||
-rw-r--r-- | alembic/runtime/migration.py | 3 | ||||
-rw-r--r-- | alembic/util/langhelpers.py | 43 | ||||
-rw-r--r-- | docs/build/api/autogenerate.rst | 255 | ||||
-rw-r--r-- | docs/build/api/operations.rst | 66 | ||||
-rw-r--r-- | docs/build/changelog.rst | 11 | ||||
-rw-r--r-- | tests/_autogen_fixtures.py | 41 | ||||
-rw-r--r-- | tests/test_autogen_composition.py | 21 | ||||
-rw-r--r-- | tests/test_autogen_diffs.py | 203 | ||||
-rw-r--r-- | tests/test_autogen_fks.py | 6 | ||||
-rw-r--r-- | tests/test_autogen_indexes.py | 12 | ||||
-rw-r--r-- | tests/test_autogen_render.py | 177 | ||||
-rw-r--r-- | tests/test_postgresql.py | 67 |
19 files changed, 1263 insertions, 852 deletions
diff --git a/alembic/autogenerate/__init__.py b/alembic/autogenerate/__init__.py index 4272a7e..78520a8 100644 --- a/alembic/autogenerate/__init__.py +++ b/alembic/autogenerate/__init__.py @@ -1,7 +1,7 @@ from .api import ( # noqa compare_metadata, _render_migration_diffs, - produce_migrations, render_python_code + produce_migrations, render_python_code, + RevisionContext ) -from .compare import _produce_net_changes # noqa -from .generate import RevisionContext # noqa +from .compare import _produce_net_changes, comparators # noqa from .render import render_op_text, renderers # noqa
\ No newline at end of file diff --git a/alembic/autogenerate/api.py b/alembic/autogenerate/api.py index cff977b..e9af4cf 100644 --- a/alembic/autogenerate/api.py +++ b/alembic/autogenerate/api.py @@ -4,8 +4,9 @@ automatically.""" from ..operations import ops from . import render from . import compare -from . import compose from .. import util +from sqlalchemy.engine.reflection import Inspector +import contextlib def compare_metadata(context, metadata): @@ -98,20 +99,8 @@ def compare_metadata(context, metadata): """ - autogen_context = _autogen_context(context, metadata=metadata) - - # as_sql=True is nonsensical here. autogenerate requires a connection - # it can use to run queries against to get the database schema. - if context.as_sql: - raise util.CommandError( - "autogenerate can't use as_sql=True as it prevents querying " - "the database for schema information") - - diffs = [] - - compare._produce_net_changes(autogen_context, diffs) - - return diffs + migration_script = produce_migrations(context, metadata) + return migration_script.upgrade_ops.as_diffs() def produce_migrations(context, metadata): @@ -132,10 +121,7 @@ def produce_migrations(context, metadata): """ - autogen_context = _autogen_context(context, metadata=metadata) - diffs = [] - - compare._produce_net_changes(autogen_context, diffs) + autogen_context = AutogenContext(context, metadata=metadata) migration_script = ops.MigrationScript( rev_id=None, @@ -143,7 +129,7 @@ def produce_migrations(context, metadata): downgrade_ops=ops.DowngradeOps([]), ) - compose._to_migration_script(autogen_context, migration_script, diffs) + compare._populate_migration_script(autogen_context, migration_script) return migration_script @@ -152,6 +138,7 @@ def render_python_code( up_or_down_op, sqlalchemy_module_prefix='sa.', alembic_module_prefix='op.', + render_as_batch=False, imports=(), render_item=None, ): @@ -162,84 +149,239 @@ def render_python_code( autogenerate output of a user-defined :class:`.MigrationScript` structure. """ - autogen_context = { - 'opts': { - 'sqlalchemy_module_prefix': sqlalchemy_module_prefix, - 'alembic_module_prefix': alembic_module_prefix, - 'render_item': render_item, - }, - 'imports': set(imports) + opts = { + 'sqlalchemy_module_prefix': sqlalchemy_module_prefix, + 'alembic_module_prefix': alembic_module_prefix, + 'render_item': render_item, + 'render_as_batch': render_as_batch, } + + autogen_context = AutogenContext(None, opts=opts) + autogen_context.imports = set(imports) return render._indent(render._render_cmd_body( up_or_down_op, autogen_context)) - - -def _render_migration_diffs(context, template_args, imports): +def _render_migration_diffs(context, template_args): """legacy, used by test_autogen_composition at the moment""" - migration_script = produce_migrations(context, None) - - autogen_context = _autogen_context(context, imports=imports) - diffs = [] + autogen_context = AutogenContext(context) - compare._produce_net_changes(autogen_context, diffs) + upgrade_ops = ops.UpgradeOps([]) + compare._produce_net_changes(autogen_context, upgrade_ops) migration_script = ops.MigrationScript( rev_id=None, - imports=imports, - upgrade_ops=ops.UpgradeOps([]), - downgrade_ops=ops.DowngradeOps([]), + upgrade_ops=upgrade_ops, + downgrade_ops=upgrade_ops.reverse(), ) - compose._to_migration_script(autogen_context, migration_script, diffs) - render._render_migration_script( autogen_context, migration_script, template_args ) -def _autogen_context( - context, imports=None, metadata=None, include_symbol=None, - include_object=None, include_schemas=False): - - opts = context.opts - metadata = opts['target_metadata'] if metadata is None else metadata - include_schemas = opts.get('include_schemas', include_schemas) - - include_symbol = opts.get('include_symbol', include_symbol) - include_object = opts.get('include_object', include_object) - - object_filters = [] - if include_symbol: - def include_symbol_filter(object, name, type_, reflected, compare_to): - if type_ == "table": - return include_symbol(name, object.schema) - else: - return True - object_filters.append(include_symbol_filter) - if include_object: - object_filters.append(include_object) - - if metadata is None: - raise util.CommandError( - "Can't proceed with --autogenerate option; environment " - "script %s does not provide " - "a MetaData object to the context." % ( - context.script.env_py_location - )) - - opts = context.opts - connection = context.bind - return { - 'imports': imports if imports is not None else set(), - 'connection': connection, - 'dialect': connection.dialect, - 'context': context, - 'opts': opts, - 'metadata': metadata, - 'object_filters': object_filters, - 'include_schemas': include_schemas - } +class AutogenContext(object): + """Maintains configuration and state that's specific to an + autogenerate operation.""" + + metadata = None + """The :class:`~sqlalchemy.schema.MetaData` object + representing the destination. + + This object is the one that is passed within ``env.py`` + to the :paramref:`.EnvironmentContext.configure.target_metadata` + parameter. It represents the structure of :class:`.Table` and other + objects as stated in the current database model, and represents the + destination structure for the database being examined. + + While the :class:`~sqlalchemy.schema.MetaData` object is primarily + known as a collection of :class:`~sqlalchemy.schema.Table` objects, + it also has an :attr:`~sqlalchemy.schema.MetaData.info` dictionary + that may be used by end-user schemes to store additional schema-level + objects that are to be compared in custom autogeneration schemes. + + """ + + connection = None + """The :class:`~sqlalchemy.engine.base.Connection` object currently + connected to the database backend being compared. + + This is obtained from the :attr:`.MigrationContext.bind` and is + utimately set up in the ``env.py`` script. + + """ + + dialect = None + """The :class:`~sqlalchemy.engine.Dialect` object currently in use. + + This is normally obtained from the + :attr:`~sqlalchemy.engine.base.Connection.dialect` attribute. + + """ + + migration_context = None + """The :class:`.MigrationContext` established by the ``env.py`` script.""" + + def __init__(self, migration_context, metadata=None, opts=None): + + if migration_context is not None and migration_context.as_sql: + raise util.CommandError( + "autogenerate can't use as_sql=True as it prevents querying " + "the database for schema information") + + if opts is None: + opts = migration_context.opts + self.metadata = metadata = opts.get('target_metadata', None) \ + if metadata is None else metadata + + if metadata is None and \ + migration_context is not None and \ + migration_context.script is not None: + raise util.CommandError( + "Can't proceed with --autogenerate option; environment " + "script %s does not provide " + "a MetaData object to the context." % ( + migration_context.script.env_py_location + )) + + include_symbol = opts.get('include_symbol', None) + include_object = opts.get('include_object', None) + + object_filters = [] + if include_symbol: + def include_symbol_filter( + object, name, type_, reflected, compare_to): + if type_ == "table": + return include_symbol(name, object.schema) + else: + return True + object_filters.append(include_symbol_filter) + if include_object: + object_filters.append(include_object) + + self._object_filters = object_filters + + self.migration_context = migration_context + if self.migration_context is not None: + self.connection = self.migration_context.bind + self.dialect = self.migration_context.dialect + + self._imports = set() + self.opts = opts + self._has_batch = False + + @util.memoized_property + def inspector(self): + return Inspector.from_engine(self.connection) + + @contextlib.contextmanager + def _within_batch(self): + self._has_batch = True + yield + self._has_batch = False + + def run_filters(self, object_, name, type_, reflected, compare_to): + """Run the context's object filters and return True if the targets + should be part of the autogenerate operation. + + This method should be run for every kind of object encountered within + an autogenerate operation, giving the environment the chance + to filter what objects should be included in the comparison. + The filters here are produced directly via the + :paramref:`.EnvironmentContext.configure.include_object` + and :paramref:`.EnvironmentContext.configure.include_symbol` + functions, if present. + + """ + for fn in self._object_filters: + if not fn(object_, name, type_, reflected, compare_to): + return False + else: + return True + + +class RevisionContext(object): + """Maintains configuration and state that's specific to a revision + file generation operation.""" + + def __init__(self, config, script_directory, command_args): + self.config = config + self.script_directory = script_directory + self.command_args = command_args + self.template_args = { + 'config': config # Let templates use config for + # e.g. multiple databases + } + self.generated_revisions = [ + self._default_revision() + ] + + def _to_script(self, migration_script): + template_args = {} + for k, v in self.template_args.items(): + template_args.setdefault(k, v) + + if migration_script._autogen_context is not None: + render._render_migration_script( + migration_script._autogen_context, migration_script, + template_args + ) + + return self.script_directory.generate_revision( + migration_script.rev_id, + migration_script.message, + refresh=True, + head=migration_script.head, + splice=migration_script.splice, + branch_labels=migration_script.branch_label, + version_path=migration_script.version_path, + **template_args) + + def run_autogenerate(self, rev, context): + if self.command_args['sql']: + raise util.CommandError( + "Using --sql with --autogenerate does not make any sense") + if set(self.script_directory.get_revisions(rev)) != \ + set(self.script_directory.get_revisions("heads")): + raise util.CommandError("Target database is not up to date.") + + autogen_context = AutogenContext(context) + + migration_script = self.generated_revisions[0] + + compare._populate_migration_script(autogen_context, migration_script) + + hook = context.opts.get('process_revision_directives', None) + if hook: + hook(context, rev, self.generated_revisions) + + for migration_script in self.generated_revisions: + migration_script._autogen_context = autogen_context + + def run_no_autogenerate(self, rev, context): + hook = context.opts.get('process_revision_directives', None) + if hook: + hook(context, rev, self.generated_revisions) + + for migration_script in self.generated_revisions: + migration_script._autogen_context = None + + def _default_revision(self): + op = ops.MigrationScript( + rev_id=self.command_args['rev_id'] or util.rev_id(), + message=self.command_args['message'], + imports=set(), + upgrade_ops=ops.UpgradeOps([]), + downgrade_ops=ops.DowngradeOps([]), + head=self.command_args['head'], + splice=self.command_args['splice'], + branch_label=self.command_args['branch_label'], + version_path=self.command_args['version_path'] + ) + op._autogen_context = None + return op + def generate_scripts(self): + for generated_revision in self.generated_revisions: + yield self._to_script(generated_revision) diff --git a/alembic/autogenerate/compare.py b/alembic/autogenerate/compare.py index cd6b696..fdc3cae 100644 --- a/alembic/autogenerate/compare.py +++ b/alembic/autogenerate/compare.py @@ -1,7 +1,9 @@ from sqlalchemy import schema as sa_schema, types as sqltypes from sqlalchemy.engine.reflection import Inspector from sqlalchemy import event +from ..operations import ops import logging +from .. import util from ..util import compat from ..util import sqla_compat from sqlalchemy.util import OrderedSet @@ -13,15 +15,20 @@ from alembic.ddl.base import _fk_spec log = logging.getLogger(__name__) -def _produce_net_changes(autogen_context, diffs): +def _populate_migration_script(autogen_context, migration_script): + _produce_net_changes(autogen_context, migration_script.upgrade_ops) + migration_script.upgrade_ops.reverse_into(migration_script.downgrade_ops) - metadata = autogen_context['metadata'] - connection = autogen_context['connection'] - object_filters = autogen_context.get('object_filters', ()) - include_schemas = autogen_context.get('include_schemas', False) + +comparators = util.Dispatcher(uselist=True) + + +def _produce_net_changes(autogen_context, upgrade_ops): + + connection = autogen_context.connection + include_schemas = autogen_context.opts.get('include_schemas', False) inspector = Inspector.from_engine(connection) - conn_table_names = set() default_schema = connection.dialect.default_schema_name if include_schemas: @@ -34,14 +41,28 @@ def _produce_net_changes(autogen_context, diffs): else: schemas = [None] - version_table_schema = autogen_context['context'].version_table_schema - version_table = autogen_context['context'].version_table + comparators.dispatch("schema", autogen_context.dialect.name)( + autogen_context, upgrade_ops, schemas + ) + + +@comparators.dispatch_for("schema") +def _autogen_for_tables(autogen_context, upgrade_ops, schemas): + inspector = autogen_context.inspector + + metadata = autogen_context.metadata + + conn_table_names = set() + + version_table_schema = \ + autogen_context.migration_context.version_table_schema + version_table = autogen_context.migration_context.version_table for s in schemas: tables = set(inspector.get_table_names(schema=s)) if s == version_table_schema: tables = tables.difference( - [autogen_context['context'].version_table] + [autogen_context.migration_context.version_table] ) conn_table_names.update(zip([s] * len(tables), tables)) @@ -50,21 +71,11 @@ def _produce_net_changes(autogen_context, diffs): ).difference([(version_table_schema, version_table)]) _compare_tables(conn_table_names, metadata_table_names, - object_filters, - inspector, metadata, diffs, autogen_context) - - -def _run_filters(object_, name, type_, reflected, compare_to, object_filters): - for fn in object_filters: - if not fn(object_, name, type_, reflected, compare_to): - return False - else: - return True + inspector, metadata, upgrade_ops, autogen_context) def _compare_tables(conn_table_names, metadata_table_names, - object_filters, - inspector, metadata, diffs, autogen_context): + inspector, metadata, upgrade_ops, autogen_context): default_schema = inspector.bind.dialect.default_schema_name @@ -95,14 +106,19 @@ def _compare_tables(conn_table_names, metadata_table_names, for s, tname in metadata_table_names.difference(conn_table_names): name = '%s.%s' % (s, tname) if s else tname metadata_table = tname_to_table[(s, tname)] - if _run_filters( - metadata_table, tname, "table", False, None, object_filters): - diffs.append(("add_table", metadata_table)) + if autogen_context.run_filters( + metadata_table, tname, "table", False, None): + upgrade_ops.ops.append( + ops.CreateTableOp.from_table(metadata_table)) log.info("Detected added table %r", name) - _compare_indexes_and_uniques(s, tname, object_filters, - None, - metadata_table, - diffs, autogen_context, inspector) + modify_table_ops = ops.ModifyTableOps(tname, [], schema=s) + + comparators.dispatch("table")( + autogen_context, modify_table_ops, + s, tname, None, metadata_table + ) + if not modify_table_ops.is_empty(): + upgrade_ops.ops.append(modify_table_ops) removal_metadata = sa_schema.MetaData() for s, tname in conn_table_names.difference(metadata_table_names): @@ -114,11 +130,13 @@ def _compare_tables(conn_table_names, metadata_table_names, event.listen( t, "column_reflect", - autogen_context['context'].impl. + autogen_context.migration_context.impl. _compat_autogen_column_reflect(inspector)) inspector.reflecttable(t, None) - if _run_filters(t, tname, "table", True, None, object_filters): - diffs.append(("remove_table", t)) + if autogen_context.run_filters(t, tname, "table", True, None): + upgrade_ops.ops.append( + ops.DropTableOp.from_table(t) + ) log.info("Detected removed table %r", name) existing_tables = conn_table_names.intersection(metadata_table_names) @@ -133,7 +151,7 @@ def _compare_tables(conn_table_names, metadata_table_names, event.listen( t, "column_reflect", - autogen_context['context'].impl. + autogen_context.migration_context.impl. _compat_autogen_column_reflect(inspector)) inspector.reflecttable(t, None) conn_column_info[(s, tname)] = t @@ -144,25 +162,24 @@ def _compare_tables(conn_table_names, metadata_table_names, metadata_table = tname_to_table[(s, tname)] conn_table = existing_metadata.tables[name] - if _run_filters( + if autogen_context.run_filters( metadata_table, tname, "table", False, - conn_table, object_filters): + conn_table): + + modify_table_ops = ops.ModifyTableOps(tname, [], schema=s) with _compare_columns( - s, tname, object_filters, + s, tname, conn_table, metadata_table, - diffs, autogen_context, inspector): - _compare_indexes_and_uniques(s, tname, object_filters, - conn_table, - metadata_table, - diffs, autogen_context, inspector) - _compare_foreign_keys(s, tname, object_filters, conn_table, - metadata_table, diffs, autogen_context, - inspector) + modify_table_ops, autogen_context, inspector): - # TODO: - # table constraints - # sequences + comparators.dispatch("table")( + autogen_context, modify_table_ops, + s, tname, conn_table, metadata_table + ) + + if not modify_table_ops.is_empty(): + upgrade_ops.ops.append(modify_table_ops) def _make_index(params, conn_table): @@ -202,56 +219,51 @@ def _make_foreign_key(params, conn_table): @contextlib.contextmanager -def _compare_columns(schema, tname, object_filters, conn_table, metadata_table, - diffs, autogen_context, inspector): +def _compare_columns(schema, tname, conn_table, metadata_table, + modify_table_ops, autogen_context, inspector): name = '%s.%s' % (schema, tname) if schema else tname metadata_cols_by_name = dict((c.name, c) for c in metadata_table.c) conn_col_names = dict((c.name, c) for c in conn_table.c) metadata_col_names = OrderedSet(sorted(metadata_cols_by_name)) for cname in metadata_col_names.difference(conn_col_names): - if _run_filters(metadata_cols_by_name[cname], cname, - "column", False, None, object_filters): - diffs.append( - ("add_column", schema, tname, metadata_cols_by_name[cname]) + if autogen_context.run_filters( + metadata_cols_by_name[cname], cname, + "column", False, None): + modify_table_ops.ops.append( + ops.AddColumnOp.from_column_and_tablename( + schema, tname, metadata_cols_by_name[cname]) ) log.info("Detected added column '%s.%s'", name, cname) for colname in metadata_col_names.intersection(conn_col_names): metadata_col = metadata_cols_by_name[colname] conn_col = conn_table.c[colname] - if not _run_filters( + if not autogen_context.run_filters( metadata_col, colname, "column", False, - conn_col, object_filters): + conn_col): continue - col_diff = [] - _compare_type(schema, tname, colname, - conn_col, - metadata_col, - col_diff, autogen_context - ) - # work around SQLAlchemy issue #3023 - if not metadata_col.primary_key: - _compare_nullable(schema, tname, colname, - conn_col, - metadata_col.nullable, - col_diff, autogen_context - ) - _compare_server_default(schema, tname, colname, - conn_col, - metadata_col, - col_diff, autogen_context - ) - if col_diff: - diffs.append(col_diff) + alter_column_op = ops.AlterColumnOp( + tname, colname, schema=schema) + + comparators.dispatch("column")( + autogen_context, alter_column_op, + schema, tname, colname, conn_col, metadata_col + ) + + if alter_column_op.has_changes(): + modify_table_ops.ops.append(alter_column_op) yield for cname in set(conn_col_names).difference(metadata_col_names): - if _run_filters(conn_table.c[cname], cname, - "column", True, None, object_filters): - diffs.append( - ("remove_column", schema, tname, conn_table.c[cname]) + if autogen_context.run_filters( + conn_table.c[cname], cname, + "column", True, None): + modify_table_ops.ops.append( + ops.DropColumnOp.from_column_and_tablename( + schema, tname, conn_table.c[cname] + ) ) log.info("Detected removed column '%s.%s'", name, cname) @@ -310,10 +322,12 @@ class _fk_constraint_sig(_constraint_sig): ) -def _compare_indexes_and_uniques(schema, tname, object_filters, conn_table, - metadata_table, diffs, - autogen_context, inspector): +@comparators.dispatch_for("table") +def _compare_indexes_and_uniques( + autogen_context, modify_ops, schema, tname, conn_table, + metadata_table): + inspector = autogen_context.inspector is_create_table = conn_table is None # 1a. get raw indexes and unique constraints from metadata ... @@ -350,7 +364,7 @@ def _compare_indexes_and_uniques(schema, tname, object_filters, conn_table, # 3. give the dialect a chance to omit indexes and constraints that # we know are either added implicitly by the DB or that the DB # can't accurately report on - autogen_context['context'].impl.\ + autogen_context.migration_context.impl.\ correct_for_autogen_constraints( conn_uniques, conn_indexes, metadata_unique_constraints, @@ -411,9 +425,11 @@ def _compare_indexes_and_uniques(schema, tname, object_filters, conn_table, def obj_added(obj): if obj.is_index: - if _run_filters( - obj.const, obj.name, "index", False, None, object_filters): - diffs.append(("add_index", obj.const)) + if autogen_context.run_filters( + obj.const, obj.name, "index", False, None): + modify_ops.ops.append( + ops.CreateIndexOp.from_index(obj.const) + ) log.info("Detected added index '%s' on %s", obj.name, ', '.join([ "'%s'" % obj.column_names @@ -426,10 +442,12 @@ def _compare_indexes_and_uniques(schema, tname, object_filters, conn_table, if is_create_table: # unique constraints are created inline with table defs return - if _run_filters( + if autogen_context.run_filters( obj.const, obj.name, - "unique_constraint", False, None, object_filters): - diffs.append(("add_constraint", obj.const)) + "unique_constraint", False, None): + modify_ops.ops.append( + ops.AddConstraintOp.from_constraint(obj.const) + ) log.info("Detected added unique constraint '%s' on %s", obj.name, ', '.join([ "'%s'" % obj.column_names @@ -443,39 +461,51 @@ def _compare_indexes_and_uniques(schema, tname, object_filters, conn_table, # be sure what we're doing here return - if _run_filters( - obj.const, obj.name, "index", True, None, object_filters): - diffs.append(("remove_index", obj.const)) + if autogen_context.run_filters( + obj.const, obj.name, "index", True, None): + modify_ops.ops.append( + ops.DropIndexOp.from_index(obj.const) + ) log.info( "Detected removed index '%s' on '%s'", obj.name, tname) else: - if _run_filters( + if autogen_context.run_filters( obj.const, obj.name, - "unique_constraint", True, None, object_filters): - diffs.append(("remove_constraint", obj.const)) + "unique_constraint", True, None): + modify_ops.ops.append( + ops.DropConstraintOp.from_constraint(obj.const) + ) log.info("Detected removed unique constraint '%s' on '%s'", obj.name, tname ) def obj_changed(old, new, msg): if old.is_index: - if _run_filters( + if autogen_context.run_filters( new.const, new.name, "index", - False, old.const, object_filters): + False, old.const): log.info("Detected changed index '%s' on '%s':%s", old.name, tname, ', '.join(msg) ) - diffs.append(("remove_index", old.const)) - diffs.append(("add_index", new.const)) + modify_ops.ops.append( + ops.DropIndexOp.from_index(old.const) + ) + modify_ops.ops.append( + ops.CreateIndexOp.from_index(new.const) + ) else: - if _run_filters( + if autogen_context.run_filters( new.const, new.name, - "unique_constraint", False, old.const, object_filters): + "unique_constraint", False, old.const): log.info("Detected changed unique constraint '%s' on '%s':%s", old.name, tname, ', '.join(msg) ) - diffs.append(("remove_constraint", old.const)) - diffs.append(("add_constraint", new.const)) + modify_ops.ops.append( + ops.DropConstraintOp.from_constraint(old.const) + ) + modify_ops.ops.append( + ops.AddConstraintOp.from_constraint(new.const) + ) for added_name in sorted(set(metadata_names).difference(conn_names)): obj = metadata_names[added_name] @@ -528,20 +558,21 @@ def _compare_indexes_and_uniques(schema, tname, object_filters, conn_table, obj_added(unnamed_metadata_uniques[uq_sig]) -def _compare_nullable(schema, tname, cname, conn_col, - metadata_col_nullable, diffs, - autogen_context): +@comparators.dispatch_for("column") +def _compare_nullable( + autogen_context, alter_column_op, schema, tname, cname, conn_col, + metadata_col): + + # work around SQLAlchemy issue #3023 + if metadata_col.primary_key: + return + + metadata_col_nullable = metadata_col.nullable conn_col_nullable = conn_col.nullable + alter_column_op.existing_nullable = conn_col_nullable + if conn_col_nullable is not metadata_col_nullable: - diffs.append( - ("modify_nullable", schema, tname, cname, - { - "existing_type": conn_col.type, - "existing_server_default": conn_col.server_default, - }, - conn_col_nullable, - metadata_col_nullable), - ) + alter_column_op.modify_nullable = metadata_col_nullable log.info("Detected %s on column '%s.%s'", "NULL" if metadata_col_nullable else "NOT NULL", tname, @@ -549,11 +580,13 @@ def _compare_nullable(schema, tname, cname, conn_col, ) -def _compare_type(schema, tname, cname, conn_col, - metadata_col, diffs, - autogen_context): +@comparators.dispatch_for("column") +def _compare_type( + autogen_context, alter_column_op, schema, tname, cname, conn_col, + metadata_col): conn_type = conn_col.type + alter_column_op.existing_type = conn_type metadata_type = metadata_col.type if conn_type._type_affinity is sqltypes.NullType: log.info("Couldn't determine database type " @@ -564,19 +597,11 @@ def _compare_type(schema, tname, cname, conn_col, "the model; can't compare", tname, cname) return - isdiff = autogen_context['context']._compare_type(conn_col, metadata_col) + isdiff = autogen_context.migration_context._compare_type( + conn_col, metadata_col) if isdiff: - - diffs.append( - ("modify_type", schema, tname, cname, - { - "existing_nullable": conn_col.nullable, - "existing_server_default": conn_col.server_default, - }, - conn_type, - metadata_type), - ) + alter_column_op.modify_type = metadata_type log.info("Detected type change from %r to %r on '%s.%s'", conn_type, metadata_type, tname, cname ) @@ -594,7 +619,7 @@ def _render_server_default_for_compare(metadata_default, metadata_default = metadata_default.arg else: metadata_default = str(metadata_default.arg.compile( - dialect=autogen_context['dialect'])) + dialect=autogen_context.dialect)) if isinstance(metadata_default, compat.string_types): if metadata_col.type._type_affinity is sqltypes.String: metadata_default = re.sub(r"^'|'$", "", metadata_default) @@ -605,8 +630,10 @@ def _render_server_default_for_compare(metadata_default, return None -def _compare_server_default(schema, tname, cname, conn_col, metadata_col, - diffs, autogen_context): +@comparators.dispatch_for("column") +def _compare_server_default( + autogen_context, alter_column_op, schema, tname, cname, + conn_col, metadata_col): metadata_default = metadata_col.server_default conn_col_default = conn_col.server_default @@ -618,36 +645,31 @@ def _compare_server_default(schema, tname, cname, conn_col, metadata_col, rendered_conn_default = conn_col.server_default.arg.text \ if conn_col.server_default else None - isdiff = autogen_context['context']._compare_server_default( + alter_column_op.existing_server_default = conn_col_default + + isdiff = autogen_context.migration_context._compare_server_default( conn_col, metadata_col, rendered_metadata_default, rendered_conn_default ) if isdiff: - conn_col_default = rendered_conn_default - diffs.append( - ("modify_default", schema, tname, cname, - { - "existing_nullable": conn_col.nullable, - "existing_type": conn_col.type, - }, - conn_col_default, - metadata_default), - ) - log.info("Detected server default on column '%s.%s'", - tname, - cname - ) + alter_column_op.modify_server_default = metadata_default + log.info( + "Detected server default on column '%s.%s'", + tname, cname) -def _compare_foreign_keys(schema, tname, object_filters, conn_table, - metadata_table, diffs, autogen_context, inspector): +@comparators.dispatch_for("table") +def _compare_foreign_keys( + autogen_context, modify_table_ops, schema, tname, conn_table, + metadata_table): # if we're doing CREATE TABLE, all FKs are created # inline within the table def if conn_table is None: return + inspector = autogen_context.inspector metadata_fks = set( fk for fk in metadata_table.constraints if isinstance(fk, sa_schema.ForeignKeyConstraint) @@ -673,10 +695,12 @@ def _compare_foreign_keys(schema, tname, object_filters, conn_table, ) def _add_fk(obj, compare_to): - if _run_filters( + if autogen_context.run_filters( obj.const, obj.name, "foreign_key_constraint", False, - compare_to, object_filters): - diffs.append(('add_fk', const.const)) + compare_to): + modify_table_ops.ops.append( + ops.CreateForeignKeyOp.from_constraint(const.const) + ) log.info( "Detected added foreign key (%s)(%s) on table %s%s", @@ -686,10 +710,12 @@ def _compare_foreign_keys(schema, tname, object_filters, conn_table, obj.source_table) def _remove_fk(obj, compare_to): - if _run_filters( + if autogen_context.run_filters( obj.const, obj.name, "foreign_key_constraint", True, - compare_to, object_filters): - diffs.append(('remove_fk', obj.const)) + compare_to): + modify_table_ops.ops.append( + ops.DropConstraintOp.from_constraint(obj.const) + ) log.info( "Detected removed foreign key (%s)(%s) on table %s%s", ", ".join(obj.source_columns), @@ -713,5 +739,3 @@ def _compare_foreign_keys(schema, tname, object_filters, conn_table, compare_to = conn_fks_by_name[const.name].const \ if const.name in conn_fks_by_name else None _add_fk(const, compare_to) - - return diffs diff --git a/alembic/autogenerate/compose.py b/alembic/autogenerate/compose.py deleted file mode 100644 index b42b505..0000000 --- a/alembic/autogenerate/compose.py +++ /dev/null @@ -1,144 +0,0 @@ -import itertools -from ..operations import ops - - -def _to_migration_script(autogen_context, migration_script, diffs): - _to_upgrade_op( - autogen_context, - diffs, - migration_script.upgrade_ops, - ) - - _to_downgrade_op( - autogen_context, - diffs, - migration_script.downgrade_ops, - ) - - -def _to_upgrade_op(autogen_context, diffs, upgrade_ops): - return _to_updown_op(autogen_context, diffs, upgrade_ops, "upgrade") - - -def _to_downgrade_op(autogen_context, diffs, downgrade_ops): - return _to_updown_op(autogen_context, diffs, downgrade_ops, "downgrade") - - -def _to_updown_op(autogen_context, diffs, op_container, type_): - if not diffs: - return - - if type_ == 'downgrade': - diffs = reversed(diffs) - - dest = [op_container.ops] - - for (schema, tablename), subdiffs in _group_diffs_by_table(diffs): - subdiffs = list(subdiffs) - if tablename is not None: - table_ops = [] - op = ops.ModifyTableOps(tablename, table_ops, schema=schema) - dest[-1].append(op) - dest.append(table_ops) - for diff in subdiffs: - _produce_command(autogen_context, diff, dest[-1], type_) - if tablename is not None: - dest.pop(-1) - - -def _produce_command(autogen_context, diff, op_list, updown): - if isinstance(diff, tuple): - _produce_adddrop_command(updown, diff, op_list, autogen_context) - else: - _produce_modify_command(updown, diff, op_list, autogen_context) - - -def _produce_adddrop_command(updown, diff, op_list, autogen_context): - cmd_type = diff[0] - adddrop, cmd_type = cmd_type.split("_") - - cmd_args = diff[1:] - - _commands = { - "table": (ops.DropTableOp.from_table, ops.CreateTableOp.from_table), - "column": ( - ops.DropColumnOp.from_column_and_tablename, - ops.AddColumnOp.from_column_and_tablename), - "index": (ops.DropIndexOp.from_index, ops.CreateIndexOp.from_index), - "constraint": ( - ops.DropConstraintOp.from_constraint, - ops.AddConstraintOp.from_constraint), - "fk": ( - ops.DropConstraintOp.from_constraint, - ops.CreateForeignKeyOp.from_constraint) - } - - cmd_callables = _commands[cmd_type] - - if ( - updown == "upgrade" and adddrop == "add" - ) or ( - updown == "downgrade" and adddrop == "remove" - ): - op_list.append(cmd_callables[1](*cmd_args)) - else: - op_list.append(cmd_callables[0](*cmd_args)) - - -def _produce_modify_command(updown, diffs, op_list, autogen_context): - sname, tname, cname = diffs[0][1:4] - kw = {} - - _arg_struct = { - "modify_type": ("existing_type", "modify_type"), - "modify_nullable": ("existing_nullable", "modify_nullable"), - "modify_default": ("existing_server_default", "modify_server_default"), - } - for diff in diffs: - diff_kw = diff[4] - for arg in ("existing_type", - "existing_nullable", - "existing_server_default"): - if arg in diff_kw: - kw.setdefault(arg, diff_kw[arg]) - old_kw, new_kw = _arg_struct[diff[0]] - if updown == "upgrade": - kw[new_kw] = diff[-1] - kw[old_kw] = diff[-2] - else: - kw[new_kw] = diff[-2] - kw[old_kw] = diff[-1] - - if "modify_nullable" in kw: - kw.pop("existing_nullable", None) - if "modify_server_default" in kw: - kw.pop("existing_server_default", None) - - op_list.append( - ops.AlterColumnOp( - tname, cname, schema=sname, - **kw - ) - ) - - -def _group_diffs_by_table(diffs): - _adddrop = { - "table": lambda diff: (None, None), - "column": lambda diff: (diff[0], diff[1]), - "index": lambda diff: (diff[0].table.schema, diff[0].table.name), - "constraint": lambda diff: (diff[0].table.schema, diff[0].table.name), - "fk": lambda diff: (diff[0].parent.schema, diff[0].parent.name) - } - - def _derive_table(diff): - if isinstance(diff, tuple): - cmd_type = diff[0] - adddrop, cmd_type = cmd_type.split("_") - return _adddrop[cmd_type](diff[1:]) - else: - sname, tname = diff[0][1:3] - return sname, tname - - return itertools.groupby(diffs, _derive_table) - diff --git a/alembic/autogenerate/generate.py b/alembic/autogenerate/generate.py deleted file mode 100644 index c686156..0000000 --- a/alembic/autogenerate/generate.py +++ /dev/null @@ -1,92 +0,0 @@ -from .. import util -from . import api -from . import compose -from . import compare -from . import render -from ..operations import ops - - -class RevisionContext(object): - def __init__(self, config, script_directory, command_args): - self.config = config - self.script_directory = script_directory - self.command_args = command_args - self.template_args = { - 'config': config # Let templates use config for - # e.g. multiple databases - } - self.generated_revisions = [ - self._default_revision() - ] - - def _to_script(self, migration_script): - template_args = {} - for k, v in self.template_args.items(): - template_args.setdefault(k, v) - - if migration_script._autogen_context is not None: - render._render_migration_script( - migration_script._autogen_context, migration_script, - template_args - ) - - return self.script_directory.generate_revision( - migration_script.rev_id, - migration_script.message, - refresh=True, - head=migration_script.head, - splice=migration_script.splice, - branch_labels=migration_script.branch_label, - version_path=migration_script.version_path, - **template_args) - - def run_autogenerate(self, rev, context): - if self.command_args['sql']: - raise util.CommandError( - "Using --sql with --autogenerate does not make any sense") - if set(self.script_directory.get_revisions(rev)) != \ - set(self.script_directory.get_revisions("heads")): - raise util.CommandError("Target database is not up to date.") - - autogen_context = api._autogen_context(context) - - diffs = [] - compare._produce_net_changes(autogen_context, diffs) - - migration_script = self.generated_revisions[0] - - compose._to_migration_script(autogen_context, migration_script, diffs) - - hook = context.opts.get('process_revision_directives', None) - if hook: - hook(context, rev, self.generated_revisions) - - for migration_script in self.generated_revisions: - migration_script._autogen_context = autogen_context - - def run_no_autogenerate(self, rev, context): - hook = context.opts.get('process_revision_directives', None) - if hook: - hook(context, rev, self.generated_revisions) - - for migration_script in self.generated_revisions: - migration_script._autogen_context = None - - def _default_revision(self): - op = ops.MigrationScript( - rev_id=self.command_args['rev_id'] or util.rev_id(), - message=self.command_args['message'], - imports=set(), - upgrade_ops=ops.UpgradeOps([]), - downgrade_ops=ops.DowngradeOps([]), - head=self.command_args['head'], - splice=self.command_args['splice'], - branch_label=self.command_args['branch_label'], - version_path=self.command_args['version_path'] - ) - op._autogen_context = None - return op - - def generate_scripts(self): - for generated_revision in self.generated_revisions: - yield self._to_script(generated_revision) diff --git a/alembic/autogenerate/render.py b/alembic/autogenerate/render.py index c3f3df1..6f5f96c 100644 --- a/alembic/autogenerate/render.py +++ b/alembic/autogenerate/render.py @@ -30,8 +30,8 @@ def _indent(text): def _render_migration_script(autogen_context, migration_script, template_args): - opts = autogen_context['opts'] - imports = autogen_context['imports'] + opts = autogen_context.opts + imports = autogen_context._imports template_args[opts['upgrade_token']] = _indent(_render_cmd_body( migration_script.upgrade_ops, autogen_context)) template_args[opts['downgrade_token']] = _indent(_render_cmd_body( @@ -78,23 +78,26 @@ def render_op_text(autogen_context, op): @renderers.dispatch_for(ops.ModifyTableOps) def _render_modify_table(autogen_context, op): - opts = autogen_context['opts'] + opts = autogen_context.opts render_as_batch = opts.get('render_as_batch', False) if op.ops: lines = [] if render_as_batch: - lines.append( - "with op.batch_alter_table(%r, schema=%r) as batch_op:" - % (op.table_name, op.schema) - ) - autogen_context['batch_prefix'] = 'batch_op.' - for t_op in op.ops: - t_lines = render_op(autogen_context, t_op) - lines.extend(t_lines) - if render_as_batch: - del autogen_context['batch_prefix'] - lines.append("") + with autogen_context._within_batch(): + lines.append( + "with op.batch_alter_table(%r, schema=%r) as batch_op:" + % (op.table_name, op.schema) + ) + for t_op in op.ops: + t_lines = render_op(autogen_context, t_op) + lines.extend(t_lines) + lines.append("") + else: + for t_op in op.ops: + t_lines = render_op(autogen_context, t_op) + lines.extend(t_lines) + return lines else: return [ @@ -149,7 +152,7 @@ def _drop_table(autogen_context, op): def _add_index(autogen_context, op): index = op.to_index() - has_batch = 'batch_prefix' in autogen_context + has_batch = autogen_context._has_batch if has_batch: tmpl = "%(prefix)screate_index(%(name)r, [%(columns)s], "\ @@ -180,7 +183,7 @@ def _add_index(autogen_context, op): @renderers.dispatch_for(ops.DropIndexOp) def _drop_index(autogen_context, op): - has_batch = 'batch_prefix' in autogen_context + has_batch = autogen_context._has_batch if has_batch: tmpl = "%(prefix)sdrop_index(%(name)r)" @@ -243,7 +246,7 @@ def _add_check_constraint(constraint, autogen_context): @renderers.dispatch_for(ops.DropConstraintOp) def _drop_constraint(autogen_context, op): - if 'batch_prefix' in autogen_context: + if autogen_context._has_batch: template = "%(prefix)sdrop_constraint"\ "(%(name)r, type_=%(type)r)" else: @@ -266,7 +269,7 @@ def _drop_constraint(autogen_context, op): def _add_column(autogen_context, op): schema, tname, column = op.schema, op.table_name, op.column - if 'batch_prefix' in autogen_context: + if autogen_context._has_batch: template = "%(prefix)sadd_column(%(column)s)" else: template = "%(prefix)sadd_column(%(tname)r, %(column)s" @@ -287,7 +290,7 @@ def _drop_column(autogen_context, op): schema, tname, column_name = op.schema, op.table_name, op.column_name - if 'batch_prefix' in autogen_context: + if autogen_context._has_batch: template = "%(prefix)sdrop_column(%(cname)r)" else: template = "%(prefix)sdrop_column(%(tname)r, %(cname)r" @@ -319,7 +322,7 @@ def _alter_column(autogen_context, op): indent = " " * 11 - if 'batch_prefix' in autogen_context: + if autogen_context._has_batch: template = "%(prefix)salter_column(%(cname)r" else: template = "%(prefix)salter_column(%(tname)r, %(cname)r" @@ -343,16 +346,16 @@ def _alter_column(autogen_context, op): if nullable is not None: text += ",\n%snullable=%r" % ( indent, nullable,) - if existing_nullable is not None: + if nullable is None and existing_nullable is not None: text += ",\n%sexisting_nullable=%r" % ( indent, existing_nullable) - if existing_server_default: + if server_default is False and existing_server_default: rendered = _render_server_default( existing_server_default, autogen_context) text += ",\n%sexisting_server_default=%s" % ( indent, rendered) - if schema and "batch_prefix" not in autogen_context: + if schema and not autogen_context._has_batch: text += ",\n%sschema=%r" % (indent, schema) text += ")" return text @@ -409,7 +412,7 @@ def _render_potential_expr(value, autogen_context, wrap_in_text=True): return template % { "prefix": _sqlalchemy_autogenerate_prefix(autogen_context), "sql": compat.text_type( - value.compile(dialect=autogen_context['dialect'], + value.compile(dialect=autogen_context.dialect, **compile_kw) ) } @@ -432,7 +435,7 @@ def _get_index_rendered_expressions(idx, autogen_context): def _uq_constraint(constraint, autogen_context, alter): opts = [] - has_batch = 'batch_prefix' in autogen_context + has_batch = autogen_context._has_batch if constraint.deferrable: opts.append(("deferrable", str(constraint.deferrable))) @@ -467,7 +470,7 @@ def _uq_constraint(constraint, autogen_context, alter): def _user_autogenerate_prefix(autogen_context, target): - prefix = autogen_context['opts']['user_module_prefix'] + prefix = autogen_context.opts['user_module_prefix'] if prefix is None: return "%s." % target.__module__ else: @@ -475,20 +478,19 @@ def _user_autogenerate_prefix(autogen_context, target): def _sqlalchemy_autogenerate_prefix(autogen_context): - return autogen_context['opts']['sqlalchemy_module_prefix'] or '' + return autogen_context.opts['sqlalchemy_module_prefix'] or '' def _alembic_autogenerate_prefix(autogen_context): - if 'batch_prefix' in autogen_context: - return autogen_context['batch_prefix'] + if autogen_context._has_batch: + return 'batch_op.' else: - return autogen_context['opts']['alembic_module_prefix'] or '' + return autogen_context.opts['alembic_module_prefix'] or '' def _user_defined_render(type_, object_, autogen_context): - if 'opts' in autogen_context and \ - 'render_item' in autogen_context['opts']: - render = autogen_context['opts']['render_item'] + if 'render_item' in autogen_context.opts: + render = autogen_context.opts['render_item'] if render: rendered = render(type_, object_, autogen_context) if rendered is not False: @@ -547,7 +549,7 @@ def _repr_type(type_, autogen_context): return rendered mod = type(type_).__module__ - imports = autogen_context.get('imports', None) + imports = autogen_context._imports if mod.startswith("sqlalchemy.dialects"): dname = re.match(r"sqlalchemy\.dialects\.(\w+)", mod).group(1) if imports is not None: diff --git a/alembic/operations/ops.py b/alembic/operations/ops.py index 08a0551..71e8515 100644 --- a/alembic/operations/ops.py +++ b/alembic/operations/ops.py @@ -3,6 +3,7 @@ from ..util import sqla_compat from . import schemaobj from sqlalchemy.types import NULLTYPE from .base import Operations, BatchOperations +import re class MigrateOperation(object): @@ -34,6 +35,10 @@ class MigrateOperation(object): class AddConstraintOp(MigrateOperation): """Represent an add constraint operation.""" + @property + def constraint_type(self): + raise NotImplementedError() + @classmethod def from_constraint(cls, constraint): funcs = { @@ -45,17 +50,40 @@ class AddConstraintOp(MigrateOperation): } return funcs[constraint.__visit_name__](constraint) + def reverse(self): + return DropConstraintOp.from_constraint(self.to_constraint()) + + def to_diff_tuple(self): + return ("add_constraint", self.to_constraint()) + @Operations.register_operation("drop_constraint") @BatchOperations.register_operation("drop_constraint", "batch_drop_constraint") class DropConstraintOp(MigrateOperation): """Represent a drop constraint operation.""" - def __init__(self, constraint_name, table_name, type_=None, schema=None): + def __init__( + self, + constraint_name, table_name, type_=None, schema=None, + _orig_constraint=None): self.constraint_name = constraint_name self.table_name = table_name self.constraint_type = type_ self.schema = schema + self._orig_constraint = _orig_constraint + + def reverse(self): + if self._orig_constraint is None: + raise ValueError( + "operation is not reversible; " + "original constraint is not present") + return AddConstraintOp.from_constraint(self._orig_constraint) + + def to_diff_tuple(self): + if self.constraint_type == "foreignkey": + return ("remove_fk", self.to_constraint()) + else: + return ("remove_constraint", self.to_constraint()) @classmethod def from_constraint(cls, constraint): @@ -72,9 +100,18 @@ class DropConstraintOp(MigrateOperation): constraint.name, constraint_table.name, schema=constraint_table.schema, - type_=types[constraint.__visit_name__] + type_=types[constraint.__visit_name__], + _orig_constraint=constraint ) + def to_constraint(self): + if self._orig_constraint is not None: + return self._orig_constraint + else: + raise ValueError( + "constraint cannot be produced; " + "original constraint is not present") + @classmethod @util._with_legacy_names([("type", "type_")]) def drop_constraint( @@ -124,8 +161,11 @@ class DropConstraintOp(MigrateOperation): class CreatePrimaryKeyOp(AddConstraintOp): """Represent a create primary key operation.""" + constraint_type = "primarykey" + def __init__( - self, constraint_name, table_name, columns, schema=None, **kw): + self, constraint_name, table_name, columns, + schema=None, **kw): self.constraint_name = constraint_name self.table_name = table_name self.columns = columns @@ -225,8 +265,11 @@ class CreatePrimaryKeyOp(AddConstraintOp): class CreateUniqueConstraintOp(AddConstraintOp): """Represent a create unique constraint operation.""" + constraint_type = "unique" + def __init__( - self, constraint_name, table_name, columns, schema=None, **kw): + self, constraint_name, table_name, + columns, schema=None, **kw): self.constraint_name = constraint_name self.table_name = table_name self.columns = columns @@ -342,6 +385,8 @@ class CreateUniqueConstraintOp(AddConstraintOp): class CreateForeignKeyOp(AddConstraintOp): """Represent a create foreign key constraint operation.""" + constraint_type = "foreignkey" + def __init__( self, constraint_name, source_table, referent_table, local_cols, remote_cols, **kw): @@ -352,6 +397,9 @@ class CreateForeignKeyOp(AddConstraintOp): self.remote_cols = remote_cols self.kw = kw + def to_diff_tuple(self): + return ("add_fk", self.to_constraint()) + @classmethod def from_constraint(cls, constraint): kw = {} @@ -507,6 +555,8 @@ class CreateForeignKeyOp(AddConstraintOp): class CreateCheckConstraintOp(AddConstraintOp): """Represent a create check constraint operation.""" + constraint_type = "check" + def __init__( self, constraint_name, table_name, condition, schema=None, **kw): self.constraint_name = constraint_name @@ -523,7 +573,7 @@ class CreateCheckConstraintOp(AddConstraintOp): constraint.name, constraint_table.name, constraint.condition, - schema=constraint_table.schema + schema=constraint_table.schema, ) def to_constraint(self, migration_context=None): @@ -624,6 +674,12 @@ class CreateIndexOp(MigrateOperation): self.kw = kw self._orig_index = _orig_index + def reverse(self): + return DropIndexOp.from_index(self.to_index()) + + def to_diff_tuple(self): + return ("add_index", self.to_index()) + @classmethod def from_index(cls, index): return cls( @@ -729,10 +785,22 @@ class CreateIndexOp(MigrateOperation): class DropIndexOp(MigrateOperation): """Represent a drop index operation.""" - def __init__(self, index_name, table_name=None, schema=None): + def __init__( + self, index_name, table_name=None, schema=None, _orig_index=None): self.index_name = index_name self.table_name = table_name self.schema = schema + self._orig_index = _orig_index + + def to_diff_tuple(self): + return ("remove_index", self.to_index()) + + def reverse(self): + if self._orig_index is None: + raise ValueError( + "operation is not reversible; " + "original index is not present") + return CreateIndexOp.from_index(self._orig_index) @classmethod def from_index(cls, index): @@ -740,6 +808,7 @@ class DropIndexOp(MigrateOperation): index.name, index.table.name, schema=index.table.schema, + _orig_index=index ) def to_index(self, migration_context=None): @@ -807,6 +876,12 @@ class CreateTableOp(MigrateOperation): self.kw = kw self._orig_table = _orig_table + def reverse(self): + return DropTableOp.from_table(self.to_table()) + + def to_diff_tuple(self): + return ("add_table", self.to_table()) + @classmethod def from_table(cls, table): return cls( @@ -921,16 +996,30 @@ class CreateTableOp(MigrateOperation): class DropTableOp(MigrateOperation): """Represent a drop table operation.""" - def __init__(self, table_name, schema=None, table_kw=None): + def __init__( + self, table_name, schema=None, table_kw=None, _orig_table=None): self.table_name = table_name self.schema = schema self.table_kw = table_kw or {} + self._orig_table = _orig_table + + def to_diff_tuple(self): + return ("remove_table", self.to_table()) + + def reverse(self): + if self._orig_table is None: + raise ValueError( + "operation is not reversible; " + "original table is not present") + return CreateTableOp.from_table(self._orig_table) @classmethod def from_table(cls, table): - return cls(table.name, schema=table.schema) + return cls(table.name, schema=table.schema, _orig_table=table) - def to_table(self, migration_context): + def to_table(self, migration_context=None): + if self._orig_table is not None: + return self._orig_table schema_obj = schemaobj.SchemaObjects(migration_context) return schema_obj.table( self.table_name, @@ -1029,6 +1118,87 @@ class AlterColumnOp(AlterTableOp): self.modify_type = modify_type self.kw = kw + def to_diff_tuple(self): + col_diff = [] + schema, tname, cname = self.schema, self.table_name, self.column_name + + if self.modify_type is not None: + col_diff.append( + ("modify_type", schema, tname, cname, + { + "existing_nullable": self.existing_nullable, + "existing_server_default": self.existing_server_default, + }, + self.existing_type, + self.modify_type) + ) + + if self.modify_nullable is not None: + col_diff.append( + ("modify_nullable", schema, tname, cname, + { + "existing_type": self.existing_type, + "existing_server_default": self.existing_server_default + }, + self.existing_nullable, + self.modify_nullable) + ) + + if self.modify_server_default is not False: + col_diff.append( + ("modify_default", schema, tname, cname, + { + "existing_nullable": self.existing_nullable, + "existing_type": self.existing_type + }, + self.existing_server_default, + self.modify_server_default) + ) + + return col_diff + + def has_changes(self): + hc1 = self.modify_nullable is not None or \ + self.modify_server_default is not False or \ + self.modify_type is not None + if hc1: + return True + for kw in self.kw: + if kw.startswith('modify_'): + return True + else: + return False + + def reverse(self): + + kw = self.kw.copy() + kw['existing_type'] = self.existing_type + kw['existing_nullable'] = self.existing_nullable + kw['existing_server_default'] = self.existing_server_default + if self.modify_type is not None: + kw['modify_type'] = self.modify_type + if self.modify_nullable is not None: + kw['modify_nullable'] = self.modify_nullable + if self.modify_server_default is not False: + kw['modify_server_default'] = self.modify_server_default + + # TODO: make this a little simpler + all_keys = set(m.group(1) for m in [ + re.match(r'^(?:existing_|modify_)(.+)$', k) + for k in kw + ] if m) + + for k in all_keys: + if 'modify_%s' % k in kw: + swap = kw['existing_%s' % k] + kw['existing_%s' % k] = kw['modify_%s' % k] + kw['modify_%s' % k] = swap + + return self.__class__( + self.table_name, self.column_name, schema=self.schema, + **kw + ) + @classmethod @util._with_legacy_names([('name', 'new_column_name')]) def alter_column( @@ -1177,6 +1347,13 @@ class AddColumnOp(AlterTableOp): super(AddColumnOp, self).__init__(table_name, schema=schema) self.column = column + def reverse(self): + return DropColumnOp.from_column_and_tablename( + self.schema, self.table_name, self.column) + + def to_diff_tuple(self): + return ("add_column", self.schema, self.table_name, self.column) + @classmethod def from_column(cls, col): return cls(col.table.name, col, schema=col.table.schema) @@ -1265,14 +1442,30 @@ class AddColumnOp(AlterTableOp): class DropColumnOp(AlterTableOp): """Represent a drop column operation.""" - def __init__(self, table_name, column_name, schema=None, **kw): + def __init__( + self, table_name, column_name, schema=None, + _orig_column=None, **kw): super(DropColumnOp, self).__init__(table_name, schema=schema) self.column_name = column_name self.kw = kw + self._orig_column = _orig_column + + def to_diff_tuple(self): + return ( + "remove_column", self.schema, self.table_name, self.to_column()) + + def reverse(self): + if self._orig_column is None: + raise ValueError( + "operation is not reversible; " + "original column is not present") + + return AddColumnOp.from_column_and_tablename( + self.schema, self.table_name, self._orig_column) @classmethod def from_column_and_tablename(cls, schema, tname, col): - return cls(tname, col.name, schema=schema) + return cls(tname, col.name, schema=schema, _orig_column=col) def to_column(self, migration_context=None): schema_obj = schemaobj.SchemaObjects(migration_context) @@ -1522,6 +1715,21 @@ class OpContainer(MigrateOperation): def __init__(self, ops=()): self.ops = ops + def is_empty(self): + return not self.ops + + def as_diffs(self): + return list(OpContainer._ops_as_diffs(self)) + + @classmethod + def _ops_as_diffs(cls, migrations): + for op in migrations.ops: + if hasattr(op, 'ops'): + for sub_op in cls._ops_as_diffs(op): + yield sub_op + else: + yield op.to_diff_tuple() + class ModifyTableOps(OpContainer): """Contains a sequence of operations that all apply to a single Table.""" @@ -1531,6 +1739,15 @@ class ModifyTableOps(OpContainer): self.table_name = table_name self.schema = schema + def reverse(self): + return ModifyTableOps( + self.table_name, + ops=list(reversed( + [op.reverse() for op in self.ops] + )), + schema=self.schema + ) + class UpgradeOps(OpContainer): """contains a sequence of operations that would apply to the @@ -1542,6 +1759,15 @@ class UpgradeOps(OpContainer): """ + def reverse_into(self, downgrade_ops): + downgrade_ops.ops[:] = list(reversed( + [op.reverse() for op in self.ops] + )) + return downgrade_ops + + def reverse(self): + return self.reverse_into(DowngradeOps(ops=[])) + class DowngradeOps(OpContainer): """contains a sequence of operations that would apply to the @@ -1553,6 +1779,13 @@ class DowngradeOps(OpContainer): """ + def reverse(self): + return UpgradeOps( + ops=list(reversed( + [op.reverse() for op in self.ops] + )) + ) + class MigrationScript(MigrateOperation): """represents a migration script. @@ -1583,4 +1816,3 @@ class MigrationScript(MigrateOperation): self.version_path = version_path self.upgrade_ops = upgrade_ops self.downgrade_ops = downgrade_ops - diff --git a/alembic/runtime/migration.py b/alembic/runtime/migration.py index 84a3c7f..e811a36 100644 --- a/alembic/runtime/migration.py +++ b/alembic/runtime/migration.py @@ -118,6 +118,7 @@ class MigrationContext(object): connection=None, url=None, dialect_name=None, + dialect=None, environment_context=None, opts=None, ): @@ -152,7 +153,7 @@ class MigrationContext(object): elif dialect_name: url = sqla_url.make_url("%s://" % dialect_name) dialect = url.get_dialect()() - else: + elif not dialect: raise Exception("Connection, url, or dialect_name is required.") return MigrationContext(dialect, connection, opts, environment_context) diff --git a/alembic/util/langhelpers.py b/alembic/util/langhelpers.py index 1fb0942..6c92e3c 100644 --- a/alembic/util/langhelpers.py +++ b/alembic/util/langhelpers.py @@ -257,30 +257,57 @@ class immutabledict(dict): class Dispatcher(object): - def __init__(self): + def __init__(self, uselist=False): self._registry = {} + self.uselist = uselist def dispatch_for(self, target, qualifier='default'): def decorate(fn): - assert isinstance(target, type) - assert target not in self._registry - self._registry[(target, qualifier)] = fn + if self.uselist: + assert target not in self._registry + self._registry.setdefault((target, qualifier), []).append(fn) + else: + assert target not in self._registry + self._registry[(target, qualifier)] = fn return fn return decorate def dispatch(self, obj, qualifier='default'): - for spcls in type(obj).__mro__: + + if isinstance(obj, string_types): + targets = [obj] + elif isinstance(obj, type): + targets = obj.__mro__ + else: + targets = type(obj).__mro__ + + for spcls in targets: if qualifier != 'default' and (spcls, qualifier) in self._registry: - return self._registry[(spcls, qualifier)] + return self._fn_or_list(self._registry[(spcls, qualifier)]) elif (spcls, 'default') in self._registry: - return self._registry[(spcls, 'default')] + return self._fn_or_list(self._registry[(spcls, 'default')]) else: raise ValueError("no dispatch function for object: %s" % obj) + def _fn_or_list(self, fn_or_list): + if self.uselist: + def go(*arg, **kw): + for fn in fn_or_list: + fn(*arg, **kw) + return go + else: + return fn_or_list + def branch(self): """Return a copy of this dispatcher that is independently writable.""" d = Dispatcher() - d._registry.update(self._registry) + if self.uselist: + d._registry.update( + (k, [fn for fn in self._registry[k]]) + for k in self._registry + ) + else: + d._registry.update(self._registry) return d diff --git a/docs/build/api/autogenerate.rst b/docs/build/api/autogenerate.rst index b024ab1..8b026e8 100644 --- a/docs/build/api/autogenerate.rst +++ b/docs/build/api/autogenerate.rst @@ -4,7 +4,8 @@ Autogeneration ============== -The autogenerate system has two areas of API that are public: +The autogeneration system has a wide degree of public API, including +the following areas: 1. The ability to do a "diff" of a :class:`~sqlalchemy.schema.MetaData` object against a database, and receive a data structure back. This structure @@ -15,9 +16,22 @@ The autogenerate system has two areas of API that are public: revision scripts, including support for multiple revision scripts generated in one pass. +3. The ability to add new operation directives to autogeneration, including + custom schema/model comparison functions and revision script rendering. + Getting Diffs ============== +The simplest API autogenerate provides is the "schema comparison" API; +these are simple functions that will run all registered "comparison" functions +between a :class:`~sqlalchemy.schema.MetaData` object and a database +backend to produce a structure showing how they differ. The two +functions provided are :func:`.compare_metadata`, which is more of the +"legacy" function that produces diff tuples, and :func:`.produce_migrations`, +which produces a structure consisting of operation directives detailed in +:ref:`alembic.operations.toplevel`. + + .. autofunction:: alembic.autogenerate.compare_metadata .. autofunction:: alembic.autogenerate.produce_migrations @@ -184,6 +198,8 @@ to whatever is in this list. .. autofunction:: alembic.autogenerate.render_python_code +.. _autogen_custom_ops: + Autogenerating Custom Operation Directives ========================================== @@ -192,16 +208,180 @@ subclasses of :class:`.MigrateOperation` in order to add new ``op.`` directives. In the preceding section :ref:`customizing_revision`, we also learned that these same :class:`.MigrateOperation` structures are at the base of how the autogenerate system knows what Python code to render. -How to connect these two systems, so that our own custom operation -directives can be used? First off, we'd probably be implementing -a :paramref:`.EnvironmentContext.configure.process_revision_directives` -plugin as described previously, so that we can add our own directives -to the autogenerate stream. What if we wanted to add our ``CreateSequenceOp`` -to the autogenerate structure? We basically need to define an autogenerate -renderer for it, as follows:: +Using this knowledge, we can create additional functions that plug into +the autogenerate system so that our new operations can be generated +into migration scripts when ``alembic revision --autogenerate`` is run. + +The following sections will detail an example of this using the +the ``CreateSequenceOp`` and ``DropSequenceOp`` directives +we created in :ref:`operation_plugins`, which correspond to the +SQLAlchemy :class:`~sqlalchemy.schema.Sequence` construct. + +.. versionadded:: 0.8.0 - custom operations can be added to the + autogenerate system to support new kinds of database objects. + +Tracking our Object with the Model +---------------------------------- + +The basic job of an autogenerate comparison function is to inspect +a series of objects in the database and compare them against a series +of objects defined in our model. By "in our model", we mean anything +defined in Python code that we want to track, however most commonly +we're talking about a series of :class:`~sqlalchemy.schema.Table` +objects present in a :class:`~sqlalchemy.schema.MetaData` collection. + +Let's propose a simple way of seeing what :class:`~sqlalchemy.schema.Sequence` +objects we want to ensure exist in the database when autogenerate +runs. While these objects do have some integrations with +:class:`~sqlalchemy.schema.Table` and :class:`~sqlalchemy.schema.MetaData` +already, let's assume they don't, as the example here intends to illustrate +how we would do this for most any kind of custom construct. We +associate the object with the :attr:`~sqlalchemy.schema.MetaData.info` +collection of :class:`~sqlalchemy.schema.MetaData`, which is a dictionary +we can use for anything, which we also know will be passed to the autogenerate +process:: + + from sqlalchemy.schema import Sequence + + def add_sequence_to_model(sequence, metadata): + metadata.info.setdefault("sequences", set()).add( + (sequence.schema, sequence.name) + ) + + my_seq = Sequence("my_sequence") + add_sequence_to_model(my_seq, model_metadata) + +The :attr:`~sqlalchemy.schema.MetaData.info` +dictionary is a good place to put things that we want our autogeneration +routines to be able to locate, which can include any object such as +custom DDL objects representing views, triggers, special constraints, +or anything else we want to support. + - # note: this is a continuation of the example from the - # "Operation Plugins" section +Registering a Comparison Function +--------------------------------- + +We now need to register a comparison hook, which will be used +to compare the database to our model and produce ``CreateSequenceOp`` +and ``DropSequenceOp`` directives to be included in our migration +script. Note that we are assuming a +Postgresql backend:: + + from alembic.autogenerate import comparators + + @comparators.dispatch_for("schema") + def compare_sequences(autogen_context, upgrade_ops, schemas): + all_conn_sequences = set() + + for sch in schemas: + + all_conn_sequences.update([ + (sch, row[0]) for row in + autogen_context.connection.execute( + "SELECT relname FROM pg_class c join " + "pg_namespace n on n.oid=c.relnamespace where " + "relkind='S' and n.nspname=%(nspname)s", + + # note that we consider a schema of 'None' in our + # model to be the "default" name in the PG database; + # this usually is the name 'public' + nspname=autogen_context.dialect.default_schema_name + if sch is None else sch + ) + ]) + + # get the collection of Sequence objects we're storing with + # our MetaData + metadata_sequences = autogen_context.metadata.info.setdefault( + "sequences", set()) + + # for new names, produce CreateSequenceOp directives + for sch, name in metadata_sequences.difference(all_conn_sequences): + upgrade_ops.ops.append( + CreateSequenceOp(name, schema=sch) + ) + + # for names that are going away, produce DropSequenceOp + # directives + for sch, name in all_conn_sequences.difference(metadata_sequences): + upgrade_ops.ops.append( + DropSequenceOp(name, schema=sch) + ) + +Above, we've built a new function ``compare_sequences()`` and registered +it as a "schema" level comparison function with autogenerate. The +job that it performs is that it compares the list of sequence names +present in each database schema with that of a list of sequence names +that we are maintaining in our :class:`~sqlalchemy.schema.MetaData` object. + +When autogenerate completes, it will have a series of +``CreateSequenceOp`` and ``DropSequenceOp`` directives in the list of +"upgrade" operations; the list of "downgrade" operations is generated +directly from these using the +``CreateSequenceOp.reverse()`` and ``DropSequenceOp.reverse()`` methods +that we've implemented on these objects. + +The registration of our function at the scope of "schema" means our +autogenerate comparison function is called outside of the context +of any specific table or column. The three available scopes +are "schema", "table", and "column", summarized as follows: + +* **Schema level** - these hooks are passed a :class:`.AutogenContext`, + an :class:`.UpgradeOps` collection, and a collection of string schema + names to be operated upon. If the + :class:`.UpgradeOps` collection contains changes after all + hooks are run, it is included in the migration script: + + :: + + @comparators.dispatch_for("schema") + def compare_schema_level(autogen_context, upgrade_ops, schemas): + pass + +* **Table level** - these hooks are passed a :class:`.AutogenContext`, + a :class:`.ModifyTableOps` collection, a schema name, table name, + a :class:`~sqlalchemy.schema.Table` reflected from the database if any + or ``None``, and a :class:`~sqlalchemy.schema.Table` present in the + local :class:`~sqlalchemy.schema.MetaData`. If the + :class:`.ModifyTableOps` collection contains changes after all + hooks are run, it is included in the migration script: + + :: + + @comparators.dispatch_for("table") + def compare_table_level(autogen_context, modify_ops, + schemaname, tablename, conn_table, metadata_table): + pass + +* **Column level** - these hooks are passed a :class:`.AutogenContext`, + an :class:`.AlterColumnOp` object, a schema name, table name, + column name, a :class:`~sqlalchemy.schema.Column` reflected from the + database and a :class:`~sqlalchemy.schema.Column` present in the + local table. If the :class:`.AlterColumnOp` contains changes after + all hooks are run, it is included in the migration script; + a "change" is considered to be present if any of the ``modify_`` attributes + are set to a non-default value, or there are any keys + in the ``.kw`` collection with the prefix ``"modify_"``: + + :: + + @comparators.dispatch_for("column") + def compare_column_level(autogen_context, alter_column_op, + schemaname, tname, cname, conn_col, metadata_col): + pass + +The :class:`.AutogenContext` passed to these hooks is documented below. + +.. autoclass:: alembic.autogenerate.api.AutogenContext + :members: + +Creating a Render Function +-------------------------- + +The second autogenerate integration hook is to provide a "render" function; +since the autogenerate +system renders Python code, we need to build a function that renders +the correct "op" instructions for our directive:: from alembic.autogenerate import renderers @@ -209,27 +389,52 @@ renderer for it, as follows:: def render_create_sequence(autogen_context, op): return "op.create_sequence(%r, **%r)" % ( op.sequence_name, - op.kw + {"schema": op.schema} ) -With our render function established, we can our ``CreateSequenceOp`` -generated in an autogenerate context using the :func:`.render_python_code` -debugging function in conjunction with an :class:`.UpgradeOps` structure:: - from alembic.operations import ops - from alembic.autogenerate import render_python_code + @renderers.dispatch_for(DropSequenceOp) + def render_drop_sequence(autogen_context, op): + return "op.drop_sequence(%r, **%r)" % ( + op.sequence_name, + {"schema": op.schema} + ) - upgrade_ops = ops.UpgradeOps( - ops=[ - CreateSequenceOp("my_seq") - ] - ) +The above functions will render Python code corresponding to the +presence of ``CreateSequenceOp`` and ``DropSequenceOp`` instructions +in the list that our comparison function generates. - print(render_python_code(upgrade_ops)) +Running It +---------- -Which produces:: +All the above code can be organized however the developer sees fit; +the only thing that needs to make it work is that when the +Alembic environment ``env.py`` is invoked, it either imports modules +which contain all the above routines, or they are locally present, +or some combination thereof. - ### commands auto generated by Alembic - please adjust! ### - op.create_sequence('my_seq', **{}) +If we then have code in our model (which of course also needs to be invoked +when ``env.py`` runs!) like this:: + + from sqlalchemy.schema import Sequence + + my_seq_1 = Sequence("my_sequence_1") + add_sequence_to_model(my_seq_1, target_metadata) + +When we first run ``alembic revision --autogenerate``, we'll see this +in our migration file:: + + def upgrade(): + ### commands auto generated by Alembic - please adjust! ### + op.create_sequence('my_sequence_1', **{'schema': None}) ### end Alembic commands ### + + def downgrade(): + ### commands auto generated by Alembic - please adjust! ### + op.drop_sequence('my_sequence_1', **{'schema': None}) + ### end Alembic commands ### + +These are our custom directives that will invoke when ``alembic upgrade`` +or ``alembic downgrade`` is run. + diff --git a/docs/build/api/operations.rst b/docs/build/api/operations.rst index d9ff238..2eb8358 100644 --- a/docs/build/api/operations.rst +++ b/docs/build/api/operations.rst @@ -1,7 +1,7 @@ .. _alembic.operations.toplevel: ===================== -The Operations Object +Operation Directives ===================== Within migration scripts, actual database migration operations are handled @@ -48,9 +48,9 @@ migration scripts:: class CreateSequenceOp(MigrateOperation): """Create a SEQUENCE.""" - def __init__(self, sequence_name, **kw): + def __init__(self, sequence_name, schema=None): self.sequence_name = sequence_name - self.kw = kw + self.schema = schema @classmethod def create_sequence(cls, operations, sequence_name, **kw): @@ -59,20 +59,58 @@ migration scripts:: op = CreateSequenceOp(sequence_name, **kw) return operations.invoke(op) -Above, the ``CreateSequenceOp`` class represents a new operation that will -be available as ``op.create_sequence()``. The reason the operation -is represented as a stateful class is so that an operation and a specific + def reverse(self): + # only needed to support autogenerate + return DropSequenceOp(self.sequence_name, schema=self.schema) + + @Operations.register_operation("drop_sequence") + class DropSequenceOp(MigrateOperation): + """Drop a SEQUENCE.""" + + def __init__(self, sequence_name, schema=None): + self.sequence_name = sequence_name + self.schema = schema + + @classmethod + def drop_sequence(cls, operations, sequence_name, **kw): + """Issue a "DROP SEQUENCE" instruction.""" + + op = DropSequenceOp(sequence_name, **kw) + return operations.invoke(op) + + def reverse(self): + # only needed to support autogenerate + return CreateSequenceOp(self.sequence_name, schema=self.schema) + +Above, the ``CreateSequenceOp`` and ``DropSequenceOp`` classes represent +new operations that will +be available as ``op.create_sequence()`` and ``op.drop_sequence()``. +The reason the operations +are represented as stateful classes is so that an operation and a specific set of arguments can be represented generically; the state can then correspond to different kinds of operations, such as invoking the instruction against a database, or autogenerating Python code for the operation into a script. -In order to establish the migrate-script behavior of the new operation, +In order to establish the migrate-script behavior of the new operations, we use the :meth:`.Operations.implementation_for` decorator:: @Operations.implementation_for(CreateSequenceOp) def create_sequence(operations, operation): - operations.execute("CREATE SEQUENCE %s" % operation.sequence_name) + if operation.schema is not None: + name = "%s.%s" % (operation.schema, operation.sequence_name) + else: + name = operation.sequence_name + operations.execute("CREATE SEQUENCE %s" % name) + + + @Operations.implementation_for(DropSequenceOp) + def drop_sequence(operations, operation): + if operation.schema is not None: + name = "%s.%s" % (operation.schema, operation.sequence_name) + else: + name = operation.sequence_name + operations.execute("DROP SEQUENCE %s" % name) Above, we use the simplest possible technique of invoking our DDL, which is just to call :meth:`.Operations.execute` with literal SQL. If this is @@ -80,16 +118,24 @@ all a custom operation needs, then this is fine. However, options for more comprehensive support include building out a custom SQL construct, as documented at :ref:`sqlalchemy.ext.compiler_toplevel`. -With the above two steps, a migration script can now use a new method -``op.create_sequence()`` that will proxy to our object as a classmethod:: +With the above two steps, a migration script can now use new methods +``op.create_sequence()`` and ``op.drop_sequence()`` that will proxy to +our object as a classmethod:: def upgrade(): op.create_sequence("my_sequence") + def downgrade(): + op.drop_sequence("my_sequence") + The registration of new operations only needs to occur in time for the ``env.py`` script to invoke :meth:`.MigrationContext.run_migrations`; within the module level of the ``env.py`` script is sufficient. +.. seealso:: + + :ref:`autogen_custom_ops` - how to add autogenerate support to + custom operations. .. versionadded:: 0.8 - the migration operations available via the :class:`.Operations` class as well as the ``alembic.op`` namespace diff --git a/docs/build/changelog.rst b/docs/build/changelog.rst index 691402d..424bf8f 100644 --- a/docs/build/changelog.rst +++ b/docs/build/changelog.rst @@ -27,7 +27,7 @@ Changelog .. change:: :tags: feature, autogenerate - :tickets: 301 + :tickets: 301, 306 The internal system for autogenerate been reworked to build upon the extensible system of operation objects present in @@ -38,9 +38,12 @@ Changelog :paramref:`.EnvironmentContext.configure.process_revision_directives` allows end-user code to fully customize what autogenerate will do, including not just full manipulation of the Python steps to take - but also what file or files will be written and where. It is also - possible to write a system that reads an autogenerate stream and - invokes it directly against a database without writing any files. + but also what file or files will be written and where. Additionally, + autogenerate is now extensible as far as database objects compared + and rendered into scripts; any new operation directive can also be + registered into a series of hooks that allow custom database/model + comparison functions to run as well as to render new operation + directives into autogenerate scripts. .. seealso:: diff --git a/tests/_autogen_fixtures.py b/tests/_autogen_fixtures.py index 7ef6cbf..e668885 100644 --- a/tests/_autogen_fixtures.py +++ b/tests/_autogen_fixtures.py @@ -2,12 +2,14 @@ from sqlalchemy import MetaData, Column, Table, Integer, String, Text, \ Numeric, CHAR, ForeignKey, Index, UniqueConstraint, CheckConstraint, text from sqlalchemy.engine.reflection import Inspector +from alembic.operations import ops from alembic import autogenerate from alembic.migration import MigrationContext from alembic.testing import config from alembic.testing.env import staging_env, clear_staging_env from alembic.testing import eq_ from alembic.ddl.base import _fk_spec +from alembic.autogenerate import api names_in_this_test = set() @@ -25,9 +27,7 @@ def _default_include_object(obj, name, type_, reflected, compare_to): else: return True -_default_object_filters = [ - _default_include_object -] +_default_object_filters = _default_include_object class ModelOne(object): @@ -177,6 +177,7 @@ class AutogenTest(_ComparesFKs): 'downgrade_token': "downgrades", 'alembic_module_prefix': 'op.', 'sqlalchemy_module_prefix': 'sa.', + 'include_object': _default_object_filters } if self.configure_opts: ctx_opts.update(self.configure_opts) @@ -185,17 +186,18 @@ class AutogenTest(_ComparesFKs): opts=ctx_opts ) - connection = context.bind - self.autogen_context = { - 'imports': set(), - 'connection': connection, - 'dialect': connection.dialect, - 'context': context - } + self.autogen_context = api.AutogenContext(context, self.m2) def tearDown(self): self.conn.close() + def _update_context(self, object_filters=None, include_schemas=None): + if include_schemas is not None: + self.autogen_context.opts['include_schemas'] = include_schemas + if object_filters is not None: + self.autogen_context._object_filters = [object_filters] + return self.autogen_context + class AutogenFixtureTest(_ComparesFKs): @@ -214,6 +216,8 @@ class AutogenFixtureTest(_ComparesFKs): 'downgrade_token': "downgrades", 'alembic_module_prefix': 'op.', 'sqlalchemy_module_prefix': 'sa.', + 'include_object': object_filters, + 'include_schemas': include_schemas } if opts: ctx_opts.update(opts) @@ -222,21 +226,12 @@ class AutogenFixtureTest(_ComparesFKs): opts=ctx_opts ) - connection = context.bind - autogen_context = { - 'imports': set(), - 'connection': connection, - 'dialect': connection.dialect, - 'context': context, - 'metadata': model_metadata, - 'object_filters': object_filters, - 'include_schemas': include_schemas - } - diffs = [] + autogen_context = api.AutogenContext(context, model_metadata) + uo = ops.UpgradeOps(ops=[]) autogenerate._produce_net_changes( - autogen_context, diffs + autogen_context, uo ) - return diffs + return uo.as_diffs() reports_unnamed_constraints = False diff --git a/tests/test_autogen_composition.py b/tests/test_autogen_composition.py index b1717ab..6d1f55b 100644 --- a/tests/test_autogen_composition.py +++ b/tests/test_autogen_composition.py @@ -23,7 +23,7 @@ class AutogenerateDiffTest(ModelOne, AutogenTest, TestBase): } ) template_args = {} - autogenerate._render_migration_diffs(context, template_args, set()) + autogenerate._render_migration_diffs(context, template_args) eq_(re.sub(r"u'", "'", template_args['upgrades']), """### commands auto generated by Alembic - please adjust! ### @@ -50,10 +50,8 @@ class AutogenerateDiffTest(ModelOne, AutogenTest, TestBase): } ) template_args = {} - autogenerate._render_migration_diffs( - context, template_args, set(), + autogenerate._render_migration_diffs(context, template_args) - ) eq_(re.sub(r"u'", "'", template_args['upgrades']), """### commands auto generated by Alembic - please adjust! ### pass @@ -67,8 +65,7 @@ class AutogenerateDiffTest(ModelOne, AutogenTest, TestBase): """test a full render including indentation""" template_args = {} - autogenerate._render_migration_diffs( - self.context, template_args, set()) + autogenerate._render_migration_diffs(self.context, template_args) eq_(re.sub(r"u'", "'", template_args['upgrades']), """### commands auto generated by Alembic - please adjust! ### op.create_table('item', @@ -135,8 +132,7 @@ nullable=True)) template_args = {} self.context.opts['render_as_batch'] = True - autogenerate._render_migration_diffs( - self.context, template_args, set()) + autogenerate._render_migration_diffs(self.context, template_args) eq_(re.sub(r"u'", "'", template_args['upgrades']), """### commands auto generated by Alembic - please adjust! ### @@ -229,10 +225,8 @@ class AutogenerateDiffTestWSchema(ModelOne, AutogenTest, TestBase): } ) template_args = {} - autogenerate._render_migration_diffs( - context, template_args, set(), + autogenerate._render_migration_diffs(context, template_args) - ) eq_(re.sub(r"u'", "'", template_args['upgrades']), """### commands auto generated by Alembic - please adjust! ### pass @@ -250,9 +244,7 @@ class AutogenerateDiffTestWSchema(ModelOne, AutogenTest, TestBase): 'include_object': _default_include_object, 'include_schemas': True }) - autogenerate._render_migration_diffs( - self.context, template_args, set() - ) + autogenerate._render_migration_diffs(self.context, template_args) eq_(re.sub(r"u'", "'", template_args['upgrades']), """### commands auto generated by Alembic - please adjust! ### @@ -326,3 +318,4 @@ name='extra_uid_fkey'), ) op.drop_table('item', schema='%(schema)s') ### end Alembic commands ###""" % {"schema": self.schema}) + diff --git a/tests/test_autogen_diffs.py b/tests/test_autogen_diffs.py index f32fd84..d176b91 100644 --- a/tests/test_autogen_diffs.py +++ b/tests/test_autogen_diffs.py @@ -6,6 +6,7 @@ from sqlalchemy import MetaData, Column, Table, Integer, String, Text, \ from sqlalchemy.types import NULLTYPE from sqlalchemy.engine.reflection import Inspector +from alembic.operations import ops from alembic import autogenerate from alembic.migration import MigrationContext from alembic.testing import TestBase @@ -14,8 +15,7 @@ from alembic.testing import assert_raises_message from alembic.testing.mock import Mock from alembic.testing import eq_ from alembic.util import CommandError -from ._autogen_fixtures import \ - AutogenTest, AutogenFixtureTest, _default_object_filters +from ._autogen_fixtures import AutogenTest, AutogenFixtureTest py3k = sys.version_info >= (3, ) @@ -63,25 +63,24 @@ class AutogenCrossSchemaTest(AutogenTest, TestBase): return m def test_default_schema_omitted_upgrade(self): - diffs = [] def include_object(obj, name, type_, reflected, compare_to): if type_ == "table": return name == "t3" else: return True - self.autogen_context.update({ - 'object_filters': [include_object], - 'include_schemas': True, - 'metadata': self.m2 - }) - autogenerate._produce_net_changes(self.autogen_context, diffs) + self._update_context( + object_filters=include_object, + include_schemas=True, + ) + uo = ops.UpgradeOps(ops=[]) + autogenerate._produce_net_changes(self.autogen_context, uo) + diffs = uo.as_diffs() eq_(diffs[0][0], "add_table") eq_(diffs[0][1].schema, None) def test_alt_schema_included_upgrade(self): - diffs = [] def include_object(obj, name, type_, reflected, compare_to): if type_ == "table": @@ -89,48 +88,48 @@ class AutogenCrossSchemaTest(AutogenTest, TestBase): else: return True - self.autogen_context.update({ - 'object_filters': [include_object], - 'include_schemas': True, - 'metadata': self.m2 - }) - autogenerate._produce_net_changes(self.autogen_context, diffs) + self._update_context( + object_filters=include_object, + include_schemas=True, + ) + uo = ops.UpgradeOps(ops=[]) + autogenerate._produce_net_changes(self.autogen_context, uo) + diffs = uo.as_diffs() eq_(diffs[0][0], "add_table") eq_(diffs[0][1].schema, config.test_schema) def test_default_schema_omitted_downgrade(self): - diffs = [] - def include_object(obj, name, type_, reflected, compare_to): if type_ == "table": return name == "t1" else: return True - self.autogen_context.update({ - 'object_filters': [include_object], - 'include_schemas': True, - 'metadata': self.m2 - }) - autogenerate._produce_net_changes(self.autogen_context, diffs) + self._update_context( + object_filters=include_object, + include_schemas=True, + ) + uo = ops.UpgradeOps(ops=[]) + autogenerate._produce_net_changes(self.autogen_context, uo) + diffs = uo.as_diffs() eq_(diffs[0][0], "remove_table") eq_(diffs[0][1].schema, None) def test_alt_schema_included_downgrade(self): - diffs = [] def include_object(obj, name, type_, reflected, compare_to): if type_ == "table": return name == "t2" else: return True - self.autogen_context.update({ - 'object_filters': [include_object], - 'include_schemas': True, - 'metadata': self.m2 - }) - autogenerate._produce_net_changes(self.autogen_context, diffs) + self._update_context( + object_filters=include_object, + include_schemas=True, + ) + uo = ops.UpgradeOps(ops=[]) + autogenerate._produce_net_changes(self.autogen_context, uo) + diffs = uo.as_diffs() eq_(diffs[0][0], "remove_table") eq_(diffs[0][1].schema, config.test_schema) @@ -268,14 +267,14 @@ class AutogenerateDiffTest(ModelOne, AutogenTest, TestBase): """test generation of diff rules""" metadata = self.m2 - diffs = [] - ctx = self.autogen_context.copy() - ctx['metadata'] = self.m2 - ctx['object_filters'] = _default_object_filters + uo = ops.UpgradeOps(ops=[]) + ctx = self.autogen_context + autogenerate._produce_net_changes( - ctx, diffs + ctx, uo ) + diffs = uo.as_diffs() eq_( diffs[0], ('add_table', metadata.tables['item']) @@ -396,21 +395,25 @@ class AutogenerateDiffTest(ModelOne, AutogenTest, TestBase): eq_(alter_cols, set(['user_id', 'order', 'user'])) def test_skip_null_type_comparison_reflected(self): - diff = [] - autogenerate.compare._compare_type(None, "sometable", "somecol", - Column("somecol", NULLTYPE), - Column("somecol", Integer()), - diff, self.autogen_context - ) + ac = ops.AlterColumnOp("sometable", "somecol") + autogenerate.compare._compare_type( + self.autogen_context, ac, + None, "sometable", "somecol", + Column("somecol", NULLTYPE), + Column("somecol", Integer()), + ) + diff = ac.to_diff_tuple() assert not diff def test_skip_null_type_comparison_local(self): - diff = [] - autogenerate.compare._compare_type(None, "sometable", "somecol", - Column("somecol", Integer()), - Column("somecol", NULLTYPE), - diff, self.autogen_context - ) + ac = ops.AlterColumnOp("sometable", "somecol") + autogenerate.compare._compare_type( + self.autogen_context, ac, + None, "sometable", "somecol", + Column("somecol", Integer()), + Column("somecol", NULLTYPE), + ) + diff = ac.to_diff_tuple() assert not diff def test_custom_type_compare(self): @@ -420,20 +423,24 @@ class AutogenerateDiffTest(ModelOne, AutogenTest, TestBase): def compare_against_backend(self, dialect, conn_type): return isinstance(conn_type, Integer) - diff = [] - autogenerate.compare._compare_type(None, "sometable", "somecol", - Column("somecol", INTEGER()), - Column("somecol", MyType()), - diff, self.autogen_context - ) - assert not diff + ac = ops.AlterColumnOp("sometable", "somecol") + autogenerate.compare._compare_type( + self.autogen_context, ac, + None, "sometable", "somecol", + Column("somecol", INTEGER()), + Column("somecol", MyType()), + ) + + assert not ac.has_changes() - diff = [] - autogenerate.compare._compare_type(None, "sometable", "somecol", - Column("somecol", String()), - Column("somecol", MyType()), - diff, self.autogen_context - ) + ac = ops.AlterColumnOp("sometable", "somecol") + autogenerate.compare._compare_type( + self.autogen_context, ac, + None, "sometable", "somecol", + Column("somecol", String()), + Column("somecol", MyType()), + ) + diff = ac.to_diff_tuple() eq_( diff[0][0:4], ('modify_type', None, 'sometable', 'somecol') @@ -449,26 +456,26 @@ class AutogenerateDiffTest(ModelOne, AutogenTest, TestBase): else: return dialect.type_descriptor(CHAR(32)) - diff = [] + uo = ops.AlterColumnOp('sometable', 'somecol') autogenerate.compare._compare_type( + self.autogen_context, uo, None, "sometable", "somecol", Column("somecol", Integer, nullable=True), - Column("somecol", MyType()), - diff, self.autogen_context + Column("somecol", MyType()) ) - assert not diff + assert not uo.has_changes() def test_dont_barf_on_already_reflected(self): - diffs = [] from sqlalchemy.util import OrderedSet inspector = Inspector.from_engine(self.bind) + uo = ops.UpgradeOps(ops=[]) autogenerate.compare._compare_tables( OrderedSet([(None, 'extra'), (None, 'user')]), - OrderedSet(), [], inspector, - MetaData(), diffs, self.autogen_context + OrderedSet(), inspector, + MetaData(), uo, self.autogen_context ) eq_( - [(rec[0], rec[1].name) for rec in diffs], + [(rec[0], rec[1].name) for rec in uo.as_diffs()], [('remove_table', 'extra'), ('remove_table', 'user')] ) @@ -481,14 +488,14 @@ class AutogenerateDiffTestWSchema(ModelOne, AutogenTest, TestBase): """test generation of diff rules""" metadata = self.m2 - diffs = [] - self.autogen_context.update({ - 'object_filters': _default_object_filters, - 'include_schemas': True, - 'metadata': self.m2 - }) - autogenerate._produce_net_changes(self.autogen_context, diffs) + self._update_context( + include_schemas=True, + ) + uo = ops.UpgradeOps(ops=[]) + autogenerate._produce_net_changes(self.autogen_context, uo) + + diffs = uo.as_diffs() eq_( diffs[0], @@ -567,10 +574,10 @@ class AutogenerateCustomCompareTypeTest(AutogenTest, TestBase): my_compare_type = Mock() self.context._user_compare_type = my_compare_type - diffs = [] - ctx = self.autogen_context.copy() - ctx['metadata'] = self.m2 - autogenerate._produce_net_changes(ctx, diffs) + uo = ops.UpgradeOps(ops=[]) + + ctx = self.autogen_context + autogenerate._produce_net_changes(ctx, uo) first_table = self.m2.tables['sometable'] first_column = first_table.columns['id'] @@ -593,8 +600,7 @@ class AutogenerateCustomCompareTypeTest(AutogenTest, TestBase): self.context._user_compare_type = my_compare_type diffs = [] - ctx = self.autogen_context.copy() - ctx['metadata'] = self.m2 + ctx = self.autogen_context diffs = [] autogenerate._produce_net_changes(ctx, diffs) @@ -605,10 +611,10 @@ class AutogenerateCustomCompareTypeTest(AutogenTest, TestBase): my_compare_type.return_value = True self.context._user_compare_type = my_compare_type - ctx = self.autogen_context.copy() - ctx['metadata'] = self.m2 - diffs = [] - autogenerate._produce_net_changes(ctx, diffs) + ctx = self.autogen_context + uo = ops.UpgradeOps(ops=[]) + autogenerate._produce_net_changes(ctx, uo) + diffs = uo.as_diffs() eq_(diffs[0][0][0], 'modify_type') eq_(diffs[1][0][0], 'modify_type') @@ -636,8 +642,7 @@ class PKConstraintUpgradesIgnoresNullableTest(AutogenTest, TestBase): def test_no_change(self): diffs = [] - ctx = self.autogen_context.copy() - ctx['metadata'] = self.m2 + ctx = self.autogen_context autogenerate._produce_net_changes(ctx, diffs) eq_(diffs, []) @@ -674,11 +679,11 @@ class AutogenKeyTest(AutogenTest, TestBase): def test_autogen(self): - diffs = [] + uo = ops.UpgradeOps(ops=[]) - ctx = self.autogen_context.copy() - ctx['metadata'] = self.m2 - autogenerate._produce_net_changes(ctx, diffs) + ctx = self.autogen_context + autogenerate._produce_net_changes(ctx, uo) + diffs = uo.as_diffs() eq_(diffs[0][0], "add_table") eq_(diffs[0][1].name, "sometable") eq_(diffs[1][0], "add_column") @@ -705,8 +710,7 @@ class AutogenVersionTableTest(AutogenTest, TestBase): def test_no_version_table(self): diffs = [] - ctx = self.autogen_context.copy() - ctx['metadata'] = self.m2 + ctx = self.autogen_context autogenerate._produce_net_changes(ctx, diffs) eq_(diffs, []) @@ -717,8 +721,7 @@ class AutogenVersionTableTest(AutogenTest, TestBase): self.version_table_name, self.m2, Column('x', Integer), schema=self.version_table_schema) - ctx = self.autogen_context.copy() - ctx['metadata'] = self.m2 + ctx = self.autogen_context autogenerate._produce_net_changes(ctx, diffs) eq_(diffs, []) @@ -769,10 +772,10 @@ class AutogenerateDiffOrderTest(AutogenTest, TestBase): before their parent tables """ - ctx = self.autogen_context.copy() - ctx['metadata'] = self.m2 - diffs = [] - autogenerate._produce_net_changes(ctx, diffs) + ctx = self.autogen_context + uo = ops.UpgradeOps(ops=[]) + autogenerate._produce_net_changes(ctx, uo) + diffs = uo.as_diffs() eq_(diffs[0][0], 'add_table') eq_(diffs[0][1].name, "parent") diff --git a/tests/test_autogen_fks.py b/tests/test_autogen_fks.py index 525bed5..174a538 100644 --- a/tests/test_autogen_fks.py +++ b/tests/test_autogen_fks.py @@ -351,7 +351,7 @@ class IncludeHooksTest(AutogenFixtureTest, TestBase): type_ == 'foreign_key_constraint' and reflected and name == 'fk1') - diffs = self._fixture(m1, m2, object_filters=[include_object]) + diffs = self._fixture(m1, m2, object_filters=include_object) self._assert_fk_diff( diffs[0], "remove_fk", @@ -390,7 +390,7 @@ class IncludeHooksTest(AutogenFixtureTest, TestBase): type_ == 'foreign_key_constraint' and not reflected and name == 'fk1') - diffs = self._fixture(m1, m2, object_filters=[include_object]) + diffs = self._fixture(m1, m2, object_filters=include_object) self._assert_fk_diff( diffs[0], "add_fk", @@ -456,7 +456,7 @@ class IncludeHooksTest(AutogenFixtureTest, TestBase): and name == 'fk1' ) - diffs = self._fixture(m1, m2, object_filters=[include_object]) + diffs = self._fixture(m1, m2, object_filters=include_object) self._assert_fk_diff( diffs[0], "remove_fk", diff --git a/tests/test_autogen_indexes.py b/tests/test_autogen_indexes.py index 8ee33bc..9b6cd44 100644 --- a/tests/test_autogen_indexes.py +++ b/tests/test_autogen_indexes.py @@ -798,7 +798,7 @@ class IncludeHooksTest(AutogenFixtureTest, TestBase): isinstance(object_, Index) and type_ == 'index' and reflected and name == 'ix1') - diffs = self._fixture(m1, m2, object_filters=[include_object]) + diffs = self._fixture(m1, m2, object_filters=include_object) eq_(diffs[0][0], 'remove_index') eq_(diffs[0][1].name, 'ix2') @@ -825,7 +825,7 @@ class IncludeHooksTest(AutogenFixtureTest, TestBase): isinstance(object_, UniqueConstraint) and type_ == 'unique_constraint' and reflected and name == 'uq1') - diffs = self._fixture(m1, m2, object_filters=[include_object]) + diffs = self._fixture(m1, m2, object_filters=include_object) eq_(diffs[0][0], 'remove_constraint') eq_(diffs[0][1].name, 'uq2') @@ -846,7 +846,7 @@ class IncludeHooksTest(AutogenFixtureTest, TestBase): isinstance(object_, Index) and type_ == 'index' and not reflected and name == 'ix1') - diffs = self._fixture(m1, m2, object_filters=[include_object]) + diffs = self._fixture(m1, m2, object_filters=include_object) eq_(diffs[0][0], 'add_index') eq_(diffs[0][1].name, 'ix2') @@ -871,7 +871,7 @@ class IncludeHooksTest(AutogenFixtureTest, TestBase): type_ == 'unique_constraint' and not reflected and name == 'uq1') - diffs = self._fixture(m1, m2, object_filters=[include_object]) + diffs = self._fixture(m1, m2, object_filters=include_object) eq_(diffs[0][0], 'add_constraint') eq_(diffs[0][1].name, 'uq2') @@ -899,7 +899,7 @@ class IncludeHooksTest(AutogenFixtureTest, TestBase): type_ == 'index' and not reflected and name == 'ix1' and isinstance(compare_to, Index)) - diffs = self._fixture(m1, m2, object_filters=[include_object]) + diffs = self._fixture(m1, m2, object_filters=include_object) eq_(diffs[0][0], 'remove_index') eq_(diffs[0][1].name, 'ix2') @@ -935,7 +935,7 @@ class IncludeHooksTest(AutogenFixtureTest, TestBase): not reflected and name == 'uq1' and isinstance(compare_to, UniqueConstraint)) - diffs = self._fixture(m1, m2, object_filters=[include_object]) + diffs = self._fixture(m1, m2, object_filters=include_object) eq_(diffs[0][0], 'remove_constraint') eq_(diffs[0][1].name, 'uq2') diff --git a/tests/test_autogen_render.py b/tests/test_autogen_render.py index 4a49d5c..a73cff5 100644 --- a/tests/test_autogen_render.py +++ b/tests/test_autogen_render.py @@ -14,6 +14,8 @@ from sqlalchemy.types import UserDefinedType from sqlalchemy.dialects import mysql, postgresql from sqlalchemy.engine.default import DefaultDialect from sqlalchemy.sql import and_, column, literal_column, false +from alembic.migration import MigrationContext +from alembic.autogenerate import api from alembic.testing.mock import patch @@ -32,22 +34,30 @@ class AutogenRenderTest(TestBase): """test individual directives""" - @classmethod - def setup_class(cls): - cls.autogen_context = { - 'opts': { - 'sqlalchemy_module_prefix': 'sa.', - 'alembic_module_prefix': 'op.', - }, - 'dialect': mysql.dialect() - } - cls.pg_autogen_context = { - 'opts': { - 'sqlalchemy_module_prefix': 'sa.', - 'alembic_module_prefix': 'op.', - }, - 'dialect': postgresql.dialect() + def setUp(self): + ctx_opts = { + 'sqlalchemy_module_prefix': 'sa.', + 'alembic_module_prefix': 'op.', + 'target_metadata': MetaData() } + context = MigrationContext.configure( + dialect_name="mysql", + opts=ctx_opts + ) + + self.autogen_context = api.AutogenContext(context) + + context = MigrationContext.configure( + dialect_name="postgresql", + opts=ctx_opts + ) + self.pg_autogen_context = api.AutogenContext(context) + + context = MigrationContext.configure( + dialect=DefaultDialect(), + opts=ctx_opts + ) + self.default_autogen_context = api.AutogenContext(context) def test_render_add_index(self): """ @@ -812,10 +822,10 @@ unique=False, """ return "col(%s)" % obj.name return "render:%s" % type_ - autogen_context = {"opts": { - 'render_item': render, - 'alembic_module_prefix': 'sa.' - }} + self.autogen_context.opts.update( + render_item=render, + alembic_module_prefix='sa.' + ) t = Table('t', MetaData(), Column('x', Integer), @@ -824,7 +834,7 @@ unique=False, """ ForeignKeyConstraint(['x'], ['y']) ) op_obj = ops.CreateTableOp.from_table(t) - result = autogenerate.render_op_text(autogen_context, op_obj) + result = autogenerate.render_op_text(self.autogen_context, op_obj) eq_ignore_whitespace( result, "sa.create_table('t'," @@ -1087,28 +1097,13 @@ unique=False, """ def test_repr_plain_sqla_type(self): type_ = Integer() - autogen_context = { - 'opts': { - 'sqlalchemy_module_prefix': 'sa.', - 'alembic_module_prefix': 'op.', - }, - 'dialect': mysql.dialect() - } - eq_ignore_whitespace( - autogenerate.render._repr_type(type_, autogen_context), + autogenerate.render._repr_type(type_, self.autogen_context), "sa.Integer()" ) def test_repr_custom_type_w_sqla_prefix(self): - autogen_context = { - 'opts': { - 'sqlalchemy_module_prefix': 'sa.', - 'alembic_module_prefix': 'op.', - 'user_module_prefix': None - }, - 'dialect': mysql.dialect() - } + self.autogen_context.opts['user_module_prefix'] = None class MyType(UserDefinedType): pass @@ -1118,7 +1113,7 @@ unique=False, """ type_ = MyType() eq_ignore_whitespace( - autogenerate.render._repr_type(type_, autogen_context), + autogenerate.render._repr_type(type_, self.autogen_context), "sqlalchemy_util.types.MyType()" ) @@ -1129,17 +1124,10 @@ unique=False, """ return "MYTYPE" type_ = MyType() - autogen_context = { - 'opts': { - 'sqlalchemy_module_prefix': 'sa.', - 'alembic_module_prefix': 'op.', - 'user_module_prefix': None - }, - 'dialect': mysql.dialect() - } + self.autogen_context.opts['user_module_prefix'] = None eq_ignore_whitespace( - autogenerate.render._repr_type(type_, autogen_context), + autogenerate.render._repr_type(type_, self.autogen_context), "tests.test_autogen_render.MyType()" ) @@ -1152,17 +1140,11 @@ unique=False, """ return "MYTYPE" type_ = MyType() - autogen_context = { - 'opts': { - 'sqlalchemy_module_prefix': 'sa.', - 'alembic_module_prefix': 'op.', - 'user_module_prefix': 'user.', - }, - 'dialect': mysql.dialect() - } + + self.autogen_context.opts['user_module_prefix'] = 'user.' eq_ignore_whitespace( - autogenerate.render._repr_type(type_, autogen_context), + autogenerate.render._repr_type(type_, self.autogen_context), "user.MyType()" ) @@ -1171,20 +1153,14 @@ unique=False, """ from sqlalchemy.dialects.mysql import VARCHAR type_ = VARCHAR(20, charset='utf8', national=True) - autogen_context = { - 'opts': { - 'sqlalchemy_module_prefix': 'sa.', - 'alembic_module_prefix': 'op.', - 'user_module_prefix': None, - }, - 'imports': set(), - 'dialect': mysql.dialect() - } + + self.autogen_context.opts['user_module_prefix'] = None + eq_ignore_whitespace( - autogenerate.render._repr_type(type_, autogen_context), + autogenerate.render._repr_type(type_, self.autogen_context), "mysql.VARCHAR(charset='utf8', national=True, length=20)" ) - eq_(autogen_context['imports'], + eq_(self.autogen_context._imports, set(['from sqlalchemy.dialects import mysql']) ) @@ -1204,19 +1180,12 @@ unique=False, """ ) def test_render_server_default_native_boolean(self): - autogen_context = { - 'opts': { - 'sqlalchemy_module_prefix': 'sa.', - 'alembic_module_prefix': 'op.', - }, - 'dialect': postgresql.dialect() - } c = Column( 'updated_at', Boolean(), server_default=false(), nullable=False) result = autogenerate.render._render_column( - c, autogen_context, + c, self.autogen_context, ) eq_ignore_whitespace( result, @@ -1231,17 +1200,10 @@ unique=False, """ 'updated_at', Boolean(), server_default=false(), nullable=False) - dialect = DefaultDialect() - autogen_context = { - 'opts': { - 'sqlalchemy_module_prefix': 'sa.', - 'alembic_module_prefix': 'op.', - }, - 'dialect': dialect - } +# MARKMARK result = autogenerate.render._render_column( - c, autogen_context + c, self.default_autogen_context ) eq_ignore_whitespace( result, @@ -1296,16 +1258,6 @@ unique=False, """ class RenderNamingConventionTest(TestBase): __requires__ = ('sqlalchemy_094',) - @classmethod - def setup_class(cls): - cls.autogen_context = { - 'opts': { - 'sqlalchemy_module_prefix': 'sa.', - 'alembic_module_prefix': 'op.', - }, - 'dialect': postgresql.dialect() - } - def setUp(self): convention = { @@ -1322,6 +1274,17 @@ class RenderNamingConventionTest(TestBase): naming_convention=convention ) + ctx_opts = { + 'sqlalchemy_module_prefix': 'sa.', + 'alembic_module_prefix': 'op.', + 'target_metadata': MetaData() + } + context = MigrationContext.configure( + dialect_name="postgresql", + opts=ctx_opts + ) + self.autogen_context = api.AutogenContext(context) + def test_schema_type_boolean(self): t = Table('t', self.metadata, Column('c', Boolean(name='xyz'))) op_obj = ops.AddColumnOp.from_column(t.c.c) @@ -1457,3 +1420,29 @@ class RenderNamingConventionTest(TestBase): "sa.CheckConstraint(!U'im a constraint', name=op.f('ck_t_cc1'))" ) + def test_create_table_plus_add_index_in_modify(self): + uo = ops.UpgradeOps(ops=[ + ops.CreateTableOp( + "sometable", + [Column('x', Integer), Column('y', Integer)] + ), + ops.ModifyTableOps( + "sometable", ops=[ + ops.CreateIndexOp('ix1', 'sometable', ['x', 'y']) + ] + ) + ]) + + eq_( + autogenerate.render_python_code(uo, render_as_batch=True), + "### commands auto generated by Alembic - please adjust! ###\n" + " op.create_table('sometable',\n" + " sa.Column('x', sa.Integer(), nullable=True),\n" + " sa.Column('y', sa.Integer(), nullable=True)\n" + " )\n" + " with op.batch_alter_table('sometable', schema=None) " + "as batch_op:\n" + " batch_op.create_index(" + "'ix1', ['x', 'y'], unique=False)\n\n" + " ### end Alembic commands ###" + ) diff --git a/tests/test_postgresql.py b/tests/test_postgresql.py index e70d05a..576d957 100644 --- a/tests/test_postgresql.py +++ b/tests/test_postgresql.py @@ -8,9 +8,11 @@ from sqlalchemy.sql import table, column from alembic.autogenerate.compare import \ _compare_server_default, _compare_tables, _render_server_default_for_compare +from alembic.operations import ops from alembic import command, util from alembic.migration import MigrationContext from alembic.script import ScriptDirectory +from alembic.autogenerate import api from alembic.testing import eq_, provide_metadata from alembic.testing.env import staging_env, clear_staging_env, \ @@ -162,34 +164,22 @@ class PostgresqlDefaultCompareTest(TestBase): def setup_class(cls): cls.bind = config.db staging_env() - context = MigrationContext.configure( + cls.migration_context = MigrationContext.configure( connection=cls.bind.connect(), opts={ 'compare_type': True, 'compare_server_default': True } ) - connection = context.bind - cls.autogen_context = { - 'imports': set(), - 'connection': connection, - 'dialect': connection.dialect, - 'context': context, - 'opts': { - 'compare_type': True, - 'compare_server_default': True, - 'alembic_module_prefix': 'op.', - 'sqlalchemy_module_prefix': 'sa.', - } - } + + def setUp(self): + self.metadata = MetaData(self.bind) + self.autogen_context = api.AutogenContext(self.migration_context) @classmethod def teardown_class(cls): clear_staging_env() - def setUp(self): - self.metadata = MetaData(self.bind) - def tearDown(self): self.metadata.drop_all() @@ -212,9 +202,12 @@ class PostgresqlDefaultCompareTest(TestBase): cols = insp.get_columns(t1.name) insp_col = Column("somecol", cols[0]['type'], server_default=text(cols[0]['default'])) - diffs = [] - _compare_server_default(None, "test", "somecol", insp_col, - t2.c.somecol, diffs, self.autogen_context) + op = ops.AlterColumnOp("test", "somecol") + _compare_server_default( + self.autogen_context, op, + None, "test", "somecol", insp_col, t2.c.somecol) + + diffs = op.to_diff_tuple() eq_(bool(diffs), diff_expected) def _compare_default( @@ -225,7 +218,7 @@ class PostgresqlDefaultCompareTest(TestBase): t1.create(self.bind, checkfirst=True) insp = Inspector.from_engine(self.bind) cols = insp.get_columns(t1.name) - ctx = self.autogen_context['context'] + ctx = self.autogen_context.migration_context return ctx.impl.compare_server_default( None, @@ -385,26 +378,16 @@ class PostgresqlDetectSerialTest(TestBase): cls.bind = config.db cls.conn = cls.bind.connect() staging_env() - context = MigrationContext.configure( + cls.migration_context = MigrationContext.configure( connection=cls.conn, opts={ 'compare_type': True, 'compare_server_default': True } ) - connection = context.bind - cls.autogen_context = { - 'imports': set(), - 'connection': connection, - 'dialect': connection.dialect, - 'context': context, - 'opts': { - 'compare_type': True, - 'compare_server_default': True, - 'alembic_module_prefix': 'op.', - 'sqlalchemy_module_prefix': 'sa.', - } - } + + def setUp(self): + self.autogen_context = api.AutogenContext(self.migration_context) @classmethod def teardown_class(cls): @@ -420,24 +403,26 @@ class PostgresqlDetectSerialTest(TestBase): self.metadata.create_all(config.db) insp = Inspector.from_engine(config.db) - diffs = [] + + uo = ops.UpgradeOps(ops=[]) _compare_tables( set([(None, 't')]), set([]), - [], - insp, self.metadata, diffs, self.autogen_context) + insp, self.metadata, uo, self.autogen_context) + diffs = uo.as_diffs() tab = diffs[0][1] + eq_(_render_server_default_for_compare( tab.c.x.server_default, tab.c.x, self.autogen_context), c_expected) insp = Inspector.from_engine(config.db) - diffs = [] + uo = ops.UpgradeOps(ops=[]) m2 = MetaData() Table('t', m2, Column('x', BigInteger())) _compare_tables( set([(None, 't')]), set([(None, 't')]), - [], - insp, m2, diffs, self.autogen_context) + insp, m2, uo, self.autogen_context) + diffs = uo.as_diffs() server_default = diffs[0][0][4]['existing_server_default'] eq_(_render_server_default_for_compare( server_default, tab.c.x, self.autogen_context), |