diff options
85 files changed, 10842 insertions, 8317 deletions
diff --git a/alembic/__init__.py b/alembic/__init__.py index 3432a88..a7a2845 100644 --- a/alembic/__init__.py +++ b/alembic/__init__.py @@ -1,6 +1,6 @@ from os import path -__version__ = '1.0.6' +__version__ = "1.0.6" package_dir = path.abspath(path.dirname(__file__)) @@ -11,5 +11,6 @@ from . import context # noqa import sys from .runtime import environment from .runtime import migration -sys.modules['alembic.migration'] = migration -sys.modules['alembic.environment'] = environment + +sys.modules["alembic.migration"] = migration +sys.modules["alembic.environment"] = environment diff --git a/alembic/autogenerate/__init__.py b/alembic/autogenerate/__init__.py index 142f55d..ad3e6e1 100644 --- a/alembic/autogenerate/__init__.py +++ b/alembic/autogenerate/__init__.py @@ -1,8 +1,10 @@ -from .api import ( # noqa - compare_metadata, _render_migration_diffs, - produce_migrations, render_python_code, - RevisionContext - ) +from .api import ( # noqa + compare_metadata, + _render_migration_diffs, + produce_migrations, + render_python_code, + RevisionContext, +) from .compare import _produce_net_changes, comparators # noqa from .render import render_op_text, renderers # noqa -from .rewriter import Rewriter # noqa
\ No newline at end of file +from .rewriter import Rewriter # noqa diff --git a/alembic/autogenerate/api.py b/alembic/autogenerate/api.py index 15b5b6b..cfd6e86 100644 --- a/alembic/autogenerate/api.py +++ b/alembic/autogenerate/api.py @@ -136,8 +136,8 @@ def produce_migrations(context, metadata): def render_python_code( up_or_down_op, - sqlalchemy_module_prefix='sa.', - alembic_module_prefix='op.', + sqlalchemy_module_prefix="sa.", + alembic_module_prefix="op.", render_as_batch=False, imports=(), render_item=None, @@ -150,16 +150,17 @@ def render_python_code( """ opts = { - 'sqlalchemy_module_prefix': sqlalchemy_module_prefix, - 'alembic_module_prefix': alembic_module_prefix, - 'render_item': render_item, - 'render_as_batch': render_as_batch, + "sqlalchemy_module_prefix": sqlalchemy_module_prefix, + "alembic_module_prefix": alembic_module_prefix, + "render_item": render_item, + "render_as_batch": render_as_batch, } autogen_context = AutogenContext(None, opts=opts) autogen_context.imports = set(imports) - return render._indent(render._render_cmd_body( - up_or_down_op, autogen_context)) + return render._indent( + render._render_cmd_body(up_or_down_op, autogen_context) + ) def _render_migration_diffs(context, template_args): @@ -240,42 +241,53 @@ class AutogenContext(object): """The :class:`.MigrationContext` established by the ``env.py`` script.""" def __init__( - self, migration_context, metadata=None, - opts=None, autogenerate=True): - - if autogenerate and \ - migration_context is not None and migration_context.as_sql: + self, migration_context, metadata=None, opts=None, autogenerate=True + ): + + if ( + autogenerate + and migration_context is not None + and migration_context.as_sql + ): raise util.CommandError( "autogenerate can't use as_sql=True as it prevents querying " - "the database for schema information") + "the database for schema information" + ) if opts is None: opts = migration_context.opts - self.metadata = metadata = opts.get('target_metadata', None) \ - if metadata is None else metadata + self.metadata = metadata = ( + opts.get("target_metadata", None) if metadata is None else metadata + ) - if autogenerate and metadata is None and \ - migration_context is not None and \ - migration_context.script is not None: + if ( + autogenerate + and metadata is None + and migration_context is not None + and migration_context.script is not None + ): raise util.CommandError( "Can't proceed with --autogenerate option; environment " "script %s does not provide " - "a MetaData object or sequence of objects to the context." % ( - migration_context.script.env_py_location - )) + "a MetaData object or sequence of objects to the context." + % (migration_context.script.env_py_location) + ) - include_symbol = opts.get('include_symbol', None) - include_object = opts.get('include_object', None) + include_symbol = opts.get("include_symbol", None) + include_object = opts.get("include_object", None) object_filters = [] if include_symbol: + def include_symbol_filter( - object, name, type_, reflected, compare_to): + object, name, type_, reflected, compare_to + ): if type_ == "table": return include_symbol(name, object.schema) else: return True + object_filters.append(include_symbol_filter) if include_object: object_filters.append(include_object) @@ -357,8 +369,8 @@ class AutogenContext(object): if intersect: raise ValueError( "Duplicate table keys across multiple " - "MetaData objects: %s" % - (", ".join('"%s"' % key for key in sorted(intersect))) + "MetaData objects: %s" + % (", ".join('"%s"' % key for key in sorted(intersect))) ) result.update(m.tables) @@ -369,26 +381,29 @@ class RevisionContext(object): """Maintains configuration and state that's specific to a revision file generation operation.""" - def __init__(self, config, script_directory, command_args, - process_revision_directives=None): + def __init__( + self, + config, + script_directory, + command_args, + process_revision_directives=None, + ): self.config = config self.script_directory = script_directory self.command_args = command_args self.process_revision_directives = process_revision_directives self.template_args = { - 'config': config # Let templates use config for - # e.g. multiple databases + "config": config # Let templates use config for + # e.g. multiple databases } - self.generated_revisions = [ - self._default_revision() - ] + self.generated_revisions = [self._default_revision()] def _to_script(self, migration_script): template_args = {} for k, v in self.template_args.items(): template_args.setdefault(k, v) - if getattr(migration_script, '_needs_render', False): + if getattr(migration_script, "_needs_render", False): autogen_context = self._last_autogen_context # clear out existing imports if we are doing multiple @@ -409,7 +424,8 @@ class RevisionContext(object): branch_labels=migration_script.branch_label, version_path=migration_script.version_path, depends_on=migration_script.depends_on, - **template_args) + **template_args + ) def run_autogenerate(self, rev, migration_context): self._run_environment(rev, migration_context, True) @@ -419,21 +435,24 @@ class RevisionContext(object): def _run_environment(self, rev, migration_context, autogenerate): if autogenerate: - if self.command_args['sql']: + if self.command_args["sql"]: raise util.CommandError( - "Using --sql with --autogenerate does not make any sense") - if set(self.script_directory.get_revisions(rev)) != \ - set(self.script_directory.get_revisions("heads")): + "Using --sql with --autogenerate does not make any sense" + ) + if set(self.script_directory.get_revisions(rev)) != set( + self.script_directory.get_revisions("heads") + ): raise util.CommandError("Target database is not up to date.") - upgrade_token = migration_context.opts['upgrade_token'] - downgrade_token = migration_context.opts['downgrade_token'] + upgrade_token = migration_context.opts["upgrade_token"] + downgrade_token = migration_context.opts["downgrade_token"] migration_script = self.generated_revisions[-1] - if not getattr(migration_script, '_needs_render', False): + if not getattr(migration_script, "_needs_render", False): migration_script.upgrade_ops_list[-1].upgrade_token = upgrade_token - migration_script.downgrade_ops_list[-1].downgrade_token = \ - downgrade_token + migration_script.downgrade_ops_list[ + -1 + ].downgrade_token = downgrade_token migration_script._needs_render = True else: migration_script._upgrade_ops.append( @@ -443,18 +462,21 @@ class RevisionContext(object): ops.DowngradeOps([], downgrade_token=downgrade_token) ) - self._last_autogen_context = autogen_context = \ - AutogenContext(migration_context, autogenerate=autogenerate) + self._last_autogen_context = autogen_context = AutogenContext( + migration_context, autogenerate=autogenerate + ) if autogenerate: compare._populate_migration_script( - autogen_context, migration_script) + autogen_context, migration_script + ) if self.process_revision_directives: self.process_revision_directives( - migration_context, rev, self.generated_revisions) + migration_context, rev, self.generated_revisions + ) - hook = migration_context.opts['process_revision_directives'] + hook = migration_context.opts["process_revision_directives"] if hook: hook(migration_context, rev, self.generated_revisions) @@ -463,15 +485,15 @@ class RevisionContext(object): def _default_revision(self): op = ops.MigrationScript( - rev_id=self.command_args['rev_id'] or util.rev_id(), - message=self.command_args['message'], + rev_id=self.command_args["rev_id"] or util.rev_id(), + message=self.command_args["message"], upgrade_ops=ops.UpgradeOps([]), downgrade_ops=ops.DowngradeOps([]), - head=self.command_args['head'], - splice=self.command_args['splice'], - branch_label=self.command_args['branch_label'], - version_path=self.command_args['version_path'], - depends_on=self.command_args['depends_on'] + head=self.command_args["head"], + splice=self.command_args["splice"], + branch_label=self.command_args["branch_label"], + version_path=self.command_args["version_path"], + depends_on=self.command_args["depends_on"], ) return op diff --git a/alembic/autogenerate/compare.py b/alembic/autogenerate/compare.py index 8b41647..7ff8be6 100644 --- a/alembic/autogenerate/compare.py +++ b/alembic/autogenerate/compare.py @@ -29,7 +29,7 @@ comparators = util.Dispatcher(uselist=True) def _produce_net_changes(autogen_context, upgrade_ops): connection = autogen_context.connection - include_schemas = autogen_context.opts.get('include_schemas', False) + include_schemas = autogen_context.opts.get("include_schemas", False) inspector = Inspector.from_engine(connection) @@ -55,8 +55,9 @@ def _autogen_for_tables(autogen_context, upgrade_ops, schemas): conn_table_names = set() - version_table_schema = \ + version_table_schema = ( autogen_context.migration_context.version_table_schema + ) version_table = autogen_context.migration_context.version_table for s in schemas: @@ -71,12 +72,22 @@ def _autogen_for_tables(autogen_context, upgrade_ops, schemas): [(table.schema, table.name) for table in autogen_context.sorted_tables] ).difference([(version_table_schema, version_table)]) - _compare_tables(conn_table_names, metadata_table_names, - inspector, upgrade_ops, autogen_context) + _compare_tables( + conn_table_names, + metadata_table_names, + inspector, + upgrade_ops, + autogen_context, + ) -def _compare_tables(conn_table_names, metadata_table_names, - inspector, upgrade_ops, autogen_context): +def _compare_tables( + conn_table_names, + metadata_table_names, + inspector, + upgrade_ops, + autogen_context, +): default_schema = inspector.bind.dialect.default_schema_name @@ -85,10 +96,12 @@ def _compare_tables(conn_table_names, metadata_table_names, # of table names from local metadata that also have "None" if schema # == default_schema_name. Most setups will be like this anyway but # some are not (see #170) - metadata_table_names_no_dflt_schema = OrderedSet([ - (schema if schema != default_schema else None, tname) - for schema, tname in metadata_table_names - ]) + metadata_table_names_no_dflt_schema = OrderedSet( + [ + (schema if schema != default_schema else None, tname) + for schema, tname in metadata_table_names + ] + ) # to adjust for the MetaData collection storing the tables either # as "schemaname.tablename" or just "tablename", create a new lookup @@ -97,27 +110,34 @@ def _compare_tables(conn_table_names, metadata_table_names, ( no_dflt_schema, autogen_context.table_key_to_table[ - sa_schema._get_table_key(tname, schema)] + sa_schema._get_table_key(tname, schema) + ], ) for no_dflt_schema, (schema, tname) in zip( - metadata_table_names_no_dflt_schema, - metadata_table_names) + metadata_table_names_no_dflt_schema, metadata_table_names + ) ) metadata_table_names = metadata_table_names_no_dflt_schema for s, tname in metadata_table_names.difference(conn_table_names): - name = '%s.%s' % (s, tname) if s else tname + name = "%s.%s" % (s, tname) if s else tname metadata_table = tname_to_table[(s, tname)] if autogen_context.run_filters( - metadata_table, tname, "table", False, None): + metadata_table, tname, "table", False, None + ): upgrade_ops.ops.append( - ops.CreateTableOp.from_table(metadata_table)) + ops.CreateTableOp.from_table(metadata_table) + ) log.info("Detected added table %r", name) modify_table_ops = ops.ModifyTableOps(tname, [], schema=s) comparators.dispatch("table")( - autogen_context, modify_table_ops, - s, tname, None, metadata_table + autogen_context, + modify_table_ops, + s, + tname, + None, + metadata_table, ) if not modify_table_ops.is_empty(): upgrade_ops.ops.append(modify_table_ops) @@ -132,23 +152,22 @@ def _compare_tables(conn_table_names, metadata_table_names, event.listen( t, "column_reflect", - autogen_context.migration_context.impl. - _compat_autogen_column_reflect(inspector)) + autogen_context.migration_context.impl._compat_autogen_column_reflect( + inspector + ), + ) inspector.reflecttable(t, None) if autogen_context.run_filters(t, tname, "table", True, None): modify_table_ops = ops.ModifyTableOps(tname, [], schema=s) comparators.dispatch("table")( - autogen_context, modify_table_ops, - s, tname, t, None + autogen_context, modify_table_ops, s, tname, t, None ) if not modify_table_ops.is_empty(): upgrade_ops.ops.append(modify_table_ops) - upgrade_ops.ops.append( - ops.DropTableOp.from_table(t) - ) + upgrade_ops.ops.append(ops.DropTableOp.from_table(t)) log.info("Detected removed table %r", name) existing_tables = conn_table_names.intersection(metadata_table_names) @@ -163,31 +182,41 @@ def _compare_tables(conn_table_names, metadata_table_names, event.listen( t, "column_reflect", - autogen_context.migration_context.impl. - _compat_autogen_column_reflect(inspector)) + autogen_context.migration_context.impl._compat_autogen_column_reflect( + inspector + ), + ) inspector.reflecttable(t, None) conn_column_info[(s, tname)] = t - for s, tname in sorted(existing_tables, key=lambda x: (x[0] or '', x[1])): + for s, tname in sorted(existing_tables, key=lambda x: (x[0] or "", x[1])): s = s or None - name = '%s.%s' % (s, tname) if s else tname + name = "%s.%s" % (s, tname) if s else tname metadata_table = tname_to_table[(s, tname)] conn_table = existing_metadata.tables[name] if autogen_context.run_filters( - metadata_table, tname, "table", False, - conn_table): + metadata_table, tname, "table", False, conn_table + ): modify_table_ops = ops.ModifyTableOps(tname, [], schema=s) with _compare_columns( - s, tname, + s, + tname, conn_table, metadata_table, - modify_table_ops, autogen_context, inspector): + modify_table_ops, + autogen_context, + inspector, + ): comparators.dispatch("table")( - autogen_context, modify_table_ops, - s, tname, conn_table, metadata_table + autogen_context, + modify_table_ops, + s, + tname, + conn_table, + metadata_table, ) if not modify_table_ops.is_empty(): @@ -196,41 +225,41 @@ def _compare_tables(conn_table_names, metadata_table_names, def _make_index(params, conn_table): ix = sa_schema.Index( - params['name'], - *[conn_table.c[cname] for cname in params['column_names']], - unique=params['unique'] + params["name"], + *[conn_table.c[cname] for cname in params["column_names"]], + unique=params["unique"] ) - if 'duplicates_constraint' in params: - ix.info['duplicates_constraint'] = params['duplicates_constraint'] + if "duplicates_constraint" in params: + ix.info["duplicates_constraint"] = params["duplicates_constraint"] return ix def _make_unique_constraint(params, conn_table): uq = sa_schema.UniqueConstraint( - *[conn_table.c[cname] for cname in params['column_names']], - name=params['name'] + *[conn_table.c[cname] for cname in params["column_names"]], + name=params["name"] ) - if 'duplicates_index' in params: - uq.info['duplicates_index'] = params['duplicates_index'] + if "duplicates_index" in params: + uq.info["duplicates_index"] = params["duplicates_index"] return uq def _make_foreign_key(params, conn_table): - tname = params['referred_table'] - if params['referred_schema']: - tname = "%s.%s" % (params['referred_schema'], tname) + tname = params["referred_table"] + if params["referred_schema"]: + tname = "%s.%s" % (params["referred_schema"], tname) - options = params.get('options', {}) + options = params.get("options", {}) const = sa_schema.ForeignKeyConstraint( - [conn_table.c[cname] for cname in params['constrained_columns']], - ["%s.%s" % (tname, n) for n in params['referred_columns']], - onupdate=options.get('onupdate'), - ondelete=options.get('ondelete'), - deferrable=options.get('deferrable'), - initially=options.get('initially'), - name=params['name'] + [conn_table.c[cname] for cname in params["constrained_columns"]], + ["%s.%s" % (tname, n) for n in params["referred_columns"]], + onupdate=options.get("onupdate"), + ondelete=options.get("ondelete"), + deferrable=options.get("deferrable"), + initially=options.get("initially"), + name=params["name"], ) # needed by 0.7 conn_table.append_constraint(const) @@ -238,21 +267,30 @@ def _make_foreign_key(params, conn_table): @contextlib.contextmanager -def _compare_columns(schema, tname, conn_table, metadata_table, - modify_table_ops, autogen_context, inspector): - name = '%s.%s' % (schema, tname) if schema else tname +def _compare_columns( + schema, + tname, + conn_table, + metadata_table, + modify_table_ops, + autogen_context, + inspector, +): + name = "%s.%s" % (schema, tname) if schema else tname metadata_cols_by_name = dict( - (c.name, c) for c in metadata_table.c if not c.system) + (c.name, c) for c in metadata_table.c if not c.system + ) conn_col_names = dict((c.name, c) for c in conn_table.c) metadata_col_names = OrderedSet(sorted(metadata_cols_by_name)) for cname in metadata_col_names.difference(conn_col_names): if autogen_context.run_filters( - metadata_cols_by_name[cname], cname, - "column", False, None): + metadata_cols_by_name[cname], cname, "column", False, None + ): modify_table_ops.ops.append( ops.AddColumnOp.from_column_and_tablename( - schema, tname, metadata_cols_by_name[cname]) + schema, tname, metadata_cols_by_name[cname] + ) ) log.info("Detected added column '%s.%s'", name, cname) @@ -260,15 +298,19 @@ def _compare_columns(schema, tname, conn_table, metadata_table, metadata_col = metadata_cols_by_name[colname] conn_col = conn_table.c[colname] if not autogen_context.run_filters( - metadata_col, colname, "column", False, - conn_col): + metadata_col, colname, "column", False, conn_col + ): continue - alter_column_op = ops.AlterColumnOp( - tname, colname, schema=schema) + alter_column_op = ops.AlterColumnOp(tname, colname, schema=schema) comparators.dispatch("column")( - autogen_context, alter_column_op, - schema, tname, colname, conn_col, metadata_col + autogen_context, + alter_column_op, + schema, + tname, + colname, + conn_col, + metadata_col, ) if alter_column_op.has_changes(): @@ -278,8 +320,8 @@ def _compare_columns(schema, tname, conn_table, metadata_table, for cname in set(conn_col_names).difference(metadata_col_names): if autogen_context.run_filters( - conn_table.c[cname], cname, - "column", True, None): + conn_table.c[cname], cname, "column", True, None + ): modify_table_ops.ops.append( ops.DropColumnOp.from_column_and_tablename( schema, tname, conn_table.c[cname] @@ -289,7 +331,6 @@ def _compare_columns(schema, tname, conn_table, metadata_table, class _constraint_sig(object): - def md_name_to_sql_name(self, context): return self.name @@ -340,36 +381,47 @@ class _fk_constraint_sig(_constraint_sig): self.name = const.name ( - self.source_schema, self.source_table, - self.source_columns, self.target_schema, self.target_table, + self.source_schema, + self.source_table, + self.source_columns, + self.target_schema, + self.target_table, self.target_columns, - onupdate, ondelete, - deferrable, initially) = _fk_spec(const) + onupdate, + ondelete, + deferrable, + initially, + ) = _fk_spec(const) self.sig = ( - self.source_schema, self.source_table, tuple(self.source_columns), - self.target_schema, self.target_table, tuple(self.target_columns) + self.source_schema, + self.source_table, + tuple(self.source_columns), + self.target_schema, + self.target_table, + tuple(self.target_columns), ) if include_options: self.sig += ( - (None if onupdate.lower() == 'no action' - else onupdate.lower()) - if onupdate else None, - (None if ondelete.lower() == 'no action' - else ondelete.lower()) - if ondelete else None, + (None if onupdate.lower() == "no action" else onupdate.lower()) + if onupdate + else None, + (None if ondelete.lower() == "no action" else ondelete.lower()) + if ondelete + else None, # convert initially + deferrable into one three-state value "initially_deferrable" if initially and initially.lower() == "deferred" - else "deferrable" if deferrable - else "not deferrable" + else "deferrable" + if deferrable + else "not deferrable", ) @comparators.dispatch_for("table") def _compare_indexes_and_uniques( - autogen_context, modify_ops, schema, tname, conn_table, - metadata_table): + autogen_context, modify_ops, schema, tname, conn_table, metadata_table +): inspector = autogen_context.inspector is_create_table = conn_table is None @@ -378,7 +430,8 @@ def _compare_indexes_and_uniques( # 1a. get raw indexes and unique constraints from metadata ... if metadata_table is not None: metadata_unique_constraints = set( - uq for uq in metadata_table.constraints + uq + for uq in metadata_table.constraints if isinstance(uq, sa_schema.UniqueConstraint) ) metadata_indexes = set(metadata_table.indexes) @@ -397,7 +450,8 @@ def _compare_indexes_and_uniques( if hasattr(inspector, "get_unique_constraints"): try: conn_uniques = inspector.get_unique_constraints( - tname, schema=schema) + tname, schema=schema + ) supports_unique_constraints = True except NotImplementedError: pass @@ -408,7 +462,7 @@ def _compare_indexes_and_uniques( pass else: for uq in conn_uniques: - if uq.get('duplicates_index'): + if uq.get("duplicates_index"): unique_constraints_duplicate_unique_indexes = True try: conn_indexes = inspector.get_indexes(tname, schema=schema) @@ -421,8 +475,10 @@ def _compare_indexes_and_uniques( # for DROP TABLE uniques are inline, don't need them conn_uniques = set() else: - conn_uniques = set(_make_unique_constraint(uq_def, conn_table) - for uq_def in conn_uniques) + conn_uniques = set( + _make_unique_constraint(uq_def, conn_table) + for uq_def in conn_uniques + ) conn_indexes = set(_make_index(ix, conn_table) for ix in conn_indexes) @@ -431,64 +487,71 @@ def _compare_indexes_and_uniques( if unique_constraints_duplicate_unique_indexes: _correct_for_uq_duplicates_uix( - conn_uniques, conn_indexes, + conn_uniques, + conn_indexes, metadata_unique_constraints, - metadata_indexes + metadata_indexes, ) # 3. give the dialect a chance to omit indexes and constraints that # we know are either added implicitly by the DB or that the DB # can't accurately report on - autogen_context.migration_context.impl.\ - correct_for_autogen_constraints( - conn_uniques, conn_indexes, - metadata_unique_constraints, - metadata_indexes) + autogen_context.migration_context.impl.correct_for_autogen_constraints( + conn_uniques, + conn_indexes, + metadata_unique_constraints, + metadata_indexes, + ) # 4. organize the constraints into "signature" collections, the # _constraint_sig() objects provide a consistent facade over both # Index and UniqueConstraint so we can easily work with them # interchangeably - metadata_unique_constraints = set(_uq_constraint_sig(uq) - for uq in metadata_unique_constraints - ) + metadata_unique_constraints = set( + _uq_constraint_sig(uq) for uq in metadata_unique_constraints + ) metadata_indexes = set(_ix_constraint_sig(ix) for ix in metadata_indexes) conn_unique_constraints = set( - _uq_constraint_sig(uq) for uq in conn_uniques) + _uq_constraint_sig(uq) for uq in conn_uniques + ) conn_indexes = set(_ix_constraint_sig(ix) for ix in conn_indexes) # 5. index things by name, for those objects that have names metadata_names = dict( - (c.md_name_to_sql_name(autogen_context), c) for c in - metadata_unique_constraints.union(metadata_indexes) - if c.name is not None) + (c.md_name_to_sql_name(autogen_context), c) + for c in metadata_unique_constraints.union(metadata_indexes) + if c.name is not None + ) conn_uniques_by_name = dict((c.name, c) for c in conn_unique_constraints) conn_indexes_by_name = dict((c.name, c) for c in conn_indexes) - conn_names = dict((c.name, c) for c in - conn_unique_constraints.union(conn_indexes) - if c.name is not None) + conn_names = dict( + (c.name, c) + for c in conn_unique_constraints.union(conn_indexes) + if c.name is not None + ) doubled_constraints = dict( (name, (conn_uniques_by_name[name], conn_indexes_by_name[name])) - for name in set( - conn_uniques_by_name).intersection(conn_indexes_by_name) + for name in set(conn_uniques_by_name).intersection( + conn_indexes_by_name + ) ) # 6. index things by "column signature", to help with unnamed unique # constraints. conn_uniques_by_sig = dict((uq.sig, uq) for uq in conn_unique_constraints) metadata_uniques_by_sig = dict( - (uq.sig, uq) for uq in metadata_unique_constraints) - metadata_indexes_by_sig = dict( - (ix.sig, ix) for ix in metadata_indexes) + (uq.sig, uq) for uq in metadata_unique_constraints + ) + metadata_indexes_by_sig = dict((ix.sig, ix) for ix in metadata_indexes) unnamed_metadata_uniques = dict( - (uq.sig, uq) for uq in - metadata_unique_constraints if uq.name is None) + (uq.sig, uq) for uq in metadata_unique_constraints if uq.name is None + ) # assumptions: # 1. a unique constraint or an index from the connection *always* @@ -501,14 +564,14 @@ def _compare_indexes_and_uniques( def obj_added(obj): if obj.is_index: if autogen_context.run_filters( - obj.const, obj.name, "index", False, None): - modify_ops.ops.append( - ops.CreateIndexOp.from_index(obj.const) + obj.const, obj.name, "index", False, None + ): + modify_ops.ops.append(ops.CreateIndexOp.from_index(obj.const)) + log.info( + "Detected added index '%s' on %s", + obj.name, + ", ".join(["'%s'" % obj.column_names]), ) - log.info("Detected added index '%s' on %s", - obj.name, ', '.join([ - "'%s'" % obj.column_names - ])) else: if not supports_unique_constraints: # can't report unique indexes as added if we don't @@ -518,15 +581,16 @@ def _compare_indexes_and_uniques( # unique constraints are created inline with table defs return if autogen_context.run_filters( - obj.const, obj.name, - "unique_constraint", False, None): + obj.const, obj.name, "unique_constraint", False, None + ): modify_ops.ops.append( ops.AddConstraintOp.from_constraint(obj.const) ) - log.info("Detected added unique constraint '%s' on %s", - obj.name, ', '.join([ - "'%s'" % obj.column_names - ])) + log.info( + "Detected added unique constraint '%s' on %s", + obj.name, + ", ".join(["'%s'" % obj.column_names]), + ) def obj_removed(obj): if obj.is_index: @@ -537,48 +601,52 @@ def _compare_indexes_and_uniques( return if autogen_context.run_filters( - obj.const, obj.name, "index", True, None): - modify_ops.ops.append( - ops.DropIndexOp.from_index(obj.const) - ) + obj.const, obj.name, "index", True, None + ): + modify_ops.ops.append(ops.DropIndexOp.from_index(obj.const)) log.info( - "Detected removed index '%s' on '%s'", obj.name, tname) + "Detected removed index '%s' on '%s'", obj.name, tname + ) else: if is_create_table or is_drop_table: # if the whole table is being dropped, we don't need to # consider unique constraint separately return if autogen_context.run_filters( - obj.const, obj.name, - "unique_constraint", True, None): + obj.const, obj.name, "unique_constraint", True, None + ): modify_ops.ops.append( ops.DropConstraintOp.from_constraint(obj.const) ) - log.info("Detected removed unique constraint '%s' on '%s'", - obj.name, tname - ) + log.info( + "Detected removed unique constraint '%s' on '%s'", + obj.name, + tname, + ) def obj_changed(old, new, msg): if old.is_index: if autogen_context.run_filters( - new.const, new.name, "index", - False, old.const): - log.info("Detected changed index '%s' on '%s':%s", - old.name, tname, ', '.join(msg) - ) - modify_ops.ops.append( - ops.DropIndexOp.from_index(old.const) - ) - modify_ops.ops.append( - ops.CreateIndexOp.from_index(new.const) + new.const, new.name, "index", False, old.const + ): + log.info( + "Detected changed index '%s' on '%s':%s", + old.name, + tname, + ", ".join(msg), ) + modify_ops.ops.append(ops.DropIndexOp.from_index(old.const)) + modify_ops.ops.append(ops.CreateIndexOp.from_index(new.const)) else: if autogen_context.run_filters( - new.const, new.name, - "unique_constraint", False, old.const): - log.info("Detected changed unique constraint '%s' on '%s':%s", - old.name, tname, ', '.join(msg) - ) + new.const, new.name, "unique_constraint", False, old.const + ): + log.info( + "Detected changed unique constraint '%s' on '%s':%s", + old.name, + tname, + ", ".join(msg), + ) modify_ops.ops.append( ops.DropConstraintOp.from_constraint(old.const) ) @@ -608,13 +676,14 @@ def _compare_indexes_and_uniques( else: msg = [] if conn_obj.is_unique != metadata_obj.is_unique: - msg.append(' unique=%r to unique=%r' % ( - conn_obj.is_unique, metadata_obj.is_unique - )) + msg.append( + " unique=%r to unique=%r" + % (conn_obj.is_unique, metadata_obj.is_unique) + ) if conn_obj.sig != metadata_obj.sig: - msg.append(' columns %r to %r' % ( - conn_obj.sig, metadata_obj.sig - )) + msg.append( + " columns %r to %r" % (conn_obj.sig, metadata_obj.sig) + ) if msg: obj_changed(conn_obj, metadata_obj, msg) @@ -624,8 +693,10 @@ def _compare_indexes_and_uniques( if not conn_obj.is_index and conn_obj.sig in unnamed_metadata_uniques: continue elif removed_name in doubled_constraints: - if conn_obj.sig not in metadata_indexes_by_sig and \ - conn_obj.sig not in metadata_uniques_by_sig: + if ( + conn_obj.sig not in metadata_indexes_by_sig + and conn_obj.sig not in metadata_uniques_by_sig + ): conn_uq, conn_idx = doubled_constraints[removed_name] obj_removed(conn_uq) obj_removed(conn_idx) @@ -639,40 +710,51 @@ def _compare_indexes_and_uniques( def _correct_for_uq_duplicates_uix( conn_unique_constraints, - conn_indexes, - metadata_unique_constraints, - metadata_indexes): + conn_indexes, + metadata_unique_constraints, + metadata_indexes, +): # dedupe unique indexes vs. constraints, since MySQL / Oracle # doesn't really have unique constraints as a separate construct. # but look in the metadata and try to maintain constructs # that already seem to be defined one way or the other # on that side. This logic was formerly local to MySQL dialect, # generalized to Oracle and others. See #276 - metadata_uq_names = set([ - cons.name for cons in metadata_unique_constraints - if cons.name is not None]) - - unnamed_metadata_uqs = set([ - _uq_constraint_sig(cons).sig - for cons in metadata_unique_constraints - if cons.name is None - ]) - - metadata_ix_names = set([ - cons.name for cons in metadata_indexes if cons.unique]) + metadata_uq_names = set( + [ + cons.name + for cons in metadata_unique_constraints + if cons.name is not None + ] + ) + + unnamed_metadata_uqs = set( + [ + _uq_constraint_sig(cons).sig + for cons in metadata_unique_constraints + if cons.name is None + ] + ) + + metadata_ix_names = set( + [cons.name for cons in metadata_indexes if cons.unique] + ) conn_ix_names = dict( (cons.name, cons) for cons in conn_indexes if cons.unique ) uqs_dupe_indexes = dict( - (cons.name, cons) for cons in conn_unique_constraints - if cons.info['duplicates_index'] + (cons.name, cons) + for cons in conn_unique_constraints + if cons.info["duplicates_index"] ) for overlap in uqs_dupe_indexes: if overlap not in metadata_uq_names: - if _uq_constraint_sig(uqs_dupe_indexes[overlap]).sig \ - not in unnamed_metadata_uqs: + if ( + _uq_constraint_sig(uqs_dupe_indexes[overlap]).sig + not in unnamed_metadata_uqs + ): conn_unique_constraints.discard(uqs_dupe_indexes[overlap]) elif overlap not in metadata_ix_names: @@ -681,8 +763,14 @@ def _correct_for_uq_duplicates_uix( @comparators.dispatch_for("column") def _compare_nullable( - autogen_context, alter_column_op, schema, tname, cname, conn_col, - metadata_col): + autogen_context, + alter_column_op, + schema, + tname, + cname, + conn_col, + metadata_col, +): # work around SQLAlchemy issue #3023 if metadata_col.primary_key: @@ -694,57 +782,83 @@ def _compare_nullable( if conn_col_nullable is not metadata_col_nullable: alter_column_op.modify_nullable = metadata_col_nullable - log.info("Detected %s on column '%s.%s'", - "NULL" if metadata_col_nullable else "NOT NULL", - tname, - cname - ) + log.info( + "Detected %s on column '%s.%s'", + "NULL" if metadata_col_nullable else "NOT NULL", + tname, + cname, + ) @comparators.dispatch_for("column") def _setup_autoincrement( - autogen_context, alter_column_op, schema, tname, cname, conn_col, - metadata_col): + autogen_context, + alter_column_op, + schema, + tname, + cname, + conn_col, + metadata_col, +): if metadata_col.table._autoincrement_column is metadata_col: - alter_column_op.kw['autoincrement'] = True + alter_column_op.kw["autoincrement"] = True elif util.sqla_110 and metadata_col.autoincrement is True: - alter_column_op.kw['autoincrement'] = True + alter_column_op.kw["autoincrement"] = True elif metadata_col.autoincrement is False: - alter_column_op.kw['autoincrement'] = False + alter_column_op.kw["autoincrement"] = False @comparators.dispatch_for("column") def _compare_type( - autogen_context, alter_column_op, schema, tname, cname, conn_col, - metadata_col): + autogen_context, + alter_column_op, + schema, + tname, + cname, + conn_col, + metadata_col, +): conn_type = conn_col.type alter_column_op.existing_type = conn_type metadata_type = metadata_col.type if conn_type._type_affinity is sqltypes.NullType: - log.info("Couldn't determine database type " - "for column '%s.%s'", tname, cname) + log.info( + "Couldn't determine database type " "for column '%s.%s'", + tname, + cname, + ) return if metadata_type._type_affinity is sqltypes.NullType: - log.info("Column '%s.%s' has no type within " - "the model; can't compare", tname, cname) + log.info( + "Column '%s.%s' has no type within " "the model; can't compare", + tname, + cname, + ) return isdiff = autogen_context.migration_context._compare_type( - conn_col, metadata_col) + conn_col, metadata_col + ) if isdiff: alter_column_op.modify_type = metadata_type - log.info("Detected type change from %r to %r on '%s.%s'", - conn_type, metadata_type, tname, cname - ) + log.info( + "Detected type change from %r to %r on '%s.%s'", + conn_type, + metadata_type, + tname, + cname, + ) -def _render_server_default_for_compare(metadata_default, - metadata_col, autogen_context): +def _render_server_default_for_compare( + metadata_default, metadata_col, autogen_context +): rendered = _user_defined_render( - "server_default", metadata_default, autogen_context) + "server_default", metadata_default, autogen_context + ) if rendered is not False: return rendered @@ -752,8 +866,9 @@ def _render_server_default_for_compare(metadata_default, if isinstance(metadata_default.arg, compat.string_types): metadata_default = metadata_default.arg else: - metadata_default = str(metadata_default.arg.compile( - dialect=autogen_context.dialect)) + metadata_default = str( + metadata_default.arg.compile(dialect=autogen_context.dialect) + ) if isinstance(metadata_default, compat.string_types): if metadata_col.type._type_affinity is sqltypes.String: metadata_default = re.sub(r"^'|'$", "", metadata_default) @@ -766,37 +881,49 @@ def _render_server_default_for_compare(metadata_default, @comparators.dispatch_for("column") def _compare_server_default( - autogen_context, alter_column_op, schema, tname, cname, - conn_col, metadata_col): + autogen_context, + alter_column_op, + schema, + tname, + cname, + conn_col, + metadata_col, +): metadata_default = metadata_col.server_default conn_col_default = conn_col.server_default if conn_col_default is None and metadata_default is None: return False rendered_metadata_default = _render_server_default_for_compare( - metadata_default, metadata_col, autogen_context) + metadata_default, metadata_col, autogen_context + ) - rendered_conn_default = conn_col.server_default.arg.text \ - if conn_col.server_default else None + rendered_conn_default = ( + conn_col.server_default.arg.text if conn_col.server_default else None + ) alter_column_op.existing_server_default = conn_col_default isdiff = autogen_context.migration_context._compare_server_default( - conn_col, metadata_col, + conn_col, + metadata_col, rendered_metadata_default, - rendered_conn_default + rendered_conn_default, ) if isdiff: alter_column_op.modify_server_default = metadata_default - log.info( - "Detected server default on column '%s.%s'", - tname, cname) + log.info("Detected server default on column '%s.%s'", tname, cname) @comparators.dispatch_for("table") def _compare_foreign_keys( - autogen_context, modify_table_ops, schema, tname, conn_table, - metadata_table): + autogen_context, + modify_table_ops, + schema, + tname, + conn_table, + metadata_table, +): # if we're doing CREATE TABLE, all FKs are created # inline within the table def @@ -805,22 +932,22 @@ def _compare_foreign_keys( inspector = autogen_context.inspector metadata_fks = set( - fk for fk in metadata_table.constraints + fk + for fk in metadata_table.constraints if isinstance(fk, sa_schema.ForeignKeyConstraint) ) conn_fks = inspector.get_foreign_keys(tname, schema=schema) - backend_reflects_fk_options = conn_fks and 'options' in conn_fks[0] + backend_reflects_fk_options = conn_fks and "options" in conn_fks[0] conn_fks = set(_make_foreign_key(const, conn_table) for const in conn_fks) # give the dialect a chance to correct the FKs to match more # closely - autogen_context.migration_context.impl.\ - correct_for_autogen_foreignkeys( - conn_fks, metadata_fks, - ) + autogen_context.migration_context.impl.correct_for_autogen_foreignkeys( + conn_fks, metadata_fks + ) metadata_fks = set( _fk_constraint_sig(fk, include_options=backend_reflects_fk_options) @@ -832,12 +959,8 @@ def _compare_foreign_keys( for fk in conn_fks ) - conn_fks_by_sig = dict( - (c.sig, c) for c in conn_fks - ) - metadata_fks_by_sig = dict( - (c.sig, c) for c in metadata_fks - ) + conn_fks_by_sig = dict((c.sig, c) for c in conn_fks) + metadata_fks_by_sig = dict((c.sig, c) for c in metadata_fks) metadata_fks_by_name = dict( (c.name, c) for c in metadata_fks if c.name is not None @@ -848,8 +971,8 @@ def _compare_foreign_keys( def _add_fk(obj, compare_to): if autogen_context.run_filters( - obj.const, obj.name, "foreign_key_constraint", False, - compare_to): + obj.const, obj.name, "foreign_key_constraint", False, compare_to + ): modify_table_ops.ops.append( ops.CreateForeignKeyOp.from_constraint(const.const) ) @@ -859,12 +982,13 @@ def _compare_foreign_keys( ", ".join(obj.source_columns), ", ".join(obj.target_columns), "%s." % obj.source_schema if obj.source_schema else "", - obj.source_table) + obj.source_table, + ) def _remove_fk(obj, compare_to): if autogen_context.run_filters( - obj.const, obj.name, "foreign_key_constraint", True, - compare_to): + obj.const, obj.name, "foreign_key_constraint", True, compare_to + ): modify_table_ops.ops.append( ops.DropConstraintOp.from_constraint(obj.const) ) @@ -873,7 +997,8 @@ def _compare_foreign_keys( ", ".join(obj.source_columns), ", ".join(obj.target_columns), "%s." % obj.source_schema if obj.source_schema else "", - obj.source_table) + obj.source_table, + ) # so far it appears we don't need to do this by name at all. # SQLite doesn't preserve constraint names anyway @@ -881,13 +1006,19 @@ def _compare_foreign_keys( for removed_sig in set(conn_fks_by_sig).difference(metadata_fks_by_sig): const = conn_fks_by_sig[removed_sig] if removed_sig not in metadata_fks_by_sig: - compare_to = metadata_fks_by_name[const.name].const \ - if const.name in metadata_fks_by_name else None + compare_to = ( + metadata_fks_by_name[const.name].const + if const.name in metadata_fks_by_name + else None + ) _remove_fk(const, compare_to) for added_sig in set(metadata_fks_by_sig).difference(conn_fks_by_sig): const = metadata_fks_by_sig[added_sig] if added_sig not in conn_fks_by_sig: - compare_to = conn_fks_by_name[const.name].const \ - if const.name in conn_fks_by_name else None + compare_to = ( + conn_fks_by_name[const.name].const + if const.name in conn_fks_by_name + else None + ) _add_fk(const, compare_to) diff --git a/alembic/autogenerate/render.py b/alembic/autogenerate/render.py index 4fbe91f..573ee02 100644 --- a/alembic/autogenerate/render.py +++ b/alembic/autogenerate/render.py @@ -19,29 +19,35 @@ try: return _f_name(_alembic_autogenerate_prefix(autogen_context), name) else: return name + + except ImportError: + def _render_gen_name(autogen_context, name): return name def _indent(text): - text = re.compile(r'^', re.M).sub(" ", text).strip() - text = re.compile(r' +$', re.M).sub("", text) + text = re.compile(r"^", re.M).sub(" ", text).strip() + text = re.compile(r" +$", re.M).sub("", text) return text def _render_python_into_templatevars( - autogen_context, migration_script, template_args): + autogen_context, migration_script, template_args +): imports = autogen_context.imports for upgrade_ops, downgrade_ops in zip( - migration_script.upgrade_ops_list, - migration_script.downgrade_ops_list): + migration_script.upgrade_ops_list, migration_script.downgrade_ops_list + ): template_args[upgrade_ops.upgrade_token] = _indent( - _render_cmd_body(upgrade_ops, autogen_context)) + _render_cmd_body(upgrade_ops, autogen_context) + ) template_args[downgrade_ops.downgrade_token] = _indent( - _render_cmd_body(downgrade_ops, autogen_context)) - template_args['imports'] = "\n".join(sorted(imports)) + _render_cmd_body(downgrade_ops, autogen_context) + ) + template_args["imports"] = "\n".join(sorted(imports)) default_renderers = renderers = util.Dispatcher() @@ -83,7 +89,7 @@ def render_op_text(autogen_context, op): @renderers.dispatch_for(ops.ModifyTableOps) def _render_modify_table(autogen_context, op): opts = autogen_context.opts - render_as_batch = opts.get('render_as_batch', False) + render_as_batch = opts.get("render_as_batch", False) if op.ops: lines = [] @@ -104,33 +110,39 @@ def _render_modify_table(autogen_context, op): return lines else: - return [ - "pass" - ] + return ["pass"] @renderers.dispatch_for(ops.CreateTableOp) def _add_table(autogen_context, op): table = op.to_table() - args = [col for col in - [_render_column(col, autogen_context) for col in table.columns] - if col] + \ - sorted([rcons for rcons in - [_render_constraint(cons, autogen_context) for cons in - table.constraints] - if rcons is not None - ]) + args = [ + col + for col in [ + _render_column(col, autogen_context) for col in table.columns + ] + if col + ] + sorted( + [ + rcons + for rcons in [ + _render_constraint(cons, autogen_context) + for cons in table.constraints + ] + if rcons is not None + ] + ) if len(args) > MAX_PYTHON_ARGS: - args = '*[' + ',\n'.join(args) + ']' + args = "*[" + ",\n".join(args) + "]" else: - args = ',\n'.join(args) + args = ",\n".join(args) text = "%(prefix)screate_table(%(tablename)r,\n%(args)s" % { - 'tablename': _ident(op.table_name), - 'prefix': _alembic_autogenerate_prefix(autogen_context), - 'args': args, + "tablename": _ident(op.table_name), + "prefix": _alembic_autogenerate_prefix(autogen_context), + "args": args, } if op.schema: text += ",\nschema=%r" % _ident(op.schema) @@ -144,7 +156,7 @@ def _add_table(autogen_context, op): def _drop_table(autogen_context, op): text = "%(prefix)sdrop_table(%(tname)r" % { "prefix": _alembic_autogenerate_prefix(autogen_context), - "tname": _ident(op.table_name) + "tname": _ident(op.table_name), } if op.schema: text += ", schema=%r" % _ident(op.schema) @@ -159,28 +171,39 @@ def _add_index(autogen_context, op): has_batch = autogen_context._has_batch if has_batch: - tmpl = "%(prefix)screate_index(%(name)r, [%(columns)s], "\ + tmpl = ( + "%(prefix)screate_index(%(name)r, [%(columns)s], " "unique=%(unique)r%(kwargs)s)" + ) else: - tmpl = "%(prefix)screate_index(%(name)r, %(table)r, [%(columns)s], "\ + tmpl = ( + "%(prefix)screate_index(%(name)r, %(table)r, [%(columns)s], " "unique=%(unique)r%(schema)s%(kwargs)s)" + ) text = tmpl % { - 'prefix': _alembic_autogenerate_prefix(autogen_context), - 'name': _render_gen_name(autogen_context, index.name), - 'table': _ident(index.table.name), - 'columns': ", ".join( - _get_index_rendered_expressions(index, autogen_context)), - 'unique': index.unique or False, - 'schema': (", schema=%r" % _ident(index.table.schema)) - if index.table.schema else '', - 'kwargs': ( - ', ' + - ', '.join( - ["%s=%s" % - (key, _render_potential_expr(val, autogen_context)) - for key, val in index.kwargs.items()])) - if len(index.kwargs) else '' + "prefix": _alembic_autogenerate_prefix(autogen_context), + "name": _render_gen_name(autogen_context, index.name), + "table": _ident(index.table.name), + "columns": ", ".join( + _get_index_rendered_expressions(index, autogen_context) + ), + "unique": index.unique or False, + "schema": (", schema=%r" % _ident(index.table.schema)) + if index.table.schema + else "", + "kwargs": ( + ", " + + ", ".join( + [ + "%s=%s" + % (key, _render_potential_expr(val, autogen_context)) + for key, val in index.kwargs.items() + ] + ) + ) + if len(index.kwargs) + else "", } return text @@ -192,15 +215,16 @@ def _drop_index(autogen_context, op): if has_batch: tmpl = "%(prefix)sdrop_index(%(name)r)" else: - tmpl = "%(prefix)sdrop_index(%(name)r, "\ + tmpl = ( + "%(prefix)sdrop_index(%(name)r, " "table_name=%(table_name)r%(schema)s)" + ) text = tmpl % { - 'prefix': _alembic_autogenerate_prefix(autogen_context), - 'name': _render_gen_name(autogen_context, op.index_name), - 'table_name': _ident(op.table_name), - 'schema': ((", schema=%r" % _ident(op.schema)) - if op.schema else '') + "prefix": _alembic_autogenerate_prefix(autogen_context), + "name": _render_gen_name(autogen_context, op.index_name), + "table_name": _ident(op.table_name), + "schema": ((", schema=%r" % _ident(op.schema)) if op.schema else ""), } return text @@ -213,30 +237,28 @@ def _add_unique_constraint(autogen_context, op): @renderers.dispatch_for(ops.CreateForeignKeyOp) def _add_fk_constraint(autogen_context, op): - args = [ - repr( - _render_gen_name(autogen_context, op.constraint_name)), - ] + args = [repr(_render_gen_name(autogen_context, op.constraint_name))] if not autogen_context._has_batch: - args.append( - repr(_ident(op.source_table)) - ) + args.append(repr(_ident(op.source_table))) args.extend( [ repr(_ident(op.referent_table)), repr([_ident(col) for col in op.local_cols]), - repr([_ident(col) for col in op.remote_cols]) + repr([_ident(col) for col in op.remote_cols]), ] ) kwargs = [ - 'referent_schema', - 'onupdate', 'ondelete', 'initially', - 'deferrable', 'use_alter' + "referent_schema", + "onupdate", + "ondelete", + "initially", + "deferrable", + "use_alter", ] if not autogen_context._has_batch: - kwargs.insert(0, 'source_schema') + kwargs.insert(0, "source_schema") for k in kwargs: if k in op.kw: @@ -245,8 +267,8 @@ def _add_fk_constraint(autogen_context, op): args.append("%s=%r" % (k, value)) return "%(prefix)screate_foreign_key(%(args)s)" % { - 'prefix': _alembic_autogenerate_prefix(autogen_context), - 'args': ", ".join(args) + "prefix": _alembic_autogenerate_prefix(autogen_context), + "args": ", ".join(args), } @@ -264,20 +286,19 @@ def _add_check_constraint(constraint, autogen_context): def _drop_constraint(autogen_context, op): if autogen_context._has_batch: - template = "%(prefix)sdrop_constraint"\ - "(%(name)r, type_=%(type)r)" + template = "%(prefix)sdrop_constraint" "(%(name)r, type_=%(type)r)" else: - template = "%(prefix)sdrop_constraint"\ + template = ( + "%(prefix)sdrop_constraint" "(%(name)r, '%(table_name)s'%(schema)s, type_=%(type)r)" + ) text = template % { - 'prefix': _alembic_autogenerate_prefix(autogen_context), - 'name': _render_gen_name( - autogen_context, op.constraint_name), - 'table_name': _ident(op.table_name), - 'type': op.constraint_type, - 'schema': (", schema=%r" % _ident(op.schema)) - if op.schema else '', + "prefix": _alembic_autogenerate_prefix(autogen_context), + "name": _render_gen_name(autogen_context, op.constraint_name), + "table_name": _ident(op.table_name), + "type": op.constraint_type, + "schema": (", schema=%r" % _ident(op.schema)) if op.schema else "", } return text @@ -297,7 +318,7 @@ def _add_column(autogen_context, op): "prefix": _alembic_autogenerate_prefix(autogen_context), "tname": tname, "column": _render_column(column, autogen_context), - "schema": schema + "schema": schema, } return text @@ -319,7 +340,7 @@ def _drop_column(autogen_context, op): "prefix": _alembic_autogenerate_prefix(autogen_context), "tname": _ident(tname), "cname": _ident(column_name), - "schema": _ident(schema) + "schema": _ident(schema), } return text @@ -332,7 +353,7 @@ def _alter_column(autogen_context, op): server_default = op.modify_server_default type_ = op.modify_type nullable = op.modify_nullable - autoincrement = op.kw.get('autoincrement', None) + autoincrement = op.kw.get("autoincrement", None) existing_type = op.existing_type existing_nullable = op.existing_nullable existing_server_default = op.existing_server_default @@ -346,37 +367,32 @@ def _alter_column(autogen_context, op): template = "%(prefix)salter_column(%(tname)r, %(cname)r" text = template % { - 'prefix': _alembic_autogenerate_prefix( - autogen_context), - 'tname': tname, - 'cname': cname} + "prefix": _alembic_autogenerate_prefix(autogen_context), + "tname": tname, + "cname": cname, + } if existing_type is not None: text += ",\n%sexisting_type=%s" % ( indent, - _repr_type(existing_type, autogen_context)) + _repr_type(existing_type, autogen_context), + ) if server_default is not False: - rendered = _render_server_default( - server_default, autogen_context) + rendered = _render_server_default(server_default, autogen_context) text += ",\n%sserver_default=%s" % (indent, rendered) if type_ is not None: - text += ",\n%stype_=%s" % (indent, - _repr_type(type_, autogen_context)) + text += ",\n%stype_=%s" % (indent, _repr_type(type_, autogen_context)) if nullable is not None: - text += ",\n%snullable=%r" % ( - indent, nullable,) + text += ",\n%snullable=%r" % (indent, nullable) if nullable is None and existing_nullable is not None: - text += ",\n%sexisting_nullable=%r" % ( - indent, existing_nullable) + text += ",\n%sexisting_nullable=%r" % (indent, existing_nullable) if autoincrement is not None: - text += ",\n%sautoincrement=%r" % ( - indent, autoincrement) + text += ",\n%sautoincrement=%r" % (indent, autoincrement) if server_default is False and existing_server_default: rendered = _render_server_default( - existing_server_default, - autogen_context) - text += ",\n%sexisting_server_default=%s" % ( - indent, rendered) + existing_server_default, autogen_context + ) + text += ",\n%sexisting_server_default=%s" % (indent, rendered) if schema and not autogen_context._has_batch: text += ",\n%sschema=%r" % (indent, schema) text += ")" @@ -384,7 +400,6 @@ def _alter_column(autogen_context, op): class _f_name(object): - def __init__(self, prefix, name): self.prefix = prefix self.name = name @@ -410,7 +425,7 @@ def _ident(name): # u'' literals only when py2k + SQLA 0.9, in particular # makes unit tests testing code generation very difficult try: - return name.encode('ascii') + return name.encode("ascii") except UnicodeError: return compat.text_type(name) else: @@ -421,8 +436,9 @@ def _ident(name): def _render_potential_expr(value, autogen_context, wrap_in_text=True): if isinstance(value, sql.ClauseElement): - compile_kw = dict(compile_kwargs={ - 'literal_binds': True, "include_table": False}) + compile_kw = dict( + compile_kwargs={"literal_binds": True, "include_table": False} + ) if wrap_in_text: template = "%(prefix)stext(%(sql)r)" @@ -432,9 +448,8 @@ def _render_potential_expr(value, autogen_context, wrap_in_text=True): return template % { "prefix": _sqlalchemy_autogenerate_prefix(autogen_context), "sql": compat.text_type( - value.compile(dialect=autogen_context.dialect, - **compile_kw) - ) + value.compile(dialect=autogen_context.dialect, **compile_kw) + ), } else: @@ -442,10 +457,12 @@ def _render_potential_expr(value, autogen_context, wrap_in_text=True): def _get_index_rendered_expressions(idx, autogen_context): - return [repr(_ident(getattr(exp, "name", None))) - if isinstance(exp, sa_schema.Column) - else _render_potential_expr(exp, autogen_context) - for exp in idx.expressions] + return [ + repr(_ident(getattr(exp, "name", None))) + if isinstance(exp, sa_schema.Column) + else _render_potential_expr(exp, autogen_context) + for exp in idx.expressions + ] def _uq_constraint(constraint, autogen_context, alter): @@ -461,32 +478,30 @@ def _uq_constraint(constraint, autogen_context, alter): opts.append(("schema", _ident(constraint.table.schema))) if not alter and constraint.name: opts.append( - ("name", - _render_gen_name(autogen_context, constraint.name))) + ("name", _render_gen_name(autogen_context, constraint.name)) + ) if alter: - args = [ - repr(_render_gen_name( - autogen_context, constraint.name))] + args = [repr(_render_gen_name(autogen_context, constraint.name))] if not has_batch: args += [repr(_ident(constraint.table.name))] args.append(repr([_ident(col.name) for col in constraint.columns])) args.extend(["%s=%r" % (k, v) for k, v in opts]) return "%(prefix)screate_unique_constraint(%(args)s)" % { - 'prefix': _alembic_autogenerate_prefix(autogen_context), - 'args': ", ".join(args) + "prefix": _alembic_autogenerate_prefix(autogen_context), + "args": ", ".join(args), } else: args = [repr(_ident(col.name)) for col in constraint.columns] args.extend(["%s=%r" % (k, v) for k, v in opts]) return "%(prefix)sUniqueConstraint(%(args)s)" % { "prefix": _sqlalchemy_autogenerate_prefix(autogen_context), - "args": ", ".join(args) + "args": ", ".join(args), } def _user_autogenerate_prefix(autogen_context, target): - prefix = autogen_context.opts['user_module_prefix'] + prefix = autogen_context.opts["user_module_prefix"] if prefix is None: return "%s." % target.__module__ else: @@ -494,19 +509,19 @@ def _user_autogenerate_prefix(autogen_context, target): def _sqlalchemy_autogenerate_prefix(autogen_context): - return autogen_context.opts['sqlalchemy_module_prefix'] or '' + return autogen_context.opts["sqlalchemy_module_prefix"] or "" def _alembic_autogenerate_prefix(autogen_context): if autogen_context._has_batch: - return 'batch_op.' + return "batch_op." else: - return autogen_context.opts['alembic_module_prefix'] or '' + return autogen_context.opts["alembic_module_prefix"] or "" def _user_defined_render(type_, object_, autogen_context): - if 'render_item' in autogen_context.opts: - render = autogen_context.opts['render_item'] + if "render_item" in autogen_context.opts: + render = autogen_context.opts["render_item"] if render: rendered = render(type_, object_, autogen_context) if rendered is not False: @@ -527,8 +542,10 @@ def _render_column(column, autogen_context): if rendered: opts.append(("server_default", rendered)) - if column.autoincrement is not None and \ - column.autoincrement != sqla_compat.AUTOINCREMENT_DEFAULT: + if ( + column.autoincrement is not None + and column.autoincrement != sqla_compat.AUTOINCREMENT_DEFAULT + ): opts.append(("autoincrement", column.autoincrement)) if column.nullable is not None: @@ -539,10 +556,10 @@ def _render_column(column, autogen_context): # TODO: for non-ascii colname, assign a "key" return "%(prefix)sColumn(%(name)r, %(type)s, %(kw)s)" % { - 'prefix': _sqlalchemy_autogenerate_prefix(autogen_context), - 'name': _ident(column.name), - 'type': _repr_type(column.type, autogen_context), - 'kw': ", ".join(["%s=%s" % (kwname, val) for kwname, val in opts]) + "prefix": _sqlalchemy_autogenerate_prefix(autogen_context), + "name": _ident(column.name), + "type": _repr_type(column.type, autogen_context), + "kw": ", ".join(["%s=%s" % (kwname, val) for kwname, val in opts]), } @@ -568,9 +585,10 @@ def _repr_type(type_, autogen_context): if rendered is not False: return rendered - if hasattr(autogen_context.migration_context, 'impl'): + if hasattr(autogen_context.migration_context, "impl"): impl_rt = autogen_context.migration_context.impl.render_type( - type_, autogen_context) + type_, autogen_context + ) else: impl_rt = None @@ -587,8 +605,8 @@ def _repr_type(type_, autogen_context): elif impl_rt: return impl_rt elif mod.startswith("sqlalchemy."): - if '_render_%s_type' % type_.__visit_name__ in globals(): - fn = globals()['_render_%s_type' % type_.__visit_name__] + if "_render_%s_type" % type_.__visit_name__ in globals(): + fn = globals()["_render_%s_type" % type_.__visit_name__] return fn(type_, autogen_context) else: prefix = _sqlalchemy_autogenerate_prefix(autogen_context) @@ -600,12 +618,13 @@ def _repr_type(type_, autogen_context): def _render_ARRAY_type(type_, autogen_context): return _render_type_w_subtype( - type_, autogen_context, 'item_type', r'(.+?\()' + type_, autogen_context, "item_type", r"(.+?\()" ) def _render_type_w_subtype( - type_, autogen_context, attrname, regexp, prefix=None): + type_, autogen_context, attrname, regexp, prefix=None +): outer_repr = repr(type_) inner_type = getattr(type_, attrname, None) if inner_type is None: @@ -613,11 +632,9 @@ def _render_type_w_subtype( inner_repr = repr(inner_type) - inner_repr = re.sub(r'([\(\)])', r'\\\1', inner_repr) + inner_repr = re.sub(r"([\(\)])", r"\\\1", inner_repr) sub_type = _repr_type(getattr(type_, attrname), autogen_context) - outer_type = re.sub( - regexp + inner_repr, - r"\1%s" % sub_type, outer_repr) + outer_type = re.sub(regexp + inner_repr, r"\1%s" % sub_type, outer_repr) if prefix: return "%s%s" % (prefix, outer_type) @@ -632,6 +649,7 @@ def _render_type_w_subtype( else: return None + _constraint_renderers = util.Dispatcher() @@ -656,13 +674,14 @@ def _render_primary_key(constraint, autogen_context): opts = [] if constraint.name: - opts.append(("name", repr( - _render_gen_name(autogen_context, constraint.name)))) + opts.append( + ("name", repr(_render_gen_name(autogen_context, constraint.name))) + ) return "%(prefix)sPrimaryKeyConstraint(%(args)s)" % { "prefix": _sqlalchemy_autogenerate_prefix(autogen_context), "args": ", ".join( - [repr(c.name) for c in constraint.columns] + - ["%s=%s" % (kwname, val) for kwname, val in opts] + [repr(c.name) for c in constraint.columns] + + ["%s=%s" % (kwname, val) for kwname, val in opts] ), } @@ -681,8 +700,11 @@ def _fk_colspec(fk, metadata_schema): else: table_fullname = ".".join(tokens[0:-1]) - if not fk.link_to_name and \ - fk.parent is not None and fk.parent.table is not None: + if ( + not fk.link_to_name + and fk.parent is not None + and fk.parent.table is not None + ): # try to resolve the remote table in order to adjust for column.key. # the FK constraint needs to be rendered in terms of the column # name. @@ -719,23 +741,30 @@ def _render_foreign_key(constraint, autogen_context): opts = [] if constraint.name: - opts.append(("name", repr( - _render_gen_name(autogen_context, constraint.name)))) + opts.append( + ("name", repr(_render_gen_name(autogen_context, constraint.name))) + ) _populate_render_fk_opts(constraint, opts) apply_metadata_schema = constraint.parent.metadata.schema - return "%(prefix)sForeignKeyConstraint([%(cols)s], "\ - "[%(refcols)s], %(args)s)" % { + return ( + "%(prefix)sForeignKeyConstraint([%(cols)s], " + "[%(refcols)s], %(args)s)" + % { "prefix": _sqlalchemy_autogenerate_prefix(autogen_context), "cols": ", ".join( - "%r" % _ident(f.parent.name) for f in constraint.elements), - "refcols": ", ".join(repr(_fk_colspec(f, apply_metadata_schema)) - for f in constraint.elements), + "%r" % _ident(f.parent.name) for f in constraint.elements + ), + "refcols": ", ".join( + repr(_fk_colspec(f, apply_metadata_schema)) + for f in constraint.elements + ), "args": ", ".join( - ["%s=%s" % (kwname, val) for kwname, val in opts] + ["%s=%s" % (kwname, val) for kwname, val in opts] ), } + ) @_constraint_renderers.dispatch_for(sa_schema.UniqueConstraint) @@ -757,27 +786,25 @@ def _render_check_constraint(constraint, autogen_context): # a parent type which is probably in the Table already. # ideally SQLAlchemy would give us more of a first class # way to detect this. - if constraint._create_rule and \ - hasattr(constraint._create_rule, 'target') and \ - isinstance(constraint._create_rule.target, - sqltypes.TypeEngine): + if ( + constraint._create_rule + and hasattr(constraint._create_rule, "target") + and isinstance(constraint._create_rule.target, sqltypes.TypeEngine) + ): return None opts = [] if constraint.name: opts.append( - ( - "name", - repr( - _render_gen_name( - autogen_context, constraint.name)) - ) + ("name", repr(_render_gen_name(autogen_context, constraint.name))) ) return "%(prefix)sCheckConstraint(%(sqltext)s%(opts)s)" % { "prefix": _sqlalchemy_autogenerate_prefix(autogen_context), - "opts": ", " + (", ".join("%s=%s" % (k, v) - for k, v in opts)) if opts else "", + "opts": ", " + (", ".join("%s=%s" % (k, v) for k, v in opts)) + if opts + else "", "sqltext": _render_potential_expr( - constraint.sqltext, autogen_context, wrap_in_text=False) + constraint.sqltext, autogen_context, wrap_in_text=False + ), } @@ -788,7 +815,7 @@ def _execute_sql(autogen_context, op): "Autogenerate rendering of SQL Expression language constructs " "not supported here; please use a plain SQL string" ) - return 'op.execute(%r)' % op.sqltext + return "op.execute(%r)" % op.sqltext renderers = default_renderers.branch() diff --git a/alembic/autogenerate/rewriter.py b/alembic/autogenerate/rewriter.py index 941bd4b..1e9522b 100644 --- a/alembic/autogenerate/rewriter.py +++ b/alembic/autogenerate/rewriter.py @@ -95,7 +95,8 @@ class Rewriter(object): yield directive else: for r_directive in util.to_list( - _rewriter(context, revision, directive)): + _rewriter(context, revision, directive) + ): yield r_directive def __call__(self, context, revision, directives): @@ -110,17 +111,20 @@ class Rewriter(object): ret = self._traverse_for(context, revision, directive.upgrade_ops) if len(ret) != 1: raise ValueError( - "Can only return single object for UpgradeOps traverse") + "Can only return single object for UpgradeOps traverse" + ) upgrade_ops_list.append(ret[0]) directive.upgrade_ops = upgrade_ops_list downgrade_ops_list = [] for downgrade_ops in directive.downgrade_ops_list: ret = self._traverse_for( - context, revision, directive.downgrade_ops) + context, revision, directive.downgrade_ops + ) if len(ret) != 1: raise ValueError( - "Can only return single object for DowngradeOps traverse") + "Can only return single object for DowngradeOps traverse" + ) downgrade_ops_list.append(ret[0]) directive.downgrade_ops = downgrade_ops_list diff --git a/alembic/command.py b/alembic/command.py index cd61fd1..20027b4 100644 --- a/alembic/command.py +++ b/alembic/command.py @@ -15,10 +15,9 @@ def list_templates(config): config.print_stdout("Available templates:\n") for tempname in os.listdir(config.get_template_directory()): - with open(os.path.join( - config.get_template_directory(), - tempname, - 'README')) as readme: + with open( + os.path.join(config.get_template_directory(), tempname, "README") + ) as readme: synopsis = next(readme) config.print_stdout("%s - %s", tempname, synopsis) @@ -26,7 +25,7 @@ def list_templates(config): config.print_stdout("\n alembic init --template generic ./scripts") -def init(config, directory, template='generic'): +def init(config, directory, template="generic"): """Initialize a new scripts directory. :param config: a :class:`.Config` object. @@ -41,48 +40,58 @@ def init(config, directory, template='generic'): if os.access(directory, os.F_OK): raise util.CommandError("Directory %s already exists" % directory) - template_dir = os.path.join(config.get_template_directory(), - template) + template_dir = os.path.join(config.get_template_directory(), template) if not os.access(template_dir, os.F_OK): raise util.CommandError("No such template %r" % template) - util.status("Creating directory %s" % os.path.abspath(directory), - os.makedirs, directory) + util.status( + "Creating directory %s" % os.path.abspath(directory), + os.makedirs, + directory, + ) - versions = os.path.join(directory, 'versions') - util.status("Creating directory %s" % os.path.abspath(versions), - os.makedirs, versions) + versions = os.path.join(directory, "versions") + util.status( + "Creating directory %s" % os.path.abspath(versions), + os.makedirs, + versions, + ) script = ScriptDirectory(directory) for file_ in os.listdir(template_dir): file_path = os.path.join(template_dir, file_) - if file_ == 'alembic.ini.mako': + if file_ == "alembic.ini.mako": config_file = os.path.abspath(config.config_file_name) if os.access(config_file, os.F_OK): util.msg("File %s already exists, skipping" % config_file) else: script._generate_template( - file_path, - config_file, - script_location=directory + file_path, config_file, script_location=directory ) elif os.path.isfile(file_path): output_file = os.path.join(directory, file_) - script._copy_file( - file_path, - output_file - ) + script._copy_file(file_path, output_file) - util.msg("Please edit configuration/connection/logging " - "settings in %r before proceeding." % config_file) + util.msg( + "Please edit configuration/connection/logging " + "settings in %r before proceeding." % config_file + ) def revision( - config, message=None, autogenerate=False, sql=False, - head="head", splice=False, branch_label=None, - version_path=None, rev_id=None, depends_on=None, - process_revision_directives=None): + config, + message=None, + autogenerate=False, + sql=False, + head="head", + splice=False, + branch_label=None, + version_path=None, + rev_id=None, + depends_on=None, + process_revision_directives=None, +): """Create a new revision file. :param config: a :class:`.Config` object. @@ -134,35 +143,46 @@ def revision( command_args = dict( message=message, autogenerate=autogenerate, - sql=sql, head=head, splice=splice, branch_label=branch_label, - version_path=version_path, rev_id=rev_id, depends_on=depends_on + sql=sql, + head=head, + splice=splice, + branch_label=branch_label, + version_path=version_path, + rev_id=rev_id, + depends_on=depends_on, ) revision_context = autogen.RevisionContext( - config, script_directory, command_args, - process_revision_directives=process_revision_directives) - - environment = util.asbool( - config.get_main_option("revision_environment") + config, + script_directory, + command_args, + process_revision_directives=process_revision_directives, ) + environment = util.asbool(config.get_main_option("revision_environment")) + if autogenerate: environment = True if sql: raise util.CommandError( - "Using --sql with --autogenerate does not make any sense") + "Using --sql with --autogenerate does not make any sense" + ) def retrieve_migrations(rev, context): revision_context.run_autogenerate(rev, context) return [] + elif environment: + def retrieve_migrations(rev, context): revision_context.run_no_autogenerate(rev, context) return [] + elif sql: raise util.CommandError( "Using --sql with the revision command when " - "revision_environment is not configured does not make any sense") + "revision_environment is not configured does not make any sense" + ) if environment: with EnvironmentContext( @@ -171,14 +191,11 @@ def revision( fn=retrieve_migrations, as_sql=sql, template_args=revision_context.template_args, - revision_context=revision_context + revision_context=revision_context, ): script_directory.run_env() - scripts = [ - script for script in - revision_context.generate_scripts() - ] + scripts = [script for script in revision_context.generate_scripts()] if len(scripts) == 1: return scripts[0] else: @@ -207,13 +224,17 @@ def merge(config, revisions, message=None, branch_label=None, rev_id=None): script = ScriptDirectory.from_config(config) template_args = { - 'config': config # Let templates use config for - # e.g. multiple databases + "config": config # Let templates use config for + # e.g. multiple databases } return script.generate_revision( - rev_id or util.rev_id(), message, refresh=True, - head=revisions, branch_labels=branch_label, - **template_args) + rev_id or util.rev_id(), + message, + refresh=True, + head=revisions, + branch_labels=branch_label, + **template_args + ) def upgrade(config, revision, sql=False, tag=None): @@ -237,7 +258,7 @@ def upgrade(config, revision, sql=False, tag=None): if ":" in revision: if not sql: raise util.CommandError("Range revision not allowed") - starting_rev, revision = revision.split(':', 2) + starting_rev, revision = revision.split(":", 2) def upgrade(rev, context): return script._upgrade_revs(revision, rev) @@ -249,7 +270,7 @@ def upgrade(config, revision, sql=False, tag=None): as_sql=sql, starting_rev=starting_rev, destination_rev=revision, - tag=tag + tag=tag, ): script.run_env() @@ -274,10 +295,11 @@ def downgrade(config, revision, sql=False, tag=None): if ":" in revision: if not sql: raise util.CommandError("Range revision not allowed") - starting_rev, revision = revision.split(':', 2) + starting_rev, revision = revision.split(":", 2) elif sql: raise util.CommandError( - "downgrade with --sql requires <fromrev>:<torev>") + "downgrade with --sql requires <fromrev>:<torev>" + ) def downgrade(rev, context): return script._downgrade_revs(revision, rev) @@ -289,7 +311,7 @@ def downgrade(config, revision, sql=False, tag=None): as_sql=sql, starting_rev=starting_rev, destination_rev=revision, - tag=tag + tag=tag, ): script.run_env() @@ -306,15 +328,13 @@ def show(config, rev): script = ScriptDirectory.from_config(config) if rev == "current": + def show_current(rev, context): for sc in script.get_revisions(rev): config.print_stdout(sc.log_entry) return [] - with EnvironmentContext( - config, - script, - fn=show_current - ): + + with EnvironmentContext(config, script, fn=show_current): script.run_env() else: for sc in script.get_revisions(rev): @@ -340,44 +360,45 @@ def history(config, rev_range=None, verbose=False, indicate_current=False): if rev_range is not None: if ":" not in rev_range: raise util.CommandError( - "History range requires [start]:[end], " - "[start]:, or :[end]") + "History range requires [start]:[end], " "[start]:, or :[end]" + ) base, head = rev_range.strip().split(":") else: base = head = None - environment = util.asbool( - config.get_main_option("revision_environment") - ) or indicate_current + environment = ( + util.asbool(config.get_main_option("revision_environment")) + or indicate_current + ) def _display_history(config, script, base, head, currents=()): for sc in script.walk_revisions( - base=base or "base", - head=head or "heads"): + base=base or "base", head=head or "heads" + ): if indicate_current: sc._db_current_indicator = sc.revision in currents config.print_stdout( sc.cmd_format( - verbose=verbose, include_branches=True, - include_doc=True, include_parents=True)) + verbose=verbose, + include_branches=True, + include_doc=True, + include_parents=True, + ) + ) def _display_history_w_current(config, script, base, head): def _display_current_history(rev, context): - if head == 'current': + if head == "current": _display_history(config, script, base, rev, rev) - elif base == 'current': + elif base == "current": _display_history(config, script, rev, head, rev) else: _display_history(config, script, base, head, rev) return [] - with EnvironmentContext( - config, - script, - fn=_display_current_history - ): + with EnvironmentContext(config, script, fn=_display_current_history): script.run_env() if base == "current" or head == "current" or environment: @@ -406,7 +427,9 @@ def heads(config, verbose=False, resolve_dependencies=False): for rev in heads: config.print_stdout( rev.cmd_format( - verbose, include_branches=True, tree_indicators=False)) + verbose, include_branches=True, tree_indicators=False + ) + ) def branches(config, verbose=False): @@ -424,13 +447,17 @@ def branches(config, verbose=False): "%s\n%s\n", sc.cmd_format(verbose, include_branches=True), "\n".join( - "%s -> %s" % ( + "%s -> %s" + % ( " " * len(str(sc.revision)), rev_obj.cmd_format( - False, include_branches=True, include_doc=verbose) - ) for rev_obj in - (script.get_revision(rev) for rev in sc.nextrev) - ) + False, include_branches=True, include_doc=verbose + ), + ) + for rev_obj in ( + script.get_revision(rev) for rev in sc.nextrev + ) + ), ) @@ -454,18 +481,14 @@ def current(config, verbose=False, head_only=False): if verbose: config.print_stdout( "Current revision(s) for %s:", - util.obfuscate_url_pw(context.connection.engine.url) + util.obfuscate_url_pw(context.connection.engine.url), ) for rev in script.get_all_current(rev): config.print_stdout(rev.cmd_format(verbose)) return [] - with EnvironmentContext( - config, - script, - fn=display_version - ): + with EnvironmentContext(config, script, fn=display_version): script.run_env() @@ -491,7 +514,7 @@ def stamp(config, revision, sql=False, tag=None): if ":" in revision: if not sql: raise util.CommandError("Range revision not allowed") - starting_rev, revision = revision.split(':', 2) + starting_rev, revision = revision.split(":", 2) def do_stamp(rev, context): return script._stamp_revs(revision, rev) @@ -503,7 +526,7 @@ def stamp(config, revision, sql=False, tag=None): as_sql=sql, destination_rev=revision, starting_rev=starting_rev, - tag=tag + tag=tag, ): script.run_env() @@ -520,23 +543,21 @@ def edit(config, rev): script = ScriptDirectory.from_config(config) if rev == "current": + def edit_current(rev, context): if not rev: raise util.CommandError("No current revisions") for sc in script.get_revisions(rev): util.edit(sc.path) return [] - with EnvironmentContext( - config, - script, - fn=edit_current - ): + + with EnvironmentContext(config, script, fn=edit_current): script.run_env() else: revs = script.get_revisions(rev) if not revs: raise util.CommandError( - "No revision files indicated by symbol '%s'" % rev) + "No revision files indicated by symbol '%s'" % rev + ) for sc in revs: util.edit(sc.path) - diff --git a/alembic/config.py b/alembic/config.py index 5856099..915091c 100644 --- a/alembic/config.py +++ b/alembic/config.py @@ -90,9 +90,16 @@ class Config(object): """ - def __init__(self, file_=None, ini_section='alembic', output_buffer=None, - stdout=sys.stdout, cmd_opts=None, - config_args=util.immutabledict(), attributes=None): + def __init__( + self, + file_=None, + ini_section="alembic", + output_buffer=None, + stdout=sys.stdout, + cmd_opts=None, + config_args=util.immutabledict(), + attributes=None, + ): """Construct a new :class:`.Config` """ @@ -167,15 +174,11 @@ class Config(object): """ if arg: - output = (compat.text_type(text) % arg) + output = compat.text_type(text) % arg else: output = compat.text_type(text) - util.write_outstream( - self.stdout, - output, - "\n" - ) + util.write_outstream(self.stdout, output, "\n") @util.memoized_property def file_config(self): @@ -192,7 +195,7 @@ class Config(object): here = os.path.abspath(os.path.dirname(self.config_file_name)) else: here = "" - self.config_args['here'] = here + self.config_args["here"] = here file_config = SafeConfigParser(self.config_args) if self.config_file_name: file_config.read([self.config_file_name]) @@ -207,7 +210,7 @@ class Config(object): commands. """ - return os.path.join(package_dir, 'templates') + return os.path.join(package_dir, "templates") def get_section(self, name): """Return all the configuration options from a given .ini file section @@ -265,9 +268,10 @@ class Config(object): """ if not self.file_config.has_section(section): - raise util.CommandError("No config file %r found, or file has no " - "'[%s]' section" % - (self.config_file_name, section)) + raise util.CommandError( + "No config file %r found, or file has no " + "'[%s]' section" % (self.config_file_name, section) + ) if self.file_config.has_option(section, name): return self.file_config.get(section, name) else: @@ -285,140 +289,144 @@ class Config(object): class CommandLine(object): - def __init__(self, prog=None): self._generate_args(prog) def _generate_args(self, prog): def add_options(parser, positional, kwargs): kwargs_opts = { - 'template': ( - "-t", "--template", + "template": ( + "-t", + "--template", dict( - default='generic', + default="generic", type=str, - help="Setup template for use with 'init'" - ) + help="Setup template for use with 'init'", + ), ), - 'message': ( - "-m", "--message", + "message": ( + "-m", + "--message", dict( - type=str, - help="Message string to use with 'revision'") + type=str, help="Message string to use with 'revision'" + ), ), - 'sql': ( + "sql": ( "--sql", dict( action="store_true", help="Don't emit SQL to database - dump to " "standard output/file instead. See docs on " - "offline mode." - ) + "offline mode.", + ), ), - 'tag': ( + "tag": ( "--tag", dict( type=str, help="Arbitrary 'tag' name - can be used by " - "custom env.py scripts.") + "custom env.py scripts.", + ), ), - 'head': ( + "head": ( "--head", dict( type=str, help="Specify head revision or <branchname>@head " - "to base new revision on." - ) + "to base new revision on.", + ), ), - 'splice': ( + "splice": ( "--splice", dict( action="store_true", help="Allow a non-head revision as the " - "'head' to splice onto" - ) + "'head' to splice onto", + ), ), - 'depends_on': ( + "depends_on": ( "--depends-on", dict( action="append", help="Specify one or more revision identifiers " - "which this revision should depend on." - ) + "which this revision should depend on.", + ), ), - 'rev_id': ( + "rev_id": ( "--rev-id", dict( type=str, help="Specify a hardcoded revision id instead of " - "generating one" - ) + "generating one", + ), ), - 'version_path': ( + "version_path": ( "--version-path", dict( type=str, help="Specify specific path from config for " - "version file" - ) + "version file", + ), ), - 'branch_label': ( + "branch_label": ( "--branch-label", dict( type=str, help="Specify a branch label to apply to the " - "new revision" - ) + "new revision", + ), ), - 'verbose': ( - "-v", "--verbose", - dict( - action="store_true", - help="Use more verbose output" - ) + "verbose": ( + "-v", + "--verbose", + dict(action="store_true", help="Use more verbose output"), ), - 'resolve_dependencies': ( - '--resolve-dependencies', + "resolve_dependencies": ( + "--resolve-dependencies", dict( action="store_true", - help="Treat dependency versions as down revisions" - ) + help="Treat dependency versions as down revisions", + ), ), - 'autogenerate': ( + "autogenerate": ( "--autogenerate", dict( action="store_true", help="Populate revision script with candidate " "migration operations, based on comparison " - "of database to model.") + "of database to model.", + ), ), - 'head_only': ( + "head_only": ( "--head-only", dict( action="store_true", help="Deprecated. Use --verbose for " - "additional output") + "additional output", + ), ), - 'rev_range': ( - "-r", "--rev-range", + "rev_range": ( + "-r", + "--rev-range", dict( action="store", help="Specify a revision range; " - "format is [start]:[end]") + "format is [start]:[end]", + ), ), - 'indicate_current': ( - "-i", "--indicate-current", + "indicate_current": ( + "-i", + "--indicate-current", dict( action="store_true", - help="Indicate the current revision" - ) - ) + help="Indicate the current revision", + ), + ), } positional_help = { - 'directory': "location of scripts directory", - 'revision': "revision identifier", - 'revisions': "one or more revisions, or 'heads' for all heads" - + "directory": "location of scripts directory", + "revision": "revision identifier", + "revisions": "one or more revisions, or 'heads' for all heads", } for arg in kwargs: if arg in kwargs_opts: @@ -429,44 +437,56 @@ class CommandLine(object): for arg in positional: if arg == "revisions": subparser.add_argument( - arg, nargs='+', help=positional_help.get(arg)) + arg, nargs="+", help=positional_help.get(arg) + ) else: subparser.add_argument(arg, help=positional_help.get(arg)) parser = ArgumentParser(prog=prog) - parser.add_argument("-c", "--config", - type=str, - default="alembic.ini", - help="Alternate config file") - parser.add_argument("-n", "--name", - type=str, - default="alembic", - help="Name of section in .ini file to " - "use for Alembic config") - parser.add_argument("-x", action="append", - help="Additional arguments consumed by " - "custom env.py scripts, e.g. -x " - "setting1=somesetting -x setting2=somesetting") - parser.add_argument("--raiseerr", action="store_true", - help="Raise a full stack trace on error") + parser.add_argument( + "-c", + "--config", + type=str, + default="alembic.ini", + help="Alternate config file", + ) + parser.add_argument( + "-n", + "--name", + type=str, + default="alembic", + help="Name of section in .ini file to " "use for Alembic config", + ) + parser.add_argument( + "-x", + action="append", + help="Additional arguments consumed by " + "custom env.py scripts, e.g. -x " + "setting1=somesetting -x setting2=somesetting", + ) + parser.add_argument( + "--raiseerr", + action="store_true", + help="Raise a full stack trace on error", + ) subparsers = parser.add_subparsers() for fn in [getattr(command, n) for n in dir(command)]: - if inspect.isfunction(fn) and \ - fn.__name__[0] != '_' and \ - fn.__module__ == 'alembic.command': + if ( + inspect.isfunction(fn) + and fn.__name__[0] != "_" + and fn.__module__ == "alembic.command" + ): spec = compat.inspect_getargspec(fn) if spec[3]: - positional = spec[0][1:-len(spec[3])] - kwarg = spec[0][-len(spec[3]):] + positional = spec[0][1 : -len(spec[3])] + kwarg = spec[0][-len(spec[3]) :] else: positional = spec[0][1:] kwarg = [] - subparser = subparsers.add_parser( - fn.__name__, - help=fn.__doc__) + subparser = subparsers.add_parser(fn.__name__, help=fn.__doc__) add_options(subparser, positional, kwarg) subparser.set_defaults(cmd=(fn, positional, kwarg)) self.parser = parser @@ -475,10 +495,11 @@ class CommandLine(object): fn, positional, kwarg = options.cmd try: - fn(config, - *[getattr(options, k, None) for k in positional], - **dict((k, getattr(options, k, None)) for k in kwarg) - ) + fn( + config, + *[getattr(options, k, None) for k in positional], + **dict((k, getattr(options, k, None)) for k in kwarg) + ) except util.CommandError as e: if options.raiseerr: raise @@ -492,8 +513,11 @@ class CommandLine(object): # behavior changed incompatibly in py3.3 self.parser.error("too few arguments") else: - cfg = Config(file_=options.config, - ini_section=options.name, cmd_opts=options) + cfg = Config( + file_=options.config, + ini_section=options.name, + cmd_opts=options, + ) self.run_cmd(cfg, options) @@ -502,5 +526,6 @@ def main(argv=None, prog=None, **kwargs): CommandLine(prog=prog).main(argv=argv) -if __name__ == '__main__': + +if __name__ == "__main__": main() diff --git a/alembic/ddl/base.py b/alembic/ddl/base.py index f4a525f..f177a07 100644 --- a/alembic/ddl/base.py +++ b/alembic/ddl/base.py @@ -9,7 +9,11 @@ from .. import util # backwards compat from ..util.sqla_compat import ( # noqa _table_for_constraint, - _columns_for_constraint, _fk_spec, _is_type_bound, _find_columns) + _columns_for_constraint, + _fk_spec, + _is_type_bound, + _find_columns, +) if util.sqla_09: from sqlalchemy.sql.elements import quoted_name @@ -30,65 +34,63 @@ class AlterTable(DDLElement): class RenameTable(AlterTable): - def __init__(self, old_table_name, new_table_name, schema=None): super(RenameTable, self).__init__(old_table_name, schema=schema) self.new_table_name = new_table_name class AlterColumn(AlterTable): - - def __init__(self, name, column_name, schema=None, - existing_type=None, - existing_nullable=None, - existing_server_default=None): + def __init__( + self, + name, + column_name, + schema=None, + existing_type=None, + existing_nullable=None, + existing_server_default=None, + ): super(AlterColumn, self).__init__(name, schema=schema) self.column_name = column_name - self.existing_type = sqltypes.to_instance(existing_type) \ - if existing_type is not None else None + self.existing_type = ( + sqltypes.to_instance(existing_type) + if existing_type is not None + else None + ) self.existing_nullable = existing_nullable self.existing_server_default = existing_server_default class ColumnNullable(AlterColumn): - def __init__(self, name, column_name, nullable, **kw): - super(ColumnNullable, self).__init__(name, column_name, - **kw) + super(ColumnNullable, self).__init__(name, column_name, **kw) self.nullable = nullable class ColumnType(AlterColumn): - def __init__(self, name, column_name, type_, **kw): - super(ColumnType, self).__init__(name, column_name, - **kw) + super(ColumnType, self).__init__(name, column_name, **kw) self.type_ = sqltypes.to_instance(type_) class ColumnName(AlterColumn): - def __init__(self, name, column_name, newname, **kw): super(ColumnName, self).__init__(name, column_name, **kw) self.newname = newname class ColumnDefault(AlterColumn): - def __init__(self, name, column_name, default, **kw): super(ColumnDefault, self).__init__(name, column_name, **kw) self.default = default class AddColumn(AlterTable): - def __init__(self, name, column, schema=None): super(AddColumn, self).__init__(name, schema=schema) self.column = column class DropColumn(AlterTable): - def __init__(self, name, column, schema=None): super(DropColumn, self).__init__(name, schema=schema) self.column = column @@ -98,7 +100,7 @@ class DropColumn(AlterTable): def visit_rename_table(element, compiler, **kw): return "%s RENAME TO %s" % ( alter_table(compiler, element.table_name, element.schema), - format_table_name(compiler, element.new_table_name, element.schema) + format_table_name(compiler, element.new_table_name, element.schema), ) @@ -106,7 +108,7 @@ def visit_rename_table(element, compiler, **kw): def visit_add_column(element, compiler, **kw): return "%s %s" % ( alter_table(compiler, element.table_name, element.schema), - add_column(compiler, element.column, **kw) + add_column(compiler, element.column, **kw), ) @@ -114,7 +116,7 @@ def visit_add_column(element, compiler, **kw): def visit_drop_column(element, compiler, **kw): return "%s %s" % ( alter_table(compiler, element.table_name, element.schema), - drop_column(compiler, element.column.name, **kw) + drop_column(compiler, element.column.name, **kw), ) @@ -123,7 +125,7 @@ def visit_column_nullable(element, compiler, **kw): return "%s %s %s" % ( alter_table(compiler, element.table_name, element.schema), alter_column(compiler, element.column_name), - "DROP NOT NULL" if element.nullable else "SET NOT NULL" + "DROP NOT NULL" if element.nullable else "SET NOT NULL", ) @@ -132,7 +134,7 @@ def visit_column_type(element, compiler, **kw): return "%s %s %s" % ( alter_table(compiler, element.table_name, element.schema), alter_column(compiler, element.column_name), - "TYPE %s" % format_type(compiler, element.type_) + "TYPE %s" % format_type(compiler, element.type_), ) @@ -141,7 +143,7 @@ def visit_column_name(element, compiler, **kw): return "%s RENAME %s TO %s" % ( alter_table(compiler, element.table_name, element.schema), format_column_name(compiler, element.column_name), - format_column_name(compiler, element.newname) + format_column_name(compiler, element.newname), ) @@ -150,10 +152,9 @@ def visit_column_default(element, compiler, **kw): return "%s %s %s" % ( alter_table(compiler, element.table_name, element.schema), alter_column(compiler, element.column_name), - "SET DEFAULT %s" % - format_server_default(compiler, element.default) + "SET DEFAULT %s" % format_server_default(compiler, element.default) if element.default is not None - else "DROP DEFAULT" + else "DROP DEFAULT", ) @@ -162,7 +163,7 @@ def quote_dotted(name, quote): if util.sqla_09 and isinstance(name, quoted_name): return quote(name) - result = '.'.join([quote(x) for x in name.split('.')]) + result = ".".join([quote(x) for x in name.split(".")]) return result @@ -193,11 +194,11 @@ def alter_table(compiler, name, schema): def drop_column(compiler, name): - return 'DROP COLUMN %s' % format_column_name(compiler, name) + return "DROP COLUMN %s" % format_column_name(compiler, name) def alter_column(compiler, name): - return 'ALTER COLUMN %s' % format_column_name(compiler, name) + return "ALTER COLUMN %s" % format_column_name(compiler, name) def add_column(compiler, column, **kw): diff --git a/alembic/ddl/impl.py b/alembic/ddl/impl.py index 98be164..4e3ff04 100644 --- a/alembic/ddl/impl.py +++ b/alembic/ddl/impl.py @@ -1,22 +1,20 @@ from sqlalchemy import schema, text from sqlalchemy import types as sqltypes -from ..util.compat import ( - string_types, text_type, with_metaclass -) +from ..util.compat import string_types, text_type, with_metaclass from ..util import sqla_compat from .. import util from . import base class ImplMeta(type): - def __init__(cls, classname, bases, dict_): newtype = type.__init__(cls, classname, bases, dict_) - if '__dialect__' in dict_: - _impls[dict_['__dialect__']] = cls + if "__dialect__" in dict_: + _impls[dict_["__dialect__"]] = cls return newtype + _impls = {} @@ -33,18 +31,25 @@ class DefaultImpl(with_metaclass(ImplMeta)): bulk inserts. """ - __dialect__ = 'default' + + __dialect__ = "default" transactional_ddl = False command_terminator = ";" - def __init__(self, dialect, connection, as_sql, - transactional_ddl, output_buffer, - context_opts): + def __init__( + self, + dialect, + connection, + as_sql, + transactional_ddl, + output_buffer, + context_opts, + ): self.dialect = dialect self.connection = connection self.as_sql = as_sql - self.literal_binds = context_opts.get('literal_binds', False) + self.literal_binds = context_opts.get("literal_binds", False) self.output_buffer = output_buffer self.memo = {} @@ -55,7 +60,8 @@ class DefaultImpl(with_metaclass(ImplMeta)): if self.literal_binds: if not self.as_sql: raise util.CommandError( - "Can't use literal_binds setting without as_sql mode") + "Can't use literal_binds setting without as_sql mode" + ) @classmethod def get_by_dialect(cls, dialect): @@ -89,9 +95,13 @@ class DefaultImpl(with_metaclass(ImplMeta)): def bind(self): return self.connection - def _exec(self, construct, execution_options=None, - multiparams=(), - params=util.immutabledict()): + def _exec( + self, + construct, + execution_options=None, + multiparams=(), + params=util.immutabledict(), + ): if isinstance(construct, string_types): construct = text(construct) if self.as_sql: @@ -100,14 +110,20 @@ class DefaultImpl(with_metaclass(ImplMeta)): raise Exception("Execution arguments not allowed with as_sql") if self.literal_binds and not isinstance( - construct, schema.DDLElement): + construct, schema.DDLElement + ): compile_kw = dict(compile_kwargs={"literal_binds": True}) else: compile_kw = {} - self.static_output(text_type( - construct.compile(dialect=self.dialect, **compile_kw) - ).replace("\t", " ").strip() + self.command_terminator) + self.static_output( + text_type( + construct.compile(dialect=self.dialect, **compile_kw) + ) + .replace("\t", " ") + .strip() + + self.command_terminator + ) else: conn = self.connection if execution_options: @@ -117,53 +133,75 @@ class DefaultImpl(with_metaclass(ImplMeta)): def execute(self, sql, execution_options=None): self._exec(sql, execution_options) - def alter_column(self, table_name, column_name, - nullable=None, - server_default=False, - name=None, - type_=None, - schema=None, - autoincrement=None, - existing_type=None, - existing_server_default=None, - existing_nullable=None, - existing_autoincrement=None - ): + def alter_column( + self, + table_name, + column_name, + nullable=None, + server_default=False, + name=None, + type_=None, + schema=None, + autoincrement=None, + existing_type=None, + existing_server_default=None, + existing_nullable=None, + existing_autoincrement=None, + ): if autoincrement is not None or existing_autoincrement is not None: util.warn( "autoincrement and existing_autoincrement " - "only make sense for MySQL") + "only make sense for MySQL" + ) if nullable is not None: - self._exec(base.ColumnNullable( - table_name, column_name, - nullable, schema=schema, - existing_type=existing_type, - existing_server_default=existing_server_default, - existing_nullable=existing_nullable, - )) + self._exec( + base.ColumnNullable( + table_name, + column_name, + nullable, + schema=schema, + existing_type=existing_type, + existing_server_default=existing_server_default, + existing_nullable=existing_nullable, + ) + ) if server_default is not False: - self._exec(base.ColumnDefault( - table_name, column_name, server_default, - schema=schema, - existing_type=existing_type, - existing_server_default=existing_server_default, - existing_nullable=existing_nullable, - )) + self._exec( + base.ColumnDefault( + table_name, + column_name, + server_default, + schema=schema, + existing_type=existing_type, + existing_server_default=existing_server_default, + existing_nullable=existing_nullable, + ) + ) if type_ is not None: - self._exec(base.ColumnType( - table_name, column_name, type_, schema=schema, - existing_type=existing_type, - existing_server_default=existing_server_default, - existing_nullable=existing_nullable, - )) + self._exec( + base.ColumnType( + table_name, + column_name, + type_, + schema=schema, + existing_type=existing_type, + existing_server_default=existing_server_default, + existing_nullable=existing_nullable, + ) + ) # do the new name last ;) if name is not None: - self._exec(base.ColumnName( - table_name, column_name, name, schema=schema, - existing_type=existing_type, - existing_server_default=existing_server_default, - existing_nullable=existing_nullable, - )) + self._exec( + base.ColumnName( + table_name, + column_name, + name, + schema=schema, + existing_type=existing_type, + existing_server_default=existing_server_default, + existing_nullable=existing_nullable, + ) + ) def add_column(self, table_name, column, schema=None): self._exec(base.AddColumn(table_name, column, schema=schema)) @@ -172,25 +210,25 @@ class DefaultImpl(with_metaclass(ImplMeta)): self._exec(base.DropColumn(table_name, column, schema=schema)) def add_constraint(self, const): - if const._create_rule is None or \ - const._create_rule(self): + if const._create_rule is None or const._create_rule(self): self._exec(schema.AddConstraint(const)) def drop_constraint(self, const): self._exec(schema.DropConstraint(const)) def rename_table(self, old_table_name, new_table_name, schema=None): - self._exec(base.RenameTable(old_table_name, - new_table_name, schema=schema)) + self._exec( + base.RenameTable(old_table_name, new_table_name, schema=schema) + ) def create_table(self, table): - table.dispatch.before_create(table, self.connection, - checkfirst=False, - _ddl_runner=self) + table.dispatch.before_create( + table, self.connection, checkfirst=False, _ddl_runner=self + ) self._exec(schema.CreateTable(table)) - table.dispatch.after_create(table, self.connection, - checkfirst=False, - _ddl_runner=self) + table.dispatch.after_create( + table, self.connection, checkfirst=False, _ddl_runner=self + ) for index in table.indexes: self._exec(schema.CreateIndex(index)) @@ -210,17 +248,26 @@ class DefaultImpl(with_metaclass(ImplMeta)): raise TypeError("List of dictionaries expected") if self.as_sql: for row in rows: - self._exec(table.insert(inline=True).values(**dict( - (k, - sqla_compat._literal_bindparam( - k, v, type_=table.c[k].type) - if not isinstance( - v, sqla_compat._literal_bindparam) else v) - for k, v in row.items() - ))) + self._exec( + table.insert(inline=True).values( + **dict( + ( + k, + sqla_compat._literal_bindparam( + k, v, type_=table.c[k].type + ) + if not isinstance( + v, sqla_compat._literal_bindparam + ) + else v, + ) + for k, v in row.items() + ) + ) + ) else: # work around http://www.sqlalchemy.org/trac/ticket/2461 - if not hasattr(table, '_autoincrement_column'): + if not hasattr(table, "_autoincrement_column"): table._autoincrement_column = None if rows: if multiinsert: @@ -240,32 +287,38 @@ class DefaultImpl(with_metaclass(ImplMeta)): # work around SQLAlchemy bug "stale value for type affinity" # fixed in 0.7.4 - metadata_impl.__dict__.pop('_type_affinity', None) + metadata_impl.__dict__.pop("_type_affinity", None) if hasattr(metadata_impl, "compare_against_backend"): comparison = metadata_impl.compare_against_backend( - self.dialect, conn_type) + self.dialect, conn_type + ) if comparison is not None: return not comparison - if conn_type._compare_type_affinity( - metadata_impl - ): + if conn_type._compare_type_affinity(metadata_impl): comparator = _type_comparators.get(conn_type._type_affinity, None) return comparator and comparator(metadata_impl, conn_type) else: return True - def compare_server_default(self, inspector_column, - metadata_column, - rendered_metadata_default, - rendered_inspector_default): + def compare_server_default( + self, + inspector_column, + metadata_column, + rendered_metadata_default, + rendered_inspector_default, + ): return rendered_inspector_default != rendered_metadata_default - def correct_for_autogen_constraints(self, conn_uniques, conn_indexes, - metadata_unique_constraints, - metadata_indexes): + def correct_for_autogen_constraints( + self, + conn_uniques, + conn_indexes, + metadata_unique_constraints, + metadata_indexes, + ): pass def _compat_autogen_column_reflect(self, inspector): @@ -316,38 +369,37 @@ class DefaultImpl(with_metaclass(ImplMeta)): def _string_compare(t1, t2): - return \ - t1.length is not None and \ - t1.length != t2.length + return t1.length is not None and t1.length != t2.length def _numeric_compare(t1, t2): - return ( - t1.precision is not None and - t1.precision != t2.precision - ) or ( - t1.precision is not None and - t1.scale is not None and - t1.scale != t2.scale + return (t1.precision is not None and t1.precision != t2.precision) or ( + t1.precision is not None + and t1.scale is not None + and t1.scale != t2.scale ) def _integer_compare(t1, t2): t1_small_or_big = ( - 'S' if isinstance(t1, sqltypes.SmallInteger) - else 'B' if isinstance(t1, sqltypes.BigInteger) else 'I' + "S" + if isinstance(t1, sqltypes.SmallInteger) + else "B" + if isinstance(t1, sqltypes.BigInteger) + else "I" ) t2_small_or_big = ( - 'S' if isinstance(t2, sqltypes.SmallInteger) - else 'B' if isinstance(t2, sqltypes.BigInteger) else 'I' + "S" + if isinstance(t2, sqltypes.SmallInteger) + else "B" + if isinstance(t2, sqltypes.BigInteger) + else "I" ) return t1_small_or_big != t2_small_or_big def _datetime_compare(t1, t2): - return ( - t1.timezone != t2.timezone - ) + return t1.timezone != t2.timezone _type_comparators = { diff --git a/alembic/ddl/mssql.py b/alembic/ddl/mssql.py index f303be4..7f43a89 100644 --- a/alembic/ddl/mssql.py +++ b/alembic/ddl/mssql.py @@ -2,24 +2,35 @@ from sqlalchemy.ext.compiler import compiles from .. import util from .impl import DefaultImpl -from .base import alter_table, AddColumn, ColumnName, RenameTable,\ - format_table_name, format_column_name, ColumnNullable, alter_column,\ - format_server_default, ColumnDefault, format_type, ColumnType +from .base import ( + alter_table, + AddColumn, + ColumnName, + RenameTable, + format_table_name, + format_column_name, + ColumnNullable, + alter_column, + format_server_default, + ColumnDefault, + format_type, + ColumnType, +) from sqlalchemy.sql.expression import ClauseElement, Executable from sqlalchemy.schema import CreateIndex, Column from sqlalchemy import types as sqltypes class MSSQLImpl(DefaultImpl): - __dialect__ = 'mssql' + __dialect__ = "mssql" transactional_ddl = True batch_separator = "GO" def __init__(self, *arg, **kw): super(MSSQLImpl, self).__init__(*arg, **kw) self.batch_separator = self.context_opts.get( - "mssql_batch_separator", - self.batch_separator) + "mssql_batch_separator", self.batch_separator + ) def _exec(self, construct, *args, **kw): result = super(MSSQLImpl, self)._exec(construct, *args, **kw) @@ -35,17 +46,20 @@ class MSSQLImpl(DefaultImpl): if self.as_sql and self.batch_separator: self.static_output(self.batch_separator) - def alter_column(self, table_name, column_name, - nullable=None, - server_default=False, - name=None, - type_=None, - schema=None, - existing_type=None, - existing_server_default=None, - existing_nullable=None, - **kw - ): + def alter_column( + self, + table_name, + column_name, + nullable=None, + server_default=False, + name=None, + type_=None, + schema=None, + existing_type=None, + existing_server_default=None, + existing_nullable=None, + **kw + ): if nullable is not None and existing_type is None: if type_ is not None: @@ -57,10 +71,12 @@ class MSSQLImpl(DefaultImpl): raise util.CommandError( "MS-SQL ALTER COLUMN operations " "with NULL or NOT NULL require the " - "existing_type or a new type_ be passed.") + "existing_type or a new type_ be passed." + ) super(MSSQLImpl, self).alter_column( - table_name, column_name, + table_name, + column_name, nullable=nullable, type_=type_, schema=schema, @@ -70,30 +86,30 @@ class MSSQLImpl(DefaultImpl): ) if server_default is not False: - if existing_server_default is not False or \ - server_default is None: + if existing_server_default is not False or server_default is None: self._exec( _ExecDropConstraint( - table_name, column_name, - 'sys.default_constraints') + table_name, column_name, "sys.default_constraints" + ) ) if server_default is not None: super(MSSQLImpl, self).alter_column( - table_name, column_name, + table_name, + column_name, schema=schema, - server_default=server_default) + server_default=server_default, + ) if name is not None: super(MSSQLImpl, self).alter_column( - table_name, column_name, - schema=schema, - name=name) + table_name, column_name, schema=schema, name=name + ) def create_index(self, index): # this likely defaults to None if not present, so get() # should normally not return the default value. being # defensive in any case - mssql_include = index.kwargs.get('mssql_include', None) or () + mssql_include = index.kwargs.get("mssql_include", None) or () for col in mssql_include: if col not in index.table.c: index.table.append_column(Column(col, sqltypes.NullType)) @@ -102,42 +118,39 @@ class MSSQLImpl(DefaultImpl): def bulk_insert(self, table, rows, **kw): if self.as_sql: self._exec( - "SET IDENTITY_INSERT %s ON" % - self.dialect.identifier_preparer.format_table(table) + "SET IDENTITY_INSERT %s ON" + % self.dialect.identifier_preparer.format_table(table) ) super(MSSQLImpl, self).bulk_insert(table, rows, **kw) self._exec( - "SET IDENTITY_INSERT %s OFF" % - self.dialect.identifier_preparer.format_table(table) + "SET IDENTITY_INSERT %s OFF" + % self.dialect.identifier_preparer.format_table(table) ) else: super(MSSQLImpl, self).bulk_insert(table, rows, **kw) def drop_column(self, table_name, column, **kw): - drop_default = kw.pop('mssql_drop_default', False) + drop_default = kw.pop("mssql_drop_default", False) if drop_default: self._exec( _ExecDropConstraint( - table_name, column, - 'sys.default_constraints') + table_name, column, "sys.default_constraints" + ) ) - drop_check = kw.pop('mssql_drop_check', False) + drop_check = kw.pop("mssql_drop_check", False) if drop_check: self._exec( _ExecDropConstraint( - table_name, column, - 'sys.check_constraints') + table_name, column, "sys.check_constraints" + ) ) - drop_fks = kw.pop('mssql_drop_foreign_key', False) + drop_fks = kw.pop("mssql_drop_foreign_key", False) if drop_fks: - self._exec( - _ExecDropFKConstraint(table_name, column) - ) + self._exec(_ExecDropFKConstraint(table_name, column)) super(MSSQLImpl, self).drop_column(table_name, column, **kw) class _ExecDropConstraint(Executable, ClauseElement): - def __init__(self, tname, colname, type_): self.tname = tname self.colname = colname @@ -145,13 +158,12 @@ class _ExecDropConstraint(Executable, ClauseElement): class _ExecDropFKConstraint(Executable, ClauseElement): - def __init__(self, tname, colname): self.tname = tname self.colname = colname -@compiles(_ExecDropConstraint, 'mssql') +@compiles(_ExecDropConstraint, "mssql") def _exec_drop_col_constraint(element, compiler, **kw): tname, colname, type_ = element.tname, element.colname, element.type_ # from http://www.mssqltips.com/sqlservertip/1425/\ @@ -162,14 +174,14 @@ select @const_name = [name] from %(type)s where parent_object_id = object_id('%(tname)s') and col_name(parent_object_id, parent_column_id) = '%(colname)s' exec('alter table %(tname_quoted)s drop constraint ' + @const_name)""" % { - 'type': type_, - 'tname': tname, - 'colname': colname, - 'tname_quoted': format_table_name(compiler, tname, None), + "type": type_, + "tname": tname, + "colname": colname, + "tname_quoted": format_table_name(compiler, tname, None), } -@compiles(_ExecDropFKConstraint, 'mssql') +@compiles(_ExecDropFKConstraint, "mssql") def _exec_drop_col_fk_constraint(element, compiler, **kw): tname, colname = element.tname, element.colname @@ -180,17 +192,17 @@ select @const_name = [name] from where fkc.parent_object_id = object_id('%(tname)s') and col_name(fkc.parent_object_id, fkc.parent_column_id) = '%(colname)s' exec('alter table %(tname_quoted)s drop constraint ' + @const_name)""" % { - 'tname': tname, - 'colname': colname, - 'tname_quoted': format_table_name(compiler, tname, None), + "tname": tname, + "colname": colname, + "tname_quoted": format_table_name(compiler, tname, None), } -@compiles(AddColumn, 'mssql') +@compiles(AddColumn, "mssql") def visit_add_column(element, compiler, **kw): return "%s %s" % ( alter_table(compiler, element.table_name, element.schema), - mssql_add_column(compiler, element.column, **kw) + mssql_add_column(compiler, element.column, **kw), ) @@ -198,49 +210,48 @@ def mssql_add_column(compiler, column, **kw): return "ADD %s" % compiler.get_column_specification(column, **kw) -@compiles(ColumnNullable, 'mssql') +@compiles(ColumnNullable, "mssql") def visit_column_nullable(element, compiler, **kw): return "%s %s %s %s" % ( alter_table(compiler, element.table_name, element.schema), alter_column(compiler, element.column_name), format_type(compiler, element.existing_type), - "NULL" if element.nullable else "NOT NULL" + "NULL" if element.nullable else "NOT NULL", ) -@compiles(ColumnDefault, 'mssql') +@compiles(ColumnDefault, "mssql") def visit_column_default(element, compiler, **kw): # TODO: there can also be a named constraint # with ADD CONSTRAINT here return "%s ADD DEFAULT %s FOR %s" % ( alter_table(compiler, element.table_name, element.schema), format_server_default(compiler, element.default), - format_column_name(compiler, element.column_name) + format_column_name(compiler, element.column_name), ) -@compiles(ColumnName, 'mssql') +@compiles(ColumnName, "mssql") def visit_rename_column(element, compiler, **kw): return "EXEC sp_rename '%s.%s', %s, 'COLUMN'" % ( format_table_name(compiler, element.table_name, element.schema), format_column_name(compiler, element.column_name), - format_column_name(compiler, element.newname) + format_column_name(compiler, element.newname), ) -@compiles(ColumnType, 'mssql') +@compiles(ColumnType, "mssql") def visit_column_type(element, compiler, **kw): return "%s %s %s" % ( alter_table(compiler, element.table_name, element.schema), alter_column(compiler, element.column_name), - format_type(compiler, element.type_) + format_type(compiler, element.type_), ) -@compiles(RenameTable, 'mssql') +@compiles(RenameTable, "mssql") def visit_rename_table(element, compiler, **kw): return "EXEC sp_rename '%s', %s" % ( format_table_name(compiler, element.table_name, element.schema), - format_table_name(compiler, element.new_table_name, None) + format_table_name(compiler, element.new_table_name, None), ) - diff --git a/alembic/ddl/mysql.py b/alembic/ddl/mysql.py index 1f4b345..bc98005 100644 --- a/alembic/ddl/mysql.py +++ b/alembic/ddl/mysql.py @@ -5,9 +5,15 @@ from sqlalchemy import schema from ..util.compat import string_types from .. import util from .impl import DefaultImpl -from .base import ColumnNullable, ColumnName, ColumnDefault, \ - ColumnType, AlterColumn, format_column_name, \ - format_server_default +from .base import ( + ColumnNullable, + ColumnName, + ColumnDefault, + ColumnType, + AlterColumn, + format_column_name, + format_server_default, +) from .base import alter_table from ..autogenerate import compare from ..util.sqla_compat import _is_type_bound, sqla_100 @@ -15,64 +21,76 @@ import re class MySQLImpl(DefaultImpl): - __dialect__ = 'mysql' + __dialect__ = "mysql" transactional_ddl = False - def alter_column(self, table_name, column_name, - nullable=None, - server_default=False, - name=None, - type_=None, - schema=None, - existing_type=None, - existing_server_default=None, - existing_nullable=None, - autoincrement=None, - existing_autoincrement=None, - **kw - ): + def alter_column( + self, + table_name, + column_name, + nullable=None, + server_default=False, + name=None, + type_=None, + schema=None, + existing_type=None, + existing_server_default=None, + existing_nullable=None, + autoincrement=None, + existing_autoincrement=None, + **kw + ): if name is not None: self._exec( MySQLChangeColumn( - table_name, column_name, + table_name, + column_name, schema=schema, newname=name, - nullable=nullable if nullable is not None else - existing_nullable + nullable=nullable + if nullable is not None + else existing_nullable if existing_nullable is not None else True, type_=type_ if type_ is not None else existing_type, - default=server_default if server_default is not False + default=server_default + if server_default is not False else existing_server_default, - autoincrement=autoincrement if autoincrement is not None - else existing_autoincrement + autoincrement=autoincrement + if autoincrement is not None + else existing_autoincrement, ) ) - elif nullable is not None or \ - type_ is not None or \ - autoincrement is not None: + elif ( + nullable is not None + or type_ is not None + or autoincrement is not None + ): self._exec( MySQLModifyColumn( - table_name, column_name, + table_name, + column_name, schema=schema, newname=name if name is not None else column_name, - nullable=nullable if nullable is not None else - existing_nullable + nullable=nullable + if nullable is not None + else existing_nullable if existing_nullable is not None else True, type_=type_ if type_ is not None else existing_type, - default=server_default if server_default is not False + default=server_default + if server_default is not False else existing_server_default, - autoincrement=autoincrement if autoincrement is not None - else existing_autoincrement + autoincrement=autoincrement + if autoincrement is not None + else existing_autoincrement, ) ) elif server_default is not False: self._exec( MySQLAlterDefault( - table_name, column_name, server_default, - schema=schema, + table_name, column_name, server_default, schema=schema ) ) @@ -82,41 +100,47 @@ class MySQLImpl(DefaultImpl): super(MySQLImpl, self).drop_constraint(const) - def compare_server_default(self, inspector_column, - metadata_column, - rendered_metadata_default, - rendered_inspector_default): + def compare_server_default( + self, + inspector_column, + metadata_column, + rendered_metadata_default, + rendered_inspector_default, + ): # partially a workaround for SQLAlchemy issue #3023; if the # column were created without "NOT NULL", MySQL may have added # an implicit default of '0' which we need to skip # TODO: this is not really covered anymore ? - if metadata_column.type._type_affinity is sqltypes.Integer and \ - inspector_column.primary_key and \ - not inspector_column.autoincrement and \ - not rendered_metadata_default and \ - rendered_inspector_default == "'0'": + if ( + metadata_column.type._type_affinity is sqltypes.Integer + and inspector_column.primary_key + and not inspector_column.autoincrement + and not rendered_metadata_default + and rendered_inspector_default == "'0'" + ): return False elif inspector_column.type._type_affinity is sqltypes.Integer: rendered_inspector_default = re.sub( - r"^'|'$", '', rendered_inspector_default) + r"^'|'$", "", rendered_inspector_default + ) return rendered_inspector_default != rendered_metadata_default elif rendered_inspector_default and rendered_metadata_default: # adjust for "function()" vs. "FUNCTION" - return ( - re.sub( - r'(.*?)(?:\(\))?$', r'\1', - rendered_inspector_default.lower()) != - re.sub( - r'(.*?)(?:\(\))?$', r'\1', - rendered_metadata_default.lower()) + return re.sub( + r"(.*?)(?:\(\))?$", r"\1", rendered_inspector_default.lower() + ) != re.sub( + r"(.*?)(?:\(\))?$", r"\1", rendered_metadata_default.lower() ) else: return rendered_inspector_default != rendered_metadata_default - def correct_for_autogen_constraints(self, conn_unique_constraints, - conn_indexes, - metadata_unique_constraints, - metadata_indexes): + def correct_for_autogen_constraints( + self, + conn_unique_constraints, + conn_indexes, + metadata_unique_constraints, + metadata_indexes, + ): # TODO: if SQLA 1.0, make use of "duplicates_index" # metadata @@ -153,31 +177,41 @@ class MySQLImpl(DefaultImpl): conn_unique_constraints, conn_indexes, metadata_unique_constraints, - metadata_indexes + metadata_indexes, ) - def _legacy_correct_for_dupe_uq_uix(self, conn_unique_constraints, - conn_indexes, - metadata_unique_constraints, - metadata_indexes): + def _legacy_correct_for_dupe_uq_uix( + self, + conn_unique_constraints, + conn_indexes, + metadata_unique_constraints, + metadata_indexes, + ): # then dedupe unique indexes vs. constraints, since MySQL # doesn't really have unique constraints as a separate construct. # but look in the metadata and try to maintain constructs # that already seem to be defined one way or the other # on that side. See #276 - metadata_uq_names = set([ - cons.name for cons in metadata_unique_constraints - if cons.name is not None]) - - unnamed_metadata_uqs = set([ - compare._uq_constraint_sig(cons).sig - for cons in metadata_unique_constraints - if cons.name is None - ]) - - metadata_ix_names = set([ - cons.name for cons in metadata_indexes if cons.unique]) + metadata_uq_names = set( + [ + cons.name + for cons in metadata_unique_constraints + if cons.name is not None + ] + ) + + unnamed_metadata_uqs = set( + [ + compare._uq_constraint_sig(cons).sig + for cons in metadata_unique_constraints + if cons.name is None + ] + ) + + metadata_ix_names = set( + [cons.name for cons in metadata_indexes if cons.unique] + ) conn_uq_names = dict( (cons.name, cons) for cons in conn_unique_constraints ) @@ -187,8 +221,10 @@ class MySQLImpl(DefaultImpl): for overlap in set(conn_uq_names).intersection(conn_ix_names): if overlap not in metadata_uq_names: - if compare._uq_constraint_sig(conn_uq_names[overlap]).sig \ - not in unnamed_metadata_uqs: + if ( + compare._uq_constraint_sig(conn_uq_names[overlap]).sig + not in unnamed_metadata_uqs + ): conn_unique_constraints.discard(conn_uq_names[overlap]) elif overlap not in metadata_ix_names: @@ -208,18 +244,21 @@ class MySQLImpl(DefaultImpl): # MySQL considers RESTRICT to be the default and doesn't # report on it. if the model has explicit RESTRICT and # the conn FK has None, set it to RESTRICT - if mdfk.ondelete is not None and \ - mdfk.ondelete.lower() == 'restrict' and \ - cnfk.ondelete is None: - cnfk.ondelete = 'RESTRICT' - if mdfk.onupdate is not None and \ - mdfk.onupdate.lower() == 'restrict' and \ - cnfk.onupdate is None: - cnfk.onupdate = 'RESTRICT' + if ( + mdfk.ondelete is not None + and mdfk.ondelete.lower() == "restrict" + and cnfk.ondelete is None + ): + cnfk.ondelete = "RESTRICT" + if ( + mdfk.onupdate is not None + and mdfk.onupdate.lower() == "restrict" + and cnfk.onupdate is None + ): + cnfk.onupdate = "RESTRICT" class MySQLAlterDefault(AlterColumn): - def __init__(self, name, column_name, default, schema=None): super(AlterColumn, self).__init__(name, schema=schema) self.column_name = column_name @@ -227,13 +266,17 @@ class MySQLAlterDefault(AlterColumn): class MySQLChangeColumn(AlterColumn): - - def __init__(self, name, column_name, schema=None, - newname=None, - type_=None, - nullable=None, - default=False, - autoincrement=None): + def __init__( + self, + name, + column_name, + schema=None, + newname=None, + type_=None, + nullable=None, + default=False, + autoincrement=None, + ): super(AlterColumn, self).__init__(name, schema=schema) self.column_name = column_name self.nullable = nullable @@ -253,10 +296,10 @@ class MySQLModifyColumn(MySQLChangeColumn): pass -@compiles(ColumnNullable, 'mysql') -@compiles(ColumnName, 'mysql') -@compiles(ColumnDefault, 'mysql') -@compiles(ColumnType, 'mysql') +@compiles(ColumnNullable, "mysql") +@compiles(ColumnName, "mysql") +@compiles(ColumnDefault, "mysql") +@compiles(ColumnType, "mysql") def _mysql_doesnt_support_individual(element, compiler, **kw): raise NotImplementedError( "Individual alter column constructs not supported by MySQL" @@ -270,7 +313,7 @@ def _mysql_alter_default(element, compiler, **kw): format_column_name(compiler, element.column_name), "SET DEFAULT %s" % format_server_default(compiler, element.default) if element.default is not None - else "DROP DEFAULT" + else "DROP DEFAULT", ) @@ -284,7 +327,7 @@ def _mysql_modify_column(element, compiler, **kw): nullable=element.nullable, server_default=element.default, type_=element.type_, - autoincrement=element.autoincrement + autoincrement=element.autoincrement, ), ) @@ -300,7 +343,7 @@ def _mysql_change_column(element, compiler, **kw): nullable=element.nullable, server_default=element.default, type_=element.type_, - autoincrement=element.autoincrement + autoincrement=element.autoincrement, ), ) @@ -312,11 +355,10 @@ def _render_value(compiler, expr): return compiler.sql_compiler.process(expr) -def _mysql_colspec(compiler, nullable, server_default, type_, - autoincrement): +def _mysql_colspec(compiler, nullable, server_default, type_, autoincrement): spec = "%s %s" % ( compiler.dialect.type_compiler.process(type_), - "NULL" if nullable else "NOT NULL" + "NULL" if nullable else "NOT NULL", ) if autoincrement: spec += " AUTO_INCREMENT" @@ -332,21 +374,25 @@ def _mysql_drop_constraint(element, compiler, **kw): raise errors for invalid constraint type.""" constraint = element.element - if isinstance(constraint, (schema.ForeignKeyConstraint, - schema.PrimaryKeyConstraint, - schema.UniqueConstraint) - ): + if isinstance( + constraint, + ( + schema.ForeignKeyConstraint, + schema.PrimaryKeyConstraint, + schema.UniqueConstraint, + ), + ): return compiler.visit_drop_constraint(element, **kw) elif isinstance(constraint, schema.CheckConstraint): # note that SQLAlchemy as of 1.2 does not yet support # DROP CONSTRAINT for MySQL/MariaDB, so we implement fully # here. - return "ALTER TABLE %s DROP CONSTRAINT %s" % \ - (compiler.preparer.format_table(constraint.table), - compiler.preparer.format_constraint(constraint)) + return "ALTER TABLE %s DROP CONSTRAINT %s" % ( + compiler.preparer.format_table(constraint.table), + compiler.preparer.format_constraint(constraint), + ) else: raise NotImplementedError( "No generic 'DROP CONSTRAINT' in MySQL - " - "please specify constraint type") - - + "please specify constraint type" + ) diff --git a/alembic/ddl/oracle.py b/alembic/ddl/oracle.py index e528744..3376155 100644 --- a/alembic/ddl/oracle.py +++ b/alembic/ddl/oracle.py @@ -1,13 +1,21 @@ from sqlalchemy.ext.compiler import compiles from .impl import DefaultImpl -from .base import alter_table, AddColumn, ColumnName, \ - format_column_name, ColumnNullable, \ - format_server_default, ColumnDefault, format_type, ColumnType +from .base import ( + alter_table, + AddColumn, + ColumnName, + format_column_name, + ColumnNullable, + format_server_default, + ColumnDefault, + format_type, + ColumnType, +) class OracleImpl(DefaultImpl): - __dialect__ = 'oracle' + __dialect__ = "oracle" transactional_ddl = False batch_separator = "/" command_terminator = "" @@ -15,8 +23,8 @@ class OracleImpl(DefaultImpl): def __init__(self, *arg, **kw): super(OracleImpl, self).__init__(*arg, **kw) self.batch_separator = self.context_opts.get( - "oracle_batch_separator", - self.batch_separator) + "oracle_batch_separator", self.batch_separator + ) def _exec(self, construct, *args, **kw): result = super(OracleImpl, self)._exec(construct, *args, **kw) @@ -31,7 +39,7 @@ class OracleImpl(DefaultImpl): self._exec("COMMIT") -@compiles(AddColumn, 'oracle') +@compiles(AddColumn, "oracle") def visit_add_column(element, compiler, **kw): return "%s %s" % ( alter_table(compiler, element.table_name, element.schema), @@ -39,47 +47,46 @@ def visit_add_column(element, compiler, **kw): ) -@compiles(ColumnNullable, 'oracle') +@compiles(ColumnNullable, "oracle") def visit_column_nullable(element, compiler, **kw): return "%s %s %s" % ( alter_table(compiler, element.table_name, element.schema), alter_column(compiler, element.column_name), - "NULL" if element.nullable else "NOT NULL" + "NULL" if element.nullable else "NOT NULL", ) -@compiles(ColumnType, 'oracle') +@compiles(ColumnType, "oracle") def visit_column_type(element, compiler, **kw): return "%s %s %s" % ( alter_table(compiler, element.table_name, element.schema), alter_column(compiler, element.column_name), - "%s" % format_type(compiler, element.type_) + "%s" % format_type(compiler, element.type_), ) -@compiles(ColumnName, 'oracle') +@compiles(ColumnName, "oracle") def visit_column_name(element, compiler, **kw): return "%s RENAME COLUMN %s TO %s" % ( alter_table(compiler, element.table_name, element.schema), format_column_name(compiler, element.column_name), - format_column_name(compiler, element.newname) + format_column_name(compiler, element.newname), ) -@compiles(ColumnDefault, 'oracle') +@compiles(ColumnDefault, "oracle") def visit_column_default(element, compiler, **kw): return "%s %s %s" % ( alter_table(compiler, element.table_name, element.schema), alter_column(compiler, element.column_name), - "DEFAULT %s" % - format_server_default(compiler, element.default) + "DEFAULT %s" % format_server_default(compiler, element.default) if element.default is not None - else "DEFAULT NULL" + else "DEFAULT NULL", ) def alter_column(compiler, name): - return 'MODIFY %s' % format_column_name(compiler, name) + return "MODIFY %s" % format_column_name(compiler, name) def add_column(compiler, column, **kw): diff --git a/alembic/ddl/postgresql.py b/alembic/ddl/postgresql.py index d399833..f133a05 100644 --- a/alembic/ddl/postgresql.py +++ b/alembic/ddl/postgresql.py @@ -2,8 +2,15 @@ import re from ..util import compat from .. import util -from .base import compiles, alter_column, alter_table, format_table_name, \ - format_type, AlterColumn, RenameTable +from .base import ( + compiles, + alter_column, + alter_table, + format_table_name, + format_type, + AlterColumn, + RenameTable, +) from .impl import DefaultImpl from sqlalchemy.dialects.postgresql import INTEGER, BIGINT from ..autogenerate import render @@ -30,7 +37,7 @@ log = logging.getLogger(__name__) class PostgresqlImpl(DefaultImpl): - __dialect__ = 'postgresql' + __dialect__ = "postgresql" transactional_ddl = True def prep_table_for_batch(self, table): @@ -38,13 +45,18 @@ class PostgresqlImpl(DefaultImpl): if constraint.name is not None: self.drop_constraint(constraint) - def compare_server_default(self, inspector_column, - metadata_column, - rendered_metadata_default, - rendered_inspector_default): + def compare_server_default( + self, + inspector_column, + metadata_column, + rendered_metadata_default, + rendered_inspector_default, + ): # don't do defaults for SERIAL columns - if metadata_column.primary_key and \ - metadata_column is metadata_column.table._autoincrement_column: + if ( + metadata_column.primary_key + and metadata_column is metadata_column.table._autoincrement_column + ): return False conn_col_default = rendered_inspector_default @@ -56,53 +68,65 @@ class PostgresqlImpl(DefaultImpl): if None in (conn_col_default, rendered_metadata_default): return not defaults_equal - if metadata_column.server_default is not None and \ - isinstance(metadata_column.server_default.arg, - compat.string_types) and \ - not re.match(r"^'.+'$", rendered_metadata_default) and \ - not isinstance(inspector_column.type, Numeric): - # don't single quote if the column type is float/numeric, - # otherwise a comparison such as SELECT 5 = '5.0' will fail + if ( + metadata_column.server_default is not None + and isinstance( + metadata_column.server_default.arg, compat.string_types + ) + and not re.match(r"^'.+'$", rendered_metadata_default) + and not isinstance(inspector_column.type, Numeric) + ): + # don't single quote if the column type is float/numeric, + # otherwise a comparison such as SELECT 5 = '5.0' will fail rendered_metadata_default = re.sub( - r"^u?'?|'?$", "'", rendered_metadata_default) + r"^u?'?|'?$", "'", rendered_metadata_default + ) return not self.connection.scalar( - "SELECT %s = %s" % ( - conn_col_default, - rendered_metadata_default - ) + "SELECT %s = %s" % (conn_col_default, rendered_metadata_default) ) - def alter_column(self, table_name, column_name, - nullable=None, - server_default=False, - name=None, - type_=None, - schema=None, - autoincrement=None, - existing_type=None, - existing_server_default=None, - existing_nullable=None, - existing_autoincrement=None, - **kw - ): - - using = kw.pop('postgresql_using', None) + def alter_column( + self, + table_name, + column_name, + nullable=None, + server_default=False, + name=None, + type_=None, + schema=None, + autoincrement=None, + existing_type=None, + existing_server_default=None, + existing_nullable=None, + existing_autoincrement=None, + **kw + ): + + using = kw.pop("postgresql_using", None) if using is not None and type_ is None: raise util.CommandError( - "postgresql_using must be used with the type_ parameter") + "postgresql_using must be used with the type_ parameter" + ) if type_ is not None: - self._exec(PostgresqlColumnType( - table_name, column_name, type_, schema=schema, - using=using, existing_type=existing_type, - existing_server_default=existing_server_default, - existing_nullable=existing_nullable, - )) + self._exec( + PostgresqlColumnType( + table_name, + column_name, + type_, + schema=schema, + using=using, + existing_type=existing_type, + existing_server_default=existing_server_default, + existing_nullable=existing_nullable, + ) + ) super(PostgresqlImpl, self).alter_column( - table_name, column_name, + table_name, + column_name, nullable=nullable, server_default=server_default, name=name, @@ -112,57 +136,70 @@ class PostgresqlImpl(DefaultImpl): existing_server_default=existing_server_default, existing_nullable=existing_nullable, existing_autoincrement=existing_autoincrement, - **kw) + **kw + ) def autogen_column_reflect(self, inspector, table, column_info): - if column_info.get('default') and \ - isinstance(column_info['type'], (INTEGER, BIGINT)): + if column_info.get("default") and isinstance( + column_info["type"], (INTEGER, BIGINT) + ): seq_match = re.match( - r"nextval\('(.+?)'::regclass\)", - column_info['default']) + r"nextval\('(.+?)'::regclass\)", column_info["default"] + ) if seq_match: - info = inspector.bind.execute(text( - "select c.relname, a.attname " - "from pg_class as c join pg_depend d on d.objid=c.oid and " - "d.classid='pg_class'::regclass and " - "d.refclassid='pg_class'::regclass " - "join pg_class t on t.oid=d.refobjid " - "join pg_attribute a on a.attrelid=t.oid and " - "a.attnum=d.refobjsubid " - "where c.relkind='S' and c.relname=:seqname" - ), seqname=seq_match.group(1)).first() + info = inspector.bind.execute( + text( + "select c.relname, a.attname " + "from pg_class as c join pg_depend d on d.objid=c.oid and " + "d.classid='pg_class'::regclass and " + "d.refclassid='pg_class'::regclass " + "join pg_class t on t.oid=d.refobjid " + "join pg_attribute a on a.attrelid=t.oid and " + "a.attnum=d.refobjsubid " + "where c.relkind='S' and c.relname=:seqname" + ), + seqname=seq_match.group(1), + ).first() if info: seqname, colname = info - if colname == column_info['name']: + if colname == column_info["name"]: log.info( "Detected sequence named '%s' as " "owned by integer column '%s(%s)', " "assuming SERIAL and omitting", - seqname, table.name, colname) + seqname, + table.name, + colname, + ) # sequence, and the owner is this column, # its a SERIAL - whack it! - del column_info['default'] + del column_info["default"] - def correct_for_autogen_constraints(self, conn_unique_constraints, - conn_indexes, - metadata_unique_constraints, - metadata_indexes): + def correct_for_autogen_constraints( + self, + conn_unique_constraints, + conn_indexes, + metadata_unique_constraints, + metadata_indexes, + ): conn_uniques_by_name = dict( - (c.name, c) for c in conn_unique_constraints) - conn_indexes_by_name = dict( - (c.name, c) for c in conn_indexes) + (c.name, c) for c in conn_unique_constraints + ) + conn_indexes_by_name = dict((c.name, c) for c in conn_indexes) if not util.sqla_100: doubled_constraints = set( conn_indexes_by_name[name] for name in set(conn_uniques_by_name).intersection( - conn_indexes_by_name) + conn_indexes_by_name + ) ) else: doubled_constraints = set( - index for index in - conn_indexes if index.info.get('duplicates_constraint') + index + for index in conn_indexes + if index.info.get("duplicates_constraint") ) for ix in doubled_constraints: @@ -187,37 +224,36 @@ class PostgresqlImpl(DefaultImpl): if not mod.startswith("sqlalchemy.dialects.postgresql"): return False - if hasattr(self, '_render_%s_type' % type_.__visit_name__): - meth = getattr(self, '_render_%s_type' % type_.__visit_name__) + if hasattr(self, "_render_%s_type" % type_.__visit_name__): + meth = getattr(self, "_render_%s_type" % type_.__visit_name__) return meth(type_, autogen_context) return False def _render_HSTORE_type(self, type_, autogen_context): return render._render_type_w_subtype( - type_, autogen_context, 'text_type', r'(.+?\(.*text_type=)' + type_, autogen_context, "text_type", r"(.+?\(.*text_type=)" ) def _render_ARRAY_type(self, type_, autogen_context): return render._render_type_w_subtype( - type_, autogen_context, 'item_type', r'(.+?\()' + type_, autogen_context, "item_type", r"(.+?\()" ) def _render_JSON_type(self, type_, autogen_context): return render._render_type_w_subtype( - type_, autogen_context, 'astext_type', r'(.+?\(.*astext_type=)' + type_, autogen_context, "astext_type", r"(.+?\(.*astext_type=)" ) def _render_JSONB_type(self, type_, autogen_context): return render._render_type_w_subtype( - type_, autogen_context, 'astext_type', r'(.+?\(.*astext_type=)' + type_, autogen_context, "astext_type", r"(.+?\(.*astext_type=)" ) class PostgresqlColumnType(AlterColumn): - def __init__(self, name, column_name, type_, **kw): - using = kw.pop('using', None) + using = kw.pop("using", None) super(PostgresqlColumnType, self).__init__(name, column_name, **kw) self.type_ = sqltypes.to_instance(type_) self.using = using @@ -227,7 +263,7 @@ class PostgresqlColumnType(AlterColumn): def visit_rename_table(element, compiler, **kw): return "%s RENAME TO %s" % ( alter_table(compiler, element.table_name, element.schema), - format_table_name(compiler, element.new_table_name, None) + format_table_name(compiler, element.new_table_name, None), ) @@ -237,13 +273,14 @@ def visit_column_type(element, compiler, **kw): alter_table(compiler, element.table_name, element.schema), alter_column(compiler, element.column_name), "TYPE %s" % format_type(compiler, element.type_), - "USING %s" % element.using if element.using else "" + "USING %s" % element.using if element.using else "", ) @Operations.register_operation("create_exclude_constraint") @BatchOperations.register_operation( - "create_exclude_constraint", "batch_create_exclude_constraint") + "create_exclude_constraint", "batch_create_exclude_constraint" +) @ops.AddConstraintOp.register_add_constraint("exclude_constraint") class CreateExcludeConstraintOp(ops.AddConstraintOp): """Represent a create exclude constraint operation.""" @@ -251,9 +288,15 @@ class CreateExcludeConstraintOp(ops.AddConstraintOp): constraint_type = "exclude" def __init__( - self, constraint_name, table_name, - elements, where=None, schema=None, - _orig_constraint=None, **kw): + self, + constraint_name, + table_name, + elements, + where=None, + schema=None, + _orig_constraint=None, + **kw + ): self.constraint_name = constraint_name self.table_name = table_name self.elements = elements @@ -275,13 +318,14 @@ class CreateExcludeConstraintOp(ops.AddConstraintOp): _orig_constraint=constraint, deferrable=constraint.deferrable, initially=constraint.initially, - using=constraint.using + using=constraint.using, ) def to_constraint(self, migration_context=None): if not util.sqla_100: raise NotImplementedError( - "ExcludeConstraint not supported until SQLAlchemy 1.0") + "ExcludeConstraint not supported until SQLAlchemy 1.0" + ) if self._orig_constraint is not None: return self._orig_constraint schema_obj = schemaobj.SchemaObjects(migration_context) @@ -299,8 +343,8 @@ class CreateExcludeConstraintOp(ops.AddConstraintOp): @classmethod def create_exclude_constraint( - cls, operations, - constraint_name, table_name, *elements, **kw): + cls, operations, constraint_name, table_name, *elements, **kw + ): """Issue an alter to create an EXCLUDE constraint using the current migration context. @@ -344,7 +388,8 @@ class CreateExcludeConstraintOp(ops.AddConstraintOp): @classmethod def batch_create_exclude_constraint( - cls, operations, constraint_name, *elements, **kw): + cls, operations, constraint_name, *elements, **kw + ): """Issue a "create exclude constraint" instruction using the current batch migration context. @@ -358,24 +403,23 @@ class CreateExcludeConstraintOp(ops.AddConstraintOp): :meth:`.Operations.create_exclude_constraint` """ - kw['schema'] = operations.impl.schema + kw["schema"] = operations.impl.schema op = cls(constraint_name, operations.impl.table_name, elements, **kw) return operations.invoke(op) @render.renderers.dispatch_for(CreateExcludeConstraintOp) def _add_exclude_constraint(autogen_context, op): - return _exclude_constraint( - op.to_constraint(), - autogen_context, - alter=True - ) + return _exclude_constraint(op.to_constraint(), autogen_context, alter=True) + if util.sqla_100: + @render._constraint_renderers.dispatch_for(ExcludeConstraint) def _render_inline_exclude_constraint(constraint, autogen_context): rendered = render._user_defined_render( - "exclude", constraint, autogen_context) + "exclude", constraint, autogen_context + ) if rendered is not False: return rendered @@ -405,48 +449,54 @@ def _exclude_constraint(constraint, autogen_context, alter): opts.append(("schema", render._ident(constraint.table.schema))) if not alter and constraint.name: opts.append( - ("name", - render._render_gen_name(autogen_context, constraint.name))) + ("name", render._render_gen_name(autogen_context, constraint.name)) + ) if alter: args = [ - repr(render._render_gen_name( - autogen_context, constraint.name))] + repr(render._render_gen_name(autogen_context, constraint.name)) + ] if not has_batch: args += [repr(render._ident(constraint.table.name))] - args.extend([ - "(%s, %r)" % ( - _render_potential_column(sqltext, autogen_context), - opstring - ) - for sqltext, name, opstring in constraint._render_exprs - ]) + args.extend( + [ + "(%s, %r)" + % ( + _render_potential_column(sqltext, autogen_context), + opstring, + ) + for sqltext, name, opstring in constraint._render_exprs + ] + ) if constraint.where is not None: args.append( - "where=%s" % render._render_potential_expr( - constraint.where, autogen_context) + "where=%s" + % render._render_potential_expr( + constraint.where, autogen_context + ) ) args.extend(["%s=%r" % (k, v) for k, v in opts]) return "%(prefix)screate_exclude_constraint(%(args)s)" % { - 'prefix': render._alembic_autogenerate_prefix(autogen_context), - 'args': ", ".join(args) + "prefix": render._alembic_autogenerate_prefix(autogen_context), + "args": ", ".join(args), } else: args = [ - "(%s, %r)" % ( - _render_potential_column(sqltext, autogen_context), - opstring - ) for sqltext, name, opstring in constraint._render_exprs + "(%s, %r)" + % (_render_potential_column(sqltext, autogen_context), opstring) + for sqltext, name, opstring in constraint._render_exprs ] if constraint.where is not None: args.append( - "where=%s" % render._render_potential_expr( - constraint.where, autogen_context) + "where=%s" + % render._render_potential_expr( + constraint.where, autogen_context + ) ) args.extend(["%s=%r" % (k, v) for k, v in opts]) return "%(prefix)sExcludeConstraint(%(args)s)" % { "prefix": _postgresql_autogenerate_prefix(autogen_context), - "args": ", ".join(args) + "args": ", ".join(args), } @@ -456,8 +506,10 @@ def _render_potential_column(value, autogen_context): return template % { "prefix": render._sqlalchemy_autogenerate_prefix(autogen_context), - "name": value.name + "name": value.name, } else: - return render._render_potential_expr(value, autogen_context, wrap_in_text=False) + return render._render_potential_expr( + value, autogen_context, wrap_in_text=False + ) diff --git a/alembic/ddl/sqlite.py b/alembic/ddl/sqlite.py index 5d231b5..f7699e6 100644 --- a/alembic/ddl/sqlite.py +++ b/alembic/ddl/sqlite.py @@ -4,7 +4,7 @@ import re class SQLiteImpl(DefaultImpl): - __dialect__ = 'sqlite' + __dialect__ = "sqlite" transactional_ddl = False """SQLite supports transactional DDL, but pysqlite does not: @@ -21,7 +21,7 @@ class SQLiteImpl(DefaultImpl): """ for op in batch_op.batch: - if op[0] not in ('add_column', 'create_index', 'drop_index'): + if op[0] not in ("add_column", "create_index", "drop_index"): return True else: return False @@ -31,34 +31,46 @@ class SQLiteImpl(DefaultImpl): # auto-gen constraint and an explicit one if const._create_rule is None: raise NotImplementedError( - "No support for ALTER of constraints in SQLite dialect") + "No support for ALTER of constraints in SQLite dialect" + ) elif const._create_rule(self): - util.warn("Skipping unsupported ALTER for " - "creation of implicit constraint") + util.warn( + "Skipping unsupported ALTER for " + "creation of implicit constraint" + ) def drop_constraint(self, const): if const._create_rule is None: raise NotImplementedError( - "No support for ALTER of constraints in SQLite dialect") + "No support for ALTER of constraints in SQLite dialect" + ) - def compare_server_default(self, inspector_column, - metadata_column, - rendered_metadata_default, - rendered_inspector_default): + def compare_server_default( + self, + inspector_column, + metadata_column, + rendered_metadata_default, + rendered_inspector_default, + ): if rendered_metadata_default is not None: rendered_metadata_default = re.sub( - r"^\"'|\"'$", "", rendered_metadata_default) + r"^\"'|\"'$", "", rendered_metadata_default + ) if rendered_inspector_default is not None: rendered_inspector_default = re.sub( - r"^\"'|\"'$", "", rendered_inspector_default) + r"^\"'|\"'$", "", rendered_inspector_default + ) return rendered_inspector_default != rendered_metadata_default def correct_for_autogen_constraints( - self, conn_unique_constraints, conn_indexes, + self, + conn_unique_constraints, + conn_indexes, metadata_unique_constraints, - metadata_indexes): + metadata_indexes, + ): if util.sqla_100: return @@ -70,10 +82,7 @@ class SQLiteImpl(DefaultImpl): def uq_sig(uq): return tuple(sorted(uq.columns.keys())) - conn_unique_sigs = set( - uq_sig(uq) - for uq in conn_unique_constraints - ) + conn_unique_sigs = set(uq_sig(uq) for uq in conn_unique_constraints) for idx in list(metadata_unique_constraints): # SQLite backend can't report on unnamed UNIQUE constraints, diff --git a/alembic/op.py b/alembic/op.py index 1f367a1..f3f5fac 100644 --- a/alembic/op.py +++ b/alembic/op.py @@ -3,4 +3,3 @@ from .operations.base import Operations # create proxy functions for # each method on the Operations class. Operations.create_module_class_proxy(globals(), locals()) - diff --git a/alembic/operations/__init__.py b/alembic/operations/__init__.py index 1f6ee5d..e1ff01c 100644 --- a/alembic/operations/__init__.py +++ b/alembic/operations/__init__.py @@ -3,4 +3,4 @@ from .ops import MigrateOperation from . import toimpl -__all__ = ['Operations', 'BatchOperations', 'MigrateOperation']
\ No newline at end of file +__all__ = ["Operations", "BatchOperations", "MigrateOperation"] diff --git a/alembic/operations/base.py b/alembic/operations/base.py index 1ae9524..2c3408a 100644 --- a/alembic/operations/base.py +++ b/alembic/operations/base.py @@ -9,7 +9,7 @@ from ..util.compat import inspect_formatargspec from ..util.compat import inspect_getargspec import textwrap -__all__ = ('Operations', 'BatchOperations') +__all__ = ("Operations", "BatchOperations") try: from sqlalchemy.sql.naming import conv @@ -84,6 +84,7 @@ class Operations(util.ModuleClsProxy): """ + def register(op_cls): if sourcename is None: fn = getattr(op_cls, name) @@ -95,45 +96,53 @@ class Operations(util.ModuleClsProxy): spec = inspect_getargspec(fn) name_args = spec[0] - assert name_args[0:2] == ['cls', 'operations'] + assert name_args[0:2] == ["cls", "operations"] - name_args[0:2] = ['self'] + name_args[0:2] = ["self"] args = inspect_formatargspec(*spec) num_defaults = len(spec[3]) if spec[3] else 0 if num_defaults: - defaulted_vals = name_args[0 - num_defaults:] + defaulted_vals = name_args[0 - num_defaults :] else: defaulted_vals = () apply_kw = inspect_formatargspec( - name_args, spec[1], spec[2], + name_args, + spec[1], + spec[2], defaulted_vals, - formatvalue=lambda x: '=' + x) + formatvalue=lambda x: "=" + x, + ) - func_text = textwrap.dedent("""\ + func_text = textwrap.dedent( + """\ def %(name)s%(args)s: %(doc)r return op_cls.%(source_name)s%(apply_kw)s - """ % { - 'name': name, - 'source_name': source_name, - 'args': args, - 'apply_kw': apply_kw, - 'doc': fn.__doc__, - 'meth': fn.__name__ - }) - globals_ = {'op_cls': op_cls} + """ + % { + "name": name, + "source_name": source_name, + "args": args, + "apply_kw": apply_kw, + "doc": fn.__doc__, + "meth": fn.__name__, + } + ) + globals_ = {"op_cls": op_cls} lcl = {} exec_(func_text, globals_, lcl) setattr(cls, name, lcl[name]) - fn.__func__.__doc__ = "This method is proxied on "\ - "the :class:`.%s` class, via the :meth:`.%s.%s` method." % ( - cls.__name__, cls.__name__, name - ) - if hasattr(fn, '_legacy_translations'): + fn.__func__.__doc__ = ( + "This method is proxied on " + "the :class:`.%s` class, via the :meth:`.%s.%s` method." + % (cls.__name__, cls.__name__, name) + ) + if hasattr(fn, "_legacy_translations"): lcl[name]._legacy_translations = fn._legacy_translations return op_cls + return register @classmethod @@ -151,6 +160,7 @@ class Operations(util.ModuleClsProxy): def decorate(fn): cls._to_impl.dispatch_for(op_cls)(fn) return fn + return decorate @classmethod @@ -163,10 +173,17 @@ class Operations(util.ModuleClsProxy): @contextmanager def batch_alter_table( - self, table_name, schema=None, recreate="auto", copy_from=None, - table_args=(), table_kwargs=util.immutabledict(), - reflect_args=(), reflect_kwargs=util.immutabledict(), - naming_convention=None): + self, + table_name, + schema=None, + recreate="auto", + copy_from=None, + table_args=(), + table_kwargs=util.immutabledict(), + reflect_args=(), + reflect_kwargs=util.immutabledict(), + naming_convention=None, + ): """Invoke a series of per-table migrations in batch. Batch mode allows a series of operations specific to a table @@ -292,9 +309,17 @@ class Operations(util.ModuleClsProxy): """ impl = batch.BatchOperationsImpl( - self, table_name, schema, recreate, - copy_from, table_args, table_kwargs, reflect_args, - reflect_kwargs, naming_convention) + self, + table_name, + schema, + recreate, + copy_from, + table_args, + table_kwargs, + reflect_args, + reflect_kwargs, + naming_convention, + ) batch_op = BatchOperations(self.migration_context, impl=impl) yield batch_op impl.flush() @@ -315,7 +340,8 @@ class Operations(util.ModuleClsProxy): """ fn = self._to_impl.dispatch( - operation, self.migration_context.impl.__dialect__) + operation, self.migration_context.impl.__dialect__ + ) return fn(self, operation) def f(self, name): @@ -363,7 +389,8 @@ class Operations(util.ModuleClsProxy): return conv(name) else: raise NotImplementedError( - "op.f() feature requires SQLAlchemy 0.9.4 or greater.") + "op.f() feature requires SQLAlchemy 0.9.4 or greater." + ) def inline_literal(self, value, type_=None): """Produce an 'inline literal' expression, suitable for @@ -442,4 +469,5 @@ class BatchOperations(Operations): def _noop(self, operation): raise NotImplementedError( "The %s method does not apply to a batch table alter operation." - % operation) + % operation + ) diff --git a/alembic/operations/batch.py b/alembic/operations/batch.py index 79ad533..9362876 100644 --- a/alembic/operations/batch.py +++ b/alembic/operations/batch.py @@ -1,24 +1,48 @@ -from sqlalchemy import Table, MetaData, Index, select, Column, \ - ForeignKeyConstraint, PrimaryKeyConstraint, cast, CheckConstraint +from sqlalchemy import ( + Table, + MetaData, + Index, + select, + Column, + ForeignKeyConstraint, + PrimaryKeyConstraint, + cast, + CheckConstraint, +) from sqlalchemy import types as sqltypes from sqlalchemy import schema as sql_schema from sqlalchemy.util import OrderedDict from .. import util from sqlalchemy.events import SchemaEventTarget -from ..util.sqla_compat import _columns_for_constraint, \ - _is_type_bound, _fk_is_self_referential, _remove_column_from_collection +from ..util.sqla_compat import ( + _columns_for_constraint, + _is_type_bound, + _fk_is_self_referential, + _remove_column_from_collection, +) class BatchOperationsImpl(object): - def __init__(self, operations, table_name, schema, recreate, - copy_from, table_args, table_kwargs, - reflect_args, reflect_kwargs, naming_convention): + def __init__( + self, + operations, + table_name, + schema, + recreate, + copy_from, + table_args, + table_kwargs, + reflect_args, + reflect_kwargs, + naming_convention, + ): self.operations = operations self.table_name = table_name self.schema = schema - if recreate not in ('auto', 'always', 'never'): + if recreate not in ("auto", "always", "never"): raise ValueError( - "recreate may be one of 'auto', 'always', or 'never'.") + "recreate may be one of 'auto', 'always', or 'never'." + ) self.recreate = recreate self.copy_from = copy_from self.table_args = table_args @@ -37,9 +61,9 @@ class BatchOperationsImpl(object): return self.operations.impl def _should_recreate(self): - if self.recreate == 'auto': + if self.recreate == "auto": return self.operations.impl.requires_recreate_in_batch(self) - elif self.recreate == 'always': + elif self.recreate == "always": return True else: return False @@ -62,15 +86,19 @@ class BatchOperationsImpl(object): reflected = False else: existing_table = Table( - self.table_name, m1, + self.table_name, + m1, schema=self.schema, autoload=True, autoload_with=self.operations.get_bind(), - *self.reflect_args, **self.reflect_kwargs) + *self.reflect_args, + **self.reflect_kwargs + ) reflected = True batch_impl = ApplyBatchImpl( - existing_table, self.table_args, self.table_kwargs, reflected) + existing_table, self.table_args, self.table_kwargs, reflected + ) for opname, arg, kw in self.batch: fn = getattr(batch_impl, opname) fn(*arg, **kw) @@ -90,7 +118,7 @@ class BatchOperationsImpl(object): self.batch.append(("add_constraint", (const,), {})) def drop_constraint(self, const): - self.batch.append(("drop_constraint", (const, ), {})) + self.batch.append(("drop_constraint", (const,), {})) def rename_table(self, *arg, **kw): self.batch.append(("rename_table", arg, kw)) @@ -116,7 +144,7 @@ class ApplyBatchImpl(object): self.temp_table_name = self._calc_temp_name(table.name) self.new_table = None self.column_transfers = OrderedDict( - (c.name, {'expr': c}) for c in self.table.c + (c.name, {"expr": c}) for c in self.table.c ) self.reflected = reflected self._grab_table_elements() @@ -165,16 +193,20 @@ class ApplyBatchImpl(object): schema = self.table.schema self.new_table = new_table = Table( - self.temp_table_name, m, + self.temp_table_name, + m, *(list(self.columns.values()) + list(self.table_args)), schema=schema, - **self.table_kwargs) + **self.table_kwargs + ) - for const in list(self.named_constraints.values()) + \ - self.unnamed_constraints: + for const in ( + list(self.named_constraints.values()) + self.unnamed_constraints + ): - const_columns = set([ - c.key for c in _columns_for_constraint(const)]) + const_columns = set( + [c.key for c in _columns_for_constraint(const)] + ) if not const_columns.issubset(self.column_transfers): continue @@ -188,7 +220,8 @@ class ApplyBatchImpl(object): # no foreign keys just keeps the names unchanged, so # when we rename back, they match again. const_copy = const.copy( - schema=schema, target_table=self.table) + schema=schema, target_table=self.table + ) else: # "target_table" for ForeignKeyConstraint.copy() is # only used if the FK is detected as being @@ -209,7 +242,8 @@ class ApplyBatchImpl(object): index.name, unique=index.unique, *[self.new_table.c[col] for col in index.columns.keys()], - **index.kwargs) + **index.kwargs + ) ) return idx @@ -229,16 +263,20 @@ class ApplyBatchImpl(object): for elem in constraint.elements: colname = elem._get_colspec().split(".")[-1] if not t.c.contains_column(colname): - t.append_column( - Column(colname, sqltypes.NULLTYPE) - ) + t.append_column(Column(colname, sqltypes.NULLTYPE)) else: Table( - tname, metadata, - *[Column(n, sqltypes.NULLTYPE) for n in - [elem._get_colspec().split(".")[-1] - for elem in constraint.elements]], - schema=referent_schema) + tname, + metadata, + *[ + Column(n, sqltypes.NULLTYPE) + for n in [ + elem._get_colspec().split(".")[-1] + for elem in constraint.elements + ] + ], + schema=referent_schema + ) def _create(self, op_impl): self._transfer_elements_to_new_table() @@ -249,13 +287,18 @@ class ApplyBatchImpl(object): try: op_impl._exec( self.new_table.insert(inline=True).from_select( - list(k for k, transfer in - self.column_transfers.items() if 'expr' in transfer), - select([ - transfer['expr'] - for transfer in self.column_transfers.values() - if 'expr' in transfer - ]) + list( + k + for k, transfer in self.column_transfers.items() + if "expr" in transfer + ), + select( + [ + transfer["expr"] + for transfer in self.column_transfers.values() + if "expr" in transfer + ] + ), ) ) op_impl.drop_table(self.table) @@ -264,9 +307,7 @@ class ApplyBatchImpl(object): raise else: op_impl.rename_table( - self.temp_table_name, - self.table.name, - schema=self.table.schema + self.temp_table_name, self.table.name, schema=self.table.schema ) self.new_table.name = self.table.name try: @@ -275,14 +316,17 @@ class ApplyBatchImpl(object): finally: self.new_table.name = self.temp_table_name - def alter_column(self, table_name, column_name, - nullable=None, - server_default=False, - name=None, - type_=None, - autoincrement=None, - **kw - ): + def alter_column( + self, + table_name, + column_name, + nullable=None, + server_default=False, + name=None, + type_=None, + autoincrement=None, + **kw + ): existing = self.columns[column_name] existing_transfer = self.column_transfers[column_name] if name is not None and name != column_name: @@ -299,12 +343,14 @@ class ApplyBatchImpl(object): # we also ignore the drop_constraint that will come here from # Operations.implementation_for(alter_column) if isinstance(existing.type, SchemaEventTarget): - existing.type._create_events = \ - existing.type.create_constraint = False + existing.type._create_events = ( + existing.type.create_constraint + ) = False if existing.type._type_affinity is not type_._type_affinity: existing_transfer["expr"] = cast( - existing_transfer["expr"], type_) + existing_transfer["expr"], type_ + ) existing.type = type_ @@ -332,8 +378,7 @@ class ApplyBatchImpl(object): def drop_column(self, table_name, column, **kw): if column.name in self.table.primary_key.columns: _remove_column_from_collection( - self.table.primary_key.columns, - column + self.table.primary_key.columns, column ) del self.columns[column.name] del self.column_transfers[column.name] diff --git a/alembic/operations/ops.py b/alembic/operations/ops.py index ade1cb3..5824469 100644 --- a/alembic/operations/ops.py +++ b/alembic/operations/ops.py @@ -46,12 +46,14 @@ class AddConstraintOp(MigrateOperation): def go(klass): cls.add_constraint_ops.dispatch_for(type_)(klass.from_constraint) return klass + return go @classmethod def from_constraint(cls, constraint): - return cls.add_constraint_ops.dispatch( - constraint.__visit_name__)(constraint) + return cls.add_constraint_ops.dispatch(constraint.__visit_name__)( + constraint + ) def reverse(self): return DropConstraintOp.from_constraint(self.to_constraint()) @@ -66,9 +68,13 @@ class DropConstraintOp(MigrateOperation): """Represent a drop constraint operation.""" def __init__( - self, - constraint_name, table_name, type_=None, schema=None, - _orig_constraint=None): + self, + constraint_name, + table_name, + type_=None, + schema=None, + _orig_constraint=None, + ): self.constraint_name = constraint_name self.table_name = table_name self.constraint_type = type_ @@ -79,7 +85,8 @@ class DropConstraintOp(MigrateOperation): if self._orig_constraint is None: raise ValueError( "operation is not reversible; " - "original constraint is not present") + "original constraint is not present" + ) return AddConstraintOp.from_constraint(self._orig_constraint) def to_diff_tuple(self): @@ -104,7 +111,7 @@ class DropConstraintOp(MigrateOperation): constraint_table.name, schema=constraint_table.schema, type_=types[constraint.__visit_name__], - _orig_constraint=constraint + _orig_constraint=constraint, ) def to_constraint(self): @@ -113,16 +120,14 @@ class DropConstraintOp(MigrateOperation): else: raise ValueError( "constraint cannot be produced; " - "original constraint is not present") + "original constraint is not present" + ) @classmethod - @util._with_legacy_names([ - ("type", "type_"), - ("name", "constraint_name"), - ]) + @util._with_legacy_names([("type", "type_"), ("name", "constraint_name")]) def drop_constraint( - cls, operations, constraint_name, table_name, - type_=None, schema=None): + cls, operations, constraint_name, table_name, type_=None, schema=None + ): """Drop a constraint of the given name, typically via DROP CONSTRAINT. :param constraint_name: name of the constraint. @@ -166,15 +171,18 @@ class DropConstraintOp(MigrateOperation): """ op = cls( - constraint_name, operations.impl.table_name, - type_=type_, schema=operations.impl.schema + constraint_name, + operations.impl.table_name, + type_=type_, + schema=operations.impl.schema, ) return operations.invoke(op) @Operations.register_operation("create_primary_key") @BatchOperations.register_operation( - "create_primary_key", "batch_create_primary_key") + "create_primary_key", "batch_create_primary_key" +) @AddConstraintOp.register_add_constraint("primary_key_constraint") class CreatePrimaryKeyOp(AddConstraintOp): """Represent a create primary key operation.""" @@ -182,8 +190,14 @@ class CreatePrimaryKeyOp(AddConstraintOp): constraint_type = "primarykey" def __init__( - self, constraint_name, table_name, columns, - schema=None, _orig_constraint=None, **kw): + self, + constraint_name, + table_name, + columns, + schema=None, + _orig_constraint=None, + **kw + ): self.constraint_name = constraint_name self.table_name = table_name self.columns = columns @@ -200,7 +214,7 @@ class CreatePrimaryKeyOp(AddConstraintOp): constraint_table.name, constraint.columns, schema=constraint_table.schema, - _orig_constraint=constraint + _orig_constraint=constraint, ) def to_constraint(self, migration_context=None): @@ -209,17 +223,19 @@ class CreatePrimaryKeyOp(AddConstraintOp): schema_obj = schemaobj.SchemaObjects(migration_context) return schema_obj.primary_key_constraint( - self.constraint_name, self.table_name, - self.columns, schema=self.schema) + self.constraint_name, + self.table_name, + self.columns, + schema=self.schema, + ) @classmethod - @util._with_legacy_names([ - ('name', 'constraint_name'), - ('cols', 'columns') - ]) + @util._with_legacy_names( + [("name", "constraint_name"), ("cols", "columns")] + ) def create_primary_key( - cls, operations, - constraint_name, table_name, columns, schema=None): + cls, operations, constraint_name, table_name, columns, schema=None + ): """Issue a "create primary key" instruction using the current migration context. @@ -282,15 +298,18 @@ class CreatePrimaryKeyOp(AddConstraintOp): """ op = cls( - constraint_name, operations.impl.table_name, columns, - schema=operations.impl.schema + constraint_name, + operations.impl.table_name, + columns, + schema=operations.impl.schema, ) return operations.invoke(op) @Operations.register_operation("create_unique_constraint") @BatchOperations.register_operation( - "create_unique_constraint", "batch_create_unique_constraint") + "create_unique_constraint", "batch_create_unique_constraint" +) @AddConstraintOp.register_add_constraint("unique_constraint") class CreateUniqueConstraintOp(AddConstraintOp): """Represent a create unique constraint operation.""" @@ -298,8 +317,14 @@ class CreateUniqueConstraintOp(AddConstraintOp): constraint_type = "unique" def __init__( - self, constraint_name, table_name, - columns, schema=None, _orig_constraint=None, **kw): + self, + constraint_name, + table_name, + columns, + schema=None, + _orig_constraint=None, + **kw + ): self.constraint_name = constraint_name self.table_name = table_name self.columns = columns @@ -313,9 +338,9 @@ class CreateUniqueConstraintOp(AddConstraintOp): kw = {} if constraint.deferrable: - kw['deferrable'] = constraint.deferrable + kw["deferrable"] = constraint.deferrable if constraint.initially: - kw['initially'] = constraint.initially + kw["initially"] = constraint.initially return cls( constraint.name, @@ -332,18 +357,30 @@ class CreateUniqueConstraintOp(AddConstraintOp): schema_obj = schemaobj.SchemaObjects(migration_context) return schema_obj.unique_constraint( - self.constraint_name, self.table_name, self.columns, - schema=self.schema, **self.kw) + self.constraint_name, + self.table_name, + self.columns, + schema=self.schema, + **self.kw + ) @classmethod - @util._with_legacy_names([ - ('name', 'constraint_name'), - ('source', 'table_name'), - ('local_cols', 'columns'), - ]) + @util._with_legacy_names( + [ + ("name", "constraint_name"), + ("source", "table_name"), + ("local_cols", "columns"), + ] + ) def create_unique_constraint( - cls, operations, constraint_name, table_name, columns, - schema=None, **kw): + cls, + operations, + constraint_name, + table_name, + columns, + schema=None, + **kw + ): """Issue a "create unique constraint" instruction using the current migration context. @@ -392,16 +429,14 @@ class CreateUniqueConstraintOp(AddConstraintOp): """ - op = cls( - constraint_name, table_name, columns, - schema=schema, **kw - ) + op = cls(constraint_name, table_name, columns, schema=schema, **kw) return operations.invoke(op) @classmethod - @util._with_legacy_names([('name', 'constraint_name')]) + @util._with_legacy_names([("name", "constraint_name")]) def batch_create_unique_constraint( - cls, operations, constraint_name, columns, **kw): + cls, operations, constraint_name, columns, **kw + ): """Issue a "create unique constraint" instruction using the current batch migration context. @@ -418,17 +453,15 @@ class CreateUniqueConstraintOp(AddConstraintOp): * name -> constraint_name """ - kw['schema'] = operations.impl.schema - op = cls( - constraint_name, operations.impl.table_name, columns, - **kw - ) + kw["schema"] = operations.impl.schema + op = cls(constraint_name, operations.impl.table_name, columns, **kw) return operations.invoke(op) @Operations.register_operation("create_foreign_key") @BatchOperations.register_operation( - "create_foreign_key", "batch_create_foreign_key") + "create_foreign_key", "batch_create_foreign_key" +) @AddConstraintOp.register_add_constraint("foreign_key_constraint") class CreateForeignKeyOp(AddConstraintOp): """Represent a create foreign key constraint operation.""" @@ -436,8 +469,15 @@ class CreateForeignKeyOp(AddConstraintOp): constraint_type = "foreignkey" def __init__( - self, constraint_name, source_table, referent_table, local_cols, - remote_cols, _orig_constraint=None, **kw): + self, + constraint_name, + source_table, + referent_table, + local_cols, + remote_cols, + _orig_constraint=None, + **kw + ): self.constraint_name = constraint_name self.source_table = source_table self.referent_table = referent_table @@ -453,24 +493,22 @@ class CreateForeignKeyOp(AddConstraintOp): def from_constraint(cls, constraint): kw = {} if constraint.onupdate: - kw['onupdate'] = constraint.onupdate + kw["onupdate"] = constraint.onupdate if constraint.ondelete: - kw['ondelete'] = constraint.ondelete + kw["ondelete"] = constraint.ondelete if constraint.initially: - kw['initially'] = constraint.initially + kw["initially"] = constraint.initially if constraint.deferrable: - kw['deferrable'] = constraint.deferrable + kw["deferrable"] = constraint.deferrable if constraint.use_alter: - kw['use_alter'] = constraint.use_alter + kw["use_alter"] = constraint.use_alter - source_schema, source_table, \ - source_columns, target_schema, \ - target_table, target_columns,\ - onupdate, ondelete, deferrable, initially \ - = sqla_compat._fk_spec(constraint) + source_schema, source_table, source_columns, target_schema, target_table, target_columns, onupdate, ondelete, deferrable, initially = sqla_compat._fk_spec( + constraint + ) - kw['source_schema'] = source_schema - kw['referent_schema'] = target_schema + kw["source_schema"] = source_schema + kw["referent_schema"] = target_schema return cls( constraint.name, @@ -488,22 +526,38 @@ class CreateForeignKeyOp(AddConstraintOp): schema_obj = schemaobj.SchemaObjects(migration_context) return schema_obj.foreign_key_constraint( self.constraint_name, - self.source_table, self.referent_table, - self.local_cols, self.remote_cols, - **self.kw) + self.source_table, + self.referent_table, + self.local_cols, + self.remote_cols, + **self.kw + ) @classmethod - @util._with_legacy_names([ - ('name', 'constraint_name'), - ('source', 'source_table'), - ('referent', 'referent_table'), - ]) - def create_foreign_key(cls, operations, constraint_name, - source_table, referent_table, local_cols, - remote_cols, onupdate=None, ondelete=None, - deferrable=None, initially=None, match=None, - source_schema=None, referent_schema=None, - **dialect_kw): + @util._with_legacy_names( + [ + ("name", "constraint_name"), + ("source", "source_table"), + ("referent", "referent_table"), + ] + ) + def create_foreign_key( + cls, + operations, + constraint_name, + source_table, + referent_table, + local_cols, + remote_cols, + onupdate=None, + ondelete=None, + deferrable=None, + initially=None, + match=None, + source_schema=None, + referent_schema=None, + **dialect_kw + ): """Issue a "create foreign key" instruction using the current migration context. @@ -558,29 +612,40 @@ class CreateForeignKeyOp(AddConstraintOp): op = cls( constraint_name, - source_table, referent_table, - local_cols, remote_cols, - onupdate=onupdate, ondelete=ondelete, + source_table, + referent_table, + local_cols, + remote_cols, + onupdate=onupdate, + ondelete=ondelete, deferrable=deferrable, source_schema=source_schema, referent_schema=referent_schema, - initially=initially, match=match, + initially=initially, + match=match, **dialect_kw ) return operations.invoke(op) @classmethod - @util._with_legacy_names([ - ('name', 'constraint_name'), - ('referent', 'referent_table') - ]) + @util._with_legacy_names( + [("name", "constraint_name"), ("referent", "referent_table")] + ) def batch_create_foreign_key( - cls, operations, constraint_name, referent_table, - local_cols, remote_cols, - referent_schema=None, - onupdate=None, ondelete=None, - deferrable=None, initially=None, match=None, - **dialect_kw): + cls, + operations, + constraint_name, + referent_table, + local_cols, + remote_cols, + referent_schema=None, + onupdate=None, + ondelete=None, + deferrable=None, + initially=None, + match=None, + **dialect_kw + ): """Issue a "create foreign key" instruction using the current batch migration context. @@ -607,13 +672,17 @@ class CreateForeignKeyOp(AddConstraintOp): """ op = cls( constraint_name, - operations.impl.table_name, referent_table, - local_cols, remote_cols, - onupdate=onupdate, ondelete=ondelete, + operations.impl.table_name, + referent_table, + local_cols, + remote_cols, + onupdate=onupdate, + ondelete=ondelete, deferrable=deferrable, source_schema=operations.impl.schema, referent_schema=referent_schema, - initially=initially, match=match, + initially=initially, + match=match, **dialect_kw ) return operations.invoke(op) @@ -621,7 +690,8 @@ class CreateForeignKeyOp(AddConstraintOp): @Operations.register_operation("create_check_constraint") @BatchOperations.register_operation( - "create_check_constraint", "batch_create_check_constraint") + "create_check_constraint", "batch_create_check_constraint" +) @AddConstraintOp.register_add_constraint("check_constraint") @AddConstraintOp.register_add_constraint("column_check_constraint") class CreateCheckConstraintOp(AddConstraintOp): @@ -630,8 +700,14 @@ class CreateCheckConstraintOp(AddConstraintOp): constraint_type = "check" def __init__( - self, constraint_name, table_name, - condition, schema=None, _orig_constraint=None, **kw): + self, + constraint_name, + table_name, + condition, + schema=None, + _orig_constraint=None, + **kw + ): self.constraint_name = constraint_name self.table_name = table_name self.condition = condition @@ -648,7 +724,7 @@ class CreateCheckConstraintOp(AddConstraintOp): constraint_table.name, constraint.sqltext, schema=constraint_table.schema, - _orig_constraint=constraint + _orig_constraint=constraint, ) def to_constraint(self, migration_context=None): @@ -656,18 +732,26 @@ class CreateCheckConstraintOp(AddConstraintOp): return self._orig_constraint schema_obj = schemaobj.SchemaObjects(migration_context) return schema_obj.check_constraint( - self.constraint_name, self.table_name, - self.condition, schema=self.schema, **self.kw) + self.constraint_name, + self.table_name, + self.condition, + schema=self.schema, + **self.kw + ) @classmethod - @util._with_legacy_names([ - ('name', 'constraint_name'), - ('source', 'table_name') - ]) + @util._with_legacy_names( + [("name", "constraint_name"), ("source", "table_name")] + ) def create_check_constraint( - cls, operations, - constraint_name, table_name, condition, - schema=None, **kw): + cls, + operations, + constraint_name, + table_name, + condition, + schema=None, + **kw + ): """Issue a "create check constraint" instruction using the current migration context. @@ -721,9 +805,10 @@ class CreateCheckConstraintOp(AddConstraintOp): return operations.invoke(op) @classmethod - @util._with_legacy_names([('name', 'constraint_name')]) + @util._with_legacy_names([("name", "constraint_name")]) def batch_create_check_constraint( - cls, operations, constraint_name, condition, **kw): + cls, operations, constraint_name, condition, **kw + ): """Issue a "create check constraint" instruction using the current batch migration context. @@ -741,8 +826,12 @@ class CreateCheckConstraintOp(AddConstraintOp): """ op = cls( - constraint_name, operations.impl.table_name, - condition, schema=operations.impl.schema, **kw) + constraint_name, + operations.impl.table_name, + condition, + schema=operations.impl.schema, + **kw + ) return operations.invoke(op) @@ -752,8 +841,15 @@ class CreateIndexOp(MigrateOperation): """Represent a create index operation.""" def __init__( - self, index_name, table_name, columns, schema=None, - unique=False, _orig_index=None, **kw): + self, + index_name, + table_name, + columns, + schema=None, + unique=False, + _orig_index=None, + **kw + ): self.index_name = index_name self.table_name = table_name self.columns = columns @@ -785,15 +881,26 @@ class CreateIndexOp(MigrateOperation): return self._orig_index schema_obj = schemaobj.SchemaObjects(migration_context) return schema_obj.index( - self.index_name, self.table_name, self.columns, schema=self.schema, - unique=self.unique, **self.kw) + self.index_name, + self.table_name, + self.columns, + schema=self.schema, + unique=self.unique, + **self.kw + ) @classmethod - @util._with_legacy_names([('name', 'index_name')]) + @util._with_legacy_names([("name", "index_name")]) def create_index( - cls, operations, - index_name, table_name, columns, schema=None, - unique=False, **kw): + cls, + operations, + index_name, + table_name, + columns, + schema=None, + unique=False, + **kw + ): r"""Issue a "create index" instruction using the current migration context. @@ -851,8 +958,7 @@ class CreateIndexOp(MigrateOperation): """ op = cls( - index_name, table_name, columns, schema=schema, - unique=unique, **kw + index_name, table_name, columns, schema=schema, unique=unique, **kw ) return operations.invoke(op) @@ -868,8 +974,11 @@ class CreateIndexOp(MigrateOperation): """ op = cls( - index_name, operations.impl.table_name, columns, - schema=operations.impl.schema, **kw + index_name, + operations.impl.table_name, + columns, + schema=operations.impl.schema, + **kw ) return operations.invoke(op) @@ -880,8 +989,8 @@ class DropIndexOp(MigrateOperation): """Represent a drop index operation.""" def __init__( - self, index_name, table_name=None, - schema=None, _orig_index=None, **kw): + self, index_name, table_name=None, schema=None, _orig_index=None, **kw + ): self.index_name = index_name self.table_name = table_name self.schema = schema @@ -894,8 +1003,8 @@ class DropIndexOp(MigrateOperation): def reverse(self): if self._orig_index is None: raise ValueError( - "operation is not reversible; " - "original index is not present") + "operation is not reversible; " "original index is not present" + ) return CreateIndexOp.from_index(self._orig_index) @classmethod @@ -917,16 +1026,20 @@ class DropIndexOp(MigrateOperation): # need a dummy column name here since SQLAlchemy # 0.7.6 and further raises on Index with no columns return schema_obj.index( - self.index_name, self.table_name, ['x'], - schema=self.schema, **self.kw) + self.index_name, + self.table_name, + ["x"], + schema=self.schema, + **self.kw + ) @classmethod - @util._with_legacy_names([ - ('name', 'index_name'), - ('tablename', 'table_name') - ]) - def drop_index(cls, operations, index_name, - table_name=None, schema=None, **kw): + @util._with_legacy_names( + [("name", "index_name"), ("tablename", "table_name")] + ) + def drop_index( + cls, operations, index_name, table_name=None, schema=None, **kw + ): r"""Issue a "drop index" instruction using the current migration context. @@ -964,7 +1077,7 @@ class DropIndexOp(MigrateOperation): return operations.invoke(op) @classmethod - @util._with_legacy_names([('name', 'index_name')]) + @util._with_legacy_names([("name", "index_name")]) def batch_drop_index(cls, operations, index_name, **kw): """Issue a "drop index" instruction using the current batch migration context. @@ -981,8 +1094,10 @@ class DropIndexOp(MigrateOperation): """ op = cls( - index_name, table_name=operations.impl.table_name, - schema=operations.impl.schema, **kw + index_name, + table_name=operations.impl.table_name, + schema=operations.impl.schema, + **kw ) return operations.invoke(op) @@ -992,7 +1107,8 @@ class CreateTableOp(MigrateOperation): """Represent a create table operation.""" def __init__( - self, table_name, columns, schema=None, _orig_table=None, **kw): + self, table_name, columns, schema=None, _orig_table=None, **kw + ): self.table_name = table_name self.columns = columns self.schema = schema @@ -1025,7 +1141,7 @@ class CreateTableOp(MigrateOperation): ) @classmethod - @util._with_legacy_names([('name', 'table_name')]) + @util._with_legacy_names([("name", "table_name")]) def create_table(cls, operations, table_name, *columns, **kw): r"""Issue a "create table" instruction using the current migration context. @@ -1125,7 +1241,8 @@ class DropTableOp(MigrateOperation): """Represent a drop table operation.""" def __init__( - self, table_name, schema=None, table_kw=None, _orig_table=None): + self, table_name, schema=None, table_kw=None, _orig_table=None + ): self.table_name = table_name self.schema = schema self.table_kw = table_kw or {} @@ -1137,8 +1254,8 @@ class DropTableOp(MigrateOperation): def reverse(self): if self._orig_table is None: raise ValueError( - "operation is not reversible; " - "original table is not present") + "operation is not reversible; " "original table is not present" + ) return CreateTableOp.from_table(self._orig_table) @classmethod @@ -1150,12 +1267,11 @@ class DropTableOp(MigrateOperation): return self._orig_table schema_obj = schemaobj.SchemaObjects(migration_context) return schema_obj.table( - self.table_name, - schema=self.schema, - **self.table_kw) + self.table_name, schema=self.schema, **self.table_kw + ) @classmethod - @util._with_legacy_names([('name', 'table_name')]) + @util._with_legacy_names([("name", "table_name")]) def drop_table(cls, operations, table_name, schema=None, **kw): r"""Issue a "drop table" instruction using the current migration context. @@ -1205,7 +1321,8 @@ class RenameTableOp(AlterTableOp): @classmethod def rename_table( - cls, operations, old_table_name, new_table_name, schema=None): + cls, operations, old_table_name, new_table_name, schema=None + ): """Emit an ALTER TABLE to rename a table. :param old_table_name: old name. @@ -1229,16 +1346,18 @@ class AlterColumnOp(AlterTableOp): """Represent an alter column operation.""" def __init__( - self, table_name, column_name, schema=None, - existing_type=None, - existing_server_default=False, - existing_nullable=None, - modify_nullable=None, - modify_server_default=False, - modify_name=None, - modify_type=None, - **kw - + self, + table_name, + column_name, + schema=None, + existing_type=None, + existing_server_default=False, + existing_nullable=None, + modify_nullable=None, + modify_server_default=False, + modify_name=None, + modify_type=None, + **kw ): super(AlterColumnOp, self).__init__(table_name, schema=schema) self.column_name = column_name @@ -1257,47 +1376,64 @@ class AlterColumnOp(AlterTableOp): if self.modify_type is not None: col_diff.append( - ("modify_type", schema, tname, cname, - { - "existing_nullable": self.existing_nullable, - "existing_server_default": self.existing_server_default, - }, - self.existing_type, - self.modify_type) + ( + "modify_type", + schema, + tname, + cname, + { + "existing_nullable": self.existing_nullable, + "existing_server_default": self.existing_server_default, + }, + self.existing_type, + self.modify_type, + ) ) if self.modify_nullable is not None: col_diff.append( - ("modify_nullable", schema, tname, cname, + ( + "modify_nullable", + schema, + tname, + cname, { "existing_type": self.existing_type, - "existing_server_default": self.existing_server_default + "existing_server_default": self.existing_server_default, }, self.existing_nullable, - self.modify_nullable) + self.modify_nullable, + ) ) if self.modify_server_default is not False: col_diff.append( - ("modify_default", schema, tname, cname, - { - "existing_nullable": self.existing_nullable, - "existing_type": self.existing_type - }, - self.existing_server_default, - self.modify_server_default) + ( + "modify_default", + schema, + tname, + cname, + { + "existing_nullable": self.existing_nullable, + "existing_type": self.existing_type, + }, + self.existing_server_default, + self.modify_server_default, + ) ) return col_diff def has_changes(self): - hc1 = self.modify_nullable is not None or \ - self.modify_server_default is not False or \ - self.modify_type is not None + hc1 = ( + self.modify_nullable is not None + or self.modify_server_default is not False + or self.modify_type is not None + ) if hc1: return True for kw in self.kw: - if kw.startswith('modify_'): + if kw.startswith("modify_"): return True else: return False @@ -1305,37 +1441,40 @@ class AlterColumnOp(AlterTableOp): def reverse(self): kw = self.kw.copy() - kw['existing_type'] = self.existing_type - kw['existing_nullable'] = self.existing_nullable - kw['existing_server_default'] = self.existing_server_default + kw["existing_type"] = self.existing_type + kw["existing_nullable"] = self.existing_nullable + kw["existing_server_default"] = self.existing_server_default if self.modify_type is not None: - kw['modify_type'] = self.modify_type + kw["modify_type"] = self.modify_type if self.modify_nullable is not None: - kw['modify_nullable'] = self.modify_nullable + kw["modify_nullable"] = self.modify_nullable if self.modify_server_default is not False: - kw['modify_server_default'] = self.modify_server_default + kw["modify_server_default"] = self.modify_server_default # TODO: make this a little simpler - all_keys = set(m.group(1) for m in [ - re.match(r'^(?:existing_|modify_)(.+)$', k) - for k in kw - ] if m) + all_keys = set( + m.group(1) + for m in [re.match(r"^(?:existing_|modify_)(.+)$", k) for k in kw] + if m + ) for k in all_keys: - if 'modify_%s' % k in kw: - swap = kw['existing_%s' % k] - kw['existing_%s' % k] = kw['modify_%s' % k] - kw['modify_%s' % k] = swap + if "modify_%s" % k in kw: + swap = kw["existing_%s" % k] + kw["existing_%s" % k] = kw["modify_%s" % k] + kw["modify_%s" % k] = swap return self.__class__( - self.table_name, self.column_name, schema=self.schema, - **kw + self.table_name, self.column_name, schema=self.schema, **kw ) @classmethod - @util._with_legacy_names([('name', 'new_column_name')]) + @util._with_legacy_names([("name", "new_column_name")]) def alter_column( - cls, operations, table_name, column_name, + cls, + operations, + table_name, + column_name, nullable=None, server_default=False, new_column_name=None, @@ -1343,7 +1482,8 @@ class AlterColumnOp(AlterTableOp): existing_type=None, existing_server_default=False, existing_nullable=None, - schema=None, **kw + schema=None, + **kw ): """Issue an "alter column" instruction using the current migration context. @@ -1430,7 +1570,9 @@ class AlterColumnOp(AlterTableOp): """ alt = cls( - table_name, column_name, schema=schema, + table_name, + column_name, + schema=schema, existing_type=existing_type, existing_server_default=existing_server_default, existing_nullable=existing_nullable, @@ -1445,7 +1587,9 @@ class AlterColumnOp(AlterTableOp): @classmethod def batch_alter_column( - cls, operations, column_name, + cls, + operations, + column_name, nullable=None, server_default=False, new_column_name=None, @@ -1464,7 +1608,8 @@ class AlterColumnOp(AlterTableOp): """ alt = cls( - operations.impl.table_name, column_name, + operations.impl.table_name, + column_name, schema=operations.impl.schema, existing_type=existing_type, existing_server_default=existing_server_default, @@ -1490,7 +1635,8 @@ class AddColumnOp(AlterTableOp): def reverse(self): return DropColumnOp.from_column_and_tablename( - self.schema, self.table_name, self.column) + self.schema, self.table_name, self.column + ) def to_diff_tuple(self): return ("add_column", self.schema, self.table_name, self.column) @@ -1575,8 +1721,7 @@ class AddColumnOp(AlterTableOp): """ op = cls( - operations.impl.table_name, column, - schema=operations.impl.schema + operations.impl.table_name, column, schema=operations.impl.schema ) return operations.invoke(op) @@ -1587,8 +1732,8 @@ class DropColumnOp(AlterTableOp): """Represent a drop column operation.""" def __init__( - self, table_name, column_name, schema=None, - _orig_column=None, **kw): + self, table_name, column_name, schema=None, _orig_column=None, **kw + ): super(DropColumnOp, self).__init__(table_name, schema=schema) self.column_name = column_name self.kw = kw @@ -1596,16 +1741,22 @@ class DropColumnOp(AlterTableOp): def to_diff_tuple(self): return ( - "remove_column", self.schema, self.table_name, self.to_column()) + "remove_column", + self.schema, + self.table_name, + self.to_column(), + ) def reverse(self): if self._orig_column is None: raise ValueError( "operation is not reversible; " - "original column is not present") + "original column is not present" + ) return AddColumnOp.from_column_and_tablename( - self.schema, self.table_name, self._orig_column) + self.schema, self.table_name, self._orig_column + ) @classmethod def from_column_and_tablename(cls, schema, tname, col): @@ -1619,7 +1770,8 @@ class DropColumnOp(AlterTableOp): @classmethod def drop_column( - cls, operations, table_name, column_name, schema=None, **kw): + cls, operations, table_name, column_name, schema=None, **kw + ): """Issue a "drop column" instruction using the current migration context. @@ -1677,8 +1829,11 @@ class DropColumnOp(AlterTableOp): """ op = cls( - operations.impl.table_name, column_name, - schema=operations.impl.schema, **kw) + operations.impl.table_name, + column_name, + schema=operations.impl.schema, + **kw + ) return operations.invoke(op) @@ -1877,6 +2032,7 @@ class ExecuteSQLOp(MigrateOperation): class OpContainer(MigrateOperation): """Represent a sequence of operations operation.""" + def __init__(self, ops=()): self.ops = ops @@ -1889,7 +2045,7 @@ class OpContainer(MigrateOperation): @classmethod def _ops_as_diffs(cls, migrations): for op in migrations.ops: - if hasattr(op, 'ops'): + if hasattr(op, "ops"): for sub_op in cls._ops_as_diffs(op): yield sub_op else: @@ -1907,10 +2063,8 @@ class ModifyTableOps(OpContainer): def reverse(self): return ModifyTableOps( self.table_name, - ops=list(reversed( - [op.reverse() for op in self.ops] - )), - schema=self.schema + ops=list(reversed([op.reverse() for op in self.ops])), + schema=self.schema, ) @@ -1929,9 +2083,9 @@ class UpgradeOps(OpContainer): self.upgrade_token = upgrade_token def reverse_into(self, downgrade_ops): - downgrade_ops.ops[:] = list(reversed( - [op.reverse() for op in self.ops] - )) + downgrade_ops.ops[:] = list( + reversed([op.reverse() for op in self.ops]) + ) return downgrade_ops def reverse(self): @@ -1954,9 +2108,7 @@ class DowngradeOps(OpContainer): def reverse(self): return UpgradeOps( - ops=list(reversed( - [op.reverse() for op in self.ops] - )) + ops=list(reversed([op.reverse() for op in self.ops])) ) @@ -1990,10 +2142,18 @@ class MigrationScript(MigrateOperation): """ def __init__( - self, rev_id, upgrade_ops, downgrade_ops, - message=None, - imports=set(), head=None, splice=None, - branch_label=None, version_path=None, depends_on=None): + self, + rev_id, + upgrade_ops, + downgrade_ops, + message=None, + imports=set(), + head=None, + splice=None, + branch_label=None, + version_path=None, + depends_on=None, + ): self.rev_id = rev_id self.message = message self.imports = imports @@ -2017,7 +2177,8 @@ class MigrationScript(MigrateOperation): raise ValueError( "This MigrationScript instance has a multiple-entry " "list for UpgradeOps; please use the " - "upgrade_ops_list attribute.") + "upgrade_ops_list attribute." + ) elif not self._upgrade_ops: return None else: @@ -2041,7 +2202,8 @@ class MigrationScript(MigrateOperation): raise ValueError( "This MigrationScript instance has a multiple-entry " "list for DowngradeOps; please use the " - "downgrade_ops_list attribute.") + "downgrade_ops_list attribute." + ) elif not self._downgrade_ops: return None else: @@ -2078,4 +2240,3 @@ class MigrationScript(MigrateOperation): """ return self._downgrade_ops - diff --git a/alembic/operations/schemaobj.py b/alembic/operations/schemaobj.py index 1014ace..548b6c5 100644 --- a/alembic/operations/schemaobj.py +++ b/alembic/operations/schemaobj.py @@ -5,69 +5,82 @@ from .. import util class SchemaObjects(object): - def __init__(self, migration_context=None): self.migration_context = migration_context def primary_key_constraint(self, name, table_name, cols, schema=None): m = self.metadata() columns = [sa_schema.Column(n, NULLTYPE) for n in cols] - t = sa_schema.Table( - table_name, m, - *columns, - schema=schema) - p = sa_schema.PrimaryKeyConstraint( - *[t.c[n] for n in cols], name=name) + t = sa_schema.Table(table_name, m, *columns, schema=schema) + p = sa_schema.PrimaryKeyConstraint(*[t.c[n] for n in cols], name=name) t.append_constraint(p) return p def foreign_key_constraint( - self, name, source, referent, - local_cols, remote_cols, - onupdate=None, ondelete=None, - deferrable=None, source_schema=None, - referent_schema=None, initially=None, - match=None, **dialect_kw): + self, + name, + source, + referent, + local_cols, + remote_cols, + onupdate=None, + ondelete=None, + deferrable=None, + source_schema=None, + referent_schema=None, + initially=None, + match=None, + **dialect_kw + ): m = self.metadata() if source == referent and source_schema == referent_schema: t1_cols = local_cols + remote_cols else: t1_cols = local_cols sa_schema.Table( - referent, m, + referent, + m, *[sa_schema.Column(n, NULLTYPE) for n in remote_cols], - schema=referent_schema) + schema=referent_schema + ) t1 = sa_schema.Table( - source, m, + source, + m, *[sa_schema.Column(n, NULLTYPE) for n in t1_cols], - schema=source_schema) - - tname = "%s.%s" % (referent_schema, referent) if referent_schema \ - else referent - - dialect_kw['match'] = match - - f = sa_schema.ForeignKeyConstraint(local_cols, - ["%s.%s" % (tname, n) - for n in remote_cols], - name=name, - onupdate=onupdate, - ondelete=ondelete, - deferrable=deferrable, - initially=initially, - **dialect_kw - ) + schema=source_schema + ) + + tname = ( + "%s.%s" % (referent_schema, referent) + if referent_schema + else referent + ) + + dialect_kw["match"] = match + + f = sa_schema.ForeignKeyConstraint( + local_cols, + ["%s.%s" % (tname, n) for n in remote_cols], + name=name, + onupdate=onupdate, + ondelete=ondelete, + deferrable=deferrable, + initially=initially, + **dialect_kw + ) t1.append_constraint(f) return f def unique_constraint(self, name, source, local_cols, schema=None, **kw): t = sa_schema.Table( - source, self.metadata(), + source, + self.metadata(), *[sa_schema.Column(n, NULLTYPE) for n in local_cols], - schema=schema) - kw['name'] = name + schema=schema + ) + kw["name"] = name uq = sa_schema.UniqueConstraint(*[t.c[n] for n in local_cols], **kw) # TODO: need event tests to ensure the event # is fired off here @@ -75,8 +88,12 @@ class SchemaObjects(object): return uq def check_constraint(self, name, source, condition, schema=None, **kw): - t = sa_schema.Table(source, self.metadata(), - sa_schema.Column('x', Integer), schema=schema) + t = sa_schema.Table( + source, + self.metadata(), + sa_schema.Column("x", Integer), + schema=schema, + ) ck = sa_schema.CheckConstraint(condition, name=name, **kw) t.append_constraint(ck) return ck @@ -84,18 +101,21 @@ class SchemaObjects(object): def generic_constraint(self, name, table_name, type_, schema=None, **kw): t = self.table(table_name, schema=schema) types = { - 'foreignkey': lambda name: sa_schema.ForeignKeyConstraint( - [], [], name=name), - 'primary': sa_schema.PrimaryKeyConstraint, - 'unique': sa_schema.UniqueConstraint, - 'check': lambda name: sa_schema.CheckConstraint("", name=name), - None: sa_schema.Constraint + "foreignkey": lambda name: sa_schema.ForeignKeyConstraint( + [], [], name=name + ), + "primary": sa_schema.PrimaryKeyConstraint, + "unique": sa_schema.UniqueConstraint, + "check": lambda name: sa_schema.CheckConstraint("", name=name), + None: sa_schema.Constraint, } try: const = types[type_] except KeyError: - raise TypeError("'type' can be one of %s" % - ", ".join(sorted(repr(x) for x in types))) + raise TypeError( + "'type' can be one of %s" + % ", ".join(sorted(repr(x) for x in types)) + ) else: const = const(name=name) t.append_constraint(const) @@ -103,11 +123,13 @@ class SchemaObjects(object): def metadata(self): kw = {} - if self.migration_context is not None and \ - 'target_metadata' in self.migration_context.opts: - mt = self.migration_context.opts['target_metadata'] - if hasattr(mt, 'naming_convention'): - kw['naming_convention'] = mt.naming_convention + if ( + self.migration_context is not None + and "target_metadata" in self.migration_context.opts + ): + mt = self.migration_context.opts["target_metadata"] + if hasattr(mt, "naming_convention"): + kw["naming_convention"] = mt.naming_convention return sa_schema.MetaData(**kw) def table(self, name, *columns, **kw): @@ -122,18 +144,18 @@ class SchemaObjects(object): def index(self, name, tablename, columns, schema=None, **kw): t = sa_schema.Table( - tablename or 'no_table', self.metadata(), - schema=schema + tablename or "no_table", self.metadata(), schema=schema ) idx = sa_schema.Index( name, *[util.sqla_compat._textual_index_column(t, n) for n in columns], - **kw) + **kw + ) return idx def _parse_table_key(self, table_key): - if '.' in table_key: - tokens = table_key.split('.') + if "." in table_key: + tokens = table_key.split(".") sname = ".".join(tokens[0:-1]) tname = tokens[-1] else: @@ -147,7 +169,7 @@ class SchemaObjects(object): """ if isinstance(fk._colspec, string_types): - table_key, cname = fk._colspec.rsplit('.', 1) + table_key, cname = fk._colspec.rsplit(".", 1) sname, tname = self._parse_table_key(table_key) if table_key not in metadata.tables: rel_t = sa_schema.Table(tname, metadata, schema=sname) diff --git a/alembic/operations/toimpl.py b/alembic/operations/toimpl.py index 1327367..1635a42 100644 --- a/alembic/operations/toimpl.py +++ b/alembic/operations/toimpl.py @@ -8,8 +8,7 @@ from sqlalchemy import schema as sa_schema def alter_column(operations, operation): compiler = operations.impl.dialect.statement_compiler( - operations.impl.dialect, - None + operations.impl.dialect, None ) existing_type = operation.existing_type @@ -24,24 +23,23 @@ def alter_column(operations, operation): nullable = operation.modify_nullable def _count_constraint(constraint): - return not isinstance( - constraint, - sa_schema.PrimaryKeyConstraint) and \ - (not constraint._create_rule or - constraint._create_rule(compiler)) + return not isinstance(constraint, sa_schema.PrimaryKeyConstraint) and ( + not constraint._create_rule or constraint._create_rule(compiler) + ) if existing_type and type_: t = operations.schema_obj.table( table_name, sa_schema.Column(column_name, existing_type), - schema=schema + schema=schema, ) for constraint in t.constraints: if _count_constraint(constraint): operations.impl.drop_constraint(constraint) operations.impl.alter_column( - table_name, column_name, + table_name, + column_name, nullable=nullable, server_default=server_default, name=new_column_name, @@ -57,7 +55,7 @@ def alter_column(operations, operation): t = operations.schema_obj.table( table_name, operations.schema_obj.column(column_name, type_), - schema=schema + schema=schema, ) for constraint in t.constraints: if _count_constraint(constraint): @@ -75,10 +73,7 @@ def drop_table(operations, operation): def drop_column(operations, operation): column = operation.to_column(operations.migration_context) operations.impl.drop_column( - operation.table_name, - column, - schema=operation.schema, - **operation.kw + operation.table_name, column, schema=operation.schema, **operation.kw ) @@ -105,9 +100,8 @@ def create_table(operations, operation): @Operations.implementation_for(ops.RenameTableOp) def rename_table(operations, operation): operations.impl.rename_table( - operation.table_name, - operation.new_table_name, - schema=operation.schema) + operation.table_name, operation.new_table_name, schema=operation.schema + ) @Operations.implementation_for(ops.AddColumnOp) @@ -117,11 +111,7 @@ def add_column(operations, operation): schema = operation.schema t = operations.schema_obj.table(table_name, column, schema=schema) - operations.impl.add_column( - table_name, - column, - schema=schema - ) + operations.impl.add_column(table_name, column, schema=schema) for constraint in t.constraints: if not isinstance(constraint, sa_schema.PrimaryKeyConstraint): operations.impl.add_constraint(constraint) @@ -151,12 +141,12 @@ def drop_constraint(operations, operation): @Operations.implementation_for(ops.BulkInsertOp) def bulk_insert(operations, operation): operations.impl.bulk_insert( - operation.table, operation.rows, multiinsert=operation.multiinsert) + operation.table, operation.rows, multiinsert=operation.multiinsert + ) @Operations.implementation_for(ops.ExecuteSQLOp) def execute_sql(operations, operation): operations.migration_context.impl.execute( - operation.sqltext, - execution_options=operation.execution_options + operation.sqltext, execution_options=operation.execution_options ) diff --git a/alembic/runtime/environment.py b/alembic/runtime/environment.py index ce9be63..32db3ae 100644 --- a/alembic/runtime/environment.py +++ b/alembic/runtime/environment.py @@ -120,7 +120,7 @@ class EnvironmentContext(util.ModuleClsProxy): has been configured. """ - return self.context_opts.get('as_sql', False) + return self.context_opts.get("as_sql", False) def is_transactional_ddl(self): """Return True if the context is configured to expect a @@ -182,17 +182,20 @@ class EnvironmentContext(util.ModuleClsProxy): """ if self._migration_context is not None: return self.script.as_revision_number( - self.get_context()._start_from_rev) - elif 'starting_rev' in self.context_opts: + self.get_context()._start_from_rev + ) + elif "starting_rev" in self.context_opts: return self.script.as_revision_number( - self.context_opts['starting_rev']) + self.context_opts["starting_rev"] + ) else: # this should raise only in the case that a command # is being run where the "starting rev" is never applicable; # this is to catch scripts which rely upon this in # non-sql mode or similar raise util.CommandError( - "No starting revision argument is available.") + "No starting revision argument is available." + ) def get_revision_argument(self): """Get the 'destination' revision argument. @@ -209,7 +212,8 @@ class EnvironmentContext(util.ModuleClsProxy): """ return self.script.as_revision_number( - self.context_opts['destination_rev']) + self.context_opts["destination_rev"] + ) def get_tag_argument(self): """Return the value passed for the ``--tag`` argument, if any. @@ -229,7 +233,7 @@ class EnvironmentContext(util.ModuleClsProxy): line. """ - return self.context_opts.get('tag', None) + return self.context_opts.get("tag", None) def get_x_argument(self, as_dictionary=False): """Return the value(s) passed for the ``-x`` argument, if any. @@ -277,39 +281,38 @@ class EnvironmentContext(util.ModuleClsProxy): else: value = [] if as_dictionary: - value = dict( - arg.split('=', 1) for arg in value - ) + value = dict(arg.split("=", 1) for arg in value) return value - def configure(self, - connection=None, - url=None, - dialect_name=None, - transactional_ddl=None, - transaction_per_migration=False, - output_buffer=None, - starting_rev=None, - tag=None, - template_args=None, - render_as_batch=False, - target_metadata=None, - include_symbol=None, - include_object=None, - include_schemas=False, - process_revision_directives=None, - compare_type=False, - compare_server_default=False, - render_item=None, - literal_binds=False, - upgrade_token="upgrades", - downgrade_token="downgrades", - alembic_module_prefix="op.", - sqlalchemy_module_prefix="sa.", - user_module_prefix=None, - on_version_apply=None, - **kw - ): + def configure( + self, + connection=None, + url=None, + dialect_name=None, + transactional_ddl=None, + transaction_per_migration=False, + output_buffer=None, + starting_rev=None, + tag=None, + template_args=None, + render_as_batch=False, + target_metadata=None, + include_symbol=None, + include_object=None, + include_schemas=False, + process_revision_directives=None, + compare_type=False, + compare_server_default=False, + render_item=None, + literal_binds=False, + upgrade_token="upgrades", + downgrade_token="downgrades", + alembic_module_prefix="op.", + sqlalchemy_module_prefix="sa.", + user_module_prefix=None, + on_version_apply=None, + **kw + ): """Configure a :class:`.MigrationContext` within this :class:`.EnvironmentContext` which will provide database connectivity and other configuration to a series of @@ -774,33 +777,33 @@ class EnvironmentContext(util.ModuleClsProxy): elif self.config.output_buffer is not None: opts["output_buffer"] = self.config.output_buffer if starting_rev: - opts['starting_rev'] = starting_rev + opts["starting_rev"] = starting_rev if tag: - opts['tag'] = tag - if template_args and 'template_args' in opts: - opts['template_args'].update(template_args) + opts["tag"] = tag + if template_args and "template_args" in opts: + opts["template_args"].update(template_args) opts["transaction_per_migration"] = transaction_per_migration - opts['target_metadata'] = target_metadata - opts['include_symbol'] = include_symbol - opts['include_object'] = include_object - opts['include_schemas'] = include_schemas - opts['render_as_batch'] = render_as_batch - opts['upgrade_token'] = upgrade_token - opts['downgrade_token'] = downgrade_token - opts['sqlalchemy_module_prefix'] = sqlalchemy_module_prefix - opts['alembic_module_prefix'] = alembic_module_prefix - opts['user_module_prefix'] = user_module_prefix - opts['literal_binds'] = literal_binds - opts['process_revision_directives'] = process_revision_directives - opts['on_version_apply'] = util.to_tuple(on_version_apply, default=()) + opts["target_metadata"] = target_metadata + opts["include_symbol"] = include_symbol + opts["include_object"] = include_object + opts["include_schemas"] = include_schemas + opts["render_as_batch"] = render_as_batch + opts["upgrade_token"] = upgrade_token + opts["downgrade_token"] = downgrade_token + opts["sqlalchemy_module_prefix"] = sqlalchemy_module_prefix + opts["alembic_module_prefix"] = alembic_module_prefix + opts["user_module_prefix"] = user_module_prefix + opts["literal_binds"] = literal_binds + opts["process_revision_directives"] = process_revision_directives + opts["on_version_apply"] = util.to_tuple(on_version_apply, default=()) if render_item is not None: - opts['render_item'] = render_item + opts["render_item"] = render_item if compare_type is not None: - opts['compare_type'] = compare_type + opts["compare_type"] = compare_type if compare_server_default is not None: - opts['compare_server_default'] = compare_server_default - opts['script'] = self.script + opts["compare_server_default"] = compare_server_default + opts["script"] = self.script opts.update(kw) @@ -809,7 +812,7 @@ class EnvironmentContext(util.ModuleClsProxy): url=url, dialect_name=dialect_name, environment_context=self, - opts=opts + opts=opts, ) def run_migrations(self, **kw): @@ -847,8 +850,7 @@ class EnvironmentContext(util.ModuleClsProxy): first been made available via :meth:`.configure`. """ - self.get_context().execute(sql, - execution_options=execution_options) + self.get_context().execute(sql, execution_options=execution_options) def static_output(self, text): """Emit text directly to the "offline" SQL stream. diff --git a/alembic/runtime/migration.py b/alembic/runtime/migration.py index 17cc226..80dc8ff 100644 --- a/alembic/runtime/migration.py +++ b/alembic/runtime/migration.py @@ -2,8 +2,14 @@ import logging import sys from contextlib import contextmanager -from sqlalchemy import MetaData, Table, Column, String, literal_column,\ - PrimaryKeyConstraint +from sqlalchemy import ( + MetaData, + Table, + Column, + String, + literal_column, + PrimaryKeyConstraint, +) from sqlalchemy.engine.strategies import MockEngineStrategy from sqlalchemy.engine import url as sqla_url from sqlalchemy.engine import Connection @@ -65,71 +71,82 @@ class MigrationContext(object): self.environment_context = environment_context self.opts = opts self.dialect = dialect - self.script = opts.get('script') - as_sql = opts.get('as_sql', False) + self.script = opts.get("script") + as_sql = opts.get("as_sql", False) transactional_ddl = opts.get("transactional_ddl") self._transaction_per_migration = opts.get( - "transaction_per_migration", False) - self.on_version_apply_callbacks = opts.get('on_version_apply', ()) + "transaction_per_migration", False + ) + self.on_version_apply_callbacks = opts.get("on_version_apply", ()) if as_sql: self.connection = self._stdout_connection(connection) assert self.connection is not None else: self.connection = connection - self._migrations_fn = opts.get('fn') + self._migrations_fn = opts.get("fn") self.as_sql = as_sql if "output_encoding" in opts: self.output_buffer = EncodedIO( opts.get("output_buffer") or sys.stdout, - opts['output_encoding'] + opts["output_encoding"], ) else: self.output_buffer = opts.get("output_buffer", sys.stdout) - self._user_compare_type = opts.get('compare_type', False) + self._user_compare_type = opts.get("compare_type", False) self._user_compare_server_default = opts.get( - 'compare_server_default', - False) + "compare_server_default", False + ) self.version_table = version_table = opts.get( - 'version_table', 'alembic_version') - self.version_table_schema = version_table_schema = \ - opts.get('version_table_schema', None) + "version_table", "alembic_version" + ) + self.version_table_schema = version_table_schema = opts.get( + "version_table_schema", None + ) self._version = Table( - version_table, MetaData(), - Column('version_num', String(32), nullable=False), - schema=version_table_schema) + version_table, + MetaData(), + Column("version_num", String(32), nullable=False), + schema=version_table_schema, + ) if opts.get("version_table_pk", True): self._version.append_constraint( PrimaryKeyConstraint( - 'version_num', name="%s_pkc" % version_table + "version_num", name="%s_pkc" % version_table ) ) self._start_from_rev = opts.get("starting_rev") self.impl = ddl.DefaultImpl.get_by_dialect(dialect)( - dialect, self.connection, self.as_sql, + dialect, + self.connection, + self.as_sql, transactional_ddl, self.output_buffer, - opts + opts, ) log.info("Context impl %s.", self.impl.__class__.__name__) if self.as_sql: log.info("Generating static SQL") - log.info("Will assume %s DDL.", - "transactional" if self.impl.transactional_ddl - else "non-transactional") + log.info( + "Will assume %s DDL.", + "transactional" + if self.impl.transactional_ddl + else "non-transactional", + ) @classmethod - def configure(cls, - connection=None, - url=None, - dialect_name=None, - dialect=None, - environment_context=None, - opts=None, - ): + def configure( + cls, + connection=None, + url=None, + dialect_name=None, + dialect=None, + environment_context=None, + opts=None, + ): """Create a new :class:`.MigrationContext`. This is a factory method usually called @@ -158,7 +175,8 @@ class MigrationContext(object): util.warn( "'connection' argument to configure() is expected " "to be a sqlalchemy.engine.Connection instance, " - "got %r" % connection) + "got %r" % connection + ) dialect = connection.dialect elif url: url = sqla_url.make_url(url) @@ -175,22 +193,28 @@ class MigrationContext(object): transaction_now = _per_migration == self._transaction_per_migration if not transaction_now: + @contextmanager def do_nothing(): yield + return do_nothing() elif not self.impl.transactional_ddl: + @contextmanager def do_nothing(): yield + return do_nothing() elif self.as_sql: + @contextmanager def begin_commit(): self.impl.emit_begin() yield self.impl.emit_commit() + return begin_commit() else: return self.bind.begin() @@ -217,7 +241,8 @@ class MigrationContext(object): elif len(heads) > 1: raise util.CommandError( "Version table '%s' has more than one head present; " - "please use get_current_heads()" % self.version_table) + "please use get_current_heads()" % self.version_table + ) else: return heads[0] @@ -243,18 +268,20 @@ class MigrationContext(object): """ if self.as_sql: start_from_rev = self._start_from_rev - if start_from_rev == 'base': + if start_from_rev == "base": start_from_rev = None elif start_from_rev is not None and self.script: - start_from_rev = \ - self.script.get_revision(start_from_rev).revision + start_from_rev = self.script.get_revision( + start_from_rev + ).revision return util.to_tuple(start_from_rev, default=()) else: if self._start_from_rev: raise util.CommandError( "Can't specify current_rev to context " - "when using a database connection") + "when using a database connection" + ) if not self._has_version_table(): return () return tuple( @@ -266,7 +293,8 @@ class MigrationContext(object): def _has_version_table(self): return self.connection.dialect.has_table( - self.connection, self.version_table, self.version_table_schema) + self.connection, self.version_table, self.version_table_schema + ) def stamp(self, script_directory, revision): """Stamp the version table with a specific revision. @@ -315,8 +343,9 @@ class MigrationContext(object): head_maintainer = HeadMaintainer(self, heads) - starting_in_transaction = not self.as_sql and \ - self._in_connection_transaction() + starting_in_transaction = ( + not self.as_sql and self._in_connection_transaction() + ) for step in self._migrations_fn(heads, self): with self.begin_transaction(_per_migration=True): @@ -326,7 +355,9 @@ class MigrationContext(object): self._version.create(self.connection) log.info("Running %s", step) if self.as_sql: - self.impl.static_output("-- Running %s" % (step.short_log,)) + self.impl.static_output( + "-- Running %s" % (step.short_log,) + ) step.migration_fn(**kw) # previously, we wouldn't stamp per migration @@ -336,19 +367,24 @@ class MigrationContext(object): # just to run the operations on every version head_maintainer.update_to_step(step) for callback in self.on_version_apply_callbacks: - callback(ctx=self, - step=step.info, - heads=set(head_maintainer.heads), - run_args=kw) - - if not starting_in_transaction and not self.as_sql and \ - not self.impl.transactional_ddl and \ - self._in_connection_transaction(): + callback( + ctx=self, + step=step.info, + heads=set(head_maintainer.heads), + run_args=kw, + ) + + if ( + not starting_in_transaction + and not self.as_sql + and not self.impl.transactional_ddl + and self._in_connection_transaction() + ): raise util.CommandError( - "Migration \"%s\" has left an uncommitted " + 'Migration "%s" has left an uncommitted ' "transaction opened; transactional_ddl is False so " - "Alembic is not committing transactions" - % step) + "Alembic is not committing transactions" % step + ) if self.as_sql and not head_maintainer.heads: self._version.drop(self.connection) @@ -421,19 +457,20 @@ class MigrationContext(object): inspector_column, metadata_column, inspector_column.type, - metadata_column.type + metadata_column.type, ) if user_value is not None: return user_value - return self.impl.compare_type( - inspector_column, - metadata_column) + return self.impl.compare_type(inspector_column, metadata_column) - def _compare_server_default(self, inspector_column, - metadata_column, - rendered_metadata_default, - rendered_column_default): + def _compare_server_default( + self, + inspector_column, + metadata_column, + rendered_metadata_default, + rendered_column_default, + ): if self._user_compare_server_default is False: return False @@ -445,7 +482,7 @@ class MigrationContext(object): metadata_column, rendered_column_default, metadata_column.server_default, - rendered_metadata_default + rendered_metadata_default, ) if user_value is not None: return user_value @@ -454,7 +491,8 @@ class MigrationContext(object): inspector_column, metadata_column, rendered_metadata_default, - rendered_column_default) + rendered_column_default, + ) class HeadMaintainer(object): @@ -467,8 +505,7 @@ class HeadMaintainer(object): self.heads.add(version) self.context.impl._exec( - self.context._version.insert(). - values( + self.context._version.insert().values( version_num=literal_column("'%s'" % version) ) ) @@ -478,15 +515,17 @@ class HeadMaintainer(object): ret = self.context.impl._exec( self.context._version.delete().where( - self.context._version.c.version_num == - literal_column("'%s'" % version))) + self.context._version.c.version_num + == literal_column("'%s'" % version) + ) + ) if not self.context.as_sql and ret.rowcount != 1: raise util.CommandError( "Online migration expected to match one " "row when deleting '%s' in '%s'; " "%d found" - % (version, - self.context.version_table, ret.rowcount)) + % (version, self.context.version_table, ret.rowcount) + ) def _update_version(self, from_, to_): assert to_ not in self.heads @@ -494,17 +533,20 @@ class HeadMaintainer(object): self.heads.add(to_) ret = self.context.impl._exec( - self.context._version.update(). - values(version_num=literal_column("'%s'" % to_)).where( + self.context._version.update() + .values(version_num=literal_column("'%s'" % to_)) + .where( self.context._version.c.version_num - == literal_column("'%s'" % from_)) + == literal_column("'%s'" % from_) + ) ) if not self.context.as_sql and ret.rowcount != 1: raise util.CommandError( "Online migration expected to match one " "row when updating '%s' to '%s' in '%s'; " "%d found" - % (from_, to_, self.context.version_table, ret.rowcount)) + % (from_, to_, self.context.version_table, ret.rowcount) + ) def update_to_step(self, step): if step.should_delete_branch(self.heads): @@ -517,20 +559,32 @@ class HeadMaintainer(object): self._insert_version(vers) elif step.should_merge_branches(self.heads): # delete revs, update from rev, update to rev - (delete_revs, update_from_rev, - update_to_rev) = step.merge_branch_idents(self.heads) + ( + delete_revs, + update_from_rev, + update_to_rev, + ) = step.merge_branch_idents(self.heads) log.debug( "merge, delete %s, update %s to %s", - delete_revs, update_from_rev, update_to_rev) + delete_revs, + update_from_rev, + update_to_rev, + ) for delrev in delete_revs: self._delete_version(delrev) self._update_version(update_from_rev, update_to_rev) elif step.should_unmerge_branches(self.heads): - (update_from_rev, update_to_rev, - insert_revs) = step.unmerge_branch_idents(self.heads) + ( + update_from_rev, + update_to_rev, + insert_revs, + ) = step.unmerge_branch_idents(self.heads) log.debug( "unmerge, insert %s, update %s to %s", - insert_revs, update_from_rev, update_to_rev) + insert_revs, + update_from_rev, + update_to_rev, + ) for insrev in insert_revs: self._insert_version(insrev) self._update_version(update_from_rev, update_to_rev) @@ -597,8 +651,9 @@ class MigrationInfo(object): revision_map = None """The revision map inside of which this operation occurs.""" - def __init__(self, revision_map, is_upgrade, is_stamp, up_revisions, - down_revisions): + def __init__( + self, revision_map, is_upgrade, is_stamp, up_revisions, down_revisions + ): self.revision_map = revision_map self.is_upgrade = is_upgrade self.is_stamp = is_stamp @@ -625,14 +680,16 @@ class MigrationInfo(object): @property def source_revision_ids(self): """Active revisions before this migration step is applied.""" - return self.down_revision_ids if self.is_upgrade \ - else self.up_revision_ids + return ( + self.down_revision_ids if self.is_upgrade else self.up_revision_ids + ) @property def destination_revision_ids(self): """Active revisions after this migration step is applied.""" - return self.up_revision_ids if self.is_upgrade \ - else self.down_revision_ids + return ( + self.up_revision_ids if self.is_upgrade else self.down_revision_ids + ) @property def up_revision(self): @@ -689,7 +746,7 @@ class MigrationStep(object): return "%s %s -> %s" % ( self.name, util.format_as_comma(self.from_revisions_no_deps), - util.format_as_comma(self.to_revisions_no_deps) + util.format_as_comma(self.to_revisions_no_deps), ) def __str__(self): @@ -698,7 +755,7 @@ class MigrationStep(object): self.name, util.format_as_comma(self.from_revisions_no_deps), util.format_as_comma(self.to_revisions_no_deps), - self.doc + self.doc, ) else: return self.short_log @@ -716,13 +773,16 @@ class RevisionStep(MigrationStep): def __repr__(self): return "RevisionStep(%r, is_upgrade=%r)" % ( - self.revision.revision, self.is_upgrade + self.revision.revision, + self.is_upgrade, ) def __eq__(self, other): - return isinstance(other, RevisionStep) and \ - other.revision == self.revision and \ - self.is_upgrade == other.is_upgrade + return ( + isinstance(other, RevisionStep) + and other.revision == self.revision + and self.is_upgrade == other.is_upgrade + ) @property def doc(self): @@ -733,26 +793,26 @@ class RevisionStep(MigrationStep): if self.is_upgrade: return self.revision._all_down_revisions else: - return (self.revision.revision, ) + return (self.revision.revision,) @property def from_revisions_no_deps(self): if self.is_upgrade: return self.revision._versioned_down_revisions else: - return (self.revision.revision, ) + return (self.revision.revision,) @property def to_revisions(self): if self.is_upgrade: - return (self.revision.revision, ) + return (self.revision.revision,) else: return self.revision._all_down_revisions @property def to_revisions_no_deps(self): if self.is_upgrade: - return (self.revision.revision, ) + return (self.revision.revision,) else: return self.revision._versioned_down_revisions @@ -788,31 +848,31 @@ class RevisionStep(MigrationStep): if other_heads: ancestors = set( - r.revision for r in - self.revision_map._get_ancestor_nodes( - self.revision_map.get_revisions(other_heads), - check=False + r.revision + for r in self.revision_map._get_ancestor_nodes( + self.revision_map.get_revisions(other_heads), check=False ) ) from_revisions = list( - set(self.from_revisions).difference(ancestors)) + set(self.from_revisions).difference(ancestors) + ) else: from_revisions = list(self.from_revisions) return ( # delete revs, update from rev, update to rev - list(from_revisions[0:-1]), from_revisions[-1], - self.to_revisions[0] + list(from_revisions[0:-1]), + from_revisions[-1], + self.to_revisions[0], ) def _unmerge_to_revisions(self, heads): other_heads = set(heads).difference([self.revision.revision]) if other_heads: ancestors = set( - r.revision for r in - self.revision_map._get_ancestor_nodes( - self.revision_map.get_revisions(other_heads), - check=False + r.revision + for r in self.revision_map._get_ancestor_nodes( + self.revision_map.get_revisions(other_heads), check=False ) ) return list(set(self.to_revisions).difference(ancestors)) @@ -824,8 +884,9 @@ class RevisionStep(MigrationStep): return ( # update from rev, update to rev, insert revs - self.from_revisions[0], to_revisions[-1], - to_revisions[0:-1] + self.from_revisions[0], + to_revisions[-1], + to_revisions[0:-1], ) def should_create_branch(self, heads): @@ -853,8 +914,7 @@ class RevisionStep(MigrationStep): downrevs = self.revision._all_down_revisions - if len(downrevs) > 1 and \ - len(heads.intersection(downrevs)) > 1: + if len(downrevs) > 1 and len(heads.intersection(downrevs)) > 1: return True return False @@ -873,8 +933,9 @@ class RevisionStep(MigrationStep): def update_version_num(self, heads): if not self._has_scalar_down_revision: downrev = heads.intersection(self.revision._all_down_revisions) - assert len(downrev) == 1, \ - "Can't do an UPDATE because downrevision is ambiguous" + assert ( + len(downrev) == 1 + ), "Can't do an UPDATE because downrevision is ambiguous" down_revision = list(downrev)[0] else: down_revision = self.revision._all_down_revisions[0] @@ -894,10 +955,13 @@ class RevisionStep(MigrationStep): @property def info(self): - return MigrationInfo(revision_map=self.revision_map, - up_revisions=self.revision.revision, - down_revisions=self.revision._all_down_revisions, - is_upgrade=self.is_upgrade, is_stamp=False) + return MigrationInfo( + revision_map=self.revision_map, + up_revisions=self.revision.revision, + down_revisions=self.revision._all_down_revisions, + is_upgrade=self.is_upgrade, + is_stamp=False, + ) class StampStep(MigrationStep): @@ -915,11 +979,13 @@ class StampStep(MigrationStep): return None def __eq__(self, other): - return isinstance(other, StampStep) and \ - other.from_revisions == self.revisions and \ - other.to_revisions == self.to_revisions and \ - other.branch_move == self.branch_move and \ - self.is_upgrade == other.is_upgrade + return ( + isinstance(other, StampStep) + and other.from_revisions == self.revisions + and other.to_revisions == self.to_revisions + and other.branch_move == self.branch_move + and self.is_upgrade == other.is_upgrade + ) @property def from_revisions(self): @@ -955,15 +1021,17 @@ class StampStep(MigrationStep): def merge_branch_idents(self, heads): return ( # delete revs, update from rev, update to rev - list(self.from_[0:-1]), self.from_[-1], - self.to_[0] + list(self.from_[0:-1]), + self.from_[-1], + self.to_[0], ) def unmerge_branch_idents(self, heads): return ( # update from rev, update to rev, insert revs - self.from_[0], self.to_[-1], - list(self.to_[0:-1]) + self.from_[0], + self.to_[-1], + list(self.to_[0:-1]), ) def should_delete_branch(self, heads): @@ -980,10 +1048,15 @@ class StampStep(MigrationStep): @property def info(self): - up, down = (self.to_, self.from_) if self.is_upgrade \ + up, down = ( + (self.to_, self.from_) + if self.is_upgrade else (self.from_, self.to_) - return MigrationInfo(revision_map=self.revision_map, - up_revisions=up, - down_revisions=down, - is_upgrade=self.is_upgrade, - is_stamp=True) + ) + return MigrationInfo( + revision_map=self.revision_map, + up_revisions=up, + down_revisions=down, + is_upgrade=self.is_upgrade, + is_stamp=True, + ) diff --git a/alembic/script/__init__.py b/alembic/script/__init__.py index cae294f..65562b4 100644 --- a/alembic/script/__init__.py +++ b/alembic/script/__init__.py @@ -1,3 +1,3 @@ from .base import ScriptDirectory, Script # noqa -__all__ = ['ScriptDirectory', 'Script'] +__all__ = ["ScriptDirectory", "Script"] diff --git a/alembic/script/base.py b/alembic/script/base.py index 12e1510..1c63e08 100644 --- a/alembic/script/base.py +++ b/alembic/script/base.py @@ -10,13 +10,13 @@ from ..runtime import migration from contextlib import contextmanager -_sourceless_rev_file = re.compile(r'(?!\.\#|__init__)(.*\.py)(c|o)?$') -_only_source_rev_file = re.compile(r'(?!\.\#|__init__)(.*\.py)$') -_legacy_rev = re.compile(r'([a-f0-9]+)\.py$') -_mod_def_re = re.compile(r'(upgrade|downgrade)_([a-z0-9]+)') -_slug_re = re.compile(r'\w+') +_sourceless_rev_file = re.compile(r"(?!\.\#|__init__)(.*\.py)(c|o)?$") +_only_source_rev_file = re.compile(r"(?!\.\#|__init__)(.*\.py)$") +_legacy_rev = re.compile(r"([a-f0-9]+)\.py$") +_mod_def_re = re.compile(r"(upgrade|downgrade)_([a-z0-9]+)") +_slug_re = re.compile(r"\w+") _default_file_template = "%(rev)s_%(slug)s" -_split_on_space_comma = re.compile(r',|(?: +)') +_split_on_space_comma = re.compile(r",|(?: +)") class ScriptDirectory(object): @@ -40,11 +40,16 @@ class ScriptDirectory(object): """ - def __init__(self, dir, file_template=_default_file_template, - truncate_slug_length=40, - version_locations=None, - sourceless=False, output_encoding="utf-8", - timezone=None): + def __init__( + self, + dir, + file_template=_default_file_template, + truncate_slug_length=40, + version_locations=None, + sourceless=False, + output_encoding="utf-8", + timezone=None, + ): self.dir = dir self.file_template = file_template self.version_locations = version_locations @@ -55,9 +60,11 @@ class ScriptDirectory(object): self.timezone = timezone if not os.access(dir, os.F_OK): - raise util.CommandError("Path doesn't exist: %r. Please use " - "the 'init' command to create a new " - "scripts folder." % dir) + raise util.CommandError( + "Path doesn't exist: %r. Please use " + "the 'init' command to create a new " + "scripts folder." % dir + ) @property def versions(self): @@ -75,13 +82,15 @@ class ScriptDirectory(object): for location in self.version_locations ] else: - return (os.path.abspath(os.path.join(self.dir, 'versions')),) + return (os.path.abspath(os.path.join(self.dir, "versions")),) def _load_revisions(self): if self.version_locations: paths = [ - vers for vers in self._version_locations - if os.path.exists(vers)] + vers + for vers in self._version_locations + if os.path.exists(vers) + ] else: paths = [self.versions] @@ -110,10 +119,11 @@ class ScriptDirectory(object): present. """ - script_location = config.get_main_option('script_location') + script_location = config.get_main_option("script_location") if script_location is None: - raise util.CommandError("No 'script_location' key " - "found in configuration.") + raise util.CommandError( + "No 'script_location' key " "found in configuration." + ) truncate_slug_length = config.get_main_option("truncate_slug_length") if truncate_slug_length is not None: truncate_slug_length = int(truncate_slug_length) @@ -125,20 +135,24 @@ class ScriptDirectory(object): return ScriptDirectory( util.coerce_resource_to_filename(script_location), file_template=config.get_main_option( - 'file_template', - _default_file_template), + "file_template", _default_file_template + ), truncate_slug_length=truncate_slug_length, sourceless=config.get_main_option("sourceless") == "true", output_encoding=config.get_main_option("output_encoding", "utf-8"), version_locations=version_locations, - timezone=config.get_main_option("timezone") + timezone=config.get_main_option("timezone"), ) @contextmanager def _catch_revision_errors( - self, - ancestor=None, multiple_heads=None, start=None, end=None, - resolution=None): + self, + ancestor=None, + multiple_heads=None, + start=None, + end=None, + resolution=None, + ): try: yield except revision.RangeNotAncestorError as rna: @@ -160,10 +174,11 @@ class ScriptDirectory(object): "argument '%(head_arg)s'; please " "specify a specific target revision, " "'<branchname>@%(head_arg)s' to " - "narrow to a specific head, or 'heads' for all heads") + "narrow to a specific head, or 'heads' for all heads" + ) multiple_heads = multiple_heads % { "head_arg": end or mh.argument, - "heads": util.format_as_comma(mh.heads) + "heads": util.format_as_comma(mh.heads), } compat.raise_from_cause(util.CommandError(multiple_heads)) except revision.ResolutionError as re: @@ -192,7 +207,8 @@ class ScriptDirectory(object): """ with self._catch_revision_errors(start=base, end=head): for rev in self.revision_map.iterate_revisions( - head, base, inclusive=True, assert_relative_length=False): + head, base, inclusive=True, assert_relative_length=False + ): yield rev def get_revisions(self, id_): @@ -210,7 +226,8 @@ class ScriptDirectory(object): top_revs = set(self.revision_map.get_revisions(id_)) top_revs.update( self.revision_map._get_ancestor_nodes( - list(top_revs), include_dependencies=True) + list(top_revs), include_dependencies=True + ) ) top_revs = self.revision_map._filter_into_branch_heads(top_revs) return top_revs @@ -275,11 +292,13 @@ class ScriptDirectory(object): :meth:`.ScriptDirectory.get_heads` """ - with self._catch_revision_errors(multiple_heads=( - 'The script directory has multiple heads (due to branching).' - 'Please use get_heads(), or merge the branches using ' - 'alembic merge.' - )): + with self._catch_revision_errors( + multiple_heads=( + "The script directory has multiple heads (due to branching)." + "Please use get_heads(), or merge the branches using " + "alembic merge." + ) + ): return self.revision_map.get_current_head() def get_heads(self): @@ -310,7 +329,8 @@ class ScriptDirectory(object): if len(bases) > 1: raise util.CommandError( "The script directory has multiple bases. " - "Please use get_bases().") + "Please use get_bases()." + ) elif bases: return bases[0] else: @@ -329,40 +349,50 @@ class ScriptDirectory(object): def _upgrade_revs(self, destination, current_rev): with self._catch_revision_errors( - ancestor="Destination %(end)s is not a valid upgrade " - "target from current head(s)", end=destination): + ancestor="Destination %(end)s is not a valid upgrade " + "target from current head(s)", + end=destination, + ): revs = self.revision_map.iterate_revisions( - destination, current_rev, implicit_base=True) + destination, current_rev, implicit_base=True + ) revs = list(revs) return [ migration.MigrationStep.upgrade_from_script( - self.revision_map, script) + self.revision_map, script + ) for script in reversed(list(revs)) ] def _downgrade_revs(self, destination, current_rev): with self._catch_revision_errors( - ancestor="Destination %(end)s is not a valid downgrade " - "target from current head(s)", end=destination): + ancestor="Destination %(end)s is not a valid downgrade " + "target from current head(s)", + end=destination, + ): revs = self.revision_map.iterate_revisions( - current_rev, destination, select_for_downgrade=True) + current_rev, destination, select_for_downgrade=True + ) return [ migration.MigrationStep.downgrade_from_script( - self.revision_map, script) + self.revision_map, script + ) for script in revs ] def _stamp_revs(self, revision, heads): with self._catch_revision_errors( - multiple_heads="Multiple heads are present; please specify a " - "single target revision"): + multiple_heads="Multiple heads are present; please specify a " + "single target revision" + ): heads = self.get_revisions(heads) # filter for lineage will resolve things like # branchname@base, version@base, etc. filtered_heads = self.revision_map.filter_for_lineage( - heads, revision, include_dependencies=True) + heads, revision, include_dependencies=True + ) steps = [] @@ -371,11 +401,18 @@ class ScriptDirectory(object): if dest is None: # dest is 'base'. Return a "delete branch" migration # for all applicable heads. - steps.extend([ - migration.StampStep(head.revision, None, False, True, - self.revision_map) - for head in filtered_heads - ]) + steps.extend( + [ + migration.StampStep( + head.revision, + None, + False, + True, + self.revision_map, + ) + for head in filtered_heads + ] + ) continue elif dest in filtered_heads: # the dest is already in the version table, do nothing. @@ -384,7 +421,8 @@ class ScriptDirectory(object): # figure out if the dest is a descendant or an # ancestor of the selected nodes descendants = set( - self.revision_map._get_descendant_nodes([dest])) + self.revision_map._get_descendant_nodes([dest]) + ) ancestors = set(self.revision_map._get_ancestor_nodes([dest])) if descendants.intersection(filtered_heads): @@ -393,8 +431,12 @@ class ScriptDirectory(object): assert not ancestors.intersection(filtered_heads) todo_heads = [head.revision for head in filtered_heads] step = migration.StampStep( - todo_heads, dest.revision, False, False, - self.revision_map) + todo_heads, + dest.revision, + False, + False, + self.revision_map, + ) steps.append(step) continue elif ancestors.intersection(filtered_heads): @@ -402,15 +444,20 @@ class ScriptDirectory(object): # we can treat them as a "merge", single step. todo_heads = [head.revision for head in filtered_heads] step = migration.StampStep( - todo_heads, dest.revision, True, False, - self.revision_map) + todo_heads, + dest.revision, + True, + False, + self.revision_map, + ) steps.append(step) continue else: # destination is in a branch not represented, # treat it as new branch - step = migration.StampStep((), dest.revision, True, True, - self.revision_map) + step = migration.StampStep( + (), dest.revision, True, True, self.revision_map + ) steps.append(step) continue return steps @@ -424,32 +471,31 @@ class ScriptDirectory(object): """ - util.load_python_file(self.dir, 'env.py') + util.load_python_file(self.dir, "env.py") @property def env_py_location(self): return os.path.abspath(os.path.join(self.dir, "env.py")) def _generate_template(self, src, dest, **kw): - util.status("Generating %s" % os.path.abspath(dest), - util.template_to_file, - src, - dest, - self.output_encoding, - **kw - ) + util.status( + "Generating %s" % os.path.abspath(dest), + util.template_to_file, + src, + dest, + self.output_encoding, + **kw + ) def _copy_file(self, src, dest): - util.status("Generating %s" % os.path.abspath(dest), - shutil.copy, - src, dest) + util.status( + "Generating %s" % os.path.abspath(dest), shutil.copy, src, dest + ) def _ensure_directory(self, path): path = os.path.abspath(path) if not os.path.exists(path): - util.status( - "Creating directory %s" % path, - os.makedirs, path) + util.status("Creating directory %s" % path, os.makedirs, path) def _generate_create_date(self): if self.timezone is not None: @@ -460,17 +506,29 @@ class ScriptDirectory(object): tzinfo = tz.gettz(self.timezone.upper()) if tzinfo is None: raise util.CommandError( - "Can't locate timezone: %s" % self.timezone) - create_date = datetime.datetime.utcnow().replace( - tzinfo=tz.tzutc()).astimezone(tzinfo) + "Can't locate timezone: %s" % self.timezone + ) + create_date = ( + datetime.datetime.utcnow() + .replace(tzinfo=tz.tzutc()) + .astimezone(tzinfo) + ) else: create_date = datetime.datetime.now() return create_date def generate_revision( - self, revid, message, head=None, - refresh=False, splice=False, branch_labels=None, - version_path=None, depends_on=None, **kw): + self, + revid, + message, + head=None, + refresh=False, + splice=False, + branch_labels=None, + version_path=None, + depends_on=None, + **kw + ): """Generate a new revision file. This runs the ``script.py.mako`` template, given @@ -500,11 +558,13 @@ class ScriptDirectory(object): except revision.RevisionError as err: compat.raise_from_cause(util.CommandError(err.args[0])) - with self._catch_revision_errors(multiple_heads=( - "Multiple heads are present; please specify the head " - "revision on which the new revision should be based, " - "or perform a merge." - )): + with self._catch_revision_errors( + multiple_heads=( + "Multiple heads are present; please specify the head " + "revision on which the new revision should be based, " + "or perform a merge." + ) + ): heads = self.revision_map.get_revisions(head) if len(set(heads)) != len(heads): @@ -521,7 +581,8 @@ class ScriptDirectory(object): else: raise util.CommandError( "Multiple version locations present, " - "please specify --version-path") + "please specify --version-path" + ) else: version_path = self.versions @@ -532,7 +593,8 @@ class ScriptDirectory(object): else: raise util.CommandError( "Path %s is not represented in current " - "version locations" % version_path) + "version locations" % version_path + ) if self.version_locations: self._ensure_directory(version_path) @@ -545,7 +607,8 @@ class ScriptDirectory(object): raise util.CommandError( "Revision %s is not a head revision; please specify " "--splice to create a new branch from this revision" - % head.revision) + % head.revision + ) if depends_on: with self._catch_revision_errors(): @@ -557,7 +620,6 @@ class ScriptDirectory(object): (self.revision_map.get_revision(dep), dep) for dep in util.to_list(depends_on) ] - ] self._generate_template( @@ -565,7 +627,8 @@ class ScriptDirectory(object): path, up_revision=str(revid), down_revision=revision.tuple_rev_as_scalar( - tuple(h.revision if h is not None else None for h in heads)), + tuple(h.revision if h is not None else None for h in heads) + ), branch_labels=util.to_tuple(branch_labels), depends_on=revision.tuple_rev_as_scalar(depends_on), create_date=create_date, @@ -582,9 +645,9 @@ class ScriptDirectory(object): "Version %s specified branch_labels %s, however the " "migration file %s does not have them; have you upgraded " "your script.py.mako to include the " - "'branch_labels' section?" % ( - script.revision, branch_labels, script.path - )) + "'branch_labels' section?" + % (script.revision, branch_labels, script.path) + ) self.revision_map.add_revision(script) return script @@ -592,17 +655,18 @@ class ScriptDirectory(object): def _rev_path(self, path, rev_id, message, create_date): slug = "_".join(_slug_re.findall(message or "")).lower() if len(slug) > self.truncate_slug_length: - slug = slug[:self.truncate_slug_length].rsplit('_', 1)[0] + '_' + slug = slug[: self.truncate_slug_length].rsplit("_", 1)[0] + "_" filename = "%s.py" % ( - self.file_template % { - 'rev': rev_id, - 'slug': slug, - 'year': create_date.year, - 'month': create_date.month, - 'day': create_date.day, - 'hour': create_date.hour, - 'minute': create_date.minute, - 'second': create_date.second + self.file_template + % { + "rev": rev_id, + "slug": slug, + "year": create_date.year, + "month": create_date.month, + "day": create_date.day, + "hour": create_date.hour, + "minute": create_date.minute, + "second": create_date.second, } ) return os.path.join(path, filename) @@ -624,9 +688,11 @@ class Script(revision.Revision): rev_id, module.down_revision, branch_labels=util.to_tuple( - getattr(module, 'branch_labels', None), default=()), + getattr(module, "branch_labels", None), default=() + ), dependencies=util.to_tuple( - getattr(module, 'depends_on', None), default=()) + getattr(module, "depends_on", None), default=() + ), ) module = None @@ -664,32 +730,32 @@ class Script(revision.Revision): " (head)" if self.is_head else "", " (branchpoint)" if self.is_branch_point else "", " (mergepoint)" if self.is_merge_point else "", - " (current)" if self._db_current_indicator else "" + " (current)" if self._db_current_indicator else "", ) if self.is_merge_point: - entry += "Merges: %s\n" % (self._format_down_revision(), ) + entry += "Merges: %s\n" % (self._format_down_revision(),) else: - entry += "Parent: %s\n" % (self._format_down_revision(), ) + entry += "Parent: %s\n" % (self._format_down_revision(),) if self.dependencies: entry += "Also depends on: %s\n" % ( - util.format_as_comma(self.dependencies)) + util.format_as_comma(self.dependencies) + ) if self.is_branch_point: entry += "Branches into: %s\n" % ( - util.format_as_comma(self.nextrev)) + util.format_as_comma(self.nextrev) + ) if self.branch_labels: entry += "Branch names: %s\n" % ( - util.format_as_comma(self.branch_labels), ) + util.format_as_comma(self.branch_labels), + ) entry += "Path: %s\n" % (self.path,) entry += "\n%s\n" % ( - "\n".join( - " %s" % para - for para in self.longdoc.splitlines() - ) + "\n".join(" %s" % para for para in self.longdoc.splitlines()) ) return entry @@ -700,36 +766,41 @@ class Script(revision.Revision): " (head)" if self.is_head else "", " (branchpoint)" if self.is_branch_point else "", " (mergepoint)" if self.is_merge_point else "", - self.doc) + self.doc, + ) def _head_only( - self, include_branches=False, include_doc=False, - include_parents=False, tree_indicators=True, - head_indicators=True): + self, + include_branches=False, + include_doc=False, + include_parents=False, + tree_indicators=True, + head_indicators=True, + ): text = self.revision if include_parents: if self.dependencies: text = "%s (%s) -> %s" % ( self._format_down_revision(), util.format_as_comma(self.dependencies), - text + text, ) else: - text = "%s -> %s" % ( - self._format_down_revision(), text) + text = "%s -> %s" % (self._format_down_revision(), text) if include_branches and self.branch_labels: text += " (%s)" % util.format_as_comma(self.branch_labels) if head_indicators or tree_indicators: text += "%s%s%s" % ( " (head)" if self._is_real_head else "", - " (effective head)" if self.is_head and - not self._is_real_head else "", - " (current)" if self._db_current_indicator else "" + " (effective head)" + if self.is_head and not self._is_real_head + else "", + " (current)" if self._db_current_indicator else "", ) if tree_indicators: text += "%s%s" % ( " (branchpoint)" if self.is_branch_point else "", - " (mergepoint)" if self.is_merge_point else "" + " (mergepoint)" if self.is_merge_point else "", ) if include_doc: text += ", %s" % self.doc @@ -737,15 +808,18 @@ class Script(revision.Revision): def cmd_format( self, - verbose, - include_branches=False, include_doc=False, - include_parents=False, tree_indicators=True): + verbose, + include_branches=False, + include_doc=False, + include_parents=False, + tree_indicators=True, + ): if verbose: return self.log_entry else: return self._head_only( - include_branches, include_doc, - include_parents, tree_indicators) + include_branches, include_doc, include_parents, tree_indicators + ) def _format_down_revision(self): if not self.down_revision: @@ -768,13 +842,13 @@ class Script(revision.Revision): names = set(fname.split(".")[0] for fname in paths) # look for __pycache__ - if os.path.exists(os.path.join(path, '__pycache__')): + if os.path.exists(os.path.join(path, "__pycache__")): # add all files from __pycache__ whose filename is not # already in the names we got from the version directory. # add as relative paths including __pycache__ token paths.extend( - os.path.join('__pycache__', pyc) - for pyc in os.listdir(os.path.join(path, '__pycache__')) + os.path.join("__pycache__", pyc) + for pyc in os.listdir(os.path.join(path, "__pycache__")) if pyc.split(".")[0] not in names ) return paths @@ -794,8 +868,8 @@ class Script(revision.Revision): py_filename = py_match.group(1) if scriptdir.sourceless: - is_c = py_match.group(2) == 'c' - is_o = py_match.group(2) == 'o' + is_c = py_match.group(2) == "c" + is_o = py_match.group(2) == "o" else: is_c = is_o = False @@ -821,7 +895,8 @@ class Script(revision.Revision): "Be sure the 'revision' variable is " "declared inside the script (please see 'Upgrading " "from Alembic 0.1 to 0.2' in the documentation)." - % filename) + % filename + ) else: revision = m.group(1) else: diff --git a/alembic/script/revision.py b/alembic/script/revision.py index 3d9a332..832cce1 100644 --- a/alembic/script/revision.py +++ b/alembic/script/revision.py @@ -5,8 +5,8 @@ from .. import util from sqlalchemy import util as sqlautil from ..util import compat -_relative_destination = re.compile(r'(?:(.+?)@)?(\w+)?((?:\+|-)\d+)') -_revision_illegal_chars = ['@', '-', '+'] +_relative_destination = re.compile(r"(?:(.+?)@)?(\w+)?((?:\+|-)\d+)") +_revision_illegal_chars = ["@", "-", "+"] class RevisionError(Exception): @@ -18,8 +18,8 @@ class RangeNotAncestorError(RevisionError): self.lower = lower self.upper = upper super(RangeNotAncestorError, self).__init__( - "Revision %s is not an ancestor of revision %s" % - (lower or "base", upper or "base") + "Revision %s is not an ancestor of revision %s" + % (lower or "base", upper or "base") ) @@ -122,8 +122,9 @@ class RevisionMap(object): for revision in self._generator(): if revision.revision in map_: - util.warn("Revision %s is present more than once" % - revision.revision) + util.warn( + "Revision %s is present more than once" % revision.revision + ) map_[revision.revision] = revision if revision.branch_labels: has_branch_labels.add(revision) @@ -132,9 +133,9 @@ class RevisionMap(object): heads.add(revision.revision) _real_heads.add(revision.revision) if revision.is_base: - self.bases += (revision.revision, ) + self.bases += (revision.revision,) if revision._is_real_base: - self._real_bases += (revision.revision, ) + self._real_bases += (revision.revision,) # add the branch_labels to the map_. We'll need these # to resolve the dependencies. @@ -147,8 +148,10 @@ class RevisionMap(object): for rev in map_.values(): for downrev in rev._all_down_revisions: if downrev not in map_: - util.warn("Revision %s referenced from %s is not present" - % (downrev, rev)) + util.warn( + "Revision %s referenced from %s is not present" + % (downrev, rev) + ) down_revision = map_[downrev] down_revision.add_nextrev(rev) if downrev in rev._versioned_down_revisions: @@ -169,9 +172,12 @@ class RevisionMap(object): if branch_label in map_: raise RevisionError( "Branch name '%s' in revision %s already " - "used by revision %s" % - (branch_label, revision.revision, - map_[branch_label].revision) + "used by revision %s" + % ( + branch_label, + revision.revision, + map_[branch_label].revision, + ) ) map_[branch_label] = revision @@ -182,13 +188,16 @@ class RevisionMap(object): if revision.branch_labels: revision.branch_labels.update(revision.branch_labels) for node in self._get_descendant_nodes( - [revision], map_, include_dependencies=False): + [revision], map_, include_dependencies=False + ): node.branch_labels.update(revision.branch_labels) parent = node - while parent and \ - not parent._is_real_branch_point and \ - not parent.is_merge_point: + while ( + parent + and not parent._is_real_branch_point + and not parent.is_merge_point + ): parent.branch_labels.update(revision.branch_labels) if parent.down_revision: @@ -201,7 +210,6 @@ class RevisionMap(object): deps = [map_[dep] for dep in util.to_tuple(revision.dependencies)] revision._resolved_dependencies = tuple([d.revision for d in deps]) - def add_revision(self, revision, _replace=False): """add a single revision to an existing map. @@ -211,8 +219,9 @@ class RevisionMap(object): """ map_ = self._revision_map if not _replace and revision.revision in map_: - util.warn("Revision %s is present more than once" % - revision.revision) + util.warn( + "Revision %s is present more than once" % revision.revision + ) elif _replace and revision.revision not in map_: raise Exception("revision %s not in map" % revision.revision) @@ -221,9 +230,9 @@ class RevisionMap(object): self._add_depends_on(revision, map_) if revision.is_base: - self.bases += (revision.revision, ) + self.bases += (revision.revision,) if revision._is_real_base: - self._real_bases += (revision.revision, ) + self._real_bases += (revision.revision,) for downrev in revision._all_down_revisions: if downrev not in map_: util.warn( @@ -233,15 +242,21 @@ class RevisionMap(object): map_[downrev].add_nextrev(revision) if revision._is_real_head: self._real_heads = tuple( - head for head in self._real_heads - if head not in - set(revision._all_down_revisions).union([revision.revision]) + head + for head in self._real_heads + if head + not in set(revision._all_down_revisions).union( + [revision.revision] + ) ) + (revision.revision,) if revision.is_head: self.heads = tuple( - head for head in self.heads - if head not in - set(revision._versioned_down_revisions).union([revision.revision]) + head + for head in self.heads + if head + not in set(revision._versioned_down_revisions).union( + [revision.revision] + ) ) + (revision.revision,) def get_current_head(self, branch_label=None): @@ -264,11 +279,14 @@ class RevisionMap(object): """ current_heads = self.heads if branch_label: - current_heads = self.filter_for_lineage(current_heads, branch_label) + current_heads = self.filter_for_lineage( + current_heads, branch_label + ) if len(current_heads) > 1: raise MultipleHeads( current_heads, - "%s@head" % branch_label if branch_label else "head") + "%s@head" % branch_label if branch_label else "head", + ) if current_heads: return current_heads[0] @@ -301,7 +319,8 @@ class RevisionMap(object): resolved_id, branch_label = self._resolve_revision_number(id_) return tuple( self._revision_for_ident(rev_id, branch_label) - for rev_id in resolved_id) + for rev_id in resolved_id + ) def get_revision(self, id_): """Return the :class:`.Revision` instance with the given rev id. @@ -333,7 +352,8 @@ class RevisionMap(object): nonbranch_rev = self._revision_for_ident(branch_label) except ResolutionError: raise ResolutionError( - "No such branch: '%s'" % branch_label, branch_label) + "No such branch: '%s'" % branch_label, branch_label + ) else: return nonbranch_rev else: @@ -352,30 +372,37 @@ class RevisionMap(object): revision = False if revision is False: # do a partial lookup - revs = [x for x in self._revision_map - if x and x.startswith(resolved_id)] + revs = [ + x + for x in self._revision_map + if x and x.startswith(resolved_id) + ] if branch_rev: revs = self.filter_for_lineage(revs, check_branch) if not revs: raise ResolutionError( "No such revision or branch '%s'" % resolved_id, - resolved_id) + resolved_id, + ) elif len(revs) > 1: raise ResolutionError( "Multiple revisions start " - "with '%s': %s..." % ( - resolved_id, - ", ".join("'%s'" % r for r in revs[0:3]) - ), resolved_id) + "with '%s': %s..." + % (resolved_id, ", ".join("'%s'" % r for r in revs[0:3])), + resolved_id, + ) else: revision = self._revision_map[revs[0]] if check_branch and revision is not None: if not self._shares_lineage( - revision.revision, branch_rev.revision): + revision.revision, branch_rev.revision + ): raise ResolutionError( - "Revision %s is not a member of branch '%s'" % - (revision.revision, check_branch), resolved_id) + "Revision %s is not a member of branch '%s'" + % (revision.revision, check_branch), + resolved_id, + ) return revision def _filter_into_branch_heads(self, targets): @@ -383,14 +410,14 @@ class RevisionMap(object): for rev in list(targets): if targets.intersection( - self._get_descendant_nodes( - [rev], include_dependencies=False)).\ - difference([rev]): + self._get_descendant_nodes([rev], include_dependencies=False) + ).difference([rev]): targets.discard(rev) return targets def filter_for_lineage( - self, targets, check_against, include_dependencies=False): + self, targets, check_against, include_dependencies=False + ): id_, branch_label = self._resolve_revision_number(check_against) shares = [] @@ -400,12 +427,16 @@ class RevisionMap(object): shares.extend(id_) return [ - tg for tg in targets + tg + for tg in targets if self._shares_lineage( - tg, shares, include_dependencies=include_dependencies)] + tg, shares, include_dependencies=include_dependencies + ) + ] def _shares_lineage( - self, target, test_against_revs, include_dependencies=False): + self, target, test_against_revs, include_dependencies=False + ): if not test_against_revs: return True if not isinstance(target, Revision): @@ -415,46 +446,61 @@ class RevisionMap(object): self._revision_for_ident(test_against_rev) if not isinstance(test_against_rev, Revision) else test_against_rev - for test_against_rev - in util.to_tuple(test_against_revs, default=()) + for test_against_rev in util.to_tuple( + test_against_revs, default=() + ) ] return bool( - set(self._get_descendant_nodes([target], - include_dependencies=include_dependencies)) - .union(self._get_ancestor_nodes([target], - include_dependencies=include_dependencies)) + set( + self._get_descendant_nodes( + [target], include_dependencies=include_dependencies + ) + ) + .union( + self._get_ancestor_nodes( + [target], include_dependencies=include_dependencies + ) + ) .intersection(test_against_revs) ) def _resolve_revision_number(self, id_): if isinstance(id_, compat.string_types) and "@" in id_: - branch_label, id_ = id_.split('@', 1) + branch_label, id_ = id_.split("@", 1) else: branch_label = None # ensure map is loaded self._revision_map - if id_ == 'heads': + if id_ == "heads": if branch_label: - return self.filter_for_lineage( - self.heads, branch_label), branch_label + return ( + self.filter_for_lineage(self.heads, branch_label), + branch_label, + ) else: return self._real_heads, branch_label - elif id_ == 'head': + elif id_ == "head": current_head = self.get_current_head(branch_label) if current_head: - return (current_head, ), branch_label + return (current_head,), branch_label else: return (), branch_label - elif id_ == 'base' or id_ is None: + elif id_ == "base" or id_ is None: return (), branch_label else: return util.to_tuple(id_, default=None), branch_label def _relative_iterate( - self, destination, source, is_upwards, - implicit_base, inclusive, assert_relative_length): + self, + destination, + source, + is_upwards, + implicit_base, + inclusive, + assert_relative_length, + ): if isinstance(destination, compat.string_types): match = _relative_destination.match(destination) if not match: @@ -490,13 +536,15 @@ class RevisionMap(object): revs = list( self._iterate_revisions( - from_, to_, - inclusive=inclusive, implicit_base=implicit_base)) + from_, to_, inclusive=inclusive, implicit_base=implicit_base + ) + ) if symbol: if branch_label: symbol_rev = self.get_revision( - "%s@%s" % (branch_label, symbol)) + "%s@%s" % (branch_label, symbol) + ) else: symbol_rev = self.get_revision(symbol) if symbol.startswith("head"): @@ -513,25 +561,39 @@ class RevisionMap(object): else: index = 0 if is_upwards: - revs = revs[index - relative - reldelta:] - if not index and assert_relative_length and \ - len(revs) < abs(relative - reldelta): + revs = revs[index - relative - reldelta :] + if ( + not index + and assert_relative_length + and len(revs) < abs(relative - reldelta) + ): raise RevisionError( "Relative revision %s didn't " - "produce %d migrations" % (destination, abs(relative))) + "produce %d migrations" % (destination, abs(relative)) + ) else: - revs = revs[0:index - relative + reldelta] - if not index and assert_relative_length and \ - len(revs) != abs(relative) + reldelta: + revs = revs[0 : index - relative + reldelta] + if ( + not index + and assert_relative_length + and len(revs) != abs(relative) + reldelta + ): raise RevisionError( "Relative revision %s didn't " - "produce %d migrations" % (destination, abs(relative))) + "produce %d migrations" % (destination, abs(relative)) + ) return iter(revs) def iterate_revisions( - self, upper, lower, implicit_base=False, inclusive=False, - assert_relative_length=True, select_for_downgrade=False): + self, + upper, + lower, + implicit_base=False, + inclusive=False, + assert_relative_length=True, + select_for_downgrade=False, + ): """Iterate through script revisions, starting at the given upper revision identifier and ending at the lower. @@ -545,37 +607,59 @@ class RevisionMap(object): """ relative_upper = self._relative_iterate( - upper, lower, True, implicit_base, - inclusive, assert_relative_length + upper, + lower, + True, + implicit_base, + inclusive, + assert_relative_length, ) if relative_upper: return relative_upper relative_lower = self._relative_iterate( - lower, upper, False, implicit_base, - inclusive, assert_relative_length + lower, + upper, + False, + implicit_base, + inclusive, + assert_relative_length, ) if relative_lower: return relative_lower return self._iterate_revisions( - upper, lower, inclusive=inclusive, implicit_base=implicit_base, - select_for_downgrade=select_for_downgrade) + upper, + lower, + inclusive=inclusive, + implicit_base=implicit_base, + select_for_downgrade=select_for_downgrade, + ) def _get_descendant_nodes( - self, targets, map_=None, check=False, - omit_immediate_dependencies=False, include_dependencies=True): + self, + targets, + map_=None, + check=False, + omit_immediate_dependencies=False, + include_dependencies=True, + ): if omit_immediate_dependencies: + def fn(rev): if rev not in targets: return rev._all_nextrev else: return rev.nextrev + elif include_dependencies: + def fn(rev): return rev._all_nextrev + else: + def fn(rev): return rev.nextrev @@ -584,12 +668,16 @@ class RevisionMap(object): ) def _get_ancestor_nodes( - self, targets, map_=None, check=False, include_dependencies=True): + self, targets, map_=None, check=False, include_dependencies=True + ): if include_dependencies: + def fn(rev): return rev._all_down_revisions + else: + def fn(rev): return rev._versioned_down_revisions @@ -617,24 +705,30 @@ class RevisionMap(object): if rev in seen: continue seen.add(rev) - todo.extend( - map_[rev_id] for rev_id in fn(rev)) + todo.extend(map_[rev_id] for rev_id in fn(rev)) yield rev if check: - overlaps = per_target.intersection(targets).\ - difference([target]) + overlaps = per_target.intersection(targets).difference( + [target] + ) if overlaps: raise RevisionError( "Requested revision %s overlaps with " - "other requested revisions %s" % ( + "other requested revisions %s" + % ( target.revision, - ", ".join(r.revision for r in overlaps) + ", ".join(r.revision for r in overlaps), ) ) def _iterate_revisions( - self, upper, lower, inclusive=True, implicit_base=False, - select_for_downgrade=False): + self, + upper, + lower, + inclusive=True, + implicit_base=False, + select_for_downgrade=False, + ): """iterate revisions from upper to lower. The traversal is depth-first within branches, and breadth-first @@ -650,8 +744,9 @@ class RevisionMap(object): # is specified using a branch identifier, then we limit operations # to just that branch. - limit_to_lower_branch = \ - isinstance(lower, compat.string_types) and lower.endswith('@base') + limit_to_lower_branch = isinstance( + lower, compat.string_types + ) and lower.endswith("@base") uppers = util.dedupe_tuple(self.get_revisions(upper)) @@ -663,16 +758,14 @@ class RevisionMap(object): if limit_to_lower_branch: lowers = self.get_revisions(self._get_base_revisions(lower)) elif implicit_base and requested_lowers: - lower_ancestors = set( - self._get_ancestor_nodes(requested_lowers) - ) + lower_ancestors = set(self._get_ancestor_nodes(requested_lowers)) lower_descendants = set( self._get_descendant_nodes(requested_lowers) ) base_lowers = set() - candidate_lowers = upper_ancestors.\ - difference(lower_ancestors).\ - difference(lower_descendants) + candidate_lowers = upper_ancestors.difference( + lower_ancestors + ).difference(lower_descendants) for rev in candidate_lowers: for downrev in rev._all_down_revisions: if self._revision_map[downrev] in candidate_lowers: @@ -690,13 +783,15 @@ class RevisionMap(object): # represents all nodes we will produce total_space = set( - rev.revision for rev in upper_ancestors).intersection( - rev.revision for rev - in self._get_descendant_nodes( - lowers, check=True, + rev.revision for rev in upper_ancestors + ).intersection( + rev.revision + for rev in self._get_descendant_nodes( + lowers, + check=True, omit_immediate_dependencies=( select_for_downgrade and requested_lowers - ) + ), ) ) @@ -706,7 +801,8 @@ class RevisionMap(object): start_from = set(requested_lowers) start_from.update( self._get_ancestor_nodes( - list(start_from), include_dependencies=True) + list(start_from), include_dependencies=True + ) ) # determine all the current branch points represented @@ -725,19 +821,18 @@ class RevisionMap(object): # organize branch points to be consumed separately from # member nodes branch_todo = set( - rev for rev in - (self._revision_map[rev] for rev in total_space) - if rev._is_real_branch_point and - len(total_space.intersection(rev._all_nextrev)) > 1 + rev + for rev in (self._revision_map[rev] for rev in total_space) + if rev._is_real_branch_point + and len(total_space.intersection(rev._all_nextrev)) > 1 ) # it's not possible for any "uppers" to be in branch_todo, # because the ._all_nextrev of those nodes is not in total_space - #assert not branch_todo.intersection(uppers) + # assert not branch_todo.intersection(uppers) todo = collections.deque( - r for r in uppers - if r.revision in total_space + r for r in uppers if r.revision in total_space ) # iterate for total_space being emptied out @@ -746,7 +841,8 @@ class RevisionMap(object): if not total_space_modified: raise RevisionError( - "Dependency resolution failed; iteration can't proceed") + "Dependency resolution failed; iteration can't proceed" + ) total_space_modified = False # when everything non-branch pending is consumed, # add to the todo any branch nodes that have no @@ -755,12 +851,13 @@ class RevisionMap(object): todo.extendleft( sorted( ( - rev for rev in branch_todo + rev + for rev in branch_todo if not rev._all_nextrev.intersection(total_space) ), # favor "revisioned" branch points before # dependent ones - key=lambda rev: 0 if rev.is_branch_point else 1 + key=lambda rev: 0 if rev.is_branch_point else 1, ) ) branch_todo.difference_update(todo) @@ -772,11 +869,14 @@ class RevisionMap(object): # do depth first for elements within branches, # don't consume any actual branch nodes - todo.extendleft([ - self._revision_map[downrev] - for downrev in reversed(rev._all_down_revisions) - if self._revision_map[downrev] not in branch_todo - and downrev in total_space]) + todo.extendleft( + [ + self._revision_map[downrev] + for downrev in reversed(rev._all_down_revisions) + if self._revision_map[downrev] not in branch_todo + and downrev in total_space + ] + ) if not inclusive and rev in requested_lowers: continue @@ -795,6 +895,7 @@ class Revision(object): to Python files in a version directory. """ + nextrev = frozenset() """following revisions, based on down_revision only.""" @@ -830,15 +931,13 @@ class Revision(object): illegal_chars = set(revision).intersection(_revision_illegal_chars) if illegal_chars: raise RevisionError( - "Character(s) '%s' not allowed in revision identifier '%s'" % ( - ", ".join(sorted(illegal_chars)), - revision - ) + "Character(s) '%s' not allowed in revision identifier '%s'" + % (", ".join(sorted(illegal_chars)), revision) ) def __init__( - self, revision, down_revision, - dependencies=None, branch_labels=None): + self, revision, down_revision, dependencies=None, branch_labels=None + ): self.verify_rev_id(revision) self.revision = revision self.down_revision = tuple_rev_as_scalar(down_revision) @@ -848,18 +947,12 @@ class Revision(object): self.branch_labels = set(self._orig_branch_labels) def __repr__(self): - args = [ - repr(self.revision), - repr(self.down_revision) - ] + args = [repr(self.revision), repr(self.down_revision)] if self.dependencies: args.append("dependencies=%r" % (self.dependencies,)) if self.branch_labels: args.append("branch_labels=%r" % (self.branch_labels,)) - return "%s(%s)" % ( - self.__class__.__name__, - ", ".join(args) - ) + return "%s(%s)" % (self.__class__.__name__, ", ".join(args)) def add_nextrev(self, revision): self._all_nextrev = self._all_nextrev.union([revision.revision]) @@ -868,8 +961,10 @@ class Revision(object): @property def _all_down_revisions(self): - return util.to_tuple(self.down_revision, default=()) + \ - self._resolved_dependencies + return ( + util.to_tuple(self.down_revision, default=()) + + self._resolved_dependencies + ) @property def _versioned_down_revisions(self): diff --git a/alembic/templates/generic/env.py b/alembic/templates/generic/env.py index 058378b..f3df952 100644 --- a/alembic/templates/generic/env.py +++ b/alembic/templates/generic/env.py @@ -37,7 +37,8 @@ def run_migrations_offline(): """ url = config.get_main_option("sqlalchemy.url") context.configure( - url=url, target_metadata=target_metadata, literal_binds=True) + url=url, target_metadata=target_metadata, literal_binds=True + ) with context.begin_transaction(): context.run_migrations() @@ -52,18 +53,19 @@ def run_migrations_online(): """ connectable = engine_from_config( config.get_section(config.config_ini_section), - prefix='sqlalchemy.', - poolclass=pool.NullPool) + prefix="sqlalchemy.", + poolclass=pool.NullPool, + ) with connectable.connect() as connection: context.configure( - connection=connection, - target_metadata=target_metadata + connection=connection, target_metadata=target_metadata ) with context.begin_transaction(): context.run_migrations() + if context.is_offline_mode(): run_migrations_offline() else: diff --git a/alembic/templates/multidb/env.py b/alembic/templates/multidb/env.py index db24173..f5ad3d4 100644 --- a/alembic/templates/multidb/env.py +++ b/alembic/templates/multidb/env.py @@ -14,12 +14,12 @@ config = context.config # Interpret the config file for Python logging. # This line sets up loggers basically. fileConfig(config.config_file_name) -logger = logging.getLogger('alembic.env') +logger = logging.getLogger("alembic.env") # gather section names referring to different # databases. These are named "engine1", "engine2" # in the sample .ini file. -db_names = config.get_main_option('databases') +db_names = config.get_main_option("databases") # add your model's MetaData objects here # for 'autogenerate' support. These must be set @@ -56,19 +56,21 @@ def run_migrations_offline(): # individual files. engines = {} - for name in re.split(r',\s*', db_names): + for name in re.split(r",\s*", db_names): engines[name] = rec = {} - rec['url'] = context.config.get_section_option(name, - "sqlalchemy.url") + rec["url"] = context.config.get_section_option(name, "sqlalchemy.url") for name, rec in engines.items(): logger.info("Migrating database %s" % name) file_ = "%s.sql" % name logger.info("Writing output to %s" % file_) - with open(file_, 'w') as buffer: - context.configure(url=rec['url'], output_buffer=buffer, - target_metadata=target_metadata.get(name), - literal_binds=True) + with open(file_, "w") as buffer: + context.configure( + url=rec["url"], + output_buffer=buffer, + target_metadata=target_metadata.get(name), + literal_binds=True, + ) with context.begin_transaction(): context.run_migrations(engine_name=name) @@ -85,46 +87,47 @@ def run_migrations_online(): # engines, then run all migrations, then commit all transactions. engines = {} - for name in re.split(r',\s*', db_names): + for name in re.split(r",\s*", db_names): engines[name] = rec = {} - rec['engine'] = engine_from_config( + rec["engine"] = engine_from_config( context.config.get_section(name), - prefix='sqlalchemy.', - poolclass=pool.NullPool) + prefix="sqlalchemy.", + poolclass=pool.NullPool, + ) for name, rec in engines.items(): - engine = rec['engine'] - rec['connection'] = conn = engine.connect() + engine = rec["engine"] + rec["connection"] = conn = engine.connect() if USE_TWOPHASE: - rec['transaction'] = conn.begin_twophase() + rec["transaction"] = conn.begin_twophase() else: - rec['transaction'] = conn.begin() + rec["transaction"] = conn.begin() try: for name, rec in engines.items(): logger.info("Migrating database %s" % name) context.configure( - connection=rec['connection'], + connection=rec["connection"], upgrade_token="%s_upgrades" % name, downgrade_token="%s_downgrades" % name, - target_metadata=target_metadata.get(name) + target_metadata=target_metadata.get(name), ) context.run_migrations(engine_name=name) if USE_TWOPHASE: for rec in engines.values(): - rec['transaction'].prepare() + rec["transaction"].prepare() for rec in engines.values(): - rec['transaction'].commit() + rec["transaction"].commit() except: for rec in engines.values(): - rec['transaction'].rollback() + rec["transaction"].rollback() raise finally: for rec in engines.values(): - rec['connection'].close() + rec["connection"].close() if context.is_offline_mode(): diff --git a/alembic/templates/pylons/env.py b/alembic/templates/pylons/env.py index 5ad9fd5..8c06cdc 100644 --- a/alembic/templates/pylons/env.py +++ b/alembic/templates/pylons/env.py @@ -13,18 +13,21 @@ from sqlalchemy.engine.base import Engine try: # if pylons app already in, don't create a new app from pylons import config as pylons_config - pylons_config['__file__'] + + pylons_config["__file__"] except: config = context.config # can use config['__file__'] here, i.e. the Pylons # ini file, instead of alembic.ini - config_file = config.get_main_option('pylons_config_file') + config_file = config.get_main_option("pylons_config_file") fileConfig(config_file) - wsgi_app = loadapp('config:%s' % config_file, relative_to='.') + wsgi_app = loadapp("config:%s" % config_file, relative_to=".") # customize this section for non-standard engine configurations. -meta = __import__("%s.model.meta" % wsgi_app.config['pylons.package']).model.meta +meta = __import__( + "%s.model.meta" % wsgi_app.config["pylons.package"] +).model.meta # add your model's MetaData object here # for 'autogenerate' support @@ -46,8 +49,10 @@ def run_migrations_offline(): """ context.configure( - url=meta.engine.url, target_metadata=target_metadata, - literal_binds=True) + url=meta.engine.url, + target_metadata=target_metadata, + literal_binds=True, + ) with context.begin_transaction(): context.run_migrations() @@ -65,13 +70,13 @@ def run_migrations_online(): with engine.connect() as connection: context.configure( - connection=connection, - target_metadata=target_metadata + connection=connection, target_metadata=target_metadata ) with context.begin_transaction(): context.run_migrations() + if context.is_offline_mode(): run_migrations_offline() else: diff --git a/alembic/testing/__init__.py b/alembic/testing/__init__.py index 553f501..70c28e0 100644 --- a/alembic/testing/__init__.py +++ b/alembic/testing/__init__.py @@ -1,6 +1,13 @@ from .fixtures import TestBase -from .assertions import eq_, ne_, is_, is_not_, assert_raises_message, \ - eq_ignore_whitespace, assert_raises +from .assertions import ( + eq_, + ne_, + is_, + is_not_, + assert_raises_message, + eq_ignore_whitespace, + assert_raises, +) from .util import provide_metadata diff --git a/alembic/testing/assertions.py b/alembic/testing/assertions.py index 2c7382c..c25b444 100644 --- a/alembic/testing/assertions.py +++ b/alembic/testing/assertions.py @@ -15,6 +15,7 @@ from . import config if not util.sqla_094: + def eq_(a, b, msg=None): """Assert a == b, with repr messaging on failure.""" assert a == b, msg or "%r != %r" % (a, b) @@ -46,27 +47,36 @@ if not util.sqla_094: callable_(*args, **kwargs) assert False, "Callable did not raise an exception" except except_cls as e: - assert re.search( - msg, text_type(e), re.UNICODE), "%r !~ %s" % (msg, e) - print(text_type(e).encode('utf-8')) + assert re.search(msg, text_type(e), re.UNICODE), "%r !~ %s" % ( + msg, + e, + ) + print(text_type(e).encode("utf-8")) + else: - from sqlalchemy.testing.assertions import eq_, ne_, is_, is_not_, \ - assert_raises_message, assert_raises + from sqlalchemy.testing.assertions import ( + eq_, + ne_, + is_, + is_not_, + assert_raises_message, + assert_raises, + ) def eq_ignore_whitespace(a, b, msg=None): - a = re.sub(r'^\s+?|\n', "", a) - a = re.sub(r' {2,}', " ", a) - b = re.sub(r'^\s+?|\n', "", b) - b = re.sub(r' {2,}', " ", b) + a = re.sub(r"^\s+?|\n", "", a) + a = re.sub(r" {2,}", " ", a) + b = re.sub(r"^\s+?|\n", "", b) + b = re.sub(r" {2,}", " ", b) # convert for unicode string rendering, # using special escape character "!U" if py3k: - b = re.sub(r'!U', '', b) + b = re.sub(r"!U", "", b) else: - b = re.sub(r'!U', 'u', b) + b = re.sub(r"!U", "u", b) assert a == b, msg or "%r != %r" % (a, b) @@ -74,9 +84,10 @@ def eq_ignore_whitespace(a, b, msg=None): def assert_compiled(element, assert_string, dialect=None): dialect = _get_dialect(dialect) eq_( - text_type(element.compile(dialect=dialect)). - replace("\n", "").replace("\t", ""), - assert_string.replace("\n", "").replace("\t", "") + text_type(element.compile(dialect=dialect)) + .replace("\n", "") + .replace("\t", ""), + assert_string.replace("\n", "").replace("\t", ""), ) @@ -84,19 +95,20 @@ _dialect_mods = {} def _get_dialect(name): - if name is None or name == 'default': + if name is None or name == "default": return default.DefaultDialect() else: try: dialect_mod = _dialect_mods[name] except KeyError: dialect_mod = getattr( - __import__('sqlalchemy.dialects.%s' % name).dialects, name) + __import__("sqlalchemy.dialects.%s" % name).dialects, name + ) _dialect_mods[name] = dialect_mod d = dialect_mod.dialect() - if name == 'postgresql': + if name == "postgresql": d.implicit_returning = True - elif name == 'mssql': + elif name == "mssql": d.legacy_schema_aliasing = False return d @@ -161,6 +173,7 @@ def emits_warning_on(db, *messages): were in fact seen. """ + @decorator def decorate(fn, *args, **kw): with expect_warnings_on(db, *messages): @@ -189,8 +202,9 @@ def _expect_warnings(exc_cls, messages, regex=True, assert_=True): return for filter_ in filters: - if (regex and filter_.match(msg)) or \ - (not regex and filter_ == msg): + if (regex and filter_.match(msg)) or ( + not regex and filter_ == msg + ): seen.discard(filter_) break else: @@ -203,6 +217,6 @@ def _expect_warnings(exc_cls, messages, regex=True, assert_=True): yield if assert_: - assert not seen, "Warnings were not seen: %s" % \ - ", ".join("%r" % (s.pattern if regex else s) for s in seen) - + assert not seen, "Warnings were not seen: %s" % ", ".join( + "%r" % (s.pattern if regex else s) for s in seen + ) diff --git a/alembic/testing/compat.py b/alembic/testing/compat.py index e0af6a2..9fbd50f 100644 --- a/alembic/testing/compat.py +++ b/alembic/testing/compat.py @@ -1,13 +1,12 @@ def get_url_driver_name(url): - if '+' not in url.drivername: + if "+" not in url.drivername: return url.get_dialect().driver else: - return url.drivername.split('+')[1] + return url.drivername.split("+")[1] def get_url_backend_name(url): - if '+' not in url.drivername: + if "+" not in url.drivername: return url.drivername else: - return url.drivername.split('+')[0] - + return url.drivername.split("+")[0] diff --git a/alembic/testing/config.py b/alembic/testing/config.py index ca28c6b..7d7009e 100644 --- a/alembic/testing/config.py +++ b/alembic/testing/config.py @@ -66,7 +66,8 @@ class Config(object): assert _current, "Can't push without a default Config set up" cls.push( Config( - db, _current.db_opts, _current.options, _current.file_config) + db, _current.db_opts, _current.options, _current.file_config + ) ) @classmethod @@ -88,4 +89,3 @@ class Config(object): def all_dbs(cls): for cfg in cls.all_configs(): yield cfg.db - diff --git a/alembic/testing/engines.py b/alembic/testing/engines.py index dadabc8..68d0068 100644 --- a/alembic/testing/engines.py +++ b/alembic/testing/engines.py @@ -25,4 +25,3 @@ def testing_engine(url=None, options=None): engine = create_engine(url, **options) return engine - diff --git a/alembic/testing/env.py b/alembic/testing/env.py index 0318703..51483d1 100644 --- a/alembic/testing/env.py +++ b/alembic/testing/env.py @@ -15,21 +15,22 @@ def _get_staging_directory(): if provision.FOLLOWER_IDENT: return "scratch_%s" % provision.FOLLOWER_IDENT else: - return 'scratch' + return "scratch" def staging_env(create=True, template="generic", sourceless=False): from alembic import command, script + cfg = _testing_config() if create: - path = os.path.join(_get_staging_directory(), 'scripts') + path = os.path.join(_get_staging_directory(), "scripts") if os.path.exists(path): shutil.rmtree(path) command.init(cfg, path, template=template) if sourceless: try: # do an import so that a .pyc/.pyo is generated. - util.load_python_file(path, 'env.py') + util.load_python_file(path, "env.py") except AttributeError: # we don't have the migration context set up yet # so running the .env py throws this exception. @@ -38,10 +39,13 @@ def staging_env(create=True, template="generic", sourceless=False): # worth it. pass assert sourceless in ( - "pep3147_envonly", "simple", "pep3147_everything"), sourceless + "pep3147_envonly", + "simple", + "pep3147_everything", + ), sourceless make_sourceless( os.path.join(path, "env.py"), - "pep3147" if "pep3147" in sourceless else "simple" + "pep3147" if "pep3147" in sourceless else "simple", ) sc = script.ScriptDirectory.from_config(cfg) @@ -53,40 +57,44 @@ def clear_staging_env(): def script_file_fixture(txt): - dir_ = os.path.join(_get_staging_directory(), 'scripts') + dir_ = os.path.join(_get_staging_directory(), "scripts") path = os.path.join(dir_, "script.py.mako") - with open(path, 'w') as f: + with open(path, "w") as f: f.write(txt) def env_file_fixture(txt): - dir_ = os.path.join(_get_staging_directory(), 'scripts') - txt = """ + dir_ = os.path.join(_get_staging_directory(), "scripts") + txt = ( + """ from alembic import context config = context.config -""" + txt +""" + + txt + ) path = os.path.join(dir_, "env.py") pyc_path = util.pyc_file_from_path(path) if pyc_path: os.unlink(pyc_path) - with open(path, 'w') as f: + with open(path, "w") as f: f.write(txt) def _sqlite_file_db(tempname="foo.db"): - dir_ = os.path.join(_get_staging_directory(), 'scripts') + dir_ = os.path.join(_get_staging_directory(), "scripts") url = "sqlite:///%s/%s" % (dir_, tempname) return engines.testing_engine(url=url) def _sqlite_testing_config(sourceless=False): - dir_ = os.path.join(_get_staging_directory(), 'scripts') + dir_ = os.path.join(_get_staging_directory(), "scripts") url = "sqlite:///%s/foo.db" % dir_ - return _write_config_file(""" + return _write_config_file( + """ [alembic] script_location = %s sqlalchemy.url = %s @@ -115,14 +123,17 @@ keys = generic [formatter_generic] format = %%(levelname)-5.5s [%%(name)s] %%(message)s datefmt = %%H:%%M:%%S - """ % (dir_, url, "true" if sourceless else "false")) + """ + % (dir_, url, "true" if sourceless else "false") + ) -def _multi_dir_testing_config(sourceless=False, extra_version_location=''): - dir_ = os.path.join(_get_staging_directory(), 'scripts') +def _multi_dir_testing_config(sourceless=False, extra_version_location=""): + dir_ = os.path.join(_get_staging_directory(), "scripts") url = "sqlite:///%s/foo.db" % dir_ - return _write_config_file(""" + return _write_config_file( + """ [alembic] script_location = %s sqlalchemy.url = %s @@ -152,15 +163,22 @@ keys = generic [formatter_generic] format = %%(levelname)-5.5s [%%(name)s] %%(message)s datefmt = %%H:%%M:%%S - """ % (dir_, url, "true" if sourceless else "false", - extra_version_location)) + """ + % ( + dir_, + url, + "true" if sourceless else "false", + extra_version_location, + ) + ) def _no_sql_testing_config(dialect="postgresql", directives=""): """use a postgresql url with no host so that connections guaranteed to fail""" - dir_ = os.path.join(_get_staging_directory(), 'scripts') - return _write_config_file(""" + dir_ = os.path.join(_get_staging_directory(), "scripts") + return _write_config_file( + """ [alembic] script_location = %s sqlalchemy.url = %s:// @@ -190,32 +208,36 @@ keys = generic format = %%(levelname)-5.5s [%%(name)s] %%(message)s datefmt = %%H:%%M:%%S -""" % (dir_, dialect, directives)) +""" + % (dir_, dialect, directives) + ) def _write_config_file(text): cfg = _testing_config() - with open(cfg.config_file_name, 'w') as f: + with open(cfg.config_file_name, "w") as f: f.write(text) return cfg def _testing_config(): from alembic.config import Config + if not os.access(_get_staging_directory(), os.F_OK): os.mkdir(_get_staging_directory()) - return Config(os.path.join(_get_staging_directory(), 'test_alembic.ini')) + return Config(os.path.join(_get_staging_directory(), "test_alembic.ini")) def write_script( - scriptdir, rev_id, content, encoding='ascii', sourceless=False): + scriptdir, rev_id, content, encoding="ascii", sourceless=False +): old = scriptdir.revision_map.get_revision(rev_id) path = old.path content = textwrap.dedent(content) if encoding: content = content.encode(encoding) - with open(path, 'wb') as fp: + with open(path, "wb") as fp: fp.write(content) pyc_path = util.pyc_file_from_path(path) if pyc_path: @@ -223,20 +245,21 @@ def write_script( script = Script._from_path(scriptdir, path) old = scriptdir.revision_map.get_revision(script.revision) if old.down_revision != script.down_revision: - raise Exception("Can't change down_revision " - "on a refresh operation.") + raise Exception( + "Can't change down_revision " "on a refresh operation." + ) scriptdir.revision_map.add_revision(script, _replace=True) if sourceless: make_sourceless( - path, - "pep3147" if sourceless == "pep3147_everything" else "simple" + path, "pep3147" if sourceless == "pep3147_everything" else "simple" ) def make_sourceless(path, style): import py_compile + py_compile.compile(path) if style == "simple" and has_pep3147(): @@ -264,7 +287,10 @@ def three_rev_fixture(cfg): script = ScriptDirectory.from_config(cfg) script.generate_revision(a, "revision a", refresh=True) - write_script(script, a, """\ + write_script( + script, + a, + """\ "Rev A" revision = '%s' down_revision = None @@ -279,10 +305,16 @@ def upgrade(): def downgrade(): op.execute("DROP STEP 1") -""" % a) +""" + % a, + ) script.generate_revision(b, "revision b", refresh=True) - write_script(script, b, u("""# coding: utf-8 + write_script( + script, + b, + u( + """# coding: utf-8 "Rev B, méil, %3" revision = '{}' down_revision = '{}' @@ -297,10 +329,16 @@ def upgrade(): def downgrade(): op.execute("DROP STEP 2") -""").format(b, a), encoding="utf-8") +""" + ).format(b, a), + encoding="utf-8", + ) script.generate_revision(c, "revision c", refresh=True) - write_script(script, c, """\ + write_script( + script, + c, + """\ "Rev C" revision = '%s' down_revision = '%s' @@ -315,7 +353,9 @@ def upgrade(): def downgrade(): op.execute("DROP STEP 3") -""" % (c, b)) +""" + % (c, b), + ) return a, b, c @@ -328,8 +368,12 @@ def multi_heads_fixture(cfg, a, b, c): script = ScriptDirectory.from_config(cfg) script.generate_revision( - d, "revision d from b", head=b, splice=True, refresh=True) - write_script(script, d, """\ + d, "revision d from b", head=b, splice=True, refresh=True + ) + write_script( + script, + d, + """\ "Rev D" revision = '%s' down_revision = '%s' @@ -344,11 +388,17 @@ def upgrade(): def downgrade(): op.execute("DROP STEP 4") -""" % (d, b)) +""" + % (d, b), + ) script.generate_revision( - e, "revision e from d", head=d, splice=True, refresh=True) - write_script(script, e, """\ + e, "revision e from d", head=d, splice=True, refresh=True + ) + write_script( + script, + e, + """\ "Rev E" revision = '%s' down_revision = '%s' @@ -363,11 +413,17 @@ def upgrade(): def downgrade(): op.execute("DROP STEP 5") -""" % (e, d)) +""" + % (e, d), + ) script.generate_revision( - f, "revision f from b", head=b, splice=True, refresh=True) - write_script(script, f, """\ + f, "revision f from b", head=b, splice=True, refresh=True + ) + write_script( + script, + f, + """\ "Rev F" revision = '%s' down_revision = '%s' @@ -382,7 +438,9 @@ def upgrade(): def downgrade(): op.execute("DROP STEP 6") -""" % (f, b)) +""" + % (f, b), + ) return d, e, f @@ -390,18 +448,16 @@ def downgrade(): def _multidb_testing_config(engines): """alembic.ini fixture to work exactly with the 'multidb' template""" - dir_ = os.path.join(_get_staging_directory(), 'scripts') + dir_ = os.path.join(_get_staging_directory(), "scripts") - databases = ", ".join( - engines.keys() - ) + databases = ", ".join(engines.keys()) engines = "\n\n".join( - "[%s]\n" - "sqlalchemy.url = %s" % (key, value.url) + "[%s]\n" "sqlalchemy.url = %s" % (key, value.url) for key, value in engines.items() ) - return _write_config_file(""" + return _write_config_file( + """ [alembic] script_location = %s sourceless = false @@ -432,5 +488,6 @@ keys = generic [formatter_generic] format = %%(levelname)-5.5s [%%(name)s] %%(message)s datefmt = %%H:%%M:%%S - """ % (dir_, databases, engines) + """ + % (dir_, databases, engines) ) diff --git a/alembic/testing/exclusions.py b/alembic/testing/exclusions.py index 7d33a5b..41ed547 100644 --- a/alembic/testing/exclusions.py +++ b/alembic/testing/exclusions.py @@ -74,15 +74,15 @@ class compound(object): def matching_config_reasons(self, config): return [ - predicate._as_string(config) for predicate - in self.skips.union(self.fails) + predicate._as_string(config) + for predicate in self.skips.union(self.fails) if predicate(config) ] def include_test(self, include_tags, exclude_tags): return bool( - not self.tags.intersection(exclude_tags) and - (not include_tags or self.tags.intersection(include_tags)) + not self.tags.intersection(exclude_tags) + and (not include_tags or self.tags.intersection(include_tags)) ) def _extend(self, other): @@ -91,13 +91,14 @@ class compound(object): self.tags.update(other.tags) def __call__(self, fn): - if hasattr(fn, '_sa_exclusion_extend'): + if hasattr(fn, "_sa_exclusion_extend"): fn._sa_exclusion_extend._extend(self) return fn @decorator def decorate(fn, *args, **kw): return self._do(config._current, fn, *args, **kw) + decorated = decorate(fn) decorated._sa_exclusion_extend = self return decorated @@ -117,10 +118,7 @@ class compound(object): def _do(self, config, fn, *args, **kw): for skip in self.skips: if skip(config): - msg = "'%s' : %s" % ( - fn.__name__, - skip._as_string(config) - ) + msg = "'%s' : %s" % (fn.__name__, skip._as_string(config)) raise SkipTest(msg) try: @@ -131,16 +129,20 @@ class compound(object): self._expect_success(config, name=fn.__name__) return return_value - def _expect_failure(self, config, ex, name='block'): + def _expect_failure(self, config, ex, name="block"): for fail in self.fails: if fail(config): - print(("%s failed as expected (%s): %s " % ( - name, fail._as_string(config), str(ex)))) + print( + ( + "%s failed as expected (%s): %s " + % (name, fail._as_string(config), str(ex)) + ) + ) break else: compat.raise_from_cause(ex) - def _expect_success(self, config, name='block'): + def _expect_success(self, config, name="block"): if not self.fails: return for fail in self.fails: @@ -148,13 +150,12 @@ class compound(object): break else: raise AssertionError( - "Unexpected success for '%s' (%s)" % - ( + "Unexpected success for '%s' (%s)" + % ( name, " and ".join( - fail._as_string(config) - for fail in self.fails - ) + fail._as_string(config) for fail in self.fails + ), ) ) @@ -191,8 +192,8 @@ class Predicate(object): return predicate elif isinstance(predicate, (list, set)): return OrPredicate( - [cls.as_predicate(pred) for pred in predicate], - description) + [cls.as_predicate(pred) for pred in predicate], description + ) elif isinstance(predicate, tuple): return SpecPredicate(*predicate) elif isinstance(predicate, compat.string_types): @@ -217,7 +218,7 @@ class Predicate(object): "driver": get_url_driver_name(config.db.url), "database": get_url_backend_name(config.db.url), "doesnt_support": "doesn't support" if bool_ else "does support", - "does_support": "does support" if bool_ else "doesn't support" + "does_support": "does support" if bool_ else "doesn't support", } def _as_string(self, config=None, negate=False): @@ -244,21 +245,21 @@ class SpecPredicate(Predicate): self.description = description _ops = { - '<': operator.lt, - '>': operator.gt, - '==': operator.eq, - '!=': operator.ne, - '<=': operator.le, - '>=': operator.ge, - 'in': operator.contains, - 'between': lambda val, pair: val >= pair[0] and val <= pair[1], + "<": operator.lt, + ">": operator.gt, + "==": operator.eq, + "!=": operator.ne, + "<=": operator.le, + ">=": operator.ge, + "in": operator.contains, + "between": lambda val, pair: val >= pair[0] and val <= pair[1], } def __call__(self, config): engine = config.db if "+" in self.db: - dialect, driver = self.db.split('+') + dialect, driver = self.db.split("+") else: dialect, driver = self.db, None @@ -271,8 +272,9 @@ class SpecPredicate(Predicate): assert driver is None, "DBAPI version specs not supported yet" version = _server_version(engine) - oper = hasattr(self.op, '__call__') and self.op \ - or self._ops[self.op] + oper = ( + hasattr(self.op, "__call__") and self.op or self._ops[self.op] + ) return oper(version, self.spec) else: return True @@ -287,17 +289,9 @@ class SpecPredicate(Predicate): return "%s" % self.db else: if negate: - return "not %s %s %s" % ( - self.db, - self.op, - self.spec - ) + return "not %s %s %s" % (self.db, self.op, self.spec) else: - return "%s %s %s" % ( - self.db, - self.op, - self.spec - ) + return "%s %s %s" % (self.db, self.op, self.spec) class LambdaPredicate(Predicate): @@ -354,8 +348,9 @@ class OrPredicate(Predicate): conjunction = " and " else: conjunction = " or " - return conjunction.join(p._as_string(config, negate=negate) - for p in self.predicates) + return conjunction.join( + p._as_string(config, negate=negate) for p in self.predicates + ) def _negation_str(self, config): if self.description is not None: @@ -385,15 +380,13 @@ def _server_version(engine): # force metadata to be retrieved conn = engine.connect() - version = getattr(engine.dialect, 'server_version_info', ()) + version = getattr(engine.dialect, "server_version_info", ()) conn.close() return version def db_spec(*dbs): - return OrPredicate( - [Predicate.as_predicate(db) for db in dbs] - ) + return OrPredicate([Predicate.as_predicate(db) for db in dbs]) def open(): @@ -418,11 +411,7 @@ def fails_on(db, reason=None): def fails_on_everything_except(*dbs): - return succeeds_if( - OrPredicate([ - Predicate.as_predicate(db) for db in dbs - ]) - ) + return succeeds_if(OrPredicate([Predicate.as_predicate(db) for db in dbs])) def skip(db, reason=None): @@ -441,7 +430,6 @@ def exclude(db, op, spec, reason=None): def against(config, *queries): assert queries, "no queries sent!" - return OrPredicate([ - Predicate.as_predicate(query) - for query in queries - ])(config) + return OrPredicate([Predicate.as_predicate(query) for query in queries])( + config + ) diff --git a/alembic/testing/fixtures.py b/alembic/testing/fixtures.py index 86d40a2..b812476 100644 --- a/alembic/testing/fixtures.py +++ b/alembic/testing/fixtures.py @@ -17,10 +17,11 @@ from .assertions import _get_dialect, eq_ from . import mock testing_config = configparser.ConfigParser() -testing_config.read(['test.cfg']) +testing_config.read(["test.cfg"]) if not util.sqla_094: + class TestBase(object): # A sequence of database names to always run, regardless of the # constraints below. @@ -51,6 +52,8 @@ if not util.sqla_094: def teardown(self): if hasattr(self, "tearDown"): self.tearDown() + + else: from sqlalchemy.testing.fixtures import TestBase @@ -60,23 +63,22 @@ def capture_db(): def dump(sql, *multiparams, **params): buf.append(str(sql.compile(dialect=engine.dialect))) + engine = create_engine("postgresql://", strategy="mock", executor=dump) return engine, buf + _engs = {} @contextmanager def capture_context_buffer(**kw): - if kw.pop('bytes_io', False): + if kw.pop("bytes_io", False): buf = io.BytesIO() else: buf = io.StringIO() - kw.update({ - 'dialect_name': "sqlite", - 'output_buffer': buf - }) + kw.update({"dialect_name": "sqlite", "output_buffer": buf}) conf = EnvironmentContext.configure def configure(*arg, **opt): @@ -88,17 +90,20 @@ def capture_context_buffer(**kw): def op_fixture( - dialect='default', as_sql=False, - naming_convention=None, literal_binds=False, - native_boolean=None): + dialect="default", + as_sql=False, + naming_convention=None, + literal_binds=False, + native_boolean=None, +): opts = {} if naming_convention: if not util.sqla_092: raise SkipTest( - "naming_convention feature requires " - "sqla 0.9.2 or greater") - opts['target_metadata'] = MetaData(naming_convention=naming_convention) + "naming_convention feature requires " "sqla 0.9.2 or greater" + ) + opts["target_metadata"] = MetaData(naming_convention=naming_convention) class buffer_(object): def __init__(self): @@ -106,12 +111,12 @@ def op_fixture( def write(self, msg): msg = msg.strip() - msg = re.sub(r'[\n\t]', '', msg) + msg = re.sub(r"[\n\t]", "", msg) if as_sql: # the impl produces soft tabs, # so search for blocks of 4 spaces - msg = re.sub(r' ', '', msg) - msg = re.sub(r'\;\n*$', '', msg) + msg = re.sub(r" ", "", msg) + msg = re.sub(r"\;\n*$", "", msg) self.lines.append(msg) @@ -136,13 +141,13 @@ def op_fixture( else: assert False, "Could not locate fragment %r in %r" % ( sql, - buf.lines + buf.lines, ) if as_sql: - opts['as_sql'] = as_sql + opts["as_sql"] = as_sql if literal_binds: - opts['literal_binds'] = literal_binds + opts["literal_binds"] = literal_binds ctx_dialect = _get_dialect(dialect) if native_boolean is not None: ctx_dialect.supports_native_boolean = native_boolean @@ -150,6 +155,7 @@ def op_fixture( # which breaks assumptions in the alembic test suite ctx_dialect.non_native_boolean_check_constraint = True if not as_sql: + def execute(stmt, *multiparam, **param): if isinstance(stmt, string_types): stmt = text(stmt) @@ -160,12 +166,9 @@ def op_fixture( connection = mock.Mock(dialect=ctx_dialect, execute=execute) else: - opts['output_buffer'] = buf + opts["output_buffer"] = buf connection = None - context = ctx( - ctx_dialect, - connection, - opts) + context = ctx(ctx_dialect, connection, opts) alembic.op._proxy = Operations(context) return context diff --git a/alembic/testing/mock.py b/alembic/testing/mock.py index 08a756c..1d5256d 100644 --- a/alembic/testing/mock.py +++ b/alembic/testing/mock.py @@ -22,4 +22,5 @@ else: except ImportError: raise ImportError( "SQLAlchemy's test suite requires the " - "'mock' library as of 0.8.2.") + "'mock' library as of 0.8.2." + ) diff --git a/alembic/testing/plugin/bootstrap.py b/alembic/testing/plugin/bootstrap.py index 9f42fd2..4bd415d 100644 --- a/alembic/testing/plugin/bootstrap.py +++ b/alembic/testing/plugin/bootstrap.py @@ -20,20 +20,23 @@ this should be removable when Alembic targets SQLAlchemy 1.0.0. import os import sys -bootstrap_file = locals()['bootstrap_file'] -to_bootstrap = locals()['to_bootstrap'] +bootstrap_file = locals()["bootstrap_file"] +to_bootstrap = locals()["to_bootstrap"] def load_file_as_module(name): path = os.path.join(os.path.dirname(bootstrap_file), "%s.py" % name) if sys.version_info.major >= 3: from importlib import machinery + mod = machinery.SourceFileLoader(name, path).load_module() else: import imp + mod = imp.load_source(name, path) return mod + if to_bootstrap == "pytest": sys.modules["alembic_plugin_base"] = load_file_as_module("plugin_base") sys.modules["alembic_pytestplugin"] = load_file_as_module("pytestplugin") diff --git a/alembic/testing/plugin/noseplugin.py b/alembic/testing/plugin/noseplugin.py index f8894d6..fafb9e1 100644 --- a/alembic/testing/plugin/noseplugin.py +++ b/alembic/testing/plugin/noseplugin.py @@ -25,6 +25,7 @@ import os import sys from nose.plugins import Plugin + fixtures = None py3k = sys.version_info.major >= 3 @@ -33,7 +34,7 @@ py3k = sys.version_info.major >= 3 class NoseSQLAlchemy(Plugin): enabled = True - name = 'sqla_testing' + name = "sqla_testing" score = 100 def options(self, parser, env=os.environ): @@ -43,8 +44,10 @@ class NoseSQLAlchemy(Plugin): def make_option(name, **kw): callback_ = kw.pop("callback", None) if callback_: + def wrap_(option, opt_str, value, parser): callback_(opt_str, value, parser) + kw["callback"] = wrap_ opt(name, **kw) @@ -71,7 +74,7 @@ class NoseSQLAlchemy(Plugin): def wantMethod(self, fn): if py3k: - if not hasattr(fn.__self__, 'cls'): + if not hasattr(fn.__self__, "cls"): return False cls = fn.__self__.cls else: @@ -85,19 +88,19 @@ class NoseSQLAlchemy(Plugin): plugin_base.before_test( test, test.test.cls.__module__, - test.test.cls, test.test.method.__name__) + test.test.cls, + test.test.method.__name__, + ) def afterTest(self, test): plugin_base.after_test(test) def startContext(self, ctx): - if not isinstance(ctx, type) \ - or not issubclass(ctx, fixtures.TestBase): + if not isinstance(ctx, type) or not issubclass(ctx, fixtures.TestBase): return plugin_base.start_test_class(ctx) def stopContext(self, ctx): - if not isinstance(ctx, type) \ - or not issubclass(ctx, fixtures.TestBase): + if not isinstance(ctx, type) or not issubclass(ctx, fixtures.TestBase): return plugin_base.stop_test_class(ctx) diff --git a/alembic/testing/plugin/plugin_base.py b/alembic/testing/plugin/plugin_base.py index 141e82f..9acffb5 100644 --- a/alembic/testing/plugin/plugin_base.py +++ b/alembic/testing/plugin/plugin_base.py @@ -17,12 +17,14 @@ this should be removable when Alembic targets SQLAlchemy 1.0.0 """ from __future__ import absolute_import + try: # unitttest has a SkipTest also but pytest doesn't # honor it unless nose is imported too... from nose import SkipTest except ImportError: from pytest import skip + SkipTest = skip.Exception import sys @@ -55,54 +57,118 @@ options = None def setup_options(make_option): - make_option("--log-info", action="callback", type="string", callback=_log, - help="turn on info logging for <LOG> (multiple OK)") - make_option("--log-debug", action="callback", - type="string", callback=_log, - help="turn on debug logging for <LOG> (multiple OK)") - make_option("--db", action="append", type="string", dest="db", - help="Use prefab database uri. Multiple OK, " - "first one is run by default.") - make_option('--dbs', action='callback', zeroarg_callback=_list_dbs, - help="List available prefab dbs") - make_option("--dburi", action="append", type="string", dest="dburi", - help="Database uri. Multiple OK, " - "first one is run by default.") - make_option("--dropfirst", action="store_true", dest="dropfirst", - help="Drop all tables in the target database first") - make_option("--backend-only", action="store_true", dest="backend_only", - help="Run only tests marked with __backend__") - make_option("--postgresql-templatedb", type="string", - help="name of template database to use for Postgresql " - "CREATE DATABASE (defaults to current database)") - make_option("--low-connections", action="store_true", - dest="low_connections", - help="Use a low number of distinct connections - " - "i.e. for Oracle TNS") - make_option("--write-idents", type="string", dest="write_idents", - help="write out generated follower idents to <file>, " - "when -n<num> is used") - make_option("--reversetop", action="store_true", - dest="reversetop", default=False, - help="Use a random-ordering set implementation in the ORM " - "(helps reveal dependency issues)") - make_option("--requirements", action="callback", type="string", - callback=_requirements_opt, - help="requirements class for testing, overrides setup.cfg") - make_option("--with-cdecimal", action="store_true", - dest="cdecimal", default=False, - help="Monkeypatch the cdecimal library into Python 'decimal' " - "for all tests") - make_option("--include-tag", action="callback", callback=_include_tag, - type="string", - help="Include tests with tag <tag>") - make_option("--exclude-tag", action="callback", callback=_exclude_tag, - type="string", - help="Exclude tests with tag <tag>") - make_option("--mysql-engine", action="store", - dest="mysql_engine", default=None, - help="Use the specified MySQL storage engine for all tables, " - "default is a db-default/InnoDB combo.") + make_option( + "--log-info", + action="callback", + type="string", + callback=_log, + help="turn on info logging for <LOG> (multiple OK)", + ) + make_option( + "--log-debug", + action="callback", + type="string", + callback=_log, + help="turn on debug logging for <LOG> (multiple OK)", + ) + make_option( + "--db", + action="append", + type="string", + dest="db", + help="Use prefab database uri. Multiple OK, " + "first one is run by default.", + ) + make_option( + "--dbs", + action="callback", + zeroarg_callback=_list_dbs, + help="List available prefab dbs", + ) + make_option( + "--dburi", + action="append", + type="string", + dest="dburi", + help="Database uri. Multiple OK, " "first one is run by default.", + ) + make_option( + "--dropfirst", + action="store_true", + dest="dropfirst", + help="Drop all tables in the target database first", + ) + make_option( + "--backend-only", + action="store_true", + dest="backend_only", + help="Run only tests marked with __backend__", + ) + make_option( + "--postgresql-templatedb", + type="string", + help="name of template database to use for Postgresql " + "CREATE DATABASE (defaults to current database)", + ) + make_option( + "--low-connections", + action="store_true", + dest="low_connections", + help="Use a low number of distinct connections - " + "i.e. for Oracle TNS", + ) + make_option( + "--write-idents", + type="string", + dest="write_idents", + help="write out generated follower idents to <file>, " + "when -n<num> is used", + ) + make_option( + "--reversetop", + action="store_true", + dest="reversetop", + default=False, + help="Use a random-ordering set implementation in the ORM " + "(helps reveal dependency issues)", + ) + make_option( + "--requirements", + action="callback", + type="string", + callback=_requirements_opt, + help="requirements class for testing, overrides setup.cfg", + ) + make_option( + "--with-cdecimal", + action="store_true", + dest="cdecimal", + default=False, + help="Monkeypatch the cdecimal library into Python 'decimal' " + "for all tests", + ) + make_option( + "--include-tag", + action="callback", + callback=_include_tag, + type="string", + help="Include tests with tag <tag>", + ) + make_option( + "--exclude-tag", + action="callback", + callback=_exclude_tag, + type="string", + help="Exclude tests with tag <tag>", + ) + make_option( + "--mysql-engine", + action="store", + dest="mysql_engine", + default=None, + help="Use the specified MySQL storage engine for all tables, " + "default is a db-default/InnoDB combo.", + ) def configure_follower(follower_ident): @@ -113,6 +179,7 @@ def configure_follower(follower_ident): """ from alembic.testing import provision + provision.FOLLOWER_IDENT = follower_ident @@ -126,9 +193,9 @@ def memoize_important_follower_config(dict_): callables, so we have to just copy all of that over. """ - dict_['memoized_config'] = { - 'include_tags': include_tags, - 'exclude_tags': exclude_tags + dict_["memoized_config"] = { + "include_tags": include_tags, + "exclude_tags": exclude_tags, } @@ -138,14 +205,14 @@ def restore_important_follower_config(dict_): This invokes in the follower process. """ - include_tags.update(dict_['memoized_config']['include_tags']) - exclude_tags.update(dict_['memoized_config']['exclude_tags']) + include_tags.update(dict_["memoized_config"]["include_tags"]) + exclude_tags.update(dict_["memoized_config"]["exclude_tags"]) def read_config(): global file_config file_config = configparser.ConfigParser() - file_config.read(['setup.cfg', 'test.cfg']) + file_config.read(["setup.cfg", "test.cfg"]) def pre_begin(opt): @@ -169,12 +236,11 @@ def post_begin(): # late imports, has to happen after config as well # as nose plugins like coverage - global util, fixtures, engines, exclusions, \ - assertions, warnings, profiling,\ - config, testing + global util, fixtures, engines, exclusions, assertions, warnings, profiling, config, testing from alembic.testing import config, warnings, exclusions # noqa from alembic.testing import engines, fixtures # noqa from sqlalchemy import util # noqa + warnings.setup_filters() @@ -182,18 +248,19 @@ def _log(opt_str, value, parser): global logging if not logging: import logging + logging.basicConfig() - if opt_str.endswith('-info'): + if opt_str.endswith("-info"): logging.getLogger(value).setLevel(logging.INFO) - elif opt_str.endswith('-debug'): + elif opt_str.endswith("-debug"): logging.getLogger(value).setLevel(logging.DEBUG) def _list_dbs(*args): print("Available --db options (use --dburi to override)") - for macro in sorted(file_config.options('db')): - print("%20s\t%s" % (macro, file_config.get('db', macro))) + for macro in sorted(file_config.options("db")): + print("%20s\t%s" % (macro, file_config.get("db", macro))) sys.exit(0) @@ -202,11 +269,12 @@ def _requirements_opt(opt_str, value, parser): def _exclude_tag(opt_str, value, parser): - exclude_tags.add(value.replace('-', '_')) + exclude_tags.add(value.replace("-", "_")) def _include_tag(opt_str, value, parser): - include_tags.add(value.replace('-', '_')) + include_tags.add(value.replace("-", "_")) + pre_configure = [] post_configure = [] @@ -228,12 +296,12 @@ def _setup_options(opt, file_config): options = opt - @pre def _monkeypatch_cdecimal(options, file_config): if options.cdecimal: import cdecimal - sys.modules['decimal'] = cdecimal + + sys.modules["decimal"] = cdecimal @post @@ -248,26 +316,27 @@ def _engine_uri(options, file_config): if options.db: for db_token in options.db: - for db in re.split(r'[,\s]+', db_token): - if db not in file_config.options('db'): + for db in re.split(r"[,\s]+", db_token): + if db not in file_config.options("db"): raise RuntimeError( "Unknown URI specifier '%s'. " - "Specify --dbs for known uris." - % db) + "Specify --dbs for known uris." % db + ) else: - db_urls.append(file_config.get('db', db)) + db_urls.append(file_config.get("db", db)) if not db_urls: - db_urls.append(file_config.get('db', 'default')) + db_urls.append(file_config.get("db", "default")) for db_url in db_urls: - if options.write_idents and provision.FOLLOWER_IDENT: # != 'master': + if options.write_idents and provision.FOLLOWER_IDENT: # != 'master': with open(options.write_idents, "a") as file_: file_.write(provision.FOLLOWER_IDENT + " " + db_url + "\n") cfg = provision.setup_config( - db_url, options, file_config, provision.FOLLOWER_IDENT) + db_url, options, file_config, provision.FOLLOWER_IDENT + ) if not config._current: cfg.set_as_current(cfg) @@ -276,7 +345,7 @@ def _engine_uri(options, file_config): @post def _requirements(options, file_config): - requirement_cls = file_config.get('sqla_testing', "requirement_cls") + requirement_cls = file_config.get("sqla_testing", "requirement_cls") _setup_requirements(requirement_cls) @@ -317,56 +386,75 @@ def _prep_testing_database(options, file_config): pass else: for vname in view_names: - e.execute(schema._DropView( - schema.Table(vname, schema.MetaData()) - )) + e.execute( + schema._DropView( + schema.Table(vname, schema.MetaData()) + ) + ) if config.requirements.schemas.enabled_for_config(cfg): try: - view_names = inspector.get_view_names( - schema="test_schema") + view_names = inspector.get_view_names(schema="test_schema") except NotImplementedError: pass else: for vname in view_names: - e.execute(schema._DropView( - schema.Table(vname, schema.MetaData(), - schema="test_schema") - )) - - for tname in reversed(inspector.get_table_names( - order_by="foreign_key")): - e.execute(schema.DropTable( - schema.Table(tname, schema.MetaData()) - )) + e.execute( + schema._DropView( + schema.Table( + vname, + schema.MetaData(), + schema="test_schema", + ) + ) + ) + + for tname in reversed( + inspector.get_table_names(order_by="foreign_key") + ): + e.execute( + schema.DropTable(schema.Table(tname, schema.MetaData())) + ) if config.requirements.schemas.enabled_for_config(cfg): - for tname in reversed(inspector.get_table_names( - order_by="foreign_key", schema="test_schema")): - e.execute(schema.DropTable( - schema.Table(tname, schema.MetaData(), - schema="test_schema") - )) + for tname in reversed( + inspector.get_table_names( + order_by="foreign_key", schema="test_schema" + ) + ): + e.execute( + schema.DropTable( + schema.Table( + tname, schema.MetaData(), schema="test_schema" + ) + ) + ) if against(cfg, "postgresql") and util.sqla_100: from sqlalchemy.dialects import postgresql + for enum in inspector.get_enums("*"): - e.execute(postgresql.DropEnumType( - postgresql.ENUM( - name=enum['name'], - schema=enum['schema']))) + e.execute( + postgresql.DropEnumType( + postgresql.ENUM( + name=enum["name"], schema=enum["schema"] + ) + ) + ) @post def _reverse_topological(options, file_config): if options.reversetop: from sqlalchemy.orm.util import randomize_unitofwork + randomize_unitofwork() @post def _post_setup_options(opt, file_config): from alembic.testing import config + config.options = options config.file_config = file_config @@ -374,10 +462,11 @@ def _post_setup_options(opt, file_config): def want_class(cls): if not issubclass(cls, fixtures.TestBase): return False - elif cls.__name__.startswith('_'): + elif cls.__name__.startswith("_"): return False - elif config.options.backend_only and not getattr(cls, '__backend__', - False): + elif config.options.backend_only and not getattr( + cls, "__backend__", False + ): return False else: return True @@ -390,25 +479,28 @@ def want_method(cls, fn): return False elif include_tags: return ( - hasattr(cls, '__tags__') and - exclusions.tags(cls.__tags__).include_test( - include_tags, exclude_tags) + hasattr(cls, "__tags__") + and exclusions.tags(cls.__tags__).include_test( + include_tags, exclude_tags + ) ) or ( - hasattr(fn, '_sa_exclusion_extend') and - fn._sa_exclusion_extend.include_test( - include_tags, exclude_tags) + hasattr(fn, "_sa_exclusion_extend") + and fn._sa_exclusion_extend.include_test( + include_tags, exclude_tags + ) ) - elif exclude_tags and hasattr(cls, '__tags__'): + elif exclude_tags and hasattr(cls, "__tags__"): return exclusions.tags(cls.__tags__).include_test( - include_tags, exclude_tags) - elif exclude_tags and hasattr(fn, '_sa_exclusion_extend'): + include_tags, exclude_tags + ) + elif exclude_tags and hasattr(fn, "_sa_exclusion_extend"): return fn._sa_exclusion_extend.include_test(include_tags, exclude_tags) else: return True def generate_sub_tests(cls, module): - if getattr(cls, '__backend__', False): + if getattr(cls, "__backend__", False): for cfg in _possible_configs_for_cls(cls): orig_name = cls.__name__ @@ -416,17 +508,14 @@ def generate_sub_tests(cls, module): # pytest junit plugin, which is tripped up by the brackets # and periods, so sanitize - alpha_name = re.sub(r'[_\[\]\.]+', '_', cfg.name) - alpha_name = re.sub('_+$', '', alpha_name) + alpha_name = re.sub(r"[_\[\]\.]+", "_", cfg.name) + alpha_name = re.sub("_+$", "", alpha_name) name = "%s_%s" % (cls.__name__, alpha_name) subcls = type( name, - (cls, ), - { - "_sa_orig_cls_name": orig_name, - "__only_on_config__": cfg - } + (cls,), + {"_sa_orig_cls_name": orig_name, "__only_on_config__": cfg}, ) setattr(module, name, subcls) yield subcls @@ -440,8 +529,8 @@ def start_test_class(cls): def stop_test_class(cls): - #from sqlalchemy import inspect - #assert not inspect(testing.db).get_table_names() + # from sqlalchemy import inspect + # assert not inspect(testing.db).get_table_names() _restore_engine() @@ -450,7 +539,7 @@ def _restore_engine(): def _setup_engine(cls): - if getattr(cls, '__engine_options__', None): + if getattr(cls, "__engine_options__", None): eng = engines.testing_engine(options=cls.__engine_options__) config._current.push_engine(eng) @@ -472,16 +561,16 @@ def _possible_configs_for_cls(cls, reasons=None): if spec(config_obj): all_configs.remove(config_obj) - if getattr(cls, '__only_on__', None): + if getattr(cls, "__only_on__", None): spec = exclusions.db_spec(*util.to_list(cls.__only_on__)) for config_obj in list(all_configs): if not spec(config_obj): all_configs.remove(config_obj) - if getattr(cls, '__only_on_config__', None): + if getattr(cls, "__only_on_config__", None): all_configs.intersection_update([cls.__only_on_config__]) - if hasattr(cls, '__requires__'): + if hasattr(cls, "__requires__"): requirements = config.requirements for config_obj in list(all_configs): for requirement in cls.__requires__: @@ -494,7 +583,7 @@ def _possible_configs_for_cls(cls, reasons=None): reasons.extend(skip_reasons) break - if hasattr(cls, '__prefer_requires__'): + if hasattr(cls, "__prefer_requires__"): non_preferred = set() requirements = config.requirements for config_obj in list(all_configs): @@ -513,30 +602,32 @@ def _do_skips(cls): reasons = [] all_configs = _possible_configs_for_cls(cls, reasons) - if getattr(cls, '__skip_if__', False): - for c in getattr(cls, '__skip_if__'): + if getattr(cls, "__skip_if__", False): + for c in getattr(cls, "__skip_if__"): if c(): - raise SkipTest("'%s' skipped by %s" % ( - cls.__name__, c.__name__) + raise SkipTest( + "'%s' skipped by %s" % (cls.__name__, c.__name__) ) if not all_configs: msg = "'%s' unsupported on any DB implementation %s%s" % ( cls.__name__, ", ".join( - "'%s(%s)+%s'" % ( + "'%s(%s)+%s'" + % ( config_obj.db.name, ".".join( - str(dig) for dig in - config_obj.db.dialect.server_version_info), - config_obj.db.driver + str(dig) + for dig in config_obj.db.dialect.server_version_info + ), + config_obj.db.driver, ) - for config_obj in config.Config.all_configs() + for config_obj in config.Config.all_configs() ), - ", ".join(reasons) + ", ".join(reasons), ) raise SkipTest(msg) - elif hasattr(cls, '__prefer_backends__'): + elif hasattr(cls, "__prefer_backends__"): non_preferred = set() spec = exclusions.db_spec(*util.to_list(cls.__prefer_backends__)) for config_obj in all_configs: diff --git a/alembic/testing/plugin/pytestplugin.py b/alembic/testing/plugin/pytestplugin.py index 4d0f340..cc5b69f 100644 --- a/alembic/testing/plugin/pytestplugin.py +++ b/alembic/testing/plugin/pytestplugin.py @@ -21,6 +21,7 @@ import os try: import xdist # noqa + has_xdist = True except ImportError: has_xdist = False @@ -32,30 +33,42 @@ def pytest_addoption(parser): def make_option(name, **kw): callback_ = kw.pop("callback", None) if callback_: + class CallableAction(argparse.Action): - def __call__(self, parser, namespace, - values, option_string=None): + def __call__( + self, parser, namespace, values, option_string=None + ): callback_(option_string, values, parser) + kw["action"] = CallableAction zeroarg_callback = kw.pop("zeroarg_callback", None) if zeroarg_callback: + class CallableAction(argparse.Action): - def __init__(self, option_strings, - dest, default=False, - required=False, help=None): - super(CallableAction, self).__init__( - option_strings=option_strings, - dest=dest, - nargs=0, - const=True, - default=default, - required=required, - help=help) - - def __call__(self, parser, namespace, - values, option_string=None): + def __init__( + self, + option_strings, + dest, + default=False, + required=False, + help=None, + ): + super(CallableAction, self).__init__( + option_strings=option_strings, + dest=dest, + nargs=0, + const=True, + default=default, + required=required, + help=help, + ) + + def __call__( + self, parser, namespace, values, option_string=None + ): zeroarg_callback(option_string, values, parser) + kw["action"] = CallableAction group.addoption(name, **kw) @@ -67,23 +80,24 @@ def pytest_addoption(parser): def pytest_configure(config): if hasattr(config, "slaveinput"): plugin_base.restore_important_follower_config(config.slaveinput) - plugin_base.configure_follower( - config.slaveinput["follower_ident"] - ) + plugin_base.configure_follower(config.slaveinput["follower_ident"]) else: - if config.option.write_idents and \ - os.path.exists(config.option.write_idents): + if config.option.write_idents and os.path.exists( + config.option.write_idents + ): os.remove(config.option.write_idents) plugin_base.pre_begin(config.option) - plugin_base.set_coverage_flag(bool(getattr(config.option, - "cov_source", False))) + plugin_base.set_coverage_flag( + bool(getattr(config.option, "cov_source", False)) + ) def pytest_sessionstart(session): plugin_base.post_begin() + if has_xdist: import uuid @@ -95,10 +109,12 @@ if has_xdist: node.slaveinput["follower_ident"] = "test_%s" % uuid.uuid4().hex[0:12] from alembic.testing import provision + provision.create_follower_db(node.slaveinput["follower_ident"]) def pytest_testnodedown(node, error): from alembic.testing import provision + provision.drop_follower_db(node.slaveinput["follower_ident"]) @@ -115,18 +131,19 @@ def pytest_collection_modifyitems(session, config, items): rebuilt_items = collections.defaultdict(list) items[:] = [ - item for item in - items if isinstance(item.parent, pytest.Instance)] + item for item in items if isinstance(item.parent, pytest.Instance) + ] test_classes = set(item.parent for item in items) for test_class in test_classes: for sub_cls in plugin_base.generate_sub_tests( - test_class.cls, test_class.parent.module): + test_class.cls, test_class.parent.module + ): if sub_cls is not test_class.cls: list_ = rebuilt_items[test_class.cls] for inst in pytest.Class( - sub_cls.__name__, - parent=test_class.parent.parent).collect(): + sub_cls.__name__, parent=test_class.parent.parent + ).collect(): list_.extend(inst.collect()) newitems = [] @@ -139,23 +156,29 @@ def pytest_collection_modifyitems(session, config, items): # seems like the functions attached to a test class aren't sorted already? # is that true and why's that? (when using unittest, they're sorted) - items[:] = sorted(newitems, key=lambda item: ( - item.parent.parent.parent.name, - item.parent.parent.name, - item.name - )) + items[:] = sorted( + newitems, + key=lambda item: ( + item.parent.parent.parent.name, + item.parent.parent.name, + item.name, + ), + ) def pytest_pycollect_makeitem(collector, name, obj): if inspect.isclass(obj) and plugin_base.want_class(obj): return pytest.Class(name, parent=collector) - elif inspect.isfunction(obj) and \ - isinstance(collector, pytest.Instance) and \ - plugin_base.want_method(collector.cls, obj): + elif ( + inspect.isfunction(obj) + and isinstance(collector, pytest.Instance) + and plugin_base.want_method(collector.cls, obj) + ): return pytest.Function(name, parent=collector) else: return [] + _current_class = None @@ -180,6 +203,7 @@ def pytest_runtest_setup(item): global _current_class class_teardown(item.parent.parent) _current_class = None + item.parent.parent.addfinalizer(finalize) test_setup(item) @@ -194,8 +218,9 @@ def pytest_runtest_teardown(item): def test_setup(item): - plugin_base.before_test(item, item.parent.module.__name__, - item.parent.cls, item.name) + plugin_base.before_test( + item, item.parent.module.__name__, item.parent.cls, item.name + ) def test_teardown(item): diff --git a/alembic/testing/provision.py b/alembic/testing/provision.py index 05a21d3..a5ce53c 100644 --- a/alembic/testing/provision.py +++ b/alembic/testing/provision.py @@ -30,6 +30,7 @@ class register(object): def decorate(fn): self.fns[dbname] = fn return self + return decorate def __call__(self, cfg, *arg): @@ -43,7 +44,7 @@ class register(object): if backend in self.fns: return self.fns[backend](cfg, *arg) else: - return self.fns['*'](cfg, *arg) + return self.fns["*"](cfg, *arg) def create_follower_db(follower_ident): @@ -86,9 +87,7 @@ def _configs_for_db_operation(): for cfg in config.Config.all_configs(): url = cfg.db.url backend = get_url_backend_name(url) - host_conf = ( - backend, - url.username, url.host, url.database) + host_conf = (backend, url.username, url.host, url.database) if host_conf not in hosts: yield cfg @@ -132,13 +131,13 @@ def _follower_url_from_main(url, ident): @_update_db_opts.for_db("mssql") def _mssql_update_db_opts(db_url, db_opts): - db_opts['legacy_schema_aliasing'] = False + db_opts["legacy_schema_aliasing"] = False @_follower_url_from_main.for_db("sqlite") def _sqlite_follower_url_from_main(url, ident): url = sa_url.make_url(url) - if not url.database or url.database == ':memory:': + if not url.database or url.database == ":memory:": return url else: return sa_url.make_url("sqlite:///%s.db" % ident) @@ -154,19 +153,20 @@ def _sqlite_post_configure_engine(url, engine, follower_ident): # as an attached if not follower_ident: dbapi_connection.execute( - 'ATTACH DATABASE "test_schema.db" AS test_schema') + 'ATTACH DATABASE "test_schema.db" AS test_schema' + ) else: dbapi_connection.execute( 'ATTACH DATABASE "%s_test_schema.db" AS test_schema' - % follower_ident) + % follower_ident + ) @_create_db.for_db("postgresql") def _pg_create_db(cfg, eng, ident): template_db = cfg.options.postgresql_templatedb - with eng.connect().execution_options( - isolation_level="AUTOCOMMIT") as conn: + with eng.connect().execution_options(isolation_level="AUTOCOMMIT") as conn: try: _pg_drop_db(cfg, conn, ident) except Exception: @@ -222,14 +222,15 @@ def _sqlite_create_db(cfg, eng, ident): @_drop_db.for_db("postgresql") def _pg_drop_db(cfg, eng, ident): - with eng.connect().execution_options( - isolation_level="AUTOCOMMIT") as conn: + with eng.connect().execution_options(isolation_level="AUTOCOMMIT") as conn: conn.execute( text( "select pg_terminate_backend(pid) from pg_stat_activity " "where usename=current_user and pid != pg_backend_pid() " "and datname=:dname" - ), dname=ident) + ), + dname=ident, + ) conn.execute("DROP DATABASE %s" % ident) @@ -258,7 +259,7 @@ def _oracle_create_db(cfg, eng, ident): conn.execute("create user %s identified by xe" % ident) conn.execute("create user %s_ts1 identified by xe" % ident) conn.execute("create user %s_ts2 identified by xe" % ident) - conn.execute("grant dba to %s" % (ident, )) + conn.execute("grant dba to %s" % (ident,)) conn.execute("grant unlimited tablespace to %s" % ident) conn.execute("grant unlimited tablespace to %s_ts1" % ident) conn.execute("grant unlimited tablespace to %s_ts2" % ident) @@ -316,8 +317,9 @@ def reap_oracle_dbs(idents_file): to_reap = conn.execute( "select u.username from all_users u where username " "like 'TEST_%' and not exists (select username " - "from v$session where username=u.username)") - all_names = set(username.lower() for (username, ) in to_reap) + "from v$session where username=u.username)" + ) + all_names = set(username.lower() for (username,) in to_reap) to_drop = set() for name in all_names: if name.endswith("_ts1") or name.endswith("_ts2"): @@ -334,15 +336,13 @@ def reap_oracle_dbs(idents_file): if _ora_drop_ignore(conn, username): dropped += 1 log.info( - "Dropped %d out of %d stale databases detected", - dropped, total) + "Dropped %d out of %d stale databases detected", dropped, total + ) @_follower_url_from_main.for_db("oracle") def _oracle_follower_url_from_main(url, ident): url = sa_url.make_url(url) url.username = ident - url.password = 'xe' + url.password = "xe" return url - - diff --git a/alembic/testing/requirements.py b/alembic/testing/requirements.py index 400642f..f25f5d7 100644 --- a/alembic/testing/requirements.py +++ b/alembic/testing/requirements.py @@ -5,6 +5,7 @@ from . import exclusions if util.sqla_094: from sqlalchemy.testing.requirements import Requirements else: + class Requirements(object): pass @@ -28,7 +29,7 @@ class SuiteRequirements(Requirements): insp = inspect(config.db) try: - insp.get_unique_constraints('x') + insp.get_unique_constraints("x") except NotImplementedError: return True except TypeError: @@ -62,83 +63,80 @@ class SuiteRequirements(Requirements): def fail_before_sqla_100(self): return exclusions.fails_if( lambda config: not util.sqla_100, - "SQLAlchemy 1.0.0 or greater required" + "SQLAlchemy 1.0.0 or greater required", ) @property def fail_before_sqla_1010(self): return exclusions.fails_if( lambda config: not util.sqla_1010, - "SQLAlchemy 1.0.10 or greater required" + "SQLAlchemy 1.0.10 or greater required", ) @property def fail_before_sqla_099(self): return exclusions.fails_if( lambda config: not util.sqla_099, - "SQLAlchemy 0.9.9 or greater required" + "SQLAlchemy 0.9.9 or greater required", ) @property def fail_before_sqla_110(self): return exclusions.fails_if( lambda config: not util.sqla_110, - "SQLAlchemy 1.1.0 or greater required" + "SQLAlchemy 1.1.0 or greater required", ) @property def sqlalchemy_092(self): return exclusions.skip_if( lambda config: not util.sqla_092, - "SQLAlchemy 0.9.2 or greater required" + "SQLAlchemy 0.9.2 or greater required", ) @property def sqlalchemy_094(self): return exclusions.skip_if( lambda config: not util.sqla_094, - "SQLAlchemy 0.9.4 or greater required" + "SQLAlchemy 0.9.4 or greater required", ) @property def sqlalchemy_099(self): return exclusions.skip_if( lambda config: not util.sqla_099, - "SQLAlchemy 0.9.9 or greater required" + "SQLAlchemy 0.9.9 or greater required", ) @property def sqlalchemy_100(self): return exclusions.skip_if( lambda config: not util.sqla_100, - "SQLAlchemy 1.0.0 or greater required" + "SQLAlchemy 1.0.0 or greater required", ) @property def sqlalchemy_1014(self): return exclusions.skip_if( lambda config: not util.sqla_1014, - "SQLAlchemy 1.0.14 or greater required" + "SQLAlchemy 1.0.14 or greater required", ) @property def sqlalchemy_1115(self): return exclusions.skip_if( lambda config: not util.sqla_1115, - "SQLAlchemy 1.1.15 or greater required" + "SQLAlchemy 1.1.15 or greater required", ) @property def sqlalchemy_110(self): return exclusions.skip_if( lambda config: not util.sqla_110, - "SQLAlchemy 1.1.0 or greater required" + "SQLAlchemy 1.1.0 or greater required", ) @property def pep3147(self): - return exclusions.only_if( - lambda config: util.compat.has_pep3147() - ) - + return exclusions.only_if(lambda config: util.compat.has_pep3147()) diff --git a/alembic/testing/runner.py b/alembic/testing/runner.py index d4adbcf..46236a0 100644 --- a/alembic/testing/runner.py +++ b/alembic/testing/runner.py @@ -45,4 +45,4 @@ def setup_py_test(): to nose. """ - nose.main(addplugins=[NoseSQLAlchemy()], argv=['runner']) + nose.main(addplugins=[NoseSQLAlchemy()], argv=["runner"]) diff --git a/alembic/testing/util.py b/alembic/testing/util.py index 466dea3..b2b3476 100644 --- a/alembic/testing/util.py +++ b/alembic/testing/util.py @@ -10,7 +10,7 @@ def provide_metadata(fn, *args, **kw): metadata = schema.MetaData(config.db) self = args[0] - prev_meta = getattr(self, 'metadata', None) + prev_meta = getattr(self, "metadata", None) self.metadata = metadata try: return fn(*args, **kw) diff --git a/alembic/testing/warnings.py b/alembic/testing/warnings.py index de91778..cb59a64 100644 --- a/alembic/testing/warnings.py +++ b/alembic/testing/warnings.py @@ -17,11 +17,12 @@ import re def setup_filters(): """Set global warning behavior for the test suite.""" - warnings.filterwarnings('ignore', - category=sa_exc.SAPendingDeprecationWarning) - warnings.filterwarnings('error', category=sa_exc.SADeprecationWarning) - warnings.filterwarnings('error', category=sa_exc.SAWarning) - warnings.filterwarnings('error', category=DeprecationWarning) + warnings.filterwarnings( + "ignore", category=sa_exc.SAPendingDeprecationWarning + ) + warnings.filterwarnings("error", category=sa_exc.SADeprecationWarning) + warnings.filterwarnings("error", category=sa_exc.SAWarning) + warnings.filterwarnings("error", category=DeprecationWarning) def assert_warnings(fn, warning_msgs, regex=False): diff --git a/alembic/util/__init__.py b/alembic/util/__init__.py index 1e6c645..e28f715 100644 --- a/alembic/util/__init__.py +++ b/alembic/util/__init__.py @@ -1,17 +1,45 @@ from .langhelpers import ( # noqa - asbool, rev_id, to_tuple, to_list, memoized_property, dedupe_tuple, - immutabledict, _with_legacy_names, Dispatcher, ModuleClsProxy) + asbool, + rev_id, + to_tuple, + to_list, + memoized_property, + dedupe_tuple, + immutabledict, + _with_legacy_names, + Dispatcher, + ModuleClsProxy, +) from .messaging import ( # noqa - write_outstream, status, err, obfuscate_url_pw, warn, msg, format_as_comma) + write_outstream, + status, + err, + obfuscate_url_pw, + warn, + msg, + format_as_comma, +) from .pyfiles import ( # noqa - template_to_file, coerce_resource_to_filename, - pyc_file_from_path, load_python_file, edit) + template_to_file, + coerce_resource_to_filename, + pyc_file_from_path, + load_python_file, + edit, +) from .sqla_compat import ( # noqa - sqla_09, sqla_092, sqla_094, sqla_099, sqla_100, sqla_105, sqla_110, sqla_1010, - sqla_1014, sqla_1115) + sqla_09, + sqla_092, + sqla_094, + sqla_099, + sqla_100, + sqla_105, + sqla_110, + sqla_1010, + sqla_1014, + sqla_1115, +) from .exc import CommandError if not sqla_09: - raise CommandError( - "SQLAlchemy 0.9.0 or greater is required. ") + raise CommandError("SQLAlchemy 0.9.0 or greater is required. ") diff --git a/alembic/util/compat.py b/alembic/util/compat.py index dec2ca8..7e07ed4 100644 --- a/alembic/util/compat.py +++ b/alembic/util/compat.py @@ -19,12 +19,13 @@ else: if py3k: import builtins as compat_builtins - string_types = str, + + string_types = (str,) binary_type = bytes text_type = str def callable(fn): - return hasattr(fn, '__call__') + return hasattr(fn, "__call__") def u(s): return s @@ -35,7 +36,8 @@ if py3k: range = range else: import __builtin__ as compat_builtins - string_types = basestring, + + string_types = (basestring,) binary_type = str text_type = unicode callable = callable @@ -55,16 +57,17 @@ else: if py3k: import collections + ArgSpec = collections.namedtuple( - "ArgSpec", - ["args", "varargs", "keywords", "defaults"]) + "ArgSpec", ["args", "varargs", "keywords", "defaults"] + ) from inspect import getfullargspec as inspect_getfullargspec def inspect_getargspec(func): - return ArgSpec( - *inspect_getfullargspec(func)[0:4] - ) + return ArgSpec(*inspect_getfullargspec(func)[0:4]) + + else: from inspect import getargspec as inspect_getargspec # noqa @@ -72,14 +75,20 @@ if py35: from inspect import formatannotation def inspect_formatargspec( - args, varargs=None, varkw=None, defaults=None, - kwonlyargs=(), kwonlydefaults={}, annotations={}, - formatarg=str, - formatvarargs=lambda name: '*' + name, - formatvarkw=lambda name: '**' + name, - formatvalue=lambda value: '=' + repr(value), - formatreturns=lambda text: ' -> ' + text, - formatannotation=formatannotation): + args, + varargs=None, + varkw=None, + defaults=None, + kwonlyargs=(), + kwonlydefaults={}, + annotations={}, + formatarg=str, + formatvarargs=lambda name: "*" + name, + formatvarkw=lambda name: "**" + name, + formatvalue=lambda value: "=" + repr(value), + formatreturns=lambda text: " -> " + text, + formatannotation=formatannotation, + ): """Copy formatargspec from python 3.7 standard library. Python 3 has deprecated formatargspec and requested that Signature @@ -93,8 +102,9 @@ if py35: def formatargandannotation(arg): result = formatarg(arg) if arg in annotations: - result += ': ' + formatannotation(annotations[arg]) + result += ": " + formatannotation(annotations[arg]) return result + specs = [] if defaults: firstdefault = len(args) - len(defaults) @@ -107,7 +117,7 @@ if py35: specs.append(formatvarargs(formatargandannotation(varargs))) else: if kwonlyargs: - specs.append('*') + specs.append("*") if kwonlyargs: for kwonlyarg in kwonlyargs: spec = formatargandannotation(kwonlyarg) @@ -116,11 +126,12 @@ if py35: specs.append(spec) if varkw is not None: specs.append(formatvarkw(formatargandannotation(varkw))) - result = '(' + ', '.join(specs) + ')' - if 'return' in annotations: - result += formatreturns(formatannotation(annotations['return'])) + result = "(" + ", ".join(specs) + ")" + if "return" in annotations: + result += formatreturns(formatannotation(annotations["return"])) return result + else: from inspect import formatargspec as inspect_formatargspec @@ -151,22 +162,27 @@ if py35: spec.loader.exec_module(module) return module + elif py3k: import importlib.machinery def load_module_py(module_id, path): module = importlib.machinery.SourceFileLoader( - module_id, path).load_module(module_id) + module_id, path + ).load_module(module_id) del sys.modules[module_id] return module def load_module_pyc(module_id, path): module = importlib.machinery.SourcelessFileLoader( - module_id, path).load_module(module_id) + module_id, path + ).load_module(module_id) del sys.modules[module_id] return module + if py3k: + def get_bytecode_suffixes(): try: return importlib.machinery.BYTECODE_SUFFIXES @@ -188,13 +204,15 @@ if py3k: # http://www.python.org/dev/peps/pep-3147/#detecting-pep-3147-availability import imp - return hasattr(imp, 'get_tag') + + return hasattr(imp, "get_tag") + else: import imp def load_module_py(module_id, path): # noqa - with open(path, 'rb') as fp: + with open(path, "rb") as fp: mod = imp.load_source(module_id, path, fp) if py2k: source_encoding = parse_encoding(fp) @@ -204,7 +222,7 @@ else: return mod def load_module_pyc(module_id, path): # noqa - with open(path, 'rb') as fp: + with open(path, "rb") as fp: mod = imp.load_compiled(module_id, path, fp) # no source encoding here del sys.modules[module_id] @@ -219,12 +237,14 @@ else: def has_pep3147(): return False + try: - exec_ = getattr(compat_builtins, 'exec') + exec_ = getattr(compat_builtins, "exec") except AttributeError: # Python 2 def exec_(func_text, globals_, lcl): - exec('exec func_text in globals_, lcl') + exec("exec func_text in globals_, lcl") + ################################################ # cross-compatible metaclass implementation @@ -234,9 +254,12 @@ except AttributeError: def with_metaclass(meta, base=object): """Create a base class with a metaclass.""" return meta("%sBase" % meta.__name__, (base,), {}) + + ################################################ if py3k: + def reraise(tp, value, tb=None, cause=None): if cause is not None: value.__cause__ = cause @@ -249,9 +272,13 @@ if py3k: exc_info = sys.exc_info() exc_type, exc_value, exc_tb = exc_info reraise(type(exception), exception, tb=exc_tb, cause=exc_value) + + else: - exec("def reraise(tp, value, tb=None, cause=None):\n" - " raise tp, value, tb\n") + exec( + "def reraise(tp, value, tb=None, cause=None):\n" + " raise tp, value, tb\n" + ) def raise_from_cause(exception, exc_info=None): # not as nice as that of Py3K, but at least preserves @@ -261,14 +288,15 @@ else: exc_type, exc_value, exc_tb = exc_info reraise(type(exception), exception, tb=exc_tb) + # produce a wrapper that allows encoded text to stream # into a given buffer, but doesn't close it. # not sure of a more idiomatic approach to this. class EncodedIO(io.TextIOWrapper): - def close(self): pass + if py2k: # in Py2K, the io.* package is awkward because it does not # easily wrap the file type (e.g. sys.stdout) and I can't @@ -303,7 +331,7 @@ if py2k: return self.file_.flush() class EncodedIO(EncodedIO): - def __init__(self, file_, encoding): super(EncodedIO, self).__init__( - ActLikePy3kIO(file_), encoding=encoding) + ActLikePy3kIO(file_), encoding=encoding + ) diff --git a/alembic/util/langhelpers.py b/alembic/util/langhelpers.py index 832332c..a298cc0 100644 --- a/alembic/util/langhelpers.py +++ b/alembic/util/langhelpers.py @@ -37,23 +37,21 @@ class ModuleClsProxy(with_metaclass(_ModuleClsMeta)): def _install_proxy(self): attr_names, modules = self._setups[self.__class__] for globals_, locals_ in modules: - globals_['_proxy'] = self + globals_["_proxy"] = self for attr_name in attr_names: globals_[attr_name] = getattr(self, attr_name) def _remove_proxy(self): attr_names, modules = self._setups[self.__class__] for globals_, locals_ in modules: - globals_['_proxy'] = None + globals_["_proxy"] = None for attr_name in attr_names: del globals_[attr_name] @classmethod def create_module_class_proxy(cls, globals_, locals_): attr_names, modules = cls._setups[cls] - modules.append( - (globals_, locals_) - ) + modules.append((globals_, locals_)) cls._setup_proxy(globals_, locals_, attr_names) @classmethod @@ -63,11 +61,12 @@ class ModuleClsProxy(with_metaclass(_ModuleClsMeta)): @classmethod def _add_proxied_attribute(cls, methname, globals_, locals_, attr_names): - if not methname.startswith('_'): + if not methname.startswith("_"): meth = getattr(cls, methname) if callable(meth): locals_[methname] = cls._create_method_proxy( - methname, globals_, locals_) + methname, globals_, locals_ + ) else: attr_names.add(methname) @@ -75,7 +74,7 @@ class ModuleClsProxy(with_metaclass(_ModuleClsMeta)): def _create_method_proxy(cls, name, globals_, locals_): fn = getattr(cls, name) spec = inspect_getargspec(fn) - if spec[0] and spec[0][0] == 'self': + if spec[0] and spec[0][0] == "self": spec[0].pop(0) args = inspect_formatargspec(*spec) num_defaults = 0 @@ -83,24 +82,28 @@ class ModuleClsProxy(with_metaclass(_ModuleClsMeta)): num_defaults += len(spec[3]) name_args = spec[0] if num_defaults: - defaulted_vals = name_args[0 - num_defaults:] + defaulted_vals = name_args[0 - num_defaults :] else: defaulted_vals = () apply_kw = inspect_formatargspec( - name_args, spec[1], spec[2], + name_args, + spec[1], + spec[2], defaulted_vals, - formatvalue=lambda x: '=' + x) + formatvalue=lambda x: "=" + x, + ) def _name_error(name): raise NameError( "Can't invoke function '%s', as the proxy object has " "not yet been " "established for the Alembic '%s' class. " - "Try placing this code inside a callable." % ( - name, cls.__name__ - )) - globals_['_name_error'] = _name_error + "Try placing this code inside a callable." + % (name, cls.__name__) + ) + + globals_["_name_error"] = _name_error translations = getattr(fn, "_legacy_translations", []) if translations: @@ -108,7 +111,7 @@ class ModuleClsProxy(with_metaclass(_ModuleClsMeta)): translate_str = "args, kw = _translate(%r, %r, %r, args, kw)" % ( fn.__name__, tuple(spec), - translations + translations, ) def translate(fn_name, spec, translations, args, kw): @@ -119,15 +122,14 @@ class ModuleClsProxy(with_metaclass(_ModuleClsMeta)): if oldname in kw: warnings.warn( "Argument %r is now named %r " - "for method %s()." % ( - oldname, newname, fn_name - )) + "for method %s()." % (oldname, newname, fn_name) + ) return_kw[newname] = kw.pop(oldname) return_kw.update(kw) args = list(args) if spec[3]: - pos_only = spec[0][:-len(spec[3])] + pos_only = spec[0][: -len(spec[3])] else: pos_only = spec[0] for arg in pos_only: @@ -137,17 +139,20 @@ class ModuleClsProxy(with_metaclass(_ModuleClsMeta)): except IndexError: raise TypeError( "missing required positional argument: %s" - % arg) + % arg + ) return_args.extend(args) return return_args, return_kw - globals_['_translate'] = translate + + globals_["_translate"] = translate else: outer_args = args[1:-1] inner_args = apply_kw[1:-1] translate_str = "" - func_text = textwrap.dedent("""\ + func_text = textwrap.dedent( + """\ def %(name)s(%(args)s): %(doc)r %(translate)s @@ -157,13 +162,15 @@ class ModuleClsProxy(with_metaclass(_ModuleClsMeta)): _name_error('%(name)s') return _proxy.%(name)s(%(apply_kw)s) e - """ % { - 'name': name, - 'translate': translate_str, - 'args': outer_args, - 'apply_kw': inner_args, - 'doc': fn.__doc__, - }) + """ + % { + "name": name, + "translate": translate_str, + "args": outer_args, + "apply_kw": inner_args, + "doc": fn.__doc__, + } + ) lcl = {} exec_(func_text, globals_, lcl) return lcl[name] @@ -178,8 +185,7 @@ def _with_legacy_names(translations): def asbool(value): - return value is not None and \ - value.lower() == 'true' + return value is not None and value.lower() == "true" def rev_id(): @@ -201,31 +207,30 @@ def to_tuple(x, default=None): if x is None: return default elif isinstance(x, string_types): - return (x, ) + return (x,) elif isinstance(x, collections_abc.Iterable): return tuple(x) else: - return (x, ) + return (x,) def unique_list(seq, hashfunc=None): seen = set() seen_add = seen.add if not hashfunc: - return [x for x in seq - if x not in seen - and not seen_add(x)] + return [x for x in seq if x not in seen and not seen_add(x)] else: - return [x for x in seq - if hashfunc(x) not in seen - and not seen_add(hashfunc(x))] + return [ + x + for x in seq + if hashfunc(x) not in seen and not seen_add(hashfunc(x)) + ] def dedupe_tuple(tup): return tuple(unique_list(tup)) - class memoized_property(object): """A read-only @property that is only evaluated once.""" @@ -243,13 +248,12 @@ class memoized_property(object): class immutabledict(dict): - def _immutable(self, *arg, **kw): raise TypeError("%s object is immutable" % self.__class__.__name__) - __delitem__ = __setitem__ = __setattr__ = \ - clear = pop = popitem = setdefault = \ - update = _immutable + __delitem__ = ( + __setitem__ + ) = __setattr__ = clear = pop = popitem = setdefault = update = _immutable def __new__(cls, *args): new = dict.__new__(cls) @@ -260,7 +264,7 @@ class immutabledict(dict): pass def __reduce__(self): - return immutabledict, (dict(self), ) + return immutabledict, (dict(self),) def union(self, d): if not self: @@ -279,7 +283,7 @@ class Dispatcher(object): self._registry = {} self.uselist = uselist - def dispatch_for(self, target, qualifier='default'): + def dispatch_for(self, target, qualifier="default"): def decorate(fn): if self.uselist: self._registry.setdefault((target, qualifier), []).append(fn) @@ -287,9 +291,10 @@ class Dispatcher(object): assert (target, qualifier) not in self._registry self._registry[(target, qualifier)] = fn return fn + return decorate - def dispatch(self, obj, qualifier='default'): + def dispatch(self, obj, qualifier="default"): if isinstance(obj, string_types): targets = [obj] @@ -299,20 +304,20 @@ class Dispatcher(object): targets = type(obj).__mro__ for spcls in targets: - if qualifier != 'default' and (spcls, qualifier) in self._registry: - return self._fn_or_list( - self._registry[(spcls, qualifier)]) - elif (spcls, 'default') in self._registry: - return self._fn_or_list( - self._registry[(spcls, 'default')]) + if qualifier != "default" and (spcls, qualifier) in self._registry: + return self._fn_or_list(self._registry[(spcls, qualifier)]) + elif (spcls, "default") in self._registry: + return self._fn_or_list(self._registry[(spcls, "default")]) else: raise ValueError("no dispatch function for object: %s" % obj) def _fn_or_list(self, fn_or_list): if self.uselist: + def go(*arg, **kw): for fn in fn_or_list: fn(*arg, **kw) + return go else: return fn_or_list @@ -324,8 +329,7 @@ class Dispatcher(object): d = Dispatcher() if self.uselist: d._registry.update( - (k, [fn for fn in self._registry[k]]) - for k in self._registry + (k, [fn for fn in self._registry[k]]) for k in self._registry ) else: d._registry.update(self._registry) diff --git a/alembic/util/messaging.py b/alembic/util/messaging.py index 872345b..44eacbf 100644 --- a/alembic/util/messaging.py +++ b/alembic/util/messaging.py @@ -11,16 +11,16 @@ log = logging.getLogger(__name__) if py27: # disable "no handler found" errors - logging.getLogger('alembic').addHandler(logging.NullHandler()) + logging.getLogger("alembic").addHandler(logging.NullHandler()) try: import fcntl import termios import struct - ioctl = fcntl.ioctl(0, termios.TIOCGWINSZ, - struct.pack('HHHH', 0, 0, 0, 0)) - _h, TERMWIDTH, _hp, _wp = struct.unpack('HHHH', ioctl) + + ioctl = fcntl.ioctl(0, termios.TIOCGWINSZ, struct.pack("HHHH", 0, 0, 0, 0)) + _h, TERMWIDTH, _hp, _wp = struct.unpack("HHHH", ioctl) if TERMWIDTH <= 0: # can occur if running in emacs pseudo-tty TERMWIDTH = None except (ImportError, IOError): @@ -28,10 +28,10 @@ except (ImportError, IOError): def write_outstream(stream, *text): - encoding = getattr(stream, 'encoding', 'ascii') or 'ascii' + encoding = getattr(stream, "encoding", "ascii") or "ascii" for t in text: if not isinstance(t, binary_type): - t = t.encode(encoding, 'replace') + t = t.encode(encoding, "replace") t = t.decode(encoding) try: stream.write(t) @@ -62,7 +62,7 @@ def err(message): def obfuscate_url_pw(u): u = url.make_url(u) if u.password: - u.password = 'XXXXX' + u.password = "XXXXX" return str(u) diff --git a/alembic/util/pyfiles.py b/alembic/util/pyfiles.py index 0e52133..4093b89 100644 --- a/alembic/util/pyfiles.py +++ b/alembic/util/pyfiles.py @@ -1,8 +1,12 @@ import sys import os import re -from .compat import load_module_py, load_module_pyc, \ - get_current_bytecode_suffixes, has_pep3147 +from .compat import ( + load_module_py, + load_module_pyc, + get_current_bytecode_suffixes, + has_pep3147, +) from mako.template import Template from mako import exceptions import tempfile @@ -14,16 +18,19 @@ def template_to_file(template_file, dest, output_encoding, **kw): try: output = template.render_unicode(**kw).encode(output_encoding) except: - with tempfile.NamedTemporaryFile(suffix='.txt', delete=False) as ntf: + with tempfile.NamedTemporaryFile(suffix=".txt", delete=False) as ntf: ntf.write( - exceptions.text_error_template(). - render_unicode().encode(output_encoding)) + exceptions.text_error_template() + .render_unicode() + .encode(output_encoding) + ) fname = ntf.name raise CommandError( "Template rendering failed; see %s for a " - "template-oriented traceback." % fname) + "template-oriented traceback." % fname + ) else: - with open(dest, 'wb') as f: + with open(dest, "wb") as f: f.write(output) @@ -37,7 +44,8 @@ def coerce_resource_to_filename(fname): """ if not os.path.isabs(fname) and ":" in fname: import pkg_resources - fname = pkg_resources.resource_filename(*fname.split(':')) + + fname = pkg_resources.resource_filename(*fname.split(":")) return fname @@ -48,6 +56,7 @@ def pyc_file_from_path(path): if has_pep3147(): import imp + candidate = imp.cache_from_source(path) if os.path.exists(candidate): return candidate @@ -64,16 +73,17 @@ def edit(path): """Given a source path, run the EDITOR for it""" import editor + try: editor.edit(path) except Exception as exc: - raise CommandError('Error executing editor (%s)' % (exc,)) + raise CommandError("Error executing editor (%s)" % (exc,)) def load_python_file(dir_, filename): """Load a file from the given path as a Python module.""" - module_id = re.sub(r'\W', "_", filename) + module_id = re.sub(r"\W", "_", filename) path = os.path.join(dir_, filename) _, ext = os.path.splitext(filename) if ext == ".py": diff --git a/alembic/util/sqla_compat.py b/alembic/util/sqla_compat.py index 0556124..63c9798 100644 --- a/alembic/util/sqla_compat.py +++ b/alembic/util/sqla_compat.py @@ -15,8 +15,11 @@ def _safe_int(value): return int(value) except: return value + + _vers = tuple( - [_safe_int(x) for x in re.findall(r'(\d+|[abc]\d)', __version__)]) + [_safe_int(x) for x in re.findall(r"(\d+|[abc]\d)", __version__)] +) sqla_09 = _vers >= (0, 9, 0) sqla_092 = _vers >= (0, 9, 2) sqla_094 = _vers >= (0, 9, 4) @@ -31,7 +34,7 @@ sqla_1115 = _vers >= (1, 1, 15) if sqla_110: - AUTOINCREMENT_DEFAULT = 'auto' + AUTOINCREMENT_DEFAULT = "auto" else: AUTOINCREMENT_DEFAULT = True @@ -55,10 +58,12 @@ def _columns_for_constraint(constraint): def _fk_spec(constraint): if sqla_100: source_columns = [ - constraint.columns[key].name for key in constraint.column_keys] + constraint.columns[key].name for key in constraint.column_keys + ] else: source_columns = [ - element.parent.name for element in constraint.elements] + element.parent.name for element in constraint.elements + ] source_table = constraint.parent.name source_schema = constraint.parent.schema @@ -70,9 +75,17 @@ def _fk_spec(constraint): deferrable = constraint.deferrable initially = constraint.initially return ( - source_schema, source_table, - source_columns, target_schema, target_table, target_columns, - onupdate, ondelete, deferrable, initially) + source_schema, + source_table, + source_columns, + target_schema, + target_table, + target_columns, + onupdate, + ondelete, + deferrable, + initially, + ) def _fk_is_self_referential(constraint): @@ -91,11 +104,9 @@ def _is_type_bound(constraint): return constraint._type_bound else: # old way, look at what we know Boolean/Enum to use - return ( - constraint._create_rule is not None and - isinstance( - getattr(constraint._create_rule, "target", None), - sqltypes.SchemaType) + return constraint._create_rule is not None and isinstance( + getattr(constraint._create_rule, "target", None), + sqltypes.SchemaType, ) @@ -103,7 +114,7 @@ def _find_columns(clause): """locate Column objects within the given expression.""" cols = set() - traverse(clause, {}, {'column': cols.add}) + traverse(clause, {}, {"column": cols.add}) return cols @@ -143,7 +154,8 @@ class _textual_index_element(sql.ColumnElement): See SQLAlchemy issue 3174. """ - __visit_name__ = '_textual_idx_element' + + __visit_name__ = "_textual_idx_element" def __init__(self, table, text): self.table = table @@ -198,7 +210,7 @@ def _get_index_final_name(dialect, idx): def _is_mariadb(mysql_dialect): - return 'MariaDB' in mysql_dialect.server_version_info + return "MariaDB" in mysql_dialect.server_version_info def _mariadb_normalized_version_info(mysql_dialect): @@ -15,6 +15,21 @@ identity = C4DAFEE1 with-sqla_testing = true where = tests +[flake8] +show-source = true +enable-extensions = G +# E203 is due to https://github.com/PyCQA/pycodestyle/issues/373 +ignore = + A003, + D, + E203,E305,E711,E712,E721,E722,E741, + N801,N802,N806, + RST304,RST303,RST299,RST399, + W503,W504 +exclude = .venv,.git,.tox,dist,doc,*egg,build +import-order-style = google +application-import-names = alembic,tests + [sqla_testing] requirement_cls=tests.requirements:DefaultRequirements @@ -5,23 +5,23 @@ import re import sys -v = open(os.path.join(os.path.dirname(__file__), 'alembic', '__init__.py')) -VERSION = re.compile(r".*__version__ = '(.*?)'", re.S).match(v.read()).group(1) +v = open(os.path.join(os.path.dirname(__file__), "alembic", "__init__.py")) +VERSION = re.compile(r""".*__version__ = ["'](.*?)["']""", re.S).match(v.read()).group(1) v.close() -readme = os.path.join(os.path.dirname(__file__), 'README.rst') +readme = os.path.join(os.path.dirname(__file__), "README.rst") requires = [ - 'SQLAlchemy>=0.9.0', - 'Mako', - 'python-editor>=0.3', - 'python-dateutil' + "SQLAlchemy>=0.9.0", + "Mako", + "python-editor>=0.3", + "python-dateutil", ] class PyTest(TestCommand): - user_options = [('pytest-args=', 'a', "Arguments to pass to py.test")] + user_options = [("pytest-args=", "a", "Arguments to pass to py.test")] def initialize_options(self): TestCommand.initialize_options(self) @@ -35,42 +35,42 @@ class PyTest(TestCommand): def run_tests(self): # import here, cause outside the eggs aren't loaded import pytest + errno = pytest.main(self.pytest_args) sys.exit(errno) -setup(name='alembic', - version=VERSION, - description="A database migration tool for SQLAlchemy.", - long_description=open(readme).read(), - python_requires='>=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*', - classifiers=[ - 'Development Status :: 5 - Production/Stable', - 'Environment :: Console', - 'Intended Audience :: Developers', - 'Programming Language :: Python', - 'Programming Language :: Python :: 2', - 'Programming Language :: Python :: 2.7', - 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.4', - 'Programming Language :: Python :: 3.5', - 'Programming Language :: Python :: 3.6', - 'Programming Language :: Python :: Implementation :: CPython', - 'Programming Language :: Python :: Implementation :: PyPy', - 'Topic :: Database :: Front-Ends', - ], - keywords='SQLAlchemy migrations', - author='Mike Bayer', - author_email='mike@zzzcomputing.com', - url='https://alembic.sqlalchemy.org', - license='MIT', - packages=find_packages('.', exclude=['examples*', 'test*']), - include_package_data=True, - tests_require=['pytest!=3.9.1,!=3.9.2', 'mock', 'Mako'], - cmdclass={'test': PyTest}, - zip_safe=False, - install_requires=requires, - entry_points={ - 'console_scripts': ['alembic = alembic.config:main'], - } - ) +setup( + name="alembic", + version=VERSION, + description="A database migration tool for SQLAlchemy.", + long_description=open(readme).read(), + python_requires=">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*", + classifiers=[ + "Development Status :: 5 - Production/Stable", + "Environment :: Console", + "Intended Audience :: Developers", + "Programming Language :: Python", + "Programming Language :: Python :: 2", + "Programming Language :: Python :: 2.7", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.4", + "Programming Language :: Python :: 3.5", + "Programming Language :: Python :: 3.6", + "Programming Language :: Python :: Implementation :: CPython", + "Programming Language :: Python :: Implementation :: PyPy", + "Topic :: Database :: Front-Ends", + ], + keywords="SQLAlchemy migrations", + author="Mike Bayer", + author_email="mike@zzzcomputing.com", + url="https://alembic.sqlalchemy.org", + license="MIT", + packages=find_packages(".", exclude=["examples*", "test*"]), + include_package_data=True, + tests_require=["pytest!=3.9.1,!=3.9.2", "mock", "Mako"], + cmdclass={"test": PyTest}, + zip_safe=False, + install_requires=requires, + entry_points={"console_scripts": ["alembic = alembic.config:main"]}, +) diff --git a/tests/_autogen_fixtures.py b/tests/_autogen_fixtures.py index 94c6866..4bda756 100644 --- a/tests/_autogen_fixtures.py +++ b/tests/_autogen_fixtures.py @@ -1,5 +1,18 @@ -from sqlalchemy import MetaData, Column, Table, Integer, String, Text, \ - Numeric, CHAR, ForeignKey, Index, UniqueConstraint, CheckConstraint, text +from sqlalchemy import ( + MetaData, + Column, + Table, + Integer, + String, + Text, + Numeric, + CHAR, + ForeignKey, + Index, + UniqueConstraint, + CheckConstraint, + text, +) from sqlalchemy.engine.reflection import Inspector from alembic.operations import ops @@ -27,11 +40,12 @@ def _default_include_object(obj, name, type_, reflected, compare_to): else: return True + _default_object_filters = _default_include_object class ModelOne(object): - __requires__ = ('unique_constraint_reflection', ) + __requires__ = ("unique_constraint_reflection",) schema = None @@ -41,30 +55,42 @@ class ModelOne(object): m = MetaData(schema=schema) - Table('user', m, - Column('id', Integer, primary_key=True), - Column('name', String(50)), - Column('a1', Text), - Column("pw", String(50)), - Index('pw_idx', 'pw') - ) - - Table('address', m, - Column('id', Integer, primary_key=True), - Column('email_address', String(100), nullable=False), - ) - - Table('order', m, - Column('order_id', Integer, primary_key=True), - Column("amount", Numeric(8, 2), nullable=False, - server_default=text("0")), - CheckConstraint('amount >= 0', name='ck_order_amount') - ) - - Table('extra', m, - Column("x", CHAR), - Column('uid', Integer, ForeignKey('user.id')) - ) + Table( + "user", + m, + Column("id", Integer, primary_key=True), + Column("name", String(50)), + Column("a1", Text), + Column("pw", String(50)), + Index("pw_idx", "pw"), + ) + + Table( + "address", + m, + Column("id", Integer, primary_key=True), + Column("email_address", String(100), nullable=False), + ) + + Table( + "order", + m, + Column("order_id", Integer, primary_key=True), + Column( + "amount", + Numeric(8, 2), + nullable=False, + server_default=text("0"), + ), + CheckConstraint("amount >= 0", name="ck_order_amount"), + ) + + Table( + "extra", + m, + Column("x", CHAR), + Column("uid", Integer, ForeignKey("user.id")), + ) return m @@ -74,50 +100,80 @@ class ModelOne(object): m = MetaData(schema=schema) - Table('user', m, - Column('id', Integer, primary_key=True), - Column('name', String(50), nullable=False), - Column('a1', Text, server_default="x") - ) - - Table('address', m, - Column('id', Integer, primary_key=True), - Column('email_address', String(100), nullable=False), - Column('street', String(50)), - UniqueConstraint('email_address', name="uq_email") - ) - - Table('order', m, - Column('order_id', Integer, primary_key=True), - Column('amount', Numeric(10, 2), nullable=True, - server_default=text("0")), - Column('user_id', Integer, ForeignKey('user.id')), - CheckConstraint('amount > -1', name='ck_order_amount'), - ) - - Table('item', m, - Column('id', Integer, primary_key=True), - Column('description', String(100)), - Column('order_id', Integer, ForeignKey('order.order_id')), - CheckConstraint('len(description) > 5') - ) + Table( + "user", + m, + Column("id", Integer, primary_key=True), + Column("name", String(50), nullable=False), + Column("a1", Text, server_default="x"), + ) + + Table( + "address", + m, + Column("id", Integer, primary_key=True), + Column("email_address", String(100), nullable=False), + Column("street", String(50)), + UniqueConstraint("email_address", name="uq_email"), + ) + + Table( + "order", + m, + Column("order_id", Integer, primary_key=True), + Column( + "amount", + Numeric(10, 2), + nullable=True, + server_default=text("0"), + ), + Column("user_id", Integer, ForeignKey("user.id")), + CheckConstraint("amount > -1", name="ck_order_amount"), + ) + + Table( + "item", + m, + Column("id", Integer, primary_key=True), + Column("description", String(100)), + Column("order_id", Integer, ForeignKey("order.order_id")), + CheckConstraint("len(description) > 5"), + ) return m class _ComparesFKs(object): def _assert_fk_diff( - self, diff, type_, source_table, source_columns, - target_table, target_columns, name=None, conditional_name=None, - source_schema=None, onupdate=None, ondelete=None, - initially=None, deferrable=None): + self, + diff, + type_, + source_table, + source_columns, + target_table, + target_columns, + name=None, + conditional_name=None, + source_schema=None, + onupdate=None, + ondelete=None, + initially=None, + deferrable=None, + ): # the public API for ForeignKeyConstraint was not very rich # in 0.7, 0.8, so here we use the well-known but slightly # private API to get at its elements - (fk_source_schema, fk_source_table, - fk_source_columns, fk_target_schema, fk_target_table, - fk_target_columns, - fk_onupdate, fk_ondelete, fk_deferrable, fk_initially - ) = _fk_spec(diff[1]) + ( + fk_source_schema, + fk_source_table, + fk_source_columns, + fk_target_schema, + fk_target_table, + fk_target_columns, + fk_onupdate, + fk_ondelete, + fk_deferrable, + fk_initially, + ) = _fk_spec(diff[1]) eq_(diff[0], type_) eq_(fk_source_table, source_table) @@ -129,15 +185,15 @@ class _ComparesFKs(object): eq_(fk_initially, initially) eq_(fk_deferrable, deferrable) - eq_([elem.column.name for elem in diff[1].elements], - target_columns) + eq_([elem.column.name for elem in diff[1].elements], target_columns) if conditional_name is not None: if config.requirements.no_fk_names.enabled: eq_(diff[1].name, None) - elif conditional_name == 'servergenerated': - fks = Inspector.from_engine(self.bind).\ - get_foreign_keys(source_table) - server_fk_name = fks[0]['name'] + elif conditional_name == "servergenerated": + fks = Inspector.from_engine(self.bind).get_foreign_keys( + source_table + ) + server_fk_name = fks[0]["name"] eq_(diff[1].name, server_fk_name) else: eq_(diff[1].name, conditional_name) @@ -146,7 +202,6 @@ class _ComparesFKs(object): class AutogenTest(_ComparesFKs): - def _flatten_diffs(self, diffs): for d in diffs: if isinstance(d, list): @@ -177,20 +232,19 @@ class AutogenTest(_ComparesFKs): def setUp(self): self.conn = conn = self.bind.connect() ctx_opts = { - 'compare_type': True, - 'compare_server_default': True, - 'target_metadata': self.m2, - 'upgrade_token': "upgrades", - 'downgrade_token': "downgrades", - 'alembic_module_prefix': 'op.', - 'sqlalchemy_module_prefix': 'sa.', - 'include_object': _default_object_filters + "compare_type": True, + "compare_server_default": True, + "target_metadata": self.m2, + "upgrade_token": "upgrades", + "downgrade_token": "downgrades", + "alembic_module_prefix": "op.", + "sqlalchemy_module_prefix": "sa.", + "include_object": _default_object_filters, } if self.configure_opts: ctx_opts.update(self.configure_opts) self.context = context = MigrationContext.configure( - connection=conn, - opts=ctx_opts + connection=conn, opts=ctx_opts ) self.autogen_context = api.AutogenContext(context, self.m2) @@ -200,46 +254,47 @@ class AutogenTest(_ComparesFKs): def _update_context(self, object_filters=None, include_schemas=None): if include_schemas is not None: - self.autogen_context.opts['include_schemas'] = include_schemas + self.autogen_context.opts["include_schemas"] = include_schemas if object_filters is not None: self.autogen_context._object_filters = [object_filters] return self.autogen_context class AutogenFixtureTest(_ComparesFKs): - def _fixture( - self, m1, m2, include_schemas=False, - opts=None, object_filters=_default_object_filters, - return_ops=False): + self, + m1, + m2, + include_schemas=False, + opts=None, + object_filters=_default_object_filters, + return_ops=False, + ): self.metadata, model_metadata = m1, m2 for m in util.to_list(self.metadata): m.create_all(self.bind) with self.bind.connect() as conn: ctx_opts = { - 'compare_type': True, - 'compare_server_default': True, - 'target_metadata': model_metadata, - 'upgrade_token': "upgrades", - 'downgrade_token': "downgrades", - 'alembic_module_prefix': 'op.', - 'sqlalchemy_module_prefix': 'sa.', - 'include_object': object_filters, - 'include_schemas': include_schemas + "compare_type": True, + "compare_server_default": True, + "target_metadata": model_metadata, + "upgrade_token": "upgrades", + "downgrade_token": "downgrades", + "alembic_module_prefix": "op.", + "sqlalchemy_module_prefix": "sa.", + "include_object": object_filters, + "include_schemas": include_schemas, } if opts: ctx_opts.update(opts) self.context = context = MigrationContext.configure( - connection=conn, - opts=ctx_opts + connection=conn, opts=ctx_opts ) autogen_context = api.AutogenContext(context, model_metadata) uo = ops.UpgradeOps(ops=[]) - autogenerate._produce_net_changes( - autogen_context, uo - ) + autogenerate._produce_net_changes(autogen_context, uo) if return_ops: return uo @@ -253,8 +308,7 @@ class AutogenFixtureTest(_ComparesFKs): self.bind = config.db def tearDown(self): - if hasattr(self, 'metadata'): + if hasattr(self, "metadata"): for m in util.to_list(self.metadata): m.drop_all(self.bind) clear_staging_env() - diff --git a/tests/_large_map.py b/tests/_large_map.py index bc7133f..13ac41f 100644 --- a/tests/_large_map.py +++ b/tests/_large_map.py @@ -1,152 +1,148 @@ from alembic.script.revision import RevisionMap, Revision data = [ - Revision('3fc8a578bc0a', ('4878cb1cb7f6', '454a0529f84e'), ), - Revision('69285b0faaa', ('36c31e4e1c37', '3a3b24a31b57'), ), - Revision('3b0452c64639', '2f1a0f3667f3', ), - Revision('2d9d787a496', '135b5fd31062', ), - Revision('184f65ed83af', '3b0452c64639', ), - Revision('430074f99c29', '54f871bfe0b0', ), - Revision('3ffb59981d9a', '519c9f3ce294', ), - Revision('454a0529f84e', ('40f6508e4373', '38a936c6ab11'), ), - Revision('24c2620b2e3f', ('430074f99c29', '1f5ceb1ec255'), ), - Revision('169a948471a9', '247ad6880f93', ), - Revision('2f1a0f3667f3', '17dd0f165262', ), - Revision('27227dc4fda8', '2a66d7c4d8a1', ), - Revision('4b2ad1ffe2e7', ('3b409f268da4', '4f8a9b79a063'), ), - Revision('124ef6a17781', '2529684536da', ), - Revision('4789d9c82ca7', '593b8076fb2c', ), - Revision('64ed798bcc3', ('44ed1bf512a0', '169a948471a9'), ), - Revision('2588a3c36a0f', '50c7b21c9089', ), - Revision('359329c2ebb', ('5810e9eff996', '339faa12616'), ), - Revision('540bc5634bd', '3a5db5f31209', ), - Revision('20fe477817d2', '53d5ff905573', ), - Revision('4f8a9b79a063', ('3cf34fcd6473', '300209d8594'), ), - Revision('6918589deaf', '3314c17f6e35', ), - Revision('1755e3b1481c', ('17b66754be21', '31b1d4b7fc95'), ), - Revision('58c988e1aa4e', ('219240032b88', 'f067f0b825c'), ), - Revision('593b8076fb2c', '1d94175d221b', ), - Revision('38d069994064', ('46b70a57edc0', '3ed56beabfb7'), ), - Revision('3e2f6c6d1182', '7f96a01461b', ), - Revision('1f6969597fe7', '1811bdae9e63', ), - Revision('17dd0f165262', '3cf02a593a68', ), - Revision('3cf02a593a68', '25a7ef58d293', ), - Revision('34dfac7edb2d', '28f4dd53ad3a', ), - Revision('4009c533e05d', '42ded7355da2', ), - Revision('5a0003c3b09c', ('3ed56beabfb7', '2028d94d3863'), ), - Revision('38a936c6ab11', '2588a3c36a0f', ), - Revision('59223c5b7b36', '2f93dd880bae', ), - Revision('4121bd6e99e9', '540bc5634bd', ), - Revision('260714a3f2de', '6918589deaf', ), - Revision('ae77a2ed69b', '274fd2642933', ), - Revision('18ff1ab3b4c4', '430133b6d46c', ), - Revision('2b9a327527a9', ('359329c2ebb', '593b8076fb2c'), ), - Revision('4e6167c75ed0', '325b273d61bd', ), - Revision('21ab11a7c5c4', ('3da31f3323ec', '22f26011d635'), ), - Revision('3b93e98481b1', '4e28e2f4fe2f', ), - Revision('145d8f1e334d', 'b4143d129e', ), - Revision('135b5fd31062', '1d94175d221b', ), - Revision('300209d8594', ('52804033910e', '593b8076fb2c'), ), - Revision('8dca95cce28', 'f034666cd80', ), - Revision('46b70a57edc0', ('145d8f1e334d', '4cc2960cbe19'), ), - Revision('4d45e479fbb9', '2d9d787a496', ), - Revision('22f085bf8bbd', '540bc5634bd', ), - Revision('263e91fd17d8', '2b9a327527a9', ), - Revision('219240032b88', ('300209d8594', '2b9a327527a9'), ), - Revision('325b273d61bd', '4b2ad1ffe2e7', ), - Revision('199943ccc774', '1aa674ccfa4e', ), - Revision('247ad6880f93', '1f6969597fe7', ), - Revision('4878cb1cb7f6', '28f4dd53ad3a', ), - Revision('2a66d7c4d8a1', '23f1ccb18d6d', ), - Revision('42b079245b55', '593b8076fb2c', ), - Revision('1cccf82219cb', ('20fe477817d2', '915c67915c2'), ), - Revision('b4143d129e', ('159331d6f484', '504d5168afe1'), ), - Revision('53d5ff905573', '3013877bf5bd', ), - Revision('1f5ceb1ec255', '3ffb59981d9a', ), - Revision('ef1c1c1531f', '4738812e6ece', ), - Revision('1f6963d1ae02', '247ad6880f93', ), - Revision('44d58f1d31f0', '18ff1ab3b4c4', ), - Revision('c3ebe64dfb5', ('3409c57b0da', '31f352e77045'), ), - Revision('f067f0b825c', '359329c2ebb', ), - Revision('52ab2d3b57ce', '96d590bd82e', ), - Revision('3b409f268da4', ('20e90eb3eeb6', '263e91fd17d8'), ), - Revision('5a4ca8889674', '4e6167c75ed0', ), - Revision('5810e9eff996', ('2d30d79c4093', '52804033910e'), ), - Revision('40f6508e4373', '4ed16fad67a7', ), - Revision('1811bdae9e63', '260714a3f2de', ), - Revision('3013877bf5bd', ('8dca95cce28', '3fc8a578bc0a'), ), - Revision('16426dbea880', '28f4dd53ad3a', ), - Revision('22f26011d635', ('4c93d063d2ba', '3b93e98481b1'), ), - Revision('3409c57b0da', '17b66754be21', ), - Revision('44373001000f', ('42b079245b55', '219240032b88'), ), - Revision('28f4dd53ad3a', '2e71fd90eb9d', ), - Revision('4cc2960cbe19', '504d5168afe1', ), - Revision('31f352e77045', ('17b66754be21', '22f085bf8bbd'), ), - Revision('4ed16fad67a7', 'f034666cd80', ), - Revision('3da31f3323ec', '4c93d063d2ba', ), - Revision('31b1d4b7fc95', '1cc4459fd115', ), - Revision('11bc0ff42f87', '28f4dd53ad3a', ), - Revision('3a5db5f31209', '59742a546b84', ), - Revision('20e90eb3eeb6', ('58c988e1aa4e', '44373001000f'), ), - Revision('23f1ccb18d6d', '52ab2d3b57ce', ), - Revision('1d94175d221b', '21ab11a7c5c4', ), - Revision('36f1a410ed', '54f871bfe0b0', ), - Revision('181a149173e', '2ee35cac4c62', ), - Revision('171ad2f0c672', '4a4e0838e206', ), - Revision('2f93dd880bae', '540bc5634bd', ), - Revision('25a7ef58d293', None, ), - Revision('7f96a01461b', '184f65ed83af', ), - Revision('b21f22233f', '3e2f6c6d1182', ), - Revision('52804033910e', '1d94175d221b', ), - Revision('1e6240aba5b3', ('4121bd6e99e9', '2c50d8bab6ee'), ), - Revision('1cc4459fd115', '1e6240aba5b3', ), - Revision('274fd2642933', '4009c533e05d', ), - Revision('1aa674ccfa4e', ('59223c5b7b36', '42050bf030fd'), ), - Revision('4e28e2f4fe2f', '596d7b9e11', ), - Revision('49ddec8c7a5e', ('124ef6a17781', '47578179e766'), ), - Revision('3e9bb349cc46', 'ef1c1c1531f', ), - Revision('2028d94d3863', '504d5168afe1', ), - Revision('159331d6f484', '34dfac7edb2d', ), - Revision('596d7b9e11', '171ad2f0c672', ), - Revision('3b96bcc8da76', 'f034666cd80', ), - Revision('4738812e6ece', '78982bf5499', ), - Revision('3314c17f6e35', '27227dc4fda8', ), - Revision('30931c545bf', '2e71fd90eb9d', ), - Revision('2e71fd90eb9d', ('c3ebe64dfb5', '1755e3b1481c'), ), - Revision('3ed56beabfb7', ('11bc0ff42f87', '69285b0faaa'), ), - Revision('96d590bd82e', '3e9bb349cc46', ), - Revision('339faa12616', '4d45e479fbb9', ), - Revision('47578179e766', '2529684536da', ), - Revision('2ee35cac4c62', 'b21f22233f', ), - Revision('50c7b21c9089', ('4ed16fad67a7', '3b96bcc8da76'), ), - Revision('78982bf5499', 'ae77a2ed69b', ), - Revision('519c9f3ce294', '2c50d8bab6ee', ), - Revision('2720fc75e5fd', '1cccf82219cb', ), - Revision('21638ec787ba', '44d58f1d31f0', ), - Revision('59742a546b84', '49ddec8c7a5e', ), - Revision('2d30d79c4093', '135b5fd31062', ), - Revision('f034666cd80', ('5a0003c3b09c', '38d069994064'), ), - Revision('430133b6d46c', '181a149173e', ), - Revision('3a3b24a31b57', ('16426dbea880', '4cc2960cbe19'), ), - Revision('2529684536da', ('64ed798bcc3', '1f6963d1ae02'), ), - Revision('17b66754be21', ('19e0db9d806a', '24c2620b2e3f'), ), - Revision('3cf34fcd6473', ('52804033910e', '4789d9c82ca7'), ), - Revision('36c31e4e1c37', '504d5168afe1', ), - Revision('54f871bfe0b0', '519c9f3ce294', ), - Revision('4a4e0838e206', '2a7f37cf7770', ), - Revision('19e0db9d806a', ('430074f99c29', '36f1a410ed'), ), - Revision('44ed1bf512a0', '247ad6880f93', ), - Revision('42050bf030fd', '2f93dd880bae', ), - Revision('2c50d8bab6ee', '199943ccc774', ), - Revision('504d5168afe1', ('28f4dd53ad3a', '30931c545bf'), ), - Revision('915c67915c2', '3fc8a578bc0a', ), - Revision('2a7f37cf7770', '2720fc75e5fd', ), - Revision('4c93d063d2ba', '4e28e2f4fe2f', ), - Revision('42ded7355da2', '21638ec787ba', ), + Revision("3fc8a578bc0a", ("4878cb1cb7f6", "454a0529f84e")), + Revision("69285b0faaa", ("36c31e4e1c37", "3a3b24a31b57")), + Revision("3b0452c64639", "2f1a0f3667f3"), + Revision("2d9d787a496", "135b5fd31062"), + Revision("184f65ed83af", "3b0452c64639"), + Revision("430074f99c29", "54f871bfe0b0"), + Revision("3ffb59981d9a", "519c9f3ce294"), + Revision("454a0529f84e", ("40f6508e4373", "38a936c6ab11")), + Revision("24c2620b2e3f", ("430074f99c29", "1f5ceb1ec255")), + Revision("169a948471a9", "247ad6880f93"), + Revision("2f1a0f3667f3", "17dd0f165262"), + Revision("27227dc4fda8", "2a66d7c4d8a1"), + Revision("4b2ad1ffe2e7", ("3b409f268da4", "4f8a9b79a063")), + Revision("124ef6a17781", "2529684536da"), + Revision("4789d9c82ca7", "593b8076fb2c"), + Revision("64ed798bcc3", ("44ed1bf512a0", "169a948471a9")), + Revision("2588a3c36a0f", "50c7b21c9089"), + Revision("359329c2ebb", ("5810e9eff996", "339faa12616")), + Revision("540bc5634bd", "3a5db5f31209"), + Revision("20fe477817d2", "53d5ff905573"), + Revision("4f8a9b79a063", ("3cf34fcd6473", "300209d8594")), + Revision("6918589deaf", "3314c17f6e35"), + Revision("1755e3b1481c", ("17b66754be21", "31b1d4b7fc95")), + Revision("58c988e1aa4e", ("219240032b88", "f067f0b825c")), + Revision("593b8076fb2c", "1d94175d221b"), + Revision("38d069994064", ("46b70a57edc0", "3ed56beabfb7")), + Revision("3e2f6c6d1182", "7f96a01461b"), + Revision("1f6969597fe7", "1811bdae9e63"), + Revision("17dd0f165262", "3cf02a593a68"), + Revision("3cf02a593a68", "25a7ef58d293"), + Revision("34dfac7edb2d", "28f4dd53ad3a"), + Revision("4009c533e05d", "42ded7355da2"), + Revision("5a0003c3b09c", ("3ed56beabfb7", "2028d94d3863")), + Revision("38a936c6ab11", "2588a3c36a0f"), + Revision("59223c5b7b36", "2f93dd880bae"), + Revision("4121bd6e99e9", "540bc5634bd"), + Revision("260714a3f2de", "6918589deaf"), + Revision("ae77a2ed69b", "274fd2642933"), + Revision("18ff1ab3b4c4", "430133b6d46c"), + Revision("2b9a327527a9", ("359329c2ebb", "593b8076fb2c")), + Revision("4e6167c75ed0", "325b273d61bd"), + Revision("21ab11a7c5c4", ("3da31f3323ec", "22f26011d635")), + Revision("3b93e98481b1", "4e28e2f4fe2f"), + Revision("145d8f1e334d", "b4143d129e"), + Revision("135b5fd31062", "1d94175d221b"), + Revision("300209d8594", ("52804033910e", "593b8076fb2c")), + Revision("8dca95cce28", "f034666cd80"), + Revision("46b70a57edc0", ("145d8f1e334d", "4cc2960cbe19")), + Revision("4d45e479fbb9", "2d9d787a496"), + Revision("22f085bf8bbd", "540bc5634bd"), + Revision("263e91fd17d8", "2b9a327527a9"), + Revision("219240032b88", ("300209d8594", "2b9a327527a9")), + Revision("325b273d61bd", "4b2ad1ffe2e7"), + Revision("199943ccc774", "1aa674ccfa4e"), + Revision("247ad6880f93", "1f6969597fe7"), + Revision("4878cb1cb7f6", "28f4dd53ad3a"), + Revision("2a66d7c4d8a1", "23f1ccb18d6d"), + Revision("42b079245b55", "593b8076fb2c"), + Revision("1cccf82219cb", ("20fe477817d2", "915c67915c2")), + Revision("b4143d129e", ("159331d6f484", "504d5168afe1")), + Revision("53d5ff905573", "3013877bf5bd"), + Revision("1f5ceb1ec255", "3ffb59981d9a"), + Revision("ef1c1c1531f", "4738812e6ece"), + Revision("1f6963d1ae02", "247ad6880f93"), + Revision("44d58f1d31f0", "18ff1ab3b4c4"), + Revision("c3ebe64dfb5", ("3409c57b0da", "31f352e77045")), + Revision("f067f0b825c", "359329c2ebb"), + Revision("52ab2d3b57ce", "96d590bd82e"), + Revision("3b409f268da4", ("20e90eb3eeb6", "263e91fd17d8")), + Revision("5a4ca8889674", "4e6167c75ed0"), + Revision("5810e9eff996", ("2d30d79c4093", "52804033910e")), + Revision("40f6508e4373", "4ed16fad67a7"), + Revision("1811bdae9e63", "260714a3f2de"), + Revision("3013877bf5bd", ("8dca95cce28", "3fc8a578bc0a")), + Revision("16426dbea880", "28f4dd53ad3a"), + Revision("22f26011d635", ("4c93d063d2ba", "3b93e98481b1")), + Revision("3409c57b0da", "17b66754be21"), + Revision("44373001000f", ("42b079245b55", "219240032b88")), + Revision("28f4dd53ad3a", "2e71fd90eb9d"), + Revision("4cc2960cbe19", "504d5168afe1"), + Revision("31f352e77045", ("17b66754be21", "22f085bf8bbd")), + Revision("4ed16fad67a7", "f034666cd80"), + Revision("3da31f3323ec", "4c93d063d2ba"), + Revision("31b1d4b7fc95", "1cc4459fd115"), + Revision("11bc0ff42f87", "28f4dd53ad3a"), + Revision("3a5db5f31209", "59742a546b84"), + Revision("20e90eb3eeb6", ("58c988e1aa4e", "44373001000f")), + Revision("23f1ccb18d6d", "52ab2d3b57ce"), + Revision("1d94175d221b", "21ab11a7c5c4"), + Revision("36f1a410ed", "54f871bfe0b0"), + Revision("181a149173e", "2ee35cac4c62"), + Revision("171ad2f0c672", "4a4e0838e206"), + Revision("2f93dd880bae", "540bc5634bd"), + Revision("25a7ef58d293", None), + Revision("7f96a01461b", "184f65ed83af"), + Revision("b21f22233f", "3e2f6c6d1182"), + Revision("52804033910e", "1d94175d221b"), + Revision("1e6240aba5b3", ("4121bd6e99e9", "2c50d8bab6ee")), + Revision("1cc4459fd115", "1e6240aba5b3"), + Revision("274fd2642933", "4009c533e05d"), + Revision("1aa674ccfa4e", ("59223c5b7b36", "42050bf030fd")), + Revision("4e28e2f4fe2f", "596d7b9e11"), + Revision("49ddec8c7a5e", ("124ef6a17781", "47578179e766")), + Revision("3e9bb349cc46", "ef1c1c1531f"), + Revision("2028d94d3863", "504d5168afe1"), + Revision("159331d6f484", "34dfac7edb2d"), + Revision("596d7b9e11", "171ad2f0c672"), + Revision("3b96bcc8da76", "f034666cd80"), + Revision("4738812e6ece", "78982bf5499"), + Revision("3314c17f6e35", "27227dc4fda8"), + Revision("30931c545bf", "2e71fd90eb9d"), + Revision("2e71fd90eb9d", ("c3ebe64dfb5", "1755e3b1481c")), + Revision("3ed56beabfb7", ("11bc0ff42f87", "69285b0faaa")), + Revision("96d590bd82e", "3e9bb349cc46"), + Revision("339faa12616", "4d45e479fbb9"), + Revision("47578179e766", "2529684536da"), + Revision("2ee35cac4c62", "b21f22233f"), + Revision("50c7b21c9089", ("4ed16fad67a7", "3b96bcc8da76")), + Revision("78982bf5499", "ae77a2ed69b"), + Revision("519c9f3ce294", "2c50d8bab6ee"), + Revision("2720fc75e5fd", "1cccf82219cb"), + Revision("21638ec787ba", "44d58f1d31f0"), + Revision("59742a546b84", "49ddec8c7a5e"), + Revision("2d30d79c4093", "135b5fd31062"), + Revision("f034666cd80", ("5a0003c3b09c", "38d069994064")), + Revision("430133b6d46c", "181a149173e"), + Revision("3a3b24a31b57", ("16426dbea880", "4cc2960cbe19")), + Revision("2529684536da", ("64ed798bcc3", "1f6963d1ae02")), + Revision("17b66754be21", ("19e0db9d806a", "24c2620b2e3f")), + Revision("3cf34fcd6473", ("52804033910e", "4789d9c82ca7")), + Revision("36c31e4e1c37", "504d5168afe1"), + Revision("54f871bfe0b0", "519c9f3ce294"), + Revision("4a4e0838e206", "2a7f37cf7770"), + Revision("19e0db9d806a", ("430074f99c29", "36f1a410ed")), + Revision("44ed1bf512a0", "247ad6880f93"), + Revision("42050bf030fd", "2f93dd880bae"), + Revision("2c50d8bab6ee", "199943ccc774"), + Revision("504d5168afe1", ("28f4dd53ad3a", "30931c545bf")), + Revision("915c67915c2", "3fc8a578bc0a"), + Revision("2a7f37cf7770", "2720fc75e5fd"), + Revision("4c93d063d2ba", "4e28e2f4fe2f"), + Revision("42ded7355da2", "21638ec787ba"), ] -map_ = RevisionMap( - lambda: data -) - - +map_ = RevisionMap(lambda: data) diff --git a/tests/conftest.py b/tests/conftest.py index 608d903..6cf770d 100755 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,12 +11,16 @@ import os # use bootstrapping so that test plugins are loaded # without touching the main library before coverage starts bootstrap_file = os.path.join( - os.path.dirname(__file__), "..", "alembic", - "testing", "plugin", "bootstrap.py" + os.path.dirname(__file__), + "..", + "alembic", + "testing", + "plugin", + "bootstrap.py", ) with open(bootstrap_file) as f: - code = compile(f.read(), "bootstrap.py", 'exec') + code = compile(f.read(), "bootstrap.py", "exec") to_bootstrap = "pytest" exec(code, globals(), locals()) from pytestplugin import * # noqa diff --git a/tests/requirements.py b/tests/requirements.py index 0be95eb..74a7d83 100644 --- a/tests/requirements.py +++ b/tests/requirements.py @@ -5,16 +5,12 @@ from alembic.util import sqla_compat class DefaultRequirements(SuiteRequirements): - @property def schemas(self): """Target database must support external schemas, and have one named 'test_schema'.""" - return exclusions.skip_if([ - "sqlite", - "firebird" - ], "no schema support") + return exclusions.skip_if(["sqlite", "firebird"], "no schema support") @property def no_referential_integrity(self): @@ -48,12 +44,12 @@ class DefaultRequirements(SuiteRequirements): @property def unnamed_constraints(self): """constraints without names are supported.""" - return exclusions.only_on(['sqlite']) + return exclusions.only_on(["sqlite"]) @property def fk_names(self): """foreign key constraints always have names in the DB""" - return exclusions.fails_on('sqlite') + return exclusions.fails_on("sqlite") @property def no_name_normalize(self): @@ -63,20 +59,24 @@ class DefaultRequirements(SuiteRequirements): @property def reflects_fk_options(self): - return exclusions.only_on([ - 'postgresql', 'mysql', - lambda config: util.sqla_110 and - exclusions.against(config, 'sqlite')]) + return exclusions.only_on( + [ + "postgresql", + "mysql", + lambda config: util.sqla_110 + and exclusions.against(config, "sqlite"), + ] + ) @property def fk_initially(self): """backend supports INITIALLY option in foreign keys""" - return exclusions.only_on(['postgresql']) + return exclusions.only_on(["postgresql"]) @property def fk_deferrable(self): """backend supports DEFERRABLE option in foreign keys""" - return exclusions.only_on(['postgresql']) + return exclusions.only_on(["postgresql"]) @property def flexible_fk_cascades(self): @@ -84,8 +84,7 @@ class DefaultRequirements(SuiteRequirements): full range of keywords (e.g. NO ACTION, etc.)""" return exclusions.skip_if( - ['oracle'], - 'target backend has poor FK cascade syntax' + ["oracle"], "target backend has poor FK cascade syntax" ) @property @@ -97,10 +96,13 @@ class DefaultRequirements(SuiteRequirements): """Target driver reflects the name of primary key constraints.""" return exclusions.fails_on_everything_except( - 'postgresql', 'oracle', 'mssql', 'sybase', + "postgresql", + "oracle", + "mssql", + "sybase", lambda config: ( util.sqla_110 and exclusions.against(config, "sqlite") - ) + ), ) @property @@ -122,8 +124,10 @@ class DefaultRequirements(SuiteRequirements): return False count = config.db.scalar( "SELECT count(*) FROM pg_extension " - "WHERE extname='%s'" % name) + "WHERE extname='%s'" % name + ) return bool(count) + return exclusions.only_if(check, "needs %s extension" % name) @property @@ -134,7 +138,6 @@ class DefaultRequirements(SuiteRequirements): def btree_gist(self): return self._has_pg_extension("btree_gist") - @property def autoincrement_on_composite_pk(self): return exclusions.skip_if(["sqlite"], "not supported by database") @@ -153,34 +156,42 @@ class DefaultRequirements(SuiteRequirements): @property def mysql_check_reflection_or_none(self): def go(config): - return not self._mariadb_102(config) \ - or self.sqlalchemy_1115.enabled + return ( + not self._mariadb_102(config) or self.sqlalchemy_1115.enabled + ) + return exclusions.succeeds_if(go) @property def mysql_timestamp_reflection(self): def go(config): - return not self._mariadb_102(config) \ - or self.sqlalchemy_1115.enabled + return ( + not self._mariadb_102(config) or self.sqlalchemy_1115.enabled + ) + return exclusions.only_if(go) def _mariadb_102(self, config): - return exclusions.against(config, "mysql") and \ - sqla_compat._is_mariadb(config.db.dialect) and \ - sqla_compat._mariadb_normalized_version_info( - config.db.dialect) > (10, 2) + return ( + exclusions.against(config, "mysql") + and sqla_compat._is_mariadb(config.db.dialect) + and sqla_compat._mariadb_normalized_version_info(config.db.dialect) + > (10, 2) + ) def _mariadb_only_102(self, config): - return exclusions.against(config, "mysql") and \ - sqla_compat._is_mariadb(config.db.dialect) and \ - sqla_compat._mariadb_normalized_version_info( - config.db.dialect) >= (10, 2) and \ - sqla_compat._mariadb_normalized_version_info( - config.db.dialect) < (10, 3) + return ( + exclusions.against(config, "mysql") + and sqla_compat._is_mariadb(config.db.dialect) + and sqla_compat._mariadb_normalized_version_info(config.db.dialect) + >= (10, 2) + and sqla_compat._mariadb_normalized_version_info(config.db.dialect) + < (10, 3) + ) def _mysql_not_mariadb_102(self, config): return exclusions.against(config, "mysql") and ( - not sqla_compat._is_mariadb(config.db.dialect) or - sqla_compat._mariadb_normalized_version_info( - config.db.dialect) < (10, 2) + not sqla_compat._is_mariadb(config.db.dialect) + or sqla_compat._mariadb_normalized_version_info(config.db.dialect) + < (10, 2) ) diff --git a/tests/test_autogen_composition.py b/tests/test_autogen_composition.py index c536317..7694cbd 100644 --- a/tests/test_autogen_composition.py +++ b/tests/test_autogen_composition.py @@ -9,64 +9,73 @@ from ._autogen_fixtures import AutogenTest, ModelOne, _default_include_object class AutogenerateDiffTest(ModelOne, AutogenTest, TestBase): - __only_on__ = 'sqlite' + __only_on__ = "sqlite" def test_render_nothing(self): context = MigrationContext.configure( connection=self.bind.connect(), opts={ - 'compare_type': True, - 'compare_server_default': True, - 'target_metadata': self.m1, - 'upgrade_token': "upgrades", - 'downgrade_token': "downgrades", - } + "compare_type": True, + "compare_server_default": True, + "target_metadata": self.m1, + "upgrade_token": "upgrades", + "downgrade_token": "downgrades", + }, ) template_args = {} autogenerate._render_migration_diffs(context, template_args) - eq_(re.sub(r"u'", "'", template_args['upgrades']), + eq_( + re.sub(r"u'", "'", template_args["upgrades"]), """# ### commands auto generated by Alembic - please adjust! ### pass - # ### end Alembic commands ###""") - eq_(re.sub(r"u'", "'", template_args['downgrades']), + # ### end Alembic commands ###""", + ) + eq_( + re.sub(r"u'", "'", template_args["downgrades"]), """# ### commands auto generated by Alembic - please adjust! ### pass - # ### end Alembic commands ###""") + # ### end Alembic commands ###""", + ) def test_render_nothing_batch(self): context = MigrationContext.configure( connection=self.bind.connect(), opts={ - 'compare_type': True, - 'compare_server_default': True, - 'target_metadata': self.m1, - 'upgrade_token': "upgrades", - 'downgrade_token': "downgrades", - 'alembic_module_prefix': 'op.', - 'sqlalchemy_module_prefix': 'sa.', - 'render_as_batch': True, - 'include_symbol': lambda name, schema: False - } + "compare_type": True, + "compare_server_default": True, + "target_metadata": self.m1, + "upgrade_token": "upgrades", + "downgrade_token": "downgrades", + "alembic_module_prefix": "op.", + "sqlalchemy_module_prefix": "sa.", + "render_as_batch": True, + "include_symbol": lambda name, schema: False, + }, ) template_args = {} autogenerate._render_migration_diffs(context, template_args) - eq_(re.sub(r"u'", "'", template_args['upgrades']), + eq_( + re.sub(r"u'", "'", template_args["upgrades"]), """# ### commands auto generated by Alembic - please adjust! ### pass - # ### end Alembic commands ###""") - eq_(re.sub(r"u'", "'", template_args['downgrades']), + # ### end Alembic commands ###""", + ) + eq_( + re.sub(r"u'", "'", template_args["downgrades"]), """# ### commands auto generated by Alembic - please adjust! ### pass - # ### end Alembic commands ###""") + # ### end Alembic commands ###""", + ) def test_render_diffs_standard(self): """test a full render including indentation""" template_args = {} autogenerate._render_migration_diffs(self.context, template_args) - eq_(re.sub(r"u'", "'", template_args['upgrades']), + eq_( + re.sub(r"u'", "'", template_args["upgrades"]), """# ### commands auto generated by Alembic - please adjust! ### op.create_table('item', sa.Column('id', sa.Integer(), nullable=False), @@ -96,9 +105,11 @@ nullable=True)) nullable=False) op.drop_index('pw_idx', table_name='user') op.drop_column('user', 'pw') - # ### end Alembic commands ###""") + # ### end Alembic commands ###""", + ) - eq_(re.sub(r"u'", "'", template_args['downgrades']), + eq_( + re.sub(r"u'", "'", template_args["downgrades"]), """# ### commands auto generated by Alembic - please adjust! ### op.add_column('user', sa.Column('pw', sa.VARCHAR(length=50), \ nullable=True)) @@ -125,16 +136,18 @@ nullable=True)) sa.ForeignKeyConstraint(['uid'], ['user.id'], ) ) op.drop_table('item') - # ### end Alembic commands ###""") + # ### end Alembic commands ###""", + ) def test_render_diffs_batch(self): """test a full render in batch mode including indentation""" template_args = {} - self.context.opts['render_as_batch'] = True + self.context.opts["render_as_batch"] = True autogenerate._render_migration_diffs(self.context, template_args) - eq_(re.sub(r"u'", "'", template_args['upgrades']), + eq_( + re.sub(r"u'", "'", template_args["upgrades"]), """# ### commands auto generated by Alembic - please adjust! ### op.create_table('item', sa.Column('id', sa.Integer(), nullable=False), @@ -169,9 +182,11 @@ nullable=True)) batch_op.drop_index('pw_idx') batch_op.drop_column('pw') - # ### end Alembic commands ###""") + # ### end Alembic commands ###""", + ) - eq_(re.sub(r"u'", "'", template_args['downgrades']), + eq_( + re.sub(r"u'", "'", template_args["downgrades"]), """# ### commands auto generated by Alembic - please adjust! ### with op.batch_alter_table('user', schema=None) as batch_op: batch_op.add_column(sa.Column('pw', sa.VARCHAR(length=50), nullable=True)) @@ -203,74 +218,80 @@ nullable=True)) sa.ForeignKeyConstraint(['uid'], ['user.id'], ) ) op.drop_table('item') - # ### end Alembic commands ###""") + # ### end Alembic commands ###""", + ) def test_imports_maintined(self): template_args = {} - self.context.opts['render_as_batch'] = True + self.context.opts["render_as_batch"] = True def render_item(type_, col, autogen_context): autogen_context.imports.add( "from mypackage import my_special_import" ) - autogen_context.imports.add( - "from foobar import bat" - ) + autogen_context.imports.add("from foobar import bat") self.context.opts["render_item"] = render_item autogenerate._render_migration_diffs(self.context, template_args) eq_( + set(template_args["imports"].split("\n")), set( - template_args['imports'].split("\n") + [ + "from foobar import bat", + "from mypackage import my_special_import", + ] ), - set([ - "from foobar import bat", - "from mypackage import my_special_import" - ]) ) class AutogenerateDiffTestWSchema(ModelOne, AutogenTest, TestBase): - __only_on__ = 'postgresql' + __only_on__ = "postgresql" schema = "test_schema" def test_render_nothing(self): context = MigrationContext.configure( connection=self.bind.connect(), opts={ - 'compare_type': True, - 'compare_server_default': True, - 'target_metadata': self.m1, - 'upgrade_token': "upgrades", - 'downgrade_token': "downgrades", - 'alembic_module_prefix': 'op.', - 'sqlalchemy_module_prefix': 'sa.', - 'include_symbol': lambda name, schema: False - } + "compare_type": True, + "compare_server_default": True, + "target_metadata": self.m1, + "upgrade_token": "upgrades", + "downgrade_token": "downgrades", + "alembic_module_prefix": "op.", + "sqlalchemy_module_prefix": "sa.", + "include_symbol": lambda name, schema: False, + }, ) template_args = {} autogenerate._render_migration_diffs(context, template_args) - eq_(re.sub(r"u'", "'", template_args['upgrades']), + eq_( + re.sub(r"u'", "'", template_args["upgrades"]), """# ### commands auto generated by Alembic - please adjust! ### pass - # ### end Alembic commands ###""") - eq_(re.sub(r"u'", "'", template_args['downgrades']), + # ### end Alembic commands ###""", + ) + eq_( + re.sub(r"u'", "'", template_args["downgrades"]), """# ### commands auto generated by Alembic - please adjust! ### pass - # ### end Alembic commands ###""") + # ### end Alembic commands ###""", + ) def test_render_diffs_extras(self): """test a full render including indentation (include and schema)""" template_args = {} - self.context.opts.update({ - 'include_object': _default_include_object, - 'include_schemas': True - }) + self.context.opts.update( + { + "include_object": _default_include_object, + "include_schemas": True, + } + ) autogenerate._render_migration_diffs(self.context, template_args) - eq_(re.sub(r"u'", "'", template_args['upgrades']), + eq_( + re.sub(r"u'", "'", template_args["upgrades"]), """# ### commands auto generated by Alembic - please adjust! ### op.create_table('item', sa.Column('id', sa.Integer(), nullable=False), @@ -307,9 +328,12 @@ source_schema='%(schema)s', referent_schema='%(schema)s') schema='%(schema)s') op.drop_index('pw_idx', table_name='user', schema='test_schema') op.drop_column('user', 'pw', schema='%(schema)s') - # ### end Alembic commands ###""" % {"schema": self.schema}) + # ### end Alembic commands ###""" + % {"schema": self.schema}, + ) - eq_(re.sub(r"u'", "'", template_args['downgrades']), + eq_( + re.sub(r"u'", "'", template_args["downgrades"]), """# ### commands auto generated by Alembic - please adjust! ### op.add_column('user', sa.Column('pw', sa.VARCHAR(length=50), \ autoincrement=False, nullable=True), schema='%(schema)s') @@ -341,5 +365,6 @@ name='extra_uid_fkey'), schema='%(schema)s' ) op.drop_table('item', schema='%(schema)s') - # ### end Alembic commands ###""" % {"schema": self.schema}) - + # ### end Alembic commands ###""" + % {"schema": self.schema}, + ) diff --git a/tests/test_autogen_diffs.py b/tests/test_autogen_diffs.py index 38e06b6..af2e2c2 100644 --- a/tests/test_autogen_diffs.py +++ b/tests/test_autogen_diffs.py @@ -1,10 +1,30 @@ import sys -from sqlalchemy import MetaData, Column, Table, Integer, String, Text, \ - Numeric, CHAR, ForeignKey, INTEGER, Index, UniqueConstraint, \ - TypeDecorator, CheckConstraint, text, PrimaryKeyConstraint, \ - ForeignKeyConstraint, VARCHAR, DECIMAL, DateTime, BigInteger, BIGINT, \ - SmallInteger +from sqlalchemy import ( + MetaData, + Column, + Table, + Integer, + String, + Text, + Numeric, + CHAR, + ForeignKey, + INTEGER, + Index, + UniqueConstraint, + TypeDecorator, + CheckConstraint, + text, + PrimaryKeyConstraint, + ForeignKeyConstraint, + VARCHAR, + DECIMAL, + DateTime, + BigInteger, + BIGINT, + SmallInteger, +) from sqlalchemy.dialects import sqlite from sqlalchemy.types import NULLTYPE, VARBINARY from sqlalchemy.engine.reflection import Inspector @@ -20,62 +40,41 @@ from alembic.testing import eq_, is_, is_not_ from alembic.util import CommandError from ._autogen_fixtures import AutogenTest, AutogenFixtureTest -py3k = sys.version_info >= (3, ) +py3k = sys.version_info >= (3,) class AutogenCrossSchemaTest(AutogenTest, TestBase): - __only_on__ = 'postgresql' + __only_on__ = "postgresql" __backend__ = True @classmethod def _get_db_schema(cls): m = MetaData() - Table('t1', m, - Column('x', Integer) - ) - Table('t2', m, - Column('y', Integer), - schema=config.test_schema - ) - Table('t6', m, - Column('u', Integer) - ) - Table('t7', m, - Column('v', Integer), - schema=config.test_schema - ) + Table("t1", m, Column("x", Integer)) + Table("t2", m, Column("y", Integer), schema=config.test_schema) + Table("t6", m, Column("u", Integer)) + Table("t7", m, Column("v", Integer), schema=config.test_schema) return m @classmethod def _get_model_schema(cls): m = MetaData() - Table('t3', m, - Column('q', Integer) - ) - Table('t4', m, - Column('z', Integer), - schema=config.test_schema - ) - Table('t6', m, - Column('u', Integer) - ) - Table('t7', m, - Column('v', Integer), - schema=config.test_schema - ) + Table("t3", m, Column("q", Integer)) + Table("t4", m, Column("z", Integer), schema=config.test_schema) + Table("t6", m, Column("u", Integer)) + Table("t7", m, Column("v", Integer), schema=config.test_schema) return m def test_default_schema_omitted_upgrade(self): - def include_object(obj, name, type_, reflected, compare_to): if type_ == "table": return name == "t3" else: return True + self._update_context( - object_filters=include_object, - include_schemas=True, + object_filters=include_object, include_schemas=True ) uo = ops.UpgradeOps(ops=[]) autogenerate._produce_net_changes(self.autogen_context, uo) @@ -85,7 +84,6 @@ class AutogenCrossSchemaTest(AutogenTest, TestBase): eq_(diffs[0][1].schema, None) def test_alt_schema_included_upgrade(self): - def include_object(obj, name, type_, reflected, compare_to): if type_ == "table": return name == "t4" @@ -93,8 +91,7 @@ class AutogenCrossSchemaTest(AutogenTest, TestBase): return True self._update_context( - object_filters=include_object, - include_schemas=True, + object_filters=include_object, include_schemas=True ) uo = ops.UpgradeOps(ops=[]) autogenerate._produce_net_changes(self.autogen_context, uo) @@ -109,9 +106,9 @@ class AutogenCrossSchemaTest(AutogenTest, TestBase): return name == "t1" else: return True + self._update_context( - object_filters=include_object, - include_schemas=True, + object_filters=include_object, include_schemas=True ) uo = ops.UpgradeOps(ops=[]) autogenerate._produce_net_changes(self.autogen_context, uo) @@ -121,15 +118,14 @@ class AutogenCrossSchemaTest(AutogenTest, TestBase): eq_(diffs[0][1].schema, None) def test_alt_schema_included_downgrade(self): - def include_object(obj, name, type_, reflected, compare_to): if type_ == "table": return name == "t2" else: return True + self._update_context( - object_filters=include_object, - include_schemas=True, + object_filters=include_object, include_schemas=True ) uo = ops.UpgradeOps(ops=[]) autogenerate._produce_net_changes(self.autogen_context, uo) @@ -139,7 +135,7 @@ class AutogenCrossSchemaTest(AutogenTest, TestBase): class AutogenDefaultSchemaTest(AutogenFixtureTest, TestBase): - __only_on__ = 'postgresql' + __only_on__ = "postgresql" __backend__ = True def test_uses_explcit_schema_in_default_one(self): @@ -149,8 +145,8 @@ class AutogenDefaultSchemaTest(AutogenFixtureTest, TestBase): m1 = MetaData() m2 = MetaData() - Table('a', m1, Column('x', String(50))) - Table('a', m2, Column('x', String(50)), schema=default_schema) + Table("a", m1, Column("x", String(50))) + Table("a", m2, Column("x", String(50)), schema=default_schema) diffs = self._fixture(m1, m2, include_schemas=True) eq_(diffs, []) @@ -162,15 +158,15 @@ class AutogenDefaultSchemaTest(AutogenFixtureTest, TestBase): m1 = MetaData() m2 = MetaData() - Table('a', m1, Column('x', String(50))) - Table('a', m2, Column('x', String(50)), schema=default_schema) - Table('a', m2, Column('y', String(50)), schema="test_schema") + Table("a", m1, Column("x", String(50))) + Table("a", m2, Column("x", String(50)), schema=default_schema) + Table("a", m2, Column("y", String(50)), schema="test_schema") diffs = self._fixture(m1, m2, include_schemas=True) eq_(len(diffs), 1) eq_(diffs[0][0], "add_table") eq_(diffs[0][1].schema, "test_schema") - eq_(diffs[0][1].c.keys(), ['y']) + eq_(diffs[0][1].c.keys(), ["y"]) def test_uses_explcit_schema_in_default_three(self): @@ -179,20 +175,20 @@ class AutogenDefaultSchemaTest(AutogenFixtureTest, TestBase): m1 = MetaData() m2 = MetaData() - Table('a', m1, Column('y', String(50)), schema="test_schema") + Table("a", m1, Column("y", String(50)), schema="test_schema") - Table('a', m2, Column('x', String(50)), schema=default_schema) - Table('a', m2, Column('y', String(50)), schema="test_schema") + Table("a", m2, Column("x", String(50)), schema=default_schema) + Table("a", m2, Column("y", String(50)), schema="test_schema") diffs = self._fixture(m1, m2, include_schemas=True) eq_(len(diffs), 1) eq_(diffs[0][0], "add_table") eq_(diffs[0][1].schema, default_schema) - eq_(diffs[0][1].c.keys(), ['x']) + eq_(diffs[0][1].c.keys(), ["x"]) class AutogenDefaultSchemaIsNoneTest(AutogenFixtureTest, TestBase): - __only_on__ = 'sqlite' + __only_on__ = "sqlite" def setUp(self): super(AutogenDefaultSchemaIsNoneTest, self).setUp() @@ -205,23 +201,23 @@ class AutogenDefaultSchemaIsNoneTest(AutogenFixtureTest, TestBase): m1 = MetaData() m2 = MetaData() - Table('a', m1, Column('x', String(50))) - Table('a', m2, Column('x', String(50))) + Table("a", m1, Column("x", String(50))) + Table("a", m2, Column("x", String(50))) def _include_object(obj, name, type_, reflected, compare_to): if type_ == "table": - return name in 'a' and obj.schema != 'main' + return name in "a" and obj.schema != "main" else: return True diffs = self._fixture( - m1, m2, include_schemas=True, - object_filters=_include_object) + m1, m2, include_schemas=True, object_filters=_include_object + ) eq_(len(diffs), 0) class ModelOne(object): - __requires__ = ('unique_constraint_reflection', ) + __requires__ = ("unique_constraint_reflection",) schema = None @@ -231,30 +227,42 @@ class ModelOne(object): m = MetaData(schema=schema) - Table('user', m, - Column('id', Integer, primary_key=True), - Column('name', String(50)), - Column('a1', Text), - Column("pw", String(50)), - Index('pw_idx', 'pw') - ) - - Table('address', m, - Column('id', Integer, primary_key=True), - Column('email_address', String(100), nullable=False), - ) - - Table('order', m, - Column('order_id', Integer, primary_key=True), - Column("amount", Numeric(8, 2), nullable=False, - server_default=text("0")), - CheckConstraint('amount >= 0', name='ck_order_amount') - ) - - Table('extra', m, - Column("x", CHAR), - Column('uid', Integer, ForeignKey('user.id')) - ) + Table( + "user", + m, + Column("id", Integer, primary_key=True), + Column("name", String(50)), + Column("a1", Text), + Column("pw", String(50)), + Index("pw_idx", "pw"), + ) + + Table( + "address", + m, + Column("id", Integer, primary_key=True), + Column("email_address", String(100), nullable=False), + ) + + Table( + "order", + m, + Column("order_id", Integer, primary_key=True), + Column( + "amount", + Numeric(8, 2), + nullable=False, + server_default=text("0"), + ), + CheckConstraint("amount >= 0", name="ck_order_amount"), + ) + + Table( + "extra", + m, + Column("x", CHAR), + Column("uid", Integer, ForeignKey("user.id")), + ) return m @@ -264,38 +272,50 @@ class ModelOne(object): m = MetaData(schema=schema) - Table('user', m, - Column('id', Integer, primary_key=True), - Column('name', String(50), nullable=False), - Column('a1', Text, server_default="x") - ) - - Table('address', m, - Column('id', Integer, primary_key=True), - Column('email_address', String(100), nullable=False), - Column('street', String(50)), - UniqueConstraint('email_address', name="uq_email") - ) - - Table('order', m, - Column('order_id', Integer, primary_key=True), - Column('amount', Numeric(10, 2), nullable=True, - server_default=text("0")), - Column('user_id', Integer, ForeignKey('user.id')), - CheckConstraint('amount > -1', name='ck_order_amount'), - ) - - Table('item', m, - Column('id', Integer, primary_key=True), - Column('description', String(100)), - Column('order_id', Integer, ForeignKey('order.order_id')), - CheckConstraint('len(description) > 5') - ) + Table( + "user", + m, + Column("id", Integer, primary_key=True), + Column("name", String(50), nullable=False), + Column("a1", Text, server_default="x"), + ) + + Table( + "address", + m, + Column("id", Integer, primary_key=True), + Column("email_address", String(100), nullable=False), + Column("street", String(50)), + UniqueConstraint("email_address", name="uq_email"), + ) + + Table( + "order", + m, + Column("order_id", Integer, primary_key=True), + Column( + "amount", + Numeric(10, 2), + nullable=True, + server_default=text("0"), + ), + Column("user_id", Integer, ForeignKey("user.id")), + CheckConstraint("amount > -1", name="ck_order_amount"), + ) + + Table( + "item", + m, + Column("id", Integer, primary_key=True), + Column("description", String(100)), + Column("order_id", Integer, ForeignKey("order.order_id")), + CheckConstraint("len(description) > 5"), + ) return m class AutogenerateDiffTest(ModelOne, AutogenTest, TestBase): - __only_on__ = 'sqlite' + __only_on__ = "sqlite" def test_diffs(self): """test generation of diff rules""" @@ -304,23 +324,18 @@ class AutogenerateDiffTest(ModelOne, AutogenTest, TestBase): uo = ops.UpgradeOps(ops=[]) ctx = self.autogen_context - autogenerate._produce_net_changes( - ctx, uo - ) + autogenerate._produce_net_changes(ctx, uo) diffs = uo.as_diffs() - eq_( - diffs[0], - ('add_table', metadata.tables['item']) - ) + eq_(diffs[0], ("add_table", metadata.tables["item"])) - eq_(diffs[1][0], 'remove_table') + eq_(diffs[1][0], "remove_table") eq_(diffs[1][1].name, "extra") eq_(diffs[2][0], "add_column") eq_(diffs[2][1], None) eq_(diffs[2][2], "address") - eq_(diffs[2][3], metadata.tables['address'].c.street) + eq_(diffs[2][3], metadata.tables["address"].c.street) eq_(diffs[3][0], "add_constraint") eq_(diffs[3][1].name, "uq_email") @@ -328,7 +343,7 @@ class AutogenerateDiffTest(ModelOne, AutogenTest, TestBase): eq_(diffs[4][0], "add_column") eq_(diffs[4][1], None) eq_(diffs[4][2], "order") - eq_(diffs[4][3], metadata.tables['order'].c.user_id) + eq_(diffs[4][3], metadata.tables["order"].c.user_id) eq_(diffs[5][0][0], "modify_type") eq_(diffs[5][0][1], None) @@ -338,9 +353,7 @@ class AutogenerateDiffTest(ModelOne, AutogenTest, TestBase): eq_(repr(diffs[5][0][6]), "Numeric(precision=10, scale=2)") self._assert_fk_diff( - diffs[6], "add_fk", - "order", ["user_id"], - "user", ["id"] + diffs[6], "add_fk", "order", ["user_id"], "user", ["id"] ) eq_(diffs[7][0][0], "modify_default") @@ -349,45 +362,47 @@ class AutogenerateDiffTest(ModelOne, AutogenTest, TestBase): eq_(diffs[7][0][3], "a1") eq_(diffs[7][0][6].arg, "x") - eq_(diffs[8][0][0], 'modify_nullable') + eq_(diffs[8][0][0], "modify_nullable") eq_(diffs[8][0][5], True) eq_(diffs[8][0][6], False) - eq_(diffs[9][0], 'remove_index') - eq_(diffs[9][1].name, 'pw_idx') + eq_(diffs[9][0], "remove_index") + eq_(diffs[9][1].name, "pw_idx") - eq_(diffs[10][0], 'remove_column') - eq_(diffs[10][3].name, 'pw') - eq_(diffs[10][3].table.name, 'user') - assert isinstance( - diffs[10][3].type, String - ) + eq_(diffs[10][0], "remove_column") + eq_(diffs[10][3].name, "pw") + eq_(diffs[10][3].table.name, "user") + assert isinstance(diffs[10][3].type, String) def test_include_symbol(self): diffs = [] def include_symbol(name, schema=None): - return name in ('address', 'order') + return name in ("address", "order") context = MigrationContext.configure( connection=self.bind.connect(), opts={ - 'compare_type': True, - 'compare_server_default': True, - 'target_metadata': self.m2, - 'include_symbol': include_symbol, - } + "compare_type": True, + "compare_server_default": True, + "target_metadata": self.m2, + "include_symbol": include_symbol, + }, ) diffs = autogenerate.compare_metadata( - context, context.opts['target_metadata']) + context, context.opts["target_metadata"] + ) - alter_cols = set([ - d[2] for d in self._flatten_diffs(diffs) - if d[0].startswith('modify') - ]) - eq_(alter_cols, set(['order'])) + alter_cols = set( + [ + d[2] + for d in self._flatten_diffs(diffs) + if d[0].startswith("modify") + ] + ) + eq_(alter_cols, set(["order"])) def test_include_object(self): def include_object(obj, name, type_, reflected, compare_to): @@ -410,33 +425,46 @@ class AutogenerateDiffTest(ModelOne, AutogenTest, TestBase): context = MigrationContext.configure( connection=self.bind.connect(), opts={ - 'compare_type': True, - 'compare_server_default': True, - 'target_metadata': self.m2, - 'include_object': include_object, - } + "compare_type": True, + "compare_server_default": True, + "target_metadata": self.m2, + "include_object": include_object, + }, ) diffs = autogenerate.compare_metadata( - context, context.opts['target_metadata']) + context, context.opts["target_metadata"] + ) - alter_cols = set([ - d[2] for d in self._flatten_diffs(diffs) - if d[0].startswith('modify') - ]).union( - d[3].name for d in self._flatten_diffs(diffs) - if d[0] == 'add_column' - ).union( - d[1].name for d in self._flatten_diffs(diffs) - if d[0] == 'add_table' + alter_cols = ( + set( + [ + d[2] + for d in self._flatten_diffs(diffs) + if d[0].startswith("modify") + ] + ) + .union( + d[3].name + for d in self._flatten_diffs(diffs) + if d[0] == "add_column" + ) + .union( + d[1].name + for d in self._flatten_diffs(diffs) + if d[0] == "add_table" + ) ) - eq_(alter_cols, set(['user_id', 'order', 'user'])) + eq_(alter_cols, set(["user_id", "order", "user"])) def test_skip_null_type_comparison_reflected(self): ac = ops.AlterColumnOp("sometable", "somecol") autogenerate.compare._compare_type( - self.autogen_context, ac, - None, "sometable", "somecol", + self.autogen_context, + ac, + None, + "sometable", + "somecol", Column("somecol", NULLTYPE), Column("somecol", Integer()), ) @@ -446,8 +474,11 @@ class AutogenerateDiffTest(ModelOne, AutogenTest, TestBase): def test_skip_null_type_comparison_local(self): ac = ops.AlterColumnOp("sometable", "somecol") autogenerate.compare._compare_type( - self.autogen_context, ac, - None, "sometable", "somecol", + self.autogen_context, + ac, + None, + "sometable", + "somecol", Column("somecol", Integer()), Column("somecol", NULLTYPE), ) @@ -463,8 +494,11 @@ class AutogenerateDiffTest(ModelOne, AutogenTest, TestBase): ac = ops.AlterColumnOp("sometable", "somecol") autogenerate.compare._compare_type( - self.autogen_context, ac, - None, "sometable", "somecol", + self.autogen_context, + ac, + None, + "sometable", + "somecol", Column("somecol", INTEGER()), Column("somecol", MyType()), ) @@ -473,56 +507,63 @@ class AutogenerateDiffTest(ModelOne, AutogenTest, TestBase): ac = ops.AlterColumnOp("sometable", "somecol") autogenerate.compare._compare_type( - self.autogen_context, ac, - None, "sometable", "somecol", + self.autogen_context, + ac, + None, + "sometable", + "somecol", Column("somecol", String()), Column("somecol", MyType()), ) diff = ac.to_diff_tuple() - eq_( - diff[0][0:4], - ('modify_type', None, 'sometable', 'somecol') - ) + eq_(diff[0][0:4], ("modify_type", None, "sometable", "somecol")) def test_affinity_typedec(self): class MyType(TypeDecorator): impl = CHAR def load_dialect_impl(self, dialect): - if dialect.name == 'sqlite': + if dialect.name == "sqlite": return dialect.type_descriptor(Integer()) else: return dialect.type_descriptor(CHAR(32)) - uo = ops.AlterColumnOp('sometable', 'somecol') + uo = ops.AlterColumnOp("sometable", "somecol") autogenerate.compare._compare_type( - self.autogen_context, uo, - None, "sometable", "somecol", + self.autogen_context, + uo, + None, + "sometable", + "somecol", Column("somecol", Integer, nullable=True), - Column("somecol", MyType()) + Column("somecol", MyType()), ) assert not uo.has_changes() def test_dont_barf_on_already_reflected(self): from sqlalchemy.util import OrderedSet + inspector = Inspector.from_engine(self.bind) uo = ops.UpgradeOps(ops=[]) autogenerate.compare._compare_tables( - OrderedSet([(None, 'extra'), (None, 'user')]), - OrderedSet(), inspector, - uo, self.autogen_context + OrderedSet([(None, "extra"), (None, "user")]), + OrderedSet(), + inspector, + uo, + self.autogen_context, ) eq_( [(rec[0], rec[1].name) for rec in uo.as_diffs()], [ - ('remove_table', 'extra'), - ('remove_index', 'pw_idx'), - ('remove_table', 'user'), ] + ("remove_table", "extra"), + ("remove_index", "pw_idx"), + ("remove_table", "user"), + ], ) class AutogenerateDiffTestWSchema(ModelOne, AutogenTest, TestBase): - __only_on__ = 'postgresql' + __only_on__ = "postgresql" __backend__ = True schema = "test_schema" @@ -531,26 +572,21 @@ class AutogenerateDiffTestWSchema(ModelOne, AutogenTest, TestBase): metadata = self.m2 - self._update_context( - include_schemas=True, - ) + self._update_context(include_schemas=True) uo = ops.UpgradeOps(ops=[]) autogenerate._produce_net_changes(self.autogen_context, uo) diffs = uo.as_diffs() - eq_( - diffs[0], - ('add_table', metadata.tables['%s.item' % self.schema]) - ) + eq_(diffs[0], ("add_table", metadata.tables["%s.item" % self.schema])) - eq_(diffs[1][0], 'remove_table') + eq_(diffs[1][0], "remove_table") eq_(diffs[1][1].name, "extra") eq_(diffs[2][0], "add_column") eq_(diffs[2][1], self.schema) eq_(diffs[2][2], "address") - eq_(diffs[2][3], metadata.tables['%s.address' % self.schema].c.street) + eq_(diffs[2][3], metadata.tables["%s.address" % self.schema].c.street) eq_(diffs[3][0], "add_constraint") eq_(diffs[3][1].name, "uq_email") @@ -558,7 +594,7 @@ class AutogenerateDiffTestWSchema(ModelOne, AutogenTest, TestBase): eq_(diffs[4][0], "add_column") eq_(diffs[4][1], self.schema) eq_(diffs[4][2], "order") - eq_(diffs[4][3], metadata.tables['%s.order' % self.schema].c.user_id) + eq_(diffs[4][3], metadata.tables["%s.order" % self.schema].c.user_id) eq_(diffs[5][0][0], "modify_type") eq_(diffs[5][0][1], self.schema) @@ -568,10 +604,13 @@ class AutogenerateDiffTestWSchema(ModelOne, AutogenTest, TestBase): eq_(repr(diffs[5][0][6]), "Numeric(precision=10, scale=2)") self._assert_fk_diff( - diffs[6], "add_fk", - "order", ["user_id"], - "user", ["id"], - source_schema=config.test_schema + diffs[6], + "add_fk", + "order", + ["user_id"], + "user", + ["id"], + source_schema=config.test_schema, ) eq_(diffs[7][0][0], "modify_default") @@ -580,15 +619,15 @@ class AutogenerateDiffTestWSchema(ModelOne, AutogenTest, TestBase): eq_(diffs[7][0][3], "a1") eq_(diffs[7][0][6].arg, "x") - eq_(diffs[8][0][0], 'modify_nullable') + eq_(diffs[8][0][0], "modify_nullable") eq_(diffs[8][0][5], True) eq_(diffs[8][0][6], False) - eq_(diffs[9][0], 'remove_index') - eq_(diffs[9][1].name, 'pw_idx') + eq_(diffs[9][0], "remove_index") + eq_(diffs[9][1].name, "pw_idx") - eq_(diffs[10][0], 'remove_column') - eq_(diffs[10][3].name, 'pw') + eq_(diffs[10][0], "remove_column") + eq_(diffs[10][3].name, "pw") class CompareTypeSpecificityTest(TestBase): @@ -597,10 +636,10 @@ class CompareTypeSpecificityTest(TestBase): from sqlalchemy.engine import default return impl.DefaultImpl( - default.DefaultDialect(), None, False, True, None, {}) + default.DefaultDialect(), None, False, True, None, {} + ) def test_typedec_to_nonstandard(self): - class PasswordType(TypeDecorator): impl = VARBINARY @@ -608,7 +647,7 @@ class CompareTypeSpecificityTest(TestBase): return PasswordType(self.impl.length) def load_dialect_impl(self, dialect): - if dialect.name == 'default': + if dialect.name == "default": impl = sqlite.NUMERIC(self.length) else: impl = VARBINARY(self.length) @@ -616,8 +655,8 @@ class CompareTypeSpecificityTest(TestBase): impl = self._fixture() impl.compare_type( - Column('x', sqlite.NUMERIC(50)), - Column('x', PasswordType(50))) + Column("x", sqlite.NUMERIC(50)), Column("x", PasswordType(50)) + ) def test_string(self): t1 = String(30) @@ -626,9 +665,9 @@ class CompareTypeSpecificityTest(TestBase): t4 = Integer impl = self._fixture() - is_(impl.compare_type(Column('x', t3), Column('x', t1)), False) - is_(impl.compare_type(Column('x', t3), Column('x', t2)), True) - is_(impl.compare_type(Column('x', t3), Column('x', t4)), True) + is_(impl.compare_type(Column("x", t3), Column("x", t1)), False) + is_(impl.compare_type(Column("x", t3), Column("x", t2)), True) + is_(impl.compare_type(Column("x", t3), Column("x", t4)), True) def test_numeric(self): t1 = Numeric(10, 5) @@ -637,16 +676,16 @@ class CompareTypeSpecificityTest(TestBase): t4 = DateTime impl = self._fixture() - is_(impl.compare_type(Column('x', t3), Column('x', t1)), False) - is_(impl.compare_type(Column('x', t3), Column('x', t2)), True) - is_(impl.compare_type(Column('x', t3), Column('x', t4)), True) + is_(impl.compare_type(Column("x", t3), Column("x", t1)), False) + is_(impl.compare_type(Column("x", t3), Column("x", t2)), True) + is_(impl.compare_type(Column("x", t3), Column("x", t4)), True) def test_numeric_noprecision(self): t1 = Numeric() t2 = Numeric(scale=5) impl = self._fixture() - is_(impl.compare_type(Column('x', t1), Column('x', t2)), False) + is_(impl.compare_type(Column("x", t1), Column("x", t2)), False) def test_integer(self): t1 = Integer() @@ -657,12 +696,12 @@ class CompareTypeSpecificityTest(TestBase): t6 = BigInteger() impl = self._fixture() - is_(impl.compare_type(Column('x', t5), Column('x', t1)), False) - is_(impl.compare_type(Column('x', t3), Column('x', t1)), True) - is_(impl.compare_type(Column('x', t3), Column('x', t6)), False) - is_(impl.compare_type(Column('x', t3), Column('x', t2)), True) - is_(impl.compare_type(Column('x', t5), Column('x', t2)), True) - is_(impl.compare_type(Column('x', t1), Column('x', t4)), True) + is_(impl.compare_type(Column("x", t5), Column("x", t1)), False) + is_(impl.compare_type(Column("x", t3), Column("x", t1)), True) + is_(impl.compare_type(Column("x", t3), Column("x", t6)), False) + is_(impl.compare_type(Column("x", t3), Column("x", t2)), True) + is_(impl.compare_type(Column("x", t5), Column("x", t2)), True) + is_(impl.compare_type(Column("x", t1), Column("x", t4)), True) def test_datetime(self): t1 = DateTime() @@ -670,22 +709,19 @@ class CompareTypeSpecificityTest(TestBase): t3 = DateTime(timezone=True) impl = self._fixture() - is_(impl.compare_type(Column('x', t1), Column('x', t2)), False) - is_(impl.compare_type(Column('x', t1), Column('x', t3)), True) - is_(impl.compare_type(Column('x', t2), Column('x', t3)), True) + is_(impl.compare_type(Column("x", t1), Column("x", t2)), False) + is_(impl.compare_type(Column("x", t1), Column("x", t3)), True) + is_(impl.compare_type(Column("x", t2), Column("x", t3)), True) class AutogenSystemColTest(AutogenTest, TestBase): - __only_on__ = 'postgresql' + __only_on__ = "postgresql" @classmethod def _get_db_schema(cls): m = MetaData() - Table( - 'sometable', m, - Column('id', Integer, primary_key=True), - ) + Table("sometable", m, Column("id", Integer, primary_key=True)) return m @classmethod @@ -695,9 +731,10 @@ class AutogenSystemColTest(AutogenTest, TestBase): # 'xmin' is implicitly present, when added to a model should produce # no change Table( - 'sometable', m, - Column('id', Integer, primary_key=True), - Column('xmin', Integer, system=True) + "sometable", + m, + Column("id", Integer, primary_key=True), + Column("xmin", Integer, system=True), ) return m @@ -715,30 +752,38 @@ class AutogenerateVariantCompareTest(AutogenTest, TestBase): # 1.0.13 and lower fail on Postgresql due to variant / bigserial issue # #3739 - __requires__ = ('sqlalchemy_1014', ) + __requires__ = ("sqlalchemy_1014",) @classmethod def _get_db_schema(cls): m = MetaData() - Table('sometable', m, - Column( - 'id', - BigInteger().with_variant(Integer, "sqlite"), - primary_key=True), - Column('value', String(50))) + Table( + "sometable", + m, + Column( + "id", + BigInteger().with_variant(Integer, "sqlite"), + primary_key=True, + ), + Column("value", String(50)), + ) return m @classmethod def _get_model_schema(cls): m = MetaData() - Table('sometable', m, - Column( - 'id', - BigInteger().with_variant(Integer, "sqlite"), - primary_key=True), - Column('value', String(50))) + Table( + "sometable", + m, + Column( + "id", + BigInteger().with_variant(Integer, "sqlite"), + primary_key=True, + ), + Column("value", String(50)), + ) return m def test_variant_no_issue(self): @@ -750,24 +795,30 @@ class AutogenerateVariantCompareTest(AutogenTest, TestBase): class AutogenerateCustomCompareTypeTest(AutogenTest, TestBase): - __only_on__ = 'sqlite' + __only_on__ = "sqlite" @classmethod def _get_db_schema(cls): m = MetaData() - Table('sometable', m, - Column('id', Integer, primary_key=True), - Column('value', Integer)) + Table( + "sometable", + m, + Column("id", Integer, primary_key=True), + Column("value", Integer), + ) return m @classmethod def _get_model_schema(cls): m = MetaData() - Table('sometable', m, - Column('id', Integer, primary_key=True), - Column('value', String)) + Table( + "sometable", + m, + Column("id", Integer, primary_key=True), + Column("value", String), + ) return m def test_uses_custom_compare_type_function(self): @@ -779,15 +830,20 @@ class AutogenerateCustomCompareTypeTest(AutogenTest, TestBase): ctx = self.autogen_context autogenerate._produce_net_changes(ctx, uo) - first_table = self.m2.tables['sometable'] - first_column = first_table.columns['id'] + first_table = self.m2.tables["sometable"] + first_column = first_table.columns["id"] eq_(len(my_compare_type.mock_calls), 2) # We'll just test the first call _, args, _ = my_compare_type.mock_calls[0] - (context, inspected_column, metadata_column, - inspected_type, metadata_type) = args + ( + context, + inspected_column, + metadata_column, + inspected_type, + metadata_type, + ) = args eq_(context, self.context) eq_(metadata_column, first_column) eq_(metadata_type, first_column.type) @@ -816,8 +872,8 @@ class AutogenerateCustomCompareTypeTest(AutogenTest, TestBase): autogenerate._produce_net_changes(ctx, uo) diffs = uo.as_diffs() - eq_(diffs[0][0][0], 'modify_type') - eq_(diffs[1][0][0], 'modify_type') + eq_(diffs[0][0][0], "modify_type") + eq_(diffs[1][0][0], "modify_type") class PKConstraintUpgradesIgnoresNullableTest(AutogenTest, TestBase): @@ -829,10 +885,11 @@ class PKConstraintUpgradesIgnoresNullableTest(AutogenTest, TestBase): m = MetaData() Table( - 'person_to_role', m, - Column('person_id', Integer, autoincrement=False), - Column('role_id', Integer, autoincrement=False), - PrimaryKeyConstraint('person_id', 'role_id') + "person_to_role", + m, + Column("person_id", Integer, autoincrement=False), + Column("role_id", Integer, autoincrement=False), + PrimaryKeyConstraint("person_id", "role_id"), ) return m @@ -849,34 +906,40 @@ class PKConstraintUpgradesIgnoresNullableTest(AutogenTest, TestBase): class AutogenKeyTest(AutogenTest, TestBase): - __only_on__ = 'sqlite' + __only_on__ = "sqlite" @classmethod def _get_db_schema(cls): m = MetaData() - Table('someothertable', m, - Column('id', Integer, primary_key=True), - Column('value', Integer, key="somekey"), - ) + Table( + "someothertable", + m, + Column("id", Integer, primary_key=True), + Column("value", Integer, key="somekey"), + ) return m @classmethod def _get_model_schema(cls): m = MetaData() - Table('sometable', m, - Column('id', Integer, primary_key=True), - Column('value', Integer, key="someotherkey"), - ) - Table('someothertable', m, - Column('id', Integer, primary_key=True), - Column('value', Integer, key="somekey"), - Column("othervalue", Integer, key="otherkey") - ) + Table( + "sometable", + m, + Column("id", Integer, primary_key=True), + Column("value", Integer, key="someotherkey"), + ) + Table( + "someothertable", + m, + Column("id", Integer, primary_key=True), + Column("value", Integer, key="somekey"), + Column("othervalue", Integer, key="otherkey"), + ) return m - symbols = ['someothertable', 'sometable'] + symbols = ["someothertable", "sometable"] def test_autogen(self): @@ -892,16 +955,19 @@ class AutogenKeyTest(AutogenTest, TestBase): class AutogenVersionTableTest(AutogenTest, TestBase): - __only_on__ = 'sqlite' - version_table_name = 'alembic_version' + __only_on__ = "sqlite" + version_table_name = "alembic_version" version_table_schema = None @classmethod def _get_db_schema(cls): m = MetaData() Table( - cls.version_table_name, m, - Column('x', Integer), schema=cls.version_table_schema) + cls.version_table_name, + m, + Column("x", Integer), + schema=cls.version_table_schema, + ) return m @classmethod @@ -919,7 +985,10 @@ class AutogenVersionTableTest(AutogenTest, TestBase): def test_version_table_in_target(self): Table( self.version_table_name, - self.m2, Column('x', Integer), schema=self.version_table_schema) + self.m2, + Column("x", Integer), + schema=self.version_table_schema, + ) ctx = self.autogen_context uo = ops.UpgradeOps(ops=[]) @@ -928,29 +997,30 @@ class AutogenVersionTableTest(AutogenTest, TestBase): class AutogenCustomVersionTableSchemaTest(AutogenVersionTableTest): - __only_on__ = 'postgresql' + __only_on__ = "postgresql" __backend__ = True - version_table_schema = 'test_schema' - configure_opts = {'version_table_schema': 'test_schema'} + version_table_schema = "test_schema" + configure_opts = {"version_table_schema": "test_schema"} class AutogenCustomVersionTableTest(AutogenVersionTableTest): - version_table_name = 'my_version_table' - configure_opts = {'version_table': 'my_version_table'} + version_table_name = "my_version_table" + configure_opts = {"version_table": "my_version_table"} class AutogenCustomVersionTableAndSchemaTest(AutogenVersionTableTest): - __only_on__ = 'postgresql' + __only_on__ = "postgresql" __backend__ = True - version_table_name = 'my_version_table' - version_table_schema = 'test_schema' + version_table_name = "my_version_table" + version_table_schema = "test_schema" configure_opts = { - 'version_table': 'my_version_table', - 'version_table_schema': 'test_schema'} + "version_table": "my_version_table", + "version_table_schema": "test_schema", + } class AutogenerateDiffOrderTest(AutogenTest, TestBase): - __only_on__ = 'sqlite' + __only_on__ = "sqlite" @classmethod def _get_db_schema(cls): @@ -959,13 +1029,11 @@ class AutogenerateDiffOrderTest(AutogenTest, TestBase): @classmethod def _get_model_schema(cls): m = MetaData() - Table('parent', m, - Column('id', Integer, primary_key=True) - ) + Table("parent", m, Column("id", Integer, primary_key=True)) - Table('child', m, - Column('parent_id', Integer, ForeignKey('parent.id')), - ) + Table( + "child", m, Column("parent_id", Integer, ForeignKey("parent.id")) + ) return m @@ -980,32 +1048,29 @@ class AutogenerateDiffOrderTest(AutogenTest, TestBase): autogenerate._produce_net_changes(ctx, uo) diffs = uo.as_diffs() - eq_(diffs[0][0], 'add_table') + eq_(diffs[0][0], "add_table") eq_(diffs[0][1].name, "parent") - eq_(diffs[1][0], 'add_table') + eq_(diffs[1][0], "add_table") eq_(diffs[1][1].name, "child") class CompareMetadataTest(ModelOne, AutogenTest, TestBase): - __only_on__ = 'sqlite' + __only_on__ = "sqlite" def test_compare_metadata(self): metadata = self.m2 diffs = autogenerate.compare_metadata(self.context, metadata) - eq_( - diffs[0], - ('add_table', metadata.tables['item']) - ) + eq_(diffs[0], ("add_table", metadata.tables["item"])) - eq_(diffs[1][0], 'remove_table') + eq_(diffs[1][0], "remove_table") eq_(diffs[1][1].name, "extra") eq_(diffs[2][0], "add_column") eq_(diffs[2][1], None) eq_(diffs[2][2], "address") - eq_(diffs[2][3], metadata.tables['address'].c.street) + eq_(diffs[2][3], metadata.tables["address"].c.street) eq_(diffs[3][0], "add_constraint") eq_(diffs[3][1].name, "uq_email") @@ -1013,7 +1078,7 @@ class CompareMetadataTest(ModelOne, AutogenTest, TestBase): eq_(diffs[4][0], "add_column") eq_(diffs[4][1], None) eq_(diffs[4][2], "order") - eq_(diffs[4][3], metadata.tables['order'].c.user_id) + eq_(diffs[4][3], metadata.tables["order"].c.user_id) eq_(diffs[5][0][0], "modify_type") eq_(diffs[5][0][1], None) @@ -1023,9 +1088,7 @@ class CompareMetadataTest(ModelOne, AutogenTest, TestBase): eq_(repr(diffs[5][0][6]), "Numeric(precision=10, scale=2)") self._assert_fk_diff( - diffs[6], "add_fk", - "order", ["user_id"], - "user", ["id"] + diffs[6], "add_fk", "order", ["user_id"], "user", ["id"] ) eq_(diffs[7][0][0], "modify_default") @@ -1034,15 +1097,15 @@ class CompareMetadataTest(ModelOne, AutogenTest, TestBase): eq_(diffs[7][0][3], "a1") eq_(diffs[7][0][6].arg, "x") - eq_(diffs[8][0][0], 'modify_nullable') + eq_(diffs[8][0][0], "modify_nullable") eq_(diffs[8][0][5], True) eq_(diffs[8][0][6], False) - eq_(diffs[9][0], 'remove_index') - eq_(diffs[9][1].name, 'pw_idx') + eq_(diffs[9][0], "remove_index") + eq_(diffs[9][1].name, "pw_idx") - eq_(diffs[10][0], 'remove_column') - eq_(diffs[10][3].name, 'pw') + eq_(diffs[10][0], "remove_column") + eq_(diffs[10][3].name, "pw") def test_compare_metadata_include_object(self): metadata = self.m2 @@ -1058,46 +1121,46 @@ class CompareMetadataTest(ModelOne, AutogenTest, TestBase): context = MigrationContext.configure( connection=self.bind.connect(), opts={ - 'compare_type': True, - 'compare_server_default': True, - 'include_object': include_object, - } + "compare_type": True, + "compare_server_default": True, + "include_object": include_object, + }, ) diffs = autogenerate.compare_metadata(context, metadata) - eq_(diffs[0][0], 'remove_table') + eq_(diffs[0][0], "remove_table") eq_(diffs[0][1].name, "extra") eq_(diffs[1][0], "add_column") eq_(diffs[1][1], None) eq_(diffs[1][2], "order") - eq_(diffs[1][3], metadata.tables['order'].c.user_id) + eq_(diffs[1][3], metadata.tables["order"].c.user_id) def test_compare_metadata_include_symbol(self): metadata = self.m2 def include_symbol(table_name, schema_name): - return table_name in ('extra', 'order') + return table_name in ("extra", "order") context = MigrationContext.configure( connection=self.bind.connect(), opts={ - 'compare_type': True, - 'compare_server_default': True, - 'include_symbol': include_symbol, - } + "compare_type": True, + "compare_server_default": True, + "include_symbol": include_symbol, + }, ) diffs = autogenerate.compare_metadata(context, metadata) - eq_(diffs[0][0], 'remove_table') + eq_(diffs[0][0], "remove_table") eq_(diffs[0][1].name, "extra") eq_(diffs[1][0], "add_column") eq_(diffs[1][1], None) eq_(diffs[1][2], "order") - eq_(diffs[1][3], metadata.tables['order'].c.user_id) + eq_(diffs[1][3], metadata.tables["order"].c.user_id) eq_(diffs[2][0][0], "modify_type") eq_(diffs[2][0][1], None) @@ -1106,15 +1169,14 @@ class CompareMetadataTest(ModelOne, AutogenTest, TestBase): eq_(repr(diffs[2][0][5]), "NUMERIC(precision=8, scale=2)") eq_(repr(diffs[2][0][6]), "Numeric(precision=10, scale=2)") - eq_(diffs[2][1][0], 'modify_nullable') - eq_(diffs[2][1][2], 'order') + eq_(diffs[2][1][0], "modify_nullable") + eq_(diffs[2][1][2], "order") eq_(diffs[2][1][5], False) eq_(diffs[2][1][6], True) def test_compare_metadata_as_sql(self): context = MigrationContext.configure( - connection=self.bind.connect(), - opts={'as_sql': True} + connection=self.bind.connect(), opts={"as_sql": True} ) metadata = self.m2 @@ -1122,12 +1184,14 @@ class CompareMetadataTest(ModelOne, AutogenTest, TestBase): CommandError, "autogenerate can't use as_sql=True as it prevents " "querying the database for schema information", - autogenerate.compare_metadata, context, metadata + autogenerate.compare_metadata, + context, + metadata, ) class PGCompareMetaData(ModelOne, AutogenTest, TestBase): - __only_on__ = 'postgresql' + __only_on__ = "postgresql" __backend__ = True schema = "test_schema" @@ -1135,26 +1199,20 @@ class PGCompareMetaData(ModelOne, AutogenTest, TestBase): metadata = self.m2 context = MigrationContext.configure( - connection=self.bind.connect(), - opts={ - "include_schemas": True - } + connection=self.bind.connect(), opts={"include_schemas": True} ) diffs = autogenerate.compare_metadata(context, metadata) - eq_( - diffs[0], - ('add_table', metadata.tables['test_schema.item']) - ) + eq_(diffs[0], ("add_table", metadata.tables["test_schema.item"])) - eq_(diffs[1][0], 'remove_table') + eq_(diffs[1][0], "remove_table") eq_(diffs[1][1].name, "extra") eq_(diffs[2][0], "add_column") eq_(diffs[2][1], "test_schema") eq_(diffs[2][2], "address") - eq_(diffs[2][3], metadata.tables['test_schema.address'].c.street) + eq_(diffs[2][3], metadata.tables["test_schema.address"].c.street) eq_(diffs[3][0], "add_constraint") eq_(diffs[3][1].name, "uq_email") @@ -1162,27 +1220,25 @@ class PGCompareMetaData(ModelOne, AutogenTest, TestBase): eq_(diffs[4][0], "add_column") eq_(diffs[4][1], "test_schema") eq_(diffs[4][2], "order") - eq_(diffs[4][3], metadata.tables['test_schema.order'].c.user_id) + eq_(diffs[4][3], metadata.tables["test_schema.order"].c.user_id) - eq_(diffs[5][0][0], 'modify_nullable') + eq_(diffs[5][0][0], "modify_nullable") eq_(diffs[5][0][5], False) eq_(diffs[5][0][6], True) + class OrigObjectTest(TestBase): def setUp(self): self.metadata = m = MetaData() t = Table( - 't', m, - Column('id', Integer(), primary_key=True), - Column('x', Integer()) - ) - self.ix = Index('ix1', t.c.id) - fk = ForeignKeyConstraint(['t_id'], ['t.id']) - q = Table( - 'q', m, - Column('t_id', Integer()), - fk + "t", + m, + Column("id", Integer(), primary_key=True), + Column("x", Integer()), ) + self.ix = Index("ix1", t.c.id) + fk = ForeignKeyConstraint(["t_id"], ["t.id"]) + q = Table("q", m, Column("t_id", Integer()), fk) self.table = t self.fk = fk self.ck = CheckConstraint(t.c.x > 5) @@ -1232,10 +1288,10 @@ class OrigObjectTest(TestBase): is_not_(None, op.to_constraint().table) def test_add_pk_no_orig(self): - op = ops.CreatePrimaryKeyOp('pk1', 't', ['x', 'y']) + op = ops.CreatePrimaryKeyOp("pk1", "t", ["x", "y"]) pk = op.to_constraint() - eq_(pk.name, 'pk1') - eq_(pk.table.name, 't') + eq_(pk.name, "pk1") + eq_(pk.table.name, "t") def test_add_pk(self): pk = self.pk @@ -1254,7 +1310,7 @@ class OrigObjectTest(TestBase): def test_drop_column(self): t = self.table - op = ops.DropColumnOp.from_column_and_tablename(None, 't', t.c.x) + op = ops.DropColumnOp.from_column_and_tablename(None, "t", t.c.x) is_(op.to_column(), t.c.x) is_(op.reverse().to_column(), t.c.x) is_not_(None, op.to_column().table) @@ -1262,7 +1318,7 @@ class OrigObjectTest(TestBase): def test_add_column(self): t = self.table - op = ops.AddColumnOp.from_column_and_tablename(None, 't', t.c.x) + op = ops.AddColumnOp.from_column_and_tablename(None, "t", t.c.x) is_(op.to_column(), t.c.x) is_(op.reverse().to_column(), t.c.x) is_not_(None, op.to_column().table) @@ -1304,25 +1360,33 @@ class MultipleMetaDataTest(AutogenFixtureTest, TestBase): m2b = MetaData() m2c = MetaData() - Table('a', m1a, Column('id', Integer, primary_key=True)) - Table('b1', m1b, Column('id', Integer, primary_key=True)) - Table('b2', m1b, Column('id', Integer, primary_key=True)) - Table('c1', m1c, Column('id', Integer, primary_key=True), - Column('x', Integer)) + Table("a", m1a, Column("id", Integer, primary_key=True)) + Table("b1", m1b, Column("id", Integer, primary_key=True)) + Table("b2", m1b, Column("id", Integer, primary_key=True)) + Table( + "c1", + m1c, + Column("id", Integer, primary_key=True), + Column("x", Integer), + ) - a = Table('a', m2a, Column('id', Integer, primary_key=True), - Column('q', Integer)) - Table('b1', m2b, Column('id', Integer, primary_key=True)) - Table('c1', m2c, Column('id', Integer, primary_key=True)) - c2 = Table('c2', m2c, Column('id', Integer, primary_key=True)) + a = Table( + "a", + m2a, + Column("id", Integer, primary_key=True), + Column("q", Integer), + ) + Table("b1", m2b, Column("id", Integer, primary_key=True)) + Table("c1", m2c, Column("id", Integer, primary_key=True)) + c2 = Table("c2", m2c, Column("id", Integer, primary_key=True)) diffs = self._fixture([m1a, m1b, m1c], [m2a, m2b, m2c]) - eq_(diffs[0], ('add_table', c2)) - eq_(diffs[1][0], 'remove_table') - eq_(diffs[1][1].name, 'b2') - eq_(diffs[2], ('add_column', None, 'a', a.c.q)) - eq_(diffs[3][0:3], ('remove_column', None, 'c1')) - eq_(diffs[3][3].name, 'x') + eq_(diffs[0], ("add_table", c2)) + eq_(diffs[1][0], "remove_table") + eq_(diffs[1][1].name, "b2") + eq_(diffs[2], ("add_column", None, "a", a.c.q)) + eq_(diffs[3][0:3], ("remove_column", None, "c1")) + eq_(diffs[3][3].name, "x") def test_empty_list(self): # because they're going to do it.... @@ -1339,18 +1403,19 @@ class MultipleMetaDataTest(AutogenFixtureTest, TestBase): m2a = MetaData() m2b = MetaData() - Table('a', m1a, Column('id', Integer, primary_key=True)) - Table('b', m1b, Column('id', Integer, primary_key=True)) + Table("a", m1a, Column("id", Integer, primary_key=True)) + Table("b", m1b, Column("id", Integer, primary_key=True)) - Table('a', m2a, Column('id', Integer, primary_key=True)) - b = Table('b', m2b, Column('id', Integer, primary_key=True), - Column('q', Integer)) + Table("a", m2a, Column("id", Integer, primary_key=True)) + b = Table( + "b", + m2b, + Column("id", Integer, primary_key=True), + Column("q", Integer), + ) diffs = self._fixture((m1a, m1b), (m2a, m2b)) - eq_( - diffs, - [('add_column', None, 'b', b.c.q)] - ) + eq_(diffs, [("add_column", None, "b", b.c.q)]) def test_raise_on_dupe(self): m1a = MetaData() @@ -1359,116 +1424,123 @@ class MultipleMetaDataTest(AutogenFixtureTest, TestBase): m2a = MetaData() m2b = MetaData() - Table('a', m1a, Column('id', Integer, primary_key=True)) - Table('b1', m1b, Column('id', Integer, primary_key=True)) - Table('b2', m1b, Column('id', Integer, primary_key=True)) - Table('b3', m1b, Column('id', Integer, primary_key=True)) + Table("a", m1a, Column("id", Integer, primary_key=True)) + Table("b1", m1b, Column("id", Integer, primary_key=True)) + Table("b2", m1b, Column("id", Integer, primary_key=True)) + Table("b3", m1b, Column("id", Integer, primary_key=True)) - Table('a', m2a, Column('id', Integer, primary_key=True)) - Table('a', m2b, Column('id', Integer, primary_key=True)) - Table('b1', m2b, Column('id', Integer, primary_key=True)) - Table('b2', m2a, Column('id', Integer, primary_key=True)) - Table('b2', m2b, Column('id', Integer, primary_key=True)) + Table("a", m2a, Column("id", Integer, primary_key=True)) + Table("a", m2b, Column("id", Integer, primary_key=True)) + Table("b1", m2b, Column("id", Integer, primary_key=True)) + Table("b2", m2a, Column("id", Integer, primary_key=True)) + Table("b2", m2b, Column("id", Integer, primary_key=True)) assert_raises_message( ValueError, 'Duplicate table keys across multiple MetaData objects: "a", "b2"', self._fixture, - [m1a, m1b], [m2a, m2b] + [m1a, m1b], + [m2a, m2b], ) class AutoincrementTest(AutogenFixtureTest, TestBase): __backend__ = True - __requires__ = 'integer_subtype_comparisons', + __requires__ = ("integer_subtype_comparisons",) def test_alter_column_autoincrement_none(self): m1 = MetaData() m2 = MetaData() - Table('a', m1, Column('x', Integer, nullable=False)) - Table('a', m2, Column('x', Integer, nullable=True)) + Table("a", m1, Column("x", Integer, nullable=False)) + Table("a", m2, Column("x", Integer, nullable=True)) ops = self._fixture(m1, m2, return_ops=True) - assert 'autoincrement' not in ops.ops[0].ops[0].kw + assert "autoincrement" not in ops.ops[0].ops[0].kw def test_alter_column_autoincrement_pk_false(self): m1 = MetaData() m2 = MetaData() Table( - 'a', m1, - Column('x', Integer, primary_key=True, autoincrement=False)) + "a", + m1, + Column("x", Integer, primary_key=True, autoincrement=False), + ) Table( - 'a', m2, - Column('x', BigInteger, primary_key=True, autoincrement=False)) + "a", + m2, + Column("x", BigInteger, primary_key=True, autoincrement=False), + ) ops = self._fixture(m1, m2, return_ops=True) - is_(ops.ops[0].ops[0].kw['autoincrement'], False) + is_(ops.ops[0].ops[0].kw["autoincrement"], False) def test_alter_column_autoincrement_pk_implicit_true(self): m1 = MetaData() m2 = MetaData() - Table( - 'a', m1, - Column('x', Integer, primary_key=True)) - Table( - 'a', m2, - Column('x', BigInteger, primary_key=True)) + Table("a", m1, Column("x", Integer, primary_key=True)) + Table("a", m2, Column("x", BigInteger, primary_key=True)) ops = self._fixture(m1, m2, return_ops=True) - is_(ops.ops[0].ops[0].kw['autoincrement'], True) + is_(ops.ops[0].ops[0].kw["autoincrement"], True) def test_alter_column_autoincrement_pk_explicit_true(self): m1 = MetaData() m2 = MetaData() Table( - 'a', m1, - Column('x', Integer, primary_key=True, autoincrement=True)) + "a", m1, Column("x", Integer, primary_key=True, autoincrement=True) + ) Table( - 'a', m2, - Column('x', BigInteger, primary_key=True, autoincrement=True)) + "a", + m2, + Column("x", BigInteger, primary_key=True, autoincrement=True), + ) ops = self._fixture(m1, m2, return_ops=True) - is_(ops.ops[0].ops[0].kw['autoincrement'], True) + is_(ops.ops[0].ops[0].kw["autoincrement"], True) def test_alter_column_autoincrement_nonpk_false(self): m1 = MetaData() m2 = MetaData() Table( - 'a', m1, - Column('id', Integer, primary_key=True), - Column('x', Integer, autoincrement=False) + "a", + m1, + Column("id", Integer, primary_key=True), + Column("x", Integer, autoincrement=False), ) Table( - 'a', m2, - Column('id', Integer, primary_key=True), - Column('x', BigInteger, autoincrement=False) + "a", + m2, + Column("id", Integer, primary_key=True), + Column("x", BigInteger, autoincrement=False), ) ops = self._fixture(m1, m2, return_ops=True) - is_(ops.ops[0].ops[0].kw['autoincrement'], False) + is_(ops.ops[0].ops[0].kw["autoincrement"], False) def test_alter_column_autoincrement_nonpk_implicit_false(self): m1 = MetaData() m2 = MetaData() Table( - 'a', m1, - Column('id', Integer, primary_key=True), - Column('x', Integer) + "a", + m1, + Column("id", Integer, primary_key=True), + Column("x", Integer), ) Table( - 'a', m2, - Column('id', Integer, primary_key=True), - Column('x', BigInteger) + "a", + m2, + Column("id", Integer, primary_key=True), + Column("x", BigInteger), ) ops = self._fixture(m1, m2, return_ops=True) - assert 'autoincrement' not in ops.ops[0].ops[0].kw + assert "autoincrement" not in ops.ops[0].ops[0].kw @config.requirements.fail_before_sqla_110 def test_alter_column_autoincrement_nonpk_explicit_true(self): @@ -1476,54 +1548,60 @@ class AutoincrementTest(AutogenFixtureTest, TestBase): m2 = MetaData() Table( - 'a', m1, - Column('id', Integer, primary_key=True), - Column('x', Integer, autoincrement=True) + "a", + m1, + Column("id", Integer, primary_key=True), + Column("x", Integer, autoincrement=True), ) Table( - 'a', m2, - Column('id', Integer, primary_key=True), - Column('x', BigInteger, autoincrement=True) + "a", + m2, + Column("id", Integer, primary_key=True), + Column("x", BigInteger, autoincrement=True), ) ops = self._fixture(m1, m2, return_ops=True) - is_(ops.ops[0].ops[0].kw['autoincrement'], True) + is_(ops.ops[0].ops[0].kw["autoincrement"], True) def test_alter_column_autoincrement_compositepk_false(self): m1 = MetaData() m2 = MetaData() Table( - 'a', m1, - Column('id', Integer, primary_key=True), - Column('x', Integer, primary_key=True, autoincrement=False) + "a", + m1, + Column("id", Integer, primary_key=True), + Column("x", Integer, primary_key=True, autoincrement=False), ) Table( - 'a', m2, - Column('id', Integer, primary_key=True), - Column('x', BigInteger, primary_key=True, autoincrement=False) + "a", + m2, + Column("id", Integer, primary_key=True), + Column("x", BigInteger, primary_key=True, autoincrement=False), ) ops = self._fixture(m1, m2, return_ops=True) - is_(ops.ops[0].ops[0].kw['autoincrement'], False) + is_(ops.ops[0].ops[0].kw["autoincrement"], False) def test_alter_column_autoincrement_compositepk_implicit_false(self): m1 = MetaData() m2 = MetaData() Table( - 'a', m1, - Column('id', Integer, primary_key=True), - Column('x', Integer, primary_key=True) + "a", + m1, + Column("id", Integer, primary_key=True), + Column("x", Integer, primary_key=True), ) Table( - 'a', m2, - Column('id', Integer, primary_key=True), - Column('x', BigInteger, primary_key=True) + "a", + m2, + Column("id", Integer, primary_key=True), + Column("x", BigInteger, primary_key=True), ) ops = self._fixture(m1, m2, return_ops=True) - assert 'autoincrement' not in ops.ops[0].ops[0].kw + assert "autoincrement" not in ops.ops[0].ops[0].kw @config.requirements.autoincrement_on_composite_pk def test_alter_column_autoincrement_compositepk_explicit_true(self): @@ -1531,20 +1609,22 @@ class AutoincrementTest(AutogenFixtureTest, TestBase): m2 = MetaData() Table( - 'a', m1, - Column('id', Integer, primary_key=True, autoincrement=False), - Column('x', Integer, primary_key=True, autoincrement=True), + "a", + m1, + Column("id", Integer, primary_key=True, autoincrement=False), + Column("x", Integer, primary_key=True, autoincrement=True), # on SQLA 1.0 and earlier, this being present # trips the "add KEY for the primary key" so that the # AUTO_INCREMENT keyword is accepted by MySQL. SQLA 1.1 and # greater the columns are just reorganized. - mysql_engine='InnoDB' + mysql_engine="InnoDB", ) Table( - 'a', m2, - Column('id', Integer, primary_key=True, autoincrement=False), - Column('x', BigInteger, primary_key=True, autoincrement=True) + "a", + m2, + Column("id", Integer, primary_key=True, autoincrement=False), + Column("x", BigInteger, primary_key=True, autoincrement=True), ) ops = self._fixture(m1, m2, return_ops=True) - is_(ops.ops[0].ops[0].kw['autoincrement'], True) + is_(ops.ops[0].ops[0].kw["autoincrement"], True) diff --git a/tests/test_autogen_fks.py b/tests/test_autogen_fks.py index 3dd66ae..66c5ac4 100644 --- a/tests/test_autogen_fks.py +++ b/tests/test_autogen_fks.py @@ -1,8 +1,14 @@ import sys from alembic.testing import TestBase, config, mock -from sqlalchemy import MetaData, Column, Table, Integer, String, \ - ForeignKeyConstraint +from sqlalchemy import ( + MetaData, + Column, + Table, + Integer, + String, + ForeignKeyConstraint, +) from alembic.testing import eq_ py3k = sys.version_info.major >= 3 @@ -17,105 +23,141 @@ class AutogenerateForeignKeysTest(AutogenFixtureTest, TestBase): m1 = MetaData() m2 = MetaData() - Table('some_table', m1, - Column('test', String(10), primary_key=True), - mysql_engine='InnoDB') - - Table('user', m1, - Column('id', Integer, primary_key=True), - Column('name', String(50), nullable=False), - Column('a1', String(10), server_default="x"), - Column('test2', String(10)), - ForeignKeyConstraint(['test2'], ['some_table.test']), - mysql_engine='InnoDB') - - Table('some_table', m2, - Column('test', String(10), primary_key=True), - mysql_engine='InnoDB') - - Table('user', m2, - Column('id', Integer, primary_key=True), - Column('name', String(50), nullable=False), - Column('a1', String(10), server_default="x"), - Column('test2', String(10)), - mysql_engine='InnoDB' - ) + Table( + "some_table", + m1, + Column("test", String(10), primary_key=True), + mysql_engine="InnoDB", + ) + + Table( + "user", + m1, + Column("id", Integer, primary_key=True), + Column("name", String(50), nullable=False), + Column("a1", String(10), server_default="x"), + Column("test2", String(10)), + ForeignKeyConstraint(["test2"], ["some_table.test"]), + mysql_engine="InnoDB", + ) + + Table( + "some_table", + m2, + Column("test", String(10), primary_key=True), + mysql_engine="InnoDB", + ) + + Table( + "user", + m2, + Column("id", Integer, primary_key=True), + Column("name", String(50), nullable=False), + Column("a1", String(10), server_default="x"), + Column("test2", String(10)), + mysql_engine="InnoDB", + ) diffs = self._fixture(m1, m2) self._assert_fk_diff( - diffs[0], "remove_fk", - "user", ['test2'], - 'some_table', ['test'], - conditional_name="servergenerated" + diffs[0], + "remove_fk", + "user", + ["test2"], + "some_table", + ["test"], + conditional_name="servergenerated", ) def test_add_fk(self): m1 = MetaData() m2 = MetaData() - Table('some_table', m1, - Column('id', Integer, primary_key=True), - Column('test', String(10)), - mysql_engine='InnoDB') - - Table('user', m1, - Column('id', Integer, primary_key=True), - Column('name', String(50), nullable=False), - Column('a1', String(10), server_default="x"), - Column('test2', String(10)), - mysql_engine='InnoDB') - - Table('some_table', m2, - Column('id', Integer, primary_key=True), - Column('test', String(10)), - mysql_engine='InnoDB') - - Table('user', m2, - Column('id', Integer, primary_key=True), - Column('name', String(50), nullable=False), - Column('a1', String(10), server_default="x"), - Column('test2', String(10)), - ForeignKeyConstraint(['test2'], ['some_table.test']), - mysql_engine='InnoDB') + Table( + "some_table", + m1, + Column("id", Integer, primary_key=True), + Column("test", String(10)), + mysql_engine="InnoDB", + ) + + Table( + "user", + m1, + Column("id", Integer, primary_key=True), + Column("name", String(50), nullable=False), + Column("a1", String(10), server_default="x"), + Column("test2", String(10)), + mysql_engine="InnoDB", + ) + + Table( + "some_table", + m2, + Column("id", Integer, primary_key=True), + Column("test", String(10)), + mysql_engine="InnoDB", + ) + + Table( + "user", + m2, + Column("id", Integer, primary_key=True), + Column("name", String(50), nullable=False), + Column("a1", String(10), server_default="x"), + Column("test2", String(10)), + ForeignKeyConstraint(["test2"], ["some_table.test"]), + mysql_engine="InnoDB", + ) diffs = self._fixture(m1, m2) self._assert_fk_diff( - diffs[0], "add_fk", - "user", ["test2"], - "some_table", ["test"] + diffs[0], "add_fk", "user", ["test2"], "some_table", ["test"] ) def test_no_change(self): m1 = MetaData() m2 = MetaData() - Table('some_table', m1, - Column('id', Integer, primary_key=True), - Column('test', String(10)), - mysql_engine='InnoDB') - - Table('user', m1, - Column('id', Integer, primary_key=True), - Column('name', String(50), nullable=False), - Column('a1', String(10), server_default="x"), - Column('test2', Integer), - ForeignKeyConstraint(['test2'], ['some_table.id']), - mysql_engine='InnoDB') - - Table('some_table', m2, - Column('id', Integer, primary_key=True), - Column('test', String(10)), - mysql_engine='InnoDB') - - Table('user', m2, - Column('id', Integer, primary_key=True), - Column('name', String(50), nullable=False), - Column('a1', String(10), server_default="x"), - Column('test2', Integer), - ForeignKeyConstraint(['test2'], ['some_table.id']), - mysql_engine='InnoDB') + Table( + "some_table", + m1, + Column("id", Integer, primary_key=True), + Column("test", String(10)), + mysql_engine="InnoDB", + ) + + Table( + "user", + m1, + Column("id", Integer, primary_key=True), + Column("name", String(50), nullable=False), + Column("a1", String(10), server_default="x"), + Column("test2", Integer), + ForeignKeyConstraint(["test2"], ["some_table.id"]), + mysql_engine="InnoDB", + ) + + Table( + "some_table", + m2, + Column("id", Integer, primary_key=True), + Column("test", String(10)), + mysql_engine="InnoDB", + ) + + Table( + "user", + m2, + Column("id", Integer, primary_key=True), + Column("name", String(50), nullable=False), + Column("a1", String(10), server_default="x"), + Column("test2", Integer), + ForeignKeyConstraint(["test2"], ["some_table.id"]), + mysql_engine="InnoDB", + ) diffs = self._fixture(m1, m2) @@ -125,36 +167,51 @@ class AutogenerateForeignKeysTest(AutogenFixtureTest, TestBase): m1 = MetaData() m2 = MetaData() - Table('some_table', m1, - Column('id_1', String(10), primary_key=True), - Column('id_2', String(10), primary_key=True), - mysql_engine='InnoDB') - - Table('user', m1, - Column('id', Integer, primary_key=True), - Column('name', String(50), nullable=False), - Column('a1', String(10), server_default="x"), - Column('other_id_1', String(10)), - Column('other_id_2', String(10)), - ForeignKeyConstraint(['other_id_1', 'other_id_2'], - ['some_table.id_1', 'some_table.id_2']), - mysql_engine='InnoDB') - - Table('some_table', m2, - Column('id_1', String(10), primary_key=True), - Column('id_2', String(10), primary_key=True), - mysql_engine='InnoDB' - ) - - Table('user', m2, - Column('id', Integer, primary_key=True), - Column('name', String(50), nullable=False), - Column('a1', String(10), server_default="x"), - Column('other_id_1', String(10)), - Column('other_id_2', String(10)), - ForeignKeyConstraint(['other_id_1', 'other_id_2'], - ['some_table.id_1', 'some_table.id_2']), - mysql_engine='InnoDB') + Table( + "some_table", + m1, + Column("id_1", String(10), primary_key=True), + Column("id_2", String(10), primary_key=True), + mysql_engine="InnoDB", + ) + + Table( + "user", + m1, + Column("id", Integer, primary_key=True), + Column("name", String(50), nullable=False), + Column("a1", String(10), server_default="x"), + Column("other_id_1", String(10)), + Column("other_id_2", String(10)), + ForeignKeyConstraint( + ["other_id_1", "other_id_2"], + ["some_table.id_1", "some_table.id_2"], + ), + mysql_engine="InnoDB", + ) + + Table( + "some_table", + m2, + Column("id_1", String(10), primary_key=True), + Column("id_2", String(10), primary_key=True), + mysql_engine="InnoDB", + ) + + Table( + "user", + m2, + Column("id", Integer, primary_key=True), + Column("name", String(50), nullable=False), + Column("a1", String(10), server_default="x"), + Column("other_id_1", String(10)), + Column("other_id_2", String(10)), + ForeignKeyConstraint( + ["other_id_1", "other_id_2"], + ["some_table.id_1", "some_table.id_2"], + ), + mysql_engine="InnoDB", + ) diffs = self._fixture(m1, m2) @@ -164,42 +221,59 @@ class AutogenerateForeignKeysTest(AutogenFixtureTest, TestBase): m1 = MetaData() m2 = MetaData() - Table('some_table', m1, - Column('id_1', String(10), primary_key=True), - Column('id_2', String(10), primary_key=True), - mysql_engine='InnoDB') - - Table('user', m1, - Column('id', Integer, primary_key=True), - Column('name', String(50), nullable=False), - Column('a1', String(10), server_default="x"), - Column('other_id_1', String(10)), - Column('other_id_2', String(10)), - mysql_engine='InnoDB') - - Table('some_table', m2, - Column('id_1', String(10), primary_key=True), - Column('id_2', String(10), primary_key=True), - mysql_engine='InnoDB') - - Table('user', m2, - Column('id', Integer, primary_key=True), - Column('name', String(50), nullable=False), - Column('a1', String(10), server_default="x"), - Column('other_id_1', String(10)), - Column('other_id_2', String(10)), - ForeignKeyConstraint(['other_id_1', 'other_id_2'], - ['some_table.id_1', 'some_table.id_2'], - name='fk_test_name'), - mysql_engine='InnoDB') + Table( + "some_table", + m1, + Column("id_1", String(10), primary_key=True), + Column("id_2", String(10), primary_key=True), + mysql_engine="InnoDB", + ) + + Table( + "user", + m1, + Column("id", Integer, primary_key=True), + Column("name", String(50), nullable=False), + Column("a1", String(10), server_default="x"), + Column("other_id_1", String(10)), + Column("other_id_2", String(10)), + mysql_engine="InnoDB", + ) + + Table( + "some_table", + m2, + Column("id_1", String(10), primary_key=True), + Column("id_2", String(10), primary_key=True), + mysql_engine="InnoDB", + ) + + Table( + "user", + m2, + Column("id", Integer, primary_key=True), + Column("name", String(50), nullable=False), + Column("a1", String(10), server_default="x"), + Column("other_id_1", String(10)), + Column("other_id_2", String(10)), + ForeignKeyConstraint( + ["other_id_1", "other_id_2"], + ["some_table.id_1", "some_table.id_2"], + name="fk_test_name", + ), + mysql_engine="InnoDB", + ) diffs = self._fixture(m1, m2) self._assert_fk_diff( - diffs[0], "add_fk", - "user", ['other_id_1', 'other_id_2'], - 'some_table', ['id_1', 'id_2'], - name="fk_test_name" + diffs[0], + "add_fk", + "user", + ["other_id_1", "other_id_2"], + "some_table", + ["id_1", "id_2"], + name="fk_test_name", ) @config.requirements.no_name_normalize @@ -207,111 +281,160 @@ class AutogenerateForeignKeysTest(AutogenFixtureTest, TestBase): m1 = MetaData() m2 = MetaData() - Table('some_table', m1, - Column('id_1', String(10), primary_key=True), - Column('id_2', String(10), primary_key=True), - mysql_engine='InnoDB') - - Table('user', m1, - Column('id', Integer, primary_key=True), - Column('name', String(50), nullable=False), - Column('a1', String(10), server_default="x"), - Column('other_id_1', String(10)), - Column('other_id_2', String(10)), - ForeignKeyConstraint(['other_id_1', 'other_id_2'], - ['some_table.id_1', 'some_table.id_2'], - name='fk_test_name'), - mysql_engine='InnoDB') - - Table('some_table', m2, - Column('id_1', String(10), primary_key=True), - Column('id_2', String(10), primary_key=True), - mysql_engine='InnoDB') - - Table('user', m2, - Column('id', Integer, primary_key=True), - Column('name', String(50), nullable=False), - Column('a1', String(10), server_default="x"), - Column('other_id_1', String(10)), - Column('other_id_2', String(10)), - mysql_engine='InnoDB') + Table( + "some_table", + m1, + Column("id_1", String(10), primary_key=True), + Column("id_2", String(10), primary_key=True), + mysql_engine="InnoDB", + ) + + Table( + "user", + m1, + Column("id", Integer, primary_key=True), + Column("name", String(50), nullable=False), + Column("a1", String(10), server_default="x"), + Column("other_id_1", String(10)), + Column("other_id_2", String(10)), + ForeignKeyConstraint( + ["other_id_1", "other_id_2"], + ["some_table.id_1", "some_table.id_2"], + name="fk_test_name", + ), + mysql_engine="InnoDB", + ) + + Table( + "some_table", + m2, + Column("id_1", String(10), primary_key=True), + Column("id_2", String(10), primary_key=True), + mysql_engine="InnoDB", + ) + + Table( + "user", + m2, + Column("id", Integer, primary_key=True), + Column("name", String(50), nullable=False), + Column("a1", String(10), server_default="x"), + Column("other_id_1", String(10)), + Column("other_id_2", String(10)), + mysql_engine="InnoDB", + ) diffs = self._fixture(m1, m2) self._assert_fk_diff( - diffs[0], "remove_fk", - "user", ['other_id_1', 'other_id_2'], - "some_table", ['id_1', 'id_2'], - conditional_name="fk_test_name" + diffs[0], + "remove_fk", + "user", + ["other_id_1", "other_id_2"], + "some_table", + ["id_1", "id_2"], + conditional_name="fk_test_name", ) def test_add_fk_colkeys(self): m1 = MetaData() m2 = MetaData() - Table('some_table', m1, - Column('id_1', String(10), primary_key=True), - Column('id_2', String(10), primary_key=True), - mysql_engine='InnoDB') - - Table('user', m1, - Column('id', Integer, primary_key=True), - Column('other_id_1', String(10)), - Column('other_id_2', String(10)), - mysql_engine='InnoDB') - - Table('some_table', m2, - Column('id_1', String(10), key='tid1', primary_key=True), - Column('id_2', String(10), key='tid2', primary_key=True), - mysql_engine='InnoDB') - - Table('user', m2, - Column('id', Integer, primary_key=True), - Column('other_id_1', String(10), key='oid1'), - Column('other_id_2', String(10), key='oid2'), - ForeignKeyConstraint(['oid1', 'oid2'], - ['some_table.tid1', 'some_table.tid2'], - name='fk_test_name'), - mysql_engine='InnoDB') + Table( + "some_table", + m1, + Column("id_1", String(10), primary_key=True), + Column("id_2", String(10), primary_key=True), + mysql_engine="InnoDB", + ) + + Table( + "user", + m1, + Column("id", Integer, primary_key=True), + Column("other_id_1", String(10)), + Column("other_id_2", String(10)), + mysql_engine="InnoDB", + ) + + Table( + "some_table", + m2, + Column("id_1", String(10), key="tid1", primary_key=True), + Column("id_2", String(10), key="tid2", primary_key=True), + mysql_engine="InnoDB", + ) + + Table( + "user", + m2, + Column("id", Integer, primary_key=True), + Column("other_id_1", String(10), key="oid1"), + Column("other_id_2", String(10), key="oid2"), + ForeignKeyConstraint( + ["oid1", "oid2"], + ["some_table.tid1", "some_table.tid2"], + name="fk_test_name", + ), + mysql_engine="InnoDB", + ) diffs = self._fixture(m1, m2) self._assert_fk_diff( - diffs[0], "add_fk", - "user", ['other_id_1', 'other_id_2'], - 'some_table', ['id_1', 'id_2'], - name="fk_test_name" + diffs[0], + "add_fk", + "user", + ["other_id_1", "other_id_2"], + "some_table", + ["id_1", "id_2"], + name="fk_test_name", ) def test_no_change_colkeys(self): m1 = MetaData() m2 = MetaData() - Table('some_table', m1, - Column('id_1', String(10), primary_key=True), - Column('id_2', String(10), primary_key=True), - mysql_engine='InnoDB') - - Table('user', m1, - Column('id', Integer, primary_key=True), - Column('other_id_1', String(10)), - Column('other_id_2', String(10)), - ForeignKeyConstraint(['other_id_1', 'other_id_2'], - ['some_table.id_1', 'some_table.id_2']), - mysql_engine='InnoDB') - - Table('some_table', m2, - Column('id_1', String(10), key='tid1', primary_key=True), - Column('id_2', String(10), key='tid2', primary_key=True), - mysql_engine='InnoDB') - - Table('user', m2, - Column('id', Integer, primary_key=True), - Column('other_id_1', String(10), key='oid1'), - Column('other_id_2', String(10), key='oid2'), - ForeignKeyConstraint(['oid1', 'oid2'], - ['some_table.tid1', 'some_table.tid2']), - mysql_engine='InnoDB') + Table( + "some_table", + m1, + Column("id_1", String(10), primary_key=True), + Column("id_2", String(10), primary_key=True), + mysql_engine="InnoDB", + ) + + Table( + "user", + m1, + Column("id", Integer, primary_key=True), + Column("other_id_1", String(10)), + Column("other_id_2", String(10)), + ForeignKeyConstraint( + ["other_id_1", "other_id_2"], + ["some_table.id_1", "some_table.id_2"], + ), + mysql_engine="InnoDB", + ) + + Table( + "some_table", + m2, + Column("id_1", String(10), key="tid1", primary_key=True), + Column("id_2", String(10), key="tid2", primary_key=True), + mysql_engine="InnoDB", + ) + + Table( + "user", + m2, + Column("id", Integer, primary_key=True), + Column("other_id_1", String(10), key="oid1"), + Column("other_id_2", String(10), key="oid2"), + ForeignKeyConstraint( + ["oid1", "oid2"], ["some_table.tid1", "some_table.tid2"] + ), + mysql_engine="InnoDB", + ) diffs = self._fixture(m1, m2) @@ -320,7 +443,7 @@ class AutogenerateForeignKeysTest(AutogenFixtureTest, TestBase): class IncludeHooksTest(AutogenFixtureTest, TestBase): __backend__ = True - __requires__ = 'fk_names', + __requires__ = ("fk_names",) @config.requirements.no_name_normalize def test_remove_connection_fk(self): @@ -328,11 +451,18 @@ class IncludeHooksTest(AutogenFixtureTest, TestBase): m2 = MetaData() ref = Table( - 'ref', m1, Column('id', Integer, primary_key=True), - mysql_engine='InnoDB') + "ref", + m1, + Column("id", Integer, primary_key=True), + mysql_engine="InnoDB", + ) t1 = Table( - 't', m1, Column('x', Integer), Column('y', Integer), - mysql_engine='InnoDB') + "t", + m1, + Column("x", Integer), + Column("y", Integer), + mysql_engine="InnoDB", + ) t1.append_constraint( ForeignKeyConstraint([t1.c.x], [ref.c.id], name="fk1") ) @@ -341,24 +471,37 @@ class IncludeHooksTest(AutogenFixtureTest, TestBase): ) ref = Table( - 'ref', m2, Column('id', Integer, primary_key=True), - mysql_engine='InnoDB') + "ref", + m2, + Column("id", Integer, primary_key=True), + mysql_engine="InnoDB", + ) Table( - 't', m2, Column('x', Integer), Column('y', Integer), - mysql_engine='InnoDB') + "t", + m2, + Column("x", Integer), + Column("y", Integer), + mysql_engine="InnoDB", + ) def include_object(object_, name, type_, reflected, compare_to): return not ( - isinstance(object_, ForeignKeyConstraint) and - type_ == 'foreign_key_constraint' - and reflected and name == 'fk1') + isinstance(object_, ForeignKeyConstraint) + and type_ == "foreign_key_constraint" + and reflected + and name == "fk1" + ) diffs = self._fixture(m1, m2, object_filters=include_object) self._assert_fk_diff( - diffs[0], "remove_fk", - 't', ['y'], 'ref', ['id'], - conditional_name='fk2' + diffs[0], + "remove_fk", + "t", + ["y"], + "ref", + ["id"], + conditional_name="fk2", ) eq_(len(diffs), 1) @@ -367,18 +510,32 @@ class IncludeHooksTest(AutogenFixtureTest, TestBase): m2 = MetaData() Table( - 'ref', m1, - Column('id', Integer, primary_key=True), mysql_engine='InnoDB') + "ref", + m1, + Column("id", Integer, primary_key=True), + mysql_engine="InnoDB", + ) Table( - 't', m1, - Column('x', Integer), Column('y', Integer), mysql_engine='InnoDB') + "t", + m1, + Column("x", Integer), + Column("y", Integer), + mysql_engine="InnoDB", + ) ref = Table( - 'ref', m2, Column('id', Integer, primary_key=True), - mysql_engine='InnoDB') + "ref", + m2, + Column("id", Integer, primary_key=True), + mysql_engine="InnoDB", + ) t2 = Table( - 't', m2, Column('x', Integer), Column('y', Integer), - mysql_engine='InnoDB') + "t", + m2, + Column("x", Integer), + Column("y", Integer), + mysql_engine="InnoDB", + ) t2.append_constraint( ForeignKeyConstraint([t2.c.x], [ref.c.id], name="fk1") ) @@ -388,16 +545,16 @@ class IncludeHooksTest(AutogenFixtureTest, TestBase): def include_object(object_, name, type_, reflected, compare_to): return not ( - isinstance(object_, ForeignKeyConstraint) and - type_ == 'foreign_key_constraint' - and not reflected and name == 'fk1') + isinstance(object_, ForeignKeyConstraint) + and type_ == "foreign_key_constraint" + and not reflected + and name == "fk1" + ) diffs = self._fixture(m1, m2, object_filters=include_object) self._assert_fk_diff( - diffs[0], "add_fk", - 't', ['y'], 'ref', ['id'], - name='fk2' + diffs[0], "add_fk", "t", ["y"], "ref", ["id"], name="fk2" ) eq_(len(diffs), 1) @@ -407,20 +564,26 @@ class IncludeHooksTest(AutogenFixtureTest, TestBase): m2 = MetaData() r1a = Table( - 'ref_a', m1, - Column('a', Integer, primary_key=True), - mysql_engine='InnoDB' + "ref_a", + m1, + Column("a", Integer, primary_key=True), + mysql_engine="InnoDB", ) Table( - 'ref_b', m1, - Column('a', Integer, primary_key=True), - Column('b', Integer, primary_key=True), - mysql_engine='InnoDB' + "ref_b", + m1, + Column("a", Integer, primary_key=True), + Column("b", Integer, primary_key=True), + mysql_engine="InnoDB", ) t1 = Table( - 't', m1, Column('x', Integer), - Column('y', Integer), Column('z', Integer), - mysql_engine='InnoDB') + "t", + m1, + Column("x", Integer), + Column("y", Integer), + Column("z", Integer), + mysql_engine="InnoDB", + ) t1.append_constraint( ForeignKeyConstraint([t1.c.x], [r1a.c.a], name="fk1") ) @@ -429,82 +592,104 @@ class IncludeHooksTest(AutogenFixtureTest, TestBase): ) Table( - 'ref_a', m2, - Column('a', Integer, primary_key=True), - mysql_engine='InnoDB' + "ref_a", + m2, + Column("a", Integer, primary_key=True), + mysql_engine="InnoDB", ) r2b = Table( - 'ref_b', m2, - Column('a', Integer, primary_key=True), - Column('b', Integer, primary_key=True), - mysql_engine='InnoDB' + "ref_b", + m2, + Column("a", Integer, primary_key=True), + Column("b", Integer, primary_key=True), + mysql_engine="InnoDB", ) t2 = Table( - 't', m2, Column('x', Integer), - Column('y', Integer), Column('z', Integer), - mysql_engine='InnoDB') + "t", + m2, + Column("x", Integer), + Column("y", Integer), + Column("z", Integer), + mysql_engine="InnoDB", + ) t2.append_constraint( ForeignKeyConstraint( - [t2.c.x, t2.c.z], [r2b.c.a, r2b.c.b], name="fk1") + [t2.c.x, t2.c.z], [r2b.c.a, r2b.c.b], name="fk1" + ) ) t2.append_constraint( ForeignKeyConstraint( - [t2.c.y, t2.c.z], [r2b.c.a, r2b.c.b], name="fk2") + [t2.c.y, t2.c.z], [r2b.c.a, r2b.c.b], name="fk2" + ) ) def include_object(object_, name, type_, reflected, compare_to): return not ( - isinstance(object_, ForeignKeyConstraint) and - type_ == 'foreign_key_constraint' - and name == 'fk1' + isinstance(object_, ForeignKeyConstraint) + and type_ == "foreign_key_constraint" + and name == "fk1" ) diffs = self._fixture(m1, m2, object_filters=include_object) self._assert_fk_diff( - diffs[0], "remove_fk", - 't', ['y'], 'ref_a', ['a'], - name='fk2' + diffs[0], "remove_fk", "t", ["y"], "ref_a", ["a"], name="fk2" ) self._assert_fk_diff( - diffs[1], "add_fk", - 't', ['y', 'z'], 'ref_b', ['a', 'b'], - name='fk2' + diffs[1], + "add_fk", + "t", + ["y", "z"], + "ref_b", + ["a", "b"], + name="fk2", ) eq_(len(diffs), 2) class AutogenerateFKOptionsTest(AutogenFixtureTest, TestBase): __backend__ = True - __requires__ = ('flexible_fk_cascades', ) + __requires__ = ("flexible_fk_cascades",) def _fk_opts_fixture(self, old_opts, new_opts): m1 = MetaData() m2 = MetaData() - Table('some_table', m1, - Column('id', Integer, primary_key=True), - Column('test', String(10)), - mysql_engine='InnoDB') - - Table('user', m1, - Column('id', Integer, primary_key=True), - Column('name', String(50), nullable=False), - Column('tid', Integer), - ForeignKeyConstraint(['tid'], ['some_table.id'], **old_opts), - mysql_engine='InnoDB') - - Table('some_table', m2, - Column('id', Integer, primary_key=True), - Column('test', String(10)), - mysql_engine='InnoDB') - - Table('user', m2, - Column('id', Integer, primary_key=True), - Column('name', String(50), nullable=False), - Column('tid', Integer), - ForeignKeyConstraint(['tid'], ['some_table.id'], **new_opts), - mysql_engine='InnoDB') + Table( + "some_table", + m1, + Column("id", Integer, primary_key=True), + Column("test", String(10)), + mysql_engine="InnoDB", + ) + + Table( + "user", + m1, + Column("id", Integer, primary_key=True), + Column("name", String(50), nullable=False), + Column("tid", Integer), + ForeignKeyConstraint(["tid"], ["some_table.id"], **old_opts), + mysql_engine="InnoDB", + ) + + Table( + "some_table", + m2, + Column("id", Integer, primary_key=True), + Column("test", String(10)), + mysql_engine="InnoDB", + ) + + Table( + "user", + m2, + Column("id", Integer, primary_key=True), + Column("name", String(50), nullable=False), + Column("tid", Integer), + ForeignKeyConstraint(["tid"], ["some_table.id"], **new_opts), + mysql_engine="InnoDB", + ) return self._fixture(m1, m2) @@ -521,47 +706,55 @@ class AutogenerateFKOptionsTest(AutogenFixtureTest, TestBase): return True def test_add_ondelete(self): - diffs = self._fk_opts_fixture( - {}, {"ondelete": "cascade"} - ) + diffs = self._fk_opts_fixture({}, {"ondelete": "cascade"}) if self._expect_opts_supported(): self._assert_fk_diff( - diffs[0], "remove_fk", - "user", ["tid"], - "some_table", ["id"], + diffs[0], + "remove_fk", + "user", + ["tid"], + "some_table", + ["id"], ondelete=None, - conditional_name="servergenerated" + conditional_name="servergenerated", ) self._assert_fk_diff( - diffs[1], "add_fk", - "user", ["tid"], - "some_table", ["id"], - ondelete="cascade" + diffs[1], + "add_fk", + "user", + ["tid"], + "some_table", + ["id"], + ondelete="cascade", ) else: eq_(diffs, []) def test_remove_ondelete(self): - diffs = self._fk_opts_fixture( - {"ondelete": "CASCADE"}, {} - ) + diffs = self._fk_opts_fixture({"ondelete": "CASCADE"}, {}) if self._expect_opts_supported(): self._assert_fk_diff( - diffs[0], "remove_fk", - "user", ["tid"], - "some_table", ["id"], + diffs[0], + "remove_fk", + "user", + ["tid"], + "some_table", + ["id"], ondelete="CASCADE", - conditional_name="servergenerated" + conditional_name="servergenerated", ) self._assert_fk_diff( - diffs[1], "add_fk", - "user", ["tid"], - "some_table", ["id"], - ondelete=None + diffs[1], + "add_fk", + "user", + ["tid"], + "some_table", + ["id"], + ondelete=None, ) else: eq_(diffs, []) @@ -574,47 +767,55 @@ class AutogenerateFKOptionsTest(AutogenFixtureTest, TestBase): eq_(diffs, []) def test_add_onupdate(self): - diffs = self._fk_opts_fixture( - {}, {"onupdate": "cascade"} - ) + diffs = self._fk_opts_fixture({}, {"onupdate": "cascade"}) if self._expect_opts_supported(): self._assert_fk_diff( - diffs[0], "remove_fk", - "user", ["tid"], - "some_table", ["id"], + diffs[0], + "remove_fk", + "user", + ["tid"], + "some_table", + ["id"], onupdate=None, - conditional_name="servergenerated" + conditional_name="servergenerated", ) self._assert_fk_diff( - diffs[1], "add_fk", - "user", ["tid"], - "some_table", ["id"], - onupdate="cascade" + diffs[1], + "add_fk", + "user", + ["tid"], + "some_table", + ["id"], + onupdate="cascade", ) else: eq_(diffs, []) def test_remove_onupdate(self): - diffs = self._fk_opts_fixture( - {"onupdate": "CASCADE"}, {} - ) + diffs = self._fk_opts_fixture({"onupdate": "CASCADE"}, {}) if self._expect_opts_supported(): self._assert_fk_diff( - diffs[0], "remove_fk", - "user", ["tid"], - "some_table", ["id"], + diffs[0], + "remove_fk", + "user", + ["tid"], + "some_table", + ["id"], onupdate="CASCADE", - conditional_name="servergenerated" + conditional_name="servergenerated", ) self._assert_fk_diff( - diffs[1], "add_fk", - "user", ["tid"], - "some_table", ["id"], - onupdate=None + diffs[1], + "add_fk", + "user", + ["tid"], + "some_table", + ["id"], + onupdate=None, ) else: eq_(diffs, []) @@ -668,20 +869,26 @@ class AutogenerateFKOptionsTest(AutogenFixtureTest, TestBase): ) if self._expect_opts_supported(): self._assert_fk_diff( - diffs[0], "remove_fk", - "user", ["tid"], - "some_table", ["id"], + diffs[0], + "remove_fk", + "user", + ["tid"], + "some_table", + ["id"], onupdate=None, ondelete=mock.ANY, # MySQL reports None, PG reports RESTRICT - conditional_name="servergenerated" + conditional_name="servergenerated", ) self._assert_fk_diff( - diffs[1], "add_fk", - "user", ["tid"], - "some_table", ["id"], + diffs[1], + "add_fk", + "user", + ["tid"], + "some_table", + ["id"], onupdate=None, - ondelete="cascade" + ondelete="cascade", ) else: eq_(diffs, []) @@ -696,20 +903,26 @@ class AutogenerateFKOptionsTest(AutogenFixtureTest, TestBase): ) if self._expect_opts_supported(): self._assert_fk_diff( - diffs[0], "remove_fk", - "user", ["tid"], - "some_table", ["id"], + diffs[0], + "remove_fk", + "user", + ["tid"], + "some_table", + ["id"], onupdate=mock.ANY, # MySQL reports None, PG reports RESTRICT ondelete=None, - conditional_name="servergenerated" + conditional_name="servergenerated", ) self._assert_fk_diff( - diffs[1], "add_fk", - "user", ["tid"], - "some_table", ["id"], + diffs[1], + "add_fk", + "user", + ["tid"], + "some_table", + ["id"], onupdate="cascade", - ondelete=None + ondelete=None, ) else: eq_(diffs, []) @@ -717,70 +930,84 @@ class AutogenerateFKOptionsTest(AutogenFixtureTest, TestBase): def test_ondelete_onupdate_combo(self): diffs = self._fk_opts_fixture( {"onupdate": "CASCADE", "ondelete": "SET NULL"}, - {"onupdate": "RESTRICT", "ondelete": "RESTRICT"} + {"onupdate": "RESTRICT", "ondelete": "RESTRICT"}, ) if self._expect_opts_supported(): self._assert_fk_diff( - diffs[0], "remove_fk", - "user", ["tid"], - "some_table", ["id"], + diffs[0], + "remove_fk", + "user", + ["tid"], + "some_table", + ["id"], onupdate="CASCADE", ondelete="SET NULL", - conditional_name="servergenerated" + conditional_name="servergenerated", ) self._assert_fk_diff( - diffs[1], "add_fk", - "user", ["tid"], - "some_table", ["id"], + diffs[1], + "add_fk", + "user", + ["tid"], + "some_table", + ["id"], onupdate="RESTRICT", - ondelete="RESTRICT" + ondelete="RESTRICT", ) else: eq_(diffs, []) @config.requirements.fk_initially def test_add_initially_deferred(self): - diffs = self._fk_opts_fixture( - {}, {"initially": "deferred"} - ) + diffs = self._fk_opts_fixture({}, {"initially": "deferred"}) self._assert_fk_diff( - diffs[0], "remove_fk", - "user", ["tid"], - "some_table", ["id"], + diffs[0], + "remove_fk", + "user", + ["tid"], + "some_table", + ["id"], initially=None, - conditional_name="servergenerated" + conditional_name="servergenerated", ) self._assert_fk_diff( - diffs[1], "add_fk", - "user", ["tid"], - "some_table", ["id"], - initially="deferred" + diffs[1], + "add_fk", + "user", + ["tid"], + "some_table", + ["id"], + initially="deferred", ) @config.requirements.fk_initially def test_remove_initially_deferred(self): - diffs = self._fk_opts_fixture( - {"initially": "deferred"}, {} - ) + diffs = self._fk_opts_fixture({"initially": "deferred"}, {}) self._assert_fk_diff( - diffs[0], "remove_fk", - "user", ["tid"], - "some_table", ["id"], + diffs[0], + "remove_fk", + "user", + ["tid"], + "some_table", + ["id"], initially="DEFERRED", deferrable=True, - conditional_name="servergenerated" + conditional_name="servergenerated", ) self._assert_fk_diff( - diffs[1], "add_fk", - "user", ["tid"], - "some_table", ["id"], - initially=None + diffs[1], + "add_fk", + "user", + ["tid"], + "some_table", + ["id"], + initially=None, ) @config.requirements.fk_deferrable @@ -791,19 +1018,25 @@ class AutogenerateFKOptionsTest(AutogenFixtureTest, TestBase): ) self._assert_fk_diff( - diffs[0], "remove_fk", - "user", ["tid"], - "some_table", ["id"], + diffs[0], + "remove_fk", + "user", + ["tid"], + "some_table", + ["id"], initially=None, - conditional_name="servergenerated" + conditional_name="servergenerated", ) self._assert_fk_diff( - diffs[1], "add_fk", - "user", ["tid"], - "some_table", ["id"], + diffs[1], + "add_fk", + "user", + ["tid"], + "some_table", + ["id"], initially="immediate", - deferrable=True + deferrable=True, ) @config.requirements.fk_deferrable @@ -814,20 +1047,26 @@ class AutogenerateFKOptionsTest(AutogenFixtureTest, TestBase): ) self._assert_fk_diff( - diffs[0], "remove_fk", - "user", ["tid"], - "some_table", ["id"], + diffs[0], + "remove_fk", + "user", + ["tid"], + "some_table", + ["id"], initially=None, # immediate is the default deferrable=True, - conditional_name="servergenerated" + conditional_name="servergenerated", ) self._assert_fk_diff( - diffs[1], "add_fk", - "user", ["tid"], - "some_table", ["id"], + diffs[1], + "add_fk", + "user", + ["tid"], + "some_table", + ["id"], initially=None, - deferrable=None + deferrable=None, ) @config.requirements.fk_initially @@ -835,7 +1074,7 @@ class AutogenerateFKOptionsTest(AutogenFixtureTest, TestBase): def test_add_initially_deferrable_nochange_one(self): diffs = self._fk_opts_fixture( {"deferrable": True, "initially": "immediate"}, - {"deferrable": True, "initially": "immediate"} + {"deferrable": True, "initially": "immediate"}, ) eq_(diffs, []) @@ -845,7 +1084,7 @@ class AutogenerateFKOptionsTest(AutogenFixtureTest, TestBase): def test_add_initially_deferrable_nochange_two(self): diffs = self._fk_opts_fixture( {"deferrable": True, "initially": "deferred"}, - {"deferrable": True, "initially": "deferred"} + {"deferrable": True, "initially": "deferred"}, ) eq_(diffs, []) @@ -855,49 +1094,57 @@ class AutogenerateFKOptionsTest(AutogenFixtureTest, TestBase): def test_add_initially_deferrable_nochange_three(self): diffs = self._fk_opts_fixture( {"deferrable": None, "initially": "deferred"}, - {"deferrable": None, "initially": "deferred"} + {"deferrable": None, "initially": "deferred"}, ) eq_(diffs, []) @config.requirements.fk_deferrable def test_add_deferrable(self): - diffs = self._fk_opts_fixture( - {}, {"deferrable": True} - ) + diffs = self._fk_opts_fixture({}, {"deferrable": True}) self._assert_fk_diff( - diffs[0], "remove_fk", - "user", ["tid"], - "some_table", ["id"], + diffs[0], + "remove_fk", + "user", + ["tid"], + "some_table", + ["id"], deferrable=None, - conditional_name="servergenerated" + conditional_name="servergenerated", ) self._assert_fk_diff( - diffs[1], "add_fk", - "user", ["tid"], - "some_table", ["id"], - deferrable=True + diffs[1], + "add_fk", + "user", + ["tid"], + "some_table", + ["id"], + deferrable=True, ) @config.requirements.fk_deferrable def test_remove_deferrable(self): - diffs = self._fk_opts_fixture( - {"deferrable": True}, {} - ) + diffs = self._fk_opts_fixture({"deferrable": True}, {}) self._assert_fk_diff( - diffs[0], "remove_fk", - "user", ["tid"], - "some_table", ["id"], + diffs[0], + "remove_fk", + "user", + ["tid"], + "some_table", + ["id"], deferrable=True, - conditional_name="servergenerated" + conditional_name="servergenerated", ) self._assert_fk_diff( - diffs[1], "add_fk", - "user", ["tid"], - "some_table", ["id"], - deferrable=None + diffs[1], + "add_fk", + "user", + ["tid"], + "some_table", + ["id"], + deferrable=None, ) diff --git a/tests/test_autogen_indexes.py b/tests/test_autogen_indexes.py index b588cbe..f03155f 100644 --- a/tests/test_autogen_indexes.py +++ b/tests/test_autogen_indexes.py @@ -3,14 +3,24 @@ from alembic.testing import TestBase from alembic.testing import config from alembic.testing import assertions -from sqlalchemy import MetaData, Column, Table, Integer, String, \ - Numeric, UniqueConstraint, Index, ForeignKeyConstraint,\ - ForeignKey, func +from sqlalchemy import ( + MetaData, + Column, + Table, + Integer, + String, + Numeric, + UniqueConstraint, + Index, + ForeignKeyConstraint, + ForeignKey, + func, +) from alembic.testing import engines from alembic.testing import eq_ from alembic.testing.env import staging_env -py3k = sys.version_info >= (3, ) +py3k = sys.version_info >= (3,) from ._autogen_fixtures import AutogenFixtureTest @@ -24,6 +34,7 @@ class NoUqReflection(object): def unimpl(*arg, **kw): raise NotImplementedError() + eng.dialect.get_unique_constraints = unimpl def test_add_ix_on_table_create(self): @@ -37,25 +48,29 @@ class AutogenerateUniqueIndexTest(AutogenFixtureTest, TestBase): reports_unique_constraints = True reports_unique_constraints_as_indexes = False - __requires__ = ('unique_constraint_reflection', ) - __only_on__ = 'sqlite' + __requires__ = ("unique_constraint_reflection",) + __only_on__ = "sqlite" def test_index_flag_becomes_named_unique_constraint(self): m1 = MetaData() m2 = MetaData() - Table('user', m1, - Column('id', Integer, primary_key=True), - Column('name', String(50), nullable=False, index=True), - Column('a1', String(10), server_default="x") - ) + Table( + "user", + m1, + Column("id", Integer, primary_key=True), + Column("name", String(50), nullable=False, index=True), + Column("a1", String(10), server_default="x"), + ) - Table('user', m2, - Column('id', Integer, primary_key=True), - Column('name', String(50), nullable=False), - Column('a1', String(10), server_default="x"), - UniqueConstraint("name", name="uq_user_name") - ) + Table( + "user", + m2, + Column("id", Integer, primary_key=True), + Column("name", String(50), nullable=False), + Column("a1", String(10), server_default="x"), + UniqueConstraint("name", name="uq_user_name"), + ) diffs = self._fixture(m1, m2) @@ -72,17 +87,21 @@ class AutogenerateUniqueIndexTest(AutogenFixtureTest, TestBase): def test_add_unique_constraint(self): m1 = MetaData() m2 = MetaData() - Table('address', m1, - Column('id', Integer, primary_key=True), - Column('email_address', String(100), nullable=False), - Column('qpr', String(10), index=True), - ) - Table('address', m2, - Column('id', Integer, primary_key=True), - Column('email_address', String(100), nullable=False), - Column('qpr', String(10), index=True), - UniqueConstraint("email_address", name="uq_email_address") - ) + Table( + "address", + m1, + Column("id", Integer, primary_key=True), + Column("email_address", String(100), nullable=False), + Column("qpr", String(10), index=True), + ) + Table( + "address", + m2, + Column("id", Integer, primary_key=True), + Column("email_address", String(100), nullable=False), + Column("qpr", String(10), index=True), + UniqueConstraint("email_address", name="uq_email_address"), + ) diffs = self._fixture(m1, m2) @@ -96,17 +115,21 @@ class AutogenerateUniqueIndexTest(AutogenFixtureTest, TestBase): m1 = MetaData() m2 = MetaData() - Table('unq_idx', m1, - Column('id', Integer, primary_key=True), - Column('x', String(20)), - Index('x', 'x', unique=True) - ) + Table( + "unq_idx", + m1, + Column("id", Integer, primary_key=True), + Column("x", String(20)), + Index("x", "x", unique=True), + ) - Table('unq_idx', m2, - Column('id', Integer, primary_key=True), - Column('x', String(20)), - Index('x', 'x', unique=True) - ) + Table( + "unq_idx", + m2, + Column("id", Integer, primary_key=True), + Column("x", String(20)), + Index("x", "x", unique=True), + ) diffs = self._fixture(m1, m2) eq_(diffs, []) @@ -114,27 +137,31 @@ class AutogenerateUniqueIndexTest(AutogenFixtureTest, TestBase): def test_index_becomes_unique(self): m1 = MetaData() m2 = MetaData() - Table('order', m1, - Column('order_id', Integer, primary_key=True), - Column('amount', Numeric(10, 2), nullable=True), - Column('user_id', Integer), - UniqueConstraint('order_id', 'user_id', - name='order_order_id_user_id_unique' - ), - Index('order_user_id_amount_idx', 'user_id', 'amount') - ) - - Table('order', m2, - Column('order_id', Integer, primary_key=True), - Column('amount', Numeric(10, 2), nullable=True), - Column('user_id', Integer), - UniqueConstraint('order_id', 'user_id', - name='order_order_id_user_id_unique' - ), - Index( - 'order_user_id_amount_idx', 'user_id', - 'amount', unique=True), - ) + Table( + "order", + m1, + Column("order_id", Integer, primary_key=True), + Column("amount", Numeric(10, 2), nullable=True), + Column("user_id", Integer), + UniqueConstraint( + "order_id", "user_id", name="order_order_id_user_id_unique" + ), + Index("order_user_id_amount_idx", "user_id", "amount"), + ) + + Table( + "order", + m2, + Column("order_id", Integer, primary_key=True), + Column("amount", Numeric(10, 2), nullable=True), + Column("user_id", Integer), + UniqueConstraint( + "order_id", "user_id", name="order_order_id_user_id_unique" + ), + Index( + "order_user_id_amount_idx", "user_id", "amount", unique=True + ), + ) diffs = self._fixture(m1, m2) eq_(diffs[0][0], "remove_index") @@ -148,16 +175,16 @@ class AutogenerateUniqueIndexTest(AutogenFixtureTest, TestBase): def test_mismatch_db_named_col_flag(self): m1 = MetaData() m2 = MetaData() - Table('item', m1, - Column('x', Integer), - UniqueConstraint('x', name="db_generated_name") - ) + Table( + "item", + m1, + Column("x", Integer), + UniqueConstraint("x", name="db_generated_name"), + ) # test mismatch between unique=True and # named uq constraint - Table('item', m2, - Column('x', Integer, unique=True) - ) + Table("item", m2, Column("x", Integer, unique=True)) diffs = self._fixture(m1, m2) @@ -166,11 +193,13 @@ class AutogenerateUniqueIndexTest(AutogenFixtureTest, TestBase): def test_new_table_added(self): m1 = MetaData() m2 = MetaData() - Table('extra', m2, - Column('foo', Integer, index=True), - Column('bar', Integer), - Index('newtable_idx', 'bar') - ) + Table( + "extra", + m2, + Column("foo", Integer, index=True), + Column("bar", Integer), + Index("newtable_idx", "bar"), + ) diffs = self._fixture(m1, m2) @@ -185,16 +214,20 @@ class AutogenerateUniqueIndexTest(AutogenFixtureTest, TestBase): def test_named_cols_changed(self): m1 = MetaData() m2 = MetaData() - Table('col_change', m1, - Column('x', Integer), - Column('y', Integer), - UniqueConstraint('x', name="nochange") - ) - Table('col_change', m2, - Column('x', Integer), - Column('y', Integer), - UniqueConstraint('x', 'y', name="nochange") - ) + Table( + "col_change", + m1, + Column("x", Integer), + Column("y", Integer), + UniqueConstraint("x", name="nochange"), + ) + Table( + "col_change", + m2, + Column("x", Integer), + Column("y", Integer), + UniqueConstraint("x", "y", name="nochange"), + ) diffs = self._fixture(m1, m2) @@ -211,13 +244,17 @@ class AutogenerateUniqueIndexTest(AutogenFixtureTest, TestBase): m1 = MetaData() m2 = MetaData() - Table('nothing_changed', m1, - Column('x', String(20), unique=True, index=True) - ) + Table( + "nothing_changed", + m1, + Column("x", String(20), unique=True, index=True), + ) - Table('nothing_changed', m2, - Column('x', String(20), unique=True, index=True) - ) + Table( + "nothing_changed", + m2, + Column("x", String(20), unique=True, index=True), + ) diffs = self._fixture(m1, m2) eq_(diffs, []) @@ -226,35 +263,43 @@ class AutogenerateUniqueIndexTest(AutogenFixtureTest, TestBase): m1 = MetaData() m2 = MetaData() - Table('nothing_changed', m1, - Column('id1', Integer, primary_key=True), - Column('id2', Integer, primary_key=True), - Column('x', String(20), unique=True), - mysql_engine='InnoDB' - ) - Table('nothing_changed_related', m1, - Column('id1', Integer), - Column('id2', Integer), - ForeignKeyConstraint( - ['id1', 'id2'], - ['nothing_changed.id1', 'nothing_changed.id2']), - mysql_engine='InnoDB' - ) - - Table('nothing_changed', m2, - Column('id1', Integer, primary_key=True), - Column('id2', Integer, primary_key=True), - Column('x', String(20), unique=True), - mysql_engine='InnoDB' - ) - Table('nothing_changed_related', m2, - Column('id1', Integer), - Column('id2', Integer), - ForeignKeyConstraint( - ['id1', 'id2'], - ['nothing_changed.id1', 'nothing_changed.id2']), - mysql_engine='InnoDB' - ) + Table( + "nothing_changed", + m1, + Column("id1", Integer, primary_key=True), + Column("id2", Integer, primary_key=True), + Column("x", String(20), unique=True), + mysql_engine="InnoDB", + ) + Table( + "nothing_changed_related", + m1, + Column("id1", Integer), + Column("id2", Integer), + ForeignKeyConstraint( + ["id1", "id2"], ["nothing_changed.id1", "nothing_changed.id2"] + ), + mysql_engine="InnoDB", + ) + + Table( + "nothing_changed", + m2, + Column("id1", Integer, primary_key=True), + Column("id2", Integer, primary_key=True), + Column("x", String(20), unique=True), + mysql_engine="InnoDB", + ) + Table( + "nothing_changed_related", + m2, + Column("id1", Integer), + Column("id2", Integer), + ForeignKeyConstraint( + ["id1", "id2"], ["nothing_changed.id1", "nothing_changed.id2"] + ), + mysql_engine="InnoDB", + ) diffs = self._fixture(m1, m2) eq_(diffs, []) @@ -263,15 +308,19 @@ class AutogenerateUniqueIndexTest(AutogenFixtureTest, TestBase): m1 = MetaData() m2 = MetaData() - Table('nothing_changed', m1, - Column('x', String(20), key='nx'), - UniqueConstraint('nx') - ) + Table( + "nothing_changed", + m1, + Column("x", String(20), key="nx"), + UniqueConstraint("nx"), + ) - Table('nothing_changed', m2, - Column('x', String(20), key='nx'), - UniqueConstraint('nx') - ) + Table( + "nothing_changed", + m2, + Column("x", String(20), key="nx"), + UniqueConstraint("nx"), + ) diffs = self._fixture(m1, m2) eq_(diffs, []) @@ -280,15 +329,19 @@ class AutogenerateUniqueIndexTest(AutogenFixtureTest, TestBase): m1 = MetaData() m2 = MetaData() - Table('nothing_changed', m1, - Column('x', String(20), key='nx'), - Index('foobar', 'nx') - ) + Table( + "nothing_changed", + m1, + Column("x", String(20), key="nx"), + Index("foobar", "nx"), + ) - Table('nothing_changed', m2, - Column('x', String(20), key='nx'), - Index('foobar', 'nx') - ) + Table( + "nothing_changed", + m2, + Column("x", String(20), key="nx"), + Index("foobar", "nx"), + ) diffs = self._fixture(m1, m2) eq_(diffs, []) @@ -297,19 +350,23 @@ class AutogenerateUniqueIndexTest(AutogenFixtureTest, TestBase): m1 = MetaData() m2 = MetaData() - Table('nothing_changed', m1, - Column('id1', Integer, primary_key=True), - Column('id2', Integer, primary_key=True), - Column('x', String(20)), - Index('x', 'x') - ) + Table( + "nothing_changed", + m1, + Column("id1", Integer, primary_key=True), + Column("id2", Integer, primary_key=True), + Column("x", String(20)), + Index("x", "x"), + ) - Table('nothing_changed', m2, - Column('id1', Integer, primary_key=True), - Column('id2', Integer, primary_key=True), - Column('x', String(20)), - Index('x', 'x') - ) + Table( + "nothing_changed", + m2, + Column("id1", Integer, primary_key=True), + Column("id2", Integer, primary_key=True), + Column("x", String(20)), + Index("x", "x"), + ) diffs = self._fixture(m1, m2) eq_(diffs, []) @@ -318,29 +375,43 @@ class AutogenerateUniqueIndexTest(AutogenFixtureTest, TestBase): m1 = MetaData() m2 = MetaData() - Table("nothing_changed", m1, - Column('id', Integer, primary_key=True), - Column('other_id', - ForeignKey('nc2.id', - name='fk_my_table_other_table' - ), - nullable=False), - Column('foo', Integer), - mysql_engine='InnoDB') - Table('nc2', m1, - Column('id', Integer, primary_key=True), - mysql_engine='InnoDB') - - Table("nothing_changed", m2, - Column('id', Integer, primary_key=True), - Column('other_id', ForeignKey('nc2.id', - name='fk_my_table_other_table'), - nullable=False), - Column('foo', Integer), - mysql_engine='InnoDB') - Table('nc2', m2, - Column('id', Integer, primary_key=True), - mysql_engine='InnoDB') + Table( + "nothing_changed", + m1, + Column("id", Integer, primary_key=True), + Column( + "other_id", + ForeignKey("nc2.id", name="fk_my_table_other_table"), + nullable=False, + ), + Column("foo", Integer), + mysql_engine="InnoDB", + ) + Table( + "nc2", + m1, + Column("id", Integer, primary_key=True), + mysql_engine="InnoDB", + ) + + Table( + "nothing_changed", + m2, + Column("id", Integer, primary_key=True), + Column( + "other_id", + ForeignKey("nc2.id", name="fk_my_table_other_table"), + nullable=False, + ), + Column("foo", Integer), + mysql_engine="InnoDB", + ) + Table( + "nc2", + m2, + Column("id", Integer, primary_key=True), + mysql_engine="InnoDB", + ) diffs = self._fixture(m1, m2) eq_(diffs, []) @@ -348,35 +419,49 @@ class AutogenerateUniqueIndexTest(AutogenFixtureTest, TestBase): m1 = MetaData() m2 = MetaData() - Table("nothing_changed", m1, - Column('id', Integer, primary_key=True), - Column('other_id_1', Integer), - Column('other_id_2', Integer), - Column('foo', Integer), - ForeignKeyConstraint( - ['other_id_1', 'other_id_2'], ['nc2.id1', 'nc2.id2'], - name='fk_my_table_other_table' - ), - mysql_engine='InnoDB') - Table('nc2', m1, - Column('id1', Integer, primary_key=True), - Column('id2', Integer, primary_key=True), - mysql_engine='InnoDB') - - Table("nothing_changed", m2, - Column('id', Integer, primary_key=True), - Column('other_id_1', Integer), - Column('other_id_2', Integer), - Column('foo', Integer), - ForeignKeyConstraint( - ['other_id_1', 'other_id_2'], ['nc2.id1', 'nc2.id2'], - name='fk_my_table_other_table' - ), - mysql_engine='InnoDB') - Table('nc2', m2, - Column('id1', Integer, primary_key=True), - Column('id2', Integer, primary_key=True), - mysql_engine='InnoDB') + Table( + "nothing_changed", + m1, + Column("id", Integer, primary_key=True), + Column("other_id_1", Integer), + Column("other_id_2", Integer), + Column("foo", Integer), + ForeignKeyConstraint( + ["other_id_1", "other_id_2"], + ["nc2.id1", "nc2.id2"], + name="fk_my_table_other_table", + ), + mysql_engine="InnoDB", + ) + Table( + "nc2", + m1, + Column("id1", Integer, primary_key=True), + Column("id2", Integer, primary_key=True), + mysql_engine="InnoDB", + ) + + Table( + "nothing_changed", + m2, + Column("id", Integer, primary_key=True), + Column("other_id_1", Integer), + Column("other_id_2", Integer), + Column("foo", Integer), + ForeignKeyConstraint( + ["other_id_1", "other_id_2"], + ["nc2.id1", "nc2.id2"], + name="fk_my_table_other_table", + ), + mysql_engine="InnoDB", + ) + Table( + "nc2", + m2, + Column("id1", Integer, primary_key=True), + Column("id2", Integer, primary_key=True), + mysql_engine="InnoDB", + ) diffs = self._fixture(m1, m2) eq_(diffs, []) @@ -384,65 +469,73 @@ class AutogenerateUniqueIndexTest(AutogenFixtureTest, TestBase): m1 = MetaData() m2 = MetaData() - Table('new_idx', m1, - Column('id1', Integer, primary_key=True), - Column('id2', Integer, primary_key=True), - Column('x', String(20)), - ) + Table( + "new_idx", + m1, + Column("id1", Integer, primary_key=True), + Column("id2", Integer, primary_key=True), + Column("x", String(20)), + ) - idx = Index('x', 'x') - Table('new_idx', m2, - Column('id1', Integer, primary_key=True), - Column('id2', Integer, primary_key=True), - Column('x', String(20)), - idx - ) + idx = Index("x", "x") + Table( + "new_idx", + m2, + Column("id1", Integer, primary_key=True), + Column("id2", Integer, primary_key=True), + Column("x", String(20)), + idx, + ) diffs = self._fixture(m1, m2) - eq_(diffs, [('add_index', idx)]) + eq_(diffs, [("add_index", idx)]) def test_removed_idx_index_named_as_column(self): m1 = MetaData() m2 = MetaData() - idx = Index('x', 'x') - Table('new_idx', m1, - Column('id1', Integer, primary_key=True), - Column('id2', Integer, primary_key=True), - Column('x', String(20)), - idx - ) + idx = Index("x", "x") + Table( + "new_idx", + m1, + Column("id1", Integer, primary_key=True), + Column("id2", Integer, primary_key=True), + Column("x", String(20)), + idx, + ) - Table('new_idx', m2, - Column('id1', Integer, primary_key=True), - Column('id2', Integer, primary_key=True), - Column('x', String(20)) - ) + Table( + "new_idx", + m2, + Column("id1", Integer, primary_key=True), + Column("id2", Integer, primary_key=True), + Column("x", String(20)), + ) diffs = self._fixture(m1, m2) - eq_(diffs[0][0], 'remove_index') + eq_(diffs[0][0], "remove_index") def test_drop_table_w_indexes(self): m1 = MetaData() m2 = MetaData() t = Table( - 'some_table', m1, - Column('id', Integer, primary_key=True), - Column('x', String(20)), - Column('y', String(20)), + "some_table", + m1, + Column("id", Integer, primary_key=True), + Column("x", String(20)), + Column("y", String(20)), ) - Index('xy_idx', t.c.x, t.c.y) - Index('y_idx', t.c.y) + Index("xy_idx", t.c.x, t.c.y) + Index("y_idx", t.c.y) diffs = self._fixture(m1, m2) - eq_(diffs[0][0], 'remove_index') - eq_(diffs[1][0], 'remove_index') - eq_(diffs[2][0], 'remove_table') + eq_(diffs[0][0], "remove_index") + eq_(diffs[1][0], "remove_index") + eq_(diffs[2][0], "remove_table") eq_( - set([diffs[0][1].name, diffs[1][1].name]), - set(['xy_idx', 'y_idx']) + set([diffs[0][1].name, diffs[1][1].name]), set(["xy_idx", "y_idx"]) ) # this simply doesn't fully work before we had @@ -453,11 +546,12 @@ class AutogenerateUniqueIndexTest(AutogenFixtureTest, TestBase): m2 = MetaData() Table( - 'some_table', m1, - Column('id', Integer, primary_key=True), - Column('x', String(20)), - Column('y', String(20)), - UniqueConstraint('y', name='uq_y') + "some_table", + m1, + Column("id", Integer, primary_key=True), + Column("x", String(20)), + Column("y", String(20)), + UniqueConstraint("y", name="uq_y"), ) diffs = self._fixture(m1, m2) @@ -465,65 +559,80 @@ class AutogenerateUniqueIndexTest(AutogenFixtureTest, TestBase): if self.reports_unique_constraints_as_indexes: # for MySQL this UQ will look like an index, so # make sure it at least sets it up correctly - eq_(diffs[0][0], 'remove_index') - eq_(diffs[1][0], 'remove_table') + eq_(diffs[0][0], "remove_index") + eq_(diffs[1][0], "remove_table") eq_(len(diffs), 2) - constraints = [c for c in diffs[1][1].constraints - if isinstance(c, UniqueConstraint)] + constraints = [ + c + for c in diffs[1][1].constraints + if isinstance(c, UniqueConstraint) + ] eq_(len(constraints), 0) else: - eq_(diffs[0][0], 'remove_table') + eq_(diffs[0][0], "remove_table") eq_(len(diffs), 1) - constraints = [c for c in diffs[0][1].constraints - if isinstance(c, UniqueConstraint)] + constraints = [ + c + for c in diffs[0][1].constraints + if isinstance(c, UniqueConstraint) + ] if self.reports_unique_constraints: eq_(len(constraints), 1) def test_unnamed_cols_changed(self): m1 = MetaData() m2 = MetaData() - Table('col_change', m1, - Column('x', Integer), - Column('y', Integer), - UniqueConstraint('x') - ) - Table('col_change', m2, - Column('x', Integer), - Column('y', Integer), - UniqueConstraint('x', 'y') - ) + Table( + "col_change", + m1, + Column("x", Integer), + Column("y", Integer), + UniqueConstraint("x"), + ) + Table( + "col_change", + m2, + Column("x", Integer), + Column("y", Integer), + UniqueConstraint("x", "y"), + ) diffs = self._fixture(m1, m2) - diffs = set((cmd, - ('x' in obj.name) if obj.name is not None else False) - for cmd, obj in diffs) + diffs = set( + (cmd, ("x" in obj.name) if obj.name is not None else False) + for cmd, obj in diffs + ) if self.reports_unnamed_constraints: if self.reports_unique_constraints_as_indexes: eq_( diffs, - set([("remove_index", True), ("add_constraint", False)]) + set([("remove_index", True), ("add_constraint", False)]), ) else: eq_( diffs, - set([("remove_constraint", True), - ("add_constraint", False)]) + set( + [ + ("remove_constraint", True), + ("add_constraint", False), + ] + ), ) def test_remove_named_unique_index(self): m1 = MetaData() m2 = MetaData() - Table('remove_idx', m1, - Column('x', Integer), - Index('xidx', 'x', unique=True) - ) - Table('remove_idx', m2, - Column('x', Integer) - ) + Table( + "remove_idx", + m1, + Column("x", Integer), + Index("xidx", "x", unique=True), + ) + Table("remove_idx", m2, Column("x", Integer)) diffs = self._fixture(m1, m2) @@ -537,13 +646,13 @@ class AutogenerateUniqueIndexTest(AutogenFixtureTest, TestBase): m1 = MetaData() m2 = MetaData() - Table('remove_idx', m1, - Column('x', Integer), - UniqueConstraint('x', name='xidx') - ) - Table('remove_idx', m2, - Column('x', Integer), - ) + Table( + "remove_idx", + m1, + Column("x", Integer), + UniqueConstraint("x", name="xidx"), + ) + Table("remove_idx", m2, Column("x", Integer)) diffs = self._fixture(m1, m2) @@ -559,46 +668,49 @@ class AutogenerateUniqueIndexTest(AutogenFixtureTest, TestBase): def test_dont_add_uq_on_table_create(self): m1 = MetaData() m2 = MetaData() - Table('no_uq', m2, Column('x', String(50), unique=True)) + Table("no_uq", m2, Column("x", String(50), unique=True)) diffs = self._fixture(m1, m2) eq_(diffs[0][0], "add_table") eq_(len(diffs), 1) assert UniqueConstraint in set( - type(c) for c in diffs[0][1].constraints) + type(c) for c in diffs[0][1].constraints + ) def test_add_uq_ix_on_table_create(self): m1 = MetaData() m2 = MetaData() - Table('add_ix', m2, Column('x', String(50), unique=True, index=True)) + Table("add_ix", m2, Column("x", String(50), unique=True, index=True)) diffs = self._fixture(m1, m2) eq_(diffs[0][0], "add_table") eq_(len(diffs), 2) assert UniqueConstraint not in set( - type(c) for c in diffs[0][1].constraints) + type(c) for c in diffs[0][1].constraints + ) eq_(diffs[1][0], "add_index") eq_(diffs[1][1].unique, True) def test_add_ix_on_table_create(self): m1 = MetaData() m2 = MetaData() - Table('add_ix', m2, Column('x', String(50), index=True)) + Table("add_ix", m2, Column("x", String(50), index=True)) diffs = self._fixture(m1, m2) eq_(diffs[0][0], "add_table") eq_(len(diffs), 2) assert UniqueConstraint not in set( - type(c) for c in diffs[0][1].constraints) + type(c) for c in diffs[0][1].constraints + ) eq_(diffs[1][0], "add_index") eq_(diffs[1][1].unique, False) def test_add_idx_non_col(self): m1 = MetaData() m2 = MetaData() - Table('add_ix', m1, Column('x', String(50))) - t2 = Table('add_ix', m2, Column('x', String(50))) - Index('foo_idx', t2.c.x.desc()) + Table("add_ix", m1, Column("x", String(50))) + t2 = Table("add_ix", m2, Column("x", String(50))) + Index("foo_idx", t2.c.x.desc()) diffs = self._fixture(m1, m2) eq_(diffs[0][0], "add_index") @@ -606,10 +718,10 @@ class AutogenerateUniqueIndexTest(AutogenFixtureTest, TestBase): def test_unchanged_idx_non_col(self): m1 = MetaData() m2 = MetaData() - t1 = Table('add_ix', m1, Column('x', String(50))) - Index('foo_idx', t1.c.x.desc()) - t2 = Table('add_ix', m2, Column('x', String(50))) - Index('foo_idx', t2.c.x.desc()) + t1 = Table("add_ix", m1, Column("x", String(50))) + Index("foo_idx", t1.c.x.desc()) + t2 = Table("add_ix", m2, Column("x", String(50))) + Index("foo_idx", t2.c.x.desc()) diffs = self._fixture(m1, m2) eq_(diffs, []) @@ -622,8 +734,8 @@ class AutogenerateUniqueIndexTest(AutogenFixtureTest, TestBase): def test_unchanged_case_sensitive_implicit_idx(self): m1 = MetaData() m2 = MetaData() - Table('add_ix', m1, Column('regNumber', String(50), index=True)) - Table('add_ix', m2, Column('regNumber', String(50), index=True)) + Table("add_ix", m1, Column("regNumber", String(50), index=True)) + Table("add_ix", m2, Column("regNumber", String(50), index=True)) diffs = self._fixture(m1, m2) eq_(diffs, []) @@ -631,10 +743,10 @@ class AutogenerateUniqueIndexTest(AutogenFixtureTest, TestBase): def test_unchanged_case_sensitive_explicit_idx(self): m1 = MetaData() m2 = MetaData() - t1 = Table('add_ix', m1, Column('reg_number', String(50))) - Index('regNumber_idx', t1.c.reg_number) - t2 = Table('add_ix', m2, Column('reg_number', String(50))) - Index('regNumber_idx', t2.c.reg_number) + t1 = Table("add_ix", m1, Column("reg_number", String(50))) + Index("regNumber_idx", t1.c.reg_number) + t2 = Table("add_ix", m2, Column("reg_number", String(50))) + Index("regNumber_idx", t2.c.reg_number) diffs = self._fixture(m1, m2) @@ -649,21 +761,36 @@ class PGUniqueIndexTest(AutogenerateUniqueIndexTest): def test_idx_added_schema(self): m1 = MetaData() m2 = MetaData() - Table('add_ix', m1, Column('x', String(50)), schema="test_schema") - Table('add_ix', m2, Column('x', String(50)), - Index('ix_1', 'x'), schema="test_schema") + Table("add_ix", m1, Column("x", String(50)), schema="test_schema") + Table( + "add_ix", + m2, + Column("x", String(50)), + Index("ix_1", "x"), + schema="test_schema", + ) diffs = self._fixture(m1, m2, include_schemas=True) eq_(diffs[0][0], "add_index") - eq_(diffs[0][1].name, 'ix_1') + eq_(diffs[0][1].name, "ix_1") def test_idx_unchanged_schema(self): m1 = MetaData() m2 = MetaData() - Table('add_ix', m1, Column('x', String(50)), Index('ix_1', 'x'), - schema="test_schema") - Table('add_ix', m2, Column('x', String(50)), - Index('ix_1', 'x'), schema="test_schema") + Table( + "add_ix", + m1, + Column("x", String(50)), + Index("ix_1", "x"), + schema="test_schema", + ) + Table( + "add_ix", + m2, + Column("x", String(50)), + Index("ix_1", "x"), + schema="test_schema", + ) diffs = self._fixture(m1, m2, include_schemas=True) eq_(diffs, []) @@ -671,23 +798,36 @@ class PGUniqueIndexTest(AutogenerateUniqueIndexTest): def test_uq_added_schema(self): m1 = MetaData() m2 = MetaData() - Table('add_uq', m1, Column('x', String(50)), schema="test_schema") - Table('add_uq', m2, Column('x', String(50)), - UniqueConstraint('x', name='ix_1'), schema="test_schema") + Table("add_uq", m1, Column("x", String(50)), schema="test_schema") + Table( + "add_uq", + m2, + Column("x", String(50)), + UniqueConstraint("x", name="ix_1"), + schema="test_schema", + ) diffs = self._fixture(m1, m2, include_schemas=True) eq_(diffs[0][0], "add_constraint") - eq_(diffs[0][1].name, 'ix_1') + eq_(diffs[0][1].name, "ix_1") def test_uq_unchanged_schema(self): m1 = MetaData() m2 = MetaData() - Table('add_uq', m1, Column('x', String(50)), - UniqueConstraint('x', name='ix_1'), - schema="test_schema") - Table('add_uq', m2, Column('x', String(50)), - UniqueConstraint('x', name='ix_1'), - schema="test_schema") + Table( + "add_uq", + m1, + Column("x", String(50)), + UniqueConstraint("x", name="ix_1"), + schema="test_schema", + ) + Table( + "add_uq", + m2, + Column("x", String(50)), + UniqueConstraint("x", name="ix_1"), + schema="test_schema", + ) diffs = self._fixture(m1, m2, include_schemas=True) eq_(diffs, []) @@ -701,17 +841,19 @@ class PGUniqueIndexTest(AutogenerateUniqueIndexTest): m2 = MetaData() Table( - 'add_excl', m1, - Column('id', Integer, primary_key=True), - Column('period', TSRANGE), - ExcludeConstraint(('period', '&&'), name='quarters_period_excl') + "add_excl", + m1, + Column("id", Integer, primary_key=True), + Column("period", TSRANGE), + ExcludeConstraint(("period", "&&"), name="quarters_period_excl"), ) Table( - 'add_excl', m2, - Column('id', Integer, primary_key=True), - Column('period', TSRANGE), - ExcludeConstraint(('period', '&&'), name='quarters_period_excl') + "add_excl", + m2, + Column("id", Integer, primary_key=True), + Column("period", TSRANGE), + ExcludeConstraint(("period", "&&"), name="quarters_period_excl"), ) diffs = self._fixture(m1, m2) @@ -721,10 +863,10 @@ class PGUniqueIndexTest(AutogenerateUniqueIndexTest): m1 = MetaData() m2 = MetaData() - Table('add_ix', m1, Column('x', String(50)), Index('ix_1', 'x')) + Table("add_ix", m1, Column("x", String(50)), Index("ix_1", "x")) - Table('add_ix', m2, Column('x', String(50)), Index('ix_1', 'x')) - Table('add_ix', m2, Column('x', String(50)), schema="test_schema") + Table("add_ix", m2, Column("x", String(50)), Index("ix_1", "x")) + Table("add_ix", m2, Column("x", String(50)), schema="test_schema") diffs = self._fixture(m1, m2, include_schemas=True) eq_(diffs[0][0], "add_table") @@ -734,15 +876,17 @@ class PGUniqueIndexTest(AutogenerateUniqueIndexTest): m1 = MetaData() m2 = MetaData() Table( - 'add_uq', m1, - Column('id', Integer, primary_key=True), - Column('name', String), - UniqueConstraint('name', name='uq_name') + "add_uq", + m1, + Column("id", Integer, primary_key=True), + Column("name", String), + UniqueConstraint("name", name="uq_name"), ) Table( - 'add_uq', m2, - Column('id', Integer, primary_key=True), - Column('name', String), + "add_uq", + m2, + Column("id", Integer, primary_key=True), + Column("name", String), ) diffs = self._fixture(m1, m2, include_schemas=True) eq_(diffs[0][0], "remove_constraint") @@ -754,22 +898,24 @@ class PGUniqueIndexTest(AutogenerateUniqueIndexTest): m2 = MetaData() t1 = Table( - 'foo', m1, - Column('id', Integer, primary_key=True), - Column('email', String(50)) + "foo", + m1, + Column("id", Integer, primary_key=True), + Column("email", String(50)), ) Index("email_idx", func.lower(t1.c.email), unique=True) t2 = Table( - 'foo', m2, - Column('id', Integer, primary_key=True), - Column('email', String(50)) + "foo", + m2, + Column("id", Integer, primary_key=True), + Column("email", String(50)), ) Index("email_idx", func.lower(t2.c.email), unique=True) with assertions.expect_warnings( - "Skipped unsupported reflection", - "autogenerate skipping functional index" + "Skipped unsupported reflection", + "autogenerate skipping functional index", ): diffs = self._fixture(m1, m2) eq_(diffs, []) @@ -779,28 +925,34 @@ class PGUniqueIndexTest(AutogenerateUniqueIndexTest): m2 = MetaData() t1 = Table( - 'foo', m1, - Column('id', Integer, primary_key=True), - Column('email', String(50)), - Column('name', String(50)) + "foo", + m1, + Column("id", Integer, primary_key=True), + Column("email", String(50)), + Column("name", String(50)), ) Index( "email_idx", - func.coalesce(t1.c.email, t1.c.name).desc(), unique=True) + func.coalesce(t1.c.email, t1.c.name).desc(), + unique=True, + ) t2 = Table( - 'foo', m2, - Column('id', Integer, primary_key=True), - Column('email', String(50)), - Column('name', String(50)) + "foo", + m2, + Column("id", Integer, primary_key=True), + Column("email", String(50)), + Column("name", String(50)), ) Index( "email_idx", - func.coalesce(t2.c.email, t2.c.name).desc(), unique=True) + func.coalesce(t2.c.email, t2.c.name).desc(), + unique=True, + ) with assertions.expect_warnings( - "Skipped unsupported reflection", - "autogenerate skipping functional index" + "Skipped unsupported reflection", + "autogenerate skipping functional index", ): diffs = self._fixture(m1, m2) eq_(diffs, []) @@ -809,13 +961,14 @@ class PGUniqueIndexTest(AutogenerateUniqueIndexTest): class MySQLUniqueIndexTest(AutogenerateUniqueIndexTest): reports_unnamed_constraints = True reports_unique_constraints_as_indexes = True - __only_on__ = 'mysql' + __only_on__ = "mysql" __backend__ = True def test_removed_idx_index_named_as_column(self): try: - super(MySQLUniqueIndexTest, - self).test_removed_idx_index_named_as_column() + super( + MySQLUniqueIndexTest, self + ).test_removed_idx_index_named_as_column() except IndexError: assert True else: @@ -828,61 +981,70 @@ class OracleUniqueIndexTest(AutogenerateUniqueIndexTest): __only_on__ = "oracle" __backend__ = True + class NoUqReflectionIndexTest(NoUqReflection, AutogenerateUniqueIndexTest): reports_unique_constraints = False - __only_on__ = 'sqlite' + __only_on__ = "sqlite" def test_unique_not_reported(self): m1 = MetaData() - Table('order', m1, - Column('order_id', Integer, primary_key=True), - Column('amount', Numeric(10, 2), nullable=True), - Column('user_id', Integer), - UniqueConstraint('order_id', 'user_id', - name='order_order_id_user_id_unique' - ) - ) + Table( + "order", + m1, + Column("order_id", Integer, primary_key=True), + Column("amount", Numeric(10, 2), nullable=True), + Column("user_id", Integer), + UniqueConstraint( + "order_id", "user_id", name="order_order_id_user_id_unique" + ), + ) diffs = self._fixture(m1, m1) eq_(diffs, []) def test_remove_unique_index_not_reported(self): m1 = MetaData() - Table('order', m1, - Column('order_id', Integer, primary_key=True), - Column('amount', Numeric(10, 2), nullable=True), - Column('user_id', Integer), - Index('oid_ix', 'order_id', 'user_id', - unique=True - ) - ) + Table( + "order", + m1, + Column("order_id", Integer, primary_key=True), + Column("amount", Numeric(10, 2), nullable=True), + Column("user_id", Integer), + Index("oid_ix", "order_id", "user_id", unique=True), + ) m2 = MetaData() - Table('order', m2, - Column('order_id', Integer, primary_key=True), - Column('amount', Numeric(10, 2), nullable=True), - Column('user_id', Integer), - ) + Table( + "order", + m2, + Column("order_id", Integer, primary_key=True), + Column("amount", Numeric(10, 2), nullable=True), + Column("user_id", Integer), + ) diffs = self._fixture(m1, m2) eq_(diffs, []) def test_remove_plain_index_is_reported(self): m1 = MetaData() - Table('order', m1, - Column('order_id', Integer, primary_key=True), - Column('amount', Numeric(10, 2), nullable=True), - Column('user_id', Integer), - Index('oid_ix', 'order_id', 'user_id') - ) + Table( + "order", + m1, + Column("order_id", Integer, primary_key=True), + Column("amount", Numeric(10, 2), nullable=True), + Column("user_id", Integer), + Index("oid_ix", "order_id", "user_id"), + ) m2 = MetaData() - Table('order', m2, - Column('order_id', Integer, primary_key=True), - Column('amount', Numeric(10, 2), nullable=True), - Column('user_id', Integer), - ) + Table( + "order", + m2, + Column("order_id", Integer, primary_key=True), + Column("amount", Numeric(10, 2), nullable=True), + Column("user_id", Integer), + ) diffs = self._fixture(m1, m2) - eq_(diffs[0][0], 'remove_index') + eq_(diffs[0][0], "remove_index") class NoUqReportsIndAsUqTest(NoUqReflectionIndexTest): @@ -899,7 +1061,7 @@ class NoUqReportsIndAsUqTest(NoUqReflectionIndexTest): """ - __only_on__ = 'sqlite' + __only_on__ = "sqlite" @classmethod def _get_bind(cls): @@ -916,7 +1078,7 @@ class NoUqReportsIndAsUqTest(NoUqReflectionIndexTest): for uq in _get_unique_constraints( self, connection, tablename, **kw ): - uq['unique'] = True + uq["unique"] = True indexes.append(uq) return indexes @@ -932,23 +1094,26 @@ class IncludeHooksTest(AutogenFixtureTest, TestBase): m1 = MetaData() m2 = MetaData() - t1 = Table('t', m1, Column('x', Integer), Column('y', Integer)) - Index('ix1', t1.c.x) - Index('ix2', t1.c.y) + t1 = Table("t", m1, Column("x", Integer), Column("y", Integer)) + Index("ix1", t1.c.x) + Index("ix2", t1.c.y) - Table('t', m2, Column('x', Integer), Column('y', Integer)) + Table("t", m2, Column("x", Integer), Column("y", Integer)) def include_object(object_, name, type_, reflected, compare_to): - if type_ == 'unique_constraint': + if type_ == "unique_constraint": return False return not ( - isinstance(object_, Index) and - type_ == 'index' and reflected and name == 'ix1') + isinstance(object_, Index) + and type_ == "index" + and reflected + and name == "ix1" + ) diffs = self._fixture(m1, m2, object_filters=include_object) - eq_(diffs[0][0], 'remove_index') - eq_(diffs[0][1].name, 'ix2') + eq_(diffs[0][0], "remove_index") + eq_(diffs[0][1].name, "ix2") eq_(len(diffs), 1) @config.requirements.unique_constraint_reflection @@ -958,45 +1123,54 @@ class IncludeHooksTest(AutogenFixtureTest, TestBase): m2 = MetaData() Table( - 't', m1, Column('x', Integer), Column('y', Integer), - UniqueConstraint('x', name='uq1'), - UniqueConstraint('y', name='uq2'), + "t", + m1, + Column("x", Integer), + Column("y", Integer), + UniqueConstraint("x", name="uq1"), + UniqueConstraint("y", name="uq2"), ) - Table('t', m2, Column('x', Integer), Column('y', Integer)) + Table("t", m2, Column("x", Integer), Column("y", Integer)) def include_object(object_, name, type_, reflected, compare_to): - if type_ == 'index': + if type_ == "index": return False return not ( - isinstance(object_, UniqueConstraint) and - type_ == 'unique_constraint' and reflected and name == 'uq1') + isinstance(object_, UniqueConstraint) + and type_ == "unique_constraint" + and reflected + and name == "uq1" + ) diffs = self._fixture(m1, m2, object_filters=include_object) - eq_(diffs[0][0], 'remove_constraint') - eq_(diffs[0][1].name, 'uq2') + eq_(diffs[0][0], "remove_constraint") + eq_(diffs[0][1].name, "uq2") eq_(len(diffs), 1) def test_add_metadata_index(self): m1 = MetaData() m2 = MetaData() - Table('t', m1, Column('x', Integer)) + Table("t", m1, Column("x", Integer)) - t2 = Table('t', m2, Column('x', Integer)) - Index('ix1', t2.c.x) - Index('ix2', t2.c.x) + t2 = Table("t", m2, Column("x", Integer)) + Index("ix1", t2.c.x) + Index("ix2", t2.c.x) def include_object(object_, name, type_, reflected, compare_to): return not ( - isinstance(object_, Index) and - type_ == 'index' and not reflected and name == 'ix1') + isinstance(object_, Index) + and type_ == "index" + and not reflected + and name == "ix1" + ) diffs = self._fixture(m1, m2, object_filters=include_object) - eq_(diffs[0][0], 'add_index') - eq_(diffs[0][1].name, 'ix2') + eq_(diffs[0][0], "add_index") + eq_(diffs[0][1].name, "ix2") eq_(len(diffs), 1) @config.requirements.unique_constraint_reflection @@ -1004,24 +1178,28 @@ class IncludeHooksTest(AutogenFixtureTest, TestBase): m1 = MetaData() m2 = MetaData() - Table('t', m1, Column('x', Integer)) + Table("t", m1, Column("x", Integer)) Table( - 't', m2, Column('x', Integer), - UniqueConstraint('x', name='uq1'), - UniqueConstraint('x', name='uq2') + "t", + m2, + Column("x", Integer), + UniqueConstraint("x", name="uq1"), + UniqueConstraint("x", name="uq2"), ) def include_object(object_, name, type_, reflected, compare_to): return not ( - isinstance(object_, UniqueConstraint) and - type_ == 'unique_constraint' and - not reflected and name == 'uq1') + isinstance(object_, UniqueConstraint) + and type_ == "unique_constraint" + and not reflected + and name == "uq1" + ) diffs = self._fixture(m1, m2, object_filters=include_object) - eq_(diffs[0][0], 'add_constraint') - eq_(diffs[0][1].name, 'uq2') + eq_(diffs[0][0], "add_constraint") + eq_(diffs[0][1].name, "uq2") eq_(len(diffs), 1) def test_change_index(self): @@ -1029,29 +1207,40 @@ class IncludeHooksTest(AutogenFixtureTest, TestBase): m2 = MetaData() t1 = Table( - 't', m1, Column('x', Integer), - Column('y', Integer), Column('z', Integer)) - Index('ix1', t1.c.x) - Index('ix2', t1.c.y) + "t", + m1, + Column("x", Integer), + Column("y", Integer), + Column("z", Integer), + ) + Index("ix1", t1.c.x) + Index("ix2", t1.c.y) t2 = Table( - 't', m2, Column('x', Integer), - Column('y', Integer), Column('z', Integer)) - Index('ix1', t2.c.x, t2.c.y) - Index('ix2', t2.c.x, t2.c.z) + "t", + m2, + Column("x", Integer), + Column("y", Integer), + Column("z", Integer), + ) + Index("ix1", t2.c.x, t2.c.y) + Index("ix2", t2.c.x, t2.c.z) def include_object(object_, name, type_, reflected, compare_to): return not ( - isinstance(object_, Index) and - type_ == 'index' and not reflected and name == 'ix1' - and isinstance(compare_to, Index)) + isinstance(object_, Index) + and type_ == "index" + and not reflected + and name == "ix1" + and isinstance(compare_to, Index) + ) diffs = self._fixture(m1, m2, object_filters=include_object) - eq_(diffs[0][0], 'remove_index') - eq_(diffs[0][1].name, 'ix2') - eq_(diffs[1][0], 'add_index') - eq_(diffs[1][1].name, 'ix2') + eq_(diffs[0][0], "remove_index") + eq_(diffs[0][1].name, "ix2") + eq_(diffs[1][0], "add_index") + eq_(diffs[1][1].name, "ix2") eq_(len(diffs), 2) @config.requirements.unique_constraint_reflection @@ -1060,39 +1249,46 @@ class IncludeHooksTest(AutogenFixtureTest, TestBase): m2 = MetaData() Table( - 't', m1, Column('x', Integer), - Column('y', Integer), Column('z', Integer), - UniqueConstraint('x', name='uq1'), - UniqueConstraint('y', name='uq2') + "t", + m1, + Column("x", Integer), + Column("y", Integer), + Column("z", Integer), + UniqueConstraint("x", name="uq1"), + UniqueConstraint("y", name="uq2"), ) Table( - 't', m2, Column('x', Integer), Column('y', Integer), - Column('z', Integer), - UniqueConstraint('x', 'z', name='uq1'), - UniqueConstraint('y', 'z', name='uq2') + "t", + m2, + Column("x", Integer), + Column("y", Integer), + Column("z", Integer), + UniqueConstraint("x", "z", name="uq1"), + UniqueConstraint("y", "z", name="uq2"), ) def include_object(object_, name, type_, reflected, compare_to): - if type_ == 'index': + if type_ == "index": return False return not ( - isinstance(object_, UniqueConstraint) and - type_ == 'unique_constraint' and - not reflected and name == 'uq1' - and isinstance(compare_to, UniqueConstraint)) + isinstance(object_, UniqueConstraint) + and type_ == "unique_constraint" + and not reflected + and name == "uq1" + and isinstance(compare_to, UniqueConstraint) + ) diffs = self._fixture(m1, m2, object_filters=include_object) - eq_(diffs[0][0], 'remove_constraint') - eq_(diffs[0][1].name, 'uq2') - eq_(diffs[1][0], 'add_constraint') - eq_(diffs[1][1].name, 'uq2') + eq_(diffs[0][0], "remove_constraint") + eq_(diffs[0][1].name, "uq2") + eq_(diffs[1][0], "add_constraint") + eq_(diffs[1][1].name, "uq2") eq_(len(diffs), 2) class TruncatedIdxTest(AutogenFixtureTest, TestBase): - def setUp(self): self.bind = engines.testing_engine() self.bind.dialect.max_identifier_length = 30 @@ -1102,12 +1298,13 @@ class TruncatedIdxTest(AutogenFixtureTest, TestBase): m1 = MetaData() Table( - 'q', m1, - Column('id', Integer, primary_key=True), - Column('data', Integer), + "q", + m1, + Column("id", Integer, primary_key=True), + Column("data", Integer), Index( - conv("idx_q_table_this_is_more_than_thirty_characters"), - "data") + conv("idx_q_table_this_is_more_than_thirty_characters"), "data" + ), ) diffs = self._fixture(m1, m1) diff --git a/tests/test_autogen_render.py b/tests/test_autogen_render.py index b32358f..37e1c61 100644 --- a/tests/test_autogen_render.py +++ b/tests/test_autogen_render.py @@ -4,11 +4,31 @@ from alembic.testing import TestBase, exclusions, assert_raises from alembic.testing import assertions from alembic.operations import ops -from sqlalchemy import MetaData, Column, Table, String, \ - Numeric, CHAR, ForeignKey, DATETIME, Integer, BigInteger, \ - CheckConstraint, Unicode, Enum, cast,\ - DateTime, UniqueConstraint, Boolean, ForeignKeyConstraint,\ - PrimaryKeyConstraint, Index, func, text, DefaultClause +from sqlalchemy import ( + MetaData, + Column, + Table, + String, + Numeric, + CHAR, + ForeignKey, + DATETIME, + Integer, + BigInteger, + CheckConstraint, + Unicode, + Enum, + cast, + DateTime, + UniqueConstraint, + Boolean, + ForeignKeyConstraint, + PrimaryKeyConstraint, + Index, + func, + text, + DefaultClause, +) from sqlalchemy.types import TIMESTAMP from sqlalchemy import types @@ -28,7 +48,7 @@ from alembic.testing.fixtures import op_fixture from alembic import op # noqa import sqlalchemy as sa # noqa -py3k = sys.version_info >= (3, ) +py3k = sys.version_info >= (3,) class AutogenRenderTest(TestBase): @@ -37,13 +57,12 @@ class AutogenRenderTest(TestBase): def setUp(self): ctx_opts = { - 'sqlalchemy_module_prefix': 'sa.', - 'alembic_module_prefix': 'op.', - 'target_metadata': MetaData() + "sqlalchemy_module_prefix": "sa.", + "alembic_module_prefix": "op.", + "target_metadata": MetaData(), } context = MigrationContext.configure( - dialect=DefaultDialect(), - opts=ctx_opts + dialect=DefaultDialect(), opts=ctx_opts ) self.autogen_context = api.AutogenContext(context) @@ -53,17 +72,19 @@ class AutogenRenderTest(TestBase): autogenerate.render._add_index """ m = MetaData() - t = Table('test', m, - Column('id', Integer, primary_key=True), - Column('active', Boolean()), - Column('code', String(255)), - ) - idx = Index('test_active_code_idx', t.c.active, t.c.code) + t = Table( + "test", + m, + Column("id", Integer, primary_key=True), + Column("active", Boolean()), + Column("code", String(255)), + ) + idx = Index("test_active_code_idx", t.c.active, t.c.code) op_obj = ops.CreateIndexOp.from_index(idx) eq_ignore_whitespace( autogenerate.render_op_text(self.autogen_context, op_obj), "op.create_index('test_active_code_idx', 'test', " - "['active', 'code'], unique=False)" + "['active', 'code'], unique=False)", ) def test_render_add_index_batch(self): @@ -71,18 +92,20 @@ class AutogenRenderTest(TestBase): autogenerate.render._add_index """ m = MetaData() - t = Table('test', m, - Column('id', Integer, primary_key=True), - Column('active', Boolean()), - Column('code', String(255)), - ) - idx = Index('test_active_code_idx', t.c.active, t.c.code) + t = Table( + "test", + m, + Column("id", Integer, primary_key=True), + Column("active", Boolean()), + Column("code", String(255)), + ) + idx = Index("test_active_code_idx", t.c.active, t.c.code) op_obj = ops.CreateIndexOp.from_index(idx) with self.autogen_context._within_batch(): eq_ignore_whitespace( autogenerate.render_op_text(self.autogen_context, op_obj), "batch_op.create_index('test_active_code_idx', " - "['active', 'code'], unique=False)" + "['active', 'code'], unique=False)", ) def test_render_add_index_schema(self): @@ -90,18 +113,20 @@ class AutogenRenderTest(TestBase): autogenerate.render._add_index using schema """ m = MetaData() - t = Table('test', m, - Column('id', Integer, primary_key=True), - Column('active', Boolean()), - Column('code', String(255)), - schema='CamelSchema' - ) - idx = Index('test_active_code_idx', t.c.active, t.c.code) + t = Table( + "test", + m, + Column("id", Integer, primary_key=True), + Column("active", Boolean()), + Column("code", String(255)), + schema="CamelSchema", + ) + idx = Index("test_active_code_idx", t.c.active, t.c.code) op_obj = ops.CreateIndexOp.from_index(idx) eq_ignore_whitespace( autogenerate.render_op_text(self.autogen_context, op_obj), "op.create_index('test_active_code_idx', 'test', " - "['active', 'code'], unique=False, schema='CamelSchema')" + "['active', 'code'], unique=False, schema='CamelSchema')", ) def test_render_add_index_schema_batch(self): @@ -109,73 +134,78 @@ class AutogenRenderTest(TestBase): autogenerate.render._add_index using schema """ m = MetaData() - t = Table('test', m, - Column('id', Integer, primary_key=True), - Column('active', Boolean()), - Column('code', String(255)), - schema='CamelSchema' - ) - idx = Index('test_active_code_idx', t.c.active, t.c.code) + t = Table( + "test", + m, + Column("id", Integer, primary_key=True), + Column("active", Boolean()), + Column("code", String(255)), + schema="CamelSchema", + ) + idx = Index("test_active_code_idx", t.c.active, t.c.code) op_obj = ops.CreateIndexOp.from_index(idx) with self.autogen_context._within_batch(): eq_ignore_whitespace( autogenerate.render_op_text(self.autogen_context, op_obj), "batch_op.create_index('test_active_code_idx', " - "['active', 'code'], unique=False)" + "['active', 'code'], unique=False)", ) def test_render_add_index_func(self): m = MetaData() t = Table( - 'test', m, - Column('id', Integer, primary_key=True), - Column('code', String(255)) + "test", + m, + Column("id", Integer, primary_key=True), + Column("code", String(255)), ) - idx = Index('test_lower_code_idx', func.lower(t.c.code)) + idx = Index("test_lower_code_idx", func.lower(t.c.code)) op_obj = ops.CreateIndexOp.from_index(idx) eq_ignore_whitespace( autogenerate.render_op_text(self.autogen_context, op_obj), "op.create_index('test_lower_code_idx', 'test', " - "[sa.text(!U'lower(code)')], unique=False)" + "[sa.text(!U'lower(code)')], unique=False)", ) def test_render_add_index_cast(self): m = MetaData() t = Table( - 'test', m, - Column('id', Integer, primary_key=True), - Column('code', String(255)) + "test", + m, + Column("id", Integer, primary_key=True), + Column("code", String(255)), ) - idx = Index('test_lower_code_idx', cast(t.c.code, String)) + idx = Index("test_lower_code_idx", cast(t.c.code, String)) op_obj = ops.CreateIndexOp.from_index(idx) if config.requirements.sqlalchemy_110.enabled: eq_ignore_whitespace( autogenerate.render_op_text(self.autogen_context, op_obj), "op.create_index('test_lower_code_idx', 'test', " - "[sa.text(!U'CAST(code AS VARCHAR)')], unique=False)" + "[sa.text(!U'CAST(code AS VARCHAR)')], unique=False)", ) else: eq_ignore_whitespace( autogenerate.render_op_text(self.autogen_context, op_obj), "op.create_index('test_lower_code_idx', 'test', " - "[sa.text(!U'CAST(code AS VARCHAR)')], unique=False)" + "[sa.text(!U'CAST(code AS VARCHAR)')], unique=False)", ) def test_render_add_index_desc(self): m = MetaData() t = Table( - 'test', m, - Column('id', Integer, primary_key=True), - Column('code', String(255)) + "test", + m, + Column("id", Integer, primary_key=True), + Column("code", String(255)), ) - idx = Index('test_desc_code_idx', t.c.code.desc()) + idx = Index("test_desc_code_idx", t.c.code.desc()) op_obj = ops.CreateIndexOp.from_index(idx) eq_ignore_whitespace( autogenerate.render_op_text(self.autogen_context, op_obj), "op.create_index('test_desc_code_idx', 'test', " - "[sa.text(!U'code DESC')], unique=False)" + "[sa.text(!U'code DESC')], unique=False)", ) def test_drop_index(self): @@ -183,16 +213,18 @@ class AutogenRenderTest(TestBase): autogenerate.render._drop_index """ m = MetaData() - t = Table('test', m, - Column('id', Integer, primary_key=True), - Column('active', Boolean()), - Column('code', String(255)), - ) - idx = Index('test_active_code_idx', t.c.active, t.c.code) + t = Table( + "test", + m, + Column("id", Integer, primary_key=True), + Column("active", Boolean()), + Column("code", String(255)), + ) + idx = Index("test_active_code_idx", t.c.active, t.c.code) op_obj = ops.DropIndexOp.from_index(idx) eq_ignore_whitespace( autogenerate.render_op_text(self.autogen_context, op_obj), - "op.drop_index('test_active_code_idx', table_name='test')" + "op.drop_index('test_active_code_idx', table_name='test')", ) def test_drop_index_batch(self): @@ -200,17 +232,19 @@ class AutogenRenderTest(TestBase): autogenerate.render._drop_index """ m = MetaData() - t = Table('test', m, - Column('id', Integer, primary_key=True), - Column('active', Boolean()), - Column('code', String(255)), - ) - idx = Index('test_active_code_idx', t.c.active, t.c.code) + t = Table( + "test", + m, + Column("id", Integer, primary_key=True), + Column("active", Boolean()), + Column("code", String(255)), + ) + idx = Index("test_active_code_idx", t.c.active, t.c.code) op_obj = ops.DropIndexOp.from_index(idx) with self.autogen_context._within_batch(): eq_ignore_whitespace( autogenerate.render_op_text(self.autogen_context, op_obj), - "batch_op.drop_index('test_active_code_idx')" + "batch_op.drop_index('test_active_code_idx')", ) def test_drop_index_schema(self): @@ -218,18 +252,20 @@ class AutogenRenderTest(TestBase): autogenerate.render._drop_index using schema """ m = MetaData() - t = Table('test', m, - Column('id', Integer, primary_key=True), - Column('active', Boolean()), - Column('code', String(255)), - schema='CamelSchema' - ) - idx = Index('test_active_code_idx', t.c.active, t.c.code) + t = Table( + "test", + m, + Column("id", Integer, primary_key=True), + Column("active", Boolean()), + Column("code", String(255)), + schema="CamelSchema", + ) + idx = Index("test_active_code_idx", t.c.active, t.c.code) op_obj = ops.DropIndexOp.from_index(idx) eq_ignore_whitespace( autogenerate.render_op_text(self.autogen_context, op_obj), - "op.drop_index('test_active_code_idx', " + - "table_name='test', schema='CamelSchema')" + "op.drop_index('test_active_code_idx', " + + "table_name='test', schema='CamelSchema')", ) def test_drop_index_schema_batch(self): @@ -237,18 +273,20 @@ class AutogenRenderTest(TestBase): autogenerate.render._drop_index using schema """ m = MetaData() - t = Table('test', m, - Column('id', Integer, primary_key=True), - Column('active', Boolean()), - Column('code', String(255)), - schema='CamelSchema' - ) - idx = Index('test_active_code_idx', t.c.active, t.c.code) + t = Table( + "test", + m, + Column("id", Integer, primary_key=True), + Column("active", Boolean()), + Column("code", String(255)), + schema="CamelSchema", + ) + idx = Index("test_active_code_idx", t.c.active, t.c.code) op_obj = ops.DropIndexOp.from_index(idx) with self.autogen_context._within_batch(): eq_ignore_whitespace( autogenerate.render_op_text(self.autogen_context, op_obj), - "batch_op.drop_index('test_active_code_idx')" + "batch_op.drop_index('test_active_code_idx')", ) def test_add_unique_constraint(self): @@ -256,16 +294,18 @@ class AutogenRenderTest(TestBase): autogenerate.render._add_unique_constraint """ m = MetaData() - t = Table('test', m, - Column('id', Integer, primary_key=True), - Column('active', Boolean()), - Column('code', String(255)), - ) - uq = UniqueConstraint(t.c.code, name='uq_test_code') + t = Table( + "test", + m, + Column("id", Integer, primary_key=True), + Column("active", Boolean()), + Column("code", String(255)), + ) + uq = UniqueConstraint(t.c.code, name="uq_test_code") op_obj = ops.AddConstraintOp.from_constraint(uq) eq_ignore_whitespace( autogenerate.render_op_text(self.autogen_context, op_obj), - "op.create_unique_constraint('uq_test_code', 'test', ['code'])" + "op.create_unique_constraint('uq_test_code', 'test', ['code'])", ) def test_add_unique_constraint_batch(self): @@ -273,17 +313,19 @@ class AutogenRenderTest(TestBase): autogenerate.render._add_unique_constraint """ m = MetaData() - t = Table('test', m, - Column('id', Integer, primary_key=True), - Column('active', Boolean()), - Column('code', String(255)), - ) - uq = UniqueConstraint(t.c.code, name='uq_test_code') + t = Table( + "test", + m, + Column("id", Integer, primary_key=True), + Column("active", Boolean()), + Column("code", String(255)), + ) + uq = UniqueConstraint(t.c.code, name="uq_test_code") op_obj = ops.AddConstraintOp.from_constraint(uq) with self.autogen_context._within_batch(): eq_ignore_whitespace( autogenerate.render_op_text(self.autogen_context, op_obj), - "batch_op.create_unique_constraint('uq_test_code', ['code'])" + "batch_op.create_unique_constraint('uq_test_code', ['code'])", ) def test_add_unique_constraint_schema(self): @@ -291,18 +333,20 @@ class AutogenRenderTest(TestBase): autogenerate.render._add_unique_constraint using schema """ m = MetaData() - t = Table('test', m, - Column('id', Integer, primary_key=True), - Column('active', Boolean()), - Column('code', String(255)), - schema='CamelSchema' - ) - uq = UniqueConstraint(t.c.code, name='uq_test_code') + t = Table( + "test", + m, + Column("id", Integer, primary_key=True), + Column("active", Boolean()), + Column("code", String(255)), + schema="CamelSchema", + ) + uq = UniqueConstraint(t.c.code, name="uq_test_code") op_obj = ops.AddConstraintOp.from_constraint(uq) eq_ignore_whitespace( autogenerate.render_op_text(self.autogen_context, op_obj), "op.create_unique_constraint('uq_test_code', 'test', " - "['code'], schema='CamelSchema')" + "['code'], schema='CamelSchema')", ) def test_add_unique_constraint_schema_batch(self): @@ -310,19 +354,21 @@ class AutogenRenderTest(TestBase): autogenerate.render._add_unique_constraint using schema """ m = MetaData() - t = Table('test', m, - Column('id', Integer, primary_key=True), - Column('active', Boolean()), - Column('code', String(255)), - schema='CamelSchema' - ) - uq = UniqueConstraint(t.c.code, name='uq_test_code') + t = Table( + "test", + m, + Column("id", Integer, primary_key=True), + Column("active", Boolean()), + Column("code", String(255)), + schema="CamelSchema", + ) + uq = UniqueConstraint(t.c.code, name="uq_test_code") op_obj = ops.AddConstraintOp.from_constraint(uq) with self.autogen_context._within_batch(): eq_ignore_whitespace( autogenerate.render_op_text(self.autogen_context, op_obj), "batch_op.create_unique_constraint('uq_test_code', " - "['code'])" + "['code'])", ) def test_drop_unique_constraint(self): @@ -330,16 +376,18 @@ class AutogenRenderTest(TestBase): autogenerate.render._drop_constraint """ m = MetaData() - t = Table('test', m, - Column('id', Integer, primary_key=True), - Column('active', Boolean()), - Column('code', String(255)), - ) - uq = UniqueConstraint(t.c.code, name='uq_test_code') + t = Table( + "test", + m, + Column("id", Integer, primary_key=True), + Column("active", Boolean()), + Column("code", String(255)), + ) + uq = UniqueConstraint(t.c.code, name="uq_test_code") op_obj = ops.DropConstraintOp.from_constraint(uq) eq_ignore_whitespace( autogenerate.render_op_text(self.autogen_context, op_obj), - "op.drop_constraint('uq_test_code', 'test', type_='unique')" + "op.drop_constraint('uq_test_code', 'test', type_='unique')", ) def test_drop_unique_constraint_schema(self): @@ -347,67 +395,72 @@ class AutogenRenderTest(TestBase): autogenerate.render._drop_constraint using schema """ m = MetaData() - t = Table('test', m, - Column('id', Integer, primary_key=True), - Column('active', Boolean()), - Column('code', String(255)), - schema='CamelSchema' - ) - uq = UniqueConstraint(t.c.code, name='uq_test_code') + t = Table( + "test", + m, + Column("id", Integer, primary_key=True), + Column("active", Boolean()), + Column("code", String(255)), + schema="CamelSchema", + ) + uq = UniqueConstraint(t.c.code, name="uq_test_code") op_obj = ops.DropConstraintOp.from_constraint(uq) eq_ignore_whitespace( autogenerate.render_op_text(self.autogen_context, op_obj), "op.drop_constraint('uq_test_code', 'test', " - "schema='CamelSchema', type_='unique')" + "schema='CamelSchema', type_='unique')", ) def test_drop_unique_constraint_schema_reprobj(self): """ autogenerate.render._drop_constraint using schema """ + class SomeObj(str): def __repr__(self): return "foo.camel_schema" op_obj = ops.DropConstraintOp( - "uq_test_code", "test", type_="unique", - schema=SomeObj("CamelSchema") + "uq_test_code", + "test", + type_="unique", + schema=SomeObj("CamelSchema"), ) eq_ignore_whitespace( autogenerate.render_op_text(self.autogen_context, op_obj), "op.drop_constraint('uq_test_code', 'test', " - "schema=foo.camel_schema, type_='unique')" + "schema=foo.camel_schema, type_='unique')", ) def test_add_fk_constraint(self): m = MetaData() - Table('a', m, Column('id', Integer, primary_key=True)) - b = Table('b', m, Column('a_id', Integer, ForeignKey('a.id'))) - fk = ForeignKeyConstraint(['a_id'], ['a.id'], name='fk_a_id') + Table("a", m, Column("id", Integer, primary_key=True)) + b = Table("b", m, Column("a_id", Integer, ForeignKey("a.id"))) + fk = ForeignKeyConstraint(["a_id"], ["a.id"], name="fk_a_id") b.append_constraint(fk) op_obj = ops.AddConstraintOp.from_constraint(fk) eq_ignore_whitespace( autogenerate.render_op_text(self.autogen_context, op_obj), - "op.create_foreign_key('fk_a_id', 'b', 'a', ['a_id'], ['id'])" + "op.create_foreign_key('fk_a_id', 'b', 'a', ['a_id'], ['id'])", ) def test_add_fk_constraint_batch(self): m = MetaData() - Table('a', m, Column('id', Integer, primary_key=True)) - b = Table('b', m, Column('a_id', Integer, ForeignKey('a.id'))) - fk = ForeignKeyConstraint(['a_id'], ['a.id'], name='fk_a_id') + Table("a", m, Column("id", Integer, primary_key=True)) + b = Table("b", m, Column("a_id", Integer, ForeignKey("a.id"))) + fk = ForeignKeyConstraint(["a_id"], ["a.id"], name="fk_a_id") b.append_constraint(fk) op_obj = ops.AddConstraintOp.from_constraint(fk) with self.autogen_context._within_batch(): eq_ignore_whitespace( autogenerate.render_op_text(self.autogen_context, op_obj), - "batch_op.create_foreign_key('fk_a_id', 'a', ['a_id'], ['id'])" + "batch_op.create_foreign_key('fk_a_id', 'a', ['a_id'], ['id'])", ) def test_add_fk_constraint_kwarg(self): m = MetaData() - t1 = Table('t', m, Column('c', Integer)) - t2 = Table('t2', m, Column('c_rem', Integer)) + t1 = Table("t", m, Column("c", Integer)) + t2 = Table("t2", m, Column("c_rem", Integer)) fk = ForeignKeyConstraint([t1.c.c], [t2.c.c_rem], onupdate="CASCADE") @@ -417,11 +470,12 @@ class AutogenRenderTest(TestBase): op_obj = ops.AddConstraintOp.from_constraint(fk) eq_ignore_whitespace( re.sub( - r"u'", "'", + r"u'", + "'", autogenerate.render_op_text(self.autogen_context, op_obj), ), "op.create_foreign_key(None, 't', 't2', ['c'], ['c_rem'], " - "onupdate='CASCADE')" + "onupdate='CASCADE')", ) fk = ForeignKeyConstraint([t1.c.c], [t2.c.c_rem], ondelete="CASCADE") @@ -429,54 +483,62 @@ class AutogenRenderTest(TestBase): op_obj = ops.AddConstraintOp.from_constraint(fk) eq_ignore_whitespace( re.sub( - r"u'", "'", - autogenerate.render_op_text(self.autogen_context, op_obj) + r"u'", + "'", + autogenerate.render_op_text(self.autogen_context, op_obj), ), "op.create_foreign_key(None, 't', 't2', ['c'], ['c_rem'], " - "ondelete='CASCADE')" + "ondelete='CASCADE')", ) fk = ForeignKeyConstraint([t1.c.c], [t2.c.c_rem], deferrable=True) op_obj = ops.AddConstraintOp.from_constraint(fk) eq_ignore_whitespace( re.sub( - r"u'", "'", - autogenerate.render_op_text(self.autogen_context, op_obj) + r"u'", + "'", + autogenerate.render_op_text(self.autogen_context, op_obj), ), "op.create_foreign_key(None, 't', 't2', ['c'], ['c_rem'], " - "deferrable=True)" + "deferrable=True)", ) fk = ForeignKeyConstraint([t1.c.c], [t2.c.c_rem], initially="XYZ") op_obj = ops.AddConstraintOp.from_constraint(fk) eq_ignore_whitespace( re.sub( - r"u'", "'", + r"u'", + "'", autogenerate.render_op_text(self.autogen_context, op_obj), ), "op.create_foreign_key(None, 't', 't2', ['c'], ['c_rem'], " - "initially='XYZ')" + "initially='XYZ')", ) fk = ForeignKeyConstraint( - [t1.c.c], [t2.c.c_rem], - initially="XYZ", ondelete="CASCADE", deferrable=True) + [t1.c.c], + [t2.c.c_rem], + initially="XYZ", + ondelete="CASCADE", + deferrable=True, + ) op_obj = ops.AddConstraintOp.from_constraint(fk) eq_ignore_whitespace( re.sub( - r"u'", "'", - autogenerate.render_op_text(self.autogen_context, op_obj) + r"u'", + "'", + autogenerate.render_op_text(self.autogen_context, op_obj), ), "op.create_foreign_key(None, 't', 't2', ['c'], ['c_rem'], " - "ondelete='CASCADE', initially='XYZ', deferrable=True)" + "ondelete='CASCADE', initially='XYZ', deferrable=True)", ) def test_add_fk_constraint_inline_colkeys(self): m = MetaData() - Table('a', m, Column('id', Integer, key='aid', primary_key=True)) + Table("a", m, Column("id", Integer, key="aid", primary_key=True)) b = Table( - 'b', m, - Column('a_id', Integer, ForeignKey('a.aid'), key='baid')) + "b", m, Column("a_id", Integer, ForeignKey("a.aid"), key="baid") + ) op_obj = ops.CreateTableOp.from_table(b) py_code = autogenerate.render_op_text(self.autogen_context, op_obj) @@ -485,20 +547,21 @@ class AutogenRenderTest(TestBase): py_code, "op.create_table('b'," "sa.Column('a_id', sa.Integer(), nullable=True)," - "sa.ForeignKeyConstraint(['a_id'], ['a.id'], ))" + "sa.ForeignKeyConstraint(['a_id'], ['a.id'], ))", ) context = op_fixture() eval(py_code) context.assert_( "CREATE TABLE b (a_id INTEGER, " - "FOREIGN KEY(a_id) REFERENCES a (id))") + "FOREIGN KEY(a_id) REFERENCES a (id))" + ) def test_add_fk_constraint_separate_colkeys(self): m = MetaData() - Table('a', m, Column('id', Integer, key='aid', primary_key=True)) - b = Table('b', m, Column('a_id', Integer, key='baid')) - fk = ForeignKeyConstraint(['baid'], ['a.aid'], name='fk_a_id') + Table("a", m, Column("id", Integer, key="aid", primary_key=True)) + b = Table("b", m, Column("a_id", Integer, key="baid")) + fk = ForeignKeyConstraint(["baid"], ["a.aid"], name="fk_a_id") b.append_constraint(fk) op_obj = ops.CreateTableOp.from_table(b) @@ -508,14 +571,15 @@ class AutogenRenderTest(TestBase): py_code, "op.create_table('b'," "sa.Column('a_id', sa.Integer(), nullable=True)," - "sa.ForeignKeyConstraint(['a_id'], ['a.id'], name='fk_a_id'))" + "sa.ForeignKeyConstraint(['a_id'], ['a.id'], name='fk_a_id'))", ) context = op_fixture() eval(py_code) context.assert_( "CREATE TABLE b (a_id INTEGER, CONSTRAINT " - "fk_a_id FOREIGN KEY(a_id) REFERENCES a (id))") + "fk_a_id FOREIGN KEY(a_id) REFERENCES a (id))" + ) context = op_fixture() @@ -523,7 +587,7 @@ class AutogenRenderTest(TestBase): eq_ignore_whitespace( autogenerate.render_op_text(self.autogen_context, op_obj), - "op.create_foreign_key('fk_a_id', 'b', 'a', ['a_id'], ['id'])" + "op.create_foreign_key('fk_a_id', 'b', 'a', ['a_id'], ['id'])", ) py_code = autogenerate.render_op_text(self.autogen_context, op_obj) @@ -531,124 +595,151 @@ class AutogenRenderTest(TestBase): eval(py_code) context.assert_( "ALTER TABLE b ADD CONSTRAINT fk_a_id " - "FOREIGN KEY(a_id) REFERENCES a (id)") + "FOREIGN KEY(a_id) REFERENCES a (id)" + ) def test_add_fk_constraint_schema(self): m = MetaData() Table( - 'a', m, Column('id', Integer, primary_key=True), - schema="CamelSchemaTwo") + "a", + m, + Column("id", Integer, primary_key=True), + schema="CamelSchemaTwo", + ) b = Table( - 'b', m, Column('a_id', Integer, ForeignKey('a.id')), - schema="CamelSchemaOne") + "b", + m, + Column("a_id", Integer, ForeignKey("a.id")), + schema="CamelSchemaOne", + ) fk = ForeignKeyConstraint( - ["a_id"], - ["CamelSchemaTwo.a.id"], name='fk_a_id') + ["a_id"], ["CamelSchemaTwo.a.id"], name="fk_a_id" + ) b.append_constraint(fk) op_obj = ops.AddConstraintOp.from_constraint(fk) eq_ignore_whitespace( autogenerate.render_op_text(self.autogen_context, op_obj), "op.create_foreign_key('fk_a_id', 'b', 'a', ['a_id'], ['id']," " source_schema='CamelSchemaOne', " - "referent_schema='CamelSchemaTwo')" + "referent_schema='CamelSchemaTwo')", ) def test_add_fk_constraint_schema_batch(self): m = MetaData() Table( - 'a', m, Column('id', Integer, primary_key=True), - schema="CamelSchemaTwo") + "a", + m, + Column("id", Integer, primary_key=True), + schema="CamelSchemaTwo", + ) b = Table( - 'b', m, Column('a_id', Integer, ForeignKey('a.id')), - schema="CamelSchemaOne") + "b", + m, + Column("a_id", Integer, ForeignKey("a.id")), + schema="CamelSchemaOne", + ) fk = ForeignKeyConstraint( - ["a_id"], - ["CamelSchemaTwo.a.id"], name='fk_a_id') + ["a_id"], ["CamelSchemaTwo.a.id"], name="fk_a_id" + ) b.append_constraint(fk) op_obj = ops.AddConstraintOp.from_constraint(fk) with self.autogen_context._within_batch(): eq_ignore_whitespace( autogenerate.render_op_text(self.autogen_context, op_obj), "batch_op.create_foreign_key('fk_a_id', 'a', ['a_id'], ['id']," - " referent_schema='CamelSchemaTwo')" + " referent_schema='CamelSchemaTwo')", ) def test_drop_fk_constraint(self): m = MetaData() - Table('a', m, Column('id', Integer, primary_key=True)) - b = Table('b', m, Column('a_id', Integer, ForeignKey('a.id'))) - fk = ForeignKeyConstraint(['a_id'], ['a.id'], name='fk_a_id') + Table("a", m, Column("id", Integer, primary_key=True)) + b = Table("b", m, Column("a_id", Integer, ForeignKey("a.id"))) + fk = ForeignKeyConstraint(["a_id"], ["a.id"], name="fk_a_id") b.append_constraint(fk) op_obj = ops.DropConstraintOp.from_constraint(fk) eq_ignore_whitespace( autogenerate.render_op_text(self.autogen_context, op_obj), - "op.drop_constraint('fk_a_id', 'b', type_='foreignkey')" + "op.drop_constraint('fk_a_id', 'b', type_='foreignkey')", ) def test_drop_fk_constraint_batch(self): m = MetaData() - Table('a', m, Column('id', Integer, primary_key=True)) - b = Table('b', m, Column('a_id', Integer, ForeignKey('a.id'))) - fk = ForeignKeyConstraint(['a_id'], ['a.id'], name='fk_a_id') + Table("a", m, Column("id", Integer, primary_key=True)) + b = Table("b", m, Column("a_id", Integer, ForeignKey("a.id"))) + fk = ForeignKeyConstraint(["a_id"], ["a.id"], name="fk_a_id") b.append_constraint(fk) op_obj = ops.DropConstraintOp.from_constraint(fk) with self.autogen_context._within_batch(): eq_ignore_whitespace( autogenerate.render_op_text(self.autogen_context, op_obj), - "batch_op.drop_constraint('fk_a_id', type_='foreignkey')" + "batch_op.drop_constraint('fk_a_id', type_='foreignkey')", ) def test_drop_fk_constraint_schema(self): m = MetaData() Table( - 'a', m, Column('id', Integer, primary_key=True), - schema="CamelSchemaTwo") + "a", + m, + Column("id", Integer, primary_key=True), + schema="CamelSchemaTwo", + ) b = Table( - 'b', m, Column('a_id', Integer, ForeignKey('a.id')), - schema="CamelSchemaOne") + "b", + m, + Column("a_id", Integer, ForeignKey("a.id")), + schema="CamelSchemaOne", + ) fk = ForeignKeyConstraint( - ["a_id"], - ["CamelSchemaTwo.a.id"], name='fk_a_id') + ["a_id"], ["CamelSchemaTwo.a.id"], name="fk_a_id" + ) b.append_constraint(fk) op_obj = ops.DropConstraintOp.from_constraint(fk) eq_ignore_whitespace( autogenerate.render_op_text(self.autogen_context, op_obj), "op.drop_constraint('fk_a_id', 'b', schema='CamelSchemaOne', " - "type_='foreignkey')" + "type_='foreignkey')", ) def test_drop_fk_constraint_batch_schema(self): m = MetaData() Table( - 'a', m, Column('id', Integer, primary_key=True), - schema="CamelSchemaTwo") + "a", + m, + Column("id", Integer, primary_key=True), + schema="CamelSchemaTwo", + ) b = Table( - 'b', m, Column('a_id', Integer, ForeignKey('a.id')), - schema="CamelSchemaOne") + "b", + m, + Column("a_id", Integer, ForeignKey("a.id")), + schema="CamelSchemaOne", + ) fk = ForeignKeyConstraint( - ["a_id"], - ["CamelSchemaTwo.a.id"], name='fk_a_id') + ["a_id"], ["CamelSchemaTwo.a.id"], name="fk_a_id" + ) b.append_constraint(fk) op_obj = ops.DropConstraintOp.from_constraint(fk) with self.autogen_context._within_batch(): eq_ignore_whitespace( autogenerate.render_op_text(self.autogen_context, op_obj), - "batch_op.drop_constraint('fk_a_id', type_='foreignkey')" + "batch_op.drop_constraint('fk_a_id', type_='foreignkey')", ) def test_render_table_upgrade(self): m = MetaData() - t = Table('test', m, - Column('id', Integer, primary_key=True), - Column('name', Unicode(255)), - Column("address_id", Integer, ForeignKey("address.id")), - Column("timestamp", DATETIME, server_default="NOW()"), - Column("amount", Numeric(5, 2)), - UniqueConstraint("name", name="uq_name"), - UniqueConstraint("timestamp"), - ) + t = Table( + "test", + m, + Column("id", Integer, primary_key=True), + Column("name", Unicode(255)), + Column("address_id", Integer, ForeignKey("address.id")), + Column("timestamp", DATETIME, server_default="NOW()"), + Column("amount", Numeric(5, 2)), + UniqueConstraint("name", name="uq_name"), + UniqueConstraint("timestamp"), + ) op_obj = ops.CreateTableOp.from_table(t) eq_ignore_whitespace( @@ -666,16 +757,18 @@ class AutogenRenderTest(TestBase): "sa.PrimaryKeyConstraint('id')," "sa.UniqueConstraint('name', name='uq_name')," "sa.UniqueConstraint('timestamp')" - ")" + ")", ) def test_render_table_w_schema(self): m = MetaData() - t = Table('test', m, - Column('id', Integer, primary_key=True), - Column('q', Integer, ForeignKey('address.id')), - schema='foo' - ) + t = Table( + "test", + m, + Column("id", Integer, primary_key=True), + Column("q", Integer, ForeignKey("address.id")), + schema="foo", + ) op_obj = ops.CreateTableOp.from_table(t) eq_ignore_whitespace( autogenerate.render_op_text(self.autogen_context, op_obj), @@ -685,84 +778,89 @@ class AutogenRenderTest(TestBase): "sa.ForeignKeyConstraint(['q'], ['address.id'], )," "sa.PrimaryKeyConstraint('id')," "schema='foo'" - ")" + ")", ) def test_render_table_w_system(self): m = MetaData() - t = Table('sometable', m, - Column('id', Integer, primary_key=True), - Column('xmin', Integer, system=True, nullable=False) - ) + t = Table( + "sometable", + m, + Column("id", Integer, primary_key=True), + Column("xmin", Integer, system=True, nullable=False), + ) op_obj = ops.CreateTableOp.from_table(t) eq_ignore_whitespace( autogenerate.render_op_text(self.autogen_context, op_obj), "op.create_table('sometable'," "sa.Column('id', sa.Integer(), nullable=False)," "sa.Column('xmin', sa.Integer(), nullable=False, system=True)," - "sa.PrimaryKeyConstraint('id'))" + "sa.PrimaryKeyConstraint('id'))", ) def test_render_table_w_unicode_name(self): m = MetaData() - t = Table(compat.ue('\u0411\u0435\u0437'), m, - Column('id', Integer, primary_key=True), - ) + t = Table( + compat.ue("\u0411\u0435\u0437"), + m, + Column("id", Integer, primary_key=True), + ) op_obj = ops.CreateTableOp.from_table(t) eq_ignore_whitespace( autogenerate.render_op_text(self.autogen_context, op_obj), "op.create_table(%r," "sa.Column('id', sa.Integer(), nullable=False)," - "sa.PrimaryKeyConstraint('id'))" % compat.ue('\u0411\u0435\u0437') + "sa.PrimaryKeyConstraint('id'))" % compat.ue("\u0411\u0435\u0437"), ) def test_render_table_w_unicode_schema(self): m = MetaData() - t = Table('test', m, - Column('id', Integer, primary_key=True), - schema=compat.ue('\u0411\u0435\u0437') - ) + t = Table( + "test", + m, + Column("id", Integer, primary_key=True), + schema=compat.ue("\u0411\u0435\u0437"), + ) op_obj = ops.CreateTableOp.from_table(t) eq_ignore_whitespace( autogenerate.render_op_text(self.autogen_context, op_obj), "op.create_table('test'," "sa.Column('id', sa.Integer(), nullable=False)," "sa.PrimaryKeyConstraint('id')," - "schema=%r)" % compat.ue('\u0411\u0435\u0437') + "schema=%r)" % compat.ue("\u0411\u0435\u0437"), ) def test_render_table_w_unsupported_constraint(self): from sqlalchemy.sql.schema import ColumnCollectionConstraint class SomeCustomConstraint(ColumnCollectionConstraint): - __visit_name__ = 'some_custom' + __visit_name__ = "some_custom" m = MetaData() - t = Table( - 't', m, Column('id', Integer), - SomeCustomConstraint('id'), - ) + t = Table("t", m, Column("id", Integer), SomeCustomConstraint("id")) op_obj = ops.CreateTableOp.from_table(t) with assertions.expect_warnings( - "No renderer is established for object SomeCustomConstraint"): + "No renderer is established for object SomeCustomConstraint" + ): eq_ignore_whitespace( autogenerate.render_op_text(self.autogen_context, op_obj), "op.create_table('t'," "sa.Column('id', sa.Integer(), nullable=True)," "[Unknown Python object " - "SomeCustomConstraint(Column('id', Integer(), table=<t>))])" + "SomeCustomConstraint(Column('id', Integer(), table=<t>))])", ) @patch("alembic.autogenerate.render.MAX_PYTHON_ARGS", 3) def test_render_table_max_cols(self): m = MetaData() t = Table( - 'test', m, - Column('a', Integer), - Column('b', Integer), - Column('c', Integer), - Column('d', Integer), + "test", + m, + Column("a", Integer), + Column("b", Integer), + Column("c", Integer), + Column("d", Integer), ) op_obj = ops.CreateTableOp.from_table(t) eq_ignore_whitespace( @@ -771,14 +869,15 @@ class AutogenRenderTest(TestBase): "*[sa.Column('a', sa.Integer(), nullable=True)," "sa.Column('b', sa.Integer(), nullable=True)," "sa.Column('c', sa.Integer(), nullable=True)," - "sa.Column('d', sa.Integer(), nullable=True)])" + "sa.Column('d', sa.Integer(), nullable=True)])", ) t2 = Table( - 'test2', m, - Column('a', Integer), - Column('b', Integer), - Column('c', Integer), + "test2", + m, + Column("a", Integer), + Column("b", Integer), + Column("c", Integer), ) op_obj = ops.CreateTableOp.from_table(t2) @@ -787,15 +886,17 @@ class AutogenRenderTest(TestBase): "op.create_table('test2'," "sa.Column('a', sa.Integer(), nullable=True)," "sa.Column('b', sa.Integer(), nullable=True)," - "sa.Column('c', sa.Integer(), nullable=True))" + "sa.Column('c', sa.Integer(), nullable=True))", ) def test_render_table_w_fk_schema(self): m = MetaData() - t = Table('test', m, - Column('id', Integer, primary_key=True), - Column('q', Integer, ForeignKey('foo.address.id')), - ) + t = Table( + "test", + m, + Column("id", Integer, primary_key=True), + Column("q", Integer, ForeignKey("foo.address.id")), + ) op_obj = ops.CreateTableOp.from_table(t) eq_ignore_whitespace( autogenerate.render_op_text(self.autogen_context, op_obj), @@ -804,20 +905,23 @@ class AutogenRenderTest(TestBase): "sa.Column('q', sa.Integer(), nullable=True)," "sa.ForeignKeyConstraint(['q'], ['foo.address.id'], )," "sa.PrimaryKeyConstraint('id')" - ")" + ")", ) def test_render_table_w_metadata_schema(self): m = MetaData(schema="foo") - t = Table('test', m, - Column('id', Integer, primary_key=True), - Column('q', Integer, ForeignKey('address.id')), - ) + t = Table( + "test", + m, + Column("id", Integer, primary_key=True), + Column("q", Integer, ForeignKey("address.id")), + ) op_obj = ops.CreateTableOp.from_table(t) eq_ignore_whitespace( re.sub( - r"u'", "'", - autogenerate.render_op_text(self.autogen_context, op_obj) + r"u'", + "'", + autogenerate.render_op_text(self.autogen_context, op_obj), ), "op.create_table('test'," "sa.Column('id', sa.Integer(), nullable=False)," @@ -825,15 +929,17 @@ class AutogenRenderTest(TestBase): "sa.ForeignKeyConstraint(['q'], ['foo.address.id'], )," "sa.PrimaryKeyConstraint('id')," "schema='foo'" - ")" + ")", ) def test_render_table_w_metadata_schema_override(self): m = MetaData(schema="foo") - t = Table('test', m, - Column('id', Integer, primary_key=True), - Column('q', Integer, ForeignKey('bar.address.id')), - ) + t = Table( + "test", + m, + Column("id", Integer, primary_key=True), + Column("q", Integer, ForeignKey("bar.address.id")), + ) op_obj = ops.CreateTableOp.from_table(t) eq_ignore_whitespace( autogenerate.render_op_text(self.autogen_context, op_obj), @@ -843,16 +949,19 @@ class AutogenRenderTest(TestBase): "sa.ForeignKeyConstraint(['q'], ['bar.address.id'], )," "sa.PrimaryKeyConstraint('id')," "schema='foo'" - ")" + ")", ) def test_render_addtl_args(self): m = MetaData() - t = Table('test', m, - Column('id', Integer, primary_key=True), - Column('q', Integer, ForeignKey('bar.address.id')), - sqlite_autoincrement=True, mysql_engine="InnoDB" - ) + t = Table( + "test", + m, + Column("id", Integer, primary_key=True), + Column("q", Integer, ForeignKey("bar.address.id")), + sqlite_autoincrement=True, + mysql_engine="InnoDB", + ) op_obj = ops.CreateTableOp.from_table(t) eq_ignore_whitespace( autogenerate.render_op_text(self.autogen_context, op_obj), @@ -861,60 +970,58 @@ class AutogenRenderTest(TestBase): "sa.Column('q', sa.Integer(), nullable=True)," "sa.ForeignKeyConstraint(['q'], ['bar.address.id'], )," "sa.PrimaryKeyConstraint('id')," - "mysql_engine='InnoDB',sqlite_autoincrement=True)" + "mysql_engine='InnoDB',sqlite_autoincrement=True)", ) def test_render_drop_table(self): - op_obj = ops.DropTableOp.from_table( - Table("sometable", MetaData()) - ) + op_obj = ops.DropTableOp.from_table(Table("sometable", MetaData())) eq_ignore_whitespace( autogenerate.render_op_text(self.autogen_context, op_obj), - "op.drop_table('sometable')" + "op.drop_table('sometable')", ) def test_render_drop_table_w_schema(self): op_obj = ops.DropTableOp.from_table( - Table("sometable", MetaData(), schema='foo') + Table("sometable", MetaData(), schema="foo") ) eq_ignore_whitespace( autogenerate.render_op_text(self.autogen_context, op_obj), - "op.drop_table('sometable', schema='foo')" + "op.drop_table('sometable', schema='foo')", ) def test_render_table_no_implicit_check(self): m = MetaData() - t = Table('test', m, Column('x', Boolean())) + t = Table("test", m, Column("x", Boolean())) op_obj = ops.CreateTableOp.from_table(t) eq_ignore_whitespace( autogenerate.render_op_text(self.autogen_context, op_obj), "op.create_table('test'," - "sa.Column('x', sa.Boolean(), nullable=True))" + "sa.Column('x', sa.Boolean(), nullable=True))", ) def test_render_pk_with_col_name_vs_col_key(self): m = MetaData() - t1 = Table('t1', m, Column('x', Integer, key='y', primary_key=True)) + t1 = Table("t1", m, Column("x", Integer, key="y", primary_key=True)) op_obj = ops.CreateTableOp.from_table(t1) eq_ignore_whitespace( autogenerate.render_op_text(self.autogen_context, op_obj), "op.create_table('t1'," "sa.Column('x', sa.Integer(), nullable=False)," - "sa.PrimaryKeyConstraint('x'))" + "sa.PrimaryKeyConstraint('x'))", ) def test_render_empty_pk_vs_nonempty_pk(self): m = MetaData() - t1 = Table('t1', m, Column('x', Integer)) - t2 = Table('t2', m, Column('x', Integer, primary_key=True)) + t1 = Table("t1", m, Column("x", Integer)) + t2 = Table("t2", m, Column("x", Integer, primary_key=True)) op_obj = ops.CreateTableOp.from_table(t1) eq_ignore_whitespace( autogenerate.render_op_text(self.autogen_context, op_obj), "op.create_table('t1'," - "sa.Column('x', sa.Integer(), nullable=True))" + "sa.Column('x', sa.Integer(), nullable=True))", ) op_obj = ops.CreateTableOp.from_table(t2) @@ -922,16 +1029,18 @@ class AutogenRenderTest(TestBase): autogenerate.render_op_text(self.autogen_context, op_obj), "op.create_table('t2'," "sa.Column('x', sa.Integer(), nullable=False)," - "sa.PrimaryKeyConstraint('x'))" + "sa.PrimaryKeyConstraint('x'))", ) @config.requirements.fail_before_sqla_110 def test_render_table_w_autoincrement(self): m = MetaData() t = Table( - 'test', m, - Column('id1', Integer, primary_key=True), - Column('id2', Integer, primary_key=True, autoincrement=True)) + "test", + m, + Column("id1", Integer, primary_key=True), + Column("id2", Integer, primary_key=True, autoincrement=True), + ) op_obj = ops.CreateTableOp.from_table(t) eq_ignore_whitespace( autogenerate.render_op_text(self.autogen_context, op_obj), @@ -940,111 +1049,109 @@ class AutogenRenderTest(TestBase): "sa.Column('id2', sa.Integer(), autoincrement=True, " "nullable=False)," "sa.PrimaryKeyConstraint('id1', 'id2')" - ")" + ")", ) def test_render_add_column(self): op_obj = ops.AddColumnOp( - "foo", Column("x", Integer, server_default="5")) + "foo", Column("x", Integer, server_default="5") + ) eq_ignore_whitespace( autogenerate.render_op_text(self.autogen_context, op_obj), "op.add_column('foo', sa.Column('x', sa.Integer(), " - "server_default='5', nullable=True))" + "server_default='5', nullable=True))", ) def test_render_add_column_system(self): # this would never actually happen since "system" columns # can't be added in any case. Howver it will render as # part of op.CreateTableOp. - op_obj = ops.AddColumnOp( - "foo", Column("xmin", Integer, system=True)) + op_obj = ops.AddColumnOp("foo", Column("xmin", Integer, system=True)) eq_ignore_whitespace( autogenerate.render_op_text(self.autogen_context, op_obj), "op.add_column('foo', sa.Column('xmin', sa.Integer(), " - "nullable=True, system=True))" + "nullable=True, system=True))", ) def test_render_add_column_w_schema(self): op_obj = ops.AddColumnOp( - "bar", Column("x", Integer, server_default="5"), - schema="foo") + "bar", Column("x", Integer, server_default="5"), schema="foo" + ) eq_ignore_whitespace( autogenerate.render_op_text(self.autogen_context, op_obj), "op.add_column('bar', sa.Column('x', sa.Integer(), " - "server_default='5', nullable=True), schema='foo')" + "server_default='5', nullable=True), schema='foo')", ) def test_render_drop_column(self): op_obj = ops.DropColumnOp.from_column_and_tablename( - None, "foo", Column("x", Integer, server_default="5")) + None, "foo", Column("x", Integer, server_default="5") + ) eq_ignore_whitespace( autogenerate.render_op_text(self.autogen_context, op_obj), - "op.drop_column('foo', 'x')" + "op.drop_column('foo', 'x')", ) def test_render_drop_column_w_schema(self): op_obj = ops.DropColumnOp.from_column_and_tablename( - "foo", "bar", Column("x", Integer, server_default="5")) + "foo", "bar", Column("x", Integer, server_default="5") + ) eq_ignore_whitespace( autogenerate.render_op_text(self.autogen_context, op_obj), - "op.drop_column('bar', 'x', schema='foo')" + "op.drop_column('bar', 'x', schema='foo')", ) def test_render_quoted_server_default(self): eq_( autogenerate.render._render_server_default( "nextval('group_to_perm_group_to_perm_id_seq'::regclass)", - self.autogen_context), - '"nextval(\'group_to_perm_group_to_perm_id_seq\'::regclass)"' + self.autogen_context, + ), + "\"nextval('group_to_perm_group_to_perm_id_seq'::regclass)\"", ) def test_render_unicode_server_default(self): default = compat.ue( - '\u0411\u0435\u0437 ' - '\u043d\u0430\u0437\u0432\u0430\u043d\u0438\u044f' + "\u0411\u0435\u0437 " + "\u043d\u0430\u0437\u0432\u0430\u043d\u0438\u044f" ) - c = Column( - 'x', Unicode, - server_default=text(default) - ) + c = Column("x", Unicode, server_default=text(default)) eq_ignore_whitespace( autogenerate.render._render_server_default( c.server_default, self.autogen_context ), - "sa.text(%r)" % default + "sa.text(%r)" % default, ) def test_render_col_with_server_default(self): - c = Column('updated_at', TIMESTAMP(), - server_default='TIMEZONE("utc", CURRENT_TIMESTAMP)', - nullable=False) - result = autogenerate.render._render_column( - c, self.autogen_context + c = Column( + "updated_at", + TIMESTAMP(), + server_default='TIMEZONE("utc", CURRENT_TIMESTAMP)', + nullable=False, ) + result = autogenerate.render._render_column(c, self.autogen_context) eq_ignore_whitespace( result, - 'sa.Column(\'updated_at\', sa.TIMESTAMP(), ' - 'server_default=\'TIMEZONE("utc", CURRENT_TIMESTAMP)\', ' - 'nullable=False)' + "sa.Column('updated_at', sa.TIMESTAMP(), " + "server_default='TIMEZONE(\"utc\", CURRENT_TIMESTAMP)', " + "nullable=False)", ) def test_render_col_autoinc_false_mysql(self): - c = Column('some_key', Integer, primary_key=True, autoincrement=False) - Table('some_table', MetaData(), c) - result = autogenerate.render._render_column( - c, self.autogen_context - ) + c = Column("some_key", Integer, primary_key=True, autoincrement=False) + Table("some_table", MetaData(), c) + result = autogenerate.render._render_column(c, self.autogen_context) eq_ignore_whitespace( result, - 'sa.Column(\'some_key\', sa.Integer(), ' - 'autoincrement=False, ' - 'nullable=False)' + "sa.Column('some_key', sa.Integer(), " + "autoincrement=False, " + "nullable=False)", ) def test_render_custom(self): - class MySpecialType(Integer): pass @@ -1065,17 +1172,18 @@ class AutogenRenderTest(TestBase): return "render:%s" % type_ self.autogen_context.opts.update( - render_item=render, - alembic_module_prefix='sa.' + render_item=render, alembic_module_prefix="sa." ) - t = Table('t', MetaData(), - Column('x', Integer), - Column('y', Integer), - Column('q', MySpecialType()), - PrimaryKeyConstraint('x'), - ForeignKeyConstraint(['x'], ['y']) - ) + t = Table( + "t", + MetaData(), + Column("x", Integer), + Column("y", Integer), + Column("q", MySpecialType()), + PrimaryKeyConstraint("x"), + ForeignKeyConstraint(["x"], ["y"]), + ) op_obj = ops.CreateTableOp.from_table(t) result = autogenerate.render_op_text(self.autogen_context, op_obj) eq_ignore_whitespace( @@ -1083,89 +1191,97 @@ class AutogenRenderTest(TestBase): "sa.create_table('t'," "col(x)," "sa.Column('q', MySpecialType(), nullable=True)," - "render:primary_key)" + "render:primary_key)", ) eq_( self.autogen_context.imports, - set(['from mypackage import MySpecialType']) + set(["from mypackage import MySpecialType"]), ) def test_render_modify_type(self): op_obj = ops.AlterColumnOp( - "sometable", "somecolumn", - modify_type=CHAR(10), existing_type=CHAR(20) + "sometable", + "somecolumn", + modify_type=CHAR(10), + existing_type=CHAR(20), ) eq_ignore_whitespace( autogenerate.render_op_text(self.autogen_context, op_obj), "op.alter_column('sometable', 'somecolumn', " - "existing_type=sa.CHAR(length=20), type_=sa.CHAR(length=10))" + "existing_type=sa.CHAR(length=20), type_=sa.CHAR(length=10))", ) def test_render_modify_type_w_schema(self): op_obj = ops.AlterColumnOp( - "sometable", "somecolumn", - modify_type=CHAR(10), existing_type=CHAR(20), - schema='foo' + "sometable", + "somecolumn", + modify_type=CHAR(10), + existing_type=CHAR(20), + schema="foo", ) eq_ignore_whitespace( autogenerate.render_op_text(self.autogen_context, op_obj), "op.alter_column('sometable', 'somecolumn', " "existing_type=sa.CHAR(length=20), type_=sa.CHAR(length=10), " - "schema='foo')" + "schema='foo')", ) def test_render_modify_nullable(self): op_obj = ops.AlterColumnOp( - "sometable", "somecolumn", + "sometable", + "somecolumn", existing_type=Integer(), - modify_nullable=True + modify_nullable=True, ) eq_ignore_whitespace( autogenerate.render_op_text(self.autogen_context, op_obj), "op.alter_column('sometable', 'somecolumn', " - "existing_type=sa.Integer(), nullable=True)" + "existing_type=sa.Integer(), nullable=True)", ) def test_render_modify_nullable_no_existing_type(self): op_obj = ops.AlterColumnOp( - "sometable", "somecolumn", - modify_nullable=True + "sometable", "somecolumn", modify_nullable=True ) eq_ignore_whitespace( autogenerate.render_op_text(self.autogen_context, op_obj), - "op.alter_column('sometable', 'somecolumn', nullable=True)" + "op.alter_column('sometable', 'somecolumn', nullable=True)", ) def test_render_modify_nullable_w_schema(self): op_obj = ops.AlterColumnOp( - "sometable", "somecolumn", + "sometable", + "somecolumn", existing_type=Integer(), - modify_nullable=True, schema='foo' + modify_nullable=True, + schema="foo", ) eq_ignore_whitespace( autogenerate.render_op_text(self.autogen_context, op_obj), "op.alter_column('sometable', 'somecolumn', " - "existing_type=sa.Integer(), nullable=True, schema='foo')" + "existing_type=sa.Integer(), nullable=True, schema='foo')", ) def test_render_modify_type_w_autoincrement(self): op_obj = ops.AlterColumnOp( - "sometable", "somecolumn", - modify_type=Integer(), existing_type=BigInteger(), - autoincrement=True + "sometable", + "somecolumn", + modify_type=Integer(), + existing_type=BigInteger(), + autoincrement=True, ) eq_ignore_whitespace( autogenerate.render_op_text(self.autogen_context, op_obj), "op.alter_column('sometable', 'somecolumn', " "existing_type=sa.BigInteger(), type_=sa.Integer(), " - "autoincrement=True)" + "autoincrement=True)", ) def test_render_fk_constraint_kwarg(self): m = MetaData() - t1 = Table('t', m, Column('c', Integer)) - t2 = Table('t2', m, Column('c_rem', Integer)) + t1 = Table("t", m, Column("c", Integer)) + t2 = Table("t2", m, Column("c_rem", Integer)) fk = ForeignKeyConstraint([t1.c.c], [t2.c.c_rem], onupdate="CASCADE") @@ -1174,239 +1290,274 @@ class AutogenRenderTest(TestBase): eq_ignore_whitespace( re.sub( - r"u'", "'", + r"u'", + "'", autogenerate.render._render_constraint( - fk, self.autogen_context)), - "sa.ForeignKeyConstraint(['c'], ['t2.c_rem'], onupdate='CASCADE')" + fk, self.autogen_context + ), + ), + "sa.ForeignKeyConstraint(['c'], ['t2.c_rem'], onupdate='CASCADE')", ) fk = ForeignKeyConstraint([t1.c.c], [t2.c.c_rem], ondelete="CASCADE") eq_ignore_whitespace( re.sub( - r"u'", "'", + r"u'", + "'", autogenerate.render._render_constraint( - fk, self.autogen_context)), - "sa.ForeignKeyConstraint(['c'], ['t2.c_rem'], ondelete='CASCADE')" + fk, self.autogen_context + ), + ), + "sa.ForeignKeyConstraint(['c'], ['t2.c_rem'], ondelete='CASCADE')", ) fk = ForeignKeyConstraint([t1.c.c], [t2.c.c_rem], deferrable=True) eq_ignore_whitespace( re.sub( - r"u'", "'", + r"u'", + "'", autogenerate.render._render_constraint( - fk, self.autogen_context), + fk, self.autogen_context + ), ), - "sa.ForeignKeyConstraint(['c'], ['t2.c_rem'], deferrable=True)" + "sa.ForeignKeyConstraint(['c'], ['t2.c_rem'], deferrable=True)", ) fk = ForeignKeyConstraint([t1.c.c], [t2.c.c_rem], initially="XYZ") eq_ignore_whitespace( re.sub( - r"u'", "'", + r"u'", + "'", autogenerate.render._render_constraint( - fk, self.autogen_context) + fk, self.autogen_context + ), ), - "sa.ForeignKeyConstraint(['c'], ['t2.c_rem'], initially='XYZ')" + "sa.ForeignKeyConstraint(['c'], ['t2.c_rem'], initially='XYZ')", ) fk = ForeignKeyConstraint( - [t1.c.c], [t2.c.c_rem], - initially="XYZ", ondelete="CASCADE", deferrable=True) + [t1.c.c], + [t2.c.c_rem], + initially="XYZ", + ondelete="CASCADE", + deferrable=True, + ) eq_ignore_whitespace( re.sub( - r"u'", "'", + r"u'", + "'", autogenerate.render._render_constraint( - fk, self.autogen_context) + fk, self.autogen_context + ), ), "sa.ForeignKeyConstraint(['c'], ['t2.c_rem'], " - "ondelete='CASCADE', initially='XYZ', deferrable=True)" + "ondelete='CASCADE', initially='XYZ', deferrable=True)", ) def test_render_fk_constraint_resolve_key(self): m = MetaData() - t1 = Table('t', m, Column('c', Integer)) - t2 = Table('t2', m, Column('c_rem', Integer, key='c_remkey')) + t1 = Table("t", m, Column("c", Integer)) + t2 = Table("t2", m, Column("c_rem", Integer, key="c_remkey")) - fk = ForeignKeyConstraint(['c'], ['t2.c_remkey']) + fk = ForeignKeyConstraint(["c"], ["t2.c_remkey"]) t1.append_constraint(fk) eq_ignore_whitespace( re.sub( - r"u'", "'", + r"u'", + "'", autogenerate.render._render_constraint( - fk, self.autogen_context)), - "sa.ForeignKeyConstraint(['c'], ['t2.c_rem'], )" + fk, self.autogen_context + ), + ), + "sa.ForeignKeyConstraint(['c'], ['t2.c_rem'], )", ) def test_render_fk_constraint_bad_table_resolve(self): m = MetaData() - t1 = Table('t', m, Column('c', Integer)) - t2 = Table('t2', m, Column('c_rem', Integer)) + t1 = Table("t", m, Column("c", Integer)) + t2 = Table("t2", m, Column("c_rem", Integer)) - fk = ForeignKeyConstraint(['c'], ['t2.nonexistent']) + fk = ForeignKeyConstraint(["c"], ["t2.nonexistent"]) t1.append_constraint(fk) eq_ignore_whitespace( re.sub( - r"u'", "'", + r"u'", + "'", autogenerate.render._render_constraint( - fk, self.autogen_context)), - "sa.ForeignKeyConstraint(['c'], ['t2.nonexistent'], )" + fk, self.autogen_context + ), + ), + "sa.ForeignKeyConstraint(['c'], ['t2.nonexistent'], )", ) def test_render_fk_constraint_bad_table_resolve_dont_get_confused(self): m = MetaData() - t1 = Table('t', m, Column('c', Integer)) + t1 = Table("t", m, Column("c", Integer)) t2 = Table( - 't2', m, - Column('c_rem', Integer, key='cr_key'), - Column('c_rem_2', Integer, key='c_rem') - + "t2", + m, + Column("c_rem", Integer, key="cr_key"), + Column("c_rem_2", Integer, key="c_rem"), ) - fk = ForeignKeyConstraint(['c'], ['t2.c_rem'], link_to_name=True) + fk = ForeignKeyConstraint(["c"], ["t2.c_rem"], link_to_name=True) t1.append_constraint(fk) eq_ignore_whitespace( re.sub( - r"u'", "'", + r"u'", + "'", autogenerate.render._render_constraint( - fk, self.autogen_context)), - "sa.ForeignKeyConstraint(['c'], ['t2.c_rem'], )" + fk, self.autogen_context + ), + ), + "sa.ForeignKeyConstraint(['c'], ['t2.c_rem'], )", ) def test_render_fk_constraint_link_to_name(self): m = MetaData() - t1 = Table('t', m, Column('c', Integer)) - t2 = Table('t2', m, Column('c_rem', Integer, key='c_remkey')) + t1 = Table("t", m, Column("c", Integer)) + t2 = Table("t2", m, Column("c_rem", Integer, key="c_remkey")) - fk = ForeignKeyConstraint(['c'], ['t2.c_rem'], link_to_name=True) + fk = ForeignKeyConstraint(["c"], ["t2.c_rem"], link_to_name=True) t1.append_constraint(fk) eq_ignore_whitespace( re.sub( - r"u'", "'", + r"u'", + "'", autogenerate.render._render_constraint( - fk, self.autogen_context)), - "sa.ForeignKeyConstraint(['c'], ['t2.c_rem'], )" + fk, self.autogen_context + ), + ), + "sa.ForeignKeyConstraint(['c'], ['t2.c_rem'], )", ) def test_render_fk_constraint_use_alter(self): m = MetaData() - Table('t', m, Column('c', Integer)) + Table("t", m, Column("c", Integer)) t2 = Table( - 't2', m, + "t2", + m, Column( - 'c_rem', Integer, - ForeignKey('t.c', name="fk1", use_alter=True))) + "c_rem", Integer, ForeignKey("t.c", name="fk1", use_alter=True) + ), + ) const = list(t2.foreign_keys)[0].constraint eq_ignore_whitespace( autogenerate.render._render_constraint( - const, self.autogen_context), + const, self.autogen_context + ), "sa.ForeignKeyConstraint(['c_rem'], ['t.c'], " - "name='fk1', use_alter=True)" + "name='fk1', use_alter=True)", ) def test_render_fk_constraint_w_metadata_schema(self): m = MetaData(schema="foo") - t1 = Table('t', m, Column('c', Integer)) - t2 = Table('t2', m, Column('c_rem', Integer)) + t1 = Table("t", m, Column("c", Integer)) + t2 = Table("t2", m, Column("c_rem", Integer)) fk = ForeignKeyConstraint([t1.c.c], [t2.c.c_rem], onupdate="CASCADE") eq_ignore_whitespace( re.sub( - r"u'", "'", + r"u'", + "'", autogenerate.render._render_constraint( - fk, self.autogen_context) + fk, self.autogen_context + ), ), "sa.ForeignKeyConstraint(['c'], ['foo.t2.c_rem'], " - "onupdate='CASCADE')" + "onupdate='CASCADE')", ) def test_render_check_constraint_literal(self): eq_ignore_whitespace( autogenerate.render._render_check_constraint( - CheckConstraint("im a constraint", name='cc1'), - self.autogen_context + CheckConstraint("im a constraint", name="cc1"), + self.autogen_context, ), - "sa.CheckConstraint(!U'im a constraint', name='cc1')" + "sa.CheckConstraint(!U'im a constraint', name='cc1')", ) def test_render_check_constraint_sqlexpr(self): - c = column('c') - five = literal_column('5') - ten = literal_column('10') + c = column("c") + five = literal_column("5") + ten = literal_column("10") eq_ignore_whitespace( autogenerate.render._render_check_constraint( - CheckConstraint(and_(c > five, c < ten)), - self.autogen_context + CheckConstraint(and_(c > five, c < ten)), self.autogen_context ), - "sa.CheckConstraint(!U'c > 5 AND c < 10')" + "sa.CheckConstraint(!U'c > 5 AND c < 10')", ) def test_render_check_constraint_literal_binds(self): - c = column('c') + c = column("c") eq_ignore_whitespace( autogenerate.render._render_check_constraint( - CheckConstraint(and_(c > 5, c < 10)), - self.autogen_context + CheckConstraint(and_(c > 5, c < 10)), self.autogen_context ), - "sa.CheckConstraint(!U'c > 5 AND c < 10')" + "sa.CheckConstraint(!U'c > 5 AND c < 10')", ) def test_render_unique_constraint_opts(self): m = MetaData() - t = Table('t', m, Column('c', Integer)) + t = Table("t", m, Column("c", Integer)) eq_ignore_whitespace( autogenerate.render._render_unique_constraint( - UniqueConstraint(t.c.c, name='uq_1', deferrable='XYZ'), - self.autogen_context + UniqueConstraint(t.c.c, name="uq_1", deferrable="XYZ"), + self.autogen_context, ), - "sa.UniqueConstraint('c', deferrable='XYZ', name='uq_1')" + "sa.UniqueConstraint('c', deferrable='XYZ', name='uq_1')", ) def test_add_unique_constraint_unicode_schema(self): m = MetaData() t = Table( - 't', m, Column('c', Integer), - schema=compat.ue('\u0411\u0435\u0437') + "t", + m, + Column("c", Integer), + schema=compat.ue("\u0411\u0435\u0437"), ) op_obj = ops.AddConstraintOp.from_constraint(UniqueConstraint(t.c.c)) eq_ignore_whitespace( autogenerate.render_op_text(self.autogen_context, op_obj), "op.create_unique_constraint(None, 't', ['c'], " - "schema=%r)" % compat.ue('\u0411\u0435\u0437') + "schema=%r)" % compat.ue("\u0411\u0435\u0437"), ) def test_render_modify_nullable_w_default(self): op_obj = ops.AlterColumnOp( - "sometable", "somecolumn", + "sometable", + "somecolumn", existing_type=Integer(), existing_server_default="5", - modify_nullable=True + modify_nullable=True, ) eq_ignore_whitespace( autogenerate.render_op_text(self.autogen_context, op_obj), "op.alter_column('sometable', 'somecolumn', " "existing_type=sa.Integer(), nullable=True, " - "existing_server_default='5')" + "existing_server_default='5')", ) def test_render_enum(self): eq_ignore_whitespace( autogenerate.render._repr_type( Enum("one", "two", "three", name="myenum"), - self.autogen_context), - "sa.Enum('one', 'two', 'three', name='myenum')" + self.autogen_context, + ), + "sa.Enum('one', 'two', 'three', name='myenum')", ) eq_ignore_whitespace( autogenerate.render._repr_type( - Enum("one", "two", "three"), - self.autogen_context), - "sa.Enum('one', 'two', 'three')" + Enum("one", "two", "three"), self.autogen_context + ), + "sa.Enum('one', 'two', 'three')", ) @config.requirements.sqlalchemy_099 @@ -1414,15 +1565,16 @@ class AutogenRenderTest(TestBase): eq_ignore_whitespace( autogenerate.render._repr_type( Enum("one", "two", "three", native_enum=False), - self.autogen_context), - "sa.Enum('one', 'two', 'three', native_enum=False)" + self.autogen_context, + ), + "sa.Enum('one', 'two', 'three', native_enum=False)", ) def test_repr_plain_sqla_type(self): type_ = Integer() eq_ignore_whitespace( autogenerate.render._repr_type(type_, self.autogen_context), - "sa.Integer()" + "sa.Integer()", ) @config.requirements.sqlalchemy_110 @@ -1430,24 +1582,27 @@ class AutogenRenderTest(TestBase): eq_ignore_whitespace( autogenerate.render._repr_type( - types.ARRAY(Integer), self.autogen_context), - "sa.ARRAY(sa.Integer())" + types.ARRAY(Integer), self.autogen_context + ), + "sa.ARRAY(sa.Integer())", ) eq_ignore_whitespace( autogenerate.render._repr_type( - types.ARRAY(DateTime(timezone=True)), self.autogen_context), - "sa.ARRAY(sa.DateTime(timezone=True))" + types.ARRAY(DateTime(timezone=True)), self.autogen_context + ), + "sa.ARRAY(sa.DateTime(timezone=True))", ) @config.requirements.sqlalchemy_110 def test_render_array_no_context(self): - uo = ops.UpgradeOps(ops=[ - ops.CreateTableOp( - "sometable", - [Column('x', types.ARRAY(Integer))] - ) - ]) + uo = ops.UpgradeOps( + ops=[ + ops.CreateTableOp( + "sometable", [Column("x", types.ARRAY(Integer))] + ) + ] + ) eq_( autogenerate.render_python_code(uo), @@ -1455,11 +1610,11 @@ class AutogenRenderTest(TestBase): " op.create_table('sometable',\n" " sa.Column('x', sa.ARRAY(sa.Integer()), nullable=True)\n" " )\n" - " # ### end Alembic commands ###" + " # ### end Alembic commands ###", ) def test_repr_custom_type_w_sqla_prefix(self): - self.autogen_context.opts['user_module_prefix'] = None + self.autogen_context.opts["user_module_prefix"] = None class MyType(UserDefinedType): pass @@ -1470,283 +1625,280 @@ class AutogenRenderTest(TestBase): eq_ignore_whitespace( autogenerate.render._repr_type(type_, self.autogen_context), - "sqlalchemy_util.types.MyType()" + "sqlalchemy_util.types.MyType()", ) def test_repr_user_type_user_prefix_None(self): class MyType(UserDefinedType): - def get_col_spec(self): return "MYTYPE" type_ = MyType() - self.autogen_context.opts['user_module_prefix'] = None + self.autogen_context.opts["user_module_prefix"] = None eq_ignore_whitespace( autogenerate.render._repr_type(type_, self.autogen_context), - "tests.test_autogen_render.MyType()" + "tests.test_autogen_render.MyType()", ) def test_repr_user_type_user_prefix_present(self): from sqlalchemy.types import UserDefinedType class MyType(UserDefinedType): - def get_col_spec(self): return "MYTYPE" type_ = MyType() - self.autogen_context.opts['user_module_prefix'] = 'user.' + self.autogen_context.opts["user_module_prefix"] = "user." eq_ignore_whitespace( autogenerate.render._repr_type(type_, self.autogen_context), - "user.MyType()" + "user.MyType()", ) def test_repr_dialect_type(self): from sqlalchemy.dialects.mysql import VARCHAR - type_ = VARCHAR(20, charset='utf8', national=True) + type_ = VARCHAR(20, charset="utf8", national=True) - self.autogen_context.opts['user_module_prefix'] = None + self.autogen_context.opts["user_module_prefix"] = None eq_ignore_whitespace( autogenerate.render._repr_type(type_, self.autogen_context), - "mysql.VARCHAR(charset='utf8', national=True, length=20)" + "mysql.VARCHAR(charset='utf8', national=True, length=20)", + ) + eq_( + self.autogen_context.imports, + set(["from sqlalchemy.dialects import mysql"]), ) - eq_(self.autogen_context.imports, - set(['from sqlalchemy.dialects import mysql']) - ) def test_render_server_default_text(self): c = Column( - 'updated_at', TIMESTAMP(), - server_default=text('now()'), - nullable=False) - result = autogenerate.render._render_column( - c, self.autogen_context + "updated_at", + TIMESTAMP(), + server_default=text("now()"), + nullable=False, ) + result = autogenerate.render._render_column(c, self.autogen_context) eq_ignore_whitespace( result, - 'sa.Column(\'updated_at\', sa.TIMESTAMP(), ' - 'server_default=sa.text(!U\'now()\'), ' - 'nullable=False)' + "sa.Column('updated_at', sa.TIMESTAMP(), " + "server_default=sa.text(!U'now()'), " + "nullable=False)", ) def test_render_server_default_non_native_boolean(self): c = Column( - 'updated_at', Boolean(), - server_default=false(), - nullable=False) - - result = autogenerate.render._render_column( - c, self.autogen_context + "updated_at", Boolean(), server_default=false(), nullable=False ) + + result = autogenerate.render._render_column(c, self.autogen_context) eq_ignore_whitespace( result, - 'sa.Column(\'updated_at\', sa.Boolean(), ' - 'server_default=sa.text(!U\'0\'), ' - 'nullable=False)' + "sa.Column('updated_at', sa.Boolean(), " + "server_default=sa.text(!U'0'), " + "nullable=False)", ) def test_render_server_default_func(self): c = Column( - 'updated_at', TIMESTAMP(), + "updated_at", + TIMESTAMP(), server_default=func.now(), - nullable=False) - result = autogenerate.render._render_column( - c, self.autogen_context + nullable=False, ) + result = autogenerate.render._render_column(c, self.autogen_context) eq_ignore_whitespace( result, - 'sa.Column(\'updated_at\', sa.TIMESTAMP(), ' - 'server_default=sa.text(!U\'now()\'), ' - 'nullable=False)' + "sa.Column('updated_at', sa.TIMESTAMP(), " + "server_default=sa.text(!U'now()'), " + "nullable=False)", ) def test_render_server_default_int(self): - c = Column( - 'value', Integer, - server_default="0") - result = autogenerate.render._render_column( - c, self.autogen_context - ) + c = Column("value", Integer, server_default="0") + result = autogenerate.render._render_column(c, self.autogen_context) eq_( result, "sa.Column('value', sa.Integer(), " - "server_default='0', nullable=True)" + "server_default='0', nullable=True)", ) def test_render_modify_reflected_int_server_default(self): op_obj = ops.AlterColumnOp( - "sometable", "somecolumn", + "sometable", + "somecolumn", existing_type=Integer(), existing_server_default=DefaultClause(text("5")), - modify_nullable=True + modify_nullable=True, ) eq_ignore_whitespace( autogenerate.render_op_text(self.autogen_context, op_obj), "op.alter_column('sometable', 'somecolumn', " "existing_type=sa.Integer(), nullable=True, " - "existing_server_default=sa.text(!U'5'))" + "existing_server_default=sa.text(!U'5'))", ) def test_render_executesql_plaintext(self): op_obj = ops.ExecuteSQLOp("drop table foo") eq_( autogenerate.render_op_text(self.autogen_context, op_obj), - "op.execute('drop table foo')" + "op.execute('drop table foo')", ) def test_render_executesql_sqlexpr_notimplemented(self): - sql = table('x', column('q')).insert() + sql = table("x", column("q")).insert() op_obj = ops.ExecuteSQLOp(sql) assert_raises( NotImplementedError, - autogenerate.render_op_text, self.autogen_context, op_obj + autogenerate.render_op_text, + self.autogen_context, + op_obj, ) class RenderNamingConventionTest(TestBase): - __requires__ = ('sqlalchemy_094',) + __requires__ = ("sqlalchemy_094",) def setUp(self): convention = { - "ix": 'ix_%(custom)s_%(column_0_label)s', + "ix": "ix_%(custom)s_%(column_0_label)s", "uq": "uq_%(custom)s_%(table_name)s_%(column_0_name)s", "ck": "ck_%(custom)s_%(table_name)s", "fk": "fk_%(custom)s_%(table_name)s_" "%(column_0_name)s_%(referred_table_name)s", "pk": "pk_%(custom)s_%(table_name)s", - "custom": lambda const, table: "ct" + "custom": lambda const, table: "ct", } - self.metadata = MetaData( - naming_convention=convention - ) + self.metadata = MetaData(naming_convention=convention) ctx_opts = { - 'sqlalchemy_module_prefix': 'sa.', - 'alembic_module_prefix': 'op.', - 'target_metadata': MetaData() + "sqlalchemy_module_prefix": "sa.", + "alembic_module_prefix": "op.", + "target_metadata": MetaData(), } context = MigrationContext.configure( - dialect_name="postgresql", - opts=ctx_opts + dialect_name="postgresql", opts=ctx_opts ) self.autogen_context = api.AutogenContext(context) def test_schema_type_boolean(self): - t = Table('t', self.metadata, Column('c', Boolean(name='xyz'))) + t = Table("t", self.metadata, Column("c", Boolean(name="xyz"))) op_obj = ops.AddColumnOp.from_column(t.c.c) eq_ignore_whitespace( autogenerate.render_op_text(self.autogen_context, op_obj), "op.add_column('t', " - "sa.Column('c', sa.Boolean(name='xyz'), nullable=True))" + "sa.Column('c', sa.Boolean(name='xyz'), nullable=True))", ) def test_explicit_unique_constraint(self): - t = Table('t', self.metadata, Column('c', Integer)) + t = Table("t", self.metadata, Column("c", Integer)) eq_ignore_whitespace( autogenerate.render._render_unique_constraint( - UniqueConstraint(t.c.c, deferrable='XYZ'), - self.autogen_context + UniqueConstraint(t.c.c, deferrable="XYZ"), self.autogen_context ), "sa.UniqueConstraint('c', deferrable='XYZ', " - "name=op.f('uq_ct_t_c'))" + "name=op.f('uq_ct_t_c'))", ) def test_explicit_named_unique_constraint(self): - t = Table('t', self.metadata, Column('c', Integer)) + t = Table("t", self.metadata, Column("c", Integer)) eq_ignore_whitespace( autogenerate.render._render_unique_constraint( - UniqueConstraint(t.c.c, name='q'), - self.autogen_context + UniqueConstraint(t.c.c, name="q"), self.autogen_context ), - "sa.UniqueConstraint('c', name='q')" + "sa.UniqueConstraint('c', name='q')", ) def test_render_add_index(self): - t = Table('test', self.metadata, - Column('id', Integer, primary_key=True), - Column('active', Boolean()), - Column('code', String(255)), - ) + t = Table( + "test", + self.metadata, + Column("id", Integer, primary_key=True), + Column("active", Boolean()), + Column("code", String(255)), + ) idx = Index(None, t.c.active, t.c.code) op_obj = ops.CreateIndexOp.from_index(idx) eq_ignore_whitespace( autogenerate.render_op_text(self.autogen_context, op_obj), "op.create_index(op.f('ix_ct_test_active'), 'test', " - "['active', 'code'], unique=False)" + "['active', 'code'], unique=False)", ) def test_render_drop_index(self): - t = Table('test', self.metadata, - Column('id', Integer, primary_key=True), - Column('active', Boolean()), - Column('code', String(255)), - ) + t = Table( + "test", + self.metadata, + Column("id", Integer, primary_key=True), + Column("active", Boolean()), + Column("code", String(255)), + ) idx = Index(None, t.c.active, t.c.code) op_obj = ops.DropIndexOp.from_index(idx) eq_ignore_whitespace( autogenerate.render_op_text(self.autogen_context, op_obj), - "op.drop_index(op.f('ix_ct_test_active'), table_name='test')" + "op.drop_index(op.f('ix_ct_test_active'), table_name='test')", ) def test_render_add_index_schema(self): - t = Table('test', self.metadata, - Column('id', Integer, primary_key=True), - Column('active', Boolean()), - Column('code', String(255)), - schema='CamelSchema' - ) + t = Table( + "test", + self.metadata, + Column("id", Integer, primary_key=True), + Column("active", Boolean()), + Column("code", String(255)), + schema="CamelSchema", + ) idx = Index(None, t.c.active, t.c.code) op_obj = ops.CreateIndexOp.from_index(idx) eq_ignore_whitespace( autogenerate.render_op_text(self.autogen_context, op_obj), "op.create_index(op.f('ix_ct_CamelSchema_test_active'), 'test', " - "['active', 'code'], unique=False, schema='CamelSchema')" + "['active', 'code'], unique=False, schema='CamelSchema')", ) def test_implicit_unique_constraint(self): - t = Table('t', self.metadata, Column('c', Integer, unique=True)) + t = Table("t", self.metadata, Column("c", Integer, unique=True)) uq = [c for c in t.constraints if isinstance(c, UniqueConstraint)][0] eq_ignore_whitespace( - autogenerate.render._render_unique_constraint(uq, - self.autogen_context - ), - "sa.UniqueConstraint('c', name=op.f('uq_ct_t_c'))" + autogenerate.render._render_unique_constraint( + uq, self.autogen_context + ), + "sa.UniqueConstraint('c', name=op.f('uq_ct_t_c'))", ) def test_inline_pk_constraint(self): - t = Table('t', self.metadata, Column('c', Integer, primary_key=True)) + t = Table("t", self.metadata, Column("c", Integer, primary_key=True)) op_obj = ops.CreateTableOp.from_table(t) eq_ignore_whitespace( autogenerate.render_op_text(self.autogen_context, op_obj), "op.create_table('t',sa.Column('c', sa.Integer(), nullable=False)," - "sa.PrimaryKeyConstraint('c', name=op.f('pk_ct_t')))" + "sa.PrimaryKeyConstraint('c', name=op.f('pk_ct_t')))", ) def test_inline_ck_constraint(self): t = Table( - 't', self.metadata, Column('c', Integer), CheckConstraint("c > 5")) + "t", self.metadata, Column("c", Integer), CheckConstraint("c > 5") + ) op_obj = ops.CreateTableOp.from_table(t) eq_ignore_whitespace( autogenerate.render_op_text(self.autogen_context, op_obj), "op.create_table('t',sa.Column('c', sa.Integer(), nullable=True)," - "sa.CheckConstraint(!U'c > 5', name=op.f('ck_ct_t')))" + "sa.CheckConstraint(!U'c > 5', name=op.f('ck_ct_t')))", ) def test_inline_fk(self): - t = Table('t', self.metadata, Column('c', Integer, ForeignKey('q.id'))) + t = Table("t", self.metadata, Column("c", Integer, ForeignKey("q.id"))) op_obj = ops.CreateTableOp.from_table(t) eq_ignore_whitespace( autogenerate.render_op_text(self.autogen_context, op_obj), "op.create_table('t',sa.Column('c', sa.Integer(), nullable=True)," "sa.ForeignKeyConstraint(['c'], ['q.id'], " - "name=op.f('fk_ct_t_c_q')))" + "name=op.f('fk_ct_t_c_q')))", ) def test_render_check_constraint_renamed(self): @@ -1760,31 +1912,31 @@ class RenderNamingConventionTest(TestBase): used. """ - m1 = MetaData(naming_convention={ - "ck": "ck_%(table_name)s_%(constraint_name)s"}) + m1 = MetaData( + naming_convention={"ck": "ck_%(table_name)s_%(constraint_name)s"} + ) ck = CheckConstraint("im a constraint", name="cc1") - Table('t', m1, Column('x'), ck) + Table("t", m1, Column("x"), ck) eq_ignore_whitespace( autogenerate.render._render_check_constraint( - ck, - self.autogen_context + ck, self.autogen_context ), - "sa.CheckConstraint(!U'im a constraint', name=op.f('ck_t_cc1'))" + "sa.CheckConstraint(!U'im a constraint', name=op.f('ck_t_cc1'))", ) def test_create_table_plus_add_index_in_modify(self): - uo = ops.UpgradeOps(ops=[ - ops.CreateTableOp( - "sometable", - [Column('x', Integer), Column('y', Integer)] - ), - ops.ModifyTableOps( - "sometable", ops=[ - ops.CreateIndexOp('ix1', 'sometable', ['x', 'y']) - ] - ) - ]) + uo = ops.UpgradeOps( + ops=[ + ops.CreateTableOp( + "sometable", [Column("x", Integer), Column("y", Integer)] + ), + ops.ModifyTableOps( + "sometable", + ops=[ops.CreateIndexOp("ix1", "sometable", ["x", "y"])], + ), + ] + ) eq_( autogenerate.render_python_code(uo, render_as_batch=True), @@ -1797,5 +1949,5 @@ class RenderNamingConventionTest(TestBase): "as batch_op:\n" " batch_op.create_index(" "'ix1', ['x', 'y'], unique=False)\n\n" - " # ### end Alembic commands ###" + " # ### end Alembic commands ###", ) diff --git a/tests/test_batch.py b/tests/test_batch.py index 99605d0..2a7c52e 100644 --- a/tests/test_batch.py +++ b/tests/test_batch.py @@ -11,9 +11,22 @@ from alembic.operations.batch import ApplyBatchImpl from alembic.runtime.migration import MigrationContext -from sqlalchemy import Integer, Table, Column, String, MetaData, ForeignKey, \ - UniqueConstraint, ForeignKeyConstraint, Index, Boolean, CheckConstraint, \ - Enum, DateTime, PrimaryKeyConstraint +from sqlalchemy import ( + Integer, + Table, + Column, + String, + MetaData, + ForeignKey, + UniqueConstraint, + ForeignKeyConstraint, + Index, + Boolean, + CheckConstraint, + Enum, + DateTime, + PrimaryKeyConstraint, +) from sqlalchemy.engine.reflection import Inspector from sqlalchemy.sql import column, text, select from sqlalchemy.schema import CreateTable, CreateIndex @@ -21,84 +34,91 @@ from sqlalchemy import exc class BatchApplyTest(TestBase): - def setUp(self): self.op = Operations(mock.Mock(opts={})) def _simple_fixture(self, table_args=(), table_kwargs={}): m = MetaData() t = Table( - 'tname', m, - Column('id', Integer, primary_key=True), - Column('x', String(10)), - Column('y', Integer) + "tname", + m, + Column("id", Integer, primary_key=True), + Column("x", String(10)), + Column("y", Integer), ) return ApplyBatchImpl(t, table_args, table_kwargs, False) def _uq_fixture(self, table_args=(), table_kwargs={}): m = MetaData() t = Table( - 'tname', m, - Column('id', Integer, primary_key=True), - Column('x', String()), - Column('y', Integer), - UniqueConstraint('y', name='uq1') + "tname", + m, + Column("id", Integer, primary_key=True), + Column("x", String()), + Column("y", Integer), + UniqueConstraint("y", name="uq1"), ) return ApplyBatchImpl(t, table_args, table_kwargs, False) def _ix_fixture(self, table_args=(), table_kwargs={}): m = MetaData() t = Table( - 'tname', m, - Column('id', Integer, primary_key=True), - Column('x', String()), - Column('y', Integer), - Index('ix1', 'y') + "tname", + m, + Column("id", Integer, primary_key=True), + Column("x", String()), + Column("y", Integer), + Index("ix1", "y"), ) return ApplyBatchImpl(t, table_args, table_kwargs, False) def _pk_fixture(self): m = MetaData() t = Table( - 'tname', m, - Column('id', Integer), - Column('x', String()), - Column('y', Integer), - PrimaryKeyConstraint('id', name="mypk") + "tname", + m, + Column("id", Integer), + Column("x", String()), + Column("y", Integer), + PrimaryKeyConstraint("id", name="mypk"), ) return ApplyBatchImpl(t, (), {}, False) def _literal_ck_fixture( - self, copy_from=None, table_args=(), table_kwargs={}): + self, copy_from=None, table_args=(), table_kwargs={} + ): m = MetaData() if copy_from is not None: t = copy_from else: t = Table( - 'tname', m, - Column('id', Integer, primary_key=True), - Column('email', String()), - CheckConstraint("email LIKE '%@%'") + "tname", + m, + Column("id", Integer, primary_key=True), + Column("email", String()), + CheckConstraint("email LIKE '%@%'"), ) return ApplyBatchImpl(t, table_args, table_kwargs, False) def _sql_ck_fixture(self, table_args=(), table_kwargs={}): m = MetaData() t = Table( - 'tname', m, - Column('id', Integer, primary_key=True), - Column('email', String()) + "tname", + m, + Column("id", Integer, primary_key=True), + Column("email", String()), ) - t.append_constraint(CheckConstraint(t.c.email.like('%@%'))) + t.append_constraint(CheckConstraint(t.c.email.like("%@%"))) return ApplyBatchImpl(t, table_args, table_kwargs, False) def _fk_fixture(self, table_args=(), table_kwargs={}): m = MetaData() t = Table( - 'tname', m, - Column('id', Integer, primary_key=True), - Column('email', String()), - Column('user_id', Integer, ForeignKey('user.id')) + "tname", + m, + Column("id", Integer, primary_key=True), + Column("email", String()), + Column("user_id", Integer, ForeignKey("user.id")), ) return ApplyBatchImpl(t, table_args, table_kwargs, False) @@ -110,93 +130,108 @@ class BatchApplyTest(TestBase): schemaarg = "" t = Table( - 'tname', m, - Column('id', Integer, primary_key=True), - Column('email', String()), - Column('user_id_1', Integer, ForeignKey('%suser.id' % schemaarg)), - Column('user_id_2', Integer, ForeignKey('%suser.id' % schemaarg)), - Column('user_id_3', Integer), - Column('user_id_version', Integer), + "tname", + m, + Column("id", Integer, primary_key=True), + Column("email", String()), + Column("user_id_1", Integer, ForeignKey("%suser.id" % schemaarg)), + Column("user_id_2", Integer, ForeignKey("%suser.id" % schemaarg)), + Column("user_id_3", Integer), + Column("user_id_version", Integer), ForeignKeyConstraint( - ['user_id_3', 'user_id_version'], - ['%suser.id' % schemaarg, '%suser.id_version' % schemaarg]), - schema=schema + ["user_id_3", "user_id_version"], + ["%suser.id" % schemaarg, "%suser.id_version" % schemaarg], + ), + schema=schema, ) return ApplyBatchImpl(t, table_args, table_kwargs, False) def _named_fk_fixture(self, table_args=(), table_kwargs={}): m = MetaData() t = Table( - 'tname', m, - Column('id', Integer, primary_key=True), - Column('email', String()), - Column('user_id', Integer, ForeignKey('user.id', name='ufk')) + "tname", + m, + Column("id", Integer, primary_key=True), + Column("email", String()), + Column("user_id", Integer, ForeignKey("user.id", name="ufk")), ) return ApplyBatchImpl(t, table_args, table_kwargs, False) def _selfref_fk_fixture(self, table_args=(), table_kwargs={}): m = MetaData() t = Table( - 'tname', m, - Column('id', Integer, primary_key=True), - Column('parent_id', Integer, ForeignKey('tname.id')), - Column('data', String) + "tname", + m, + Column("id", Integer, primary_key=True), + Column("parent_id", Integer, ForeignKey("tname.id")), + Column("data", String), ) return ApplyBatchImpl(t, table_args, table_kwargs, False) def _boolean_fixture(self, table_args=(), table_kwargs={}): m = MetaData() t = Table( - 'tname', m, - Column('id', Integer, primary_key=True), - Column('flag', Boolean) + "tname", + m, + Column("id", Integer, primary_key=True), + Column("flag", Boolean), ) return ApplyBatchImpl(t, table_args, table_kwargs, False) def _boolean_no_ck_fixture(self, table_args=(), table_kwargs={}): m = MetaData() t = Table( - 'tname', m, - Column('id', Integer, primary_key=True), - Column('flag', Boolean(create_constraint=False)) + "tname", + m, + Column("id", Integer, primary_key=True), + Column("flag", Boolean(create_constraint=False)), ) return ApplyBatchImpl(t, table_args, table_kwargs, False) def _enum_fixture(self, table_args=(), table_kwargs={}): m = MetaData() t = Table( - 'tname', m, - Column('id', Integer, primary_key=True), - Column('thing', Enum('a', 'b', 'c')) + "tname", + m, + Column("id", Integer, primary_key=True), + Column("thing", Enum("a", "b", "c")), ) return ApplyBatchImpl(t, table_args, table_kwargs, False) def _server_default_fixture(self, table_args=(), table_kwargs={}): m = MetaData() t = Table( - 'tname', m, - Column('id', Integer, primary_key=True), - Column('thing', String(), server_default='') + "tname", + m, + Column("id", Integer, primary_key=True), + Column("thing", String(), server_default=""), ) return ApplyBatchImpl(t, table_args, table_kwargs, False) - def _assert_impl(self, impl, colnames=None, - ddl_contains=None, ddl_not_contains=None, - dialect='default', schema=None): + def _assert_impl( + self, + impl, + colnames=None, + ddl_contains=None, + ddl_not_contains=None, + dialect="default", + schema=None, + ): context = op_fixture(dialect=dialect) impl._create(context.impl) if colnames is None: - colnames = ['id', 'x', 'y'] + colnames = ["id", "x", "y"] eq_(impl.new_table.c.keys(), colnames) pk_cols = [col for col in impl.new_table.c if col.primary_key] eq_(list(impl.new_table.primary_key), pk_cols) create_stmt = str( - CreateTable(impl.new_table).compile(dialect=context.dialect)) - create_stmt = re.sub(r'[\n\t]', '', create_stmt) + CreateTable(impl.new_table).compile(dialect=context.dialect) + ) + create_stmt = re.sub(r"[\n\t]", "", create_stmt) idx_stmt = "" for idx in impl.indexes.values(): @@ -205,17 +240,16 @@ class BatchApplyTest(TestBase): impl.new_table.name = impl.table.name idx_stmt += str(CreateIndex(idx).compile(dialect=context.dialect)) impl.new_table.name = ApplyBatchImpl._calc_temp_name( - impl.table.name) - idx_stmt = re.sub(r'[\n\t]', '', idx_stmt) + impl.table.name + ) + idx_stmt = re.sub(r"[\n\t]", "", idx_stmt) if ddl_contains: assert ddl_contains in create_stmt + idx_stmt if ddl_not_contains: assert ddl_not_contains not in create_stmt + idx_stmt - expected = [ - create_stmt, - ] + expected = [create_stmt] if schema: args = {"schema": "%s." % schema} @@ -224,32 +258,40 @@ class BatchApplyTest(TestBase): args["temp_name"] = impl.new_table.name - args['colnames'] = ", ".join([ - impl.new_table.c[name].name - for name in colnames - if name in impl.table.c]) + args["colnames"] = ", ".join( + [ + impl.new_table.c[name].name + for name in colnames + if name in impl.table.c + ] + ) - args['tname_colnames'] = ", ".join( - "CAST(%(schema)stname.%(name)s AS %(type)s) AS anon_1" % { - 'schema': args['schema'], - 'name': name, - 'type': impl.new_table.c[name].type + args["tname_colnames"] = ", ".join( + "CAST(%(schema)stname.%(name)s AS %(type)s) AS anon_1" + % { + "schema": args["schema"], + "name": name, + "type": impl.new_table.c[name].type, } if ( impl.new_table.c[name].type._type_affinity - is not impl.table.c[name].type._type_affinity) - else "%(schema)stname.%(name)s" % { - 'schema': args['schema'], 'name': name} - for name in colnames if name in impl.table.c - ) - - expected.extend([ - 'INSERT INTO %(schema)s%(temp_name)s (%(colnames)s) ' - 'SELECT %(tname_colnames)s FROM %(schema)stname' % args, - 'DROP TABLE %(schema)stname' % args, - 'ALTER TABLE %(schema)s%(temp_name)s ' - 'RENAME TO %(schema)stname' % args - ]) + is not impl.table.c[name].type._type_affinity + ) + else "%(schema)stname.%(name)s" + % {"schema": args["schema"], "name": name} + for name in colnames + if name in impl.table.c + ) + + expected.extend( + [ + "INSERT INTO %(schema)s%(temp_name)s (%(colnames)s) " + "SELECT %(tname_colnames)s FROM %(schema)stname" % args, + "DROP TABLE %(schema)stname" % args, + "ALTER TABLE %(schema)s%(temp_name)s " + "RENAME TO %(schema)stname" % args, + ] + ) if idx_stmt: expected.append(idx_stmt) context.assert_(*expected) @@ -257,36 +299,42 @@ class BatchApplyTest(TestBase): def test_change_type(self): impl = self._simple_fixture() - impl.alter_column('tname', 'x', type_=String) + impl.alter_column("tname", "x", type_=String) new_table = self._assert_impl(impl) assert new_table.c.x.type._type_affinity is String def test_rename_col(self): impl = self._simple_fixture() - impl.alter_column('tname', 'x', name='q') + impl.alter_column("tname", "x", name="q") new_table = self._assert_impl(impl) - eq_(new_table.c.x.name, 'q') + eq_(new_table.c.x.name, "q") def test_rename_col_boolean(self): impl = self._boolean_fixture() - impl.alter_column('tname', 'flag', name='bflag') + impl.alter_column("tname", "flag", name="bflag") new_table = self._assert_impl( - impl, ddl_contains="CHECK (bflag IN (0, 1)", - colnames=["id", "flag"]) - eq_(new_table.c.flag.name, 'bflag') + impl, + ddl_contains="CHECK (bflag IN (0, 1)", + colnames=["id", "flag"], + ) + eq_(new_table.c.flag.name, "bflag") eq_( - len([ - const for const - in new_table.constraints - if isinstance(const, CheckConstraint)]), - 1) + len( + [ + const + for const in new_table.constraints + if isinstance(const, CheckConstraint) + ] + ), + 1, + ) def test_change_type_schematype_to_non(self): impl = self._boolean_fixture() - impl.alter_column('tname', 'flag', type_=Integer) + impl.alter_column("tname", "flag", type_=Integer) new_table = self._assert_impl( - impl, colnames=['id', 'flag'], - ddl_not_contains="CHECK") + impl, colnames=["id", "flag"], ddl_not_contains="CHECK" + ) assert new_table.c.flag.type._type_affinity is Integer # NOTE: we can't do test_change_type_non_to_schematype @@ -295,254 +343,310 @@ class BatchApplyTest(TestBase): def test_rename_col_boolean_no_ck(self): impl = self._boolean_no_ck_fixture() - impl.alter_column('tname', 'flag', name='bflag') + impl.alter_column("tname", "flag", name="bflag") new_table = self._assert_impl( - impl, ddl_not_contains="CHECK", - colnames=["id", "flag"]) - eq_(new_table.c.flag.name, 'bflag') + impl, ddl_not_contains="CHECK", colnames=["id", "flag"] + ) + eq_(new_table.c.flag.name, "bflag") eq_( - len([ - const for const - in new_table.constraints - if isinstance(const, CheckConstraint)]), - 0) + len( + [ + const + for const in new_table.constraints + if isinstance(const, CheckConstraint) + ] + ), + 0, + ) def test_rename_col_enum(self): impl = self._enum_fixture() - impl.alter_column('tname', 'thing', name='thang') + impl.alter_column("tname", "thing", name="thang") new_table = self._assert_impl( - impl, ddl_contains="CHECK (thang IN ('a', 'b', 'c')", - colnames=["id", "thing"]) - eq_(new_table.c.thing.name, 'thang') + impl, + ddl_contains="CHECK (thang IN ('a', 'b', 'c')", + colnames=["id", "thing"], + ) + eq_(new_table.c.thing.name, "thang") eq_( - len([ - const for const - in new_table.constraints - if isinstance(const, CheckConstraint)]), - 1) + len( + [ + const + for const in new_table.constraints + if isinstance(const, CheckConstraint) + ] + ), + 1, + ) def test_rename_col_literal_ck(self): impl = self._literal_ck_fixture() - impl.alter_column('tname', 'email', name='emol') + impl.alter_column("tname", "email", name="emol") new_table = self._assert_impl( # note this is wrong, we don't dig into the SQL - impl, ddl_contains="CHECK (email LIKE '%@%')", - colnames=["id", "email"]) + impl, + ddl_contains="CHECK (email LIKE '%@%')", + colnames=["id", "email"], + ) eq_( - len([c for c in new_table.constraints - if isinstance(c, CheckConstraint)]), 1) + len( + [ + c + for c in new_table.constraints + if isinstance(c, CheckConstraint) + ] + ), + 1, + ) - eq_(new_table.c.email.name, 'emol') + eq_(new_table.c.email.name, "emol") def test_rename_col_literal_ck_workaround(self): impl = self._literal_ck_fixture( copy_from=Table( - 'tname', MetaData(), - Column('id', Integer, primary_key=True), - Column('email', String), + "tname", + MetaData(), + Column("id", Integer, primary_key=True), + Column("email", String), ), - table_args=[CheckConstraint("emol LIKE '%@%'")]) + table_args=[CheckConstraint("emol LIKE '%@%'")], + ) - impl.alter_column('tname', 'email', name='emol') + impl.alter_column("tname", "email", name="emol") new_table = self._assert_impl( - impl, ddl_contains="CHECK (emol LIKE '%@%')", - colnames=["id", "email"]) + impl, + ddl_contains="CHECK (emol LIKE '%@%')", + colnames=["id", "email"], + ) eq_( - len([c for c in new_table.constraints - if isinstance(c, CheckConstraint)]), 1) - eq_(new_table.c.email.name, 'emol') + len( + [ + c + for c in new_table.constraints + if isinstance(c, CheckConstraint) + ] + ), + 1, + ) + eq_(new_table.c.email.name, "emol") def test_rename_col_sql_ck(self): impl = self._sql_ck_fixture() - impl.alter_column('tname', 'email', name='emol') + impl.alter_column("tname", "email", name="emol") new_table = self._assert_impl( - impl, ddl_contains="CHECK (emol LIKE '%@%')", - colnames=["id", "email"]) + impl, + ddl_contains="CHECK (emol LIKE '%@%')", + colnames=["id", "email"], + ) eq_( - len([c for c in new_table.constraints - if isinstance(c, CheckConstraint)]), 1) + len( + [ + c + for c in new_table.constraints + if isinstance(c, CheckConstraint) + ] + ), + 1, + ) - eq_(new_table.c.email.name, 'emol') + eq_(new_table.c.email.name, "emol") def test_add_col(self): impl = self._simple_fixture() - col = Column('g', Integer) + col = Column("g", Integer) # operations.add_column produces a table - t = self.op.schema_obj.table('tname', col) # noqa - impl.add_column('tname', col) - new_table = self._assert_impl(impl, colnames=['id', 'x', 'y', 'g']) - eq_(new_table.c.g.name, 'g') + t = self.op.schema_obj.table("tname", col) # noqa + impl.add_column("tname", col) + new_table = self._assert_impl(impl, colnames=["id", "x", "y", "g"]) + eq_(new_table.c.g.name, "g") def test_add_server_default(self): impl = self._simple_fixture() - impl.alter_column('tname', 'y', server_default="10") - new_table = self._assert_impl( - impl, ddl_contains="DEFAULT '10'") - eq_( - new_table.c.y.server_default.arg, "10" - ) + impl.alter_column("tname", "y", server_default="10") + new_table = self._assert_impl(impl, ddl_contains="DEFAULT '10'") + eq_(new_table.c.y.server_default.arg, "10") def test_drop_server_default(self): impl = self._server_default_fixture() - impl.alter_column('tname', 'thing', server_default=None) + impl.alter_column("tname", "thing", server_default=None) new_table = self._assert_impl( - impl, colnames=['id', 'thing'], ddl_not_contains="DEFAULT") + impl, colnames=["id", "thing"], ddl_not_contains="DEFAULT" + ) eq_(new_table.c.thing.server_default, None) def test_rename_col_pk(self): impl = self._simple_fixture() - impl.alter_column('tname', 'id', name='foobar') + impl.alter_column("tname", "id", name="foobar") new_table = self._assert_impl( - impl, ddl_contains="PRIMARY KEY (foobar)") - eq_(new_table.c.id.name, 'foobar') + impl, ddl_contains="PRIMARY KEY (foobar)" + ) + eq_(new_table.c.id.name, "foobar") eq_(list(new_table.primary_key), [new_table.c.id]) def test_rename_col_fk(self): impl = self._fk_fixture() - impl.alter_column('tname', 'user_id', name='foobar') + impl.alter_column("tname", "user_id", name="foobar") new_table = self._assert_impl( - impl, colnames=['id', 'email', 'user_id'], - ddl_contains='FOREIGN KEY(foobar) REFERENCES "user" (id)') - eq_(new_table.c.user_id.name, 'foobar') + impl, + colnames=["id", "email", "user_id"], + ddl_contains='FOREIGN KEY(foobar) REFERENCES "user" (id)', + ) + eq_(new_table.c.user_id.name, "foobar") eq_( - list(new_table.c.user_id.foreign_keys)[0]._get_colspec(), - "user.id" + list(new_table.c.user_id.foreign_keys)[0]._get_colspec(), "user.id" ) def test_regen_multi_fk(self): impl = self._multi_fk_fixture() self._assert_impl( - impl, colnames=[ - 'id', 'email', 'user_id_1', 'user_id_2', - 'user_id_3', 'user_id_version'], - ddl_contains='FOREIGN KEY(user_id_3, user_id_version) ' - 'REFERENCES "user" (id, id_version)') + impl, + colnames=[ + "id", + "email", + "user_id_1", + "user_id_2", + "user_id_3", + "user_id_version", + ], + ddl_contains="FOREIGN KEY(user_id_3, user_id_version) " + 'REFERENCES "user" (id, id_version)', + ) def test_regen_multi_fk_schema(self): - impl = self._multi_fk_fixture(schema='foo_schema') + impl = self._multi_fk_fixture(schema="foo_schema") self._assert_impl( - impl, colnames=[ - 'id', 'email', 'user_id_1', 'user_id_2', - 'user_id_3', 'user_id_version'], - ddl_contains='FOREIGN KEY(user_id_3, user_id_version) ' + impl, + colnames=[ + "id", + "email", + "user_id_1", + "user_id_2", + "user_id_3", + "user_id_version", + ], + ddl_contains="FOREIGN KEY(user_id_3, user_id_version) " 'REFERENCES foo_schema."user" (id, id_version)', - schema='foo_schema') + schema="foo_schema", + ) def test_drop_col(self): impl = self._simple_fixture() - impl.drop_column('tname', column('x')) - new_table = self._assert_impl(impl, colnames=['id', 'y']) - assert 'y' in new_table.c - assert 'x' not in new_table.c + impl.drop_column("tname", column("x")) + new_table = self._assert_impl(impl, colnames=["id", "y"]) + assert "y" in new_table.c + assert "x" not in new_table.c def test_drop_col_remove_pk(self): impl = self._simple_fixture() - impl.drop_column('tname', column('id')) + impl.drop_column("tname", column("id")) new_table = self._assert_impl( - impl, colnames=['x', 'y'], ddl_not_contains="PRIMARY KEY") - assert 'y' in new_table.c - assert 'id' not in new_table.c + impl, colnames=["x", "y"], ddl_not_contains="PRIMARY KEY" + ) + assert "y" in new_table.c + assert "id" not in new_table.c assert not new_table.primary_key def test_drop_col_remove_fk(self): impl = self._fk_fixture() - impl.drop_column('tname', column('user_id')) + impl.drop_column("tname", column("user_id")) new_table = self._assert_impl( - impl, colnames=['id', 'email'], ddl_not_contains="FOREIGN KEY") - assert 'user_id' not in new_table.c + impl, colnames=["id", "email"], ddl_not_contains="FOREIGN KEY" + ) + assert "user_id" not in new_table.c assert not new_table.foreign_keys def test_drop_col_retain_fk(self): impl = self._fk_fixture() - impl.drop_column('tname', column('email')) + impl.drop_column("tname", column("email")) new_table = self._assert_impl( - impl, colnames=['id', 'user_id'], - ddl_contains='FOREIGN KEY(user_id) REFERENCES "user" (id)') - assert 'email' not in new_table.c + impl, + colnames=["id", "user_id"], + ddl_contains='FOREIGN KEY(user_id) REFERENCES "user" (id)', + ) + assert "email" not in new_table.c assert new_table.c.user_id.foreign_keys def test_drop_col_retain_fk_selfref(self): impl = self._selfref_fk_fixture() - impl.drop_column('tname', column('data')) - new_table = self._assert_impl(impl, colnames=['id', 'parent_id']) - assert 'data' not in new_table.c + impl.drop_column("tname", column("data")) + new_table = self._assert_impl(impl, colnames=["id", "parent_id"]) + assert "data" not in new_table.c assert new_table.c.parent_id.foreign_keys def test_add_fk(self): impl = self._simple_fixture() - impl.add_column('tname', Column('user_id', Integer)) + impl.add_column("tname", Column("user_id", Integer)) fk = self.op.schema_obj.foreign_key_constraint( - 'fk1', 'tname', 'user', - ['user_id'], ['id']) + "fk1", "tname", "user", ["user_id"], ["id"] + ) impl.add_constraint(fk) new_table = self._assert_impl( - impl, colnames=['id', 'x', 'y', 'user_id'], - ddl_contains='CONSTRAINT fk1 FOREIGN KEY(user_id) ' - 'REFERENCES "user" (id)') + impl, + colnames=["id", "x", "y", "user_id"], + ddl_contains="CONSTRAINT fk1 FOREIGN KEY(user_id) " + 'REFERENCES "user" (id)', + ) eq_( - list(new_table.c.user_id.foreign_keys)[0]._get_colspec(), - 'user.id' + list(new_table.c.user_id.foreign_keys)[0]._get_colspec(), "user.id" ) def test_drop_fk(self): impl = self._named_fk_fixture() - fk = ForeignKeyConstraint([], [], name='ufk') + fk = ForeignKeyConstraint([], [], name="ufk") impl.drop_constraint(fk) new_table = self._assert_impl( - impl, colnames=['id', 'email', 'user_id'], - ddl_not_contains="CONSTRANT fk1") - eq_( - list(new_table.foreign_keys), - [] + impl, + colnames=["id", "email", "user_id"], + ddl_not_contains="CONSTRANT fk1", ) + eq_(list(new_table.foreign_keys), []) def test_add_uq(self): impl = self._simple_fixture() - uq = self.op.schema_obj.unique_constraint( - 'uq1', 'tname', ['y'] - ) + uq = self.op.schema_obj.unique_constraint("uq1", "tname", ["y"]) impl.add_constraint(uq) self._assert_impl( - impl, colnames=['id', 'x', 'y'], - ddl_contains="CONSTRAINT uq1 UNIQUE") + impl, + colnames=["id", "x", "y"], + ddl_contains="CONSTRAINT uq1 UNIQUE", + ) def test_drop_uq(self): impl = self._uq_fixture() - uq = self.op.schema_obj.unique_constraint( - 'uq1', 'tname', ['y'] - ) + uq = self.op.schema_obj.unique_constraint("uq1", "tname", ["y"]) impl.drop_constraint(uq) self._assert_impl( - impl, colnames=['id', 'x', 'y'], - ddl_not_contains="CONSTRAINT uq1 UNIQUE") + impl, + colnames=["id", "x", "y"], + ddl_not_contains="CONSTRAINT uq1 UNIQUE", + ) def test_create_index(self): impl = self._simple_fixture() - ix = self.op.schema_obj.index('ix1', 'tname', ['y']) + ix = self.op.schema_obj.index("ix1", "tname", ["y"]) impl.create_index(ix) self._assert_impl( - impl, colnames=['id', 'x', 'y'], - ddl_contains="CREATE INDEX ix1") + impl, colnames=["id", "x", "y"], ddl_contains="CREATE INDEX ix1" + ) def test_drop_index(self): impl = self._ix_fixture() - ix = self.op.schema_obj.index('ix1', 'tname', ['y']) + ix = self.op.schema_obj.index("ix1", "tname", ["y"]) impl.drop_index(ix) self._assert_impl( - impl, colnames=['id', 'x', 'y'], - ddl_not_contains="CONSTRAINT uq1 UNIQUE") + impl, + colnames=["id", "x", "y"], + ddl_not_contains="CONSTRAINT uq1 UNIQUE", + ) def test_add_table_opts(self): - impl = self._simple_fixture(table_kwargs={'mysql_engine': 'InnoDB'}) - self._assert_impl( - impl, ddl_contains="ENGINE=InnoDB", - dialect='mysql' - ) + impl = self._simple_fixture(table_kwargs={"mysql_engine": "InnoDB"}) + self._assert_impl(impl, ddl_contains="ENGINE=InnoDB", dialect="mysql") def test_drop_pk(self): impl = self._pk_fixture() @@ -554,14 +658,15 @@ class BatchApplyTest(TestBase): class BatchAPITest(TestBase): - @contextmanager def _fixture(self, schema=None): migration_context = mock.Mock( - opts={}, impl=mock.MagicMock(__dialect__='sqlite')) + opts={}, impl=mock.MagicMock(__dialect__="sqlite") + ) op = Operations(migration_context) batch = op.batch_alter_table( - 'tname', recreate='never', schema=schema).__enter__() + "tname", recreate="never", schema=schema + ).__enter__() mock_schema = mock.MagicMock() with mock.patch("alembic.operations.schemaobj.sa_schema", mock_schema): @@ -571,105 +676,131 @@ class BatchAPITest(TestBase): def test_drop_col(self): with self._fixture() as batch: - batch.drop_column('q') + batch.drop_column("q") eq_( batch.impl.operations.impl.mock_calls, - [mock.call.drop_column( - 'tname', self.mock_schema.Column(), schema=None)] + [ + mock.call.drop_column( + "tname", self.mock_schema.Column(), schema=None + ) + ], ) def test_add_col(self): - column = Column('w', String(50)) + column = Column("w", String(50)) with self._fixture() as batch: batch.add_column(column) eq_( batch.impl.operations.impl.mock_calls, - [mock.call.add_column( - 'tname', column, schema=None)] + [mock.call.add_column("tname", column, schema=None)], ) def test_create_fk(self): with self._fixture() as batch: - batch.create_foreign_key('myfk', 'user', ['x'], ['y']) + batch.create_foreign_key("myfk", "user", ["x"], ["y"]) eq_( self.mock_schema.ForeignKeyConstraint.mock_calls, [ mock.call( - ['x'], ['user.y'], - onupdate=None, ondelete=None, name='myfk', - initially=None, deferrable=None, match=None) - ] + ["x"], + ["user.y"], + onupdate=None, + ondelete=None, + name="myfk", + initially=None, + deferrable=None, + match=None, + ) + ], ) eq_( self.mock_schema.Table.mock_calls, [ mock.call( - 'user', self.mock_schema.MetaData(), + "user", + self.mock_schema.MetaData(), self.mock_schema.Column(), - schema=None + schema=None, ), mock.call( - 'tname', self.mock_schema.MetaData(), + "tname", + self.mock_schema.MetaData(), self.mock_schema.Column(), - schema=None + schema=None, ), mock.call().append_constraint( - self.mock_schema.ForeignKeyConstraint()) - ] + self.mock_schema.ForeignKeyConstraint() + ), + ], ) eq_( batch.impl.operations.impl.mock_calls, - [mock.call.add_constraint( - self.mock_schema.ForeignKeyConstraint())] + [ + mock.call.add_constraint( + self.mock_schema.ForeignKeyConstraint() + ) + ], ) def test_create_fk_schema(self): - with self._fixture(schema='foo') as batch: - batch.create_foreign_key('myfk', 'user', ['x'], ['y']) + with self._fixture(schema="foo") as batch: + batch.create_foreign_key("myfk", "user", ["x"], ["y"]) eq_( self.mock_schema.ForeignKeyConstraint.mock_calls, [ mock.call( - ['x'], ['user.y'], - onupdate=None, ondelete=None, name='myfk', - initially=None, deferrable=None, match=None) - ] + ["x"], + ["user.y"], + onupdate=None, + ondelete=None, + name="myfk", + initially=None, + deferrable=None, + match=None, + ) + ], ) eq_( self.mock_schema.Table.mock_calls, [ mock.call( - 'user', self.mock_schema.MetaData(), + "user", + self.mock_schema.MetaData(), self.mock_schema.Column(), - schema=None + schema=None, ), mock.call( - 'tname', self.mock_schema.MetaData(), + "tname", + self.mock_schema.MetaData(), self.mock_schema.Column(), - schema='foo' + schema="foo", ), mock.call().append_constraint( - self.mock_schema.ForeignKeyConstraint()) - ] + self.mock_schema.ForeignKeyConstraint() + ), + ], ) eq_( batch.impl.operations.impl.mock_calls, - [mock.call.add_constraint( - self.mock_schema.ForeignKeyConstraint())] + [ + mock.call.add_constraint( + self.mock_schema.ForeignKeyConstraint() + ) + ], ) def test_create_uq(self): with self._fixture() as batch: - batch.create_unique_constraint('uq1', ['a', 'b']) + batch.create_unique_constraint("uq1", ["a", "b"]) eq_( self.mock_schema.Table().c.__getitem__.mock_calls, - [mock.call('a'), mock.call('b')] + [mock.call("a"), mock.call("b")], ) eq_( @@ -678,23 +809,22 @@ class BatchAPITest(TestBase): mock.call( self.mock_schema.Table().c.__getitem__(), self.mock_schema.Table().c.__getitem__(), - name='uq1' + name="uq1", ) - ] + ], ) eq_( batch.impl.operations.impl.mock_calls, - [mock.call.add_constraint( - self.mock_schema.UniqueConstraint())] + [mock.call.add_constraint(self.mock_schema.UniqueConstraint())], ) def test_create_pk(self): with self._fixture() as batch: - batch.create_primary_key('pk1', ['a', 'b']) + batch.create_primary_key("pk1", ["a", "b"]) eq_( self.mock_schema.Table().c.__getitem__.mock_calls, - [mock.call('a'), mock.call('b')] + [mock.call("a"), mock.call("b")], ) eq_( @@ -703,60 +833,53 @@ class BatchAPITest(TestBase): mock.call( self.mock_schema.Table().c.__getitem__(), self.mock_schema.Table().c.__getitem__(), - name='pk1' + name="pk1", ) - ] + ], ) eq_( batch.impl.operations.impl.mock_calls, - [mock.call.add_constraint( - self.mock_schema.PrimaryKeyConstraint())] + [ + mock.call.add_constraint( + self.mock_schema.PrimaryKeyConstraint() + ) + ], ) def test_create_check(self): expr = text("a > b") with self._fixture() as batch: - batch.create_check_constraint('ck1', expr) + batch.create_check_constraint("ck1", expr) eq_( self.mock_schema.CheckConstraint.mock_calls, - [ - mock.call( - expr, name="ck1" - ) - ] + [mock.call(expr, name="ck1")], ) eq_( batch.impl.operations.impl.mock_calls, - [mock.call.add_constraint( - self.mock_schema.CheckConstraint())] + [mock.call.add_constraint(self.mock_schema.CheckConstraint())], ) def test_drop_constraint(self): with self._fixture() as batch: - batch.drop_constraint('uq1') + batch.drop_constraint("uq1") - eq_( - self.mock_schema.Constraint.mock_calls, - [ - mock.call(name='uq1') - ] - ) + eq_(self.mock_schema.Constraint.mock_calls, [mock.call(name="uq1")]) eq_( batch.impl.operations.impl.mock_calls, - [mock.call.drop_constraint(self.mock_schema.Constraint())] + [mock.call.drop_constraint(self.mock_schema.Constraint())], ) class CopyFromTest(TestBase): - def _fixture(self): self.metadata = MetaData() self.table = Table( - 'foo', self.metadata, - Column('id', Integer, primary_key=True), - Column('data', String(50)), - Column('x', Integer), + "foo", + self.metadata, + Column("id", Integer, primary_key=True), + Column("data", String(50)), + Column("x", Integer), ) context = op_fixture(dialect="sqlite", as_sql=True) @@ -766,148 +889,151 @@ class CopyFromTest(TestBase): def test_change_type(self): context = self._fixture() with self.op.batch_alter_table( - "foo", copy_from=self.table) as batch_op: - batch_op.alter_column('data', type_=Integer) + "foo", copy_from=self.table + ) as batch_op: + batch_op.alter_column("data", type_=Integer) context.assert_( - 'CREATE TABLE _alembic_tmp_foo (id INTEGER NOT NULL, ' - 'data INTEGER, x INTEGER, PRIMARY KEY (id))', - 'INSERT INTO _alembic_tmp_foo (id, data, x) SELECT foo.id, ' - 'CAST(foo.data AS INTEGER) AS anon_1, foo.x FROM foo', - 'DROP TABLE foo', - 'ALTER TABLE _alembic_tmp_foo RENAME TO foo' + "CREATE TABLE _alembic_tmp_foo (id INTEGER NOT NULL, " + "data INTEGER, x INTEGER, PRIMARY KEY (id))", + "INSERT INTO _alembic_tmp_foo (id, data, x) SELECT foo.id, " + "CAST(foo.data AS INTEGER) AS anon_1, foo.x FROM foo", + "DROP TABLE foo", + "ALTER TABLE _alembic_tmp_foo RENAME TO foo", ) def test_change_type_from_schematype(self): context = self._fixture() self.table.append_column( - Column('y', Boolean( - create_constraint=True, name="ck1"))) + Column("y", Boolean(create_constraint=True, name="ck1")) + ) with self.op.batch_alter_table( - "foo", copy_from=self.table) as batch_op: + "foo", copy_from=self.table + ) as batch_op: batch_op.alter_column( - 'y', type_=Integer, - existing_type=Boolean( - create_constraint=True, name="ck1")) + "y", + type_=Integer, + existing_type=Boolean(create_constraint=True, name="ck1"), + ) context.assert_( - 'CREATE TABLE _alembic_tmp_foo (id INTEGER NOT NULL, ' - 'data VARCHAR(50), x INTEGER, y INTEGER, PRIMARY KEY (id))', - 'INSERT INTO _alembic_tmp_foo (id, data, x, y) SELECT foo.id, ' - 'foo.data, foo.x, CAST(foo.y AS INTEGER) AS anon_1 FROM foo', - 'DROP TABLE foo', - 'ALTER TABLE _alembic_tmp_foo RENAME TO foo' + "CREATE TABLE _alembic_tmp_foo (id INTEGER NOT NULL, " + "data VARCHAR(50), x INTEGER, y INTEGER, PRIMARY KEY (id))", + "INSERT INTO _alembic_tmp_foo (id, data, x, y) SELECT foo.id, " + "foo.data, foo.x, CAST(foo.y AS INTEGER) AS anon_1 FROM foo", + "DROP TABLE foo", + "ALTER TABLE _alembic_tmp_foo RENAME TO foo", ) def test_change_type_to_schematype(self): context = self._fixture() - self.table.append_column( - Column('y', Integer)) + self.table.append_column(Column("y", Integer)) with self.op.batch_alter_table( - "foo", copy_from=self.table) as batch_op: + "foo", copy_from=self.table + ) as batch_op: batch_op.alter_column( - 'y', existing_type=Integer, - type_=Boolean( - create_constraint=True, name="ck1")) + "y", + existing_type=Integer, + type_=Boolean(create_constraint=True, name="ck1"), + ) context.assert_( - 'CREATE TABLE _alembic_tmp_foo (id INTEGER NOT NULL, ' - 'data VARCHAR(50), x INTEGER, y BOOLEAN, PRIMARY KEY (id), ' - 'CONSTRAINT ck1 CHECK (y IN (0, 1)))', - 'INSERT INTO _alembic_tmp_foo (id, data, x, y) SELECT foo.id, ' - 'foo.data, foo.x, CAST(foo.y AS BOOLEAN) AS anon_1 FROM foo', - 'DROP TABLE foo', - 'ALTER TABLE _alembic_tmp_foo RENAME TO foo' + "CREATE TABLE _alembic_tmp_foo (id INTEGER NOT NULL, " + "data VARCHAR(50), x INTEGER, y BOOLEAN, PRIMARY KEY (id), " + "CONSTRAINT ck1 CHECK (y IN (0, 1)))", + "INSERT INTO _alembic_tmp_foo (id, data, x, y) SELECT foo.id, " + "foo.data, foo.x, CAST(foo.y AS BOOLEAN) AS anon_1 FROM foo", + "DROP TABLE foo", + "ALTER TABLE _alembic_tmp_foo RENAME TO foo", ) def test_create_drop_index_w_always(self): context = self._fixture() with self.op.batch_alter_table( - "foo", copy_from=self.table, recreate='always') as batch_op: - batch_op.create_index( - 'ix_data', ['data'], unique=True) + "foo", copy_from=self.table, recreate="always" + ) as batch_op: + batch_op.create_index("ix_data", ["data"], unique=True) context.assert_( - 'CREATE TABLE _alembic_tmp_foo (id INTEGER NOT NULL, ' - 'data VARCHAR(50), ' - 'x INTEGER, PRIMARY KEY (id))', - 'INSERT INTO _alembic_tmp_foo (id, data, x) ' - 'SELECT foo.id, foo.data, foo.x FROM foo', - 'DROP TABLE foo', - 'ALTER TABLE _alembic_tmp_foo RENAME TO foo', - 'CREATE UNIQUE INDEX ix_data ON foo (data)', + "CREATE TABLE _alembic_tmp_foo (id INTEGER NOT NULL, " + "data VARCHAR(50), " + "x INTEGER, PRIMARY KEY (id))", + "INSERT INTO _alembic_tmp_foo (id, data, x) " + "SELECT foo.id, foo.data, foo.x FROM foo", + "DROP TABLE foo", + "ALTER TABLE _alembic_tmp_foo RENAME TO foo", + "CREATE UNIQUE INDEX ix_data ON foo (data)", ) context.clear_assertions() - Index('ix_data', self.table.c.data, unique=True) + Index("ix_data", self.table.c.data, unique=True) with self.op.batch_alter_table( - "foo", copy_from=self.table, recreate='always') as batch_op: - batch_op.drop_index('ix_data') + "foo", copy_from=self.table, recreate="always" + ) as batch_op: + batch_op.drop_index("ix_data") context.assert_( - 'CREATE TABLE _alembic_tmp_foo (id INTEGER NOT NULL, ' - 'data VARCHAR(50), x INTEGER, PRIMARY KEY (id))', - 'INSERT INTO _alembic_tmp_foo (id, data, x) ' - 'SELECT foo.id, foo.data, foo.x FROM foo', - 'DROP TABLE foo', - 'ALTER TABLE _alembic_tmp_foo RENAME TO foo' + "CREATE TABLE _alembic_tmp_foo (id INTEGER NOT NULL, " + "data VARCHAR(50), x INTEGER, PRIMARY KEY (id))", + "INSERT INTO _alembic_tmp_foo (id, data, x) " + "SELECT foo.id, foo.data, foo.x FROM foo", + "DROP TABLE foo", + "ALTER TABLE _alembic_tmp_foo RENAME TO foo", ) def test_create_drop_index_wo_always(self): context = self._fixture() with self.op.batch_alter_table( - "foo", copy_from=self.table) as batch_op: - batch_op.create_index( - 'ix_data', ['data'], unique=True) + "foo", copy_from=self.table + ) as batch_op: + batch_op.create_index("ix_data", ["data"], unique=True) - context.assert_( - 'CREATE UNIQUE INDEX ix_data ON foo (data)' - ) + context.assert_("CREATE UNIQUE INDEX ix_data ON foo (data)") context.clear_assertions() - Index('ix_data', self.table.c.data, unique=True) + Index("ix_data", self.table.c.data, unique=True) with self.op.batch_alter_table( - "foo", copy_from=self.table) as batch_op: - batch_op.drop_index('ix_data') + "foo", copy_from=self.table + ) as batch_op: + batch_op.drop_index("ix_data") - context.assert_( - 'DROP INDEX ix_data' - ) + context.assert_("DROP INDEX ix_data") def test_create_drop_index_w_other_ops(self): context = self._fixture() with self.op.batch_alter_table( - "foo", copy_from=self.table) as batch_op: - batch_op.alter_column('data', type_=Integer) - batch_op.create_index( - 'ix_data', ['data'], unique=True) + "foo", copy_from=self.table + ) as batch_op: + batch_op.alter_column("data", type_=Integer) + batch_op.create_index("ix_data", ["data"], unique=True) context.assert_( - 'CREATE TABLE _alembic_tmp_foo (id INTEGER NOT NULL, ' - 'data INTEGER, x INTEGER, PRIMARY KEY (id))', - 'INSERT INTO _alembic_tmp_foo (id, data, x) SELECT foo.id, ' - 'CAST(foo.data AS INTEGER) AS anon_1, foo.x FROM foo', - 'DROP TABLE foo', - 'ALTER TABLE _alembic_tmp_foo RENAME TO foo', - 'CREATE UNIQUE INDEX ix_data ON foo (data)', + "CREATE TABLE _alembic_tmp_foo (id INTEGER NOT NULL, " + "data INTEGER, x INTEGER, PRIMARY KEY (id))", + "INSERT INTO _alembic_tmp_foo (id, data, x) SELECT foo.id, " + "CAST(foo.data AS INTEGER) AS anon_1, foo.x FROM foo", + "DROP TABLE foo", + "ALTER TABLE _alembic_tmp_foo RENAME TO foo", + "CREATE UNIQUE INDEX ix_data ON foo (data)", ) context.clear_assertions() - Index('ix_data', self.table.c.data, unique=True) + Index("ix_data", self.table.c.data, unique=True) with self.op.batch_alter_table( - "foo", copy_from=self.table) as batch_op: - batch_op.drop_index('ix_data') - batch_op.alter_column('data', type_=String) + "foo", copy_from=self.table + ) as batch_op: + batch_op.drop_index("ix_data") + batch_op.alter_column("data", type_=String) context.assert_( - 'CREATE TABLE _alembic_tmp_foo (id INTEGER NOT NULL, ' - 'data VARCHAR, x INTEGER, PRIMARY KEY (id))', - 'INSERT INTO _alembic_tmp_foo (id, data, x) SELECT foo.id, ' - 'foo.data, foo.x FROM foo', - 'DROP TABLE foo', - 'ALTER TABLE _alembic_tmp_foo RENAME TO foo' + "CREATE TABLE _alembic_tmp_foo (id INTEGER NOT NULL, " + "data VARCHAR, x INTEGER, PRIMARY KEY (id))", + "INSERT INTO _alembic_tmp_foo (id, data, x) SELECT foo.id, " + "foo.data, foo.x FROM foo", + "DROP TABLE foo", + "ALTER TABLE _alembic_tmp_foo RENAME TO foo", ) @@ -918,11 +1044,12 @@ class BatchRoundTripTest(TestBase): self.conn = config.db.connect() self.metadata = MetaData() t1 = Table( - 'foo', self.metadata, - Column('id', Integer, primary_key=True), - Column('data', String(50)), - Column('x', Integer), - mysql_engine='InnoDB' + "foo", + self.metadata, + Column("id", Integer, primary_key=True), + Column("data", String(50)), + Column("x", Integer), + mysql_engine="InnoDB", ) t1.create(self.conn) @@ -933,8 +1060,8 @@ class BatchRoundTripTest(TestBase): {"id": 2, "data": "22", "x": 6}, {"id": 3, "data": "8.5", "x": 7}, {"id": 4, "data": "9.46", "x": 8}, - {"id": 5, "data": "d5", "x": 9} - ] + {"id": 5, "data": "d5", "x": 9}, + ], ) context = MigrationContext.configure(self.conn) self.op = Operations(context) @@ -949,80 +1076,75 @@ class BatchRoundTripTest(TestBase): def _no_pk_fixture(self): nopk = Table( - 'nopk', self.metadata, - Column('a', Integer), - Column('b', Integer), - Column('c', Integer), - mysql_engine='InnoDB' + "nopk", + self.metadata, + Column("a", Integer), + Column("b", Integer), + Column("c", Integer), + mysql_engine="InnoDB", ) nopk.create(self.conn) self.conn.execute( - nopk.insert(), - [ - {"a": 1, "b": 2, "c": 3}, - {"a": 2, "b": 4, "c": 5}, - ] - + nopk.insert(), [{"a": 1, "b": 2, "c": 3}, {"a": 2, "b": 4, "c": 5}] ) return nopk def _table_w_index_fixture(self): t = Table( - 't_w_ix', self.metadata, - Column('id', Integer, primary_key=True), - Column('thing', Integer), - Column('data', String(20)), + "t_w_ix", + self.metadata, + Column("id", Integer, primary_key=True), + Column("thing", Integer), + Column("data", String(20)), ) - Index('ix_thing', t.c.thing) + Index("ix_thing", t.c.thing) t.create(self.conn) return t def _boolean_fixture(self): t = Table( - 'hasbool', self.metadata, - Column('x', Boolean(create_constraint=True, name='ck1')), - Column('y', Integer) + "hasbool", + self.metadata, + Column("x", Boolean(create_constraint=True, name="ck1")), + Column("y", Integer), ) t.create(self.conn) def _timestamp_fixture(self): - t = Table( - 'hasts', self.metadata, - Column('x', DateTime()), - ) + t = Table("hasts", self.metadata, Column("x", DateTime())) t.create(self.conn) return t def _int_to_boolean_fixture(self): - t = Table( - 'hasbool', self.metadata, - Column('x', Integer) - ) + t = Table("hasbool", self.metadata, Column("x", Integer)) t.create(self.conn) def test_change_type_boolean_to_int(self): self._boolean_fixture() - with self.op.batch_alter_table( - "hasbool" - ) as batch_op: + with self.op.batch_alter_table("hasbool") as batch_op: batch_op.alter_column( - 'x', type_=Integer, existing_type=Boolean( - create_constraint=True, name='ck1')) + "x", + type_=Integer, + existing_type=Boolean(create_constraint=True, name="ck1"), + ) insp = Inspector.from_engine(config.db) eq_( - [c['type']._type_affinity for c in insp.get_columns('hasbool') - if c['name'] == 'x'], - [Integer] + [ + c["type"]._type_affinity + for c in insp.get_columns("hasbool") + if c["name"] == "x" + ], + [Integer], ) def test_no_net_change_timestamp(self): t = self._timestamp_fixture() import datetime + self.conn.execute( - t.insert(), - {"x": datetime.datetime(2012, 5, 18, 15, 32, 5)} + t.insert(), {"x": datetime.datetime(2012, 5, 18, 15, 32, 5)} ) with self.op.batch_alter_table("hasts") as batch_op: @@ -1030,69 +1152,71 @@ class BatchRoundTripTest(TestBase): eq_( self.conn.execute(select([t.c.x])).fetchall(), - [(datetime.datetime(2012, 5, 18, 15, 32, 5),)] + [(datetime.datetime(2012, 5, 18, 15, 32, 5),)], ) def test_drop_col_schematype(self): self._boolean_fixture() - with self.op.batch_alter_table( - "hasbool" - ) as batch_op: - batch_op.drop_column('x') + with self.op.batch_alter_table("hasbool") as batch_op: + batch_op.drop_column("x") insp = Inspector.from_engine(config.db) - assert 'x' not in (c['name'] for c in insp.get_columns('hasbool')) + assert "x" not in (c["name"] for c in insp.get_columns("hasbool")) def test_change_type_int_to_boolean(self): self._int_to_boolean_fixture() - with self.op.batch_alter_table( - "hasbool" - ) as batch_op: + with self.op.batch_alter_table("hasbool") as batch_op: batch_op.alter_column( - 'x', type_=Boolean(create_constraint=True, name='ck1')) + "x", type_=Boolean(create_constraint=True, name="ck1") + ) insp = Inspector.from_engine(config.db) if exclusions.against(config, "sqlite"): eq_( - [c['type']._type_affinity for - c in insp.get_columns('hasbool') if c['name'] == 'x'], - [Boolean] + [ + c["type"]._type_affinity + for c in insp.get_columns("hasbool") + if c["name"] == "x" + ], + [Boolean], ) elif exclusions.against(config, "mysql"): eq_( - [c['type']._type_affinity for - c in insp.get_columns('hasbool') if c['name'] == 'x'], - [Integer] + [ + c["type"]._type_affinity + for c in insp.get_columns("hasbool") + if c["name"] == "x" + ], + [Integer], ) def tearDown(self): self.metadata.drop_all(self.conn) self.conn.close() - def _assert_data(self, data, tablename='foo'): + def _assert_data(self, data, tablename="foo"): eq_( - [dict(row) for row - in self.conn.execute("select * from %s" % tablename)], - data + [ + dict(row) + for row in self.conn.execute("select * from %s" % tablename) + ], + data, ) def test_ix_existing(self): self._table_w_index_fixture() with self.op.batch_alter_table("t_w_ix") as batch_op: - batch_op.alter_column('data', type_=String(30)) + batch_op.alter_column("data", type_=String(30)) batch_op.create_index("ix_data", ["data"]) insp = Inspector.from_engine(config.db) eq_( set( - (ix['name'], tuple(ix['column_names'])) for ix in - insp.get_indexes('t_w_ix') + (ix["name"], tuple(ix["column_names"])) + for ix in insp.get_indexes("t_w_ix") ), - set([ - ('ix_data', ('data',)), - ('ix_thing', ('thing', )) - ]) + set([("ix_data", ("data",)), ("ix_thing", ("thing",))]), ) def test_fk_points_to_me_auto(self): @@ -1108,31 +1232,39 @@ class BatchRoundTripTest(TestBase): @exclusions.only_on("sqlite") @exclusions.fails( "intentionally asserting that this " - "doesn't work w/ pragma foreign keys") + "doesn't work w/ pragma foreign keys" + ) def test_fk_points_to_me_sqlite_refinteg(self): with self._sqlite_referential_integrity(): self._test_fk_points_to_me("auto") def _test_fk_points_to_me(self, recreate): bar = Table( - 'bar', self.metadata, - Column('id', Integer, primary_key=True), - Column('foo_id', Integer, ForeignKey('foo.id')), - mysql_engine='InnoDB' + "bar", + self.metadata, + Column("id", Integer, primary_key=True), + Column("foo_id", Integer, ForeignKey("foo.id")), + mysql_engine="InnoDB", ) bar.create(self.conn) - self.conn.execute(bar.insert(), {'id': 1, 'foo_id': 3}) + self.conn.execute(bar.insert(), {"id": 1, "foo_id": 3}) with self.op.batch_alter_table("foo", recreate=recreate) as batch_op: batch_op.alter_column( - 'data', new_column_name='newdata', existing_type=String(50)) + "data", new_column_name="newdata", existing_type=String(50) + ) insp = Inspector.from_engine(self.conn) eq_( - [(key['referred_table'], - key['referred_columns'], key['constrained_columns']) - for key in insp.get_foreign_keys('bar')], - [('foo', ['id'], ['foo_id'])] + [ + ( + key["referred_table"], + key["referred_columns"], + key["constrained_columns"], + ) + for key in insp.get_foreign_keys("bar") + ], + [("foo", ["id"], ["foo_id"])], ) def test_selfref_fk_auto(self): @@ -1145,100 +1277,112 @@ class BatchRoundTripTest(TestBase): @exclusions.only_on("sqlite") @exclusions.fails( "intentionally asserting that this " - "doesn't work w/ pragma foreign keys") + "doesn't work w/ pragma foreign keys" + ) def test_selfref_fk_sqlite_refinteg(self): with self._sqlite_referential_integrity(): self._test_selfref_fk("auto") def _test_selfref_fk(self, recreate): bar = Table( - 'bar', self.metadata, - Column('id', Integer, primary_key=True), - Column('bar_id', Integer, ForeignKey('bar.id')), - Column('data', String(50)), - mysql_engine='InnoDB' + "bar", + self.metadata, + Column("id", Integer, primary_key=True), + Column("bar_id", Integer, ForeignKey("bar.id")), + Column("data", String(50)), + mysql_engine="InnoDB", ) bar.create(self.conn) - self.conn.execute(bar.insert(), {'id': 1, 'data': 'x', 'bar_id': None}) - self.conn.execute(bar.insert(), {'id': 2, 'data': 'y', 'bar_id': 1}) + self.conn.execute(bar.insert(), {"id": 1, "data": "x", "bar_id": None}) + self.conn.execute(bar.insert(), {"id": 2, "data": "y", "bar_id": 1}) with self.op.batch_alter_table("bar", recreate=recreate) as batch_op: batch_op.alter_column( - 'data', new_column_name='newdata', existing_type=String(50)) + "data", new_column_name="newdata", existing_type=String(50) + ) insp = Inspector.from_engine(self.conn) insp = Inspector.from_engine(self.conn) eq_( - [(key['referred_table'], - key['referred_columns'], key['constrained_columns']) - for key in insp.get_foreign_keys('bar')], - [('bar', ['id'], ['bar_id'])] + [ + ( + key["referred_table"], + key["referred_columns"], + key["constrained_columns"], + ) + for key in insp.get_foreign_keys("bar") + ], + [("bar", ["id"], ["bar_id"])], ) def test_change_type(self): with self.op.batch_alter_table("foo") as batch_op: - batch_op.alter_column('data', type_=Integer) + batch_op.alter_column("data", type_=Integer) - self._assert_data([ - {"id": 1, "data": 0, "x": 5}, - {"id": 2, "data": 22, "x": 6}, - {"id": 3, "data": 8, "x": 7}, - {"id": 4, "data": 9, "x": 8}, - {"id": 5, "data": 0, "x": 9} - ]) + self._assert_data( + [ + {"id": 1, "data": 0, "x": 5}, + {"id": 2, "data": 22, "x": 6}, + {"id": 3, "data": 8, "x": 7}, + {"id": 4, "data": 9, "x": 8}, + {"id": 5, "data": 0, "x": 9}, + ] + ) def test_drop_column(self): with self.op.batch_alter_table("foo") as batch_op: - batch_op.drop_column('data') + batch_op.drop_column("data") - self._assert_data([ - {"id": 1, "x": 5}, - {"id": 2, "x": 6}, - {"id": 3, "x": 7}, - {"id": 4, "x": 8}, - {"id": 5, "x": 9} - ]) + self._assert_data( + [ + {"id": 1, "x": 5}, + {"id": 2, "x": 6}, + {"id": 3, "x": 7}, + {"id": 4, "x": 8}, + {"id": 5, "x": 9}, + ] + ) def test_drop_pk_col_readd_col(self): # drop a column, add it back without primary_key=True, should no # longer be in the constraint with self.op.batch_alter_table("foo") as batch_op: - batch_op.drop_column('id') - batch_op.add_column(Column('id', Integer)) + batch_op.drop_column("id") + batch_op.add_column(Column("id", Integer)) - pk_const = Inspector.from_engine(self.conn).get_pk_constraint('foo') - eq_(pk_const['constrained_columns'], []) + pk_const = Inspector.from_engine(self.conn).get_pk_constraint("foo") + eq_(pk_const["constrained_columns"], []) def test_drop_pk_col_readd_pk_col(self): # drop a column, add it back with primary_key=True, should remain with self.op.batch_alter_table("foo") as batch_op: - batch_op.drop_column('id') - batch_op.add_column(Column('id', Integer, primary_key=True)) + batch_op.drop_column("id") + batch_op.add_column(Column("id", Integer, primary_key=True)) - pk_const = Inspector.from_engine(self.conn).get_pk_constraint('foo') - eq_(pk_const['constrained_columns'], ['id']) + pk_const = Inspector.from_engine(self.conn).get_pk_constraint("foo") + eq_(pk_const["constrained_columns"], ["id"]) def test_drop_pk_col_readd_col_also_pk_const(self): # drop a column, add it back without primary_key=True, but then # also make anew PK constraint that includes it, should remain with self.op.batch_alter_table("foo") as batch_op: - batch_op.drop_column('id') - batch_op.add_column(Column('id', Integer)) - batch_op.create_primary_key('newpk', ['id']) + batch_op.drop_column("id") + batch_op.add_column(Column("id", Integer)) + batch_op.create_primary_key("newpk", ["id"]) - pk_const = Inspector.from_engine(self.conn).get_pk_constraint('foo') - eq_(pk_const['constrained_columns'], ['id']) + pk_const = Inspector.from_engine(self.conn).get_pk_constraint("foo") + eq_(pk_const["constrained_columns"], ["id"]) def test_add_pk_constraint(self): self._no_pk_fixture() with self.op.batch_alter_table("nopk", recreate="always") as batch_op: - batch_op.create_primary_key('newpk', ['a', 'b']) + batch_op.create_primary_key("newpk", ["a", "b"]) - pk_const = Inspector.from_engine(self.conn).get_pk_constraint('nopk') + pk_const = Inspector.from_engine(self.conn).get_pk_constraint("nopk") with config.requirements.reflects_pk_names.fail_if(): - eq_(pk_const['name'], 'newpk') - eq_(pk_const['constrained_columns'], ['a', 'b']) + eq_(pk_const["name"], "newpk") + eq_(pk_const["constrained_columns"], ["a", "b"]) @config.requirements.check_constraints_w_enforcement def test_add_ck_constraint(self): @@ -1247,203 +1391,219 @@ class BatchRoundTripTest(TestBase): # we dont support reflection of CHECK constraints # so test this by just running invalid data in - foo = self.metadata.tables['foo'] + foo = self.metadata.tables["foo"] assert_raises_message( exc.IntegrityError, "newck", self.conn.execute, - foo.insert(), {"id": 6, "data": 5, "x": -2} + foo.insert(), + {"id": 6, "data": 5, "x": -2}, ) @config.requirements.sqlalchemy_094 @config.requirements.unnamed_constraints def test_drop_foreign_key(self): bar = Table( - 'bar', self.metadata, - Column('id', Integer, primary_key=True), - Column('foo_id', Integer, ForeignKey('foo.id')), - mysql_engine='InnoDB' + "bar", + self.metadata, + Column("id", Integer, primary_key=True), + Column("foo_id", Integer, ForeignKey("foo.id")), + mysql_engine="InnoDB", ) bar.create(self.conn) - self.conn.execute(bar.insert(), {'id': 1, 'foo_id': 3}) + self.conn.execute(bar.insert(), {"id": 1, "foo_id": 3}) naming_convention = { - "fk": - "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s", + "fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s" } with self.op.batch_alter_table( - "bar", naming_convention=naming_convention) as batch_op: - batch_op.drop_constraint( - "fk_bar_foo_id_foo", type_="foreignkey") - eq_( - Inspector.from_engine(self.conn).get_foreign_keys('bar'), - [] - ) + "bar", naming_convention=naming_convention + ) as batch_op: + batch_op.drop_constraint("fk_bar_foo_id_foo", type_="foreignkey") + eq_(Inspector.from_engine(self.conn).get_foreign_keys("bar"), []) def test_drop_column_fk_recreate(self): - with self.op.batch_alter_table("foo", recreate='always') as batch_op: - batch_op.drop_column('data') + with self.op.batch_alter_table("foo", recreate="always") as batch_op: + batch_op.drop_column("data") - self._assert_data([ - {"id": 1, "x": 5}, - {"id": 2, "x": 6}, - {"id": 3, "x": 7}, - {"id": 4, "x": 8}, - {"id": 5, "x": 9} - ]) + self._assert_data( + [ + {"id": 1, "x": 5}, + {"id": 2, "x": 6}, + {"id": 3, "x": 7}, + {"id": 4, "x": 8}, + {"id": 5, "x": 9}, + ] + ) def test_rename_column(self): with self.op.batch_alter_table("foo") as batch_op: - batch_op.alter_column('x', new_column_name='y') + batch_op.alter_column("x", new_column_name="y") - self._assert_data([ - {"id": 1, "data": "d1", "y": 5}, - {"id": 2, "data": "22", "y": 6}, - {"id": 3, "data": "8.5", "y": 7}, - {"id": 4, "data": "9.46", "y": 8}, - {"id": 5, "data": "d5", "y": 9} - ]) + self._assert_data( + [ + {"id": 1, "data": "d1", "y": 5}, + {"id": 2, "data": "22", "y": 6}, + {"id": 3, "data": "8.5", "y": 7}, + {"id": 4, "data": "9.46", "y": 8}, + {"id": 5, "data": "d5", "y": 9}, + ] + ) def test_rename_column_boolean(self): bar = Table( - 'bar', self.metadata, - Column('id', Integer, primary_key=True), - Column('flag', Boolean()), - mysql_engine='InnoDB' + "bar", + self.metadata, + Column("id", Integer, primary_key=True), + Column("flag", Boolean()), + mysql_engine="InnoDB", ) bar.create(self.conn) - self.conn.execute(bar.insert(), {'id': 1, 'flag': True}) - self.conn.execute(bar.insert(), {'id': 2, 'flag': False}) + self.conn.execute(bar.insert(), {"id": 1, "flag": True}) + self.conn.execute(bar.insert(), {"id": 2, "flag": False}) - with self.op.batch_alter_table( - "bar" - ) as batch_op: + with self.op.batch_alter_table("bar") as batch_op: batch_op.alter_column( - 'flag', new_column_name='bflag', existing_type=Boolean) + "flag", new_column_name="bflag", existing_type=Boolean + ) - self._assert_data([ - {"id": 1, 'bflag': True}, - {"id": 2, 'bflag': False}, - ], 'bar') + self._assert_data( + [{"id": 1, "bflag": True}, {"id": 2, "bflag": False}], "bar" + ) @config.requirements.non_native_boolean def test_rename_column_non_native_boolean_no_ck(self): bar = Table( - 'bar', self.metadata, - Column('id', Integer, primary_key=True), - Column('flag', Boolean(create_constraint=False)), - mysql_engine='InnoDB' + "bar", + self.metadata, + Column("id", Integer, primary_key=True), + Column("flag", Boolean(create_constraint=False)), + mysql_engine="InnoDB", ) bar.create(self.conn) - self.conn.execute(bar.insert(), {'id': 1, 'flag': True}) - self.conn.execute(bar.insert(), {'id': 2, 'flag': False}) + self.conn.execute(bar.insert(), {"id": 1, "flag": True}) + self.conn.execute(bar.insert(), {"id": 2, "flag": False}) self.conn.execute( # override Boolean type which as of 1.1 coerces numerics # to 1/0 text("insert into bar (id, flag) values (:id, :flag)"), - {'id': 3, 'flag': 5}) + {"id": 3, "flag": 5}, + ) with self.op.batch_alter_table( "bar", - reflect_args=[Column('flag', Boolean(create_constraint=False))] + reflect_args=[Column("flag", Boolean(create_constraint=False))], ) as batch_op: batch_op.alter_column( - 'flag', new_column_name='bflag', existing_type=Boolean) + "flag", new_column_name="bflag", existing_type=Boolean + ) - self._assert_data([ - {"id": 1, 'bflag': True}, - {"id": 2, 'bflag': False}, - {'id': 3, 'bflag': 5} - ], 'bar') + self._assert_data( + [ + {"id": 1, "bflag": True}, + {"id": 2, "bflag": False}, + {"id": 3, "bflag": 5}, + ], + "bar", + ) def test_drop_column_pk(self): with self.op.batch_alter_table("foo") as batch_op: - batch_op.drop_column('id') + batch_op.drop_column("id") - self._assert_data([ - {"data": "d1", "x": 5}, - {"data": "22", "x": 6}, - {"data": "8.5", "x": 7}, - {"data": "9.46", "x": 8}, - {"data": "d5", "x": 9} - ]) + self._assert_data( + [ + {"data": "d1", "x": 5}, + {"data": "22", "x": 6}, + {"data": "8.5", "x": 7}, + {"data": "9.46", "x": 8}, + {"data": "d5", "x": 9}, + ] + ) def test_rename_column_pk(self): with self.op.batch_alter_table("foo") as batch_op: - batch_op.alter_column('id', new_column_name='ident') + batch_op.alter_column("id", new_column_name="ident") - self._assert_data([ - {"ident": 1, "data": "d1", "x": 5}, - {"ident": 2, "data": "22", "x": 6}, - {"ident": 3, "data": "8.5", "x": 7}, - {"ident": 4, "data": "9.46", "x": 8}, - {"ident": 5, "data": "d5", "x": 9} - ]) + self._assert_data( + [ + {"ident": 1, "data": "d1", "x": 5}, + {"ident": 2, "data": "22", "x": 6}, + {"ident": 3, "data": "8.5", "x": 7}, + {"ident": 4, "data": "9.46", "x": 8}, + {"ident": 5, "data": "d5", "x": 9}, + ] + ) def test_add_column_auto(self): # note this uses ALTER with self.op.batch_alter_table("foo") as batch_op: batch_op.add_column( - Column('data2', String(50), server_default='hi')) + Column("data2", String(50), server_default="hi") + ) - self._assert_data([ - {"id": 1, "data": "d1", "x": 5, 'data2': 'hi'}, - {"id": 2, "data": "22", "x": 6, 'data2': 'hi'}, - {"id": 3, "data": "8.5", "x": 7, 'data2': 'hi'}, - {"id": 4, "data": "9.46", "x": 8, 'data2': 'hi'}, - {"id": 5, "data": "d5", "x": 9, 'data2': 'hi'} - ]) + self._assert_data( + [ + {"id": 1, "data": "d1", "x": 5, "data2": "hi"}, + {"id": 2, "data": "22", "x": 6, "data2": "hi"}, + {"id": 3, "data": "8.5", "x": 7, "data2": "hi"}, + {"id": 4, "data": "9.46", "x": 8, "data2": "hi"}, + {"id": 5, "data": "d5", "x": 9, "data2": "hi"}, + ] + ) def test_add_column_recreate(self): - with self.op.batch_alter_table("foo", recreate='always') as batch_op: + with self.op.batch_alter_table("foo", recreate="always") as batch_op: batch_op.add_column( - Column('data2', String(50), server_default='hi')) + Column("data2", String(50), server_default="hi") + ) - self._assert_data([ - {"id": 1, "data": "d1", "x": 5, 'data2': 'hi'}, - {"id": 2, "data": "22", "x": 6, 'data2': 'hi'}, - {"id": 3, "data": "8.5", "x": 7, 'data2': 'hi'}, - {"id": 4, "data": "9.46", "x": 8, 'data2': 'hi'}, - {"id": 5, "data": "d5", "x": 9, 'data2': 'hi'} - ]) + self._assert_data( + [ + {"id": 1, "data": "d1", "x": 5, "data2": "hi"}, + {"id": 2, "data": "22", "x": 6, "data2": "hi"}, + {"id": 3, "data": "8.5", "x": 7, "data2": "hi"}, + {"id": 4, "data": "9.46", "x": 8, "data2": "hi"}, + {"id": 5, "data": "d5", "x": 9, "data2": "hi"}, + ] + ) def test_create_drop_index(self): insp = Inspector.from_engine(config.db) - eq_( - insp.get_indexes('foo'), [] - ) + eq_(insp.get_indexes("foo"), []) - with self.op.batch_alter_table("foo", recreate='always') as batch_op: - batch_op.create_index( - 'ix_data', ['data'], unique=True) + with self.op.batch_alter_table("foo", recreate="always") as batch_op: + batch_op.create_index("ix_data", ["data"], unique=True) - self._assert_data([ - {"id": 1, "data": "d1", "x": 5}, - {"id": 2, "data": "22", "x": 6}, - {"id": 3, "data": "8.5", "x": 7}, - {"id": 4, "data": "9.46", "x": 8}, - {"id": 5, "data": "d5", "x": 9} - ]) + self._assert_data( + [ + {"id": 1, "data": "d1", "x": 5}, + {"id": 2, "data": "22", "x": 6}, + {"id": 3, "data": "8.5", "x": 7}, + {"id": 4, "data": "9.46", "x": 8}, + {"id": 5, "data": "d5", "x": 9}, + ] + ) insp = Inspector.from_engine(config.db) eq_( [ - dict(unique=ix['unique'], - name=ix['name'], - column_names=ix['column_names']) - for ix in insp.get_indexes('foo') + dict( + unique=ix["unique"], + name=ix["name"], + column_names=ix["column_names"], + ) + for ix in insp.get_indexes("foo") ], - [{'unique': True, 'name': 'ix_data', 'column_names': ['data']}] + [{"unique": True, "name": "ix_data", "column_names": ["data"]}], ) - with self.op.batch_alter_table("foo", recreate='always') as batch_op: - batch_op.drop_index('ix_data') + with self.op.batch_alter_table("foo", recreate="always") as batch_op: + batch_op.drop_index("ix_data") insp = Inspector.from_engine(config.db) - eq_( - insp.get_indexes('foo'), [] - ) + eq_(insp.get_indexes("foo"), []) class BatchRoundTripMySQLTest(BatchRoundTripTest): @@ -1496,7 +1656,8 @@ class BatchRoundTripPostgresqlTest(BatchRoundTripTest): @exclusions.fails() def test_drop_pk_col_readd_pk_col(self): super( - BatchRoundTripPostgresqlTest, self).test_drop_pk_col_readd_pk_col() + BatchRoundTripPostgresqlTest, self + ).test_drop_pk_col_readd_pk_col() @exclusions.fails() def test_drop_pk_col_readd_col_also_pk_const(self): @@ -1513,10 +1674,12 @@ class BatchRoundTripPostgresqlTest(BatchRoundTripTest): @exclusions.fails() def test_change_type_int_to_boolean(self): - super(BatchRoundTripPostgresqlTest, self).\ - test_change_type_int_to_boolean() + super( + BatchRoundTripPostgresqlTest, self + ).test_change_type_int_to_boolean() @exclusions.fails() def test_change_type_boolean_to_int(self): - super(BatchRoundTripPostgresqlTest, self).\ - test_change_type_boolean_to_int() + super( + BatchRoundTripPostgresqlTest, self + ).test_change_type_boolean_to_int() diff --git a/tests/test_bulk_insert.py b/tests/test_bulk_insert.py index 2655630..220719a 100644 --- a/tests/test_bulk_insert.py +++ b/tests/test_bulk_insert.py @@ -13,127 +13,131 @@ from alembic.testing import eq_, assert_raises_message, config class BulkInsertTest(TestBase): def _table_fixture(self, dialect, as_sql): context = op_fixture(dialect, as_sql) - t1 = table("ins_table", - column('id', Integer), - column('v1', String()), - column('v2', String()), - ) + t1 = table( + "ins_table", + column("id", Integer), + column("v1", String()), + column("v2", String()), + ) return context, t1 def _big_t_table_fixture(self, dialect, as_sql): context = op_fixture(dialect, as_sql) - t1 = Table("ins_table", MetaData(), - Column('id', Integer, primary_key=True), - Column('v1', String()), - Column('v2', String()), - ) + t1 = Table( + "ins_table", + MetaData(), + Column("id", Integer, primary_key=True), + Column("v1", String()), + Column("v2", String()), + ) return context, t1 def _test_bulk_insert(self, dialect, as_sql): context, t1 = self._table_fixture(dialect, as_sql) - op.bulk_insert(t1, [ - {'id': 1, 'v1': 'row v1', 'v2': 'row v5'}, - {'id': 2, 'v1': 'row v2', 'v2': 'row v6'}, - {'id': 3, 'v1': 'row v3', 'v2': 'row v7'}, - {'id': 4, 'v1': 'row v4', 'v2': 'row v8'}, - ]) + op.bulk_insert( + t1, + [ + {"id": 1, "v1": "row v1", "v2": "row v5"}, + {"id": 2, "v1": "row v2", "v2": "row v6"}, + {"id": 3, "v1": "row v3", "v2": "row v7"}, + {"id": 4, "v1": "row v4", "v2": "row v8"}, + ], + ) return context def _test_bulk_insert_single(self, dialect, as_sql): context, t1 = self._table_fixture(dialect, as_sql) - op.bulk_insert(t1, [ - {'id': 1, 'v1': 'row v1', 'v2': 'row v5'}, - ]) + op.bulk_insert(t1, [{"id": 1, "v1": "row v1", "v2": "row v5"}]) return context def _test_bulk_insert_single_bigt(self, dialect, as_sql): context, t1 = self._big_t_table_fixture(dialect, as_sql) - op.bulk_insert(t1, [ - {'id': 1, 'v1': 'row v1', 'v2': 'row v5'}, - ]) + op.bulk_insert(t1, [{"id": 1, "v1": "row v1", "v2": "row v5"}]) return context def test_bulk_insert(self): - context = self._test_bulk_insert('default', False) + context = self._test_bulk_insert("default", False) context.assert_( - 'INSERT INTO ins_table (id, v1, v2) VALUES (:id, :v1, :v2)' + "INSERT INTO ins_table (id, v1, v2) VALUES (:id, :v1, :v2)" ) def test_bulk_insert_wrong_cols(self): - context = op_fixture('postgresql') - t1 = table("ins_table", - column('id', Integer), - column('v1', String()), - column('v2', String()), - ) - op.bulk_insert(t1, [ - {'v1': 'row v1', }, - ]) + context = op_fixture("postgresql") + t1 = table( + "ins_table", + column("id", Integer), + column("v1", String()), + column("v2", String()), + ) + op.bulk_insert(t1, [{"v1": "row v1"}]) context.assert_( - 'INSERT INTO ins_table (id, v1, v2) VALUES (%(id)s, %(v1)s, %(v2)s)' + "INSERT INTO ins_table (id, v1, v2) VALUES (%(id)s, %(v1)s, %(v2)s)" ) def test_bulk_insert_no_rows(self): - context, t1 = self._table_fixture('default', False) + context, t1 = self._table_fixture("default", False) op.bulk_insert(t1, []) context.assert_() def test_bulk_insert_pg(self): - context = self._test_bulk_insert('postgresql', False) + context = self._test_bulk_insert("postgresql", False) context.assert_( - 'INSERT INTO ins_table (id, v1, v2) ' - 'VALUES (%(id)s, %(v1)s, %(v2)s)' + "INSERT INTO ins_table (id, v1, v2) " + "VALUES (%(id)s, %(v1)s, %(v2)s)" ) def test_bulk_insert_pg_single(self): - context = self._test_bulk_insert_single('postgresql', False) + context = self._test_bulk_insert_single("postgresql", False) context.assert_( - 'INSERT INTO ins_table (id, v1, v2) ' - 'VALUES (%(id)s, %(v1)s, %(v2)s)' + "INSERT INTO ins_table (id, v1, v2) " + "VALUES (%(id)s, %(v1)s, %(v2)s)" ) def test_bulk_insert_pg_single_as_sql(self): - context = self._test_bulk_insert_single('postgresql', True) + context = self._test_bulk_insert_single("postgresql", True) context.assert_( "INSERT INTO ins_table (id, v1, v2) VALUES (1, 'row v1', 'row v5')" ) def test_bulk_insert_pg_single_big_t_as_sql(self): - context = self._test_bulk_insert_single_bigt('postgresql', True) + context = self._test_bulk_insert_single_bigt("postgresql", True) context.assert_( "INSERT INTO ins_table (id, v1, v2) " "VALUES (1, 'row v1', 'row v5')" ) def test_bulk_insert_mssql(self): - context = self._test_bulk_insert('mssql', False) + context = self._test_bulk_insert("mssql", False) context.assert_( - 'INSERT INTO ins_table (id, v1, v2) VALUES (:id, :v1, :v2)' + "INSERT INTO ins_table (id, v1, v2) VALUES (:id, :v1, :v2)" ) def test_bulk_insert_inline_literal_as_sql(self): - context = op_fixture('postgresql', True) + context = op_fixture("postgresql", True) class MyType(TypeEngine): pass - t1 = table('t', column('id', Integer), column('data', MyType())) + t1 = table("t", column("id", Integer), column("data", MyType())) - op.bulk_insert(t1, [ - {'id': 1, 'data': op.inline_literal('d1')}, - {'id': 2, 'data': op.inline_literal('d2')}, - ]) + op.bulk_insert( + t1, + [ + {"id": 1, "data": op.inline_literal("d1")}, + {"id": 2, "data": op.inline_literal("d2")}, + ], + ) context.assert_( "INSERT INTO t (id, data) VALUES (1, 'd1')", - "INSERT INTO t (id, data) VALUES (2, 'd2')" + "INSERT INTO t (id, data) VALUES (2, 'd2')", ) def test_bulk_insert_as_sql(self): - context = self._test_bulk_insert('default', True) + context = self._test_bulk_insert("default", True) context.assert_( "INSERT INTO ins_table (id, v1, v2) " "VALUES (1, 'row v1', 'row v5')", @@ -142,11 +146,11 @@ class BulkInsertTest(TestBase): "INSERT INTO ins_table (id, v1, v2) " "VALUES (3, 'row v3', 'row v7')", "INSERT INTO ins_table (id, v1, v2) " - "VALUES (4, 'row v4', 'row v8')" + "VALUES (4, 'row v4', 'row v8')", ) def test_bulk_insert_as_sql_pg(self): - context = self._test_bulk_insert('postgresql', True) + context = self._test_bulk_insert("postgresql", True) context.assert_( "INSERT INTO ins_table (id, v1, v2) " "VALUES (1, 'row v1', 'row v5')", @@ -155,65 +159,68 @@ class BulkInsertTest(TestBase): "INSERT INTO ins_table (id, v1, v2) " "VALUES (3, 'row v3', 'row v7')", "INSERT INTO ins_table (id, v1, v2) " - "VALUES (4, 'row v4', 'row v8')" + "VALUES (4, 'row v4', 'row v8')", ) def test_bulk_insert_as_sql_mssql(self): - context = self._test_bulk_insert('mssql', True) + context = self._test_bulk_insert("mssql", True) # SQL server requires IDENTITY_INSERT # TODO: figure out if this is safe to enable for a table that # doesn't have an IDENTITY column context.assert_( - 'SET IDENTITY_INSERT ins_table ON', - 'GO', + "SET IDENTITY_INSERT ins_table ON", + "GO", "INSERT INTO ins_table (id, v1, v2) " "VALUES (1, 'row v1', 'row v5')", - 'GO', + "GO", "INSERT INTO ins_table (id, v1, v2) " "VALUES (2, 'row v2', 'row v6')", - 'GO', + "GO", "INSERT INTO ins_table (id, v1, v2) " "VALUES (3, 'row v3', 'row v7')", - 'GO', + "GO", "INSERT INTO ins_table (id, v1, v2) " "VALUES (4, 'row v4', 'row v8')", - 'GO', - 'SET IDENTITY_INSERT ins_table OFF', - 'GO', + "GO", + "SET IDENTITY_INSERT ins_table OFF", + "GO", ) def test_bulk_insert_from_new_table(self): context = op_fixture("postgresql", True) t1 = op.create_table( "ins_table", - Column('id', Integer), - Column('v1', String()), - Column('v2', String()), + Column("id", Integer), + Column("v1", String()), + Column("v2", String()), + ) + op.bulk_insert( + t1, + [ + {"id": 1, "v1": "row v1", "v2": "row v5"}, + {"id": 2, "v1": "row v2", "v2": "row v6"}, + ], ) - op.bulk_insert(t1, [ - {'id': 1, 'v1': 'row v1', 'v2': 'row v5'}, - {'id': 2, 'v1': 'row v2', 'v2': 'row v6'}, - ]) context.assert_( - 'CREATE TABLE ins_table (id INTEGER, v1 VARCHAR, v2 VARCHAR)', + "CREATE TABLE ins_table (id INTEGER, v1 VARCHAR, v2 VARCHAR)", "INSERT INTO ins_table (id, v1, v2) VALUES " "(1, 'row v1', 'row v5')", "INSERT INTO ins_table (id, v1, v2) VALUES " - "(2, 'row v2', 'row v6')" + "(2, 'row v2', 'row v6')", ) def test_invalid_format(self): context, t1 = self._table_fixture("sqlite", False) assert_raises_message( - TypeError, - "List expected", - op.bulk_insert, t1, {"id": 5} + TypeError, "List expected", op.bulk_insert, t1, {"id": 5} ) assert_raises_message( TypeError, "List of dictionaries expected", - op.bulk_insert, t1, [(5, )] + op.bulk_insert, + t1, + [(5,)], ) @@ -223,86 +230,85 @@ class RoundTripTest(TestBase): def setUp(self): from sqlalchemy import create_engine from alembic.migration import MigrationContext + self.conn = config.db.connect() - self.conn.execute(""" + self.conn.execute( + """ create table foo( id integer primary key, data varchar(50), x integer ) - """) + """ + ) context = MigrationContext.configure(self.conn) self.op = op.Operations(context) - self.t1 = table('foo', - column('id'), - column('data'), - column('x') - ) + self.t1 = table("foo", column("id"), column("data"), column("x")) def tearDown(self): self.conn.execute("drop table foo") self.conn.close() def test_single_insert_round_trip(self): - self.op.bulk_insert(self.t1, - [{'data': "d1", "x": "x1"}] - ) + self.op.bulk_insert(self.t1, [{"data": "d1", "x": "x1"}]) eq_( self.conn.execute("select id, data, x from foo").fetchall(), - [ - (1, "d1", "x1"), - ] + [(1, "d1", "x1")], ) def test_bulk_insert_round_trip(self): - self.op.bulk_insert(self.t1, [ - {'data': "d1", "x": "x1"}, - {'data': "d2", "x": "x2"}, - {'data': "d3", "x": "x3"}, - ]) + self.op.bulk_insert( + self.t1, + [ + {"data": "d1", "x": "x1"}, + {"data": "d2", "x": "x2"}, + {"data": "d3", "x": "x3"}, + ], + ) eq_( self.conn.execute("select id, data, x from foo").fetchall(), - [ - (1, "d1", "x1"), - (2, "d2", "x2"), - (3, "d3", "x3") - ] + [(1, "d1", "x1"), (2, "d2", "x2"), (3, "d3", "x3")], ) def test_bulk_insert_inline_literal(self): class MyType(TypeEngine): pass - t1 = table('foo', column('id', Integer), column('data', MyType())) + t1 = table("foo", column("id", Integer), column("data", MyType())) - self.op.bulk_insert(t1, [ - {'id': 1, 'data': self.op.inline_literal('d1')}, - {'id': 2, 'data': self.op.inline_literal('d2')}, - ], multiinsert=False) + self.op.bulk_insert( + t1, + [ + {"id": 1, "data": self.op.inline_literal("d1")}, + {"id": 2, "data": self.op.inline_literal("d2")}, + ], + multiinsert=False, + ) eq_( self.conn.execute("select id, data from foo").fetchall(), - [ - (1, "d1"), - (2, "d2"), - ] + [(1, "d1"), (2, "d2")], ) def test_bulk_insert_from_new_table(self): t1 = self.op.create_table( "ins_table", - Column('id', Integer), - Column('v1', String()), - Column('v2', String()), + Column("id", Integer), + Column("v1", String()), + Column("v2", String()), + ) + self.op.bulk_insert( + t1, + [ + {"id": 1, "v1": "row v1", "v2": "row v5"}, + {"id": 2, "v1": "row v2", "v2": "row v6"}, + ], ) - self.op.bulk_insert(t1, [ - {'id': 1, 'v1': 'row v1', 'v2': 'row v5'}, - {'id': 2, 'v1': 'row v2', 'v2': 'row v6'}, - ]) eq_( self.conn.execute( - "select id, v1, v2 from ins_table order by id").fetchall(), - [(1, u'row v1', u'row v5'), (2, u'row v2', u'row v6')] - )
\ No newline at end of file + "select id, v1, v2 from ins_table order by id" + ).fetchall(), + [(1, u"row v1", u"row v5"), (2, u"row v2", u"row v6")], + ) diff --git a/tests/test_command.py b/tests/test_command.py index 3f3daf5..a9f0e5d 100644 --- a/tests/test_command.py +++ b/tests/test_command.py @@ -3,9 +3,16 @@ from io import TextIOWrapper, BytesIO from alembic.script import ScriptDirectory from alembic import config from alembic.testing.fixtures import TestBase, capture_context_buffer -from alembic.testing.env import staging_env, _sqlite_testing_config, \ - three_rev_fixture, clear_staging_env, _no_sql_testing_config, \ - _sqlite_file_db, write_script, env_file_fixture +from alembic.testing.env import ( + staging_env, + _sqlite_testing_config, + three_rev_fixture, + clear_staging_env, + _no_sql_testing_config, + _sqlite_file_db, + write_script, + env_file_fixture, +) from alembic.testing import eq_, assert_raises_message, mock, assert_raises from alembic import util from contextlib import contextmanager @@ -18,13 +25,12 @@ class _BufMixin(object): # try to simulate how sys.stdout looks - we send it u'' # but then it's trying to encode to something. buf = BytesIO() - wrapper = TextIOWrapper(buf, encoding='ascii', line_buffering=True) + wrapper = TextIOWrapper(buf, encoding="ascii", line_buffering=True) wrapper.getvalue = buf.getvalue return wrapper class HistoryTest(_BufMixin, TestBase): - @classmethod def setup_class(cls): cls.env = staging_env() @@ -41,7 +47,8 @@ class HistoryTest(_BufMixin, TestBase): @classmethod def _setup_env_file(self): - env_file_fixture(r""" + env_file_fixture( + r""" from sqlalchemy import MetaData, engine_from_config target_metadata = MetaData() @@ -63,7 +70,8 @@ try: finally: connection.close() -""") +""" + ) def _eq_cmd_output(self, buf, expected, env_token=False, currents=()): script = ScriptDirectory.from_config(self.cfg) @@ -82,9 +90,11 @@ finally: assert_lines.insert(0, "environment included OK") eq_( - buf.getvalue().decode("ascii", 'replace').strip(), - "\n".join(assert_lines). - encode("ascii", "replace").decode("ascii").strip() + buf.getvalue().decode("ascii", "replace").strip(), + "\n".join(assert_lines) + .encode("ascii", "replace") + .decode("ascii") + .strip(), ) def test_history_full(self): @@ -163,11 +173,11 @@ finally: self.cfg.stdout = buf = self._buf_fixture() command.history(self.cfg, indicate_current=True, verbose=True) self._eq_cmd_output( - buf, [self.c, self.b, self.a], currents=(self.b,), env_token=True) + buf, [self.c, self.b, self.a], currents=(self.b,), env_token=True + ) class CurrentTest(_BufMixin, TestBase): - @classmethod def setup_class(cls): cls.env = env = staging_env() @@ -189,11 +199,15 @@ class CurrentTest(_BufMixin, TestBase): yield - lines = set([ - re.match(r'(^.\w)', elem).group(1) - for elem in re.split( - "\n", - buf.getvalue().decode('ascii', 'replace').strip()) if elem]) + lines = set( + [ + re.match(r"(^.\w)", elem).group(1) + for elem in re.split( + "\n", buf.getvalue().decode("ascii", "replace").strip() + ) + if elem + ] + ) eq_(lines, set(revs)) @@ -205,25 +219,25 @@ class CurrentTest(_BufMixin, TestBase): def test_plain_current(self): command.stamp(self.cfg, ()) command.stamp(self.cfg, self.a3.revision) - with self._assert_lines(['a3']): + with self._assert_lines(["a3"]): command.current(self.cfg) def test_two_heads(self): command.stamp(self.cfg, ()) command.stamp(self.cfg, (self.a1.revision, self.b1.revision)) - with self._assert_lines(['a1', 'b1']): + with self._assert_lines(["a1", "b1"]): command.current(self.cfg) def test_heads_one_is_dependent(self): command.stamp(self.cfg, ()) - command.stamp(self.cfg, (self.b2.revision, )) - with self._assert_lines(['a2', 'b2']): + command.stamp(self.cfg, (self.b2.revision,)) + with self._assert_lines(["a2", "b2"]): command.current(self.cfg) def test_heads_upg(self): - command.stamp(self.cfg, (self.b2.revision, )) + command.stamp(self.cfg, (self.b2.revision,)) command.upgrade(self.cfg, (self.b3.revision)) - with self._assert_lines(['a2', 'b3']): + with self._assert_lines(["a2", "b3"]): command.current(self.cfg) @@ -236,7 +250,8 @@ class RevisionTest(TestBase): clear_staging_env() def _env_fixture(self, version_table_pk=True): - env_file_fixture(""" + env_file_fixture( + """ from sqlalchemy import MetaData, engine_from_config target_metadata = MetaData() @@ -258,7 +273,9 @@ try: finally: connection.close() -""" % (version_table_pk, )) +""" + % (version_table_pk,) + ) def test_create_rev_plain_db_not_up_to_date(self): self._env_fixture() @@ -275,7 +292,9 @@ finally: assert_raises_message( util.CommandError, "Target database is not up to date.", - command.revision, self.cfg, autogenerate=True + command.revision, + self.cfg, + autogenerate=True, ) def test_create_rev_autogen_db_not_up_to_date_multi_heads(self): @@ -290,7 +309,9 @@ finally: assert_raises_message( util.CommandError, "Target database is not up to date.", - command.revision, self.cfg, autogenerate=True + command.revision, + self.cfg, + autogenerate=True, ) def test_create_rev_plain_db_not_up_to_date_multi_heads(self): @@ -306,7 +327,8 @@ finally: util.CommandError, "Multiple heads are present; please specify the head revision " "on which the new revision should be based, or perform a merge.", - command.revision, self.cfg + command.revision, + self.cfg, ) def test_create_rev_autogen_need_to_select_head(self): @@ -321,7 +343,9 @@ finally: util.CommandError, "Multiple heads are present; please specify the head revision " "on which the new revision should be based, or perform a merge.", - command.revision, self.cfg, autogenerate=True + command.revision, + self.cfg, + autogenerate=True, ) def test_pk_constraint_normally_prevents_dupe_rows(self): @@ -333,7 +357,7 @@ finally: assert_raises( sqla_exc.IntegrityError, db.execute, - "insert into alembic_version values ('%s')" % r2.revision + "insert into alembic_version values ('%s')" % r2.revision, ) def test_err_correctly_raised_on_dupe_rows_no_pk(self): @@ -347,7 +371,9 @@ finally: util.CommandError, "Online migration expected to match one row when " "updating .* in 'alembic_version'; 2 found", - command.downgrade, self.cfg, "-1" + command.downgrade, + self.cfg, + "-1", ) def test_create_rev_plain_need_to_select_head(self): @@ -362,7 +388,8 @@ finally: util.CommandError, "Multiple heads are present; please specify the head revision " "on which the new revision should be based, or perform a merge.", - command.revision, self.cfg + command.revision, + self.cfg, ) def test_create_rev_plain_post_merge(self): @@ -389,27 +416,20 @@ finally: command.revision(self.cfg) rev2 = command.revision(self.cfg) rev3 = command.revision(self.cfg, depends_on=rev2.revision) - eq_( - rev3._resolved_dependencies, (rev2.revision, ) - ) + eq_(rev3._resolved_dependencies, (rev2.revision,)) rev4 = command.revision( - self.cfg, depends_on=[rev2.revision, rev3.revision]) - eq_( - rev4._resolved_dependencies, (rev2.revision, rev3.revision) + self.cfg, depends_on=[rev2.revision, rev3.revision] ) + eq_(rev4._resolved_dependencies, (rev2.revision, rev3.revision)) def test_create_rev_depends_on_branch_label(self): self._env_fixture() command.revision(self.cfg) - rev2 = command.revision(self.cfg, branch_label='foobar') - rev3 = command.revision(self.cfg, depends_on='foobar') - eq_( - rev3.dependencies, 'foobar' - ) - eq_( - rev3._resolved_dependencies, (rev2.revision, ) - ) + rev2 = command.revision(self.cfg, branch_label="foobar") + rev3 = command.revision(self.cfg, depends_on="foobar") + eq_(rev3.dependencies, "foobar") + eq_(rev3._resolved_dependencies, (rev2.revision,)) def test_create_rev_depends_on_partial_revid(self): self._env_fixture() @@ -417,12 +437,8 @@ finally: rev2 = command.revision(self.cfg) assert len(rev2.revision) > 7 rev3 = command.revision(self.cfg, depends_on=rev2.revision[0:4]) - eq_( - rev3.dependencies, rev2.revision - ) - eq_( - rev3._resolved_dependencies, (rev2.revision, ) - ) + eq_(rev3.dependencies, rev2.revision) + eq_(rev3._resolved_dependencies, (rev2.revision,)) def test_create_rev_invalid_depends_on(self): self._env_fixture() @@ -430,7 +446,9 @@ finally: assert_raises_message( util.CommandError, "Can't locate revision identified by 'invalid'", - command.revision, self.cfg, depends_on='invalid' + command.revision, + self.cfg, + depends_on="invalid", ) def test_create_rev_autogenerate_db_not_up_to_date_post_merge(self): @@ -444,7 +462,9 @@ finally: assert_raises_message( util.CommandError, "Target database is not up to date.", - command.revision, self.cfg, autogenerate=True + command.revision, + self.cfg, + autogenerate=True, ) def test_nonsensical_sql_mode_autogen(self): @@ -452,7 +472,10 @@ finally: assert_raises_message( util.CommandError, "Using --sql with --autogenerate does not make any sense", - command.revision, self.cfg, autogenerate=True, sql=True + command.revision, + self.cfg, + autogenerate=True, + sql=True, ) def test_nonsensical_sql_no_env(self): @@ -461,7 +484,9 @@ finally: util.CommandError, "Using --sql with the revision command when revision_environment " "is not configured does not make any sense", - command.revision, self.cfg, sql=True + command.revision, + self.cfg, + sql=True, ) def test_sensical_sql_w_env(self): @@ -471,12 +496,11 @@ finally: class UpgradeDowngradeStampTest(TestBase): - def setUp(self): self.env = staging_env() self.cfg = cfg = _no_sql_testing_config() - cfg.set_main_option('dialect_name', 'sqlite') - cfg.remove_main_option('url') + cfg.set_main_option("dialect_name", "sqlite") + cfg.remove_main_option("url") self.a, self.b, self.c = three_rev_fixture(cfg) @@ -559,7 +583,7 @@ class UpgradeDowngradeStampTest(TestBase): class LiveStampTest(TestBase): - __only_on__ = 'sqlite' + __only_on__ = "sqlite" def setUp(self): self.bind = _sqlite_file_db() @@ -569,15 +593,25 @@ class LiveStampTest(TestBase): self.b = b = util.rev_id() script = ScriptDirectory.from_config(self.cfg) script.generate_revision(a, None, refresh=True) - write_script(script, a, """ + write_script( + script, + a, + """ revision = '%s' down_revision = None -""" % a) +""" + % a, + ) script.generate_revision(b, None, refresh=True) - write_script(script, b, """ + write_script( + script, + b, + """ revision = '%s' down_revision = '%s' -""" % (b, a)) +""" + % (b, a), + ) def tearDown(self): clear_staging_env() @@ -585,29 +619,25 @@ down_revision = '%s' def test_stamp_creates_table(self): command.stamp(self.cfg, "head") eq_( - self.bind.scalar("select version_num from alembic_version"), - self.b + self.bind.scalar("select version_num from alembic_version"), self.b ) def test_stamp_existing_upgrade(self): command.stamp(self.cfg, self.a) command.stamp(self.cfg, self.b) eq_( - self.bind.scalar("select version_num from alembic_version"), - self.b + self.bind.scalar("select version_num from alembic_version"), self.b ) def test_stamp_existing_downgrade(self): command.stamp(self.cfg, self.b) command.stamp(self.cfg, self.a) eq_( - self.bind.scalar("select version_num from alembic_version"), - self.a + self.bind.scalar("select version_num from alembic_version"), self.a ) class EditTest(TestBase): - @classmethod def setup_class(cls): cls.env = staging_env() @@ -622,56 +652,61 @@ class EditTest(TestBase): command.stamp(self.cfg, "base") def test_edit_head(self): - expected_call_arg = '%s/scripts/versions/%s_revision_c.py' % ( - EditTest.cfg.config_args['here'], - EditTest.c + expected_call_arg = "%s/scripts/versions/%s_revision_c.py" % ( + EditTest.cfg.config_args["here"], + EditTest.c, ) - with mock.patch('alembic.util.edit') as edit: + with mock.patch("alembic.util.edit") as edit: command.edit(self.cfg, "head") edit.assert_called_with(expected_call_arg) def test_edit_b(self): - expected_call_arg = '%s/scripts/versions/%s_revision_b.py' % ( - EditTest.cfg.config_args['here'], - EditTest.b + expected_call_arg = "%s/scripts/versions/%s_revision_b.py" % ( + EditTest.cfg.config_args["here"], + EditTest.b, ) - with mock.patch('alembic.util.edit') as edit: + with mock.patch("alembic.util.edit") as edit: command.edit(self.cfg, self.b[0:3]) edit.assert_called_with(expected_call_arg) def test_edit_with_missing_editor(self): - with mock.patch('editor.edit') as edit_mock: + with mock.patch("editor.edit") as edit_mock: edit_mock.side_effect = OSError("file not found") assert_raises_message( util.CommandError, - 'file not found', + "file not found", util.edit, - "/not/a/file.txt") + "/not/a/file.txt", + ) def test_edit_no_revs(self): assert_raises_message( util.CommandError, "No revision files indicated by symbol 'base'", command.edit, - self.cfg, "base") + self.cfg, + "base", + ) def test_edit_no_current(self): assert_raises_message( util.CommandError, "No current revisions", command.edit, - self.cfg, "current") + self.cfg, + "current", + ) def test_edit_current(self): - expected_call_arg = '%s/scripts/versions/%s_revision_b.py' % ( - EditTest.cfg.config_args['here'], - EditTest.b + expected_call_arg = "%s/scripts/versions/%s_revision_b.py" % ( + EditTest.cfg.config_args["here"], + EditTest.b, ) command.stamp(self.cfg, self.b) - with mock.patch('alembic.util.edit') as edit: + with mock.patch("alembic.util.edit") as edit: command.edit(self.cfg, "current") edit.assert_called_with(expected_call_arg) @@ -691,16 +726,21 @@ class CommandLineTest(TestBase): # the command function has "process_revision_directives" # however the ArgumentParser does not. ensure things work def revision( - config, message=None, autogenerate=False, sql=False, - head="head", splice=False, branch_label=None, - version_path=None, rev_id=None, depends_on=None, - process_revision_directives=None + config, + message=None, + autogenerate=False, + sql=False, + head="head", + splice=False, + branch_label=None, + version_path=None, + rev_id=None, + depends_on=None, + process_revision_directives=None, ): - canary( - config, message=message - ) + canary(config, message=message) - revision.__module__ = 'alembic.command' + revision.__module__ = "alembic.command" # CommandLine() pulls the function into the ArgumentParser # and needs the full signature, so we can't patch the "revision" @@ -712,7 +752,4 @@ class CommandLineTest(TestBase): commandline.run_cmd(self.cfg, options) finally: config.command.revision = orig_revision - eq_( - canary.mock_calls, - [mock.call(self.cfg, message="foo")] - ) + eq_(canary.mock_calls, [mock.call(self.cfg, message="foo")]) diff --git a/tests/test_config.py b/tests/test_config.py index 50e1b05..b1d1ca1 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -11,30 +11,35 @@ from alembic.testing.mock import Mock, call from alembic.testing import eq_, assert_raises_message from alembic.testing.fixtures import capture_db -from alembic.testing.env import _no_sql_testing_config, clear_staging_env,\ - staging_env, _write_config_file +from alembic.testing.env import ( + _no_sql_testing_config, + clear_staging_env, + staging_env, + _write_config_file, +) class FileConfigTest(TestBase): - def test_config_args(self): - cfg = _write_config_file(""" + cfg = _write_config_file( + """ [alembic] migrations = %(base_path)s/db/migrations -""") +""" + ) test_cfg = config.Config( cfg.config_file_name, config_args=dict(base_path="/home/alembic") ) eq_( test_cfg.get_section_option("alembic", "migrations"), - "/home/alembic/db/migrations") + "/home/alembic/db/migrations", + ) def tearDown(self): clear_staging_env() class ConfigTest(TestBase): - def test_config_no_file_main_option(self): cfg = config.Config() cfg.set_main_option("url", "postgresql://foo/bar") @@ -66,12 +71,12 @@ class ConfigTest(TestBase): cfg = config.Config() cfg.set_section_option("some_section", "foob", "foob_value") - cfg.set_section_option( - "some_section", "bar", "bar with %(foob)s") + cfg.set_section_option("some_section", "bar", "bar with %(foob)s") eq_( cfg.get_section_option("some_section", "bar"), - "bar with foob_value") + "bar with foob_value", + ) def test_standalone_op(self): eng, buf = capture_db() @@ -80,71 +85,58 @@ class ConfigTest(TestBase): op = Operations(env) op.alter_column("t", "c", nullable=True) - eq_(buf, ['ALTER TABLE t ALTER COLUMN c DROP NOT NULL']) + eq_(buf, ["ALTER TABLE t ALTER COLUMN c DROP NOT NULL"]) def test_no_script_error(self): cfg = config.Config() assert_raises_message( util.CommandError, "No 'script_location' key found in configuration.", - ScriptDirectory.from_config, cfg + ScriptDirectory.from_config, + cfg, ) def test_attributes_attr(self): m1 = Mock() cfg = config.Config() - cfg.attributes['connection'] = m1 - eq_( - cfg.attributes['connection'], m1 - ) + cfg.attributes["connection"] = m1 + eq_(cfg.attributes["connection"], m1) def test_attributes_construtor(self): m1 = Mock() m2 = Mock() - cfg = config.Config(attributes={'m1': m1}) - cfg.attributes['connection'] = m2 - eq_( - cfg.attributes, {'m1': m1, 'connection': m2} - ) + cfg = config.Config(attributes={"m1": m1}) + cfg.attributes["connection"] = m2 + eq_(cfg.attributes, {"m1": m1, "connection": m2}) class StdoutOutputEncodingTest(TestBase): - def test_plain(self): - stdout = Mock(encoding='latin-1') + stdout = Mock(encoding="latin-1") cfg = config.Config(stdout=stdout) cfg.print_stdout("test %s %s", "x", "y") - eq_( - stdout.mock_calls, - [call.write('test x y'), call.write('\n')] - ) + eq_(stdout.mock_calls, [call.write("test x y"), call.write("\n")]) def test_utf8_unicode(self): - stdout = Mock(encoding='latin-1') + stdout = Mock(encoding="latin-1") cfg = config.Config(stdout=stdout) cfg.print_stdout(compat.u("méil %s %s"), "x", "y") eq_( stdout.mock_calls, - [call.write(compat.u('méil x y')), call.write('\n')] + [call.write(compat.u("méil x y")), call.write("\n")], ) def test_ascii_unicode(self): stdout = Mock(encoding=None) cfg = config.Config(stdout=stdout) cfg.print_stdout(compat.u("méil %s %s"), "x", "y") - eq_( - stdout.mock_calls, - [call.write('m?il x y'), call.write('\n')] - ) + eq_(stdout.mock_calls, [call.write("m?il x y"), call.write("\n")]) def test_only_formats_output_with_args(self): stdout = Mock(encoding=None) cfg = config.Config(stdout=stdout) cfg.print_stdout(compat.u("test 3%")) - eq_( - stdout.mock_calls, - [call.write('test 3%'), call.write('\n')] - ) + eq_(stdout.mock_calls, [call.write("test 3%"), call.write("\n")]) class TemplateOutputEncodingTest(TestBase): @@ -157,9 +149,9 @@ class TemplateOutputEncodingTest(TestBase): def test_default(self): script = ScriptDirectory.from_config(self.cfg) - eq_(script.output_encoding, 'utf-8') + eq_(script.output_encoding, "utf-8") def test_setting(self): - self.cfg.set_main_option('output_encoding', 'latin-1') + self.cfg.set_main_option("output_encoding", "latin-1") script = ScriptDirectory.from_config(self.cfg) - eq_(script.output_encoding, 'latin-1') + eq_(script.output_encoding, "latin-1") diff --git a/tests/test_environment.py b/tests/test_environment.py index 42ff328..cfa72f6 100644 --- a/tests/test_environment.py +++ b/tests/test_environment.py @@ -5,15 +5,19 @@ from alembic.environment import EnvironmentContext from alembic.migration import MigrationContext from alembic.testing.fixtures import TestBase from alembic.testing.mock import Mock, call, MagicMock -from alembic.testing.env import _no_sql_testing_config, \ - staging_env, clear_staging_env, write_script, _sqlite_file_db +from alembic.testing.env import ( + _no_sql_testing_config, + staging_env, + clear_staging_env, + write_script, + _sqlite_file_db, +) from alembic.testing.assertions import expect_warnings from alembic.testing import eq_, is_ class EnvironmentTest(TestBase): - def setUp(self): staging_env() self.cfg = _no_sql_testing_config() @@ -23,49 +27,30 @@ class EnvironmentTest(TestBase): def _fixture(self, **kw): script = ScriptDirectory.from_config(self.cfg) - env = EnvironmentContext( - self.cfg, - script, - **kw - ) + env = EnvironmentContext(self.cfg, script, **kw) return env def test_x_arg(self): env = self._fixture() self.cfg.cmd_opts = Mock(x="y=5") - eq_( - env.get_x_argument(), - "y=5" - ) + eq_(env.get_x_argument(), "y=5") def test_x_arg_asdict(self): env = self._fixture() self.cfg.cmd_opts = Mock(x=["y=5"]) - eq_( - env.get_x_argument(as_dictionary=True), - {"y": "5"} - ) + eq_(env.get_x_argument(as_dictionary=True), {"y": "5"}) def test_x_arg_no_opts(self): env = self._fixture() - eq_( - env.get_x_argument(), - [] - ) + eq_(env.get_x_argument(), []) def test_x_arg_no_opts_asdict(self): env = self._fixture() - eq_( - env.get_x_argument(as_dictionary=True), - {} - ) + eq_(env.get_x_argument(as_dictionary=True), {}) def test_tag_arg(self): env = self._fixture(tag="x") - eq_( - env.get_tag_argument(), - "x" - ) + eq_(env.get_tag_argument(), "x") def test_migration_context_has_config(self): env = self._fixture() @@ -81,9 +66,12 @@ class EnvironmentTest(TestBase): engine = _sqlite_file_db() - a_rev = 'arev' + a_rev = "arev" env.script.generate_revision(a_rev, "revision a", refresh=True) - write_script(env.script, a_rev, """\ + write_script( + env.script, + a_rev, + """\ "Rev A" revision = '%s' down_revision = None @@ -98,7 +86,9 @@ def upgrade(): def downgrade(): pass -""" % a_rev) +""" + % a_rev, + ) migration_fn = MagicMock() def upgrade(rev, context): @@ -106,15 +96,13 @@ def downgrade(): return env.script._upgrade_revs(a_rev, rev) with expect_warnings( - r"'connection' argument to configure\(\) is " - r"expected to be a sqlalchemy.engine.Connection "): + r"'connection' argument to configure\(\) is " + r"expected to be a sqlalchemy.engine.Connection " + ): env.configure( - connection=engine, fn=upgrade, - transactional_ddl=False) + connection=engine, fn=upgrade, transactional_ddl=False + ) env.run_migrations() - eq_( - migration_fn.mock_calls, - [call((), env._migration_context)] - ) + eq_(migration_fn.mock_calls, [call((), env._migration_context)]) diff --git a/tests/test_external_dialect.py b/tests/test_external_dialect.py index dc01b75..1c3222d 100644 --- a/tests/test_external_dialect.py +++ b/tests/test_external_dialect.py @@ -15,6 +15,7 @@ from sqlalchemy.engine import default class CustomDialect(default.DefaultDialect): name = "custom_dialect" + try: from sqlalchemy.dialects import registry except ImportError: @@ -24,20 +25,22 @@ else: class CustomDialectImpl(impl.DefaultImpl): - __dialect__ = 'custom_dialect' + __dialect__ = "custom_dialect" transactional_ddl = False def render_type(self, type_, autogen_context): if type_.__module__ == __name__: autogen_context.imports.add( - "from %s import custom_dialect_types" % (__name__, )) + "from %s import custom_dialect_types" % (__name__,) + ) is_external = True else: is_external = False - if is_external and \ - hasattr(self, '_render_%s_type' % type_.__visit_name__): - meth = getattr(self, '_render_%s_type' % type_.__visit_name__) + if is_external and hasattr( + self, "_render_%s_type" % type_.__visit_name__ + ): + meth = getattr(self, "_render_%s_type" % type_.__visit_name__) return meth(type_, autogen_context) if is_external: @@ -47,13 +50,16 @@ class CustomDialectImpl(impl.DefaultImpl): def _render_EXT_ARRAY_type(self, type_, autogen_context): return render._render_type_w_subtype( - type_, autogen_context, 'item_type', r'(.+?\()', - prefix="custom_dialect_types." + type_, + autogen_context, + "item_type", + r"(.+?\()", + prefix="custom_dialect_types.", ) class EXT_ARRAY(sqla_types.TypeEngine): - __visit_name__ = 'EXT_ARRAY' + __visit_name__ = "EXT_ARRAY" def __init__(self, item_type): if isinstance(item_type, type): @@ -63,75 +69,78 @@ class EXT_ARRAY(sqla_types.TypeEngine): class FOOBARTYPE(sqla_types.TypeEngine): - __visit_name__ = 'FOOBARTYPE' + __visit_name__ = "FOOBARTYPE" class ExternalDialectRenderTest(TestBase): - def setUp(self): ctx_opts = { - 'sqlalchemy_module_prefix': 'sa.', - 'alembic_module_prefix': 'op.', - 'target_metadata': MetaData(), - 'user_module_prefix': None + "sqlalchemy_module_prefix": "sa.", + "alembic_module_prefix": "op.", + "target_metadata": MetaData(), + "user_module_prefix": None, } context = MigrationContext.configure( - dialect_name="custom_dialect", - opts=ctx_opts + dialect_name="custom_dialect", opts=ctx_opts ) self.autogen_context = api.AutogenContext(context) def test_render_type(self): eq_ignore_whitespace( - autogenerate.render._repr_type( - FOOBARTYPE(), self.autogen_context), - "custom_dialect_types.FOOBARTYPE()" + autogenerate.render._repr_type(FOOBARTYPE(), self.autogen_context), + "custom_dialect_types.FOOBARTYPE()", ) eq_( self.autogen_context.imports, - set([ - 'from tests.test_external_dialect import custom_dialect_types' - ]) + set( + [ + "from tests.test_external_dialect import custom_dialect_types" + ] + ), ) def test_external_nested_render_sqla_type(self): eq_ignore_whitespace( autogenerate.render._repr_type( - EXT_ARRAY(sqla_types.Integer), self.autogen_context), - "custom_dialect_types.EXT_ARRAY(sa.Integer())" + EXT_ARRAY(sqla_types.Integer), self.autogen_context + ), + "custom_dialect_types.EXT_ARRAY(sa.Integer())", ) eq_ignore_whitespace( autogenerate.render._repr_type( - EXT_ARRAY( - sqla_types.DateTime(timezone=True) - ), - self.autogen_context), - "custom_dialect_types.EXT_ARRAY(sa.DateTime(timezone=True))" + EXT_ARRAY(sqla_types.DateTime(timezone=True)), + self.autogen_context, + ), + "custom_dialect_types.EXT_ARRAY(sa.DateTime(timezone=True))", ) eq_( self.autogen_context.imports, - set([ - 'from tests.test_external_dialect import custom_dialect_types' - ]) + set( + [ + "from tests.test_external_dialect import custom_dialect_types" + ] + ), ) def test_external_nested_render_external_type(self): eq_ignore_whitespace( autogenerate.render._repr_type( - EXT_ARRAY(FOOBARTYPE), - self.autogen_context), - "custom_dialect_types.EXT_ARRAY(custom_dialect_types.FOOBARTYPE())" + EXT_ARRAY(FOOBARTYPE), self.autogen_context + ), + "custom_dialect_types.EXT_ARRAY(custom_dialect_types.FOOBARTYPE())", ) eq_( self.autogen_context.imports, - set([ - 'from tests.test_external_dialect import custom_dialect_types' - ]) + set( + [ + "from tests.test_external_dialect import custom_dialect_types" + ] + ), ) diff --git a/tests/test_mssql.py b/tests/test_mssql.py index b092dcf..1657ccc 100644 --- a/tests/test_mssql.py +++ b/tests/test_mssql.py @@ -8,13 +8,16 @@ from alembic import op, command, util from alembic.testing import eq_, assert_raises_message from alembic.testing.fixtures import capture_context_buffer, op_fixture -from alembic.testing.env import staging_env, _no_sql_testing_config, \ - three_rev_fixture, clear_staging_env +from alembic.testing.env import ( + staging_env, + _no_sql_testing_config, + three_rev_fixture, + clear_staging_env, +) from alembic.testing import config class FullEnvironmentTests(TestBase): - @classmethod def setup_class(cls): staging_env() @@ -24,8 +27,7 @@ class FullEnvironmentTests(TestBase): directives = "" cls.cfg = cfg = _no_sql_testing_config("mssql", directives) - cls.a, cls.b, cls.c = \ - three_rev_fixture(cfg) + cls.a, cls.b, cls.c = three_rev_fixture(cfg) @classmethod def teardown_class(cls): @@ -39,7 +41,7 @@ class FullEnvironmentTests(TestBase): # ensure ends in COMMIT; GO eq_( [x for x in buf.getvalue().splitlines() if x][-2:], - ['COMMIT;', 'GO'] + ["COMMIT;", "GO"], ) def test_batch_separator_default(self): @@ -54,242 +56,248 @@ class FullEnvironmentTests(TestBase): class OpTest(TestBase): - def test_add_column(self): - context = op_fixture('mssql') - op.add_column('t1', Column('c1', Integer, nullable=False)) + context = op_fixture("mssql") + op.add_column("t1", Column("c1", Integer, nullable=False)) context.assert_("ALTER TABLE t1 ADD c1 INTEGER NOT NULL") def test_add_column_with_default(self): context = op_fixture("mssql") op.add_column( - 't1', Column('c1', Integer, nullable=False, server_default="12")) + "t1", Column("c1", Integer, nullable=False, server_default="12") + ) context.assert_("ALTER TABLE t1 ADD c1 INTEGER NOT NULL DEFAULT '12'") def test_alter_column_rename_mssql(self): - context = op_fixture('mssql') + context = op_fixture("mssql") op.alter_column("t", "c", new_column_name="x") - context.assert_( - "EXEC sp_rename 't.c', x, 'COLUMN'" - ) + context.assert_("EXEC sp_rename 't.c', x, 'COLUMN'") def test_alter_column_rename_quoted_mssql(self): - context = op_fixture('mssql') + context = op_fixture("mssql") op.alter_column("t", "c", new_column_name="SomeFancyName") - context.assert_( - "EXEC sp_rename 't.c', [SomeFancyName], 'COLUMN'" - ) + context.assert_("EXEC sp_rename 't.c', [SomeFancyName], 'COLUMN'") def test_alter_column_new_type(self): - context = op_fixture('mssql') + context = op_fixture("mssql") op.alter_column("t", "c", type_=Integer) - context.assert_( - 'ALTER TABLE t ALTER COLUMN c INTEGER' - ) + context.assert_("ALTER TABLE t ALTER COLUMN c INTEGER") def test_alter_column_dont_touch_constraints(self): - context = op_fixture('mssql') + context = op_fixture("mssql") from sqlalchemy import Boolean - op.alter_column('tests', 'col', - existing_type=Boolean(), - nullable=False) - context.assert_('ALTER TABLE tests ALTER COLUMN col BIT NOT NULL') + + op.alter_column( + "tests", "col", existing_type=Boolean(), nullable=False + ) + context.assert_("ALTER TABLE tests ALTER COLUMN col BIT NOT NULL") def test_drop_index(self): - context = op_fixture('mssql') - op.drop_index('my_idx', 'my_table') + context = op_fixture("mssql") + op.drop_index("my_idx", "my_table") context.assert_contains("DROP INDEX my_idx ON my_table") def test_drop_column_w_default(self): - context = op_fixture('mssql') - op.drop_column('t1', 'c1', mssql_drop_default=True) - op.drop_column('t1', 'c2', mssql_drop_default=True) + context = op_fixture("mssql") + op.drop_column("t1", "c1", mssql_drop_default=True) + op.drop_column("t1", "c2", mssql_drop_default=True) context.assert_contains( - "exec('alter table t1 drop constraint ' + @const_name)") + "exec('alter table t1 drop constraint ' + @const_name)" + ) context.assert_contains("ALTER TABLE t1 DROP COLUMN c1") def test_drop_column_w_default_in_batch(self): - context = op_fixture('mssql') - with op.batch_alter_table('t1', schema=None) as batch_op: - batch_op.drop_column('c1', mssql_drop_default=True) - batch_op.drop_column('c2', mssql_drop_default=True) + context = op_fixture("mssql") + with op.batch_alter_table("t1", schema=None) as batch_op: + batch_op.drop_column("c1", mssql_drop_default=True) + batch_op.drop_column("c2", mssql_drop_default=True) context.assert_contains( - "exec('alter table t1 drop constraint ' + @const_name)") + "exec('alter table t1 drop constraint ' + @const_name)" + ) context.assert_contains("ALTER TABLE t1 DROP COLUMN c1") def test_alter_column_drop_default(self): - context = op_fixture('mssql') + context = op_fixture("mssql") op.alter_column("t", "c", server_default=None) context.assert_contains( - "exec('alter table t drop constraint ' + @const_name)") + "exec('alter table t drop constraint ' + @const_name)" + ) def test_alter_column_dont_drop_default(self): - context = op_fixture('mssql') + context = op_fixture("mssql") op.alter_column("t", "c", server_default=False) context.assert_() def test_drop_column_w_schema(self): - context = op_fixture('mssql') - op.drop_column('t1', 'c1', schema='xyz') + context = op_fixture("mssql") + op.drop_column("t1", "c1", schema="xyz") context.assert_contains("ALTER TABLE xyz.t1 DROP COLUMN c1") def test_drop_column_w_check(self): - context = op_fixture('mssql') - op.drop_column('t1', 'c1', mssql_drop_check=True) - op.drop_column('t1', 'c2', mssql_drop_check=True) + context = op_fixture("mssql") + op.drop_column("t1", "c1", mssql_drop_check=True) + op.drop_column("t1", "c2", mssql_drop_check=True) context.assert_contains( - "exec('alter table t1 drop constraint ' + @const_name)") + "exec('alter table t1 drop constraint ' + @const_name)" + ) context.assert_contains("ALTER TABLE t1 DROP COLUMN c1") def test_drop_column_w_check_in_batch(self): - context = op_fixture('mssql') - with op.batch_alter_table('t1', schema=None) as batch_op: - batch_op.drop_column('c1', mssql_drop_check=True) - batch_op.drop_column('c2', mssql_drop_check=True) + context = op_fixture("mssql") + with op.batch_alter_table("t1", schema=None) as batch_op: + batch_op.drop_column("c1", mssql_drop_check=True) + batch_op.drop_column("c2", mssql_drop_check=True) context.assert_contains( - "exec('alter table t1 drop constraint ' + @const_name)") + "exec('alter table t1 drop constraint ' + @const_name)" + ) context.assert_contains("ALTER TABLE t1 DROP COLUMN c1") def test_drop_column_w_check_quoting(self): - context = op_fixture('mssql') - op.drop_column('table', 'column', mssql_drop_check=True) + context = op_fixture("mssql") + op.drop_column("table", "column", mssql_drop_check=True) context.assert_contains( - "exec('alter table [table] drop constraint ' + @const_name)") + "exec('alter table [table] drop constraint ' + @const_name)" + ) context.assert_contains("ALTER TABLE [table] DROP COLUMN [column]") def test_alter_column_nullable_w_existing_type(self): - context = op_fixture('mssql') + context = op_fixture("mssql") op.alter_column("t", "c", nullable=True, existing_type=Integer) - context.assert_( - "ALTER TABLE t ALTER COLUMN c INTEGER NULL" - ) + context.assert_("ALTER TABLE t ALTER COLUMN c INTEGER NULL") def test_drop_column_w_fk(self): - context = op_fixture('mssql') - op.drop_column('t1', 'c1', mssql_drop_foreign_key=True) + context = op_fixture("mssql") + op.drop_column("t1", "c1", mssql_drop_foreign_key=True) context.assert_contains( - "exec('alter table t1 drop constraint ' + @const_name)") + "exec('alter table t1 drop constraint ' + @const_name)" + ) context.assert_contains("ALTER TABLE t1 DROP COLUMN c1") def test_drop_column_w_fk_in_batch(self): - context = op_fixture('mssql') - with op.batch_alter_table('t1', schema=None) as batch_op: - batch_op.drop_column('c1', mssql_drop_foreign_key=True) + context = op_fixture("mssql") + with op.batch_alter_table("t1", schema=None) as batch_op: + batch_op.drop_column("c1", mssql_drop_foreign_key=True) context.assert_contains( - "exec('alter table t1 drop constraint ' + @const_name)") + "exec('alter table t1 drop constraint ' + @const_name)" + ) context.assert_contains("ALTER TABLE t1 DROP COLUMN c1") def test_alter_column_not_nullable_w_existing_type(self): - context = op_fixture('mssql') + context = op_fixture("mssql") op.alter_column("t", "c", nullable=False, existing_type=Integer) - context.assert_( - "ALTER TABLE t ALTER COLUMN c INTEGER NOT NULL" - ) + context.assert_("ALTER TABLE t ALTER COLUMN c INTEGER NOT NULL") def test_alter_column_nullable_w_new_type(self): - context = op_fixture('mssql') + context = op_fixture("mssql") op.alter_column("t", "c", nullable=True, type_=Integer) - context.assert_( - "ALTER TABLE t ALTER COLUMN c INTEGER NULL" - ) + context.assert_("ALTER TABLE t ALTER COLUMN c INTEGER NULL") def test_alter_column_not_nullable_w_new_type(self): - context = op_fixture('mssql') + context = op_fixture("mssql") op.alter_column("t", "c", nullable=False, type_=Integer) - context.assert_( - "ALTER TABLE t ALTER COLUMN c INTEGER NOT NULL" - ) + context.assert_("ALTER TABLE t ALTER COLUMN c INTEGER NOT NULL") def test_alter_column_nullable_type_required(self): - context = op_fixture('mssql') + context = op_fixture("mssql") assert_raises_message( util.CommandError, "MS-SQL ALTER COLUMN operations with NULL or " "NOT NULL require the existing_type or a new " "type_ be passed.", - op.alter_column, "t", "c", nullable=False + op.alter_column, + "t", + "c", + nullable=False, ) def test_alter_add_server_default(self): - context = op_fixture('mssql') + context = op_fixture("mssql") op.alter_column("t", "c", server_default="5") - context.assert_( - "ALTER TABLE t ADD DEFAULT '5' FOR c" - ) + context.assert_("ALTER TABLE t ADD DEFAULT '5' FOR c") def test_alter_replace_server_default(self): - context = op_fixture('mssql') + context = op_fixture("mssql") op.alter_column( - "t", "c", server_default="5", existing_server_default="6") - context.assert_contains( - "exec('alter table t drop constraint ' + @const_name)") + "t", "c", server_default="5", existing_server_default="6" + ) context.assert_contains( - "ALTER TABLE t ADD DEFAULT '5' FOR c" + "exec('alter table t drop constraint ' + @const_name)" ) + context.assert_contains("ALTER TABLE t ADD DEFAULT '5' FOR c") def test_alter_remove_server_default(self): - context = op_fixture('mssql') + context = op_fixture("mssql") op.alter_column("t", "c", server_default=None) context.assert_contains( - "exec('alter table t drop constraint ' + @const_name)") + "exec('alter table t drop constraint ' + @const_name)" + ) def test_alter_do_everything(self): - context = op_fixture('mssql') - op.alter_column("t", "c", new_column_name="c2", nullable=True, - type_=Integer, server_default="5") + context = op_fixture("mssql") + op.alter_column( + "t", + "c", + new_column_name="c2", + nullable=True, + type_=Integer, + server_default="5", + ) context.assert_( - 'ALTER TABLE t ALTER COLUMN c INTEGER NULL', + "ALTER TABLE t ALTER COLUMN c INTEGER NULL", "ALTER TABLE t ADD DEFAULT '5' FOR c", - "EXEC sp_rename 't.c', c2, 'COLUMN'" + "EXEC sp_rename 't.c', c2, 'COLUMN'", ) def test_rename_table(self): - context = op_fixture('mssql') - op.rename_table('t1', 't2') + context = op_fixture("mssql") + op.rename_table("t1", "t2") context.assert_contains("EXEC sp_rename 't1', t2") def test_rename_table_schema(self): - context = op_fixture('mssql') - op.rename_table('t1', 't2', schema="foobar") + context = op_fixture("mssql") + op.rename_table("t1", "t2", schema="foobar") context.assert_contains("EXEC sp_rename 'foobar.t1', t2") def test_rename_table_casesens(self): - context = op_fixture('mssql') - op.rename_table('TeeOne', 'TeeTwo') + context = op_fixture("mssql") + op.rename_table("TeeOne", "TeeTwo") # yup, ran this in SQL Server 2014, the two levels of quoting # seems to be understood. Can't do the two levels on the # target name though ! context.assert_contains("EXEC sp_rename '[TeeOne]', [TeeTwo]") def test_rename_table_schema_casesens(self): - context = op_fixture('mssql') - op.rename_table('TeeOne', 'TeeTwo', schema="FooBar") + context = op_fixture("mssql") + op.rename_table("TeeOne", "TeeTwo", schema="FooBar") # yup, ran this in SQL Server 2014, the two levels of quoting # seems to be understood. Can't do the two levels on the # target name though ! context.assert_contains("EXEC sp_rename '[FooBar].[TeeOne]', [TeeTwo]") def test_alter_column_rename_mssql_schema(self): - context = op_fixture('mssql') + context = op_fixture("mssql") op.alter_column("t", "c", name="x", schema="y") - context.assert_( - "EXEC sp_rename 'y.t.c', x, 'COLUMN'" - ) + context.assert_("EXEC sp_rename 'y.t.c', x, 'COLUMN'") def test_create_index_mssql_include(self): - context = op_fixture('mssql') + context = op_fixture("mssql") op.create_index( - op.f('ix_mytable_a_b'), 'mytable', ['col_a', 'col_b'], - unique=False, mssql_include=['col_c']) + op.f("ix_mytable_a_b"), + "mytable", + ["col_a", "col_b"], + unique=False, + mssql_include=["col_c"], + ) context.assert_contains( "CREATE INDEX ix_mytable_a_b ON mytable " - "(col_a, col_b) INCLUDE (col_c)") + "(col_a, col_b) INCLUDE (col_c)" + ) def test_create_index_mssql_include_is_none(self): - context = op_fixture('mssql') + context = op_fixture("mssql") op.create_index( - op.f('ix_mytable_a_b'), 'mytable', ['col_a', 'col_b'], - unique=False) + op.f("ix_mytable_a_b"), "mytable", ["col_a", "col_b"], unique=False + ) context.assert_contains( - "CREATE INDEX ix_mytable_a_b ON mytable " - "(col_a, col_b)") + "CREATE INDEX ix_mytable_a_b ON mytable " "(col_a, col_b)" + ) diff --git a/tests/test_mysql.py b/tests/test_mysql.py index dd872f7..68746ba 100644 --- a/tests/test_mysql.py +++ b/tests/test_mysql.py @@ -7,53 +7,68 @@ from alembic import op, util from alembic.testing import eq_, assert_raises_message from alembic.testing.fixtures import capture_context_buffer, op_fixture -from alembic.testing.env import staging_env, _no_sql_testing_config, \ - three_rev_fixture, clear_staging_env +from alembic.testing.env import ( + staging_env, + _no_sql_testing_config, + three_rev_fixture, + clear_staging_env, +) from alembic.migration import MigrationContext class MySQLOpTest(TestBase): - def test_rename_column(self): - context = op_fixture('mysql') + context = op_fixture("mysql") op.alter_column( - 't1', 'c1', new_column_name="c2", existing_type=Integer) - context.assert_( - 'ALTER TABLE t1 CHANGE c1 c2 INTEGER NULL' + "t1", "c1", new_column_name="c2", existing_type=Integer ) + context.assert_("ALTER TABLE t1 CHANGE c1 c2 INTEGER NULL") def test_rename_column_quotes_needed_one(self): - context = op_fixture('mysql') - op.alter_column('MyTable', 'ColumnOne', new_column_name="ColumnTwo", - existing_type=Integer) + context = op_fixture("mysql") + op.alter_column( + "MyTable", + "ColumnOne", + new_column_name="ColumnTwo", + existing_type=Integer, + ) context.assert_( - 'ALTER TABLE `MyTable` CHANGE `ColumnOne` `ColumnTwo` INTEGER NULL' + "ALTER TABLE `MyTable` CHANGE `ColumnOne` `ColumnTwo` INTEGER NULL" ) def test_rename_column_quotes_needed_two(self): - context = op_fixture('mysql') - op.alter_column('my table', 'column one', new_column_name="column two", - existing_type=Integer) + context = op_fixture("mysql") + op.alter_column( + "my table", + "column one", + new_column_name="column two", + existing_type=Integer, + ) context.assert_( - 'ALTER TABLE `my table` CHANGE `column one` ' - '`column two` INTEGER NULL' + "ALTER TABLE `my table` CHANGE `column one` " + "`column two` INTEGER NULL" ) def test_rename_column_serv_default(self): - context = op_fixture('mysql') + context = op_fixture("mysql") op.alter_column( - 't1', 'c1', new_column_name="c2", existing_type=Integer, - existing_server_default="q") - context.assert_( - "ALTER TABLE t1 CHANGE c1 c2 INTEGER NULL DEFAULT 'q'" + "t1", + "c1", + new_column_name="c2", + existing_type=Integer, + existing_server_default="q", ) + context.assert_("ALTER TABLE t1 CHANGE c1 c2 INTEGER NULL DEFAULT 'q'") def test_rename_column_serv_compiled_default(self): - context = op_fixture('mysql') + context = op_fixture("mysql") op.alter_column( - 't1', 'c1', existing_type=Integer, - server_default=func.utc_thing(func.current_timestamp())) + "t1", + "c1", + existing_type=Integer, + server_default=func.utc_thing(func.current_timestamp()), + ) # this is not a valid MySQL default but the point is to just # test SQL expression rendering context.assert_( @@ -62,184 +77,183 @@ class MySQLOpTest(TestBase): ) def test_rename_column_autoincrement(self): - context = op_fixture('mysql') + context = op_fixture("mysql") op.alter_column( - 't1', 'c1', new_column_name="c2", existing_type=Integer, - existing_autoincrement=True) + "t1", + "c1", + new_column_name="c2", + existing_type=Integer, + existing_autoincrement=True, + ) context.assert_( - 'ALTER TABLE t1 CHANGE c1 c2 INTEGER NULL AUTO_INCREMENT' + "ALTER TABLE t1 CHANGE c1 c2 INTEGER NULL AUTO_INCREMENT" ) def test_col_add_autoincrement(self): - context = op_fixture('mysql') - op.alter_column('t1', 'c1', existing_type=Integer, - autoincrement=True) - context.assert_( - 'ALTER TABLE t1 MODIFY c1 INTEGER NULL AUTO_INCREMENT' - ) + context = op_fixture("mysql") + op.alter_column("t1", "c1", existing_type=Integer, autoincrement=True) + context.assert_("ALTER TABLE t1 MODIFY c1 INTEGER NULL AUTO_INCREMENT") def test_col_remove_autoincrement(self): - context = op_fixture('mysql') - op.alter_column('t1', 'c1', existing_type=Integer, - existing_autoincrement=True, - autoincrement=False) - context.assert_( - 'ALTER TABLE t1 MODIFY c1 INTEGER NULL' + context = op_fixture("mysql") + op.alter_column( + "t1", + "c1", + existing_type=Integer, + existing_autoincrement=True, + autoincrement=False, ) + context.assert_("ALTER TABLE t1 MODIFY c1 INTEGER NULL") def test_col_dont_remove_server_default(self): - context = op_fixture('mysql') - op.alter_column('t1', 'c1', existing_type=Integer, - existing_server_default='1', - server_default=False) + context = op_fixture("mysql") + op.alter_column( + "t1", + "c1", + existing_type=Integer, + existing_server_default="1", + server_default=False, + ) context.assert_() def test_alter_column_drop_default(self): - context = op_fixture('mysql') + context = op_fixture("mysql") op.alter_column("t", "c", existing_type=Integer, server_default=None) - context.assert_( - 'ALTER TABLE t ALTER COLUMN c DROP DEFAULT' - ) + context.assert_("ALTER TABLE t ALTER COLUMN c DROP DEFAULT") def test_alter_column_remove_schematype(self): - context = op_fixture('mysql') + context = op_fixture("mysql") op.alter_column( - "t", "c", + "t", + "c", type_=Integer, existing_type=Boolean(create_constraint=True, name="ck1"), - server_default=None) - context.assert_( - 'ALTER TABLE t MODIFY c INTEGER NULL' + server_default=None, ) + context.assert_("ALTER TABLE t MODIFY c INTEGER NULL") def test_alter_column_modify_default(self): - context = op_fixture('mysql') + context = op_fixture("mysql") # notice we dont need the existing type on this one... - op.alter_column("t", "c", server_default='1') - context.assert_( - "ALTER TABLE t ALTER COLUMN c SET DEFAULT '1'" - ) + op.alter_column("t", "c", server_default="1") + context.assert_("ALTER TABLE t ALTER COLUMN c SET DEFAULT '1'") def test_col_not_nullable(self): - context = op_fixture('mysql') - op.alter_column('t1', 'c1', nullable=False, existing_type=Integer) - context.assert_( - 'ALTER TABLE t1 MODIFY c1 INTEGER NOT NULL' - ) + context = op_fixture("mysql") + op.alter_column("t1", "c1", nullable=False, existing_type=Integer) + context.assert_("ALTER TABLE t1 MODIFY c1 INTEGER NOT NULL") def test_col_not_nullable_existing_serv_default(self): - context = op_fixture('mysql') - op.alter_column('t1', 'c1', nullable=False, existing_type=Integer, - existing_server_default='5') + context = op_fixture("mysql") + op.alter_column( + "t1", + "c1", + nullable=False, + existing_type=Integer, + existing_server_default="5", + ) context.assert_( "ALTER TABLE t1 MODIFY c1 INTEGER NOT NULL DEFAULT '5'" ) def test_col_nullable(self): - context = op_fixture('mysql') - op.alter_column('t1', 'c1', nullable=True, existing_type=Integer) - context.assert_( - 'ALTER TABLE t1 MODIFY c1 INTEGER NULL' - ) + context = op_fixture("mysql") + op.alter_column("t1", "c1", nullable=True, existing_type=Integer) + context.assert_("ALTER TABLE t1 MODIFY c1 INTEGER NULL") def test_col_multi_alter(self): - context = op_fixture('mysql') + context = op_fixture("mysql") op.alter_column( - 't1', 'c1', nullable=False, server_default="q", type_=Integer) + "t1", "c1", nullable=False, server_default="q", type_=Integer + ) context.assert_( "ALTER TABLE t1 MODIFY c1 INTEGER NOT NULL DEFAULT 'q'" ) def test_alter_column_multi_alter_w_drop_default(self): - context = op_fixture('mysql') + context = op_fixture("mysql") op.alter_column( - 't1', 'c1', nullable=False, server_default=None, type_=Integer) - context.assert_( - "ALTER TABLE t1 MODIFY c1 INTEGER NOT NULL" + "t1", "c1", nullable=False, server_default=None, type_=Integer ) + context.assert_("ALTER TABLE t1 MODIFY c1 INTEGER NOT NULL") def test_col_alter_type_required(self): - op_fixture('mysql') + op_fixture("mysql") assert_raises_message( util.CommandError, "MySQL CHANGE/MODIFY COLUMN operations require the existing type.", - op.alter_column, 't1', 'c1', nullable=False, server_default="q" + op.alter_column, + "t1", + "c1", + nullable=False, + server_default="q", ) def test_drop_fk(self): - context = op_fixture('mysql') + context = op_fixture("mysql") op.drop_constraint("f1", "t1", "foreignkey") - context.assert_( - "ALTER TABLE t1 DROP FOREIGN KEY f1" - ) + context.assert_("ALTER TABLE t1 DROP FOREIGN KEY f1") def test_drop_fk_quoted(self): - context = op_fixture('mysql') + context = op_fixture("mysql") op.drop_constraint("MyFk", "MyTable", "foreignkey") - context.assert_( - "ALTER TABLE `MyTable` DROP FOREIGN KEY `MyFk`" - ) + context.assert_("ALTER TABLE `MyTable` DROP FOREIGN KEY `MyFk`") def test_drop_constraint_primary(self): - context = op_fixture('mysql') - op.drop_constraint('primary', 't1', type_='primary') - context.assert_( - "ALTER TABLE t1 DROP PRIMARY KEY" - ) + context = op_fixture("mysql") + op.drop_constraint("primary", "t1", type_="primary") + context.assert_("ALTER TABLE t1 DROP PRIMARY KEY") def test_drop_unique(self): - context = op_fixture('mysql') + context = op_fixture("mysql") op.drop_constraint("f1", "t1", "unique") - context.assert_( - "ALTER TABLE t1 DROP INDEX f1" - ) + context.assert_("ALTER TABLE t1 DROP INDEX f1") def test_drop_unique_quoted(self): - context = op_fixture('mysql') + context = op_fixture("mysql") op.drop_constraint("MyUnique", "MyTable", "unique") - context.assert_( - "ALTER TABLE `MyTable` DROP INDEX `MyUnique`" - ) + context.assert_("ALTER TABLE `MyTable` DROP INDEX `MyUnique`") def test_drop_check(self): - context = op_fixture('mysql') + context = op_fixture("mysql") op.drop_constraint("f1", "t1", "check") - context.assert_( - "ALTER TABLE t1 DROP CONSTRAINT f1" - ) + context.assert_("ALTER TABLE t1 DROP CONSTRAINT f1") def test_drop_check_quoted(self): - context = op_fixture('mysql') + context = op_fixture("mysql") op.drop_constraint("MyCheck", "MyTable", "check") - context.assert_( - "ALTER TABLE `MyTable` DROP CONSTRAINT `MyCheck`" - ) + context.assert_("ALTER TABLE `MyTable` DROP CONSTRAINT `MyCheck`") def test_drop_unknown(self): - op_fixture('mysql') + op_fixture("mysql") assert_raises_message( TypeError, "'type' can be one of 'check', 'foreignkey', " "'primary', 'unique', None", - op.drop_constraint, "f1", "t1", "typo" + op.drop_constraint, + "f1", + "t1", + "typo", ) def test_drop_generic_constraint(self): - op_fixture('mysql') + op_fixture("mysql") assert_raises_message( NotImplementedError, "No generic 'DROP CONSTRAINT' in MySQL - please " "specify constraint type", - op.drop_constraint, "f1", "t1" + op.drop_constraint, + "f1", + "t1", ) class MySQLDefaultCompareTest(TestBase): - __only_on__ = 'mysql' + __only_on__ = "mysql" __backend__ = True - __requires__ = 'mysql_timestamp_reflection', + __requires__ = ("mysql_timestamp_reflection",) @classmethod def setup_class(cls): @@ -247,17 +261,14 @@ class MySQLDefaultCompareTest(TestBase): staging_env() context = MigrationContext.configure( connection=cls.bind.connect(), - opts={ - 'compare_type': True, - 'compare_server_default': True - } + opts={"compare_type": True, "compare_server_default": True}, ) connection = context.bind cls.autogen_context = { - 'imports': set(), - 'connection': connection, - 'dialect': connection.dialect, - 'context': context + "imports": set(), + "connection": connection, + "dialect": connection.dialect, + "context": context, } @classmethod @@ -277,64 +288,46 @@ class MySQLDefaultCompareTest(TestBase): alternate = txt expected = False t = Table( - "test", self.metadata, + "test", + self.metadata, Column( - "somecol", type_, - server_default=text(txt) if txt else None - ) + "somecol", type_, server_default=text(txt) if txt else None + ), + ) + t2 = Table( + "test", + MetaData(), + Column("somecol", type_, server_default=text(alternate)), ) - t2 = Table("test", MetaData(), - Column("somecol", type_, server_default=text(alternate)) - ) - assert self._compare_default( - t, t2, t2.c.somecol, alternate - ) is expected - - def _compare_default( - self, - t1, t2, col, - rendered - ): + assert ( + self._compare_default(t, t2, t2.c.somecol, alternate) is expected + ) + + def _compare_default(self, t1, t2, col, rendered): t1.create(self.bind) insp = Inspector.from_engine(self.bind) cols = insp.get_columns(t1.name) refl = Table(t1.name, MetaData()) insp.reflecttable(refl, None) - ctx = self.autogen_context['context'] + ctx = self.autogen_context["context"] return ctx.impl.compare_server_default( - refl.c[cols[0]['name']], - col, - rendered, - cols[0]['default']) + refl.c[cols[0]["name"]], col, rendered, cols[0]["default"] + ) def test_compare_timestamp_current_timestamp(self): - self._compare_default_roundtrip( - TIMESTAMP(), - "CURRENT_TIMESTAMP", - ) + self._compare_default_roundtrip(TIMESTAMP(), "CURRENT_TIMESTAMP") def test_compare_timestamp_current_timestamp_diff(self): - self._compare_default_roundtrip( - TIMESTAMP(), - None, "CURRENT_TIMESTAMP", - ) + self._compare_default_roundtrip(TIMESTAMP(), None, "CURRENT_TIMESTAMP") def test_compare_integer_same(self): - self._compare_default_roundtrip( - Integer(), "5" - ) + self._compare_default_roundtrip(Integer(), "5") def test_compare_integer_diff(self): - self._compare_default_roundtrip( - Integer(), "5", "7" - ) + self._compare_default_roundtrip(Integer(), "5", "7") def test_compare_boolean_same(self): - self._compare_default_roundtrip( - Boolean(), "1" - ) + self._compare_default_roundtrip(Boolean(), "1") def test_compare_boolean_diff(self): - self._compare_default_roundtrip( - Boolean(), "1", "0" - ) + self._compare_default_roundtrip(Boolean(), "1", "0") diff --git a/tests/test_offline_environment.py b/tests/test_offline_environment.py index 3920690..fbbbec3 100644 --- a/tests/test_offline_environment.py +++ b/tests/test_offline_environment.py @@ -3,16 +3,20 @@ from alembic.testing.fixtures import TestBase, capture_context_buffer from alembic import command, util from alembic.testing import assert_raises_message -from alembic.testing.env import staging_env, _no_sql_testing_config, \ - three_rev_fixture, clear_staging_env, env_file_fixture, \ - multi_heads_fixture +from alembic.testing.env import ( + staging_env, + _no_sql_testing_config, + three_rev_fixture, + clear_staging_env, + env_file_fixture, + multi_heads_fixture, +) import re a = b = c = None class OfflineEnvironmentTest(TestBase): - def setUp(self): staging_env() self.cfg = _no_sql_testing_config() @@ -24,92 +28,122 @@ class OfflineEnvironmentTest(TestBase): clear_staging_env() def test_not_requires_connection(self): - env_file_fixture(""" + env_file_fixture( + """ assert not context.requires_connection() -""") +""" + ) command.upgrade(self.cfg, a, sql=True) command.downgrade(self.cfg, "%s:%s" % (b, a), sql=True) def test_requires_connection(self): - env_file_fixture(""" + env_file_fixture( + """ assert context.requires_connection() -""") +""" + ) command.upgrade(self.cfg, a) command.downgrade(self.cfg, a) def test_starting_rev_post_context(self): - env_file_fixture(""" + env_file_fixture( + """ context.configure(dialect_name='sqlite', starting_rev='x') assert context.get_starting_revision_argument() == 'x' -""") +""" + ) command.upgrade(self.cfg, a, sql=True) command.downgrade(self.cfg, "%s:%s" % (b, a), sql=True) command.current(self.cfg) command.stamp(self.cfg, a) def test_starting_rev_pre_context(self): - env_file_fixture(""" + env_file_fixture( + """ assert context.get_starting_revision_argument() == 'x' -""") +""" + ) command.upgrade(self.cfg, "x:y", sql=True) command.downgrade(self.cfg, "x:y", sql=True) def test_starting_rev_pre_context_cmd_w_no_startrev(self): - env_file_fixture(""" + env_file_fixture( + """ assert context.get_starting_revision_argument() == 'x' -""") +""" + ) assert_raises_message( util.CommandError, "No starting revision argument is available.", - command.current, self.cfg) + command.current, + self.cfg, + ) def test_starting_rev_current_pre_context(self): - env_file_fixture(""" + env_file_fixture( + """ assert context.get_starting_revision_argument() is None -""") +""" + ) assert_raises_message( util.CommandError, "No starting revision argument is available.", - command.current, self.cfg + command.current, + self.cfg, ) def test_destination_rev_pre_context(self): - env_file_fixture(""" + env_file_fixture( + """ assert context.get_revision_argument() == '%s' -""" % b) +""" + % b + ) command.upgrade(self.cfg, b, sql=True) command.stamp(self.cfg, b, sql=True) command.downgrade(self.cfg, "%s:%s" % (c, b), sql=True) def test_destination_rev_pre_context_multihead(self): d, e, f = multi_heads_fixture(self.cfg, a, b, c) - env_file_fixture(""" + env_file_fixture( + """ assert set(context.get_revision_argument()) == set(('%s', '%s', '%s', )) -""" % (f, e, c)) - command.upgrade(self.cfg, 'heads', sql=True) +""" + % (f, e, c) + ) + command.upgrade(self.cfg, "heads", sql=True) def test_destination_rev_post_context(self): - env_file_fixture(""" + env_file_fixture( + """ context.configure(dialect_name='sqlite') assert context.get_revision_argument() == '%s' -""" % b) +""" + % b + ) command.upgrade(self.cfg, b, sql=True) command.downgrade(self.cfg, "%s:%s" % (c, b), sql=True) command.stamp(self.cfg, b, sql=True) def test_destination_rev_post_context_multihead(self): d, e, f = multi_heads_fixture(self.cfg, a, b, c) - env_file_fixture(""" + env_file_fixture( + """ context.configure(dialect_name='sqlite') assert set(context.get_revision_argument()) == set(('%s', '%s', '%s', )) -""" % (f, e, c)) - command.upgrade(self.cfg, 'heads', sql=True) +""" + % (f, e, c) + ) + command.upgrade(self.cfg, "heads", sql=True) def test_head_rev_pre_context(self): - env_file_fixture(""" + env_file_fixture( + """ assert context.get_head_revision() == '%s' assert context.get_head_revisions() == ('%s', ) -""" % (c, c)) +""" + % (c, c) + ) command.upgrade(self.cfg, b, sql=True) command.downgrade(self.cfg, "%s:%s" % (b, a), sql=True) command.stamp(self.cfg, b, sql=True) @@ -117,20 +151,26 @@ assert context.get_head_revisions() == ('%s', ) def test_head_rev_pre_context_multihead(self): d, e, f = multi_heads_fixture(self.cfg, a, b, c) - env_file_fixture(""" + env_file_fixture( + """ assert set(context.get_head_revisions()) == set(('%s', '%s', '%s', )) -""" % (e, f, c)) +""" + % (e, f, c) + ) command.upgrade(self.cfg, e, sql=True) command.downgrade(self.cfg, "%s:%s" % (e, b), sql=True) command.stamp(self.cfg, c, sql=True) command.current(self.cfg) def test_head_rev_post_context(self): - env_file_fixture(""" + env_file_fixture( + """ context.configure(dialect_name='sqlite') assert context.get_head_revision() == '%s' assert context.get_head_revisions() == ('%s', ) -""" % (c, c)) +""" + % (c, c) + ) command.upgrade(self.cfg, b, sql=True) command.downgrade(self.cfg, "%s:%s" % (b, a), sql=True) command.stamp(self.cfg, b, sql=True) @@ -138,70 +178,89 @@ assert context.get_head_revisions() == ('%s', ) def test_head_rev_post_context_multihead(self): d, e, f = multi_heads_fixture(self.cfg, a, b, c) - env_file_fixture(""" + env_file_fixture( + """ context.configure(dialect_name='sqlite') assert set(context.get_head_revisions()) == set(('%s', '%s', '%s', )) -""" % (e, f, c)) +""" + % (e, f, c) + ) command.upgrade(self.cfg, e, sql=True) command.downgrade(self.cfg, "%s:%s" % (e, b), sql=True) command.stamp(self.cfg, c, sql=True) command.current(self.cfg) def test_tag_pre_context(self): - env_file_fixture(""" + env_file_fixture( + """ assert context.get_tag_argument() == 'hi' -""") - command.upgrade(self.cfg, b, sql=True, tag='hi') - command.downgrade(self.cfg, "%s:%s" % (b, a), sql=True, tag='hi') +""" + ) + command.upgrade(self.cfg, b, sql=True, tag="hi") + command.downgrade(self.cfg, "%s:%s" % (b, a), sql=True, tag="hi") def test_tag_pre_context_None(self): - env_file_fixture(""" + env_file_fixture( + """ assert context.get_tag_argument() is None -""") +""" + ) command.upgrade(self.cfg, b, sql=True) command.downgrade(self.cfg, "%s:%s" % (b, a), sql=True) def test_tag_cmd_arg(self): - env_file_fixture(""" + env_file_fixture( + """ context.configure(dialect_name='sqlite') assert context.get_tag_argument() == 'hi' -""") - command.upgrade(self.cfg, b, sql=True, tag='hi') - command.downgrade(self.cfg, "%s:%s" % (b, a), sql=True, tag='hi') +""" + ) + command.upgrade(self.cfg, b, sql=True, tag="hi") + command.downgrade(self.cfg, "%s:%s" % (b, a), sql=True, tag="hi") def test_tag_cfg_arg(self): - env_file_fixture(""" + env_file_fixture( + """ context.configure(dialect_name='sqlite', tag='there') assert context.get_tag_argument() == 'there' -""") - command.upgrade(self.cfg, b, sql=True, tag='hi') - command.downgrade(self.cfg, "%s:%s" % (b, a), sql=True, tag='hi') +""" + ) + command.upgrade(self.cfg, b, sql=True, tag="hi") + command.downgrade(self.cfg, "%s:%s" % (b, a), sql=True, tag="hi") def test_tag_None(self): - env_file_fixture(""" + env_file_fixture( + """ context.configure(dialect_name='sqlite') assert context.get_tag_argument() is None -""") +""" + ) command.upgrade(self.cfg, b, sql=True) command.downgrade(self.cfg, "%s:%s" % (b, a), sql=True) def test_downgrade_wo_colon(self): - env_file_fixture(""" + env_file_fixture( + """ context.configure(dialect_name='sqlite') -""") +""" + ) assert_raises_message( util.CommandError, "downgrade with --sql requires <fromrev>:<torev>", command.downgrade, - self.cfg, b, sql=True + self.cfg, + b, + sql=True, ) def test_upgrade_with_output_encoding(self): - env_file_fixture(""" + env_file_fixture( + """ url = config.get_main_option('sqlalchemy.url') context.configure(url=url, output_encoding='utf-8') assert not context.requires_connection() -""") +""" + ) command.upgrade(self.cfg, a, sql=True) command.downgrade(self.cfg, "%s:%s" % (b, a), sql=True) @@ -213,37 +272,49 @@ assert not context.requires_connection() with capture_context_buffer(transactional_ddl=True) as buf: command.upgrade(self.cfg, "%s:%s" % (a, d.revision), sql=True) - assert not re.match(r".*-- .*and multiline", buf.getvalue(), re.S | re.M) + assert not re.match( + r".*-- .*and multiline", buf.getvalue(), re.S | re.M + ) def test_starting_rev_pre_context_abbreviated(self): - env_file_fixture(""" + env_file_fixture( + """ assert context.get_starting_revision_argument() == '%s' -""" % b[0:4]) +""" + % b[0:4] + ) command.upgrade(self.cfg, "%s:%s" % (b[0:4], c), sql=True) command.stamp(self.cfg, "%s:%s" % (b[0:4], c), sql=True) command.downgrade(self.cfg, "%s:%s" % (b[0:4], a), sql=True) def test_destination_rev_pre_context_abbreviated(self): - env_file_fixture(""" + env_file_fixture( + """ assert context.get_revision_argument() == '%s' -""" % b[0:4]) +""" + % b[0:4] + ) command.upgrade(self.cfg, "%s:%s" % (a, b[0:4]), sql=True) command.stamp(self.cfg, b[0:4], sql=True) command.downgrade(self.cfg, "%s:%s" % (c, b[0:4]), sql=True) def test_starting_rev_context_runs_abbreviated(self): - env_file_fixture(""" + env_file_fixture( + """ context.configure(dialect_name='sqlite') context.run_migrations() -""") +""" + ) command.upgrade(self.cfg, "%s:%s" % (b[0:4], c), sql=True) command.downgrade(self.cfg, "%s:%s" % (b[0:4], a), sql=True) def test_destination_rev_context_runs_abbreviated(self): - env_file_fixture(""" + env_file_fixture( + """ context.configure(dialect_name='sqlite') context.run_migrations() -""") +""" + ) command.upgrade(self.cfg, "%s:%s" % (a, b[0:4]), sql=True) command.stamp(self.cfg, b[0:4], sql=True) command.downgrade(self.cfg, "%s:%s" % (c, b[0:4]), sql=True) diff --git a/tests/test_op.py b/tests/test_op.py index f9a6c51..fb2db5f 100644 --- a/tests/test_op.py +++ b/tests/test_op.py @@ -1,7 +1,6 @@ """Test against the builders in the op.* module.""" -from sqlalchemy import Integer, Column, ForeignKey, \ - Table, String, Boolean +from sqlalchemy import Integer, Column, ForeignKey, Table, String, Boolean from sqlalchemy.sql import column, func, text from sqlalchemy import event @@ -17,19 +16,18 @@ from alembic.operations import schemaobj, ops @event.listens_for(Table, "after_parent_attach") def _add_cols(table, metadata): if table.name == "tbl_with_auto_appended_column": - table.append_column(Column('bat', Integer)) + table.append_column(Column("bat", Integer)) class OpTest(TestBase): - def test_rename_table(self): context = op_fixture() - op.rename_table('t1', 't2') + op.rename_table("t1", "t2") context.assert_("ALTER TABLE t1 RENAME TO t2") def test_rename_table_schema(self): context = op_fixture() - op.rename_table('t1', 't2', schema="foo") + op.rename_table("t1", "t2", schema="foo") context.assert_("ALTER TABLE foo.t1 RENAME TO foo.t2") def test_create_index_no_expr_allowed(self): @@ -37,15 +35,21 @@ class OpTest(TestBase): assert_raises_message( ValueError, r"String or text\(\) construct expected", - op.create_index, 'name', 'tname', [func.foo(column('x'))] + op.create_index, + "name", + "tname", + [func.foo(column("x"))], ) def test_add_column_schema_hard_quoting(self): from sqlalchemy.sql.schema import quoted_name + context = op_fixture("postgresql") op.add_column( - "somename", Column("colname", String), - schema=quoted_name("some.schema", quote=True)) + "somename", + Column("colname", String), + schema=quoted_name("some.schema", quote=True), + ) context.assert_( 'ALTER TABLE "some.schema".somename ADD COLUMN colname VARCHAR' @@ -53,68 +57,67 @@ class OpTest(TestBase): def test_rename_table_schema_hard_quoting(self): from sqlalchemy.sql.schema import quoted_name + context = op_fixture("postgresql") op.rename_table( - 't1', 't2', - schema=quoted_name("some.schema", quote=True)) - - context.assert_( - 'ALTER TABLE "some.schema".t1 RENAME TO t2' + "t1", "t2", schema=quoted_name("some.schema", quote=True) ) + context.assert_('ALTER TABLE "some.schema".t1 RENAME TO t2') + def test_add_constraint_schema_hard_quoting(self): from sqlalchemy.sql.schema import quoted_name + context = op_fixture("postgresql") op.create_check_constraint( "ck_user_name_len", "user_table", - func.len(column('name')) > 5, - schema=quoted_name("some.schema", quote=True) + func.len(column("name")) > 5, + schema=quoted_name("some.schema", quote=True), ) context.assert_( 'ALTER TABLE "some.schema".user_table ADD ' - 'CONSTRAINT ck_user_name_len CHECK (len(name) > 5)' + "CONSTRAINT ck_user_name_len CHECK (len(name) > 5)" ) def test_create_index_quoting(self): context = op_fixture("postgresql") - op.create_index( - 'geocoded', - 'locations', - ["IShouldBeQuoted"]) + op.create_index("geocoded", "locations", ["IShouldBeQuoted"]) context.assert_( - 'CREATE INDEX geocoded ON locations ("IShouldBeQuoted")') + 'CREATE INDEX geocoded ON locations ("IShouldBeQuoted")' + ) def test_create_index_expressions(self): context = op_fixture() - op.create_index( - 'geocoded', - 'locations', - [text('lower(coordinates)')]) + op.create_index("geocoded", "locations", [text("lower(coordinates)")]) context.assert_( - "CREATE INDEX geocoded ON locations (lower(coordinates))") + "CREATE INDEX geocoded ON locations (lower(coordinates))" + ) def test_add_column(self): context = op_fixture() - op.add_column('t1', Column('c1', Integer, nullable=False)) + op.add_column("t1", Column("c1", Integer, nullable=False)) context.assert_("ALTER TABLE t1 ADD COLUMN c1 INTEGER NOT NULL") def test_add_column_schema(self): context = op_fixture() - op.add_column('t1', Column('c1', Integer, nullable=False), schema="foo") + op.add_column( + "t1", Column("c1", Integer, nullable=False), schema="foo" + ) context.assert_("ALTER TABLE foo.t1 ADD COLUMN c1 INTEGER NOT NULL") def test_add_column_with_default(self): context = op_fixture() op.add_column( - 't1', Column('c1', Integer, nullable=False, server_default="12")) + "t1", Column("c1", Integer, nullable=False, server_default="12") + ) context.assert_( - "ALTER TABLE t1 ADD COLUMN c1 INTEGER DEFAULT '12' NOT NULL") + "ALTER TABLE t1 ADD COLUMN c1 INTEGER DEFAULT '12' NOT NULL" + ) def test_add_column_with_index(self): context = op_fixture() - op.add_column( - 't1', Column('c1', Integer, nullable=False, index=True)) + op.add_column("t1", Column("c1", Integer, nullable=False, index=True)) context.assert_( "ALTER TABLE t1 ADD COLUMN c1 INTEGER NOT NULL", "CREATE INDEX ix_t1_c1 ON t1 (c1)", @@ -122,107 +125,117 @@ class OpTest(TestBase): def test_add_column_schema_with_default(self): context = op_fixture() - op.add_column('t1', - Column('c1', Integer, nullable=False, server_default="12"), - schema='foo') + op.add_column( + "t1", + Column("c1", Integer, nullable=False, server_default="12"), + schema="foo", + ) context.assert_( - "ALTER TABLE foo.t1 ADD COLUMN c1 INTEGER DEFAULT '12' NOT NULL") + "ALTER TABLE foo.t1 ADD COLUMN c1 INTEGER DEFAULT '12' NOT NULL" + ) def test_add_column_fk(self): context = op_fixture() op.add_column( - 't1', Column('c1', Integer, ForeignKey('c2.id'), nullable=False)) + "t1", Column("c1", Integer, ForeignKey("c2.id"), nullable=False) + ) context.assert_( "ALTER TABLE t1 ADD COLUMN c1 INTEGER NOT NULL", - "ALTER TABLE t1 ADD FOREIGN KEY(c1) REFERENCES c2 (id)" + "ALTER TABLE t1 ADD FOREIGN KEY(c1) REFERENCES c2 (id)", ) def test_add_column_schema_fk(self): context = op_fixture() - op.add_column('t1', - Column('c1', Integer, ForeignKey('c2.id'), nullable=False), - schema='foo') + op.add_column( + "t1", + Column("c1", Integer, ForeignKey("c2.id"), nullable=False), + schema="foo", + ) context.assert_( "ALTER TABLE foo.t1 ADD COLUMN c1 INTEGER NOT NULL", - "ALTER TABLE foo.t1 ADD FOREIGN KEY(c1) REFERENCES c2 (id)" + "ALTER TABLE foo.t1 ADD FOREIGN KEY(c1) REFERENCES c2 (id)", ) def test_add_column_schema_type(self): """Test that a schema type generates its constraints....""" context = op_fixture() - op.add_column('t1', Column('c1', Boolean, nullable=False)) + op.add_column("t1", Column("c1", Boolean, nullable=False)) context.assert_( - 'ALTER TABLE t1 ADD COLUMN c1 BOOLEAN NOT NULL', - 'ALTER TABLE t1 ADD CHECK (c1 IN (0, 1))' + "ALTER TABLE t1 ADD COLUMN c1 BOOLEAN NOT NULL", + "ALTER TABLE t1 ADD CHECK (c1 IN (0, 1))", ) def test_add_column_schema_schema_type(self): """Test that a schema type generates its constraints....""" context = op_fixture() - op.add_column('t1', Column('c1', Boolean, nullable=False), schema='foo') + op.add_column( + "t1", Column("c1", Boolean, nullable=False), schema="foo" + ) context.assert_( - 'ALTER TABLE foo.t1 ADD COLUMN c1 BOOLEAN NOT NULL', - 'ALTER TABLE foo.t1 ADD CHECK (c1 IN (0, 1))' + "ALTER TABLE foo.t1 ADD COLUMN c1 BOOLEAN NOT NULL", + "ALTER TABLE foo.t1 ADD CHECK (c1 IN (0, 1))", ) def test_add_column_schema_type_checks_rule(self): """Test that a schema type doesn't generate a constraint based on check rule.""" - context = op_fixture('postgresql') - op.add_column('t1', Column('c1', Boolean, nullable=False)) - context.assert_( - 'ALTER TABLE t1 ADD COLUMN c1 BOOLEAN NOT NULL', - ) + context = op_fixture("postgresql") + op.add_column("t1", Column("c1", Boolean, nullable=False)) + context.assert_("ALTER TABLE t1 ADD COLUMN c1 BOOLEAN NOT NULL") def test_add_column_fk_self_referential(self): context = op_fixture() op.add_column( - 't1', Column('c1', Integer, ForeignKey('t1.c2'), nullable=False)) + "t1", Column("c1", Integer, ForeignKey("t1.c2"), nullable=False) + ) context.assert_( "ALTER TABLE t1 ADD COLUMN c1 INTEGER NOT NULL", - "ALTER TABLE t1 ADD FOREIGN KEY(c1) REFERENCES t1 (c2)" + "ALTER TABLE t1 ADD FOREIGN KEY(c1) REFERENCES t1 (c2)", ) def test_add_column_schema_fk_self_referential(self): context = op_fixture() op.add_column( - 't1', - Column('c1', Integer, ForeignKey('foo.t1.c2'), nullable=False), - schema='foo') + "t1", + Column("c1", Integer, ForeignKey("foo.t1.c2"), nullable=False), + schema="foo", + ) context.assert_( "ALTER TABLE foo.t1 ADD COLUMN c1 INTEGER NOT NULL", - "ALTER TABLE foo.t1 ADD FOREIGN KEY(c1) REFERENCES foo.t1 (c2)" + "ALTER TABLE foo.t1 ADD FOREIGN KEY(c1) REFERENCES foo.t1 (c2)", ) def test_add_column_fk_schema(self): context = op_fixture() op.add_column( - 't1', - Column('c1', Integer, ForeignKey('remote.t2.c2'), nullable=False)) + "t1", + Column("c1", Integer, ForeignKey("remote.t2.c2"), nullable=False), + ) context.assert_( - 'ALTER TABLE t1 ADD COLUMN c1 INTEGER NOT NULL', - 'ALTER TABLE t1 ADD FOREIGN KEY(c1) REFERENCES remote.t2 (c2)' + "ALTER TABLE t1 ADD COLUMN c1 INTEGER NOT NULL", + "ALTER TABLE t1 ADD FOREIGN KEY(c1) REFERENCES remote.t2 (c2)", ) def test_add_column_schema_fk_schema(self): context = op_fixture() op.add_column( - 't1', - Column('c1', Integer, ForeignKey('remote.t2.c2'), nullable=False), - schema='foo') + "t1", + Column("c1", Integer, ForeignKey("remote.t2.c2"), nullable=False), + schema="foo", + ) context.assert_( - 'ALTER TABLE foo.t1 ADD COLUMN c1 INTEGER NOT NULL', - 'ALTER TABLE foo.t1 ADD FOREIGN KEY(c1) REFERENCES remote.t2 (c2)' + "ALTER TABLE foo.t1 ADD COLUMN c1 INTEGER NOT NULL", + "ALTER TABLE foo.t1 ADD FOREIGN KEY(c1) REFERENCES remote.t2 (c2)", ) def test_drop_column(self): context = op_fixture() - op.drop_column('t1', 'c1') + op.drop_column("t1", "c1") context.assert_("ALTER TABLE t1 DROP COLUMN c1") def test_drop_column_schema(self): context = op_fixture() - op.drop_column('t1', 'c1', schema='foo') + op.drop_column("t1", "c1", schema="foo") context.assert_("ALTER TABLE foo.t1 DROP COLUMN c1") def test_alter_column_nullable(self): @@ -236,7 +249,7 @@ class OpTest(TestBase): def test_alter_column_schema_nullable(self): context = op_fixture() - op.alter_column("t", "c", nullable=True, schema='foo') + op.alter_column("t", "c", nullable=True, schema="foo") context.assert_( # TODO: not sure if this is PG only or standard # SQL @@ -254,7 +267,7 @@ class OpTest(TestBase): def test_alter_column_schema_not_nullable(self): context = op_fixture() - op.alter_column("t", "c", nullable=False, schema='foo') + op.alter_column("t", "c", nullable=False, schema="foo") context.assert_( # TODO: not sure if this is PG only or standard # SQL @@ -264,58 +277,50 @@ class OpTest(TestBase): def test_alter_column_rename(self): context = op_fixture() op.alter_column("t", "c", new_column_name="x") - context.assert_( - "ALTER TABLE t RENAME c TO x" - ) + context.assert_("ALTER TABLE t RENAME c TO x") def test_alter_column_schema_rename(self): context = op_fixture() - op.alter_column("t", "c", new_column_name="x", schema='foo') - context.assert_( - "ALTER TABLE foo.t RENAME c TO x" - ) + op.alter_column("t", "c", new_column_name="x", schema="foo") + context.assert_("ALTER TABLE foo.t RENAME c TO x") def test_alter_column_type(self): context = op_fixture() op.alter_column("t", "c", type_=String(50)) - context.assert_( - 'ALTER TABLE t ALTER COLUMN c TYPE VARCHAR(50)' - ) + context.assert_("ALTER TABLE t ALTER COLUMN c TYPE VARCHAR(50)") def test_alter_column_schema_type(self): context = op_fixture() - op.alter_column("t", "c", type_=String(50), schema='foo') - context.assert_( - 'ALTER TABLE foo.t ALTER COLUMN c TYPE VARCHAR(50)' - ) + op.alter_column("t", "c", type_=String(50), schema="foo") + context.assert_("ALTER TABLE foo.t ALTER COLUMN c TYPE VARCHAR(50)") def test_alter_column_set_default(self): context = op_fixture() op.alter_column("t", "c", server_default="q") - context.assert_( - "ALTER TABLE t ALTER COLUMN c SET DEFAULT 'q'" - ) + context.assert_("ALTER TABLE t ALTER COLUMN c SET DEFAULT 'q'") def test_alter_column_schema_set_default(self): context = op_fixture() - op.alter_column("t", "c", server_default="q", schema='foo') - context.assert_( - "ALTER TABLE foo.t ALTER COLUMN c SET DEFAULT 'q'" - ) + op.alter_column("t", "c", server_default="q", schema="foo") + context.assert_("ALTER TABLE foo.t ALTER COLUMN c SET DEFAULT 'q'") def test_alter_column_set_compiled_default(self): context = op_fixture() - op.alter_column("t", "c", - server_default=func.utc_thing(func.current_timestamp())) + op.alter_column( + "t", "c", server_default=func.utc_thing(func.current_timestamp()) + ) context.assert_( "ALTER TABLE t ALTER COLUMN c SET DEFAULT utc_thing(CURRENT_TIMESTAMP)" ) def test_alter_column_schema_set_compiled_default(self): context = op_fixture() - op.alter_column("t", "c", - server_default=func.utc_thing(func.current_timestamp()), - schema='foo') + op.alter_column( + "t", + "c", + server_default=func.utc_thing(func.current_timestamp()), + schema="foo", + ) context.assert_( "ALTER TABLE foo.t ALTER COLUMN c " "SET DEFAULT utc_thing(CURRENT_TIMESTAMP)" @@ -324,101 +329,98 @@ class OpTest(TestBase): def test_alter_column_drop_default(self): context = op_fixture() op.alter_column("t", "c", server_default=None) - context.assert_( - 'ALTER TABLE t ALTER COLUMN c DROP DEFAULT' - ) + context.assert_("ALTER TABLE t ALTER COLUMN c DROP DEFAULT") def test_alter_column_schema_drop_default(self): context = op_fixture() - op.alter_column("t", "c", server_default=None, schema='foo') - context.assert_( - 'ALTER TABLE foo.t ALTER COLUMN c DROP DEFAULT' - ) + op.alter_column("t", "c", server_default=None, schema="foo") + context.assert_("ALTER TABLE foo.t ALTER COLUMN c DROP DEFAULT") def test_alter_column_schema_type_unnamed(self): - context = op_fixture('mssql', native_boolean=False) + context = op_fixture("mssql", native_boolean=False) op.alter_column("t", "c", type_=Boolean()) context.assert_( - 'ALTER TABLE t ALTER COLUMN c BIT', - 'ALTER TABLE t ADD CHECK (c IN (0, 1))' + "ALTER TABLE t ALTER COLUMN c BIT", + "ALTER TABLE t ADD CHECK (c IN (0, 1))", ) def test_alter_column_schema_schema_type_unnamed(self): - context = op_fixture('mssql', native_boolean=False) - op.alter_column("t", "c", type_=Boolean(), schema='foo') + context = op_fixture("mssql", native_boolean=False) + op.alter_column("t", "c", type_=Boolean(), schema="foo") context.assert_( - 'ALTER TABLE foo.t ALTER COLUMN c BIT', - 'ALTER TABLE foo.t ADD CHECK (c IN (0, 1))' + "ALTER TABLE foo.t ALTER COLUMN c BIT", + "ALTER TABLE foo.t ADD CHECK (c IN (0, 1))", ) def test_alter_column_schema_type_named(self): - context = op_fixture('mssql', native_boolean=False) + context = op_fixture("mssql", native_boolean=False) op.alter_column("t", "c", type_=Boolean(name="xyz")) context.assert_( - 'ALTER TABLE t ALTER COLUMN c BIT', - 'ALTER TABLE t ADD CONSTRAINT xyz CHECK (c IN (0, 1))' + "ALTER TABLE t ALTER COLUMN c BIT", + "ALTER TABLE t ADD CONSTRAINT xyz CHECK (c IN (0, 1))", ) def test_alter_column_schema_schema_type_named(self): - context = op_fixture('mssql', native_boolean=False) - op.alter_column("t", "c", type_=Boolean(name="xyz"), schema='foo') + context = op_fixture("mssql", native_boolean=False) + op.alter_column("t", "c", type_=Boolean(name="xyz"), schema="foo") context.assert_( - 'ALTER TABLE foo.t ALTER COLUMN c BIT', - 'ALTER TABLE foo.t ADD CONSTRAINT xyz CHECK (c IN (0, 1))' + "ALTER TABLE foo.t ALTER COLUMN c BIT", + "ALTER TABLE foo.t ADD CONSTRAINT xyz CHECK (c IN (0, 1))", ) def test_alter_column_schema_type_existing_type(self): - context = op_fixture('mssql', native_boolean=False) + context = op_fixture("mssql", native_boolean=False) op.alter_column( - "t", "c", type_=String(10), existing_type=Boolean(name="xyz")) + "t", "c", type_=String(10), existing_type=Boolean(name="xyz") + ) context.assert_( - 'ALTER TABLE t DROP CONSTRAINT xyz', - 'ALTER TABLE t ALTER COLUMN c VARCHAR(10)' + "ALTER TABLE t DROP CONSTRAINT xyz", + "ALTER TABLE t ALTER COLUMN c VARCHAR(10)", ) def test_alter_column_schema_schema_type_existing_type(self): - context = op_fixture('mssql', native_boolean=False) - op.alter_column("t", "c", type_=String(10), - existing_type=Boolean(name="xyz"), schema='foo') + context = op_fixture("mssql", native_boolean=False) + op.alter_column( + "t", + "c", + type_=String(10), + existing_type=Boolean(name="xyz"), + schema="foo", + ) context.assert_( - 'ALTER TABLE foo.t DROP CONSTRAINT xyz', - 'ALTER TABLE foo.t ALTER COLUMN c VARCHAR(10)' + "ALTER TABLE foo.t DROP CONSTRAINT xyz", + "ALTER TABLE foo.t ALTER COLUMN c VARCHAR(10)", ) def test_alter_column_schema_type_existing_type_no_const(self): - context = op_fixture('postgresql') + context = op_fixture("postgresql") op.alter_column("t", "c", type_=String(10), existing_type=Boolean()) - context.assert_( - 'ALTER TABLE t ALTER COLUMN c TYPE VARCHAR(10)' - ) + context.assert_("ALTER TABLE t ALTER COLUMN c TYPE VARCHAR(10)") def test_alter_column_schema_schema_type_existing_type_no_const(self): - context = op_fixture('postgresql') - op.alter_column("t", "c", type_=String(10), existing_type=Boolean(), - schema='foo') - context.assert_( - 'ALTER TABLE foo.t ALTER COLUMN c TYPE VARCHAR(10)' + context = op_fixture("postgresql") + op.alter_column( + "t", "c", type_=String(10), existing_type=Boolean(), schema="foo" ) + context.assert_("ALTER TABLE foo.t ALTER COLUMN c TYPE VARCHAR(10)") def test_alter_column_schema_type_existing_type_no_new_type(self): - context = op_fixture('postgresql') + context = op_fixture("postgresql") op.alter_column("t", "c", nullable=False, existing_type=Boolean()) - context.assert_( - 'ALTER TABLE t ALTER COLUMN c SET NOT NULL' - ) + context.assert_("ALTER TABLE t ALTER COLUMN c SET NOT NULL") def test_alter_column_schema_schema_type_existing_type_no_new_type(self): - context = op_fixture('postgresql') - op.alter_column("t", "c", nullable=False, existing_type=Boolean(), - schema='foo') - context.assert_( - 'ALTER TABLE foo.t ALTER COLUMN c SET NOT NULL' + context = op_fixture("postgresql") + op.alter_column( + "t", "c", nullable=False, existing_type=Boolean(), schema="foo" ) + context.assert_("ALTER TABLE foo.t ALTER COLUMN c SET NOT NULL") def test_add_foreign_key(self): context = op_fixture() - op.create_foreign_key('fk_test', 't1', 't2', - ['foo', 'bar'], ['bat', 'hoho']) + op.create_foreign_key( + "fk_test", "t1", "t2", ["foo", "bar"], ["bat", "hoho"] + ) context.assert_( "ALTER TABLE t1 ADD CONSTRAINT fk_test FOREIGN KEY(foo, bar) " "REFERENCES t2 (bat, hoho)" @@ -426,9 +428,15 @@ class OpTest(TestBase): def test_add_foreign_key_schema(self): context = op_fixture() - op.create_foreign_key('fk_test', 't1', 't2', - ['foo', 'bar'], ['bat', 'hoho'], - source_schema='foo2', referent_schema='bar2') + op.create_foreign_key( + "fk_test", + "t1", + "t2", + ["foo", "bar"], + ["bat", "hoho"], + source_schema="foo2", + referent_schema="bar2", + ) context.assert_( "ALTER TABLE foo2.t1 ADD CONSTRAINT fk_test FOREIGN KEY(foo, bar) " "REFERENCES bar2.t2 (bat, hoho)" @@ -436,9 +444,15 @@ class OpTest(TestBase): def test_add_foreign_key_schema_same_tablename(self): context = op_fixture() - op.create_foreign_key('fk_test', 't1', 't1', - ['foo', 'bar'], ['bat', 'hoho'], - source_schema='foo2', referent_schema='bar2') + op.create_foreign_key( + "fk_test", + "t1", + "t1", + ["foo", "bar"], + ["bat", "hoho"], + source_schema="foo2", + referent_schema="bar2", + ) context.assert_( "ALTER TABLE foo2.t1 ADD CONSTRAINT fk_test FOREIGN KEY(foo, bar) " "REFERENCES bar2.t1 (bat, hoho)" @@ -446,9 +460,14 @@ class OpTest(TestBase): def test_add_foreign_key_onupdate(self): context = op_fixture() - op.create_foreign_key('fk_test', 't1', 't2', - ['foo', 'bar'], ['bat', 'hoho'], - onupdate='CASCADE') + op.create_foreign_key( + "fk_test", + "t1", + "t2", + ["foo", "bar"], + ["bat", "hoho"], + onupdate="CASCADE", + ) context.assert_( "ALTER TABLE t1 ADD CONSTRAINT fk_test FOREIGN KEY(foo, bar) " "REFERENCES t2 (bat, hoho) ON UPDATE CASCADE" @@ -456,9 +475,14 @@ class OpTest(TestBase): def test_add_foreign_key_ondelete(self): context = op_fixture() - op.create_foreign_key('fk_test', 't1', 't2', - ['foo', 'bar'], ['bat', 'hoho'], - ondelete='CASCADE') + op.create_foreign_key( + "fk_test", + "t1", + "t2", + ["foo", "bar"], + ["bat", "hoho"], + ondelete="CASCADE", + ) context.assert_( "ALTER TABLE t1 ADD CONSTRAINT fk_test FOREIGN KEY(foo, bar) " "REFERENCES t2 (bat, hoho) ON DELETE CASCADE" @@ -466,9 +490,14 @@ class OpTest(TestBase): def test_add_foreign_key_deferrable(self): context = op_fixture() - op.create_foreign_key('fk_test', 't1', 't2', - ['foo', 'bar'], ['bat', 'hoho'], - deferrable=True) + op.create_foreign_key( + "fk_test", + "t1", + "t2", + ["foo", "bar"], + ["bat", "hoho"], + deferrable=True, + ) context.assert_( "ALTER TABLE t1 ADD CONSTRAINT fk_test FOREIGN KEY(foo, bar) " "REFERENCES t2 (bat, hoho) DEFERRABLE" @@ -476,9 +505,14 @@ class OpTest(TestBase): def test_add_foreign_key_initially(self): context = op_fixture() - op.create_foreign_key('fk_test', 't1', 't2', - ['foo', 'bar'], ['bat', 'hoho'], - initially='INITIAL') + op.create_foreign_key( + "fk_test", + "t1", + "t2", + ["foo", "bar"], + ["bat", "hoho"], + initially="INITIAL", + ) context.assert_( "ALTER TABLE t1 ADD CONSTRAINT fk_test FOREIGN KEY(foo, bar) " "REFERENCES t2 (bat, hoho) INITIALLY INITIAL" @@ -487,9 +521,14 @@ class OpTest(TestBase): @config.requirements.foreign_key_match def test_add_foreign_key_match(self): context = op_fixture() - op.create_foreign_key('fk_test', 't1', 't2', - ['foo', 'bar'], ['bat', 'hoho'], - match='SIMPLE') + op.create_foreign_key( + "fk_test", + "t1", + "t2", + ["foo", "bar"], + ["bat", "hoho"], + match="SIMPLE", + ) context.assert_( "ALTER TABLE t1 ADD CONSTRAINT fk_test FOREIGN KEY(foo, bar) " "REFERENCES t2 (bat, hoho) MATCH SIMPLE" @@ -497,24 +536,44 @@ class OpTest(TestBase): def test_add_foreign_key_dialect_kw(self): op_fixture() - with mock.patch( - "sqlalchemy.schema.ForeignKeyConstraint" - ) as fkc: - op.create_foreign_key('fk_test', 't1', 't2', - ['foo', 'bar'], ['bat', 'hoho'], - foobar_arg='xyz') + with mock.patch("sqlalchemy.schema.ForeignKeyConstraint") as fkc: + op.create_foreign_key( + "fk_test", + "t1", + "t2", + ["foo", "bar"], + ["bat", "hoho"], + foobar_arg="xyz", + ) if config.requirements.foreign_key_match.enabled: - eq_(fkc.mock_calls[0], - mock.call(['foo', 'bar'], ['t2.bat', 't2.hoho'], - onupdate=None, ondelete=None, name='fk_test', - foobar_arg='xyz', - deferrable=None, initially=None, match=None)) + eq_( + fkc.mock_calls[0], + mock.call( + ["foo", "bar"], + ["t2.bat", "t2.hoho"], + onupdate=None, + ondelete=None, + name="fk_test", + foobar_arg="xyz", + deferrable=None, + initially=None, + match=None, + ), + ) else: - eq_(fkc.mock_calls[0], - mock.call(['foo', 'bar'], ['t2.bat', 't2.hoho'], - onupdate=None, ondelete=None, name='fk_test', - foobar_arg='xyz', - deferrable=None, initially=None)) + eq_( + fkc.mock_calls[0], + mock.call( + ["foo", "bar"], + ["t2.bat", "t2.hoho"], + onupdate=None, + ondelete=None, + name="fk_test", + foobar_arg="xyz", + deferrable=None, + initially=None, + ), + ) def test_add_foreign_key_self_referential(self): context = op_fixture() @@ -541,9 +600,7 @@ class OpTest(TestBase): def test_add_check_constraint(self): context = op_fixture() op.create_check_constraint( - "ck_user_name_len", - "user_table", - func.len(column('name')) > 5 + "ck_user_name_len", "user_table", func.len(column("name")) > 5 ) context.assert_( "ALTER TABLE user_table ADD CONSTRAINT ck_user_name_len " @@ -555,8 +612,8 @@ class OpTest(TestBase): op.create_check_constraint( "ck_user_name_len", "user_table", - func.len(column('name')) > 5, - schema='foo' + func.len(column("name")) > 5, + schema="foo", ) context.assert_( "ALTER TABLE foo.user_table ADD CONSTRAINT ck_user_name_len " @@ -565,7 +622,7 @@ class OpTest(TestBase): def test_add_unique_constraint(self): context = op_fixture() - op.create_unique_constraint('uk_test', 't1', ['foo', 'bar']) + op.create_unique_constraint("uk_test", "t1", ["foo", "bar"]) context.assert_( "ALTER TABLE t1 ADD CONSTRAINT uk_test UNIQUE (foo, bar)" ) @@ -574,12 +631,12 @@ class OpTest(TestBase): context = op_fixture() op.create_foreign_key( - name='some_fk', - source='some_table', - referent='referred_table', - local_cols=['a', 'b'], - remote_cols=['c', 'd'], - ondelete='CASCADE' + name="some_fk", + source="some_table", + referent="referred_table", + local_cols=["a", "b"], + remote_cols=["c", "d"], + ondelete="CASCADE", ) context.assert_( "ALTER TABLE some_table ADD CONSTRAINT some_fk " @@ -590,27 +647,26 @@ class OpTest(TestBase): def test_add_unique_constraint_legacy_kwarg(self): context = op_fixture() op.create_unique_constraint( - name='uk_test', - source='t1', - local_cols=['foo', 'bar']) + name="uk_test", source="t1", local_cols=["foo", "bar"] + ) context.assert_( "ALTER TABLE t1 ADD CONSTRAINT uk_test UNIQUE (foo, bar)" ) def test_drop_constraint_legacy_kwarg(self): context = op_fixture() - op.drop_constraint(name='pk_name', - table_name='sometable', - type_='primary') - context.assert_( - "ALTER TABLE sometable DROP CONSTRAINT pk_name" + op.drop_constraint( + name="pk_name", table_name="sometable", type_="primary" ) + context.assert_("ALTER TABLE sometable DROP CONSTRAINT pk_name") def test_create_pk_legacy_kwarg(self): context = op_fixture() - op.create_primary_key(name=None, - table_name='sometable', - cols=['router_id', 'l3_agent_id']) + op.create_primary_key( + name=None, + table_name="sometable", + cols=["router_id", "l3_agent_id"], + ) context.assert_( "ALTER TABLE sometable ADD PRIMARY KEY (router_id, l3_agent_id)" ) @@ -623,57 +679,50 @@ class OpTest(TestBase): "missing required positional argument: columns", op.create_primary_key, name=None, - table_name='sometable', - wrong_cols=['router_id', 'l3_agent_id'] + table_name="sometable", + wrong_cols=["router_id", "l3_agent_id"], ) def test_add_unique_constraint_schema(self): context = op_fixture() op.create_unique_constraint( - 'uk_test', 't1', ['foo', 'bar'], schema='foo') + "uk_test", "t1", ["foo", "bar"], schema="foo" + ) context.assert_( "ALTER TABLE foo.t1 ADD CONSTRAINT uk_test UNIQUE (foo, bar)" ) def test_drop_constraint(self): context = op_fixture() - op.drop_constraint('foo_bar_bat', 't1') - context.assert_( - "ALTER TABLE t1 DROP CONSTRAINT foo_bar_bat" - ) + op.drop_constraint("foo_bar_bat", "t1") + context.assert_("ALTER TABLE t1 DROP CONSTRAINT foo_bar_bat") def test_drop_constraint_schema(self): context = op_fixture() - op.drop_constraint('foo_bar_bat', 't1', schema='foo') - context.assert_( - "ALTER TABLE foo.t1 DROP CONSTRAINT foo_bar_bat" - ) + op.drop_constraint("foo_bar_bat", "t1", schema="foo") + context.assert_("ALTER TABLE foo.t1 DROP CONSTRAINT foo_bar_bat") def test_create_index(self): context = op_fixture() - op.create_index('ik_test', 't1', ['foo', 'bar']) - context.assert_( - "CREATE INDEX ik_test ON t1 (foo, bar)" - ) + op.create_index("ik_test", "t1", ["foo", "bar"]) + context.assert_("CREATE INDEX ik_test ON t1 (foo, bar)") def test_create_unique_index(self): context = op_fixture() - op.create_index('ik_test', 't1', ['foo', 'bar'], unique=True) - context.assert_( - "CREATE UNIQUE INDEX ik_test ON t1 (foo, bar)" - ) + op.create_index("ik_test", "t1", ["foo", "bar"], unique=True) + context.assert_("CREATE UNIQUE INDEX ik_test ON t1 (foo, bar)") def test_create_index_quote_flag(self): context = op_fixture() - op.create_index('ik_test', 't1', ['foo', 'bar'], quote=True) - context.assert_( - 'CREATE INDEX "ik_test" ON t1 (foo, bar)' - ) + op.create_index("ik_test", "t1", ["foo", "bar"], quote=True) + context.assert_('CREATE INDEX "ik_test" ON t1 (foo, bar)') def test_create_index_table_col_event(self): context = op_fixture() - op.create_index('ik_test', 'tbl_with_auto_appended_column', ['foo', 'bar']) + op.create_index( + "ik_test", "tbl_with_auto_appended_column", ["foo", "bar"] + ) context.assert_( "CREATE INDEX ik_test ON tbl_with_auto_appended_column (foo, bar)" ) @@ -681,8 +730,8 @@ class OpTest(TestBase): def test_add_unique_constraint_col_event(self): context = op_fixture() op.create_unique_constraint( - 'ik_test', - 'tbl_with_auto_appended_column', ['foo', 'bar']) + "ik_test", "tbl_with_auto_appended_column", ["foo", "bar"] + ) context.assert_( "ALTER TABLE tbl_with_auto_appended_column " "ADD CONSTRAINT ik_test UNIQUE (foo, bar)" @@ -690,45 +739,35 @@ class OpTest(TestBase): def test_create_index_schema(self): context = op_fixture() - op.create_index('ik_test', 't1', ['foo', 'bar'], schema='foo') - context.assert_( - "CREATE INDEX ik_test ON foo.t1 (foo, bar)" - ) + op.create_index("ik_test", "t1", ["foo", "bar"], schema="foo") + context.assert_("CREATE INDEX ik_test ON foo.t1 (foo, bar)") def test_drop_index(self): context = op_fixture() - op.drop_index('ik_test') - context.assert_( - "DROP INDEX ik_test" - ) + op.drop_index("ik_test") + context.assert_("DROP INDEX ik_test") def test_drop_index_schema(self): context = op_fixture() - op.drop_index('ik_test', schema='foo') - context.assert_( - "DROP INDEX foo.ik_test" - ) + op.drop_index("ik_test", schema="foo") + context.assert_("DROP INDEX foo.ik_test") def test_drop_table(self): context = op_fixture() - op.drop_table('tb_test') - context.assert_( - "DROP TABLE tb_test" - ) + op.drop_table("tb_test") + context.assert_("DROP TABLE tb_test") def test_drop_table_schema(self): context = op_fixture() - op.drop_table('tb_test', schema='foo') - context.assert_( - "DROP TABLE foo.tb_test" - ) + op.drop_table("tb_test", schema="foo") + context.assert_("DROP TABLE foo.tb_test") def test_create_table_selfref(self): context = op_fixture() op.create_table( "some_table", - Column('id', Integer, primary_key=True), - Column('st_id', Integer, ForeignKey('some_table.id')) + Column("id", Integer, primary_key=True), + Column("st_id", Integer, ForeignKey("some_table.id")), ) context.assert_( "CREATE TABLE some_table (" @@ -742,9 +781,9 @@ class OpTest(TestBase): context = op_fixture() t1 = op.create_table( "some_table", - Column('id', Integer, primary_key=True), - Column('foo_id', Integer, ForeignKey('foo.id')), - schema='schema' + Column("id", Integer, primary_key=True), + Column("foo_id", Integer, ForeignKey("foo.id")), + schema="schema", ) context.assert_( "CREATE TABLE schema.some_table (" @@ -760,9 +799,9 @@ class OpTest(TestBase): context = op_fixture() t1 = op.create_table( "some_table", - Column('x', Integer), - Column('y', Integer), - Column('z', Integer), + Column("x", Integer), + Column("y", Integer), + Column("z", Integer), ) context.assert_( "CREATE TABLE some_table (x INTEGER, y INTEGER, z INTEGER)" @@ -773,9 +812,9 @@ class OpTest(TestBase): context = op_fixture() op.create_table( "some_table", - Column('id', Integer, primary_key=True), - Column('foo_id', Integer, ForeignKey('foo.id')), - Column('foo_bar', Integer, ForeignKey('foo.bar')), + Column("id", Integer, primary_key=True), + Column("foo_id", Integer, ForeignKey("foo.id")), + Column("foo_bar", Integer, ForeignKey("foo.bar")), ) context.assert_( "CREATE TABLE some_table (" @@ -792,27 +831,26 @@ class OpTest(TestBase): from sqlalchemy.sql import table, column from sqlalchemy import String, Integer - account = table('account', - column('name', String), - column('id', Integer) - ) + account = table( + "account", column("name", String), column("id", Integer) + ) op.execute( - account.update(). - where(account.c.name == op.inline_literal('account 1')). - values({'name': op.inline_literal('account 2')}) + account.update() + .where(account.c.name == op.inline_literal("account 1")) + .values({"name": op.inline_literal("account 2")}) ) op.execute( - account.update(). - where(account.c.id == op.inline_literal(1)). - values({'id': op.inline_literal(2)}) + account.update() + .where(account.c.id == op.inline_literal(1)) + .values({"id": op.inline_literal(2)}) ) context.assert_( "UPDATE account SET name='account 2' WHERE account.name = 'account 1'", - "UPDATE account SET id=2 WHERE account.id = 1" + "UPDATE account SET id=2 WHERE account.id = 1", ) def test_cant_op(self): - if hasattr(op, '_proxy'): + if hasattr(op, "_proxy"): del op._proxy assert_raises_message( NameError, @@ -820,7 +858,8 @@ class OpTest(TestBase): "proxy object has not yet been established " "for the Alembic 'Operations' class. " "Try placing this code inside a callable.", - op.inline_literal, "asdf" + op.inline_literal, + "asdf", ) def test_naming_changes(self): @@ -832,17 +871,17 @@ class OpTest(TestBase): op.alter_column("t", "c", new_column_name="x") context.assert_("ALTER TABLE t RENAME c TO x") - context = op_fixture('mysql') + context = op_fixture("mysql") op.drop_constraint("f1", "t1", type="foreignkey") context.assert_("ALTER TABLE t1 DROP FOREIGN KEY f1") - context = op_fixture('mysql') + context = op_fixture("mysql") op.drop_constraint("f1", "t1", type_="foreignkey") context.assert_("ALTER TABLE t1 DROP FOREIGN KEY f1") def test_naming_changes_drop_idx(self): - context = op_fixture('mssql') - op.drop_index('ik_test', tablename='t1') + context = op_fixture("mssql") + op.drop_index("ik_test", tablename="t1") context.assert_("DROP INDEX ik_test ON t1") @@ -852,20 +891,19 @@ class SQLModeOpTest(TestBase): from sqlalchemy.sql import table, column from sqlalchemy import String, Integer - account = table('account', - column('name', String), - column('id', Integer) - ) + account = table( + "account", column("name", String), column("id", Integer) + ) op.execute( - account.update(). - where(account.c.name == op.inline_literal('account 1')). - values({'name': op.inline_literal('account 2')}) + account.update() + .where(account.c.name == op.inline_literal("account 1")) + .values({"name": op.inline_literal("account 2")}) ) - op.execute(text("update table set foo=:bar").bindparams(bar='bat')) + op.execute(text("update table set foo=:bar").bindparams(bar="bat")) context.assert_( "UPDATE account SET name='account 2' " "WHERE account.name = 'account 1'", - "update table set foo='bat'" + "update table set foo='bat'", ) def test_create_table_literal_binds(self): @@ -873,8 +911,8 @@ class SQLModeOpTest(TestBase): op.create_table( "some_table", - Column('id', Integer, primary_key=True), - Column('st_id', Integer, ForeignKey('some_table.id')) + Column("id", Integer, primary_key=True), + Column("st_id", Integer, ForeignKey("some_table.id")), ) context.assert_( @@ -907,7 +945,7 @@ class CustomOpTest(TestBase): operations.execute("CREATE SEQUENCE %s" % operation.sequence_name) context = op_fixture() - op.create_sequence('foob') + op.create_sequence("foob") context.assert_("CREATE SEQUENCE foob") @@ -923,48 +961,36 @@ class EnsureOrigObjectFromToTest(TestBase): def test_drop_index(self): schema_obj = schemaobj.SchemaObjects() - idx = schema_obj.index('x', 'y', ['z']) + idx = schema_obj.index("x", "y", ["z"]) op = ops.DropIndexOp.from_index(idx) - is_( - op.to_index(), idx - ) + is_(op.to_index(), idx) def test_create_index(self): schema_obj = schemaobj.SchemaObjects() - idx = schema_obj.index('x', 'y', ['z']) + idx = schema_obj.index("x", "y", ["z"]) op = ops.CreateIndexOp.from_index(idx) - is_( - op.to_index(), idx - ) + is_(op.to_index(), idx) def test_drop_table(self): schema_obj = schemaobj.SchemaObjects() - table = schema_obj.table('x', Column('q', Integer)) + table = schema_obj.table("x", Column("q", Integer)) op = ops.DropTableOp.from_table(table) - is_( - op.to_table(), table - ) + is_(op.to_table(), table) def test_create_table(self): schema_obj = schemaobj.SchemaObjects() - table = schema_obj.table('x', Column('q', Integer)) + table = schema_obj.table("x", Column("q", Integer)) op = ops.CreateTableOp.from_table(table) - is_( - op.to_table(), table - ) + is_(op.to_table(), table) def test_drop_unique_constraint(self): schema_obj = schemaobj.SchemaObjects() - const = schema_obj.unique_constraint('x', 'foobar', ['a']) + const = schema_obj.unique_constraint("x", "foobar", ["a"]) op = ops.DropConstraintOp.from_constraint(const) - is_( - op.to_constraint(), const - ) + is_(op.to_constraint(), const) def test_drop_constraint_not_available(self): - op = ops.DropConstraintOp('x', 'y', type_='unique') + op = ops.DropConstraintOp("x", "y", type_="unique") assert_raises_message( - ValueError, - "constraint cannot be produced", - op.to_constraint + ValueError, "constraint cannot be produced", op.to_constraint ) diff --git a/tests/test_op_naming_convention.py b/tests/test_op_naming_convention.py index fd70faa..fbcd181 100644 --- a/tests/test_op_naming_convention.py +++ b/tests/test_op_naming_convention.py @@ -1,5 +1,11 @@ -from sqlalchemy import Integer, Column, \ - Table, Boolean, MetaData, CheckConstraint +from sqlalchemy import ( + Integer, + Column, + Table, + Boolean, + MetaData, + CheckConstraint, +) from sqlalchemy.sql import column, func from alembic import op @@ -9,16 +15,14 @@ from alembic.testing.fixtures import TestBase class AutoNamingConventionTest(TestBase): - __requires__ = ('sqlalchemy_094', ) + __requires__ = ("sqlalchemy_094",) def test_add_check_constraint(self): - context = op_fixture(naming_convention={ - "ck": "ck_%(table_name)s_%(constraint_name)s" - }) + context = op_fixture( + naming_convention={"ck": "ck_%(table_name)s_%(constraint_name)s"} + ) op.create_check_constraint( - "foo", - "user_table", - func.len(column('name')) > 5 + "foo", "user_table", func.len(column("name")) > 5 ) context.assert_( "ALTER TABLE user_table ADD CONSTRAINT ck_user_table_foo " @@ -26,13 +30,9 @@ class AutoNamingConventionTest(TestBase): ) def test_add_check_constraint_name_is_none(self): - context = op_fixture(naming_convention={ - "ck": "ck_%(table_name)s_foo" - }) + context = op_fixture(naming_convention={"ck": "ck_%(table_name)s_foo"}) op.create_check_constraint( - None, - "user_table", - func.len(column('name')) > 5 + None, "user_table", func.len(column("name")) > 5 ) context.assert_( "ALTER TABLE user_table ADD CONSTRAINT ck_user_table_foo " @@ -40,44 +40,29 @@ class AutoNamingConventionTest(TestBase): ) def test_add_unique_constraint_name_is_none(self): - context = op_fixture(naming_convention={ - "uq": "uq_%(table_name)s_foo" - }) - op.create_unique_constraint( - None, - "user_table", - 'x' - ) + context = op_fixture(naming_convention={"uq": "uq_%(table_name)s_foo"}) + op.create_unique_constraint(None, "user_table", "x") context.assert_( "ALTER TABLE user_table ADD CONSTRAINT uq_user_table_foo UNIQUE (x)" ) def test_add_index_name_is_none(self): - context = op_fixture(naming_convention={ - "ix": "ix_%(table_name)s_foo" - }) - op.create_index( - None, - "user_table", - 'x' - ) - context.assert_( - "CREATE INDEX ix_user_table_foo ON user_table (x)" - ) + context = op_fixture(naming_convention={"ix": "ix_%(table_name)s_foo"}) + op.create_index(None, "user_table", "x") + context.assert_("CREATE INDEX ix_user_table_foo ON user_table (x)") def test_add_check_constraint_already_named_from_schema(self): m1 = MetaData( - naming_convention={"ck": "ck_%(table_name)s_%(constraint_name)s"}) + naming_convention={"ck": "ck_%(table_name)s_%(constraint_name)s"} + ) ck = CheckConstraint("im a constraint", name="cc1") - Table('t', m1, Column('x'), ck) + Table("t", m1, Column("x"), ck) context = op_fixture( - naming_convention={"ck": "ck_%(table_name)s_%(constraint_name)s"}) - - op.create_table( - "some_table", - Column('x', Integer, ck), + naming_convention={"ck": "ck_%(table_name)s_%(constraint_name)s"} ) + + op.create_table("some_table", Column("x", Integer, ck)) context.assert_( "CREATE TABLE some_table " "(x INTEGER CONSTRAINT ck_t_cc1 CHECK (im a constraint))" @@ -85,11 +70,12 @@ class AutoNamingConventionTest(TestBase): def test_add_check_constraint_inline_on_table(self): context = op_fixture( - naming_convention={"ck": "ck_%(table_name)s_%(constraint_name)s"}) + naming_convention={"ck": "ck_%(table_name)s_%(constraint_name)s"} + ) op.create_table( "some_table", - Column('x', Integer), - CheckConstraint("im a constraint", name="cc1") + Column("x", Integer), + CheckConstraint("im a constraint", name="cc1"), ) context.assert_( "CREATE TABLE some_table " @@ -98,11 +84,12 @@ class AutoNamingConventionTest(TestBase): def test_add_check_constraint_inline_on_table_w_f(self): context = op_fixture( - naming_convention={"ck": "ck_%(table_name)s_%(constraint_name)s"}) + naming_convention={"ck": "ck_%(table_name)s_%(constraint_name)s"} + ) op.create_table( "some_table", - Column('x', Integer), - CheckConstraint("im a constraint", name=op.f("ck_some_table_cc1")) + Column("x", Integer), + CheckConstraint("im a constraint", name=op.f("ck_some_table_cc1")), ) context.assert_( "CREATE TABLE some_table " @@ -111,10 +98,13 @@ class AutoNamingConventionTest(TestBase): def test_add_check_constraint_inline_on_column(self): context = op_fixture( - naming_convention={"ck": "ck_%(table_name)s_%(constraint_name)s"}) + naming_convention={"ck": "ck_%(table_name)s_%(constraint_name)s"} + ) op.create_table( "some_table", - Column('x', Integer, CheckConstraint("im a constraint", name="cc1")) + Column( + "x", Integer, CheckConstraint("im a constraint", name="cc1") + ), ) context.assert_( "CREATE TABLE some_table " @@ -123,12 +113,15 @@ class AutoNamingConventionTest(TestBase): def test_add_check_constraint_inline_on_column_w_f(self): context = op_fixture( - naming_convention={"ck": "ck_%(table_name)s_%(constraint_name)s"}) + naming_convention={"ck": "ck_%(table_name)s_%(constraint_name)s"} + ) op.create_table( "some_table", Column( - 'x', Integer, - CheckConstraint("im a constraint", name=op.f("ck_q_cc1"))) + "x", + Integer, + CheckConstraint("im a constraint", name=op.f("ck_q_cc1")), + ), ) context.assert_( "CREATE TABLE some_table " @@ -136,22 +129,23 @@ class AutoNamingConventionTest(TestBase): ) def test_add_column_schema_type(self): - context = op_fixture(naming_convention={ - "ck": "ck_%(table_name)s_%(constraint_name)s" - }) - op.add_column('t1', Column('c1', Boolean(name='foo'), nullable=False)) + context = op_fixture( + naming_convention={"ck": "ck_%(table_name)s_%(constraint_name)s"} + ) + op.add_column("t1", Column("c1", Boolean(name="foo"), nullable=False)) context.assert_( - 'ALTER TABLE t1 ADD COLUMN c1 BOOLEAN NOT NULL', - 'ALTER TABLE t1 ADD CONSTRAINT ck_t1_foo CHECK (c1 IN (0, 1))' + "ALTER TABLE t1 ADD COLUMN c1 BOOLEAN NOT NULL", + "ALTER TABLE t1 ADD CONSTRAINT ck_t1_foo CHECK (c1 IN (0, 1))", ) def test_add_column_schema_type_w_f(self): - context = op_fixture(naming_convention={ - "ck": "ck_%(table_name)s_%(constraint_name)s" - }) + context = op_fixture( + naming_convention={"ck": "ck_%(table_name)s_%(constraint_name)s"} + ) op.add_column( - 't1', Column('c1', Boolean(name=op.f('foo')), nullable=False)) + "t1", Column("c1", Boolean(name=op.f("foo")), nullable=False) + ) context.assert_( - 'ALTER TABLE t1 ADD COLUMN c1 BOOLEAN NOT NULL', - 'ALTER TABLE t1 ADD CONSTRAINT foo CHECK (c1 IN (0, 1))' + "ALTER TABLE t1 ADD COLUMN c1 BOOLEAN NOT NULL", + "ALTER TABLE t1 ADD CONSTRAINT foo CHECK (c1 IN (0, 1))", ) diff --git a/tests/test_oracle.py b/tests/test_oracle.py index 8b9c9e5..86e0ece 100644 --- a/tests/test_oracle.py +++ b/tests/test_oracle.py @@ -1,23 +1,24 @@ - from sqlalchemy import Integer, Column from alembic import op, command from alembic.testing.fixtures import TestBase from alembic.testing.fixtures import op_fixture, capture_context_buffer -from alembic.testing.env import _no_sql_testing_config, staging_env, \ - three_rev_fixture, clear_staging_env +from alembic.testing.env import ( + _no_sql_testing_config, + staging_env, + three_rev_fixture, + clear_staging_env, +) class FullEnvironmentTests(TestBase): - @classmethod def setup_class(cls): staging_env() cls.cfg = cfg = _no_sql_testing_config("oracle") - cls.a, cls.b, cls.c = \ - three_rev_fixture(cfg) + cls.a, cls.b, cls.c = three_rev_fixture(cfg) @classmethod def teardown_class(cls): @@ -42,113 +43,99 @@ class FullEnvironmentTests(TestBase): class OpTest(TestBase): - def test_add_column(self): - context = op_fixture('oracle') - op.add_column('t1', Column('c1', Integer, nullable=False)) + context = op_fixture("oracle") + op.add_column("t1", Column("c1", Integer, nullable=False)) context.assert_("ALTER TABLE t1 ADD c1 INTEGER NOT NULL") def test_add_column_with_default(self): context = op_fixture("oracle") op.add_column( - 't1', Column('c1', Integer, nullable=False, server_default="12")) + "t1", Column("c1", Integer, nullable=False, server_default="12") + ) context.assert_("ALTER TABLE t1 ADD c1 INTEGER DEFAULT '12' NOT NULL") def test_alter_column_rename_oracle(self): - context = op_fixture('oracle') + context = op_fixture("oracle") op.alter_column("t", "c", name="x") - context.assert_( - "ALTER TABLE t RENAME COLUMN c TO x" - ) + context.assert_("ALTER TABLE t RENAME COLUMN c TO x") def test_alter_column_new_type(self): - context = op_fixture('oracle') + context = op_fixture("oracle") op.alter_column("t", "c", type_=Integer) - context.assert_( - 'ALTER TABLE t MODIFY c INTEGER' - ) + context.assert_("ALTER TABLE t MODIFY c INTEGER") def test_drop_index(self): - context = op_fixture('oracle') - op.drop_index('my_idx', 'my_table') + context = op_fixture("oracle") + op.drop_index("my_idx", "my_table") context.assert_contains("DROP INDEX my_idx") def test_drop_column_w_default(self): - context = op_fixture('oracle') - op.drop_column('t1', 'c1') - context.assert_( - "ALTER TABLE t1 DROP COLUMN c1" - ) + context = op_fixture("oracle") + op.drop_column("t1", "c1") + context.assert_("ALTER TABLE t1 DROP COLUMN c1") def test_drop_column_w_check(self): - context = op_fixture('oracle') - op.drop_column('t1', 'c1') - context.assert_( - "ALTER TABLE t1 DROP COLUMN c1" - ) + context = op_fixture("oracle") + op.drop_column("t1", "c1") + context.assert_("ALTER TABLE t1 DROP COLUMN c1") def test_alter_column_nullable_w_existing_type(self): - context = op_fixture('oracle') + context = op_fixture("oracle") op.alter_column("t", "c", nullable=True, existing_type=Integer) - context.assert_( - "ALTER TABLE t MODIFY c NULL" - ) + context.assert_("ALTER TABLE t MODIFY c NULL") def test_alter_column_not_nullable_w_existing_type(self): - context = op_fixture('oracle') + context = op_fixture("oracle") op.alter_column("t", "c", nullable=False, existing_type=Integer) - context.assert_( - "ALTER TABLE t MODIFY c NOT NULL" - ) + context.assert_("ALTER TABLE t MODIFY c NOT NULL") def test_alter_column_nullable_w_new_type(self): - context = op_fixture('oracle') + context = op_fixture("oracle") op.alter_column("t", "c", nullable=True, type_=Integer) context.assert_( - "ALTER TABLE t MODIFY c NULL", - 'ALTER TABLE t MODIFY c INTEGER' + "ALTER TABLE t MODIFY c NULL", "ALTER TABLE t MODIFY c INTEGER" ) def test_alter_column_not_nullable_w_new_type(self): - context = op_fixture('oracle') + context = op_fixture("oracle") op.alter_column("t", "c", nullable=False, type_=Integer) context.assert_( - "ALTER TABLE t MODIFY c NOT NULL", - "ALTER TABLE t MODIFY c INTEGER" + "ALTER TABLE t MODIFY c NOT NULL", "ALTER TABLE t MODIFY c INTEGER" ) def test_alter_add_server_default(self): - context = op_fixture('oracle') + context = op_fixture("oracle") op.alter_column("t", "c", server_default="5") - context.assert_( - "ALTER TABLE t MODIFY c DEFAULT '5'" - ) + context.assert_("ALTER TABLE t MODIFY c DEFAULT '5'") def test_alter_replace_server_default(self): - context = op_fixture('oracle') + context = op_fixture("oracle") op.alter_column( - "t", "c", server_default="5", existing_server_default="6") - context.assert_( - "ALTER TABLE t MODIFY c DEFAULT '5'" + "t", "c", server_default="5", existing_server_default="6" ) + context.assert_("ALTER TABLE t MODIFY c DEFAULT '5'") def test_alter_remove_server_default(self): - context = op_fixture('oracle') + context = op_fixture("oracle") op.alter_column("t", "c", server_default=None) - context.assert_( - "ALTER TABLE t MODIFY c DEFAULT NULL" - ) + context.assert_("ALTER TABLE t MODIFY c DEFAULT NULL") def test_alter_do_everything(self): - context = op_fixture('oracle') + context = op_fixture("oracle") op.alter_column( - "t", "c", name="c2", nullable=True, - type_=Integer, server_default="5") + "t", + "c", + name="c2", + nullable=True, + type_=Integer, + server_default="5", + ) context.assert_( - 'ALTER TABLE t MODIFY c NULL', + "ALTER TABLE t MODIFY c NULL", "ALTER TABLE t MODIFY c DEFAULT '5'", - 'ALTER TABLE t MODIFY c INTEGER', - 'ALTER TABLE t RENAME COLUMN c TO c2' + "ALTER TABLE t MODIFY c INTEGER", + "ALTER TABLE t RENAME COLUMN c TO c2", ) # TODO: when we add schema support diff --git a/tests/test_postgresql.py b/tests/test_postgresql.py index 23ec49c..61ba2d1 100644 --- a/tests/test_postgresql.py +++ b/tests/test_postgresql.py @@ -1,13 +1,28 @@ - -from sqlalchemy import DateTime, MetaData, Table, Column, text, Integer, \ - String, Interval, Sequence, Numeric, BigInteger, Float, Numeric +from sqlalchemy import ( + DateTime, + MetaData, + Table, + Column, + text, + Integer, + String, + Interval, + Sequence, + Numeric, + BigInteger, + Float, + Numeric, +) from sqlalchemy.dialects.postgresql import ARRAY, UUID, BYTEA from sqlalchemy.engine.reflection import Inspector from sqlalchemy import types from alembic.operations import Operations from sqlalchemy.sql import table, column -from alembic.autogenerate.compare import \ - _compare_server_default, _compare_tables, _render_server_default_for_compare +from alembic.autogenerate.compare import ( + _compare_server_default, + _compare_tables, + _render_server_default_for_compare, +) from alembic.operations import ops from alembic import command, util @@ -16,8 +31,12 @@ from alembic.script import ScriptDirectory from alembic.autogenerate import api from alembic.testing import eq_, provide_metadata -from alembic.testing.env import staging_env, clear_staging_env, \ - _no_sql_testing_config, write_script +from alembic.testing.env import ( + staging_env, + clear_staging_env, + _no_sql_testing_config, + write_script, +) from alembic.testing.fixtures import capture_context_buffer from alembic.testing.fixtures import TestBase from alembic.testing.fixtures import op_fixture @@ -36,79 +55,79 @@ if util.sqla_09: class PostgresqlOpTest(TestBase): - def test_rename_table_postgresql(self): context = op_fixture("postgresql") - op.rename_table('t1', 't2') + op.rename_table("t1", "t2") context.assert_("ALTER TABLE t1 RENAME TO t2") def test_rename_table_schema_postgresql(self): context = op_fixture("postgresql") - op.rename_table('t1', 't2', schema="foo") + op.rename_table("t1", "t2", schema="foo") context.assert_("ALTER TABLE foo.t1 RENAME TO t2") def test_create_index_postgresql_expressions(self): context = op_fixture("postgresql") op.create_index( - 'geocoded', - 'locations', - [text('lower(coordinates)')], - postgresql_where=text("locations.coordinates != Null")) + "geocoded", + "locations", + [text("lower(coordinates)")], + postgresql_where=text("locations.coordinates != Null"), + ) context.assert_( "CREATE INDEX geocoded ON locations (lower(coordinates)) " - "WHERE locations.coordinates != Null") + "WHERE locations.coordinates != Null" + ) def test_create_index_postgresql_where(self): context = op_fixture("postgresql") op.create_index( - 'geocoded', - 'locations', - ['coordinates'], - postgresql_where=text("locations.coordinates != Null")) + "geocoded", + "locations", + ["coordinates"], + postgresql_where=text("locations.coordinates != Null"), + ) context.assert_( "CREATE INDEX geocoded ON locations (coordinates) " - "WHERE locations.coordinates != Null") + "WHERE locations.coordinates != Null" + ) @config.requirements.fail_before_sqla_099 def test_create_index_postgresql_concurrently(self): context = op_fixture("postgresql") op.create_index( - 'geocoded', - 'locations', - ['coordinates'], - postgresql_concurrently=True) + "geocoded", + "locations", + ["coordinates"], + postgresql_concurrently=True, + ) context.assert_( - "CREATE INDEX CONCURRENTLY geocoded ON locations (coordinates)") + "CREATE INDEX CONCURRENTLY geocoded ON locations (coordinates)" + ) @config.requirements.fail_before_sqla_110 def test_drop_index_postgresql_concurrently(self): context = op_fixture("postgresql") - op.drop_index( - 'geocoded', - 'locations', - postgresql_concurrently=True) - context.assert_( - "DROP INDEX CONCURRENTLY geocoded") + op.drop_index("geocoded", "locations", postgresql_concurrently=True) + context.assert_("DROP INDEX CONCURRENTLY geocoded") def test_alter_column_type_using(self): - context = op_fixture('postgresql') - op.alter_column("t", "c", type_=Integer, postgresql_using='c::integer') + context = op_fixture("postgresql") + op.alter_column("t", "c", type_=Integer, postgresql_using="c::integer") context.assert_( - 'ALTER TABLE t ALTER COLUMN c TYPE INTEGER USING c::integer' + "ALTER TABLE t ALTER COLUMN c TYPE INTEGER USING c::integer" ) def test_col_w_pk_is_serial(self): context = op_fixture("postgresql") - op.add_column("some_table", Column('q', Integer, primary_key=True)) - context.assert_( - 'ALTER TABLE some_table ADD COLUMN q SERIAL NOT NULL' - ) + op.add_column("some_table", Column("q", Integer, primary_key=True)) + context.assert_("ALTER TABLE some_table ADD COLUMN q SERIAL NOT NULL") @config.requirements.fail_before_sqla_100 def test_create_exclude_constraint(self): context = op_fixture("postgresql") op.create_exclude_constraint( - "ex1", "t1", ('x', '>'), where='x > 5', using="gist") + "ex1", "t1", ("x", ">"), where="x > 5", using="gist" + ) context.assert_( "ALTER TABLE t1 ADD CONSTRAINT ex1 EXCLUDE USING gist (x WITH >) " "WHERE (x > 5)" @@ -118,8 +137,12 @@ class PostgresqlOpTest(TestBase): def test_create_exclude_constraint_quoted_literal(self): context = op_fixture("postgresql") op.create_exclude_constraint( - "ex1", "SomeTable", ('"SomeColumn"', '>'), - where='"SomeColumn" > 5', using="gist") + "ex1", + "SomeTable", + ('"SomeColumn"', ">"), + where='"SomeColumn" > 5', + using="gist", + ) context.assert_( 'ALTER TABLE "SomeTable" ADD CONSTRAINT ex1 EXCLUDE USING gist ' '("SomeColumn" WITH >) WHERE ("SomeColumn" > 5)' @@ -129,8 +152,12 @@ class PostgresqlOpTest(TestBase): def test_create_exclude_constraint_quoted_column(self): context = op_fixture("postgresql") op.create_exclude_constraint( - "ex1", "SomeTable", (column("SomeColumn"), '>'), - where=column("SomeColumn") > 5, using="gist") + "ex1", + "SomeTable", + (column("SomeColumn"), ">"), + where=column("SomeColumn") > 5, + using="gist", + ) context.assert_( 'ALTER TABLE "SomeTable" ADD CONSTRAINT ex1 EXCLUDE ' 'USING gist ("SomeColumn" WITH >) WHERE ("SomeColumn" > 5)' @@ -138,7 +165,6 @@ class PostgresqlOpTest(TestBase): class PGOfflineEnumTest(TestBase): - def setUp(self): staging_env() self.cfg = cfg = _no_sql_testing_config() @@ -152,7 +178,10 @@ class PGOfflineEnumTest(TestBase): clear_staging_env() def _inline_enum_script(self): - write_script(self.script, self.rid, """ + write_script( + self.script, + self.rid, + """ revision = '%s' down_revision = None @@ -169,10 +198,15 @@ def upgrade(): def downgrade(): op.drop_table("sometable") -""" % self.rid) +""" + % self.rid, + ) def _distinct_enum_script(self): - write_script(self.script, self.rid, """ + write_script( + self.script, + self.rid, + """ revision = '%s' down_revision = None @@ -193,14 +227,18 @@ def downgrade(): op.drop_table("sometable") ENUM(name="pgenum").drop(op.get_bind(), checkfirst=False) -""" % self.rid) +""" + % self.rid, + ) def test_offline_inline_enum_create(self): self._inline_enum_script() with capture_context_buffer() as buf: command.upgrade(self.cfg, self.rid, sql=True) - assert "CREATE TYPE pgenum AS "\ + assert ( + "CREATE TYPE pgenum AS " "ENUM ('one', 'two', 'three')" in buf.getvalue() + ) assert "CREATE TABLE sometable (\n data pgenum\n)" in buf.getvalue() def test_offline_inline_enum_drop(self): @@ -215,8 +253,10 @@ def downgrade(): self._distinct_enum_script() with capture_context_buffer() as buf: command.upgrade(self.cfg, self.rid, sql=True) - assert "CREATE TYPE pgenum AS ENUM "\ + assert ( + "CREATE TYPE pgenum AS ENUM " "('one', 'two', 'three')" in buf.getvalue() + ) assert "CREATE TABLE sometable (\n data pgenum\n)" in buf.getvalue() def test_offline_distinct_enum_drop(self): @@ -228,23 +268,27 @@ def downgrade(): class PostgresqlInlineLiteralTest(TestBase): - __only_on__ = 'postgresql' + __only_on__ = "postgresql" __backend__ = True @classmethod def setup_class(cls): cls.bind = config.db - cls.bind.execute(""" + cls.bind.execute( + """ create table tab ( col varchar(50) ) - """) - cls.bind.execute(""" + """ + ) + cls.bind.execute( + """ insert into tab (col) values ('old data 1'), ('old data 2.1'), ('old data 3') - """) + """ + ) @classmethod def teardown_class(cls): @@ -260,35 +304,32 @@ class PostgresqlInlineLiteralTest(TestBase): def test_inline_percent(self): # TODO: here's the issue, you need to escape this. - tab = table('tab', column('col')) + tab = table("tab", column("col")) self.op.execute( - tab.update().where( - tab.c.col.like(self.op.inline_literal('%.%')) - ).values(col=self.op.inline_literal('new data')), - execution_options={'no_parameters': True} + tab.update() + .where(tab.c.col.like(self.op.inline_literal("%.%"))) + .values(col=self.op.inline_literal("new data")), + execution_options={"no_parameters": True}, ) eq_( self.conn.execute( - "select count(*) from tab where col='new data'").scalar(), + "select count(*) from tab where col='new data'" + ).scalar(), 1, ) class PostgresqlDefaultCompareTest(TestBase): - __only_on__ = 'postgresql' + __only_on__ = "postgresql" __backend__ = True - @classmethod def setup_class(cls): cls.bind = config.db staging_env() cls.migration_context = MigrationContext.configure( connection=cls.bind.connect(), - opts={ - 'compare_type': True, - 'compare_server_default': True - } + opts={"compare_type": True, "compare_server_default": True}, ) def setUp(self): @@ -303,216 +344,166 @@ class PostgresqlDefaultCompareTest(TestBase): self.metadata.drop_all() def _compare_default_roundtrip( - self, type_, orig_default, alternate=None, diff_expected=None): - diff_expected = diff_expected \ - if diff_expected is not None \ + self, type_, orig_default, alternate=None, diff_expected=None + ): + diff_expected = ( + diff_expected + if diff_expected is not None else alternate is not None + ) if alternate is None: alternate = orig_default - t1 = Table("test", self.metadata, - Column("somecol", type_, server_default=orig_default)) - t2 = Table("test", MetaData(), - Column("somecol", type_, server_default=alternate)) + t1 = Table( + "test", + self.metadata, + Column("somecol", type_, server_default=orig_default), + ) + t2 = Table( + "test", + MetaData(), + Column("somecol", type_, server_default=alternate), + ) t1.create(self.bind) insp = Inspector.from_engine(self.bind) cols = insp.get_columns(t1.name) - insp_col = Column("somecol", cols[0]['type'], - server_default=text(cols[0]['default'])) + insp_col = Column( + "somecol", cols[0]["type"], server_default=text(cols[0]["default"]) + ) op = ops.AlterColumnOp("test", "somecol") _compare_server_default( - self.autogen_context, op, - None, "test", "somecol", insp_col, t2.c.somecol) + self.autogen_context, + op, + None, + "test", + "somecol", + insp_col, + t2.c.somecol, + ) diffs = op.to_diff_tuple() eq_(bool(diffs), diff_expected) - def _compare_default( - self, - t1, t2, col, - rendered - ): + def _compare_default(self, t1, t2, col, rendered): t1.create(self.bind, checkfirst=True) insp = Inspector.from_engine(self.bind) cols = insp.get_columns(t1.name) ctx = self.autogen_context.migration_context return ctx.impl.compare_server_default( - None, - col, - rendered, - cols[0]['default']) + None, col, rendered, cols[0]["default"] + ) def test_compare_interval_str(self): # this form shouldn't be used but testing here # for compatibility - self._compare_default_roundtrip( - Interval, - "14 days" - ) + self._compare_default_roundtrip(Interval, "14 days") @config.requirements.postgresql_uuid_ossp def test_compare_uuid_text(self): - self._compare_default_roundtrip( - UUID, - text("uuid_generate_v4()") - ) + self._compare_default_roundtrip(UUID, text("uuid_generate_v4()")) def test_compare_interval_text(self): - self._compare_default_roundtrip( - Interval, - text("'14 days'") - ) + self._compare_default_roundtrip(Interval, text("'14 days'")) def test_compare_array_of_integer_text(self): self._compare_default_roundtrip( - ARRAY(Integer), - text("(ARRAY[]::integer[])") + ARRAY(Integer), text("(ARRAY[]::integer[])") ) def test_compare_current_timestamp_text(self): self._compare_default_roundtrip( - DateTime(), - text("TIMEZONE('utc', CURRENT_TIMESTAMP)"), + DateTime(), text("TIMEZONE('utc', CURRENT_TIMESTAMP)") ) def test_compare_integer_str(self): - self._compare_default_roundtrip( - Integer(), - "5", - ) + self._compare_default_roundtrip(Integer(), "5") def test_compare_integer_text(self): - self._compare_default_roundtrip( - Integer(), - text("5"), - ) + self._compare_default_roundtrip(Integer(), text("5")) def test_compare_integer_text_diff(self): - self._compare_default_roundtrip( - Integer(), - text("5"), "7" - ) + self._compare_default_roundtrip(Integer(), text("5"), "7") def test_compare_float_str(self): - self._compare_default_roundtrip( - Float(), - "5.2", - ) + self._compare_default_roundtrip(Float(), "5.2") def test_compare_float_text(self): - self._compare_default_roundtrip( - Float(), - text("5.2"), - ) + self._compare_default_roundtrip(Float(), text("5.2")) def test_compare_float_no_diff1(self): self._compare_default_roundtrip( - Float(), - text("5.2"), "5.2", - diff_expected=False + Float(), text("5.2"), "5.2", diff_expected=False ) def test_compare_float_no_diff2(self): self._compare_default_roundtrip( - Float(), - "5.2", text("5.2"), - diff_expected=False + Float(), "5.2", text("5.2"), diff_expected=False ) def test_compare_float_no_diff3(self): self._compare_default_roundtrip( - Float(), - text("5"), text("5.0"), - diff_expected=False + Float(), text("5"), text("5.0"), diff_expected=False ) def test_compare_float_no_diff4(self): self._compare_default_roundtrip( - Float(), - "5", "5.0", - diff_expected=False + Float(), "5", "5.0", diff_expected=False ) def test_compare_float_no_diff5(self): self._compare_default_roundtrip( - Float(), - text("5"), "5.0", - diff_expected=False + Float(), text("5"), "5.0", diff_expected=False ) def test_compare_float_no_diff6(self): self._compare_default_roundtrip( - Float(), - "5", text("5.0"), - diff_expected=False + Float(), "5", text("5.0"), diff_expected=False ) def test_compare_numeric_no_diff(self): self._compare_default_roundtrip( - Numeric(), - text("5"), "5.0", - diff_expected=False + Numeric(), text("5"), "5.0", diff_expected=False ) def test_compare_unicode_literal(self): - self._compare_default_roundtrip( - String(), - u'im a default' - ) + self._compare_default_roundtrip(String(), u"im a default") # TOOD: will need to actually eval() the repr() and # spend more effort figuring out exactly the kind of expression # to use def _TODO_test_compare_character_str_w_singlequote(self): - self._compare_default_roundtrip( - String(), - "hel''lo", - ) + self._compare_default_roundtrip(String(), "hel''lo") def test_compare_character_str(self): - self._compare_default_roundtrip( - String(), - "hello", - ) + self._compare_default_roundtrip(String(), "hello") def test_compare_character_text(self): - self._compare_default_roundtrip( - String(), - text("'hello'"), - ) + self._compare_default_roundtrip(String(), text("'hello'")) def test_compare_character_str_diff(self): - self._compare_default_roundtrip( - String(), - "hello", - "there" - ) + self._compare_default_roundtrip(String(), "hello", "there") def test_compare_character_text_diff(self): self._compare_default_roundtrip( - String(), - text("'hello'"), - text("'there'") + String(), text("'hello'"), text("'there'") ) def test_primary_key_skip(self): """Test that SERIAL cols are just skipped""" - t1 = Table("sometable", self.metadata, - Column("id", Integer, primary_key=True) - ) - t2 = Table("sometable", MetaData(), - Column("id", Integer, primary_key=True) - ) - assert not self._compare_default( - t1, t2, t2.c.id, "" + t1 = Table( + "sometable", self.metadata, Column("id", Integer, primary_key=True) + ) + t2 = Table( + "sometable", MetaData(), Column("id", Integer, primary_key=True) ) + assert not self._compare_default(t1, t2, t2.c.id, "") class PostgresqlDetectSerialTest(TestBase): - __only_on__ = 'postgresql' + __only_on__ = "postgresql" __backend__ = True @classmethod @@ -522,10 +513,7 @@ class PostgresqlDetectSerialTest(TestBase): staging_env() cls.migration_context = MigrationContext.configure( connection=cls.conn, - opts={ - 'compare_type': True, - 'compare_server_default': True - } + opts={"compare_type": True, "compare_server_default": True}, ) def setUp(self): @@ -538,7 +526,7 @@ class PostgresqlDetectSerialTest(TestBase): @provide_metadata def _expect_default(self, c_expected, col, seq=None): - Table('t', self.metadata, col) + Table("t", self.metadata, col) self.autogen_context.metadata = self.metadata @@ -550,43 +538,50 @@ class PostgresqlDetectSerialTest(TestBase): uo = ops.UpgradeOps(ops=[]) _compare_tables( - set([(None, 't')]), set([]), - insp, uo, self.autogen_context) + set([(None, "t")]), set([]), insp, uo, self.autogen_context + ) diffs = uo.as_diffs() tab = diffs[0][1] - eq_(_render_server_default_for_compare( - tab.c.x.server_default, tab.c.x, self.autogen_context), - c_expected) + eq_( + _render_server_default_for_compare( + tab.c.x.server_default, tab.c.x, self.autogen_context + ), + c_expected, + ) insp = Inspector.from_engine(config.db) uo = ops.UpgradeOps(ops=[]) m2 = MetaData() - Table('t', m2, Column('x', BigInteger())) + Table("t", m2, Column("x", BigInteger())) self.autogen_context.metadata = m2 _compare_tables( - set([(None, 't')]), set([(None, 't')]), - insp, uo, self.autogen_context) + set([(None, "t")]), + set([(None, "t")]), + insp, + uo, + self.autogen_context, + ) diffs = uo.as_diffs() - server_default = diffs[0][0][4]['existing_server_default'] - eq_(_render_server_default_for_compare( - server_default, tab.c.x, self.autogen_context), - c_expected) + server_default = diffs[0][0][4]["existing_server_default"] + eq_( + _render_server_default_for_compare( + server_default, tab.c.x, self.autogen_context + ), + c_expected, + ) def test_serial(self): - self._expect_default( - None, - Column('x', Integer, primary_key=True) - ) + self._expect_default(None, Column("x", Integer, primary_key=True)) def test_separate_seq(self): seq = Sequence("x_id_seq") self._expect_default( "nextval('x_id_seq'::regclass)", Column( - 'x', Integer, - server_default=seq.next_value(), primary_key=True), - seq + "x", Integer, server_default=seq.next_value(), primary_key=True + ), + seq, ) def test_numeric(self): @@ -594,29 +589,29 @@ class PostgresqlDetectSerialTest(TestBase): self._expect_default( "nextval('x_id_seq'::regclass)", Column( - 'x', Numeric(8, 2), server_default=seq.next_value(), - primary_key=True), - seq + "x", + Numeric(8, 2), + server_default=seq.next_value(), + primary_key=True, + ), + seq, ) def test_no_default(self): self._expect_default( - None, - Column('x', Integer, autoincrement=False, primary_key=True) + None, Column("x", Integer, autoincrement=False, primary_key=True) ) class PostgresqlAutogenRenderTest(TestBase): - def setUp(self): ctx_opts = { - 'sqlalchemy_module_prefix': 'sa.', - 'alembic_module_prefix': 'op.', - 'target_metadata': MetaData() + "sqlalchemy_module_prefix": "sa.", + "alembic_module_prefix": "op.", + "target_metadata": MetaData(), } context = MigrationContext.configure( - dialect_name="postgresql", - opts=ctx_opts + dialect_name="postgresql", opts=ctx_opts ) self.autogen_context = api.AutogenContext(context) @@ -625,13 +620,11 @@ class PostgresqlAutogenRenderTest(TestBase): autogen_context = self.autogen_context m = MetaData() - t = Table('t', m, - Column('x', String), - Column('y', String) - ) + t = Table("t", m, Column("x", String), Column("y", String)) - idx = Index('foo_idx', t.c.x, t.c.y, - postgresql_where=(t.c.y == 'something')) + idx = Index( + "foo_idx", t.c.x, t.c.y, postgresql_where=(t.c.y == "something") + ) op_obj = ops.CreateIndexOp.from_index(idx) @@ -639,114 +632,124 @@ class PostgresqlAutogenRenderTest(TestBase): autogenerate.render_op_text(autogen_context, op_obj), """op.create_index('foo_idx', 't', \ ['x', 'y'], unique=False, """ - """postgresql_where=sa.text(!U"y = 'something'"))""" + """postgresql_where=sa.text(!U"y = 'something'"))""", ) def test_render_server_default_native_boolean(self): c = Column( - 'updated_at', Boolean(), - server_default=false(), - nullable=False) - result = autogenerate.render._render_column( - c, self.autogen_context, + "updated_at", Boolean(), server_default=false(), nullable=False ) + result = autogenerate.render._render_column(c, self.autogen_context) eq_ignore_whitespace( result, - 'sa.Column(\'updated_at\', sa.Boolean(), ' - 'server_default=sa.text(!U\'false\'), ' - 'nullable=False)' + "sa.Column('updated_at', sa.Boolean(), " + "server_default=sa.text(!U'false'), " + "nullable=False)", ) def test_postgresql_array_type(self): eq_ignore_whitespace( autogenerate.render._repr_type( - ARRAY(Integer), self.autogen_context), - "postgresql.ARRAY(sa.Integer())" + ARRAY(Integer), self.autogen_context + ), + "postgresql.ARRAY(sa.Integer())", ) eq_ignore_whitespace( autogenerate.render._repr_type( - ARRAY(DateTime(timezone=True)), self.autogen_context), - "postgresql.ARRAY(sa.DateTime(timezone=True))" + ARRAY(DateTime(timezone=True)), self.autogen_context + ), + "postgresql.ARRAY(sa.DateTime(timezone=True))", ) eq_ignore_whitespace( autogenerate.render._repr_type( - ARRAY(BYTEA, as_tuple=True, dimensions=2), - self.autogen_context), - "postgresql.ARRAY(postgresql.BYTEA(), as_tuple=True, dimensions=2)" + ARRAY(BYTEA, as_tuple=True, dimensions=2), self.autogen_context + ), + "postgresql.ARRAY(postgresql.BYTEA(), as_tuple=True, dimensions=2)", ) - assert 'from sqlalchemy.dialects import postgresql' in \ - self.autogen_context.imports + assert ( + "from sqlalchemy.dialects import postgresql" + in self.autogen_context.imports + ) @config.requirements.sqlalchemy_110 def test_postgresql_hstore_subtypes(self): eq_ignore_whitespace( - autogenerate.render._repr_type( - HSTORE(), self.autogen_context), - "postgresql.HSTORE(text_type=sa.Text())" + autogenerate.render._repr_type(HSTORE(), self.autogen_context), + "postgresql.HSTORE(text_type=sa.Text())", ) eq_ignore_whitespace( autogenerate.render._repr_type( - HSTORE(text_type=String()), self.autogen_context), - "postgresql.HSTORE(text_type=sa.String())" + HSTORE(text_type=String()), self.autogen_context + ), + "postgresql.HSTORE(text_type=sa.String())", ) eq_ignore_whitespace( autogenerate.render._repr_type( - HSTORE(text_type=BYTEA()), self.autogen_context), - "postgresql.HSTORE(text_type=postgresql.BYTEA())" + HSTORE(text_type=BYTEA()), self.autogen_context + ), + "postgresql.HSTORE(text_type=postgresql.BYTEA())", ) - assert 'from sqlalchemy.dialects import postgresql' in \ - self.autogen_context.imports + assert ( + "from sqlalchemy.dialects import postgresql" + in self.autogen_context.imports + ) @config.requirements.sqlalchemy_110 def test_generic_array_type(self): eq_ignore_whitespace( autogenerate.render._repr_type( - types.ARRAY(Integer), self.autogen_context), - "sa.ARRAY(sa.Integer())" + types.ARRAY(Integer), self.autogen_context + ), + "sa.ARRAY(sa.Integer())", ) eq_ignore_whitespace( autogenerate.render._repr_type( - types.ARRAY(DateTime(timezone=True)), self.autogen_context), - "sa.ARRAY(sa.DateTime(timezone=True))" + types.ARRAY(DateTime(timezone=True)), self.autogen_context + ), + "sa.ARRAY(sa.DateTime(timezone=True))", ) - assert 'from sqlalchemy.dialects import postgresql' not in \ - self.autogen_context.imports + assert ( + "from sqlalchemy.dialects import postgresql" + not in self.autogen_context.imports + ) eq_ignore_whitespace( autogenerate.render._repr_type( types.ARRAY(BYTEA, as_tuple=True, dimensions=2), - self.autogen_context), - "sa.ARRAY(postgresql.BYTEA(), as_tuple=True, dimensions=2)" + self.autogen_context, + ), + "sa.ARRAY(postgresql.BYTEA(), as_tuple=True, dimensions=2)", ) - assert 'from sqlalchemy.dialects import postgresql' in \ - self.autogen_context.imports + assert ( + "from sqlalchemy.dialects import postgresql" + in self.autogen_context.imports + ) def test_array_type_user_defined_inner(self): def repr_type(typestring, object_, autogen_context): - if typestring == 'type' and isinstance(object_, String): + if typestring == "type" and isinstance(object_, String): return "foobar.MYVARCHAR" else: return False - self.autogen_context.opts.update( - render_item=repr_type - ) + self.autogen_context.opts.update(render_item=repr_type) eq_ignore_whitespace( autogenerate.render._repr_type( - ARRAY(String), self.autogen_context), - "postgresql.ARRAY(foobar.MYVARCHAR)" + ARRAY(String), self.autogen_context + ), + "postgresql.ARRAY(foobar.MYVARCHAR)", ) @config.requirements.fail_before_sqla_100 @@ -756,22 +759,18 @@ class PostgresqlAutogenRenderTest(TestBase): autogen_context = self.autogen_context m = MetaData() - t = Table('t', m, - Column('x', String), - Column('y', String) - ) - - op_obj = ops.AddConstraintOp.from_constraint(ExcludeConstraint( - (t.c.x, ">"), - where=t.c.x != 2, - using="gist", - name="t_excl_x" - )) + t = Table("t", m, Column("x", String), Column("y", String)) + + op_obj = ops.AddConstraintOp.from_constraint( + ExcludeConstraint( + (t.c.x, ">"), where=t.c.x != 2, using="gist", name="t_excl_x" + ) + ) eq_ignore_whitespace( autogenerate.render_op_text(autogen_context, op_obj), "op.create_exclude_constraint('t_excl_x', 't', (sa.column('x'), '>'), " - "where=sa.text(!U'x != 2'), using='gist')" + "where=sa.text(!U'x != 2'), using='gist')", ) @config.requirements.fail_before_sqla_100 @@ -781,25 +780,25 @@ class PostgresqlAutogenRenderTest(TestBase): autogen_context = self.autogen_context m = MetaData() - t = Table('TTAble', m, - Column('XColumn', String), - Column('YColumn', String) - ) + t = Table( + "TTAble", m, Column("XColumn", String), Column("YColumn", String) + ) - op_obj = ops.AddConstraintOp.from_constraint(ExcludeConstraint( - (t.c.XColumn, ">"), - where=t.c.XColumn != 2, - using="gist", - name="t_excl_x" - )) + op_obj = ops.AddConstraintOp.from_constraint( + ExcludeConstraint( + (t.c.XColumn, ">"), + where=t.c.XColumn != 2, + using="gist", + name="t_excl_x", + ) + ) eq_ignore_whitespace( autogenerate.render_op_text(autogen_context, op_obj), "op.create_exclude_constraint('t_excl_x', 'TTAble', (sa.column('XColumn'), '>'), " - "where=sa.text(!U'\"XColumn\" != 2'), using='gist')" + "where=sa.text(!U'\"XColumn\" != 2'), using='gist')", ) - @config.requirements.fail_before_sqla_100 def test_inline_exclude_constraint(self): from sqlalchemy.dialects.postgresql import ExcludeConstraint @@ -808,15 +807,13 @@ class PostgresqlAutogenRenderTest(TestBase): m = MetaData() t = Table( - 't', m, - Column('x', String), - Column('y', String), + "t", + m, + Column("x", String), + Column("y", String), ExcludeConstraint( - ('x', ">"), - using="gist", - where='x != 2', - name="t_excl_x" - ) + ("x", ">"), using="gist", where="x != 2", name="t_excl_x" + ), ) op_obj = ops.CreateTableOp.from_table(t) @@ -827,7 +824,7 @@ class PostgresqlAutogenRenderTest(TestBase): "sa.Column('y', sa.String(), nullable=True)," "postgresql.ExcludeConstraint((!U'x', '>'), " "where=sa.text(!U'x != 2'), using='gist', name='t_excl_x')" - ")" + ")", ) @config.requirements.fail_before_sqla_100 @@ -838,15 +835,13 @@ class PostgresqlAutogenRenderTest(TestBase): m = MetaData() t = Table( - 'TTable', m, - Column('XColumn', String), - Column('YColumn', String), + "TTable", m, Column("XColumn", String), Column("YColumn", String) ) ExcludeConstraint( (t.c.XColumn, ">"), using="gist", where='"XColumn" != 2', - name="TExclX" + name="TExclX", ) op_obj = ops.CreateTableOp.from_table(t) @@ -858,33 +853,29 @@ class PostgresqlAutogenRenderTest(TestBase): "sa.Column('YColumn', sa.String(), nullable=True)," "postgresql.ExcludeConstraint((sa.column('XColumn'), '>'), " "where=sa.text(!U'\"XColumn\" != 2'), using='gist', " - "name='TExclX'))" + "name='TExclX'))", ) def test_json_type(self): if config.requirements.sqlalchemy_110.enabled: eq_ignore_whitespace( - autogenerate.render._repr_type( - JSON(), self.autogen_context), - "postgresql.JSON(astext_type=sa.Text())" + autogenerate.render._repr_type(JSON(), self.autogen_context), + "postgresql.JSON(astext_type=sa.Text())", ) else: eq_ignore_whitespace( - autogenerate.render._repr_type( - JSON(), self.autogen_context), - "postgresql.JSON()" + autogenerate.render._repr_type(JSON(), self.autogen_context), + "postgresql.JSON()", ) def test_jsonb_type(self): if config.requirements.sqlalchemy_110.enabled: eq_ignore_whitespace( - autogenerate.render._repr_type( - JSONB(), self.autogen_context), - "postgresql.JSONB(astext_type=sa.Text())" + autogenerate.render._repr_type(JSONB(), self.autogen_context), + "postgresql.JSONB(astext_type=sa.Text())", ) else: eq_ignore_whitespace( - autogenerate.render._repr_type( - JSONB(), self.autogen_context), - "postgresql.JSONB()" + autogenerate.render._repr_type(JSONB(), self.autogen_context), + "postgresql.JSONB()", ) diff --git a/tests/test_revision.py b/tests/test_revision.py index 1f1c342..41a713e 100644 --- a/tests/test_revision.py +++ b/tests/test_revision.py @@ -1,7 +1,11 @@ from alembic.testing.fixtures import TestBase from alembic.testing import eq_, assert_raises_message -from alembic.script.revision import RevisionMap, Revision, MultipleHeads, \ - RevisionError +from alembic.script.revision import ( + RevisionMap, + Revision, + MultipleHeads, + RevisionError, +) from . import _large_map @@ -9,136 +13,144 @@ class APITest(TestBase): def test_add_revision_one_head(self): map_ = RevisionMap( lambda: [ - Revision('a', ()), - Revision('b', ('a',)), - Revision('c', ('b',)), + Revision("a", ()), + Revision("b", ("a",)), + Revision("c", ("b",)), ] ) - eq_(map_.heads, ('c', )) + eq_(map_.heads, ("c",)) - map_.add_revision(Revision('d', ('c', ))) - eq_(map_.heads, ('d', )) + map_.add_revision(Revision("d", ("c",))) + eq_(map_.heads, ("d",)) def test_add_revision_two_head(self): map_ = RevisionMap( lambda: [ - Revision('a', ()), - Revision('b', ('a',)), - Revision('c1', ('b',)), - Revision('c2', ('b',)), + Revision("a", ()), + Revision("b", ("a",)), + Revision("c1", ("b",)), + Revision("c2", ("b",)), ] ) - eq_(map_.heads, ('c1', 'c2')) + eq_(map_.heads, ("c1", "c2")) - map_.add_revision(Revision('d1', ('c1', ))) - eq_(map_.heads, ('c2', 'd1')) + map_.add_revision(Revision("d1", ("c1",))) + eq_(map_.heads, ("c2", "d1")) def test_get_revision_head_single(self): map_ = RevisionMap( lambda: [ - Revision('a', ()), - Revision('b', ('a',)), - Revision('c', ('b',)), + Revision("a", ()), + Revision("b", ("a",)), + Revision("c", ("b",)), ] ) - eq_(map_.get_revision('head'), map_._revision_map['c']) + eq_(map_.get_revision("head"), map_._revision_map["c"]) def test_get_revision_base_single(self): map_ = RevisionMap( lambda: [ - Revision('a', ()), - Revision('b', ('a',)), - Revision('c', ('b',)), + Revision("a", ()), + Revision("b", ("a",)), + Revision("c", ("b",)), ] ) - eq_(map_.get_revision('base'), None) + eq_(map_.get_revision("base"), None) def test_get_revision_head_multiple(self): map_ = RevisionMap( lambda: [ - Revision('a', ()), - Revision('b', ('a',)), - Revision('c1', ('b',)), - Revision('c2', ('b',)), + Revision("a", ()), + Revision("b", ("a",)), + Revision("c1", ("b",)), + Revision("c2", ("b",)), ] ) assert_raises_message( MultipleHeads, "Multiple heads are present", - map_.get_revision, 'head' + map_.get_revision, + "head", ) def test_get_revision_heads_multiple(self): map_ = RevisionMap( lambda: [ - Revision('a', ()), - Revision('b', ('a',)), - Revision('c1', ('b',)), - Revision('c2', ('b',)), + Revision("a", ()), + Revision("b", ("a",)), + Revision("c1", ("b",)), + Revision("c2", ("b",)), ] ) assert_raises_message( MultipleHeads, "Multiple heads are present", - map_.get_revision, "heads" + map_.get_revision, + "heads", ) def test_get_revision_base_multiple(self): map_ = RevisionMap( lambda: [ - Revision('a', ()), - Revision('b', ('a',)), - Revision('c', ()), - Revision('d', ('c',)), + Revision("a", ()), + Revision("b", ("a",)), + Revision("c", ()), + Revision("d", ("c",)), ] ) - eq_(map_.get_revision('base'), None) + eq_(map_.get_revision("base"), None) def test_iterate_tolerates_dupe_targets(self): map_ = RevisionMap( lambda: [ - Revision('a', ()), - Revision('b', ('a',)), - Revision('c', ('b',)), + Revision("a", ()), + Revision("b", ("a",)), + Revision("c", ("b",)), ] ) eq_( - [ - r.revision for r in - map_._iterate_revisions(('c', 'c'), 'a') - ], - ['c', 'b', 'a'] + [r.revision for r in map_._iterate_revisions(("c", "c"), "a")], + ["c", "b", "a"], ) def test_repr_revs(self): map_ = RevisionMap( lambda: [ - Revision('a', ()), - Revision('b', ('a',)), - Revision('c', (), dependencies=('a', 'b')), + Revision("a", ()), + Revision("b", ("a",)), + Revision("c", (), dependencies=("a", "b")), ] ) - c = map_._revision_map['c'] + c = map_._revision_map["c"] eq_(repr(c), "Revision('c', None, dependencies=('a', 'b'))") class DownIterateTest(TestBase): def _assert_iteration( - self, upper, lower, assertion, inclusive=True, map_=None, - implicit_base=False, select_for_downgrade=False): + self, + upper, + lower, + assertion, + inclusive=True, + map_=None, + implicit_base=False, + select_for_downgrade=False, + ): if map_ is None: map_ = self.map eq_( [ - rev.revision for rev in - map_.iterate_revisions( - upper, lower, - inclusive=inclusive, implicit_base=implicit_base, - select_for_downgrade=select_for_downgrade + rev.revision + for rev in map_.iterate_revisions( + upper, + lower, + inclusive=inclusive, + implicit_base=implicit_base, + select_for_downgrade=select_for_downgrade, ) ], - assertion + assertion, ) @@ -146,173 +158,141 @@ class DiamondTest(DownIterateTest): def setUp(self): self.map = RevisionMap( lambda: [ - Revision('a', ()), - Revision('b1', ('a',)), - Revision('b2', ('a',)), - Revision('c', ('b1', 'b2')), - Revision('d', ('c',)), + Revision("a", ()), + Revision("b1", ("a",)), + Revision("b2", ("a",)), + Revision("c", ("b1", "b2")), + Revision("d", ("c",)), ] ) def test_iterate_simple_diamond(self): - self._assert_iteration( - "d", "a", - ["d", "c", "b1", "b2", "a"] - ) + self._assert_iteration("d", "a", ["d", "c", "b1", "b2", "a"]) class EmptyMapTest(DownIterateTest): # see issue #258 def setUp(self): - self.map = RevisionMap( - lambda: [] - ) + self.map = RevisionMap(lambda: []) def test_iterate(self): - self._assert_iteration( - "head", "base", - [] - ) + self._assert_iteration("head", "base", []) class LabeledBranchTest(DownIterateTest): def test_dupe_branch_collection(self): fn = lambda: [ - Revision('a', ()), - Revision('b', ('a',)), - Revision('c', ('b',), branch_labels=['xy1']), - Revision('d', ()), - Revision('e', ('d',), branch_labels=['xy1']), - Revision('f', ('e',)) + Revision("a", ()), + Revision("b", ("a",)), + Revision("c", ("b",), branch_labels=["xy1"]), + Revision("d", ()), + Revision("e", ("d",), branch_labels=["xy1"]), + Revision("f", ("e",)), ] assert_raises_message( RevisionError, r"Branch name 'xy1' in revision (?:e|c) already " "used by revision (?:e|c)", - getattr, RevisionMap(fn), "_revision_map" + getattr, + RevisionMap(fn), + "_revision_map", ) def test_filter_for_lineage_labeled_head_across_merge(self): fn = lambda: [ - Revision('a', ()), - Revision('b', ('a', )), - Revision('c1', ('b', ), branch_labels='c1branch'), - Revision('c2', ('b', )), - Revision('d', ('c1', 'c2')), - + Revision("a", ()), + Revision("b", ("a",)), + Revision("c1", ("b",), branch_labels="c1branch"), + Revision("c2", ("b",)), + Revision("d", ("c1", "c2")), ] map_ = RevisionMap(fn) - c1 = map_.get_revision('c1') - c2 = map_.get_revision('c2') - d = map_.get_revision('d') - eq_( - map_.filter_for_lineage([c1, c2, d], "c1branch@head"), - [c1, c2, d] - ) + c1 = map_.get_revision("c1") + c2 = map_.get_revision("c2") + d = map_.get_revision("d") + eq_(map_.filter_for_lineage([c1, c2, d], "c1branch@head"), [c1, c2, d]) def test_filter_for_lineage_heads(self): eq_( - self.map.filter_for_lineage( - [self.map.get_revision("f")], - "heads" - ), - [self.map.get_revision("f")] + self.map.filter_for_lineage([self.map.get_revision("f")], "heads"), + [self.map.get_revision("f")], ) def setUp(self): - self.map = RevisionMap(lambda: [ - Revision('a', (), branch_labels='abranch'), - Revision('b', ('a',)), - Revision('somelongername', ('b',)), - Revision('c', ('somelongername',)), - Revision('d', ()), - Revision('e', ('d',), branch_labels=['ebranch']), - Revision('someothername', ('e',)), - Revision('f', ('someothername',)), - ]) + self.map = RevisionMap( + lambda: [ + Revision("a", (), branch_labels="abranch"), + Revision("b", ("a",)), + Revision("somelongername", ("b",)), + Revision("c", ("somelongername",)), + Revision("d", ()), + Revision("e", ("d",), branch_labels=["ebranch"]), + Revision("someothername", ("e",)), + Revision("f", ("someothername",)), + ] + ) def test_get_base_revisions_labeled(self): - eq_( - self.map._get_base_revisions("somelongername@base"), - ['a'] - ) + eq_(self.map._get_base_revisions("somelongername@base"), ["a"]) def test_get_current_named_rev(self): - eq_( - self.map.get_revision("ebranch@head"), - self.map.get_revision("f") - ) + eq_(self.map.get_revision("ebranch@head"), self.map.get_revision("f")) def test_get_base_revisions(self): - eq_( - self.map._get_base_revisions("base"), - ['a', 'd'] - ) + eq_(self.map._get_base_revisions("base"), ["a", "d"]) def test_iterate_head_to_named_base(self): self._assert_iteration( - "heads", "ebranch@base", - ['f', 'someothername', 'e', 'd'] + "heads", "ebranch@base", ["f", "someothername", "e", "d"] ) self._assert_iteration( - "heads", "abranch@base", - ['c', 'somelongername', 'b', 'a'] + "heads", "abranch@base", ["c", "somelongername", "b", "a"] ) def test_iterate_named_head_to_base(self): self._assert_iteration( - "ebranch@head", "base", - ['f', 'someothername', 'e', 'd'] + "ebranch@head", "base", ["f", "someothername", "e", "d"] ) self._assert_iteration( - "abranch@head", "base", - ['c', 'somelongername', 'b', 'a'] + "abranch@head", "base", ["c", "somelongername", "b", "a"] ) def test_iterate_named_head_to_heads(self): - self._assert_iteration( - "heads", "ebranch@head", - ['f'], - inclusive=True - ) + self._assert_iteration("heads", "ebranch@head", ["f"], inclusive=True) def test_iterate_named_rev_to_heads(self): self._assert_iteration( - "heads", "ebranch@d", - ['f', 'someothername', 'e', 'd'], - inclusive=True + "heads", + "ebranch@d", + ["f", "someothername", "e", "d"], + inclusive=True, ) def test_iterate_head_to_version_specific_base(self): self._assert_iteration( - "heads", "e@base", - ['f', 'someothername', 'e', 'd'] + "heads", "e@base", ["f", "someothername", "e", "d"] ) self._assert_iteration( - "heads", "c@base", - ['c', 'somelongername', 'b', 'a'] + "heads", "c@base", ["c", "somelongername", "b", "a"] ) def test_iterate_to_branch_at_rev(self): self._assert_iteration( - "heads", "ebranch@d", - ['f', 'someothername', 'e', 'd'] + "heads", "ebranch@d", ["f", "someothername", "e", "d"] ) def test_branch_w_down_relative(self): self._assert_iteration( - "heads", "ebranch@-2", - ['f', 'someothername', 'e'] + "heads", "ebranch@-2", ["f", "someothername", "e"] ) def test_branch_w_up_relative(self): self._assert_iteration( - "ebranch@+2", "base", - ['someothername', 'e', 'd'] + "ebranch@+2", "base", ["someothername", "e", "d"] ) def test_partial_id_resolve(self): @@ -320,43 +300,43 @@ class LabeledBranchTest(DownIterateTest): eq_(self.map.get_revision("abranch@some").revision, "somelongername") def test_branch_at_heads(self): - eq_( - self.map.get_revision("abranch@heads").revision, - "c" - ) + eq_(self.map.get_revision("abranch@heads").revision, "c") def test_branch_at_syntax(self): - eq_(self.map.get_revision("abranch@head").revision, 'c') + eq_(self.map.get_revision("abranch@head").revision, "c") eq_(self.map.get_revision("abranch@base"), None) - eq_(self.map.get_revision("ebranch@head").revision, 'f') + eq_(self.map.get_revision("ebranch@head").revision, "f") eq_(self.map.get_revision("abranch@base"), None) - eq_(self.map.get_revision("ebranch@d").revision, 'd') + eq_(self.map.get_revision("ebranch@d").revision, "d") def test_branch_at_self(self): - eq_(self.map.get_revision("ebranch@ebranch").revision, 'e') + eq_(self.map.get_revision("ebranch@ebranch").revision, "e") def test_retrieve_branch_revision(self): - eq_(self.map.get_revision("abranch").revision, 'a') - eq_(self.map.get_revision("ebranch").revision, 'e') + eq_(self.map.get_revision("abranch").revision, "a") + eq_(self.map.get_revision("ebranch").revision, "e") def test_rev_not_in_branch(self): assert_raises_message( RevisionError, "Revision b is not a member of branch 'ebranch'", - self.map.get_revision, "ebranch@b" + self.map.get_revision, + "ebranch@b", ) assert_raises_message( RevisionError, "Revision d is not a member of branch 'abranch'", - self.map.get_revision, "abranch@d" + self.map.get_revision, + "abranch@d", ) def test_no_revision_exists(self): assert_raises_message( RevisionError, "No such revision or branch 'q'", - self.map.get_revision, "abranch@q" + self.map.get_revision, + "abranch@q", ) def test_not_actually_a_branch(self): @@ -367,9 +347,7 @@ class LabeledBranchTest(DownIterateTest): def test_no_such_branch(self): assert_raises_message( - RevisionError, - "No such branch: 'x'", - self.map.get_revision, "x@d" + RevisionError, "No such branch: 'x'", self.map.get_revision, "x@d" ) @@ -377,19 +355,18 @@ class LongShortBranchTest(DownIterateTest): def setUp(self): self.map = RevisionMap( lambda: [ - Revision('a', ()), - Revision('b1', ('a',)), - Revision('b2', ('a',)), - Revision('c1', ('b1',)), - Revision('d11', ('c1',)), - Revision('d12', ('c1',)), + Revision("a", ()), + Revision("b1", ("a",)), + Revision("b2", ("a",)), + Revision("c1", ("b1",)), + Revision("d11", ("c1",)), + Revision("d12", ("c1",)), ] ) def test_iterate_full(self): self._assert_iteration( - "heads", "base", - ['b2', 'd11', 'd12', 'c1', 'b1', 'a'] + "heads", "base", ["b2", "d11", "d12", "c1", "b1", "a"] ) @@ -397,56 +374,46 @@ class MultipleBranchTest(DownIterateTest): def setUp(self): self.map = RevisionMap( lambda: [ - Revision('a', ()), - Revision('b1', ('a',)), - Revision('b2', ('a',)), - Revision('cb1', ('b1',)), - Revision('cb2', ('b2',)), - Revision('d1cb1', ('cb1',)), # head - Revision('d2cb1', ('cb1',)), # head - Revision('d1cb2', ('cb2',)), - Revision('d2cb2', ('cb2',)), - Revision('d3cb2', ('cb2',)), # head - Revision('d1d2cb2', ('d1cb2', 'd2cb2')) # head + merge point + Revision("a", ()), + Revision("b1", ("a",)), + Revision("b2", ("a",)), + Revision("cb1", ("b1",)), + Revision("cb2", ("b2",)), + Revision("d1cb1", ("cb1",)), # head + Revision("d2cb1", ("cb1",)), # head + Revision("d1cb2", ("cb2",)), + Revision("d2cb2", ("cb2",)), + Revision("d3cb2", ("cb2",)), # head + Revision("d1d2cb2", ("d1cb2", "d2cb2")), # head + merge point ] ) def test_iterate_from_merge_point(self): self._assert_iteration( - "d1d2cb2", "a", - ['d1d2cb2', 'd1cb2', 'd2cb2', 'cb2', 'b2', 'a'] + "d1d2cb2", "a", ["d1d2cb2", "d1cb2", "d2cb2", "cb2", "b2", "a"] ) def test_iterate_multiple_heads(self): self._assert_iteration( - ["d2cb2", "d3cb2"], "a", - ['d2cb2', 'd3cb2', 'cb2', 'b2', 'a'] + ["d2cb2", "d3cb2"], "a", ["d2cb2", "d3cb2", "cb2", "b2", "a"] ) def test_iterate_single_branch(self): - self._assert_iteration( - "d3cb2", "a", - ['d3cb2', 'cb2', 'b2', 'a'] - ) + self._assert_iteration("d3cb2", "a", ["d3cb2", "cb2", "b2", "a"]) def test_iterate_single_branch_to_base(self): - self._assert_iteration( - "d3cb2", "base", - ['d3cb2', 'cb2', 'b2', 'a'] - ) + self._assert_iteration("d3cb2", "base", ["d3cb2", "cb2", "b2", "a"]) def test_iterate_multiple_branch_to_base(self): self._assert_iteration( - ["d3cb2", "cb1"], "base", - ['d3cb2', 'cb2', 'b2', 'cb1', 'b1', 'a'] + ["d3cb2", "cb1"], "base", ["d3cb2", "cb2", "b2", "cb1", "b1", "a"] ) def test_iterate_multiple_heads_single_base(self): # head d1cb1 is omitted as it is not # a descendant of b2 self._assert_iteration( - ["d1cb1", "d2cb2", "d3cb2"], "b2", - ["d2cb2", 'd3cb2', 'cb2', 'b2'] + ["d1cb1", "d2cb2", "d3cb2"], "b2", ["d2cb2", "d3cb2", "cb2", "b2"] ) def test_same_branch_wrong_direction(self): @@ -456,7 +423,7 @@ class MultipleBranchTest(DownIterateTest): RevisionError, r"Revision d1cb1 is not an ancestor of revision b1", list, - self.map._iterate_revisions('b1', 'd1cb1') + self.map._iterate_revisions("b1", "d1cb1"), ) def test_distinct_branches(self): @@ -465,7 +432,7 @@ class MultipleBranchTest(DownIterateTest): RevisionError, r"Revision b1 is not an ancestor of revision d2cb2", list, - self.map._iterate_revisions('d2cb2', 'b1') + self.map._iterate_revisions("d2cb2", "b1"), ) def test_wrong_direction_to_base_as_none(self): @@ -475,7 +442,7 @@ class MultipleBranchTest(DownIterateTest): RevisionError, r"Revision d1cb1 is not an ancestor of revision base", list, - self.map._iterate_revisions(None, 'd1cb1') + self.map._iterate_revisions(None, "d1cb1"), ) def test_wrong_direction_to_base_as_empty(self): @@ -485,7 +452,7 @@ class MultipleBranchTest(DownIterateTest): RevisionError, r"Revision d1cb1 is not an ancestor of revision base", list, - self.map._iterate_revisions((), 'd1cb1') + self.map._iterate_revisions((), "d1cb1"), ) @@ -501,22 +468,20 @@ class BranchTravellingTest(DownIterateTest): def setUp(self): self.map = RevisionMap( lambda: [ - Revision('a1', ()), - Revision('a2', ('a1',)), - Revision('a3', ('a2',)), - Revision('b1', ('a3',)), - Revision('b2', ('a3',)), - Revision('cb1', ('b1',)), - Revision('cb2', ('b2',)), - Revision('db1', ('cb1',)), - Revision('db2', ('cb2',)), - - Revision('e1b1', ('db1',)), - Revision('fe1b1', ('e1b1',)), - - Revision('e2b1', ('db1',)), - Revision('e2b2', ('db2',)), - Revision("merge", ('e2b1', 'e2b2')) + Revision("a1", ()), + Revision("a2", ("a1",)), + Revision("a3", ("a2",)), + Revision("b1", ("a3",)), + Revision("b2", ("a3",)), + Revision("cb1", ("b1",)), + Revision("cb2", ("b2",)), + Revision("db1", ("cb1",)), + Revision("db2", ("cb2",)), + Revision("e1b1", ("db1",)), + Revision("fe1b1", ("e1b1",)), + Revision("e2b1", ("db1",)), + Revision("e2b2", ("db2",)), + Revision("merge", ("e2b1", "e2b2")), ] ) @@ -524,19 +489,31 @@ class BranchTravellingTest(DownIterateTest): # test that when we hit a merge point, implicit base will # ensure all branches that supply the merge point are filled in self._assert_iteration( - "merge", "db1", - ['merge', - 'e2b1', 'db1', - 'e2b2', 'db2', 'cb2', 'b2'], - implicit_base=True + "merge", + "db1", + ["merge", "e2b1", "db1", "e2b2", "db2", "cb2", "b2"], + implicit_base=True, ) def test_three_branches_end_in_single_branch(self): self._assert_iteration( - ["merge", "fe1b1"], "a3", - ['merge', 'e2b1', 'e2b2', 'db2', 'cb2', 'b2', - 'fe1b1', 'e1b1', 'db1', 'cb1', 'b1', 'a3'] + ["merge", "fe1b1"], + "a3", + [ + "merge", + "e2b1", + "e2b2", + "db2", + "cb2", + "b2", + "fe1b1", + "e1b1", + "db1", + "cb1", + "b1", + "a3", + ], ) def test_two_branches_to_root(self): @@ -544,80 +521,103 @@ class BranchTravellingTest(DownIterateTest): # here we want 'a3' as a "stop" branch point, but *not* # 'db1', as we don't have multiple traversals on db1 self._assert_iteration( - "merge", "a1", - ['merge', - 'e2b1', 'db1', 'cb1', 'b1', # e2b1 branch - 'e2b2', 'db2', 'cb2', 'b2', # e2b2 branch - 'a3', # both terminate at a3 - 'a2', 'a1' # finish out - ] # noqa + "merge", + "a1", + [ + "merge", + "e2b1", + "db1", + "cb1", + "b1", # e2b1 branch + "e2b2", + "db2", + "cb2", + "b2", # e2b2 branch + "a3", # both terminate at a3 + "a2", + "a1", # finish out + ], # noqa ) def test_two_branches_end_in_branch(self): self._assert_iteration( - "merge", "b1", + "merge", + "b1", # 'b1' is local to 'e2b1' # branch so that is all we get - ['merge', 'e2b1', 'db1', 'cb1', 'b1', - - ] # noqa + ["merge", "e2b1", "db1", "cb1", "b1"], # noqa ) def test_two_branches_end_behind_branch(self): self._assert_iteration( - "merge", "a2", - ['merge', - 'e2b1', 'db1', 'cb1', 'b1', # e2b1 branch - 'e2b2', 'db2', 'cb2', 'b2', # e2b2 branch - 'a3', # both terminate at a3 - 'a2' - ] # noqa + "merge", + "a2", + [ + "merge", + "e2b1", + "db1", + "cb1", + "b1", # e2b1 branch + "e2b2", + "db2", + "cb2", + "b2", # e2b2 branch + "a3", # both terminate at a3 + "a2", + ], # noqa ) def test_three_branches_to_root(self): # in this case, both "a3" and "db1" are stop points self._assert_iteration( - ["merge", "fe1b1"], "a1", - ['merge', - 'e2b1', # e2b1 branch - 'e2b2', 'db2', 'cb2', 'b2', # e2b2 branch - 'fe1b1', 'e1b1', # fe1b1 branch - 'db1', # fe1b1 and e2b1 branches terminate at db1 - 'cb1', 'b1', # e2b1 branch continued....might be nicer - # if this was before the e2b2 branch... - 'a3', # e2b1 and e2b2 branches terminate at a3 - 'a2', 'a1' # finish out - ] # noqa + ["merge", "fe1b1"], + "a1", + [ + "merge", + "e2b1", # e2b1 branch + "e2b2", + "db2", + "cb2", + "b2", # e2b2 branch + "fe1b1", + "e1b1", # fe1b1 branch + "db1", # fe1b1 and e2b1 branches terminate at db1 + "cb1", + "b1", # e2b1 branch continued....might be nicer + # if this was before the e2b2 branch... + "a3", # e2b1 and e2b2 branches terminate at a3 + "a2", + "a1", # finish out + ], # noqa ) def test_three_branches_end_multiple_bases(self): # in this case, both "a3" and "db1" are stop points self._assert_iteration( - ["merge", "fe1b1"], ["cb1", "cb2"], + ["merge", "fe1b1"], + ["cb1", "cb2"], [ - 'merge', - 'e2b1', - 'e2b2', 'db2', 'cb2', - 'fe1b1', 'e1b1', - 'db1', - 'cb1' - ] + "merge", + "e2b1", + "e2b2", + "db2", + "cb2", + "fe1b1", + "e1b1", + "db1", + "cb1", + ], ) def test_three_branches_end_multiple_bases_exclusive(self): self._assert_iteration( - ["merge", "fe1b1"], ["cb1", "cb2"], - [ - 'merge', - 'e2b1', - 'e2b2', 'db2', - 'fe1b1', 'e1b1', - 'db1', - ], - inclusive=False + ["merge", "fe1b1"], + ["cb1", "cb2"], + ["merge", "e2b1", "e2b2", "db2", "fe1b1", "e1b1", "db1"], + inclusive=False, ) def test_detect_invalid_head_selection(self): @@ -627,26 +627,34 @@ class BranchTravellingTest(DownIterateTest): "Requested revision fe1b1 overlaps " "with other requested revisions", list, - self.map._iterate_revisions(["db1", "b2", "fe1b1"], ()) + self.map._iterate_revisions(["db1", "b2", "fe1b1"], ()), ) def test_three_branches_end_multiple_bases_exclusive_blank(self): self._assert_iteration( - ["e2b1", "b2", "fe1b1"], (), + ["e2b1", "b2", "fe1b1"], + (), [ - 'e2b1', - 'b2', - 'fe1b1', 'e1b1', - 'db1', 'cb1', 'b1', 'a3', 'a2', 'a1' + "e2b1", + "b2", + "fe1b1", + "e1b1", + "db1", + "cb1", + "b1", + "a3", + "a2", + "a1", ], - inclusive=False + inclusive=False, ) def test_iterate_to_symbolic_base(self): self._assert_iteration( - ["fe1b1"], "base", - ['fe1b1', 'e1b1', 'db1', 'cb1', 'b1', 'a3', 'a2', 'a1'], - inclusive=False + ["fe1b1"], + "base", + ["fe1b1", "e1b1", "db1", "cb1", "b1", "a3", "a2", "a1"], + inclusive=False, ) def test_ancestor_nodes(self): @@ -656,8 +664,22 @@ class BranchTravellingTest(DownIterateTest): rev.revision for rev in self.map._get_ancestor_nodes([merge], check=True) ), - set(['a1', 'e2b2', 'e2b1', 'cb2', 'merge', - 'a3', 'a2', 'b1', 'b2', 'db1', 'db2', 'cb1']) + set( + [ + "a1", + "e2b2", + "e2b1", + "cb2", + "merge", + "a3", + "a2", + "b1", + "b2", + "db1", + "db2", + "cb1", + ] + ), ) @@ -665,125 +687,153 @@ class MultipleBaseTest(DownIterateTest): def setUp(self): self.map = RevisionMap( lambda: [ - Revision('base1', ()), - Revision('base2', ()), - Revision('base3', ()), - - Revision('a1a', ('base1',)), - Revision('a1b', ('base1',)), - Revision('a2', ('base2',)), - Revision('a3', ('base3',)), - - Revision('b1a', ('a1a',)), - Revision('b1b', ('a1b',)), - Revision('b2', ('a2',)), - Revision('b3', ('a3',)), - - Revision('c2', ('b2',)), - Revision('d2', ('c2',)), - - Revision('mergeb3d2', ('b3', 'd2')) + Revision("base1", ()), + Revision("base2", ()), + Revision("base3", ()), + Revision("a1a", ("base1",)), + Revision("a1b", ("base1",)), + Revision("a2", ("base2",)), + Revision("a3", ("base3",)), + Revision("b1a", ("a1a",)), + Revision("b1b", ("a1b",)), + Revision("b2", ("a2",)), + Revision("b3", ("a3",)), + Revision("c2", ("b2",)), + Revision("d2", ("c2",)), + Revision("mergeb3d2", ("b3", "d2")), ] ) def test_heads_to_base(self): self._assert_iteration( - "heads", "base", + "heads", + "base", [ - 'b1a', 'a1a', - 'b1b', 'a1b', - 'mergeb3d2', - 'b3', 'a3', 'base3', - 'd2', 'c2', 'b2', 'a2', 'base2', - 'base1' - ] + "b1a", + "a1a", + "b1b", + "a1b", + "mergeb3d2", + "b3", + "a3", + "base3", + "d2", + "c2", + "b2", + "a2", + "base2", + "base1", + ], ) def test_heads_to_base_exclusive(self): self._assert_iteration( - "heads", "base", + "heads", + "base", [ - 'b1a', 'a1a', - 'b1b', 'a1b', - 'mergeb3d2', - 'b3', 'a3', 'base3', - 'd2', 'c2', 'b2', 'a2', 'base2', - 'base1', + "b1a", + "a1a", + "b1b", + "a1b", + "mergeb3d2", + "b3", + "a3", + "base3", + "d2", + "c2", + "b2", + "a2", + "base2", + "base1", ], - inclusive=False + inclusive=False, ) def test_heads_to_blank(self): self._assert_iteration( - "heads", None, + "heads", + None, [ - 'b1a', 'a1a', - 'b1b', 'a1b', - 'mergeb3d2', - 'b3', 'a3', 'base3', - 'd2', 'c2', 'b2', 'a2', 'base2', - 'base1' - ] + "b1a", + "a1a", + "b1b", + "a1b", + "mergeb3d2", + "b3", + "a3", + "base3", + "d2", + "c2", + "b2", + "a2", + "base2", + "base1", + ], ) def test_detect_invalid_base_selection(self): assert_raises_message( RevisionError, - "Requested revision a2 overlaps with " - "other requested revisions", + "Requested revision a2 overlaps with " "other requested revisions", list, - self.map._iterate_revisions(["c2"], ["a2", "b2"]) + self.map._iterate_revisions(["c2"], ["a2", "b2"]), ) def test_heads_to_revs_plus_implicit_base_exclusive(self): self._assert_iteration( - "heads", ["c2"], + "heads", + ["c2"], [ - 'b1a', 'a1a', - 'b1b', 'a1b', - 'mergeb3d2', - 'b3', 'a3', 'base3', - 'd2', - 'base1' + "b1a", + "a1a", + "b1b", + "a1b", + "mergeb3d2", + "b3", + "a3", + "base3", + "d2", + "base1", ], inclusive=False, - implicit_base=True + implicit_base=True, ) def test_heads_to_revs_base_exclusive(self): self._assert_iteration( - "heads", ["c2"], - [ - 'mergeb3d2', 'd2' - ], - inclusive=False + "heads", ["c2"], ["mergeb3d2", "d2"], inclusive=False ) def test_heads_to_revs_plus_implicit_base_inclusive(self): self._assert_iteration( - "heads", ["c2"], + "heads", + ["c2"], [ - 'b1a', 'a1a', - 'b1b', 'a1b', - 'mergeb3d2', - 'b3', 'a3', 'base3', - 'd2', 'c2', - 'base1' + "b1a", + "a1a", + "b1b", + "a1b", + "mergeb3d2", + "b3", + "a3", + "base3", + "d2", + "c2", + "base1", ], - implicit_base=True + implicit_base=True, ) def test_specific_path_one(self): - self._assert_iteration( - "b3", "base3", - ['b3', 'a3', 'base3'] - ) + self._assert_iteration("b3", "base3", ["b3", "a3", "base3"]) def test_specific_path_two_implicit_base(self): self._assert_iteration( - ["b3", "b2"], "base3", - ['b3', 'a3', 'b2', 'a2', 'base2'], - inclusive=False, implicit_base=True + ["b3", "b2"], + "base3", + ["b3", "a3", "b2", "a2", "base2"], + inclusive=False, + implicit_base=True, ) @@ -808,21 +858,19 @@ class MultipleBaseCrossDependencyTestOne(DownIterateTest): """ self.map = RevisionMap( lambda: [ - Revision('base1', (), branch_labels='b_1'), - Revision('a1a', ('base1',)), - Revision('a1b', ('base1',)), - Revision('b1a', ('a1a',)), - Revision('b1b', ('a1b', ), dependencies='a3'), - - Revision('base2', (), branch_labels='b_2'), - Revision('a2', ('base2',)), - Revision('b2', ('a2',)), - Revision('c2', ('b2', ), dependencies='a3'), - Revision('d2', ('c2',)), - - Revision('base3', (), branch_labels='b_3'), - Revision('a3', ('base3',)), - Revision('b3', ('a3',)), + Revision("base1", (), branch_labels="b_1"), + Revision("a1a", ("base1",)), + Revision("a1b", ("base1",)), + Revision("b1a", ("a1a",)), + Revision("b1b", ("a1b",), dependencies="a3"), + Revision("base2", (), branch_labels="b_2"), + Revision("a2", ("base2",)), + Revision("b2", ("a2",)), + Revision("c2", ("b2",), dependencies="a3"), + Revision("d2", ("c2",)), + Revision("base3", (), branch_labels="b_3"), + Revision("a3", ("base3",)), + Revision("b3", ("a3",)), ] ) @@ -831,25 +879,45 @@ class MultipleBaseCrossDependencyTestOne(DownIterateTest): def test_heads_to_base(self): self._assert_iteration( - "heads", "base", + "heads", + "base", [ - - 'b1a', 'a1a', 'b1b', 'a1b', 'd2', 'c2', 'b2', 'a2', 'base2', - 'b3', 'a3', 'base3', - 'base1' - ] + "b1a", + "a1a", + "b1b", + "a1b", + "d2", + "c2", + "b2", + "a2", + "base2", + "b3", + "a3", + "base3", + "base1", + ], ) def test_heads_to_base_downgrade(self): self._assert_iteration( - "heads", "base", + "heads", + "base", [ - - 'b1a', 'a1a', 'b1b', 'a1b', 'd2', 'c2', 'b2', 'a2', 'base2', - 'b3', 'a3', 'base3', - 'base1' + "b1a", + "a1a", + "b1b", + "a1b", + "d2", + "c2", + "b2", + "a2", + "base2", + "b3", + "a3", + "base3", + "base1", ], - select_for_downgrade=True + select_for_downgrade=True, ) def test_same_branch_wrong_direction(self): @@ -857,83 +925,79 @@ class MultipleBaseCrossDependencyTestOne(DownIterateTest): RevisionError, r"Revision d2 is not an ancestor of revision b2", list, - self.map._iterate_revisions('b2', 'd2') + self.map._iterate_revisions("b2", "d2"), ) def test_different_branch_not_wrong_direction(self): - self._assert_iteration( - "b3", "d2", - [] - ) + self._assert_iteration("b3", "d2", []) def test_we_need_head2_upgrade(self): # the 2 branch relies on the 3 branch self._assert_iteration( - "b_2@head", "base", - ['d2', 'c2', 'b2', 'a2', 'base2', 'a3', 'base3'] + "b_2@head", + "base", + ["d2", "c2", "b2", "a2", "base2", "a3", "base3"], ) def test_we_need_head2_downgrade(self): # the 2 branch relies on the 3 branch, but # on the downgrade side, don't need to touch the 3 branch self._assert_iteration( - "b_2@head", "b_2@base", - ['d2', 'c2', 'b2', 'a2', 'base2'], - select_for_downgrade=True + "b_2@head", + "b_2@base", + ["d2", "c2", "b2", "a2", "base2"], + select_for_downgrade=True, ) def test_we_need_head3_upgrade(self): # the 3 branch can be upgraded alone. - self._assert_iteration( - "b_3@head", "base", - ['b3', 'a3', 'base3'] - ) + self._assert_iteration("b_3@head", "base", ["b3", "a3", "base3"]) def test_we_need_head3_downgrade(self): # the 3 branch can be upgraded alone. self._assert_iteration( - "b_3@head", "base", - ['b3', 'a3', 'base3'], - select_for_downgrade=True + "b_3@head", + "base", + ["b3", "a3", "base3"], + select_for_downgrade=True, ) def test_we_need_head1_upgrade(self): # the 1 branch relies on the 3 branch self._assert_iteration( - "b1b@head", "base", - ['b1b', 'a1b', 'base1', 'a3', 'base3'] + "b1b@head", "base", ["b1b", "a1b", "base1", "a3", "base3"] ) def test_we_need_head1_downgrade(self): # going down we don't need a3-> base3, as long # as we are limiting the base target self._assert_iteration( - "b1b@head", "b1b@base", - ['b1b', 'a1b', 'base1'], - select_for_downgrade=True + "b1b@head", + "b1b@base", + ["b1b", "a1b", "base1"], + select_for_downgrade=True, ) def test_we_need_base2_upgrade(self): # consider a downgrade to b_2@base - we # want to run through all the "2"s alone, and we're done. self._assert_iteration( - "heads", "b_2@base", - ['d2', 'c2', 'b2', 'a2', 'base2'] + "heads", "b_2@base", ["d2", "c2", "b2", "a2", "base2"] ) def test_we_need_base2_downgrade(self): # consider a downgrade to b_2@base - we # want to run through all the "2"s alone, and we're done. self._assert_iteration( - "heads", "b_2@base", - ['d2', 'c2', 'b2', 'a2', 'base2'], - select_for_downgrade=True + "heads", + "b_2@base", + ["d2", "c2", "b2", "a2", "base2"], + select_for_downgrade=True, ) def test_we_need_base3_upgrade(self): self._assert_iteration( - "heads", "b_3@base", - ['b1b', 'd2', 'c2', 'b3', 'a3', 'base3'] + "heads", "b_3@base", ["b1b", "d2", "c2", "b3", "a3", "base3"] ) def test_we_need_base3_downgrade(self): @@ -942,9 +1006,10 @@ class MultipleBaseCrossDependencyTestOne(DownIterateTest): # as well, which means b1b and c2. Then we can downgrade # the 3s. self._assert_iteration( - "heads", "b_3@base", - ['b1b', 'd2', 'c2', 'b3', 'a3', 'base3'], - select_for_downgrade=True + "heads", + "b_3@base", + ["b1b", "d2", "c2", "b3", "a3", "base3"], + select_for_downgrade=True, ) @@ -952,22 +1017,20 @@ class MultipleBaseCrossDependencyTestTwo(DownIterateTest): def setUp(self): self.map = RevisionMap( lambda: [ - Revision('base1', (), branch_labels='b_1'), - Revision('a1', 'base1'), - Revision('b1', 'a1'), - Revision('c1', 'b1'), - - Revision('base2', (), dependencies='b_1', branch_labels='b_2'), - Revision('a2', 'base2'), - Revision('b2', 'a2'), - Revision('c2', 'b2'), - Revision('d2', 'c2'), - - Revision('base3', (), branch_labels='b_3'), - Revision('a3', 'base3'), - Revision('b3', 'a3'), - Revision('c3', 'b3', dependencies='b2'), - Revision('d3', 'c3'), + Revision("base1", (), branch_labels="b_1"), + Revision("a1", "base1"), + Revision("b1", "a1"), + Revision("c1", "b1"), + Revision("base2", (), dependencies="b_1", branch_labels="b_2"), + Revision("a2", "base2"), + Revision("b2", "a2"), + Revision("c2", "b2"), + Revision("d2", "c2"), + Revision("base3", (), branch_labels="b_3"), + Revision("a3", "base3"), + Revision("b3", "a3"), + Revision("c3", "b3", dependencies="b2"), + Revision("d3", "c3"), ] ) @@ -976,55 +1039,68 @@ class MultipleBaseCrossDependencyTestTwo(DownIterateTest): def test_heads_to_base(self): self._assert_iteration( - "heads", "base", + "heads", + "base", [ - 'c1', 'b1', 'a1', - 'd2', 'c2', - 'd3', 'c3', 'b3', 'a3', 'base3', - 'b2', 'a2', 'base2', - 'base1' - ] + "c1", + "b1", + "a1", + "d2", + "c2", + "d3", + "c3", + "b3", + "a3", + "base3", + "b2", + "a2", + "base2", + "base1", + ], ) def test_we_need_head2(self): self._assert_iteration( - "b_2@head", "base", - ['d2', 'c2', 'b2', 'a2', 'base2', 'base1'] + "b_2@head", "base", ["d2", "c2", "b2", "a2", "base2", "base1"] ) def test_we_need_head3(self): self._assert_iteration( - "b_3@head", "base", - ['d3', 'c3', 'b3', 'a3', 'base3', 'b2', 'a2', 'base2', 'base1'] + "b_3@head", + "base", + ["d3", "c3", "b3", "a3", "base3", "b2", "a2", "base2", "base1"], ) def test_we_need_head1(self): - self._assert_iteration( - "b_1@head", "base", - ['c1', 'b1', 'a1', 'base1'] - ) + self._assert_iteration("b_1@head", "base", ["c1", "b1", "a1", "base1"]) def test_we_need_base1(self): self._assert_iteration( - "heads", "b_1@base", + "heads", + "b_1@base", [ - 'c1', 'b1', 'a1', - 'd2', 'c2', - 'd3', 'c3', 'b2', 'a2', 'base2', - 'base1' - ] + "c1", + "b1", + "a1", + "d2", + "c2", + "d3", + "c3", + "b2", + "a2", + "base2", + "base1", + ], ) def test_we_need_base2(self): self._assert_iteration( - "heads", "b_2@base", - ['d2', 'c2', 'd3', 'c3', 'b2', 'a2', 'base2'] + "heads", "b_2@base", ["d2", "c2", "d3", "c3", "b2", "a2", "base2"] ) def test_we_need_base3(self): self._assert_iteration( - "heads", "b_3@base", - ['d3', 'c3', 'b3', 'a3', 'base3'] + "heads", "b_3@base", ["d3", "c3", "b3", "a3", "base3"] ) @@ -1035,24 +1111,21 @@ class LargeMapTest(DownIterateTest): def test_all(self): raw = [r for r in self.map._revision_map.values() if r is not None] - revs = [ - rev for rev in - self.map.iterate_revisions( - "heads", "base" - ) - ] + revs = [rev for rev in self.map.iterate_revisions("heads", "base")] eq_(set(raw), set(revs)) for idx, rev in enumerate(revs): - ancestors = set( - self.map._get_ancestor_nodes([rev])).difference([rev]) + ancestors = set(self.map._get_ancestor_nodes([rev])).difference( + [rev] + ) descendants = set( - self.map._get_descendant_nodes([rev])).difference([rev]) + self.map._get_descendant_nodes([rev]) + ).difference([rev]) assert not ancestors.intersection(descendants) - remaining = set(revs[idx + 1:]) + remaining = set(revs[idx + 1 :]) if remaining: assert remaining.intersection(ancestors) @@ -1061,22 +1134,20 @@ class DepResolutionFailedTest(DownIterateTest): def setUp(self): self.map = RevisionMap( lambda: [ - Revision('base1', ()), - Revision('a1', 'base1'), - Revision('a2', 'base1'), - Revision('b1', 'a1'), - Revision('c1', 'b1'), + Revision("base1", ()), + Revision("a1", "base1"), + Revision("a2", "base1"), + Revision("b1", "a1"), + Revision("c1", "b1"), ] ) # intentionally make a broken map - self.map._revision_map['fake'] = self.map._revision_map['a2'] - self.map._revision_map['b1'].dependencies = 'fake' - self.map._revision_map['b1']._resolved_dependencies = ('fake', ) + self.map._revision_map["fake"] = self.map._revision_map["a2"] + self.map._revision_map["b1"].dependencies = "fake" + self.map._revision_map["b1"]._resolved_dependencies = ("fake",) def test_failure_message(self): iter_ = self.map.iterate_revisions("c1", "base1") assert_raises_message( - RevisionError, - "Dependency resolution failed;", - list, iter_ + RevisionError, "Dependency resolution failed;", list, iter_ ) diff --git a/tests/test_script_consumption.py b/tests/test_script_consumption.py index b394784..749b173 100644 --- a/tests/test_script_consumption.py +++ b/tests/test_script_consumption.py @@ -7,9 +7,16 @@ import textwrap from alembic import command, util from alembic.util import compat from alembic.script import ScriptDirectory, Script -from alembic.testing.env import clear_staging_env, staging_env, \ - _sqlite_testing_config, write_script, _sqlite_file_db, \ - three_rev_fixture, _no_sql_testing_config, env_file_fixture +from alembic.testing.env import ( + clear_staging_env, + staging_env, + _sqlite_testing_config, + write_script, + _sqlite_file_db, + three_rev_fixture, + _no_sql_testing_config, + env_file_fixture, +) from alembic.testing import eq_, assert_raises_message from alembic.testing.fixtures import TestBase, capture_context_buffer from alembic.environment import EnvironmentContext @@ -18,7 +25,7 @@ from alembic.testing import mock class ApplyVersionsFunctionalTest(TestBase): - __only_on__ = 'sqlite' + __only_on__ = "sqlite" sourceless = False @@ -46,7 +53,10 @@ class ApplyVersionsFunctionalTest(TestBase): script = ScriptDirectory.from_config(self.cfg) script.generate_revision(a, None, refresh=True) - write_script(script, a, """ + write_script( + script, + a, + """ revision = '%s' down_revision = None @@ -60,10 +70,16 @@ class ApplyVersionsFunctionalTest(TestBase): def downgrade(): op.execute("DROP TABLE foo") - """ % a, sourceless=self.sourceless) + """ + % a, + sourceless=self.sourceless, + ) script.generate_revision(b, None, refresh=True) - write_script(script, b, """ + write_script( + script, + b, + """ revision = '%s' down_revision = '%s' @@ -77,10 +93,16 @@ class ApplyVersionsFunctionalTest(TestBase): def downgrade(): op.execute("DROP TABLE bar") - """ % (b, a), sourceless=self.sourceless) + """ + % (b, a), + sourceless=self.sourceless, + ) script.generate_revision(c, None, refresh=True) - write_script(script, c, """ + write_script( + script, + c, + """ revision = '%s' down_revision = '%s' @@ -94,49 +116,52 @@ class ApplyVersionsFunctionalTest(TestBase): def downgrade(): op.execute("DROP TABLE bat") - """ % (c, b), sourceless=self.sourceless) + """ + % (c, b), + sourceless=self.sourceless, + ) def _test_002_upgrade(self): command.upgrade(self.cfg, self.c) db = self.bind - assert db.dialect.has_table(db.connect(), 'foo') - assert db.dialect.has_table(db.connect(), 'bar') - assert db.dialect.has_table(db.connect(), 'bat') + assert db.dialect.has_table(db.connect(), "foo") + assert db.dialect.has_table(db.connect(), "bar") + assert db.dialect.has_table(db.connect(), "bat") def _test_003_downgrade(self): command.downgrade(self.cfg, self.a) db = self.bind - assert db.dialect.has_table(db.connect(), 'foo') - assert not db.dialect.has_table(db.connect(), 'bar') - assert not db.dialect.has_table(db.connect(), 'bat') + assert db.dialect.has_table(db.connect(), "foo") + assert not db.dialect.has_table(db.connect(), "bar") + assert not db.dialect.has_table(db.connect(), "bat") def _test_004_downgrade(self): - command.downgrade(self.cfg, 'base') + command.downgrade(self.cfg, "base") db = self.bind - assert not db.dialect.has_table(db.connect(), 'foo') - assert not db.dialect.has_table(db.connect(), 'bar') - assert not db.dialect.has_table(db.connect(), 'bat') + assert not db.dialect.has_table(db.connect(), "foo") + assert not db.dialect.has_table(db.connect(), "bar") + assert not db.dialect.has_table(db.connect(), "bat") def _test_005_upgrade(self): command.upgrade(self.cfg, self.b) db = self.bind - assert db.dialect.has_table(db.connect(), 'foo') - assert db.dialect.has_table(db.connect(), 'bar') - assert not db.dialect.has_table(db.connect(), 'bat') + assert db.dialect.has_table(db.connect(), "foo") + assert db.dialect.has_table(db.connect(), "bar") + assert not db.dialect.has_table(db.connect(), "bat") def _test_006_upgrade_again(self): command.upgrade(self.cfg, self.b) db = self.bind - assert db.dialect.has_table(db.connect(), 'foo') - assert db.dialect.has_table(db.connect(), 'bar') - assert not db.dialect.has_table(db.connect(), 'bat') + assert db.dialect.has_table(db.connect(), "foo") + assert db.dialect.has_table(db.connect(), "bar") + assert not db.dialect.has_table(db.connect(), "bat") def _test_007_stamp_upgrade(self): command.stamp(self.cfg, self.c) db = self.bind - assert db.dialect.has_table(db.connect(), 'foo') - assert db.dialect.has_table(db.connect(), 'bar') - assert not db.dialect.has_table(db.connect(), 'bat') + assert db.dialect.has_table(db.connect(), "foo") + assert db.dialect.has_table(db.connect(), "bar") + assert not db.dialect.has_table(db.connect(), "bat") class SimpleSourcelessApplyVersionsTest(ApplyVersionsFunctionalTest): @@ -144,25 +169,29 @@ class SimpleSourcelessApplyVersionsTest(ApplyVersionsFunctionalTest): class NewFangledSourcelessEnvOnlyApplyVersionsTest( - ApplyVersionsFunctionalTest): + ApplyVersionsFunctionalTest +): sourceless = "pep3147_envonly" - __requires__ = "pep3147", + __requires__ = ("pep3147",) class NewFangledSourcelessEverythingApplyVersionsTest( - ApplyVersionsFunctionalTest): + ApplyVersionsFunctionalTest +): sourceless = "pep3147_everything" - __requires__ = "pep3147", + __requires__ = ("pep3147",) class CallbackEnvironmentTest(ApplyVersionsFunctionalTest): - exp_kwargs = frozenset(('ctx', 'heads', 'run_args', 'step')) + exp_kwargs = frozenset(("ctx", "heads", "run_args", "step")) @staticmethod def _env_file_fixture(): - env_file_fixture(textwrap.dedent("""\ + env_file_fixture( + textwrap.dedent( + """\ import alembic from alembic import context from sqlalchemy import engine_from_config, pool @@ -199,13 +228,16 @@ class CallbackEnvironmentTest(ApplyVersionsFunctionalTest): run_migrations_offline() else: run_migrations_online() - """)) + """ + ) + ) def test_steps(self): import alembic + alembic.mock_event_listener = None self._env_file_fixture() - with mock.patch('alembic.mock_event_listener', mock.Mock()) as mymock: + with mock.patch("alembic.mock_event_listener", mock.Mock()) as mymock: super(CallbackEnvironmentTest, self).test_steps() calls = mymock.call_args_list assert calls @@ -213,27 +245,27 @@ class CallbackEnvironmentTest(ApplyVersionsFunctionalTest): args, kw = call assert not args assert set(kw.keys()) >= self.exp_kwargs - assert kw['run_args'] == {} - assert hasattr(kw['ctx'], 'get_current_revision') + assert kw["run_args"] == {} + assert hasattr(kw["ctx"], "get_current_revision") - step = kw['step'] + step = kw["step"] assert isinstance(step.is_upgrade, bool) assert isinstance(step.is_stamp, bool) assert isinstance(step.is_migration, bool) assert isinstance(step.up_revision_id, compat.string_types) assert isinstance(step.up_revision, Script) - for revtype in 'up', 'down', 'source', 'destination': - revs = getattr(step, '%s_revisions' % revtype) + for revtype in "up", "down", "source", "destination": + revs = getattr(step, "%s_revisions" % revtype) assert isinstance(revs, tuple) for rev in revs: assert isinstance(rev, Script) - revids = getattr(step, '%s_revision_ids' % revtype) + revids = getattr(step, "%s_revision_ids" % revtype) for revid in revids: assert isinstance(revid, compat.string_types) - heads = kw['heads'] - assert hasattr(heads, '__iter__') + heads = kw["heads"] + assert hasattr(heads, "__iter__") for h in heads: assert h is None or isinstance(h, compat.string_types) @@ -242,8 +274,8 @@ class OfflineTransactionalDDLTest(TestBase): def setUp(self): self.env = staging_env() self.cfg = cfg = _no_sql_testing_config() - cfg.set_main_option('dialect_name', 'sqlite') - cfg.remove_main_option('url') + cfg.set_main_option("dialect_name", "sqlite") + cfg.remove_main_option("url") self.a, self.b, self.c = three_rev_fixture(cfg) @@ -254,11 +286,12 @@ class OfflineTransactionalDDLTest(TestBase): with capture_context_buffer(transactional_ddl=True) as buf: command.upgrade(self.cfg, self.c, sql=True) assert re.match( - (r"^BEGIN;\s+CREATE TABLE.*?%s.*" % self.a) + - (r".*%s" % self.b) + - (r".*%s.*?COMMIT;.*$" % self.c), - - buf.getvalue(), re.S) + (r"^BEGIN;\s+CREATE TABLE.*?%s.*" % self.a) + + (r".*%s" % self.b) + + (r".*%s.*?COMMIT;.*$" % self.c), + buf.getvalue(), + re.S, + ) def test_begin_commit_nontransactional_ddl(self): with capture_context_buffer(transactional_ddl=False) as buf: @@ -270,11 +303,12 @@ class OfflineTransactionalDDLTest(TestBase): with capture_context_buffer(transaction_per_migration=True) as buf: command.upgrade(self.cfg, self.c, sql=True) assert re.match( - (r"^BEGIN;\s+CREATE TABLE.*%s.*?COMMIT;.*" % self.a) + - (r"BEGIN;.*?%s.*?COMMIT;.*" % self.b) + - (r"BEGIN;.*?%s.*?COMMIT;.*$" % self.c), - - buf.getvalue(), re.S) + (r"^BEGIN;\s+CREATE TABLE.*%s.*?COMMIT;.*" % self.a) + + (r"BEGIN;.*?%s.*?COMMIT;.*" % self.b) + + (r"BEGIN;.*?%s.*?COMMIT;.*$" % self.c), + buf.getvalue(), + re.S, + ) class OnlineTransactionalDDLTest(TestBase): @@ -290,7 +324,10 @@ class OnlineTransactionalDDLTest(TestBase): b = util.rev_id() c = util.rev_id() script.generate_revision(a, "revision a", refresh=True) - write_script(script, a, """ + write_script( + script, + a, + """ "rev a" revision = '%s' @@ -302,9 +339,14 @@ def upgrade(): def downgrade(): pass -""" % (a, )) +""" + % (a,), + ) script.generate_revision(b, "revision b", refresh=True) - write_script(script, b, """ + write_script( + script, + b, + """ "rev b" revision = '%s' down_revision = '%s' @@ -320,9 +362,14 @@ def upgrade(): def downgrade(): pass -""" % (b, a)) +""" + % (b, a), + ) script.generate_revision(c, "revision c", refresh=True) - write_script(script, c, """ + write_script( + script, + c, + """ "rev c" revision = '%s' down_revision = '%s' @@ -337,7 +384,9 @@ def upgrade(): def downgrade(): pass -""" % (c, b)) +""" + % (c, b), + ) return a, b, c @contextmanager @@ -347,7 +396,8 @@ def downgrade(): def configure(*arg, **opt): opt.update( transactional_ddl=transactional_ddl, - transaction_per_migration=transaction_per_migration) + transaction_per_migration=transaction_per_migration, + ) return conf(*arg, **opt) with mock.patch.object(EnvironmentContext, "configure", configure): @@ -357,39 +407,47 @@ def downgrade(): a, b, c = self._opened_transaction_fixture() with self._patch_environment( - transactional_ddl=False, transaction_per_migration=False): + transactional_ddl=False, transaction_per_migration=False + ): assert_raises_message( util.CommandError, r'Migration "upgrade .*, rev b" has left an uncommitted ' - r'transaction opened; transactional_ddl is False so Alembic ' - r'is not committing transactions', - command.upgrade, self.cfg, c + r"transaction opened; transactional_ddl is False so Alembic " + r"is not committing transactions", + command.upgrade, + self.cfg, + c, ) def test_raise_when_rev_leaves_open_transaction_tpm(self): a, b, c = self._opened_transaction_fixture() with self._patch_environment( - transactional_ddl=False, transaction_per_migration=True): + transactional_ddl=False, transaction_per_migration=True + ): assert_raises_message( util.CommandError, r'Migration "upgrade .*, rev b" has left an uncommitted ' - r'transaction opened; transactional_ddl is False so Alembic ' - r'is not committing transactions', - command.upgrade, self.cfg, c + r"transaction opened; transactional_ddl is False so Alembic " + r"is not committing transactions", + command.upgrade, + self.cfg, + c, ) def test_noerr_rev_leaves_open_transaction_transactional_ddl(self): a, b, c = self._opened_transaction_fixture() with self._patch_environment( - transactional_ddl=True, transaction_per_migration=False): + transactional_ddl=True, transaction_per_migration=False + ): command.upgrade(self.cfg, c) def test_noerr_transaction_opened_externally(self): a, b, c = self._opened_transaction_fixture() - env_file_fixture(""" + env_file_fixture( + """ from sqlalchemy import engine_from_config, pool def run_migrations_online(): @@ -411,22 +469,27 @@ def run_migrations_online(): run_migrations_online() -""") +""" + ) command.stamp(self.cfg, c) class EncodingTest(TestBase): - def setUp(self): self.env = staging_env() self.cfg = cfg = _no_sql_testing_config() - cfg.set_main_option('dialect_name', 'sqlite') - cfg.remove_main_option('url') + cfg.set_main_option("dialect_name", "sqlite") + cfg.remove_main_option("url") self.a = util.rev_id() script = ScriptDirectory.from_config(cfg) script.generate_revision(self.a, "revision a", refresh=True) - write_script(script, self.a, (compat.u("""# coding: utf-8 + write_script( + script, + self.a, + ( + compat.u( + """# coding: utf-8 from __future__ import unicode_literals revision = '%s' down_revision = None @@ -439,22 +502,25 @@ def upgrade(): def downgrade(): op.execute("drôle de petite voix m’a réveillé") -""") % self.a), encoding='utf-8') +""" + ) + % self.a + ), + encoding="utf-8", + ) def tearDown(self): clear_staging_env() def test_encode(self): with capture_context_buffer( - bytes_io=True, - output_encoding='utf-8' + bytes_io=True, output_encoding="utf-8" ) as buf: command.upgrade(self.cfg, self.a, sql=True) assert compat.u("« S’il vous plaît…").encode("utf-8") in buf.getvalue() class VersionNameTemplateTest(TestBase): - def setUp(self): self.env = staging_env() self.cfg = _sqlite_testing_config() @@ -467,7 +533,10 @@ class VersionNameTemplateTest(TestBase): script = ScriptDirectory.from_config(self.cfg) a = util.rev_id() script.generate_revision(a, "some message", refresh=True) - write_script(script, a, """ + write_script( + script, + a, + """ revision = '%s' down_revision = None @@ -481,7 +550,9 @@ class VersionNameTemplateTest(TestBase): def downgrade(): op.execute("DROP TABLE foo") - """ % a) + """ + % a, + ) script = ScriptDirectory.from_config(self.cfg) rev = script.get_revision(a) @@ -493,7 +564,10 @@ class VersionNameTemplateTest(TestBase): script = ScriptDirectory.from_config(self.cfg) a = util.rev_id() script.generate_revision(a, None, refresh=True) - write_script(script, a, """ + write_script( + script, + a, + """ down_revision = None from alembic import op @@ -506,7 +580,8 @@ class VersionNameTemplateTest(TestBase): def downgrade(): op.execute("DROP TABLE foo") - """) + """, + ) script = ScriptDirectory.from_config(self.cfg) rev = script.get_revision(a) @@ -520,8 +595,9 @@ class VersionNameTemplateTest(TestBase): script.generate_revision(a, "foobar", refresh=True) path = script.get_revision(a).path - with open(path, 'w') as fp: - fp.write(""" + with open(path, "w") as fp: + fp.write( + """ down_revision = None from alembic import op @@ -534,7 +610,8 @@ def upgrade(): def downgrade(): op.execute("DROP TABLE foo") -""") +""" + ) pyc_path = util.pyc_file_from_path(path) if pyc_path is not None and os.access(pyc_path, os.F_OK): os.unlink(pyc_path) @@ -544,7 +621,10 @@ def downgrade(): "Could not determine revision id from filename foobar_%s.py. " "Be sure the 'revision' variable is declared " "inside the script." % a, - Script._from_path, script, path) + Script._from_path, + script, + path, + ) class IgnoreFilesTest(TestBase): @@ -563,13 +643,11 @@ class IgnoreFilesTest(TestBase): command.revision(self.cfg, message="some rev") script = ScriptDirectory.from_config(self.cfg) path = os.path.join(script.versions, fname) - with open(path, 'w') as f: - f.write( - "crap, crap -> crap" - ) + with open(path, "w") as f: + f.write("crap, crap -> crap") command.revision(self.cfg, message="another rev") - script.get_revision('head') + script.get_revision("head") def _test_ignore_init_py(self, ext): """test that __init__.py is ignored.""" @@ -613,17 +691,16 @@ class SimpleSourcelessIgnoreFilesTest(IgnoreFilesTest): class NewFangledEnvOnlySourcelessIgnoreFilesTest(IgnoreFilesTest): sourceless = "pep3147_envonly" - __requires__ = "pep3147", + __requires__ = ("pep3147",) class NewFangledEverythingSourcelessIgnoreFilesTest(IgnoreFilesTest): sourceless = "pep3147_everything" - __requires__ = "pep3147", + __requires__ = ("pep3147",) class SourcelessNeedsFlagTest(TestBase): - def setUp(self): self.env = staging_env(sourceless=False) self.cfg = _sqlite_testing_config() @@ -636,7 +713,10 @@ class SourcelessNeedsFlagTest(TestBase): script = ScriptDirectory.from_config(self.cfg) script.generate_revision(a, None, refresh=True) - write_script(script, a, """ + write_script( + script, + a, + """ revision = '%s' down_revision = None @@ -650,7 +730,10 @@ class SourcelessNeedsFlagTest(TestBase): def downgrade(): op.execute("DROP TABLE foo") - """ % a, sourceless=True) + """ + % a, + sourceless=True, + ) script = ScriptDirectory.from_config(self.cfg) eq_(script.get_heads(), []) diff --git a/tests/test_script_production.py b/tests/test_script_production.py index af01a38..f7837d9 100644 --- a/tests/test_script_production.py +++ b/tests/test_script_production.py @@ -1,10 +1,20 @@ from alembic.testing.fixtures import TestBase from alembic.testing import eq_, ne_, assert_raises_message, is_, assertions -from alembic.testing.env import clear_staging_env, staging_env, \ - _get_staging_directory, _no_sql_testing_config, env_file_fixture, \ - script_file_fixture, _testing_config, _sqlite_testing_config, \ - three_rev_fixture, _multi_dir_testing_config, write_script,\ - _sqlite_file_db, _multidb_testing_config +from alembic.testing.env import ( + clear_staging_env, + staging_env, + _get_staging_directory, + _no_sql_testing_config, + env_file_fixture, + script_file_fixture, + _testing_config, + _sqlite_testing_config, + three_rev_fixture, + _multi_dir_testing_config, + write_script, + _sqlite_file_db, + _multidb_testing_config, +) from alembic import command from alembic.script import ScriptDirectory from alembic.environment import EnvironmentContext @@ -24,7 +34,6 @@ env, abc, def_ = None, None, None class GeneralOrderedTests(TestBase): - def setUp(self): global env env = staging_env() @@ -43,11 +52,8 @@ class GeneralOrderedTests(TestBase): self._test_008_long_name_configurable() def _test_001_environment(self): - assert_set = set(['env.py', 'script.py.mako', 'README']) - eq_( - assert_set.intersection(os.listdir(env.dir)), - assert_set - ) + assert_set = set(["env.py", "script.py.mako", "README"]) + eq_(assert_set.intersection(os.listdir(env.dir)), assert_set) def _test_002_rev_ids(self): global abc, def_ @@ -66,19 +72,23 @@ class GeneralOrderedTests(TestBase): eq_(script.revision, abc) eq_(script.down_revision, None) assert os.access( - os.path.join(env.dir, 'versions', - '%s_this_is_a_message.py' % abc), os.F_OK) + os.path.join(env.dir, "versions", "%s_this_is_a_message.py" % abc), + os.F_OK, + ) assert callable(script.module.upgrade) eq_(env.get_heads(), [abc]) eq_(env.get_base(), abc) def _test_005_nextrev(self): script = env.generate_revision( - def_, "this is the next rev", refresh=True) + def_, "this is the next rev", refresh=True + ) assert os.access( os.path.join( - env.dir, 'versions', - '%s_this_is_the_next_rev.py' % def_), os.F_OK) + env.dir, "versions", "%s_this_is_the_next_rev.py" % def_ + ), + os.F_OK, + ) eq_(script.revision, def_) eq_(script.down_revision, abc) eq_(env.get_revision(abc).nextrev, set([def_])) @@ -103,32 +113,42 @@ class GeneralOrderedTests(TestBase): def _test_007_long_name(self): rid = util.rev_id() - env.generate_revision(rid, - "this is a really long name with " - "lots of characters and also " - "I'd like it to\nhave\nnewlines") + env.generate_revision( + rid, + "this is a really long name with " + "lots of characters and also " + "I'd like it to\nhave\nnewlines", + ) assert os.access( os.path.join( - env.dir, 'versions', - '%s_this_is_a_really_long_name_with_lots_of_.py' % rid), - os.F_OK) + env.dir, + "versions", + "%s_this_is_a_really_long_name_with_lots_of_.py" % rid, + ), + os.F_OK, + ) def _test_008_long_name_configurable(self): env.truncate_slug_length = 60 rid = util.rev_id() - env.generate_revision(rid, - "this is a really long name with " - "lots of characters and also " - "I'd like it to\nhave\nnewlines") + env.generate_revision( + rid, + "this is a really long name with " + "lots of characters and also " + "I'd like it to\nhave\nnewlines", + ) assert os.access( - os.path.join(env.dir, 'versions', - '%s_this_is_a_really_long_name_with_lots_' - 'of_characters_and_also_.py' % rid), - os.F_OK) + os.path.join( + env.dir, + "versions", + "%s_this_is_a_really_long_name_with_lots_" + "of_characters_and_also_.py" % rid, + ), + os.F_OK, + ) class ScriptNamingTest(TestBase): - @classmethod def setup_class(cls): _testing_config() @@ -143,15 +163,17 @@ class ScriptNamingTest(TestBase): file_template="%(rev)s_%(slug)s_" "%(year)s_%(month)s_" "%(day)s_%(hour)s_" - "%(minute)s_%(second)s" + "%(minute)s_%(second)s", ) create_date = datetime.datetime(2012, 7, 25, 15, 8, 5) eq_( script._rev_path( - script.versions, "12345", "this is a message", create_date), + script.versions, "12345", "this is a message", create_date + ), os.path.abspath( "%s/versions/12345_this_is_a_" - "message_2012_7_25_15_8_5.py" % _get_staging_directory()) + "message_2012_7_25_15_8_5.py" % _get_staging_directory() + ), ) def _test_tz(self, timezone_arg, given, expected): @@ -161,61 +183,57 @@ class ScriptNamingTest(TestBase): "%(year)s_%(month)s_" "%(day)s_%(hour)s_" "%(minute)s_%(second)s", - timezone=timezone_arg + timezone=timezone_arg, ) with mock.patch( - "alembic.script.base.datetime", - mock.Mock( - datetime=mock.Mock( - utcnow=lambda: given, - now=lambda: given - ) - ) + "alembic.script.base.datetime", + mock.Mock( + datetime=mock.Mock(utcnow=lambda: given, now=lambda: given) + ), ): create_date = script._generate_create_date() - eq_( - create_date, - expected - ) + eq_(create_date, expected) def test_custom_tz(self): self._test_tz( - 'EST5EDT', + "EST5EDT", datetime.datetime(2012, 7, 25, 15, 8, 5), datetime.datetime( - 2012, 7, 25, 11, 8, 5, tzinfo=tz.gettz('EST5EDT')) + 2012, 7, 25, 11, 8, 5, tzinfo=tz.gettz("EST5EDT") + ), ) def test_custom_tz_lowercase(self): self._test_tz( - 'est5edt', + "est5edt", datetime.datetime(2012, 7, 25, 15, 8, 5), datetime.datetime( - 2012, 7, 25, 11, 8, 5, tzinfo=tz.gettz('EST5EDT')) + 2012, 7, 25, 11, 8, 5, tzinfo=tz.gettz("EST5EDT") + ), ) def test_custom_tz_utc(self): self._test_tz( - 'utc', + "utc", datetime.datetime(2012, 7, 25, 15, 8, 5), - datetime.datetime( - 2012, 7, 25, 15, 8, 5, tzinfo=tz.gettz('UTC')) + datetime.datetime(2012, 7, 25, 15, 8, 5, tzinfo=tz.gettz("UTC")), ) def test_custom_tzdata_tz(self): self._test_tz( - 'Europe/Berlin', + "Europe/Berlin", datetime.datetime(2012, 7, 25, 15, 8, 5), datetime.datetime( - 2012, 7, 25, 17, 8, 5, tzinfo=tz.gettz('Europe/Berlin')) + 2012, 7, 25, 17, 8, 5, tzinfo=tz.gettz("Europe/Berlin") + ), ) def test_default_tz(self): self._test_tz( None, datetime.datetime(2012, 7, 25, 15, 8, 5), - datetime.datetime(2012, 7, 25, 15, 8, 5) + datetime.datetime(2012, 7, 25, 15, 8, 5), ) def test_tz_cant_locate(self): @@ -225,7 +243,7 @@ class ScriptNamingTest(TestBase): self._test_tz, "fake", datetime.datetime(2012, 7, 25, 15, 8, 5), - datetime.datetime(2012, 7, 25, 15, 8, 5) + datetime.datetime(2012, 7, 25, 15, 8, 5), ) @@ -247,7 +265,8 @@ class RevisionCommandTest(TestBase): def test_create_script_splice(self): rev = command.revision( - self.cfg, message="some message", head=self.b, splice=True) + self.cfg, message="some message", head=self.b, splice=True + ) script = ScriptDirectory.from_config(self.cfg) rev = script.get_revision(rev.revision) eq_(rev.down_revision, self.b) @@ -260,7 +279,9 @@ class RevisionCommandTest(TestBase): "Revision %s is not a head revision; please specify --splice " "to create a new branch from this revision" % self.b, command.revision, - self.cfg, message="some message", head=self.b + self.cfg, + message="some message", + head=self.b, ) def test_illegal_revision_chars(self): @@ -269,19 +290,23 @@ class RevisionCommandTest(TestBase): r"Character\(s\) '-' not allowed in " "revision identifier 'no-dashes'", command.revision, - self.cfg, message="some message", rev_id="no-dashes" + self.cfg, + message="some message", + rev_id="no-dashes", ) assert not os.path.exists( - os.path.join( - self.env.dir, "versions", "no-dashes_some_message.py")) + os.path.join(self.env.dir, "versions", "no-dashes_some_message.py") + ) assert_raises_message( util.CommandError, r"Character\(s\) '@' not allowed in " "revision identifier 'no@atsigns'", command.revision, - self.cfg, message="some message", rev_id="no@atsigns" + self.cfg, + message="some message", + rev_id="no@atsigns", ) assert_raises_message( @@ -289,7 +314,9 @@ class RevisionCommandTest(TestBase): r"Character\(s\) '-, @' not allowed in revision " "identifier 'no@atsigns-ordashes'", command.revision, - self.cfg, message="some message", rev_id="no@atsigns-ordashes" + self.cfg, + message="some message", + rev_id="no@atsigns-ordashes", ) assert_raises_message( @@ -297,12 +324,15 @@ class RevisionCommandTest(TestBase): r"Character\(s\) '\+' not allowed in revision " r"identifier 'no\+plussignseither'", command.revision, - self.cfg, message="some message", rev_id="no+plussignseither" + self.cfg, + message="some message", + rev_id="no+plussignseither", ) def test_create_script_branches(self): rev = command.revision( - self.cfg, message="some message", branch_label="foobar") + self.cfg, message="some message", branch_label="foobar" + ) script = ScriptDirectory.from_config(self.cfg) rev = script.get_revision(rev.revision) eq_(script.get_revision("foobar"), rev) @@ -330,7 +360,9 @@ class RevisionCommandTest(TestBase): "upgraded your script.py.mako to include the 'branch_labels' " r"section\?", command.revision, - self.cfg, message="some message", branch_label="foobar" + self.cfg, + message="some message", + branch_label="foobar", ) @@ -350,11 +382,17 @@ class CustomizeRevisionTest(TestBase): (self.model3, "model3"), ]: script.generate_revision( - model, name, refresh=True, + model, + name, + refresh=True, version_path=os.path.join(_get_staging_directory(), name), - head="base") + head="base", + ) - write_script(script, model, """\ + write_script( + script, + model, + """\ "%s" revision = '%s' down_revision = None @@ -370,7 +408,9 @@ def upgrade(): def downgrade(): pass -""" % (name, model, name)) +""" + % (name, model, name), + ) def tearDown(self): clear_staging_env() @@ -385,13 +425,13 @@ def downgrade(): context.configure( connection=connection, target_metadata=target_metadata, - process_revision_directives=fn) + process_revision_directives=fn, + ) with context.begin_transaction(): context.run_migrations() return mock.patch( - "alembic.script.base.ScriptDirectory.run_env", - run_env + "alembic.script.base.ScriptDirectory.run_env", run_env ) def test_new_locations_no_autogen(self): @@ -404,24 +444,27 @@ def downgrade(): ops.UpgradeOps(), ops.DowngradeOps(), version_path=os.path.join( - _get_staging_directory(), "model1"), - head="model1@head" + _get_staging_directory(), "model1" + ), + head="model1@head", ), ops.MigrationScript( util.rev_id(), ops.UpgradeOps(), ops.DowngradeOps(), version_path=os.path.join( - _get_staging_directory(), "model2"), - head="model2@head" + _get_staging_directory(), "model2" + ), + head="model2@head", ), ops.MigrationScript( util.rev_id(), ops.UpgradeOps(), ops.DowngradeOps(), version_path=os.path.join( - _get_staging_directory(), "model3"), - head="model3@head" + _get_staging_directory(), "model3" + ), + head="model3@head", ), ] @@ -438,10 +481,13 @@ def downgrade(): rev_script = script.get_revision(rev.revision) eq_( rev_script.path, - os.path.abspath(os.path.join( - _get_staging_directory(), model, - "%s_.py" % (rev_script.revision, ) - )) + os.path.abspath( + os.path.join( + _get_staging_directory(), + model, + "%s_.py" % (rev_script.revision,), + ) + ), ) assert os.path.exists(rev_script.path) @@ -455,19 +501,23 @@ def downgrade(): with self._env_fixture(process_revision_directives, m): rev = command.revision( - self.cfg, message="some message", head="model1@head", sql=True) + self.cfg, message="some message", head="model1@head", sql=True + ) with mock.patch.object(rev.module, "op") as op_mock: rev.module.upgrade() eq_( op_mock.mock_calls, - [mock.call.create_index( - 'some_index', 'some_table', ['a', 'b'], unique=False)] + [ + mock.call.create_index( + "some_index", "some_table", ["a", "b"], unique=False + ) + ], ) def test_autogen(self): m = sa.MetaData() - sa.Table('t', m, sa.Column('x', sa.Integer)) + sa.Table("t", m, sa.Column("x", sa.Integer)) def process_revision_directives(context, rev, generate_revisions): existing_upgrades = generate_revisions[0].upgrade_ops @@ -483,17 +533,19 @@ def downgrade(): existing_upgrades, ops.DowngradeOps(), version_path=os.path.join( - _get_staging_directory(), "model1"), - head="model1@head" + _get_staging_directory(), "model1" + ), + head="model1@head", ), ops.MigrationScript( util.rev_id(), ops.UpgradeOps(ops=existing_downgrades.ops), ops.DowngradeOps(), version_path=os.path.join( - _get_staging_directory(), "model2"), - head="model2@head" - ) + _get_staging_directory(), "model2" + ), + head="model2@head", + ), ] with self._env_fixture(process_revision_directives, m): @@ -501,57 +553,57 @@ def downgrade(): eq_( Inspector.from_engine(self.engine).get_table_names(), - ["alembic_version"] + ["alembic_version"], ) command.revision( - self.cfg, message="some message", - autogenerate=True) + self.cfg, message="some message", autogenerate=True + ) command.upgrade(self.cfg, "model1@head") eq_( Inspector.from_engine(self.engine).get_table_names(), - ["alembic_version", "t"] + ["alembic_version", "t"], ) command.upgrade(self.cfg, "model2@head") eq_( Inspector.from_engine(self.engine).get_table_names(), - ["alembic_version"] + ["alembic_version"], ) def test_programmatic_command_option(self): - def process_revision_directives(context, rev, generate_revisions): generate_revisions[0].message = "test programatic" generate_revisions[0].upgrade_ops = ops.UpgradeOps( ops=[ ops.CreateTableOp( - 'test_table', + "test_table", [ - sa.Column('id', sa.Integer(), primary_key=True), - sa.Column('name', sa.String(50), nullable=False) - ] - ), + sa.Column("id", sa.Integer(), primary_key=True), + sa.Column("name", sa.String(50), nullable=False), + ], + ) ] ) generate_revisions[0].downgrade_ops = ops.DowngradeOps( - ops=[ - ops.DropTableOp('test_table') - ] + ops=[ops.DropTableOp("test_table")] ) with self._env_fixture(None, None): rev = command.revision( self.cfg, head="model1@head", - process_revision_directives=process_revision_directives) + process_revision_directives=process_revision_directives, + ) with open(rev.path) as handle: result = handle.read() - assert (""" + assert ( + ( + """ def upgrade(): # ### commands auto generated by Alembic - please adjust! ### op.create_table('test_table', @@ -560,22 +612,19 @@ def upgrade(): sa.PrimaryKeyConstraint('id') ) # ### end Alembic commands ### -""") in result +""" + ) + in result + ) class ScriptAccessorTest(TestBase): def test_upgrade_downgrade_ops_list_accessors(self): u1 = ops.UpgradeOps(ops=[]) d1 = ops.DowngradeOps(ops=[]) - m1 = ops.MigrationScript( - "somerev", u1, d1 - ) - is_( - m1.upgrade_ops, u1 - ) - is_( - m1.downgrade_ops, d1 - ) + m1 = ops.MigrationScript("somerev", u1, d1) + is_(m1.upgrade_ops, u1) + is_(m1.downgrade_ops, d1) u2 = ops.UpgradeOps(ops=[]) d2 = ops.DowngradeOps(ops=[]) m1._upgrade_ops.append(u2) @@ -585,13 +634,17 @@ class ScriptAccessorTest(TestBase): ValueError, "This MigrationScript instance has a multiple-entry list for " "UpgradeOps; please use the upgrade_ops_list attribute.", - getattr, m1, "upgrade_ops" + getattr, + m1, + "upgrade_ops", ) assert_raises_message( ValueError, "This MigrationScript instance has a multiple-entry list for " "DowngradeOps; please use the downgrade_ops_list attribute.", - getattr, m1, "downgrade_ops" + getattr, + m1, + "downgrade_ops", ) eq_(m1.upgrade_ops_list, [u1, u2]) eq_(m1.downgrade_ops_list, [d1, d2]) @@ -615,39 +668,36 @@ class ImportsTest(TestBase): context.configure( connection=connection, target_metadata=target_metadata, - **kw) + **kw + ) with context.begin_transaction(): context.run_migrations() return mock.patch( - "alembic.script.base.ScriptDirectory.run_env", - run_env + "alembic.script.base.ScriptDirectory.run_env", run_env ) def test_imports_in_script(self): from sqlalchemy import MetaData, Table, Column from sqlalchemy.dialects.mysql import VARCHAR - type_ = VARCHAR(20, charset='utf8', national=True) + type_ = VARCHAR(20, charset="utf8", national=True) m = MetaData() - Table( - 't', m, - Column('x', type_) - ) + Table("t", m, Column("x", type_)) def process_revision_directives(context, rev, generate_revisions): generate_revisions[0].imports.add( - "from sqlalchemy.dialects.mysql import TINYINT") + "from sqlalchemy.dialects.mysql import TINYINT" + ) with self._env_fixture( - m, - process_revision_directives=process_revision_directives + m, process_revision_directives=process_revision_directives ): rev = command.revision( - self.cfg, message="some message", - autogenerate=True) + self.cfg, message="some message", autogenerate=True + ) with open(rev.path) as file_: contents = file_.read() @@ -659,24 +709,24 @@ class MultiContextTest(TestBase): """test the multidb template for autogenerate front-to-back""" def setUp(self): - self.engine1 = _sqlite_file_db(tempname='eng1.db') - self.engine2 = _sqlite_file_db(tempname='eng2.db') - self.engine3 = _sqlite_file_db(tempname='eng3.db') + self.engine1 = _sqlite_file_db(tempname="eng1.db") + self.engine2 = _sqlite_file_db(tempname="eng2.db") + self.engine3 = _sqlite_file_db(tempname="eng3.db") self.env = staging_env(template="multidb") - self.cfg = _multidb_testing_config({ - "engine1": self.engine1, - "engine2": self.engine2, - "engine3": self.engine3 - }) + self.cfg = _multidb_testing_config( + { + "engine1": self.engine1, + "engine2": self.engine2, + "engine3": self.engine3, + } + ) def _write_metadata(self, meta): - path = os.path.join(_get_staging_directory(), 'scripts', 'env.py') + path = os.path.join(_get_staging_directory(), "scripts", "env.py") with open(path) as env_: existing_env = env_.read() - existing_env = existing_env.replace( - "target_metadata = {}", - meta) + existing_env = existing_env.replace("target_metadata = {}", meta) with open(path, "w") as env_: env_.write(existing_env) @@ -701,40 +751,30 @@ sa.Table('e3t1', m3, sa.Column('z', sa.Integer)) ) rev = command.revision( - self.cfg, message="some message", - autogenerate=True + self.cfg, message="some message", autogenerate=True ) with mock.patch.object(rev.module, "op") as op_mock: rev.module.upgrade_engine1() eq_( op_mock.mock_calls[-1], - mock.call.create_table('e1t1', mock.ANY) + mock.call.create_table("e1t1", mock.ANY), ) rev.module.upgrade_engine2() eq_( op_mock.mock_calls[-1], - mock.call.create_table('e2t1', mock.ANY) + mock.call.create_table("e2t1", mock.ANY), ) rev.module.upgrade_engine3() eq_( op_mock.mock_calls[-1], - mock.call.create_table('e3t1', mock.ANY) + mock.call.create_table("e3t1", mock.ANY), ) rev.module.downgrade_engine1() - eq_( - op_mock.mock_calls[-1], - mock.call.drop_table('e1t1') - ) + eq_(op_mock.mock_calls[-1], mock.call.drop_table("e1t1")) rev.module.downgrade_engine2() - eq_( - op_mock.mock_calls[-1], - mock.call.drop_table('e2t1') - ) + eq_(op_mock.mock_calls[-1], mock.call.drop_table("e2t1")) rev.module.downgrade_engine3() - eq_( - op_mock.mock_calls[-1], - mock.call.drop_table('e3t1') - ) + eq_(op_mock.mock_calls[-1], mock.call.drop_table("e3t1")) class RewriterTest(TestBase): @@ -744,20 +784,13 @@ class RewriterTest(TestBase): mocker = mock.Mock(side_effect=lambda context, revision, op: op) writer.rewrites(ops.MigrateOperation)(mocker) - addcolop = ops.AddColumnOp( - 't1', sa.Column('x', sa.Integer()) - ) + addcolop = ops.AddColumnOp("t1", sa.Column("x", sa.Integer())) directives = [ ops.MigrationScript( util.rev_id(), - ops.UpgradeOps(ops=[ - ops.ModifyTableOps('t1', ops=[ - addcolop - ]) - ]), - ops.DowngradeOps(ops=[ - ]), + ops.UpgradeOps(ops=[ops.ModifyTableOps("t1", ops=[addcolop])]), + ops.DowngradeOps(ops=[]), ) ] @@ -771,7 +804,7 @@ class RewriterTest(TestBase): mock.call(ctx, rev, directives[0].upgrade_ops.ops[0]), mock.call(ctx, rev, addcolop), mock.call(ctx, rev, directives[0].downgrade_ops), - ] + ], ) def test_double_migrate_table(self): @@ -783,28 +816,33 @@ class RewriterTest(TestBase): def second_table(context, revision, op): return [ op, - ops.ModifyTableOps('t2', ops=[ - ops.AddColumnOp('t2', sa.Column('x', sa.Integer())) - ]) + ops.ModifyTableOps( + "t2", + ops=[ops.AddColumnOp("t2", sa.Column("x", sa.Integer()))], + ), ] @writer.rewrites(ops.AddColumnOp) def add_column(context, revision, op): - idx_op = ops.CreateIndexOp('ixt', op.table_name, [op.column.name]) + idx_op = ops.CreateIndexOp("ixt", op.table_name, [op.column.name]) idx_ops.append(idx_op) - return [ - op, - idx_op - ] + return [op, idx_op] directives = [ ops.MigrationScript( util.rev_id(), - ops.UpgradeOps(ops=[ - ops.ModifyTableOps('t1', ops=[ - ops.AddColumnOp('t1', sa.Column('x', sa.Integer())) - ]) - ]), + ops.UpgradeOps( + ops=[ + ops.ModifyTableOps( + "t1", + ops=[ + ops.AddColumnOp( + "t1", sa.Column("x", sa.Integer()) + ) + ], + ) + ] + ), ops.DowngradeOps(ops=[]), ) ] @@ -812,17 +850,10 @@ class RewriterTest(TestBase): ctx, rev = mock.Mock(), mock.Mock() writer(ctx, rev, directives) eq_( - [d.table_name for d in directives[0].upgrade_ops.ops], - ['t1', 't2'] - ) - is_( - directives[0].upgrade_ops.ops[0].ops[1], - idx_ops[0] - ) - is_( - directives[0].upgrade_ops.ops[1].ops[1], - idx_ops[1] + [d.table_name for d in directives[0].upgrade_ops.ops], ["t1", "t2"] ) + is_(directives[0].upgrade_ops.ops[0].ops[1], idx_ops[0]) + is_(directives[0].upgrade_ops.ops[1].ops[1], idx_ops[1]) def test_chained_ops(self): writer1 = autogenerate.Rewriter() @@ -841,26 +872,32 @@ class RewriterTest(TestBase): op.column.name, modify_nullable=False, existing_type=op.column.type, - ) + ), ] @writer2.rewrites(ops.AddColumnOp) def add_column_idx(context, revision, op): - idx_op = ops.CreateIndexOp('ixt', op.table_name, [op.column.name]) - return [ - op, - idx_op - ] + idx_op = ops.CreateIndexOp("ixt", op.table_name, [op.column.name]) + return [op, idx_op] directives = [ ops.MigrationScript( util.rev_id(), - ops.UpgradeOps(ops=[ - ops.ModifyTableOps('t1', ops=[ - ops.AddColumnOp( - 't1', sa.Column('x', sa.Integer(), nullable=False)) - ]) - ]), + ops.UpgradeOps( + ops=[ + ops.ModifyTableOps( + "t1", + ops=[ + ops.AddColumnOp( + "t1", + sa.Column( + "x", sa.Integer(), nullable=False + ), + ) + ], + ) + ] + ), ops.DowngradeOps(ops=[]), ) ] @@ -877,7 +914,7 @@ class RewriterTest(TestBase): " op.alter_column('t1', 'x',\n" " existing_type=sa.Integer(),\n" " nullable=False)\n" - " # ### end Alembic commands ###" + " # ### end Alembic commands ###", ) @@ -894,7 +931,9 @@ class MultiDirRevisionCommandTest(TestBase): util.CommandError, "Multiple version locations present, please specify " "--version-path", - command.revision, self.cfg, message="some message" + command.revision, + self.cfg, + message="some message", ) def test_multiple_dir_no_bases_invalid_version_path(self): @@ -902,40 +941,46 @@ class MultiDirRevisionCommandTest(TestBase): util.CommandError, "Path foo/bar/ is not represented in current version locations", command.revision, - self.cfg, message="x", - version_path=os.path.join("foo/bar/") + self.cfg, + message="x", + version_path=os.path.join("foo/bar/"), ) def test_multiple_dir_no_bases_version_path(self): script = command.revision( - self.cfg, message="x", - version_path=os.path.join(_get_staging_directory(), "model1")) + self.cfg, + message="x", + version_path=os.path.join(_get_staging_directory(), "model1"), + ) assert os.access(script.path, os.F_OK) def test_multiple_dir_chooses_base(self): command.revision( - self.cfg, message="x", + self.cfg, + message="x", head="base", - version_path=os.path.join(_get_staging_directory(), "model1")) + version_path=os.path.join(_get_staging_directory(), "model1"), + ) script2 = command.revision( - self.cfg, message="y", + self.cfg, + message="y", head="base", - version_path=os.path.join(_get_staging_directory(), "model2")) + version_path=os.path.join(_get_staging_directory(), "model2"), + ) script3 = command.revision( - self.cfg, message="y2", - head=script2.revision) + self.cfg, message="y2", head=script2.revision + ) eq_( os.path.dirname(script3.path), - os.path.abspath(os.path.join(_get_staging_directory(), "model2")) + os.path.abspath(os.path.join(_get_staging_directory(), "model2")), ) assert os.access(script3.path, os.F_OK) class TemplateArgsTest(TestBase): - def setUp(self): staging_env() self.cfg = _no_sql_testing_config( @@ -949,43 +994,45 @@ class TemplateArgsTest(TestBase): config = _no_sql_testing_config() script = ScriptDirectory.from_config(config) template_args = {"x": "x1", "y": "y1", "z": "z1"} - env = EnvironmentContext( - config, - script, - template_args=template_args - ) - env.configure(dialect_name="sqlite", - template_args={"y": "y2", "q": "q1"}) - eq_( - template_args, - {"x": "x1", "y": "y2", "z": "z1", "q": "q1"} + env = EnvironmentContext(config, script, template_args=template_args) + env.configure( + dialect_name="sqlite", template_args={"y": "y2", "q": "q1"} ) + eq_(template_args, {"x": "x1", "y": "y2", "z": "z1", "q": "q1"}) def test_tmpl_args_revision(self): - env_file_fixture(""" + env_file_fixture( + """ context.configure(dialect_name='sqlite', template_args={"somearg":"somevalue"}) -""") - script_file_fixture(""" +""" + ) + script_file_fixture( + """ # somearg: ${somearg} revision = ${repr(up_revision)} down_revision = ${repr(down_revision)} -""") +""" + ) command.revision(self.cfg, message="some rev") script = ScriptDirectory.from_config(self.cfg) - rev = script.get_revision('head') + rev = script.get_revision("head") with open(rev.path) as f: text = f.read() assert "somearg: somevalue" in text def test_bad_render(self): - env_file_fixture(""" + env_file_fixture( + """ context.configure(dialect_name='sqlite', template_args={"somearg":"somevalue"}) -""") - script_file_fixture(""" +""" + ) + script_file_fixture( + """ <% z = x + y %> -""") +""" + ) try: command.revision(self.cfg, message="some rev") @@ -993,7 +1040,7 @@ context.configure(dialect_name='sqlite', template_args={"somearg":"somevalue"}) m = re.match( r"^Template rendering failed; see (.+?) " "for a template-oriented", - str(ce) + str(ce), ) assert m, "Command error did not produce a file" with open(m.group(1)) as handle: @@ -1003,13 +1050,12 @@ context.configure(dialect_name='sqlite', template_args={"somearg":"somevalue"}) class DuplicateVersionLocationsTest(TestBase): - def setUp(self): self.env = staging_env() self.cfg = _multi_dir_testing_config( # this is a duplicate of one of the paths # already present in this fixture - extra_version_location='%(here)s/model1' + extra_version_location="%(here)s/model1" ) script = ScriptDirectory.from_config(self.cfg) @@ -1022,10 +1068,16 @@ class DuplicateVersionLocationsTest(TestBase): (self.model3, "model3"), ]: script.generate_revision( - model, name, refresh=True, + model, + name, + refresh=True, version_path=os.path.join(_get_staging_directory(), name), - head="base") - write_script(script, model, """\ + head="base", + ) + write_script( + script, + model, + """\ "%s" revision = '%s' down_revision = None @@ -1041,7 +1093,9 @@ def upgrade(): def downgrade(): pass -""" % (name, model, name)) +""" + % (name, model, name), + ) def tearDown(self): clear_staging_env() @@ -1049,16 +1103,20 @@ def downgrade(): def test_env_emits_warning(self): with assertions.expect_warnings( "File %s loaded twice! ignoring. " - "Please ensure version_locations is unique" % ( - os.path.realpath(os.path.join( - _get_staging_directory(), - "model1", - "%s_model1.py" % self.model1 - ))) + "Please ensure version_locations is unique" + % ( + os.path.realpath( + os.path.join( + _get_staging_directory(), + "model1", + "%s_model1.py" % self.model1, + ) + ) + ) ): script = ScriptDirectory.from_config(self.cfg) script.revision_map.heads eq_( [rev.revision for rev in script.walk_revisions()], - [self.model1, self.model2, self.model3] + [self.model1, self.model2, self.model3], ) diff --git a/tests/test_sqlite.py b/tests/test_sqlite.py index 75972d4..6718be9 100644 --- a/tests/test_sqlite.py +++ b/tests/test_sqlite.py @@ -7,34 +7,29 @@ from alembic.testing.fixtures import TestBase class SQLiteTest(TestBase): - def test_add_column(self): - context = op_fixture('sqlite') - op.add_column('t1', Column('c1', Integer)) - context.assert_( - 'ALTER TABLE t1 ADD COLUMN c1 INTEGER' - ) + context = op_fixture("sqlite") + op.add_column("t1", Column("c1", Integer)) + context.assert_("ALTER TABLE t1 ADD COLUMN c1 INTEGER") def test_add_column_implicit_constraint(self): - context = op_fixture('sqlite') - op.add_column('t1', Column('c1', Boolean)) - context.assert_( - 'ALTER TABLE t1 ADD COLUMN c1 BOOLEAN' - ) + context = op_fixture("sqlite") + op.add_column("t1", Column("c1", Boolean)) + context.assert_("ALTER TABLE t1 ADD COLUMN c1 BOOLEAN") def test_add_explicit_constraint(self): - op_fixture('sqlite') + op_fixture("sqlite") assert_raises_message( NotImplementedError, "No support for ALTER of constraints in SQLite dialect", op.create_check_constraint, "foo", "sometable", - column('name') > 5 + column("name") > 5, ) def test_drop_explicit_constraint(self): - op_fixture('sqlite') + op_fixture("sqlite") assert_raises_message( NotImplementedError, "No support for ALTER of constraints in SQLite dialect", diff --git a/tests/test_version_table.py b/tests/test_version_table.py index 29530c0..0a545cf 100644 --- a/tests/test_version_table.py +++ b/tests/test_version_table.py @@ -8,24 +8,22 @@ from alembic import migration from alembic.util import CommandError -version_table = Table('version_table', MetaData(), - Column('version_num', String(32), nullable=False)) +version_table = Table( + "version_table", + MetaData(), + Column("version_num", String(32), nullable=False), +) def _up(from_, to_, branch_presence_changed=False): - return migration.StampStep( - from_, to_, True, branch_presence_changed - ) + return migration.StampStep(from_, to_, True, branch_presence_changed) def _down(from_, to_, branch_presence_changed=False): - return migration.StampStep( - from_, to_, False, branch_presence_changed - ) + return migration.StampStep(from_, to_, False, branch_presence_changed) class TestMigrationContext(TestBase): - @classmethod def setup_class(cls): cls.bind = config.db @@ -48,112 +46,126 @@ class TestMigrationContext(TestBase): if len(rows) == 0: return None eq_(len(rows), 1) - return rows[0]['version_num'] + return rows[0]["version_num"] def test_config_default_version_table_name(self): - context = self.make_one(dialect_name='sqlite') - eq_(context._version.name, 'alembic_version') + context = self.make_one(dialect_name="sqlite") + eq_(context._version.name, "alembic_version") def test_config_explicit_version_table_name(self): - context = self.make_one(dialect_name='sqlite', - opts={'version_table': 'explicit'}) - eq_(context._version.name, 'explicit') - eq_(context._version.primary_key.name, 'explicit_pkc') + context = self.make_one( + dialect_name="sqlite", opts={"version_table": "explicit"} + ) + eq_(context._version.name, "explicit") + eq_(context._version.primary_key.name, "explicit_pkc") def test_config_explicit_version_table_schema(self): - context = self.make_one(dialect_name='sqlite', - opts={'version_table_schema': 'explicit'}) - eq_(context._version.schema, 'explicit') + context = self.make_one( + dialect_name="sqlite", opts={"version_table_schema": "explicit"} + ) + eq_(context._version.schema, "explicit") def test_config_explicit_no_pk(self): - context = self.make_one(dialect_name='sqlite', - opts={'version_table_pk': False}) + context = self.make_one( + dialect_name="sqlite", opts={"version_table_pk": False} + ) eq_(len(context._version.primary_key), 0) def test_config_explicit_w_pk(self): - context = self.make_one(dialect_name='sqlite', - opts={'version_table_pk': True}) + context = self.make_one( + dialect_name="sqlite", opts={"version_table_pk": True} + ) eq_(len(context._version.primary_key), 1) eq_(context._version.primary_key.name, "alembic_version_pkc") def test_get_current_revision_doesnt_create_version_table(self): - context = self.make_one(connection=self.connection, - opts={'version_table': 'version_table'}) + context = self.make_one( + connection=self.connection, opts={"version_table": "version_table"} + ) eq_(context.get_current_revision(), None) insp = Inspector(self.connection) - assert ('version_table' not in insp.get_table_names()) + assert "version_table" not in insp.get_table_names() def test_get_current_revision(self): - context = self.make_one(connection=self.connection, - opts={'version_table': 'version_table'}) + context = self.make_one( + connection=self.connection, opts={"version_table": "version_table"} + ) version_table.create(self.connection) eq_(context.get_current_revision(), None) self.connection.execute( - version_table.insert().values(version_num='revid')) - eq_(context.get_current_revision(), 'revid') + version_table.insert().values(version_num="revid") + ) + eq_(context.get_current_revision(), "revid") def test_get_current_revision_error_if_starting_rev_given_online(self): - context = self.make_one(connection=self.connection, - opts={'starting_rev': 'boo'}) - assert_raises( - CommandError, - context.get_current_revision + context = self.make_one( + connection=self.connection, opts={"starting_rev": "boo"} ) + assert_raises(CommandError, context.get_current_revision) def test_get_current_revision_offline(self): - context = self.make_one(dialect_name='sqlite', - opts={'starting_rev': 'startrev', - 'as_sql': True}) - eq_(context.get_current_revision(), 'startrev') + context = self.make_one( + dialect_name="sqlite", + opts={"starting_rev": "startrev", "as_sql": True}, + ) + eq_(context.get_current_revision(), "startrev") def test_get_current_revision_multiple_heads(self): version_table.create(self.connection) - context = self.make_one(connection=self.connection, - opts={'version_table': 'version_table'}) + context = self.make_one( + connection=self.connection, opts={"version_table": "version_table"} + ) updater = migration.HeadMaintainer(context, ()) - updater.update_to_step(_up(None, 'a', True)) - updater.update_to_step(_up(None, 'b', True)) + updater.update_to_step(_up(None, "a", True)) + updater.update_to_step(_up(None, "b", True)) assert_raises_message( CommandError, "Version table 'version_table' has more than one head present; " "please use get_current_heads()", - context.get_current_revision + context.get_current_revision, ) def test_get_heads(self): version_table.create(self.connection) - context = self.make_one(connection=self.connection, - opts={'version_table': 'version_table'}) + context = self.make_one( + connection=self.connection, opts={"version_table": "version_table"} + ) updater = migration.HeadMaintainer(context, ()) - updater.update_to_step(_up(None, 'a', True)) - updater.update_to_step(_up(None, 'b', True)) - eq_(context.get_current_heads(), ('a', 'b')) + updater.update_to_step(_up(None, "a", True)) + updater.update_to_step(_up(None, "b", True)) + eq_(context.get_current_heads(), ("a", "b")) def test_get_heads_offline(self): version_table.create(self.connection) - context = self.make_one(connection=self.connection, - opts={ - 'starting_rev': 'q', - 'version_table': 'version_table', - 'as_sql': True}) - eq_(context.get_current_heads(), ('q', )) + context = self.make_one( + connection=self.connection, + opts={ + "starting_rev": "q", + "version_table": "version_table", + "as_sql": True, + }, + ) + eq_(context.get_current_heads(), ("q",)) def test_stamp_api_creates_table(self): context = self.make_one(connection=self.connection) assert ( - 'alembic_version' - not in Inspector(self.connection).get_table_names()) + "alembic_version" + not in Inspector(self.connection).get_table_names() + ) - script = mock.Mock(_stamp_revs=lambda revision, heads: [ - _up(None, 'a', True), - _up(None, 'b', True) - ]) + script = mock.Mock( + _stamp_revs=lambda revision, heads: [ + _up(None, "a", True), + _up(None, "b", True), + ] + ) - context.stamp(script, 'b') - eq_(context.get_current_heads(), ('a', 'b')) + context.stamp(script, "b") + eq_(context.get_current_heads(), ("a", "b")) assert ( - 'alembic_version' - in Inspector(self.connection).get_table_names()) + "alembic_version" in Inspector(self.connection).get_table_names() + ) class UpdateRevTest(TestBase): @@ -166,8 +178,8 @@ class UpdateRevTest(TestBase): def setUp(self): self.connection = self.bind.connect() self.context = migration.MigrationContext.configure( - connection=self.connection, - opts={"version_table": "version_table"}) + connection=self.connection, opts={"version_table": "version_table"} + ) version_table.create(self.connection) self.updater = migration.HeadMaintainer(self.context, ()) @@ -180,105 +192,108 @@ class UpdateRevTest(TestBase): eq_(self.updater.heads, set(heads)) def test_update_none_to_single(self): - self.updater.update_to_step(_up(None, 'a', True)) - self._assert_heads(('a',)) + self.updater.update_to_step(_up(None, "a", True)) + self._assert_heads(("a",)) def test_update_single_to_single(self): - self.updater.update_to_step(_up(None, 'a', True)) - self.updater.update_to_step(_up('a', 'b')) - self._assert_heads(('b',)) + self.updater.update_to_step(_up(None, "a", True)) + self.updater.update_to_step(_up("a", "b")) + self._assert_heads(("b",)) def test_update_single_to_none(self): - self.updater.update_to_step(_up(None, 'a', True)) - self.updater.update_to_step(_down('a', None, True)) + self.updater.update_to_step(_up(None, "a", True)) + self.updater.update_to_step(_down("a", None, True)) self._assert_heads(()) def test_add_branches(self): - self.updater.update_to_step(_up(None, 'a', True)) - self.updater.update_to_step(_up('a', 'b')) - self.updater.update_to_step(_up(None, 'c', True)) - self._assert_heads(('b', 'c')) - self.updater.update_to_step(_up('c', 'd')) - self.updater.update_to_step(_up('d', 'e1')) - self.updater.update_to_step(_up('d', 'e2', True)) - self._assert_heads(('b', 'e1', 'e2')) + self.updater.update_to_step(_up(None, "a", True)) + self.updater.update_to_step(_up("a", "b")) + self.updater.update_to_step(_up(None, "c", True)) + self._assert_heads(("b", "c")) + self.updater.update_to_step(_up("c", "d")) + self.updater.update_to_step(_up("d", "e1")) + self.updater.update_to_step(_up("d", "e2", True)) + self._assert_heads(("b", "e1", "e2")) def test_teardown_branches(self): - self.updater.update_to_step(_up(None, 'd1', True)) - self.updater.update_to_step(_up(None, 'd2', True)) - self._assert_heads(('d1', 'd2')) + self.updater.update_to_step(_up(None, "d1", True)) + self.updater.update_to_step(_up(None, "d2", True)) + self._assert_heads(("d1", "d2")) - self.updater.update_to_step(_down('d1', 'c')) - self._assert_heads(('c', 'd2')) + self.updater.update_to_step(_down("d1", "c")) + self._assert_heads(("c", "d2")) - self.updater.update_to_step(_down('d2', 'c', True)) + self.updater.update_to_step(_down("d2", "c", True)) - self._assert_heads(('c',)) - self.updater.update_to_step(_down('c', 'b')) - self._assert_heads(('b',)) + self._assert_heads(("c",)) + self.updater.update_to_step(_down("c", "b")) + self._assert_heads(("b",)) def test_resolve_merges(self): - self.updater.update_to_step(_up(None, 'a', True)) - self.updater.update_to_step(_up('a', 'b')) - self.updater.update_to_step(_up('b', 'c1')) - self.updater.update_to_step(_up('b', 'c2', True)) - self.updater.update_to_step(_up('c1', 'd1')) - self.updater.update_to_step(_up('c2', 'd2')) - self._assert_heads(('d1', 'd2')) - self.updater.update_to_step(_up(('d1', 'd2'), 'e')) - self._assert_heads(('e',)) + self.updater.update_to_step(_up(None, "a", True)) + self.updater.update_to_step(_up("a", "b")) + self.updater.update_to_step(_up("b", "c1")) + self.updater.update_to_step(_up("b", "c2", True)) + self.updater.update_to_step(_up("c1", "d1")) + self.updater.update_to_step(_up("c2", "d2")) + self._assert_heads(("d1", "d2")) + self.updater.update_to_step(_up(("d1", "d2"), "e")) + self._assert_heads(("e",)) def test_unresolve_merges(self): - self.updater.update_to_step(_up(None, 'e', True)) + self.updater.update_to_step(_up(None, "e", True)) - self.updater.update_to_step(_down('e', ('d1', 'd2'))) - self._assert_heads(('d2', 'd1')) + self.updater.update_to_step(_down("e", ("d1", "d2"))) + self._assert_heads(("d2", "d1")) - self.updater.update_to_step(_down('d2', 'c2')) - self._assert_heads(('c2', 'd1')) + self.updater.update_to_step(_down("d2", "c2")) + self._assert_heads(("c2", "d1")) def test_update_no_match(self): - self.updater.update_to_step(_up(None, 'a', True)) - self.updater.heads.add('x') + self.updater.update_to_step(_up(None, "a", True)) + self.updater.heads.add("x") assert_raises_message( CommandError, "Online migration expected to match one row when updating " "'x' to 'b' in 'version_table'; 0 found", - self.updater.update_to_step, _up('x', 'b') + self.updater.update_to_step, + _up("x", "b"), ) def test_update_multi_match(self): - self.connection.execute(version_table.insert(), version_num='a') - self.connection.execute(version_table.insert(), version_num='a') + self.connection.execute(version_table.insert(), version_num="a") + self.connection.execute(version_table.insert(), version_num="a") - self.updater.heads.add('a') + self.updater.heads.add("a") assert_raises_message( CommandError, "Online migration expected to match one row when updating " "'a' to 'b' in 'version_table'; 2 found", - self.updater.update_to_step, _up('a', 'b') + self.updater.update_to_step, + _up("a", "b"), ) def test_delete_no_match(self): - self.updater.update_to_step(_up(None, 'a', True)) + self.updater.update_to_step(_up(None, "a", True)) - self.updater.heads.add('x') + self.updater.heads.add("x") assert_raises_message( CommandError, "Online migration expected to match one row when " "deleting 'x' in 'version_table'; 0 found", - self.updater.update_to_step, _down('x', None, True) + self.updater.update_to_step, + _down("x", None, True), ) def test_delete_multi_match(self): - self.connection.execute(version_table.insert(), version_num='a') - self.connection.execute(version_table.insert(), version_num='a') + self.connection.execute(version_table.insert(), version_num="a") + self.connection.execute(version_table.insert(), version_num="a") - self.updater.heads.add('a') + self.updater.heads.add("a") assert_raises_message( CommandError, "Online migration expected to match one row when " "deleting 'a' in 'version_table'; 2 found", - self.updater.update_to_step, _down('a', None, True) + self.updater.update_to_step, + _down("a", None, True), ) - diff --git a/tests/test_version_traversal.py b/tests/test_version_traversal.py index f69a9bd..c32c0c8 100644 --- a/tests/test_version_traversal.py +++ b/tests/test_version_traversal.py @@ -7,20 +7,15 @@ from alembic.migration import MigrationStep, HeadMaintainer class MigrationTest(TestBase): - def up_(self, rev): - return MigrationStep.upgrade_from_script( - self.env.revision_map, rev) + return MigrationStep.upgrade_from_script(self.env.revision_map, rev) def down_(self, rev): - return MigrationStep.downgrade_from_script( - self.env.revision_map, rev) + return MigrationStep.downgrade_from_script(self.env.revision_map, rev) def _assert_downgrade(self, destination, source, expected, expected_heads): revs = self.env._downgrade_revs(destination, source) - eq_( - revs, expected - ) + eq_(revs, expected) heads = set(util.to_tuple(source, default=())) head = HeadMaintainer(mock.Mock(), heads) for rev in revs: @@ -29,9 +24,7 @@ class MigrationTest(TestBase): def _assert_upgrade(self, destination, source, expected, expected_heads): revs = self.env._upgrade_revs(destination, source) - eq_( - revs, expected - ) + eq_(revs, expected) heads = set(util.to_tuple(source, default=())) head = HeadMaintainer(mock.Mock(), heads) for rev in revs: @@ -40,15 +33,14 @@ class MigrationTest(TestBase): class RevisionPathTest(MigrationTest): - @classmethod def setup_class(cls): cls.env = env = staging_env() - cls.a = env.generate_revision(util.rev_id(), '->a') - cls.b = env.generate_revision(util.rev_id(), 'a->b') - cls.c = env.generate_revision(util.rev_id(), 'b->c') - cls.d = env.generate_revision(util.rev_id(), 'c->d') - cls.e = env.generate_revision(util.rev_id(), 'd->e') + cls.a = env.generate_revision(util.rev_id(), "->a") + cls.b = env.generate_revision(util.rev_id(), "a->b") + cls.c = env.generate_revision(util.rev_id(), "b->c") + cls.d = env.generate_revision(util.rev_id(), "c->d") + cls.e = env.generate_revision(util.rev_id(), "d->e") @classmethod def teardown_class(cls): @@ -56,58 +48,50 @@ class RevisionPathTest(MigrationTest): def test_upgrade_path(self): self._assert_upgrade( - self.e.revision, self.c.revision, - [ - self.up_(self.d), - self.up_(self.e) - ], - set([self.e.revision]) + self.e.revision, + self.c.revision, + [self.up_(self.d), self.up_(self.e)], + set([self.e.revision]), ) self._assert_upgrade( - self.c.revision, None, - [ - self.up_(self.a), - self.up_(self.b), - self.up_(self.c), - ], - set([self.c.revision]) + self.c.revision, + None, + [self.up_(self.a), self.up_(self.b), self.up_(self.c)], + set([self.c.revision]), ) def test_relative_upgrade_path(self): self._assert_upgrade( - "+2", self.a.revision, - [ - self.up_(self.b), - self.up_(self.c), - ], - set([self.c.revision]) + "+2", + self.a.revision, + [self.up_(self.b), self.up_(self.c)], + set([self.c.revision]), ) self._assert_upgrade( - "+1", self.a.revision, - [ - self.up_(self.b) - ], - set([self.b.revision]) + "+1", self.a.revision, [self.up_(self.b)], set([self.b.revision]) ) self._assert_upgrade( - "+3", self.b.revision, + "+3", + self.b.revision, [self.up_(self.c), self.up_(self.d), self.up_(self.e)], - set([self.e.revision]) + set([self.e.revision]), ) self._assert_upgrade( - "%s+2" % self.b.revision, self.a.revision, + "%s+2" % self.b.revision, + self.a.revision, [self.up_(self.b), self.up_(self.c), self.up_(self.d)], - set([self.d.revision]) + set([self.d.revision]), ) self._assert_upgrade( - "%s-2" % self.d.revision, self.a.revision, + "%s-2" % self.d.revision, + self.a.revision, [self.up_(self.b)], - set([self.b.revision]) + set([self.b.revision]), ) def test_invalid_relative_upgrade_path(self): @@ -115,53 +99,60 @@ class RevisionPathTest(MigrationTest): assert_raises_message( util.CommandError, "Relative revision -2 didn't produce 2 migrations", - self.env._upgrade_revs, "-2", self.b.revision + self.env._upgrade_revs, + "-2", + self.b.revision, ) assert_raises_message( util.CommandError, r"Relative revision \+5 didn't produce 5 migrations", - self.env._upgrade_revs, "+5", self.b.revision + self.env._upgrade_revs, + "+5", + self.b.revision, ) def test_downgrade_path(self): self._assert_downgrade( - self.c.revision, self.e.revision, + self.c.revision, + self.e.revision, [self.down_(self.e), self.down_(self.d)], - set([self.c.revision]) + set([self.c.revision]), ) self._assert_downgrade( - None, self.c.revision, + None, + self.c.revision, [self.down_(self.c), self.down_(self.b), self.down_(self.a)], - set() + set(), ) def test_relative_downgrade_path(self): self._assert_downgrade( - "-1", self.c.revision, - [self.down_(self.c)], - set([self.b.revision]) + "-1", self.c.revision, [self.down_(self.c)], set([self.b.revision]) ) self._assert_downgrade( - "-3", self.e.revision, + "-3", + self.e.revision, [self.down_(self.e), self.down_(self.d), self.down_(self.c)], - set([self.b.revision]) + set([self.b.revision]), ) self._assert_downgrade( - "%s+2" % self.a.revision, self.d.revision, + "%s+2" % self.a.revision, + self.d.revision, [self.down_(self.d)], - set([self.c.revision]) + set([self.c.revision]), ) self._assert_downgrade( - "%s-2" % self.c.revision, self.d.revision, + "%s-2" % self.c.revision, + self.d.revision, [self.down_(self.d), self.down_(self.c), self.down_(self.b)], - set([self.a.revision]) + set([self.a.revision]), ) def test_invalid_relative_downgrade_path(self): @@ -169,13 +160,17 @@ class RevisionPathTest(MigrationTest): assert_raises_message( util.CommandError, "Relative revision -5 didn't produce 5 migrations", - self.env._downgrade_revs, "-5", self.b.revision + self.env._downgrade_revs, + "-5", + self.b.revision, ) assert_raises_message( util.CommandError, r"Relative revision \+2 didn't produce 2 migrations", - self.env._downgrade_revs, "+2", self.b.revision + self.env._downgrade_revs, + "+2", + self.b.revision, ) def test_invalid_move_rev_to_none(self): @@ -184,7 +179,9 @@ class RevisionPathTest(MigrationTest): util.CommandError, r"Destination %s is not a valid downgrade " r"target from current head\(s\)" % self.b.revision[0:3], - self.env._downgrade_revs, self.b.revision[0:3], None + self.env._downgrade_revs, + self.b.revision[0:3], + None, ) def test_invalid_move_higher_to_lower(self): @@ -193,7 +190,9 @@ class RevisionPathTest(MigrationTest): util.CommandError, r"Destination %s is not a valid downgrade " r"target from current head\(s\)" % self.c.revision[0:4], - self.env._downgrade_revs, self.c.revision[0:4], self.b.revision + self.env._downgrade_revs, + self.c.revision[0:4], + self.b.revision, ) def test_stamp_to_base(self): @@ -204,26 +203,27 @@ class RevisionPathTest(MigrationTest): class BranchedPathTest(MigrationTest): - @classmethod def setup_class(cls): cls.env = env = staging_env() - cls.a = env.generate_revision(util.rev_id(), '->a') - cls.b = env.generate_revision(util.rev_id(), 'a->b') + cls.a = env.generate_revision(util.rev_id(), "->a") + cls.b = env.generate_revision(util.rev_id(), "a->b") cls.c1 = env.generate_revision( - util.rev_id(), 'b->c1', - branch_labels='c1branch', - refresh=True) - cls.d1 = env.generate_revision(util.rev_id(), 'c1->d1') + util.rev_id(), "b->c1", branch_labels="c1branch", refresh=True + ) + cls.d1 = env.generate_revision(util.rev_id(), "c1->d1") cls.c2 = env.generate_revision( - util.rev_id(), 'b->c2', - branch_labels='c2branch', - head=cls.b.revision, splice=True) + util.rev_id(), + "b->c2", + branch_labels="c2branch", + head=cls.b.revision, + splice=True, + ) cls.d2 = env.generate_revision( - util.rev_id(), 'c2->d2', - head=cls.c2.revision) + util.rev_id(), "c2->d2", head=cls.c2.revision + ) @classmethod def teardown_class(cls): @@ -231,73 +231,87 @@ class BranchedPathTest(MigrationTest): def test_stamp_down_across_multiple_branch_to_branchpoint(self): heads = [self.d1.revision, self.c2.revision] - revs = self.env._stamp_revs( - self.b.revision, heads) + revs = self.env._stamp_revs(self.b.revision, heads) eq_(len(revs), 1) eq_( revs[0].merge_branch_idents(heads), # DELETE d1 revision, UPDATE c2 to b - ([self.d1.revision], self.c2.revision, self.b.revision) + ([self.d1.revision], self.c2.revision, self.b.revision), ) def test_stamp_to_labeled_base_multiple_heads(self): revs = self.env._stamp_revs( - "c1branch@base", [self.d1.revision, self.c2.revision]) + "c1branch@base", [self.d1.revision, self.c2.revision] + ) eq_(len(revs), 1) assert revs[0].should_delete_branch eq_(revs[0].delete_version_num, self.d1.revision) def test_stamp_to_labeled_head_multiple_heads(self): heads = [self.d1.revision, self.c2.revision] - revs = self.env._stamp_revs( - "c2branch@head", heads) + revs = self.env._stamp_revs("c2branch@head", heads) eq_(len(revs), 1) eq_( revs[0].merge_branch_idents(heads), # the c1branch remains unchanged - ([], self.c2.revision, self.d2.revision) + ([], self.c2.revision, self.d2.revision), ) def test_upgrade_single_branch(self): self._assert_upgrade( - self.d1.revision, self.b.revision, + self.d1.revision, + self.b.revision, [self.up_(self.c1), self.up_(self.d1)], - set([self.d1.revision]) + set([self.d1.revision]), ) def test_upgrade_multiple_branch(self): # move from a single head to multiple heads self._assert_upgrade( - (self.d1.revision, self.d2.revision), self.a.revision, - [self.up_(self.b), self.up_(self.c2), self.up_(self.d2), - self.up_(self.c1), self.up_(self.d1)], - set([self.d1.revision, self.d2.revision]) + (self.d1.revision, self.d2.revision), + self.a.revision, + [ + self.up_(self.b), + self.up_(self.c2), + self.up_(self.d2), + self.up_(self.c1), + self.up_(self.d1), + ], + set([self.d1.revision, self.d2.revision]), ) def test_downgrade_multiple_branch(self): self._assert_downgrade( - self.a.revision, (self.d1.revision, self.d2.revision), - [self.down_(self.d1), self.down_(self.c1), self.down_(self.d2), - self.down_(self.c2), self.down_(self.b)], - set([self.a.revision]) + self.a.revision, + (self.d1.revision, self.d2.revision), + [ + self.down_(self.d1), + self.down_(self.c1), + self.down_(self.d2), + self.down_(self.c2), + self.down_(self.b), + ], + set([self.a.revision]), ) def test_relative_upgrade(self): self._assert_upgrade( - "c2branch@head-1", self.b.revision, + "c2branch@head-1", + self.b.revision, [self.up_(self.c2)], - set([self.c2.revision]) + set([self.c2.revision]), ) def test_relative_downgrade(self): self._assert_downgrade( - "c2branch@base+2", [self.d2.revision, self.d1.revision], + "c2branch@base+2", + [self.d2.revision, self.d1.revision], [self.down_(self.d2), self.down_(self.c2), self.down_(self.d1)], - set([self.c1.revision]) + set([self.c1.revision]), ) @@ -311,43 +325,54 @@ class BranchFromMergepointTest(MigrationTest): @classmethod def setup_class(cls): cls.env = env = staging_env() - cls.a1 = env.generate_revision(util.rev_id(), '->a1') - cls.b1 = env.generate_revision(util.rev_id(), 'a1->b1') - cls.c1 = env.generate_revision(util.rev_id(), 'b1->c1') + cls.a1 = env.generate_revision(util.rev_id(), "->a1") + cls.b1 = env.generate_revision(util.rev_id(), "a1->b1") + cls.c1 = env.generate_revision(util.rev_id(), "b1->c1") cls.a2 = env.generate_revision( - util.rev_id(), '->a2', head=(), - refresh=True) + util.rev_id(), "->a2", head=(), refresh=True + ) cls.b2 = env.generate_revision( - util.rev_id(), 'a2->b2', head=cls.a2.revision) + util.rev_id(), "a2->b2", head=cls.a2.revision + ) cls.c2 = env.generate_revision( - util.rev_id(), 'b2->c2', head=cls.b2.revision) + util.rev_id(), "b2->c2", head=cls.b2.revision + ) # mergepoint between c1, c2 # d1 dependent on c2 cls.d1 = env.generate_revision( - util.rev_id(), 'd1', head=(cls.c1.revision, cls.c2.revision), - refresh=True) + util.rev_id(), + "d1", + head=(cls.c1.revision, cls.c2.revision), + refresh=True, + ) # but then c2 keeps going into d2 cls.d2 = env.generate_revision( - util.rev_id(), 'd2', head=cls.c2.revision, - refresh=True, splice=True) + util.rev_id(), + "d2", + head=cls.c2.revision, + refresh=True, + splice=True, + ) def test_mergepoint_to_only_one_side_upgrade(self): self._assert_upgrade( - self.d1.revision, (self.d2.revision, self.b1.revision), + self.d1.revision, + (self.d2.revision, self.b1.revision), [self.up_(self.c1), self.up_(self.d1)], - set([self.d2.revision, self.d1.revision]) + set([self.d2.revision, self.d1.revision]), ) def test_mergepoint_to_only_one_side_downgrade(self): self._assert_downgrade( - self.b1.revision, (self.d2.revision, self.d1.revision), + self.b1.revision, + (self.d2.revision, self.d1.revision), [self.down_(self.d1), self.down_(self.c1)], - set([self.d2.revision, self.b1.revision]) + set([self.d2.revision, self.b1.revision]), ) @@ -361,42 +386,56 @@ class BranchFrom3WayMergepointTest(MigrationTest): @classmethod def setup_class(cls): cls.env = env = staging_env() - cls.a1 = env.generate_revision(util.rev_id(), '->a1') - cls.b1 = env.generate_revision(util.rev_id(), 'a1->b1') - cls.c1 = env.generate_revision(util.rev_id(), 'b1->c1') + cls.a1 = env.generate_revision(util.rev_id(), "->a1") + cls.b1 = env.generate_revision(util.rev_id(), "a1->b1") + cls.c1 = env.generate_revision(util.rev_id(), "b1->c1") cls.a2 = env.generate_revision( - util.rev_id(), '->a2', head=(), - refresh=True) + util.rev_id(), "->a2", head=(), refresh=True + ) cls.b2 = env.generate_revision( - util.rev_id(), 'a2->b2', head=cls.a2.revision) + util.rev_id(), "a2->b2", head=cls.a2.revision + ) cls.c2 = env.generate_revision( - util.rev_id(), 'b2->c2', head=cls.b2.revision) + util.rev_id(), "b2->c2", head=cls.b2.revision + ) cls.a3 = env.generate_revision( - util.rev_id(), '->a3', head=(), - refresh=True) + util.rev_id(), "->a3", head=(), refresh=True + ) cls.b3 = env.generate_revision( - util.rev_id(), 'a3->b3', head=cls.a3.revision) + util.rev_id(), "a3->b3", head=cls.a3.revision + ) cls.c3 = env.generate_revision( - util.rev_id(), 'b3->c3', head=cls.b3.revision) + util.rev_id(), "b3->c3", head=cls.b3.revision + ) # mergepoint between c1, c2, c3 # d1 dependent on c2, c3 cls.d1 = env.generate_revision( - util.rev_id(), 'd1', head=( - cls.c1.revision, cls.c2.revision, cls.c3.revision), - refresh=True) + util.rev_id(), + "d1", + head=(cls.c1.revision, cls.c2.revision, cls.c3.revision), + refresh=True, + ) # but then c2 keeps going into d2 cls.d2 = env.generate_revision( - util.rev_id(), 'd2', head=cls.c2.revision, - refresh=True, splice=True) + util.rev_id(), + "d2", + head=cls.c2.revision, + refresh=True, + splice=True, + ) # c3 keeps going into d3 cls.d3 = env.generate_revision( - util.rev_id(), 'd3', head=cls.c3.revision, - refresh=True, splice=True) + util.rev_id(), + "d3", + head=cls.c3.revision, + refresh=True, + splice=True, + ) def test_mergepoint_to_only_one_side_upgrade(self): @@ -404,7 +443,7 @@ class BranchFrom3WayMergepointTest(MigrationTest): self.d1.revision, (self.d3.revision, self.d2.revision, self.b1.revision), [self.up_(self.c1), self.up_(self.d1)], - set([self.d3.revision, self.d2.revision, self.d1.revision]) + set([self.d3.revision, self.d2.revision, self.d1.revision]), ) def test_mergepoint_to_only_one_side_downgrade(self): @@ -412,7 +451,7 @@ class BranchFrom3WayMergepointTest(MigrationTest): self.b1.revision, (self.d3.revision, self.d2.revision, self.d1.revision), [self.down_(self.d1), self.down_(self.c1)], - set([self.d3.revision, self.d2.revision, self.b1.revision]) + set([self.d3.revision, self.d2.revision, self.b1.revision]), ) def test_mergepoint_to_two_sides_upgrade(self): @@ -422,14 +461,15 @@ class BranchFrom3WayMergepointTest(MigrationTest): (self.d3.revision, self.b2.revision, self.b1.revision), [self.up_(self.c2), self.up_(self.c1), self.up_(self.d1)], # this will merge b2 and b1 into d1 - set([self.d3.revision, self.d1.revision]) + set([self.d3.revision, self.d1.revision]), ) # but then! b2 will break out again if we keep going with it self._assert_upgrade( - self.d2.revision, (self.d3.revision, self.d1.revision), + self.d2.revision, + (self.d3.revision, self.d1.revision), [self.up_(self.d2)], - set([self.d3.revision, self.d2.revision, self.d1.revision]) + set([self.d3.revision, self.d2.revision, self.d1.revision]), ) @@ -438,6 +478,7 @@ class TwinMergeTest(MigrationTest): originating branches. """ + @classmethod def setup_class(cls): """ @@ -463,44 +504,43 @@ class TwinMergeTest(MigrationTest): """ cls.env = env = staging_env() - cls.a = env.generate_revision( - 'a', 'a' + cls.a = env.generate_revision("a", "a") + cls.b1 = env.generate_revision("b1", "b1", head=cls.a.revision) + cls.b2 = env.generate_revision( + "b2", "b2", splice=True, head=cls.a.revision + ) + cls.b3 = env.generate_revision( + "b3", "b3", splice=True, head=cls.a.revision ) - cls.b1 = env.generate_revision('b1', 'b1', - head=cls.a.revision) - cls.b2 = env.generate_revision('b2', 'b2', - splice=True, - head=cls.a.revision) - cls.b3 = env.generate_revision('b3', 'b3', - splice=True, - head=cls.a.revision) cls.c1 = env.generate_revision( - 'c1', 'c1', - head=(cls.b1.revision, cls.b2.revision, cls.b3.revision)) + "c1", + "c1", + head=(cls.b1.revision, cls.b2.revision, cls.b3.revision), + ) cls.c2 = env.generate_revision( - 'c2', 'c2', + "c2", + "c2", splice=True, - head=(cls.b1.revision, cls.b2.revision, cls.b3.revision)) + head=(cls.b1.revision, cls.b2.revision, cls.b3.revision), + ) - cls.d1 = env.generate_revision( - 'd1', 'd1', head=cls.c1.revision) + cls.d1 = env.generate_revision("d1", "d1", head=cls.c1.revision) - cls.d2 = env.generate_revision( - 'd2', 'd2', head=cls.c2.revision) + cls.d2 = env.generate_revision("d2", "d2", head=cls.c2.revision) def test_upgrade(self): head = HeadMaintainer(mock.Mock(), [self.a.revision]) steps = [ - (self.up_(self.b3), ('b3',)), - (self.up_(self.b1), ('b1', 'b3',)), - (self.up_(self.b2), ('b1', 'b2', 'b3',)), - (self.up_(self.c2), ('c2',)), - (self.up_(self.d2), ('d2',)), - (self.up_(self.c1), ('c1', 'd2')), - (self.up_(self.d1), ('d1', 'd2')), + (self.up_(self.b3), ("b3",)), + (self.up_(self.b1), ("b1", "b3")), + (self.up_(self.b2), ("b1", "b2", "b3")), + (self.up_(self.c2), ("c2",)), + (self.up_(self.d2), ("d2",)), + (self.up_(self.c1), ("c1", "d2")), + (self.up_(self.d1), ("d1", "d2")), ] for step, assert_ in steps: head.update_to_step(step) @@ -511,6 +551,7 @@ class NotQuiteTwinMergeTest(MigrationTest): """Test a variant of #297. """ + @classmethod def setup_class(cls): """ @@ -527,32 +568,26 @@ class NotQuiteTwinMergeTest(MigrationTest): """ cls.env = env = staging_env() - cls.a = env.generate_revision( - 'a', 'a' + cls.a = env.generate_revision("a", "a") + cls.b1 = env.generate_revision("b1", "b1", head=cls.a.revision) + cls.b2 = env.generate_revision( + "b2", "b2", splice=True, head=cls.a.revision + ) + cls.b3 = env.generate_revision( + "b3", "b3", splice=True, head=cls.a.revision ) - cls.b1 = env.generate_revision('b1', 'b1', - head=cls.a.revision) - cls.b2 = env.generate_revision('b2', 'b2', - splice=True, - head=cls.a.revision) - cls.b3 = env.generate_revision('b3', 'b3', - splice=True, - head=cls.a.revision) cls.c1 = env.generate_revision( - 'c1', 'c1', - head=(cls.b1.revision, cls.b2.revision)) + "c1", "c1", head=(cls.b1.revision, cls.b2.revision) + ) cls.c2 = env.generate_revision( - 'c2', 'c2', - splice=True, - head=(cls.b2.revision, cls.b3.revision)) + "c2", "c2", splice=True, head=(cls.b2.revision, cls.b3.revision) + ) - cls.d1 = env.generate_revision( - 'd1', 'd1', head=cls.c1.revision) + cls.d1 = env.generate_revision("d1", "d1", head=cls.c1.revision) - cls.d2 = env.generate_revision( - 'd2', 'd2', head=cls.c2.revision) + cls.d2 = env.generate_revision("d2", "d2", head=cls.c2.revision) def test_upgrade(self): head = HeadMaintainer(mock.Mock(), [self.a.revision]) @@ -568,14 +603,13 @@ class NotQuiteTwinMergeTest(MigrationTest): """ steps = [ - (self.up_(self.b2), ('b2',)), - (self.up_(self.b3), ('b2', 'b3',)), - (self.up_(self.c2), ('c2',)), - (self.up_(self.d2), ('d2',)), - - (self.up_(self.b1), ('b1', 'd2',)), - (self.up_(self.c1), ('c1', 'd2')), - (self.up_(self.d1), ('d1', 'd2')), + (self.up_(self.b2), ("b2",)), + (self.up_(self.b3), ("b2", "b3")), + (self.up_(self.c2), ("c2",)), + (self.up_(self.d2), ("d2",)), + (self.up_(self.b1), ("b1", "d2")), + (self.up_(self.c1), ("c1", "d2")), + (self.up_(self.d1), ("d1", "d2")), ] for step, assert_ in steps: head.update_to_step(step) @@ -583,32 +617,35 @@ class NotQuiteTwinMergeTest(MigrationTest): class DependsOnBranchTestOne(MigrationTest): - @classmethod def setup_class(cls): cls.env = env = staging_env() cls.a1 = env.generate_revision( - util.rev_id(), '->a1', - branch_labels=['lib1']) - cls.b1 = env.generate_revision(util.rev_id(), 'a1->b1') - cls.c1 = env.generate_revision(util.rev_id(), 'b1->c1') + util.rev_id(), "->a1", branch_labels=["lib1"] + ) + cls.b1 = env.generate_revision(util.rev_id(), "a1->b1") + cls.c1 = env.generate_revision(util.rev_id(), "b1->c1") - cls.a2 = env.generate_revision(util.rev_id(), '->a2', head=()) + cls.a2 = env.generate_revision(util.rev_id(), "->a2", head=()) cls.b2 = env.generate_revision( - util.rev_id(), 'a2->b2', head=cls.a2.revision) + util.rev_id(), "a2->b2", head=cls.a2.revision + ) cls.c2 = env.generate_revision( - util.rev_id(), 'b2->c2', head=cls.b2.revision, - depends_on=cls.c1.revision) + util.rev_id(), + "b2->c2", + head=cls.b2.revision, + depends_on=cls.c1.revision, + ) cls.d1 = env.generate_revision( - util.rev_id(), 'c1->d1', - head=cls.c1.revision) + util.rev_id(), "c1->d1", head=cls.c1.revision + ) cls.e1 = env.generate_revision( - util.rev_id(), 'd1->e1', - head=cls.d1.revision) + util.rev_id(), "d1->e1", head=cls.d1.revision + ) cls.f1 = env.generate_revision( - util.rev_id(), 'e1->f1', - head=cls.e1.revision) + util.rev_id(), "e1->f1", head=cls.e1.revision + ) def test_downgrade_to_dependency(self): heads = [self.c2.revision, self.d1.revision] @@ -625,7 +662,6 @@ class DependsOnBranchTestOne(MigrationTest): class DependsOnBranchTestTwo(MigrationTest): - @classmethod def setup_class(cls): """ @@ -656,32 +692,36 @@ class DependsOnBranchTestTwo(MigrationTest): """ cls.env = env = staging_env() - cls.a1 = env.generate_revision("a1", '->a1', head='base') - cls.a2 = env.generate_revision("a2", '->a2', head='base') - cls.a3 = env.generate_revision("a3", '->a3', head='base') - cls.amerge = env.generate_revision("amerge", 'amerge', head=[ - cls.a1.revision, cls.a2.revision, cls.a3.revision - ]) - - cls.b1 = env.generate_revision("b1", '->b1', head='base') - cls.b2 = env.generate_revision("b2", '->b2', head='base') - cls.bmerge = env.generate_revision("bmerge", 'bmerge', head=[ - cls.b1.revision, cls.b2.revision - ]) - - cls.c1 = env.generate_revision("c1", '->c1', head='base') - cls.c2 = env.generate_revision("c2", '->c2', head='base') - cls.c3 = env.generate_revision("c3", '->c3', head='base') - cls.cmerge = env.generate_revision("cmerge", 'cmerge', head=[ - cls.c1.revision, cls.c2.revision, cls.c3.revision - ]) + cls.a1 = env.generate_revision("a1", "->a1", head="base") + cls.a2 = env.generate_revision("a2", "->a2", head="base") + cls.a3 = env.generate_revision("a3", "->a3", head="base") + cls.amerge = env.generate_revision( + "amerge", + "amerge", + head=[cls.a1.revision, cls.a2.revision, cls.a3.revision], + ) + + cls.b1 = env.generate_revision("b1", "->b1", head="base") + cls.b2 = env.generate_revision("b2", "->b2", head="base") + cls.bmerge = env.generate_revision( + "bmerge", "bmerge", head=[cls.b1.revision, cls.b2.revision] + ) + + cls.c1 = env.generate_revision("c1", "->c1", head="base") + cls.c2 = env.generate_revision("c2", "->c2", head="base") + cls.c3 = env.generate_revision("c3", "->c3", head="base") + cls.cmerge = env.generate_revision( + "cmerge", + "cmerge", + head=[cls.c1.revision, cls.c2.revision, cls.c3.revision], + ) cls.d1 = env.generate_revision( - "d1", 'o', + "d1", + "o", head="base", - depends_on=[ - cls.a3.revision, cls.b2.revision, cls.c1.revision - ]) + depends_on=[cls.a3.revision, cls.b2.revision, cls.c1.revision], + ) def test_kaboom(self): # here's the upgrade path: @@ -690,55 +730,77 @@ class DependsOnBranchTestTwo(MigrationTest): heads = [ self.amerge.revision, - self.bmerge.revision, self.cmerge.revision, - self.d1.revision + self.bmerge.revision, + self.cmerge.revision, + self.d1.revision, ] self._assert_downgrade( - self.b2.revision, heads, + self.b2.revision, + heads, [self.down_(self.bmerge)], - set([ - self.amerge.revision, - self.b1.revision, self.cmerge.revision, self.d1.revision]) + set( + [ + self.amerge.revision, + self.b1.revision, + self.cmerge.revision, + self.d1.revision, + ] + ), ) # start with those heads.. heads = [ - self.amerge.revision, self.d1.revision, - self.b1.revision, self.cmerge.revision] + self.amerge.revision, + self.d1.revision, + self.b1.revision, + self.cmerge.revision, + ] # downgrade d1... self._assert_downgrade( - "d1@base", heads, + "d1@base", + heads, [self.down_(self.d1)], - # b2 has to be INSERTed, because it was implied by d1 - set([ - self.amerge.revision, self.b1.revision, - self.b2.revision, self.cmerge.revision]) + set( + [ + self.amerge.revision, + self.b1.revision, + self.b2.revision, + self.cmerge.revision, + ] + ), ) # start with those heads ... heads = [ - self.amerge.revision, self.b1.revision, - self.b2.revision, self.cmerge.revision + self.amerge.revision, + self.b1.revision, + self.b2.revision, + self.cmerge.revision, ] self._assert_downgrade( - "base", heads, + "base", + heads, [ - self.down_(self.amerge), self.down_(self.a1), - self.down_(self.a2), self.down_(self.a3), - self.down_(self.b1), self.down_(self.b2), - self.down_(self.cmerge), self.down_(self.c1), - self.down_(self.c2), self.down_(self.c3) + self.down_(self.amerge), + self.down_(self.a1), + self.down_(self.a2), + self.down_(self.a3), + self.down_(self.b1), + self.down_(self.b2), + self.down_(self.cmerge), + self.down_(self.c1), + self.down_(self.c2), + self.down_(self.c3), ], - set([]) + set([]), ) class DependsOnBranchTestThree(MigrationTest): - @classmethod def setup_class(cls): """ @@ -755,14 +817,18 @@ class DependsOnBranchTestThree(MigrationTest): """ cls.env = env = staging_env() - cls.a1 = env.generate_revision("a1", '->a1', head='base') - cls.a2 = env.generate_revision("a2", '->a2') + cls.a1 = env.generate_revision("a1", "->a1", head="base") + cls.a2 = env.generate_revision("a2", "->a2") - cls.b1 = env.generate_revision("b1", '->b1', head='base') - cls.b2 = env.generate_revision("b2", '->b2', depends_on='a2', head='b1') - cls.b3 = env.generate_revision("b3", '->b3', head='b2') + cls.b1 = env.generate_revision("b1", "->b1", head="base") + cls.b2 = env.generate_revision( + "b2", "->b2", depends_on="a2", head="b1" + ) + cls.b3 = env.generate_revision("b3", "->b3", head="b2") - cls.a3 = env.generate_revision("a3", '->a3', head='a2', depends_on='b1') + cls.a3 = env.generate_revision( + "a3", "->a3", head="a2", depends_on="b1" + ) def test_downgrade_over_crisscross(self): # this state was not possible prior to @@ -772,9 +838,10 @@ class DependsOnBranchTestThree(MigrationTest): # b2 because a2 is dependent on it, hence we add the ability # to remove half of a merge point. self._assert_downgrade( - 'b1', ['a3', 'b2'], + "b1", + ["a3", "b2"], [self.down_(self.b2)], - set(['a3']) # we have b1 also, which is implied by a3 + set(["a3"]), # we have b1 also, which is implied by a3 ) @@ -783,33 +850,35 @@ class DependsOnBranchLabelTest(MigrationTest): def setup_class(cls): cls.env = env = staging_env() cls.a1 = env.generate_revision( - util.rev_id(), '->a1', - branch_labels=['lib1']) - cls.b1 = env.generate_revision(util.rev_id(), 'a1->b1') + util.rev_id(), "->a1", branch_labels=["lib1"] + ) + cls.b1 = env.generate_revision(util.rev_id(), "a1->b1") cls.c1 = env.generate_revision( - util.rev_id(), 'b1->c1', - branch_labels=['c1lib']) + util.rev_id(), "b1->c1", branch_labels=["c1lib"] + ) - cls.a2 = env.generate_revision(util.rev_id(), '->a2', head=()) + cls.a2 = env.generate_revision(util.rev_id(), "->a2", head=()) cls.b2 = env.generate_revision( - util.rev_id(), 'a2->b2', head=cls.a2.revision) + util.rev_id(), "a2->b2", head=cls.a2.revision + ) cls.c2 = env.generate_revision( - util.rev_id(), 'b2->c2', head=cls.b2.revision, - depends_on=['c1lib']) + util.rev_id(), "b2->c2", head=cls.b2.revision, depends_on=["c1lib"] + ) cls.d1 = env.generate_revision( - util.rev_id(), 'c1->d1', - head=cls.c1.revision) + util.rev_id(), "c1->d1", head=cls.c1.revision + ) cls.e1 = env.generate_revision( - util.rev_id(), 'd1->e1', - head=cls.d1.revision) + util.rev_id(), "d1->e1", head=cls.d1.revision + ) cls.f1 = env.generate_revision( - util.rev_id(), 'e1->f1', - head=cls.e1.revision) + util.rev_id(), "e1->f1", head=cls.e1.revision + ) def test_upgrade_path(self): self._assert_upgrade( - self.c2.revision, self.a2.revision, + self.c2.revision, + self.a2.revision, [ self.up_(self.a1), self.up_(self.b1), @@ -817,23 +886,23 @@ class DependsOnBranchLabelTest(MigrationTest): self.up_(self.b2), self.up_(self.c2), ], - set([self.c2.revision]) + set([self.c2.revision]), ) class ForestTest(MigrationTest): - @classmethod def setup_class(cls): cls.env = env = staging_env() - cls.a1 = env.generate_revision(util.rev_id(), '->a1') - cls.b1 = env.generate_revision(util.rev_id(), 'a1->b1') + cls.a1 = env.generate_revision(util.rev_id(), "->a1") + cls.b1 = env.generate_revision(util.rev_id(), "a1->b1") cls.a2 = env.generate_revision( - util.rev_id(), '->a2', head=(), - refresh=True) + util.rev_id(), "->a2", head=(), refresh=True + ) cls.b2 = env.generate_revision( - util.rev_id(), 'a2->b2', head=cls.a2.revision) + util.rev_id(), "a2->b2", head=cls.a2.revision + ) @classmethod def teardown_class(cls): @@ -842,8 +911,12 @@ class ForestTest(MigrationTest): def test_base_to_heads(self): eq_( self.env._upgrade_revs("heads", "base"), - [self.up_(self.a2), self.up_(self.b2), - self.up_(self.a1), self.up_(self.b1)] + [ + self.up_(self.a2), + self.up_(self.b2), + self.up_(self.a1), + self.up_(self.b1), + ], ) def test_stamp_to_heads(self): @@ -851,40 +924,44 @@ class ForestTest(MigrationTest): eq_(len(revs), 2) eq_( set(r.to_revisions for r in revs), - set([(self.b1.revision,), (self.b2.revision,)]) + set([(self.b1.revision,), (self.b2.revision,)]), ) def test_stamp_to_heads_no_moves_needed(self): revs = self.env._stamp_revs( - "heads", (self.b1.revision, self.b2.revision)) + "heads", (self.b1.revision, self.b2.revision) + ) eq_(len(revs), 0) class MergedPathTest(MigrationTest): - @classmethod def setup_class(cls): cls.env = env = staging_env() - cls.a = env.generate_revision(util.rev_id(), '->a') - cls.b = env.generate_revision(util.rev_id(), 'a->b') + cls.a = env.generate_revision(util.rev_id(), "->a") + cls.b = env.generate_revision(util.rev_id(), "a->b") - cls.c1 = env.generate_revision(util.rev_id(), 'b->c1') - cls.d1 = env.generate_revision(util.rev_id(), 'c1->d1') + cls.c1 = env.generate_revision(util.rev_id(), "b->c1") + cls.d1 = env.generate_revision(util.rev_id(), "c1->d1") cls.c2 = env.generate_revision( - util.rev_id(), 'b->c2', - branch_labels='c2branch', - head=cls.b.revision, splice=True) + util.rev_id(), + "b->c2", + branch_labels="c2branch", + head=cls.b.revision, + splice=True, + ) cls.d2 = env.generate_revision( - util.rev_id(), 'c2->d2', - head=cls.c2.revision) + util.rev_id(), "c2->d2", head=cls.c2.revision + ) cls.e = env.generate_revision( - util.rev_id(), 'merge d1 and d2', - head=(cls.d1.revision, cls.d2.revision) + util.rev_id(), + "merge d1 and d2", + head=(cls.d1.revision, cls.d2.revision), ) - cls.f = env.generate_revision(util.rev_id(), 'e->f') + cls.f = env.generate_revision(util.rev_id(), "e->f") @classmethod def teardown_class(cls): @@ -897,7 +974,7 @@ class MergedPathTest(MigrationTest): eq_( revs[0].merge_branch_idents(heads), # no deletes, UPDATE e to c2 - ([], self.e.revision, self.c2.revision) + ([], self.e.revision, self.c2.revision), ) def test_stamp_down_across_merge_prior_branching(self): @@ -907,7 +984,7 @@ class MergedPathTest(MigrationTest): eq_( revs[0].merge_branch_idents(heads), # no deletes, UPDATE e to c2 - ([], self.e.revision, self.a.revision) + ([], self.e.revision, self.a.revision), ) def test_stamp_up_across_merge_from_single_branch(self): @@ -916,7 +993,7 @@ class MergedPathTest(MigrationTest): eq_( revs[0].merge_branch_idents([self.c2.revision]), # no deletes, UPDATE e to c2 - ([], self.c2.revision, self.e.revision) + ([], self.c2.revision, self.e.revision), ) def test_stamp_labled_head_across_merge_from_multiple_branch(self): @@ -924,23 +1001,23 @@ class MergedPathTest(MigrationTest): # d1 both in terms of "c2branch" as well as that the "head" # revision "f" is the head of both d1 and d2 revs = self.env._stamp_revs( - "c2branch@head", [self.d1.revision, self.c2.revision]) + "c2branch@head", [self.d1.revision, self.c2.revision] + ) eq_(len(revs), 1) eq_( revs[0].merge_branch_idents([self.d1.revision, self.c2.revision]), # DELETE d1 revision, UPDATE c2 to e - ([self.d1.revision], self.c2.revision, self.f.revision) + ([self.d1.revision], self.c2.revision, self.f.revision), ) def test_stamp_up_across_merge_from_multiple_branch(self): heads = [self.d1.revision, self.c2.revision] - revs = self.env._stamp_revs( - self.e.revision, heads) + revs = self.env._stamp_revs(self.e.revision, heads) eq_(len(revs), 1) eq_( revs[0].merge_branch_idents(heads), # DELETE d1 revision, UPDATE c2 to e - ([self.d1.revision], self.c2.revision, self.e.revision) + ([self.d1.revision], self.c2.revision, self.e.revision), ) def test_stamp_up_across_merge_prior_branching(self): @@ -950,7 +1027,7 @@ class MergedPathTest(MigrationTest): eq_( revs[0].merge_branch_idents(heads), # no deletes, UPDATE e to c2 - ([], self.b.revision, self.e.revision) + ([], self.b.revision, self.e.revision), ) def test_upgrade_across_merge_point(self): @@ -963,9 +1040,9 @@ class MergedPathTest(MigrationTest): self.up_(self.c1), # b->c1, create new branch self.up_(self.d1), self.up_(self.e), # d1/d2 -> e, merge branches - # (DELETE d2, UPDATE d1->e) - self.up_(self.f) - ] + # (DELETE d2, UPDATE d1->e) + self.up_(self.f), + ], ) def test_downgrade_across_merge_point(self): @@ -975,10 +1052,10 @@ class MergedPathTest(MigrationTest): [ self.down_(self.f), self.down_(self.e), # e -> d1 and d2, unmerge branches - # (UPDATE e->d1, INSERT d2) + # (UPDATE e->d1, INSERT d2) self.down_(self.d1), self.down_(self.c1), self.down_(self.d2), self.down_(self.c2), # c2->b, delete branch - ] + ], ) @@ -55,16 +55,14 @@ commands= {oracle}: python reap_oracle_dbs.py oracle_idents.txt +# thanks to https://julien.danjou.info/the-best-flake8-extensions/ [testenv:pep8] -deps=flake8 -commands = python -m flake8 {posargs} - - -[flake8] - -show-source = True -ignore = E711,E712,E721,D,N -# F841,F811,F401 -exclude=.venv,.git,.tox,dist,doc,*egg,build - - +deps= + flake8 + flake8-import-order + flake8-builtins + flake8-docstrings + flake8-rst-docstrings + # used by flake8-rst-docstrings + pygments +commands = flake8 ./alembic/ ./tests/ setup.py |