summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--alembic/__init__.py7
-rw-r--r--alembic/autogenerate/__init__.py14
-rw-r--r--alembic/autogenerate/api.py136
-rw-r--r--alembic/autogenerate/compare.py615
-rw-r--r--alembic/autogenerate/render.py355
-rw-r--r--alembic/autogenerate/rewriter.py12
-rw-r--r--alembic/command.py205
-rw-r--r--alembic/config.py233
-rw-r--r--alembic/ddl/base.py63
-rw-r--r--alembic/ddl/impl.py264
-rw-r--r--alembic/ddl/mssql.py145
-rw-r--r--alembic/ddl/mysql.py266
-rw-r--r--alembic/ddl/oracle.py43
-rw-r--r--alembic/ddl/postgresql.py296
-rw-r--r--alembic/ddl/sqlite.py45
-rw-r--r--alembic/op.py1
-rw-r--r--alembic/operations/__init__.py2
-rw-r--r--alembic/operations/base.py90
-rw-r--r--alembic/operations/batch.py155
-rw-r--r--alembic/operations/ops.py637
-rw-r--r--alembic/operations/schemaobj.py136
-rw-r--r--alembic/operations/toimpl.py40
-rw-r--r--alembic/runtime/environment.py126
-rw-r--r--alembic/runtime/migration.py333
-rw-r--r--alembic/script/__init__.py2
-rw-r--r--alembic/script/base.py347
-rw-r--r--alembic/script/revision.py373
-rw-r--r--alembic/templates/generic/env.py12
-rw-r--r--alembic/templates/multidb/env.py49
-rw-r--r--alembic/templates/pylons/env.py21
-rw-r--r--alembic/testing/__init__.py11
-rw-r--r--alembic/testing/assertions.py60
-rw-r--r--alembic/testing/compat.py9
-rw-r--r--alembic/testing/config.py4
-rw-r--r--alembic/testing/engines.py1
-rw-r--r--alembic/testing/env.py165
-rw-r--r--alembic/testing/exclusions.py102
-rw-r--r--alembic/testing/fixtures.py49
-rw-r--r--alembic/testing/mock.py3
-rw-r--r--alembic/testing/plugin/bootstrap.py7
-rw-r--r--alembic/testing/plugin/noseplugin.py17
-rw-r--r--alembic/testing/plugin/plugin_base.py365
-rw-r--r--alembic/testing/plugin/pytestplugin.py101
-rw-r--r--alembic/testing/provision.py42
-rw-r--r--alembic/testing/requirements.py30
-rw-r--r--alembic/testing/runner.py2
-rw-r--r--alembic/testing/util.py2
-rw-r--r--alembic/testing/warnings.py11
-rw-r--r--alembic/util/__init__.py46
-rw-r--r--alembic/util/compat.py94
-rw-r--r--alembic/util/langhelpers.py116
-rw-r--r--alembic/util/messaging.py14
-rw-r--r--alembic/util/pyfiles.py30
-rw-r--r--alembic/util/sqla_compat.py42
-rw-r--r--setup.cfg15
-rw-r--r--setup.py86
-rw-r--r--tests/_autogen_fixtures.py256
-rw-r--r--tests/_large_map.py290
-rwxr-xr-xtests/conftest.py10
-rw-r--r--tests/requirements.py83
-rw-r--r--tests/test_autogen_composition.py157
-rw-r--r--tests/test_autogen_diffs.py1016
-rw-r--r--tests/test_autogen_fks.py1135
-rw-r--r--tests/test_autogen_indexes.py1121
-rw-r--r--tests/test_autogen_render.py1324
-rw-r--r--tests/test_batch.py1453
-rw-r--r--tests/test_bulk_insert.py246
-rw-r--r--tests/test_command.py237
-rw-r--r--tests/test_config.py72
-rw-r--r--tests/test_environment.py66
-rw-r--r--tests/test_external_dialect.py85
-rw-r--r--tests/test_mssql.py242
-rw-r--r--tests/test_mysql.py319
-rw-r--r--tests/test_offline_environment.py203
-rw-r--r--tests/test_op.py664
-rw-r--r--tests/test_op_naming_convention.py122
-rw-r--r--tests/test_oracle.py111
-rw-r--r--tests/test_postgresql.py613
-rw-r--r--tests/test_revision.py1027
-rw-r--r--tests/test_script_consumption.py285
-rw-r--r--tests/test_script_production.py572
-rw-r--r--tests/test_sqlite.py23
-rw-r--r--tests/test_version_table.py255
-rw-r--r--tests/test_version_traversal.py703
-rw-r--r--tox.ini22
85 files changed, 10842 insertions, 8317 deletions
diff --git a/alembic/__init__.py b/alembic/__init__.py
index 3432a88..a7a2845 100644
--- a/alembic/__init__.py
+++ b/alembic/__init__.py
@@ -1,6 +1,6 @@
from os import path
-__version__ = '1.0.6'
+__version__ = "1.0.6"
package_dir = path.abspath(path.dirname(__file__))
@@ -11,5 +11,6 @@ from . import context # noqa
import sys
from .runtime import environment
from .runtime import migration
-sys.modules['alembic.migration'] = migration
-sys.modules['alembic.environment'] = environment
+
+sys.modules["alembic.migration"] = migration
+sys.modules["alembic.environment"] = environment
diff --git a/alembic/autogenerate/__init__.py b/alembic/autogenerate/__init__.py
index 142f55d..ad3e6e1 100644
--- a/alembic/autogenerate/__init__.py
+++ b/alembic/autogenerate/__init__.py
@@ -1,8 +1,10 @@
-from .api import ( # noqa
- compare_metadata, _render_migration_diffs,
- produce_migrations, render_python_code,
- RevisionContext
- )
+from .api import ( # noqa
+ compare_metadata,
+ _render_migration_diffs,
+ produce_migrations,
+ render_python_code,
+ RevisionContext,
+)
from .compare import _produce_net_changes, comparators # noqa
from .render import render_op_text, renderers # noqa
-from .rewriter import Rewriter # noqa \ No newline at end of file
+from .rewriter import Rewriter # noqa
diff --git a/alembic/autogenerate/api.py b/alembic/autogenerate/api.py
index 15b5b6b..cfd6e86 100644
--- a/alembic/autogenerate/api.py
+++ b/alembic/autogenerate/api.py
@@ -136,8 +136,8 @@ def produce_migrations(context, metadata):
def render_python_code(
up_or_down_op,
- sqlalchemy_module_prefix='sa.',
- alembic_module_prefix='op.',
+ sqlalchemy_module_prefix="sa.",
+ alembic_module_prefix="op.",
render_as_batch=False,
imports=(),
render_item=None,
@@ -150,16 +150,17 @@ def render_python_code(
"""
opts = {
- 'sqlalchemy_module_prefix': sqlalchemy_module_prefix,
- 'alembic_module_prefix': alembic_module_prefix,
- 'render_item': render_item,
- 'render_as_batch': render_as_batch,
+ "sqlalchemy_module_prefix": sqlalchemy_module_prefix,
+ "alembic_module_prefix": alembic_module_prefix,
+ "render_item": render_item,
+ "render_as_batch": render_as_batch,
}
autogen_context = AutogenContext(None, opts=opts)
autogen_context.imports = set(imports)
- return render._indent(render._render_cmd_body(
- up_or_down_op, autogen_context))
+ return render._indent(
+ render._render_cmd_body(up_or_down_op, autogen_context)
+ )
def _render_migration_diffs(context, template_args):
@@ -240,42 +241,53 @@ class AutogenContext(object):
"""The :class:`.MigrationContext` established by the ``env.py`` script."""
def __init__(
- self, migration_context, metadata=None,
- opts=None, autogenerate=True):
-
- if autogenerate and \
- migration_context is not None and migration_context.as_sql:
+ self, migration_context, metadata=None, opts=None, autogenerate=True
+ ):
+
+ if (
+ autogenerate
+ and migration_context is not None
+ and migration_context.as_sql
+ ):
raise util.CommandError(
"autogenerate can't use as_sql=True as it prevents querying "
- "the database for schema information")
+ "the database for schema information"
+ )
if opts is None:
opts = migration_context.opts
- self.metadata = metadata = opts.get('target_metadata', None) \
- if metadata is None else metadata
+ self.metadata = metadata = (
+ opts.get("target_metadata", None) if metadata is None else metadata
+ )
- if autogenerate and metadata is None and \
- migration_context is not None and \
- migration_context.script is not None:
+ if (
+ autogenerate
+ and metadata is None
+ and migration_context is not None
+ and migration_context.script is not None
+ ):
raise util.CommandError(
"Can't proceed with --autogenerate option; environment "
"script %s does not provide "
- "a MetaData object or sequence of objects to the context." % (
- migration_context.script.env_py_location
- ))
+ "a MetaData object or sequence of objects to the context."
+ % (migration_context.script.env_py_location)
+ )
- include_symbol = opts.get('include_symbol', None)
- include_object = opts.get('include_object', None)
+ include_symbol = opts.get("include_symbol", None)
+ include_object = opts.get("include_object", None)
object_filters = []
if include_symbol:
+
def include_symbol_filter(
- object, name, type_, reflected, compare_to):
+ object, name, type_, reflected, compare_to
+ ):
if type_ == "table":
return include_symbol(name, object.schema)
else:
return True
+
object_filters.append(include_symbol_filter)
if include_object:
object_filters.append(include_object)
@@ -357,8 +369,8 @@ class AutogenContext(object):
if intersect:
raise ValueError(
"Duplicate table keys across multiple "
- "MetaData objects: %s" %
- (", ".join('"%s"' % key for key in sorted(intersect)))
+ "MetaData objects: %s"
+ % (", ".join('"%s"' % key for key in sorted(intersect)))
)
result.update(m.tables)
@@ -369,26 +381,29 @@ class RevisionContext(object):
"""Maintains configuration and state that's specific to a revision
file generation operation."""
- def __init__(self, config, script_directory, command_args,
- process_revision_directives=None):
+ def __init__(
+ self,
+ config,
+ script_directory,
+ command_args,
+ process_revision_directives=None,
+ ):
self.config = config
self.script_directory = script_directory
self.command_args = command_args
self.process_revision_directives = process_revision_directives
self.template_args = {
- 'config': config # Let templates use config for
- # e.g. multiple databases
+ "config": config # Let templates use config for
+ # e.g. multiple databases
}
- self.generated_revisions = [
- self._default_revision()
- ]
+ self.generated_revisions = [self._default_revision()]
def _to_script(self, migration_script):
template_args = {}
for k, v in self.template_args.items():
template_args.setdefault(k, v)
- if getattr(migration_script, '_needs_render', False):
+ if getattr(migration_script, "_needs_render", False):
autogen_context = self._last_autogen_context
# clear out existing imports if we are doing multiple
@@ -409,7 +424,8 @@ class RevisionContext(object):
branch_labels=migration_script.branch_label,
version_path=migration_script.version_path,
depends_on=migration_script.depends_on,
- **template_args)
+ **template_args
+ )
def run_autogenerate(self, rev, migration_context):
self._run_environment(rev, migration_context, True)
@@ -419,21 +435,24 @@ class RevisionContext(object):
def _run_environment(self, rev, migration_context, autogenerate):
if autogenerate:
- if self.command_args['sql']:
+ if self.command_args["sql"]:
raise util.CommandError(
- "Using --sql with --autogenerate does not make any sense")
- if set(self.script_directory.get_revisions(rev)) != \
- set(self.script_directory.get_revisions("heads")):
+ "Using --sql with --autogenerate does not make any sense"
+ )
+ if set(self.script_directory.get_revisions(rev)) != set(
+ self.script_directory.get_revisions("heads")
+ ):
raise util.CommandError("Target database is not up to date.")
- upgrade_token = migration_context.opts['upgrade_token']
- downgrade_token = migration_context.opts['downgrade_token']
+ upgrade_token = migration_context.opts["upgrade_token"]
+ downgrade_token = migration_context.opts["downgrade_token"]
migration_script = self.generated_revisions[-1]
- if not getattr(migration_script, '_needs_render', False):
+ if not getattr(migration_script, "_needs_render", False):
migration_script.upgrade_ops_list[-1].upgrade_token = upgrade_token
- migration_script.downgrade_ops_list[-1].downgrade_token = \
- downgrade_token
+ migration_script.downgrade_ops_list[
+ -1
+ ].downgrade_token = downgrade_token
migration_script._needs_render = True
else:
migration_script._upgrade_ops.append(
@@ -443,18 +462,21 @@ class RevisionContext(object):
ops.DowngradeOps([], downgrade_token=downgrade_token)
)
- self._last_autogen_context = autogen_context = \
- AutogenContext(migration_context, autogenerate=autogenerate)
+ self._last_autogen_context = autogen_context = AutogenContext(
+ migration_context, autogenerate=autogenerate
+ )
if autogenerate:
compare._populate_migration_script(
- autogen_context, migration_script)
+ autogen_context, migration_script
+ )
if self.process_revision_directives:
self.process_revision_directives(
- migration_context, rev, self.generated_revisions)
+ migration_context, rev, self.generated_revisions
+ )
- hook = migration_context.opts['process_revision_directives']
+ hook = migration_context.opts["process_revision_directives"]
if hook:
hook(migration_context, rev, self.generated_revisions)
@@ -463,15 +485,15 @@ class RevisionContext(object):
def _default_revision(self):
op = ops.MigrationScript(
- rev_id=self.command_args['rev_id'] or util.rev_id(),
- message=self.command_args['message'],
+ rev_id=self.command_args["rev_id"] or util.rev_id(),
+ message=self.command_args["message"],
upgrade_ops=ops.UpgradeOps([]),
downgrade_ops=ops.DowngradeOps([]),
- head=self.command_args['head'],
- splice=self.command_args['splice'],
- branch_label=self.command_args['branch_label'],
- version_path=self.command_args['version_path'],
- depends_on=self.command_args['depends_on']
+ head=self.command_args["head"],
+ splice=self.command_args["splice"],
+ branch_label=self.command_args["branch_label"],
+ version_path=self.command_args["version_path"],
+ depends_on=self.command_args["depends_on"],
)
return op
diff --git a/alembic/autogenerate/compare.py b/alembic/autogenerate/compare.py
index 8b41647..7ff8be6 100644
--- a/alembic/autogenerate/compare.py
+++ b/alembic/autogenerate/compare.py
@@ -29,7 +29,7 @@ comparators = util.Dispatcher(uselist=True)
def _produce_net_changes(autogen_context, upgrade_ops):
connection = autogen_context.connection
- include_schemas = autogen_context.opts.get('include_schemas', False)
+ include_schemas = autogen_context.opts.get("include_schemas", False)
inspector = Inspector.from_engine(connection)
@@ -55,8 +55,9 @@ def _autogen_for_tables(autogen_context, upgrade_ops, schemas):
conn_table_names = set()
- version_table_schema = \
+ version_table_schema = (
autogen_context.migration_context.version_table_schema
+ )
version_table = autogen_context.migration_context.version_table
for s in schemas:
@@ -71,12 +72,22 @@ def _autogen_for_tables(autogen_context, upgrade_ops, schemas):
[(table.schema, table.name) for table in autogen_context.sorted_tables]
).difference([(version_table_schema, version_table)])
- _compare_tables(conn_table_names, metadata_table_names,
- inspector, upgrade_ops, autogen_context)
+ _compare_tables(
+ conn_table_names,
+ metadata_table_names,
+ inspector,
+ upgrade_ops,
+ autogen_context,
+ )
-def _compare_tables(conn_table_names, metadata_table_names,
- inspector, upgrade_ops, autogen_context):
+def _compare_tables(
+ conn_table_names,
+ metadata_table_names,
+ inspector,
+ upgrade_ops,
+ autogen_context,
+):
default_schema = inspector.bind.dialect.default_schema_name
@@ -85,10 +96,12 @@ def _compare_tables(conn_table_names, metadata_table_names,
# of table names from local metadata that also have "None" if schema
# == default_schema_name. Most setups will be like this anyway but
# some are not (see #170)
- metadata_table_names_no_dflt_schema = OrderedSet([
- (schema if schema != default_schema else None, tname)
- for schema, tname in metadata_table_names
- ])
+ metadata_table_names_no_dflt_schema = OrderedSet(
+ [
+ (schema if schema != default_schema else None, tname)
+ for schema, tname in metadata_table_names
+ ]
+ )
# to adjust for the MetaData collection storing the tables either
# as "schemaname.tablename" or just "tablename", create a new lookup
@@ -97,27 +110,34 @@ def _compare_tables(conn_table_names, metadata_table_names,
(
no_dflt_schema,
autogen_context.table_key_to_table[
- sa_schema._get_table_key(tname, schema)]
+ sa_schema._get_table_key(tname, schema)
+ ],
)
for no_dflt_schema, (schema, tname) in zip(
- metadata_table_names_no_dflt_schema,
- metadata_table_names)
+ metadata_table_names_no_dflt_schema, metadata_table_names
+ )
)
metadata_table_names = metadata_table_names_no_dflt_schema
for s, tname in metadata_table_names.difference(conn_table_names):
- name = '%s.%s' % (s, tname) if s else tname
+ name = "%s.%s" % (s, tname) if s else tname
metadata_table = tname_to_table[(s, tname)]
if autogen_context.run_filters(
- metadata_table, tname, "table", False, None):
+ metadata_table, tname, "table", False, None
+ ):
upgrade_ops.ops.append(
- ops.CreateTableOp.from_table(metadata_table))
+ ops.CreateTableOp.from_table(metadata_table)
+ )
log.info("Detected added table %r", name)
modify_table_ops = ops.ModifyTableOps(tname, [], schema=s)
comparators.dispatch("table")(
- autogen_context, modify_table_ops,
- s, tname, None, metadata_table
+ autogen_context,
+ modify_table_ops,
+ s,
+ tname,
+ None,
+ metadata_table,
)
if not modify_table_ops.is_empty():
upgrade_ops.ops.append(modify_table_ops)
@@ -132,23 +152,22 @@ def _compare_tables(conn_table_names, metadata_table_names,
event.listen(
t,
"column_reflect",
- autogen_context.migration_context.impl.
- _compat_autogen_column_reflect(inspector))
+ autogen_context.migration_context.impl._compat_autogen_column_reflect(
+ inspector
+ ),
+ )
inspector.reflecttable(t, None)
if autogen_context.run_filters(t, tname, "table", True, None):
modify_table_ops = ops.ModifyTableOps(tname, [], schema=s)
comparators.dispatch("table")(
- autogen_context, modify_table_ops,
- s, tname, t, None
+ autogen_context, modify_table_ops, s, tname, t, None
)
if not modify_table_ops.is_empty():
upgrade_ops.ops.append(modify_table_ops)
- upgrade_ops.ops.append(
- ops.DropTableOp.from_table(t)
- )
+ upgrade_ops.ops.append(ops.DropTableOp.from_table(t))
log.info("Detected removed table %r", name)
existing_tables = conn_table_names.intersection(metadata_table_names)
@@ -163,31 +182,41 @@ def _compare_tables(conn_table_names, metadata_table_names,
event.listen(
t,
"column_reflect",
- autogen_context.migration_context.impl.
- _compat_autogen_column_reflect(inspector))
+ autogen_context.migration_context.impl._compat_autogen_column_reflect(
+ inspector
+ ),
+ )
inspector.reflecttable(t, None)
conn_column_info[(s, tname)] = t
- for s, tname in sorted(existing_tables, key=lambda x: (x[0] or '', x[1])):
+ for s, tname in sorted(existing_tables, key=lambda x: (x[0] or "", x[1])):
s = s or None
- name = '%s.%s' % (s, tname) if s else tname
+ name = "%s.%s" % (s, tname) if s else tname
metadata_table = tname_to_table[(s, tname)]
conn_table = existing_metadata.tables[name]
if autogen_context.run_filters(
- metadata_table, tname, "table", False,
- conn_table):
+ metadata_table, tname, "table", False, conn_table
+ ):
modify_table_ops = ops.ModifyTableOps(tname, [], schema=s)
with _compare_columns(
- s, tname,
+ s,
+ tname,
conn_table,
metadata_table,
- modify_table_ops, autogen_context, inspector):
+ modify_table_ops,
+ autogen_context,
+ inspector,
+ ):
comparators.dispatch("table")(
- autogen_context, modify_table_ops,
- s, tname, conn_table, metadata_table
+ autogen_context,
+ modify_table_ops,
+ s,
+ tname,
+ conn_table,
+ metadata_table,
)
if not modify_table_ops.is_empty():
@@ -196,41 +225,41 @@ def _compare_tables(conn_table_names, metadata_table_names,
def _make_index(params, conn_table):
ix = sa_schema.Index(
- params['name'],
- *[conn_table.c[cname] for cname in params['column_names']],
- unique=params['unique']
+ params["name"],
+ *[conn_table.c[cname] for cname in params["column_names"]],
+ unique=params["unique"]
)
- if 'duplicates_constraint' in params:
- ix.info['duplicates_constraint'] = params['duplicates_constraint']
+ if "duplicates_constraint" in params:
+ ix.info["duplicates_constraint"] = params["duplicates_constraint"]
return ix
def _make_unique_constraint(params, conn_table):
uq = sa_schema.UniqueConstraint(
- *[conn_table.c[cname] for cname in params['column_names']],
- name=params['name']
+ *[conn_table.c[cname] for cname in params["column_names"]],
+ name=params["name"]
)
- if 'duplicates_index' in params:
- uq.info['duplicates_index'] = params['duplicates_index']
+ if "duplicates_index" in params:
+ uq.info["duplicates_index"] = params["duplicates_index"]
return uq
def _make_foreign_key(params, conn_table):
- tname = params['referred_table']
- if params['referred_schema']:
- tname = "%s.%s" % (params['referred_schema'], tname)
+ tname = params["referred_table"]
+ if params["referred_schema"]:
+ tname = "%s.%s" % (params["referred_schema"], tname)
- options = params.get('options', {})
+ options = params.get("options", {})
const = sa_schema.ForeignKeyConstraint(
- [conn_table.c[cname] for cname in params['constrained_columns']],
- ["%s.%s" % (tname, n) for n in params['referred_columns']],
- onupdate=options.get('onupdate'),
- ondelete=options.get('ondelete'),
- deferrable=options.get('deferrable'),
- initially=options.get('initially'),
- name=params['name']
+ [conn_table.c[cname] for cname in params["constrained_columns"]],
+ ["%s.%s" % (tname, n) for n in params["referred_columns"]],
+ onupdate=options.get("onupdate"),
+ ondelete=options.get("ondelete"),
+ deferrable=options.get("deferrable"),
+ initially=options.get("initially"),
+ name=params["name"],
)
# needed by 0.7
conn_table.append_constraint(const)
@@ -238,21 +267,30 @@ def _make_foreign_key(params, conn_table):
@contextlib.contextmanager
-def _compare_columns(schema, tname, conn_table, metadata_table,
- modify_table_ops, autogen_context, inspector):
- name = '%s.%s' % (schema, tname) if schema else tname
+def _compare_columns(
+ schema,
+ tname,
+ conn_table,
+ metadata_table,
+ modify_table_ops,
+ autogen_context,
+ inspector,
+):
+ name = "%s.%s" % (schema, tname) if schema else tname
metadata_cols_by_name = dict(
- (c.name, c) for c in metadata_table.c if not c.system)
+ (c.name, c) for c in metadata_table.c if not c.system
+ )
conn_col_names = dict((c.name, c) for c in conn_table.c)
metadata_col_names = OrderedSet(sorted(metadata_cols_by_name))
for cname in metadata_col_names.difference(conn_col_names):
if autogen_context.run_filters(
- metadata_cols_by_name[cname], cname,
- "column", False, None):
+ metadata_cols_by_name[cname], cname, "column", False, None
+ ):
modify_table_ops.ops.append(
ops.AddColumnOp.from_column_and_tablename(
- schema, tname, metadata_cols_by_name[cname])
+ schema, tname, metadata_cols_by_name[cname]
+ )
)
log.info("Detected added column '%s.%s'", name, cname)
@@ -260,15 +298,19 @@ def _compare_columns(schema, tname, conn_table, metadata_table,
metadata_col = metadata_cols_by_name[colname]
conn_col = conn_table.c[colname]
if not autogen_context.run_filters(
- metadata_col, colname, "column", False,
- conn_col):
+ metadata_col, colname, "column", False, conn_col
+ ):
continue
- alter_column_op = ops.AlterColumnOp(
- tname, colname, schema=schema)
+ alter_column_op = ops.AlterColumnOp(tname, colname, schema=schema)
comparators.dispatch("column")(
- autogen_context, alter_column_op,
- schema, tname, colname, conn_col, metadata_col
+ autogen_context,
+ alter_column_op,
+ schema,
+ tname,
+ colname,
+ conn_col,
+ metadata_col,
)
if alter_column_op.has_changes():
@@ -278,8 +320,8 @@ def _compare_columns(schema, tname, conn_table, metadata_table,
for cname in set(conn_col_names).difference(metadata_col_names):
if autogen_context.run_filters(
- conn_table.c[cname], cname,
- "column", True, None):
+ conn_table.c[cname], cname, "column", True, None
+ ):
modify_table_ops.ops.append(
ops.DropColumnOp.from_column_and_tablename(
schema, tname, conn_table.c[cname]
@@ -289,7 +331,6 @@ def _compare_columns(schema, tname, conn_table, metadata_table,
class _constraint_sig(object):
-
def md_name_to_sql_name(self, context):
return self.name
@@ -340,36 +381,47 @@ class _fk_constraint_sig(_constraint_sig):
self.name = const.name
(
- self.source_schema, self.source_table,
- self.source_columns, self.target_schema, self.target_table,
+ self.source_schema,
+ self.source_table,
+ self.source_columns,
+ self.target_schema,
+ self.target_table,
self.target_columns,
- onupdate, ondelete,
- deferrable, initially) = _fk_spec(const)
+ onupdate,
+ ondelete,
+ deferrable,
+ initially,
+ ) = _fk_spec(const)
self.sig = (
- self.source_schema, self.source_table, tuple(self.source_columns),
- self.target_schema, self.target_table, tuple(self.target_columns)
+ self.source_schema,
+ self.source_table,
+ tuple(self.source_columns),
+ self.target_schema,
+ self.target_table,
+ tuple(self.target_columns),
)
if include_options:
self.sig += (
- (None if onupdate.lower() == 'no action'
- else onupdate.lower())
- if onupdate else None,
- (None if ondelete.lower() == 'no action'
- else ondelete.lower())
- if ondelete else None,
+ (None if onupdate.lower() == "no action" else onupdate.lower())
+ if onupdate
+ else None,
+ (None if ondelete.lower() == "no action" else ondelete.lower())
+ if ondelete
+ else None,
# convert initially + deferrable into one three-state value
"initially_deferrable"
if initially and initially.lower() == "deferred"
- else "deferrable" if deferrable
- else "not deferrable"
+ else "deferrable"
+ if deferrable
+ else "not deferrable",
)
@comparators.dispatch_for("table")
def _compare_indexes_and_uniques(
- autogen_context, modify_ops, schema, tname, conn_table,
- metadata_table):
+ autogen_context, modify_ops, schema, tname, conn_table, metadata_table
+):
inspector = autogen_context.inspector
is_create_table = conn_table is None
@@ -378,7 +430,8 @@ def _compare_indexes_and_uniques(
# 1a. get raw indexes and unique constraints from metadata ...
if metadata_table is not None:
metadata_unique_constraints = set(
- uq for uq in metadata_table.constraints
+ uq
+ for uq in metadata_table.constraints
if isinstance(uq, sa_schema.UniqueConstraint)
)
metadata_indexes = set(metadata_table.indexes)
@@ -397,7 +450,8 @@ def _compare_indexes_and_uniques(
if hasattr(inspector, "get_unique_constraints"):
try:
conn_uniques = inspector.get_unique_constraints(
- tname, schema=schema)
+ tname, schema=schema
+ )
supports_unique_constraints = True
except NotImplementedError:
pass
@@ -408,7 +462,7 @@ def _compare_indexes_and_uniques(
pass
else:
for uq in conn_uniques:
- if uq.get('duplicates_index'):
+ if uq.get("duplicates_index"):
unique_constraints_duplicate_unique_indexes = True
try:
conn_indexes = inspector.get_indexes(tname, schema=schema)
@@ -421,8 +475,10 @@ def _compare_indexes_and_uniques(
# for DROP TABLE uniques are inline, don't need them
conn_uniques = set()
else:
- conn_uniques = set(_make_unique_constraint(uq_def, conn_table)
- for uq_def in conn_uniques)
+ conn_uniques = set(
+ _make_unique_constraint(uq_def, conn_table)
+ for uq_def in conn_uniques
+ )
conn_indexes = set(_make_index(ix, conn_table) for ix in conn_indexes)
@@ -431,64 +487,71 @@ def _compare_indexes_and_uniques(
if unique_constraints_duplicate_unique_indexes:
_correct_for_uq_duplicates_uix(
- conn_uniques, conn_indexes,
+ conn_uniques,
+ conn_indexes,
metadata_unique_constraints,
- metadata_indexes
+ metadata_indexes,
)
# 3. give the dialect a chance to omit indexes and constraints that
# we know are either added implicitly by the DB or that the DB
# can't accurately report on
- autogen_context.migration_context.impl.\
- correct_for_autogen_constraints(
- conn_uniques, conn_indexes,
- metadata_unique_constraints,
- metadata_indexes)
+ autogen_context.migration_context.impl.correct_for_autogen_constraints(
+ conn_uniques,
+ conn_indexes,
+ metadata_unique_constraints,
+ metadata_indexes,
+ )
# 4. organize the constraints into "signature" collections, the
# _constraint_sig() objects provide a consistent facade over both
# Index and UniqueConstraint so we can easily work with them
# interchangeably
- metadata_unique_constraints = set(_uq_constraint_sig(uq)
- for uq in metadata_unique_constraints
- )
+ metadata_unique_constraints = set(
+ _uq_constraint_sig(uq) for uq in metadata_unique_constraints
+ )
metadata_indexes = set(_ix_constraint_sig(ix) for ix in metadata_indexes)
conn_unique_constraints = set(
- _uq_constraint_sig(uq) for uq in conn_uniques)
+ _uq_constraint_sig(uq) for uq in conn_uniques
+ )
conn_indexes = set(_ix_constraint_sig(ix) for ix in conn_indexes)
# 5. index things by name, for those objects that have names
metadata_names = dict(
- (c.md_name_to_sql_name(autogen_context), c) for c in
- metadata_unique_constraints.union(metadata_indexes)
- if c.name is not None)
+ (c.md_name_to_sql_name(autogen_context), c)
+ for c in metadata_unique_constraints.union(metadata_indexes)
+ if c.name is not None
+ )
conn_uniques_by_name = dict((c.name, c) for c in conn_unique_constraints)
conn_indexes_by_name = dict((c.name, c) for c in conn_indexes)
- conn_names = dict((c.name, c) for c in
- conn_unique_constraints.union(conn_indexes)
- if c.name is not None)
+ conn_names = dict(
+ (c.name, c)
+ for c in conn_unique_constraints.union(conn_indexes)
+ if c.name is not None
+ )
doubled_constraints = dict(
(name, (conn_uniques_by_name[name], conn_indexes_by_name[name]))
- for name in set(
- conn_uniques_by_name).intersection(conn_indexes_by_name)
+ for name in set(conn_uniques_by_name).intersection(
+ conn_indexes_by_name
+ )
)
# 6. index things by "column signature", to help with unnamed unique
# constraints.
conn_uniques_by_sig = dict((uq.sig, uq) for uq in conn_unique_constraints)
metadata_uniques_by_sig = dict(
- (uq.sig, uq) for uq in metadata_unique_constraints)
- metadata_indexes_by_sig = dict(
- (ix.sig, ix) for ix in metadata_indexes)
+ (uq.sig, uq) for uq in metadata_unique_constraints
+ )
+ metadata_indexes_by_sig = dict((ix.sig, ix) for ix in metadata_indexes)
unnamed_metadata_uniques = dict(
- (uq.sig, uq) for uq in
- metadata_unique_constraints if uq.name is None)
+ (uq.sig, uq) for uq in metadata_unique_constraints if uq.name is None
+ )
# assumptions:
# 1. a unique constraint or an index from the connection *always*
@@ -501,14 +564,14 @@ def _compare_indexes_and_uniques(
def obj_added(obj):
if obj.is_index:
if autogen_context.run_filters(
- obj.const, obj.name, "index", False, None):
- modify_ops.ops.append(
- ops.CreateIndexOp.from_index(obj.const)
+ obj.const, obj.name, "index", False, None
+ ):
+ modify_ops.ops.append(ops.CreateIndexOp.from_index(obj.const))
+ log.info(
+ "Detected added index '%s' on %s",
+ obj.name,
+ ", ".join(["'%s'" % obj.column_names]),
)
- log.info("Detected added index '%s' on %s",
- obj.name, ', '.join([
- "'%s'" % obj.column_names
- ]))
else:
if not supports_unique_constraints:
# can't report unique indexes as added if we don't
@@ -518,15 +581,16 @@ def _compare_indexes_and_uniques(
# unique constraints are created inline with table defs
return
if autogen_context.run_filters(
- obj.const, obj.name,
- "unique_constraint", False, None):
+ obj.const, obj.name, "unique_constraint", False, None
+ ):
modify_ops.ops.append(
ops.AddConstraintOp.from_constraint(obj.const)
)
- log.info("Detected added unique constraint '%s' on %s",
- obj.name, ', '.join([
- "'%s'" % obj.column_names
- ]))
+ log.info(
+ "Detected added unique constraint '%s' on %s",
+ obj.name,
+ ", ".join(["'%s'" % obj.column_names]),
+ )
def obj_removed(obj):
if obj.is_index:
@@ -537,48 +601,52 @@ def _compare_indexes_and_uniques(
return
if autogen_context.run_filters(
- obj.const, obj.name, "index", True, None):
- modify_ops.ops.append(
- ops.DropIndexOp.from_index(obj.const)
- )
+ obj.const, obj.name, "index", True, None
+ ):
+ modify_ops.ops.append(ops.DropIndexOp.from_index(obj.const))
log.info(
- "Detected removed index '%s' on '%s'", obj.name, tname)
+ "Detected removed index '%s' on '%s'", obj.name, tname
+ )
else:
if is_create_table or is_drop_table:
# if the whole table is being dropped, we don't need to
# consider unique constraint separately
return
if autogen_context.run_filters(
- obj.const, obj.name,
- "unique_constraint", True, None):
+ obj.const, obj.name, "unique_constraint", True, None
+ ):
modify_ops.ops.append(
ops.DropConstraintOp.from_constraint(obj.const)
)
- log.info("Detected removed unique constraint '%s' on '%s'",
- obj.name, tname
- )
+ log.info(
+ "Detected removed unique constraint '%s' on '%s'",
+ obj.name,
+ tname,
+ )
def obj_changed(old, new, msg):
if old.is_index:
if autogen_context.run_filters(
- new.const, new.name, "index",
- False, old.const):
- log.info("Detected changed index '%s' on '%s':%s",
- old.name, tname, ', '.join(msg)
- )
- modify_ops.ops.append(
- ops.DropIndexOp.from_index(old.const)
- )
- modify_ops.ops.append(
- ops.CreateIndexOp.from_index(new.const)
+ new.const, new.name, "index", False, old.const
+ ):
+ log.info(
+ "Detected changed index '%s' on '%s':%s",
+ old.name,
+ tname,
+ ", ".join(msg),
)
+ modify_ops.ops.append(ops.DropIndexOp.from_index(old.const))
+ modify_ops.ops.append(ops.CreateIndexOp.from_index(new.const))
else:
if autogen_context.run_filters(
- new.const, new.name,
- "unique_constraint", False, old.const):
- log.info("Detected changed unique constraint '%s' on '%s':%s",
- old.name, tname, ', '.join(msg)
- )
+ new.const, new.name, "unique_constraint", False, old.const
+ ):
+ log.info(
+ "Detected changed unique constraint '%s' on '%s':%s",
+ old.name,
+ tname,
+ ", ".join(msg),
+ )
modify_ops.ops.append(
ops.DropConstraintOp.from_constraint(old.const)
)
@@ -608,13 +676,14 @@ def _compare_indexes_and_uniques(
else:
msg = []
if conn_obj.is_unique != metadata_obj.is_unique:
- msg.append(' unique=%r to unique=%r' % (
- conn_obj.is_unique, metadata_obj.is_unique
- ))
+ msg.append(
+ " unique=%r to unique=%r"
+ % (conn_obj.is_unique, metadata_obj.is_unique)
+ )
if conn_obj.sig != metadata_obj.sig:
- msg.append(' columns %r to %r' % (
- conn_obj.sig, metadata_obj.sig
- ))
+ msg.append(
+ " columns %r to %r" % (conn_obj.sig, metadata_obj.sig)
+ )
if msg:
obj_changed(conn_obj, metadata_obj, msg)
@@ -624,8 +693,10 @@ def _compare_indexes_and_uniques(
if not conn_obj.is_index and conn_obj.sig in unnamed_metadata_uniques:
continue
elif removed_name in doubled_constraints:
- if conn_obj.sig not in metadata_indexes_by_sig and \
- conn_obj.sig not in metadata_uniques_by_sig:
+ if (
+ conn_obj.sig not in metadata_indexes_by_sig
+ and conn_obj.sig not in metadata_uniques_by_sig
+ ):
conn_uq, conn_idx = doubled_constraints[removed_name]
obj_removed(conn_uq)
obj_removed(conn_idx)
@@ -639,40 +710,51 @@ def _compare_indexes_and_uniques(
def _correct_for_uq_duplicates_uix(
conn_unique_constraints,
- conn_indexes,
- metadata_unique_constraints,
- metadata_indexes):
+ conn_indexes,
+ metadata_unique_constraints,
+ metadata_indexes,
+):
# dedupe unique indexes vs. constraints, since MySQL / Oracle
# doesn't really have unique constraints as a separate construct.
# but look in the metadata and try to maintain constructs
# that already seem to be defined one way or the other
# on that side. This logic was formerly local to MySQL dialect,
# generalized to Oracle and others. See #276
- metadata_uq_names = set([
- cons.name for cons in metadata_unique_constraints
- if cons.name is not None])
-
- unnamed_metadata_uqs = set([
- _uq_constraint_sig(cons).sig
- for cons in metadata_unique_constraints
- if cons.name is None
- ])
-
- metadata_ix_names = set([
- cons.name for cons in metadata_indexes if cons.unique])
+ metadata_uq_names = set(
+ [
+ cons.name
+ for cons in metadata_unique_constraints
+ if cons.name is not None
+ ]
+ )
+
+ unnamed_metadata_uqs = set(
+ [
+ _uq_constraint_sig(cons).sig
+ for cons in metadata_unique_constraints
+ if cons.name is None
+ ]
+ )
+
+ metadata_ix_names = set(
+ [cons.name for cons in metadata_indexes if cons.unique]
+ )
conn_ix_names = dict(
(cons.name, cons) for cons in conn_indexes if cons.unique
)
uqs_dupe_indexes = dict(
- (cons.name, cons) for cons in conn_unique_constraints
- if cons.info['duplicates_index']
+ (cons.name, cons)
+ for cons in conn_unique_constraints
+ if cons.info["duplicates_index"]
)
for overlap in uqs_dupe_indexes:
if overlap not in metadata_uq_names:
- if _uq_constraint_sig(uqs_dupe_indexes[overlap]).sig \
- not in unnamed_metadata_uqs:
+ if (
+ _uq_constraint_sig(uqs_dupe_indexes[overlap]).sig
+ not in unnamed_metadata_uqs
+ ):
conn_unique_constraints.discard(uqs_dupe_indexes[overlap])
elif overlap not in metadata_ix_names:
@@ -681,8 +763,14 @@ def _correct_for_uq_duplicates_uix(
@comparators.dispatch_for("column")
def _compare_nullable(
- autogen_context, alter_column_op, schema, tname, cname, conn_col,
- metadata_col):
+ autogen_context,
+ alter_column_op,
+ schema,
+ tname,
+ cname,
+ conn_col,
+ metadata_col,
+):
# work around SQLAlchemy issue #3023
if metadata_col.primary_key:
@@ -694,57 +782,83 @@ def _compare_nullable(
if conn_col_nullable is not metadata_col_nullable:
alter_column_op.modify_nullable = metadata_col_nullable
- log.info("Detected %s on column '%s.%s'",
- "NULL" if metadata_col_nullable else "NOT NULL",
- tname,
- cname
- )
+ log.info(
+ "Detected %s on column '%s.%s'",
+ "NULL" if metadata_col_nullable else "NOT NULL",
+ tname,
+ cname,
+ )
@comparators.dispatch_for("column")
def _setup_autoincrement(
- autogen_context, alter_column_op, schema, tname, cname, conn_col,
- metadata_col):
+ autogen_context,
+ alter_column_op,
+ schema,
+ tname,
+ cname,
+ conn_col,
+ metadata_col,
+):
if metadata_col.table._autoincrement_column is metadata_col:
- alter_column_op.kw['autoincrement'] = True
+ alter_column_op.kw["autoincrement"] = True
elif util.sqla_110 and metadata_col.autoincrement is True:
- alter_column_op.kw['autoincrement'] = True
+ alter_column_op.kw["autoincrement"] = True
elif metadata_col.autoincrement is False:
- alter_column_op.kw['autoincrement'] = False
+ alter_column_op.kw["autoincrement"] = False
@comparators.dispatch_for("column")
def _compare_type(
- autogen_context, alter_column_op, schema, tname, cname, conn_col,
- metadata_col):
+ autogen_context,
+ alter_column_op,
+ schema,
+ tname,
+ cname,
+ conn_col,
+ metadata_col,
+):
conn_type = conn_col.type
alter_column_op.existing_type = conn_type
metadata_type = metadata_col.type
if conn_type._type_affinity is sqltypes.NullType:
- log.info("Couldn't determine database type "
- "for column '%s.%s'", tname, cname)
+ log.info(
+ "Couldn't determine database type " "for column '%s.%s'",
+ tname,
+ cname,
+ )
return
if metadata_type._type_affinity is sqltypes.NullType:
- log.info("Column '%s.%s' has no type within "
- "the model; can't compare", tname, cname)
+ log.info(
+ "Column '%s.%s' has no type within " "the model; can't compare",
+ tname,
+ cname,
+ )
return
isdiff = autogen_context.migration_context._compare_type(
- conn_col, metadata_col)
+ conn_col, metadata_col
+ )
if isdiff:
alter_column_op.modify_type = metadata_type
- log.info("Detected type change from %r to %r on '%s.%s'",
- conn_type, metadata_type, tname, cname
- )
+ log.info(
+ "Detected type change from %r to %r on '%s.%s'",
+ conn_type,
+ metadata_type,
+ tname,
+ cname,
+ )
-def _render_server_default_for_compare(metadata_default,
- metadata_col, autogen_context):
+def _render_server_default_for_compare(
+ metadata_default, metadata_col, autogen_context
+):
rendered = _user_defined_render(
- "server_default", metadata_default, autogen_context)
+ "server_default", metadata_default, autogen_context
+ )
if rendered is not False:
return rendered
@@ -752,8 +866,9 @@ def _render_server_default_for_compare(metadata_default,
if isinstance(metadata_default.arg, compat.string_types):
metadata_default = metadata_default.arg
else:
- metadata_default = str(metadata_default.arg.compile(
- dialect=autogen_context.dialect))
+ metadata_default = str(
+ metadata_default.arg.compile(dialect=autogen_context.dialect)
+ )
if isinstance(metadata_default, compat.string_types):
if metadata_col.type._type_affinity is sqltypes.String:
metadata_default = re.sub(r"^'|'$", "", metadata_default)
@@ -766,37 +881,49 @@ def _render_server_default_for_compare(metadata_default,
@comparators.dispatch_for("column")
def _compare_server_default(
- autogen_context, alter_column_op, schema, tname, cname,
- conn_col, metadata_col):
+ autogen_context,
+ alter_column_op,
+ schema,
+ tname,
+ cname,
+ conn_col,
+ metadata_col,
+):
metadata_default = metadata_col.server_default
conn_col_default = conn_col.server_default
if conn_col_default is None and metadata_default is None:
return False
rendered_metadata_default = _render_server_default_for_compare(
- metadata_default, metadata_col, autogen_context)
+ metadata_default, metadata_col, autogen_context
+ )
- rendered_conn_default = conn_col.server_default.arg.text \
- if conn_col.server_default else None
+ rendered_conn_default = (
+ conn_col.server_default.arg.text if conn_col.server_default else None
+ )
alter_column_op.existing_server_default = conn_col_default
isdiff = autogen_context.migration_context._compare_server_default(
- conn_col, metadata_col,
+ conn_col,
+ metadata_col,
rendered_metadata_default,
- rendered_conn_default
+ rendered_conn_default,
)
if isdiff:
alter_column_op.modify_server_default = metadata_default
- log.info(
- "Detected server default on column '%s.%s'",
- tname, cname)
+ log.info("Detected server default on column '%s.%s'", tname, cname)
@comparators.dispatch_for("table")
def _compare_foreign_keys(
- autogen_context, modify_table_ops, schema, tname, conn_table,
- metadata_table):
+ autogen_context,
+ modify_table_ops,
+ schema,
+ tname,
+ conn_table,
+ metadata_table,
+):
# if we're doing CREATE TABLE, all FKs are created
# inline within the table def
@@ -805,22 +932,22 @@ def _compare_foreign_keys(
inspector = autogen_context.inspector
metadata_fks = set(
- fk for fk in metadata_table.constraints
+ fk
+ for fk in metadata_table.constraints
if isinstance(fk, sa_schema.ForeignKeyConstraint)
)
conn_fks = inspector.get_foreign_keys(tname, schema=schema)
- backend_reflects_fk_options = conn_fks and 'options' in conn_fks[0]
+ backend_reflects_fk_options = conn_fks and "options" in conn_fks[0]
conn_fks = set(_make_foreign_key(const, conn_table) for const in conn_fks)
# give the dialect a chance to correct the FKs to match more
# closely
- autogen_context.migration_context.impl.\
- correct_for_autogen_foreignkeys(
- conn_fks, metadata_fks,
- )
+ autogen_context.migration_context.impl.correct_for_autogen_foreignkeys(
+ conn_fks, metadata_fks
+ )
metadata_fks = set(
_fk_constraint_sig(fk, include_options=backend_reflects_fk_options)
@@ -832,12 +959,8 @@ def _compare_foreign_keys(
for fk in conn_fks
)
- conn_fks_by_sig = dict(
- (c.sig, c) for c in conn_fks
- )
- metadata_fks_by_sig = dict(
- (c.sig, c) for c in metadata_fks
- )
+ conn_fks_by_sig = dict((c.sig, c) for c in conn_fks)
+ metadata_fks_by_sig = dict((c.sig, c) for c in metadata_fks)
metadata_fks_by_name = dict(
(c.name, c) for c in metadata_fks if c.name is not None
@@ -848,8 +971,8 @@ def _compare_foreign_keys(
def _add_fk(obj, compare_to):
if autogen_context.run_filters(
- obj.const, obj.name, "foreign_key_constraint", False,
- compare_to):
+ obj.const, obj.name, "foreign_key_constraint", False, compare_to
+ ):
modify_table_ops.ops.append(
ops.CreateForeignKeyOp.from_constraint(const.const)
)
@@ -859,12 +982,13 @@ def _compare_foreign_keys(
", ".join(obj.source_columns),
", ".join(obj.target_columns),
"%s." % obj.source_schema if obj.source_schema else "",
- obj.source_table)
+ obj.source_table,
+ )
def _remove_fk(obj, compare_to):
if autogen_context.run_filters(
- obj.const, obj.name, "foreign_key_constraint", True,
- compare_to):
+ obj.const, obj.name, "foreign_key_constraint", True, compare_to
+ ):
modify_table_ops.ops.append(
ops.DropConstraintOp.from_constraint(obj.const)
)
@@ -873,7 +997,8 @@ def _compare_foreign_keys(
", ".join(obj.source_columns),
", ".join(obj.target_columns),
"%s." % obj.source_schema if obj.source_schema else "",
- obj.source_table)
+ obj.source_table,
+ )
# so far it appears we don't need to do this by name at all.
# SQLite doesn't preserve constraint names anyway
@@ -881,13 +1006,19 @@ def _compare_foreign_keys(
for removed_sig in set(conn_fks_by_sig).difference(metadata_fks_by_sig):
const = conn_fks_by_sig[removed_sig]
if removed_sig not in metadata_fks_by_sig:
- compare_to = metadata_fks_by_name[const.name].const \
- if const.name in metadata_fks_by_name else None
+ compare_to = (
+ metadata_fks_by_name[const.name].const
+ if const.name in metadata_fks_by_name
+ else None
+ )
_remove_fk(const, compare_to)
for added_sig in set(metadata_fks_by_sig).difference(conn_fks_by_sig):
const = metadata_fks_by_sig[added_sig]
if added_sig not in conn_fks_by_sig:
- compare_to = conn_fks_by_name[const.name].const \
- if const.name in conn_fks_by_name else None
+ compare_to = (
+ conn_fks_by_name[const.name].const
+ if const.name in conn_fks_by_name
+ else None
+ )
_add_fk(const, compare_to)
diff --git a/alembic/autogenerate/render.py b/alembic/autogenerate/render.py
index 4fbe91f..573ee02 100644
--- a/alembic/autogenerate/render.py
+++ b/alembic/autogenerate/render.py
@@ -19,29 +19,35 @@ try:
return _f_name(_alembic_autogenerate_prefix(autogen_context), name)
else:
return name
+
+
except ImportError:
+
def _render_gen_name(autogen_context, name):
return name
def _indent(text):
- text = re.compile(r'^', re.M).sub(" ", text).strip()
- text = re.compile(r' +$', re.M).sub("", text)
+ text = re.compile(r"^", re.M).sub(" ", text).strip()
+ text = re.compile(r" +$", re.M).sub("", text)
return text
def _render_python_into_templatevars(
- autogen_context, migration_script, template_args):
+ autogen_context, migration_script, template_args
+):
imports = autogen_context.imports
for upgrade_ops, downgrade_ops in zip(
- migration_script.upgrade_ops_list,
- migration_script.downgrade_ops_list):
+ migration_script.upgrade_ops_list, migration_script.downgrade_ops_list
+ ):
template_args[upgrade_ops.upgrade_token] = _indent(
- _render_cmd_body(upgrade_ops, autogen_context))
+ _render_cmd_body(upgrade_ops, autogen_context)
+ )
template_args[downgrade_ops.downgrade_token] = _indent(
- _render_cmd_body(downgrade_ops, autogen_context))
- template_args['imports'] = "\n".join(sorted(imports))
+ _render_cmd_body(downgrade_ops, autogen_context)
+ )
+ template_args["imports"] = "\n".join(sorted(imports))
default_renderers = renderers = util.Dispatcher()
@@ -83,7 +89,7 @@ def render_op_text(autogen_context, op):
@renderers.dispatch_for(ops.ModifyTableOps)
def _render_modify_table(autogen_context, op):
opts = autogen_context.opts
- render_as_batch = opts.get('render_as_batch', False)
+ render_as_batch = opts.get("render_as_batch", False)
if op.ops:
lines = []
@@ -104,33 +110,39 @@ def _render_modify_table(autogen_context, op):
return lines
else:
- return [
- "pass"
- ]
+ return ["pass"]
@renderers.dispatch_for(ops.CreateTableOp)
def _add_table(autogen_context, op):
table = op.to_table()
- args = [col for col in
- [_render_column(col, autogen_context) for col in table.columns]
- if col] + \
- sorted([rcons for rcons in
- [_render_constraint(cons, autogen_context) for cons in
- table.constraints]
- if rcons is not None
- ])
+ args = [
+ col
+ for col in [
+ _render_column(col, autogen_context) for col in table.columns
+ ]
+ if col
+ ] + sorted(
+ [
+ rcons
+ for rcons in [
+ _render_constraint(cons, autogen_context)
+ for cons in table.constraints
+ ]
+ if rcons is not None
+ ]
+ )
if len(args) > MAX_PYTHON_ARGS:
- args = '*[' + ',\n'.join(args) + ']'
+ args = "*[" + ",\n".join(args) + "]"
else:
- args = ',\n'.join(args)
+ args = ",\n".join(args)
text = "%(prefix)screate_table(%(tablename)r,\n%(args)s" % {
- 'tablename': _ident(op.table_name),
- 'prefix': _alembic_autogenerate_prefix(autogen_context),
- 'args': args,
+ "tablename": _ident(op.table_name),
+ "prefix": _alembic_autogenerate_prefix(autogen_context),
+ "args": args,
}
if op.schema:
text += ",\nschema=%r" % _ident(op.schema)
@@ -144,7 +156,7 @@ def _add_table(autogen_context, op):
def _drop_table(autogen_context, op):
text = "%(prefix)sdrop_table(%(tname)r" % {
"prefix": _alembic_autogenerate_prefix(autogen_context),
- "tname": _ident(op.table_name)
+ "tname": _ident(op.table_name),
}
if op.schema:
text += ", schema=%r" % _ident(op.schema)
@@ -159,28 +171,39 @@ def _add_index(autogen_context, op):
has_batch = autogen_context._has_batch
if has_batch:
- tmpl = "%(prefix)screate_index(%(name)r, [%(columns)s], "\
+ tmpl = (
+ "%(prefix)screate_index(%(name)r, [%(columns)s], "
"unique=%(unique)r%(kwargs)s)"
+ )
else:
- tmpl = "%(prefix)screate_index(%(name)r, %(table)r, [%(columns)s], "\
+ tmpl = (
+ "%(prefix)screate_index(%(name)r, %(table)r, [%(columns)s], "
"unique=%(unique)r%(schema)s%(kwargs)s)"
+ )
text = tmpl % {
- 'prefix': _alembic_autogenerate_prefix(autogen_context),
- 'name': _render_gen_name(autogen_context, index.name),
- 'table': _ident(index.table.name),
- 'columns': ", ".join(
- _get_index_rendered_expressions(index, autogen_context)),
- 'unique': index.unique or False,
- 'schema': (", schema=%r" % _ident(index.table.schema))
- if index.table.schema else '',
- 'kwargs': (
- ', ' +
- ', '.join(
- ["%s=%s" %
- (key, _render_potential_expr(val, autogen_context))
- for key, val in index.kwargs.items()]))
- if len(index.kwargs) else ''
+ "prefix": _alembic_autogenerate_prefix(autogen_context),
+ "name": _render_gen_name(autogen_context, index.name),
+ "table": _ident(index.table.name),
+ "columns": ", ".join(
+ _get_index_rendered_expressions(index, autogen_context)
+ ),
+ "unique": index.unique or False,
+ "schema": (", schema=%r" % _ident(index.table.schema))
+ if index.table.schema
+ else "",
+ "kwargs": (
+ ", "
+ + ", ".join(
+ [
+ "%s=%s"
+ % (key, _render_potential_expr(val, autogen_context))
+ for key, val in index.kwargs.items()
+ ]
+ )
+ )
+ if len(index.kwargs)
+ else "",
}
return text
@@ -192,15 +215,16 @@ def _drop_index(autogen_context, op):
if has_batch:
tmpl = "%(prefix)sdrop_index(%(name)r)"
else:
- tmpl = "%(prefix)sdrop_index(%(name)r, "\
+ tmpl = (
+ "%(prefix)sdrop_index(%(name)r, "
"table_name=%(table_name)r%(schema)s)"
+ )
text = tmpl % {
- 'prefix': _alembic_autogenerate_prefix(autogen_context),
- 'name': _render_gen_name(autogen_context, op.index_name),
- 'table_name': _ident(op.table_name),
- 'schema': ((", schema=%r" % _ident(op.schema))
- if op.schema else '')
+ "prefix": _alembic_autogenerate_prefix(autogen_context),
+ "name": _render_gen_name(autogen_context, op.index_name),
+ "table_name": _ident(op.table_name),
+ "schema": ((", schema=%r" % _ident(op.schema)) if op.schema else ""),
}
return text
@@ -213,30 +237,28 @@ def _add_unique_constraint(autogen_context, op):
@renderers.dispatch_for(ops.CreateForeignKeyOp)
def _add_fk_constraint(autogen_context, op):
- args = [
- repr(
- _render_gen_name(autogen_context, op.constraint_name)),
- ]
+ args = [repr(_render_gen_name(autogen_context, op.constraint_name))]
if not autogen_context._has_batch:
- args.append(
- repr(_ident(op.source_table))
- )
+ args.append(repr(_ident(op.source_table)))
args.extend(
[
repr(_ident(op.referent_table)),
repr([_ident(col) for col in op.local_cols]),
- repr([_ident(col) for col in op.remote_cols])
+ repr([_ident(col) for col in op.remote_cols]),
]
)
kwargs = [
- 'referent_schema',
- 'onupdate', 'ondelete', 'initially',
- 'deferrable', 'use_alter'
+ "referent_schema",
+ "onupdate",
+ "ondelete",
+ "initially",
+ "deferrable",
+ "use_alter",
]
if not autogen_context._has_batch:
- kwargs.insert(0, 'source_schema')
+ kwargs.insert(0, "source_schema")
for k in kwargs:
if k in op.kw:
@@ -245,8 +267,8 @@ def _add_fk_constraint(autogen_context, op):
args.append("%s=%r" % (k, value))
return "%(prefix)screate_foreign_key(%(args)s)" % {
- 'prefix': _alembic_autogenerate_prefix(autogen_context),
- 'args': ", ".join(args)
+ "prefix": _alembic_autogenerate_prefix(autogen_context),
+ "args": ", ".join(args),
}
@@ -264,20 +286,19 @@ def _add_check_constraint(constraint, autogen_context):
def _drop_constraint(autogen_context, op):
if autogen_context._has_batch:
- template = "%(prefix)sdrop_constraint"\
- "(%(name)r, type_=%(type)r)"
+ template = "%(prefix)sdrop_constraint" "(%(name)r, type_=%(type)r)"
else:
- template = "%(prefix)sdrop_constraint"\
+ template = (
+ "%(prefix)sdrop_constraint"
"(%(name)r, '%(table_name)s'%(schema)s, type_=%(type)r)"
+ )
text = template % {
- 'prefix': _alembic_autogenerate_prefix(autogen_context),
- 'name': _render_gen_name(
- autogen_context, op.constraint_name),
- 'table_name': _ident(op.table_name),
- 'type': op.constraint_type,
- 'schema': (", schema=%r" % _ident(op.schema))
- if op.schema else '',
+ "prefix": _alembic_autogenerate_prefix(autogen_context),
+ "name": _render_gen_name(autogen_context, op.constraint_name),
+ "table_name": _ident(op.table_name),
+ "type": op.constraint_type,
+ "schema": (", schema=%r" % _ident(op.schema)) if op.schema else "",
}
return text
@@ -297,7 +318,7 @@ def _add_column(autogen_context, op):
"prefix": _alembic_autogenerate_prefix(autogen_context),
"tname": tname,
"column": _render_column(column, autogen_context),
- "schema": schema
+ "schema": schema,
}
return text
@@ -319,7 +340,7 @@ def _drop_column(autogen_context, op):
"prefix": _alembic_autogenerate_prefix(autogen_context),
"tname": _ident(tname),
"cname": _ident(column_name),
- "schema": _ident(schema)
+ "schema": _ident(schema),
}
return text
@@ -332,7 +353,7 @@ def _alter_column(autogen_context, op):
server_default = op.modify_server_default
type_ = op.modify_type
nullable = op.modify_nullable
- autoincrement = op.kw.get('autoincrement', None)
+ autoincrement = op.kw.get("autoincrement", None)
existing_type = op.existing_type
existing_nullable = op.existing_nullable
existing_server_default = op.existing_server_default
@@ -346,37 +367,32 @@ def _alter_column(autogen_context, op):
template = "%(prefix)salter_column(%(tname)r, %(cname)r"
text = template % {
- 'prefix': _alembic_autogenerate_prefix(
- autogen_context),
- 'tname': tname,
- 'cname': cname}
+ "prefix": _alembic_autogenerate_prefix(autogen_context),
+ "tname": tname,
+ "cname": cname,
+ }
if existing_type is not None:
text += ",\n%sexisting_type=%s" % (
indent,
- _repr_type(existing_type, autogen_context))
+ _repr_type(existing_type, autogen_context),
+ )
if server_default is not False:
- rendered = _render_server_default(
- server_default, autogen_context)
+ rendered = _render_server_default(server_default, autogen_context)
text += ",\n%sserver_default=%s" % (indent, rendered)
if type_ is not None:
- text += ",\n%stype_=%s" % (indent,
- _repr_type(type_, autogen_context))
+ text += ",\n%stype_=%s" % (indent, _repr_type(type_, autogen_context))
if nullable is not None:
- text += ",\n%snullable=%r" % (
- indent, nullable,)
+ text += ",\n%snullable=%r" % (indent, nullable)
if nullable is None and existing_nullable is not None:
- text += ",\n%sexisting_nullable=%r" % (
- indent, existing_nullable)
+ text += ",\n%sexisting_nullable=%r" % (indent, existing_nullable)
if autoincrement is not None:
- text += ",\n%sautoincrement=%r" % (
- indent, autoincrement)
+ text += ",\n%sautoincrement=%r" % (indent, autoincrement)
if server_default is False and existing_server_default:
rendered = _render_server_default(
- existing_server_default,
- autogen_context)
- text += ",\n%sexisting_server_default=%s" % (
- indent, rendered)
+ existing_server_default, autogen_context
+ )
+ text += ",\n%sexisting_server_default=%s" % (indent, rendered)
if schema and not autogen_context._has_batch:
text += ",\n%sschema=%r" % (indent, schema)
text += ")"
@@ -384,7 +400,6 @@ def _alter_column(autogen_context, op):
class _f_name(object):
-
def __init__(self, prefix, name):
self.prefix = prefix
self.name = name
@@ -410,7 +425,7 @@ def _ident(name):
# u'' literals only when py2k + SQLA 0.9, in particular
# makes unit tests testing code generation very difficult
try:
- return name.encode('ascii')
+ return name.encode("ascii")
except UnicodeError:
return compat.text_type(name)
else:
@@ -421,8 +436,9 @@ def _ident(name):
def _render_potential_expr(value, autogen_context, wrap_in_text=True):
if isinstance(value, sql.ClauseElement):
- compile_kw = dict(compile_kwargs={
- 'literal_binds': True, "include_table": False})
+ compile_kw = dict(
+ compile_kwargs={"literal_binds": True, "include_table": False}
+ )
if wrap_in_text:
template = "%(prefix)stext(%(sql)r)"
@@ -432,9 +448,8 @@ def _render_potential_expr(value, autogen_context, wrap_in_text=True):
return template % {
"prefix": _sqlalchemy_autogenerate_prefix(autogen_context),
"sql": compat.text_type(
- value.compile(dialect=autogen_context.dialect,
- **compile_kw)
- )
+ value.compile(dialect=autogen_context.dialect, **compile_kw)
+ ),
}
else:
@@ -442,10 +457,12 @@ def _render_potential_expr(value, autogen_context, wrap_in_text=True):
def _get_index_rendered_expressions(idx, autogen_context):
- return [repr(_ident(getattr(exp, "name", None)))
- if isinstance(exp, sa_schema.Column)
- else _render_potential_expr(exp, autogen_context)
- for exp in idx.expressions]
+ return [
+ repr(_ident(getattr(exp, "name", None)))
+ if isinstance(exp, sa_schema.Column)
+ else _render_potential_expr(exp, autogen_context)
+ for exp in idx.expressions
+ ]
def _uq_constraint(constraint, autogen_context, alter):
@@ -461,32 +478,30 @@ def _uq_constraint(constraint, autogen_context, alter):
opts.append(("schema", _ident(constraint.table.schema)))
if not alter and constraint.name:
opts.append(
- ("name",
- _render_gen_name(autogen_context, constraint.name)))
+ ("name", _render_gen_name(autogen_context, constraint.name))
+ )
if alter:
- args = [
- repr(_render_gen_name(
- autogen_context, constraint.name))]
+ args = [repr(_render_gen_name(autogen_context, constraint.name))]
if not has_batch:
args += [repr(_ident(constraint.table.name))]
args.append(repr([_ident(col.name) for col in constraint.columns]))
args.extend(["%s=%r" % (k, v) for k, v in opts])
return "%(prefix)screate_unique_constraint(%(args)s)" % {
- 'prefix': _alembic_autogenerate_prefix(autogen_context),
- 'args': ", ".join(args)
+ "prefix": _alembic_autogenerate_prefix(autogen_context),
+ "args": ", ".join(args),
}
else:
args = [repr(_ident(col.name)) for col in constraint.columns]
args.extend(["%s=%r" % (k, v) for k, v in opts])
return "%(prefix)sUniqueConstraint(%(args)s)" % {
"prefix": _sqlalchemy_autogenerate_prefix(autogen_context),
- "args": ", ".join(args)
+ "args": ", ".join(args),
}
def _user_autogenerate_prefix(autogen_context, target):
- prefix = autogen_context.opts['user_module_prefix']
+ prefix = autogen_context.opts["user_module_prefix"]
if prefix is None:
return "%s." % target.__module__
else:
@@ -494,19 +509,19 @@ def _user_autogenerate_prefix(autogen_context, target):
def _sqlalchemy_autogenerate_prefix(autogen_context):
- return autogen_context.opts['sqlalchemy_module_prefix'] or ''
+ return autogen_context.opts["sqlalchemy_module_prefix"] or ""
def _alembic_autogenerate_prefix(autogen_context):
if autogen_context._has_batch:
- return 'batch_op.'
+ return "batch_op."
else:
- return autogen_context.opts['alembic_module_prefix'] or ''
+ return autogen_context.opts["alembic_module_prefix"] or ""
def _user_defined_render(type_, object_, autogen_context):
- if 'render_item' in autogen_context.opts:
- render = autogen_context.opts['render_item']
+ if "render_item" in autogen_context.opts:
+ render = autogen_context.opts["render_item"]
if render:
rendered = render(type_, object_, autogen_context)
if rendered is not False:
@@ -527,8 +542,10 @@ def _render_column(column, autogen_context):
if rendered:
opts.append(("server_default", rendered))
- if column.autoincrement is not None and \
- column.autoincrement != sqla_compat.AUTOINCREMENT_DEFAULT:
+ if (
+ column.autoincrement is not None
+ and column.autoincrement != sqla_compat.AUTOINCREMENT_DEFAULT
+ ):
opts.append(("autoincrement", column.autoincrement))
if column.nullable is not None:
@@ -539,10 +556,10 @@ def _render_column(column, autogen_context):
# TODO: for non-ascii colname, assign a "key"
return "%(prefix)sColumn(%(name)r, %(type)s, %(kw)s)" % {
- 'prefix': _sqlalchemy_autogenerate_prefix(autogen_context),
- 'name': _ident(column.name),
- 'type': _repr_type(column.type, autogen_context),
- 'kw': ", ".join(["%s=%s" % (kwname, val) for kwname, val in opts])
+ "prefix": _sqlalchemy_autogenerate_prefix(autogen_context),
+ "name": _ident(column.name),
+ "type": _repr_type(column.type, autogen_context),
+ "kw": ", ".join(["%s=%s" % (kwname, val) for kwname, val in opts]),
}
@@ -568,9 +585,10 @@ def _repr_type(type_, autogen_context):
if rendered is not False:
return rendered
- if hasattr(autogen_context.migration_context, 'impl'):
+ if hasattr(autogen_context.migration_context, "impl"):
impl_rt = autogen_context.migration_context.impl.render_type(
- type_, autogen_context)
+ type_, autogen_context
+ )
else:
impl_rt = None
@@ -587,8 +605,8 @@ def _repr_type(type_, autogen_context):
elif impl_rt:
return impl_rt
elif mod.startswith("sqlalchemy."):
- if '_render_%s_type' % type_.__visit_name__ in globals():
- fn = globals()['_render_%s_type' % type_.__visit_name__]
+ if "_render_%s_type" % type_.__visit_name__ in globals():
+ fn = globals()["_render_%s_type" % type_.__visit_name__]
return fn(type_, autogen_context)
else:
prefix = _sqlalchemy_autogenerate_prefix(autogen_context)
@@ -600,12 +618,13 @@ def _repr_type(type_, autogen_context):
def _render_ARRAY_type(type_, autogen_context):
return _render_type_w_subtype(
- type_, autogen_context, 'item_type', r'(.+?\()'
+ type_, autogen_context, "item_type", r"(.+?\()"
)
def _render_type_w_subtype(
- type_, autogen_context, attrname, regexp, prefix=None):
+ type_, autogen_context, attrname, regexp, prefix=None
+):
outer_repr = repr(type_)
inner_type = getattr(type_, attrname, None)
if inner_type is None:
@@ -613,11 +632,9 @@ def _render_type_w_subtype(
inner_repr = repr(inner_type)
- inner_repr = re.sub(r'([\(\)])', r'\\\1', inner_repr)
+ inner_repr = re.sub(r"([\(\)])", r"\\\1", inner_repr)
sub_type = _repr_type(getattr(type_, attrname), autogen_context)
- outer_type = re.sub(
- regexp + inner_repr,
- r"\1%s" % sub_type, outer_repr)
+ outer_type = re.sub(regexp + inner_repr, r"\1%s" % sub_type, outer_repr)
if prefix:
return "%s%s" % (prefix, outer_type)
@@ -632,6 +649,7 @@ def _render_type_w_subtype(
else:
return None
+
_constraint_renderers = util.Dispatcher()
@@ -656,13 +674,14 @@ def _render_primary_key(constraint, autogen_context):
opts = []
if constraint.name:
- opts.append(("name", repr(
- _render_gen_name(autogen_context, constraint.name))))
+ opts.append(
+ ("name", repr(_render_gen_name(autogen_context, constraint.name)))
+ )
return "%(prefix)sPrimaryKeyConstraint(%(args)s)" % {
"prefix": _sqlalchemy_autogenerate_prefix(autogen_context),
"args": ", ".join(
- [repr(c.name) for c in constraint.columns] +
- ["%s=%s" % (kwname, val) for kwname, val in opts]
+ [repr(c.name) for c in constraint.columns]
+ + ["%s=%s" % (kwname, val) for kwname, val in opts]
),
}
@@ -681,8 +700,11 @@ def _fk_colspec(fk, metadata_schema):
else:
table_fullname = ".".join(tokens[0:-1])
- if not fk.link_to_name and \
- fk.parent is not None and fk.parent.table is not None:
+ if (
+ not fk.link_to_name
+ and fk.parent is not None
+ and fk.parent.table is not None
+ ):
# try to resolve the remote table in order to adjust for column.key.
# the FK constraint needs to be rendered in terms of the column
# name.
@@ -719,23 +741,30 @@ def _render_foreign_key(constraint, autogen_context):
opts = []
if constraint.name:
- opts.append(("name", repr(
- _render_gen_name(autogen_context, constraint.name))))
+ opts.append(
+ ("name", repr(_render_gen_name(autogen_context, constraint.name)))
+ )
_populate_render_fk_opts(constraint, opts)
apply_metadata_schema = constraint.parent.metadata.schema
- return "%(prefix)sForeignKeyConstraint([%(cols)s], "\
- "[%(refcols)s], %(args)s)" % {
+ return (
+ "%(prefix)sForeignKeyConstraint([%(cols)s], "
+ "[%(refcols)s], %(args)s)"
+ % {
"prefix": _sqlalchemy_autogenerate_prefix(autogen_context),
"cols": ", ".join(
- "%r" % _ident(f.parent.name) for f in constraint.elements),
- "refcols": ", ".join(repr(_fk_colspec(f, apply_metadata_schema))
- for f in constraint.elements),
+ "%r" % _ident(f.parent.name) for f in constraint.elements
+ ),
+ "refcols": ", ".join(
+ repr(_fk_colspec(f, apply_metadata_schema))
+ for f in constraint.elements
+ ),
"args": ", ".join(
- ["%s=%s" % (kwname, val) for kwname, val in opts]
+ ["%s=%s" % (kwname, val) for kwname, val in opts]
),
}
+ )
@_constraint_renderers.dispatch_for(sa_schema.UniqueConstraint)
@@ -757,27 +786,25 @@ def _render_check_constraint(constraint, autogen_context):
# a parent type which is probably in the Table already.
# ideally SQLAlchemy would give us more of a first class
# way to detect this.
- if constraint._create_rule and \
- hasattr(constraint._create_rule, 'target') and \
- isinstance(constraint._create_rule.target,
- sqltypes.TypeEngine):
+ if (
+ constraint._create_rule
+ and hasattr(constraint._create_rule, "target")
+ and isinstance(constraint._create_rule.target, sqltypes.TypeEngine)
+ ):
return None
opts = []
if constraint.name:
opts.append(
- (
- "name",
- repr(
- _render_gen_name(
- autogen_context, constraint.name))
- )
+ ("name", repr(_render_gen_name(autogen_context, constraint.name)))
)
return "%(prefix)sCheckConstraint(%(sqltext)s%(opts)s)" % {
"prefix": _sqlalchemy_autogenerate_prefix(autogen_context),
- "opts": ", " + (", ".join("%s=%s" % (k, v)
- for k, v in opts)) if opts else "",
+ "opts": ", " + (", ".join("%s=%s" % (k, v) for k, v in opts))
+ if opts
+ else "",
"sqltext": _render_potential_expr(
- constraint.sqltext, autogen_context, wrap_in_text=False)
+ constraint.sqltext, autogen_context, wrap_in_text=False
+ ),
}
@@ -788,7 +815,7 @@ def _execute_sql(autogen_context, op):
"Autogenerate rendering of SQL Expression language constructs "
"not supported here; please use a plain SQL string"
)
- return 'op.execute(%r)' % op.sqltext
+ return "op.execute(%r)" % op.sqltext
renderers = default_renderers.branch()
diff --git a/alembic/autogenerate/rewriter.py b/alembic/autogenerate/rewriter.py
index 941bd4b..1e9522b 100644
--- a/alembic/autogenerate/rewriter.py
+++ b/alembic/autogenerate/rewriter.py
@@ -95,7 +95,8 @@ class Rewriter(object):
yield directive
else:
for r_directive in util.to_list(
- _rewriter(context, revision, directive)):
+ _rewriter(context, revision, directive)
+ ):
yield r_directive
def __call__(self, context, revision, directives):
@@ -110,17 +111,20 @@ class Rewriter(object):
ret = self._traverse_for(context, revision, directive.upgrade_ops)
if len(ret) != 1:
raise ValueError(
- "Can only return single object for UpgradeOps traverse")
+ "Can only return single object for UpgradeOps traverse"
+ )
upgrade_ops_list.append(ret[0])
directive.upgrade_ops = upgrade_ops_list
downgrade_ops_list = []
for downgrade_ops in directive.downgrade_ops_list:
ret = self._traverse_for(
- context, revision, directive.downgrade_ops)
+ context, revision, directive.downgrade_ops
+ )
if len(ret) != 1:
raise ValueError(
- "Can only return single object for DowngradeOps traverse")
+ "Can only return single object for DowngradeOps traverse"
+ )
downgrade_ops_list.append(ret[0])
directive.downgrade_ops = downgrade_ops_list
diff --git a/alembic/command.py b/alembic/command.py
index cd61fd1..20027b4 100644
--- a/alembic/command.py
+++ b/alembic/command.py
@@ -15,10 +15,9 @@ def list_templates(config):
config.print_stdout("Available templates:\n")
for tempname in os.listdir(config.get_template_directory()):
- with open(os.path.join(
- config.get_template_directory(),
- tempname,
- 'README')) as readme:
+ with open(
+ os.path.join(config.get_template_directory(), tempname, "README")
+ ) as readme:
synopsis = next(readme)
config.print_stdout("%s - %s", tempname, synopsis)
@@ -26,7 +25,7 @@ def list_templates(config):
config.print_stdout("\n alembic init --template generic ./scripts")
-def init(config, directory, template='generic'):
+def init(config, directory, template="generic"):
"""Initialize a new scripts directory.
:param config: a :class:`.Config` object.
@@ -41,48 +40,58 @@ def init(config, directory, template='generic'):
if os.access(directory, os.F_OK):
raise util.CommandError("Directory %s already exists" % directory)
- template_dir = os.path.join(config.get_template_directory(),
- template)
+ template_dir = os.path.join(config.get_template_directory(), template)
if not os.access(template_dir, os.F_OK):
raise util.CommandError("No such template %r" % template)
- util.status("Creating directory %s" % os.path.abspath(directory),
- os.makedirs, directory)
+ util.status(
+ "Creating directory %s" % os.path.abspath(directory),
+ os.makedirs,
+ directory,
+ )
- versions = os.path.join(directory, 'versions')
- util.status("Creating directory %s" % os.path.abspath(versions),
- os.makedirs, versions)
+ versions = os.path.join(directory, "versions")
+ util.status(
+ "Creating directory %s" % os.path.abspath(versions),
+ os.makedirs,
+ versions,
+ )
script = ScriptDirectory(directory)
for file_ in os.listdir(template_dir):
file_path = os.path.join(template_dir, file_)
- if file_ == 'alembic.ini.mako':
+ if file_ == "alembic.ini.mako":
config_file = os.path.abspath(config.config_file_name)
if os.access(config_file, os.F_OK):
util.msg("File %s already exists, skipping" % config_file)
else:
script._generate_template(
- file_path,
- config_file,
- script_location=directory
+ file_path, config_file, script_location=directory
)
elif os.path.isfile(file_path):
output_file = os.path.join(directory, file_)
- script._copy_file(
- file_path,
- output_file
- )
+ script._copy_file(file_path, output_file)
- util.msg("Please edit configuration/connection/logging "
- "settings in %r before proceeding." % config_file)
+ util.msg(
+ "Please edit configuration/connection/logging "
+ "settings in %r before proceeding." % config_file
+ )
def revision(
- config, message=None, autogenerate=False, sql=False,
- head="head", splice=False, branch_label=None,
- version_path=None, rev_id=None, depends_on=None,
- process_revision_directives=None):
+ config,
+ message=None,
+ autogenerate=False,
+ sql=False,
+ head="head",
+ splice=False,
+ branch_label=None,
+ version_path=None,
+ rev_id=None,
+ depends_on=None,
+ process_revision_directives=None,
+):
"""Create a new revision file.
:param config: a :class:`.Config` object.
@@ -134,35 +143,46 @@ def revision(
command_args = dict(
message=message,
autogenerate=autogenerate,
- sql=sql, head=head, splice=splice, branch_label=branch_label,
- version_path=version_path, rev_id=rev_id, depends_on=depends_on
+ sql=sql,
+ head=head,
+ splice=splice,
+ branch_label=branch_label,
+ version_path=version_path,
+ rev_id=rev_id,
+ depends_on=depends_on,
)
revision_context = autogen.RevisionContext(
- config, script_directory, command_args,
- process_revision_directives=process_revision_directives)
-
- environment = util.asbool(
- config.get_main_option("revision_environment")
+ config,
+ script_directory,
+ command_args,
+ process_revision_directives=process_revision_directives,
)
+ environment = util.asbool(config.get_main_option("revision_environment"))
+
if autogenerate:
environment = True
if sql:
raise util.CommandError(
- "Using --sql with --autogenerate does not make any sense")
+ "Using --sql with --autogenerate does not make any sense"
+ )
def retrieve_migrations(rev, context):
revision_context.run_autogenerate(rev, context)
return []
+
elif environment:
+
def retrieve_migrations(rev, context):
revision_context.run_no_autogenerate(rev, context)
return []
+
elif sql:
raise util.CommandError(
"Using --sql with the revision command when "
- "revision_environment is not configured does not make any sense")
+ "revision_environment is not configured does not make any sense"
+ )
if environment:
with EnvironmentContext(
@@ -171,14 +191,11 @@ def revision(
fn=retrieve_migrations,
as_sql=sql,
template_args=revision_context.template_args,
- revision_context=revision_context
+ revision_context=revision_context,
):
script_directory.run_env()
- scripts = [
- script for script in
- revision_context.generate_scripts()
- ]
+ scripts = [script for script in revision_context.generate_scripts()]
if len(scripts) == 1:
return scripts[0]
else:
@@ -207,13 +224,17 @@ def merge(config, revisions, message=None, branch_label=None, rev_id=None):
script = ScriptDirectory.from_config(config)
template_args = {
- 'config': config # Let templates use config for
- # e.g. multiple databases
+ "config": config # Let templates use config for
+ # e.g. multiple databases
}
return script.generate_revision(
- rev_id or util.rev_id(), message, refresh=True,
- head=revisions, branch_labels=branch_label,
- **template_args)
+ rev_id or util.rev_id(),
+ message,
+ refresh=True,
+ head=revisions,
+ branch_labels=branch_label,
+ **template_args
+ )
def upgrade(config, revision, sql=False, tag=None):
@@ -237,7 +258,7 @@ def upgrade(config, revision, sql=False, tag=None):
if ":" in revision:
if not sql:
raise util.CommandError("Range revision not allowed")
- starting_rev, revision = revision.split(':', 2)
+ starting_rev, revision = revision.split(":", 2)
def upgrade(rev, context):
return script._upgrade_revs(revision, rev)
@@ -249,7 +270,7 @@ def upgrade(config, revision, sql=False, tag=None):
as_sql=sql,
starting_rev=starting_rev,
destination_rev=revision,
- tag=tag
+ tag=tag,
):
script.run_env()
@@ -274,10 +295,11 @@ def downgrade(config, revision, sql=False, tag=None):
if ":" in revision:
if not sql:
raise util.CommandError("Range revision not allowed")
- starting_rev, revision = revision.split(':', 2)
+ starting_rev, revision = revision.split(":", 2)
elif sql:
raise util.CommandError(
- "downgrade with --sql requires <fromrev>:<torev>")
+ "downgrade with --sql requires <fromrev>:<torev>"
+ )
def downgrade(rev, context):
return script._downgrade_revs(revision, rev)
@@ -289,7 +311,7 @@ def downgrade(config, revision, sql=False, tag=None):
as_sql=sql,
starting_rev=starting_rev,
destination_rev=revision,
- tag=tag
+ tag=tag,
):
script.run_env()
@@ -306,15 +328,13 @@ def show(config, rev):
script = ScriptDirectory.from_config(config)
if rev == "current":
+
def show_current(rev, context):
for sc in script.get_revisions(rev):
config.print_stdout(sc.log_entry)
return []
- with EnvironmentContext(
- config,
- script,
- fn=show_current
- ):
+
+ with EnvironmentContext(config, script, fn=show_current):
script.run_env()
else:
for sc in script.get_revisions(rev):
@@ -340,44 +360,45 @@ def history(config, rev_range=None, verbose=False, indicate_current=False):
if rev_range is not None:
if ":" not in rev_range:
raise util.CommandError(
- "History range requires [start]:[end], "
- "[start]:, or :[end]")
+ "History range requires [start]:[end], " "[start]:, or :[end]"
+ )
base, head = rev_range.strip().split(":")
else:
base = head = None
- environment = util.asbool(
- config.get_main_option("revision_environment")
- ) or indicate_current
+ environment = (
+ util.asbool(config.get_main_option("revision_environment"))
+ or indicate_current
+ )
def _display_history(config, script, base, head, currents=()):
for sc in script.walk_revisions(
- base=base or "base",
- head=head or "heads"):
+ base=base or "base", head=head or "heads"
+ ):
if indicate_current:
sc._db_current_indicator = sc.revision in currents
config.print_stdout(
sc.cmd_format(
- verbose=verbose, include_branches=True,
- include_doc=True, include_parents=True))
+ verbose=verbose,
+ include_branches=True,
+ include_doc=True,
+ include_parents=True,
+ )
+ )
def _display_history_w_current(config, script, base, head):
def _display_current_history(rev, context):
- if head == 'current':
+ if head == "current":
_display_history(config, script, base, rev, rev)
- elif base == 'current':
+ elif base == "current":
_display_history(config, script, rev, head, rev)
else:
_display_history(config, script, base, head, rev)
return []
- with EnvironmentContext(
- config,
- script,
- fn=_display_current_history
- ):
+ with EnvironmentContext(config, script, fn=_display_current_history):
script.run_env()
if base == "current" or head == "current" or environment:
@@ -406,7 +427,9 @@ def heads(config, verbose=False, resolve_dependencies=False):
for rev in heads:
config.print_stdout(
rev.cmd_format(
- verbose, include_branches=True, tree_indicators=False))
+ verbose, include_branches=True, tree_indicators=False
+ )
+ )
def branches(config, verbose=False):
@@ -424,13 +447,17 @@ def branches(config, verbose=False):
"%s\n%s\n",
sc.cmd_format(verbose, include_branches=True),
"\n".join(
- "%s -> %s" % (
+ "%s -> %s"
+ % (
" " * len(str(sc.revision)),
rev_obj.cmd_format(
- False, include_branches=True, include_doc=verbose)
- ) for rev_obj in
- (script.get_revision(rev) for rev in sc.nextrev)
- )
+ False, include_branches=True, include_doc=verbose
+ ),
+ )
+ for rev_obj in (
+ script.get_revision(rev) for rev in sc.nextrev
+ )
+ ),
)
@@ -454,18 +481,14 @@ def current(config, verbose=False, head_only=False):
if verbose:
config.print_stdout(
"Current revision(s) for %s:",
- util.obfuscate_url_pw(context.connection.engine.url)
+ util.obfuscate_url_pw(context.connection.engine.url),
)
for rev in script.get_all_current(rev):
config.print_stdout(rev.cmd_format(verbose))
return []
- with EnvironmentContext(
- config,
- script,
- fn=display_version
- ):
+ with EnvironmentContext(config, script, fn=display_version):
script.run_env()
@@ -491,7 +514,7 @@ def stamp(config, revision, sql=False, tag=None):
if ":" in revision:
if not sql:
raise util.CommandError("Range revision not allowed")
- starting_rev, revision = revision.split(':', 2)
+ starting_rev, revision = revision.split(":", 2)
def do_stamp(rev, context):
return script._stamp_revs(revision, rev)
@@ -503,7 +526,7 @@ def stamp(config, revision, sql=False, tag=None):
as_sql=sql,
destination_rev=revision,
starting_rev=starting_rev,
- tag=tag
+ tag=tag,
):
script.run_env()
@@ -520,23 +543,21 @@ def edit(config, rev):
script = ScriptDirectory.from_config(config)
if rev == "current":
+
def edit_current(rev, context):
if not rev:
raise util.CommandError("No current revisions")
for sc in script.get_revisions(rev):
util.edit(sc.path)
return []
- with EnvironmentContext(
- config,
- script,
- fn=edit_current
- ):
+
+ with EnvironmentContext(config, script, fn=edit_current):
script.run_env()
else:
revs = script.get_revisions(rev)
if not revs:
raise util.CommandError(
- "No revision files indicated by symbol '%s'" % rev)
+ "No revision files indicated by symbol '%s'" % rev
+ )
for sc in revs:
util.edit(sc.path)
-
diff --git a/alembic/config.py b/alembic/config.py
index 5856099..915091c 100644
--- a/alembic/config.py
+++ b/alembic/config.py
@@ -90,9 +90,16 @@ class Config(object):
"""
- def __init__(self, file_=None, ini_section='alembic', output_buffer=None,
- stdout=sys.stdout, cmd_opts=None,
- config_args=util.immutabledict(), attributes=None):
+ def __init__(
+ self,
+ file_=None,
+ ini_section="alembic",
+ output_buffer=None,
+ stdout=sys.stdout,
+ cmd_opts=None,
+ config_args=util.immutabledict(),
+ attributes=None,
+ ):
"""Construct a new :class:`.Config`
"""
@@ -167,15 +174,11 @@ class Config(object):
"""
if arg:
- output = (compat.text_type(text) % arg)
+ output = compat.text_type(text) % arg
else:
output = compat.text_type(text)
- util.write_outstream(
- self.stdout,
- output,
- "\n"
- )
+ util.write_outstream(self.stdout, output, "\n")
@util.memoized_property
def file_config(self):
@@ -192,7 +195,7 @@ class Config(object):
here = os.path.abspath(os.path.dirname(self.config_file_name))
else:
here = ""
- self.config_args['here'] = here
+ self.config_args["here"] = here
file_config = SafeConfigParser(self.config_args)
if self.config_file_name:
file_config.read([self.config_file_name])
@@ -207,7 +210,7 @@ class Config(object):
commands.
"""
- return os.path.join(package_dir, 'templates')
+ return os.path.join(package_dir, "templates")
def get_section(self, name):
"""Return all the configuration options from a given .ini file section
@@ -265,9 +268,10 @@ class Config(object):
"""
if not self.file_config.has_section(section):
- raise util.CommandError("No config file %r found, or file has no "
- "'[%s]' section" %
- (self.config_file_name, section))
+ raise util.CommandError(
+ "No config file %r found, or file has no "
+ "'[%s]' section" % (self.config_file_name, section)
+ )
if self.file_config.has_option(section, name):
return self.file_config.get(section, name)
else:
@@ -285,140 +289,144 @@ class Config(object):
class CommandLine(object):
-
def __init__(self, prog=None):
self._generate_args(prog)
def _generate_args(self, prog):
def add_options(parser, positional, kwargs):
kwargs_opts = {
- 'template': (
- "-t", "--template",
+ "template": (
+ "-t",
+ "--template",
dict(
- default='generic',
+ default="generic",
type=str,
- help="Setup template for use with 'init'"
- )
+ help="Setup template for use with 'init'",
+ ),
),
- 'message': (
- "-m", "--message",
+ "message": (
+ "-m",
+ "--message",
dict(
- type=str,
- help="Message string to use with 'revision'")
+ type=str, help="Message string to use with 'revision'"
+ ),
),
- 'sql': (
+ "sql": (
"--sql",
dict(
action="store_true",
help="Don't emit SQL to database - dump to "
"standard output/file instead. See docs on "
- "offline mode."
- )
+ "offline mode.",
+ ),
),
- 'tag': (
+ "tag": (
"--tag",
dict(
type=str,
help="Arbitrary 'tag' name - can be used by "
- "custom env.py scripts.")
+ "custom env.py scripts.",
+ ),
),
- 'head': (
+ "head": (
"--head",
dict(
type=str,
help="Specify head revision or <branchname>@head "
- "to base new revision on."
- )
+ "to base new revision on.",
+ ),
),
- 'splice': (
+ "splice": (
"--splice",
dict(
action="store_true",
help="Allow a non-head revision as the "
- "'head' to splice onto"
- )
+ "'head' to splice onto",
+ ),
),
- 'depends_on': (
+ "depends_on": (
"--depends-on",
dict(
action="append",
help="Specify one or more revision identifiers "
- "which this revision should depend on."
- )
+ "which this revision should depend on.",
+ ),
),
- 'rev_id': (
+ "rev_id": (
"--rev-id",
dict(
type=str,
help="Specify a hardcoded revision id instead of "
- "generating one"
- )
+ "generating one",
+ ),
),
- 'version_path': (
+ "version_path": (
"--version-path",
dict(
type=str,
help="Specify specific path from config for "
- "version file"
- )
+ "version file",
+ ),
),
- 'branch_label': (
+ "branch_label": (
"--branch-label",
dict(
type=str,
help="Specify a branch label to apply to the "
- "new revision"
- )
+ "new revision",
+ ),
),
- 'verbose': (
- "-v", "--verbose",
- dict(
- action="store_true",
- help="Use more verbose output"
- )
+ "verbose": (
+ "-v",
+ "--verbose",
+ dict(action="store_true", help="Use more verbose output"),
),
- 'resolve_dependencies': (
- '--resolve-dependencies',
+ "resolve_dependencies": (
+ "--resolve-dependencies",
dict(
action="store_true",
- help="Treat dependency versions as down revisions"
- )
+ help="Treat dependency versions as down revisions",
+ ),
),
- 'autogenerate': (
+ "autogenerate": (
"--autogenerate",
dict(
action="store_true",
help="Populate revision script with candidate "
"migration operations, based on comparison "
- "of database to model.")
+ "of database to model.",
+ ),
),
- 'head_only': (
+ "head_only": (
"--head-only",
dict(
action="store_true",
help="Deprecated. Use --verbose for "
- "additional output")
+ "additional output",
+ ),
),
- 'rev_range': (
- "-r", "--rev-range",
+ "rev_range": (
+ "-r",
+ "--rev-range",
dict(
action="store",
help="Specify a revision range; "
- "format is [start]:[end]")
+ "format is [start]:[end]",
+ ),
),
- 'indicate_current': (
- "-i", "--indicate-current",
+ "indicate_current": (
+ "-i",
+ "--indicate-current",
dict(
action="store_true",
- help="Indicate the current revision"
- )
- )
+ help="Indicate the current revision",
+ ),
+ ),
}
positional_help = {
- 'directory': "location of scripts directory",
- 'revision': "revision identifier",
- 'revisions': "one or more revisions, or 'heads' for all heads"
-
+ "directory": "location of scripts directory",
+ "revision": "revision identifier",
+ "revisions": "one or more revisions, or 'heads' for all heads",
}
for arg in kwargs:
if arg in kwargs_opts:
@@ -429,44 +437,56 @@ class CommandLine(object):
for arg in positional:
if arg == "revisions":
subparser.add_argument(
- arg, nargs='+', help=positional_help.get(arg))
+ arg, nargs="+", help=positional_help.get(arg)
+ )
else:
subparser.add_argument(arg, help=positional_help.get(arg))
parser = ArgumentParser(prog=prog)
- parser.add_argument("-c", "--config",
- type=str,
- default="alembic.ini",
- help="Alternate config file")
- parser.add_argument("-n", "--name",
- type=str,
- default="alembic",
- help="Name of section in .ini file to "
- "use for Alembic config")
- parser.add_argument("-x", action="append",
- help="Additional arguments consumed by "
- "custom env.py scripts, e.g. -x "
- "setting1=somesetting -x setting2=somesetting")
- parser.add_argument("--raiseerr", action="store_true",
- help="Raise a full stack trace on error")
+ parser.add_argument(
+ "-c",
+ "--config",
+ type=str,
+ default="alembic.ini",
+ help="Alternate config file",
+ )
+ parser.add_argument(
+ "-n",
+ "--name",
+ type=str,
+ default="alembic",
+ help="Name of section in .ini file to " "use for Alembic config",
+ )
+ parser.add_argument(
+ "-x",
+ action="append",
+ help="Additional arguments consumed by "
+ "custom env.py scripts, e.g. -x "
+ "setting1=somesetting -x setting2=somesetting",
+ )
+ parser.add_argument(
+ "--raiseerr",
+ action="store_true",
+ help="Raise a full stack trace on error",
+ )
subparsers = parser.add_subparsers()
for fn in [getattr(command, n) for n in dir(command)]:
- if inspect.isfunction(fn) and \
- fn.__name__[0] != '_' and \
- fn.__module__ == 'alembic.command':
+ if (
+ inspect.isfunction(fn)
+ and fn.__name__[0] != "_"
+ and fn.__module__ == "alembic.command"
+ ):
spec = compat.inspect_getargspec(fn)
if spec[3]:
- positional = spec[0][1:-len(spec[3])]
- kwarg = spec[0][-len(spec[3]):]
+ positional = spec[0][1 : -len(spec[3])]
+ kwarg = spec[0][-len(spec[3]) :]
else:
positional = spec[0][1:]
kwarg = []
- subparser = subparsers.add_parser(
- fn.__name__,
- help=fn.__doc__)
+ subparser = subparsers.add_parser(fn.__name__, help=fn.__doc__)
add_options(subparser, positional, kwarg)
subparser.set_defaults(cmd=(fn, positional, kwarg))
self.parser = parser
@@ -475,10 +495,11 @@ class CommandLine(object):
fn, positional, kwarg = options.cmd
try:
- fn(config,
- *[getattr(options, k, None) for k in positional],
- **dict((k, getattr(options, k, None)) for k in kwarg)
- )
+ fn(
+ config,
+ *[getattr(options, k, None) for k in positional],
+ **dict((k, getattr(options, k, None)) for k in kwarg)
+ )
except util.CommandError as e:
if options.raiseerr:
raise
@@ -492,8 +513,11 @@ class CommandLine(object):
# behavior changed incompatibly in py3.3
self.parser.error("too few arguments")
else:
- cfg = Config(file_=options.config,
- ini_section=options.name, cmd_opts=options)
+ cfg = Config(
+ file_=options.config,
+ ini_section=options.name,
+ cmd_opts=options,
+ )
self.run_cmd(cfg, options)
@@ -502,5 +526,6 @@ def main(argv=None, prog=None, **kwargs):
CommandLine(prog=prog).main(argv=argv)
-if __name__ == '__main__':
+
+if __name__ == "__main__":
main()
diff --git a/alembic/ddl/base.py b/alembic/ddl/base.py
index f4a525f..f177a07 100644
--- a/alembic/ddl/base.py
+++ b/alembic/ddl/base.py
@@ -9,7 +9,11 @@ from .. import util
# backwards compat
from ..util.sqla_compat import ( # noqa
_table_for_constraint,
- _columns_for_constraint, _fk_spec, _is_type_bound, _find_columns)
+ _columns_for_constraint,
+ _fk_spec,
+ _is_type_bound,
+ _find_columns,
+)
if util.sqla_09:
from sqlalchemy.sql.elements import quoted_name
@@ -30,65 +34,63 @@ class AlterTable(DDLElement):
class RenameTable(AlterTable):
-
def __init__(self, old_table_name, new_table_name, schema=None):
super(RenameTable, self).__init__(old_table_name, schema=schema)
self.new_table_name = new_table_name
class AlterColumn(AlterTable):
-
- def __init__(self, name, column_name, schema=None,
- existing_type=None,
- existing_nullable=None,
- existing_server_default=None):
+ def __init__(
+ self,
+ name,
+ column_name,
+ schema=None,
+ existing_type=None,
+ existing_nullable=None,
+ existing_server_default=None,
+ ):
super(AlterColumn, self).__init__(name, schema=schema)
self.column_name = column_name
- self.existing_type = sqltypes.to_instance(existing_type) \
- if existing_type is not None else None
+ self.existing_type = (
+ sqltypes.to_instance(existing_type)
+ if existing_type is not None
+ else None
+ )
self.existing_nullable = existing_nullable
self.existing_server_default = existing_server_default
class ColumnNullable(AlterColumn):
-
def __init__(self, name, column_name, nullable, **kw):
- super(ColumnNullable, self).__init__(name, column_name,
- **kw)
+ super(ColumnNullable, self).__init__(name, column_name, **kw)
self.nullable = nullable
class ColumnType(AlterColumn):
-
def __init__(self, name, column_name, type_, **kw):
- super(ColumnType, self).__init__(name, column_name,
- **kw)
+ super(ColumnType, self).__init__(name, column_name, **kw)
self.type_ = sqltypes.to_instance(type_)
class ColumnName(AlterColumn):
-
def __init__(self, name, column_name, newname, **kw):
super(ColumnName, self).__init__(name, column_name, **kw)
self.newname = newname
class ColumnDefault(AlterColumn):
-
def __init__(self, name, column_name, default, **kw):
super(ColumnDefault, self).__init__(name, column_name, **kw)
self.default = default
class AddColumn(AlterTable):
-
def __init__(self, name, column, schema=None):
super(AddColumn, self).__init__(name, schema=schema)
self.column = column
class DropColumn(AlterTable):
-
def __init__(self, name, column, schema=None):
super(DropColumn, self).__init__(name, schema=schema)
self.column = column
@@ -98,7 +100,7 @@ class DropColumn(AlterTable):
def visit_rename_table(element, compiler, **kw):
return "%s RENAME TO %s" % (
alter_table(compiler, element.table_name, element.schema),
- format_table_name(compiler, element.new_table_name, element.schema)
+ format_table_name(compiler, element.new_table_name, element.schema),
)
@@ -106,7 +108,7 @@ def visit_rename_table(element, compiler, **kw):
def visit_add_column(element, compiler, **kw):
return "%s %s" % (
alter_table(compiler, element.table_name, element.schema),
- add_column(compiler, element.column, **kw)
+ add_column(compiler, element.column, **kw),
)
@@ -114,7 +116,7 @@ def visit_add_column(element, compiler, **kw):
def visit_drop_column(element, compiler, **kw):
return "%s %s" % (
alter_table(compiler, element.table_name, element.schema),
- drop_column(compiler, element.column.name, **kw)
+ drop_column(compiler, element.column.name, **kw),
)
@@ -123,7 +125,7 @@ def visit_column_nullable(element, compiler, **kw):
return "%s %s %s" % (
alter_table(compiler, element.table_name, element.schema),
alter_column(compiler, element.column_name),
- "DROP NOT NULL" if element.nullable else "SET NOT NULL"
+ "DROP NOT NULL" if element.nullable else "SET NOT NULL",
)
@@ -132,7 +134,7 @@ def visit_column_type(element, compiler, **kw):
return "%s %s %s" % (
alter_table(compiler, element.table_name, element.schema),
alter_column(compiler, element.column_name),
- "TYPE %s" % format_type(compiler, element.type_)
+ "TYPE %s" % format_type(compiler, element.type_),
)
@@ -141,7 +143,7 @@ def visit_column_name(element, compiler, **kw):
return "%s RENAME %s TO %s" % (
alter_table(compiler, element.table_name, element.schema),
format_column_name(compiler, element.column_name),
- format_column_name(compiler, element.newname)
+ format_column_name(compiler, element.newname),
)
@@ -150,10 +152,9 @@ def visit_column_default(element, compiler, **kw):
return "%s %s %s" % (
alter_table(compiler, element.table_name, element.schema),
alter_column(compiler, element.column_name),
- "SET DEFAULT %s" %
- format_server_default(compiler, element.default)
+ "SET DEFAULT %s" % format_server_default(compiler, element.default)
if element.default is not None
- else "DROP DEFAULT"
+ else "DROP DEFAULT",
)
@@ -162,7 +163,7 @@ def quote_dotted(name, quote):
if util.sqla_09 and isinstance(name, quoted_name):
return quote(name)
- result = '.'.join([quote(x) for x in name.split('.')])
+ result = ".".join([quote(x) for x in name.split(".")])
return result
@@ -193,11 +194,11 @@ def alter_table(compiler, name, schema):
def drop_column(compiler, name):
- return 'DROP COLUMN %s' % format_column_name(compiler, name)
+ return "DROP COLUMN %s" % format_column_name(compiler, name)
def alter_column(compiler, name):
- return 'ALTER COLUMN %s' % format_column_name(compiler, name)
+ return "ALTER COLUMN %s" % format_column_name(compiler, name)
def add_column(compiler, column, **kw):
diff --git a/alembic/ddl/impl.py b/alembic/ddl/impl.py
index 98be164..4e3ff04 100644
--- a/alembic/ddl/impl.py
+++ b/alembic/ddl/impl.py
@@ -1,22 +1,20 @@
from sqlalchemy import schema, text
from sqlalchemy import types as sqltypes
-from ..util.compat import (
- string_types, text_type, with_metaclass
-)
+from ..util.compat import string_types, text_type, with_metaclass
from ..util import sqla_compat
from .. import util
from . import base
class ImplMeta(type):
-
def __init__(cls, classname, bases, dict_):
newtype = type.__init__(cls, classname, bases, dict_)
- if '__dialect__' in dict_:
- _impls[dict_['__dialect__']] = cls
+ if "__dialect__" in dict_:
+ _impls[dict_["__dialect__"]] = cls
return newtype
+
_impls = {}
@@ -33,18 +31,25 @@ class DefaultImpl(with_metaclass(ImplMeta)):
bulk inserts.
"""
- __dialect__ = 'default'
+
+ __dialect__ = "default"
transactional_ddl = False
command_terminator = ";"
- def __init__(self, dialect, connection, as_sql,
- transactional_ddl, output_buffer,
- context_opts):
+ def __init__(
+ self,
+ dialect,
+ connection,
+ as_sql,
+ transactional_ddl,
+ output_buffer,
+ context_opts,
+ ):
self.dialect = dialect
self.connection = connection
self.as_sql = as_sql
- self.literal_binds = context_opts.get('literal_binds', False)
+ self.literal_binds = context_opts.get("literal_binds", False)
self.output_buffer = output_buffer
self.memo = {}
@@ -55,7 +60,8 @@ class DefaultImpl(with_metaclass(ImplMeta)):
if self.literal_binds:
if not self.as_sql:
raise util.CommandError(
- "Can't use literal_binds setting without as_sql mode")
+ "Can't use literal_binds setting without as_sql mode"
+ )
@classmethod
def get_by_dialect(cls, dialect):
@@ -89,9 +95,13 @@ class DefaultImpl(with_metaclass(ImplMeta)):
def bind(self):
return self.connection
- def _exec(self, construct, execution_options=None,
- multiparams=(),
- params=util.immutabledict()):
+ def _exec(
+ self,
+ construct,
+ execution_options=None,
+ multiparams=(),
+ params=util.immutabledict(),
+ ):
if isinstance(construct, string_types):
construct = text(construct)
if self.as_sql:
@@ -100,14 +110,20 @@ class DefaultImpl(with_metaclass(ImplMeta)):
raise Exception("Execution arguments not allowed with as_sql")
if self.literal_binds and not isinstance(
- construct, schema.DDLElement):
+ construct, schema.DDLElement
+ ):
compile_kw = dict(compile_kwargs={"literal_binds": True})
else:
compile_kw = {}
- self.static_output(text_type(
- construct.compile(dialect=self.dialect, **compile_kw)
- ).replace("\t", " ").strip() + self.command_terminator)
+ self.static_output(
+ text_type(
+ construct.compile(dialect=self.dialect, **compile_kw)
+ )
+ .replace("\t", " ")
+ .strip()
+ + self.command_terminator
+ )
else:
conn = self.connection
if execution_options:
@@ -117,53 +133,75 @@ class DefaultImpl(with_metaclass(ImplMeta)):
def execute(self, sql, execution_options=None):
self._exec(sql, execution_options)
- def alter_column(self, table_name, column_name,
- nullable=None,
- server_default=False,
- name=None,
- type_=None,
- schema=None,
- autoincrement=None,
- existing_type=None,
- existing_server_default=None,
- existing_nullable=None,
- existing_autoincrement=None
- ):
+ def alter_column(
+ self,
+ table_name,
+ column_name,
+ nullable=None,
+ server_default=False,
+ name=None,
+ type_=None,
+ schema=None,
+ autoincrement=None,
+ existing_type=None,
+ existing_server_default=None,
+ existing_nullable=None,
+ existing_autoincrement=None,
+ ):
if autoincrement is not None or existing_autoincrement is not None:
util.warn(
"autoincrement and existing_autoincrement "
- "only make sense for MySQL")
+ "only make sense for MySQL"
+ )
if nullable is not None:
- self._exec(base.ColumnNullable(
- table_name, column_name,
- nullable, schema=schema,
- existing_type=existing_type,
- existing_server_default=existing_server_default,
- existing_nullable=existing_nullable,
- ))
+ self._exec(
+ base.ColumnNullable(
+ table_name,
+ column_name,
+ nullable,
+ schema=schema,
+ existing_type=existing_type,
+ existing_server_default=existing_server_default,
+ existing_nullable=existing_nullable,
+ )
+ )
if server_default is not False:
- self._exec(base.ColumnDefault(
- table_name, column_name, server_default,
- schema=schema,
- existing_type=existing_type,
- existing_server_default=existing_server_default,
- existing_nullable=existing_nullable,
- ))
+ self._exec(
+ base.ColumnDefault(
+ table_name,
+ column_name,
+ server_default,
+ schema=schema,
+ existing_type=existing_type,
+ existing_server_default=existing_server_default,
+ existing_nullable=existing_nullable,
+ )
+ )
if type_ is not None:
- self._exec(base.ColumnType(
- table_name, column_name, type_, schema=schema,
- existing_type=existing_type,
- existing_server_default=existing_server_default,
- existing_nullable=existing_nullable,
- ))
+ self._exec(
+ base.ColumnType(
+ table_name,
+ column_name,
+ type_,
+ schema=schema,
+ existing_type=existing_type,
+ existing_server_default=existing_server_default,
+ existing_nullable=existing_nullable,
+ )
+ )
# do the new name last ;)
if name is not None:
- self._exec(base.ColumnName(
- table_name, column_name, name, schema=schema,
- existing_type=existing_type,
- existing_server_default=existing_server_default,
- existing_nullable=existing_nullable,
- ))
+ self._exec(
+ base.ColumnName(
+ table_name,
+ column_name,
+ name,
+ schema=schema,
+ existing_type=existing_type,
+ existing_server_default=existing_server_default,
+ existing_nullable=existing_nullable,
+ )
+ )
def add_column(self, table_name, column, schema=None):
self._exec(base.AddColumn(table_name, column, schema=schema))
@@ -172,25 +210,25 @@ class DefaultImpl(with_metaclass(ImplMeta)):
self._exec(base.DropColumn(table_name, column, schema=schema))
def add_constraint(self, const):
- if const._create_rule is None or \
- const._create_rule(self):
+ if const._create_rule is None or const._create_rule(self):
self._exec(schema.AddConstraint(const))
def drop_constraint(self, const):
self._exec(schema.DropConstraint(const))
def rename_table(self, old_table_name, new_table_name, schema=None):
- self._exec(base.RenameTable(old_table_name,
- new_table_name, schema=schema))
+ self._exec(
+ base.RenameTable(old_table_name, new_table_name, schema=schema)
+ )
def create_table(self, table):
- table.dispatch.before_create(table, self.connection,
- checkfirst=False,
- _ddl_runner=self)
+ table.dispatch.before_create(
+ table, self.connection, checkfirst=False, _ddl_runner=self
+ )
self._exec(schema.CreateTable(table))
- table.dispatch.after_create(table, self.connection,
- checkfirst=False,
- _ddl_runner=self)
+ table.dispatch.after_create(
+ table, self.connection, checkfirst=False, _ddl_runner=self
+ )
for index in table.indexes:
self._exec(schema.CreateIndex(index))
@@ -210,17 +248,26 @@ class DefaultImpl(with_metaclass(ImplMeta)):
raise TypeError("List of dictionaries expected")
if self.as_sql:
for row in rows:
- self._exec(table.insert(inline=True).values(**dict(
- (k,
- sqla_compat._literal_bindparam(
- k, v, type_=table.c[k].type)
- if not isinstance(
- v, sqla_compat._literal_bindparam) else v)
- for k, v in row.items()
- )))
+ self._exec(
+ table.insert(inline=True).values(
+ **dict(
+ (
+ k,
+ sqla_compat._literal_bindparam(
+ k, v, type_=table.c[k].type
+ )
+ if not isinstance(
+ v, sqla_compat._literal_bindparam
+ )
+ else v,
+ )
+ for k, v in row.items()
+ )
+ )
+ )
else:
# work around http://www.sqlalchemy.org/trac/ticket/2461
- if not hasattr(table, '_autoincrement_column'):
+ if not hasattr(table, "_autoincrement_column"):
table._autoincrement_column = None
if rows:
if multiinsert:
@@ -240,32 +287,38 @@ class DefaultImpl(with_metaclass(ImplMeta)):
# work around SQLAlchemy bug "stale value for type affinity"
# fixed in 0.7.4
- metadata_impl.__dict__.pop('_type_affinity', None)
+ metadata_impl.__dict__.pop("_type_affinity", None)
if hasattr(metadata_impl, "compare_against_backend"):
comparison = metadata_impl.compare_against_backend(
- self.dialect, conn_type)
+ self.dialect, conn_type
+ )
if comparison is not None:
return not comparison
- if conn_type._compare_type_affinity(
- metadata_impl
- ):
+ if conn_type._compare_type_affinity(metadata_impl):
comparator = _type_comparators.get(conn_type._type_affinity, None)
return comparator and comparator(metadata_impl, conn_type)
else:
return True
- def compare_server_default(self, inspector_column,
- metadata_column,
- rendered_metadata_default,
- rendered_inspector_default):
+ def compare_server_default(
+ self,
+ inspector_column,
+ metadata_column,
+ rendered_metadata_default,
+ rendered_inspector_default,
+ ):
return rendered_inspector_default != rendered_metadata_default
- def correct_for_autogen_constraints(self, conn_uniques, conn_indexes,
- metadata_unique_constraints,
- metadata_indexes):
+ def correct_for_autogen_constraints(
+ self,
+ conn_uniques,
+ conn_indexes,
+ metadata_unique_constraints,
+ metadata_indexes,
+ ):
pass
def _compat_autogen_column_reflect(self, inspector):
@@ -316,38 +369,37 @@ class DefaultImpl(with_metaclass(ImplMeta)):
def _string_compare(t1, t2):
- return \
- t1.length is not None and \
- t1.length != t2.length
+ return t1.length is not None and t1.length != t2.length
def _numeric_compare(t1, t2):
- return (
- t1.precision is not None and
- t1.precision != t2.precision
- ) or (
- t1.precision is not None and
- t1.scale is not None and
- t1.scale != t2.scale
+ return (t1.precision is not None and t1.precision != t2.precision) or (
+ t1.precision is not None
+ and t1.scale is not None
+ and t1.scale != t2.scale
)
def _integer_compare(t1, t2):
t1_small_or_big = (
- 'S' if isinstance(t1, sqltypes.SmallInteger)
- else 'B' if isinstance(t1, sqltypes.BigInteger) else 'I'
+ "S"
+ if isinstance(t1, sqltypes.SmallInteger)
+ else "B"
+ if isinstance(t1, sqltypes.BigInteger)
+ else "I"
)
t2_small_or_big = (
- 'S' if isinstance(t2, sqltypes.SmallInteger)
- else 'B' if isinstance(t2, sqltypes.BigInteger) else 'I'
+ "S"
+ if isinstance(t2, sqltypes.SmallInteger)
+ else "B"
+ if isinstance(t2, sqltypes.BigInteger)
+ else "I"
)
return t1_small_or_big != t2_small_or_big
def _datetime_compare(t1, t2):
- return (
- t1.timezone != t2.timezone
- )
+ return t1.timezone != t2.timezone
_type_comparators = {
diff --git a/alembic/ddl/mssql.py b/alembic/ddl/mssql.py
index f303be4..7f43a89 100644
--- a/alembic/ddl/mssql.py
+++ b/alembic/ddl/mssql.py
@@ -2,24 +2,35 @@ from sqlalchemy.ext.compiler import compiles
from .. import util
from .impl import DefaultImpl
-from .base import alter_table, AddColumn, ColumnName, RenameTable,\
- format_table_name, format_column_name, ColumnNullable, alter_column,\
- format_server_default, ColumnDefault, format_type, ColumnType
+from .base import (
+ alter_table,
+ AddColumn,
+ ColumnName,
+ RenameTable,
+ format_table_name,
+ format_column_name,
+ ColumnNullable,
+ alter_column,
+ format_server_default,
+ ColumnDefault,
+ format_type,
+ ColumnType,
+)
from sqlalchemy.sql.expression import ClauseElement, Executable
from sqlalchemy.schema import CreateIndex, Column
from sqlalchemy import types as sqltypes
class MSSQLImpl(DefaultImpl):
- __dialect__ = 'mssql'
+ __dialect__ = "mssql"
transactional_ddl = True
batch_separator = "GO"
def __init__(self, *arg, **kw):
super(MSSQLImpl, self).__init__(*arg, **kw)
self.batch_separator = self.context_opts.get(
- "mssql_batch_separator",
- self.batch_separator)
+ "mssql_batch_separator", self.batch_separator
+ )
def _exec(self, construct, *args, **kw):
result = super(MSSQLImpl, self)._exec(construct, *args, **kw)
@@ -35,17 +46,20 @@ class MSSQLImpl(DefaultImpl):
if self.as_sql and self.batch_separator:
self.static_output(self.batch_separator)
- def alter_column(self, table_name, column_name,
- nullable=None,
- server_default=False,
- name=None,
- type_=None,
- schema=None,
- existing_type=None,
- existing_server_default=None,
- existing_nullable=None,
- **kw
- ):
+ def alter_column(
+ self,
+ table_name,
+ column_name,
+ nullable=None,
+ server_default=False,
+ name=None,
+ type_=None,
+ schema=None,
+ existing_type=None,
+ existing_server_default=None,
+ existing_nullable=None,
+ **kw
+ ):
if nullable is not None and existing_type is None:
if type_ is not None:
@@ -57,10 +71,12 @@ class MSSQLImpl(DefaultImpl):
raise util.CommandError(
"MS-SQL ALTER COLUMN operations "
"with NULL or NOT NULL require the "
- "existing_type or a new type_ be passed.")
+ "existing_type or a new type_ be passed."
+ )
super(MSSQLImpl, self).alter_column(
- table_name, column_name,
+ table_name,
+ column_name,
nullable=nullable,
type_=type_,
schema=schema,
@@ -70,30 +86,30 @@ class MSSQLImpl(DefaultImpl):
)
if server_default is not False:
- if existing_server_default is not False or \
- server_default is None:
+ if existing_server_default is not False or server_default is None:
self._exec(
_ExecDropConstraint(
- table_name, column_name,
- 'sys.default_constraints')
+ table_name, column_name, "sys.default_constraints"
+ )
)
if server_default is not None:
super(MSSQLImpl, self).alter_column(
- table_name, column_name,
+ table_name,
+ column_name,
schema=schema,
- server_default=server_default)
+ server_default=server_default,
+ )
if name is not None:
super(MSSQLImpl, self).alter_column(
- table_name, column_name,
- schema=schema,
- name=name)
+ table_name, column_name, schema=schema, name=name
+ )
def create_index(self, index):
# this likely defaults to None if not present, so get()
# should normally not return the default value. being
# defensive in any case
- mssql_include = index.kwargs.get('mssql_include', None) or ()
+ mssql_include = index.kwargs.get("mssql_include", None) or ()
for col in mssql_include:
if col not in index.table.c:
index.table.append_column(Column(col, sqltypes.NullType))
@@ -102,42 +118,39 @@ class MSSQLImpl(DefaultImpl):
def bulk_insert(self, table, rows, **kw):
if self.as_sql:
self._exec(
- "SET IDENTITY_INSERT %s ON" %
- self.dialect.identifier_preparer.format_table(table)
+ "SET IDENTITY_INSERT %s ON"
+ % self.dialect.identifier_preparer.format_table(table)
)
super(MSSQLImpl, self).bulk_insert(table, rows, **kw)
self._exec(
- "SET IDENTITY_INSERT %s OFF" %
- self.dialect.identifier_preparer.format_table(table)
+ "SET IDENTITY_INSERT %s OFF"
+ % self.dialect.identifier_preparer.format_table(table)
)
else:
super(MSSQLImpl, self).bulk_insert(table, rows, **kw)
def drop_column(self, table_name, column, **kw):
- drop_default = kw.pop('mssql_drop_default', False)
+ drop_default = kw.pop("mssql_drop_default", False)
if drop_default:
self._exec(
_ExecDropConstraint(
- table_name, column,
- 'sys.default_constraints')
+ table_name, column, "sys.default_constraints"
+ )
)
- drop_check = kw.pop('mssql_drop_check', False)
+ drop_check = kw.pop("mssql_drop_check", False)
if drop_check:
self._exec(
_ExecDropConstraint(
- table_name, column,
- 'sys.check_constraints')
+ table_name, column, "sys.check_constraints"
+ )
)
- drop_fks = kw.pop('mssql_drop_foreign_key', False)
+ drop_fks = kw.pop("mssql_drop_foreign_key", False)
if drop_fks:
- self._exec(
- _ExecDropFKConstraint(table_name, column)
- )
+ self._exec(_ExecDropFKConstraint(table_name, column))
super(MSSQLImpl, self).drop_column(table_name, column, **kw)
class _ExecDropConstraint(Executable, ClauseElement):
-
def __init__(self, tname, colname, type_):
self.tname = tname
self.colname = colname
@@ -145,13 +158,12 @@ class _ExecDropConstraint(Executable, ClauseElement):
class _ExecDropFKConstraint(Executable, ClauseElement):
-
def __init__(self, tname, colname):
self.tname = tname
self.colname = colname
-@compiles(_ExecDropConstraint, 'mssql')
+@compiles(_ExecDropConstraint, "mssql")
def _exec_drop_col_constraint(element, compiler, **kw):
tname, colname, type_ = element.tname, element.colname, element.type_
# from http://www.mssqltips.com/sqlservertip/1425/\
@@ -162,14 +174,14 @@ select @const_name = [name] from %(type)s
where parent_object_id = object_id('%(tname)s')
and col_name(parent_object_id, parent_column_id) = '%(colname)s'
exec('alter table %(tname_quoted)s drop constraint ' + @const_name)""" % {
- 'type': type_,
- 'tname': tname,
- 'colname': colname,
- 'tname_quoted': format_table_name(compiler, tname, None),
+ "type": type_,
+ "tname": tname,
+ "colname": colname,
+ "tname_quoted": format_table_name(compiler, tname, None),
}
-@compiles(_ExecDropFKConstraint, 'mssql')
+@compiles(_ExecDropFKConstraint, "mssql")
def _exec_drop_col_fk_constraint(element, compiler, **kw):
tname, colname = element.tname, element.colname
@@ -180,17 +192,17 @@ select @const_name = [name] from
where fkc.parent_object_id = object_id('%(tname)s')
and col_name(fkc.parent_object_id, fkc.parent_column_id) = '%(colname)s'
exec('alter table %(tname_quoted)s drop constraint ' + @const_name)""" % {
- 'tname': tname,
- 'colname': colname,
- 'tname_quoted': format_table_name(compiler, tname, None),
+ "tname": tname,
+ "colname": colname,
+ "tname_quoted": format_table_name(compiler, tname, None),
}
-@compiles(AddColumn, 'mssql')
+@compiles(AddColumn, "mssql")
def visit_add_column(element, compiler, **kw):
return "%s %s" % (
alter_table(compiler, element.table_name, element.schema),
- mssql_add_column(compiler, element.column, **kw)
+ mssql_add_column(compiler, element.column, **kw),
)
@@ -198,49 +210,48 @@ def mssql_add_column(compiler, column, **kw):
return "ADD %s" % compiler.get_column_specification(column, **kw)
-@compiles(ColumnNullable, 'mssql')
+@compiles(ColumnNullable, "mssql")
def visit_column_nullable(element, compiler, **kw):
return "%s %s %s %s" % (
alter_table(compiler, element.table_name, element.schema),
alter_column(compiler, element.column_name),
format_type(compiler, element.existing_type),
- "NULL" if element.nullable else "NOT NULL"
+ "NULL" if element.nullable else "NOT NULL",
)
-@compiles(ColumnDefault, 'mssql')
+@compiles(ColumnDefault, "mssql")
def visit_column_default(element, compiler, **kw):
# TODO: there can also be a named constraint
# with ADD CONSTRAINT here
return "%s ADD DEFAULT %s FOR %s" % (
alter_table(compiler, element.table_name, element.schema),
format_server_default(compiler, element.default),
- format_column_name(compiler, element.column_name)
+ format_column_name(compiler, element.column_name),
)
-@compiles(ColumnName, 'mssql')
+@compiles(ColumnName, "mssql")
def visit_rename_column(element, compiler, **kw):
return "EXEC sp_rename '%s.%s', %s, 'COLUMN'" % (
format_table_name(compiler, element.table_name, element.schema),
format_column_name(compiler, element.column_name),
- format_column_name(compiler, element.newname)
+ format_column_name(compiler, element.newname),
)
-@compiles(ColumnType, 'mssql')
+@compiles(ColumnType, "mssql")
def visit_column_type(element, compiler, **kw):
return "%s %s %s" % (
alter_table(compiler, element.table_name, element.schema),
alter_column(compiler, element.column_name),
- format_type(compiler, element.type_)
+ format_type(compiler, element.type_),
)
-@compiles(RenameTable, 'mssql')
+@compiles(RenameTable, "mssql")
def visit_rename_table(element, compiler, **kw):
return "EXEC sp_rename '%s', %s" % (
format_table_name(compiler, element.table_name, element.schema),
- format_table_name(compiler, element.new_table_name, None)
+ format_table_name(compiler, element.new_table_name, None),
)
-
diff --git a/alembic/ddl/mysql.py b/alembic/ddl/mysql.py
index 1f4b345..bc98005 100644
--- a/alembic/ddl/mysql.py
+++ b/alembic/ddl/mysql.py
@@ -5,9 +5,15 @@ from sqlalchemy import schema
from ..util.compat import string_types
from .. import util
from .impl import DefaultImpl
-from .base import ColumnNullable, ColumnName, ColumnDefault, \
- ColumnType, AlterColumn, format_column_name, \
- format_server_default
+from .base import (
+ ColumnNullable,
+ ColumnName,
+ ColumnDefault,
+ ColumnType,
+ AlterColumn,
+ format_column_name,
+ format_server_default,
+)
from .base import alter_table
from ..autogenerate import compare
from ..util.sqla_compat import _is_type_bound, sqla_100
@@ -15,64 +21,76 @@ import re
class MySQLImpl(DefaultImpl):
- __dialect__ = 'mysql'
+ __dialect__ = "mysql"
transactional_ddl = False
- def alter_column(self, table_name, column_name,
- nullable=None,
- server_default=False,
- name=None,
- type_=None,
- schema=None,
- existing_type=None,
- existing_server_default=None,
- existing_nullable=None,
- autoincrement=None,
- existing_autoincrement=None,
- **kw
- ):
+ def alter_column(
+ self,
+ table_name,
+ column_name,
+ nullable=None,
+ server_default=False,
+ name=None,
+ type_=None,
+ schema=None,
+ existing_type=None,
+ existing_server_default=None,
+ existing_nullable=None,
+ autoincrement=None,
+ existing_autoincrement=None,
+ **kw
+ ):
if name is not None:
self._exec(
MySQLChangeColumn(
- table_name, column_name,
+ table_name,
+ column_name,
schema=schema,
newname=name,
- nullable=nullable if nullable is not None else
- existing_nullable
+ nullable=nullable
+ if nullable is not None
+ else existing_nullable
if existing_nullable is not None
else True,
type_=type_ if type_ is not None else existing_type,
- default=server_default if server_default is not False
+ default=server_default
+ if server_default is not False
else existing_server_default,
- autoincrement=autoincrement if autoincrement is not None
- else existing_autoincrement
+ autoincrement=autoincrement
+ if autoincrement is not None
+ else existing_autoincrement,
)
)
- elif nullable is not None or \
- type_ is not None or \
- autoincrement is not None:
+ elif (
+ nullable is not None
+ or type_ is not None
+ or autoincrement is not None
+ ):
self._exec(
MySQLModifyColumn(
- table_name, column_name,
+ table_name,
+ column_name,
schema=schema,
newname=name if name is not None else column_name,
- nullable=nullable if nullable is not None else
- existing_nullable
+ nullable=nullable
+ if nullable is not None
+ else existing_nullable
if existing_nullable is not None
else True,
type_=type_ if type_ is not None else existing_type,
- default=server_default if server_default is not False
+ default=server_default
+ if server_default is not False
else existing_server_default,
- autoincrement=autoincrement if autoincrement is not None
- else existing_autoincrement
+ autoincrement=autoincrement
+ if autoincrement is not None
+ else existing_autoincrement,
)
)
elif server_default is not False:
self._exec(
MySQLAlterDefault(
- table_name, column_name, server_default,
- schema=schema,
+ table_name, column_name, server_default, schema=schema
)
)
@@ -82,41 +100,47 @@ class MySQLImpl(DefaultImpl):
super(MySQLImpl, self).drop_constraint(const)
- def compare_server_default(self, inspector_column,
- metadata_column,
- rendered_metadata_default,
- rendered_inspector_default):
+ def compare_server_default(
+ self,
+ inspector_column,
+ metadata_column,
+ rendered_metadata_default,
+ rendered_inspector_default,
+ ):
# partially a workaround for SQLAlchemy issue #3023; if the
# column were created without "NOT NULL", MySQL may have added
# an implicit default of '0' which we need to skip
# TODO: this is not really covered anymore ?
- if metadata_column.type._type_affinity is sqltypes.Integer and \
- inspector_column.primary_key and \
- not inspector_column.autoincrement and \
- not rendered_metadata_default and \
- rendered_inspector_default == "'0'":
+ if (
+ metadata_column.type._type_affinity is sqltypes.Integer
+ and inspector_column.primary_key
+ and not inspector_column.autoincrement
+ and not rendered_metadata_default
+ and rendered_inspector_default == "'0'"
+ ):
return False
elif inspector_column.type._type_affinity is sqltypes.Integer:
rendered_inspector_default = re.sub(
- r"^'|'$", '', rendered_inspector_default)
+ r"^'|'$", "", rendered_inspector_default
+ )
return rendered_inspector_default != rendered_metadata_default
elif rendered_inspector_default and rendered_metadata_default:
# adjust for "function()" vs. "FUNCTION"
- return (
- re.sub(
- r'(.*?)(?:\(\))?$', r'\1',
- rendered_inspector_default.lower()) !=
- re.sub(
- r'(.*?)(?:\(\))?$', r'\1',
- rendered_metadata_default.lower())
+ return re.sub(
+ r"(.*?)(?:\(\))?$", r"\1", rendered_inspector_default.lower()
+ ) != re.sub(
+ r"(.*?)(?:\(\))?$", r"\1", rendered_metadata_default.lower()
)
else:
return rendered_inspector_default != rendered_metadata_default
- def correct_for_autogen_constraints(self, conn_unique_constraints,
- conn_indexes,
- metadata_unique_constraints,
- metadata_indexes):
+ def correct_for_autogen_constraints(
+ self,
+ conn_unique_constraints,
+ conn_indexes,
+ metadata_unique_constraints,
+ metadata_indexes,
+ ):
# TODO: if SQLA 1.0, make use of "duplicates_index"
# metadata
@@ -153,31 +177,41 @@ class MySQLImpl(DefaultImpl):
conn_unique_constraints,
conn_indexes,
metadata_unique_constraints,
- metadata_indexes
+ metadata_indexes,
)
- def _legacy_correct_for_dupe_uq_uix(self, conn_unique_constraints,
- conn_indexes,
- metadata_unique_constraints,
- metadata_indexes):
+ def _legacy_correct_for_dupe_uq_uix(
+ self,
+ conn_unique_constraints,
+ conn_indexes,
+ metadata_unique_constraints,
+ metadata_indexes,
+ ):
# then dedupe unique indexes vs. constraints, since MySQL
# doesn't really have unique constraints as a separate construct.
# but look in the metadata and try to maintain constructs
# that already seem to be defined one way or the other
# on that side. See #276
- metadata_uq_names = set([
- cons.name for cons in metadata_unique_constraints
- if cons.name is not None])
-
- unnamed_metadata_uqs = set([
- compare._uq_constraint_sig(cons).sig
- for cons in metadata_unique_constraints
- if cons.name is None
- ])
-
- metadata_ix_names = set([
- cons.name for cons in metadata_indexes if cons.unique])
+ metadata_uq_names = set(
+ [
+ cons.name
+ for cons in metadata_unique_constraints
+ if cons.name is not None
+ ]
+ )
+
+ unnamed_metadata_uqs = set(
+ [
+ compare._uq_constraint_sig(cons).sig
+ for cons in metadata_unique_constraints
+ if cons.name is None
+ ]
+ )
+
+ metadata_ix_names = set(
+ [cons.name for cons in metadata_indexes if cons.unique]
+ )
conn_uq_names = dict(
(cons.name, cons) for cons in conn_unique_constraints
)
@@ -187,8 +221,10 @@ class MySQLImpl(DefaultImpl):
for overlap in set(conn_uq_names).intersection(conn_ix_names):
if overlap not in metadata_uq_names:
- if compare._uq_constraint_sig(conn_uq_names[overlap]).sig \
- not in unnamed_metadata_uqs:
+ if (
+ compare._uq_constraint_sig(conn_uq_names[overlap]).sig
+ not in unnamed_metadata_uqs
+ ):
conn_unique_constraints.discard(conn_uq_names[overlap])
elif overlap not in metadata_ix_names:
@@ -208,18 +244,21 @@ class MySQLImpl(DefaultImpl):
# MySQL considers RESTRICT to be the default and doesn't
# report on it. if the model has explicit RESTRICT and
# the conn FK has None, set it to RESTRICT
- if mdfk.ondelete is not None and \
- mdfk.ondelete.lower() == 'restrict' and \
- cnfk.ondelete is None:
- cnfk.ondelete = 'RESTRICT'
- if mdfk.onupdate is not None and \
- mdfk.onupdate.lower() == 'restrict' and \
- cnfk.onupdate is None:
- cnfk.onupdate = 'RESTRICT'
+ if (
+ mdfk.ondelete is not None
+ and mdfk.ondelete.lower() == "restrict"
+ and cnfk.ondelete is None
+ ):
+ cnfk.ondelete = "RESTRICT"
+ if (
+ mdfk.onupdate is not None
+ and mdfk.onupdate.lower() == "restrict"
+ and cnfk.onupdate is None
+ ):
+ cnfk.onupdate = "RESTRICT"
class MySQLAlterDefault(AlterColumn):
-
def __init__(self, name, column_name, default, schema=None):
super(AlterColumn, self).__init__(name, schema=schema)
self.column_name = column_name
@@ -227,13 +266,17 @@ class MySQLAlterDefault(AlterColumn):
class MySQLChangeColumn(AlterColumn):
-
- def __init__(self, name, column_name, schema=None,
- newname=None,
- type_=None,
- nullable=None,
- default=False,
- autoincrement=None):
+ def __init__(
+ self,
+ name,
+ column_name,
+ schema=None,
+ newname=None,
+ type_=None,
+ nullable=None,
+ default=False,
+ autoincrement=None,
+ ):
super(AlterColumn, self).__init__(name, schema=schema)
self.column_name = column_name
self.nullable = nullable
@@ -253,10 +296,10 @@ class MySQLModifyColumn(MySQLChangeColumn):
pass
-@compiles(ColumnNullable, 'mysql')
-@compiles(ColumnName, 'mysql')
-@compiles(ColumnDefault, 'mysql')
-@compiles(ColumnType, 'mysql')
+@compiles(ColumnNullable, "mysql")
+@compiles(ColumnName, "mysql")
+@compiles(ColumnDefault, "mysql")
+@compiles(ColumnType, "mysql")
def _mysql_doesnt_support_individual(element, compiler, **kw):
raise NotImplementedError(
"Individual alter column constructs not supported by MySQL"
@@ -270,7 +313,7 @@ def _mysql_alter_default(element, compiler, **kw):
format_column_name(compiler, element.column_name),
"SET DEFAULT %s" % format_server_default(compiler, element.default)
if element.default is not None
- else "DROP DEFAULT"
+ else "DROP DEFAULT",
)
@@ -284,7 +327,7 @@ def _mysql_modify_column(element, compiler, **kw):
nullable=element.nullable,
server_default=element.default,
type_=element.type_,
- autoincrement=element.autoincrement
+ autoincrement=element.autoincrement,
),
)
@@ -300,7 +343,7 @@ def _mysql_change_column(element, compiler, **kw):
nullable=element.nullable,
server_default=element.default,
type_=element.type_,
- autoincrement=element.autoincrement
+ autoincrement=element.autoincrement,
),
)
@@ -312,11 +355,10 @@ def _render_value(compiler, expr):
return compiler.sql_compiler.process(expr)
-def _mysql_colspec(compiler, nullable, server_default, type_,
- autoincrement):
+def _mysql_colspec(compiler, nullable, server_default, type_, autoincrement):
spec = "%s %s" % (
compiler.dialect.type_compiler.process(type_),
- "NULL" if nullable else "NOT NULL"
+ "NULL" if nullable else "NOT NULL",
)
if autoincrement:
spec += " AUTO_INCREMENT"
@@ -332,21 +374,25 @@ def _mysql_drop_constraint(element, compiler, **kw):
raise errors for invalid constraint type."""
constraint = element.element
- if isinstance(constraint, (schema.ForeignKeyConstraint,
- schema.PrimaryKeyConstraint,
- schema.UniqueConstraint)
- ):
+ if isinstance(
+ constraint,
+ (
+ schema.ForeignKeyConstraint,
+ schema.PrimaryKeyConstraint,
+ schema.UniqueConstraint,
+ ),
+ ):
return compiler.visit_drop_constraint(element, **kw)
elif isinstance(constraint, schema.CheckConstraint):
# note that SQLAlchemy as of 1.2 does not yet support
# DROP CONSTRAINT for MySQL/MariaDB, so we implement fully
# here.
- return "ALTER TABLE %s DROP CONSTRAINT %s" % \
- (compiler.preparer.format_table(constraint.table),
- compiler.preparer.format_constraint(constraint))
+ return "ALTER TABLE %s DROP CONSTRAINT %s" % (
+ compiler.preparer.format_table(constraint.table),
+ compiler.preparer.format_constraint(constraint),
+ )
else:
raise NotImplementedError(
"No generic 'DROP CONSTRAINT' in MySQL - "
- "please specify constraint type")
-
-
+ "please specify constraint type"
+ )
diff --git a/alembic/ddl/oracle.py b/alembic/ddl/oracle.py
index e528744..3376155 100644
--- a/alembic/ddl/oracle.py
+++ b/alembic/ddl/oracle.py
@@ -1,13 +1,21 @@
from sqlalchemy.ext.compiler import compiles
from .impl import DefaultImpl
-from .base import alter_table, AddColumn, ColumnName, \
- format_column_name, ColumnNullable, \
- format_server_default, ColumnDefault, format_type, ColumnType
+from .base import (
+ alter_table,
+ AddColumn,
+ ColumnName,
+ format_column_name,
+ ColumnNullable,
+ format_server_default,
+ ColumnDefault,
+ format_type,
+ ColumnType,
+)
class OracleImpl(DefaultImpl):
- __dialect__ = 'oracle'
+ __dialect__ = "oracle"
transactional_ddl = False
batch_separator = "/"
command_terminator = ""
@@ -15,8 +23,8 @@ class OracleImpl(DefaultImpl):
def __init__(self, *arg, **kw):
super(OracleImpl, self).__init__(*arg, **kw)
self.batch_separator = self.context_opts.get(
- "oracle_batch_separator",
- self.batch_separator)
+ "oracle_batch_separator", self.batch_separator
+ )
def _exec(self, construct, *args, **kw):
result = super(OracleImpl, self)._exec(construct, *args, **kw)
@@ -31,7 +39,7 @@ class OracleImpl(DefaultImpl):
self._exec("COMMIT")
-@compiles(AddColumn, 'oracle')
+@compiles(AddColumn, "oracle")
def visit_add_column(element, compiler, **kw):
return "%s %s" % (
alter_table(compiler, element.table_name, element.schema),
@@ -39,47 +47,46 @@ def visit_add_column(element, compiler, **kw):
)
-@compiles(ColumnNullable, 'oracle')
+@compiles(ColumnNullable, "oracle")
def visit_column_nullable(element, compiler, **kw):
return "%s %s %s" % (
alter_table(compiler, element.table_name, element.schema),
alter_column(compiler, element.column_name),
- "NULL" if element.nullable else "NOT NULL"
+ "NULL" if element.nullable else "NOT NULL",
)
-@compiles(ColumnType, 'oracle')
+@compiles(ColumnType, "oracle")
def visit_column_type(element, compiler, **kw):
return "%s %s %s" % (
alter_table(compiler, element.table_name, element.schema),
alter_column(compiler, element.column_name),
- "%s" % format_type(compiler, element.type_)
+ "%s" % format_type(compiler, element.type_),
)
-@compiles(ColumnName, 'oracle')
+@compiles(ColumnName, "oracle")
def visit_column_name(element, compiler, **kw):
return "%s RENAME COLUMN %s TO %s" % (
alter_table(compiler, element.table_name, element.schema),
format_column_name(compiler, element.column_name),
- format_column_name(compiler, element.newname)
+ format_column_name(compiler, element.newname),
)
-@compiles(ColumnDefault, 'oracle')
+@compiles(ColumnDefault, "oracle")
def visit_column_default(element, compiler, **kw):
return "%s %s %s" % (
alter_table(compiler, element.table_name, element.schema),
alter_column(compiler, element.column_name),
- "DEFAULT %s" %
- format_server_default(compiler, element.default)
+ "DEFAULT %s" % format_server_default(compiler, element.default)
if element.default is not None
- else "DEFAULT NULL"
+ else "DEFAULT NULL",
)
def alter_column(compiler, name):
- return 'MODIFY %s' % format_column_name(compiler, name)
+ return "MODIFY %s" % format_column_name(compiler, name)
def add_column(compiler, column, **kw):
diff --git a/alembic/ddl/postgresql.py b/alembic/ddl/postgresql.py
index d399833..f133a05 100644
--- a/alembic/ddl/postgresql.py
+++ b/alembic/ddl/postgresql.py
@@ -2,8 +2,15 @@ import re
from ..util import compat
from .. import util
-from .base import compiles, alter_column, alter_table, format_table_name, \
- format_type, AlterColumn, RenameTable
+from .base import (
+ compiles,
+ alter_column,
+ alter_table,
+ format_table_name,
+ format_type,
+ AlterColumn,
+ RenameTable,
+)
from .impl import DefaultImpl
from sqlalchemy.dialects.postgresql import INTEGER, BIGINT
from ..autogenerate import render
@@ -30,7 +37,7 @@ log = logging.getLogger(__name__)
class PostgresqlImpl(DefaultImpl):
- __dialect__ = 'postgresql'
+ __dialect__ = "postgresql"
transactional_ddl = True
def prep_table_for_batch(self, table):
@@ -38,13 +45,18 @@ class PostgresqlImpl(DefaultImpl):
if constraint.name is not None:
self.drop_constraint(constraint)
- def compare_server_default(self, inspector_column,
- metadata_column,
- rendered_metadata_default,
- rendered_inspector_default):
+ def compare_server_default(
+ self,
+ inspector_column,
+ metadata_column,
+ rendered_metadata_default,
+ rendered_inspector_default,
+ ):
# don't do defaults for SERIAL columns
- if metadata_column.primary_key and \
- metadata_column is metadata_column.table._autoincrement_column:
+ if (
+ metadata_column.primary_key
+ and metadata_column is metadata_column.table._autoincrement_column
+ ):
return False
conn_col_default = rendered_inspector_default
@@ -56,53 +68,65 @@ class PostgresqlImpl(DefaultImpl):
if None in (conn_col_default, rendered_metadata_default):
return not defaults_equal
- if metadata_column.server_default is not None and \
- isinstance(metadata_column.server_default.arg,
- compat.string_types) and \
- not re.match(r"^'.+'$", rendered_metadata_default) and \
- not isinstance(inspector_column.type, Numeric):
- # don't single quote if the column type is float/numeric,
- # otherwise a comparison such as SELECT 5 = '5.0' will fail
+ if (
+ metadata_column.server_default is not None
+ and isinstance(
+ metadata_column.server_default.arg, compat.string_types
+ )
+ and not re.match(r"^'.+'$", rendered_metadata_default)
+ and not isinstance(inspector_column.type, Numeric)
+ ):
+ # don't single quote if the column type is float/numeric,
+ # otherwise a comparison such as SELECT 5 = '5.0' will fail
rendered_metadata_default = re.sub(
- r"^u?'?|'?$", "'", rendered_metadata_default)
+ r"^u?'?|'?$", "'", rendered_metadata_default
+ )
return not self.connection.scalar(
- "SELECT %s = %s" % (
- conn_col_default,
- rendered_metadata_default
- )
+ "SELECT %s = %s" % (conn_col_default, rendered_metadata_default)
)
- def alter_column(self, table_name, column_name,
- nullable=None,
- server_default=False,
- name=None,
- type_=None,
- schema=None,
- autoincrement=None,
- existing_type=None,
- existing_server_default=None,
- existing_nullable=None,
- existing_autoincrement=None,
- **kw
- ):
-
- using = kw.pop('postgresql_using', None)
+ def alter_column(
+ self,
+ table_name,
+ column_name,
+ nullable=None,
+ server_default=False,
+ name=None,
+ type_=None,
+ schema=None,
+ autoincrement=None,
+ existing_type=None,
+ existing_server_default=None,
+ existing_nullable=None,
+ existing_autoincrement=None,
+ **kw
+ ):
+
+ using = kw.pop("postgresql_using", None)
if using is not None and type_ is None:
raise util.CommandError(
- "postgresql_using must be used with the type_ parameter")
+ "postgresql_using must be used with the type_ parameter"
+ )
if type_ is not None:
- self._exec(PostgresqlColumnType(
- table_name, column_name, type_, schema=schema,
- using=using, existing_type=existing_type,
- existing_server_default=existing_server_default,
- existing_nullable=existing_nullable,
- ))
+ self._exec(
+ PostgresqlColumnType(
+ table_name,
+ column_name,
+ type_,
+ schema=schema,
+ using=using,
+ existing_type=existing_type,
+ existing_server_default=existing_server_default,
+ existing_nullable=existing_nullable,
+ )
+ )
super(PostgresqlImpl, self).alter_column(
- table_name, column_name,
+ table_name,
+ column_name,
nullable=nullable,
server_default=server_default,
name=name,
@@ -112,57 +136,70 @@ class PostgresqlImpl(DefaultImpl):
existing_server_default=existing_server_default,
existing_nullable=existing_nullable,
existing_autoincrement=existing_autoincrement,
- **kw)
+ **kw
+ )
def autogen_column_reflect(self, inspector, table, column_info):
- if column_info.get('default') and \
- isinstance(column_info['type'], (INTEGER, BIGINT)):
+ if column_info.get("default") and isinstance(
+ column_info["type"], (INTEGER, BIGINT)
+ ):
seq_match = re.match(
- r"nextval\('(.+?)'::regclass\)",
- column_info['default'])
+ r"nextval\('(.+?)'::regclass\)", column_info["default"]
+ )
if seq_match:
- info = inspector.bind.execute(text(
- "select c.relname, a.attname "
- "from pg_class as c join pg_depend d on d.objid=c.oid and "
- "d.classid='pg_class'::regclass and "
- "d.refclassid='pg_class'::regclass "
- "join pg_class t on t.oid=d.refobjid "
- "join pg_attribute a on a.attrelid=t.oid and "
- "a.attnum=d.refobjsubid "
- "where c.relkind='S' and c.relname=:seqname"
- ), seqname=seq_match.group(1)).first()
+ info = inspector.bind.execute(
+ text(
+ "select c.relname, a.attname "
+ "from pg_class as c join pg_depend d on d.objid=c.oid and "
+ "d.classid='pg_class'::regclass and "
+ "d.refclassid='pg_class'::regclass "
+ "join pg_class t on t.oid=d.refobjid "
+ "join pg_attribute a on a.attrelid=t.oid and "
+ "a.attnum=d.refobjsubid "
+ "where c.relkind='S' and c.relname=:seqname"
+ ),
+ seqname=seq_match.group(1),
+ ).first()
if info:
seqname, colname = info
- if colname == column_info['name']:
+ if colname == column_info["name"]:
log.info(
"Detected sequence named '%s' as "
"owned by integer column '%s(%s)', "
"assuming SERIAL and omitting",
- seqname, table.name, colname)
+ seqname,
+ table.name,
+ colname,
+ )
# sequence, and the owner is this column,
# its a SERIAL - whack it!
- del column_info['default']
+ del column_info["default"]
- def correct_for_autogen_constraints(self, conn_unique_constraints,
- conn_indexes,
- metadata_unique_constraints,
- metadata_indexes):
+ def correct_for_autogen_constraints(
+ self,
+ conn_unique_constraints,
+ conn_indexes,
+ metadata_unique_constraints,
+ metadata_indexes,
+ ):
conn_uniques_by_name = dict(
- (c.name, c) for c in conn_unique_constraints)
- conn_indexes_by_name = dict(
- (c.name, c) for c in conn_indexes)
+ (c.name, c) for c in conn_unique_constraints
+ )
+ conn_indexes_by_name = dict((c.name, c) for c in conn_indexes)
if not util.sqla_100:
doubled_constraints = set(
conn_indexes_by_name[name]
for name in set(conn_uniques_by_name).intersection(
- conn_indexes_by_name)
+ conn_indexes_by_name
+ )
)
else:
doubled_constraints = set(
- index for index in
- conn_indexes if index.info.get('duplicates_constraint')
+ index
+ for index in conn_indexes
+ if index.info.get("duplicates_constraint")
)
for ix in doubled_constraints:
@@ -187,37 +224,36 @@ class PostgresqlImpl(DefaultImpl):
if not mod.startswith("sqlalchemy.dialects.postgresql"):
return False
- if hasattr(self, '_render_%s_type' % type_.__visit_name__):
- meth = getattr(self, '_render_%s_type' % type_.__visit_name__)
+ if hasattr(self, "_render_%s_type" % type_.__visit_name__):
+ meth = getattr(self, "_render_%s_type" % type_.__visit_name__)
return meth(type_, autogen_context)
return False
def _render_HSTORE_type(self, type_, autogen_context):
return render._render_type_w_subtype(
- type_, autogen_context, 'text_type', r'(.+?\(.*text_type=)'
+ type_, autogen_context, "text_type", r"(.+?\(.*text_type=)"
)
def _render_ARRAY_type(self, type_, autogen_context):
return render._render_type_w_subtype(
- type_, autogen_context, 'item_type', r'(.+?\()'
+ type_, autogen_context, "item_type", r"(.+?\()"
)
def _render_JSON_type(self, type_, autogen_context):
return render._render_type_w_subtype(
- type_, autogen_context, 'astext_type', r'(.+?\(.*astext_type=)'
+ type_, autogen_context, "astext_type", r"(.+?\(.*astext_type=)"
)
def _render_JSONB_type(self, type_, autogen_context):
return render._render_type_w_subtype(
- type_, autogen_context, 'astext_type', r'(.+?\(.*astext_type=)'
+ type_, autogen_context, "astext_type", r"(.+?\(.*astext_type=)"
)
class PostgresqlColumnType(AlterColumn):
-
def __init__(self, name, column_name, type_, **kw):
- using = kw.pop('using', None)
+ using = kw.pop("using", None)
super(PostgresqlColumnType, self).__init__(name, column_name, **kw)
self.type_ = sqltypes.to_instance(type_)
self.using = using
@@ -227,7 +263,7 @@ class PostgresqlColumnType(AlterColumn):
def visit_rename_table(element, compiler, **kw):
return "%s RENAME TO %s" % (
alter_table(compiler, element.table_name, element.schema),
- format_table_name(compiler, element.new_table_name, None)
+ format_table_name(compiler, element.new_table_name, None),
)
@@ -237,13 +273,14 @@ def visit_column_type(element, compiler, **kw):
alter_table(compiler, element.table_name, element.schema),
alter_column(compiler, element.column_name),
"TYPE %s" % format_type(compiler, element.type_),
- "USING %s" % element.using if element.using else ""
+ "USING %s" % element.using if element.using else "",
)
@Operations.register_operation("create_exclude_constraint")
@BatchOperations.register_operation(
- "create_exclude_constraint", "batch_create_exclude_constraint")
+ "create_exclude_constraint", "batch_create_exclude_constraint"
+)
@ops.AddConstraintOp.register_add_constraint("exclude_constraint")
class CreateExcludeConstraintOp(ops.AddConstraintOp):
"""Represent a create exclude constraint operation."""
@@ -251,9 +288,15 @@ class CreateExcludeConstraintOp(ops.AddConstraintOp):
constraint_type = "exclude"
def __init__(
- self, constraint_name, table_name,
- elements, where=None, schema=None,
- _orig_constraint=None, **kw):
+ self,
+ constraint_name,
+ table_name,
+ elements,
+ where=None,
+ schema=None,
+ _orig_constraint=None,
+ **kw
+ ):
self.constraint_name = constraint_name
self.table_name = table_name
self.elements = elements
@@ -275,13 +318,14 @@ class CreateExcludeConstraintOp(ops.AddConstraintOp):
_orig_constraint=constraint,
deferrable=constraint.deferrable,
initially=constraint.initially,
- using=constraint.using
+ using=constraint.using,
)
def to_constraint(self, migration_context=None):
if not util.sqla_100:
raise NotImplementedError(
- "ExcludeConstraint not supported until SQLAlchemy 1.0")
+ "ExcludeConstraint not supported until SQLAlchemy 1.0"
+ )
if self._orig_constraint is not None:
return self._orig_constraint
schema_obj = schemaobj.SchemaObjects(migration_context)
@@ -299,8 +343,8 @@ class CreateExcludeConstraintOp(ops.AddConstraintOp):
@classmethod
def create_exclude_constraint(
- cls, operations,
- constraint_name, table_name, *elements, **kw):
+ cls, operations, constraint_name, table_name, *elements, **kw
+ ):
"""Issue an alter to create an EXCLUDE constraint using the
current migration context.
@@ -344,7 +388,8 @@ class CreateExcludeConstraintOp(ops.AddConstraintOp):
@classmethod
def batch_create_exclude_constraint(
- cls, operations, constraint_name, *elements, **kw):
+ cls, operations, constraint_name, *elements, **kw
+ ):
"""Issue a "create exclude constraint" instruction using the
current batch migration context.
@@ -358,24 +403,23 @@ class CreateExcludeConstraintOp(ops.AddConstraintOp):
:meth:`.Operations.create_exclude_constraint`
"""
- kw['schema'] = operations.impl.schema
+ kw["schema"] = operations.impl.schema
op = cls(constraint_name, operations.impl.table_name, elements, **kw)
return operations.invoke(op)
@render.renderers.dispatch_for(CreateExcludeConstraintOp)
def _add_exclude_constraint(autogen_context, op):
- return _exclude_constraint(
- op.to_constraint(),
- autogen_context,
- alter=True
- )
+ return _exclude_constraint(op.to_constraint(), autogen_context, alter=True)
+
if util.sqla_100:
+
@render._constraint_renderers.dispatch_for(ExcludeConstraint)
def _render_inline_exclude_constraint(constraint, autogen_context):
rendered = render._user_defined_render(
- "exclude", constraint, autogen_context)
+ "exclude", constraint, autogen_context
+ )
if rendered is not False:
return rendered
@@ -405,48 +449,54 @@ def _exclude_constraint(constraint, autogen_context, alter):
opts.append(("schema", render._ident(constraint.table.schema)))
if not alter and constraint.name:
opts.append(
- ("name",
- render._render_gen_name(autogen_context, constraint.name)))
+ ("name", render._render_gen_name(autogen_context, constraint.name))
+ )
if alter:
args = [
- repr(render._render_gen_name(
- autogen_context, constraint.name))]
+ repr(render._render_gen_name(autogen_context, constraint.name))
+ ]
if not has_batch:
args += [repr(render._ident(constraint.table.name))]
- args.extend([
- "(%s, %r)" % (
- _render_potential_column(sqltext, autogen_context),
- opstring
- )
- for sqltext, name, opstring in constraint._render_exprs
- ])
+ args.extend(
+ [
+ "(%s, %r)"
+ % (
+ _render_potential_column(sqltext, autogen_context),
+ opstring,
+ )
+ for sqltext, name, opstring in constraint._render_exprs
+ ]
+ )
if constraint.where is not None:
args.append(
- "where=%s" % render._render_potential_expr(
- constraint.where, autogen_context)
+ "where=%s"
+ % render._render_potential_expr(
+ constraint.where, autogen_context
+ )
)
args.extend(["%s=%r" % (k, v) for k, v in opts])
return "%(prefix)screate_exclude_constraint(%(args)s)" % {
- 'prefix': render._alembic_autogenerate_prefix(autogen_context),
- 'args': ", ".join(args)
+ "prefix": render._alembic_autogenerate_prefix(autogen_context),
+ "args": ", ".join(args),
}
else:
args = [
- "(%s, %r)" % (
- _render_potential_column(sqltext, autogen_context),
- opstring
- ) for sqltext, name, opstring in constraint._render_exprs
+ "(%s, %r)"
+ % (_render_potential_column(sqltext, autogen_context), opstring)
+ for sqltext, name, opstring in constraint._render_exprs
]
if constraint.where is not None:
args.append(
- "where=%s" % render._render_potential_expr(
- constraint.where, autogen_context)
+ "where=%s"
+ % render._render_potential_expr(
+ constraint.where, autogen_context
+ )
)
args.extend(["%s=%r" % (k, v) for k, v in opts])
return "%(prefix)sExcludeConstraint(%(args)s)" % {
"prefix": _postgresql_autogenerate_prefix(autogen_context),
- "args": ", ".join(args)
+ "args": ", ".join(args),
}
@@ -456,8 +506,10 @@ def _render_potential_column(value, autogen_context):
return template % {
"prefix": render._sqlalchemy_autogenerate_prefix(autogen_context),
- "name": value.name
+ "name": value.name,
}
else:
- return render._render_potential_expr(value, autogen_context, wrap_in_text=False)
+ return render._render_potential_expr(
+ value, autogen_context, wrap_in_text=False
+ )
diff --git a/alembic/ddl/sqlite.py b/alembic/ddl/sqlite.py
index 5d231b5..f7699e6 100644
--- a/alembic/ddl/sqlite.py
+++ b/alembic/ddl/sqlite.py
@@ -4,7 +4,7 @@ import re
class SQLiteImpl(DefaultImpl):
- __dialect__ = 'sqlite'
+ __dialect__ = "sqlite"
transactional_ddl = False
"""SQLite supports transactional DDL, but pysqlite does not:
@@ -21,7 +21,7 @@ class SQLiteImpl(DefaultImpl):
"""
for op in batch_op.batch:
- if op[0] not in ('add_column', 'create_index', 'drop_index'):
+ if op[0] not in ("add_column", "create_index", "drop_index"):
return True
else:
return False
@@ -31,34 +31,46 @@ class SQLiteImpl(DefaultImpl):
# auto-gen constraint and an explicit one
if const._create_rule is None:
raise NotImplementedError(
- "No support for ALTER of constraints in SQLite dialect")
+ "No support for ALTER of constraints in SQLite dialect"
+ )
elif const._create_rule(self):
- util.warn("Skipping unsupported ALTER for "
- "creation of implicit constraint")
+ util.warn(
+ "Skipping unsupported ALTER for "
+ "creation of implicit constraint"
+ )
def drop_constraint(self, const):
if const._create_rule is None:
raise NotImplementedError(
- "No support for ALTER of constraints in SQLite dialect")
+ "No support for ALTER of constraints in SQLite dialect"
+ )
- def compare_server_default(self, inspector_column,
- metadata_column,
- rendered_metadata_default,
- rendered_inspector_default):
+ def compare_server_default(
+ self,
+ inspector_column,
+ metadata_column,
+ rendered_metadata_default,
+ rendered_inspector_default,
+ ):
if rendered_metadata_default is not None:
rendered_metadata_default = re.sub(
- r"^\"'|\"'$", "", rendered_metadata_default)
+ r"^\"'|\"'$", "", rendered_metadata_default
+ )
if rendered_inspector_default is not None:
rendered_inspector_default = re.sub(
- r"^\"'|\"'$", "", rendered_inspector_default)
+ r"^\"'|\"'$", "", rendered_inspector_default
+ )
return rendered_inspector_default != rendered_metadata_default
def correct_for_autogen_constraints(
- self, conn_unique_constraints, conn_indexes,
+ self,
+ conn_unique_constraints,
+ conn_indexes,
metadata_unique_constraints,
- metadata_indexes):
+ metadata_indexes,
+ ):
if util.sqla_100:
return
@@ -70,10 +82,7 @@ class SQLiteImpl(DefaultImpl):
def uq_sig(uq):
return tuple(sorted(uq.columns.keys()))
- conn_unique_sigs = set(
- uq_sig(uq)
- for uq in conn_unique_constraints
- )
+ conn_unique_sigs = set(uq_sig(uq) for uq in conn_unique_constraints)
for idx in list(metadata_unique_constraints):
# SQLite backend can't report on unnamed UNIQUE constraints,
diff --git a/alembic/op.py b/alembic/op.py
index 1f367a1..f3f5fac 100644
--- a/alembic/op.py
+++ b/alembic/op.py
@@ -3,4 +3,3 @@ from .operations.base import Operations
# create proxy functions for
# each method on the Operations class.
Operations.create_module_class_proxy(globals(), locals())
-
diff --git a/alembic/operations/__init__.py b/alembic/operations/__init__.py
index 1f6ee5d..e1ff01c 100644
--- a/alembic/operations/__init__.py
+++ b/alembic/operations/__init__.py
@@ -3,4 +3,4 @@ from .ops import MigrateOperation
from . import toimpl
-__all__ = ['Operations', 'BatchOperations', 'MigrateOperation'] \ No newline at end of file
+__all__ = ["Operations", "BatchOperations", "MigrateOperation"]
diff --git a/alembic/operations/base.py b/alembic/operations/base.py
index 1ae9524..2c3408a 100644
--- a/alembic/operations/base.py
+++ b/alembic/operations/base.py
@@ -9,7 +9,7 @@ from ..util.compat import inspect_formatargspec
from ..util.compat import inspect_getargspec
import textwrap
-__all__ = ('Operations', 'BatchOperations')
+__all__ = ("Operations", "BatchOperations")
try:
from sqlalchemy.sql.naming import conv
@@ -84,6 +84,7 @@ class Operations(util.ModuleClsProxy):
"""
+
def register(op_cls):
if sourcename is None:
fn = getattr(op_cls, name)
@@ -95,45 +96,53 @@ class Operations(util.ModuleClsProxy):
spec = inspect_getargspec(fn)
name_args = spec[0]
- assert name_args[0:2] == ['cls', 'operations']
+ assert name_args[0:2] == ["cls", "operations"]
- name_args[0:2] = ['self']
+ name_args[0:2] = ["self"]
args = inspect_formatargspec(*spec)
num_defaults = len(spec[3]) if spec[3] else 0
if num_defaults:
- defaulted_vals = name_args[0 - num_defaults:]
+ defaulted_vals = name_args[0 - num_defaults :]
else:
defaulted_vals = ()
apply_kw = inspect_formatargspec(
- name_args, spec[1], spec[2],
+ name_args,
+ spec[1],
+ spec[2],
defaulted_vals,
- formatvalue=lambda x: '=' + x)
+ formatvalue=lambda x: "=" + x,
+ )
- func_text = textwrap.dedent("""\
+ func_text = textwrap.dedent(
+ """\
def %(name)s%(args)s:
%(doc)r
return op_cls.%(source_name)s%(apply_kw)s
- """ % {
- 'name': name,
- 'source_name': source_name,
- 'args': args,
- 'apply_kw': apply_kw,
- 'doc': fn.__doc__,
- 'meth': fn.__name__
- })
- globals_ = {'op_cls': op_cls}
+ """
+ % {
+ "name": name,
+ "source_name": source_name,
+ "args": args,
+ "apply_kw": apply_kw,
+ "doc": fn.__doc__,
+ "meth": fn.__name__,
+ }
+ )
+ globals_ = {"op_cls": op_cls}
lcl = {}
exec_(func_text, globals_, lcl)
setattr(cls, name, lcl[name])
- fn.__func__.__doc__ = "This method is proxied on "\
- "the :class:`.%s` class, via the :meth:`.%s.%s` method." % (
- cls.__name__, cls.__name__, name
- )
- if hasattr(fn, '_legacy_translations'):
+ fn.__func__.__doc__ = (
+ "This method is proxied on "
+ "the :class:`.%s` class, via the :meth:`.%s.%s` method."
+ % (cls.__name__, cls.__name__, name)
+ )
+ if hasattr(fn, "_legacy_translations"):
lcl[name]._legacy_translations = fn._legacy_translations
return op_cls
+
return register
@classmethod
@@ -151,6 +160,7 @@ class Operations(util.ModuleClsProxy):
def decorate(fn):
cls._to_impl.dispatch_for(op_cls)(fn)
return fn
+
return decorate
@classmethod
@@ -163,10 +173,17 @@ class Operations(util.ModuleClsProxy):
@contextmanager
def batch_alter_table(
- self, table_name, schema=None, recreate="auto", copy_from=None,
- table_args=(), table_kwargs=util.immutabledict(),
- reflect_args=(), reflect_kwargs=util.immutabledict(),
- naming_convention=None):
+ self,
+ table_name,
+ schema=None,
+ recreate="auto",
+ copy_from=None,
+ table_args=(),
+ table_kwargs=util.immutabledict(),
+ reflect_args=(),
+ reflect_kwargs=util.immutabledict(),
+ naming_convention=None,
+ ):
"""Invoke a series of per-table migrations in batch.
Batch mode allows a series of operations specific to a table
@@ -292,9 +309,17 @@ class Operations(util.ModuleClsProxy):
"""
impl = batch.BatchOperationsImpl(
- self, table_name, schema, recreate,
- copy_from, table_args, table_kwargs, reflect_args,
- reflect_kwargs, naming_convention)
+ self,
+ table_name,
+ schema,
+ recreate,
+ copy_from,
+ table_args,
+ table_kwargs,
+ reflect_args,
+ reflect_kwargs,
+ naming_convention,
+ )
batch_op = BatchOperations(self.migration_context, impl=impl)
yield batch_op
impl.flush()
@@ -315,7 +340,8 @@ class Operations(util.ModuleClsProxy):
"""
fn = self._to_impl.dispatch(
- operation, self.migration_context.impl.__dialect__)
+ operation, self.migration_context.impl.__dialect__
+ )
return fn(self, operation)
def f(self, name):
@@ -363,7 +389,8 @@ class Operations(util.ModuleClsProxy):
return conv(name)
else:
raise NotImplementedError(
- "op.f() feature requires SQLAlchemy 0.9.4 or greater.")
+ "op.f() feature requires SQLAlchemy 0.9.4 or greater."
+ )
def inline_literal(self, value, type_=None):
"""Produce an 'inline literal' expression, suitable for
@@ -442,4 +469,5 @@ class BatchOperations(Operations):
def _noop(self, operation):
raise NotImplementedError(
"The %s method does not apply to a batch table alter operation."
- % operation)
+ % operation
+ )
diff --git a/alembic/operations/batch.py b/alembic/operations/batch.py
index 79ad533..9362876 100644
--- a/alembic/operations/batch.py
+++ b/alembic/operations/batch.py
@@ -1,24 +1,48 @@
-from sqlalchemy import Table, MetaData, Index, select, Column, \
- ForeignKeyConstraint, PrimaryKeyConstraint, cast, CheckConstraint
+from sqlalchemy import (
+ Table,
+ MetaData,
+ Index,
+ select,
+ Column,
+ ForeignKeyConstraint,
+ PrimaryKeyConstraint,
+ cast,
+ CheckConstraint,
+)
from sqlalchemy import types as sqltypes
from sqlalchemy import schema as sql_schema
from sqlalchemy.util import OrderedDict
from .. import util
from sqlalchemy.events import SchemaEventTarget
-from ..util.sqla_compat import _columns_for_constraint, \
- _is_type_bound, _fk_is_self_referential, _remove_column_from_collection
+from ..util.sqla_compat import (
+ _columns_for_constraint,
+ _is_type_bound,
+ _fk_is_self_referential,
+ _remove_column_from_collection,
+)
class BatchOperationsImpl(object):
- def __init__(self, operations, table_name, schema, recreate,
- copy_from, table_args, table_kwargs,
- reflect_args, reflect_kwargs, naming_convention):
+ def __init__(
+ self,
+ operations,
+ table_name,
+ schema,
+ recreate,
+ copy_from,
+ table_args,
+ table_kwargs,
+ reflect_args,
+ reflect_kwargs,
+ naming_convention,
+ ):
self.operations = operations
self.table_name = table_name
self.schema = schema
- if recreate not in ('auto', 'always', 'never'):
+ if recreate not in ("auto", "always", "never"):
raise ValueError(
- "recreate may be one of 'auto', 'always', or 'never'.")
+ "recreate may be one of 'auto', 'always', or 'never'."
+ )
self.recreate = recreate
self.copy_from = copy_from
self.table_args = table_args
@@ -37,9 +61,9 @@ class BatchOperationsImpl(object):
return self.operations.impl
def _should_recreate(self):
- if self.recreate == 'auto':
+ if self.recreate == "auto":
return self.operations.impl.requires_recreate_in_batch(self)
- elif self.recreate == 'always':
+ elif self.recreate == "always":
return True
else:
return False
@@ -62,15 +86,19 @@ class BatchOperationsImpl(object):
reflected = False
else:
existing_table = Table(
- self.table_name, m1,
+ self.table_name,
+ m1,
schema=self.schema,
autoload=True,
autoload_with=self.operations.get_bind(),
- *self.reflect_args, **self.reflect_kwargs)
+ *self.reflect_args,
+ **self.reflect_kwargs
+ )
reflected = True
batch_impl = ApplyBatchImpl(
- existing_table, self.table_args, self.table_kwargs, reflected)
+ existing_table, self.table_args, self.table_kwargs, reflected
+ )
for opname, arg, kw in self.batch:
fn = getattr(batch_impl, opname)
fn(*arg, **kw)
@@ -90,7 +118,7 @@ class BatchOperationsImpl(object):
self.batch.append(("add_constraint", (const,), {}))
def drop_constraint(self, const):
- self.batch.append(("drop_constraint", (const, ), {}))
+ self.batch.append(("drop_constraint", (const,), {}))
def rename_table(self, *arg, **kw):
self.batch.append(("rename_table", arg, kw))
@@ -116,7 +144,7 @@ class ApplyBatchImpl(object):
self.temp_table_name = self._calc_temp_name(table.name)
self.new_table = None
self.column_transfers = OrderedDict(
- (c.name, {'expr': c}) for c in self.table.c
+ (c.name, {"expr": c}) for c in self.table.c
)
self.reflected = reflected
self._grab_table_elements()
@@ -165,16 +193,20 @@ class ApplyBatchImpl(object):
schema = self.table.schema
self.new_table = new_table = Table(
- self.temp_table_name, m,
+ self.temp_table_name,
+ m,
*(list(self.columns.values()) + list(self.table_args)),
schema=schema,
- **self.table_kwargs)
+ **self.table_kwargs
+ )
- for const in list(self.named_constraints.values()) + \
- self.unnamed_constraints:
+ for const in (
+ list(self.named_constraints.values()) + self.unnamed_constraints
+ ):
- const_columns = set([
- c.key for c in _columns_for_constraint(const)])
+ const_columns = set(
+ [c.key for c in _columns_for_constraint(const)]
+ )
if not const_columns.issubset(self.column_transfers):
continue
@@ -188,7 +220,8 @@ class ApplyBatchImpl(object):
# no foreign keys just keeps the names unchanged, so
# when we rename back, they match again.
const_copy = const.copy(
- schema=schema, target_table=self.table)
+ schema=schema, target_table=self.table
+ )
else:
# "target_table" for ForeignKeyConstraint.copy() is
# only used if the FK is detected as being
@@ -209,7 +242,8 @@ class ApplyBatchImpl(object):
index.name,
unique=index.unique,
*[self.new_table.c[col] for col in index.columns.keys()],
- **index.kwargs)
+ **index.kwargs
+ )
)
return idx
@@ -229,16 +263,20 @@ class ApplyBatchImpl(object):
for elem in constraint.elements:
colname = elem._get_colspec().split(".")[-1]
if not t.c.contains_column(colname):
- t.append_column(
- Column(colname, sqltypes.NULLTYPE)
- )
+ t.append_column(Column(colname, sqltypes.NULLTYPE))
else:
Table(
- tname, metadata,
- *[Column(n, sqltypes.NULLTYPE) for n in
- [elem._get_colspec().split(".")[-1]
- for elem in constraint.elements]],
- schema=referent_schema)
+ tname,
+ metadata,
+ *[
+ Column(n, sqltypes.NULLTYPE)
+ for n in [
+ elem._get_colspec().split(".")[-1]
+ for elem in constraint.elements
+ ]
+ ],
+ schema=referent_schema
+ )
def _create(self, op_impl):
self._transfer_elements_to_new_table()
@@ -249,13 +287,18 @@ class ApplyBatchImpl(object):
try:
op_impl._exec(
self.new_table.insert(inline=True).from_select(
- list(k for k, transfer in
- self.column_transfers.items() if 'expr' in transfer),
- select([
- transfer['expr']
- for transfer in self.column_transfers.values()
- if 'expr' in transfer
- ])
+ list(
+ k
+ for k, transfer in self.column_transfers.items()
+ if "expr" in transfer
+ ),
+ select(
+ [
+ transfer["expr"]
+ for transfer in self.column_transfers.values()
+ if "expr" in transfer
+ ]
+ ),
)
)
op_impl.drop_table(self.table)
@@ -264,9 +307,7 @@ class ApplyBatchImpl(object):
raise
else:
op_impl.rename_table(
- self.temp_table_name,
- self.table.name,
- schema=self.table.schema
+ self.temp_table_name, self.table.name, schema=self.table.schema
)
self.new_table.name = self.table.name
try:
@@ -275,14 +316,17 @@ class ApplyBatchImpl(object):
finally:
self.new_table.name = self.temp_table_name
- def alter_column(self, table_name, column_name,
- nullable=None,
- server_default=False,
- name=None,
- type_=None,
- autoincrement=None,
- **kw
- ):
+ def alter_column(
+ self,
+ table_name,
+ column_name,
+ nullable=None,
+ server_default=False,
+ name=None,
+ type_=None,
+ autoincrement=None,
+ **kw
+ ):
existing = self.columns[column_name]
existing_transfer = self.column_transfers[column_name]
if name is not None and name != column_name:
@@ -299,12 +343,14 @@ class ApplyBatchImpl(object):
# we also ignore the drop_constraint that will come here from
# Operations.implementation_for(alter_column)
if isinstance(existing.type, SchemaEventTarget):
- existing.type._create_events = \
- existing.type.create_constraint = False
+ existing.type._create_events = (
+ existing.type.create_constraint
+ ) = False
if existing.type._type_affinity is not type_._type_affinity:
existing_transfer["expr"] = cast(
- existing_transfer["expr"], type_)
+ existing_transfer["expr"], type_
+ )
existing.type = type_
@@ -332,8 +378,7 @@ class ApplyBatchImpl(object):
def drop_column(self, table_name, column, **kw):
if column.name in self.table.primary_key.columns:
_remove_column_from_collection(
- self.table.primary_key.columns,
- column
+ self.table.primary_key.columns, column
)
del self.columns[column.name]
del self.column_transfers[column.name]
diff --git a/alembic/operations/ops.py b/alembic/operations/ops.py
index ade1cb3..5824469 100644
--- a/alembic/operations/ops.py
+++ b/alembic/operations/ops.py
@@ -46,12 +46,14 @@ class AddConstraintOp(MigrateOperation):
def go(klass):
cls.add_constraint_ops.dispatch_for(type_)(klass.from_constraint)
return klass
+
return go
@classmethod
def from_constraint(cls, constraint):
- return cls.add_constraint_ops.dispatch(
- constraint.__visit_name__)(constraint)
+ return cls.add_constraint_ops.dispatch(constraint.__visit_name__)(
+ constraint
+ )
def reverse(self):
return DropConstraintOp.from_constraint(self.to_constraint())
@@ -66,9 +68,13 @@ class DropConstraintOp(MigrateOperation):
"""Represent a drop constraint operation."""
def __init__(
- self,
- constraint_name, table_name, type_=None, schema=None,
- _orig_constraint=None):
+ self,
+ constraint_name,
+ table_name,
+ type_=None,
+ schema=None,
+ _orig_constraint=None,
+ ):
self.constraint_name = constraint_name
self.table_name = table_name
self.constraint_type = type_
@@ -79,7 +85,8 @@ class DropConstraintOp(MigrateOperation):
if self._orig_constraint is None:
raise ValueError(
"operation is not reversible; "
- "original constraint is not present")
+ "original constraint is not present"
+ )
return AddConstraintOp.from_constraint(self._orig_constraint)
def to_diff_tuple(self):
@@ -104,7 +111,7 @@ class DropConstraintOp(MigrateOperation):
constraint_table.name,
schema=constraint_table.schema,
type_=types[constraint.__visit_name__],
- _orig_constraint=constraint
+ _orig_constraint=constraint,
)
def to_constraint(self):
@@ -113,16 +120,14 @@ class DropConstraintOp(MigrateOperation):
else:
raise ValueError(
"constraint cannot be produced; "
- "original constraint is not present")
+ "original constraint is not present"
+ )
@classmethod
- @util._with_legacy_names([
- ("type", "type_"),
- ("name", "constraint_name"),
- ])
+ @util._with_legacy_names([("type", "type_"), ("name", "constraint_name")])
def drop_constraint(
- cls, operations, constraint_name, table_name,
- type_=None, schema=None):
+ cls, operations, constraint_name, table_name, type_=None, schema=None
+ ):
"""Drop a constraint of the given name, typically via DROP CONSTRAINT.
:param constraint_name: name of the constraint.
@@ -166,15 +171,18 @@ class DropConstraintOp(MigrateOperation):
"""
op = cls(
- constraint_name, operations.impl.table_name,
- type_=type_, schema=operations.impl.schema
+ constraint_name,
+ operations.impl.table_name,
+ type_=type_,
+ schema=operations.impl.schema,
)
return operations.invoke(op)
@Operations.register_operation("create_primary_key")
@BatchOperations.register_operation(
- "create_primary_key", "batch_create_primary_key")
+ "create_primary_key", "batch_create_primary_key"
+)
@AddConstraintOp.register_add_constraint("primary_key_constraint")
class CreatePrimaryKeyOp(AddConstraintOp):
"""Represent a create primary key operation."""
@@ -182,8 +190,14 @@ class CreatePrimaryKeyOp(AddConstraintOp):
constraint_type = "primarykey"
def __init__(
- self, constraint_name, table_name, columns,
- schema=None, _orig_constraint=None, **kw):
+ self,
+ constraint_name,
+ table_name,
+ columns,
+ schema=None,
+ _orig_constraint=None,
+ **kw
+ ):
self.constraint_name = constraint_name
self.table_name = table_name
self.columns = columns
@@ -200,7 +214,7 @@ class CreatePrimaryKeyOp(AddConstraintOp):
constraint_table.name,
constraint.columns,
schema=constraint_table.schema,
- _orig_constraint=constraint
+ _orig_constraint=constraint,
)
def to_constraint(self, migration_context=None):
@@ -209,17 +223,19 @@ class CreatePrimaryKeyOp(AddConstraintOp):
schema_obj = schemaobj.SchemaObjects(migration_context)
return schema_obj.primary_key_constraint(
- self.constraint_name, self.table_name,
- self.columns, schema=self.schema)
+ self.constraint_name,
+ self.table_name,
+ self.columns,
+ schema=self.schema,
+ )
@classmethod
- @util._with_legacy_names([
- ('name', 'constraint_name'),
- ('cols', 'columns')
- ])
+ @util._with_legacy_names(
+ [("name", "constraint_name"), ("cols", "columns")]
+ )
def create_primary_key(
- cls, operations,
- constraint_name, table_name, columns, schema=None):
+ cls, operations, constraint_name, table_name, columns, schema=None
+ ):
"""Issue a "create primary key" instruction using the current
migration context.
@@ -282,15 +298,18 @@ class CreatePrimaryKeyOp(AddConstraintOp):
"""
op = cls(
- constraint_name, operations.impl.table_name, columns,
- schema=operations.impl.schema
+ constraint_name,
+ operations.impl.table_name,
+ columns,
+ schema=operations.impl.schema,
)
return operations.invoke(op)
@Operations.register_operation("create_unique_constraint")
@BatchOperations.register_operation(
- "create_unique_constraint", "batch_create_unique_constraint")
+ "create_unique_constraint", "batch_create_unique_constraint"
+)
@AddConstraintOp.register_add_constraint("unique_constraint")
class CreateUniqueConstraintOp(AddConstraintOp):
"""Represent a create unique constraint operation."""
@@ -298,8 +317,14 @@ class CreateUniqueConstraintOp(AddConstraintOp):
constraint_type = "unique"
def __init__(
- self, constraint_name, table_name,
- columns, schema=None, _orig_constraint=None, **kw):
+ self,
+ constraint_name,
+ table_name,
+ columns,
+ schema=None,
+ _orig_constraint=None,
+ **kw
+ ):
self.constraint_name = constraint_name
self.table_name = table_name
self.columns = columns
@@ -313,9 +338,9 @@ class CreateUniqueConstraintOp(AddConstraintOp):
kw = {}
if constraint.deferrable:
- kw['deferrable'] = constraint.deferrable
+ kw["deferrable"] = constraint.deferrable
if constraint.initially:
- kw['initially'] = constraint.initially
+ kw["initially"] = constraint.initially
return cls(
constraint.name,
@@ -332,18 +357,30 @@ class CreateUniqueConstraintOp(AddConstraintOp):
schema_obj = schemaobj.SchemaObjects(migration_context)
return schema_obj.unique_constraint(
- self.constraint_name, self.table_name, self.columns,
- schema=self.schema, **self.kw)
+ self.constraint_name,
+ self.table_name,
+ self.columns,
+ schema=self.schema,
+ **self.kw
+ )
@classmethod
- @util._with_legacy_names([
- ('name', 'constraint_name'),
- ('source', 'table_name'),
- ('local_cols', 'columns'),
- ])
+ @util._with_legacy_names(
+ [
+ ("name", "constraint_name"),
+ ("source", "table_name"),
+ ("local_cols", "columns"),
+ ]
+ )
def create_unique_constraint(
- cls, operations, constraint_name, table_name, columns,
- schema=None, **kw):
+ cls,
+ operations,
+ constraint_name,
+ table_name,
+ columns,
+ schema=None,
+ **kw
+ ):
"""Issue a "create unique constraint" instruction using the
current migration context.
@@ -392,16 +429,14 @@ class CreateUniqueConstraintOp(AddConstraintOp):
"""
- op = cls(
- constraint_name, table_name, columns,
- schema=schema, **kw
- )
+ op = cls(constraint_name, table_name, columns, schema=schema, **kw)
return operations.invoke(op)
@classmethod
- @util._with_legacy_names([('name', 'constraint_name')])
+ @util._with_legacy_names([("name", "constraint_name")])
def batch_create_unique_constraint(
- cls, operations, constraint_name, columns, **kw):
+ cls, operations, constraint_name, columns, **kw
+ ):
"""Issue a "create unique constraint" instruction using the
current batch migration context.
@@ -418,17 +453,15 @@ class CreateUniqueConstraintOp(AddConstraintOp):
* name -> constraint_name
"""
- kw['schema'] = operations.impl.schema
- op = cls(
- constraint_name, operations.impl.table_name, columns,
- **kw
- )
+ kw["schema"] = operations.impl.schema
+ op = cls(constraint_name, operations.impl.table_name, columns, **kw)
return operations.invoke(op)
@Operations.register_operation("create_foreign_key")
@BatchOperations.register_operation(
- "create_foreign_key", "batch_create_foreign_key")
+ "create_foreign_key", "batch_create_foreign_key"
+)
@AddConstraintOp.register_add_constraint("foreign_key_constraint")
class CreateForeignKeyOp(AddConstraintOp):
"""Represent a create foreign key constraint operation."""
@@ -436,8 +469,15 @@ class CreateForeignKeyOp(AddConstraintOp):
constraint_type = "foreignkey"
def __init__(
- self, constraint_name, source_table, referent_table, local_cols,
- remote_cols, _orig_constraint=None, **kw):
+ self,
+ constraint_name,
+ source_table,
+ referent_table,
+ local_cols,
+ remote_cols,
+ _orig_constraint=None,
+ **kw
+ ):
self.constraint_name = constraint_name
self.source_table = source_table
self.referent_table = referent_table
@@ -453,24 +493,22 @@ class CreateForeignKeyOp(AddConstraintOp):
def from_constraint(cls, constraint):
kw = {}
if constraint.onupdate:
- kw['onupdate'] = constraint.onupdate
+ kw["onupdate"] = constraint.onupdate
if constraint.ondelete:
- kw['ondelete'] = constraint.ondelete
+ kw["ondelete"] = constraint.ondelete
if constraint.initially:
- kw['initially'] = constraint.initially
+ kw["initially"] = constraint.initially
if constraint.deferrable:
- kw['deferrable'] = constraint.deferrable
+ kw["deferrable"] = constraint.deferrable
if constraint.use_alter:
- kw['use_alter'] = constraint.use_alter
+ kw["use_alter"] = constraint.use_alter
- source_schema, source_table, \
- source_columns, target_schema, \
- target_table, target_columns,\
- onupdate, ondelete, deferrable, initially \
- = sqla_compat._fk_spec(constraint)
+ source_schema, source_table, source_columns, target_schema, target_table, target_columns, onupdate, ondelete, deferrable, initially = sqla_compat._fk_spec(
+ constraint
+ )
- kw['source_schema'] = source_schema
- kw['referent_schema'] = target_schema
+ kw["source_schema"] = source_schema
+ kw["referent_schema"] = target_schema
return cls(
constraint.name,
@@ -488,22 +526,38 @@ class CreateForeignKeyOp(AddConstraintOp):
schema_obj = schemaobj.SchemaObjects(migration_context)
return schema_obj.foreign_key_constraint(
self.constraint_name,
- self.source_table, self.referent_table,
- self.local_cols, self.remote_cols,
- **self.kw)
+ self.source_table,
+ self.referent_table,
+ self.local_cols,
+ self.remote_cols,
+ **self.kw
+ )
@classmethod
- @util._with_legacy_names([
- ('name', 'constraint_name'),
- ('source', 'source_table'),
- ('referent', 'referent_table'),
- ])
- def create_foreign_key(cls, operations, constraint_name,
- source_table, referent_table, local_cols,
- remote_cols, onupdate=None, ondelete=None,
- deferrable=None, initially=None, match=None,
- source_schema=None, referent_schema=None,
- **dialect_kw):
+ @util._with_legacy_names(
+ [
+ ("name", "constraint_name"),
+ ("source", "source_table"),
+ ("referent", "referent_table"),
+ ]
+ )
+ def create_foreign_key(
+ cls,
+ operations,
+ constraint_name,
+ source_table,
+ referent_table,
+ local_cols,
+ remote_cols,
+ onupdate=None,
+ ondelete=None,
+ deferrable=None,
+ initially=None,
+ match=None,
+ source_schema=None,
+ referent_schema=None,
+ **dialect_kw
+ ):
"""Issue a "create foreign key" instruction using the
current migration context.
@@ -558,29 +612,40 @@ class CreateForeignKeyOp(AddConstraintOp):
op = cls(
constraint_name,
- source_table, referent_table,
- local_cols, remote_cols,
- onupdate=onupdate, ondelete=ondelete,
+ source_table,
+ referent_table,
+ local_cols,
+ remote_cols,
+ onupdate=onupdate,
+ ondelete=ondelete,
deferrable=deferrable,
source_schema=source_schema,
referent_schema=referent_schema,
- initially=initially, match=match,
+ initially=initially,
+ match=match,
**dialect_kw
)
return operations.invoke(op)
@classmethod
- @util._with_legacy_names([
- ('name', 'constraint_name'),
- ('referent', 'referent_table')
- ])
+ @util._with_legacy_names(
+ [("name", "constraint_name"), ("referent", "referent_table")]
+ )
def batch_create_foreign_key(
- cls, operations, constraint_name, referent_table,
- local_cols, remote_cols,
- referent_schema=None,
- onupdate=None, ondelete=None,
- deferrable=None, initially=None, match=None,
- **dialect_kw):
+ cls,
+ operations,
+ constraint_name,
+ referent_table,
+ local_cols,
+ remote_cols,
+ referent_schema=None,
+ onupdate=None,
+ ondelete=None,
+ deferrable=None,
+ initially=None,
+ match=None,
+ **dialect_kw
+ ):
"""Issue a "create foreign key" instruction using the
current batch migration context.
@@ -607,13 +672,17 @@ class CreateForeignKeyOp(AddConstraintOp):
"""
op = cls(
constraint_name,
- operations.impl.table_name, referent_table,
- local_cols, remote_cols,
- onupdate=onupdate, ondelete=ondelete,
+ operations.impl.table_name,
+ referent_table,
+ local_cols,
+ remote_cols,
+ onupdate=onupdate,
+ ondelete=ondelete,
deferrable=deferrable,
source_schema=operations.impl.schema,
referent_schema=referent_schema,
- initially=initially, match=match,
+ initially=initially,
+ match=match,
**dialect_kw
)
return operations.invoke(op)
@@ -621,7 +690,8 @@ class CreateForeignKeyOp(AddConstraintOp):
@Operations.register_operation("create_check_constraint")
@BatchOperations.register_operation(
- "create_check_constraint", "batch_create_check_constraint")
+ "create_check_constraint", "batch_create_check_constraint"
+)
@AddConstraintOp.register_add_constraint("check_constraint")
@AddConstraintOp.register_add_constraint("column_check_constraint")
class CreateCheckConstraintOp(AddConstraintOp):
@@ -630,8 +700,14 @@ class CreateCheckConstraintOp(AddConstraintOp):
constraint_type = "check"
def __init__(
- self, constraint_name, table_name,
- condition, schema=None, _orig_constraint=None, **kw):
+ self,
+ constraint_name,
+ table_name,
+ condition,
+ schema=None,
+ _orig_constraint=None,
+ **kw
+ ):
self.constraint_name = constraint_name
self.table_name = table_name
self.condition = condition
@@ -648,7 +724,7 @@ class CreateCheckConstraintOp(AddConstraintOp):
constraint_table.name,
constraint.sqltext,
schema=constraint_table.schema,
- _orig_constraint=constraint
+ _orig_constraint=constraint,
)
def to_constraint(self, migration_context=None):
@@ -656,18 +732,26 @@ class CreateCheckConstraintOp(AddConstraintOp):
return self._orig_constraint
schema_obj = schemaobj.SchemaObjects(migration_context)
return schema_obj.check_constraint(
- self.constraint_name, self.table_name,
- self.condition, schema=self.schema, **self.kw)
+ self.constraint_name,
+ self.table_name,
+ self.condition,
+ schema=self.schema,
+ **self.kw
+ )
@classmethod
- @util._with_legacy_names([
- ('name', 'constraint_name'),
- ('source', 'table_name')
- ])
+ @util._with_legacy_names(
+ [("name", "constraint_name"), ("source", "table_name")]
+ )
def create_check_constraint(
- cls, operations,
- constraint_name, table_name, condition,
- schema=None, **kw):
+ cls,
+ operations,
+ constraint_name,
+ table_name,
+ condition,
+ schema=None,
+ **kw
+ ):
"""Issue a "create check constraint" instruction using the
current migration context.
@@ -721,9 +805,10 @@ class CreateCheckConstraintOp(AddConstraintOp):
return operations.invoke(op)
@classmethod
- @util._with_legacy_names([('name', 'constraint_name')])
+ @util._with_legacy_names([("name", "constraint_name")])
def batch_create_check_constraint(
- cls, operations, constraint_name, condition, **kw):
+ cls, operations, constraint_name, condition, **kw
+ ):
"""Issue a "create check constraint" instruction using the
current batch migration context.
@@ -741,8 +826,12 @@ class CreateCheckConstraintOp(AddConstraintOp):
"""
op = cls(
- constraint_name, operations.impl.table_name,
- condition, schema=operations.impl.schema, **kw)
+ constraint_name,
+ operations.impl.table_name,
+ condition,
+ schema=operations.impl.schema,
+ **kw
+ )
return operations.invoke(op)
@@ -752,8 +841,15 @@ class CreateIndexOp(MigrateOperation):
"""Represent a create index operation."""
def __init__(
- self, index_name, table_name, columns, schema=None,
- unique=False, _orig_index=None, **kw):
+ self,
+ index_name,
+ table_name,
+ columns,
+ schema=None,
+ unique=False,
+ _orig_index=None,
+ **kw
+ ):
self.index_name = index_name
self.table_name = table_name
self.columns = columns
@@ -785,15 +881,26 @@ class CreateIndexOp(MigrateOperation):
return self._orig_index
schema_obj = schemaobj.SchemaObjects(migration_context)
return schema_obj.index(
- self.index_name, self.table_name, self.columns, schema=self.schema,
- unique=self.unique, **self.kw)
+ self.index_name,
+ self.table_name,
+ self.columns,
+ schema=self.schema,
+ unique=self.unique,
+ **self.kw
+ )
@classmethod
- @util._with_legacy_names([('name', 'index_name')])
+ @util._with_legacy_names([("name", "index_name")])
def create_index(
- cls, operations,
- index_name, table_name, columns, schema=None,
- unique=False, **kw):
+ cls,
+ operations,
+ index_name,
+ table_name,
+ columns,
+ schema=None,
+ unique=False,
+ **kw
+ ):
r"""Issue a "create index" instruction using the current
migration context.
@@ -851,8 +958,7 @@ class CreateIndexOp(MigrateOperation):
"""
op = cls(
- index_name, table_name, columns, schema=schema,
- unique=unique, **kw
+ index_name, table_name, columns, schema=schema, unique=unique, **kw
)
return operations.invoke(op)
@@ -868,8 +974,11 @@ class CreateIndexOp(MigrateOperation):
"""
op = cls(
- index_name, operations.impl.table_name, columns,
- schema=operations.impl.schema, **kw
+ index_name,
+ operations.impl.table_name,
+ columns,
+ schema=operations.impl.schema,
+ **kw
)
return operations.invoke(op)
@@ -880,8 +989,8 @@ class DropIndexOp(MigrateOperation):
"""Represent a drop index operation."""
def __init__(
- self, index_name, table_name=None,
- schema=None, _orig_index=None, **kw):
+ self, index_name, table_name=None, schema=None, _orig_index=None, **kw
+ ):
self.index_name = index_name
self.table_name = table_name
self.schema = schema
@@ -894,8 +1003,8 @@ class DropIndexOp(MigrateOperation):
def reverse(self):
if self._orig_index is None:
raise ValueError(
- "operation is not reversible; "
- "original index is not present")
+ "operation is not reversible; " "original index is not present"
+ )
return CreateIndexOp.from_index(self._orig_index)
@classmethod
@@ -917,16 +1026,20 @@ class DropIndexOp(MigrateOperation):
# need a dummy column name here since SQLAlchemy
# 0.7.6 and further raises on Index with no columns
return schema_obj.index(
- self.index_name, self.table_name, ['x'],
- schema=self.schema, **self.kw)
+ self.index_name,
+ self.table_name,
+ ["x"],
+ schema=self.schema,
+ **self.kw
+ )
@classmethod
- @util._with_legacy_names([
- ('name', 'index_name'),
- ('tablename', 'table_name')
- ])
- def drop_index(cls, operations, index_name,
- table_name=None, schema=None, **kw):
+ @util._with_legacy_names(
+ [("name", "index_name"), ("tablename", "table_name")]
+ )
+ def drop_index(
+ cls, operations, index_name, table_name=None, schema=None, **kw
+ ):
r"""Issue a "drop index" instruction using the current
migration context.
@@ -964,7 +1077,7 @@ class DropIndexOp(MigrateOperation):
return operations.invoke(op)
@classmethod
- @util._with_legacy_names([('name', 'index_name')])
+ @util._with_legacy_names([("name", "index_name")])
def batch_drop_index(cls, operations, index_name, **kw):
"""Issue a "drop index" instruction using the
current batch migration context.
@@ -981,8 +1094,10 @@ class DropIndexOp(MigrateOperation):
"""
op = cls(
- index_name, table_name=operations.impl.table_name,
- schema=operations.impl.schema, **kw
+ index_name,
+ table_name=operations.impl.table_name,
+ schema=operations.impl.schema,
+ **kw
)
return operations.invoke(op)
@@ -992,7 +1107,8 @@ class CreateTableOp(MigrateOperation):
"""Represent a create table operation."""
def __init__(
- self, table_name, columns, schema=None, _orig_table=None, **kw):
+ self, table_name, columns, schema=None, _orig_table=None, **kw
+ ):
self.table_name = table_name
self.columns = columns
self.schema = schema
@@ -1025,7 +1141,7 @@ class CreateTableOp(MigrateOperation):
)
@classmethod
- @util._with_legacy_names([('name', 'table_name')])
+ @util._with_legacy_names([("name", "table_name")])
def create_table(cls, operations, table_name, *columns, **kw):
r"""Issue a "create table" instruction using the current migration
context.
@@ -1125,7 +1241,8 @@ class DropTableOp(MigrateOperation):
"""Represent a drop table operation."""
def __init__(
- self, table_name, schema=None, table_kw=None, _orig_table=None):
+ self, table_name, schema=None, table_kw=None, _orig_table=None
+ ):
self.table_name = table_name
self.schema = schema
self.table_kw = table_kw or {}
@@ -1137,8 +1254,8 @@ class DropTableOp(MigrateOperation):
def reverse(self):
if self._orig_table is None:
raise ValueError(
- "operation is not reversible; "
- "original table is not present")
+ "operation is not reversible; " "original table is not present"
+ )
return CreateTableOp.from_table(self._orig_table)
@classmethod
@@ -1150,12 +1267,11 @@ class DropTableOp(MigrateOperation):
return self._orig_table
schema_obj = schemaobj.SchemaObjects(migration_context)
return schema_obj.table(
- self.table_name,
- schema=self.schema,
- **self.table_kw)
+ self.table_name, schema=self.schema, **self.table_kw
+ )
@classmethod
- @util._with_legacy_names([('name', 'table_name')])
+ @util._with_legacy_names([("name", "table_name")])
def drop_table(cls, operations, table_name, schema=None, **kw):
r"""Issue a "drop table" instruction using the current
migration context.
@@ -1205,7 +1321,8 @@ class RenameTableOp(AlterTableOp):
@classmethod
def rename_table(
- cls, operations, old_table_name, new_table_name, schema=None):
+ cls, operations, old_table_name, new_table_name, schema=None
+ ):
"""Emit an ALTER TABLE to rename a table.
:param old_table_name: old name.
@@ -1229,16 +1346,18 @@ class AlterColumnOp(AlterTableOp):
"""Represent an alter column operation."""
def __init__(
- self, table_name, column_name, schema=None,
- existing_type=None,
- existing_server_default=False,
- existing_nullable=None,
- modify_nullable=None,
- modify_server_default=False,
- modify_name=None,
- modify_type=None,
- **kw
-
+ self,
+ table_name,
+ column_name,
+ schema=None,
+ existing_type=None,
+ existing_server_default=False,
+ existing_nullable=None,
+ modify_nullable=None,
+ modify_server_default=False,
+ modify_name=None,
+ modify_type=None,
+ **kw
):
super(AlterColumnOp, self).__init__(table_name, schema=schema)
self.column_name = column_name
@@ -1257,47 +1376,64 @@ class AlterColumnOp(AlterTableOp):
if self.modify_type is not None:
col_diff.append(
- ("modify_type", schema, tname, cname,
- {
- "existing_nullable": self.existing_nullable,
- "existing_server_default": self.existing_server_default,
- },
- self.existing_type,
- self.modify_type)
+ (
+ "modify_type",
+ schema,
+ tname,
+ cname,
+ {
+ "existing_nullable": self.existing_nullable,
+ "existing_server_default": self.existing_server_default,
+ },
+ self.existing_type,
+ self.modify_type,
+ )
)
if self.modify_nullable is not None:
col_diff.append(
- ("modify_nullable", schema, tname, cname,
+ (
+ "modify_nullable",
+ schema,
+ tname,
+ cname,
{
"existing_type": self.existing_type,
- "existing_server_default": self.existing_server_default
+ "existing_server_default": self.existing_server_default,
},
self.existing_nullable,
- self.modify_nullable)
+ self.modify_nullable,
+ )
)
if self.modify_server_default is not False:
col_diff.append(
- ("modify_default", schema, tname, cname,
- {
- "existing_nullable": self.existing_nullable,
- "existing_type": self.existing_type
- },
- self.existing_server_default,
- self.modify_server_default)
+ (
+ "modify_default",
+ schema,
+ tname,
+ cname,
+ {
+ "existing_nullable": self.existing_nullable,
+ "existing_type": self.existing_type,
+ },
+ self.existing_server_default,
+ self.modify_server_default,
+ )
)
return col_diff
def has_changes(self):
- hc1 = self.modify_nullable is not None or \
- self.modify_server_default is not False or \
- self.modify_type is not None
+ hc1 = (
+ self.modify_nullable is not None
+ or self.modify_server_default is not False
+ or self.modify_type is not None
+ )
if hc1:
return True
for kw in self.kw:
- if kw.startswith('modify_'):
+ if kw.startswith("modify_"):
return True
else:
return False
@@ -1305,37 +1441,40 @@ class AlterColumnOp(AlterTableOp):
def reverse(self):
kw = self.kw.copy()
- kw['existing_type'] = self.existing_type
- kw['existing_nullable'] = self.existing_nullable
- kw['existing_server_default'] = self.existing_server_default
+ kw["existing_type"] = self.existing_type
+ kw["existing_nullable"] = self.existing_nullable
+ kw["existing_server_default"] = self.existing_server_default
if self.modify_type is not None:
- kw['modify_type'] = self.modify_type
+ kw["modify_type"] = self.modify_type
if self.modify_nullable is not None:
- kw['modify_nullable'] = self.modify_nullable
+ kw["modify_nullable"] = self.modify_nullable
if self.modify_server_default is not False:
- kw['modify_server_default'] = self.modify_server_default
+ kw["modify_server_default"] = self.modify_server_default
# TODO: make this a little simpler
- all_keys = set(m.group(1) for m in [
- re.match(r'^(?:existing_|modify_)(.+)$', k)
- for k in kw
- ] if m)
+ all_keys = set(
+ m.group(1)
+ for m in [re.match(r"^(?:existing_|modify_)(.+)$", k) for k in kw]
+ if m
+ )
for k in all_keys:
- if 'modify_%s' % k in kw:
- swap = kw['existing_%s' % k]
- kw['existing_%s' % k] = kw['modify_%s' % k]
- kw['modify_%s' % k] = swap
+ if "modify_%s" % k in kw:
+ swap = kw["existing_%s" % k]
+ kw["existing_%s" % k] = kw["modify_%s" % k]
+ kw["modify_%s" % k] = swap
return self.__class__(
- self.table_name, self.column_name, schema=self.schema,
- **kw
+ self.table_name, self.column_name, schema=self.schema, **kw
)
@classmethod
- @util._with_legacy_names([('name', 'new_column_name')])
+ @util._with_legacy_names([("name", "new_column_name")])
def alter_column(
- cls, operations, table_name, column_name,
+ cls,
+ operations,
+ table_name,
+ column_name,
nullable=None,
server_default=False,
new_column_name=None,
@@ -1343,7 +1482,8 @@ class AlterColumnOp(AlterTableOp):
existing_type=None,
existing_server_default=False,
existing_nullable=None,
- schema=None, **kw
+ schema=None,
+ **kw
):
"""Issue an "alter column" instruction using the
current migration context.
@@ -1430,7 +1570,9 @@ class AlterColumnOp(AlterTableOp):
"""
alt = cls(
- table_name, column_name, schema=schema,
+ table_name,
+ column_name,
+ schema=schema,
existing_type=existing_type,
existing_server_default=existing_server_default,
existing_nullable=existing_nullable,
@@ -1445,7 +1587,9 @@ class AlterColumnOp(AlterTableOp):
@classmethod
def batch_alter_column(
- cls, operations, column_name,
+ cls,
+ operations,
+ column_name,
nullable=None,
server_default=False,
new_column_name=None,
@@ -1464,7 +1608,8 @@ class AlterColumnOp(AlterTableOp):
"""
alt = cls(
- operations.impl.table_name, column_name,
+ operations.impl.table_name,
+ column_name,
schema=operations.impl.schema,
existing_type=existing_type,
existing_server_default=existing_server_default,
@@ -1490,7 +1635,8 @@ class AddColumnOp(AlterTableOp):
def reverse(self):
return DropColumnOp.from_column_and_tablename(
- self.schema, self.table_name, self.column)
+ self.schema, self.table_name, self.column
+ )
def to_diff_tuple(self):
return ("add_column", self.schema, self.table_name, self.column)
@@ -1575,8 +1721,7 @@ class AddColumnOp(AlterTableOp):
"""
op = cls(
- operations.impl.table_name, column,
- schema=operations.impl.schema
+ operations.impl.table_name, column, schema=operations.impl.schema
)
return operations.invoke(op)
@@ -1587,8 +1732,8 @@ class DropColumnOp(AlterTableOp):
"""Represent a drop column operation."""
def __init__(
- self, table_name, column_name, schema=None,
- _orig_column=None, **kw):
+ self, table_name, column_name, schema=None, _orig_column=None, **kw
+ ):
super(DropColumnOp, self).__init__(table_name, schema=schema)
self.column_name = column_name
self.kw = kw
@@ -1596,16 +1741,22 @@ class DropColumnOp(AlterTableOp):
def to_diff_tuple(self):
return (
- "remove_column", self.schema, self.table_name, self.to_column())
+ "remove_column",
+ self.schema,
+ self.table_name,
+ self.to_column(),
+ )
def reverse(self):
if self._orig_column is None:
raise ValueError(
"operation is not reversible; "
- "original column is not present")
+ "original column is not present"
+ )
return AddColumnOp.from_column_and_tablename(
- self.schema, self.table_name, self._orig_column)
+ self.schema, self.table_name, self._orig_column
+ )
@classmethod
def from_column_and_tablename(cls, schema, tname, col):
@@ -1619,7 +1770,8 @@ class DropColumnOp(AlterTableOp):
@classmethod
def drop_column(
- cls, operations, table_name, column_name, schema=None, **kw):
+ cls, operations, table_name, column_name, schema=None, **kw
+ ):
"""Issue a "drop column" instruction using the current
migration context.
@@ -1677,8 +1829,11 @@ class DropColumnOp(AlterTableOp):
"""
op = cls(
- operations.impl.table_name, column_name,
- schema=operations.impl.schema, **kw)
+ operations.impl.table_name,
+ column_name,
+ schema=operations.impl.schema,
+ **kw
+ )
return operations.invoke(op)
@@ -1877,6 +2032,7 @@ class ExecuteSQLOp(MigrateOperation):
class OpContainer(MigrateOperation):
"""Represent a sequence of operations operation."""
+
def __init__(self, ops=()):
self.ops = ops
@@ -1889,7 +2045,7 @@ class OpContainer(MigrateOperation):
@classmethod
def _ops_as_diffs(cls, migrations):
for op in migrations.ops:
- if hasattr(op, 'ops'):
+ if hasattr(op, "ops"):
for sub_op in cls._ops_as_diffs(op):
yield sub_op
else:
@@ -1907,10 +2063,8 @@ class ModifyTableOps(OpContainer):
def reverse(self):
return ModifyTableOps(
self.table_name,
- ops=list(reversed(
- [op.reverse() for op in self.ops]
- )),
- schema=self.schema
+ ops=list(reversed([op.reverse() for op in self.ops])),
+ schema=self.schema,
)
@@ -1929,9 +2083,9 @@ class UpgradeOps(OpContainer):
self.upgrade_token = upgrade_token
def reverse_into(self, downgrade_ops):
- downgrade_ops.ops[:] = list(reversed(
- [op.reverse() for op in self.ops]
- ))
+ downgrade_ops.ops[:] = list(
+ reversed([op.reverse() for op in self.ops])
+ )
return downgrade_ops
def reverse(self):
@@ -1954,9 +2108,7 @@ class DowngradeOps(OpContainer):
def reverse(self):
return UpgradeOps(
- ops=list(reversed(
- [op.reverse() for op in self.ops]
- ))
+ ops=list(reversed([op.reverse() for op in self.ops]))
)
@@ -1990,10 +2142,18 @@ class MigrationScript(MigrateOperation):
"""
def __init__(
- self, rev_id, upgrade_ops, downgrade_ops,
- message=None,
- imports=set(), head=None, splice=None,
- branch_label=None, version_path=None, depends_on=None):
+ self,
+ rev_id,
+ upgrade_ops,
+ downgrade_ops,
+ message=None,
+ imports=set(),
+ head=None,
+ splice=None,
+ branch_label=None,
+ version_path=None,
+ depends_on=None,
+ ):
self.rev_id = rev_id
self.message = message
self.imports = imports
@@ -2017,7 +2177,8 @@ class MigrationScript(MigrateOperation):
raise ValueError(
"This MigrationScript instance has a multiple-entry "
"list for UpgradeOps; please use the "
- "upgrade_ops_list attribute.")
+ "upgrade_ops_list attribute."
+ )
elif not self._upgrade_ops:
return None
else:
@@ -2041,7 +2202,8 @@ class MigrationScript(MigrateOperation):
raise ValueError(
"This MigrationScript instance has a multiple-entry "
"list for DowngradeOps; please use the "
- "downgrade_ops_list attribute.")
+ "downgrade_ops_list attribute."
+ )
elif not self._downgrade_ops:
return None
else:
@@ -2078,4 +2240,3 @@ class MigrationScript(MigrateOperation):
"""
return self._downgrade_ops
-
diff --git a/alembic/operations/schemaobj.py b/alembic/operations/schemaobj.py
index 1014ace..548b6c5 100644
--- a/alembic/operations/schemaobj.py
+++ b/alembic/operations/schemaobj.py
@@ -5,69 +5,82 @@ from .. import util
class SchemaObjects(object):
-
def __init__(self, migration_context=None):
self.migration_context = migration_context
def primary_key_constraint(self, name, table_name, cols, schema=None):
m = self.metadata()
columns = [sa_schema.Column(n, NULLTYPE) for n in cols]
- t = sa_schema.Table(
- table_name, m,
- *columns,
- schema=schema)
- p = sa_schema.PrimaryKeyConstraint(
- *[t.c[n] for n in cols], name=name)
+ t = sa_schema.Table(table_name, m, *columns, schema=schema)
+ p = sa_schema.PrimaryKeyConstraint(*[t.c[n] for n in cols], name=name)
t.append_constraint(p)
return p
def foreign_key_constraint(
- self, name, source, referent,
- local_cols, remote_cols,
- onupdate=None, ondelete=None,
- deferrable=None, source_schema=None,
- referent_schema=None, initially=None,
- match=None, **dialect_kw):
+ self,
+ name,
+ source,
+ referent,
+ local_cols,
+ remote_cols,
+ onupdate=None,
+ ondelete=None,
+ deferrable=None,
+ source_schema=None,
+ referent_schema=None,
+ initially=None,
+ match=None,
+ **dialect_kw
+ ):
m = self.metadata()
if source == referent and source_schema == referent_schema:
t1_cols = local_cols + remote_cols
else:
t1_cols = local_cols
sa_schema.Table(
- referent, m,
+ referent,
+ m,
*[sa_schema.Column(n, NULLTYPE) for n in remote_cols],
- schema=referent_schema)
+ schema=referent_schema
+ )
t1 = sa_schema.Table(
- source, m,
+ source,
+ m,
*[sa_schema.Column(n, NULLTYPE) for n in t1_cols],
- schema=source_schema)
-
- tname = "%s.%s" % (referent_schema, referent) if referent_schema \
- else referent
-
- dialect_kw['match'] = match
-
- f = sa_schema.ForeignKeyConstraint(local_cols,
- ["%s.%s" % (tname, n)
- for n in remote_cols],
- name=name,
- onupdate=onupdate,
- ondelete=ondelete,
- deferrable=deferrable,
- initially=initially,
- **dialect_kw
- )
+ schema=source_schema
+ )
+
+ tname = (
+ "%s.%s" % (referent_schema, referent)
+ if referent_schema
+ else referent
+ )
+
+ dialect_kw["match"] = match
+
+ f = sa_schema.ForeignKeyConstraint(
+ local_cols,
+ ["%s.%s" % (tname, n) for n in remote_cols],
+ name=name,
+ onupdate=onupdate,
+ ondelete=ondelete,
+ deferrable=deferrable,
+ initially=initially,
+ **dialect_kw
+ )
t1.append_constraint(f)
return f
def unique_constraint(self, name, source, local_cols, schema=None, **kw):
t = sa_schema.Table(
- source, self.metadata(),
+ source,
+ self.metadata(),
*[sa_schema.Column(n, NULLTYPE) for n in local_cols],
- schema=schema)
- kw['name'] = name
+ schema=schema
+ )
+ kw["name"] = name
uq = sa_schema.UniqueConstraint(*[t.c[n] for n in local_cols], **kw)
# TODO: need event tests to ensure the event
# is fired off here
@@ -75,8 +88,12 @@ class SchemaObjects(object):
return uq
def check_constraint(self, name, source, condition, schema=None, **kw):
- t = sa_schema.Table(source, self.metadata(),
- sa_schema.Column('x', Integer), schema=schema)
+ t = sa_schema.Table(
+ source,
+ self.metadata(),
+ sa_schema.Column("x", Integer),
+ schema=schema,
+ )
ck = sa_schema.CheckConstraint(condition, name=name, **kw)
t.append_constraint(ck)
return ck
@@ -84,18 +101,21 @@ class SchemaObjects(object):
def generic_constraint(self, name, table_name, type_, schema=None, **kw):
t = self.table(table_name, schema=schema)
types = {
- 'foreignkey': lambda name: sa_schema.ForeignKeyConstraint(
- [], [], name=name),
- 'primary': sa_schema.PrimaryKeyConstraint,
- 'unique': sa_schema.UniqueConstraint,
- 'check': lambda name: sa_schema.CheckConstraint("", name=name),
- None: sa_schema.Constraint
+ "foreignkey": lambda name: sa_schema.ForeignKeyConstraint(
+ [], [], name=name
+ ),
+ "primary": sa_schema.PrimaryKeyConstraint,
+ "unique": sa_schema.UniqueConstraint,
+ "check": lambda name: sa_schema.CheckConstraint("", name=name),
+ None: sa_schema.Constraint,
}
try:
const = types[type_]
except KeyError:
- raise TypeError("'type' can be one of %s" %
- ", ".join(sorted(repr(x) for x in types)))
+ raise TypeError(
+ "'type' can be one of %s"
+ % ", ".join(sorted(repr(x) for x in types))
+ )
else:
const = const(name=name)
t.append_constraint(const)
@@ -103,11 +123,13 @@ class SchemaObjects(object):
def metadata(self):
kw = {}
- if self.migration_context is not None and \
- 'target_metadata' in self.migration_context.opts:
- mt = self.migration_context.opts['target_metadata']
- if hasattr(mt, 'naming_convention'):
- kw['naming_convention'] = mt.naming_convention
+ if (
+ self.migration_context is not None
+ and "target_metadata" in self.migration_context.opts
+ ):
+ mt = self.migration_context.opts["target_metadata"]
+ if hasattr(mt, "naming_convention"):
+ kw["naming_convention"] = mt.naming_convention
return sa_schema.MetaData(**kw)
def table(self, name, *columns, **kw):
@@ -122,18 +144,18 @@ class SchemaObjects(object):
def index(self, name, tablename, columns, schema=None, **kw):
t = sa_schema.Table(
- tablename or 'no_table', self.metadata(),
- schema=schema
+ tablename or "no_table", self.metadata(), schema=schema
)
idx = sa_schema.Index(
name,
*[util.sqla_compat._textual_index_column(t, n) for n in columns],
- **kw)
+ **kw
+ )
return idx
def _parse_table_key(self, table_key):
- if '.' in table_key:
- tokens = table_key.split('.')
+ if "." in table_key:
+ tokens = table_key.split(".")
sname = ".".join(tokens[0:-1])
tname = tokens[-1]
else:
@@ -147,7 +169,7 @@ class SchemaObjects(object):
"""
if isinstance(fk._colspec, string_types):
- table_key, cname = fk._colspec.rsplit('.', 1)
+ table_key, cname = fk._colspec.rsplit(".", 1)
sname, tname = self._parse_table_key(table_key)
if table_key not in metadata.tables:
rel_t = sa_schema.Table(tname, metadata, schema=sname)
diff --git a/alembic/operations/toimpl.py b/alembic/operations/toimpl.py
index 1327367..1635a42 100644
--- a/alembic/operations/toimpl.py
+++ b/alembic/operations/toimpl.py
@@ -8,8 +8,7 @@ from sqlalchemy import schema as sa_schema
def alter_column(operations, operation):
compiler = operations.impl.dialect.statement_compiler(
- operations.impl.dialect,
- None
+ operations.impl.dialect, None
)
existing_type = operation.existing_type
@@ -24,24 +23,23 @@ def alter_column(operations, operation):
nullable = operation.modify_nullable
def _count_constraint(constraint):
- return not isinstance(
- constraint,
- sa_schema.PrimaryKeyConstraint) and \
- (not constraint._create_rule or
- constraint._create_rule(compiler))
+ return not isinstance(constraint, sa_schema.PrimaryKeyConstraint) and (
+ not constraint._create_rule or constraint._create_rule(compiler)
+ )
if existing_type and type_:
t = operations.schema_obj.table(
table_name,
sa_schema.Column(column_name, existing_type),
- schema=schema
+ schema=schema,
)
for constraint in t.constraints:
if _count_constraint(constraint):
operations.impl.drop_constraint(constraint)
operations.impl.alter_column(
- table_name, column_name,
+ table_name,
+ column_name,
nullable=nullable,
server_default=server_default,
name=new_column_name,
@@ -57,7 +55,7 @@ def alter_column(operations, operation):
t = operations.schema_obj.table(
table_name,
operations.schema_obj.column(column_name, type_),
- schema=schema
+ schema=schema,
)
for constraint in t.constraints:
if _count_constraint(constraint):
@@ -75,10 +73,7 @@ def drop_table(operations, operation):
def drop_column(operations, operation):
column = operation.to_column(operations.migration_context)
operations.impl.drop_column(
- operation.table_name,
- column,
- schema=operation.schema,
- **operation.kw
+ operation.table_name, column, schema=operation.schema, **operation.kw
)
@@ -105,9 +100,8 @@ def create_table(operations, operation):
@Operations.implementation_for(ops.RenameTableOp)
def rename_table(operations, operation):
operations.impl.rename_table(
- operation.table_name,
- operation.new_table_name,
- schema=operation.schema)
+ operation.table_name, operation.new_table_name, schema=operation.schema
+ )
@Operations.implementation_for(ops.AddColumnOp)
@@ -117,11 +111,7 @@ def add_column(operations, operation):
schema = operation.schema
t = operations.schema_obj.table(table_name, column, schema=schema)
- operations.impl.add_column(
- table_name,
- column,
- schema=schema
- )
+ operations.impl.add_column(table_name, column, schema=schema)
for constraint in t.constraints:
if not isinstance(constraint, sa_schema.PrimaryKeyConstraint):
operations.impl.add_constraint(constraint)
@@ -151,12 +141,12 @@ def drop_constraint(operations, operation):
@Operations.implementation_for(ops.BulkInsertOp)
def bulk_insert(operations, operation):
operations.impl.bulk_insert(
- operation.table, operation.rows, multiinsert=operation.multiinsert)
+ operation.table, operation.rows, multiinsert=operation.multiinsert
+ )
@Operations.implementation_for(ops.ExecuteSQLOp)
def execute_sql(operations, operation):
operations.migration_context.impl.execute(
- operation.sqltext,
- execution_options=operation.execution_options
+ operation.sqltext, execution_options=operation.execution_options
)
diff --git a/alembic/runtime/environment.py b/alembic/runtime/environment.py
index ce9be63..32db3ae 100644
--- a/alembic/runtime/environment.py
+++ b/alembic/runtime/environment.py
@@ -120,7 +120,7 @@ class EnvironmentContext(util.ModuleClsProxy):
has been configured.
"""
- return self.context_opts.get('as_sql', False)
+ return self.context_opts.get("as_sql", False)
def is_transactional_ddl(self):
"""Return True if the context is configured to expect a
@@ -182,17 +182,20 @@ class EnvironmentContext(util.ModuleClsProxy):
"""
if self._migration_context is not None:
return self.script.as_revision_number(
- self.get_context()._start_from_rev)
- elif 'starting_rev' in self.context_opts:
+ self.get_context()._start_from_rev
+ )
+ elif "starting_rev" in self.context_opts:
return self.script.as_revision_number(
- self.context_opts['starting_rev'])
+ self.context_opts["starting_rev"]
+ )
else:
# this should raise only in the case that a command
# is being run where the "starting rev" is never applicable;
# this is to catch scripts which rely upon this in
# non-sql mode or similar
raise util.CommandError(
- "No starting revision argument is available.")
+ "No starting revision argument is available."
+ )
def get_revision_argument(self):
"""Get the 'destination' revision argument.
@@ -209,7 +212,8 @@ class EnvironmentContext(util.ModuleClsProxy):
"""
return self.script.as_revision_number(
- self.context_opts['destination_rev'])
+ self.context_opts["destination_rev"]
+ )
def get_tag_argument(self):
"""Return the value passed for the ``--tag`` argument, if any.
@@ -229,7 +233,7 @@ class EnvironmentContext(util.ModuleClsProxy):
line.
"""
- return self.context_opts.get('tag', None)
+ return self.context_opts.get("tag", None)
def get_x_argument(self, as_dictionary=False):
"""Return the value(s) passed for the ``-x`` argument, if any.
@@ -277,39 +281,38 @@ class EnvironmentContext(util.ModuleClsProxy):
else:
value = []
if as_dictionary:
- value = dict(
- arg.split('=', 1) for arg in value
- )
+ value = dict(arg.split("=", 1) for arg in value)
return value
- def configure(self,
- connection=None,
- url=None,
- dialect_name=None,
- transactional_ddl=None,
- transaction_per_migration=False,
- output_buffer=None,
- starting_rev=None,
- tag=None,
- template_args=None,
- render_as_batch=False,
- target_metadata=None,
- include_symbol=None,
- include_object=None,
- include_schemas=False,
- process_revision_directives=None,
- compare_type=False,
- compare_server_default=False,
- render_item=None,
- literal_binds=False,
- upgrade_token="upgrades",
- downgrade_token="downgrades",
- alembic_module_prefix="op.",
- sqlalchemy_module_prefix="sa.",
- user_module_prefix=None,
- on_version_apply=None,
- **kw
- ):
+ def configure(
+ self,
+ connection=None,
+ url=None,
+ dialect_name=None,
+ transactional_ddl=None,
+ transaction_per_migration=False,
+ output_buffer=None,
+ starting_rev=None,
+ tag=None,
+ template_args=None,
+ render_as_batch=False,
+ target_metadata=None,
+ include_symbol=None,
+ include_object=None,
+ include_schemas=False,
+ process_revision_directives=None,
+ compare_type=False,
+ compare_server_default=False,
+ render_item=None,
+ literal_binds=False,
+ upgrade_token="upgrades",
+ downgrade_token="downgrades",
+ alembic_module_prefix="op.",
+ sqlalchemy_module_prefix="sa.",
+ user_module_prefix=None,
+ on_version_apply=None,
+ **kw
+ ):
"""Configure a :class:`.MigrationContext` within this
:class:`.EnvironmentContext` which will provide database
connectivity and other configuration to a series of
@@ -774,33 +777,33 @@ class EnvironmentContext(util.ModuleClsProxy):
elif self.config.output_buffer is not None:
opts["output_buffer"] = self.config.output_buffer
if starting_rev:
- opts['starting_rev'] = starting_rev
+ opts["starting_rev"] = starting_rev
if tag:
- opts['tag'] = tag
- if template_args and 'template_args' in opts:
- opts['template_args'].update(template_args)
+ opts["tag"] = tag
+ if template_args and "template_args" in opts:
+ opts["template_args"].update(template_args)
opts["transaction_per_migration"] = transaction_per_migration
- opts['target_metadata'] = target_metadata
- opts['include_symbol'] = include_symbol
- opts['include_object'] = include_object
- opts['include_schemas'] = include_schemas
- opts['render_as_batch'] = render_as_batch
- opts['upgrade_token'] = upgrade_token
- opts['downgrade_token'] = downgrade_token
- opts['sqlalchemy_module_prefix'] = sqlalchemy_module_prefix
- opts['alembic_module_prefix'] = alembic_module_prefix
- opts['user_module_prefix'] = user_module_prefix
- opts['literal_binds'] = literal_binds
- opts['process_revision_directives'] = process_revision_directives
- opts['on_version_apply'] = util.to_tuple(on_version_apply, default=())
+ opts["target_metadata"] = target_metadata
+ opts["include_symbol"] = include_symbol
+ opts["include_object"] = include_object
+ opts["include_schemas"] = include_schemas
+ opts["render_as_batch"] = render_as_batch
+ opts["upgrade_token"] = upgrade_token
+ opts["downgrade_token"] = downgrade_token
+ opts["sqlalchemy_module_prefix"] = sqlalchemy_module_prefix
+ opts["alembic_module_prefix"] = alembic_module_prefix
+ opts["user_module_prefix"] = user_module_prefix
+ opts["literal_binds"] = literal_binds
+ opts["process_revision_directives"] = process_revision_directives
+ opts["on_version_apply"] = util.to_tuple(on_version_apply, default=())
if render_item is not None:
- opts['render_item'] = render_item
+ opts["render_item"] = render_item
if compare_type is not None:
- opts['compare_type'] = compare_type
+ opts["compare_type"] = compare_type
if compare_server_default is not None:
- opts['compare_server_default'] = compare_server_default
- opts['script'] = self.script
+ opts["compare_server_default"] = compare_server_default
+ opts["script"] = self.script
opts.update(kw)
@@ -809,7 +812,7 @@ class EnvironmentContext(util.ModuleClsProxy):
url=url,
dialect_name=dialect_name,
environment_context=self,
- opts=opts
+ opts=opts,
)
def run_migrations(self, **kw):
@@ -847,8 +850,7 @@ class EnvironmentContext(util.ModuleClsProxy):
first been made available via :meth:`.configure`.
"""
- self.get_context().execute(sql,
- execution_options=execution_options)
+ self.get_context().execute(sql, execution_options=execution_options)
def static_output(self, text):
"""Emit text directly to the "offline" SQL stream.
diff --git a/alembic/runtime/migration.py b/alembic/runtime/migration.py
index 17cc226..80dc8ff 100644
--- a/alembic/runtime/migration.py
+++ b/alembic/runtime/migration.py
@@ -2,8 +2,14 @@ import logging
import sys
from contextlib import contextmanager
-from sqlalchemy import MetaData, Table, Column, String, literal_column,\
- PrimaryKeyConstraint
+from sqlalchemy import (
+ MetaData,
+ Table,
+ Column,
+ String,
+ literal_column,
+ PrimaryKeyConstraint,
+)
from sqlalchemy.engine.strategies import MockEngineStrategy
from sqlalchemy.engine import url as sqla_url
from sqlalchemy.engine import Connection
@@ -65,71 +71,82 @@ class MigrationContext(object):
self.environment_context = environment_context
self.opts = opts
self.dialect = dialect
- self.script = opts.get('script')
- as_sql = opts.get('as_sql', False)
+ self.script = opts.get("script")
+ as_sql = opts.get("as_sql", False)
transactional_ddl = opts.get("transactional_ddl")
self._transaction_per_migration = opts.get(
- "transaction_per_migration", False)
- self.on_version_apply_callbacks = opts.get('on_version_apply', ())
+ "transaction_per_migration", False
+ )
+ self.on_version_apply_callbacks = opts.get("on_version_apply", ())
if as_sql:
self.connection = self._stdout_connection(connection)
assert self.connection is not None
else:
self.connection = connection
- self._migrations_fn = opts.get('fn')
+ self._migrations_fn = opts.get("fn")
self.as_sql = as_sql
if "output_encoding" in opts:
self.output_buffer = EncodedIO(
opts.get("output_buffer") or sys.stdout,
- opts['output_encoding']
+ opts["output_encoding"],
)
else:
self.output_buffer = opts.get("output_buffer", sys.stdout)
- self._user_compare_type = opts.get('compare_type', False)
+ self._user_compare_type = opts.get("compare_type", False)
self._user_compare_server_default = opts.get(
- 'compare_server_default',
- False)
+ "compare_server_default", False
+ )
self.version_table = version_table = opts.get(
- 'version_table', 'alembic_version')
- self.version_table_schema = version_table_schema = \
- opts.get('version_table_schema', None)
+ "version_table", "alembic_version"
+ )
+ self.version_table_schema = version_table_schema = opts.get(
+ "version_table_schema", None
+ )
self._version = Table(
- version_table, MetaData(),
- Column('version_num', String(32), nullable=False),
- schema=version_table_schema)
+ version_table,
+ MetaData(),
+ Column("version_num", String(32), nullable=False),
+ schema=version_table_schema,
+ )
if opts.get("version_table_pk", True):
self._version.append_constraint(
PrimaryKeyConstraint(
- 'version_num', name="%s_pkc" % version_table
+ "version_num", name="%s_pkc" % version_table
)
)
self._start_from_rev = opts.get("starting_rev")
self.impl = ddl.DefaultImpl.get_by_dialect(dialect)(
- dialect, self.connection, self.as_sql,
+ dialect,
+ self.connection,
+ self.as_sql,
transactional_ddl,
self.output_buffer,
- opts
+ opts,
)
log.info("Context impl %s.", self.impl.__class__.__name__)
if self.as_sql:
log.info("Generating static SQL")
- log.info("Will assume %s DDL.",
- "transactional" if self.impl.transactional_ddl
- else "non-transactional")
+ log.info(
+ "Will assume %s DDL.",
+ "transactional"
+ if self.impl.transactional_ddl
+ else "non-transactional",
+ )
@classmethod
- def configure(cls,
- connection=None,
- url=None,
- dialect_name=None,
- dialect=None,
- environment_context=None,
- opts=None,
- ):
+ def configure(
+ cls,
+ connection=None,
+ url=None,
+ dialect_name=None,
+ dialect=None,
+ environment_context=None,
+ opts=None,
+ ):
"""Create a new :class:`.MigrationContext`.
This is a factory method usually called
@@ -158,7 +175,8 @@ class MigrationContext(object):
util.warn(
"'connection' argument to configure() is expected "
"to be a sqlalchemy.engine.Connection instance, "
- "got %r" % connection)
+ "got %r" % connection
+ )
dialect = connection.dialect
elif url:
url = sqla_url.make_url(url)
@@ -175,22 +193,28 @@ class MigrationContext(object):
transaction_now = _per_migration == self._transaction_per_migration
if not transaction_now:
+
@contextmanager
def do_nothing():
yield
+
return do_nothing()
elif not self.impl.transactional_ddl:
+
@contextmanager
def do_nothing():
yield
+
return do_nothing()
elif self.as_sql:
+
@contextmanager
def begin_commit():
self.impl.emit_begin()
yield
self.impl.emit_commit()
+
return begin_commit()
else:
return self.bind.begin()
@@ -217,7 +241,8 @@ class MigrationContext(object):
elif len(heads) > 1:
raise util.CommandError(
"Version table '%s' has more than one head present; "
- "please use get_current_heads()" % self.version_table)
+ "please use get_current_heads()" % self.version_table
+ )
else:
return heads[0]
@@ -243,18 +268,20 @@ class MigrationContext(object):
"""
if self.as_sql:
start_from_rev = self._start_from_rev
- if start_from_rev == 'base':
+ if start_from_rev == "base":
start_from_rev = None
elif start_from_rev is not None and self.script:
- start_from_rev = \
- self.script.get_revision(start_from_rev).revision
+ start_from_rev = self.script.get_revision(
+ start_from_rev
+ ).revision
return util.to_tuple(start_from_rev, default=())
else:
if self._start_from_rev:
raise util.CommandError(
"Can't specify current_rev to context "
- "when using a database connection")
+ "when using a database connection"
+ )
if not self._has_version_table():
return ()
return tuple(
@@ -266,7 +293,8 @@ class MigrationContext(object):
def _has_version_table(self):
return self.connection.dialect.has_table(
- self.connection, self.version_table, self.version_table_schema)
+ self.connection, self.version_table, self.version_table_schema
+ )
def stamp(self, script_directory, revision):
"""Stamp the version table with a specific revision.
@@ -315,8 +343,9 @@ class MigrationContext(object):
head_maintainer = HeadMaintainer(self, heads)
- starting_in_transaction = not self.as_sql and \
- self._in_connection_transaction()
+ starting_in_transaction = (
+ not self.as_sql and self._in_connection_transaction()
+ )
for step in self._migrations_fn(heads, self):
with self.begin_transaction(_per_migration=True):
@@ -326,7 +355,9 @@ class MigrationContext(object):
self._version.create(self.connection)
log.info("Running %s", step)
if self.as_sql:
- self.impl.static_output("-- Running %s" % (step.short_log,))
+ self.impl.static_output(
+ "-- Running %s" % (step.short_log,)
+ )
step.migration_fn(**kw)
# previously, we wouldn't stamp per migration
@@ -336,19 +367,24 @@ class MigrationContext(object):
# just to run the operations on every version
head_maintainer.update_to_step(step)
for callback in self.on_version_apply_callbacks:
- callback(ctx=self,
- step=step.info,
- heads=set(head_maintainer.heads),
- run_args=kw)
-
- if not starting_in_transaction and not self.as_sql and \
- not self.impl.transactional_ddl and \
- self._in_connection_transaction():
+ callback(
+ ctx=self,
+ step=step.info,
+ heads=set(head_maintainer.heads),
+ run_args=kw,
+ )
+
+ if (
+ not starting_in_transaction
+ and not self.as_sql
+ and not self.impl.transactional_ddl
+ and self._in_connection_transaction()
+ ):
raise util.CommandError(
- "Migration \"%s\" has left an uncommitted "
+ 'Migration "%s" has left an uncommitted '
"transaction opened; transactional_ddl is False so "
- "Alembic is not committing transactions"
- % step)
+ "Alembic is not committing transactions" % step
+ )
if self.as_sql and not head_maintainer.heads:
self._version.drop(self.connection)
@@ -421,19 +457,20 @@ class MigrationContext(object):
inspector_column,
metadata_column,
inspector_column.type,
- metadata_column.type
+ metadata_column.type,
)
if user_value is not None:
return user_value
- return self.impl.compare_type(
- inspector_column,
- metadata_column)
+ return self.impl.compare_type(inspector_column, metadata_column)
- def _compare_server_default(self, inspector_column,
- metadata_column,
- rendered_metadata_default,
- rendered_column_default):
+ def _compare_server_default(
+ self,
+ inspector_column,
+ metadata_column,
+ rendered_metadata_default,
+ rendered_column_default,
+ ):
if self._user_compare_server_default is False:
return False
@@ -445,7 +482,7 @@ class MigrationContext(object):
metadata_column,
rendered_column_default,
metadata_column.server_default,
- rendered_metadata_default
+ rendered_metadata_default,
)
if user_value is not None:
return user_value
@@ -454,7 +491,8 @@ class MigrationContext(object):
inspector_column,
metadata_column,
rendered_metadata_default,
- rendered_column_default)
+ rendered_column_default,
+ )
class HeadMaintainer(object):
@@ -467,8 +505,7 @@ class HeadMaintainer(object):
self.heads.add(version)
self.context.impl._exec(
- self.context._version.insert().
- values(
+ self.context._version.insert().values(
version_num=literal_column("'%s'" % version)
)
)
@@ -478,15 +515,17 @@ class HeadMaintainer(object):
ret = self.context.impl._exec(
self.context._version.delete().where(
- self.context._version.c.version_num ==
- literal_column("'%s'" % version)))
+ self.context._version.c.version_num
+ == literal_column("'%s'" % version)
+ )
+ )
if not self.context.as_sql and ret.rowcount != 1:
raise util.CommandError(
"Online migration expected to match one "
"row when deleting '%s' in '%s'; "
"%d found"
- % (version,
- self.context.version_table, ret.rowcount))
+ % (version, self.context.version_table, ret.rowcount)
+ )
def _update_version(self, from_, to_):
assert to_ not in self.heads
@@ -494,17 +533,20 @@ class HeadMaintainer(object):
self.heads.add(to_)
ret = self.context.impl._exec(
- self.context._version.update().
- values(version_num=literal_column("'%s'" % to_)).where(
+ self.context._version.update()
+ .values(version_num=literal_column("'%s'" % to_))
+ .where(
self.context._version.c.version_num
- == literal_column("'%s'" % from_))
+ == literal_column("'%s'" % from_)
+ )
)
if not self.context.as_sql and ret.rowcount != 1:
raise util.CommandError(
"Online migration expected to match one "
"row when updating '%s' to '%s' in '%s'; "
"%d found"
- % (from_, to_, self.context.version_table, ret.rowcount))
+ % (from_, to_, self.context.version_table, ret.rowcount)
+ )
def update_to_step(self, step):
if step.should_delete_branch(self.heads):
@@ -517,20 +559,32 @@ class HeadMaintainer(object):
self._insert_version(vers)
elif step.should_merge_branches(self.heads):
# delete revs, update from rev, update to rev
- (delete_revs, update_from_rev,
- update_to_rev) = step.merge_branch_idents(self.heads)
+ (
+ delete_revs,
+ update_from_rev,
+ update_to_rev,
+ ) = step.merge_branch_idents(self.heads)
log.debug(
"merge, delete %s, update %s to %s",
- delete_revs, update_from_rev, update_to_rev)
+ delete_revs,
+ update_from_rev,
+ update_to_rev,
+ )
for delrev in delete_revs:
self._delete_version(delrev)
self._update_version(update_from_rev, update_to_rev)
elif step.should_unmerge_branches(self.heads):
- (update_from_rev, update_to_rev,
- insert_revs) = step.unmerge_branch_idents(self.heads)
+ (
+ update_from_rev,
+ update_to_rev,
+ insert_revs,
+ ) = step.unmerge_branch_idents(self.heads)
log.debug(
"unmerge, insert %s, update %s to %s",
- insert_revs, update_from_rev, update_to_rev)
+ insert_revs,
+ update_from_rev,
+ update_to_rev,
+ )
for insrev in insert_revs:
self._insert_version(insrev)
self._update_version(update_from_rev, update_to_rev)
@@ -597,8 +651,9 @@ class MigrationInfo(object):
revision_map = None
"""The revision map inside of which this operation occurs."""
- def __init__(self, revision_map, is_upgrade, is_stamp, up_revisions,
- down_revisions):
+ def __init__(
+ self, revision_map, is_upgrade, is_stamp, up_revisions, down_revisions
+ ):
self.revision_map = revision_map
self.is_upgrade = is_upgrade
self.is_stamp = is_stamp
@@ -625,14 +680,16 @@ class MigrationInfo(object):
@property
def source_revision_ids(self):
"""Active revisions before this migration step is applied."""
- return self.down_revision_ids if self.is_upgrade \
- else self.up_revision_ids
+ return (
+ self.down_revision_ids if self.is_upgrade else self.up_revision_ids
+ )
@property
def destination_revision_ids(self):
"""Active revisions after this migration step is applied."""
- return self.up_revision_ids if self.is_upgrade \
- else self.down_revision_ids
+ return (
+ self.up_revision_ids if self.is_upgrade else self.down_revision_ids
+ )
@property
def up_revision(self):
@@ -689,7 +746,7 @@ class MigrationStep(object):
return "%s %s -> %s" % (
self.name,
util.format_as_comma(self.from_revisions_no_deps),
- util.format_as_comma(self.to_revisions_no_deps)
+ util.format_as_comma(self.to_revisions_no_deps),
)
def __str__(self):
@@ -698,7 +755,7 @@ class MigrationStep(object):
self.name,
util.format_as_comma(self.from_revisions_no_deps),
util.format_as_comma(self.to_revisions_no_deps),
- self.doc
+ self.doc,
)
else:
return self.short_log
@@ -716,13 +773,16 @@ class RevisionStep(MigrationStep):
def __repr__(self):
return "RevisionStep(%r, is_upgrade=%r)" % (
- self.revision.revision, self.is_upgrade
+ self.revision.revision,
+ self.is_upgrade,
)
def __eq__(self, other):
- return isinstance(other, RevisionStep) and \
- other.revision == self.revision and \
- self.is_upgrade == other.is_upgrade
+ return (
+ isinstance(other, RevisionStep)
+ and other.revision == self.revision
+ and self.is_upgrade == other.is_upgrade
+ )
@property
def doc(self):
@@ -733,26 +793,26 @@ class RevisionStep(MigrationStep):
if self.is_upgrade:
return self.revision._all_down_revisions
else:
- return (self.revision.revision, )
+ return (self.revision.revision,)
@property
def from_revisions_no_deps(self):
if self.is_upgrade:
return self.revision._versioned_down_revisions
else:
- return (self.revision.revision, )
+ return (self.revision.revision,)
@property
def to_revisions(self):
if self.is_upgrade:
- return (self.revision.revision, )
+ return (self.revision.revision,)
else:
return self.revision._all_down_revisions
@property
def to_revisions_no_deps(self):
if self.is_upgrade:
- return (self.revision.revision, )
+ return (self.revision.revision,)
else:
return self.revision._versioned_down_revisions
@@ -788,31 +848,31 @@ class RevisionStep(MigrationStep):
if other_heads:
ancestors = set(
- r.revision for r in
- self.revision_map._get_ancestor_nodes(
- self.revision_map.get_revisions(other_heads),
- check=False
+ r.revision
+ for r in self.revision_map._get_ancestor_nodes(
+ self.revision_map.get_revisions(other_heads), check=False
)
)
from_revisions = list(
- set(self.from_revisions).difference(ancestors))
+ set(self.from_revisions).difference(ancestors)
+ )
else:
from_revisions = list(self.from_revisions)
return (
# delete revs, update from rev, update to rev
- list(from_revisions[0:-1]), from_revisions[-1],
- self.to_revisions[0]
+ list(from_revisions[0:-1]),
+ from_revisions[-1],
+ self.to_revisions[0],
)
def _unmerge_to_revisions(self, heads):
other_heads = set(heads).difference([self.revision.revision])
if other_heads:
ancestors = set(
- r.revision for r in
- self.revision_map._get_ancestor_nodes(
- self.revision_map.get_revisions(other_heads),
- check=False
+ r.revision
+ for r in self.revision_map._get_ancestor_nodes(
+ self.revision_map.get_revisions(other_heads), check=False
)
)
return list(set(self.to_revisions).difference(ancestors))
@@ -824,8 +884,9 @@ class RevisionStep(MigrationStep):
return (
# update from rev, update to rev, insert revs
- self.from_revisions[0], to_revisions[-1],
- to_revisions[0:-1]
+ self.from_revisions[0],
+ to_revisions[-1],
+ to_revisions[0:-1],
)
def should_create_branch(self, heads):
@@ -853,8 +914,7 @@ class RevisionStep(MigrationStep):
downrevs = self.revision._all_down_revisions
- if len(downrevs) > 1 and \
- len(heads.intersection(downrevs)) > 1:
+ if len(downrevs) > 1 and len(heads.intersection(downrevs)) > 1:
return True
return False
@@ -873,8 +933,9 @@ class RevisionStep(MigrationStep):
def update_version_num(self, heads):
if not self._has_scalar_down_revision:
downrev = heads.intersection(self.revision._all_down_revisions)
- assert len(downrev) == 1, \
- "Can't do an UPDATE because downrevision is ambiguous"
+ assert (
+ len(downrev) == 1
+ ), "Can't do an UPDATE because downrevision is ambiguous"
down_revision = list(downrev)[0]
else:
down_revision = self.revision._all_down_revisions[0]
@@ -894,10 +955,13 @@ class RevisionStep(MigrationStep):
@property
def info(self):
- return MigrationInfo(revision_map=self.revision_map,
- up_revisions=self.revision.revision,
- down_revisions=self.revision._all_down_revisions,
- is_upgrade=self.is_upgrade, is_stamp=False)
+ return MigrationInfo(
+ revision_map=self.revision_map,
+ up_revisions=self.revision.revision,
+ down_revisions=self.revision._all_down_revisions,
+ is_upgrade=self.is_upgrade,
+ is_stamp=False,
+ )
class StampStep(MigrationStep):
@@ -915,11 +979,13 @@ class StampStep(MigrationStep):
return None
def __eq__(self, other):
- return isinstance(other, StampStep) and \
- other.from_revisions == self.revisions and \
- other.to_revisions == self.to_revisions and \
- other.branch_move == self.branch_move and \
- self.is_upgrade == other.is_upgrade
+ return (
+ isinstance(other, StampStep)
+ and other.from_revisions == self.revisions
+ and other.to_revisions == self.to_revisions
+ and other.branch_move == self.branch_move
+ and self.is_upgrade == other.is_upgrade
+ )
@property
def from_revisions(self):
@@ -955,15 +1021,17 @@ class StampStep(MigrationStep):
def merge_branch_idents(self, heads):
return (
# delete revs, update from rev, update to rev
- list(self.from_[0:-1]), self.from_[-1],
- self.to_[0]
+ list(self.from_[0:-1]),
+ self.from_[-1],
+ self.to_[0],
)
def unmerge_branch_idents(self, heads):
return (
# update from rev, update to rev, insert revs
- self.from_[0], self.to_[-1],
- list(self.to_[0:-1])
+ self.from_[0],
+ self.to_[-1],
+ list(self.to_[0:-1]),
)
def should_delete_branch(self, heads):
@@ -980,10 +1048,15 @@ class StampStep(MigrationStep):
@property
def info(self):
- up, down = (self.to_, self.from_) if self.is_upgrade \
+ up, down = (
+ (self.to_, self.from_)
+ if self.is_upgrade
else (self.from_, self.to_)
- return MigrationInfo(revision_map=self.revision_map,
- up_revisions=up,
- down_revisions=down,
- is_upgrade=self.is_upgrade,
- is_stamp=True)
+ )
+ return MigrationInfo(
+ revision_map=self.revision_map,
+ up_revisions=up,
+ down_revisions=down,
+ is_upgrade=self.is_upgrade,
+ is_stamp=True,
+ )
diff --git a/alembic/script/__init__.py b/alembic/script/__init__.py
index cae294f..65562b4 100644
--- a/alembic/script/__init__.py
+++ b/alembic/script/__init__.py
@@ -1,3 +1,3 @@
from .base import ScriptDirectory, Script # noqa
-__all__ = ['ScriptDirectory', 'Script']
+__all__ = ["ScriptDirectory", "Script"]
diff --git a/alembic/script/base.py b/alembic/script/base.py
index 12e1510..1c63e08 100644
--- a/alembic/script/base.py
+++ b/alembic/script/base.py
@@ -10,13 +10,13 @@ from ..runtime import migration
from contextlib import contextmanager
-_sourceless_rev_file = re.compile(r'(?!\.\#|__init__)(.*\.py)(c|o)?$')
-_only_source_rev_file = re.compile(r'(?!\.\#|__init__)(.*\.py)$')
-_legacy_rev = re.compile(r'([a-f0-9]+)\.py$')
-_mod_def_re = re.compile(r'(upgrade|downgrade)_([a-z0-9]+)')
-_slug_re = re.compile(r'\w+')
+_sourceless_rev_file = re.compile(r"(?!\.\#|__init__)(.*\.py)(c|o)?$")
+_only_source_rev_file = re.compile(r"(?!\.\#|__init__)(.*\.py)$")
+_legacy_rev = re.compile(r"([a-f0-9]+)\.py$")
+_mod_def_re = re.compile(r"(upgrade|downgrade)_([a-z0-9]+)")
+_slug_re = re.compile(r"\w+")
_default_file_template = "%(rev)s_%(slug)s"
-_split_on_space_comma = re.compile(r',|(?: +)')
+_split_on_space_comma = re.compile(r",|(?: +)")
class ScriptDirectory(object):
@@ -40,11 +40,16 @@ class ScriptDirectory(object):
"""
- def __init__(self, dir, file_template=_default_file_template,
- truncate_slug_length=40,
- version_locations=None,
- sourceless=False, output_encoding="utf-8",
- timezone=None):
+ def __init__(
+ self,
+ dir,
+ file_template=_default_file_template,
+ truncate_slug_length=40,
+ version_locations=None,
+ sourceless=False,
+ output_encoding="utf-8",
+ timezone=None,
+ ):
self.dir = dir
self.file_template = file_template
self.version_locations = version_locations
@@ -55,9 +60,11 @@ class ScriptDirectory(object):
self.timezone = timezone
if not os.access(dir, os.F_OK):
- raise util.CommandError("Path doesn't exist: %r. Please use "
- "the 'init' command to create a new "
- "scripts folder." % dir)
+ raise util.CommandError(
+ "Path doesn't exist: %r. Please use "
+ "the 'init' command to create a new "
+ "scripts folder." % dir
+ )
@property
def versions(self):
@@ -75,13 +82,15 @@ class ScriptDirectory(object):
for location in self.version_locations
]
else:
- return (os.path.abspath(os.path.join(self.dir, 'versions')),)
+ return (os.path.abspath(os.path.join(self.dir, "versions")),)
def _load_revisions(self):
if self.version_locations:
paths = [
- vers for vers in self._version_locations
- if os.path.exists(vers)]
+ vers
+ for vers in self._version_locations
+ if os.path.exists(vers)
+ ]
else:
paths = [self.versions]
@@ -110,10 +119,11 @@ class ScriptDirectory(object):
present.
"""
- script_location = config.get_main_option('script_location')
+ script_location = config.get_main_option("script_location")
if script_location is None:
- raise util.CommandError("No 'script_location' key "
- "found in configuration.")
+ raise util.CommandError(
+ "No 'script_location' key " "found in configuration."
+ )
truncate_slug_length = config.get_main_option("truncate_slug_length")
if truncate_slug_length is not None:
truncate_slug_length = int(truncate_slug_length)
@@ -125,20 +135,24 @@ class ScriptDirectory(object):
return ScriptDirectory(
util.coerce_resource_to_filename(script_location),
file_template=config.get_main_option(
- 'file_template',
- _default_file_template),
+ "file_template", _default_file_template
+ ),
truncate_slug_length=truncate_slug_length,
sourceless=config.get_main_option("sourceless") == "true",
output_encoding=config.get_main_option("output_encoding", "utf-8"),
version_locations=version_locations,
- timezone=config.get_main_option("timezone")
+ timezone=config.get_main_option("timezone"),
)
@contextmanager
def _catch_revision_errors(
- self,
- ancestor=None, multiple_heads=None, start=None, end=None,
- resolution=None):
+ self,
+ ancestor=None,
+ multiple_heads=None,
+ start=None,
+ end=None,
+ resolution=None,
+ ):
try:
yield
except revision.RangeNotAncestorError as rna:
@@ -160,10 +174,11 @@ class ScriptDirectory(object):
"argument '%(head_arg)s'; please "
"specify a specific target revision, "
"'<branchname>@%(head_arg)s' to "
- "narrow to a specific head, or 'heads' for all heads")
+ "narrow to a specific head, or 'heads' for all heads"
+ )
multiple_heads = multiple_heads % {
"head_arg": end or mh.argument,
- "heads": util.format_as_comma(mh.heads)
+ "heads": util.format_as_comma(mh.heads),
}
compat.raise_from_cause(util.CommandError(multiple_heads))
except revision.ResolutionError as re:
@@ -192,7 +207,8 @@ class ScriptDirectory(object):
"""
with self._catch_revision_errors(start=base, end=head):
for rev in self.revision_map.iterate_revisions(
- head, base, inclusive=True, assert_relative_length=False):
+ head, base, inclusive=True, assert_relative_length=False
+ ):
yield rev
def get_revisions(self, id_):
@@ -210,7 +226,8 @@ class ScriptDirectory(object):
top_revs = set(self.revision_map.get_revisions(id_))
top_revs.update(
self.revision_map._get_ancestor_nodes(
- list(top_revs), include_dependencies=True)
+ list(top_revs), include_dependencies=True
+ )
)
top_revs = self.revision_map._filter_into_branch_heads(top_revs)
return top_revs
@@ -275,11 +292,13 @@ class ScriptDirectory(object):
:meth:`.ScriptDirectory.get_heads`
"""
- with self._catch_revision_errors(multiple_heads=(
- 'The script directory has multiple heads (due to branching).'
- 'Please use get_heads(), or merge the branches using '
- 'alembic merge.'
- )):
+ with self._catch_revision_errors(
+ multiple_heads=(
+ "The script directory has multiple heads (due to branching)."
+ "Please use get_heads(), or merge the branches using "
+ "alembic merge."
+ )
+ ):
return self.revision_map.get_current_head()
def get_heads(self):
@@ -310,7 +329,8 @@ class ScriptDirectory(object):
if len(bases) > 1:
raise util.CommandError(
"The script directory has multiple bases. "
- "Please use get_bases().")
+ "Please use get_bases()."
+ )
elif bases:
return bases[0]
else:
@@ -329,40 +349,50 @@ class ScriptDirectory(object):
def _upgrade_revs(self, destination, current_rev):
with self._catch_revision_errors(
- ancestor="Destination %(end)s is not a valid upgrade "
- "target from current head(s)", end=destination):
+ ancestor="Destination %(end)s is not a valid upgrade "
+ "target from current head(s)",
+ end=destination,
+ ):
revs = self.revision_map.iterate_revisions(
- destination, current_rev, implicit_base=True)
+ destination, current_rev, implicit_base=True
+ )
revs = list(revs)
return [
migration.MigrationStep.upgrade_from_script(
- self.revision_map, script)
+ self.revision_map, script
+ )
for script in reversed(list(revs))
]
def _downgrade_revs(self, destination, current_rev):
with self._catch_revision_errors(
- ancestor="Destination %(end)s is not a valid downgrade "
- "target from current head(s)", end=destination):
+ ancestor="Destination %(end)s is not a valid downgrade "
+ "target from current head(s)",
+ end=destination,
+ ):
revs = self.revision_map.iterate_revisions(
- current_rev, destination, select_for_downgrade=True)
+ current_rev, destination, select_for_downgrade=True
+ )
return [
migration.MigrationStep.downgrade_from_script(
- self.revision_map, script)
+ self.revision_map, script
+ )
for script in revs
]
def _stamp_revs(self, revision, heads):
with self._catch_revision_errors(
- multiple_heads="Multiple heads are present; please specify a "
- "single target revision"):
+ multiple_heads="Multiple heads are present; please specify a "
+ "single target revision"
+ ):
heads = self.get_revisions(heads)
# filter for lineage will resolve things like
# branchname@base, version@base, etc.
filtered_heads = self.revision_map.filter_for_lineage(
- heads, revision, include_dependencies=True)
+ heads, revision, include_dependencies=True
+ )
steps = []
@@ -371,11 +401,18 @@ class ScriptDirectory(object):
if dest is None:
# dest is 'base'. Return a "delete branch" migration
# for all applicable heads.
- steps.extend([
- migration.StampStep(head.revision, None, False, True,
- self.revision_map)
- for head in filtered_heads
- ])
+ steps.extend(
+ [
+ migration.StampStep(
+ head.revision,
+ None,
+ False,
+ True,
+ self.revision_map,
+ )
+ for head in filtered_heads
+ ]
+ )
continue
elif dest in filtered_heads:
# the dest is already in the version table, do nothing.
@@ -384,7 +421,8 @@ class ScriptDirectory(object):
# figure out if the dest is a descendant or an
# ancestor of the selected nodes
descendants = set(
- self.revision_map._get_descendant_nodes([dest]))
+ self.revision_map._get_descendant_nodes([dest])
+ )
ancestors = set(self.revision_map._get_ancestor_nodes([dest]))
if descendants.intersection(filtered_heads):
@@ -393,8 +431,12 @@ class ScriptDirectory(object):
assert not ancestors.intersection(filtered_heads)
todo_heads = [head.revision for head in filtered_heads]
step = migration.StampStep(
- todo_heads, dest.revision, False, False,
- self.revision_map)
+ todo_heads,
+ dest.revision,
+ False,
+ False,
+ self.revision_map,
+ )
steps.append(step)
continue
elif ancestors.intersection(filtered_heads):
@@ -402,15 +444,20 @@ class ScriptDirectory(object):
# we can treat them as a "merge", single step.
todo_heads = [head.revision for head in filtered_heads]
step = migration.StampStep(
- todo_heads, dest.revision, True, False,
- self.revision_map)
+ todo_heads,
+ dest.revision,
+ True,
+ False,
+ self.revision_map,
+ )
steps.append(step)
continue
else:
# destination is in a branch not represented,
# treat it as new branch
- step = migration.StampStep((), dest.revision, True, True,
- self.revision_map)
+ step = migration.StampStep(
+ (), dest.revision, True, True, self.revision_map
+ )
steps.append(step)
continue
return steps
@@ -424,32 +471,31 @@ class ScriptDirectory(object):
"""
- util.load_python_file(self.dir, 'env.py')
+ util.load_python_file(self.dir, "env.py")
@property
def env_py_location(self):
return os.path.abspath(os.path.join(self.dir, "env.py"))
def _generate_template(self, src, dest, **kw):
- util.status("Generating %s" % os.path.abspath(dest),
- util.template_to_file,
- src,
- dest,
- self.output_encoding,
- **kw
- )
+ util.status(
+ "Generating %s" % os.path.abspath(dest),
+ util.template_to_file,
+ src,
+ dest,
+ self.output_encoding,
+ **kw
+ )
def _copy_file(self, src, dest):
- util.status("Generating %s" % os.path.abspath(dest),
- shutil.copy,
- src, dest)
+ util.status(
+ "Generating %s" % os.path.abspath(dest), shutil.copy, src, dest
+ )
def _ensure_directory(self, path):
path = os.path.abspath(path)
if not os.path.exists(path):
- util.status(
- "Creating directory %s" % path,
- os.makedirs, path)
+ util.status("Creating directory %s" % path, os.makedirs, path)
def _generate_create_date(self):
if self.timezone is not None:
@@ -460,17 +506,29 @@ class ScriptDirectory(object):
tzinfo = tz.gettz(self.timezone.upper())
if tzinfo is None:
raise util.CommandError(
- "Can't locate timezone: %s" % self.timezone)
- create_date = datetime.datetime.utcnow().replace(
- tzinfo=tz.tzutc()).astimezone(tzinfo)
+ "Can't locate timezone: %s" % self.timezone
+ )
+ create_date = (
+ datetime.datetime.utcnow()
+ .replace(tzinfo=tz.tzutc())
+ .astimezone(tzinfo)
+ )
else:
create_date = datetime.datetime.now()
return create_date
def generate_revision(
- self, revid, message, head=None,
- refresh=False, splice=False, branch_labels=None,
- version_path=None, depends_on=None, **kw):
+ self,
+ revid,
+ message,
+ head=None,
+ refresh=False,
+ splice=False,
+ branch_labels=None,
+ version_path=None,
+ depends_on=None,
+ **kw
+ ):
"""Generate a new revision file.
This runs the ``script.py.mako`` template, given
@@ -500,11 +558,13 @@ class ScriptDirectory(object):
except revision.RevisionError as err:
compat.raise_from_cause(util.CommandError(err.args[0]))
- with self._catch_revision_errors(multiple_heads=(
- "Multiple heads are present; please specify the head "
- "revision on which the new revision should be based, "
- "or perform a merge."
- )):
+ with self._catch_revision_errors(
+ multiple_heads=(
+ "Multiple heads are present; please specify the head "
+ "revision on which the new revision should be based, "
+ "or perform a merge."
+ )
+ ):
heads = self.revision_map.get_revisions(head)
if len(set(heads)) != len(heads):
@@ -521,7 +581,8 @@ class ScriptDirectory(object):
else:
raise util.CommandError(
"Multiple version locations present, "
- "please specify --version-path")
+ "please specify --version-path"
+ )
else:
version_path = self.versions
@@ -532,7 +593,8 @@ class ScriptDirectory(object):
else:
raise util.CommandError(
"Path %s is not represented in current "
- "version locations" % version_path)
+ "version locations" % version_path
+ )
if self.version_locations:
self._ensure_directory(version_path)
@@ -545,7 +607,8 @@ class ScriptDirectory(object):
raise util.CommandError(
"Revision %s is not a head revision; please specify "
"--splice to create a new branch from this revision"
- % head.revision)
+ % head.revision
+ )
if depends_on:
with self._catch_revision_errors():
@@ -557,7 +620,6 @@ class ScriptDirectory(object):
(self.revision_map.get_revision(dep), dep)
for dep in util.to_list(depends_on)
]
-
]
self._generate_template(
@@ -565,7 +627,8 @@ class ScriptDirectory(object):
path,
up_revision=str(revid),
down_revision=revision.tuple_rev_as_scalar(
- tuple(h.revision if h is not None else None for h in heads)),
+ tuple(h.revision if h is not None else None for h in heads)
+ ),
branch_labels=util.to_tuple(branch_labels),
depends_on=revision.tuple_rev_as_scalar(depends_on),
create_date=create_date,
@@ -582,9 +645,9 @@ class ScriptDirectory(object):
"Version %s specified branch_labels %s, however the "
"migration file %s does not have them; have you upgraded "
"your script.py.mako to include the "
- "'branch_labels' section?" % (
- script.revision, branch_labels, script.path
- ))
+ "'branch_labels' section?"
+ % (script.revision, branch_labels, script.path)
+ )
self.revision_map.add_revision(script)
return script
@@ -592,17 +655,18 @@ class ScriptDirectory(object):
def _rev_path(self, path, rev_id, message, create_date):
slug = "_".join(_slug_re.findall(message or "")).lower()
if len(slug) > self.truncate_slug_length:
- slug = slug[:self.truncate_slug_length].rsplit('_', 1)[0] + '_'
+ slug = slug[: self.truncate_slug_length].rsplit("_", 1)[0] + "_"
filename = "%s.py" % (
- self.file_template % {
- 'rev': rev_id,
- 'slug': slug,
- 'year': create_date.year,
- 'month': create_date.month,
- 'day': create_date.day,
- 'hour': create_date.hour,
- 'minute': create_date.minute,
- 'second': create_date.second
+ self.file_template
+ % {
+ "rev": rev_id,
+ "slug": slug,
+ "year": create_date.year,
+ "month": create_date.month,
+ "day": create_date.day,
+ "hour": create_date.hour,
+ "minute": create_date.minute,
+ "second": create_date.second,
}
)
return os.path.join(path, filename)
@@ -624,9 +688,11 @@ class Script(revision.Revision):
rev_id,
module.down_revision,
branch_labels=util.to_tuple(
- getattr(module, 'branch_labels', None), default=()),
+ getattr(module, "branch_labels", None), default=()
+ ),
dependencies=util.to_tuple(
- getattr(module, 'depends_on', None), default=())
+ getattr(module, "depends_on", None), default=()
+ ),
)
module = None
@@ -664,32 +730,32 @@ class Script(revision.Revision):
" (head)" if self.is_head else "",
" (branchpoint)" if self.is_branch_point else "",
" (mergepoint)" if self.is_merge_point else "",
- " (current)" if self._db_current_indicator else ""
+ " (current)" if self._db_current_indicator else "",
)
if self.is_merge_point:
- entry += "Merges: %s\n" % (self._format_down_revision(), )
+ entry += "Merges: %s\n" % (self._format_down_revision(),)
else:
- entry += "Parent: %s\n" % (self._format_down_revision(), )
+ entry += "Parent: %s\n" % (self._format_down_revision(),)
if self.dependencies:
entry += "Also depends on: %s\n" % (
- util.format_as_comma(self.dependencies))
+ util.format_as_comma(self.dependencies)
+ )
if self.is_branch_point:
entry += "Branches into: %s\n" % (
- util.format_as_comma(self.nextrev))
+ util.format_as_comma(self.nextrev)
+ )
if self.branch_labels:
entry += "Branch names: %s\n" % (
- util.format_as_comma(self.branch_labels), )
+ util.format_as_comma(self.branch_labels),
+ )
entry += "Path: %s\n" % (self.path,)
entry += "\n%s\n" % (
- "\n".join(
- " %s" % para
- for para in self.longdoc.splitlines()
- )
+ "\n".join(" %s" % para for para in self.longdoc.splitlines())
)
return entry
@@ -700,36 +766,41 @@ class Script(revision.Revision):
" (head)" if self.is_head else "",
" (branchpoint)" if self.is_branch_point else "",
" (mergepoint)" if self.is_merge_point else "",
- self.doc)
+ self.doc,
+ )
def _head_only(
- self, include_branches=False, include_doc=False,
- include_parents=False, tree_indicators=True,
- head_indicators=True):
+ self,
+ include_branches=False,
+ include_doc=False,
+ include_parents=False,
+ tree_indicators=True,
+ head_indicators=True,
+ ):
text = self.revision
if include_parents:
if self.dependencies:
text = "%s (%s) -> %s" % (
self._format_down_revision(),
util.format_as_comma(self.dependencies),
- text
+ text,
)
else:
- text = "%s -> %s" % (
- self._format_down_revision(), text)
+ text = "%s -> %s" % (self._format_down_revision(), text)
if include_branches and self.branch_labels:
text += " (%s)" % util.format_as_comma(self.branch_labels)
if head_indicators or tree_indicators:
text += "%s%s%s" % (
" (head)" if self._is_real_head else "",
- " (effective head)" if self.is_head and
- not self._is_real_head else "",
- " (current)" if self._db_current_indicator else ""
+ " (effective head)"
+ if self.is_head and not self._is_real_head
+ else "",
+ " (current)" if self._db_current_indicator else "",
)
if tree_indicators:
text += "%s%s" % (
" (branchpoint)" if self.is_branch_point else "",
- " (mergepoint)" if self.is_merge_point else ""
+ " (mergepoint)" if self.is_merge_point else "",
)
if include_doc:
text += ", %s" % self.doc
@@ -737,15 +808,18 @@ class Script(revision.Revision):
def cmd_format(
self,
- verbose,
- include_branches=False, include_doc=False,
- include_parents=False, tree_indicators=True):
+ verbose,
+ include_branches=False,
+ include_doc=False,
+ include_parents=False,
+ tree_indicators=True,
+ ):
if verbose:
return self.log_entry
else:
return self._head_only(
- include_branches, include_doc,
- include_parents, tree_indicators)
+ include_branches, include_doc, include_parents, tree_indicators
+ )
def _format_down_revision(self):
if not self.down_revision:
@@ -768,13 +842,13 @@ class Script(revision.Revision):
names = set(fname.split(".")[0] for fname in paths)
# look for __pycache__
- if os.path.exists(os.path.join(path, '__pycache__')):
+ if os.path.exists(os.path.join(path, "__pycache__")):
# add all files from __pycache__ whose filename is not
# already in the names we got from the version directory.
# add as relative paths including __pycache__ token
paths.extend(
- os.path.join('__pycache__', pyc)
- for pyc in os.listdir(os.path.join(path, '__pycache__'))
+ os.path.join("__pycache__", pyc)
+ for pyc in os.listdir(os.path.join(path, "__pycache__"))
if pyc.split(".")[0] not in names
)
return paths
@@ -794,8 +868,8 @@ class Script(revision.Revision):
py_filename = py_match.group(1)
if scriptdir.sourceless:
- is_c = py_match.group(2) == 'c'
- is_o = py_match.group(2) == 'o'
+ is_c = py_match.group(2) == "c"
+ is_o = py_match.group(2) == "o"
else:
is_c = is_o = False
@@ -821,7 +895,8 @@ class Script(revision.Revision):
"Be sure the 'revision' variable is "
"declared inside the script (please see 'Upgrading "
"from Alembic 0.1 to 0.2' in the documentation)."
- % filename)
+ % filename
+ )
else:
revision = m.group(1)
else:
diff --git a/alembic/script/revision.py b/alembic/script/revision.py
index 3d9a332..832cce1 100644
--- a/alembic/script/revision.py
+++ b/alembic/script/revision.py
@@ -5,8 +5,8 @@ from .. import util
from sqlalchemy import util as sqlautil
from ..util import compat
-_relative_destination = re.compile(r'(?:(.+?)@)?(\w+)?((?:\+|-)\d+)')
-_revision_illegal_chars = ['@', '-', '+']
+_relative_destination = re.compile(r"(?:(.+?)@)?(\w+)?((?:\+|-)\d+)")
+_revision_illegal_chars = ["@", "-", "+"]
class RevisionError(Exception):
@@ -18,8 +18,8 @@ class RangeNotAncestorError(RevisionError):
self.lower = lower
self.upper = upper
super(RangeNotAncestorError, self).__init__(
- "Revision %s is not an ancestor of revision %s" %
- (lower or "base", upper or "base")
+ "Revision %s is not an ancestor of revision %s"
+ % (lower or "base", upper or "base")
)
@@ -122,8 +122,9 @@ class RevisionMap(object):
for revision in self._generator():
if revision.revision in map_:
- util.warn("Revision %s is present more than once" %
- revision.revision)
+ util.warn(
+ "Revision %s is present more than once" % revision.revision
+ )
map_[revision.revision] = revision
if revision.branch_labels:
has_branch_labels.add(revision)
@@ -132,9 +133,9 @@ class RevisionMap(object):
heads.add(revision.revision)
_real_heads.add(revision.revision)
if revision.is_base:
- self.bases += (revision.revision, )
+ self.bases += (revision.revision,)
if revision._is_real_base:
- self._real_bases += (revision.revision, )
+ self._real_bases += (revision.revision,)
# add the branch_labels to the map_. We'll need these
# to resolve the dependencies.
@@ -147,8 +148,10 @@ class RevisionMap(object):
for rev in map_.values():
for downrev in rev._all_down_revisions:
if downrev not in map_:
- util.warn("Revision %s referenced from %s is not present"
- % (downrev, rev))
+ util.warn(
+ "Revision %s referenced from %s is not present"
+ % (downrev, rev)
+ )
down_revision = map_[downrev]
down_revision.add_nextrev(rev)
if downrev in rev._versioned_down_revisions:
@@ -169,9 +172,12 @@ class RevisionMap(object):
if branch_label in map_:
raise RevisionError(
"Branch name '%s' in revision %s already "
- "used by revision %s" %
- (branch_label, revision.revision,
- map_[branch_label].revision)
+ "used by revision %s"
+ % (
+ branch_label,
+ revision.revision,
+ map_[branch_label].revision,
+ )
)
map_[branch_label] = revision
@@ -182,13 +188,16 @@ class RevisionMap(object):
if revision.branch_labels:
revision.branch_labels.update(revision.branch_labels)
for node in self._get_descendant_nodes(
- [revision], map_, include_dependencies=False):
+ [revision], map_, include_dependencies=False
+ ):
node.branch_labels.update(revision.branch_labels)
parent = node
- while parent and \
- not parent._is_real_branch_point and \
- not parent.is_merge_point:
+ while (
+ parent
+ and not parent._is_real_branch_point
+ and not parent.is_merge_point
+ ):
parent.branch_labels.update(revision.branch_labels)
if parent.down_revision:
@@ -201,7 +210,6 @@ class RevisionMap(object):
deps = [map_[dep] for dep in util.to_tuple(revision.dependencies)]
revision._resolved_dependencies = tuple([d.revision for d in deps])
-
def add_revision(self, revision, _replace=False):
"""add a single revision to an existing map.
@@ -211,8 +219,9 @@ class RevisionMap(object):
"""
map_ = self._revision_map
if not _replace and revision.revision in map_:
- util.warn("Revision %s is present more than once" %
- revision.revision)
+ util.warn(
+ "Revision %s is present more than once" % revision.revision
+ )
elif _replace and revision.revision not in map_:
raise Exception("revision %s not in map" % revision.revision)
@@ -221,9 +230,9 @@ class RevisionMap(object):
self._add_depends_on(revision, map_)
if revision.is_base:
- self.bases += (revision.revision, )
+ self.bases += (revision.revision,)
if revision._is_real_base:
- self._real_bases += (revision.revision, )
+ self._real_bases += (revision.revision,)
for downrev in revision._all_down_revisions:
if downrev not in map_:
util.warn(
@@ -233,15 +242,21 @@ class RevisionMap(object):
map_[downrev].add_nextrev(revision)
if revision._is_real_head:
self._real_heads = tuple(
- head for head in self._real_heads
- if head not in
- set(revision._all_down_revisions).union([revision.revision])
+ head
+ for head in self._real_heads
+ if head
+ not in set(revision._all_down_revisions).union(
+ [revision.revision]
+ )
) + (revision.revision,)
if revision.is_head:
self.heads = tuple(
- head for head in self.heads
- if head not in
- set(revision._versioned_down_revisions).union([revision.revision])
+ head
+ for head in self.heads
+ if head
+ not in set(revision._versioned_down_revisions).union(
+ [revision.revision]
+ )
) + (revision.revision,)
def get_current_head(self, branch_label=None):
@@ -264,11 +279,14 @@ class RevisionMap(object):
"""
current_heads = self.heads
if branch_label:
- current_heads = self.filter_for_lineage(current_heads, branch_label)
+ current_heads = self.filter_for_lineage(
+ current_heads, branch_label
+ )
if len(current_heads) > 1:
raise MultipleHeads(
current_heads,
- "%s@head" % branch_label if branch_label else "head")
+ "%s@head" % branch_label if branch_label else "head",
+ )
if current_heads:
return current_heads[0]
@@ -301,7 +319,8 @@ class RevisionMap(object):
resolved_id, branch_label = self._resolve_revision_number(id_)
return tuple(
self._revision_for_ident(rev_id, branch_label)
- for rev_id in resolved_id)
+ for rev_id in resolved_id
+ )
def get_revision(self, id_):
"""Return the :class:`.Revision` instance with the given rev id.
@@ -333,7 +352,8 @@ class RevisionMap(object):
nonbranch_rev = self._revision_for_ident(branch_label)
except ResolutionError:
raise ResolutionError(
- "No such branch: '%s'" % branch_label, branch_label)
+ "No such branch: '%s'" % branch_label, branch_label
+ )
else:
return nonbranch_rev
else:
@@ -352,30 +372,37 @@ class RevisionMap(object):
revision = False
if revision is False:
# do a partial lookup
- revs = [x for x in self._revision_map
- if x and x.startswith(resolved_id)]
+ revs = [
+ x
+ for x in self._revision_map
+ if x and x.startswith(resolved_id)
+ ]
if branch_rev:
revs = self.filter_for_lineage(revs, check_branch)
if not revs:
raise ResolutionError(
"No such revision or branch '%s'" % resolved_id,
- resolved_id)
+ resolved_id,
+ )
elif len(revs) > 1:
raise ResolutionError(
"Multiple revisions start "
- "with '%s': %s..." % (
- resolved_id,
- ", ".join("'%s'" % r for r in revs[0:3])
- ), resolved_id)
+ "with '%s': %s..."
+ % (resolved_id, ", ".join("'%s'" % r for r in revs[0:3])),
+ resolved_id,
+ )
else:
revision = self._revision_map[revs[0]]
if check_branch and revision is not None:
if not self._shares_lineage(
- revision.revision, branch_rev.revision):
+ revision.revision, branch_rev.revision
+ ):
raise ResolutionError(
- "Revision %s is not a member of branch '%s'" %
- (revision.revision, check_branch), resolved_id)
+ "Revision %s is not a member of branch '%s'"
+ % (revision.revision, check_branch),
+ resolved_id,
+ )
return revision
def _filter_into_branch_heads(self, targets):
@@ -383,14 +410,14 @@ class RevisionMap(object):
for rev in list(targets):
if targets.intersection(
- self._get_descendant_nodes(
- [rev], include_dependencies=False)).\
- difference([rev]):
+ self._get_descendant_nodes([rev], include_dependencies=False)
+ ).difference([rev]):
targets.discard(rev)
return targets
def filter_for_lineage(
- self, targets, check_against, include_dependencies=False):
+ self, targets, check_against, include_dependencies=False
+ ):
id_, branch_label = self._resolve_revision_number(check_against)
shares = []
@@ -400,12 +427,16 @@ class RevisionMap(object):
shares.extend(id_)
return [
- tg for tg in targets
+ tg
+ for tg in targets
if self._shares_lineage(
- tg, shares, include_dependencies=include_dependencies)]
+ tg, shares, include_dependencies=include_dependencies
+ )
+ ]
def _shares_lineage(
- self, target, test_against_revs, include_dependencies=False):
+ self, target, test_against_revs, include_dependencies=False
+ ):
if not test_against_revs:
return True
if not isinstance(target, Revision):
@@ -415,46 +446,61 @@ class RevisionMap(object):
self._revision_for_ident(test_against_rev)
if not isinstance(test_against_rev, Revision)
else test_against_rev
- for test_against_rev
- in util.to_tuple(test_against_revs, default=())
+ for test_against_rev in util.to_tuple(
+ test_against_revs, default=()
+ )
]
return bool(
- set(self._get_descendant_nodes([target],
- include_dependencies=include_dependencies))
- .union(self._get_ancestor_nodes([target],
- include_dependencies=include_dependencies))
+ set(
+ self._get_descendant_nodes(
+ [target], include_dependencies=include_dependencies
+ )
+ )
+ .union(
+ self._get_ancestor_nodes(
+ [target], include_dependencies=include_dependencies
+ )
+ )
.intersection(test_against_revs)
)
def _resolve_revision_number(self, id_):
if isinstance(id_, compat.string_types) and "@" in id_:
- branch_label, id_ = id_.split('@', 1)
+ branch_label, id_ = id_.split("@", 1)
else:
branch_label = None
# ensure map is loaded
self._revision_map
- if id_ == 'heads':
+ if id_ == "heads":
if branch_label:
- return self.filter_for_lineage(
- self.heads, branch_label), branch_label
+ return (
+ self.filter_for_lineage(self.heads, branch_label),
+ branch_label,
+ )
else:
return self._real_heads, branch_label
- elif id_ == 'head':
+ elif id_ == "head":
current_head = self.get_current_head(branch_label)
if current_head:
- return (current_head, ), branch_label
+ return (current_head,), branch_label
else:
return (), branch_label
- elif id_ == 'base' or id_ is None:
+ elif id_ == "base" or id_ is None:
return (), branch_label
else:
return util.to_tuple(id_, default=None), branch_label
def _relative_iterate(
- self, destination, source, is_upwards,
- implicit_base, inclusive, assert_relative_length):
+ self,
+ destination,
+ source,
+ is_upwards,
+ implicit_base,
+ inclusive,
+ assert_relative_length,
+ ):
if isinstance(destination, compat.string_types):
match = _relative_destination.match(destination)
if not match:
@@ -490,13 +536,15 @@ class RevisionMap(object):
revs = list(
self._iterate_revisions(
- from_, to_,
- inclusive=inclusive, implicit_base=implicit_base))
+ from_, to_, inclusive=inclusive, implicit_base=implicit_base
+ )
+ )
if symbol:
if branch_label:
symbol_rev = self.get_revision(
- "%s@%s" % (branch_label, symbol))
+ "%s@%s" % (branch_label, symbol)
+ )
else:
symbol_rev = self.get_revision(symbol)
if symbol.startswith("head"):
@@ -513,25 +561,39 @@ class RevisionMap(object):
else:
index = 0
if is_upwards:
- revs = revs[index - relative - reldelta:]
- if not index and assert_relative_length and \
- len(revs) < abs(relative - reldelta):
+ revs = revs[index - relative - reldelta :]
+ if (
+ not index
+ and assert_relative_length
+ and len(revs) < abs(relative - reldelta)
+ ):
raise RevisionError(
"Relative revision %s didn't "
- "produce %d migrations" % (destination, abs(relative)))
+ "produce %d migrations" % (destination, abs(relative))
+ )
else:
- revs = revs[0:index - relative + reldelta]
- if not index and assert_relative_length and \
- len(revs) != abs(relative) + reldelta:
+ revs = revs[0 : index - relative + reldelta]
+ if (
+ not index
+ and assert_relative_length
+ and len(revs) != abs(relative) + reldelta
+ ):
raise RevisionError(
"Relative revision %s didn't "
- "produce %d migrations" % (destination, abs(relative)))
+ "produce %d migrations" % (destination, abs(relative))
+ )
return iter(revs)
def iterate_revisions(
- self, upper, lower, implicit_base=False, inclusive=False,
- assert_relative_length=True, select_for_downgrade=False):
+ self,
+ upper,
+ lower,
+ implicit_base=False,
+ inclusive=False,
+ assert_relative_length=True,
+ select_for_downgrade=False,
+ ):
"""Iterate through script revisions, starting at the given
upper revision identifier and ending at the lower.
@@ -545,37 +607,59 @@ class RevisionMap(object):
"""
relative_upper = self._relative_iterate(
- upper, lower, True, implicit_base,
- inclusive, assert_relative_length
+ upper,
+ lower,
+ True,
+ implicit_base,
+ inclusive,
+ assert_relative_length,
)
if relative_upper:
return relative_upper
relative_lower = self._relative_iterate(
- lower, upper, False, implicit_base,
- inclusive, assert_relative_length
+ lower,
+ upper,
+ False,
+ implicit_base,
+ inclusive,
+ assert_relative_length,
)
if relative_lower:
return relative_lower
return self._iterate_revisions(
- upper, lower, inclusive=inclusive, implicit_base=implicit_base,
- select_for_downgrade=select_for_downgrade)
+ upper,
+ lower,
+ inclusive=inclusive,
+ implicit_base=implicit_base,
+ select_for_downgrade=select_for_downgrade,
+ )
def _get_descendant_nodes(
- self, targets, map_=None, check=False,
- omit_immediate_dependencies=False, include_dependencies=True):
+ self,
+ targets,
+ map_=None,
+ check=False,
+ omit_immediate_dependencies=False,
+ include_dependencies=True,
+ ):
if omit_immediate_dependencies:
+
def fn(rev):
if rev not in targets:
return rev._all_nextrev
else:
return rev.nextrev
+
elif include_dependencies:
+
def fn(rev):
return rev._all_nextrev
+
else:
+
def fn(rev):
return rev.nextrev
@@ -584,12 +668,16 @@ class RevisionMap(object):
)
def _get_ancestor_nodes(
- self, targets, map_=None, check=False, include_dependencies=True):
+ self, targets, map_=None, check=False, include_dependencies=True
+ ):
if include_dependencies:
+
def fn(rev):
return rev._all_down_revisions
+
else:
+
def fn(rev):
return rev._versioned_down_revisions
@@ -617,24 +705,30 @@ class RevisionMap(object):
if rev in seen:
continue
seen.add(rev)
- todo.extend(
- map_[rev_id] for rev_id in fn(rev))
+ todo.extend(map_[rev_id] for rev_id in fn(rev))
yield rev
if check:
- overlaps = per_target.intersection(targets).\
- difference([target])
+ overlaps = per_target.intersection(targets).difference(
+ [target]
+ )
if overlaps:
raise RevisionError(
"Requested revision %s overlaps with "
- "other requested revisions %s" % (
+ "other requested revisions %s"
+ % (
target.revision,
- ", ".join(r.revision for r in overlaps)
+ ", ".join(r.revision for r in overlaps),
)
)
def _iterate_revisions(
- self, upper, lower, inclusive=True, implicit_base=False,
- select_for_downgrade=False):
+ self,
+ upper,
+ lower,
+ inclusive=True,
+ implicit_base=False,
+ select_for_downgrade=False,
+ ):
"""iterate revisions from upper to lower.
The traversal is depth-first within branches, and breadth-first
@@ -650,8 +744,9 @@ class RevisionMap(object):
# is specified using a branch identifier, then we limit operations
# to just that branch.
- limit_to_lower_branch = \
- isinstance(lower, compat.string_types) and lower.endswith('@base')
+ limit_to_lower_branch = isinstance(
+ lower, compat.string_types
+ ) and lower.endswith("@base")
uppers = util.dedupe_tuple(self.get_revisions(upper))
@@ -663,16 +758,14 @@ class RevisionMap(object):
if limit_to_lower_branch:
lowers = self.get_revisions(self._get_base_revisions(lower))
elif implicit_base and requested_lowers:
- lower_ancestors = set(
- self._get_ancestor_nodes(requested_lowers)
- )
+ lower_ancestors = set(self._get_ancestor_nodes(requested_lowers))
lower_descendants = set(
self._get_descendant_nodes(requested_lowers)
)
base_lowers = set()
- candidate_lowers = upper_ancestors.\
- difference(lower_ancestors).\
- difference(lower_descendants)
+ candidate_lowers = upper_ancestors.difference(
+ lower_ancestors
+ ).difference(lower_descendants)
for rev in candidate_lowers:
for downrev in rev._all_down_revisions:
if self._revision_map[downrev] in candidate_lowers:
@@ -690,13 +783,15 @@ class RevisionMap(object):
# represents all nodes we will produce
total_space = set(
- rev.revision for rev in upper_ancestors).intersection(
- rev.revision for rev
- in self._get_descendant_nodes(
- lowers, check=True,
+ rev.revision for rev in upper_ancestors
+ ).intersection(
+ rev.revision
+ for rev in self._get_descendant_nodes(
+ lowers,
+ check=True,
omit_immediate_dependencies=(
select_for_downgrade and requested_lowers
- )
+ ),
)
)
@@ -706,7 +801,8 @@ class RevisionMap(object):
start_from = set(requested_lowers)
start_from.update(
self._get_ancestor_nodes(
- list(start_from), include_dependencies=True)
+ list(start_from), include_dependencies=True
+ )
)
# determine all the current branch points represented
@@ -725,19 +821,18 @@ class RevisionMap(object):
# organize branch points to be consumed separately from
# member nodes
branch_todo = set(
- rev for rev in
- (self._revision_map[rev] for rev in total_space)
- if rev._is_real_branch_point and
- len(total_space.intersection(rev._all_nextrev)) > 1
+ rev
+ for rev in (self._revision_map[rev] for rev in total_space)
+ if rev._is_real_branch_point
+ and len(total_space.intersection(rev._all_nextrev)) > 1
)
# it's not possible for any "uppers" to be in branch_todo,
# because the ._all_nextrev of those nodes is not in total_space
- #assert not branch_todo.intersection(uppers)
+ # assert not branch_todo.intersection(uppers)
todo = collections.deque(
- r for r in uppers
- if r.revision in total_space
+ r for r in uppers if r.revision in total_space
)
# iterate for total_space being emptied out
@@ -746,7 +841,8 @@ class RevisionMap(object):
if not total_space_modified:
raise RevisionError(
- "Dependency resolution failed; iteration can't proceed")
+ "Dependency resolution failed; iteration can't proceed"
+ )
total_space_modified = False
# when everything non-branch pending is consumed,
# add to the todo any branch nodes that have no
@@ -755,12 +851,13 @@ class RevisionMap(object):
todo.extendleft(
sorted(
(
- rev for rev in branch_todo
+ rev
+ for rev in branch_todo
if not rev._all_nextrev.intersection(total_space)
),
# favor "revisioned" branch points before
# dependent ones
- key=lambda rev: 0 if rev.is_branch_point else 1
+ key=lambda rev: 0 if rev.is_branch_point else 1,
)
)
branch_todo.difference_update(todo)
@@ -772,11 +869,14 @@ class RevisionMap(object):
# do depth first for elements within branches,
# don't consume any actual branch nodes
- todo.extendleft([
- self._revision_map[downrev]
- for downrev in reversed(rev._all_down_revisions)
- if self._revision_map[downrev] not in branch_todo
- and downrev in total_space])
+ todo.extendleft(
+ [
+ self._revision_map[downrev]
+ for downrev in reversed(rev._all_down_revisions)
+ if self._revision_map[downrev] not in branch_todo
+ and downrev in total_space
+ ]
+ )
if not inclusive and rev in requested_lowers:
continue
@@ -795,6 +895,7 @@ class Revision(object):
to Python files in a version directory.
"""
+
nextrev = frozenset()
"""following revisions, based on down_revision only."""
@@ -830,15 +931,13 @@ class Revision(object):
illegal_chars = set(revision).intersection(_revision_illegal_chars)
if illegal_chars:
raise RevisionError(
- "Character(s) '%s' not allowed in revision identifier '%s'" % (
- ", ".join(sorted(illegal_chars)),
- revision
- )
+ "Character(s) '%s' not allowed in revision identifier '%s'"
+ % (", ".join(sorted(illegal_chars)), revision)
)
def __init__(
- self, revision, down_revision,
- dependencies=None, branch_labels=None):
+ self, revision, down_revision, dependencies=None, branch_labels=None
+ ):
self.verify_rev_id(revision)
self.revision = revision
self.down_revision = tuple_rev_as_scalar(down_revision)
@@ -848,18 +947,12 @@ class Revision(object):
self.branch_labels = set(self._orig_branch_labels)
def __repr__(self):
- args = [
- repr(self.revision),
- repr(self.down_revision)
- ]
+ args = [repr(self.revision), repr(self.down_revision)]
if self.dependencies:
args.append("dependencies=%r" % (self.dependencies,))
if self.branch_labels:
args.append("branch_labels=%r" % (self.branch_labels,))
- return "%s(%s)" % (
- self.__class__.__name__,
- ", ".join(args)
- )
+ return "%s(%s)" % (self.__class__.__name__, ", ".join(args))
def add_nextrev(self, revision):
self._all_nextrev = self._all_nextrev.union([revision.revision])
@@ -868,8 +961,10 @@ class Revision(object):
@property
def _all_down_revisions(self):
- return util.to_tuple(self.down_revision, default=()) + \
- self._resolved_dependencies
+ return (
+ util.to_tuple(self.down_revision, default=())
+ + self._resolved_dependencies
+ )
@property
def _versioned_down_revisions(self):
diff --git a/alembic/templates/generic/env.py b/alembic/templates/generic/env.py
index 058378b..f3df952 100644
--- a/alembic/templates/generic/env.py
+++ b/alembic/templates/generic/env.py
@@ -37,7 +37,8 @@ def run_migrations_offline():
"""
url = config.get_main_option("sqlalchemy.url")
context.configure(
- url=url, target_metadata=target_metadata, literal_binds=True)
+ url=url, target_metadata=target_metadata, literal_binds=True
+ )
with context.begin_transaction():
context.run_migrations()
@@ -52,18 +53,19 @@ def run_migrations_online():
"""
connectable = engine_from_config(
config.get_section(config.config_ini_section),
- prefix='sqlalchemy.',
- poolclass=pool.NullPool)
+ prefix="sqlalchemy.",
+ poolclass=pool.NullPool,
+ )
with connectable.connect() as connection:
context.configure(
- connection=connection,
- target_metadata=target_metadata
+ connection=connection, target_metadata=target_metadata
)
with context.begin_transaction():
context.run_migrations()
+
if context.is_offline_mode():
run_migrations_offline()
else:
diff --git a/alembic/templates/multidb/env.py b/alembic/templates/multidb/env.py
index db24173..f5ad3d4 100644
--- a/alembic/templates/multidb/env.py
+++ b/alembic/templates/multidb/env.py
@@ -14,12 +14,12 @@ config = context.config
# Interpret the config file for Python logging.
# This line sets up loggers basically.
fileConfig(config.config_file_name)
-logger = logging.getLogger('alembic.env')
+logger = logging.getLogger("alembic.env")
# gather section names referring to different
# databases. These are named "engine1", "engine2"
# in the sample .ini file.
-db_names = config.get_main_option('databases')
+db_names = config.get_main_option("databases")
# add your model's MetaData objects here
# for 'autogenerate' support. These must be set
@@ -56,19 +56,21 @@ def run_migrations_offline():
# individual files.
engines = {}
- for name in re.split(r',\s*', db_names):
+ for name in re.split(r",\s*", db_names):
engines[name] = rec = {}
- rec['url'] = context.config.get_section_option(name,
- "sqlalchemy.url")
+ rec["url"] = context.config.get_section_option(name, "sqlalchemy.url")
for name, rec in engines.items():
logger.info("Migrating database %s" % name)
file_ = "%s.sql" % name
logger.info("Writing output to %s" % file_)
- with open(file_, 'w') as buffer:
- context.configure(url=rec['url'], output_buffer=buffer,
- target_metadata=target_metadata.get(name),
- literal_binds=True)
+ with open(file_, "w") as buffer:
+ context.configure(
+ url=rec["url"],
+ output_buffer=buffer,
+ target_metadata=target_metadata.get(name),
+ literal_binds=True,
+ )
with context.begin_transaction():
context.run_migrations(engine_name=name)
@@ -85,46 +87,47 @@ def run_migrations_online():
# engines, then run all migrations, then commit all transactions.
engines = {}
- for name in re.split(r',\s*', db_names):
+ for name in re.split(r",\s*", db_names):
engines[name] = rec = {}
- rec['engine'] = engine_from_config(
+ rec["engine"] = engine_from_config(
context.config.get_section(name),
- prefix='sqlalchemy.',
- poolclass=pool.NullPool)
+ prefix="sqlalchemy.",
+ poolclass=pool.NullPool,
+ )
for name, rec in engines.items():
- engine = rec['engine']
- rec['connection'] = conn = engine.connect()
+ engine = rec["engine"]
+ rec["connection"] = conn = engine.connect()
if USE_TWOPHASE:
- rec['transaction'] = conn.begin_twophase()
+ rec["transaction"] = conn.begin_twophase()
else:
- rec['transaction'] = conn.begin()
+ rec["transaction"] = conn.begin()
try:
for name, rec in engines.items():
logger.info("Migrating database %s" % name)
context.configure(
- connection=rec['connection'],
+ connection=rec["connection"],
upgrade_token="%s_upgrades" % name,
downgrade_token="%s_downgrades" % name,
- target_metadata=target_metadata.get(name)
+ target_metadata=target_metadata.get(name),
)
context.run_migrations(engine_name=name)
if USE_TWOPHASE:
for rec in engines.values():
- rec['transaction'].prepare()
+ rec["transaction"].prepare()
for rec in engines.values():
- rec['transaction'].commit()
+ rec["transaction"].commit()
except:
for rec in engines.values():
- rec['transaction'].rollback()
+ rec["transaction"].rollback()
raise
finally:
for rec in engines.values():
- rec['connection'].close()
+ rec["connection"].close()
if context.is_offline_mode():
diff --git a/alembic/templates/pylons/env.py b/alembic/templates/pylons/env.py
index 5ad9fd5..8c06cdc 100644
--- a/alembic/templates/pylons/env.py
+++ b/alembic/templates/pylons/env.py
@@ -13,18 +13,21 @@ from sqlalchemy.engine.base import Engine
try:
# if pylons app already in, don't create a new app
from pylons import config as pylons_config
- pylons_config['__file__']
+
+ pylons_config["__file__"]
except:
config = context.config
# can use config['__file__'] here, i.e. the Pylons
# ini file, instead of alembic.ini
- config_file = config.get_main_option('pylons_config_file')
+ config_file = config.get_main_option("pylons_config_file")
fileConfig(config_file)
- wsgi_app = loadapp('config:%s' % config_file, relative_to='.')
+ wsgi_app = loadapp("config:%s" % config_file, relative_to=".")
# customize this section for non-standard engine configurations.
-meta = __import__("%s.model.meta" % wsgi_app.config['pylons.package']).model.meta
+meta = __import__(
+ "%s.model.meta" % wsgi_app.config["pylons.package"]
+).model.meta
# add your model's MetaData object here
# for 'autogenerate' support
@@ -46,8 +49,10 @@ def run_migrations_offline():
"""
context.configure(
- url=meta.engine.url, target_metadata=target_metadata,
- literal_binds=True)
+ url=meta.engine.url,
+ target_metadata=target_metadata,
+ literal_binds=True,
+ )
with context.begin_transaction():
context.run_migrations()
@@ -65,13 +70,13 @@ def run_migrations_online():
with engine.connect() as connection:
context.configure(
- connection=connection,
- target_metadata=target_metadata
+ connection=connection, target_metadata=target_metadata
)
with context.begin_transaction():
context.run_migrations()
+
if context.is_offline_mode():
run_migrations_offline()
else:
diff --git a/alembic/testing/__init__.py b/alembic/testing/__init__.py
index 553f501..70c28e0 100644
--- a/alembic/testing/__init__.py
+++ b/alembic/testing/__init__.py
@@ -1,6 +1,13 @@
from .fixtures import TestBase
-from .assertions import eq_, ne_, is_, is_not_, assert_raises_message, \
- eq_ignore_whitespace, assert_raises
+from .assertions import (
+ eq_,
+ ne_,
+ is_,
+ is_not_,
+ assert_raises_message,
+ eq_ignore_whitespace,
+ assert_raises,
+)
from .util import provide_metadata
diff --git a/alembic/testing/assertions.py b/alembic/testing/assertions.py
index 2c7382c..c25b444 100644
--- a/alembic/testing/assertions.py
+++ b/alembic/testing/assertions.py
@@ -15,6 +15,7 @@ from . import config
if not util.sqla_094:
+
def eq_(a, b, msg=None):
"""Assert a == b, with repr messaging on failure."""
assert a == b, msg or "%r != %r" % (a, b)
@@ -46,27 +47,36 @@ if not util.sqla_094:
callable_(*args, **kwargs)
assert False, "Callable did not raise an exception"
except except_cls as e:
- assert re.search(
- msg, text_type(e), re.UNICODE), "%r !~ %s" % (msg, e)
- print(text_type(e).encode('utf-8'))
+ assert re.search(msg, text_type(e), re.UNICODE), "%r !~ %s" % (
+ msg,
+ e,
+ )
+ print(text_type(e).encode("utf-8"))
+
else:
- from sqlalchemy.testing.assertions import eq_, ne_, is_, is_not_, \
- assert_raises_message, assert_raises
+ from sqlalchemy.testing.assertions import (
+ eq_,
+ ne_,
+ is_,
+ is_not_,
+ assert_raises_message,
+ assert_raises,
+ )
def eq_ignore_whitespace(a, b, msg=None):
- a = re.sub(r'^\s+?|\n', "", a)
- a = re.sub(r' {2,}', " ", a)
- b = re.sub(r'^\s+?|\n', "", b)
- b = re.sub(r' {2,}', " ", b)
+ a = re.sub(r"^\s+?|\n", "", a)
+ a = re.sub(r" {2,}", " ", a)
+ b = re.sub(r"^\s+?|\n", "", b)
+ b = re.sub(r" {2,}", " ", b)
# convert for unicode string rendering,
# using special escape character "!U"
if py3k:
- b = re.sub(r'!U', '', b)
+ b = re.sub(r"!U", "", b)
else:
- b = re.sub(r'!U', 'u', b)
+ b = re.sub(r"!U", "u", b)
assert a == b, msg or "%r != %r" % (a, b)
@@ -74,9 +84,10 @@ def eq_ignore_whitespace(a, b, msg=None):
def assert_compiled(element, assert_string, dialect=None):
dialect = _get_dialect(dialect)
eq_(
- text_type(element.compile(dialect=dialect)).
- replace("\n", "").replace("\t", ""),
- assert_string.replace("\n", "").replace("\t", "")
+ text_type(element.compile(dialect=dialect))
+ .replace("\n", "")
+ .replace("\t", ""),
+ assert_string.replace("\n", "").replace("\t", ""),
)
@@ -84,19 +95,20 @@ _dialect_mods = {}
def _get_dialect(name):
- if name is None or name == 'default':
+ if name is None or name == "default":
return default.DefaultDialect()
else:
try:
dialect_mod = _dialect_mods[name]
except KeyError:
dialect_mod = getattr(
- __import__('sqlalchemy.dialects.%s' % name).dialects, name)
+ __import__("sqlalchemy.dialects.%s" % name).dialects, name
+ )
_dialect_mods[name] = dialect_mod
d = dialect_mod.dialect()
- if name == 'postgresql':
+ if name == "postgresql":
d.implicit_returning = True
- elif name == 'mssql':
+ elif name == "mssql":
d.legacy_schema_aliasing = False
return d
@@ -161,6 +173,7 @@ def emits_warning_on(db, *messages):
were in fact seen.
"""
+
@decorator
def decorate(fn, *args, **kw):
with expect_warnings_on(db, *messages):
@@ -189,8 +202,9 @@ def _expect_warnings(exc_cls, messages, regex=True, assert_=True):
return
for filter_ in filters:
- if (regex and filter_.match(msg)) or \
- (not regex and filter_ == msg):
+ if (regex and filter_.match(msg)) or (
+ not regex and filter_ == msg
+ ):
seen.discard(filter_)
break
else:
@@ -203,6 +217,6 @@ def _expect_warnings(exc_cls, messages, regex=True, assert_=True):
yield
if assert_:
- assert not seen, "Warnings were not seen: %s" % \
- ", ".join("%r" % (s.pattern if regex else s) for s in seen)
-
+ assert not seen, "Warnings were not seen: %s" % ", ".join(
+ "%r" % (s.pattern if regex else s) for s in seen
+ )
diff --git a/alembic/testing/compat.py b/alembic/testing/compat.py
index e0af6a2..9fbd50f 100644
--- a/alembic/testing/compat.py
+++ b/alembic/testing/compat.py
@@ -1,13 +1,12 @@
def get_url_driver_name(url):
- if '+' not in url.drivername:
+ if "+" not in url.drivername:
return url.get_dialect().driver
else:
- return url.drivername.split('+')[1]
+ return url.drivername.split("+")[1]
def get_url_backend_name(url):
- if '+' not in url.drivername:
+ if "+" not in url.drivername:
return url.drivername
else:
- return url.drivername.split('+')[0]
-
+ return url.drivername.split("+")[0]
diff --git a/alembic/testing/config.py b/alembic/testing/config.py
index ca28c6b..7d7009e 100644
--- a/alembic/testing/config.py
+++ b/alembic/testing/config.py
@@ -66,7 +66,8 @@ class Config(object):
assert _current, "Can't push without a default Config set up"
cls.push(
Config(
- db, _current.db_opts, _current.options, _current.file_config)
+ db, _current.db_opts, _current.options, _current.file_config
+ )
)
@classmethod
@@ -88,4 +89,3 @@ class Config(object):
def all_dbs(cls):
for cfg in cls.all_configs():
yield cfg.db
-
diff --git a/alembic/testing/engines.py b/alembic/testing/engines.py
index dadabc8..68d0068 100644
--- a/alembic/testing/engines.py
+++ b/alembic/testing/engines.py
@@ -25,4 +25,3 @@ def testing_engine(url=None, options=None):
engine = create_engine(url, **options)
return engine
-
diff --git a/alembic/testing/env.py b/alembic/testing/env.py
index 0318703..51483d1 100644
--- a/alembic/testing/env.py
+++ b/alembic/testing/env.py
@@ -15,21 +15,22 @@ def _get_staging_directory():
if provision.FOLLOWER_IDENT:
return "scratch_%s" % provision.FOLLOWER_IDENT
else:
- return 'scratch'
+ return "scratch"
def staging_env(create=True, template="generic", sourceless=False):
from alembic import command, script
+
cfg = _testing_config()
if create:
- path = os.path.join(_get_staging_directory(), 'scripts')
+ path = os.path.join(_get_staging_directory(), "scripts")
if os.path.exists(path):
shutil.rmtree(path)
command.init(cfg, path, template=template)
if sourceless:
try:
# do an import so that a .pyc/.pyo is generated.
- util.load_python_file(path, 'env.py')
+ util.load_python_file(path, "env.py")
except AttributeError:
# we don't have the migration context set up yet
# so running the .env py throws this exception.
@@ -38,10 +39,13 @@ def staging_env(create=True, template="generic", sourceless=False):
# worth it.
pass
assert sourceless in (
- "pep3147_envonly", "simple", "pep3147_everything"), sourceless
+ "pep3147_envonly",
+ "simple",
+ "pep3147_everything",
+ ), sourceless
make_sourceless(
os.path.join(path, "env.py"),
- "pep3147" if "pep3147" in sourceless else "simple"
+ "pep3147" if "pep3147" in sourceless else "simple",
)
sc = script.ScriptDirectory.from_config(cfg)
@@ -53,40 +57,44 @@ def clear_staging_env():
def script_file_fixture(txt):
- dir_ = os.path.join(_get_staging_directory(), 'scripts')
+ dir_ = os.path.join(_get_staging_directory(), "scripts")
path = os.path.join(dir_, "script.py.mako")
- with open(path, 'w') as f:
+ with open(path, "w") as f:
f.write(txt)
def env_file_fixture(txt):
- dir_ = os.path.join(_get_staging_directory(), 'scripts')
- txt = """
+ dir_ = os.path.join(_get_staging_directory(), "scripts")
+ txt = (
+ """
from alembic import context
config = context.config
-""" + txt
+"""
+ + txt
+ )
path = os.path.join(dir_, "env.py")
pyc_path = util.pyc_file_from_path(path)
if pyc_path:
os.unlink(pyc_path)
- with open(path, 'w') as f:
+ with open(path, "w") as f:
f.write(txt)
def _sqlite_file_db(tempname="foo.db"):
- dir_ = os.path.join(_get_staging_directory(), 'scripts')
+ dir_ = os.path.join(_get_staging_directory(), "scripts")
url = "sqlite:///%s/%s" % (dir_, tempname)
return engines.testing_engine(url=url)
def _sqlite_testing_config(sourceless=False):
- dir_ = os.path.join(_get_staging_directory(), 'scripts')
+ dir_ = os.path.join(_get_staging_directory(), "scripts")
url = "sqlite:///%s/foo.db" % dir_
- return _write_config_file("""
+ return _write_config_file(
+ """
[alembic]
script_location = %s
sqlalchemy.url = %s
@@ -115,14 +123,17 @@ keys = generic
[formatter_generic]
format = %%(levelname)-5.5s [%%(name)s] %%(message)s
datefmt = %%H:%%M:%%S
- """ % (dir_, url, "true" if sourceless else "false"))
+ """
+ % (dir_, url, "true" if sourceless else "false")
+ )
-def _multi_dir_testing_config(sourceless=False, extra_version_location=''):
- dir_ = os.path.join(_get_staging_directory(), 'scripts')
+def _multi_dir_testing_config(sourceless=False, extra_version_location=""):
+ dir_ = os.path.join(_get_staging_directory(), "scripts")
url = "sqlite:///%s/foo.db" % dir_
- return _write_config_file("""
+ return _write_config_file(
+ """
[alembic]
script_location = %s
sqlalchemy.url = %s
@@ -152,15 +163,22 @@ keys = generic
[formatter_generic]
format = %%(levelname)-5.5s [%%(name)s] %%(message)s
datefmt = %%H:%%M:%%S
- """ % (dir_, url, "true" if sourceless else "false",
- extra_version_location))
+ """
+ % (
+ dir_,
+ url,
+ "true" if sourceless else "false",
+ extra_version_location,
+ )
+ )
def _no_sql_testing_config(dialect="postgresql", directives=""):
"""use a postgresql url with no host so that
connections guaranteed to fail"""
- dir_ = os.path.join(_get_staging_directory(), 'scripts')
- return _write_config_file("""
+ dir_ = os.path.join(_get_staging_directory(), "scripts")
+ return _write_config_file(
+ """
[alembic]
script_location = %s
sqlalchemy.url = %s://
@@ -190,32 +208,36 @@ keys = generic
format = %%(levelname)-5.5s [%%(name)s] %%(message)s
datefmt = %%H:%%M:%%S
-""" % (dir_, dialect, directives))
+"""
+ % (dir_, dialect, directives)
+ )
def _write_config_file(text):
cfg = _testing_config()
- with open(cfg.config_file_name, 'w') as f:
+ with open(cfg.config_file_name, "w") as f:
f.write(text)
return cfg
def _testing_config():
from alembic.config import Config
+
if not os.access(_get_staging_directory(), os.F_OK):
os.mkdir(_get_staging_directory())
- return Config(os.path.join(_get_staging_directory(), 'test_alembic.ini'))
+ return Config(os.path.join(_get_staging_directory(), "test_alembic.ini"))
def write_script(
- scriptdir, rev_id, content, encoding='ascii', sourceless=False):
+ scriptdir, rev_id, content, encoding="ascii", sourceless=False
+):
old = scriptdir.revision_map.get_revision(rev_id)
path = old.path
content = textwrap.dedent(content)
if encoding:
content = content.encode(encoding)
- with open(path, 'wb') as fp:
+ with open(path, "wb") as fp:
fp.write(content)
pyc_path = util.pyc_file_from_path(path)
if pyc_path:
@@ -223,20 +245,21 @@ def write_script(
script = Script._from_path(scriptdir, path)
old = scriptdir.revision_map.get_revision(script.revision)
if old.down_revision != script.down_revision:
- raise Exception("Can't change down_revision "
- "on a refresh operation.")
+ raise Exception(
+ "Can't change down_revision " "on a refresh operation."
+ )
scriptdir.revision_map.add_revision(script, _replace=True)
if sourceless:
make_sourceless(
- path,
- "pep3147" if sourceless == "pep3147_everything" else "simple"
+ path, "pep3147" if sourceless == "pep3147_everything" else "simple"
)
def make_sourceless(path, style):
import py_compile
+
py_compile.compile(path)
if style == "simple" and has_pep3147():
@@ -264,7 +287,10 @@ def three_rev_fixture(cfg):
script = ScriptDirectory.from_config(cfg)
script.generate_revision(a, "revision a", refresh=True)
- write_script(script, a, """\
+ write_script(
+ script,
+ a,
+ """\
"Rev A"
revision = '%s'
down_revision = None
@@ -279,10 +305,16 @@ def upgrade():
def downgrade():
op.execute("DROP STEP 1")
-""" % a)
+"""
+ % a,
+ )
script.generate_revision(b, "revision b", refresh=True)
- write_script(script, b, u("""# coding: utf-8
+ write_script(
+ script,
+ b,
+ u(
+ """# coding: utf-8
"Rev B, méil, %3"
revision = '{}'
down_revision = '{}'
@@ -297,10 +329,16 @@ def upgrade():
def downgrade():
op.execute("DROP STEP 2")
-""").format(b, a), encoding="utf-8")
+"""
+ ).format(b, a),
+ encoding="utf-8",
+ )
script.generate_revision(c, "revision c", refresh=True)
- write_script(script, c, """\
+ write_script(
+ script,
+ c,
+ """\
"Rev C"
revision = '%s'
down_revision = '%s'
@@ -315,7 +353,9 @@ def upgrade():
def downgrade():
op.execute("DROP STEP 3")
-""" % (c, b))
+"""
+ % (c, b),
+ )
return a, b, c
@@ -328,8 +368,12 @@ def multi_heads_fixture(cfg, a, b, c):
script = ScriptDirectory.from_config(cfg)
script.generate_revision(
- d, "revision d from b", head=b, splice=True, refresh=True)
- write_script(script, d, """\
+ d, "revision d from b", head=b, splice=True, refresh=True
+ )
+ write_script(
+ script,
+ d,
+ """\
"Rev D"
revision = '%s'
down_revision = '%s'
@@ -344,11 +388,17 @@ def upgrade():
def downgrade():
op.execute("DROP STEP 4")
-""" % (d, b))
+"""
+ % (d, b),
+ )
script.generate_revision(
- e, "revision e from d", head=d, splice=True, refresh=True)
- write_script(script, e, """\
+ e, "revision e from d", head=d, splice=True, refresh=True
+ )
+ write_script(
+ script,
+ e,
+ """\
"Rev E"
revision = '%s'
down_revision = '%s'
@@ -363,11 +413,17 @@ def upgrade():
def downgrade():
op.execute("DROP STEP 5")
-""" % (e, d))
+"""
+ % (e, d),
+ )
script.generate_revision(
- f, "revision f from b", head=b, splice=True, refresh=True)
- write_script(script, f, """\
+ f, "revision f from b", head=b, splice=True, refresh=True
+ )
+ write_script(
+ script,
+ f,
+ """\
"Rev F"
revision = '%s'
down_revision = '%s'
@@ -382,7 +438,9 @@ def upgrade():
def downgrade():
op.execute("DROP STEP 6")
-""" % (f, b))
+"""
+ % (f, b),
+ )
return d, e, f
@@ -390,18 +448,16 @@ def downgrade():
def _multidb_testing_config(engines):
"""alembic.ini fixture to work exactly with the 'multidb' template"""
- dir_ = os.path.join(_get_staging_directory(), 'scripts')
+ dir_ = os.path.join(_get_staging_directory(), "scripts")
- databases = ", ".join(
- engines.keys()
- )
+ databases = ", ".join(engines.keys())
engines = "\n\n".join(
- "[%s]\n"
- "sqlalchemy.url = %s" % (key, value.url)
+ "[%s]\n" "sqlalchemy.url = %s" % (key, value.url)
for key, value in engines.items()
)
- return _write_config_file("""
+ return _write_config_file(
+ """
[alembic]
script_location = %s
sourceless = false
@@ -432,5 +488,6 @@ keys = generic
[formatter_generic]
format = %%(levelname)-5.5s [%%(name)s] %%(message)s
datefmt = %%H:%%M:%%S
- """ % (dir_, databases, engines)
+ """
+ % (dir_, databases, engines)
)
diff --git a/alembic/testing/exclusions.py b/alembic/testing/exclusions.py
index 7d33a5b..41ed547 100644
--- a/alembic/testing/exclusions.py
+++ b/alembic/testing/exclusions.py
@@ -74,15 +74,15 @@ class compound(object):
def matching_config_reasons(self, config):
return [
- predicate._as_string(config) for predicate
- in self.skips.union(self.fails)
+ predicate._as_string(config)
+ for predicate in self.skips.union(self.fails)
if predicate(config)
]
def include_test(self, include_tags, exclude_tags):
return bool(
- not self.tags.intersection(exclude_tags) and
- (not include_tags or self.tags.intersection(include_tags))
+ not self.tags.intersection(exclude_tags)
+ and (not include_tags or self.tags.intersection(include_tags))
)
def _extend(self, other):
@@ -91,13 +91,14 @@ class compound(object):
self.tags.update(other.tags)
def __call__(self, fn):
- if hasattr(fn, '_sa_exclusion_extend'):
+ if hasattr(fn, "_sa_exclusion_extend"):
fn._sa_exclusion_extend._extend(self)
return fn
@decorator
def decorate(fn, *args, **kw):
return self._do(config._current, fn, *args, **kw)
+
decorated = decorate(fn)
decorated._sa_exclusion_extend = self
return decorated
@@ -117,10 +118,7 @@ class compound(object):
def _do(self, config, fn, *args, **kw):
for skip in self.skips:
if skip(config):
- msg = "'%s' : %s" % (
- fn.__name__,
- skip._as_string(config)
- )
+ msg = "'%s' : %s" % (fn.__name__, skip._as_string(config))
raise SkipTest(msg)
try:
@@ -131,16 +129,20 @@ class compound(object):
self._expect_success(config, name=fn.__name__)
return return_value
- def _expect_failure(self, config, ex, name='block'):
+ def _expect_failure(self, config, ex, name="block"):
for fail in self.fails:
if fail(config):
- print(("%s failed as expected (%s): %s " % (
- name, fail._as_string(config), str(ex))))
+ print(
+ (
+ "%s failed as expected (%s): %s "
+ % (name, fail._as_string(config), str(ex))
+ )
+ )
break
else:
compat.raise_from_cause(ex)
- def _expect_success(self, config, name='block'):
+ def _expect_success(self, config, name="block"):
if not self.fails:
return
for fail in self.fails:
@@ -148,13 +150,12 @@ class compound(object):
break
else:
raise AssertionError(
- "Unexpected success for '%s' (%s)" %
- (
+ "Unexpected success for '%s' (%s)"
+ % (
name,
" and ".join(
- fail._as_string(config)
- for fail in self.fails
- )
+ fail._as_string(config) for fail in self.fails
+ ),
)
)
@@ -191,8 +192,8 @@ class Predicate(object):
return predicate
elif isinstance(predicate, (list, set)):
return OrPredicate(
- [cls.as_predicate(pred) for pred in predicate],
- description)
+ [cls.as_predicate(pred) for pred in predicate], description
+ )
elif isinstance(predicate, tuple):
return SpecPredicate(*predicate)
elif isinstance(predicate, compat.string_types):
@@ -217,7 +218,7 @@ class Predicate(object):
"driver": get_url_driver_name(config.db.url),
"database": get_url_backend_name(config.db.url),
"doesnt_support": "doesn't support" if bool_ else "does support",
- "does_support": "does support" if bool_ else "doesn't support"
+ "does_support": "does support" if bool_ else "doesn't support",
}
def _as_string(self, config=None, negate=False):
@@ -244,21 +245,21 @@ class SpecPredicate(Predicate):
self.description = description
_ops = {
- '<': operator.lt,
- '>': operator.gt,
- '==': operator.eq,
- '!=': operator.ne,
- '<=': operator.le,
- '>=': operator.ge,
- 'in': operator.contains,
- 'between': lambda val, pair: val >= pair[0] and val <= pair[1],
+ "<": operator.lt,
+ ">": operator.gt,
+ "==": operator.eq,
+ "!=": operator.ne,
+ "<=": operator.le,
+ ">=": operator.ge,
+ "in": operator.contains,
+ "between": lambda val, pair: val >= pair[0] and val <= pair[1],
}
def __call__(self, config):
engine = config.db
if "+" in self.db:
- dialect, driver = self.db.split('+')
+ dialect, driver = self.db.split("+")
else:
dialect, driver = self.db, None
@@ -271,8 +272,9 @@ class SpecPredicate(Predicate):
assert driver is None, "DBAPI version specs not supported yet"
version = _server_version(engine)
- oper = hasattr(self.op, '__call__') and self.op \
- or self._ops[self.op]
+ oper = (
+ hasattr(self.op, "__call__") and self.op or self._ops[self.op]
+ )
return oper(version, self.spec)
else:
return True
@@ -287,17 +289,9 @@ class SpecPredicate(Predicate):
return "%s" % self.db
else:
if negate:
- return "not %s %s %s" % (
- self.db,
- self.op,
- self.spec
- )
+ return "not %s %s %s" % (self.db, self.op, self.spec)
else:
- return "%s %s %s" % (
- self.db,
- self.op,
- self.spec
- )
+ return "%s %s %s" % (self.db, self.op, self.spec)
class LambdaPredicate(Predicate):
@@ -354,8 +348,9 @@ class OrPredicate(Predicate):
conjunction = " and "
else:
conjunction = " or "
- return conjunction.join(p._as_string(config, negate=negate)
- for p in self.predicates)
+ return conjunction.join(
+ p._as_string(config, negate=negate) for p in self.predicates
+ )
def _negation_str(self, config):
if self.description is not None:
@@ -385,15 +380,13 @@ def _server_version(engine):
# force metadata to be retrieved
conn = engine.connect()
- version = getattr(engine.dialect, 'server_version_info', ())
+ version = getattr(engine.dialect, "server_version_info", ())
conn.close()
return version
def db_spec(*dbs):
- return OrPredicate(
- [Predicate.as_predicate(db) for db in dbs]
- )
+ return OrPredicate([Predicate.as_predicate(db) for db in dbs])
def open():
@@ -418,11 +411,7 @@ def fails_on(db, reason=None):
def fails_on_everything_except(*dbs):
- return succeeds_if(
- OrPredicate([
- Predicate.as_predicate(db) for db in dbs
- ])
- )
+ return succeeds_if(OrPredicate([Predicate.as_predicate(db) for db in dbs]))
def skip(db, reason=None):
@@ -441,7 +430,6 @@ def exclude(db, op, spec, reason=None):
def against(config, *queries):
assert queries, "no queries sent!"
- return OrPredicate([
- Predicate.as_predicate(query)
- for query in queries
- ])(config)
+ return OrPredicate([Predicate.as_predicate(query) for query in queries])(
+ config
+ )
diff --git a/alembic/testing/fixtures.py b/alembic/testing/fixtures.py
index 86d40a2..b812476 100644
--- a/alembic/testing/fixtures.py
+++ b/alembic/testing/fixtures.py
@@ -17,10 +17,11 @@ from .assertions import _get_dialect, eq_
from . import mock
testing_config = configparser.ConfigParser()
-testing_config.read(['test.cfg'])
+testing_config.read(["test.cfg"])
if not util.sqla_094:
+
class TestBase(object):
# A sequence of database names to always run, regardless of the
# constraints below.
@@ -51,6 +52,8 @@ if not util.sqla_094:
def teardown(self):
if hasattr(self, "tearDown"):
self.tearDown()
+
+
else:
from sqlalchemy.testing.fixtures import TestBase
@@ -60,23 +63,22 @@ def capture_db():
def dump(sql, *multiparams, **params):
buf.append(str(sql.compile(dialect=engine.dialect)))
+
engine = create_engine("postgresql://", strategy="mock", executor=dump)
return engine, buf
+
_engs = {}
@contextmanager
def capture_context_buffer(**kw):
- if kw.pop('bytes_io', False):
+ if kw.pop("bytes_io", False):
buf = io.BytesIO()
else:
buf = io.StringIO()
- kw.update({
- 'dialect_name': "sqlite",
- 'output_buffer': buf
- })
+ kw.update({"dialect_name": "sqlite", "output_buffer": buf})
conf = EnvironmentContext.configure
def configure(*arg, **opt):
@@ -88,17 +90,20 @@ def capture_context_buffer(**kw):
def op_fixture(
- dialect='default', as_sql=False,
- naming_convention=None, literal_binds=False,
- native_boolean=None):
+ dialect="default",
+ as_sql=False,
+ naming_convention=None,
+ literal_binds=False,
+ native_boolean=None,
+):
opts = {}
if naming_convention:
if not util.sqla_092:
raise SkipTest(
- "naming_convention feature requires "
- "sqla 0.9.2 or greater")
- opts['target_metadata'] = MetaData(naming_convention=naming_convention)
+ "naming_convention feature requires " "sqla 0.9.2 or greater"
+ )
+ opts["target_metadata"] = MetaData(naming_convention=naming_convention)
class buffer_(object):
def __init__(self):
@@ -106,12 +111,12 @@ def op_fixture(
def write(self, msg):
msg = msg.strip()
- msg = re.sub(r'[\n\t]', '', msg)
+ msg = re.sub(r"[\n\t]", "", msg)
if as_sql:
# the impl produces soft tabs,
# so search for blocks of 4 spaces
- msg = re.sub(r' ', '', msg)
- msg = re.sub(r'\;\n*$', '', msg)
+ msg = re.sub(r" ", "", msg)
+ msg = re.sub(r"\;\n*$", "", msg)
self.lines.append(msg)
@@ -136,13 +141,13 @@ def op_fixture(
else:
assert False, "Could not locate fragment %r in %r" % (
sql,
- buf.lines
+ buf.lines,
)
if as_sql:
- opts['as_sql'] = as_sql
+ opts["as_sql"] = as_sql
if literal_binds:
- opts['literal_binds'] = literal_binds
+ opts["literal_binds"] = literal_binds
ctx_dialect = _get_dialect(dialect)
if native_boolean is not None:
ctx_dialect.supports_native_boolean = native_boolean
@@ -150,6 +155,7 @@ def op_fixture(
# which breaks assumptions in the alembic test suite
ctx_dialect.non_native_boolean_check_constraint = True
if not as_sql:
+
def execute(stmt, *multiparam, **param):
if isinstance(stmt, string_types):
stmt = text(stmt)
@@ -160,12 +166,9 @@ def op_fixture(
connection = mock.Mock(dialect=ctx_dialect, execute=execute)
else:
- opts['output_buffer'] = buf
+ opts["output_buffer"] = buf
connection = None
- context = ctx(
- ctx_dialect,
- connection,
- opts)
+ context = ctx(ctx_dialect, connection, opts)
alembic.op._proxy = Operations(context)
return context
diff --git a/alembic/testing/mock.py b/alembic/testing/mock.py
index 08a756c..1d5256d 100644
--- a/alembic/testing/mock.py
+++ b/alembic/testing/mock.py
@@ -22,4 +22,5 @@ else:
except ImportError:
raise ImportError(
"SQLAlchemy's test suite requires the "
- "'mock' library as of 0.8.2.")
+ "'mock' library as of 0.8.2."
+ )
diff --git a/alembic/testing/plugin/bootstrap.py b/alembic/testing/plugin/bootstrap.py
index 9f42fd2..4bd415d 100644
--- a/alembic/testing/plugin/bootstrap.py
+++ b/alembic/testing/plugin/bootstrap.py
@@ -20,20 +20,23 @@ this should be removable when Alembic targets SQLAlchemy 1.0.0.
import os
import sys
-bootstrap_file = locals()['bootstrap_file']
-to_bootstrap = locals()['to_bootstrap']
+bootstrap_file = locals()["bootstrap_file"]
+to_bootstrap = locals()["to_bootstrap"]
def load_file_as_module(name):
path = os.path.join(os.path.dirname(bootstrap_file), "%s.py" % name)
if sys.version_info.major >= 3:
from importlib import machinery
+
mod = machinery.SourceFileLoader(name, path).load_module()
else:
import imp
+
mod = imp.load_source(name, path)
return mod
+
if to_bootstrap == "pytest":
sys.modules["alembic_plugin_base"] = load_file_as_module("plugin_base")
sys.modules["alembic_pytestplugin"] = load_file_as_module("pytestplugin")
diff --git a/alembic/testing/plugin/noseplugin.py b/alembic/testing/plugin/noseplugin.py
index f8894d6..fafb9e1 100644
--- a/alembic/testing/plugin/noseplugin.py
+++ b/alembic/testing/plugin/noseplugin.py
@@ -25,6 +25,7 @@ import os
import sys
from nose.plugins import Plugin
+
fixtures = None
py3k = sys.version_info.major >= 3
@@ -33,7 +34,7 @@ py3k = sys.version_info.major >= 3
class NoseSQLAlchemy(Plugin):
enabled = True
- name = 'sqla_testing'
+ name = "sqla_testing"
score = 100
def options(self, parser, env=os.environ):
@@ -43,8 +44,10 @@ class NoseSQLAlchemy(Plugin):
def make_option(name, **kw):
callback_ = kw.pop("callback", None)
if callback_:
+
def wrap_(option, opt_str, value, parser):
callback_(opt_str, value, parser)
+
kw["callback"] = wrap_
opt(name, **kw)
@@ -71,7 +74,7 @@ class NoseSQLAlchemy(Plugin):
def wantMethod(self, fn):
if py3k:
- if not hasattr(fn.__self__, 'cls'):
+ if not hasattr(fn.__self__, "cls"):
return False
cls = fn.__self__.cls
else:
@@ -85,19 +88,19 @@ class NoseSQLAlchemy(Plugin):
plugin_base.before_test(
test,
test.test.cls.__module__,
- test.test.cls, test.test.method.__name__)
+ test.test.cls,
+ test.test.method.__name__,
+ )
def afterTest(self, test):
plugin_base.after_test(test)
def startContext(self, ctx):
- if not isinstance(ctx, type) \
- or not issubclass(ctx, fixtures.TestBase):
+ if not isinstance(ctx, type) or not issubclass(ctx, fixtures.TestBase):
return
plugin_base.start_test_class(ctx)
def stopContext(self, ctx):
- if not isinstance(ctx, type) \
- or not issubclass(ctx, fixtures.TestBase):
+ if not isinstance(ctx, type) or not issubclass(ctx, fixtures.TestBase):
return
plugin_base.stop_test_class(ctx)
diff --git a/alembic/testing/plugin/plugin_base.py b/alembic/testing/plugin/plugin_base.py
index 141e82f..9acffb5 100644
--- a/alembic/testing/plugin/plugin_base.py
+++ b/alembic/testing/plugin/plugin_base.py
@@ -17,12 +17,14 @@ this should be removable when Alembic targets SQLAlchemy 1.0.0
"""
from __future__ import absolute_import
+
try:
# unitttest has a SkipTest also but pytest doesn't
# honor it unless nose is imported too...
from nose import SkipTest
except ImportError:
from pytest import skip
+
SkipTest = skip.Exception
import sys
@@ -55,54 +57,118 @@ options = None
def setup_options(make_option):
- make_option("--log-info", action="callback", type="string", callback=_log,
- help="turn on info logging for <LOG> (multiple OK)")
- make_option("--log-debug", action="callback",
- type="string", callback=_log,
- help="turn on debug logging for <LOG> (multiple OK)")
- make_option("--db", action="append", type="string", dest="db",
- help="Use prefab database uri. Multiple OK, "
- "first one is run by default.")
- make_option('--dbs', action='callback', zeroarg_callback=_list_dbs,
- help="List available prefab dbs")
- make_option("--dburi", action="append", type="string", dest="dburi",
- help="Database uri. Multiple OK, "
- "first one is run by default.")
- make_option("--dropfirst", action="store_true", dest="dropfirst",
- help="Drop all tables in the target database first")
- make_option("--backend-only", action="store_true", dest="backend_only",
- help="Run only tests marked with __backend__")
- make_option("--postgresql-templatedb", type="string",
- help="name of template database to use for Postgresql "
- "CREATE DATABASE (defaults to current database)")
- make_option("--low-connections", action="store_true",
- dest="low_connections",
- help="Use a low number of distinct connections - "
- "i.e. for Oracle TNS")
- make_option("--write-idents", type="string", dest="write_idents",
- help="write out generated follower idents to <file>, "
- "when -n<num> is used")
- make_option("--reversetop", action="store_true",
- dest="reversetop", default=False,
- help="Use a random-ordering set implementation in the ORM "
- "(helps reveal dependency issues)")
- make_option("--requirements", action="callback", type="string",
- callback=_requirements_opt,
- help="requirements class for testing, overrides setup.cfg")
- make_option("--with-cdecimal", action="store_true",
- dest="cdecimal", default=False,
- help="Monkeypatch the cdecimal library into Python 'decimal' "
- "for all tests")
- make_option("--include-tag", action="callback", callback=_include_tag,
- type="string",
- help="Include tests with tag <tag>")
- make_option("--exclude-tag", action="callback", callback=_exclude_tag,
- type="string",
- help="Exclude tests with tag <tag>")
- make_option("--mysql-engine", action="store",
- dest="mysql_engine", default=None,
- help="Use the specified MySQL storage engine for all tables, "
- "default is a db-default/InnoDB combo.")
+ make_option(
+ "--log-info",
+ action="callback",
+ type="string",
+ callback=_log,
+ help="turn on info logging for <LOG> (multiple OK)",
+ )
+ make_option(
+ "--log-debug",
+ action="callback",
+ type="string",
+ callback=_log,
+ help="turn on debug logging for <LOG> (multiple OK)",
+ )
+ make_option(
+ "--db",
+ action="append",
+ type="string",
+ dest="db",
+ help="Use prefab database uri. Multiple OK, "
+ "first one is run by default.",
+ )
+ make_option(
+ "--dbs",
+ action="callback",
+ zeroarg_callback=_list_dbs,
+ help="List available prefab dbs",
+ )
+ make_option(
+ "--dburi",
+ action="append",
+ type="string",
+ dest="dburi",
+ help="Database uri. Multiple OK, " "first one is run by default.",
+ )
+ make_option(
+ "--dropfirst",
+ action="store_true",
+ dest="dropfirst",
+ help="Drop all tables in the target database first",
+ )
+ make_option(
+ "--backend-only",
+ action="store_true",
+ dest="backend_only",
+ help="Run only tests marked with __backend__",
+ )
+ make_option(
+ "--postgresql-templatedb",
+ type="string",
+ help="name of template database to use for Postgresql "
+ "CREATE DATABASE (defaults to current database)",
+ )
+ make_option(
+ "--low-connections",
+ action="store_true",
+ dest="low_connections",
+ help="Use a low number of distinct connections - "
+ "i.e. for Oracle TNS",
+ )
+ make_option(
+ "--write-idents",
+ type="string",
+ dest="write_idents",
+ help="write out generated follower idents to <file>, "
+ "when -n<num> is used",
+ )
+ make_option(
+ "--reversetop",
+ action="store_true",
+ dest="reversetop",
+ default=False,
+ help="Use a random-ordering set implementation in the ORM "
+ "(helps reveal dependency issues)",
+ )
+ make_option(
+ "--requirements",
+ action="callback",
+ type="string",
+ callback=_requirements_opt,
+ help="requirements class for testing, overrides setup.cfg",
+ )
+ make_option(
+ "--with-cdecimal",
+ action="store_true",
+ dest="cdecimal",
+ default=False,
+ help="Monkeypatch the cdecimal library into Python 'decimal' "
+ "for all tests",
+ )
+ make_option(
+ "--include-tag",
+ action="callback",
+ callback=_include_tag,
+ type="string",
+ help="Include tests with tag <tag>",
+ )
+ make_option(
+ "--exclude-tag",
+ action="callback",
+ callback=_exclude_tag,
+ type="string",
+ help="Exclude tests with tag <tag>",
+ )
+ make_option(
+ "--mysql-engine",
+ action="store",
+ dest="mysql_engine",
+ default=None,
+ help="Use the specified MySQL storage engine for all tables, "
+ "default is a db-default/InnoDB combo.",
+ )
def configure_follower(follower_ident):
@@ -113,6 +179,7 @@ def configure_follower(follower_ident):
"""
from alembic.testing import provision
+
provision.FOLLOWER_IDENT = follower_ident
@@ -126,9 +193,9 @@ def memoize_important_follower_config(dict_):
callables, so we have to just copy all of that over.
"""
- dict_['memoized_config'] = {
- 'include_tags': include_tags,
- 'exclude_tags': exclude_tags
+ dict_["memoized_config"] = {
+ "include_tags": include_tags,
+ "exclude_tags": exclude_tags,
}
@@ -138,14 +205,14 @@ def restore_important_follower_config(dict_):
This invokes in the follower process.
"""
- include_tags.update(dict_['memoized_config']['include_tags'])
- exclude_tags.update(dict_['memoized_config']['exclude_tags'])
+ include_tags.update(dict_["memoized_config"]["include_tags"])
+ exclude_tags.update(dict_["memoized_config"]["exclude_tags"])
def read_config():
global file_config
file_config = configparser.ConfigParser()
- file_config.read(['setup.cfg', 'test.cfg'])
+ file_config.read(["setup.cfg", "test.cfg"])
def pre_begin(opt):
@@ -169,12 +236,11 @@ def post_begin():
# late imports, has to happen after config as well
# as nose plugins like coverage
- global util, fixtures, engines, exclusions, \
- assertions, warnings, profiling,\
- config, testing
+ global util, fixtures, engines, exclusions, assertions, warnings, profiling, config, testing
from alembic.testing import config, warnings, exclusions # noqa
from alembic.testing import engines, fixtures # noqa
from sqlalchemy import util # noqa
+
warnings.setup_filters()
@@ -182,18 +248,19 @@ def _log(opt_str, value, parser):
global logging
if not logging:
import logging
+
logging.basicConfig()
- if opt_str.endswith('-info'):
+ if opt_str.endswith("-info"):
logging.getLogger(value).setLevel(logging.INFO)
- elif opt_str.endswith('-debug'):
+ elif opt_str.endswith("-debug"):
logging.getLogger(value).setLevel(logging.DEBUG)
def _list_dbs(*args):
print("Available --db options (use --dburi to override)")
- for macro in sorted(file_config.options('db')):
- print("%20s\t%s" % (macro, file_config.get('db', macro)))
+ for macro in sorted(file_config.options("db")):
+ print("%20s\t%s" % (macro, file_config.get("db", macro)))
sys.exit(0)
@@ -202,11 +269,12 @@ def _requirements_opt(opt_str, value, parser):
def _exclude_tag(opt_str, value, parser):
- exclude_tags.add(value.replace('-', '_'))
+ exclude_tags.add(value.replace("-", "_"))
def _include_tag(opt_str, value, parser):
- include_tags.add(value.replace('-', '_'))
+ include_tags.add(value.replace("-", "_"))
+
pre_configure = []
post_configure = []
@@ -228,12 +296,12 @@ def _setup_options(opt, file_config):
options = opt
-
@pre
def _monkeypatch_cdecimal(options, file_config):
if options.cdecimal:
import cdecimal
- sys.modules['decimal'] = cdecimal
+
+ sys.modules["decimal"] = cdecimal
@post
@@ -248,26 +316,27 @@ def _engine_uri(options, file_config):
if options.db:
for db_token in options.db:
- for db in re.split(r'[,\s]+', db_token):
- if db not in file_config.options('db'):
+ for db in re.split(r"[,\s]+", db_token):
+ if db not in file_config.options("db"):
raise RuntimeError(
"Unknown URI specifier '%s'. "
- "Specify --dbs for known uris."
- % db)
+ "Specify --dbs for known uris." % db
+ )
else:
- db_urls.append(file_config.get('db', db))
+ db_urls.append(file_config.get("db", db))
if not db_urls:
- db_urls.append(file_config.get('db', 'default'))
+ db_urls.append(file_config.get("db", "default"))
for db_url in db_urls:
- if options.write_idents and provision.FOLLOWER_IDENT: # != 'master':
+ if options.write_idents and provision.FOLLOWER_IDENT: # != 'master':
with open(options.write_idents, "a") as file_:
file_.write(provision.FOLLOWER_IDENT + " " + db_url + "\n")
cfg = provision.setup_config(
- db_url, options, file_config, provision.FOLLOWER_IDENT)
+ db_url, options, file_config, provision.FOLLOWER_IDENT
+ )
if not config._current:
cfg.set_as_current(cfg)
@@ -276,7 +345,7 @@ def _engine_uri(options, file_config):
@post
def _requirements(options, file_config):
- requirement_cls = file_config.get('sqla_testing', "requirement_cls")
+ requirement_cls = file_config.get("sqla_testing", "requirement_cls")
_setup_requirements(requirement_cls)
@@ -317,56 +386,75 @@ def _prep_testing_database(options, file_config):
pass
else:
for vname in view_names:
- e.execute(schema._DropView(
- schema.Table(vname, schema.MetaData())
- ))
+ e.execute(
+ schema._DropView(
+ schema.Table(vname, schema.MetaData())
+ )
+ )
if config.requirements.schemas.enabled_for_config(cfg):
try:
- view_names = inspector.get_view_names(
- schema="test_schema")
+ view_names = inspector.get_view_names(schema="test_schema")
except NotImplementedError:
pass
else:
for vname in view_names:
- e.execute(schema._DropView(
- schema.Table(vname, schema.MetaData(),
- schema="test_schema")
- ))
-
- for tname in reversed(inspector.get_table_names(
- order_by="foreign_key")):
- e.execute(schema.DropTable(
- schema.Table(tname, schema.MetaData())
- ))
+ e.execute(
+ schema._DropView(
+ schema.Table(
+ vname,
+ schema.MetaData(),
+ schema="test_schema",
+ )
+ )
+ )
+
+ for tname in reversed(
+ inspector.get_table_names(order_by="foreign_key")
+ ):
+ e.execute(
+ schema.DropTable(schema.Table(tname, schema.MetaData()))
+ )
if config.requirements.schemas.enabled_for_config(cfg):
- for tname in reversed(inspector.get_table_names(
- order_by="foreign_key", schema="test_schema")):
- e.execute(schema.DropTable(
- schema.Table(tname, schema.MetaData(),
- schema="test_schema")
- ))
+ for tname in reversed(
+ inspector.get_table_names(
+ order_by="foreign_key", schema="test_schema"
+ )
+ ):
+ e.execute(
+ schema.DropTable(
+ schema.Table(
+ tname, schema.MetaData(), schema="test_schema"
+ )
+ )
+ )
if against(cfg, "postgresql") and util.sqla_100:
from sqlalchemy.dialects import postgresql
+
for enum in inspector.get_enums("*"):
- e.execute(postgresql.DropEnumType(
- postgresql.ENUM(
- name=enum['name'],
- schema=enum['schema'])))
+ e.execute(
+ postgresql.DropEnumType(
+ postgresql.ENUM(
+ name=enum["name"], schema=enum["schema"]
+ )
+ )
+ )
@post
def _reverse_topological(options, file_config):
if options.reversetop:
from sqlalchemy.orm.util import randomize_unitofwork
+
randomize_unitofwork()
@post
def _post_setup_options(opt, file_config):
from alembic.testing import config
+
config.options = options
config.file_config = file_config
@@ -374,10 +462,11 @@ def _post_setup_options(opt, file_config):
def want_class(cls):
if not issubclass(cls, fixtures.TestBase):
return False
- elif cls.__name__.startswith('_'):
+ elif cls.__name__.startswith("_"):
return False
- elif config.options.backend_only and not getattr(cls, '__backend__',
- False):
+ elif config.options.backend_only and not getattr(
+ cls, "__backend__", False
+ ):
return False
else:
return True
@@ -390,25 +479,28 @@ def want_method(cls, fn):
return False
elif include_tags:
return (
- hasattr(cls, '__tags__') and
- exclusions.tags(cls.__tags__).include_test(
- include_tags, exclude_tags)
+ hasattr(cls, "__tags__")
+ and exclusions.tags(cls.__tags__).include_test(
+ include_tags, exclude_tags
+ )
) or (
- hasattr(fn, '_sa_exclusion_extend') and
- fn._sa_exclusion_extend.include_test(
- include_tags, exclude_tags)
+ hasattr(fn, "_sa_exclusion_extend")
+ and fn._sa_exclusion_extend.include_test(
+ include_tags, exclude_tags
+ )
)
- elif exclude_tags and hasattr(cls, '__tags__'):
+ elif exclude_tags and hasattr(cls, "__tags__"):
return exclusions.tags(cls.__tags__).include_test(
- include_tags, exclude_tags)
- elif exclude_tags and hasattr(fn, '_sa_exclusion_extend'):
+ include_tags, exclude_tags
+ )
+ elif exclude_tags and hasattr(fn, "_sa_exclusion_extend"):
return fn._sa_exclusion_extend.include_test(include_tags, exclude_tags)
else:
return True
def generate_sub_tests(cls, module):
- if getattr(cls, '__backend__', False):
+ if getattr(cls, "__backend__", False):
for cfg in _possible_configs_for_cls(cls):
orig_name = cls.__name__
@@ -416,17 +508,14 @@ def generate_sub_tests(cls, module):
# pytest junit plugin, which is tripped up by the brackets
# and periods, so sanitize
- alpha_name = re.sub(r'[_\[\]\.]+', '_', cfg.name)
- alpha_name = re.sub('_+$', '', alpha_name)
+ alpha_name = re.sub(r"[_\[\]\.]+", "_", cfg.name)
+ alpha_name = re.sub("_+$", "", alpha_name)
name = "%s_%s" % (cls.__name__, alpha_name)
subcls = type(
name,
- (cls, ),
- {
- "_sa_orig_cls_name": orig_name,
- "__only_on_config__": cfg
- }
+ (cls,),
+ {"_sa_orig_cls_name": orig_name, "__only_on_config__": cfg},
)
setattr(module, name, subcls)
yield subcls
@@ -440,8 +529,8 @@ def start_test_class(cls):
def stop_test_class(cls):
- #from sqlalchemy import inspect
- #assert not inspect(testing.db).get_table_names()
+ # from sqlalchemy import inspect
+ # assert not inspect(testing.db).get_table_names()
_restore_engine()
@@ -450,7 +539,7 @@ def _restore_engine():
def _setup_engine(cls):
- if getattr(cls, '__engine_options__', None):
+ if getattr(cls, "__engine_options__", None):
eng = engines.testing_engine(options=cls.__engine_options__)
config._current.push_engine(eng)
@@ -472,16 +561,16 @@ def _possible_configs_for_cls(cls, reasons=None):
if spec(config_obj):
all_configs.remove(config_obj)
- if getattr(cls, '__only_on__', None):
+ if getattr(cls, "__only_on__", None):
spec = exclusions.db_spec(*util.to_list(cls.__only_on__))
for config_obj in list(all_configs):
if not spec(config_obj):
all_configs.remove(config_obj)
- if getattr(cls, '__only_on_config__', None):
+ if getattr(cls, "__only_on_config__", None):
all_configs.intersection_update([cls.__only_on_config__])
- if hasattr(cls, '__requires__'):
+ if hasattr(cls, "__requires__"):
requirements = config.requirements
for config_obj in list(all_configs):
for requirement in cls.__requires__:
@@ -494,7 +583,7 @@ def _possible_configs_for_cls(cls, reasons=None):
reasons.extend(skip_reasons)
break
- if hasattr(cls, '__prefer_requires__'):
+ if hasattr(cls, "__prefer_requires__"):
non_preferred = set()
requirements = config.requirements
for config_obj in list(all_configs):
@@ -513,30 +602,32 @@ def _do_skips(cls):
reasons = []
all_configs = _possible_configs_for_cls(cls, reasons)
- if getattr(cls, '__skip_if__', False):
- for c in getattr(cls, '__skip_if__'):
+ if getattr(cls, "__skip_if__", False):
+ for c in getattr(cls, "__skip_if__"):
if c():
- raise SkipTest("'%s' skipped by %s" % (
- cls.__name__, c.__name__)
+ raise SkipTest(
+ "'%s' skipped by %s" % (cls.__name__, c.__name__)
)
if not all_configs:
msg = "'%s' unsupported on any DB implementation %s%s" % (
cls.__name__,
", ".join(
- "'%s(%s)+%s'" % (
+ "'%s(%s)+%s'"
+ % (
config_obj.db.name,
".".join(
- str(dig) for dig in
- config_obj.db.dialect.server_version_info),
- config_obj.db.driver
+ str(dig)
+ for dig in config_obj.db.dialect.server_version_info
+ ),
+ config_obj.db.driver,
)
- for config_obj in config.Config.all_configs()
+ for config_obj in config.Config.all_configs()
),
- ", ".join(reasons)
+ ", ".join(reasons),
)
raise SkipTest(msg)
- elif hasattr(cls, '__prefer_backends__'):
+ elif hasattr(cls, "__prefer_backends__"):
non_preferred = set()
spec = exclusions.db_spec(*util.to_list(cls.__prefer_backends__))
for config_obj in all_configs:
diff --git a/alembic/testing/plugin/pytestplugin.py b/alembic/testing/plugin/pytestplugin.py
index 4d0f340..cc5b69f 100644
--- a/alembic/testing/plugin/pytestplugin.py
+++ b/alembic/testing/plugin/pytestplugin.py
@@ -21,6 +21,7 @@ import os
try:
import xdist # noqa
+
has_xdist = True
except ImportError:
has_xdist = False
@@ -32,30 +33,42 @@ def pytest_addoption(parser):
def make_option(name, **kw):
callback_ = kw.pop("callback", None)
if callback_:
+
class CallableAction(argparse.Action):
- def __call__(self, parser, namespace,
- values, option_string=None):
+ def __call__(
+ self, parser, namespace, values, option_string=None
+ ):
callback_(option_string, values, parser)
+
kw["action"] = CallableAction
zeroarg_callback = kw.pop("zeroarg_callback", None)
if zeroarg_callback:
+
class CallableAction(argparse.Action):
- def __init__(self, option_strings,
- dest, default=False,
- required=False, help=None):
- super(CallableAction, self).__init__(
- option_strings=option_strings,
- dest=dest,
- nargs=0,
- const=True,
- default=default,
- required=required,
- help=help)
-
- def __call__(self, parser, namespace,
- values, option_string=None):
+ def __init__(
+ self,
+ option_strings,
+ dest,
+ default=False,
+ required=False,
+ help=None,
+ ):
+ super(CallableAction, self).__init__(
+ option_strings=option_strings,
+ dest=dest,
+ nargs=0,
+ const=True,
+ default=default,
+ required=required,
+ help=help,
+ )
+
+ def __call__(
+ self, parser, namespace, values, option_string=None
+ ):
zeroarg_callback(option_string, values, parser)
+
kw["action"] = CallableAction
group.addoption(name, **kw)
@@ -67,23 +80,24 @@ def pytest_addoption(parser):
def pytest_configure(config):
if hasattr(config, "slaveinput"):
plugin_base.restore_important_follower_config(config.slaveinput)
- plugin_base.configure_follower(
- config.slaveinput["follower_ident"]
- )
+ plugin_base.configure_follower(config.slaveinput["follower_ident"])
else:
- if config.option.write_idents and \
- os.path.exists(config.option.write_idents):
+ if config.option.write_idents and os.path.exists(
+ config.option.write_idents
+ ):
os.remove(config.option.write_idents)
plugin_base.pre_begin(config.option)
- plugin_base.set_coverage_flag(bool(getattr(config.option,
- "cov_source", False)))
+ plugin_base.set_coverage_flag(
+ bool(getattr(config.option, "cov_source", False))
+ )
def pytest_sessionstart(session):
plugin_base.post_begin()
+
if has_xdist:
import uuid
@@ -95,10 +109,12 @@ if has_xdist:
node.slaveinput["follower_ident"] = "test_%s" % uuid.uuid4().hex[0:12]
from alembic.testing import provision
+
provision.create_follower_db(node.slaveinput["follower_ident"])
def pytest_testnodedown(node, error):
from alembic.testing import provision
+
provision.drop_follower_db(node.slaveinput["follower_ident"])
@@ -115,18 +131,19 @@ def pytest_collection_modifyitems(session, config, items):
rebuilt_items = collections.defaultdict(list)
items[:] = [
- item for item in
- items if isinstance(item.parent, pytest.Instance)]
+ item for item in items if isinstance(item.parent, pytest.Instance)
+ ]
test_classes = set(item.parent for item in items)
for test_class in test_classes:
for sub_cls in plugin_base.generate_sub_tests(
- test_class.cls, test_class.parent.module):
+ test_class.cls, test_class.parent.module
+ ):
if sub_cls is not test_class.cls:
list_ = rebuilt_items[test_class.cls]
for inst in pytest.Class(
- sub_cls.__name__,
- parent=test_class.parent.parent).collect():
+ sub_cls.__name__, parent=test_class.parent.parent
+ ).collect():
list_.extend(inst.collect())
newitems = []
@@ -139,23 +156,29 @@ def pytest_collection_modifyitems(session, config, items):
# seems like the functions attached to a test class aren't sorted already?
# is that true and why's that? (when using unittest, they're sorted)
- items[:] = sorted(newitems, key=lambda item: (
- item.parent.parent.parent.name,
- item.parent.parent.name,
- item.name
- ))
+ items[:] = sorted(
+ newitems,
+ key=lambda item: (
+ item.parent.parent.parent.name,
+ item.parent.parent.name,
+ item.name,
+ ),
+ )
def pytest_pycollect_makeitem(collector, name, obj):
if inspect.isclass(obj) and plugin_base.want_class(obj):
return pytest.Class(name, parent=collector)
- elif inspect.isfunction(obj) and \
- isinstance(collector, pytest.Instance) and \
- plugin_base.want_method(collector.cls, obj):
+ elif (
+ inspect.isfunction(obj)
+ and isinstance(collector, pytest.Instance)
+ and plugin_base.want_method(collector.cls, obj)
+ ):
return pytest.Function(name, parent=collector)
else:
return []
+
_current_class = None
@@ -180,6 +203,7 @@ def pytest_runtest_setup(item):
global _current_class
class_teardown(item.parent.parent)
_current_class = None
+
item.parent.parent.addfinalizer(finalize)
test_setup(item)
@@ -194,8 +218,9 @@ def pytest_runtest_teardown(item):
def test_setup(item):
- plugin_base.before_test(item, item.parent.module.__name__,
- item.parent.cls, item.name)
+ plugin_base.before_test(
+ item, item.parent.module.__name__, item.parent.cls, item.name
+ )
def test_teardown(item):
diff --git a/alembic/testing/provision.py b/alembic/testing/provision.py
index 05a21d3..a5ce53c 100644
--- a/alembic/testing/provision.py
+++ b/alembic/testing/provision.py
@@ -30,6 +30,7 @@ class register(object):
def decorate(fn):
self.fns[dbname] = fn
return self
+
return decorate
def __call__(self, cfg, *arg):
@@ -43,7 +44,7 @@ class register(object):
if backend in self.fns:
return self.fns[backend](cfg, *arg)
else:
- return self.fns['*'](cfg, *arg)
+ return self.fns["*"](cfg, *arg)
def create_follower_db(follower_ident):
@@ -86,9 +87,7 @@ def _configs_for_db_operation():
for cfg in config.Config.all_configs():
url = cfg.db.url
backend = get_url_backend_name(url)
- host_conf = (
- backend,
- url.username, url.host, url.database)
+ host_conf = (backend, url.username, url.host, url.database)
if host_conf not in hosts:
yield cfg
@@ -132,13 +131,13 @@ def _follower_url_from_main(url, ident):
@_update_db_opts.for_db("mssql")
def _mssql_update_db_opts(db_url, db_opts):
- db_opts['legacy_schema_aliasing'] = False
+ db_opts["legacy_schema_aliasing"] = False
@_follower_url_from_main.for_db("sqlite")
def _sqlite_follower_url_from_main(url, ident):
url = sa_url.make_url(url)
- if not url.database or url.database == ':memory:':
+ if not url.database or url.database == ":memory:":
return url
else:
return sa_url.make_url("sqlite:///%s.db" % ident)
@@ -154,19 +153,20 @@ def _sqlite_post_configure_engine(url, engine, follower_ident):
# as an attached
if not follower_ident:
dbapi_connection.execute(
- 'ATTACH DATABASE "test_schema.db" AS test_schema')
+ 'ATTACH DATABASE "test_schema.db" AS test_schema'
+ )
else:
dbapi_connection.execute(
'ATTACH DATABASE "%s_test_schema.db" AS test_schema'
- % follower_ident)
+ % follower_ident
+ )
@_create_db.for_db("postgresql")
def _pg_create_db(cfg, eng, ident):
template_db = cfg.options.postgresql_templatedb
- with eng.connect().execution_options(
- isolation_level="AUTOCOMMIT") as conn:
+ with eng.connect().execution_options(isolation_level="AUTOCOMMIT") as conn:
try:
_pg_drop_db(cfg, conn, ident)
except Exception:
@@ -222,14 +222,15 @@ def _sqlite_create_db(cfg, eng, ident):
@_drop_db.for_db("postgresql")
def _pg_drop_db(cfg, eng, ident):
- with eng.connect().execution_options(
- isolation_level="AUTOCOMMIT") as conn:
+ with eng.connect().execution_options(isolation_level="AUTOCOMMIT") as conn:
conn.execute(
text(
"select pg_terminate_backend(pid) from pg_stat_activity "
"where usename=current_user and pid != pg_backend_pid() "
"and datname=:dname"
- ), dname=ident)
+ ),
+ dname=ident,
+ )
conn.execute("DROP DATABASE %s" % ident)
@@ -258,7 +259,7 @@ def _oracle_create_db(cfg, eng, ident):
conn.execute("create user %s identified by xe" % ident)
conn.execute("create user %s_ts1 identified by xe" % ident)
conn.execute("create user %s_ts2 identified by xe" % ident)
- conn.execute("grant dba to %s" % (ident, ))
+ conn.execute("grant dba to %s" % (ident,))
conn.execute("grant unlimited tablespace to %s" % ident)
conn.execute("grant unlimited tablespace to %s_ts1" % ident)
conn.execute("grant unlimited tablespace to %s_ts2" % ident)
@@ -316,8 +317,9 @@ def reap_oracle_dbs(idents_file):
to_reap = conn.execute(
"select u.username from all_users u where username "
"like 'TEST_%' and not exists (select username "
- "from v$session where username=u.username)")
- all_names = set(username.lower() for (username, ) in to_reap)
+ "from v$session where username=u.username)"
+ )
+ all_names = set(username.lower() for (username,) in to_reap)
to_drop = set()
for name in all_names:
if name.endswith("_ts1") or name.endswith("_ts2"):
@@ -334,15 +336,13 @@ def reap_oracle_dbs(idents_file):
if _ora_drop_ignore(conn, username):
dropped += 1
log.info(
- "Dropped %d out of %d stale databases detected",
- dropped, total)
+ "Dropped %d out of %d stale databases detected", dropped, total
+ )
@_follower_url_from_main.for_db("oracle")
def _oracle_follower_url_from_main(url, ident):
url = sa_url.make_url(url)
url.username = ident
- url.password = 'xe'
+ url.password = "xe"
return url
-
-
diff --git a/alembic/testing/requirements.py b/alembic/testing/requirements.py
index 400642f..f25f5d7 100644
--- a/alembic/testing/requirements.py
+++ b/alembic/testing/requirements.py
@@ -5,6 +5,7 @@ from . import exclusions
if util.sqla_094:
from sqlalchemy.testing.requirements import Requirements
else:
+
class Requirements(object):
pass
@@ -28,7 +29,7 @@ class SuiteRequirements(Requirements):
insp = inspect(config.db)
try:
- insp.get_unique_constraints('x')
+ insp.get_unique_constraints("x")
except NotImplementedError:
return True
except TypeError:
@@ -62,83 +63,80 @@ class SuiteRequirements(Requirements):
def fail_before_sqla_100(self):
return exclusions.fails_if(
lambda config: not util.sqla_100,
- "SQLAlchemy 1.0.0 or greater required"
+ "SQLAlchemy 1.0.0 or greater required",
)
@property
def fail_before_sqla_1010(self):
return exclusions.fails_if(
lambda config: not util.sqla_1010,
- "SQLAlchemy 1.0.10 or greater required"
+ "SQLAlchemy 1.0.10 or greater required",
)
@property
def fail_before_sqla_099(self):
return exclusions.fails_if(
lambda config: not util.sqla_099,
- "SQLAlchemy 0.9.9 or greater required"
+ "SQLAlchemy 0.9.9 or greater required",
)
@property
def fail_before_sqla_110(self):
return exclusions.fails_if(
lambda config: not util.sqla_110,
- "SQLAlchemy 1.1.0 or greater required"
+ "SQLAlchemy 1.1.0 or greater required",
)
@property
def sqlalchemy_092(self):
return exclusions.skip_if(
lambda config: not util.sqla_092,
- "SQLAlchemy 0.9.2 or greater required"
+ "SQLAlchemy 0.9.2 or greater required",
)
@property
def sqlalchemy_094(self):
return exclusions.skip_if(
lambda config: not util.sqla_094,
- "SQLAlchemy 0.9.4 or greater required"
+ "SQLAlchemy 0.9.4 or greater required",
)
@property
def sqlalchemy_099(self):
return exclusions.skip_if(
lambda config: not util.sqla_099,
- "SQLAlchemy 0.9.9 or greater required"
+ "SQLAlchemy 0.9.9 or greater required",
)
@property
def sqlalchemy_100(self):
return exclusions.skip_if(
lambda config: not util.sqla_100,
- "SQLAlchemy 1.0.0 or greater required"
+ "SQLAlchemy 1.0.0 or greater required",
)
@property
def sqlalchemy_1014(self):
return exclusions.skip_if(
lambda config: not util.sqla_1014,
- "SQLAlchemy 1.0.14 or greater required"
+ "SQLAlchemy 1.0.14 or greater required",
)
@property
def sqlalchemy_1115(self):
return exclusions.skip_if(
lambda config: not util.sqla_1115,
- "SQLAlchemy 1.1.15 or greater required"
+ "SQLAlchemy 1.1.15 or greater required",
)
@property
def sqlalchemy_110(self):
return exclusions.skip_if(
lambda config: not util.sqla_110,
- "SQLAlchemy 1.1.0 or greater required"
+ "SQLAlchemy 1.1.0 or greater required",
)
@property
def pep3147(self):
- return exclusions.only_if(
- lambda config: util.compat.has_pep3147()
- )
-
+ return exclusions.only_if(lambda config: util.compat.has_pep3147())
diff --git a/alembic/testing/runner.py b/alembic/testing/runner.py
index d4adbcf..46236a0 100644
--- a/alembic/testing/runner.py
+++ b/alembic/testing/runner.py
@@ -45,4 +45,4 @@ def setup_py_test():
to nose.
"""
- nose.main(addplugins=[NoseSQLAlchemy()], argv=['runner'])
+ nose.main(addplugins=[NoseSQLAlchemy()], argv=["runner"])
diff --git a/alembic/testing/util.py b/alembic/testing/util.py
index 466dea3..b2b3476 100644
--- a/alembic/testing/util.py
+++ b/alembic/testing/util.py
@@ -10,7 +10,7 @@ def provide_metadata(fn, *args, **kw):
metadata = schema.MetaData(config.db)
self = args[0]
- prev_meta = getattr(self, 'metadata', None)
+ prev_meta = getattr(self, "metadata", None)
self.metadata = metadata
try:
return fn(*args, **kw)
diff --git a/alembic/testing/warnings.py b/alembic/testing/warnings.py
index de91778..cb59a64 100644
--- a/alembic/testing/warnings.py
+++ b/alembic/testing/warnings.py
@@ -17,11 +17,12 @@ import re
def setup_filters():
"""Set global warning behavior for the test suite."""
- warnings.filterwarnings('ignore',
- category=sa_exc.SAPendingDeprecationWarning)
- warnings.filterwarnings('error', category=sa_exc.SADeprecationWarning)
- warnings.filterwarnings('error', category=sa_exc.SAWarning)
- warnings.filterwarnings('error', category=DeprecationWarning)
+ warnings.filterwarnings(
+ "ignore", category=sa_exc.SAPendingDeprecationWarning
+ )
+ warnings.filterwarnings("error", category=sa_exc.SADeprecationWarning)
+ warnings.filterwarnings("error", category=sa_exc.SAWarning)
+ warnings.filterwarnings("error", category=DeprecationWarning)
def assert_warnings(fn, warning_msgs, regex=False):
diff --git a/alembic/util/__init__.py b/alembic/util/__init__.py
index 1e6c645..e28f715 100644
--- a/alembic/util/__init__.py
+++ b/alembic/util/__init__.py
@@ -1,17 +1,45 @@
from .langhelpers import ( # noqa
- asbool, rev_id, to_tuple, to_list, memoized_property, dedupe_tuple,
- immutabledict, _with_legacy_names, Dispatcher, ModuleClsProxy)
+ asbool,
+ rev_id,
+ to_tuple,
+ to_list,
+ memoized_property,
+ dedupe_tuple,
+ immutabledict,
+ _with_legacy_names,
+ Dispatcher,
+ ModuleClsProxy,
+)
from .messaging import ( # noqa
- write_outstream, status, err, obfuscate_url_pw, warn, msg, format_as_comma)
+ write_outstream,
+ status,
+ err,
+ obfuscate_url_pw,
+ warn,
+ msg,
+ format_as_comma,
+)
from .pyfiles import ( # noqa
- template_to_file, coerce_resource_to_filename,
- pyc_file_from_path, load_python_file, edit)
+ template_to_file,
+ coerce_resource_to_filename,
+ pyc_file_from_path,
+ load_python_file,
+ edit,
+)
from .sqla_compat import ( # noqa
- sqla_09, sqla_092, sqla_094, sqla_099, sqla_100, sqla_105, sqla_110, sqla_1010,
- sqla_1014, sqla_1115)
+ sqla_09,
+ sqla_092,
+ sqla_094,
+ sqla_099,
+ sqla_100,
+ sqla_105,
+ sqla_110,
+ sqla_1010,
+ sqla_1014,
+ sqla_1115,
+)
from .exc import CommandError
if not sqla_09:
- raise CommandError(
- "SQLAlchemy 0.9.0 or greater is required. ")
+ raise CommandError("SQLAlchemy 0.9.0 or greater is required. ")
diff --git a/alembic/util/compat.py b/alembic/util/compat.py
index dec2ca8..7e07ed4 100644
--- a/alembic/util/compat.py
+++ b/alembic/util/compat.py
@@ -19,12 +19,13 @@ else:
if py3k:
import builtins as compat_builtins
- string_types = str,
+
+ string_types = (str,)
binary_type = bytes
text_type = str
def callable(fn):
- return hasattr(fn, '__call__')
+ return hasattr(fn, "__call__")
def u(s):
return s
@@ -35,7 +36,8 @@ if py3k:
range = range
else:
import __builtin__ as compat_builtins
- string_types = basestring,
+
+ string_types = (basestring,)
binary_type = str
text_type = unicode
callable = callable
@@ -55,16 +57,17 @@ else:
if py3k:
import collections
+
ArgSpec = collections.namedtuple(
- "ArgSpec",
- ["args", "varargs", "keywords", "defaults"])
+ "ArgSpec", ["args", "varargs", "keywords", "defaults"]
+ )
from inspect import getfullargspec as inspect_getfullargspec
def inspect_getargspec(func):
- return ArgSpec(
- *inspect_getfullargspec(func)[0:4]
- )
+ return ArgSpec(*inspect_getfullargspec(func)[0:4])
+
+
else:
from inspect import getargspec as inspect_getargspec # noqa
@@ -72,14 +75,20 @@ if py35:
from inspect import formatannotation
def inspect_formatargspec(
- args, varargs=None, varkw=None, defaults=None,
- kwonlyargs=(), kwonlydefaults={}, annotations={},
- formatarg=str,
- formatvarargs=lambda name: '*' + name,
- formatvarkw=lambda name: '**' + name,
- formatvalue=lambda value: '=' + repr(value),
- formatreturns=lambda text: ' -> ' + text,
- formatannotation=formatannotation):
+ args,
+ varargs=None,
+ varkw=None,
+ defaults=None,
+ kwonlyargs=(),
+ kwonlydefaults={},
+ annotations={},
+ formatarg=str,
+ formatvarargs=lambda name: "*" + name,
+ formatvarkw=lambda name: "**" + name,
+ formatvalue=lambda value: "=" + repr(value),
+ formatreturns=lambda text: " -> " + text,
+ formatannotation=formatannotation,
+ ):
"""Copy formatargspec from python 3.7 standard library.
Python 3 has deprecated formatargspec and requested that Signature
@@ -93,8 +102,9 @@ if py35:
def formatargandannotation(arg):
result = formatarg(arg)
if arg in annotations:
- result += ': ' + formatannotation(annotations[arg])
+ result += ": " + formatannotation(annotations[arg])
return result
+
specs = []
if defaults:
firstdefault = len(args) - len(defaults)
@@ -107,7 +117,7 @@ if py35:
specs.append(formatvarargs(formatargandannotation(varargs)))
else:
if kwonlyargs:
- specs.append('*')
+ specs.append("*")
if kwonlyargs:
for kwonlyarg in kwonlyargs:
spec = formatargandannotation(kwonlyarg)
@@ -116,11 +126,12 @@ if py35:
specs.append(spec)
if varkw is not None:
specs.append(formatvarkw(formatargandannotation(varkw)))
- result = '(' + ', '.join(specs) + ')'
- if 'return' in annotations:
- result += formatreturns(formatannotation(annotations['return']))
+ result = "(" + ", ".join(specs) + ")"
+ if "return" in annotations:
+ result += formatreturns(formatannotation(annotations["return"]))
return result
+
else:
from inspect import formatargspec as inspect_formatargspec
@@ -151,22 +162,27 @@ if py35:
spec.loader.exec_module(module)
return module
+
elif py3k:
import importlib.machinery
def load_module_py(module_id, path):
module = importlib.machinery.SourceFileLoader(
- module_id, path).load_module(module_id)
+ module_id, path
+ ).load_module(module_id)
del sys.modules[module_id]
return module
def load_module_pyc(module_id, path):
module = importlib.machinery.SourcelessFileLoader(
- module_id, path).load_module(module_id)
+ module_id, path
+ ).load_module(module_id)
del sys.modules[module_id]
return module
+
if py3k:
+
def get_bytecode_suffixes():
try:
return importlib.machinery.BYTECODE_SUFFIXES
@@ -188,13 +204,15 @@ if py3k:
# http://www.python.org/dev/peps/pep-3147/#detecting-pep-3147-availability
import imp
- return hasattr(imp, 'get_tag')
+
+ return hasattr(imp, "get_tag")
+
else:
import imp
def load_module_py(module_id, path): # noqa
- with open(path, 'rb') as fp:
+ with open(path, "rb") as fp:
mod = imp.load_source(module_id, path, fp)
if py2k:
source_encoding = parse_encoding(fp)
@@ -204,7 +222,7 @@ else:
return mod
def load_module_pyc(module_id, path): # noqa
- with open(path, 'rb') as fp:
+ with open(path, "rb") as fp:
mod = imp.load_compiled(module_id, path, fp)
# no source encoding here
del sys.modules[module_id]
@@ -219,12 +237,14 @@ else:
def has_pep3147():
return False
+
try:
- exec_ = getattr(compat_builtins, 'exec')
+ exec_ = getattr(compat_builtins, "exec")
except AttributeError:
# Python 2
def exec_(func_text, globals_, lcl):
- exec('exec func_text in globals_, lcl')
+ exec("exec func_text in globals_, lcl")
+
################################################
# cross-compatible metaclass implementation
@@ -234,9 +254,12 @@ except AttributeError:
def with_metaclass(meta, base=object):
"""Create a base class with a metaclass."""
return meta("%sBase" % meta.__name__, (base,), {})
+
+
################################################
if py3k:
+
def reraise(tp, value, tb=None, cause=None):
if cause is not None:
value.__cause__ = cause
@@ -249,9 +272,13 @@ if py3k:
exc_info = sys.exc_info()
exc_type, exc_value, exc_tb = exc_info
reraise(type(exception), exception, tb=exc_tb, cause=exc_value)
+
+
else:
- exec("def reraise(tp, value, tb=None, cause=None):\n"
- " raise tp, value, tb\n")
+ exec(
+ "def reraise(tp, value, tb=None, cause=None):\n"
+ " raise tp, value, tb\n"
+ )
def raise_from_cause(exception, exc_info=None):
# not as nice as that of Py3K, but at least preserves
@@ -261,14 +288,15 @@ else:
exc_type, exc_value, exc_tb = exc_info
reraise(type(exception), exception, tb=exc_tb)
+
# produce a wrapper that allows encoded text to stream
# into a given buffer, but doesn't close it.
# not sure of a more idiomatic approach to this.
class EncodedIO(io.TextIOWrapper):
-
def close(self):
pass
+
if py2k:
# in Py2K, the io.* package is awkward because it does not
# easily wrap the file type (e.g. sys.stdout) and I can't
@@ -303,7 +331,7 @@ if py2k:
return self.file_.flush()
class EncodedIO(EncodedIO):
-
def __init__(self, file_, encoding):
super(EncodedIO, self).__init__(
- ActLikePy3kIO(file_), encoding=encoding)
+ ActLikePy3kIO(file_), encoding=encoding
+ )
diff --git a/alembic/util/langhelpers.py b/alembic/util/langhelpers.py
index 832332c..a298cc0 100644
--- a/alembic/util/langhelpers.py
+++ b/alembic/util/langhelpers.py
@@ -37,23 +37,21 @@ class ModuleClsProxy(with_metaclass(_ModuleClsMeta)):
def _install_proxy(self):
attr_names, modules = self._setups[self.__class__]
for globals_, locals_ in modules:
- globals_['_proxy'] = self
+ globals_["_proxy"] = self
for attr_name in attr_names:
globals_[attr_name] = getattr(self, attr_name)
def _remove_proxy(self):
attr_names, modules = self._setups[self.__class__]
for globals_, locals_ in modules:
- globals_['_proxy'] = None
+ globals_["_proxy"] = None
for attr_name in attr_names:
del globals_[attr_name]
@classmethod
def create_module_class_proxy(cls, globals_, locals_):
attr_names, modules = cls._setups[cls]
- modules.append(
- (globals_, locals_)
- )
+ modules.append((globals_, locals_))
cls._setup_proxy(globals_, locals_, attr_names)
@classmethod
@@ -63,11 +61,12 @@ class ModuleClsProxy(with_metaclass(_ModuleClsMeta)):
@classmethod
def _add_proxied_attribute(cls, methname, globals_, locals_, attr_names):
- if not methname.startswith('_'):
+ if not methname.startswith("_"):
meth = getattr(cls, methname)
if callable(meth):
locals_[methname] = cls._create_method_proxy(
- methname, globals_, locals_)
+ methname, globals_, locals_
+ )
else:
attr_names.add(methname)
@@ -75,7 +74,7 @@ class ModuleClsProxy(with_metaclass(_ModuleClsMeta)):
def _create_method_proxy(cls, name, globals_, locals_):
fn = getattr(cls, name)
spec = inspect_getargspec(fn)
- if spec[0] and spec[0][0] == 'self':
+ if spec[0] and spec[0][0] == "self":
spec[0].pop(0)
args = inspect_formatargspec(*spec)
num_defaults = 0
@@ -83,24 +82,28 @@ class ModuleClsProxy(with_metaclass(_ModuleClsMeta)):
num_defaults += len(spec[3])
name_args = spec[0]
if num_defaults:
- defaulted_vals = name_args[0 - num_defaults:]
+ defaulted_vals = name_args[0 - num_defaults :]
else:
defaulted_vals = ()
apply_kw = inspect_formatargspec(
- name_args, spec[1], spec[2],
+ name_args,
+ spec[1],
+ spec[2],
defaulted_vals,
- formatvalue=lambda x: '=' + x)
+ formatvalue=lambda x: "=" + x,
+ )
def _name_error(name):
raise NameError(
"Can't invoke function '%s', as the proxy object has "
"not yet been "
"established for the Alembic '%s' class. "
- "Try placing this code inside a callable." % (
- name, cls.__name__
- ))
- globals_['_name_error'] = _name_error
+ "Try placing this code inside a callable."
+ % (name, cls.__name__)
+ )
+
+ globals_["_name_error"] = _name_error
translations = getattr(fn, "_legacy_translations", [])
if translations:
@@ -108,7 +111,7 @@ class ModuleClsProxy(with_metaclass(_ModuleClsMeta)):
translate_str = "args, kw = _translate(%r, %r, %r, args, kw)" % (
fn.__name__,
tuple(spec),
- translations
+ translations,
)
def translate(fn_name, spec, translations, args, kw):
@@ -119,15 +122,14 @@ class ModuleClsProxy(with_metaclass(_ModuleClsMeta)):
if oldname in kw:
warnings.warn(
"Argument %r is now named %r "
- "for method %s()." % (
- oldname, newname, fn_name
- ))
+ "for method %s()." % (oldname, newname, fn_name)
+ )
return_kw[newname] = kw.pop(oldname)
return_kw.update(kw)
args = list(args)
if spec[3]:
- pos_only = spec[0][:-len(spec[3])]
+ pos_only = spec[0][: -len(spec[3])]
else:
pos_only = spec[0]
for arg in pos_only:
@@ -137,17 +139,20 @@ class ModuleClsProxy(with_metaclass(_ModuleClsMeta)):
except IndexError:
raise TypeError(
"missing required positional argument: %s"
- % arg)
+ % arg
+ )
return_args.extend(args)
return return_args, return_kw
- globals_['_translate'] = translate
+
+ globals_["_translate"] = translate
else:
outer_args = args[1:-1]
inner_args = apply_kw[1:-1]
translate_str = ""
- func_text = textwrap.dedent("""\
+ func_text = textwrap.dedent(
+ """\
def %(name)s(%(args)s):
%(doc)r
%(translate)s
@@ -157,13 +162,15 @@ class ModuleClsProxy(with_metaclass(_ModuleClsMeta)):
_name_error('%(name)s')
return _proxy.%(name)s(%(apply_kw)s)
e
- """ % {
- 'name': name,
- 'translate': translate_str,
- 'args': outer_args,
- 'apply_kw': inner_args,
- 'doc': fn.__doc__,
- })
+ """
+ % {
+ "name": name,
+ "translate": translate_str,
+ "args": outer_args,
+ "apply_kw": inner_args,
+ "doc": fn.__doc__,
+ }
+ )
lcl = {}
exec_(func_text, globals_, lcl)
return lcl[name]
@@ -178,8 +185,7 @@ def _with_legacy_names(translations):
def asbool(value):
- return value is not None and \
- value.lower() == 'true'
+ return value is not None and value.lower() == "true"
def rev_id():
@@ -201,31 +207,30 @@ def to_tuple(x, default=None):
if x is None:
return default
elif isinstance(x, string_types):
- return (x, )
+ return (x,)
elif isinstance(x, collections_abc.Iterable):
return tuple(x)
else:
- return (x, )
+ return (x,)
def unique_list(seq, hashfunc=None):
seen = set()
seen_add = seen.add
if not hashfunc:
- return [x for x in seq
- if x not in seen
- and not seen_add(x)]
+ return [x for x in seq if x not in seen and not seen_add(x)]
else:
- return [x for x in seq
- if hashfunc(x) not in seen
- and not seen_add(hashfunc(x))]
+ return [
+ x
+ for x in seq
+ if hashfunc(x) not in seen and not seen_add(hashfunc(x))
+ ]
def dedupe_tuple(tup):
return tuple(unique_list(tup))
-
class memoized_property(object):
"""A read-only @property that is only evaluated once."""
@@ -243,13 +248,12 @@ class memoized_property(object):
class immutabledict(dict):
-
def _immutable(self, *arg, **kw):
raise TypeError("%s object is immutable" % self.__class__.__name__)
- __delitem__ = __setitem__ = __setattr__ = \
- clear = pop = popitem = setdefault = \
- update = _immutable
+ __delitem__ = (
+ __setitem__
+ ) = __setattr__ = clear = pop = popitem = setdefault = update = _immutable
def __new__(cls, *args):
new = dict.__new__(cls)
@@ -260,7 +264,7 @@ class immutabledict(dict):
pass
def __reduce__(self):
- return immutabledict, (dict(self), )
+ return immutabledict, (dict(self),)
def union(self, d):
if not self:
@@ -279,7 +283,7 @@ class Dispatcher(object):
self._registry = {}
self.uselist = uselist
- def dispatch_for(self, target, qualifier='default'):
+ def dispatch_for(self, target, qualifier="default"):
def decorate(fn):
if self.uselist:
self._registry.setdefault((target, qualifier), []).append(fn)
@@ -287,9 +291,10 @@ class Dispatcher(object):
assert (target, qualifier) not in self._registry
self._registry[(target, qualifier)] = fn
return fn
+
return decorate
- def dispatch(self, obj, qualifier='default'):
+ def dispatch(self, obj, qualifier="default"):
if isinstance(obj, string_types):
targets = [obj]
@@ -299,20 +304,20 @@ class Dispatcher(object):
targets = type(obj).__mro__
for spcls in targets:
- if qualifier != 'default' and (spcls, qualifier) in self._registry:
- return self._fn_or_list(
- self._registry[(spcls, qualifier)])
- elif (spcls, 'default') in self._registry:
- return self._fn_or_list(
- self._registry[(spcls, 'default')])
+ if qualifier != "default" and (spcls, qualifier) in self._registry:
+ return self._fn_or_list(self._registry[(spcls, qualifier)])
+ elif (spcls, "default") in self._registry:
+ return self._fn_or_list(self._registry[(spcls, "default")])
else:
raise ValueError("no dispatch function for object: %s" % obj)
def _fn_or_list(self, fn_or_list):
if self.uselist:
+
def go(*arg, **kw):
for fn in fn_or_list:
fn(*arg, **kw)
+
return go
else:
return fn_or_list
@@ -324,8 +329,7 @@ class Dispatcher(object):
d = Dispatcher()
if self.uselist:
d._registry.update(
- (k, [fn for fn in self._registry[k]])
- for k in self._registry
+ (k, [fn for fn in self._registry[k]]) for k in self._registry
)
else:
d._registry.update(self._registry)
diff --git a/alembic/util/messaging.py b/alembic/util/messaging.py
index 872345b..44eacbf 100644
--- a/alembic/util/messaging.py
+++ b/alembic/util/messaging.py
@@ -11,16 +11,16 @@ log = logging.getLogger(__name__)
if py27:
# disable "no handler found" errors
- logging.getLogger('alembic').addHandler(logging.NullHandler())
+ logging.getLogger("alembic").addHandler(logging.NullHandler())
try:
import fcntl
import termios
import struct
- ioctl = fcntl.ioctl(0, termios.TIOCGWINSZ,
- struct.pack('HHHH', 0, 0, 0, 0))
- _h, TERMWIDTH, _hp, _wp = struct.unpack('HHHH', ioctl)
+
+ ioctl = fcntl.ioctl(0, termios.TIOCGWINSZ, struct.pack("HHHH", 0, 0, 0, 0))
+ _h, TERMWIDTH, _hp, _wp = struct.unpack("HHHH", ioctl)
if TERMWIDTH <= 0: # can occur if running in emacs pseudo-tty
TERMWIDTH = None
except (ImportError, IOError):
@@ -28,10 +28,10 @@ except (ImportError, IOError):
def write_outstream(stream, *text):
- encoding = getattr(stream, 'encoding', 'ascii') or 'ascii'
+ encoding = getattr(stream, "encoding", "ascii") or "ascii"
for t in text:
if not isinstance(t, binary_type):
- t = t.encode(encoding, 'replace')
+ t = t.encode(encoding, "replace")
t = t.decode(encoding)
try:
stream.write(t)
@@ -62,7 +62,7 @@ def err(message):
def obfuscate_url_pw(u):
u = url.make_url(u)
if u.password:
- u.password = 'XXXXX'
+ u.password = "XXXXX"
return str(u)
diff --git a/alembic/util/pyfiles.py b/alembic/util/pyfiles.py
index 0e52133..4093b89 100644
--- a/alembic/util/pyfiles.py
+++ b/alembic/util/pyfiles.py
@@ -1,8 +1,12 @@
import sys
import os
import re
-from .compat import load_module_py, load_module_pyc, \
- get_current_bytecode_suffixes, has_pep3147
+from .compat import (
+ load_module_py,
+ load_module_pyc,
+ get_current_bytecode_suffixes,
+ has_pep3147,
+)
from mako.template import Template
from mako import exceptions
import tempfile
@@ -14,16 +18,19 @@ def template_to_file(template_file, dest, output_encoding, **kw):
try:
output = template.render_unicode(**kw).encode(output_encoding)
except:
- with tempfile.NamedTemporaryFile(suffix='.txt', delete=False) as ntf:
+ with tempfile.NamedTemporaryFile(suffix=".txt", delete=False) as ntf:
ntf.write(
- exceptions.text_error_template().
- render_unicode().encode(output_encoding))
+ exceptions.text_error_template()
+ .render_unicode()
+ .encode(output_encoding)
+ )
fname = ntf.name
raise CommandError(
"Template rendering failed; see %s for a "
- "template-oriented traceback." % fname)
+ "template-oriented traceback." % fname
+ )
else:
- with open(dest, 'wb') as f:
+ with open(dest, "wb") as f:
f.write(output)
@@ -37,7 +44,8 @@ def coerce_resource_to_filename(fname):
"""
if not os.path.isabs(fname) and ":" in fname:
import pkg_resources
- fname = pkg_resources.resource_filename(*fname.split(':'))
+
+ fname = pkg_resources.resource_filename(*fname.split(":"))
return fname
@@ -48,6 +56,7 @@ def pyc_file_from_path(path):
if has_pep3147():
import imp
+
candidate = imp.cache_from_source(path)
if os.path.exists(candidate):
return candidate
@@ -64,16 +73,17 @@ def edit(path):
"""Given a source path, run the EDITOR for it"""
import editor
+
try:
editor.edit(path)
except Exception as exc:
- raise CommandError('Error executing editor (%s)' % (exc,))
+ raise CommandError("Error executing editor (%s)" % (exc,))
def load_python_file(dir_, filename):
"""Load a file from the given path as a Python module."""
- module_id = re.sub(r'\W', "_", filename)
+ module_id = re.sub(r"\W", "_", filename)
path = os.path.join(dir_, filename)
_, ext = os.path.splitext(filename)
if ext == ".py":
diff --git a/alembic/util/sqla_compat.py b/alembic/util/sqla_compat.py
index 0556124..63c9798 100644
--- a/alembic/util/sqla_compat.py
+++ b/alembic/util/sqla_compat.py
@@ -15,8 +15,11 @@ def _safe_int(value):
return int(value)
except:
return value
+
+
_vers = tuple(
- [_safe_int(x) for x in re.findall(r'(\d+|[abc]\d)', __version__)])
+ [_safe_int(x) for x in re.findall(r"(\d+|[abc]\d)", __version__)]
+)
sqla_09 = _vers >= (0, 9, 0)
sqla_092 = _vers >= (0, 9, 2)
sqla_094 = _vers >= (0, 9, 4)
@@ -31,7 +34,7 @@ sqla_1115 = _vers >= (1, 1, 15)
if sqla_110:
- AUTOINCREMENT_DEFAULT = 'auto'
+ AUTOINCREMENT_DEFAULT = "auto"
else:
AUTOINCREMENT_DEFAULT = True
@@ -55,10 +58,12 @@ def _columns_for_constraint(constraint):
def _fk_spec(constraint):
if sqla_100:
source_columns = [
- constraint.columns[key].name for key in constraint.column_keys]
+ constraint.columns[key].name for key in constraint.column_keys
+ ]
else:
source_columns = [
- element.parent.name for element in constraint.elements]
+ element.parent.name for element in constraint.elements
+ ]
source_table = constraint.parent.name
source_schema = constraint.parent.schema
@@ -70,9 +75,17 @@ def _fk_spec(constraint):
deferrable = constraint.deferrable
initially = constraint.initially
return (
- source_schema, source_table,
- source_columns, target_schema, target_table, target_columns,
- onupdate, ondelete, deferrable, initially)
+ source_schema,
+ source_table,
+ source_columns,
+ target_schema,
+ target_table,
+ target_columns,
+ onupdate,
+ ondelete,
+ deferrable,
+ initially,
+ )
def _fk_is_self_referential(constraint):
@@ -91,11 +104,9 @@ def _is_type_bound(constraint):
return constraint._type_bound
else:
# old way, look at what we know Boolean/Enum to use
- return (
- constraint._create_rule is not None and
- isinstance(
- getattr(constraint._create_rule, "target", None),
- sqltypes.SchemaType)
+ return constraint._create_rule is not None and isinstance(
+ getattr(constraint._create_rule, "target", None),
+ sqltypes.SchemaType,
)
@@ -103,7 +114,7 @@ def _find_columns(clause):
"""locate Column objects within the given expression."""
cols = set()
- traverse(clause, {}, {'column': cols.add})
+ traverse(clause, {}, {"column": cols.add})
return cols
@@ -143,7 +154,8 @@ class _textual_index_element(sql.ColumnElement):
See SQLAlchemy issue 3174.
"""
- __visit_name__ = '_textual_idx_element'
+
+ __visit_name__ = "_textual_idx_element"
def __init__(self, table, text):
self.table = table
@@ -198,7 +210,7 @@ def _get_index_final_name(dialect, idx):
def _is_mariadb(mysql_dialect):
- return 'MariaDB' in mysql_dialect.server_version_info
+ return "MariaDB" in mysql_dialect.server_version_info
def _mariadb_normalized_version_info(mysql_dialect):
diff --git a/setup.cfg b/setup.cfg
index 945d94a..50d02c1 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -15,6 +15,21 @@ identity = C4DAFEE1
with-sqla_testing = true
where = tests
+[flake8]
+show-source = true
+enable-extensions = G
+# E203 is due to https://github.com/PyCQA/pycodestyle/issues/373
+ignore =
+ A003,
+ D,
+ E203,E305,E711,E712,E721,E722,E741,
+ N801,N802,N806,
+ RST304,RST303,RST299,RST399,
+ W503,W504
+exclude = .venv,.git,.tox,dist,doc,*egg,build
+import-order-style = google
+application-import-names = alembic,tests
+
[sqla_testing]
requirement_cls=tests.requirements:DefaultRequirements
diff --git a/setup.py b/setup.py
index e962f89..70da388 100644
--- a/setup.py
+++ b/setup.py
@@ -5,23 +5,23 @@ import re
import sys
-v = open(os.path.join(os.path.dirname(__file__), 'alembic', '__init__.py'))
-VERSION = re.compile(r".*__version__ = '(.*?)'", re.S).match(v.read()).group(1)
+v = open(os.path.join(os.path.dirname(__file__), "alembic", "__init__.py"))
+VERSION = re.compile(r""".*__version__ = ["'](.*?)["']""", re.S).match(v.read()).group(1)
v.close()
-readme = os.path.join(os.path.dirname(__file__), 'README.rst')
+readme = os.path.join(os.path.dirname(__file__), "README.rst")
requires = [
- 'SQLAlchemy>=0.9.0',
- 'Mako',
- 'python-editor>=0.3',
- 'python-dateutil'
+ "SQLAlchemy>=0.9.0",
+ "Mako",
+ "python-editor>=0.3",
+ "python-dateutil",
]
class PyTest(TestCommand):
- user_options = [('pytest-args=', 'a', "Arguments to pass to py.test")]
+ user_options = [("pytest-args=", "a", "Arguments to pass to py.test")]
def initialize_options(self):
TestCommand.initialize_options(self)
@@ -35,42 +35,42 @@ class PyTest(TestCommand):
def run_tests(self):
# import here, cause outside the eggs aren't loaded
import pytest
+
errno = pytest.main(self.pytest_args)
sys.exit(errno)
-setup(name='alembic',
- version=VERSION,
- description="A database migration tool for SQLAlchemy.",
- long_description=open(readme).read(),
- python_requires='>=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*',
- classifiers=[
- 'Development Status :: 5 - Production/Stable',
- 'Environment :: Console',
- 'Intended Audience :: Developers',
- 'Programming Language :: Python',
- 'Programming Language :: Python :: 2',
- 'Programming Language :: Python :: 2.7',
- 'Programming Language :: Python :: 3',
- 'Programming Language :: Python :: 3.4',
- 'Programming Language :: Python :: 3.5',
- 'Programming Language :: Python :: 3.6',
- 'Programming Language :: Python :: Implementation :: CPython',
- 'Programming Language :: Python :: Implementation :: PyPy',
- 'Topic :: Database :: Front-Ends',
- ],
- keywords='SQLAlchemy migrations',
- author='Mike Bayer',
- author_email='mike@zzzcomputing.com',
- url='https://alembic.sqlalchemy.org',
- license='MIT',
- packages=find_packages('.', exclude=['examples*', 'test*']),
- include_package_data=True,
- tests_require=['pytest!=3.9.1,!=3.9.2', 'mock', 'Mako'],
- cmdclass={'test': PyTest},
- zip_safe=False,
- install_requires=requires,
- entry_points={
- 'console_scripts': ['alembic = alembic.config:main'],
- }
- )
+setup(
+ name="alembic",
+ version=VERSION,
+ description="A database migration tool for SQLAlchemy.",
+ long_description=open(readme).read(),
+ python_requires=">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*",
+ classifiers=[
+ "Development Status :: 5 - Production/Stable",
+ "Environment :: Console",
+ "Intended Audience :: Developers",
+ "Programming Language :: Python",
+ "Programming Language :: Python :: 2",
+ "Programming Language :: Python :: 2.7",
+ "Programming Language :: Python :: 3",
+ "Programming Language :: Python :: 3.4",
+ "Programming Language :: Python :: 3.5",
+ "Programming Language :: Python :: 3.6",
+ "Programming Language :: Python :: Implementation :: CPython",
+ "Programming Language :: Python :: Implementation :: PyPy",
+ "Topic :: Database :: Front-Ends",
+ ],
+ keywords="SQLAlchemy migrations",
+ author="Mike Bayer",
+ author_email="mike@zzzcomputing.com",
+ url="https://alembic.sqlalchemy.org",
+ license="MIT",
+ packages=find_packages(".", exclude=["examples*", "test*"]),
+ include_package_data=True,
+ tests_require=["pytest!=3.9.1,!=3.9.2", "mock", "Mako"],
+ cmdclass={"test": PyTest},
+ zip_safe=False,
+ install_requires=requires,
+ entry_points={"console_scripts": ["alembic = alembic.config:main"]},
+)
diff --git a/tests/_autogen_fixtures.py b/tests/_autogen_fixtures.py
index 94c6866..4bda756 100644
--- a/tests/_autogen_fixtures.py
+++ b/tests/_autogen_fixtures.py
@@ -1,5 +1,18 @@
-from sqlalchemy import MetaData, Column, Table, Integer, String, Text, \
- Numeric, CHAR, ForeignKey, Index, UniqueConstraint, CheckConstraint, text
+from sqlalchemy import (
+ MetaData,
+ Column,
+ Table,
+ Integer,
+ String,
+ Text,
+ Numeric,
+ CHAR,
+ ForeignKey,
+ Index,
+ UniqueConstraint,
+ CheckConstraint,
+ text,
+)
from sqlalchemy.engine.reflection import Inspector
from alembic.operations import ops
@@ -27,11 +40,12 @@ def _default_include_object(obj, name, type_, reflected, compare_to):
else:
return True
+
_default_object_filters = _default_include_object
class ModelOne(object):
- __requires__ = ('unique_constraint_reflection', )
+ __requires__ = ("unique_constraint_reflection",)
schema = None
@@ -41,30 +55,42 @@ class ModelOne(object):
m = MetaData(schema=schema)
- Table('user', m,
- Column('id', Integer, primary_key=True),
- Column('name', String(50)),
- Column('a1', Text),
- Column("pw", String(50)),
- Index('pw_idx', 'pw')
- )
-
- Table('address', m,
- Column('id', Integer, primary_key=True),
- Column('email_address', String(100), nullable=False),
- )
-
- Table('order', m,
- Column('order_id', Integer, primary_key=True),
- Column("amount", Numeric(8, 2), nullable=False,
- server_default=text("0")),
- CheckConstraint('amount >= 0', name='ck_order_amount')
- )
-
- Table('extra', m,
- Column("x", CHAR),
- Column('uid', Integer, ForeignKey('user.id'))
- )
+ Table(
+ "user",
+ m,
+ Column("id", Integer, primary_key=True),
+ Column("name", String(50)),
+ Column("a1", Text),
+ Column("pw", String(50)),
+ Index("pw_idx", "pw"),
+ )
+
+ Table(
+ "address",
+ m,
+ Column("id", Integer, primary_key=True),
+ Column("email_address", String(100), nullable=False),
+ )
+
+ Table(
+ "order",
+ m,
+ Column("order_id", Integer, primary_key=True),
+ Column(
+ "amount",
+ Numeric(8, 2),
+ nullable=False,
+ server_default=text("0"),
+ ),
+ CheckConstraint("amount >= 0", name="ck_order_amount"),
+ )
+
+ Table(
+ "extra",
+ m,
+ Column("x", CHAR),
+ Column("uid", Integer, ForeignKey("user.id")),
+ )
return m
@@ -74,50 +100,80 @@ class ModelOne(object):
m = MetaData(schema=schema)
- Table('user', m,
- Column('id', Integer, primary_key=True),
- Column('name', String(50), nullable=False),
- Column('a1', Text, server_default="x")
- )
-
- Table('address', m,
- Column('id', Integer, primary_key=True),
- Column('email_address', String(100), nullable=False),
- Column('street', String(50)),
- UniqueConstraint('email_address', name="uq_email")
- )
-
- Table('order', m,
- Column('order_id', Integer, primary_key=True),
- Column('amount', Numeric(10, 2), nullable=True,
- server_default=text("0")),
- Column('user_id', Integer, ForeignKey('user.id')),
- CheckConstraint('amount > -1', name='ck_order_amount'),
- )
-
- Table('item', m,
- Column('id', Integer, primary_key=True),
- Column('description', String(100)),
- Column('order_id', Integer, ForeignKey('order.order_id')),
- CheckConstraint('len(description) > 5')
- )
+ Table(
+ "user",
+ m,
+ Column("id", Integer, primary_key=True),
+ Column("name", String(50), nullable=False),
+ Column("a1", Text, server_default="x"),
+ )
+
+ Table(
+ "address",
+ m,
+ Column("id", Integer, primary_key=True),
+ Column("email_address", String(100), nullable=False),
+ Column("street", String(50)),
+ UniqueConstraint("email_address", name="uq_email"),
+ )
+
+ Table(
+ "order",
+ m,
+ Column("order_id", Integer, primary_key=True),
+ Column(
+ "amount",
+ Numeric(10, 2),
+ nullable=True,
+ server_default=text("0"),
+ ),
+ Column("user_id", Integer, ForeignKey("user.id")),
+ CheckConstraint("amount > -1", name="ck_order_amount"),
+ )
+
+ Table(
+ "item",
+ m,
+ Column("id", Integer, primary_key=True),
+ Column("description", String(100)),
+ Column("order_id", Integer, ForeignKey("order.order_id")),
+ CheckConstraint("len(description) > 5"),
+ )
return m
class _ComparesFKs(object):
def _assert_fk_diff(
- self, diff, type_, source_table, source_columns,
- target_table, target_columns, name=None, conditional_name=None,
- source_schema=None, onupdate=None, ondelete=None,
- initially=None, deferrable=None):
+ self,
+ diff,
+ type_,
+ source_table,
+ source_columns,
+ target_table,
+ target_columns,
+ name=None,
+ conditional_name=None,
+ source_schema=None,
+ onupdate=None,
+ ondelete=None,
+ initially=None,
+ deferrable=None,
+ ):
# the public API for ForeignKeyConstraint was not very rich
# in 0.7, 0.8, so here we use the well-known but slightly
# private API to get at its elements
- (fk_source_schema, fk_source_table,
- fk_source_columns, fk_target_schema, fk_target_table,
- fk_target_columns,
- fk_onupdate, fk_ondelete, fk_deferrable, fk_initially
- ) = _fk_spec(diff[1])
+ (
+ fk_source_schema,
+ fk_source_table,
+ fk_source_columns,
+ fk_target_schema,
+ fk_target_table,
+ fk_target_columns,
+ fk_onupdate,
+ fk_ondelete,
+ fk_deferrable,
+ fk_initially,
+ ) = _fk_spec(diff[1])
eq_(diff[0], type_)
eq_(fk_source_table, source_table)
@@ -129,15 +185,15 @@ class _ComparesFKs(object):
eq_(fk_initially, initially)
eq_(fk_deferrable, deferrable)
- eq_([elem.column.name for elem in diff[1].elements],
- target_columns)
+ eq_([elem.column.name for elem in diff[1].elements], target_columns)
if conditional_name is not None:
if config.requirements.no_fk_names.enabled:
eq_(diff[1].name, None)
- elif conditional_name == 'servergenerated':
- fks = Inspector.from_engine(self.bind).\
- get_foreign_keys(source_table)
- server_fk_name = fks[0]['name']
+ elif conditional_name == "servergenerated":
+ fks = Inspector.from_engine(self.bind).get_foreign_keys(
+ source_table
+ )
+ server_fk_name = fks[0]["name"]
eq_(diff[1].name, server_fk_name)
else:
eq_(diff[1].name, conditional_name)
@@ -146,7 +202,6 @@ class _ComparesFKs(object):
class AutogenTest(_ComparesFKs):
-
def _flatten_diffs(self, diffs):
for d in diffs:
if isinstance(d, list):
@@ -177,20 +232,19 @@ class AutogenTest(_ComparesFKs):
def setUp(self):
self.conn = conn = self.bind.connect()
ctx_opts = {
- 'compare_type': True,
- 'compare_server_default': True,
- 'target_metadata': self.m2,
- 'upgrade_token': "upgrades",
- 'downgrade_token': "downgrades",
- 'alembic_module_prefix': 'op.',
- 'sqlalchemy_module_prefix': 'sa.',
- 'include_object': _default_object_filters
+ "compare_type": True,
+ "compare_server_default": True,
+ "target_metadata": self.m2,
+ "upgrade_token": "upgrades",
+ "downgrade_token": "downgrades",
+ "alembic_module_prefix": "op.",
+ "sqlalchemy_module_prefix": "sa.",
+ "include_object": _default_object_filters,
}
if self.configure_opts:
ctx_opts.update(self.configure_opts)
self.context = context = MigrationContext.configure(
- connection=conn,
- opts=ctx_opts
+ connection=conn, opts=ctx_opts
)
self.autogen_context = api.AutogenContext(context, self.m2)
@@ -200,46 +254,47 @@ class AutogenTest(_ComparesFKs):
def _update_context(self, object_filters=None, include_schemas=None):
if include_schemas is not None:
- self.autogen_context.opts['include_schemas'] = include_schemas
+ self.autogen_context.opts["include_schemas"] = include_schemas
if object_filters is not None:
self.autogen_context._object_filters = [object_filters]
return self.autogen_context
class AutogenFixtureTest(_ComparesFKs):
-
def _fixture(
- self, m1, m2, include_schemas=False,
- opts=None, object_filters=_default_object_filters,
- return_ops=False):
+ self,
+ m1,
+ m2,
+ include_schemas=False,
+ opts=None,
+ object_filters=_default_object_filters,
+ return_ops=False,
+ ):
self.metadata, model_metadata = m1, m2
for m in util.to_list(self.metadata):
m.create_all(self.bind)
with self.bind.connect() as conn:
ctx_opts = {
- 'compare_type': True,
- 'compare_server_default': True,
- 'target_metadata': model_metadata,
- 'upgrade_token': "upgrades",
- 'downgrade_token': "downgrades",
- 'alembic_module_prefix': 'op.',
- 'sqlalchemy_module_prefix': 'sa.',
- 'include_object': object_filters,
- 'include_schemas': include_schemas
+ "compare_type": True,
+ "compare_server_default": True,
+ "target_metadata": model_metadata,
+ "upgrade_token": "upgrades",
+ "downgrade_token": "downgrades",
+ "alembic_module_prefix": "op.",
+ "sqlalchemy_module_prefix": "sa.",
+ "include_object": object_filters,
+ "include_schemas": include_schemas,
}
if opts:
ctx_opts.update(opts)
self.context = context = MigrationContext.configure(
- connection=conn,
- opts=ctx_opts
+ connection=conn, opts=ctx_opts
)
autogen_context = api.AutogenContext(context, model_metadata)
uo = ops.UpgradeOps(ops=[])
- autogenerate._produce_net_changes(
- autogen_context, uo
- )
+ autogenerate._produce_net_changes(autogen_context, uo)
if return_ops:
return uo
@@ -253,8 +308,7 @@ class AutogenFixtureTest(_ComparesFKs):
self.bind = config.db
def tearDown(self):
- if hasattr(self, 'metadata'):
+ if hasattr(self, "metadata"):
for m in util.to_list(self.metadata):
m.drop_all(self.bind)
clear_staging_env()
-
diff --git a/tests/_large_map.py b/tests/_large_map.py
index bc7133f..13ac41f 100644
--- a/tests/_large_map.py
+++ b/tests/_large_map.py
@@ -1,152 +1,148 @@
from alembic.script.revision import RevisionMap, Revision
data = [
- Revision('3fc8a578bc0a', ('4878cb1cb7f6', '454a0529f84e'), ),
- Revision('69285b0faaa', ('36c31e4e1c37', '3a3b24a31b57'), ),
- Revision('3b0452c64639', '2f1a0f3667f3', ),
- Revision('2d9d787a496', '135b5fd31062', ),
- Revision('184f65ed83af', '3b0452c64639', ),
- Revision('430074f99c29', '54f871bfe0b0', ),
- Revision('3ffb59981d9a', '519c9f3ce294', ),
- Revision('454a0529f84e', ('40f6508e4373', '38a936c6ab11'), ),
- Revision('24c2620b2e3f', ('430074f99c29', '1f5ceb1ec255'), ),
- Revision('169a948471a9', '247ad6880f93', ),
- Revision('2f1a0f3667f3', '17dd0f165262', ),
- Revision('27227dc4fda8', '2a66d7c4d8a1', ),
- Revision('4b2ad1ffe2e7', ('3b409f268da4', '4f8a9b79a063'), ),
- Revision('124ef6a17781', '2529684536da', ),
- Revision('4789d9c82ca7', '593b8076fb2c', ),
- Revision('64ed798bcc3', ('44ed1bf512a0', '169a948471a9'), ),
- Revision('2588a3c36a0f', '50c7b21c9089', ),
- Revision('359329c2ebb', ('5810e9eff996', '339faa12616'), ),
- Revision('540bc5634bd', '3a5db5f31209', ),
- Revision('20fe477817d2', '53d5ff905573', ),
- Revision('4f8a9b79a063', ('3cf34fcd6473', '300209d8594'), ),
- Revision('6918589deaf', '3314c17f6e35', ),
- Revision('1755e3b1481c', ('17b66754be21', '31b1d4b7fc95'), ),
- Revision('58c988e1aa4e', ('219240032b88', 'f067f0b825c'), ),
- Revision('593b8076fb2c', '1d94175d221b', ),
- Revision('38d069994064', ('46b70a57edc0', '3ed56beabfb7'), ),
- Revision('3e2f6c6d1182', '7f96a01461b', ),
- Revision('1f6969597fe7', '1811bdae9e63', ),
- Revision('17dd0f165262', '3cf02a593a68', ),
- Revision('3cf02a593a68', '25a7ef58d293', ),
- Revision('34dfac7edb2d', '28f4dd53ad3a', ),
- Revision('4009c533e05d', '42ded7355da2', ),
- Revision('5a0003c3b09c', ('3ed56beabfb7', '2028d94d3863'), ),
- Revision('38a936c6ab11', '2588a3c36a0f', ),
- Revision('59223c5b7b36', '2f93dd880bae', ),
- Revision('4121bd6e99e9', '540bc5634bd', ),
- Revision('260714a3f2de', '6918589deaf', ),
- Revision('ae77a2ed69b', '274fd2642933', ),
- Revision('18ff1ab3b4c4', '430133b6d46c', ),
- Revision('2b9a327527a9', ('359329c2ebb', '593b8076fb2c'), ),
- Revision('4e6167c75ed0', '325b273d61bd', ),
- Revision('21ab11a7c5c4', ('3da31f3323ec', '22f26011d635'), ),
- Revision('3b93e98481b1', '4e28e2f4fe2f', ),
- Revision('145d8f1e334d', 'b4143d129e', ),
- Revision('135b5fd31062', '1d94175d221b', ),
- Revision('300209d8594', ('52804033910e', '593b8076fb2c'), ),
- Revision('8dca95cce28', 'f034666cd80', ),
- Revision('46b70a57edc0', ('145d8f1e334d', '4cc2960cbe19'), ),
- Revision('4d45e479fbb9', '2d9d787a496', ),
- Revision('22f085bf8bbd', '540bc5634bd', ),
- Revision('263e91fd17d8', '2b9a327527a9', ),
- Revision('219240032b88', ('300209d8594', '2b9a327527a9'), ),
- Revision('325b273d61bd', '4b2ad1ffe2e7', ),
- Revision('199943ccc774', '1aa674ccfa4e', ),
- Revision('247ad6880f93', '1f6969597fe7', ),
- Revision('4878cb1cb7f6', '28f4dd53ad3a', ),
- Revision('2a66d7c4d8a1', '23f1ccb18d6d', ),
- Revision('42b079245b55', '593b8076fb2c', ),
- Revision('1cccf82219cb', ('20fe477817d2', '915c67915c2'), ),
- Revision('b4143d129e', ('159331d6f484', '504d5168afe1'), ),
- Revision('53d5ff905573', '3013877bf5bd', ),
- Revision('1f5ceb1ec255', '3ffb59981d9a', ),
- Revision('ef1c1c1531f', '4738812e6ece', ),
- Revision('1f6963d1ae02', '247ad6880f93', ),
- Revision('44d58f1d31f0', '18ff1ab3b4c4', ),
- Revision('c3ebe64dfb5', ('3409c57b0da', '31f352e77045'), ),
- Revision('f067f0b825c', '359329c2ebb', ),
- Revision('52ab2d3b57ce', '96d590bd82e', ),
- Revision('3b409f268da4', ('20e90eb3eeb6', '263e91fd17d8'), ),
- Revision('5a4ca8889674', '4e6167c75ed0', ),
- Revision('5810e9eff996', ('2d30d79c4093', '52804033910e'), ),
- Revision('40f6508e4373', '4ed16fad67a7', ),
- Revision('1811bdae9e63', '260714a3f2de', ),
- Revision('3013877bf5bd', ('8dca95cce28', '3fc8a578bc0a'), ),
- Revision('16426dbea880', '28f4dd53ad3a', ),
- Revision('22f26011d635', ('4c93d063d2ba', '3b93e98481b1'), ),
- Revision('3409c57b0da', '17b66754be21', ),
- Revision('44373001000f', ('42b079245b55', '219240032b88'), ),
- Revision('28f4dd53ad3a', '2e71fd90eb9d', ),
- Revision('4cc2960cbe19', '504d5168afe1', ),
- Revision('31f352e77045', ('17b66754be21', '22f085bf8bbd'), ),
- Revision('4ed16fad67a7', 'f034666cd80', ),
- Revision('3da31f3323ec', '4c93d063d2ba', ),
- Revision('31b1d4b7fc95', '1cc4459fd115', ),
- Revision('11bc0ff42f87', '28f4dd53ad3a', ),
- Revision('3a5db5f31209', '59742a546b84', ),
- Revision('20e90eb3eeb6', ('58c988e1aa4e', '44373001000f'), ),
- Revision('23f1ccb18d6d', '52ab2d3b57ce', ),
- Revision('1d94175d221b', '21ab11a7c5c4', ),
- Revision('36f1a410ed', '54f871bfe0b0', ),
- Revision('181a149173e', '2ee35cac4c62', ),
- Revision('171ad2f0c672', '4a4e0838e206', ),
- Revision('2f93dd880bae', '540bc5634bd', ),
- Revision('25a7ef58d293', None, ),
- Revision('7f96a01461b', '184f65ed83af', ),
- Revision('b21f22233f', '3e2f6c6d1182', ),
- Revision('52804033910e', '1d94175d221b', ),
- Revision('1e6240aba5b3', ('4121bd6e99e9', '2c50d8bab6ee'), ),
- Revision('1cc4459fd115', '1e6240aba5b3', ),
- Revision('274fd2642933', '4009c533e05d', ),
- Revision('1aa674ccfa4e', ('59223c5b7b36', '42050bf030fd'), ),
- Revision('4e28e2f4fe2f', '596d7b9e11', ),
- Revision('49ddec8c7a5e', ('124ef6a17781', '47578179e766'), ),
- Revision('3e9bb349cc46', 'ef1c1c1531f', ),
- Revision('2028d94d3863', '504d5168afe1', ),
- Revision('159331d6f484', '34dfac7edb2d', ),
- Revision('596d7b9e11', '171ad2f0c672', ),
- Revision('3b96bcc8da76', 'f034666cd80', ),
- Revision('4738812e6ece', '78982bf5499', ),
- Revision('3314c17f6e35', '27227dc4fda8', ),
- Revision('30931c545bf', '2e71fd90eb9d', ),
- Revision('2e71fd90eb9d', ('c3ebe64dfb5', '1755e3b1481c'), ),
- Revision('3ed56beabfb7', ('11bc0ff42f87', '69285b0faaa'), ),
- Revision('96d590bd82e', '3e9bb349cc46', ),
- Revision('339faa12616', '4d45e479fbb9', ),
- Revision('47578179e766', '2529684536da', ),
- Revision('2ee35cac4c62', 'b21f22233f', ),
- Revision('50c7b21c9089', ('4ed16fad67a7', '3b96bcc8da76'), ),
- Revision('78982bf5499', 'ae77a2ed69b', ),
- Revision('519c9f3ce294', '2c50d8bab6ee', ),
- Revision('2720fc75e5fd', '1cccf82219cb', ),
- Revision('21638ec787ba', '44d58f1d31f0', ),
- Revision('59742a546b84', '49ddec8c7a5e', ),
- Revision('2d30d79c4093', '135b5fd31062', ),
- Revision('f034666cd80', ('5a0003c3b09c', '38d069994064'), ),
- Revision('430133b6d46c', '181a149173e', ),
- Revision('3a3b24a31b57', ('16426dbea880', '4cc2960cbe19'), ),
- Revision('2529684536da', ('64ed798bcc3', '1f6963d1ae02'), ),
- Revision('17b66754be21', ('19e0db9d806a', '24c2620b2e3f'), ),
- Revision('3cf34fcd6473', ('52804033910e', '4789d9c82ca7'), ),
- Revision('36c31e4e1c37', '504d5168afe1', ),
- Revision('54f871bfe0b0', '519c9f3ce294', ),
- Revision('4a4e0838e206', '2a7f37cf7770', ),
- Revision('19e0db9d806a', ('430074f99c29', '36f1a410ed'), ),
- Revision('44ed1bf512a0', '247ad6880f93', ),
- Revision('42050bf030fd', '2f93dd880bae', ),
- Revision('2c50d8bab6ee', '199943ccc774', ),
- Revision('504d5168afe1', ('28f4dd53ad3a', '30931c545bf'), ),
- Revision('915c67915c2', '3fc8a578bc0a', ),
- Revision('2a7f37cf7770', '2720fc75e5fd', ),
- Revision('4c93d063d2ba', '4e28e2f4fe2f', ),
- Revision('42ded7355da2', '21638ec787ba', ),
+ Revision("3fc8a578bc0a", ("4878cb1cb7f6", "454a0529f84e")),
+ Revision("69285b0faaa", ("36c31e4e1c37", "3a3b24a31b57")),
+ Revision("3b0452c64639", "2f1a0f3667f3"),
+ Revision("2d9d787a496", "135b5fd31062"),
+ Revision("184f65ed83af", "3b0452c64639"),
+ Revision("430074f99c29", "54f871bfe0b0"),
+ Revision("3ffb59981d9a", "519c9f3ce294"),
+ Revision("454a0529f84e", ("40f6508e4373", "38a936c6ab11")),
+ Revision("24c2620b2e3f", ("430074f99c29", "1f5ceb1ec255")),
+ Revision("169a948471a9", "247ad6880f93"),
+ Revision("2f1a0f3667f3", "17dd0f165262"),
+ Revision("27227dc4fda8", "2a66d7c4d8a1"),
+ Revision("4b2ad1ffe2e7", ("3b409f268da4", "4f8a9b79a063")),
+ Revision("124ef6a17781", "2529684536da"),
+ Revision("4789d9c82ca7", "593b8076fb2c"),
+ Revision("64ed798bcc3", ("44ed1bf512a0", "169a948471a9")),
+ Revision("2588a3c36a0f", "50c7b21c9089"),
+ Revision("359329c2ebb", ("5810e9eff996", "339faa12616")),
+ Revision("540bc5634bd", "3a5db5f31209"),
+ Revision("20fe477817d2", "53d5ff905573"),
+ Revision("4f8a9b79a063", ("3cf34fcd6473", "300209d8594")),
+ Revision("6918589deaf", "3314c17f6e35"),
+ Revision("1755e3b1481c", ("17b66754be21", "31b1d4b7fc95")),
+ Revision("58c988e1aa4e", ("219240032b88", "f067f0b825c")),
+ Revision("593b8076fb2c", "1d94175d221b"),
+ Revision("38d069994064", ("46b70a57edc0", "3ed56beabfb7")),
+ Revision("3e2f6c6d1182", "7f96a01461b"),
+ Revision("1f6969597fe7", "1811bdae9e63"),
+ Revision("17dd0f165262", "3cf02a593a68"),
+ Revision("3cf02a593a68", "25a7ef58d293"),
+ Revision("34dfac7edb2d", "28f4dd53ad3a"),
+ Revision("4009c533e05d", "42ded7355da2"),
+ Revision("5a0003c3b09c", ("3ed56beabfb7", "2028d94d3863")),
+ Revision("38a936c6ab11", "2588a3c36a0f"),
+ Revision("59223c5b7b36", "2f93dd880bae"),
+ Revision("4121bd6e99e9", "540bc5634bd"),
+ Revision("260714a3f2de", "6918589deaf"),
+ Revision("ae77a2ed69b", "274fd2642933"),
+ Revision("18ff1ab3b4c4", "430133b6d46c"),
+ Revision("2b9a327527a9", ("359329c2ebb", "593b8076fb2c")),
+ Revision("4e6167c75ed0", "325b273d61bd"),
+ Revision("21ab11a7c5c4", ("3da31f3323ec", "22f26011d635")),
+ Revision("3b93e98481b1", "4e28e2f4fe2f"),
+ Revision("145d8f1e334d", "b4143d129e"),
+ Revision("135b5fd31062", "1d94175d221b"),
+ Revision("300209d8594", ("52804033910e", "593b8076fb2c")),
+ Revision("8dca95cce28", "f034666cd80"),
+ Revision("46b70a57edc0", ("145d8f1e334d", "4cc2960cbe19")),
+ Revision("4d45e479fbb9", "2d9d787a496"),
+ Revision("22f085bf8bbd", "540bc5634bd"),
+ Revision("263e91fd17d8", "2b9a327527a9"),
+ Revision("219240032b88", ("300209d8594", "2b9a327527a9")),
+ Revision("325b273d61bd", "4b2ad1ffe2e7"),
+ Revision("199943ccc774", "1aa674ccfa4e"),
+ Revision("247ad6880f93", "1f6969597fe7"),
+ Revision("4878cb1cb7f6", "28f4dd53ad3a"),
+ Revision("2a66d7c4d8a1", "23f1ccb18d6d"),
+ Revision("42b079245b55", "593b8076fb2c"),
+ Revision("1cccf82219cb", ("20fe477817d2", "915c67915c2")),
+ Revision("b4143d129e", ("159331d6f484", "504d5168afe1")),
+ Revision("53d5ff905573", "3013877bf5bd"),
+ Revision("1f5ceb1ec255", "3ffb59981d9a"),
+ Revision("ef1c1c1531f", "4738812e6ece"),
+ Revision("1f6963d1ae02", "247ad6880f93"),
+ Revision("44d58f1d31f0", "18ff1ab3b4c4"),
+ Revision("c3ebe64dfb5", ("3409c57b0da", "31f352e77045")),
+ Revision("f067f0b825c", "359329c2ebb"),
+ Revision("52ab2d3b57ce", "96d590bd82e"),
+ Revision("3b409f268da4", ("20e90eb3eeb6", "263e91fd17d8")),
+ Revision("5a4ca8889674", "4e6167c75ed0"),
+ Revision("5810e9eff996", ("2d30d79c4093", "52804033910e")),
+ Revision("40f6508e4373", "4ed16fad67a7"),
+ Revision("1811bdae9e63", "260714a3f2de"),
+ Revision("3013877bf5bd", ("8dca95cce28", "3fc8a578bc0a")),
+ Revision("16426dbea880", "28f4dd53ad3a"),
+ Revision("22f26011d635", ("4c93d063d2ba", "3b93e98481b1")),
+ Revision("3409c57b0da", "17b66754be21"),
+ Revision("44373001000f", ("42b079245b55", "219240032b88")),
+ Revision("28f4dd53ad3a", "2e71fd90eb9d"),
+ Revision("4cc2960cbe19", "504d5168afe1"),
+ Revision("31f352e77045", ("17b66754be21", "22f085bf8bbd")),
+ Revision("4ed16fad67a7", "f034666cd80"),
+ Revision("3da31f3323ec", "4c93d063d2ba"),
+ Revision("31b1d4b7fc95", "1cc4459fd115"),
+ Revision("11bc0ff42f87", "28f4dd53ad3a"),
+ Revision("3a5db5f31209", "59742a546b84"),
+ Revision("20e90eb3eeb6", ("58c988e1aa4e", "44373001000f")),
+ Revision("23f1ccb18d6d", "52ab2d3b57ce"),
+ Revision("1d94175d221b", "21ab11a7c5c4"),
+ Revision("36f1a410ed", "54f871bfe0b0"),
+ Revision("181a149173e", "2ee35cac4c62"),
+ Revision("171ad2f0c672", "4a4e0838e206"),
+ Revision("2f93dd880bae", "540bc5634bd"),
+ Revision("25a7ef58d293", None),
+ Revision("7f96a01461b", "184f65ed83af"),
+ Revision("b21f22233f", "3e2f6c6d1182"),
+ Revision("52804033910e", "1d94175d221b"),
+ Revision("1e6240aba5b3", ("4121bd6e99e9", "2c50d8bab6ee")),
+ Revision("1cc4459fd115", "1e6240aba5b3"),
+ Revision("274fd2642933", "4009c533e05d"),
+ Revision("1aa674ccfa4e", ("59223c5b7b36", "42050bf030fd")),
+ Revision("4e28e2f4fe2f", "596d7b9e11"),
+ Revision("49ddec8c7a5e", ("124ef6a17781", "47578179e766")),
+ Revision("3e9bb349cc46", "ef1c1c1531f"),
+ Revision("2028d94d3863", "504d5168afe1"),
+ Revision("159331d6f484", "34dfac7edb2d"),
+ Revision("596d7b9e11", "171ad2f0c672"),
+ Revision("3b96bcc8da76", "f034666cd80"),
+ Revision("4738812e6ece", "78982bf5499"),
+ Revision("3314c17f6e35", "27227dc4fda8"),
+ Revision("30931c545bf", "2e71fd90eb9d"),
+ Revision("2e71fd90eb9d", ("c3ebe64dfb5", "1755e3b1481c")),
+ Revision("3ed56beabfb7", ("11bc0ff42f87", "69285b0faaa")),
+ Revision("96d590bd82e", "3e9bb349cc46"),
+ Revision("339faa12616", "4d45e479fbb9"),
+ Revision("47578179e766", "2529684536da"),
+ Revision("2ee35cac4c62", "b21f22233f"),
+ Revision("50c7b21c9089", ("4ed16fad67a7", "3b96bcc8da76")),
+ Revision("78982bf5499", "ae77a2ed69b"),
+ Revision("519c9f3ce294", "2c50d8bab6ee"),
+ Revision("2720fc75e5fd", "1cccf82219cb"),
+ Revision("21638ec787ba", "44d58f1d31f0"),
+ Revision("59742a546b84", "49ddec8c7a5e"),
+ Revision("2d30d79c4093", "135b5fd31062"),
+ Revision("f034666cd80", ("5a0003c3b09c", "38d069994064")),
+ Revision("430133b6d46c", "181a149173e"),
+ Revision("3a3b24a31b57", ("16426dbea880", "4cc2960cbe19")),
+ Revision("2529684536da", ("64ed798bcc3", "1f6963d1ae02")),
+ Revision("17b66754be21", ("19e0db9d806a", "24c2620b2e3f")),
+ Revision("3cf34fcd6473", ("52804033910e", "4789d9c82ca7")),
+ Revision("36c31e4e1c37", "504d5168afe1"),
+ Revision("54f871bfe0b0", "519c9f3ce294"),
+ Revision("4a4e0838e206", "2a7f37cf7770"),
+ Revision("19e0db9d806a", ("430074f99c29", "36f1a410ed")),
+ Revision("44ed1bf512a0", "247ad6880f93"),
+ Revision("42050bf030fd", "2f93dd880bae"),
+ Revision("2c50d8bab6ee", "199943ccc774"),
+ Revision("504d5168afe1", ("28f4dd53ad3a", "30931c545bf")),
+ Revision("915c67915c2", "3fc8a578bc0a"),
+ Revision("2a7f37cf7770", "2720fc75e5fd"),
+ Revision("4c93d063d2ba", "4e28e2f4fe2f"),
+ Revision("42ded7355da2", "21638ec787ba"),
]
-map_ = RevisionMap(
- lambda: data
-)
-
-
+map_ = RevisionMap(lambda: data)
diff --git a/tests/conftest.py b/tests/conftest.py
index 608d903..6cf770d 100755
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -11,12 +11,16 @@ import os
# use bootstrapping so that test plugins are loaded
# without touching the main library before coverage starts
bootstrap_file = os.path.join(
- os.path.dirname(__file__), "..", "alembic",
- "testing", "plugin", "bootstrap.py"
+ os.path.dirname(__file__),
+ "..",
+ "alembic",
+ "testing",
+ "plugin",
+ "bootstrap.py",
)
with open(bootstrap_file) as f:
- code = compile(f.read(), "bootstrap.py", 'exec')
+ code = compile(f.read(), "bootstrap.py", "exec")
to_bootstrap = "pytest"
exec(code, globals(), locals())
from pytestplugin import * # noqa
diff --git a/tests/requirements.py b/tests/requirements.py
index 0be95eb..74a7d83 100644
--- a/tests/requirements.py
+++ b/tests/requirements.py
@@ -5,16 +5,12 @@ from alembic.util import sqla_compat
class DefaultRequirements(SuiteRequirements):
-
@property
def schemas(self):
"""Target database must support external schemas, and have one
named 'test_schema'."""
- return exclusions.skip_if([
- "sqlite",
- "firebird"
- ], "no schema support")
+ return exclusions.skip_if(["sqlite", "firebird"], "no schema support")
@property
def no_referential_integrity(self):
@@ -48,12 +44,12 @@ class DefaultRequirements(SuiteRequirements):
@property
def unnamed_constraints(self):
"""constraints without names are supported."""
- return exclusions.only_on(['sqlite'])
+ return exclusions.only_on(["sqlite"])
@property
def fk_names(self):
"""foreign key constraints always have names in the DB"""
- return exclusions.fails_on('sqlite')
+ return exclusions.fails_on("sqlite")
@property
def no_name_normalize(self):
@@ -63,20 +59,24 @@ class DefaultRequirements(SuiteRequirements):
@property
def reflects_fk_options(self):
- return exclusions.only_on([
- 'postgresql', 'mysql',
- lambda config: util.sqla_110 and
- exclusions.against(config, 'sqlite')])
+ return exclusions.only_on(
+ [
+ "postgresql",
+ "mysql",
+ lambda config: util.sqla_110
+ and exclusions.against(config, "sqlite"),
+ ]
+ )
@property
def fk_initially(self):
"""backend supports INITIALLY option in foreign keys"""
- return exclusions.only_on(['postgresql'])
+ return exclusions.only_on(["postgresql"])
@property
def fk_deferrable(self):
"""backend supports DEFERRABLE option in foreign keys"""
- return exclusions.only_on(['postgresql'])
+ return exclusions.only_on(["postgresql"])
@property
def flexible_fk_cascades(self):
@@ -84,8 +84,7 @@ class DefaultRequirements(SuiteRequirements):
full range of keywords (e.g. NO ACTION, etc.)"""
return exclusions.skip_if(
- ['oracle'],
- 'target backend has poor FK cascade syntax'
+ ["oracle"], "target backend has poor FK cascade syntax"
)
@property
@@ -97,10 +96,13 @@ class DefaultRequirements(SuiteRequirements):
"""Target driver reflects the name of primary key constraints."""
return exclusions.fails_on_everything_except(
- 'postgresql', 'oracle', 'mssql', 'sybase',
+ "postgresql",
+ "oracle",
+ "mssql",
+ "sybase",
lambda config: (
util.sqla_110 and exclusions.against(config, "sqlite")
- )
+ ),
)
@property
@@ -122,8 +124,10 @@ class DefaultRequirements(SuiteRequirements):
return False
count = config.db.scalar(
"SELECT count(*) FROM pg_extension "
- "WHERE extname='%s'" % name)
+ "WHERE extname='%s'" % name
+ )
return bool(count)
+
return exclusions.only_if(check, "needs %s extension" % name)
@property
@@ -134,7 +138,6 @@ class DefaultRequirements(SuiteRequirements):
def btree_gist(self):
return self._has_pg_extension("btree_gist")
-
@property
def autoincrement_on_composite_pk(self):
return exclusions.skip_if(["sqlite"], "not supported by database")
@@ -153,34 +156,42 @@ class DefaultRequirements(SuiteRequirements):
@property
def mysql_check_reflection_or_none(self):
def go(config):
- return not self._mariadb_102(config) \
- or self.sqlalchemy_1115.enabled
+ return (
+ not self._mariadb_102(config) or self.sqlalchemy_1115.enabled
+ )
+
return exclusions.succeeds_if(go)
@property
def mysql_timestamp_reflection(self):
def go(config):
- return not self._mariadb_102(config) \
- or self.sqlalchemy_1115.enabled
+ return (
+ not self._mariadb_102(config) or self.sqlalchemy_1115.enabled
+ )
+
return exclusions.only_if(go)
def _mariadb_102(self, config):
- return exclusions.against(config, "mysql") and \
- sqla_compat._is_mariadb(config.db.dialect) and \
- sqla_compat._mariadb_normalized_version_info(
- config.db.dialect) > (10, 2)
+ return (
+ exclusions.against(config, "mysql")
+ and sqla_compat._is_mariadb(config.db.dialect)
+ and sqla_compat._mariadb_normalized_version_info(config.db.dialect)
+ > (10, 2)
+ )
def _mariadb_only_102(self, config):
- return exclusions.against(config, "mysql") and \
- sqla_compat._is_mariadb(config.db.dialect) and \
- sqla_compat._mariadb_normalized_version_info(
- config.db.dialect) >= (10, 2) and \
- sqla_compat._mariadb_normalized_version_info(
- config.db.dialect) < (10, 3)
+ return (
+ exclusions.against(config, "mysql")
+ and sqla_compat._is_mariadb(config.db.dialect)
+ and sqla_compat._mariadb_normalized_version_info(config.db.dialect)
+ >= (10, 2)
+ and sqla_compat._mariadb_normalized_version_info(config.db.dialect)
+ < (10, 3)
+ )
def _mysql_not_mariadb_102(self, config):
return exclusions.against(config, "mysql") and (
- not sqla_compat._is_mariadb(config.db.dialect) or
- sqla_compat._mariadb_normalized_version_info(
- config.db.dialect) < (10, 2)
+ not sqla_compat._is_mariadb(config.db.dialect)
+ or sqla_compat._mariadb_normalized_version_info(config.db.dialect)
+ < (10, 2)
)
diff --git a/tests/test_autogen_composition.py b/tests/test_autogen_composition.py
index c536317..7694cbd 100644
--- a/tests/test_autogen_composition.py
+++ b/tests/test_autogen_composition.py
@@ -9,64 +9,73 @@ from ._autogen_fixtures import AutogenTest, ModelOne, _default_include_object
class AutogenerateDiffTest(ModelOne, AutogenTest, TestBase):
- __only_on__ = 'sqlite'
+ __only_on__ = "sqlite"
def test_render_nothing(self):
context = MigrationContext.configure(
connection=self.bind.connect(),
opts={
- 'compare_type': True,
- 'compare_server_default': True,
- 'target_metadata': self.m1,
- 'upgrade_token': "upgrades",
- 'downgrade_token': "downgrades",
- }
+ "compare_type": True,
+ "compare_server_default": True,
+ "target_metadata": self.m1,
+ "upgrade_token": "upgrades",
+ "downgrade_token": "downgrades",
+ },
)
template_args = {}
autogenerate._render_migration_diffs(context, template_args)
- eq_(re.sub(r"u'", "'", template_args['upgrades']),
+ eq_(
+ re.sub(r"u'", "'", template_args["upgrades"]),
"""# ### commands auto generated by Alembic - please adjust! ###
pass
- # ### end Alembic commands ###""")
- eq_(re.sub(r"u'", "'", template_args['downgrades']),
+ # ### end Alembic commands ###""",
+ )
+ eq_(
+ re.sub(r"u'", "'", template_args["downgrades"]),
"""# ### commands auto generated by Alembic - please adjust! ###
pass
- # ### end Alembic commands ###""")
+ # ### end Alembic commands ###""",
+ )
def test_render_nothing_batch(self):
context = MigrationContext.configure(
connection=self.bind.connect(),
opts={
- 'compare_type': True,
- 'compare_server_default': True,
- 'target_metadata': self.m1,
- 'upgrade_token': "upgrades",
- 'downgrade_token': "downgrades",
- 'alembic_module_prefix': 'op.',
- 'sqlalchemy_module_prefix': 'sa.',
- 'render_as_batch': True,
- 'include_symbol': lambda name, schema: False
- }
+ "compare_type": True,
+ "compare_server_default": True,
+ "target_metadata": self.m1,
+ "upgrade_token": "upgrades",
+ "downgrade_token": "downgrades",
+ "alembic_module_prefix": "op.",
+ "sqlalchemy_module_prefix": "sa.",
+ "render_as_batch": True,
+ "include_symbol": lambda name, schema: False,
+ },
)
template_args = {}
autogenerate._render_migration_diffs(context, template_args)
- eq_(re.sub(r"u'", "'", template_args['upgrades']),
+ eq_(
+ re.sub(r"u'", "'", template_args["upgrades"]),
"""# ### commands auto generated by Alembic - please adjust! ###
pass
- # ### end Alembic commands ###""")
- eq_(re.sub(r"u'", "'", template_args['downgrades']),
+ # ### end Alembic commands ###""",
+ )
+ eq_(
+ re.sub(r"u'", "'", template_args["downgrades"]),
"""# ### commands auto generated by Alembic - please adjust! ###
pass
- # ### end Alembic commands ###""")
+ # ### end Alembic commands ###""",
+ )
def test_render_diffs_standard(self):
"""test a full render including indentation"""
template_args = {}
autogenerate._render_migration_diffs(self.context, template_args)
- eq_(re.sub(r"u'", "'", template_args['upgrades']),
+ eq_(
+ re.sub(r"u'", "'", template_args["upgrades"]),
"""# ### commands auto generated by Alembic - please adjust! ###
op.create_table('item',
sa.Column('id', sa.Integer(), nullable=False),
@@ -96,9 +105,11 @@ nullable=True))
nullable=False)
op.drop_index('pw_idx', table_name='user')
op.drop_column('user', 'pw')
- # ### end Alembic commands ###""")
+ # ### end Alembic commands ###""",
+ )
- eq_(re.sub(r"u'", "'", template_args['downgrades']),
+ eq_(
+ re.sub(r"u'", "'", template_args["downgrades"]),
"""# ### commands auto generated by Alembic - please adjust! ###
op.add_column('user', sa.Column('pw', sa.VARCHAR(length=50), \
nullable=True))
@@ -125,16 +136,18 @@ nullable=True))
sa.ForeignKeyConstraint(['uid'], ['user.id'], )
)
op.drop_table('item')
- # ### end Alembic commands ###""")
+ # ### end Alembic commands ###""",
+ )
def test_render_diffs_batch(self):
"""test a full render in batch mode including indentation"""
template_args = {}
- self.context.opts['render_as_batch'] = True
+ self.context.opts["render_as_batch"] = True
autogenerate._render_migration_diffs(self.context, template_args)
- eq_(re.sub(r"u'", "'", template_args['upgrades']),
+ eq_(
+ re.sub(r"u'", "'", template_args["upgrades"]),
"""# ### commands auto generated by Alembic - please adjust! ###
op.create_table('item',
sa.Column('id', sa.Integer(), nullable=False),
@@ -169,9 +182,11 @@ nullable=True))
batch_op.drop_index('pw_idx')
batch_op.drop_column('pw')
- # ### end Alembic commands ###""")
+ # ### end Alembic commands ###""",
+ )
- eq_(re.sub(r"u'", "'", template_args['downgrades']),
+ eq_(
+ re.sub(r"u'", "'", template_args["downgrades"]),
"""# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('user', schema=None) as batch_op:
batch_op.add_column(sa.Column('pw', sa.VARCHAR(length=50), nullable=True))
@@ -203,74 +218,80 @@ nullable=True))
sa.ForeignKeyConstraint(['uid'], ['user.id'], )
)
op.drop_table('item')
- # ### end Alembic commands ###""")
+ # ### end Alembic commands ###""",
+ )
def test_imports_maintined(self):
template_args = {}
- self.context.opts['render_as_batch'] = True
+ self.context.opts["render_as_batch"] = True
def render_item(type_, col, autogen_context):
autogen_context.imports.add(
"from mypackage import my_special_import"
)
- autogen_context.imports.add(
- "from foobar import bat"
- )
+ autogen_context.imports.add("from foobar import bat")
self.context.opts["render_item"] = render_item
autogenerate._render_migration_diffs(self.context, template_args)
eq_(
+ set(template_args["imports"].split("\n")),
set(
- template_args['imports'].split("\n")
+ [
+ "from foobar import bat",
+ "from mypackage import my_special_import",
+ ]
),
- set([
- "from foobar import bat",
- "from mypackage import my_special_import"
- ])
)
class AutogenerateDiffTestWSchema(ModelOne, AutogenTest, TestBase):
- __only_on__ = 'postgresql'
+ __only_on__ = "postgresql"
schema = "test_schema"
def test_render_nothing(self):
context = MigrationContext.configure(
connection=self.bind.connect(),
opts={
- 'compare_type': True,
- 'compare_server_default': True,
- 'target_metadata': self.m1,
- 'upgrade_token': "upgrades",
- 'downgrade_token': "downgrades",
- 'alembic_module_prefix': 'op.',
- 'sqlalchemy_module_prefix': 'sa.',
- 'include_symbol': lambda name, schema: False
- }
+ "compare_type": True,
+ "compare_server_default": True,
+ "target_metadata": self.m1,
+ "upgrade_token": "upgrades",
+ "downgrade_token": "downgrades",
+ "alembic_module_prefix": "op.",
+ "sqlalchemy_module_prefix": "sa.",
+ "include_symbol": lambda name, schema: False,
+ },
)
template_args = {}
autogenerate._render_migration_diffs(context, template_args)
- eq_(re.sub(r"u'", "'", template_args['upgrades']),
+ eq_(
+ re.sub(r"u'", "'", template_args["upgrades"]),
"""# ### commands auto generated by Alembic - please adjust! ###
pass
- # ### end Alembic commands ###""")
- eq_(re.sub(r"u'", "'", template_args['downgrades']),
+ # ### end Alembic commands ###""",
+ )
+ eq_(
+ re.sub(r"u'", "'", template_args["downgrades"]),
"""# ### commands auto generated by Alembic - please adjust! ###
pass
- # ### end Alembic commands ###""")
+ # ### end Alembic commands ###""",
+ )
def test_render_diffs_extras(self):
"""test a full render including indentation (include and schema)"""
template_args = {}
- self.context.opts.update({
- 'include_object': _default_include_object,
- 'include_schemas': True
- })
+ self.context.opts.update(
+ {
+ "include_object": _default_include_object,
+ "include_schemas": True,
+ }
+ )
autogenerate._render_migration_diffs(self.context, template_args)
- eq_(re.sub(r"u'", "'", template_args['upgrades']),
+ eq_(
+ re.sub(r"u'", "'", template_args["upgrades"]),
"""# ### commands auto generated by Alembic - please adjust! ###
op.create_table('item',
sa.Column('id', sa.Integer(), nullable=False),
@@ -307,9 +328,12 @@ source_schema='%(schema)s', referent_schema='%(schema)s')
schema='%(schema)s')
op.drop_index('pw_idx', table_name='user', schema='test_schema')
op.drop_column('user', 'pw', schema='%(schema)s')
- # ### end Alembic commands ###""" % {"schema": self.schema})
+ # ### end Alembic commands ###"""
+ % {"schema": self.schema},
+ )
- eq_(re.sub(r"u'", "'", template_args['downgrades']),
+ eq_(
+ re.sub(r"u'", "'", template_args["downgrades"]),
"""# ### commands auto generated by Alembic - please adjust! ###
op.add_column('user', sa.Column('pw', sa.VARCHAR(length=50), \
autoincrement=False, nullable=True), schema='%(schema)s')
@@ -341,5 +365,6 @@ name='extra_uid_fkey'),
schema='%(schema)s'
)
op.drop_table('item', schema='%(schema)s')
- # ### end Alembic commands ###""" % {"schema": self.schema})
-
+ # ### end Alembic commands ###"""
+ % {"schema": self.schema},
+ )
diff --git a/tests/test_autogen_diffs.py b/tests/test_autogen_diffs.py
index 38e06b6..af2e2c2 100644
--- a/tests/test_autogen_diffs.py
+++ b/tests/test_autogen_diffs.py
@@ -1,10 +1,30 @@
import sys
-from sqlalchemy import MetaData, Column, Table, Integer, String, Text, \
- Numeric, CHAR, ForeignKey, INTEGER, Index, UniqueConstraint, \
- TypeDecorator, CheckConstraint, text, PrimaryKeyConstraint, \
- ForeignKeyConstraint, VARCHAR, DECIMAL, DateTime, BigInteger, BIGINT, \
- SmallInteger
+from sqlalchemy import (
+ MetaData,
+ Column,
+ Table,
+ Integer,
+ String,
+ Text,
+ Numeric,
+ CHAR,
+ ForeignKey,
+ INTEGER,
+ Index,
+ UniqueConstraint,
+ TypeDecorator,
+ CheckConstraint,
+ text,
+ PrimaryKeyConstraint,
+ ForeignKeyConstraint,
+ VARCHAR,
+ DECIMAL,
+ DateTime,
+ BigInteger,
+ BIGINT,
+ SmallInteger,
+)
from sqlalchemy.dialects import sqlite
from sqlalchemy.types import NULLTYPE, VARBINARY
from sqlalchemy.engine.reflection import Inspector
@@ -20,62 +40,41 @@ from alembic.testing import eq_, is_, is_not_
from alembic.util import CommandError
from ._autogen_fixtures import AutogenTest, AutogenFixtureTest
-py3k = sys.version_info >= (3, )
+py3k = sys.version_info >= (3,)
class AutogenCrossSchemaTest(AutogenTest, TestBase):
- __only_on__ = 'postgresql'
+ __only_on__ = "postgresql"
__backend__ = True
@classmethod
def _get_db_schema(cls):
m = MetaData()
- Table('t1', m,
- Column('x', Integer)
- )
- Table('t2', m,
- Column('y', Integer),
- schema=config.test_schema
- )
- Table('t6', m,
- Column('u', Integer)
- )
- Table('t7', m,
- Column('v', Integer),
- schema=config.test_schema
- )
+ Table("t1", m, Column("x", Integer))
+ Table("t2", m, Column("y", Integer), schema=config.test_schema)
+ Table("t6", m, Column("u", Integer))
+ Table("t7", m, Column("v", Integer), schema=config.test_schema)
return m
@classmethod
def _get_model_schema(cls):
m = MetaData()
- Table('t3', m,
- Column('q', Integer)
- )
- Table('t4', m,
- Column('z', Integer),
- schema=config.test_schema
- )
- Table('t6', m,
- Column('u', Integer)
- )
- Table('t7', m,
- Column('v', Integer),
- schema=config.test_schema
- )
+ Table("t3", m, Column("q", Integer))
+ Table("t4", m, Column("z", Integer), schema=config.test_schema)
+ Table("t6", m, Column("u", Integer))
+ Table("t7", m, Column("v", Integer), schema=config.test_schema)
return m
def test_default_schema_omitted_upgrade(self):
-
def include_object(obj, name, type_, reflected, compare_to):
if type_ == "table":
return name == "t3"
else:
return True
+
self._update_context(
- object_filters=include_object,
- include_schemas=True,
+ object_filters=include_object, include_schemas=True
)
uo = ops.UpgradeOps(ops=[])
autogenerate._produce_net_changes(self.autogen_context, uo)
@@ -85,7 +84,6 @@ class AutogenCrossSchemaTest(AutogenTest, TestBase):
eq_(diffs[0][1].schema, None)
def test_alt_schema_included_upgrade(self):
-
def include_object(obj, name, type_, reflected, compare_to):
if type_ == "table":
return name == "t4"
@@ -93,8 +91,7 @@ class AutogenCrossSchemaTest(AutogenTest, TestBase):
return True
self._update_context(
- object_filters=include_object,
- include_schemas=True,
+ object_filters=include_object, include_schemas=True
)
uo = ops.UpgradeOps(ops=[])
autogenerate._produce_net_changes(self.autogen_context, uo)
@@ -109,9 +106,9 @@ class AutogenCrossSchemaTest(AutogenTest, TestBase):
return name == "t1"
else:
return True
+
self._update_context(
- object_filters=include_object,
- include_schemas=True,
+ object_filters=include_object, include_schemas=True
)
uo = ops.UpgradeOps(ops=[])
autogenerate._produce_net_changes(self.autogen_context, uo)
@@ -121,15 +118,14 @@ class AutogenCrossSchemaTest(AutogenTest, TestBase):
eq_(diffs[0][1].schema, None)
def test_alt_schema_included_downgrade(self):
-
def include_object(obj, name, type_, reflected, compare_to):
if type_ == "table":
return name == "t2"
else:
return True
+
self._update_context(
- object_filters=include_object,
- include_schemas=True,
+ object_filters=include_object, include_schemas=True
)
uo = ops.UpgradeOps(ops=[])
autogenerate._produce_net_changes(self.autogen_context, uo)
@@ -139,7 +135,7 @@ class AutogenCrossSchemaTest(AutogenTest, TestBase):
class AutogenDefaultSchemaTest(AutogenFixtureTest, TestBase):
- __only_on__ = 'postgresql'
+ __only_on__ = "postgresql"
__backend__ = True
def test_uses_explcit_schema_in_default_one(self):
@@ -149,8 +145,8 @@ class AutogenDefaultSchemaTest(AutogenFixtureTest, TestBase):
m1 = MetaData()
m2 = MetaData()
- Table('a', m1, Column('x', String(50)))
- Table('a', m2, Column('x', String(50)), schema=default_schema)
+ Table("a", m1, Column("x", String(50)))
+ Table("a", m2, Column("x", String(50)), schema=default_schema)
diffs = self._fixture(m1, m2, include_schemas=True)
eq_(diffs, [])
@@ -162,15 +158,15 @@ class AutogenDefaultSchemaTest(AutogenFixtureTest, TestBase):
m1 = MetaData()
m2 = MetaData()
- Table('a', m1, Column('x', String(50)))
- Table('a', m2, Column('x', String(50)), schema=default_schema)
- Table('a', m2, Column('y', String(50)), schema="test_schema")
+ Table("a", m1, Column("x", String(50)))
+ Table("a", m2, Column("x", String(50)), schema=default_schema)
+ Table("a", m2, Column("y", String(50)), schema="test_schema")
diffs = self._fixture(m1, m2, include_schemas=True)
eq_(len(diffs), 1)
eq_(diffs[0][0], "add_table")
eq_(diffs[0][1].schema, "test_schema")
- eq_(diffs[0][1].c.keys(), ['y'])
+ eq_(diffs[0][1].c.keys(), ["y"])
def test_uses_explcit_schema_in_default_three(self):
@@ -179,20 +175,20 @@ class AutogenDefaultSchemaTest(AutogenFixtureTest, TestBase):
m1 = MetaData()
m2 = MetaData()
- Table('a', m1, Column('y', String(50)), schema="test_schema")
+ Table("a", m1, Column("y", String(50)), schema="test_schema")
- Table('a', m2, Column('x', String(50)), schema=default_schema)
- Table('a', m2, Column('y', String(50)), schema="test_schema")
+ Table("a", m2, Column("x", String(50)), schema=default_schema)
+ Table("a", m2, Column("y", String(50)), schema="test_schema")
diffs = self._fixture(m1, m2, include_schemas=True)
eq_(len(diffs), 1)
eq_(diffs[0][0], "add_table")
eq_(diffs[0][1].schema, default_schema)
- eq_(diffs[0][1].c.keys(), ['x'])
+ eq_(diffs[0][1].c.keys(), ["x"])
class AutogenDefaultSchemaIsNoneTest(AutogenFixtureTest, TestBase):
- __only_on__ = 'sqlite'
+ __only_on__ = "sqlite"
def setUp(self):
super(AutogenDefaultSchemaIsNoneTest, self).setUp()
@@ -205,23 +201,23 @@ class AutogenDefaultSchemaIsNoneTest(AutogenFixtureTest, TestBase):
m1 = MetaData()
m2 = MetaData()
- Table('a', m1, Column('x', String(50)))
- Table('a', m2, Column('x', String(50)))
+ Table("a", m1, Column("x", String(50)))
+ Table("a", m2, Column("x", String(50)))
def _include_object(obj, name, type_, reflected, compare_to):
if type_ == "table":
- return name in 'a' and obj.schema != 'main'
+ return name in "a" and obj.schema != "main"
else:
return True
diffs = self._fixture(
- m1, m2, include_schemas=True,
- object_filters=_include_object)
+ m1, m2, include_schemas=True, object_filters=_include_object
+ )
eq_(len(diffs), 0)
class ModelOne(object):
- __requires__ = ('unique_constraint_reflection', )
+ __requires__ = ("unique_constraint_reflection",)
schema = None
@@ -231,30 +227,42 @@ class ModelOne(object):
m = MetaData(schema=schema)
- Table('user', m,
- Column('id', Integer, primary_key=True),
- Column('name', String(50)),
- Column('a1', Text),
- Column("pw", String(50)),
- Index('pw_idx', 'pw')
- )
-
- Table('address', m,
- Column('id', Integer, primary_key=True),
- Column('email_address', String(100), nullable=False),
- )
-
- Table('order', m,
- Column('order_id', Integer, primary_key=True),
- Column("amount", Numeric(8, 2), nullable=False,
- server_default=text("0")),
- CheckConstraint('amount >= 0', name='ck_order_amount')
- )
-
- Table('extra', m,
- Column("x", CHAR),
- Column('uid', Integer, ForeignKey('user.id'))
- )
+ Table(
+ "user",
+ m,
+ Column("id", Integer, primary_key=True),
+ Column("name", String(50)),
+ Column("a1", Text),
+ Column("pw", String(50)),
+ Index("pw_idx", "pw"),
+ )
+
+ Table(
+ "address",
+ m,
+ Column("id", Integer, primary_key=True),
+ Column("email_address", String(100), nullable=False),
+ )
+
+ Table(
+ "order",
+ m,
+ Column("order_id", Integer, primary_key=True),
+ Column(
+ "amount",
+ Numeric(8, 2),
+ nullable=False,
+ server_default=text("0"),
+ ),
+ CheckConstraint("amount >= 0", name="ck_order_amount"),
+ )
+
+ Table(
+ "extra",
+ m,
+ Column("x", CHAR),
+ Column("uid", Integer, ForeignKey("user.id")),
+ )
return m
@@ -264,38 +272,50 @@ class ModelOne(object):
m = MetaData(schema=schema)
- Table('user', m,
- Column('id', Integer, primary_key=True),
- Column('name', String(50), nullable=False),
- Column('a1', Text, server_default="x")
- )
-
- Table('address', m,
- Column('id', Integer, primary_key=True),
- Column('email_address', String(100), nullable=False),
- Column('street', String(50)),
- UniqueConstraint('email_address', name="uq_email")
- )
-
- Table('order', m,
- Column('order_id', Integer, primary_key=True),
- Column('amount', Numeric(10, 2), nullable=True,
- server_default=text("0")),
- Column('user_id', Integer, ForeignKey('user.id')),
- CheckConstraint('amount > -1', name='ck_order_amount'),
- )
-
- Table('item', m,
- Column('id', Integer, primary_key=True),
- Column('description', String(100)),
- Column('order_id', Integer, ForeignKey('order.order_id')),
- CheckConstraint('len(description) > 5')
- )
+ Table(
+ "user",
+ m,
+ Column("id", Integer, primary_key=True),
+ Column("name", String(50), nullable=False),
+ Column("a1", Text, server_default="x"),
+ )
+
+ Table(
+ "address",
+ m,
+ Column("id", Integer, primary_key=True),
+ Column("email_address", String(100), nullable=False),
+ Column("street", String(50)),
+ UniqueConstraint("email_address", name="uq_email"),
+ )
+
+ Table(
+ "order",
+ m,
+ Column("order_id", Integer, primary_key=True),
+ Column(
+ "amount",
+ Numeric(10, 2),
+ nullable=True,
+ server_default=text("0"),
+ ),
+ Column("user_id", Integer, ForeignKey("user.id")),
+ CheckConstraint("amount > -1", name="ck_order_amount"),
+ )
+
+ Table(
+ "item",
+ m,
+ Column("id", Integer, primary_key=True),
+ Column("description", String(100)),
+ Column("order_id", Integer, ForeignKey("order.order_id")),
+ CheckConstraint("len(description) > 5"),
+ )
return m
class AutogenerateDiffTest(ModelOne, AutogenTest, TestBase):
- __only_on__ = 'sqlite'
+ __only_on__ = "sqlite"
def test_diffs(self):
"""test generation of diff rules"""
@@ -304,23 +324,18 @@ class AutogenerateDiffTest(ModelOne, AutogenTest, TestBase):
uo = ops.UpgradeOps(ops=[])
ctx = self.autogen_context
- autogenerate._produce_net_changes(
- ctx, uo
- )
+ autogenerate._produce_net_changes(ctx, uo)
diffs = uo.as_diffs()
- eq_(
- diffs[0],
- ('add_table', metadata.tables['item'])
- )
+ eq_(diffs[0], ("add_table", metadata.tables["item"]))
- eq_(diffs[1][0], 'remove_table')
+ eq_(diffs[1][0], "remove_table")
eq_(diffs[1][1].name, "extra")
eq_(diffs[2][0], "add_column")
eq_(diffs[2][1], None)
eq_(diffs[2][2], "address")
- eq_(diffs[2][3], metadata.tables['address'].c.street)
+ eq_(diffs[2][3], metadata.tables["address"].c.street)
eq_(diffs[3][0], "add_constraint")
eq_(diffs[3][1].name, "uq_email")
@@ -328,7 +343,7 @@ class AutogenerateDiffTest(ModelOne, AutogenTest, TestBase):
eq_(diffs[4][0], "add_column")
eq_(diffs[4][1], None)
eq_(diffs[4][2], "order")
- eq_(diffs[4][3], metadata.tables['order'].c.user_id)
+ eq_(diffs[4][3], metadata.tables["order"].c.user_id)
eq_(diffs[5][0][0], "modify_type")
eq_(diffs[5][0][1], None)
@@ -338,9 +353,7 @@ class AutogenerateDiffTest(ModelOne, AutogenTest, TestBase):
eq_(repr(diffs[5][0][6]), "Numeric(precision=10, scale=2)")
self._assert_fk_diff(
- diffs[6], "add_fk",
- "order", ["user_id"],
- "user", ["id"]
+ diffs[6], "add_fk", "order", ["user_id"], "user", ["id"]
)
eq_(diffs[7][0][0], "modify_default")
@@ -349,45 +362,47 @@ class AutogenerateDiffTest(ModelOne, AutogenTest, TestBase):
eq_(diffs[7][0][3], "a1")
eq_(diffs[7][0][6].arg, "x")
- eq_(diffs[8][0][0], 'modify_nullable')
+ eq_(diffs[8][0][0], "modify_nullable")
eq_(diffs[8][0][5], True)
eq_(diffs[8][0][6], False)
- eq_(diffs[9][0], 'remove_index')
- eq_(diffs[9][1].name, 'pw_idx')
+ eq_(diffs[9][0], "remove_index")
+ eq_(diffs[9][1].name, "pw_idx")
- eq_(diffs[10][0], 'remove_column')
- eq_(diffs[10][3].name, 'pw')
- eq_(diffs[10][3].table.name, 'user')
- assert isinstance(
- diffs[10][3].type, String
- )
+ eq_(diffs[10][0], "remove_column")
+ eq_(diffs[10][3].name, "pw")
+ eq_(diffs[10][3].table.name, "user")
+ assert isinstance(diffs[10][3].type, String)
def test_include_symbol(self):
diffs = []
def include_symbol(name, schema=None):
- return name in ('address', 'order')
+ return name in ("address", "order")
context = MigrationContext.configure(
connection=self.bind.connect(),
opts={
- 'compare_type': True,
- 'compare_server_default': True,
- 'target_metadata': self.m2,
- 'include_symbol': include_symbol,
- }
+ "compare_type": True,
+ "compare_server_default": True,
+ "target_metadata": self.m2,
+ "include_symbol": include_symbol,
+ },
)
diffs = autogenerate.compare_metadata(
- context, context.opts['target_metadata'])
+ context, context.opts["target_metadata"]
+ )
- alter_cols = set([
- d[2] for d in self._flatten_diffs(diffs)
- if d[0].startswith('modify')
- ])
- eq_(alter_cols, set(['order']))
+ alter_cols = set(
+ [
+ d[2]
+ for d in self._flatten_diffs(diffs)
+ if d[0].startswith("modify")
+ ]
+ )
+ eq_(alter_cols, set(["order"]))
def test_include_object(self):
def include_object(obj, name, type_, reflected, compare_to):
@@ -410,33 +425,46 @@ class AutogenerateDiffTest(ModelOne, AutogenTest, TestBase):
context = MigrationContext.configure(
connection=self.bind.connect(),
opts={
- 'compare_type': True,
- 'compare_server_default': True,
- 'target_metadata': self.m2,
- 'include_object': include_object,
- }
+ "compare_type": True,
+ "compare_server_default": True,
+ "target_metadata": self.m2,
+ "include_object": include_object,
+ },
)
diffs = autogenerate.compare_metadata(
- context, context.opts['target_metadata'])
+ context, context.opts["target_metadata"]
+ )
- alter_cols = set([
- d[2] for d in self._flatten_diffs(diffs)
- if d[0].startswith('modify')
- ]).union(
- d[3].name for d in self._flatten_diffs(diffs)
- if d[0] == 'add_column'
- ).union(
- d[1].name for d in self._flatten_diffs(diffs)
- if d[0] == 'add_table'
+ alter_cols = (
+ set(
+ [
+ d[2]
+ for d in self._flatten_diffs(diffs)
+ if d[0].startswith("modify")
+ ]
+ )
+ .union(
+ d[3].name
+ for d in self._flatten_diffs(diffs)
+ if d[0] == "add_column"
+ )
+ .union(
+ d[1].name
+ for d in self._flatten_diffs(diffs)
+ if d[0] == "add_table"
+ )
)
- eq_(alter_cols, set(['user_id', 'order', 'user']))
+ eq_(alter_cols, set(["user_id", "order", "user"]))
def test_skip_null_type_comparison_reflected(self):
ac = ops.AlterColumnOp("sometable", "somecol")
autogenerate.compare._compare_type(
- self.autogen_context, ac,
- None, "sometable", "somecol",
+ self.autogen_context,
+ ac,
+ None,
+ "sometable",
+ "somecol",
Column("somecol", NULLTYPE),
Column("somecol", Integer()),
)
@@ -446,8 +474,11 @@ class AutogenerateDiffTest(ModelOne, AutogenTest, TestBase):
def test_skip_null_type_comparison_local(self):
ac = ops.AlterColumnOp("sometable", "somecol")
autogenerate.compare._compare_type(
- self.autogen_context, ac,
- None, "sometable", "somecol",
+ self.autogen_context,
+ ac,
+ None,
+ "sometable",
+ "somecol",
Column("somecol", Integer()),
Column("somecol", NULLTYPE),
)
@@ -463,8 +494,11 @@ class AutogenerateDiffTest(ModelOne, AutogenTest, TestBase):
ac = ops.AlterColumnOp("sometable", "somecol")
autogenerate.compare._compare_type(
- self.autogen_context, ac,
- None, "sometable", "somecol",
+ self.autogen_context,
+ ac,
+ None,
+ "sometable",
+ "somecol",
Column("somecol", INTEGER()),
Column("somecol", MyType()),
)
@@ -473,56 +507,63 @@ class AutogenerateDiffTest(ModelOne, AutogenTest, TestBase):
ac = ops.AlterColumnOp("sometable", "somecol")
autogenerate.compare._compare_type(
- self.autogen_context, ac,
- None, "sometable", "somecol",
+ self.autogen_context,
+ ac,
+ None,
+ "sometable",
+ "somecol",
Column("somecol", String()),
Column("somecol", MyType()),
)
diff = ac.to_diff_tuple()
- eq_(
- diff[0][0:4],
- ('modify_type', None, 'sometable', 'somecol')
- )
+ eq_(diff[0][0:4], ("modify_type", None, "sometable", "somecol"))
def test_affinity_typedec(self):
class MyType(TypeDecorator):
impl = CHAR
def load_dialect_impl(self, dialect):
- if dialect.name == 'sqlite':
+ if dialect.name == "sqlite":
return dialect.type_descriptor(Integer())
else:
return dialect.type_descriptor(CHAR(32))
- uo = ops.AlterColumnOp('sometable', 'somecol')
+ uo = ops.AlterColumnOp("sometable", "somecol")
autogenerate.compare._compare_type(
- self.autogen_context, uo,
- None, "sometable", "somecol",
+ self.autogen_context,
+ uo,
+ None,
+ "sometable",
+ "somecol",
Column("somecol", Integer, nullable=True),
- Column("somecol", MyType())
+ Column("somecol", MyType()),
)
assert not uo.has_changes()
def test_dont_barf_on_already_reflected(self):
from sqlalchemy.util import OrderedSet
+
inspector = Inspector.from_engine(self.bind)
uo = ops.UpgradeOps(ops=[])
autogenerate.compare._compare_tables(
- OrderedSet([(None, 'extra'), (None, 'user')]),
- OrderedSet(), inspector,
- uo, self.autogen_context
+ OrderedSet([(None, "extra"), (None, "user")]),
+ OrderedSet(),
+ inspector,
+ uo,
+ self.autogen_context,
)
eq_(
[(rec[0], rec[1].name) for rec in uo.as_diffs()],
[
- ('remove_table', 'extra'),
- ('remove_index', 'pw_idx'),
- ('remove_table', 'user'), ]
+ ("remove_table", "extra"),
+ ("remove_index", "pw_idx"),
+ ("remove_table", "user"),
+ ],
)
class AutogenerateDiffTestWSchema(ModelOne, AutogenTest, TestBase):
- __only_on__ = 'postgresql'
+ __only_on__ = "postgresql"
__backend__ = True
schema = "test_schema"
@@ -531,26 +572,21 @@ class AutogenerateDiffTestWSchema(ModelOne, AutogenTest, TestBase):
metadata = self.m2
- self._update_context(
- include_schemas=True,
- )
+ self._update_context(include_schemas=True)
uo = ops.UpgradeOps(ops=[])
autogenerate._produce_net_changes(self.autogen_context, uo)
diffs = uo.as_diffs()
- eq_(
- diffs[0],
- ('add_table', metadata.tables['%s.item' % self.schema])
- )
+ eq_(diffs[0], ("add_table", metadata.tables["%s.item" % self.schema]))
- eq_(diffs[1][0], 'remove_table')
+ eq_(diffs[1][0], "remove_table")
eq_(diffs[1][1].name, "extra")
eq_(diffs[2][0], "add_column")
eq_(diffs[2][1], self.schema)
eq_(diffs[2][2], "address")
- eq_(diffs[2][3], metadata.tables['%s.address' % self.schema].c.street)
+ eq_(diffs[2][3], metadata.tables["%s.address" % self.schema].c.street)
eq_(diffs[3][0], "add_constraint")
eq_(diffs[3][1].name, "uq_email")
@@ -558,7 +594,7 @@ class AutogenerateDiffTestWSchema(ModelOne, AutogenTest, TestBase):
eq_(diffs[4][0], "add_column")
eq_(diffs[4][1], self.schema)
eq_(diffs[4][2], "order")
- eq_(diffs[4][3], metadata.tables['%s.order' % self.schema].c.user_id)
+ eq_(diffs[4][3], metadata.tables["%s.order" % self.schema].c.user_id)
eq_(diffs[5][0][0], "modify_type")
eq_(diffs[5][0][1], self.schema)
@@ -568,10 +604,13 @@ class AutogenerateDiffTestWSchema(ModelOne, AutogenTest, TestBase):
eq_(repr(diffs[5][0][6]), "Numeric(precision=10, scale=2)")
self._assert_fk_diff(
- diffs[6], "add_fk",
- "order", ["user_id"],
- "user", ["id"],
- source_schema=config.test_schema
+ diffs[6],
+ "add_fk",
+ "order",
+ ["user_id"],
+ "user",
+ ["id"],
+ source_schema=config.test_schema,
)
eq_(diffs[7][0][0], "modify_default")
@@ -580,15 +619,15 @@ class AutogenerateDiffTestWSchema(ModelOne, AutogenTest, TestBase):
eq_(diffs[7][0][3], "a1")
eq_(diffs[7][0][6].arg, "x")
- eq_(diffs[8][0][0], 'modify_nullable')
+ eq_(diffs[8][0][0], "modify_nullable")
eq_(diffs[8][0][5], True)
eq_(diffs[8][0][6], False)
- eq_(diffs[9][0], 'remove_index')
- eq_(diffs[9][1].name, 'pw_idx')
+ eq_(diffs[9][0], "remove_index")
+ eq_(diffs[9][1].name, "pw_idx")
- eq_(diffs[10][0], 'remove_column')
- eq_(diffs[10][3].name, 'pw')
+ eq_(diffs[10][0], "remove_column")
+ eq_(diffs[10][3].name, "pw")
class CompareTypeSpecificityTest(TestBase):
@@ -597,10 +636,10 @@ class CompareTypeSpecificityTest(TestBase):
from sqlalchemy.engine import default
return impl.DefaultImpl(
- default.DefaultDialect(), None, False, True, None, {})
+ default.DefaultDialect(), None, False, True, None, {}
+ )
def test_typedec_to_nonstandard(self):
-
class PasswordType(TypeDecorator):
impl = VARBINARY
@@ -608,7 +647,7 @@ class CompareTypeSpecificityTest(TestBase):
return PasswordType(self.impl.length)
def load_dialect_impl(self, dialect):
- if dialect.name == 'default':
+ if dialect.name == "default":
impl = sqlite.NUMERIC(self.length)
else:
impl = VARBINARY(self.length)
@@ -616,8 +655,8 @@ class CompareTypeSpecificityTest(TestBase):
impl = self._fixture()
impl.compare_type(
- Column('x', sqlite.NUMERIC(50)),
- Column('x', PasswordType(50)))
+ Column("x", sqlite.NUMERIC(50)), Column("x", PasswordType(50))
+ )
def test_string(self):
t1 = String(30)
@@ -626,9 +665,9 @@ class CompareTypeSpecificityTest(TestBase):
t4 = Integer
impl = self._fixture()
- is_(impl.compare_type(Column('x', t3), Column('x', t1)), False)
- is_(impl.compare_type(Column('x', t3), Column('x', t2)), True)
- is_(impl.compare_type(Column('x', t3), Column('x', t4)), True)
+ is_(impl.compare_type(Column("x", t3), Column("x", t1)), False)
+ is_(impl.compare_type(Column("x", t3), Column("x", t2)), True)
+ is_(impl.compare_type(Column("x", t3), Column("x", t4)), True)
def test_numeric(self):
t1 = Numeric(10, 5)
@@ -637,16 +676,16 @@ class CompareTypeSpecificityTest(TestBase):
t4 = DateTime
impl = self._fixture()
- is_(impl.compare_type(Column('x', t3), Column('x', t1)), False)
- is_(impl.compare_type(Column('x', t3), Column('x', t2)), True)
- is_(impl.compare_type(Column('x', t3), Column('x', t4)), True)
+ is_(impl.compare_type(Column("x", t3), Column("x", t1)), False)
+ is_(impl.compare_type(Column("x", t3), Column("x", t2)), True)
+ is_(impl.compare_type(Column("x", t3), Column("x", t4)), True)
def test_numeric_noprecision(self):
t1 = Numeric()
t2 = Numeric(scale=5)
impl = self._fixture()
- is_(impl.compare_type(Column('x', t1), Column('x', t2)), False)
+ is_(impl.compare_type(Column("x", t1), Column("x", t2)), False)
def test_integer(self):
t1 = Integer()
@@ -657,12 +696,12 @@ class CompareTypeSpecificityTest(TestBase):
t6 = BigInteger()
impl = self._fixture()
- is_(impl.compare_type(Column('x', t5), Column('x', t1)), False)
- is_(impl.compare_type(Column('x', t3), Column('x', t1)), True)
- is_(impl.compare_type(Column('x', t3), Column('x', t6)), False)
- is_(impl.compare_type(Column('x', t3), Column('x', t2)), True)
- is_(impl.compare_type(Column('x', t5), Column('x', t2)), True)
- is_(impl.compare_type(Column('x', t1), Column('x', t4)), True)
+ is_(impl.compare_type(Column("x", t5), Column("x", t1)), False)
+ is_(impl.compare_type(Column("x", t3), Column("x", t1)), True)
+ is_(impl.compare_type(Column("x", t3), Column("x", t6)), False)
+ is_(impl.compare_type(Column("x", t3), Column("x", t2)), True)
+ is_(impl.compare_type(Column("x", t5), Column("x", t2)), True)
+ is_(impl.compare_type(Column("x", t1), Column("x", t4)), True)
def test_datetime(self):
t1 = DateTime()
@@ -670,22 +709,19 @@ class CompareTypeSpecificityTest(TestBase):
t3 = DateTime(timezone=True)
impl = self._fixture()
- is_(impl.compare_type(Column('x', t1), Column('x', t2)), False)
- is_(impl.compare_type(Column('x', t1), Column('x', t3)), True)
- is_(impl.compare_type(Column('x', t2), Column('x', t3)), True)
+ is_(impl.compare_type(Column("x", t1), Column("x", t2)), False)
+ is_(impl.compare_type(Column("x", t1), Column("x", t3)), True)
+ is_(impl.compare_type(Column("x", t2), Column("x", t3)), True)
class AutogenSystemColTest(AutogenTest, TestBase):
- __only_on__ = 'postgresql'
+ __only_on__ = "postgresql"
@classmethod
def _get_db_schema(cls):
m = MetaData()
- Table(
- 'sometable', m,
- Column('id', Integer, primary_key=True),
- )
+ Table("sometable", m, Column("id", Integer, primary_key=True))
return m
@classmethod
@@ -695,9 +731,10 @@ class AutogenSystemColTest(AutogenTest, TestBase):
# 'xmin' is implicitly present, when added to a model should produce
# no change
Table(
- 'sometable', m,
- Column('id', Integer, primary_key=True),
- Column('xmin', Integer, system=True)
+ "sometable",
+ m,
+ Column("id", Integer, primary_key=True),
+ Column("xmin", Integer, system=True),
)
return m
@@ -715,30 +752,38 @@ class AutogenerateVariantCompareTest(AutogenTest, TestBase):
# 1.0.13 and lower fail on Postgresql due to variant / bigserial issue
# #3739
- __requires__ = ('sqlalchemy_1014', )
+ __requires__ = ("sqlalchemy_1014",)
@classmethod
def _get_db_schema(cls):
m = MetaData()
- Table('sometable', m,
- Column(
- 'id',
- BigInteger().with_variant(Integer, "sqlite"),
- primary_key=True),
- Column('value', String(50)))
+ Table(
+ "sometable",
+ m,
+ Column(
+ "id",
+ BigInteger().with_variant(Integer, "sqlite"),
+ primary_key=True,
+ ),
+ Column("value", String(50)),
+ )
return m
@classmethod
def _get_model_schema(cls):
m = MetaData()
- Table('sometable', m,
- Column(
- 'id',
- BigInteger().with_variant(Integer, "sqlite"),
- primary_key=True),
- Column('value', String(50)))
+ Table(
+ "sometable",
+ m,
+ Column(
+ "id",
+ BigInteger().with_variant(Integer, "sqlite"),
+ primary_key=True,
+ ),
+ Column("value", String(50)),
+ )
return m
def test_variant_no_issue(self):
@@ -750,24 +795,30 @@ class AutogenerateVariantCompareTest(AutogenTest, TestBase):
class AutogenerateCustomCompareTypeTest(AutogenTest, TestBase):
- __only_on__ = 'sqlite'
+ __only_on__ = "sqlite"
@classmethod
def _get_db_schema(cls):
m = MetaData()
- Table('sometable', m,
- Column('id', Integer, primary_key=True),
- Column('value', Integer))
+ Table(
+ "sometable",
+ m,
+ Column("id", Integer, primary_key=True),
+ Column("value", Integer),
+ )
return m
@classmethod
def _get_model_schema(cls):
m = MetaData()
- Table('sometable', m,
- Column('id', Integer, primary_key=True),
- Column('value', String))
+ Table(
+ "sometable",
+ m,
+ Column("id", Integer, primary_key=True),
+ Column("value", String),
+ )
return m
def test_uses_custom_compare_type_function(self):
@@ -779,15 +830,20 @@ class AutogenerateCustomCompareTypeTest(AutogenTest, TestBase):
ctx = self.autogen_context
autogenerate._produce_net_changes(ctx, uo)
- first_table = self.m2.tables['sometable']
- first_column = first_table.columns['id']
+ first_table = self.m2.tables["sometable"]
+ first_column = first_table.columns["id"]
eq_(len(my_compare_type.mock_calls), 2)
# We'll just test the first call
_, args, _ = my_compare_type.mock_calls[0]
- (context, inspected_column, metadata_column,
- inspected_type, metadata_type) = args
+ (
+ context,
+ inspected_column,
+ metadata_column,
+ inspected_type,
+ metadata_type,
+ ) = args
eq_(context, self.context)
eq_(metadata_column, first_column)
eq_(metadata_type, first_column.type)
@@ -816,8 +872,8 @@ class AutogenerateCustomCompareTypeTest(AutogenTest, TestBase):
autogenerate._produce_net_changes(ctx, uo)
diffs = uo.as_diffs()
- eq_(diffs[0][0][0], 'modify_type')
- eq_(diffs[1][0][0], 'modify_type')
+ eq_(diffs[0][0][0], "modify_type")
+ eq_(diffs[1][0][0], "modify_type")
class PKConstraintUpgradesIgnoresNullableTest(AutogenTest, TestBase):
@@ -829,10 +885,11 @@ class PKConstraintUpgradesIgnoresNullableTest(AutogenTest, TestBase):
m = MetaData()
Table(
- 'person_to_role', m,
- Column('person_id', Integer, autoincrement=False),
- Column('role_id', Integer, autoincrement=False),
- PrimaryKeyConstraint('person_id', 'role_id')
+ "person_to_role",
+ m,
+ Column("person_id", Integer, autoincrement=False),
+ Column("role_id", Integer, autoincrement=False),
+ PrimaryKeyConstraint("person_id", "role_id"),
)
return m
@@ -849,34 +906,40 @@ class PKConstraintUpgradesIgnoresNullableTest(AutogenTest, TestBase):
class AutogenKeyTest(AutogenTest, TestBase):
- __only_on__ = 'sqlite'
+ __only_on__ = "sqlite"
@classmethod
def _get_db_schema(cls):
m = MetaData()
- Table('someothertable', m,
- Column('id', Integer, primary_key=True),
- Column('value', Integer, key="somekey"),
- )
+ Table(
+ "someothertable",
+ m,
+ Column("id", Integer, primary_key=True),
+ Column("value", Integer, key="somekey"),
+ )
return m
@classmethod
def _get_model_schema(cls):
m = MetaData()
- Table('sometable', m,
- Column('id', Integer, primary_key=True),
- Column('value', Integer, key="someotherkey"),
- )
- Table('someothertable', m,
- Column('id', Integer, primary_key=True),
- Column('value', Integer, key="somekey"),
- Column("othervalue", Integer, key="otherkey")
- )
+ Table(
+ "sometable",
+ m,
+ Column("id", Integer, primary_key=True),
+ Column("value", Integer, key="someotherkey"),
+ )
+ Table(
+ "someothertable",
+ m,
+ Column("id", Integer, primary_key=True),
+ Column("value", Integer, key="somekey"),
+ Column("othervalue", Integer, key="otherkey"),
+ )
return m
- symbols = ['someothertable', 'sometable']
+ symbols = ["someothertable", "sometable"]
def test_autogen(self):
@@ -892,16 +955,19 @@ class AutogenKeyTest(AutogenTest, TestBase):
class AutogenVersionTableTest(AutogenTest, TestBase):
- __only_on__ = 'sqlite'
- version_table_name = 'alembic_version'
+ __only_on__ = "sqlite"
+ version_table_name = "alembic_version"
version_table_schema = None
@classmethod
def _get_db_schema(cls):
m = MetaData()
Table(
- cls.version_table_name, m,
- Column('x', Integer), schema=cls.version_table_schema)
+ cls.version_table_name,
+ m,
+ Column("x", Integer),
+ schema=cls.version_table_schema,
+ )
return m
@classmethod
@@ -919,7 +985,10 @@ class AutogenVersionTableTest(AutogenTest, TestBase):
def test_version_table_in_target(self):
Table(
self.version_table_name,
- self.m2, Column('x', Integer), schema=self.version_table_schema)
+ self.m2,
+ Column("x", Integer),
+ schema=self.version_table_schema,
+ )
ctx = self.autogen_context
uo = ops.UpgradeOps(ops=[])
@@ -928,29 +997,30 @@ class AutogenVersionTableTest(AutogenTest, TestBase):
class AutogenCustomVersionTableSchemaTest(AutogenVersionTableTest):
- __only_on__ = 'postgresql'
+ __only_on__ = "postgresql"
__backend__ = True
- version_table_schema = 'test_schema'
- configure_opts = {'version_table_schema': 'test_schema'}
+ version_table_schema = "test_schema"
+ configure_opts = {"version_table_schema": "test_schema"}
class AutogenCustomVersionTableTest(AutogenVersionTableTest):
- version_table_name = 'my_version_table'
- configure_opts = {'version_table': 'my_version_table'}
+ version_table_name = "my_version_table"
+ configure_opts = {"version_table": "my_version_table"}
class AutogenCustomVersionTableAndSchemaTest(AutogenVersionTableTest):
- __only_on__ = 'postgresql'
+ __only_on__ = "postgresql"
__backend__ = True
- version_table_name = 'my_version_table'
- version_table_schema = 'test_schema'
+ version_table_name = "my_version_table"
+ version_table_schema = "test_schema"
configure_opts = {
- 'version_table': 'my_version_table',
- 'version_table_schema': 'test_schema'}
+ "version_table": "my_version_table",
+ "version_table_schema": "test_schema",
+ }
class AutogenerateDiffOrderTest(AutogenTest, TestBase):
- __only_on__ = 'sqlite'
+ __only_on__ = "sqlite"
@classmethod
def _get_db_schema(cls):
@@ -959,13 +1029,11 @@ class AutogenerateDiffOrderTest(AutogenTest, TestBase):
@classmethod
def _get_model_schema(cls):
m = MetaData()
- Table('parent', m,
- Column('id', Integer, primary_key=True)
- )
+ Table("parent", m, Column("id", Integer, primary_key=True))
- Table('child', m,
- Column('parent_id', Integer, ForeignKey('parent.id')),
- )
+ Table(
+ "child", m, Column("parent_id", Integer, ForeignKey("parent.id"))
+ )
return m
@@ -980,32 +1048,29 @@ class AutogenerateDiffOrderTest(AutogenTest, TestBase):
autogenerate._produce_net_changes(ctx, uo)
diffs = uo.as_diffs()
- eq_(diffs[0][0], 'add_table')
+ eq_(diffs[0][0], "add_table")
eq_(diffs[0][1].name, "parent")
- eq_(diffs[1][0], 'add_table')
+ eq_(diffs[1][0], "add_table")
eq_(diffs[1][1].name, "child")
class CompareMetadataTest(ModelOne, AutogenTest, TestBase):
- __only_on__ = 'sqlite'
+ __only_on__ = "sqlite"
def test_compare_metadata(self):
metadata = self.m2
diffs = autogenerate.compare_metadata(self.context, metadata)
- eq_(
- diffs[0],
- ('add_table', metadata.tables['item'])
- )
+ eq_(diffs[0], ("add_table", metadata.tables["item"]))
- eq_(diffs[1][0], 'remove_table')
+ eq_(diffs[1][0], "remove_table")
eq_(diffs[1][1].name, "extra")
eq_(diffs[2][0], "add_column")
eq_(diffs[2][1], None)
eq_(diffs[2][2], "address")
- eq_(diffs[2][3], metadata.tables['address'].c.street)
+ eq_(diffs[2][3], metadata.tables["address"].c.street)
eq_(diffs[3][0], "add_constraint")
eq_(diffs[3][1].name, "uq_email")
@@ -1013,7 +1078,7 @@ class CompareMetadataTest(ModelOne, AutogenTest, TestBase):
eq_(diffs[4][0], "add_column")
eq_(diffs[4][1], None)
eq_(diffs[4][2], "order")
- eq_(diffs[4][3], metadata.tables['order'].c.user_id)
+ eq_(diffs[4][3], metadata.tables["order"].c.user_id)
eq_(diffs[5][0][0], "modify_type")
eq_(diffs[5][0][1], None)
@@ -1023,9 +1088,7 @@ class CompareMetadataTest(ModelOne, AutogenTest, TestBase):
eq_(repr(diffs[5][0][6]), "Numeric(precision=10, scale=2)")
self._assert_fk_diff(
- diffs[6], "add_fk",
- "order", ["user_id"],
- "user", ["id"]
+ diffs[6], "add_fk", "order", ["user_id"], "user", ["id"]
)
eq_(diffs[7][0][0], "modify_default")
@@ -1034,15 +1097,15 @@ class CompareMetadataTest(ModelOne, AutogenTest, TestBase):
eq_(diffs[7][0][3], "a1")
eq_(diffs[7][0][6].arg, "x")
- eq_(diffs[8][0][0], 'modify_nullable')
+ eq_(diffs[8][0][0], "modify_nullable")
eq_(diffs[8][0][5], True)
eq_(diffs[8][0][6], False)
- eq_(diffs[9][0], 'remove_index')
- eq_(diffs[9][1].name, 'pw_idx')
+ eq_(diffs[9][0], "remove_index")
+ eq_(diffs[9][1].name, "pw_idx")
- eq_(diffs[10][0], 'remove_column')
- eq_(diffs[10][3].name, 'pw')
+ eq_(diffs[10][0], "remove_column")
+ eq_(diffs[10][3].name, "pw")
def test_compare_metadata_include_object(self):
metadata = self.m2
@@ -1058,46 +1121,46 @@ class CompareMetadataTest(ModelOne, AutogenTest, TestBase):
context = MigrationContext.configure(
connection=self.bind.connect(),
opts={
- 'compare_type': True,
- 'compare_server_default': True,
- 'include_object': include_object,
- }
+ "compare_type": True,
+ "compare_server_default": True,
+ "include_object": include_object,
+ },
)
diffs = autogenerate.compare_metadata(context, metadata)
- eq_(diffs[0][0], 'remove_table')
+ eq_(diffs[0][0], "remove_table")
eq_(diffs[0][1].name, "extra")
eq_(diffs[1][0], "add_column")
eq_(diffs[1][1], None)
eq_(diffs[1][2], "order")
- eq_(diffs[1][3], metadata.tables['order'].c.user_id)
+ eq_(diffs[1][3], metadata.tables["order"].c.user_id)
def test_compare_metadata_include_symbol(self):
metadata = self.m2
def include_symbol(table_name, schema_name):
- return table_name in ('extra', 'order')
+ return table_name in ("extra", "order")
context = MigrationContext.configure(
connection=self.bind.connect(),
opts={
- 'compare_type': True,
- 'compare_server_default': True,
- 'include_symbol': include_symbol,
- }
+ "compare_type": True,
+ "compare_server_default": True,
+ "include_symbol": include_symbol,
+ },
)
diffs = autogenerate.compare_metadata(context, metadata)
- eq_(diffs[0][0], 'remove_table')
+ eq_(diffs[0][0], "remove_table")
eq_(diffs[0][1].name, "extra")
eq_(diffs[1][0], "add_column")
eq_(diffs[1][1], None)
eq_(diffs[1][2], "order")
- eq_(diffs[1][3], metadata.tables['order'].c.user_id)
+ eq_(diffs[1][3], metadata.tables["order"].c.user_id)
eq_(diffs[2][0][0], "modify_type")
eq_(diffs[2][0][1], None)
@@ -1106,15 +1169,14 @@ class CompareMetadataTest(ModelOne, AutogenTest, TestBase):
eq_(repr(diffs[2][0][5]), "NUMERIC(precision=8, scale=2)")
eq_(repr(diffs[2][0][6]), "Numeric(precision=10, scale=2)")
- eq_(diffs[2][1][0], 'modify_nullable')
- eq_(diffs[2][1][2], 'order')
+ eq_(diffs[2][1][0], "modify_nullable")
+ eq_(diffs[2][1][2], "order")
eq_(diffs[2][1][5], False)
eq_(diffs[2][1][6], True)
def test_compare_metadata_as_sql(self):
context = MigrationContext.configure(
- connection=self.bind.connect(),
- opts={'as_sql': True}
+ connection=self.bind.connect(), opts={"as_sql": True}
)
metadata = self.m2
@@ -1122,12 +1184,14 @@ class CompareMetadataTest(ModelOne, AutogenTest, TestBase):
CommandError,
"autogenerate can't use as_sql=True as it prevents "
"querying the database for schema information",
- autogenerate.compare_metadata, context, metadata
+ autogenerate.compare_metadata,
+ context,
+ metadata,
)
class PGCompareMetaData(ModelOne, AutogenTest, TestBase):
- __only_on__ = 'postgresql'
+ __only_on__ = "postgresql"
__backend__ = True
schema = "test_schema"
@@ -1135,26 +1199,20 @@ class PGCompareMetaData(ModelOne, AutogenTest, TestBase):
metadata = self.m2
context = MigrationContext.configure(
- connection=self.bind.connect(),
- opts={
- "include_schemas": True
- }
+ connection=self.bind.connect(), opts={"include_schemas": True}
)
diffs = autogenerate.compare_metadata(context, metadata)
- eq_(
- diffs[0],
- ('add_table', metadata.tables['test_schema.item'])
- )
+ eq_(diffs[0], ("add_table", metadata.tables["test_schema.item"]))
- eq_(diffs[1][0], 'remove_table')
+ eq_(diffs[1][0], "remove_table")
eq_(diffs[1][1].name, "extra")
eq_(diffs[2][0], "add_column")
eq_(diffs[2][1], "test_schema")
eq_(diffs[2][2], "address")
- eq_(diffs[2][3], metadata.tables['test_schema.address'].c.street)
+ eq_(diffs[2][3], metadata.tables["test_schema.address"].c.street)
eq_(diffs[3][0], "add_constraint")
eq_(diffs[3][1].name, "uq_email")
@@ -1162,27 +1220,25 @@ class PGCompareMetaData(ModelOne, AutogenTest, TestBase):
eq_(diffs[4][0], "add_column")
eq_(diffs[4][1], "test_schema")
eq_(diffs[4][2], "order")
- eq_(diffs[4][3], metadata.tables['test_schema.order'].c.user_id)
+ eq_(diffs[4][3], metadata.tables["test_schema.order"].c.user_id)
- eq_(diffs[5][0][0], 'modify_nullable')
+ eq_(diffs[5][0][0], "modify_nullable")
eq_(diffs[5][0][5], False)
eq_(diffs[5][0][6], True)
+
class OrigObjectTest(TestBase):
def setUp(self):
self.metadata = m = MetaData()
t = Table(
- 't', m,
- Column('id', Integer(), primary_key=True),
- Column('x', Integer())
- )
- self.ix = Index('ix1', t.c.id)
- fk = ForeignKeyConstraint(['t_id'], ['t.id'])
- q = Table(
- 'q', m,
- Column('t_id', Integer()),
- fk
+ "t",
+ m,
+ Column("id", Integer(), primary_key=True),
+ Column("x", Integer()),
)
+ self.ix = Index("ix1", t.c.id)
+ fk = ForeignKeyConstraint(["t_id"], ["t.id"])
+ q = Table("q", m, Column("t_id", Integer()), fk)
self.table = t
self.fk = fk
self.ck = CheckConstraint(t.c.x > 5)
@@ -1232,10 +1288,10 @@ class OrigObjectTest(TestBase):
is_not_(None, op.to_constraint().table)
def test_add_pk_no_orig(self):
- op = ops.CreatePrimaryKeyOp('pk1', 't', ['x', 'y'])
+ op = ops.CreatePrimaryKeyOp("pk1", "t", ["x", "y"])
pk = op.to_constraint()
- eq_(pk.name, 'pk1')
- eq_(pk.table.name, 't')
+ eq_(pk.name, "pk1")
+ eq_(pk.table.name, "t")
def test_add_pk(self):
pk = self.pk
@@ -1254,7 +1310,7 @@ class OrigObjectTest(TestBase):
def test_drop_column(self):
t = self.table
- op = ops.DropColumnOp.from_column_and_tablename(None, 't', t.c.x)
+ op = ops.DropColumnOp.from_column_and_tablename(None, "t", t.c.x)
is_(op.to_column(), t.c.x)
is_(op.reverse().to_column(), t.c.x)
is_not_(None, op.to_column().table)
@@ -1262,7 +1318,7 @@ class OrigObjectTest(TestBase):
def test_add_column(self):
t = self.table
- op = ops.AddColumnOp.from_column_and_tablename(None, 't', t.c.x)
+ op = ops.AddColumnOp.from_column_and_tablename(None, "t", t.c.x)
is_(op.to_column(), t.c.x)
is_(op.reverse().to_column(), t.c.x)
is_not_(None, op.to_column().table)
@@ -1304,25 +1360,33 @@ class MultipleMetaDataTest(AutogenFixtureTest, TestBase):
m2b = MetaData()
m2c = MetaData()
- Table('a', m1a, Column('id', Integer, primary_key=True))
- Table('b1', m1b, Column('id', Integer, primary_key=True))
- Table('b2', m1b, Column('id', Integer, primary_key=True))
- Table('c1', m1c, Column('id', Integer, primary_key=True),
- Column('x', Integer))
+ Table("a", m1a, Column("id", Integer, primary_key=True))
+ Table("b1", m1b, Column("id", Integer, primary_key=True))
+ Table("b2", m1b, Column("id", Integer, primary_key=True))
+ Table(
+ "c1",
+ m1c,
+ Column("id", Integer, primary_key=True),
+ Column("x", Integer),
+ )
- a = Table('a', m2a, Column('id', Integer, primary_key=True),
- Column('q', Integer))
- Table('b1', m2b, Column('id', Integer, primary_key=True))
- Table('c1', m2c, Column('id', Integer, primary_key=True))
- c2 = Table('c2', m2c, Column('id', Integer, primary_key=True))
+ a = Table(
+ "a",
+ m2a,
+ Column("id", Integer, primary_key=True),
+ Column("q", Integer),
+ )
+ Table("b1", m2b, Column("id", Integer, primary_key=True))
+ Table("c1", m2c, Column("id", Integer, primary_key=True))
+ c2 = Table("c2", m2c, Column("id", Integer, primary_key=True))
diffs = self._fixture([m1a, m1b, m1c], [m2a, m2b, m2c])
- eq_(diffs[0], ('add_table', c2))
- eq_(diffs[1][0], 'remove_table')
- eq_(diffs[1][1].name, 'b2')
- eq_(diffs[2], ('add_column', None, 'a', a.c.q))
- eq_(diffs[3][0:3], ('remove_column', None, 'c1'))
- eq_(diffs[3][3].name, 'x')
+ eq_(diffs[0], ("add_table", c2))
+ eq_(diffs[1][0], "remove_table")
+ eq_(diffs[1][1].name, "b2")
+ eq_(diffs[2], ("add_column", None, "a", a.c.q))
+ eq_(diffs[3][0:3], ("remove_column", None, "c1"))
+ eq_(diffs[3][3].name, "x")
def test_empty_list(self):
# because they're going to do it....
@@ -1339,18 +1403,19 @@ class MultipleMetaDataTest(AutogenFixtureTest, TestBase):
m2a = MetaData()
m2b = MetaData()
- Table('a', m1a, Column('id', Integer, primary_key=True))
- Table('b', m1b, Column('id', Integer, primary_key=True))
+ Table("a", m1a, Column("id", Integer, primary_key=True))
+ Table("b", m1b, Column("id", Integer, primary_key=True))
- Table('a', m2a, Column('id', Integer, primary_key=True))
- b = Table('b', m2b, Column('id', Integer, primary_key=True),
- Column('q', Integer))
+ Table("a", m2a, Column("id", Integer, primary_key=True))
+ b = Table(
+ "b",
+ m2b,
+ Column("id", Integer, primary_key=True),
+ Column("q", Integer),
+ )
diffs = self._fixture((m1a, m1b), (m2a, m2b))
- eq_(
- diffs,
- [('add_column', None, 'b', b.c.q)]
- )
+ eq_(diffs, [("add_column", None, "b", b.c.q)])
def test_raise_on_dupe(self):
m1a = MetaData()
@@ -1359,116 +1424,123 @@ class MultipleMetaDataTest(AutogenFixtureTest, TestBase):
m2a = MetaData()
m2b = MetaData()
- Table('a', m1a, Column('id', Integer, primary_key=True))
- Table('b1', m1b, Column('id', Integer, primary_key=True))
- Table('b2', m1b, Column('id', Integer, primary_key=True))
- Table('b3', m1b, Column('id', Integer, primary_key=True))
+ Table("a", m1a, Column("id", Integer, primary_key=True))
+ Table("b1", m1b, Column("id", Integer, primary_key=True))
+ Table("b2", m1b, Column("id", Integer, primary_key=True))
+ Table("b3", m1b, Column("id", Integer, primary_key=True))
- Table('a', m2a, Column('id', Integer, primary_key=True))
- Table('a', m2b, Column('id', Integer, primary_key=True))
- Table('b1', m2b, Column('id', Integer, primary_key=True))
- Table('b2', m2a, Column('id', Integer, primary_key=True))
- Table('b2', m2b, Column('id', Integer, primary_key=True))
+ Table("a", m2a, Column("id", Integer, primary_key=True))
+ Table("a", m2b, Column("id", Integer, primary_key=True))
+ Table("b1", m2b, Column("id", Integer, primary_key=True))
+ Table("b2", m2a, Column("id", Integer, primary_key=True))
+ Table("b2", m2b, Column("id", Integer, primary_key=True))
assert_raises_message(
ValueError,
'Duplicate table keys across multiple MetaData objects: "a", "b2"',
self._fixture,
- [m1a, m1b], [m2a, m2b]
+ [m1a, m1b],
+ [m2a, m2b],
)
class AutoincrementTest(AutogenFixtureTest, TestBase):
__backend__ = True
- __requires__ = 'integer_subtype_comparisons',
+ __requires__ = ("integer_subtype_comparisons",)
def test_alter_column_autoincrement_none(self):
m1 = MetaData()
m2 = MetaData()
- Table('a', m1, Column('x', Integer, nullable=False))
- Table('a', m2, Column('x', Integer, nullable=True))
+ Table("a", m1, Column("x", Integer, nullable=False))
+ Table("a", m2, Column("x", Integer, nullable=True))
ops = self._fixture(m1, m2, return_ops=True)
- assert 'autoincrement' not in ops.ops[0].ops[0].kw
+ assert "autoincrement" not in ops.ops[0].ops[0].kw
def test_alter_column_autoincrement_pk_false(self):
m1 = MetaData()
m2 = MetaData()
Table(
- 'a', m1,
- Column('x', Integer, primary_key=True, autoincrement=False))
+ "a",
+ m1,
+ Column("x", Integer, primary_key=True, autoincrement=False),
+ )
Table(
- 'a', m2,
- Column('x', BigInteger, primary_key=True, autoincrement=False))
+ "a",
+ m2,
+ Column("x", BigInteger, primary_key=True, autoincrement=False),
+ )
ops = self._fixture(m1, m2, return_ops=True)
- is_(ops.ops[0].ops[0].kw['autoincrement'], False)
+ is_(ops.ops[0].ops[0].kw["autoincrement"], False)
def test_alter_column_autoincrement_pk_implicit_true(self):
m1 = MetaData()
m2 = MetaData()
- Table(
- 'a', m1,
- Column('x', Integer, primary_key=True))
- Table(
- 'a', m2,
- Column('x', BigInteger, primary_key=True))
+ Table("a", m1, Column("x", Integer, primary_key=True))
+ Table("a", m2, Column("x", BigInteger, primary_key=True))
ops = self._fixture(m1, m2, return_ops=True)
- is_(ops.ops[0].ops[0].kw['autoincrement'], True)
+ is_(ops.ops[0].ops[0].kw["autoincrement"], True)
def test_alter_column_autoincrement_pk_explicit_true(self):
m1 = MetaData()
m2 = MetaData()
Table(
- 'a', m1,
- Column('x', Integer, primary_key=True, autoincrement=True))
+ "a", m1, Column("x", Integer, primary_key=True, autoincrement=True)
+ )
Table(
- 'a', m2,
- Column('x', BigInteger, primary_key=True, autoincrement=True))
+ "a",
+ m2,
+ Column("x", BigInteger, primary_key=True, autoincrement=True),
+ )
ops = self._fixture(m1, m2, return_ops=True)
- is_(ops.ops[0].ops[0].kw['autoincrement'], True)
+ is_(ops.ops[0].ops[0].kw["autoincrement"], True)
def test_alter_column_autoincrement_nonpk_false(self):
m1 = MetaData()
m2 = MetaData()
Table(
- 'a', m1,
- Column('id', Integer, primary_key=True),
- Column('x', Integer, autoincrement=False)
+ "a",
+ m1,
+ Column("id", Integer, primary_key=True),
+ Column("x", Integer, autoincrement=False),
)
Table(
- 'a', m2,
- Column('id', Integer, primary_key=True),
- Column('x', BigInteger, autoincrement=False)
+ "a",
+ m2,
+ Column("id", Integer, primary_key=True),
+ Column("x", BigInteger, autoincrement=False),
)
ops = self._fixture(m1, m2, return_ops=True)
- is_(ops.ops[0].ops[0].kw['autoincrement'], False)
+ is_(ops.ops[0].ops[0].kw["autoincrement"], False)
def test_alter_column_autoincrement_nonpk_implicit_false(self):
m1 = MetaData()
m2 = MetaData()
Table(
- 'a', m1,
- Column('id', Integer, primary_key=True),
- Column('x', Integer)
+ "a",
+ m1,
+ Column("id", Integer, primary_key=True),
+ Column("x", Integer),
)
Table(
- 'a', m2,
- Column('id', Integer, primary_key=True),
- Column('x', BigInteger)
+ "a",
+ m2,
+ Column("id", Integer, primary_key=True),
+ Column("x", BigInteger),
)
ops = self._fixture(m1, m2, return_ops=True)
- assert 'autoincrement' not in ops.ops[0].ops[0].kw
+ assert "autoincrement" not in ops.ops[0].ops[0].kw
@config.requirements.fail_before_sqla_110
def test_alter_column_autoincrement_nonpk_explicit_true(self):
@@ -1476,54 +1548,60 @@ class AutoincrementTest(AutogenFixtureTest, TestBase):
m2 = MetaData()
Table(
- 'a', m1,
- Column('id', Integer, primary_key=True),
- Column('x', Integer, autoincrement=True)
+ "a",
+ m1,
+ Column("id", Integer, primary_key=True),
+ Column("x", Integer, autoincrement=True),
)
Table(
- 'a', m2,
- Column('id', Integer, primary_key=True),
- Column('x', BigInteger, autoincrement=True)
+ "a",
+ m2,
+ Column("id", Integer, primary_key=True),
+ Column("x", BigInteger, autoincrement=True),
)
ops = self._fixture(m1, m2, return_ops=True)
- is_(ops.ops[0].ops[0].kw['autoincrement'], True)
+ is_(ops.ops[0].ops[0].kw["autoincrement"], True)
def test_alter_column_autoincrement_compositepk_false(self):
m1 = MetaData()
m2 = MetaData()
Table(
- 'a', m1,
- Column('id', Integer, primary_key=True),
- Column('x', Integer, primary_key=True, autoincrement=False)
+ "a",
+ m1,
+ Column("id", Integer, primary_key=True),
+ Column("x", Integer, primary_key=True, autoincrement=False),
)
Table(
- 'a', m2,
- Column('id', Integer, primary_key=True),
- Column('x', BigInteger, primary_key=True, autoincrement=False)
+ "a",
+ m2,
+ Column("id", Integer, primary_key=True),
+ Column("x", BigInteger, primary_key=True, autoincrement=False),
)
ops = self._fixture(m1, m2, return_ops=True)
- is_(ops.ops[0].ops[0].kw['autoincrement'], False)
+ is_(ops.ops[0].ops[0].kw["autoincrement"], False)
def test_alter_column_autoincrement_compositepk_implicit_false(self):
m1 = MetaData()
m2 = MetaData()
Table(
- 'a', m1,
- Column('id', Integer, primary_key=True),
- Column('x', Integer, primary_key=True)
+ "a",
+ m1,
+ Column("id", Integer, primary_key=True),
+ Column("x", Integer, primary_key=True),
)
Table(
- 'a', m2,
- Column('id', Integer, primary_key=True),
- Column('x', BigInteger, primary_key=True)
+ "a",
+ m2,
+ Column("id", Integer, primary_key=True),
+ Column("x", BigInteger, primary_key=True),
)
ops = self._fixture(m1, m2, return_ops=True)
- assert 'autoincrement' not in ops.ops[0].ops[0].kw
+ assert "autoincrement" not in ops.ops[0].ops[0].kw
@config.requirements.autoincrement_on_composite_pk
def test_alter_column_autoincrement_compositepk_explicit_true(self):
@@ -1531,20 +1609,22 @@ class AutoincrementTest(AutogenFixtureTest, TestBase):
m2 = MetaData()
Table(
- 'a', m1,
- Column('id', Integer, primary_key=True, autoincrement=False),
- Column('x', Integer, primary_key=True, autoincrement=True),
+ "a",
+ m1,
+ Column("id", Integer, primary_key=True, autoincrement=False),
+ Column("x", Integer, primary_key=True, autoincrement=True),
# on SQLA 1.0 and earlier, this being present
# trips the "add KEY for the primary key" so that the
# AUTO_INCREMENT keyword is accepted by MySQL. SQLA 1.1 and
# greater the columns are just reorganized.
- mysql_engine='InnoDB'
+ mysql_engine="InnoDB",
)
Table(
- 'a', m2,
- Column('id', Integer, primary_key=True, autoincrement=False),
- Column('x', BigInteger, primary_key=True, autoincrement=True)
+ "a",
+ m2,
+ Column("id", Integer, primary_key=True, autoincrement=False),
+ Column("x", BigInteger, primary_key=True, autoincrement=True),
)
ops = self._fixture(m1, m2, return_ops=True)
- is_(ops.ops[0].ops[0].kw['autoincrement'], True)
+ is_(ops.ops[0].ops[0].kw["autoincrement"], True)
diff --git a/tests/test_autogen_fks.py b/tests/test_autogen_fks.py
index 3dd66ae..66c5ac4 100644
--- a/tests/test_autogen_fks.py
+++ b/tests/test_autogen_fks.py
@@ -1,8 +1,14 @@
import sys
from alembic.testing import TestBase, config, mock
-from sqlalchemy import MetaData, Column, Table, Integer, String, \
- ForeignKeyConstraint
+from sqlalchemy import (
+ MetaData,
+ Column,
+ Table,
+ Integer,
+ String,
+ ForeignKeyConstraint,
+)
from alembic.testing import eq_
py3k = sys.version_info.major >= 3
@@ -17,105 +23,141 @@ class AutogenerateForeignKeysTest(AutogenFixtureTest, TestBase):
m1 = MetaData()
m2 = MetaData()
- Table('some_table', m1,
- Column('test', String(10), primary_key=True),
- mysql_engine='InnoDB')
-
- Table('user', m1,
- Column('id', Integer, primary_key=True),
- Column('name', String(50), nullable=False),
- Column('a1', String(10), server_default="x"),
- Column('test2', String(10)),
- ForeignKeyConstraint(['test2'], ['some_table.test']),
- mysql_engine='InnoDB')
-
- Table('some_table', m2,
- Column('test', String(10), primary_key=True),
- mysql_engine='InnoDB')
-
- Table('user', m2,
- Column('id', Integer, primary_key=True),
- Column('name', String(50), nullable=False),
- Column('a1', String(10), server_default="x"),
- Column('test2', String(10)),
- mysql_engine='InnoDB'
- )
+ Table(
+ "some_table",
+ m1,
+ Column("test", String(10), primary_key=True),
+ mysql_engine="InnoDB",
+ )
+
+ Table(
+ "user",
+ m1,
+ Column("id", Integer, primary_key=True),
+ Column("name", String(50), nullable=False),
+ Column("a1", String(10), server_default="x"),
+ Column("test2", String(10)),
+ ForeignKeyConstraint(["test2"], ["some_table.test"]),
+ mysql_engine="InnoDB",
+ )
+
+ Table(
+ "some_table",
+ m2,
+ Column("test", String(10), primary_key=True),
+ mysql_engine="InnoDB",
+ )
+
+ Table(
+ "user",
+ m2,
+ Column("id", Integer, primary_key=True),
+ Column("name", String(50), nullable=False),
+ Column("a1", String(10), server_default="x"),
+ Column("test2", String(10)),
+ mysql_engine="InnoDB",
+ )
diffs = self._fixture(m1, m2)
self._assert_fk_diff(
- diffs[0], "remove_fk",
- "user", ['test2'],
- 'some_table', ['test'],
- conditional_name="servergenerated"
+ diffs[0],
+ "remove_fk",
+ "user",
+ ["test2"],
+ "some_table",
+ ["test"],
+ conditional_name="servergenerated",
)
def test_add_fk(self):
m1 = MetaData()
m2 = MetaData()
- Table('some_table', m1,
- Column('id', Integer, primary_key=True),
- Column('test', String(10)),
- mysql_engine='InnoDB')
-
- Table('user', m1,
- Column('id', Integer, primary_key=True),
- Column('name', String(50), nullable=False),
- Column('a1', String(10), server_default="x"),
- Column('test2', String(10)),
- mysql_engine='InnoDB')
-
- Table('some_table', m2,
- Column('id', Integer, primary_key=True),
- Column('test', String(10)),
- mysql_engine='InnoDB')
-
- Table('user', m2,
- Column('id', Integer, primary_key=True),
- Column('name', String(50), nullable=False),
- Column('a1', String(10), server_default="x"),
- Column('test2', String(10)),
- ForeignKeyConstraint(['test2'], ['some_table.test']),
- mysql_engine='InnoDB')
+ Table(
+ "some_table",
+ m1,
+ Column("id", Integer, primary_key=True),
+ Column("test", String(10)),
+ mysql_engine="InnoDB",
+ )
+
+ Table(
+ "user",
+ m1,
+ Column("id", Integer, primary_key=True),
+ Column("name", String(50), nullable=False),
+ Column("a1", String(10), server_default="x"),
+ Column("test2", String(10)),
+ mysql_engine="InnoDB",
+ )
+
+ Table(
+ "some_table",
+ m2,
+ Column("id", Integer, primary_key=True),
+ Column("test", String(10)),
+ mysql_engine="InnoDB",
+ )
+
+ Table(
+ "user",
+ m2,
+ Column("id", Integer, primary_key=True),
+ Column("name", String(50), nullable=False),
+ Column("a1", String(10), server_default="x"),
+ Column("test2", String(10)),
+ ForeignKeyConstraint(["test2"], ["some_table.test"]),
+ mysql_engine="InnoDB",
+ )
diffs = self._fixture(m1, m2)
self._assert_fk_diff(
- diffs[0], "add_fk",
- "user", ["test2"],
- "some_table", ["test"]
+ diffs[0], "add_fk", "user", ["test2"], "some_table", ["test"]
)
def test_no_change(self):
m1 = MetaData()
m2 = MetaData()
- Table('some_table', m1,
- Column('id', Integer, primary_key=True),
- Column('test', String(10)),
- mysql_engine='InnoDB')
-
- Table('user', m1,
- Column('id', Integer, primary_key=True),
- Column('name', String(50), nullable=False),
- Column('a1', String(10), server_default="x"),
- Column('test2', Integer),
- ForeignKeyConstraint(['test2'], ['some_table.id']),
- mysql_engine='InnoDB')
-
- Table('some_table', m2,
- Column('id', Integer, primary_key=True),
- Column('test', String(10)),
- mysql_engine='InnoDB')
-
- Table('user', m2,
- Column('id', Integer, primary_key=True),
- Column('name', String(50), nullable=False),
- Column('a1', String(10), server_default="x"),
- Column('test2', Integer),
- ForeignKeyConstraint(['test2'], ['some_table.id']),
- mysql_engine='InnoDB')
+ Table(
+ "some_table",
+ m1,
+ Column("id", Integer, primary_key=True),
+ Column("test", String(10)),
+ mysql_engine="InnoDB",
+ )
+
+ Table(
+ "user",
+ m1,
+ Column("id", Integer, primary_key=True),
+ Column("name", String(50), nullable=False),
+ Column("a1", String(10), server_default="x"),
+ Column("test2", Integer),
+ ForeignKeyConstraint(["test2"], ["some_table.id"]),
+ mysql_engine="InnoDB",
+ )
+
+ Table(
+ "some_table",
+ m2,
+ Column("id", Integer, primary_key=True),
+ Column("test", String(10)),
+ mysql_engine="InnoDB",
+ )
+
+ Table(
+ "user",
+ m2,
+ Column("id", Integer, primary_key=True),
+ Column("name", String(50), nullable=False),
+ Column("a1", String(10), server_default="x"),
+ Column("test2", Integer),
+ ForeignKeyConstraint(["test2"], ["some_table.id"]),
+ mysql_engine="InnoDB",
+ )
diffs = self._fixture(m1, m2)
@@ -125,36 +167,51 @@ class AutogenerateForeignKeysTest(AutogenFixtureTest, TestBase):
m1 = MetaData()
m2 = MetaData()
- Table('some_table', m1,
- Column('id_1', String(10), primary_key=True),
- Column('id_2', String(10), primary_key=True),
- mysql_engine='InnoDB')
-
- Table('user', m1,
- Column('id', Integer, primary_key=True),
- Column('name', String(50), nullable=False),
- Column('a1', String(10), server_default="x"),
- Column('other_id_1', String(10)),
- Column('other_id_2', String(10)),
- ForeignKeyConstraint(['other_id_1', 'other_id_2'],
- ['some_table.id_1', 'some_table.id_2']),
- mysql_engine='InnoDB')
-
- Table('some_table', m2,
- Column('id_1', String(10), primary_key=True),
- Column('id_2', String(10), primary_key=True),
- mysql_engine='InnoDB'
- )
-
- Table('user', m2,
- Column('id', Integer, primary_key=True),
- Column('name', String(50), nullable=False),
- Column('a1', String(10), server_default="x"),
- Column('other_id_1', String(10)),
- Column('other_id_2', String(10)),
- ForeignKeyConstraint(['other_id_1', 'other_id_2'],
- ['some_table.id_1', 'some_table.id_2']),
- mysql_engine='InnoDB')
+ Table(
+ "some_table",
+ m1,
+ Column("id_1", String(10), primary_key=True),
+ Column("id_2", String(10), primary_key=True),
+ mysql_engine="InnoDB",
+ )
+
+ Table(
+ "user",
+ m1,
+ Column("id", Integer, primary_key=True),
+ Column("name", String(50), nullable=False),
+ Column("a1", String(10), server_default="x"),
+ Column("other_id_1", String(10)),
+ Column("other_id_2", String(10)),
+ ForeignKeyConstraint(
+ ["other_id_1", "other_id_2"],
+ ["some_table.id_1", "some_table.id_2"],
+ ),
+ mysql_engine="InnoDB",
+ )
+
+ Table(
+ "some_table",
+ m2,
+ Column("id_1", String(10), primary_key=True),
+ Column("id_2", String(10), primary_key=True),
+ mysql_engine="InnoDB",
+ )
+
+ Table(
+ "user",
+ m2,
+ Column("id", Integer, primary_key=True),
+ Column("name", String(50), nullable=False),
+ Column("a1", String(10), server_default="x"),
+ Column("other_id_1", String(10)),
+ Column("other_id_2", String(10)),
+ ForeignKeyConstraint(
+ ["other_id_1", "other_id_2"],
+ ["some_table.id_1", "some_table.id_2"],
+ ),
+ mysql_engine="InnoDB",
+ )
diffs = self._fixture(m1, m2)
@@ -164,42 +221,59 @@ class AutogenerateForeignKeysTest(AutogenFixtureTest, TestBase):
m1 = MetaData()
m2 = MetaData()
- Table('some_table', m1,
- Column('id_1', String(10), primary_key=True),
- Column('id_2', String(10), primary_key=True),
- mysql_engine='InnoDB')
-
- Table('user', m1,
- Column('id', Integer, primary_key=True),
- Column('name', String(50), nullable=False),
- Column('a1', String(10), server_default="x"),
- Column('other_id_1', String(10)),
- Column('other_id_2', String(10)),
- mysql_engine='InnoDB')
-
- Table('some_table', m2,
- Column('id_1', String(10), primary_key=True),
- Column('id_2', String(10), primary_key=True),
- mysql_engine='InnoDB')
-
- Table('user', m2,
- Column('id', Integer, primary_key=True),
- Column('name', String(50), nullable=False),
- Column('a1', String(10), server_default="x"),
- Column('other_id_1', String(10)),
- Column('other_id_2', String(10)),
- ForeignKeyConstraint(['other_id_1', 'other_id_2'],
- ['some_table.id_1', 'some_table.id_2'],
- name='fk_test_name'),
- mysql_engine='InnoDB')
+ Table(
+ "some_table",
+ m1,
+ Column("id_1", String(10), primary_key=True),
+ Column("id_2", String(10), primary_key=True),
+ mysql_engine="InnoDB",
+ )
+
+ Table(
+ "user",
+ m1,
+ Column("id", Integer, primary_key=True),
+ Column("name", String(50), nullable=False),
+ Column("a1", String(10), server_default="x"),
+ Column("other_id_1", String(10)),
+ Column("other_id_2", String(10)),
+ mysql_engine="InnoDB",
+ )
+
+ Table(
+ "some_table",
+ m2,
+ Column("id_1", String(10), primary_key=True),
+ Column("id_2", String(10), primary_key=True),
+ mysql_engine="InnoDB",
+ )
+
+ Table(
+ "user",
+ m2,
+ Column("id", Integer, primary_key=True),
+ Column("name", String(50), nullable=False),
+ Column("a1", String(10), server_default="x"),
+ Column("other_id_1", String(10)),
+ Column("other_id_2", String(10)),
+ ForeignKeyConstraint(
+ ["other_id_1", "other_id_2"],
+ ["some_table.id_1", "some_table.id_2"],
+ name="fk_test_name",
+ ),
+ mysql_engine="InnoDB",
+ )
diffs = self._fixture(m1, m2)
self._assert_fk_diff(
- diffs[0], "add_fk",
- "user", ['other_id_1', 'other_id_2'],
- 'some_table', ['id_1', 'id_2'],
- name="fk_test_name"
+ diffs[0],
+ "add_fk",
+ "user",
+ ["other_id_1", "other_id_2"],
+ "some_table",
+ ["id_1", "id_2"],
+ name="fk_test_name",
)
@config.requirements.no_name_normalize
@@ -207,111 +281,160 @@ class AutogenerateForeignKeysTest(AutogenFixtureTest, TestBase):
m1 = MetaData()
m2 = MetaData()
- Table('some_table', m1,
- Column('id_1', String(10), primary_key=True),
- Column('id_2', String(10), primary_key=True),
- mysql_engine='InnoDB')
-
- Table('user', m1,
- Column('id', Integer, primary_key=True),
- Column('name', String(50), nullable=False),
- Column('a1', String(10), server_default="x"),
- Column('other_id_1', String(10)),
- Column('other_id_2', String(10)),
- ForeignKeyConstraint(['other_id_1', 'other_id_2'],
- ['some_table.id_1', 'some_table.id_2'],
- name='fk_test_name'),
- mysql_engine='InnoDB')
-
- Table('some_table', m2,
- Column('id_1', String(10), primary_key=True),
- Column('id_2', String(10), primary_key=True),
- mysql_engine='InnoDB')
-
- Table('user', m2,
- Column('id', Integer, primary_key=True),
- Column('name', String(50), nullable=False),
- Column('a1', String(10), server_default="x"),
- Column('other_id_1', String(10)),
- Column('other_id_2', String(10)),
- mysql_engine='InnoDB')
+ Table(
+ "some_table",
+ m1,
+ Column("id_1", String(10), primary_key=True),
+ Column("id_2", String(10), primary_key=True),
+ mysql_engine="InnoDB",
+ )
+
+ Table(
+ "user",
+ m1,
+ Column("id", Integer, primary_key=True),
+ Column("name", String(50), nullable=False),
+ Column("a1", String(10), server_default="x"),
+ Column("other_id_1", String(10)),
+ Column("other_id_2", String(10)),
+ ForeignKeyConstraint(
+ ["other_id_1", "other_id_2"],
+ ["some_table.id_1", "some_table.id_2"],
+ name="fk_test_name",
+ ),
+ mysql_engine="InnoDB",
+ )
+
+ Table(
+ "some_table",
+ m2,
+ Column("id_1", String(10), primary_key=True),
+ Column("id_2", String(10), primary_key=True),
+ mysql_engine="InnoDB",
+ )
+
+ Table(
+ "user",
+ m2,
+ Column("id", Integer, primary_key=True),
+ Column("name", String(50), nullable=False),
+ Column("a1", String(10), server_default="x"),
+ Column("other_id_1", String(10)),
+ Column("other_id_2", String(10)),
+ mysql_engine="InnoDB",
+ )
diffs = self._fixture(m1, m2)
self._assert_fk_diff(
- diffs[0], "remove_fk",
- "user", ['other_id_1', 'other_id_2'],
- "some_table", ['id_1', 'id_2'],
- conditional_name="fk_test_name"
+ diffs[0],
+ "remove_fk",
+ "user",
+ ["other_id_1", "other_id_2"],
+ "some_table",
+ ["id_1", "id_2"],
+ conditional_name="fk_test_name",
)
def test_add_fk_colkeys(self):
m1 = MetaData()
m2 = MetaData()
- Table('some_table', m1,
- Column('id_1', String(10), primary_key=True),
- Column('id_2', String(10), primary_key=True),
- mysql_engine='InnoDB')
-
- Table('user', m1,
- Column('id', Integer, primary_key=True),
- Column('other_id_1', String(10)),
- Column('other_id_2', String(10)),
- mysql_engine='InnoDB')
-
- Table('some_table', m2,
- Column('id_1', String(10), key='tid1', primary_key=True),
- Column('id_2', String(10), key='tid2', primary_key=True),
- mysql_engine='InnoDB')
-
- Table('user', m2,
- Column('id', Integer, primary_key=True),
- Column('other_id_1', String(10), key='oid1'),
- Column('other_id_2', String(10), key='oid2'),
- ForeignKeyConstraint(['oid1', 'oid2'],
- ['some_table.tid1', 'some_table.tid2'],
- name='fk_test_name'),
- mysql_engine='InnoDB')
+ Table(
+ "some_table",
+ m1,
+ Column("id_1", String(10), primary_key=True),
+ Column("id_2", String(10), primary_key=True),
+ mysql_engine="InnoDB",
+ )
+
+ Table(
+ "user",
+ m1,
+ Column("id", Integer, primary_key=True),
+ Column("other_id_1", String(10)),
+ Column("other_id_2", String(10)),
+ mysql_engine="InnoDB",
+ )
+
+ Table(
+ "some_table",
+ m2,
+ Column("id_1", String(10), key="tid1", primary_key=True),
+ Column("id_2", String(10), key="tid2", primary_key=True),
+ mysql_engine="InnoDB",
+ )
+
+ Table(
+ "user",
+ m2,
+ Column("id", Integer, primary_key=True),
+ Column("other_id_1", String(10), key="oid1"),
+ Column("other_id_2", String(10), key="oid2"),
+ ForeignKeyConstraint(
+ ["oid1", "oid2"],
+ ["some_table.tid1", "some_table.tid2"],
+ name="fk_test_name",
+ ),
+ mysql_engine="InnoDB",
+ )
diffs = self._fixture(m1, m2)
self._assert_fk_diff(
- diffs[0], "add_fk",
- "user", ['other_id_1', 'other_id_2'],
- 'some_table', ['id_1', 'id_2'],
- name="fk_test_name"
+ diffs[0],
+ "add_fk",
+ "user",
+ ["other_id_1", "other_id_2"],
+ "some_table",
+ ["id_1", "id_2"],
+ name="fk_test_name",
)
def test_no_change_colkeys(self):
m1 = MetaData()
m2 = MetaData()
- Table('some_table', m1,
- Column('id_1', String(10), primary_key=True),
- Column('id_2', String(10), primary_key=True),
- mysql_engine='InnoDB')
-
- Table('user', m1,
- Column('id', Integer, primary_key=True),
- Column('other_id_1', String(10)),
- Column('other_id_2', String(10)),
- ForeignKeyConstraint(['other_id_1', 'other_id_2'],
- ['some_table.id_1', 'some_table.id_2']),
- mysql_engine='InnoDB')
-
- Table('some_table', m2,
- Column('id_1', String(10), key='tid1', primary_key=True),
- Column('id_2', String(10), key='tid2', primary_key=True),
- mysql_engine='InnoDB')
-
- Table('user', m2,
- Column('id', Integer, primary_key=True),
- Column('other_id_1', String(10), key='oid1'),
- Column('other_id_2', String(10), key='oid2'),
- ForeignKeyConstraint(['oid1', 'oid2'],
- ['some_table.tid1', 'some_table.tid2']),
- mysql_engine='InnoDB')
+ Table(
+ "some_table",
+ m1,
+ Column("id_1", String(10), primary_key=True),
+ Column("id_2", String(10), primary_key=True),
+ mysql_engine="InnoDB",
+ )
+
+ Table(
+ "user",
+ m1,
+ Column("id", Integer, primary_key=True),
+ Column("other_id_1", String(10)),
+ Column("other_id_2", String(10)),
+ ForeignKeyConstraint(
+ ["other_id_1", "other_id_2"],
+ ["some_table.id_1", "some_table.id_2"],
+ ),
+ mysql_engine="InnoDB",
+ )
+
+ Table(
+ "some_table",
+ m2,
+ Column("id_1", String(10), key="tid1", primary_key=True),
+ Column("id_2", String(10), key="tid2", primary_key=True),
+ mysql_engine="InnoDB",
+ )
+
+ Table(
+ "user",
+ m2,
+ Column("id", Integer, primary_key=True),
+ Column("other_id_1", String(10), key="oid1"),
+ Column("other_id_2", String(10), key="oid2"),
+ ForeignKeyConstraint(
+ ["oid1", "oid2"], ["some_table.tid1", "some_table.tid2"]
+ ),
+ mysql_engine="InnoDB",
+ )
diffs = self._fixture(m1, m2)
@@ -320,7 +443,7 @@ class AutogenerateForeignKeysTest(AutogenFixtureTest, TestBase):
class IncludeHooksTest(AutogenFixtureTest, TestBase):
__backend__ = True
- __requires__ = 'fk_names',
+ __requires__ = ("fk_names",)
@config.requirements.no_name_normalize
def test_remove_connection_fk(self):
@@ -328,11 +451,18 @@ class IncludeHooksTest(AutogenFixtureTest, TestBase):
m2 = MetaData()
ref = Table(
- 'ref', m1, Column('id', Integer, primary_key=True),
- mysql_engine='InnoDB')
+ "ref",
+ m1,
+ Column("id", Integer, primary_key=True),
+ mysql_engine="InnoDB",
+ )
t1 = Table(
- 't', m1, Column('x', Integer), Column('y', Integer),
- mysql_engine='InnoDB')
+ "t",
+ m1,
+ Column("x", Integer),
+ Column("y", Integer),
+ mysql_engine="InnoDB",
+ )
t1.append_constraint(
ForeignKeyConstraint([t1.c.x], [ref.c.id], name="fk1")
)
@@ -341,24 +471,37 @@ class IncludeHooksTest(AutogenFixtureTest, TestBase):
)
ref = Table(
- 'ref', m2, Column('id', Integer, primary_key=True),
- mysql_engine='InnoDB')
+ "ref",
+ m2,
+ Column("id", Integer, primary_key=True),
+ mysql_engine="InnoDB",
+ )
Table(
- 't', m2, Column('x', Integer), Column('y', Integer),
- mysql_engine='InnoDB')
+ "t",
+ m2,
+ Column("x", Integer),
+ Column("y", Integer),
+ mysql_engine="InnoDB",
+ )
def include_object(object_, name, type_, reflected, compare_to):
return not (
- isinstance(object_, ForeignKeyConstraint) and
- type_ == 'foreign_key_constraint'
- and reflected and name == 'fk1')
+ isinstance(object_, ForeignKeyConstraint)
+ and type_ == "foreign_key_constraint"
+ and reflected
+ and name == "fk1"
+ )
diffs = self._fixture(m1, m2, object_filters=include_object)
self._assert_fk_diff(
- diffs[0], "remove_fk",
- 't', ['y'], 'ref', ['id'],
- conditional_name='fk2'
+ diffs[0],
+ "remove_fk",
+ "t",
+ ["y"],
+ "ref",
+ ["id"],
+ conditional_name="fk2",
)
eq_(len(diffs), 1)
@@ -367,18 +510,32 @@ class IncludeHooksTest(AutogenFixtureTest, TestBase):
m2 = MetaData()
Table(
- 'ref', m1,
- Column('id', Integer, primary_key=True), mysql_engine='InnoDB')
+ "ref",
+ m1,
+ Column("id", Integer, primary_key=True),
+ mysql_engine="InnoDB",
+ )
Table(
- 't', m1,
- Column('x', Integer), Column('y', Integer), mysql_engine='InnoDB')
+ "t",
+ m1,
+ Column("x", Integer),
+ Column("y", Integer),
+ mysql_engine="InnoDB",
+ )
ref = Table(
- 'ref', m2, Column('id', Integer, primary_key=True),
- mysql_engine='InnoDB')
+ "ref",
+ m2,
+ Column("id", Integer, primary_key=True),
+ mysql_engine="InnoDB",
+ )
t2 = Table(
- 't', m2, Column('x', Integer), Column('y', Integer),
- mysql_engine='InnoDB')
+ "t",
+ m2,
+ Column("x", Integer),
+ Column("y", Integer),
+ mysql_engine="InnoDB",
+ )
t2.append_constraint(
ForeignKeyConstraint([t2.c.x], [ref.c.id], name="fk1")
)
@@ -388,16 +545,16 @@ class IncludeHooksTest(AutogenFixtureTest, TestBase):
def include_object(object_, name, type_, reflected, compare_to):
return not (
- isinstance(object_, ForeignKeyConstraint) and
- type_ == 'foreign_key_constraint'
- and not reflected and name == 'fk1')
+ isinstance(object_, ForeignKeyConstraint)
+ and type_ == "foreign_key_constraint"
+ and not reflected
+ and name == "fk1"
+ )
diffs = self._fixture(m1, m2, object_filters=include_object)
self._assert_fk_diff(
- diffs[0], "add_fk",
- 't', ['y'], 'ref', ['id'],
- name='fk2'
+ diffs[0], "add_fk", "t", ["y"], "ref", ["id"], name="fk2"
)
eq_(len(diffs), 1)
@@ -407,20 +564,26 @@ class IncludeHooksTest(AutogenFixtureTest, TestBase):
m2 = MetaData()
r1a = Table(
- 'ref_a', m1,
- Column('a', Integer, primary_key=True),
- mysql_engine='InnoDB'
+ "ref_a",
+ m1,
+ Column("a", Integer, primary_key=True),
+ mysql_engine="InnoDB",
)
Table(
- 'ref_b', m1,
- Column('a', Integer, primary_key=True),
- Column('b', Integer, primary_key=True),
- mysql_engine='InnoDB'
+ "ref_b",
+ m1,
+ Column("a", Integer, primary_key=True),
+ Column("b", Integer, primary_key=True),
+ mysql_engine="InnoDB",
)
t1 = Table(
- 't', m1, Column('x', Integer),
- Column('y', Integer), Column('z', Integer),
- mysql_engine='InnoDB')
+ "t",
+ m1,
+ Column("x", Integer),
+ Column("y", Integer),
+ Column("z", Integer),
+ mysql_engine="InnoDB",
+ )
t1.append_constraint(
ForeignKeyConstraint([t1.c.x], [r1a.c.a], name="fk1")
)
@@ -429,82 +592,104 @@ class IncludeHooksTest(AutogenFixtureTest, TestBase):
)
Table(
- 'ref_a', m2,
- Column('a', Integer, primary_key=True),
- mysql_engine='InnoDB'
+ "ref_a",
+ m2,
+ Column("a", Integer, primary_key=True),
+ mysql_engine="InnoDB",
)
r2b = Table(
- 'ref_b', m2,
- Column('a', Integer, primary_key=True),
- Column('b', Integer, primary_key=True),
- mysql_engine='InnoDB'
+ "ref_b",
+ m2,
+ Column("a", Integer, primary_key=True),
+ Column("b", Integer, primary_key=True),
+ mysql_engine="InnoDB",
)
t2 = Table(
- 't', m2, Column('x', Integer),
- Column('y', Integer), Column('z', Integer),
- mysql_engine='InnoDB')
+ "t",
+ m2,
+ Column("x", Integer),
+ Column("y", Integer),
+ Column("z", Integer),
+ mysql_engine="InnoDB",
+ )
t2.append_constraint(
ForeignKeyConstraint(
- [t2.c.x, t2.c.z], [r2b.c.a, r2b.c.b], name="fk1")
+ [t2.c.x, t2.c.z], [r2b.c.a, r2b.c.b], name="fk1"
+ )
)
t2.append_constraint(
ForeignKeyConstraint(
- [t2.c.y, t2.c.z], [r2b.c.a, r2b.c.b], name="fk2")
+ [t2.c.y, t2.c.z], [r2b.c.a, r2b.c.b], name="fk2"
+ )
)
def include_object(object_, name, type_, reflected, compare_to):
return not (
- isinstance(object_, ForeignKeyConstraint) and
- type_ == 'foreign_key_constraint'
- and name == 'fk1'
+ isinstance(object_, ForeignKeyConstraint)
+ and type_ == "foreign_key_constraint"
+ and name == "fk1"
)
diffs = self._fixture(m1, m2, object_filters=include_object)
self._assert_fk_diff(
- diffs[0], "remove_fk",
- 't', ['y'], 'ref_a', ['a'],
- name='fk2'
+ diffs[0], "remove_fk", "t", ["y"], "ref_a", ["a"], name="fk2"
)
self._assert_fk_diff(
- diffs[1], "add_fk",
- 't', ['y', 'z'], 'ref_b', ['a', 'b'],
- name='fk2'
+ diffs[1],
+ "add_fk",
+ "t",
+ ["y", "z"],
+ "ref_b",
+ ["a", "b"],
+ name="fk2",
)
eq_(len(diffs), 2)
class AutogenerateFKOptionsTest(AutogenFixtureTest, TestBase):
__backend__ = True
- __requires__ = ('flexible_fk_cascades', )
+ __requires__ = ("flexible_fk_cascades",)
def _fk_opts_fixture(self, old_opts, new_opts):
m1 = MetaData()
m2 = MetaData()
- Table('some_table', m1,
- Column('id', Integer, primary_key=True),
- Column('test', String(10)),
- mysql_engine='InnoDB')
-
- Table('user', m1,
- Column('id', Integer, primary_key=True),
- Column('name', String(50), nullable=False),
- Column('tid', Integer),
- ForeignKeyConstraint(['tid'], ['some_table.id'], **old_opts),
- mysql_engine='InnoDB')
-
- Table('some_table', m2,
- Column('id', Integer, primary_key=True),
- Column('test', String(10)),
- mysql_engine='InnoDB')
-
- Table('user', m2,
- Column('id', Integer, primary_key=True),
- Column('name', String(50), nullable=False),
- Column('tid', Integer),
- ForeignKeyConstraint(['tid'], ['some_table.id'], **new_opts),
- mysql_engine='InnoDB')
+ Table(
+ "some_table",
+ m1,
+ Column("id", Integer, primary_key=True),
+ Column("test", String(10)),
+ mysql_engine="InnoDB",
+ )
+
+ Table(
+ "user",
+ m1,
+ Column("id", Integer, primary_key=True),
+ Column("name", String(50), nullable=False),
+ Column("tid", Integer),
+ ForeignKeyConstraint(["tid"], ["some_table.id"], **old_opts),
+ mysql_engine="InnoDB",
+ )
+
+ Table(
+ "some_table",
+ m2,
+ Column("id", Integer, primary_key=True),
+ Column("test", String(10)),
+ mysql_engine="InnoDB",
+ )
+
+ Table(
+ "user",
+ m2,
+ Column("id", Integer, primary_key=True),
+ Column("name", String(50), nullable=False),
+ Column("tid", Integer),
+ ForeignKeyConstraint(["tid"], ["some_table.id"], **new_opts),
+ mysql_engine="InnoDB",
+ )
return self._fixture(m1, m2)
@@ -521,47 +706,55 @@ class AutogenerateFKOptionsTest(AutogenFixtureTest, TestBase):
return True
def test_add_ondelete(self):
- diffs = self._fk_opts_fixture(
- {}, {"ondelete": "cascade"}
- )
+ diffs = self._fk_opts_fixture({}, {"ondelete": "cascade"})
if self._expect_opts_supported():
self._assert_fk_diff(
- diffs[0], "remove_fk",
- "user", ["tid"],
- "some_table", ["id"],
+ diffs[0],
+ "remove_fk",
+ "user",
+ ["tid"],
+ "some_table",
+ ["id"],
ondelete=None,
- conditional_name="servergenerated"
+ conditional_name="servergenerated",
)
self._assert_fk_diff(
- diffs[1], "add_fk",
- "user", ["tid"],
- "some_table", ["id"],
- ondelete="cascade"
+ diffs[1],
+ "add_fk",
+ "user",
+ ["tid"],
+ "some_table",
+ ["id"],
+ ondelete="cascade",
)
else:
eq_(diffs, [])
def test_remove_ondelete(self):
- diffs = self._fk_opts_fixture(
- {"ondelete": "CASCADE"}, {}
- )
+ diffs = self._fk_opts_fixture({"ondelete": "CASCADE"}, {})
if self._expect_opts_supported():
self._assert_fk_diff(
- diffs[0], "remove_fk",
- "user", ["tid"],
- "some_table", ["id"],
+ diffs[0],
+ "remove_fk",
+ "user",
+ ["tid"],
+ "some_table",
+ ["id"],
ondelete="CASCADE",
- conditional_name="servergenerated"
+ conditional_name="servergenerated",
)
self._assert_fk_diff(
- diffs[1], "add_fk",
- "user", ["tid"],
- "some_table", ["id"],
- ondelete=None
+ diffs[1],
+ "add_fk",
+ "user",
+ ["tid"],
+ "some_table",
+ ["id"],
+ ondelete=None,
)
else:
eq_(diffs, [])
@@ -574,47 +767,55 @@ class AutogenerateFKOptionsTest(AutogenFixtureTest, TestBase):
eq_(diffs, [])
def test_add_onupdate(self):
- diffs = self._fk_opts_fixture(
- {}, {"onupdate": "cascade"}
- )
+ diffs = self._fk_opts_fixture({}, {"onupdate": "cascade"})
if self._expect_opts_supported():
self._assert_fk_diff(
- diffs[0], "remove_fk",
- "user", ["tid"],
- "some_table", ["id"],
+ diffs[0],
+ "remove_fk",
+ "user",
+ ["tid"],
+ "some_table",
+ ["id"],
onupdate=None,
- conditional_name="servergenerated"
+ conditional_name="servergenerated",
)
self._assert_fk_diff(
- diffs[1], "add_fk",
- "user", ["tid"],
- "some_table", ["id"],
- onupdate="cascade"
+ diffs[1],
+ "add_fk",
+ "user",
+ ["tid"],
+ "some_table",
+ ["id"],
+ onupdate="cascade",
)
else:
eq_(diffs, [])
def test_remove_onupdate(self):
- diffs = self._fk_opts_fixture(
- {"onupdate": "CASCADE"}, {}
- )
+ diffs = self._fk_opts_fixture({"onupdate": "CASCADE"}, {})
if self._expect_opts_supported():
self._assert_fk_diff(
- diffs[0], "remove_fk",
- "user", ["tid"],
- "some_table", ["id"],
+ diffs[0],
+ "remove_fk",
+ "user",
+ ["tid"],
+ "some_table",
+ ["id"],
onupdate="CASCADE",
- conditional_name="servergenerated"
+ conditional_name="servergenerated",
)
self._assert_fk_diff(
- diffs[1], "add_fk",
- "user", ["tid"],
- "some_table", ["id"],
- onupdate=None
+ diffs[1],
+ "add_fk",
+ "user",
+ ["tid"],
+ "some_table",
+ ["id"],
+ onupdate=None,
)
else:
eq_(diffs, [])
@@ -668,20 +869,26 @@ class AutogenerateFKOptionsTest(AutogenFixtureTest, TestBase):
)
if self._expect_opts_supported():
self._assert_fk_diff(
- diffs[0], "remove_fk",
- "user", ["tid"],
- "some_table", ["id"],
+ diffs[0],
+ "remove_fk",
+ "user",
+ ["tid"],
+ "some_table",
+ ["id"],
onupdate=None,
ondelete=mock.ANY, # MySQL reports None, PG reports RESTRICT
- conditional_name="servergenerated"
+ conditional_name="servergenerated",
)
self._assert_fk_diff(
- diffs[1], "add_fk",
- "user", ["tid"],
- "some_table", ["id"],
+ diffs[1],
+ "add_fk",
+ "user",
+ ["tid"],
+ "some_table",
+ ["id"],
onupdate=None,
- ondelete="cascade"
+ ondelete="cascade",
)
else:
eq_(diffs, [])
@@ -696,20 +903,26 @@ class AutogenerateFKOptionsTest(AutogenFixtureTest, TestBase):
)
if self._expect_opts_supported():
self._assert_fk_diff(
- diffs[0], "remove_fk",
- "user", ["tid"],
- "some_table", ["id"],
+ diffs[0],
+ "remove_fk",
+ "user",
+ ["tid"],
+ "some_table",
+ ["id"],
onupdate=mock.ANY, # MySQL reports None, PG reports RESTRICT
ondelete=None,
- conditional_name="servergenerated"
+ conditional_name="servergenerated",
)
self._assert_fk_diff(
- diffs[1], "add_fk",
- "user", ["tid"],
- "some_table", ["id"],
+ diffs[1],
+ "add_fk",
+ "user",
+ ["tid"],
+ "some_table",
+ ["id"],
onupdate="cascade",
- ondelete=None
+ ondelete=None,
)
else:
eq_(diffs, [])
@@ -717,70 +930,84 @@ class AutogenerateFKOptionsTest(AutogenFixtureTest, TestBase):
def test_ondelete_onupdate_combo(self):
diffs = self._fk_opts_fixture(
{"onupdate": "CASCADE", "ondelete": "SET NULL"},
- {"onupdate": "RESTRICT", "ondelete": "RESTRICT"}
+ {"onupdate": "RESTRICT", "ondelete": "RESTRICT"},
)
if self._expect_opts_supported():
self._assert_fk_diff(
- diffs[0], "remove_fk",
- "user", ["tid"],
- "some_table", ["id"],
+ diffs[0],
+ "remove_fk",
+ "user",
+ ["tid"],
+ "some_table",
+ ["id"],
onupdate="CASCADE",
ondelete="SET NULL",
- conditional_name="servergenerated"
+ conditional_name="servergenerated",
)
self._assert_fk_diff(
- diffs[1], "add_fk",
- "user", ["tid"],
- "some_table", ["id"],
+ diffs[1],
+ "add_fk",
+ "user",
+ ["tid"],
+ "some_table",
+ ["id"],
onupdate="RESTRICT",
- ondelete="RESTRICT"
+ ondelete="RESTRICT",
)
else:
eq_(diffs, [])
@config.requirements.fk_initially
def test_add_initially_deferred(self):
- diffs = self._fk_opts_fixture(
- {}, {"initially": "deferred"}
- )
+ diffs = self._fk_opts_fixture({}, {"initially": "deferred"})
self._assert_fk_diff(
- diffs[0], "remove_fk",
- "user", ["tid"],
- "some_table", ["id"],
+ diffs[0],
+ "remove_fk",
+ "user",
+ ["tid"],
+ "some_table",
+ ["id"],
initially=None,
- conditional_name="servergenerated"
+ conditional_name="servergenerated",
)
self._assert_fk_diff(
- diffs[1], "add_fk",
- "user", ["tid"],
- "some_table", ["id"],
- initially="deferred"
+ diffs[1],
+ "add_fk",
+ "user",
+ ["tid"],
+ "some_table",
+ ["id"],
+ initially="deferred",
)
@config.requirements.fk_initially
def test_remove_initially_deferred(self):
- diffs = self._fk_opts_fixture(
- {"initially": "deferred"}, {}
- )
+ diffs = self._fk_opts_fixture({"initially": "deferred"}, {})
self._assert_fk_diff(
- diffs[0], "remove_fk",
- "user", ["tid"],
- "some_table", ["id"],
+ diffs[0],
+ "remove_fk",
+ "user",
+ ["tid"],
+ "some_table",
+ ["id"],
initially="DEFERRED",
deferrable=True,
- conditional_name="servergenerated"
+ conditional_name="servergenerated",
)
self._assert_fk_diff(
- diffs[1], "add_fk",
- "user", ["tid"],
- "some_table", ["id"],
- initially=None
+ diffs[1],
+ "add_fk",
+ "user",
+ ["tid"],
+ "some_table",
+ ["id"],
+ initially=None,
)
@config.requirements.fk_deferrable
@@ -791,19 +1018,25 @@ class AutogenerateFKOptionsTest(AutogenFixtureTest, TestBase):
)
self._assert_fk_diff(
- diffs[0], "remove_fk",
- "user", ["tid"],
- "some_table", ["id"],
+ diffs[0],
+ "remove_fk",
+ "user",
+ ["tid"],
+ "some_table",
+ ["id"],
initially=None,
- conditional_name="servergenerated"
+ conditional_name="servergenerated",
)
self._assert_fk_diff(
- diffs[1], "add_fk",
- "user", ["tid"],
- "some_table", ["id"],
+ diffs[1],
+ "add_fk",
+ "user",
+ ["tid"],
+ "some_table",
+ ["id"],
initially="immediate",
- deferrable=True
+ deferrable=True,
)
@config.requirements.fk_deferrable
@@ -814,20 +1047,26 @@ class AutogenerateFKOptionsTest(AutogenFixtureTest, TestBase):
)
self._assert_fk_diff(
- diffs[0], "remove_fk",
- "user", ["tid"],
- "some_table", ["id"],
+ diffs[0],
+ "remove_fk",
+ "user",
+ ["tid"],
+ "some_table",
+ ["id"],
initially=None, # immediate is the default
deferrable=True,
- conditional_name="servergenerated"
+ conditional_name="servergenerated",
)
self._assert_fk_diff(
- diffs[1], "add_fk",
- "user", ["tid"],
- "some_table", ["id"],
+ diffs[1],
+ "add_fk",
+ "user",
+ ["tid"],
+ "some_table",
+ ["id"],
initially=None,
- deferrable=None
+ deferrable=None,
)
@config.requirements.fk_initially
@@ -835,7 +1074,7 @@ class AutogenerateFKOptionsTest(AutogenFixtureTest, TestBase):
def test_add_initially_deferrable_nochange_one(self):
diffs = self._fk_opts_fixture(
{"deferrable": True, "initially": "immediate"},
- {"deferrable": True, "initially": "immediate"}
+ {"deferrable": True, "initially": "immediate"},
)
eq_(diffs, [])
@@ -845,7 +1084,7 @@ class AutogenerateFKOptionsTest(AutogenFixtureTest, TestBase):
def test_add_initially_deferrable_nochange_two(self):
diffs = self._fk_opts_fixture(
{"deferrable": True, "initially": "deferred"},
- {"deferrable": True, "initially": "deferred"}
+ {"deferrable": True, "initially": "deferred"},
)
eq_(diffs, [])
@@ -855,49 +1094,57 @@ class AutogenerateFKOptionsTest(AutogenFixtureTest, TestBase):
def test_add_initially_deferrable_nochange_three(self):
diffs = self._fk_opts_fixture(
{"deferrable": None, "initially": "deferred"},
- {"deferrable": None, "initially": "deferred"}
+ {"deferrable": None, "initially": "deferred"},
)
eq_(diffs, [])
@config.requirements.fk_deferrable
def test_add_deferrable(self):
- diffs = self._fk_opts_fixture(
- {}, {"deferrable": True}
- )
+ diffs = self._fk_opts_fixture({}, {"deferrable": True})
self._assert_fk_diff(
- diffs[0], "remove_fk",
- "user", ["tid"],
- "some_table", ["id"],
+ diffs[0],
+ "remove_fk",
+ "user",
+ ["tid"],
+ "some_table",
+ ["id"],
deferrable=None,
- conditional_name="servergenerated"
+ conditional_name="servergenerated",
)
self._assert_fk_diff(
- diffs[1], "add_fk",
- "user", ["tid"],
- "some_table", ["id"],
- deferrable=True
+ diffs[1],
+ "add_fk",
+ "user",
+ ["tid"],
+ "some_table",
+ ["id"],
+ deferrable=True,
)
@config.requirements.fk_deferrable
def test_remove_deferrable(self):
- diffs = self._fk_opts_fixture(
- {"deferrable": True}, {}
- )
+ diffs = self._fk_opts_fixture({"deferrable": True}, {})
self._assert_fk_diff(
- diffs[0], "remove_fk",
- "user", ["tid"],
- "some_table", ["id"],
+ diffs[0],
+ "remove_fk",
+ "user",
+ ["tid"],
+ "some_table",
+ ["id"],
deferrable=True,
- conditional_name="servergenerated"
+ conditional_name="servergenerated",
)
self._assert_fk_diff(
- diffs[1], "add_fk",
- "user", ["tid"],
- "some_table", ["id"],
- deferrable=None
+ diffs[1],
+ "add_fk",
+ "user",
+ ["tid"],
+ "some_table",
+ ["id"],
+ deferrable=None,
)
diff --git a/tests/test_autogen_indexes.py b/tests/test_autogen_indexes.py
index b588cbe..f03155f 100644
--- a/tests/test_autogen_indexes.py
+++ b/tests/test_autogen_indexes.py
@@ -3,14 +3,24 @@ from alembic.testing import TestBase
from alembic.testing import config
from alembic.testing import assertions
-from sqlalchemy import MetaData, Column, Table, Integer, String, \
- Numeric, UniqueConstraint, Index, ForeignKeyConstraint,\
- ForeignKey, func
+from sqlalchemy import (
+ MetaData,
+ Column,
+ Table,
+ Integer,
+ String,
+ Numeric,
+ UniqueConstraint,
+ Index,
+ ForeignKeyConstraint,
+ ForeignKey,
+ func,
+)
from alembic.testing import engines
from alembic.testing import eq_
from alembic.testing.env import staging_env
-py3k = sys.version_info >= (3, )
+py3k = sys.version_info >= (3,)
from ._autogen_fixtures import AutogenFixtureTest
@@ -24,6 +34,7 @@ class NoUqReflection(object):
def unimpl(*arg, **kw):
raise NotImplementedError()
+
eng.dialect.get_unique_constraints = unimpl
def test_add_ix_on_table_create(self):
@@ -37,25 +48,29 @@ class AutogenerateUniqueIndexTest(AutogenFixtureTest, TestBase):
reports_unique_constraints = True
reports_unique_constraints_as_indexes = False
- __requires__ = ('unique_constraint_reflection', )
- __only_on__ = 'sqlite'
+ __requires__ = ("unique_constraint_reflection",)
+ __only_on__ = "sqlite"
def test_index_flag_becomes_named_unique_constraint(self):
m1 = MetaData()
m2 = MetaData()
- Table('user', m1,
- Column('id', Integer, primary_key=True),
- Column('name', String(50), nullable=False, index=True),
- Column('a1', String(10), server_default="x")
- )
+ Table(
+ "user",
+ m1,
+ Column("id", Integer, primary_key=True),
+ Column("name", String(50), nullable=False, index=True),
+ Column("a1", String(10), server_default="x"),
+ )
- Table('user', m2,
- Column('id', Integer, primary_key=True),
- Column('name', String(50), nullable=False),
- Column('a1', String(10), server_default="x"),
- UniqueConstraint("name", name="uq_user_name")
- )
+ Table(
+ "user",
+ m2,
+ Column("id", Integer, primary_key=True),
+ Column("name", String(50), nullable=False),
+ Column("a1", String(10), server_default="x"),
+ UniqueConstraint("name", name="uq_user_name"),
+ )
diffs = self._fixture(m1, m2)
@@ -72,17 +87,21 @@ class AutogenerateUniqueIndexTest(AutogenFixtureTest, TestBase):
def test_add_unique_constraint(self):
m1 = MetaData()
m2 = MetaData()
- Table('address', m1,
- Column('id', Integer, primary_key=True),
- Column('email_address', String(100), nullable=False),
- Column('qpr', String(10), index=True),
- )
- Table('address', m2,
- Column('id', Integer, primary_key=True),
- Column('email_address', String(100), nullable=False),
- Column('qpr', String(10), index=True),
- UniqueConstraint("email_address", name="uq_email_address")
- )
+ Table(
+ "address",
+ m1,
+ Column("id", Integer, primary_key=True),
+ Column("email_address", String(100), nullable=False),
+ Column("qpr", String(10), index=True),
+ )
+ Table(
+ "address",
+ m2,
+ Column("id", Integer, primary_key=True),
+ Column("email_address", String(100), nullable=False),
+ Column("qpr", String(10), index=True),
+ UniqueConstraint("email_address", name="uq_email_address"),
+ )
diffs = self._fixture(m1, m2)
@@ -96,17 +115,21 @@ class AutogenerateUniqueIndexTest(AutogenFixtureTest, TestBase):
m1 = MetaData()
m2 = MetaData()
- Table('unq_idx', m1,
- Column('id', Integer, primary_key=True),
- Column('x', String(20)),
- Index('x', 'x', unique=True)
- )
+ Table(
+ "unq_idx",
+ m1,
+ Column("id", Integer, primary_key=True),
+ Column("x", String(20)),
+ Index("x", "x", unique=True),
+ )
- Table('unq_idx', m2,
- Column('id', Integer, primary_key=True),
- Column('x', String(20)),
- Index('x', 'x', unique=True)
- )
+ Table(
+ "unq_idx",
+ m2,
+ Column("id", Integer, primary_key=True),
+ Column("x", String(20)),
+ Index("x", "x", unique=True),
+ )
diffs = self._fixture(m1, m2)
eq_(diffs, [])
@@ -114,27 +137,31 @@ class AutogenerateUniqueIndexTest(AutogenFixtureTest, TestBase):
def test_index_becomes_unique(self):
m1 = MetaData()
m2 = MetaData()
- Table('order', m1,
- Column('order_id', Integer, primary_key=True),
- Column('amount', Numeric(10, 2), nullable=True),
- Column('user_id', Integer),
- UniqueConstraint('order_id', 'user_id',
- name='order_order_id_user_id_unique'
- ),
- Index('order_user_id_amount_idx', 'user_id', 'amount')
- )
-
- Table('order', m2,
- Column('order_id', Integer, primary_key=True),
- Column('amount', Numeric(10, 2), nullable=True),
- Column('user_id', Integer),
- UniqueConstraint('order_id', 'user_id',
- name='order_order_id_user_id_unique'
- ),
- Index(
- 'order_user_id_amount_idx', 'user_id',
- 'amount', unique=True),
- )
+ Table(
+ "order",
+ m1,
+ Column("order_id", Integer, primary_key=True),
+ Column("amount", Numeric(10, 2), nullable=True),
+ Column("user_id", Integer),
+ UniqueConstraint(
+ "order_id", "user_id", name="order_order_id_user_id_unique"
+ ),
+ Index("order_user_id_amount_idx", "user_id", "amount"),
+ )
+
+ Table(
+ "order",
+ m2,
+ Column("order_id", Integer, primary_key=True),
+ Column("amount", Numeric(10, 2), nullable=True),
+ Column("user_id", Integer),
+ UniqueConstraint(
+ "order_id", "user_id", name="order_order_id_user_id_unique"
+ ),
+ Index(
+ "order_user_id_amount_idx", "user_id", "amount", unique=True
+ ),
+ )
diffs = self._fixture(m1, m2)
eq_(diffs[0][0], "remove_index")
@@ -148,16 +175,16 @@ class AutogenerateUniqueIndexTest(AutogenFixtureTest, TestBase):
def test_mismatch_db_named_col_flag(self):
m1 = MetaData()
m2 = MetaData()
- Table('item', m1,
- Column('x', Integer),
- UniqueConstraint('x', name="db_generated_name")
- )
+ Table(
+ "item",
+ m1,
+ Column("x", Integer),
+ UniqueConstraint("x", name="db_generated_name"),
+ )
# test mismatch between unique=True and
# named uq constraint
- Table('item', m2,
- Column('x', Integer, unique=True)
- )
+ Table("item", m2, Column("x", Integer, unique=True))
diffs = self._fixture(m1, m2)
@@ -166,11 +193,13 @@ class AutogenerateUniqueIndexTest(AutogenFixtureTest, TestBase):
def test_new_table_added(self):
m1 = MetaData()
m2 = MetaData()
- Table('extra', m2,
- Column('foo', Integer, index=True),
- Column('bar', Integer),
- Index('newtable_idx', 'bar')
- )
+ Table(
+ "extra",
+ m2,
+ Column("foo", Integer, index=True),
+ Column("bar", Integer),
+ Index("newtable_idx", "bar"),
+ )
diffs = self._fixture(m1, m2)
@@ -185,16 +214,20 @@ class AutogenerateUniqueIndexTest(AutogenFixtureTest, TestBase):
def test_named_cols_changed(self):
m1 = MetaData()
m2 = MetaData()
- Table('col_change', m1,
- Column('x', Integer),
- Column('y', Integer),
- UniqueConstraint('x', name="nochange")
- )
- Table('col_change', m2,
- Column('x', Integer),
- Column('y', Integer),
- UniqueConstraint('x', 'y', name="nochange")
- )
+ Table(
+ "col_change",
+ m1,
+ Column("x", Integer),
+ Column("y", Integer),
+ UniqueConstraint("x", name="nochange"),
+ )
+ Table(
+ "col_change",
+ m2,
+ Column("x", Integer),
+ Column("y", Integer),
+ UniqueConstraint("x", "y", name="nochange"),
+ )
diffs = self._fixture(m1, m2)
@@ -211,13 +244,17 @@ class AutogenerateUniqueIndexTest(AutogenFixtureTest, TestBase):
m1 = MetaData()
m2 = MetaData()
- Table('nothing_changed', m1,
- Column('x', String(20), unique=True, index=True)
- )
+ Table(
+ "nothing_changed",
+ m1,
+ Column("x", String(20), unique=True, index=True),
+ )
- Table('nothing_changed', m2,
- Column('x', String(20), unique=True, index=True)
- )
+ Table(
+ "nothing_changed",
+ m2,
+ Column("x", String(20), unique=True, index=True),
+ )
diffs = self._fixture(m1, m2)
eq_(diffs, [])
@@ -226,35 +263,43 @@ class AutogenerateUniqueIndexTest(AutogenFixtureTest, TestBase):
m1 = MetaData()
m2 = MetaData()
- Table('nothing_changed', m1,
- Column('id1', Integer, primary_key=True),
- Column('id2', Integer, primary_key=True),
- Column('x', String(20), unique=True),
- mysql_engine='InnoDB'
- )
- Table('nothing_changed_related', m1,
- Column('id1', Integer),
- Column('id2', Integer),
- ForeignKeyConstraint(
- ['id1', 'id2'],
- ['nothing_changed.id1', 'nothing_changed.id2']),
- mysql_engine='InnoDB'
- )
-
- Table('nothing_changed', m2,
- Column('id1', Integer, primary_key=True),
- Column('id2', Integer, primary_key=True),
- Column('x', String(20), unique=True),
- mysql_engine='InnoDB'
- )
- Table('nothing_changed_related', m2,
- Column('id1', Integer),
- Column('id2', Integer),
- ForeignKeyConstraint(
- ['id1', 'id2'],
- ['nothing_changed.id1', 'nothing_changed.id2']),
- mysql_engine='InnoDB'
- )
+ Table(
+ "nothing_changed",
+ m1,
+ Column("id1", Integer, primary_key=True),
+ Column("id2", Integer, primary_key=True),
+ Column("x", String(20), unique=True),
+ mysql_engine="InnoDB",
+ )
+ Table(
+ "nothing_changed_related",
+ m1,
+ Column("id1", Integer),
+ Column("id2", Integer),
+ ForeignKeyConstraint(
+ ["id1", "id2"], ["nothing_changed.id1", "nothing_changed.id2"]
+ ),
+ mysql_engine="InnoDB",
+ )
+
+ Table(
+ "nothing_changed",
+ m2,
+ Column("id1", Integer, primary_key=True),
+ Column("id2", Integer, primary_key=True),
+ Column("x", String(20), unique=True),
+ mysql_engine="InnoDB",
+ )
+ Table(
+ "nothing_changed_related",
+ m2,
+ Column("id1", Integer),
+ Column("id2", Integer),
+ ForeignKeyConstraint(
+ ["id1", "id2"], ["nothing_changed.id1", "nothing_changed.id2"]
+ ),
+ mysql_engine="InnoDB",
+ )
diffs = self._fixture(m1, m2)
eq_(diffs, [])
@@ -263,15 +308,19 @@ class AutogenerateUniqueIndexTest(AutogenFixtureTest, TestBase):
m1 = MetaData()
m2 = MetaData()
- Table('nothing_changed', m1,
- Column('x', String(20), key='nx'),
- UniqueConstraint('nx')
- )
+ Table(
+ "nothing_changed",
+ m1,
+ Column("x", String(20), key="nx"),
+ UniqueConstraint("nx"),
+ )
- Table('nothing_changed', m2,
- Column('x', String(20), key='nx'),
- UniqueConstraint('nx')
- )
+ Table(
+ "nothing_changed",
+ m2,
+ Column("x", String(20), key="nx"),
+ UniqueConstraint("nx"),
+ )
diffs = self._fixture(m1, m2)
eq_(diffs, [])
@@ -280,15 +329,19 @@ class AutogenerateUniqueIndexTest(AutogenFixtureTest, TestBase):
m1 = MetaData()
m2 = MetaData()
- Table('nothing_changed', m1,
- Column('x', String(20), key='nx'),
- Index('foobar', 'nx')
- )
+ Table(
+ "nothing_changed",
+ m1,
+ Column("x", String(20), key="nx"),
+ Index("foobar", "nx"),
+ )
- Table('nothing_changed', m2,
- Column('x', String(20), key='nx'),
- Index('foobar', 'nx')
- )
+ Table(
+ "nothing_changed",
+ m2,
+ Column("x", String(20), key="nx"),
+ Index("foobar", "nx"),
+ )
diffs = self._fixture(m1, m2)
eq_(diffs, [])
@@ -297,19 +350,23 @@ class AutogenerateUniqueIndexTest(AutogenFixtureTest, TestBase):
m1 = MetaData()
m2 = MetaData()
- Table('nothing_changed', m1,
- Column('id1', Integer, primary_key=True),
- Column('id2', Integer, primary_key=True),
- Column('x', String(20)),
- Index('x', 'x')
- )
+ Table(
+ "nothing_changed",
+ m1,
+ Column("id1", Integer, primary_key=True),
+ Column("id2", Integer, primary_key=True),
+ Column("x", String(20)),
+ Index("x", "x"),
+ )
- Table('nothing_changed', m2,
- Column('id1', Integer, primary_key=True),
- Column('id2', Integer, primary_key=True),
- Column('x', String(20)),
- Index('x', 'x')
- )
+ Table(
+ "nothing_changed",
+ m2,
+ Column("id1", Integer, primary_key=True),
+ Column("id2", Integer, primary_key=True),
+ Column("x", String(20)),
+ Index("x", "x"),
+ )
diffs = self._fixture(m1, m2)
eq_(diffs, [])
@@ -318,29 +375,43 @@ class AutogenerateUniqueIndexTest(AutogenFixtureTest, TestBase):
m1 = MetaData()
m2 = MetaData()
- Table("nothing_changed", m1,
- Column('id', Integer, primary_key=True),
- Column('other_id',
- ForeignKey('nc2.id',
- name='fk_my_table_other_table'
- ),
- nullable=False),
- Column('foo', Integer),
- mysql_engine='InnoDB')
- Table('nc2', m1,
- Column('id', Integer, primary_key=True),
- mysql_engine='InnoDB')
-
- Table("nothing_changed", m2,
- Column('id', Integer, primary_key=True),
- Column('other_id', ForeignKey('nc2.id',
- name='fk_my_table_other_table'),
- nullable=False),
- Column('foo', Integer),
- mysql_engine='InnoDB')
- Table('nc2', m2,
- Column('id', Integer, primary_key=True),
- mysql_engine='InnoDB')
+ Table(
+ "nothing_changed",
+ m1,
+ Column("id", Integer, primary_key=True),
+ Column(
+ "other_id",
+ ForeignKey("nc2.id", name="fk_my_table_other_table"),
+ nullable=False,
+ ),
+ Column("foo", Integer),
+ mysql_engine="InnoDB",
+ )
+ Table(
+ "nc2",
+ m1,
+ Column("id", Integer, primary_key=True),
+ mysql_engine="InnoDB",
+ )
+
+ Table(
+ "nothing_changed",
+ m2,
+ Column("id", Integer, primary_key=True),
+ Column(
+ "other_id",
+ ForeignKey("nc2.id", name="fk_my_table_other_table"),
+ nullable=False,
+ ),
+ Column("foo", Integer),
+ mysql_engine="InnoDB",
+ )
+ Table(
+ "nc2",
+ m2,
+ Column("id", Integer, primary_key=True),
+ mysql_engine="InnoDB",
+ )
diffs = self._fixture(m1, m2)
eq_(diffs, [])
@@ -348,35 +419,49 @@ class AutogenerateUniqueIndexTest(AutogenFixtureTest, TestBase):
m1 = MetaData()
m2 = MetaData()
- Table("nothing_changed", m1,
- Column('id', Integer, primary_key=True),
- Column('other_id_1', Integer),
- Column('other_id_2', Integer),
- Column('foo', Integer),
- ForeignKeyConstraint(
- ['other_id_1', 'other_id_2'], ['nc2.id1', 'nc2.id2'],
- name='fk_my_table_other_table'
- ),
- mysql_engine='InnoDB')
- Table('nc2', m1,
- Column('id1', Integer, primary_key=True),
- Column('id2', Integer, primary_key=True),
- mysql_engine='InnoDB')
-
- Table("nothing_changed", m2,
- Column('id', Integer, primary_key=True),
- Column('other_id_1', Integer),
- Column('other_id_2', Integer),
- Column('foo', Integer),
- ForeignKeyConstraint(
- ['other_id_1', 'other_id_2'], ['nc2.id1', 'nc2.id2'],
- name='fk_my_table_other_table'
- ),
- mysql_engine='InnoDB')
- Table('nc2', m2,
- Column('id1', Integer, primary_key=True),
- Column('id2', Integer, primary_key=True),
- mysql_engine='InnoDB')
+ Table(
+ "nothing_changed",
+ m1,
+ Column("id", Integer, primary_key=True),
+ Column("other_id_1", Integer),
+ Column("other_id_2", Integer),
+ Column("foo", Integer),
+ ForeignKeyConstraint(
+ ["other_id_1", "other_id_2"],
+ ["nc2.id1", "nc2.id2"],
+ name="fk_my_table_other_table",
+ ),
+ mysql_engine="InnoDB",
+ )
+ Table(
+ "nc2",
+ m1,
+ Column("id1", Integer, primary_key=True),
+ Column("id2", Integer, primary_key=True),
+ mysql_engine="InnoDB",
+ )
+
+ Table(
+ "nothing_changed",
+ m2,
+ Column("id", Integer, primary_key=True),
+ Column("other_id_1", Integer),
+ Column("other_id_2", Integer),
+ Column("foo", Integer),
+ ForeignKeyConstraint(
+ ["other_id_1", "other_id_2"],
+ ["nc2.id1", "nc2.id2"],
+ name="fk_my_table_other_table",
+ ),
+ mysql_engine="InnoDB",
+ )
+ Table(
+ "nc2",
+ m2,
+ Column("id1", Integer, primary_key=True),
+ Column("id2", Integer, primary_key=True),
+ mysql_engine="InnoDB",
+ )
diffs = self._fixture(m1, m2)
eq_(diffs, [])
@@ -384,65 +469,73 @@ class AutogenerateUniqueIndexTest(AutogenFixtureTest, TestBase):
m1 = MetaData()
m2 = MetaData()
- Table('new_idx', m1,
- Column('id1', Integer, primary_key=True),
- Column('id2', Integer, primary_key=True),
- Column('x', String(20)),
- )
+ Table(
+ "new_idx",
+ m1,
+ Column("id1", Integer, primary_key=True),
+ Column("id2", Integer, primary_key=True),
+ Column("x", String(20)),
+ )
- idx = Index('x', 'x')
- Table('new_idx', m2,
- Column('id1', Integer, primary_key=True),
- Column('id2', Integer, primary_key=True),
- Column('x', String(20)),
- idx
- )
+ idx = Index("x", "x")
+ Table(
+ "new_idx",
+ m2,
+ Column("id1", Integer, primary_key=True),
+ Column("id2", Integer, primary_key=True),
+ Column("x", String(20)),
+ idx,
+ )
diffs = self._fixture(m1, m2)
- eq_(diffs, [('add_index', idx)])
+ eq_(diffs, [("add_index", idx)])
def test_removed_idx_index_named_as_column(self):
m1 = MetaData()
m2 = MetaData()
- idx = Index('x', 'x')
- Table('new_idx', m1,
- Column('id1', Integer, primary_key=True),
- Column('id2', Integer, primary_key=True),
- Column('x', String(20)),
- idx
- )
+ idx = Index("x", "x")
+ Table(
+ "new_idx",
+ m1,
+ Column("id1", Integer, primary_key=True),
+ Column("id2", Integer, primary_key=True),
+ Column("x", String(20)),
+ idx,
+ )
- Table('new_idx', m2,
- Column('id1', Integer, primary_key=True),
- Column('id2', Integer, primary_key=True),
- Column('x', String(20))
- )
+ Table(
+ "new_idx",
+ m2,
+ Column("id1", Integer, primary_key=True),
+ Column("id2", Integer, primary_key=True),
+ Column("x", String(20)),
+ )
diffs = self._fixture(m1, m2)
- eq_(diffs[0][0], 'remove_index')
+ eq_(diffs[0][0], "remove_index")
def test_drop_table_w_indexes(self):
m1 = MetaData()
m2 = MetaData()
t = Table(
- 'some_table', m1,
- Column('id', Integer, primary_key=True),
- Column('x', String(20)),
- Column('y', String(20)),
+ "some_table",
+ m1,
+ Column("id", Integer, primary_key=True),
+ Column("x", String(20)),
+ Column("y", String(20)),
)
- Index('xy_idx', t.c.x, t.c.y)
- Index('y_idx', t.c.y)
+ Index("xy_idx", t.c.x, t.c.y)
+ Index("y_idx", t.c.y)
diffs = self._fixture(m1, m2)
- eq_(diffs[0][0], 'remove_index')
- eq_(diffs[1][0], 'remove_index')
- eq_(diffs[2][0], 'remove_table')
+ eq_(diffs[0][0], "remove_index")
+ eq_(diffs[1][0], "remove_index")
+ eq_(diffs[2][0], "remove_table")
eq_(
- set([diffs[0][1].name, diffs[1][1].name]),
- set(['xy_idx', 'y_idx'])
+ set([diffs[0][1].name, diffs[1][1].name]), set(["xy_idx", "y_idx"])
)
# this simply doesn't fully work before we had
@@ -453,11 +546,12 @@ class AutogenerateUniqueIndexTest(AutogenFixtureTest, TestBase):
m2 = MetaData()
Table(
- 'some_table', m1,
- Column('id', Integer, primary_key=True),
- Column('x', String(20)),
- Column('y', String(20)),
- UniqueConstraint('y', name='uq_y')
+ "some_table",
+ m1,
+ Column("id", Integer, primary_key=True),
+ Column("x", String(20)),
+ Column("y", String(20)),
+ UniqueConstraint("y", name="uq_y"),
)
diffs = self._fixture(m1, m2)
@@ -465,65 +559,80 @@ class AutogenerateUniqueIndexTest(AutogenFixtureTest, TestBase):
if self.reports_unique_constraints_as_indexes:
# for MySQL this UQ will look like an index, so
# make sure it at least sets it up correctly
- eq_(diffs[0][0], 'remove_index')
- eq_(diffs[1][0], 'remove_table')
+ eq_(diffs[0][0], "remove_index")
+ eq_(diffs[1][0], "remove_table")
eq_(len(diffs), 2)
- constraints = [c for c in diffs[1][1].constraints
- if isinstance(c, UniqueConstraint)]
+ constraints = [
+ c
+ for c in diffs[1][1].constraints
+ if isinstance(c, UniqueConstraint)
+ ]
eq_(len(constraints), 0)
else:
- eq_(diffs[0][0], 'remove_table')
+ eq_(diffs[0][0], "remove_table")
eq_(len(diffs), 1)
- constraints = [c for c in diffs[0][1].constraints
- if isinstance(c, UniqueConstraint)]
+ constraints = [
+ c
+ for c in diffs[0][1].constraints
+ if isinstance(c, UniqueConstraint)
+ ]
if self.reports_unique_constraints:
eq_(len(constraints), 1)
def test_unnamed_cols_changed(self):
m1 = MetaData()
m2 = MetaData()
- Table('col_change', m1,
- Column('x', Integer),
- Column('y', Integer),
- UniqueConstraint('x')
- )
- Table('col_change', m2,
- Column('x', Integer),
- Column('y', Integer),
- UniqueConstraint('x', 'y')
- )
+ Table(
+ "col_change",
+ m1,
+ Column("x", Integer),
+ Column("y", Integer),
+ UniqueConstraint("x"),
+ )
+ Table(
+ "col_change",
+ m2,
+ Column("x", Integer),
+ Column("y", Integer),
+ UniqueConstraint("x", "y"),
+ )
diffs = self._fixture(m1, m2)
- diffs = set((cmd,
- ('x' in obj.name) if obj.name is not None else False)
- for cmd, obj in diffs)
+ diffs = set(
+ (cmd, ("x" in obj.name) if obj.name is not None else False)
+ for cmd, obj in diffs
+ )
if self.reports_unnamed_constraints:
if self.reports_unique_constraints_as_indexes:
eq_(
diffs,
- set([("remove_index", True), ("add_constraint", False)])
+ set([("remove_index", True), ("add_constraint", False)]),
)
else:
eq_(
diffs,
- set([("remove_constraint", True),
- ("add_constraint", False)])
+ set(
+ [
+ ("remove_constraint", True),
+ ("add_constraint", False),
+ ]
+ ),
)
def test_remove_named_unique_index(self):
m1 = MetaData()
m2 = MetaData()
- Table('remove_idx', m1,
- Column('x', Integer),
- Index('xidx', 'x', unique=True)
- )
- Table('remove_idx', m2,
- Column('x', Integer)
- )
+ Table(
+ "remove_idx",
+ m1,
+ Column("x", Integer),
+ Index("xidx", "x", unique=True),
+ )
+ Table("remove_idx", m2, Column("x", Integer))
diffs = self._fixture(m1, m2)
@@ -537,13 +646,13 @@ class AutogenerateUniqueIndexTest(AutogenFixtureTest, TestBase):
m1 = MetaData()
m2 = MetaData()
- Table('remove_idx', m1,
- Column('x', Integer),
- UniqueConstraint('x', name='xidx')
- )
- Table('remove_idx', m2,
- Column('x', Integer),
- )
+ Table(
+ "remove_idx",
+ m1,
+ Column("x", Integer),
+ UniqueConstraint("x", name="xidx"),
+ )
+ Table("remove_idx", m2, Column("x", Integer))
diffs = self._fixture(m1, m2)
@@ -559,46 +668,49 @@ class AutogenerateUniqueIndexTest(AutogenFixtureTest, TestBase):
def test_dont_add_uq_on_table_create(self):
m1 = MetaData()
m2 = MetaData()
- Table('no_uq', m2, Column('x', String(50), unique=True))
+ Table("no_uq", m2, Column("x", String(50), unique=True))
diffs = self._fixture(m1, m2)
eq_(diffs[0][0], "add_table")
eq_(len(diffs), 1)
assert UniqueConstraint in set(
- type(c) for c in diffs[0][1].constraints)
+ type(c) for c in diffs[0][1].constraints
+ )
def test_add_uq_ix_on_table_create(self):
m1 = MetaData()
m2 = MetaData()
- Table('add_ix', m2, Column('x', String(50), unique=True, index=True))
+ Table("add_ix", m2, Column("x", String(50), unique=True, index=True))
diffs = self._fixture(m1, m2)
eq_(diffs[0][0], "add_table")
eq_(len(diffs), 2)
assert UniqueConstraint not in set(
- type(c) for c in diffs[0][1].constraints)
+ type(c) for c in diffs[0][1].constraints
+ )
eq_(diffs[1][0], "add_index")
eq_(diffs[1][1].unique, True)
def test_add_ix_on_table_create(self):
m1 = MetaData()
m2 = MetaData()
- Table('add_ix', m2, Column('x', String(50), index=True))
+ Table("add_ix", m2, Column("x", String(50), index=True))
diffs = self._fixture(m1, m2)
eq_(diffs[0][0], "add_table")
eq_(len(diffs), 2)
assert UniqueConstraint not in set(
- type(c) for c in diffs[0][1].constraints)
+ type(c) for c in diffs[0][1].constraints
+ )
eq_(diffs[1][0], "add_index")
eq_(diffs[1][1].unique, False)
def test_add_idx_non_col(self):
m1 = MetaData()
m2 = MetaData()
- Table('add_ix', m1, Column('x', String(50)))
- t2 = Table('add_ix', m2, Column('x', String(50)))
- Index('foo_idx', t2.c.x.desc())
+ Table("add_ix", m1, Column("x", String(50)))
+ t2 = Table("add_ix", m2, Column("x", String(50)))
+ Index("foo_idx", t2.c.x.desc())
diffs = self._fixture(m1, m2)
eq_(diffs[0][0], "add_index")
@@ -606,10 +718,10 @@ class AutogenerateUniqueIndexTest(AutogenFixtureTest, TestBase):
def test_unchanged_idx_non_col(self):
m1 = MetaData()
m2 = MetaData()
- t1 = Table('add_ix', m1, Column('x', String(50)))
- Index('foo_idx', t1.c.x.desc())
- t2 = Table('add_ix', m2, Column('x', String(50)))
- Index('foo_idx', t2.c.x.desc())
+ t1 = Table("add_ix", m1, Column("x", String(50)))
+ Index("foo_idx", t1.c.x.desc())
+ t2 = Table("add_ix", m2, Column("x", String(50)))
+ Index("foo_idx", t2.c.x.desc())
diffs = self._fixture(m1, m2)
eq_(diffs, [])
@@ -622,8 +734,8 @@ class AutogenerateUniqueIndexTest(AutogenFixtureTest, TestBase):
def test_unchanged_case_sensitive_implicit_idx(self):
m1 = MetaData()
m2 = MetaData()
- Table('add_ix', m1, Column('regNumber', String(50), index=True))
- Table('add_ix', m2, Column('regNumber', String(50), index=True))
+ Table("add_ix", m1, Column("regNumber", String(50), index=True))
+ Table("add_ix", m2, Column("regNumber", String(50), index=True))
diffs = self._fixture(m1, m2)
eq_(diffs, [])
@@ -631,10 +743,10 @@ class AutogenerateUniqueIndexTest(AutogenFixtureTest, TestBase):
def test_unchanged_case_sensitive_explicit_idx(self):
m1 = MetaData()
m2 = MetaData()
- t1 = Table('add_ix', m1, Column('reg_number', String(50)))
- Index('regNumber_idx', t1.c.reg_number)
- t2 = Table('add_ix', m2, Column('reg_number', String(50)))
- Index('regNumber_idx', t2.c.reg_number)
+ t1 = Table("add_ix", m1, Column("reg_number", String(50)))
+ Index("regNumber_idx", t1.c.reg_number)
+ t2 = Table("add_ix", m2, Column("reg_number", String(50)))
+ Index("regNumber_idx", t2.c.reg_number)
diffs = self._fixture(m1, m2)
@@ -649,21 +761,36 @@ class PGUniqueIndexTest(AutogenerateUniqueIndexTest):
def test_idx_added_schema(self):
m1 = MetaData()
m2 = MetaData()
- Table('add_ix', m1, Column('x', String(50)), schema="test_schema")
- Table('add_ix', m2, Column('x', String(50)),
- Index('ix_1', 'x'), schema="test_schema")
+ Table("add_ix", m1, Column("x", String(50)), schema="test_schema")
+ Table(
+ "add_ix",
+ m2,
+ Column("x", String(50)),
+ Index("ix_1", "x"),
+ schema="test_schema",
+ )
diffs = self._fixture(m1, m2, include_schemas=True)
eq_(diffs[0][0], "add_index")
- eq_(diffs[0][1].name, 'ix_1')
+ eq_(diffs[0][1].name, "ix_1")
def test_idx_unchanged_schema(self):
m1 = MetaData()
m2 = MetaData()
- Table('add_ix', m1, Column('x', String(50)), Index('ix_1', 'x'),
- schema="test_schema")
- Table('add_ix', m2, Column('x', String(50)),
- Index('ix_1', 'x'), schema="test_schema")
+ Table(
+ "add_ix",
+ m1,
+ Column("x", String(50)),
+ Index("ix_1", "x"),
+ schema="test_schema",
+ )
+ Table(
+ "add_ix",
+ m2,
+ Column("x", String(50)),
+ Index("ix_1", "x"),
+ schema="test_schema",
+ )
diffs = self._fixture(m1, m2, include_schemas=True)
eq_(diffs, [])
@@ -671,23 +798,36 @@ class PGUniqueIndexTest(AutogenerateUniqueIndexTest):
def test_uq_added_schema(self):
m1 = MetaData()
m2 = MetaData()
- Table('add_uq', m1, Column('x', String(50)), schema="test_schema")
- Table('add_uq', m2, Column('x', String(50)),
- UniqueConstraint('x', name='ix_1'), schema="test_schema")
+ Table("add_uq", m1, Column("x", String(50)), schema="test_schema")
+ Table(
+ "add_uq",
+ m2,
+ Column("x", String(50)),
+ UniqueConstraint("x", name="ix_1"),
+ schema="test_schema",
+ )
diffs = self._fixture(m1, m2, include_schemas=True)
eq_(diffs[0][0], "add_constraint")
- eq_(diffs[0][1].name, 'ix_1')
+ eq_(diffs[0][1].name, "ix_1")
def test_uq_unchanged_schema(self):
m1 = MetaData()
m2 = MetaData()
- Table('add_uq', m1, Column('x', String(50)),
- UniqueConstraint('x', name='ix_1'),
- schema="test_schema")
- Table('add_uq', m2, Column('x', String(50)),
- UniqueConstraint('x', name='ix_1'),
- schema="test_schema")
+ Table(
+ "add_uq",
+ m1,
+ Column("x", String(50)),
+ UniqueConstraint("x", name="ix_1"),
+ schema="test_schema",
+ )
+ Table(
+ "add_uq",
+ m2,
+ Column("x", String(50)),
+ UniqueConstraint("x", name="ix_1"),
+ schema="test_schema",
+ )
diffs = self._fixture(m1, m2, include_schemas=True)
eq_(diffs, [])
@@ -701,17 +841,19 @@ class PGUniqueIndexTest(AutogenerateUniqueIndexTest):
m2 = MetaData()
Table(
- 'add_excl', m1,
- Column('id', Integer, primary_key=True),
- Column('period', TSRANGE),
- ExcludeConstraint(('period', '&&'), name='quarters_period_excl')
+ "add_excl",
+ m1,
+ Column("id", Integer, primary_key=True),
+ Column("period", TSRANGE),
+ ExcludeConstraint(("period", "&&"), name="quarters_period_excl"),
)
Table(
- 'add_excl', m2,
- Column('id', Integer, primary_key=True),
- Column('period', TSRANGE),
- ExcludeConstraint(('period', '&&'), name='quarters_period_excl')
+ "add_excl",
+ m2,
+ Column("id", Integer, primary_key=True),
+ Column("period", TSRANGE),
+ ExcludeConstraint(("period", "&&"), name="quarters_period_excl"),
)
diffs = self._fixture(m1, m2)
@@ -721,10 +863,10 @@ class PGUniqueIndexTest(AutogenerateUniqueIndexTest):
m1 = MetaData()
m2 = MetaData()
- Table('add_ix', m1, Column('x', String(50)), Index('ix_1', 'x'))
+ Table("add_ix", m1, Column("x", String(50)), Index("ix_1", "x"))
- Table('add_ix', m2, Column('x', String(50)), Index('ix_1', 'x'))
- Table('add_ix', m2, Column('x', String(50)), schema="test_schema")
+ Table("add_ix", m2, Column("x", String(50)), Index("ix_1", "x"))
+ Table("add_ix", m2, Column("x", String(50)), schema="test_schema")
diffs = self._fixture(m1, m2, include_schemas=True)
eq_(diffs[0][0], "add_table")
@@ -734,15 +876,17 @@ class PGUniqueIndexTest(AutogenerateUniqueIndexTest):
m1 = MetaData()
m2 = MetaData()
Table(
- 'add_uq', m1,
- Column('id', Integer, primary_key=True),
- Column('name', String),
- UniqueConstraint('name', name='uq_name')
+ "add_uq",
+ m1,
+ Column("id", Integer, primary_key=True),
+ Column("name", String),
+ UniqueConstraint("name", name="uq_name"),
)
Table(
- 'add_uq', m2,
- Column('id', Integer, primary_key=True),
- Column('name', String),
+ "add_uq",
+ m2,
+ Column("id", Integer, primary_key=True),
+ Column("name", String),
)
diffs = self._fixture(m1, m2, include_schemas=True)
eq_(diffs[0][0], "remove_constraint")
@@ -754,22 +898,24 @@ class PGUniqueIndexTest(AutogenerateUniqueIndexTest):
m2 = MetaData()
t1 = Table(
- 'foo', m1,
- Column('id', Integer, primary_key=True),
- Column('email', String(50))
+ "foo",
+ m1,
+ Column("id", Integer, primary_key=True),
+ Column("email", String(50)),
)
Index("email_idx", func.lower(t1.c.email), unique=True)
t2 = Table(
- 'foo', m2,
- Column('id', Integer, primary_key=True),
- Column('email', String(50))
+ "foo",
+ m2,
+ Column("id", Integer, primary_key=True),
+ Column("email", String(50)),
)
Index("email_idx", func.lower(t2.c.email), unique=True)
with assertions.expect_warnings(
- "Skipped unsupported reflection",
- "autogenerate skipping functional index"
+ "Skipped unsupported reflection",
+ "autogenerate skipping functional index",
):
diffs = self._fixture(m1, m2)
eq_(diffs, [])
@@ -779,28 +925,34 @@ class PGUniqueIndexTest(AutogenerateUniqueIndexTest):
m2 = MetaData()
t1 = Table(
- 'foo', m1,
- Column('id', Integer, primary_key=True),
- Column('email', String(50)),
- Column('name', String(50))
+ "foo",
+ m1,
+ Column("id", Integer, primary_key=True),
+ Column("email", String(50)),
+ Column("name", String(50)),
)
Index(
"email_idx",
- func.coalesce(t1.c.email, t1.c.name).desc(), unique=True)
+ func.coalesce(t1.c.email, t1.c.name).desc(),
+ unique=True,
+ )
t2 = Table(
- 'foo', m2,
- Column('id', Integer, primary_key=True),
- Column('email', String(50)),
- Column('name', String(50))
+ "foo",
+ m2,
+ Column("id", Integer, primary_key=True),
+ Column("email", String(50)),
+ Column("name", String(50)),
)
Index(
"email_idx",
- func.coalesce(t2.c.email, t2.c.name).desc(), unique=True)
+ func.coalesce(t2.c.email, t2.c.name).desc(),
+ unique=True,
+ )
with assertions.expect_warnings(
- "Skipped unsupported reflection",
- "autogenerate skipping functional index"
+ "Skipped unsupported reflection",
+ "autogenerate skipping functional index",
):
diffs = self._fixture(m1, m2)
eq_(diffs, [])
@@ -809,13 +961,14 @@ class PGUniqueIndexTest(AutogenerateUniqueIndexTest):
class MySQLUniqueIndexTest(AutogenerateUniqueIndexTest):
reports_unnamed_constraints = True
reports_unique_constraints_as_indexes = True
- __only_on__ = 'mysql'
+ __only_on__ = "mysql"
__backend__ = True
def test_removed_idx_index_named_as_column(self):
try:
- super(MySQLUniqueIndexTest,
- self).test_removed_idx_index_named_as_column()
+ super(
+ MySQLUniqueIndexTest, self
+ ).test_removed_idx_index_named_as_column()
except IndexError:
assert True
else:
@@ -828,61 +981,70 @@ class OracleUniqueIndexTest(AutogenerateUniqueIndexTest):
__only_on__ = "oracle"
__backend__ = True
+
class NoUqReflectionIndexTest(NoUqReflection, AutogenerateUniqueIndexTest):
reports_unique_constraints = False
- __only_on__ = 'sqlite'
+ __only_on__ = "sqlite"
def test_unique_not_reported(self):
m1 = MetaData()
- Table('order', m1,
- Column('order_id', Integer, primary_key=True),
- Column('amount', Numeric(10, 2), nullable=True),
- Column('user_id', Integer),
- UniqueConstraint('order_id', 'user_id',
- name='order_order_id_user_id_unique'
- )
- )
+ Table(
+ "order",
+ m1,
+ Column("order_id", Integer, primary_key=True),
+ Column("amount", Numeric(10, 2), nullable=True),
+ Column("user_id", Integer),
+ UniqueConstraint(
+ "order_id", "user_id", name="order_order_id_user_id_unique"
+ ),
+ )
diffs = self._fixture(m1, m1)
eq_(diffs, [])
def test_remove_unique_index_not_reported(self):
m1 = MetaData()
- Table('order', m1,
- Column('order_id', Integer, primary_key=True),
- Column('amount', Numeric(10, 2), nullable=True),
- Column('user_id', Integer),
- Index('oid_ix', 'order_id', 'user_id',
- unique=True
- )
- )
+ Table(
+ "order",
+ m1,
+ Column("order_id", Integer, primary_key=True),
+ Column("amount", Numeric(10, 2), nullable=True),
+ Column("user_id", Integer),
+ Index("oid_ix", "order_id", "user_id", unique=True),
+ )
m2 = MetaData()
- Table('order', m2,
- Column('order_id', Integer, primary_key=True),
- Column('amount', Numeric(10, 2), nullable=True),
- Column('user_id', Integer),
- )
+ Table(
+ "order",
+ m2,
+ Column("order_id", Integer, primary_key=True),
+ Column("amount", Numeric(10, 2), nullable=True),
+ Column("user_id", Integer),
+ )
diffs = self._fixture(m1, m2)
eq_(diffs, [])
def test_remove_plain_index_is_reported(self):
m1 = MetaData()
- Table('order', m1,
- Column('order_id', Integer, primary_key=True),
- Column('amount', Numeric(10, 2), nullable=True),
- Column('user_id', Integer),
- Index('oid_ix', 'order_id', 'user_id')
- )
+ Table(
+ "order",
+ m1,
+ Column("order_id", Integer, primary_key=True),
+ Column("amount", Numeric(10, 2), nullable=True),
+ Column("user_id", Integer),
+ Index("oid_ix", "order_id", "user_id"),
+ )
m2 = MetaData()
- Table('order', m2,
- Column('order_id', Integer, primary_key=True),
- Column('amount', Numeric(10, 2), nullable=True),
- Column('user_id', Integer),
- )
+ Table(
+ "order",
+ m2,
+ Column("order_id", Integer, primary_key=True),
+ Column("amount", Numeric(10, 2), nullable=True),
+ Column("user_id", Integer),
+ )
diffs = self._fixture(m1, m2)
- eq_(diffs[0][0], 'remove_index')
+ eq_(diffs[0][0], "remove_index")
class NoUqReportsIndAsUqTest(NoUqReflectionIndexTest):
@@ -899,7 +1061,7 @@ class NoUqReportsIndAsUqTest(NoUqReflectionIndexTest):
"""
- __only_on__ = 'sqlite'
+ __only_on__ = "sqlite"
@classmethod
def _get_bind(cls):
@@ -916,7 +1078,7 @@ class NoUqReportsIndAsUqTest(NoUqReflectionIndexTest):
for uq in _get_unique_constraints(
self, connection, tablename, **kw
):
- uq['unique'] = True
+ uq["unique"] = True
indexes.append(uq)
return indexes
@@ -932,23 +1094,26 @@ class IncludeHooksTest(AutogenFixtureTest, TestBase):
m1 = MetaData()
m2 = MetaData()
- t1 = Table('t', m1, Column('x', Integer), Column('y', Integer))
- Index('ix1', t1.c.x)
- Index('ix2', t1.c.y)
+ t1 = Table("t", m1, Column("x", Integer), Column("y", Integer))
+ Index("ix1", t1.c.x)
+ Index("ix2", t1.c.y)
- Table('t', m2, Column('x', Integer), Column('y', Integer))
+ Table("t", m2, Column("x", Integer), Column("y", Integer))
def include_object(object_, name, type_, reflected, compare_to):
- if type_ == 'unique_constraint':
+ if type_ == "unique_constraint":
return False
return not (
- isinstance(object_, Index) and
- type_ == 'index' and reflected and name == 'ix1')
+ isinstance(object_, Index)
+ and type_ == "index"
+ and reflected
+ and name == "ix1"
+ )
diffs = self._fixture(m1, m2, object_filters=include_object)
- eq_(diffs[0][0], 'remove_index')
- eq_(diffs[0][1].name, 'ix2')
+ eq_(diffs[0][0], "remove_index")
+ eq_(diffs[0][1].name, "ix2")
eq_(len(diffs), 1)
@config.requirements.unique_constraint_reflection
@@ -958,45 +1123,54 @@ class IncludeHooksTest(AutogenFixtureTest, TestBase):
m2 = MetaData()
Table(
- 't', m1, Column('x', Integer), Column('y', Integer),
- UniqueConstraint('x', name='uq1'),
- UniqueConstraint('y', name='uq2'),
+ "t",
+ m1,
+ Column("x", Integer),
+ Column("y", Integer),
+ UniqueConstraint("x", name="uq1"),
+ UniqueConstraint("y", name="uq2"),
)
- Table('t', m2, Column('x', Integer), Column('y', Integer))
+ Table("t", m2, Column("x", Integer), Column("y", Integer))
def include_object(object_, name, type_, reflected, compare_to):
- if type_ == 'index':
+ if type_ == "index":
return False
return not (
- isinstance(object_, UniqueConstraint) and
- type_ == 'unique_constraint' and reflected and name == 'uq1')
+ isinstance(object_, UniqueConstraint)
+ and type_ == "unique_constraint"
+ and reflected
+ and name == "uq1"
+ )
diffs = self._fixture(m1, m2, object_filters=include_object)
- eq_(diffs[0][0], 'remove_constraint')
- eq_(diffs[0][1].name, 'uq2')
+ eq_(diffs[0][0], "remove_constraint")
+ eq_(diffs[0][1].name, "uq2")
eq_(len(diffs), 1)
def test_add_metadata_index(self):
m1 = MetaData()
m2 = MetaData()
- Table('t', m1, Column('x', Integer))
+ Table("t", m1, Column("x", Integer))
- t2 = Table('t', m2, Column('x', Integer))
- Index('ix1', t2.c.x)
- Index('ix2', t2.c.x)
+ t2 = Table("t", m2, Column("x", Integer))
+ Index("ix1", t2.c.x)
+ Index("ix2", t2.c.x)
def include_object(object_, name, type_, reflected, compare_to):
return not (
- isinstance(object_, Index) and
- type_ == 'index' and not reflected and name == 'ix1')
+ isinstance(object_, Index)
+ and type_ == "index"
+ and not reflected
+ and name == "ix1"
+ )
diffs = self._fixture(m1, m2, object_filters=include_object)
- eq_(diffs[0][0], 'add_index')
- eq_(diffs[0][1].name, 'ix2')
+ eq_(diffs[0][0], "add_index")
+ eq_(diffs[0][1].name, "ix2")
eq_(len(diffs), 1)
@config.requirements.unique_constraint_reflection
@@ -1004,24 +1178,28 @@ class IncludeHooksTest(AutogenFixtureTest, TestBase):
m1 = MetaData()
m2 = MetaData()
- Table('t', m1, Column('x', Integer))
+ Table("t", m1, Column("x", Integer))
Table(
- 't', m2, Column('x', Integer),
- UniqueConstraint('x', name='uq1'),
- UniqueConstraint('x', name='uq2')
+ "t",
+ m2,
+ Column("x", Integer),
+ UniqueConstraint("x", name="uq1"),
+ UniqueConstraint("x", name="uq2"),
)
def include_object(object_, name, type_, reflected, compare_to):
return not (
- isinstance(object_, UniqueConstraint) and
- type_ == 'unique_constraint' and
- not reflected and name == 'uq1')
+ isinstance(object_, UniqueConstraint)
+ and type_ == "unique_constraint"
+ and not reflected
+ and name == "uq1"
+ )
diffs = self._fixture(m1, m2, object_filters=include_object)
- eq_(diffs[0][0], 'add_constraint')
- eq_(diffs[0][1].name, 'uq2')
+ eq_(diffs[0][0], "add_constraint")
+ eq_(diffs[0][1].name, "uq2")
eq_(len(diffs), 1)
def test_change_index(self):
@@ -1029,29 +1207,40 @@ class IncludeHooksTest(AutogenFixtureTest, TestBase):
m2 = MetaData()
t1 = Table(
- 't', m1, Column('x', Integer),
- Column('y', Integer), Column('z', Integer))
- Index('ix1', t1.c.x)
- Index('ix2', t1.c.y)
+ "t",
+ m1,
+ Column("x", Integer),
+ Column("y", Integer),
+ Column("z", Integer),
+ )
+ Index("ix1", t1.c.x)
+ Index("ix2", t1.c.y)
t2 = Table(
- 't', m2, Column('x', Integer),
- Column('y', Integer), Column('z', Integer))
- Index('ix1', t2.c.x, t2.c.y)
- Index('ix2', t2.c.x, t2.c.z)
+ "t",
+ m2,
+ Column("x", Integer),
+ Column("y", Integer),
+ Column("z", Integer),
+ )
+ Index("ix1", t2.c.x, t2.c.y)
+ Index("ix2", t2.c.x, t2.c.z)
def include_object(object_, name, type_, reflected, compare_to):
return not (
- isinstance(object_, Index) and
- type_ == 'index' and not reflected and name == 'ix1'
- and isinstance(compare_to, Index))
+ isinstance(object_, Index)
+ and type_ == "index"
+ and not reflected
+ and name == "ix1"
+ and isinstance(compare_to, Index)
+ )
diffs = self._fixture(m1, m2, object_filters=include_object)
- eq_(diffs[0][0], 'remove_index')
- eq_(diffs[0][1].name, 'ix2')
- eq_(diffs[1][0], 'add_index')
- eq_(diffs[1][1].name, 'ix2')
+ eq_(diffs[0][0], "remove_index")
+ eq_(diffs[0][1].name, "ix2")
+ eq_(diffs[1][0], "add_index")
+ eq_(diffs[1][1].name, "ix2")
eq_(len(diffs), 2)
@config.requirements.unique_constraint_reflection
@@ -1060,39 +1249,46 @@ class IncludeHooksTest(AutogenFixtureTest, TestBase):
m2 = MetaData()
Table(
- 't', m1, Column('x', Integer),
- Column('y', Integer), Column('z', Integer),
- UniqueConstraint('x', name='uq1'),
- UniqueConstraint('y', name='uq2')
+ "t",
+ m1,
+ Column("x", Integer),
+ Column("y", Integer),
+ Column("z", Integer),
+ UniqueConstraint("x", name="uq1"),
+ UniqueConstraint("y", name="uq2"),
)
Table(
- 't', m2, Column('x', Integer), Column('y', Integer),
- Column('z', Integer),
- UniqueConstraint('x', 'z', name='uq1'),
- UniqueConstraint('y', 'z', name='uq2')
+ "t",
+ m2,
+ Column("x", Integer),
+ Column("y", Integer),
+ Column("z", Integer),
+ UniqueConstraint("x", "z", name="uq1"),
+ UniqueConstraint("y", "z", name="uq2"),
)
def include_object(object_, name, type_, reflected, compare_to):
- if type_ == 'index':
+ if type_ == "index":
return False
return not (
- isinstance(object_, UniqueConstraint) and
- type_ == 'unique_constraint' and
- not reflected and name == 'uq1'
- and isinstance(compare_to, UniqueConstraint))
+ isinstance(object_, UniqueConstraint)
+ and type_ == "unique_constraint"
+ and not reflected
+ and name == "uq1"
+ and isinstance(compare_to, UniqueConstraint)
+ )
diffs = self._fixture(m1, m2, object_filters=include_object)
- eq_(diffs[0][0], 'remove_constraint')
- eq_(diffs[0][1].name, 'uq2')
- eq_(diffs[1][0], 'add_constraint')
- eq_(diffs[1][1].name, 'uq2')
+ eq_(diffs[0][0], "remove_constraint")
+ eq_(diffs[0][1].name, "uq2")
+ eq_(diffs[1][0], "add_constraint")
+ eq_(diffs[1][1].name, "uq2")
eq_(len(diffs), 2)
class TruncatedIdxTest(AutogenFixtureTest, TestBase):
-
def setUp(self):
self.bind = engines.testing_engine()
self.bind.dialect.max_identifier_length = 30
@@ -1102,12 +1298,13 @@ class TruncatedIdxTest(AutogenFixtureTest, TestBase):
m1 = MetaData()
Table(
- 'q', m1,
- Column('id', Integer, primary_key=True),
- Column('data', Integer),
+ "q",
+ m1,
+ Column("id", Integer, primary_key=True),
+ Column("data", Integer),
Index(
- conv("idx_q_table_this_is_more_than_thirty_characters"),
- "data")
+ conv("idx_q_table_this_is_more_than_thirty_characters"), "data"
+ ),
)
diffs = self._fixture(m1, m1)
diff --git a/tests/test_autogen_render.py b/tests/test_autogen_render.py
index b32358f..37e1c61 100644
--- a/tests/test_autogen_render.py
+++ b/tests/test_autogen_render.py
@@ -4,11 +4,31 @@ from alembic.testing import TestBase, exclusions, assert_raises
from alembic.testing import assertions
from alembic.operations import ops
-from sqlalchemy import MetaData, Column, Table, String, \
- Numeric, CHAR, ForeignKey, DATETIME, Integer, BigInteger, \
- CheckConstraint, Unicode, Enum, cast,\
- DateTime, UniqueConstraint, Boolean, ForeignKeyConstraint,\
- PrimaryKeyConstraint, Index, func, text, DefaultClause
+from sqlalchemy import (
+ MetaData,
+ Column,
+ Table,
+ String,
+ Numeric,
+ CHAR,
+ ForeignKey,
+ DATETIME,
+ Integer,
+ BigInteger,
+ CheckConstraint,
+ Unicode,
+ Enum,
+ cast,
+ DateTime,
+ UniqueConstraint,
+ Boolean,
+ ForeignKeyConstraint,
+ PrimaryKeyConstraint,
+ Index,
+ func,
+ text,
+ DefaultClause,
+)
from sqlalchemy.types import TIMESTAMP
from sqlalchemy import types
@@ -28,7 +48,7 @@ from alembic.testing.fixtures import op_fixture
from alembic import op # noqa
import sqlalchemy as sa # noqa
-py3k = sys.version_info >= (3, )
+py3k = sys.version_info >= (3,)
class AutogenRenderTest(TestBase):
@@ -37,13 +57,12 @@ class AutogenRenderTest(TestBase):
def setUp(self):
ctx_opts = {
- 'sqlalchemy_module_prefix': 'sa.',
- 'alembic_module_prefix': 'op.',
- 'target_metadata': MetaData()
+ "sqlalchemy_module_prefix": "sa.",
+ "alembic_module_prefix": "op.",
+ "target_metadata": MetaData(),
}
context = MigrationContext.configure(
- dialect=DefaultDialect(),
- opts=ctx_opts
+ dialect=DefaultDialect(), opts=ctx_opts
)
self.autogen_context = api.AutogenContext(context)
@@ -53,17 +72,19 @@ class AutogenRenderTest(TestBase):
autogenerate.render._add_index
"""
m = MetaData()
- t = Table('test', m,
- Column('id', Integer, primary_key=True),
- Column('active', Boolean()),
- Column('code', String(255)),
- )
- idx = Index('test_active_code_idx', t.c.active, t.c.code)
+ t = Table(
+ "test",
+ m,
+ Column("id", Integer, primary_key=True),
+ Column("active", Boolean()),
+ Column("code", String(255)),
+ )
+ idx = Index("test_active_code_idx", t.c.active, t.c.code)
op_obj = ops.CreateIndexOp.from_index(idx)
eq_ignore_whitespace(
autogenerate.render_op_text(self.autogen_context, op_obj),
"op.create_index('test_active_code_idx', 'test', "
- "['active', 'code'], unique=False)"
+ "['active', 'code'], unique=False)",
)
def test_render_add_index_batch(self):
@@ -71,18 +92,20 @@ class AutogenRenderTest(TestBase):
autogenerate.render._add_index
"""
m = MetaData()
- t = Table('test', m,
- Column('id', Integer, primary_key=True),
- Column('active', Boolean()),
- Column('code', String(255)),
- )
- idx = Index('test_active_code_idx', t.c.active, t.c.code)
+ t = Table(
+ "test",
+ m,
+ Column("id", Integer, primary_key=True),
+ Column("active", Boolean()),
+ Column("code", String(255)),
+ )
+ idx = Index("test_active_code_idx", t.c.active, t.c.code)
op_obj = ops.CreateIndexOp.from_index(idx)
with self.autogen_context._within_batch():
eq_ignore_whitespace(
autogenerate.render_op_text(self.autogen_context, op_obj),
"batch_op.create_index('test_active_code_idx', "
- "['active', 'code'], unique=False)"
+ "['active', 'code'], unique=False)",
)
def test_render_add_index_schema(self):
@@ -90,18 +113,20 @@ class AutogenRenderTest(TestBase):
autogenerate.render._add_index using schema
"""
m = MetaData()
- t = Table('test', m,
- Column('id', Integer, primary_key=True),
- Column('active', Boolean()),
- Column('code', String(255)),
- schema='CamelSchema'
- )
- idx = Index('test_active_code_idx', t.c.active, t.c.code)
+ t = Table(
+ "test",
+ m,
+ Column("id", Integer, primary_key=True),
+ Column("active", Boolean()),
+ Column("code", String(255)),
+ schema="CamelSchema",
+ )
+ idx = Index("test_active_code_idx", t.c.active, t.c.code)
op_obj = ops.CreateIndexOp.from_index(idx)
eq_ignore_whitespace(
autogenerate.render_op_text(self.autogen_context, op_obj),
"op.create_index('test_active_code_idx', 'test', "
- "['active', 'code'], unique=False, schema='CamelSchema')"
+ "['active', 'code'], unique=False, schema='CamelSchema')",
)
def test_render_add_index_schema_batch(self):
@@ -109,73 +134,78 @@ class AutogenRenderTest(TestBase):
autogenerate.render._add_index using schema
"""
m = MetaData()
- t = Table('test', m,
- Column('id', Integer, primary_key=True),
- Column('active', Boolean()),
- Column('code', String(255)),
- schema='CamelSchema'
- )
- idx = Index('test_active_code_idx', t.c.active, t.c.code)
+ t = Table(
+ "test",
+ m,
+ Column("id", Integer, primary_key=True),
+ Column("active", Boolean()),
+ Column("code", String(255)),
+ schema="CamelSchema",
+ )
+ idx = Index("test_active_code_idx", t.c.active, t.c.code)
op_obj = ops.CreateIndexOp.from_index(idx)
with self.autogen_context._within_batch():
eq_ignore_whitespace(
autogenerate.render_op_text(self.autogen_context, op_obj),
"batch_op.create_index('test_active_code_idx', "
- "['active', 'code'], unique=False)"
+ "['active', 'code'], unique=False)",
)
def test_render_add_index_func(self):
m = MetaData()
t = Table(
- 'test', m,
- Column('id', Integer, primary_key=True),
- Column('code', String(255))
+ "test",
+ m,
+ Column("id", Integer, primary_key=True),
+ Column("code", String(255)),
)
- idx = Index('test_lower_code_idx', func.lower(t.c.code))
+ idx = Index("test_lower_code_idx", func.lower(t.c.code))
op_obj = ops.CreateIndexOp.from_index(idx)
eq_ignore_whitespace(
autogenerate.render_op_text(self.autogen_context, op_obj),
"op.create_index('test_lower_code_idx', 'test', "
- "[sa.text(!U'lower(code)')], unique=False)"
+ "[sa.text(!U'lower(code)')], unique=False)",
)
def test_render_add_index_cast(self):
m = MetaData()
t = Table(
- 'test', m,
- Column('id', Integer, primary_key=True),
- Column('code', String(255))
+ "test",
+ m,
+ Column("id", Integer, primary_key=True),
+ Column("code", String(255)),
)
- idx = Index('test_lower_code_idx', cast(t.c.code, String))
+ idx = Index("test_lower_code_idx", cast(t.c.code, String))
op_obj = ops.CreateIndexOp.from_index(idx)
if config.requirements.sqlalchemy_110.enabled:
eq_ignore_whitespace(
autogenerate.render_op_text(self.autogen_context, op_obj),
"op.create_index('test_lower_code_idx', 'test', "
- "[sa.text(!U'CAST(code AS VARCHAR)')], unique=False)"
+ "[sa.text(!U'CAST(code AS VARCHAR)')], unique=False)",
)
else:
eq_ignore_whitespace(
autogenerate.render_op_text(self.autogen_context, op_obj),
"op.create_index('test_lower_code_idx', 'test', "
- "[sa.text(!U'CAST(code AS VARCHAR)')], unique=False)"
+ "[sa.text(!U'CAST(code AS VARCHAR)')], unique=False)",
)
def test_render_add_index_desc(self):
m = MetaData()
t = Table(
- 'test', m,
- Column('id', Integer, primary_key=True),
- Column('code', String(255))
+ "test",
+ m,
+ Column("id", Integer, primary_key=True),
+ Column("code", String(255)),
)
- idx = Index('test_desc_code_idx', t.c.code.desc())
+ idx = Index("test_desc_code_idx", t.c.code.desc())
op_obj = ops.CreateIndexOp.from_index(idx)
eq_ignore_whitespace(
autogenerate.render_op_text(self.autogen_context, op_obj),
"op.create_index('test_desc_code_idx', 'test', "
- "[sa.text(!U'code DESC')], unique=False)"
+ "[sa.text(!U'code DESC')], unique=False)",
)
def test_drop_index(self):
@@ -183,16 +213,18 @@ class AutogenRenderTest(TestBase):
autogenerate.render._drop_index
"""
m = MetaData()
- t = Table('test', m,
- Column('id', Integer, primary_key=True),
- Column('active', Boolean()),
- Column('code', String(255)),
- )
- idx = Index('test_active_code_idx', t.c.active, t.c.code)
+ t = Table(
+ "test",
+ m,
+ Column("id", Integer, primary_key=True),
+ Column("active", Boolean()),
+ Column("code", String(255)),
+ )
+ idx = Index("test_active_code_idx", t.c.active, t.c.code)
op_obj = ops.DropIndexOp.from_index(idx)
eq_ignore_whitespace(
autogenerate.render_op_text(self.autogen_context, op_obj),
- "op.drop_index('test_active_code_idx', table_name='test')"
+ "op.drop_index('test_active_code_idx', table_name='test')",
)
def test_drop_index_batch(self):
@@ -200,17 +232,19 @@ class AutogenRenderTest(TestBase):
autogenerate.render._drop_index
"""
m = MetaData()
- t = Table('test', m,
- Column('id', Integer, primary_key=True),
- Column('active', Boolean()),
- Column('code', String(255)),
- )
- idx = Index('test_active_code_idx', t.c.active, t.c.code)
+ t = Table(
+ "test",
+ m,
+ Column("id", Integer, primary_key=True),
+ Column("active", Boolean()),
+ Column("code", String(255)),
+ )
+ idx = Index("test_active_code_idx", t.c.active, t.c.code)
op_obj = ops.DropIndexOp.from_index(idx)
with self.autogen_context._within_batch():
eq_ignore_whitespace(
autogenerate.render_op_text(self.autogen_context, op_obj),
- "batch_op.drop_index('test_active_code_idx')"
+ "batch_op.drop_index('test_active_code_idx')",
)
def test_drop_index_schema(self):
@@ -218,18 +252,20 @@ class AutogenRenderTest(TestBase):
autogenerate.render._drop_index using schema
"""
m = MetaData()
- t = Table('test', m,
- Column('id', Integer, primary_key=True),
- Column('active', Boolean()),
- Column('code', String(255)),
- schema='CamelSchema'
- )
- idx = Index('test_active_code_idx', t.c.active, t.c.code)
+ t = Table(
+ "test",
+ m,
+ Column("id", Integer, primary_key=True),
+ Column("active", Boolean()),
+ Column("code", String(255)),
+ schema="CamelSchema",
+ )
+ idx = Index("test_active_code_idx", t.c.active, t.c.code)
op_obj = ops.DropIndexOp.from_index(idx)
eq_ignore_whitespace(
autogenerate.render_op_text(self.autogen_context, op_obj),
- "op.drop_index('test_active_code_idx', " +
- "table_name='test', schema='CamelSchema')"
+ "op.drop_index('test_active_code_idx', "
+ + "table_name='test', schema='CamelSchema')",
)
def test_drop_index_schema_batch(self):
@@ -237,18 +273,20 @@ class AutogenRenderTest(TestBase):
autogenerate.render._drop_index using schema
"""
m = MetaData()
- t = Table('test', m,
- Column('id', Integer, primary_key=True),
- Column('active', Boolean()),
- Column('code', String(255)),
- schema='CamelSchema'
- )
- idx = Index('test_active_code_idx', t.c.active, t.c.code)
+ t = Table(
+ "test",
+ m,
+ Column("id", Integer, primary_key=True),
+ Column("active", Boolean()),
+ Column("code", String(255)),
+ schema="CamelSchema",
+ )
+ idx = Index("test_active_code_idx", t.c.active, t.c.code)
op_obj = ops.DropIndexOp.from_index(idx)
with self.autogen_context._within_batch():
eq_ignore_whitespace(
autogenerate.render_op_text(self.autogen_context, op_obj),
- "batch_op.drop_index('test_active_code_idx')"
+ "batch_op.drop_index('test_active_code_idx')",
)
def test_add_unique_constraint(self):
@@ -256,16 +294,18 @@ class AutogenRenderTest(TestBase):
autogenerate.render._add_unique_constraint
"""
m = MetaData()
- t = Table('test', m,
- Column('id', Integer, primary_key=True),
- Column('active', Boolean()),
- Column('code', String(255)),
- )
- uq = UniqueConstraint(t.c.code, name='uq_test_code')
+ t = Table(
+ "test",
+ m,
+ Column("id", Integer, primary_key=True),
+ Column("active", Boolean()),
+ Column("code", String(255)),
+ )
+ uq = UniqueConstraint(t.c.code, name="uq_test_code")
op_obj = ops.AddConstraintOp.from_constraint(uq)
eq_ignore_whitespace(
autogenerate.render_op_text(self.autogen_context, op_obj),
- "op.create_unique_constraint('uq_test_code', 'test', ['code'])"
+ "op.create_unique_constraint('uq_test_code', 'test', ['code'])",
)
def test_add_unique_constraint_batch(self):
@@ -273,17 +313,19 @@ class AutogenRenderTest(TestBase):
autogenerate.render._add_unique_constraint
"""
m = MetaData()
- t = Table('test', m,
- Column('id', Integer, primary_key=True),
- Column('active', Boolean()),
- Column('code', String(255)),
- )
- uq = UniqueConstraint(t.c.code, name='uq_test_code')
+ t = Table(
+ "test",
+ m,
+ Column("id", Integer, primary_key=True),
+ Column("active", Boolean()),
+ Column("code", String(255)),
+ )
+ uq = UniqueConstraint(t.c.code, name="uq_test_code")
op_obj = ops.AddConstraintOp.from_constraint(uq)
with self.autogen_context._within_batch():
eq_ignore_whitespace(
autogenerate.render_op_text(self.autogen_context, op_obj),
- "batch_op.create_unique_constraint('uq_test_code', ['code'])"
+ "batch_op.create_unique_constraint('uq_test_code', ['code'])",
)
def test_add_unique_constraint_schema(self):
@@ -291,18 +333,20 @@ class AutogenRenderTest(TestBase):
autogenerate.render._add_unique_constraint using schema
"""
m = MetaData()
- t = Table('test', m,
- Column('id', Integer, primary_key=True),
- Column('active', Boolean()),
- Column('code', String(255)),
- schema='CamelSchema'
- )
- uq = UniqueConstraint(t.c.code, name='uq_test_code')
+ t = Table(
+ "test",
+ m,
+ Column("id", Integer, primary_key=True),
+ Column("active", Boolean()),
+ Column("code", String(255)),
+ schema="CamelSchema",
+ )
+ uq = UniqueConstraint(t.c.code, name="uq_test_code")
op_obj = ops.AddConstraintOp.from_constraint(uq)
eq_ignore_whitespace(
autogenerate.render_op_text(self.autogen_context, op_obj),
"op.create_unique_constraint('uq_test_code', 'test', "
- "['code'], schema='CamelSchema')"
+ "['code'], schema='CamelSchema')",
)
def test_add_unique_constraint_schema_batch(self):
@@ -310,19 +354,21 @@ class AutogenRenderTest(TestBase):
autogenerate.render._add_unique_constraint using schema
"""
m = MetaData()
- t = Table('test', m,
- Column('id', Integer, primary_key=True),
- Column('active', Boolean()),
- Column('code', String(255)),
- schema='CamelSchema'
- )
- uq = UniqueConstraint(t.c.code, name='uq_test_code')
+ t = Table(
+ "test",
+ m,
+ Column("id", Integer, primary_key=True),
+ Column("active", Boolean()),
+ Column("code", String(255)),
+ schema="CamelSchema",
+ )
+ uq = UniqueConstraint(t.c.code, name="uq_test_code")
op_obj = ops.AddConstraintOp.from_constraint(uq)
with self.autogen_context._within_batch():
eq_ignore_whitespace(
autogenerate.render_op_text(self.autogen_context, op_obj),
"batch_op.create_unique_constraint('uq_test_code', "
- "['code'])"
+ "['code'])",
)
def test_drop_unique_constraint(self):
@@ -330,16 +376,18 @@ class AutogenRenderTest(TestBase):
autogenerate.render._drop_constraint
"""
m = MetaData()
- t = Table('test', m,
- Column('id', Integer, primary_key=True),
- Column('active', Boolean()),
- Column('code', String(255)),
- )
- uq = UniqueConstraint(t.c.code, name='uq_test_code')
+ t = Table(
+ "test",
+ m,
+ Column("id", Integer, primary_key=True),
+ Column("active", Boolean()),
+ Column("code", String(255)),
+ )
+ uq = UniqueConstraint(t.c.code, name="uq_test_code")
op_obj = ops.DropConstraintOp.from_constraint(uq)
eq_ignore_whitespace(
autogenerate.render_op_text(self.autogen_context, op_obj),
- "op.drop_constraint('uq_test_code', 'test', type_='unique')"
+ "op.drop_constraint('uq_test_code', 'test', type_='unique')",
)
def test_drop_unique_constraint_schema(self):
@@ -347,67 +395,72 @@ class AutogenRenderTest(TestBase):
autogenerate.render._drop_constraint using schema
"""
m = MetaData()
- t = Table('test', m,
- Column('id', Integer, primary_key=True),
- Column('active', Boolean()),
- Column('code', String(255)),
- schema='CamelSchema'
- )
- uq = UniqueConstraint(t.c.code, name='uq_test_code')
+ t = Table(
+ "test",
+ m,
+ Column("id", Integer, primary_key=True),
+ Column("active", Boolean()),
+ Column("code", String(255)),
+ schema="CamelSchema",
+ )
+ uq = UniqueConstraint(t.c.code, name="uq_test_code")
op_obj = ops.DropConstraintOp.from_constraint(uq)
eq_ignore_whitespace(
autogenerate.render_op_text(self.autogen_context, op_obj),
"op.drop_constraint('uq_test_code', 'test', "
- "schema='CamelSchema', type_='unique')"
+ "schema='CamelSchema', type_='unique')",
)
def test_drop_unique_constraint_schema_reprobj(self):
"""
autogenerate.render._drop_constraint using schema
"""
+
class SomeObj(str):
def __repr__(self):
return "foo.camel_schema"
op_obj = ops.DropConstraintOp(
- "uq_test_code", "test", type_="unique",
- schema=SomeObj("CamelSchema")
+ "uq_test_code",
+ "test",
+ type_="unique",
+ schema=SomeObj("CamelSchema"),
)
eq_ignore_whitespace(
autogenerate.render_op_text(self.autogen_context, op_obj),
"op.drop_constraint('uq_test_code', 'test', "
- "schema=foo.camel_schema, type_='unique')"
+ "schema=foo.camel_schema, type_='unique')",
)
def test_add_fk_constraint(self):
m = MetaData()
- Table('a', m, Column('id', Integer, primary_key=True))
- b = Table('b', m, Column('a_id', Integer, ForeignKey('a.id')))
- fk = ForeignKeyConstraint(['a_id'], ['a.id'], name='fk_a_id')
+ Table("a", m, Column("id", Integer, primary_key=True))
+ b = Table("b", m, Column("a_id", Integer, ForeignKey("a.id")))
+ fk = ForeignKeyConstraint(["a_id"], ["a.id"], name="fk_a_id")
b.append_constraint(fk)
op_obj = ops.AddConstraintOp.from_constraint(fk)
eq_ignore_whitespace(
autogenerate.render_op_text(self.autogen_context, op_obj),
- "op.create_foreign_key('fk_a_id', 'b', 'a', ['a_id'], ['id'])"
+ "op.create_foreign_key('fk_a_id', 'b', 'a', ['a_id'], ['id'])",
)
def test_add_fk_constraint_batch(self):
m = MetaData()
- Table('a', m, Column('id', Integer, primary_key=True))
- b = Table('b', m, Column('a_id', Integer, ForeignKey('a.id')))
- fk = ForeignKeyConstraint(['a_id'], ['a.id'], name='fk_a_id')
+ Table("a", m, Column("id", Integer, primary_key=True))
+ b = Table("b", m, Column("a_id", Integer, ForeignKey("a.id")))
+ fk = ForeignKeyConstraint(["a_id"], ["a.id"], name="fk_a_id")
b.append_constraint(fk)
op_obj = ops.AddConstraintOp.from_constraint(fk)
with self.autogen_context._within_batch():
eq_ignore_whitespace(
autogenerate.render_op_text(self.autogen_context, op_obj),
- "batch_op.create_foreign_key('fk_a_id', 'a', ['a_id'], ['id'])"
+ "batch_op.create_foreign_key('fk_a_id', 'a', ['a_id'], ['id'])",
)
def test_add_fk_constraint_kwarg(self):
m = MetaData()
- t1 = Table('t', m, Column('c', Integer))
- t2 = Table('t2', m, Column('c_rem', Integer))
+ t1 = Table("t", m, Column("c", Integer))
+ t2 = Table("t2", m, Column("c_rem", Integer))
fk = ForeignKeyConstraint([t1.c.c], [t2.c.c_rem], onupdate="CASCADE")
@@ -417,11 +470,12 @@ class AutogenRenderTest(TestBase):
op_obj = ops.AddConstraintOp.from_constraint(fk)
eq_ignore_whitespace(
re.sub(
- r"u'", "'",
+ r"u'",
+ "'",
autogenerate.render_op_text(self.autogen_context, op_obj),
),
"op.create_foreign_key(None, 't', 't2', ['c'], ['c_rem'], "
- "onupdate='CASCADE')"
+ "onupdate='CASCADE')",
)
fk = ForeignKeyConstraint([t1.c.c], [t2.c.c_rem], ondelete="CASCADE")
@@ -429,54 +483,62 @@ class AutogenRenderTest(TestBase):
op_obj = ops.AddConstraintOp.from_constraint(fk)
eq_ignore_whitespace(
re.sub(
- r"u'", "'",
- autogenerate.render_op_text(self.autogen_context, op_obj)
+ r"u'",
+ "'",
+ autogenerate.render_op_text(self.autogen_context, op_obj),
),
"op.create_foreign_key(None, 't', 't2', ['c'], ['c_rem'], "
- "ondelete='CASCADE')"
+ "ondelete='CASCADE')",
)
fk = ForeignKeyConstraint([t1.c.c], [t2.c.c_rem], deferrable=True)
op_obj = ops.AddConstraintOp.from_constraint(fk)
eq_ignore_whitespace(
re.sub(
- r"u'", "'",
- autogenerate.render_op_text(self.autogen_context, op_obj)
+ r"u'",
+ "'",
+ autogenerate.render_op_text(self.autogen_context, op_obj),
),
"op.create_foreign_key(None, 't', 't2', ['c'], ['c_rem'], "
- "deferrable=True)"
+ "deferrable=True)",
)
fk = ForeignKeyConstraint([t1.c.c], [t2.c.c_rem], initially="XYZ")
op_obj = ops.AddConstraintOp.from_constraint(fk)
eq_ignore_whitespace(
re.sub(
- r"u'", "'",
+ r"u'",
+ "'",
autogenerate.render_op_text(self.autogen_context, op_obj),
),
"op.create_foreign_key(None, 't', 't2', ['c'], ['c_rem'], "
- "initially='XYZ')"
+ "initially='XYZ')",
)
fk = ForeignKeyConstraint(
- [t1.c.c], [t2.c.c_rem],
- initially="XYZ", ondelete="CASCADE", deferrable=True)
+ [t1.c.c],
+ [t2.c.c_rem],
+ initially="XYZ",
+ ondelete="CASCADE",
+ deferrable=True,
+ )
op_obj = ops.AddConstraintOp.from_constraint(fk)
eq_ignore_whitespace(
re.sub(
- r"u'", "'",
- autogenerate.render_op_text(self.autogen_context, op_obj)
+ r"u'",
+ "'",
+ autogenerate.render_op_text(self.autogen_context, op_obj),
),
"op.create_foreign_key(None, 't', 't2', ['c'], ['c_rem'], "
- "ondelete='CASCADE', initially='XYZ', deferrable=True)"
+ "ondelete='CASCADE', initially='XYZ', deferrable=True)",
)
def test_add_fk_constraint_inline_colkeys(self):
m = MetaData()
- Table('a', m, Column('id', Integer, key='aid', primary_key=True))
+ Table("a", m, Column("id", Integer, key="aid", primary_key=True))
b = Table(
- 'b', m,
- Column('a_id', Integer, ForeignKey('a.aid'), key='baid'))
+ "b", m, Column("a_id", Integer, ForeignKey("a.aid"), key="baid")
+ )
op_obj = ops.CreateTableOp.from_table(b)
py_code = autogenerate.render_op_text(self.autogen_context, op_obj)
@@ -485,20 +547,21 @@ class AutogenRenderTest(TestBase):
py_code,
"op.create_table('b',"
"sa.Column('a_id', sa.Integer(), nullable=True),"
- "sa.ForeignKeyConstraint(['a_id'], ['a.id'], ))"
+ "sa.ForeignKeyConstraint(['a_id'], ['a.id'], ))",
)
context = op_fixture()
eval(py_code)
context.assert_(
"CREATE TABLE b (a_id INTEGER, "
- "FOREIGN KEY(a_id) REFERENCES a (id))")
+ "FOREIGN KEY(a_id) REFERENCES a (id))"
+ )
def test_add_fk_constraint_separate_colkeys(self):
m = MetaData()
- Table('a', m, Column('id', Integer, key='aid', primary_key=True))
- b = Table('b', m, Column('a_id', Integer, key='baid'))
- fk = ForeignKeyConstraint(['baid'], ['a.aid'], name='fk_a_id')
+ Table("a", m, Column("id", Integer, key="aid", primary_key=True))
+ b = Table("b", m, Column("a_id", Integer, key="baid"))
+ fk = ForeignKeyConstraint(["baid"], ["a.aid"], name="fk_a_id")
b.append_constraint(fk)
op_obj = ops.CreateTableOp.from_table(b)
@@ -508,14 +571,15 @@ class AutogenRenderTest(TestBase):
py_code,
"op.create_table('b',"
"sa.Column('a_id', sa.Integer(), nullable=True),"
- "sa.ForeignKeyConstraint(['a_id'], ['a.id'], name='fk_a_id'))"
+ "sa.ForeignKeyConstraint(['a_id'], ['a.id'], name='fk_a_id'))",
)
context = op_fixture()
eval(py_code)
context.assert_(
"CREATE TABLE b (a_id INTEGER, CONSTRAINT "
- "fk_a_id FOREIGN KEY(a_id) REFERENCES a (id))")
+ "fk_a_id FOREIGN KEY(a_id) REFERENCES a (id))"
+ )
context = op_fixture()
@@ -523,7 +587,7 @@ class AutogenRenderTest(TestBase):
eq_ignore_whitespace(
autogenerate.render_op_text(self.autogen_context, op_obj),
- "op.create_foreign_key('fk_a_id', 'b', 'a', ['a_id'], ['id'])"
+ "op.create_foreign_key('fk_a_id', 'b', 'a', ['a_id'], ['id'])",
)
py_code = autogenerate.render_op_text(self.autogen_context, op_obj)
@@ -531,124 +595,151 @@ class AutogenRenderTest(TestBase):
eval(py_code)
context.assert_(
"ALTER TABLE b ADD CONSTRAINT fk_a_id "
- "FOREIGN KEY(a_id) REFERENCES a (id)")
+ "FOREIGN KEY(a_id) REFERENCES a (id)"
+ )
def test_add_fk_constraint_schema(self):
m = MetaData()
Table(
- 'a', m, Column('id', Integer, primary_key=True),
- schema="CamelSchemaTwo")
+ "a",
+ m,
+ Column("id", Integer, primary_key=True),
+ schema="CamelSchemaTwo",
+ )
b = Table(
- 'b', m, Column('a_id', Integer, ForeignKey('a.id')),
- schema="CamelSchemaOne")
+ "b",
+ m,
+ Column("a_id", Integer, ForeignKey("a.id")),
+ schema="CamelSchemaOne",
+ )
fk = ForeignKeyConstraint(
- ["a_id"],
- ["CamelSchemaTwo.a.id"], name='fk_a_id')
+ ["a_id"], ["CamelSchemaTwo.a.id"], name="fk_a_id"
+ )
b.append_constraint(fk)
op_obj = ops.AddConstraintOp.from_constraint(fk)
eq_ignore_whitespace(
autogenerate.render_op_text(self.autogen_context, op_obj),
"op.create_foreign_key('fk_a_id', 'b', 'a', ['a_id'], ['id'],"
" source_schema='CamelSchemaOne', "
- "referent_schema='CamelSchemaTwo')"
+ "referent_schema='CamelSchemaTwo')",
)
def test_add_fk_constraint_schema_batch(self):
m = MetaData()
Table(
- 'a', m, Column('id', Integer, primary_key=True),
- schema="CamelSchemaTwo")
+ "a",
+ m,
+ Column("id", Integer, primary_key=True),
+ schema="CamelSchemaTwo",
+ )
b = Table(
- 'b', m, Column('a_id', Integer, ForeignKey('a.id')),
- schema="CamelSchemaOne")
+ "b",
+ m,
+ Column("a_id", Integer, ForeignKey("a.id")),
+ schema="CamelSchemaOne",
+ )
fk = ForeignKeyConstraint(
- ["a_id"],
- ["CamelSchemaTwo.a.id"], name='fk_a_id')
+ ["a_id"], ["CamelSchemaTwo.a.id"], name="fk_a_id"
+ )
b.append_constraint(fk)
op_obj = ops.AddConstraintOp.from_constraint(fk)
with self.autogen_context._within_batch():
eq_ignore_whitespace(
autogenerate.render_op_text(self.autogen_context, op_obj),
"batch_op.create_foreign_key('fk_a_id', 'a', ['a_id'], ['id'],"
- " referent_schema='CamelSchemaTwo')"
+ " referent_schema='CamelSchemaTwo')",
)
def test_drop_fk_constraint(self):
m = MetaData()
- Table('a', m, Column('id', Integer, primary_key=True))
- b = Table('b', m, Column('a_id', Integer, ForeignKey('a.id')))
- fk = ForeignKeyConstraint(['a_id'], ['a.id'], name='fk_a_id')
+ Table("a", m, Column("id", Integer, primary_key=True))
+ b = Table("b", m, Column("a_id", Integer, ForeignKey("a.id")))
+ fk = ForeignKeyConstraint(["a_id"], ["a.id"], name="fk_a_id")
b.append_constraint(fk)
op_obj = ops.DropConstraintOp.from_constraint(fk)
eq_ignore_whitespace(
autogenerate.render_op_text(self.autogen_context, op_obj),
- "op.drop_constraint('fk_a_id', 'b', type_='foreignkey')"
+ "op.drop_constraint('fk_a_id', 'b', type_='foreignkey')",
)
def test_drop_fk_constraint_batch(self):
m = MetaData()
- Table('a', m, Column('id', Integer, primary_key=True))
- b = Table('b', m, Column('a_id', Integer, ForeignKey('a.id')))
- fk = ForeignKeyConstraint(['a_id'], ['a.id'], name='fk_a_id')
+ Table("a", m, Column("id", Integer, primary_key=True))
+ b = Table("b", m, Column("a_id", Integer, ForeignKey("a.id")))
+ fk = ForeignKeyConstraint(["a_id"], ["a.id"], name="fk_a_id")
b.append_constraint(fk)
op_obj = ops.DropConstraintOp.from_constraint(fk)
with self.autogen_context._within_batch():
eq_ignore_whitespace(
autogenerate.render_op_text(self.autogen_context, op_obj),
- "batch_op.drop_constraint('fk_a_id', type_='foreignkey')"
+ "batch_op.drop_constraint('fk_a_id', type_='foreignkey')",
)
def test_drop_fk_constraint_schema(self):
m = MetaData()
Table(
- 'a', m, Column('id', Integer, primary_key=True),
- schema="CamelSchemaTwo")
+ "a",
+ m,
+ Column("id", Integer, primary_key=True),
+ schema="CamelSchemaTwo",
+ )
b = Table(
- 'b', m, Column('a_id', Integer, ForeignKey('a.id')),
- schema="CamelSchemaOne")
+ "b",
+ m,
+ Column("a_id", Integer, ForeignKey("a.id")),
+ schema="CamelSchemaOne",
+ )
fk = ForeignKeyConstraint(
- ["a_id"],
- ["CamelSchemaTwo.a.id"], name='fk_a_id')
+ ["a_id"], ["CamelSchemaTwo.a.id"], name="fk_a_id"
+ )
b.append_constraint(fk)
op_obj = ops.DropConstraintOp.from_constraint(fk)
eq_ignore_whitespace(
autogenerate.render_op_text(self.autogen_context, op_obj),
"op.drop_constraint('fk_a_id', 'b', schema='CamelSchemaOne', "
- "type_='foreignkey')"
+ "type_='foreignkey')",
)
def test_drop_fk_constraint_batch_schema(self):
m = MetaData()
Table(
- 'a', m, Column('id', Integer, primary_key=True),
- schema="CamelSchemaTwo")
+ "a",
+ m,
+ Column("id", Integer, primary_key=True),
+ schema="CamelSchemaTwo",
+ )
b = Table(
- 'b', m, Column('a_id', Integer, ForeignKey('a.id')),
- schema="CamelSchemaOne")
+ "b",
+ m,
+ Column("a_id", Integer, ForeignKey("a.id")),
+ schema="CamelSchemaOne",
+ )
fk = ForeignKeyConstraint(
- ["a_id"],
- ["CamelSchemaTwo.a.id"], name='fk_a_id')
+ ["a_id"], ["CamelSchemaTwo.a.id"], name="fk_a_id"
+ )
b.append_constraint(fk)
op_obj = ops.DropConstraintOp.from_constraint(fk)
with self.autogen_context._within_batch():
eq_ignore_whitespace(
autogenerate.render_op_text(self.autogen_context, op_obj),
- "batch_op.drop_constraint('fk_a_id', type_='foreignkey')"
+ "batch_op.drop_constraint('fk_a_id', type_='foreignkey')",
)
def test_render_table_upgrade(self):
m = MetaData()
- t = Table('test', m,
- Column('id', Integer, primary_key=True),
- Column('name', Unicode(255)),
- Column("address_id", Integer, ForeignKey("address.id")),
- Column("timestamp", DATETIME, server_default="NOW()"),
- Column("amount", Numeric(5, 2)),
- UniqueConstraint("name", name="uq_name"),
- UniqueConstraint("timestamp"),
- )
+ t = Table(
+ "test",
+ m,
+ Column("id", Integer, primary_key=True),
+ Column("name", Unicode(255)),
+ Column("address_id", Integer, ForeignKey("address.id")),
+ Column("timestamp", DATETIME, server_default="NOW()"),
+ Column("amount", Numeric(5, 2)),
+ UniqueConstraint("name", name="uq_name"),
+ UniqueConstraint("timestamp"),
+ )
op_obj = ops.CreateTableOp.from_table(t)
eq_ignore_whitespace(
@@ -666,16 +757,18 @@ class AutogenRenderTest(TestBase):
"sa.PrimaryKeyConstraint('id'),"
"sa.UniqueConstraint('name', name='uq_name'),"
"sa.UniqueConstraint('timestamp')"
- ")"
+ ")",
)
def test_render_table_w_schema(self):
m = MetaData()
- t = Table('test', m,
- Column('id', Integer, primary_key=True),
- Column('q', Integer, ForeignKey('address.id')),
- schema='foo'
- )
+ t = Table(
+ "test",
+ m,
+ Column("id", Integer, primary_key=True),
+ Column("q", Integer, ForeignKey("address.id")),
+ schema="foo",
+ )
op_obj = ops.CreateTableOp.from_table(t)
eq_ignore_whitespace(
autogenerate.render_op_text(self.autogen_context, op_obj),
@@ -685,84 +778,89 @@ class AutogenRenderTest(TestBase):
"sa.ForeignKeyConstraint(['q'], ['address.id'], ),"
"sa.PrimaryKeyConstraint('id'),"
"schema='foo'"
- ")"
+ ")",
)
def test_render_table_w_system(self):
m = MetaData()
- t = Table('sometable', m,
- Column('id', Integer, primary_key=True),
- Column('xmin', Integer, system=True, nullable=False)
- )
+ t = Table(
+ "sometable",
+ m,
+ Column("id", Integer, primary_key=True),
+ Column("xmin", Integer, system=True, nullable=False),
+ )
op_obj = ops.CreateTableOp.from_table(t)
eq_ignore_whitespace(
autogenerate.render_op_text(self.autogen_context, op_obj),
"op.create_table('sometable',"
"sa.Column('id', sa.Integer(), nullable=False),"
"sa.Column('xmin', sa.Integer(), nullable=False, system=True),"
- "sa.PrimaryKeyConstraint('id'))"
+ "sa.PrimaryKeyConstraint('id'))",
)
def test_render_table_w_unicode_name(self):
m = MetaData()
- t = Table(compat.ue('\u0411\u0435\u0437'), m,
- Column('id', Integer, primary_key=True),
- )
+ t = Table(
+ compat.ue("\u0411\u0435\u0437"),
+ m,
+ Column("id", Integer, primary_key=True),
+ )
op_obj = ops.CreateTableOp.from_table(t)
eq_ignore_whitespace(
autogenerate.render_op_text(self.autogen_context, op_obj),
"op.create_table(%r,"
"sa.Column('id', sa.Integer(), nullable=False),"
- "sa.PrimaryKeyConstraint('id'))" % compat.ue('\u0411\u0435\u0437')
+ "sa.PrimaryKeyConstraint('id'))" % compat.ue("\u0411\u0435\u0437"),
)
def test_render_table_w_unicode_schema(self):
m = MetaData()
- t = Table('test', m,
- Column('id', Integer, primary_key=True),
- schema=compat.ue('\u0411\u0435\u0437')
- )
+ t = Table(
+ "test",
+ m,
+ Column("id", Integer, primary_key=True),
+ schema=compat.ue("\u0411\u0435\u0437"),
+ )
op_obj = ops.CreateTableOp.from_table(t)
eq_ignore_whitespace(
autogenerate.render_op_text(self.autogen_context, op_obj),
"op.create_table('test',"
"sa.Column('id', sa.Integer(), nullable=False),"
"sa.PrimaryKeyConstraint('id'),"
- "schema=%r)" % compat.ue('\u0411\u0435\u0437')
+ "schema=%r)" % compat.ue("\u0411\u0435\u0437"),
)
def test_render_table_w_unsupported_constraint(self):
from sqlalchemy.sql.schema import ColumnCollectionConstraint
class SomeCustomConstraint(ColumnCollectionConstraint):
- __visit_name__ = 'some_custom'
+ __visit_name__ = "some_custom"
m = MetaData()
- t = Table(
- 't', m, Column('id', Integer),
- SomeCustomConstraint('id'),
- )
+ t = Table("t", m, Column("id", Integer), SomeCustomConstraint("id"))
op_obj = ops.CreateTableOp.from_table(t)
with assertions.expect_warnings(
- "No renderer is established for object SomeCustomConstraint"):
+ "No renderer is established for object SomeCustomConstraint"
+ ):
eq_ignore_whitespace(
autogenerate.render_op_text(self.autogen_context, op_obj),
"op.create_table('t',"
"sa.Column('id', sa.Integer(), nullable=True),"
"[Unknown Python object "
- "SomeCustomConstraint(Column('id', Integer(), table=<t>))])"
+ "SomeCustomConstraint(Column('id', Integer(), table=<t>))])",
)
@patch("alembic.autogenerate.render.MAX_PYTHON_ARGS", 3)
def test_render_table_max_cols(self):
m = MetaData()
t = Table(
- 'test', m,
- Column('a', Integer),
- Column('b', Integer),
- Column('c', Integer),
- Column('d', Integer),
+ "test",
+ m,
+ Column("a", Integer),
+ Column("b", Integer),
+ Column("c", Integer),
+ Column("d", Integer),
)
op_obj = ops.CreateTableOp.from_table(t)
eq_ignore_whitespace(
@@ -771,14 +869,15 @@ class AutogenRenderTest(TestBase):
"*[sa.Column('a', sa.Integer(), nullable=True),"
"sa.Column('b', sa.Integer(), nullable=True),"
"sa.Column('c', sa.Integer(), nullable=True),"
- "sa.Column('d', sa.Integer(), nullable=True)])"
+ "sa.Column('d', sa.Integer(), nullable=True)])",
)
t2 = Table(
- 'test2', m,
- Column('a', Integer),
- Column('b', Integer),
- Column('c', Integer),
+ "test2",
+ m,
+ Column("a", Integer),
+ Column("b", Integer),
+ Column("c", Integer),
)
op_obj = ops.CreateTableOp.from_table(t2)
@@ -787,15 +886,17 @@ class AutogenRenderTest(TestBase):
"op.create_table('test2',"
"sa.Column('a', sa.Integer(), nullable=True),"
"sa.Column('b', sa.Integer(), nullable=True),"
- "sa.Column('c', sa.Integer(), nullable=True))"
+ "sa.Column('c', sa.Integer(), nullable=True))",
)
def test_render_table_w_fk_schema(self):
m = MetaData()
- t = Table('test', m,
- Column('id', Integer, primary_key=True),
- Column('q', Integer, ForeignKey('foo.address.id')),
- )
+ t = Table(
+ "test",
+ m,
+ Column("id", Integer, primary_key=True),
+ Column("q", Integer, ForeignKey("foo.address.id")),
+ )
op_obj = ops.CreateTableOp.from_table(t)
eq_ignore_whitespace(
autogenerate.render_op_text(self.autogen_context, op_obj),
@@ -804,20 +905,23 @@ class AutogenRenderTest(TestBase):
"sa.Column('q', sa.Integer(), nullable=True),"
"sa.ForeignKeyConstraint(['q'], ['foo.address.id'], ),"
"sa.PrimaryKeyConstraint('id')"
- ")"
+ ")",
)
def test_render_table_w_metadata_schema(self):
m = MetaData(schema="foo")
- t = Table('test', m,
- Column('id', Integer, primary_key=True),
- Column('q', Integer, ForeignKey('address.id')),
- )
+ t = Table(
+ "test",
+ m,
+ Column("id", Integer, primary_key=True),
+ Column("q", Integer, ForeignKey("address.id")),
+ )
op_obj = ops.CreateTableOp.from_table(t)
eq_ignore_whitespace(
re.sub(
- r"u'", "'",
- autogenerate.render_op_text(self.autogen_context, op_obj)
+ r"u'",
+ "'",
+ autogenerate.render_op_text(self.autogen_context, op_obj),
),
"op.create_table('test',"
"sa.Column('id', sa.Integer(), nullable=False),"
@@ -825,15 +929,17 @@ class AutogenRenderTest(TestBase):
"sa.ForeignKeyConstraint(['q'], ['foo.address.id'], ),"
"sa.PrimaryKeyConstraint('id'),"
"schema='foo'"
- ")"
+ ")",
)
def test_render_table_w_metadata_schema_override(self):
m = MetaData(schema="foo")
- t = Table('test', m,
- Column('id', Integer, primary_key=True),
- Column('q', Integer, ForeignKey('bar.address.id')),
- )
+ t = Table(
+ "test",
+ m,
+ Column("id", Integer, primary_key=True),
+ Column("q", Integer, ForeignKey("bar.address.id")),
+ )
op_obj = ops.CreateTableOp.from_table(t)
eq_ignore_whitespace(
autogenerate.render_op_text(self.autogen_context, op_obj),
@@ -843,16 +949,19 @@ class AutogenRenderTest(TestBase):
"sa.ForeignKeyConstraint(['q'], ['bar.address.id'], ),"
"sa.PrimaryKeyConstraint('id'),"
"schema='foo'"
- ")"
+ ")",
)
def test_render_addtl_args(self):
m = MetaData()
- t = Table('test', m,
- Column('id', Integer, primary_key=True),
- Column('q', Integer, ForeignKey('bar.address.id')),
- sqlite_autoincrement=True, mysql_engine="InnoDB"
- )
+ t = Table(
+ "test",
+ m,
+ Column("id", Integer, primary_key=True),
+ Column("q", Integer, ForeignKey("bar.address.id")),
+ sqlite_autoincrement=True,
+ mysql_engine="InnoDB",
+ )
op_obj = ops.CreateTableOp.from_table(t)
eq_ignore_whitespace(
autogenerate.render_op_text(self.autogen_context, op_obj),
@@ -861,60 +970,58 @@ class AutogenRenderTest(TestBase):
"sa.Column('q', sa.Integer(), nullable=True),"
"sa.ForeignKeyConstraint(['q'], ['bar.address.id'], ),"
"sa.PrimaryKeyConstraint('id'),"
- "mysql_engine='InnoDB',sqlite_autoincrement=True)"
+ "mysql_engine='InnoDB',sqlite_autoincrement=True)",
)
def test_render_drop_table(self):
- op_obj = ops.DropTableOp.from_table(
- Table("sometable", MetaData())
- )
+ op_obj = ops.DropTableOp.from_table(Table("sometable", MetaData()))
eq_ignore_whitespace(
autogenerate.render_op_text(self.autogen_context, op_obj),
- "op.drop_table('sometable')"
+ "op.drop_table('sometable')",
)
def test_render_drop_table_w_schema(self):
op_obj = ops.DropTableOp.from_table(
- Table("sometable", MetaData(), schema='foo')
+ Table("sometable", MetaData(), schema="foo")
)
eq_ignore_whitespace(
autogenerate.render_op_text(self.autogen_context, op_obj),
- "op.drop_table('sometable', schema='foo')"
+ "op.drop_table('sometable', schema='foo')",
)
def test_render_table_no_implicit_check(self):
m = MetaData()
- t = Table('test', m, Column('x', Boolean()))
+ t = Table("test", m, Column("x", Boolean()))
op_obj = ops.CreateTableOp.from_table(t)
eq_ignore_whitespace(
autogenerate.render_op_text(self.autogen_context, op_obj),
"op.create_table('test',"
- "sa.Column('x', sa.Boolean(), nullable=True))"
+ "sa.Column('x', sa.Boolean(), nullable=True))",
)
def test_render_pk_with_col_name_vs_col_key(self):
m = MetaData()
- t1 = Table('t1', m, Column('x', Integer, key='y', primary_key=True))
+ t1 = Table("t1", m, Column("x", Integer, key="y", primary_key=True))
op_obj = ops.CreateTableOp.from_table(t1)
eq_ignore_whitespace(
autogenerate.render_op_text(self.autogen_context, op_obj),
"op.create_table('t1',"
"sa.Column('x', sa.Integer(), nullable=False),"
- "sa.PrimaryKeyConstraint('x'))"
+ "sa.PrimaryKeyConstraint('x'))",
)
def test_render_empty_pk_vs_nonempty_pk(self):
m = MetaData()
- t1 = Table('t1', m, Column('x', Integer))
- t2 = Table('t2', m, Column('x', Integer, primary_key=True))
+ t1 = Table("t1", m, Column("x", Integer))
+ t2 = Table("t2", m, Column("x", Integer, primary_key=True))
op_obj = ops.CreateTableOp.from_table(t1)
eq_ignore_whitespace(
autogenerate.render_op_text(self.autogen_context, op_obj),
"op.create_table('t1',"
- "sa.Column('x', sa.Integer(), nullable=True))"
+ "sa.Column('x', sa.Integer(), nullable=True))",
)
op_obj = ops.CreateTableOp.from_table(t2)
@@ -922,16 +1029,18 @@ class AutogenRenderTest(TestBase):
autogenerate.render_op_text(self.autogen_context, op_obj),
"op.create_table('t2',"
"sa.Column('x', sa.Integer(), nullable=False),"
- "sa.PrimaryKeyConstraint('x'))"
+ "sa.PrimaryKeyConstraint('x'))",
)
@config.requirements.fail_before_sqla_110
def test_render_table_w_autoincrement(self):
m = MetaData()
t = Table(
- 'test', m,
- Column('id1', Integer, primary_key=True),
- Column('id2', Integer, primary_key=True, autoincrement=True))
+ "test",
+ m,
+ Column("id1", Integer, primary_key=True),
+ Column("id2", Integer, primary_key=True, autoincrement=True),
+ )
op_obj = ops.CreateTableOp.from_table(t)
eq_ignore_whitespace(
autogenerate.render_op_text(self.autogen_context, op_obj),
@@ -940,111 +1049,109 @@ class AutogenRenderTest(TestBase):
"sa.Column('id2', sa.Integer(), autoincrement=True, "
"nullable=False),"
"sa.PrimaryKeyConstraint('id1', 'id2')"
- ")"
+ ")",
)
def test_render_add_column(self):
op_obj = ops.AddColumnOp(
- "foo", Column("x", Integer, server_default="5"))
+ "foo", Column("x", Integer, server_default="5")
+ )
eq_ignore_whitespace(
autogenerate.render_op_text(self.autogen_context, op_obj),
"op.add_column('foo', sa.Column('x', sa.Integer(), "
- "server_default='5', nullable=True))"
+ "server_default='5', nullable=True))",
)
def test_render_add_column_system(self):
# this would never actually happen since "system" columns
# can't be added in any case. Howver it will render as
# part of op.CreateTableOp.
- op_obj = ops.AddColumnOp(
- "foo", Column("xmin", Integer, system=True))
+ op_obj = ops.AddColumnOp("foo", Column("xmin", Integer, system=True))
eq_ignore_whitespace(
autogenerate.render_op_text(self.autogen_context, op_obj),
"op.add_column('foo', sa.Column('xmin', sa.Integer(), "
- "nullable=True, system=True))"
+ "nullable=True, system=True))",
)
def test_render_add_column_w_schema(self):
op_obj = ops.AddColumnOp(
- "bar", Column("x", Integer, server_default="5"),
- schema="foo")
+ "bar", Column("x", Integer, server_default="5"), schema="foo"
+ )
eq_ignore_whitespace(
autogenerate.render_op_text(self.autogen_context, op_obj),
"op.add_column('bar', sa.Column('x', sa.Integer(), "
- "server_default='5', nullable=True), schema='foo')"
+ "server_default='5', nullable=True), schema='foo')",
)
def test_render_drop_column(self):
op_obj = ops.DropColumnOp.from_column_and_tablename(
- None, "foo", Column("x", Integer, server_default="5"))
+ None, "foo", Column("x", Integer, server_default="5")
+ )
eq_ignore_whitespace(
autogenerate.render_op_text(self.autogen_context, op_obj),
- "op.drop_column('foo', 'x')"
+ "op.drop_column('foo', 'x')",
)
def test_render_drop_column_w_schema(self):
op_obj = ops.DropColumnOp.from_column_and_tablename(
- "foo", "bar", Column("x", Integer, server_default="5"))
+ "foo", "bar", Column("x", Integer, server_default="5")
+ )
eq_ignore_whitespace(
autogenerate.render_op_text(self.autogen_context, op_obj),
- "op.drop_column('bar', 'x', schema='foo')"
+ "op.drop_column('bar', 'x', schema='foo')",
)
def test_render_quoted_server_default(self):
eq_(
autogenerate.render._render_server_default(
"nextval('group_to_perm_group_to_perm_id_seq'::regclass)",
- self.autogen_context),
- '"nextval(\'group_to_perm_group_to_perm_id_seq\'::regclass)"'
+ self.autogen_context,
+ ),
+ "\"nextval('group_to_perm_group_to_perm_id_seq'::regclass)\"",
)
def test_render_unicode_server_default(self):
default = compat.ue(
- '\u0411\u0435\u0437 '
- '\u043d\u0430\u0437\u0432\u0430\u043d\u0438\u044f'
+ "\u0411\u0435\u0437 "
+ "\u043d\u0430\u0437\u0432\u0430\u043d\u0438\u044f"
)
- c = Column(
- 'x', Unicode,
- server_default=text(default)
- )
+ c = Column("x", Unicode, server_default=text(default))
eq_ignore_whitespace(
autogenerate.render._render_server_default(
c.server_default, self.autogen_context
),
- "sa.text(%r)" % default
+ "sa.text(%r)" % default,
)
def test_render_col_with_server_default(self):
- c = Column('updated_at', TIMESTAMP(),
- server_default='TIMEZONE("utc", CURRENT_TIMESTAMP)',
- nullable=False)
- result = autogenerate.render._render_column(
- c, self.autogen_context
+ c = Column(
+ "updated_at",
+ TIMESTAMP(),
+ server_default='TIMEZONE("utc", CURRENT_TIMESTAMP)',
+ nullable=False,
)
+ result = autogenerate.render._render_column(c, self.autogen_context)
eq_ignore_whitespace(
result,
- 'sa.Column(\'updated_at\', sa.TIMESTAMP(), '
- 'server_default=\'TIMEZONE("utc", CURRENT_TIMESTAMP)\', '
- 'nullable=False)'
+ "sa.Column('updated_at', sa.TIMESTAMP(), "
+ "server_default='TIMEZONE(\"utc\", CURRENT_TIMESTAMP)', "
+ "nullable=False)",
)
def test_render_col_autoinc_false_mysql(self):
- c = Column('some_key', Integer, primary_key=True, autoincrement=False)
- Table('some_table', MetaData(), c)
- result = autogenerate.render._render_column(
- c, self.autogen_context
- )
+ c = Column("some_key", Integer, primary_key=True, autoincrement=False)
+ Table("some_table", MetaData(), c)
+ result = autogenerate.render._render_column(c, self.autogen_context)
eq_ignore_whitespace(
result,
- 'sa.Column(\'some_key\', sa.Integer(), '
- 'autoincrement=False, '
- 'nullable=False)'
+ "sa.Column('some_key', sa.Integer(), "
+ "autoincrement=False, "
+ "nullable=False)",
)
def test_render_custom(self):
-
class MySpecialType(Integer):
pass
@@ -1065,17 +1172,18 @@ class AutogenRenderTest(TestBase):
return "render:%s" % type_
self.autogen_context.opts.update(
- render_item=render,
- alembic_module_prefix='sa.'
+ render_item=render, alembic_module_prefix="sa."
)
- t = Table('t', MetaData(),
- Column('x', Integer),
- Column('y', Integer),
- Column('q', MySpecialType()),
- PrimaryKeyConstraint('x'),
- ForeignKeyConstraint(['x'], ['y'])
- )
+ t = Table(
+ "t",
+ MetaData(),
+ Column("x", Integer),
+ Column("y", Integer),
+ Column("q", MySpecialType()),
+ PrimaryKeyConstraint("x"),
+ ForeignKeyConstraint(["x"], ["y"]),
+ )
op_obj = ops.CreateTableOp.from_table(t)
result = autogenerate.render_op_text(self.autogen_context, op_obj)
eq_ignore_whitespace(
@@ -1083,89 +1191,97 @@ class AutogenRenderTest(TestBase):
"sa.create_table('t',"
"col(x),"
"sa.Column('q', MySpecialType(), nullable=True),"
- "render:primary_key)"
+ "render:primary_key)",
)
eq_(
self.autogen_context.imports,
- set(['from mypackage import MySpecialType'])
+ set(["from mypackage import MySpecialType"]),
)
def test_render_modify_type(self):
op_obj = ops.AlterColumnOp(
- "sometable", "somecolumn",
- modify_type=CHAR(10), existing_type=CHAR(20)
+ "sometable",
+ "somecolumn",
+ modify_type=CHAR(10),
+ existing_type=CHAR(20),
)
eq_ignore_whitespace(
autogenerate.render_op_text(self.autogen_context, op_obj),
"op.alter_column('sometable', 'somecolumn', "
- "existing_type=sa.CHAR(length=20), type_=sa.CHAR(length=10))"
+ "existing_type=sa.CHAR(length=20), type_=sa.CHAR(length=10))",
)
def test_render_modify_type_w_schema(self):
op_obj = ops.AlterColumnOp(
- "sometable", "somecolumn",
- modify_type=CHAR(10), existing_type=CHAR(20),
- schema='foo'
+ "sometable",
+ "somecolumn",
+ modify_type=CHAR(10),
+ existing_type=CHAR(20),
+ schema="foo",
)
eq_ignore_whitespace(
autogenerate.render_op_text(self.autogen_context, op_obj),
"op.alter_column('sometable', 'somecolumn', "
"existing_type=sa.CHAR(length=20), type_=sa.CHAR(length=10), "
- "schema='foo')"
+ "schema='foo')",
)
def test_render_modify_nullable(self):
op_obj = ops.AlterColumnOp(
- "sometable", "somecolumn",
+ "sometable",
+ "somecolumn",
existing_type=Integer(),
- modify_nullable=True
+ modify_nullable=True,
)
eq_ignore_whitespace(
autogenerate.render_op_text(self.autogen_context, op_obj),
"op.alter_column('sometable', 'somecolumn', "
- "existing_type=sa.Integer(), nullable=True)"
+ "existing_type=sa.Integer(), nullable=True)",
)
def test_render_modify_nullable_no_existing_type(self):
op_obj = ops.AlterColumnOp(
- "sometable", "somecolumn",
- modify_nullable=True
+ "sometable", "somecolumn", modify_nullable=True
)
eq_ignore_whitespace(
autogenerate.render_op_text(self.autogen_context, op_obj),
- "op.alter_column('sometable', 'somecolumn', nullable=True)"
+ "op.alter_column('sometable', 'somecolumn', nullable=True)",
)
def test_render_modify_nullable_w_schema(self):
op_obj = ops.AlterColumnOp(
- "sometable", "somecolumn",
+ "sometable",
+ "somecolumn",
existing_type=Integer(),
- modify_nullable=True, schema='foo'
+ modify_nullable=True,
+ schema="foo",
)
eq_ignore_whitespace(
autogenerate.render_op_text(self.autogen_context, op_obj),
"op.alter_column('sometable', 'somecolumn', "
- "existing_type=sa.Integer(), nullable=True, schema='foo')"
+ "existing_type=sa.Integer(), nullable=True, schema='foo')",
)
def test_render_modify_type_w_autoincrement(self):
op_obj = ops.AlterColumnOp(
- "sometable", "somecolumn",
- modify_type=Integer(), existing_type=BigInteger(),
- autoincrement=True
+ "sometable",
+ "somecolumn",
+ modify_type=Integer(),
+ existing_type=BigInteger(),
+ autoincrement=True,
)
eq_ignore_whitespace(
autogenerate.render_op_text(self.autogen_context, op_obj),
"op.alter_column('sometable', 'somecolumn', "
"existing_type=sa.BigInteger(), type_=sa.Integer(), "
- "autoincrement=True)"
+ "autoincrement=True)",
)
def test_render_fk_constraint_kwarg(self):
m = MetaData()
- t1 = Table('t', m, Column('c', Integer))
- t2 = Table('t2', m, Column('c_rem', Integer))
+ t1 = Table("t", m, Column("c", Integer))
+ t2 = Table("t2", m, Column("c_rem", Integer))
fk = ForeignKeyConstraint([t1.c.c], [t2.c.c_rem], onupdate="CASCADE")
@@ -1174,239 +1290,274 @@ class AutogenRenderTest(TestBase):
eq_ignore_whitespace(
re.sub(
- r"u'", "'",
+ r"u'",
+ "'",
autogenerate.render._render_constraint(
- fk, self.autogen_context)),
- "sa.ForeignKeyConstraint(['c'], ['t2.c_rem'], onupdate='CASCADE')"
+ fk, self.autogen_context
+ ),
+ ),
+ "sa.ForeignKeyConstraint(['c'], ['t2.c_rem'], onupdate='CASCADE')",
)
fk = ForeignKeyConstraint([t1.c.c], [t2.c.c_rem], ondelete="CASCADE")
eq_ignore_whitespace(
re.sub(
- r"u'", "'",
+ r"u'",
+ "'",
autogenerate.render._render_constraint(
- fk, self.autogen_context)),
- "sa.ForeignKeyConstraint(['c'], ['t2.c_rem'], ondelete='CASCADE')"
+ fk, self.autogen_context
+ ),
+ ),
+ "sa.ForeignKeyConstraint(['c'], ['t2.c_rem'], ondelete='CASCADE')",
)
fk = ForeignKeyConstraint([t1.c.c], [t2.c.c_rem], deferrable=True)
eq_ignore_whitespace(
re.sub(
- r"u'", "'",
+ r"u'",
+ "'",
autogenerate.render._render_constraint(
- fk, self.autogen_context),
+ fk, self.autogen_context
+ ),
),
- "sa.ForeignKeyConstraint(['c'], ['t2.c_rem'], deferrable=True)"
+ "sa.ForeignKeyConstraint(['c'], ['t2.c_rem'], deferrable=True)",
)
fk = ForeignKeyConstraint([t1.c.c], [t2.c.c_rem], initially="XYZ")
eq_ignore_whitespace(
re.sub(
- r"u'", "'",
+ r"u'",
+ "'",
autogenerate.render._render_constraint(
- fk, self.autogen_context)
+ fk, self.autogen_context
+ ),
),
- "sa.ForeignKeyConstraint(['c'], ['t2.c_rem'], initially='XYZ')"
+ "sa.ForeignKeyConstraint(['c'], ['t2.c_rem'], initially='XYZ')",
)
fk = ForeignKeyConstraint(
- [t1.c.c], [t2.c.c_rem],
- initially="XYZ", ondelete="CASCADE", deferrable=True)
+ [t1.c.c],
+ [t2.c.c_rem],
+ initially="XYZ",
+ ondelete="CASCADE",
+ deferrable=True,
+ )
eq_ignore_whitespace(
re.sub(
- r"u'", "'",
+ r"u'",
+ "'",
autogenerate.render._render_constraint(
- fk, self.autogen_context)
+ fk, self.autogen_context
+ ),
),
"sa.ForeignKeyConstraint(['c'], ['t2.c_rem'], "
- "ondelete='CASCADE', initially='XYZ', deferrable=True)"
+ "ondelete='CASCADE', initially='XYZ', deferrable=True)",
)
def test_render_fk_constraint_resolve_key(self):
m = MetaData()
- t1 = Table('t', m, Column('c', Integer))
- t2 = Table('t2', m, Column('c_rem', Integer, key='c_remkey'))
+ t1 = Table("t", m, Column("c", Integer))
+ t2 = Table("t2", m, Column("c_rem", Integer, key="c_remkey"))
- fk = ForeignKeyConstraint(['c'], ['t2.c_remkey'])
+ fk = ForeignKeyConstraint(["c"], ["t2.c_remkey"])
t1.append_constraint(fk)
eq_ignore_whitespace(
re.sub(
- r"u'", "'",
+ r"u'",
+ "'",
autogenerate.render._render_constraint(
- fk, self.autogen_context)),
- "sa.ForeignKeyConstraint(['c'], ['t2.c_rem'], )"
+ fk, self.autogen_context
+ ),
+ ),
+ "sa.ForeignKeyConstraint(['c'], ['t2.c_rem'], )",
)
def test_render_fk_constraint_bad_table_resolve(self):
m = MetaData()
- t1 = Table('t', m, Column('c', Integer))
- t2 = Table('t2', m, Column('c_rem', Integer))
+ t1 = Table("t", m, Column("c", Integer))
+ t2 = Table("t2", m, Column("c_rem", Integer))
- fk = ForeignKeyConstraint(['c'], ['t2.nonexistent'])
+ fk = ForeignKeyConstraint(["c"], ["t2.nonexistent"])
t1.append_constraint(fk)
eq_ignore_whitespace(
re.sub(
- r"u'", "'",
+ r"u'",
+ "'",
autogenerate.render._render_constraint(
- fk, self.autogen_context)),
- "sa.ForeignKeyConstraint(['c'], ['t2.nonexistent'], )"
+ fk, self.autogen_context
+ ),
+ ),
+ "sa.ForeignKeyConstraint(['c'], ['t2.nonexistent'], )",
)
def test_render_fk_constraint_bad_table_resolve_dont_get_confused(self):
m = MetaData()
- t1 = Table('t', m, Column('c', Integer))
+ t1 = Table("t", m, Column("c", Integer))
t2 = Table(
- 't2', m,
- Column('c_rem', Integer, key='cr_key'),
- Column('c_rem_2', Integer, key='c_rem')
-
+ "t2",
+ m,
+ Column("c_rem", Integer, key="cr_key"),
+ Column("c_rem_2", Integer, key="c_rem"),
)
- fk = ForeignKeyConstraint(['c'], ['t2.c_rem'], link_to_name=True)
+ fk = ForeignKeyConstraint(["c"], ["t2.c_rem"], link_to_name=True)
t1.append_constraint(fk)
eq_ignore_whitespace(
re.sub(
- r"u'", "'",
+ r"u'",
+ "'",
autogenerate.render._render_constraint(
- fk, self.autogen_context)),
- "sa.ForeignKeyConstraint(['c'], ['t2.c_rem'], )"
+ fk, self.autogen_context
+ ),
+ ),
+ "sa.ForeignKeyConstraint(['c'], ['t2.c_rem'], )",
)
def test_render_fk_constraint_link_to_name(self):
m = MetaData()
- t1 = Table('t', m, Column('c', Integer))
- t2 = Table('t2', m, Column('c_rem', Integer, key='c_remkey'))
+ t1 = Table("t", m, Column("c", Integer))
+ t2 = Table("t2", m, Column("c_rem", Integer, key="c_remkey"))
- fk = ForeignKeyConstraint(['c'], ['t2.c_rem'], link_to_name=True)
+ fk = ForeignKeyConstraint(["c"], ["t2.c_rem"], link_to_name=True)
t1.append_constraint(fk)
eq_ignore_whitespace(
re.sub(
- r"u'", "'",
+ r"u'",
+ "'",
autogenerate.render._render_constraint(
- fk, self.autogen_context)),
- "sa.ForeignKeyConstraint(['c'], ['t2.c_rem'], )"
+ fk, self.autogen_context
+ ),
+ ),
+ "sa.ForeignKeyConstraint(['c'], ['t2.c_rem'], )",
)
def test_render_fk_constraint_use_alter(self):
m = MetaData()
- Table('t', m, Column('c', Integer))
+ Table("t", m, Column("c", Integer))
t2 = Table(
- 't2', m,
+ "t2",
+ m,
Column(
- 'c_rem', Integer,
- ForeignKey('t.c', name="fk1", use_alter=True)))
+ "c_rem", Integer, ForeignKey("t.c", name="fk1", use_alter=True)
+ ),
+ )
const = list(t2.foreign_keys)[0].constraint
eq_ignore_whitespace(
autogenerate.render._render_constraint(
- const, self.autogen_context),
+ const, self.autogen_context
+ ),
"sa.ForeignKeyConstraint(['c_rem'], ['t.c'], "
- "name='fk1', use_alter=True)"
+ "name='fk1', use_alter=True)",
)
def test_render_fk_constraint_w_metadata_schema(self):
m = MetaData(schema="foo")
- t1 = Table('t', m, Column('c', Integer))
- t2 = Table('t2', m, Column('c_rem', Integer))
+ t1 = Table("t", m, Column("c", Integer))
+ t2 = Table("t2", m, Column("c_rem", Integer))
fk = ForeignKeyConstraint([t1.c.c], [t2.c.c_rem], onupdate="CASCADE")
eq_ignore_whitespace(
re.sub(
- r"u'", "'",
+ r"u'",
+ "'",
autogenerate.render._render_constraint(
- fk, self.autogen_context)
+ fk, self.autogen_context
+ ),
),
"sa.ForeignKeyConstraint(['c'], ['foo.t2.c_rem'], "
- "onupdate='CASCADE')"
+ "onupdate='CASCADE')",
)
def test_render_check_constraint_literal(self):
eq_ignore_whitespace(
autogenerate.render._render_check_constraint(
- CheckConstraint("im a constraint", name='cc1'),
- self.autogen_context
+ CheckConstraint("im a constraint", name="cc1"),
+ self.autogen_context,
),
- "sa.CheckConstraint(!U'im a constraint', name='cc1')"
+ "sa.CheckConstraint(!U'im a constraint', name='cc1')",
)
def test_render_check_constraint_sqlexpr(self):
- c = column('c')
- five = literal_column('5')
- ten = literal_column('10')
+ c = column("c")
+ five = literal_column("5")
+ ten = literal_column("10")
eq_ignore_whitespace(
autogenerate.render._render_check_constraint(
- CheckConstraint(and_(c > five, c < ten)),
- self.autogen_context
+ CheckConstraint(and_(c > five, c < ten)), self.autogen_context
),
- "sa.CheckConstraint(!U'c > 5 AND c < 10')"
+ "sa.CheckConstraint(!U'c > 5 AND c < 10')",
)
def test_render_check_constraint_literal_binds(self):
- c = column('c')
+ c = column("c")
eq_ignore_whitespace(
autogenerate.render._render_check_constraint(
- CheckConstraint(and_(c > 5, c < 10)),
- self.autogen_context
+ CheckConstraint(and_(c > 5, c < 10)), self.autogen_context
),
- "sa.CheckConstraint(!U'c > 5 AND c < 10')"
+ "sa.CheckConstraint(!U'c > 5 AND c < 10')",
)
def test_render_unique_constraint_opts(self):
m = MetaData()
- t = Table('t', m, Column('c', Integer))
+ t = Table("t", m, Column("c", Integer))
eq_ignore_whitespace(
autogenerate.render._render_unique_constraint(
- UniqueConstraint(t.c.c, name='uq_1', deferrable='XYZ'),
- self.autogen_context
+ UniqueConstraint(t.c.c, name="uq_1", deferrable="XYZ"),
+ self.autogen_context,
),
- "sa.UniqueConstraint('c', deferrable='XYZ', name='uq_1')"
+ "sa.UniqueConstraint('c', deferrable='XYZ', name='uq_1')",
)
def test_add_unique_constraint_unicode_schema(self):
m = MetaData()
t = Table(
- 't', m, Column('c', Integer),
- schema=compat.ue('\u0411\u0435\u0437')
+ "t",
+ m,
+ Column("c", Integer),
+ schema=compat.ue("\u0411\u0435\u0437"),
)
op_obj = ops.AddConstraintOp.from_constraint(UniqueConstraint(t.c.c))
eq_ignore_whitespace(
autogenerate.render_op_text(self.autogen_context, op_obj),
"op.create_unique_constraint(None, 't', ['c'], "
- "schema=%r)" % compat.ue('\u0411\u0435\u0437')
+ "schema=%r)" % compat.ue("\u0411\u0435\u0437"),
)
def test_render_modify_nullable_w_default(self):
op_obj = ops.AlterColumnOp(
- "sometable", "somecolumn",
+ "sometable",
+ "somecolumn",
existing_type=Integer(),
existing_server_default="5",
- modify_nullable=True
+ modify_nullable=True,
)
eq_ignore_whitespace(
autogenerate.render_op_text(self.autogen_context, op_obj),
"op.alter_column('sometable', 'somecolumn', "
"existing_type=sa.Integer(), nullable=True, "
- "existing_server_default='5')"
+ "existing_server_default='5')",
)
def test_render_enum(self):
eq_ignore_whitespace(
autogenerate.render._repr_type(
Enum("one", "two", "three", name="myenum"),
- self.autogen_context),
- "sa.Enum('one', 'two', 'three', name='myenum')"
+ self.autogen_context,
+ ),
+ "sa.Enum('one', 'two', 'three', name='myenum')",
)
eq_ignore_whitespace(
autogenerate.render._repr_type(
- Enum("one", "two", "three"),
- self.autogen_context),
- "sa.Enum('one', 'two', 'three')"
+ Enum("one", "two", "three"), self.autogen_context
+ ),
+ "sa.Enum('one', 'two', 'three')",
)
@config.requirements.sqlalchemy_099
@@ -1414,15 +1565,16 @@ class AutogenRenderTest(TestBase):
eq_ignore_whitespace(
autogenerate.render._repr_type(
Enum("one", "two", "three", native_enum=False),
- self.autogen_context),
- "sa.Enum('one', 'two', 'three', native_enum=False)"
+ self.autogen_context,
+ ),
+ "sa.Enum('one', 'two', 'three', native_enum=False)",
)
def test_repr_plain_sqla_type(self):
type_ = Integer()
eq_ignore_whitespace(
autogenerate.render._repr_type(type_, self.autogen_context),
- "sa.Integer()"
+ "sa.Integer()",
)
@config.requirements.sqlalchemy_110
@@ -1430,24 +1582,27 @@ class AutogenRenderTest(TestBase):
eq_ignore_whitespace(
autogenerate.render._repr_type(
- types.ARRAY(Integer), self.autogen_context),
- "sa.ARRAY(sa.Integer())"
+ types.ARRAY(Integer), self.autogen_context
+ ),
+ "sa.ARRAY(sa.Integer())",
)
eq_ignore_whitespace(
autogenerate.render._repr_type(
- types.ARRAY(DateTime(timezone=True)), self.autogen_context),
- "sa.ARRAY(sa.DateTime(timezone=True))"
+ types.ARRAY(DateTime(timezone=True)), self.autogen_context
+ ),
+ "sa.ARRAY(sa.DateTime(timezone=True))",
)
@config.requirements.sqlalchemy_110
def test_render_array_no_context(self):
- uo = ops.UpgradeOps(ops=[
- ops.CreateTableOp(
- "sometable",
- [Column('x', types.ARRAY(Integer))]
- )
- ])
+ uo = ops.UpgradeOps(
+ ops=[
+ ops.CreateTableOp(
+ "sometable", [Column("x", types.ARRAY(Integer))]
+ )
+ ]
+ )
eq_(
autogenerate.render_python_code(uo),
@@ -1455,11 +1610,11 @@ class AutogenRenderTest(TestBase):
" op.create_table('sometable',\n"
" sa.Column('x', sa.ARRAY(sa.Integer()), nullable=True)\n"
" )\n"
- " # ### end Alembic commands ###"
+ " # ### end Alembic commands ###",
)
def test_repr_custom_type_w_sqla_prefix(self):
- self.autogen_context.opts['user_module_prefix'] = None
+ self.autogen_context.opts["user_module_prefix"] = None
class MyType(UserDefinedType):
pass
@@ -1470,283 +1625,280 @@ class AutogenRenderTest(TestBase):
eq_ignore_whitespace(
autogenerate.render._repr_type(type_, self.autogen_context),
- "sqlalchemy_util.types.MyType()"
+ "sqlalchemy_util.types.MyType()",
)
def test_repr_user_type_user_prefix_None(self):
class MyType(UserDefinedType):
-
def get_col_spec(self):
return "MYTYPE"
type_ = MyType()
- self.autogen_context.opts['user_module_prefix'] = None
+ self.autogen_context.opts["user_module_prefix"] = None
eq_ignore_whitespace(
autogenerate.render._repr_type(type_, self.autogen_context),
- "tests.test_autogen_render.MyType()"
+ "tests.test_autogen_render.MyType()",
)
def test_repr_user_type_user_prefix_present(self):
from sqlalchemy.types import UserDefinedType
class MyType(UserDefinedType):
-
def get_col_spec(self):
return "MYTYPE"
type_ = MyType()
- self.autogen_context.opts['user_module_prefix'] = 'user.'
+ self.autogen_context.opts["user_module_prefix"] = "user."
eq_ignore_whitespace(
autogenerate.render._repr_type(type_, self.autogen_context),
- "user.MyType()"
+ "user.MyType()",
)
def test_repr_dialect_type(self):
from sqlalchemy.dialects.mysql import VARCHAR
- type_ = VARCHAR(20, charset='utf8', national=True)
+ type_ = VARCHAR(20, charset="utf8", national=True)
- self.autogen_context.opts['user_module_prefix'] = None
+ self.autogen_context.opts["user_module_prefix"] = None
eq_ignore_whitespace(
autogenerate.render._repr_type(type_, self.autogen_context),
- "mysql.VARCHAR(charset='utf8', national=True, length=20)"
+ "mysql.VARCHAR(charset='utf8', national=True, length=20)",
+ )
+ eq_(
+ self.autogen_context.imports,
+ set(["from sqlalchemy.dialects import mysql"]),
)
- eq_(self.autogen_context.imports,
- set(['from sqlalchemy.dialects import mysql'])
- )
def test_render_server_default_text(self):
c = Column(
- 'updated_at', TIMESTAMP(),
- server_default=text('now()'),
- nullable=False)
- result = autogenerate.render._render_column(
- c, self.autogen_context
+ "updated_at",
+ TIMESTAMP(),
+ server_default=text("now()"),
+ nullable=False,
)
+ result = autogenerate.render._render_column(c, self.autogen_context)
eq_ignore_whitespace(
result,
- 'sa.Column(\'updated_at\', sa.TIMESTAMP(), '
- 'server_default=sa.text(!U\'now()\'), '
- 'nullable=False)'
+ "sa.Column('updated_at', sa.TIMESTAMP(), "
+ "server_default=sa.text(!U'now()'), "
+ "nullable=False)",
)
def test_render_server_default_non_native_boolean(self):
c = Column(
- 'updated_at', Boolean(),
- server_default=false(),
- nullable=False)
-
- result = autogenerate.render._render_column(
- c, self.autogen_context
+ "updated_at", Boolean(), server_default=false(), nullable=False
)
+
+ result = autogenerate.render._render_column(c, self.autogen_context)
eq_ignore_whitespace(
result,
- 'sa.Column(\'updated_at\', sa.Boolean(), '
- 'server_default=sa.text(!U\'0\'), '
- 'nullable=False)'
+ "sa.Column('updated_at', sa.Boolean(), "
+ "server_default=sa.text(!U'0'), "
+ "nullable=False)",
)
def test_render_server_default_func(self):
c = Column(
- 'updated_at', TIMESTAMP(),
+ "updated_at",
+ TIMESTAMP(),
server_default=func.now(),
- nullable=False)
- result = autogenerate.render._render_column(
- c, self.autogen_context
+ nullable=False,
)
+ result = autogenerate.render._render_column(c, self.autogen_context)
eq_ignore_whitespace(
result,
- 'sa.Column(\'updated_at\', sa.TIMESTAMP(), '
- 'server_default=sa.text(!U\'now()\'), '
- 'nullable=False)'
+ "sa.Column('updated_at', sa.TIMESTAMP(), "
+ "server_default=sa.text(!U'now()'), "
+ "nullable=False)",
)
def test_render_server_default_int(self):
- c = Column(
- 'value', Integer,
- server_default="0")
- result = autogenerate.render._render_column(
- c, self.autogen_context
- )
+ c = Column("value", Integer, server_default="0")
+ result = autogenerate.render._render_column(c, self.autogen_context)
eq_(
result,
"sa.Column('value', sa.Integer(), "
- "server_default='0', nullable=True)"
+ "server_default='0', nullable=True)",
)
def test_render_modify_reflected_int_server_default(self):
op_obj = ops.AlterColumnOp(
- "sometable", "somecolumn",
+ "sometable",
+ "somecolumn",
existing_type=Integer(),
existing_server_default=DefaultClause(text("5")),
- modify_nullable=True
+ modify_nullable=True,
)
eq_ignore_whitespace(
autogenerate.render_op_text(self.autogen_context, op_obj),
"op.alter_column('sometable', 'somecolumn', "
"existing_type=sa.Integer(), nullable=True, "
- "existing_server_default=sa.text(!U'5'))"
+ "existing_server_default=sa.text(!U'5'))",
)
def test_render_executesql_plaintext(self):
op_obj = ops.ExecuteSQLOp("drop table foo")
eq_(
autogenerate.render_op_text(self.autogen_context, op_obj),
- "op.execute('drop table foo')"
+ "op.execute('drop table foo')",
)
def test_render_executesql_sqlexpr_notimplemented(self):
- sql = table('x', column('q')).insert()
+ sql = table("x", column("q")).insert()
op_obj = ops.ExecuteSQLOp(sql)
assert_raises(
NotImplementedError,
- autogenerate.render_op_text, self.autogen_context, op_obj
+ autogenerate.render_op_text,
+ self.autogen_context,
+ op_obj,
)
class RenderNamingConventionTest(TestBase):
- __requires__ = ('sqlalchemy_094',)
+ __requires__ = ("sqlalchemy_094",)
def setUp(self):
convention = {
- "ix": 'ix_%(custom)s_%(column_0_label)s',
+ "ix": "ix_%(custom)s_%(column_0_label)s",
"uq": "uq_%(custom)s_%(table_name)s_%(column_0_name)s",
"ck": "ck_%(custom)s_%(table_name)s",
"fk": "fk_%(custom)s_%(table_name)s_"
"%(column_0_name)s_%(referred_table_name)s",
"pk": "pk_%(custom)s_%(table_name)s",
- "custom": lambda const, table: "ct"
+ "custom": lambda const, table: "ct",
}
- self.metadata = MetaData(
- naming_convention=convention
- )
+ self.metadata = MetaData(naming_convention=convention)
ctx_opts = {
- 'sqlalchemy_module_prefix': 'sa.',
- 'alembic_module_prefix': 'op.',
- 'target_metadata': MetaData()
+ "sqlalchemy_module_prefix": "sa.",
+ "alembic_module_prefix": "op.",
+ "target_metadata": MetaData(),
}
context = MigrationContext.configure(
- dialect_name="postgresql",
- opts=ctx_opts
+ dialect_name="postgresql", opts=ctx_opts
)
self.autogen_context = api.AutogenContext(context)
def test_schema_type_boolean(self):
- t = Table('t', self.metadata, Column('c', Boolean(name='xyz')))
+ t = Table("t", self.metadata, Column("c", Boolean(name="xyz")))
op_obj = ops.AddColumnOp.from_column(t.c.c)
eq_ignore_whitespace(
autogenerate.render_op_text(self.autogen_context, op_obj),
"op.add_column('t', "
- "sa.Column('c', sa.Boolean(name='xyz'), nullable=True))"
+ "sa.Column('c', sa.Boolean(name='xyz'), nullable=True))",
)
def test_explicit_unique_constraint(self):
- t = Table('t', self.metadata, Column('c', Integer))
+ t = Table("t", self.metadata, Column("c", Integer))
eq_ignore_whitespace(
autogenerate.render._render_unique_constraint(
- UniqueConstraint(t.c.c, deferrable='XYZ'),
- self.autogen_context
+ UniqueConstraint(t.c.c, deferrable="XYZ"), self.autogen_context
),
"sa.UniqueConstraint('c', deferrable='XYZ', "
- "name=op.f('uq_ct_t_c'))"
+ "name=op.f('uq_ct_t_c'))",
)
def test_explicit_named_unique_constraint(self):
- t = Table('t', self.metadata, Column('c', Integer))
+ t = Table("t", self.metadata, Column("c", Integer))
eq_ignore_whitespace(
autogenerate.render._render_unique_constraint(
- UniqueConstraint(t.c.c, name='q'),
- self.autogen_context
+ UniqueConstraint(t.c.c, name="q"), self.autogen_context
),
- "sa.UniqueConstraint('c', name='q')"
+ "sa.UniqueConstraint('c', name='q')",
)
def test_render_add_index(self):
- t = Table('test', self.metadata,
- Column('id', Integer, primary_key=True),
- Column('active', Boolean()),
- Column('code', String(255)),
- )
+ t = Table(
+ "test",
+ self.metadata,
+ Column("id", Integer, primary_key=True),
+ Column("active", Boolean()),
+ Column("code", String(255)),
+ )
idx = Index(None, t.c.active, t.c.code)
op_obj = ops.CreateIndexOp.from_index(idx)
eq_ignore_whitespace(
autogenerate.render_op_text(self.autogen_context, op_obj),
"op.create_index(op.f('ix_ct_test_active'), 'test', "
- "['active', 'code'], unique=False)"
+ "['active', 'code'], unique=False)",
)
def test_render_drop_index(self):
- t = Table('test', self.metadata,
- Column('id', Integer, primary_key=True),
- Column('active', Boolean()),
- Column('code', String(255)),
- )
+ t = Table(
+ "test",
+ self.metadata,
+ Column("id", Integer, primary_key=True),
+ Column("active", Boolean()),
+ Column("code", String(255)),
+ )
idx = Index(None, t.c.active, t.c.code)
op_obj = ops.DropIndexOp.from_index(idx)
eq_ignore_whitespace(
autogenerate.render_op_text(self.autogen_context, op_obj),
- "op.drop_index(op.f('ix_ct_test_active'), table_name='test')"
+ "op.drop_index(op.f('ix_ct_test_active'), table_name='test')",
)
def test_render_add_index_schema(self):
- t = Table('test', self.metadata,
- Column('id', Integer, primary_key=True),
- Column('active', Boolean()),
- Column('code', String(255)),
- schema='CamelSchema'
- )
+ t = Table(
+ "test",
+ self.metadata,
+ Column("id", Integer, primary_key=True),
+ Column("active", Boolean()),
+ Column("code", String(255)),
+ schema="CamelSchema",
+ )
idx = Index(None, t.c.active, t.c.code)
op_obj = ops.CreateIndexOp.from_index(idx)
eq_ignore_whitespace(
autogenerate.render_op_text(self.autogen_context, op_obj),
"op.create_index(op.f('ix_ct_CamelSchema_test_active'), 'test', "
- "['active', 'code'], unique=False, schema='CamelSchema')"
+ "['active', 'code'], unique=False, schema='CamelSchema')",
)
def test_implicit_unique_constraint(self):
- t = Table('t', self.metadata, Column('c', Integer, unique=True))
+ t = Table("t", self.metadata, Column("c", Integer, unique=True))
uq = [c for c in t.constraints if isinstance(c, UniqueConstraint)][0]
eq_ignore_whitespace(
- autogenerate.render._render_unique_constraint(uq,
- self.autogen_context
- ),
- "sa.UniqueConstraint('c', name=op.f('uq_ct_t_c'))"
+ autogenerate.render._render_unique_constraint(
+ uq, self.autogen_context
+ ),
+ "sa.UniqueConstraint('c', name=op.f('uq_ct_t_c'))",
)
def test_inline_pk_constraint(self):
- t = Table('t', self.metadata, Column('c', Integer, primary_key=True))
+ t = Table("t", self.metadata, Column("c", Integer, primary_key=True))
op_obj = ops.CreateTableOp.from_table(t)
eq_ignore_whitespace(
autogenerate.render_op_text(self.autogen_context, op_obj),
"op.create_table('t',sa.Column('c', sa.Integer(), nullable=False),"
- "sa.PrimaryKeyConstraint('c', name=op.f('pk_ct_t')))"
+ "sa.PrimaryKeyConstraint('c', name=op.f('pk_ct_t')))",
)
def test_inline_ck_constraint(self):
t = Table(
- 't', self.metadata, Column('c', Integer), CheckConstraint("c > 5"))
+ "t", self.metadata, Column("c", Integer), CheckConstraint("c > 5")
+ )
op_obj = ops.CreateTableOp.from_table(t)
eq_ignore_whitespace(
autogenerate.render_op_text(self.autogen_context, op_obj),
"op.create_table('t',sa.Column('c', sa.Integer(), nullable=True),"
- "sa.CheckConstraint(!U'c > 5', name=op.f('ck_ct_t')))"
+ "sa.CheckConstraint(!U'c > 5', name=op.f('ck_ct_t')))",
)
def test_inline_fk(self):
- t = Table('t', self.metadata, Column('c', Integer, ForeignKey('q.id')))
+ t = Table("t", self.metadata, Column("c", Integer, ForeignKey("q.id")))
op_obj = ops.CreateTableOp.from_table(t)
eq_ignore_whitespace(
autogenerate.render_op_text(self.autogen_context, op_obj),
"op.create_table('t',sa.Column('c', sa.Integer(), nullable=True),"
"sa.ForeignKeyConstraint(['c'], ['q.id'], "
- "name=op.f('fk_ct_t_c_q')))"
+ "name=op.f('fk_ct_t_c_q')))",
)
def test_render_check_constraint_renamed(self):
@@ -1760,31 +1912,31 @@ class RenderNamingConventionTest(TestBase):
used.
"""
- m1 = MetaData(naming_convention={
- "ck": "ck_%(table_name)s_%(constraint_name)s"})
+ m1 = MetaData(
+ naming_convention={"ck": "ck_%(table_name)s_%(constraint_name)s"}
+ )
ck = CheckConstraint("im a constraint", name="cc1")
- Table('t', m1, Column('x'), ck)
+ Table("t", m1, Column("x"), ck)
eq_ignore_whitespace(
autogenerate.render._render_check_constraint(
- ck,
- self.autogen_context
+ ck, self.autogen_context
),
- "sa.CheckConstraint(!U'im a constraint', name=op.f('ck_t_cc1'))"
+ "sa.CheckConstraint(!U'im a constraint', name=op.f('ck_t_cc1'))",
)
def test_create_table_plus_add_index_in_modify(self):
- uo = ops.UpgradeOps(ops=[
- ops.CreateTableOp(
- "sometable",
- [Column('x', Integer), Column('y', Integer)]
- ),
- ops.ModifyTableOps(
- "sometable", ops=[
- ops.CreateIndexOp('ix1', 'sometable', ['x', 'y'])
- ]
- )
- ])
+ uo = ops.UpgradeOps(
+ ops=[
+ ops.CreateTableOp(
+ "sometable", [Column("x", Integer), Column("y", Integer)]
+ ),
+ ops.ModifyTableOps(
+ "sometable",
+ ops=[ops.CreateIndexOp("ix1", "sometable", ["x", "y"])],
+ ),
+ ]
+ )
eq_(
autogenerate.render_python_code(uo, render_as_batch=True),
@@ -1797,5 +1949,5 @@ class RenderNamingConventionTest(TestBase):
"as batch_op:\n"
" batch_op.create_index("
"'ix1', ['x', 'y'], unique=False)\n\n"
- " # ### end Alembic commands ###"
+ " # ### end Alembic commands ###",
)
diff --git a/tests/test_batch.py b/tests/test_batch.py
index 99605d0..2a7c52e 100644
--- a/tests/test_batch.py
+++ b/tests/test_batch.py
@@ -11,9 +11,22 @@ from alembic.operations.batch import ApplyBatchImpl
from alembic.runtime.migration import MigrationContext
-from sqlalchemy import Integer, Table, Column, String, MetaData, ForeignKey, \
- UniqueConstraint, ForeignKeyConstraint, Index, Boolean, CheckConstraint, \
- Enum, DateTime, PrimaryKeyConstraint
+from sqlalchemy import (
+ Integer,
+ Table,
+ Column,
+ String,
+ MetaData,
+ ForeignKey,
+ UniqueConstraint,
+ ForeignKeyConstraint,
+ Index,
+ Boolean,
+ CheckConstraint,
+ Enum,
+ DateTime,
+ PrimaryKeyConstraint,
+)
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.sql import column, text, select
from sqlalchemy.schema import CreateTable, CreateIndex
@@ -21,84 +34,91 @@ from sqlalchemy import exc
class BatchApplyTest(TestBase):
-
def setUp(self):
self.op = Operations(mock.Mock(opts={}))
def _simple_fixture(self, table_args=(), table_kwargs={}):
m = MetaData()
t = Table(
- 'tname', m,
- Column('id', Integer, primary_key=True),
- Column('x', String(10)),
- Column('y', Integer)
+ "tname",
+ m,
+ Column("id", Integer, primary_key=True),
+ Column("x", String(10)),
+ Column("y", Integer),
)
return ApplyBatchImpl(t, table_args, table_kwargs, False)
def _uq_fixture(self, table_args=(), table_kwargs={}):
m = MetaData()
t = Table(
- 'tname', m,
- Column('id', Integer, primary_key=True),
- Column('x', String()),
- Column('y', Integer),
- UniqueConstraint('y', name='uq1')
+ "tname",
+ m,
+ Column("id", Integer, primary_key=True),
+ Column("x", String()),
+ Column("y", Integer),
+ UniqueConstraint("y", name="uq1"),
)
return ApplyBatchImpl(t, table_args, table_kwargs, False)
def _ix_fixture(self, table_args=(), table_kwargs={}):
m = MetaData()
t = Table(
- 'tname', m,
- Column('id', Integer, primary_key=True),
- Column('x', String()),
- Column('y', Integer),
- Index('ix1', 'y')
+ "tname",
+ m,
+ Column("id", Integer, primary_key=True),
+ Column("x", String()),
+ Column("y", Integer),
+ Index("ix1", "y"),
)
return ApplyBatchImpl(t, table_args, table_kwargs, False)
def _pk_fixture(self):
m = MetaData()
t = Table(
- 'tname', m,
- Column('id', Integer),
- Column('x', String()),
- Column('y', Integer),
- PrimaryKeyConstraint('id', name="mypk")
+ "tname",
+ m,
+ Column("id", Integer),
+ Column("x", String()),
+ Column("y", Integer),
+ PrimaryKeyConstraint("id", name="mypk"),
)
return ApplyBatchImpl(t, (), {}, False)
def _literal_ck_fixture(
- self, copy_from=None, table_args=(), table_kwargs={}):
+ self, copy_from=None, table_args=(), table_kwargs={}
+ ):
m = MetaData()
if copy_from is not None:
t = copy_from
else:
t = Table(
- 'tname', m,
- Column('id', Integer, primary_key=True),
- Column('email', String()),
- CheckConstraint("email LIKE '%@%'")
+ "tname",
+ m,
+ Column("id", Integer, primary_key=True),
+ Column("email", String()),
+ CheckConstraint("email LIKE '%@%'"),
)
return ApplyBatchImpl(t, table_args, table_kwargs, False)
def _sql_ck_fixture(self, table_args=(), table_kwargs={}):
m = MetaData()
t = Table(
- 'tname', m,
- Column('id', Integer, primary_key=True),
- Column('email', String())
+ "tname",
+ m,
+ Column("id", Integer, primary_key=True),
+ Column("email", String()),
)
- t.append_constraint(CheckConstraint(t.c.email.like('%@%')))
+ t.append_constraint(CheckConstraint(t.c.email.like("%@%")))
return ApplyBatchImpl(t, table_args, table_kwargs, False)
def _fk_fixture(self, table_args=(), table_kwargs={}):
m = MetaData()
t = Table(
- 'tname', m,
- Column('id', Integer, primary_key=True),
- Column('email', String()),
- Column('user_id', Integer, ForeignKey('user.id'))
+ "tname",
+ m,
+ Column("id", Integer, primary_key=True),
+ Column("email", String()),
+ Column("user_id", Integer, ForeignKey("user.id")),
)
return ApplyBatchImpl(t, table_args, table_kwargs, False)
@@ -110,93 +130,108 @@ class BatchApplyTest(TestBase):
schemaarg = ""
t = Table(
- 'tname', m,
- Column('id', Integer, primary_key=True),
- Column('email', String()),
- Column('user_id_1', Integer, ForeignKey('%suser.id' % schemaarg)),
- Column('user_id_2', Integer, ForeignKey('%suser.id' % schemaarg)),
- Column('user_id_3', Integer),
- Column('user_id_version', Integer),
+ "tname",
+ m,
+ Column("id", Integer, primary_key=True),
+ Column("email", String()),
+ Column("user_id_1", Integer, ForeignKey("%suser.id" % schemaarg)),
+ Column("user_id_2", Integer, ForeignKey("%suser.id" % schemaarg)),
+ Column("user_id_3", Integer),
+ Column("user_id_version", Integer),
ForeignKeyConstraint(
- ['user_id_3', 'user_id_version'],
- ['%suser.id' % schemaarg, '%suser.id_version' % schemaarg]),
- schema=schema
+ ["user_id_3", "user_id_version"],
+ ["%suser.id" % schemaarg, "%suser.id_version" % schemaarg],
+ ),
+ schema=schema,
)
return ApplyBatchImpl(t, table_args, table_kwargs, False)
def _named_fk_fixture(self, table_args=(), table_kwargs={}):
m = MetaData()
t = Table(
- 'tname', m,
- Column('id', Integer, primary_key=True),
- Column('email', String()),
- Column('user_id', Integer, ForeignKey('user.id', name='ufk'))
+ "tname",
+ m,
+ Column("id", Integer, primary_key=True),
+ Column("email", String()),
+ Column("user_id", Integer, ForeignKey("user.id", name="ufk")),
)
return ApplyBatchImpl(t, table_args, table_kwargs, False)
def _selfref_fk_fixture(self, table_args=(), table_kwargs={}):
m = MetaData()
t = Table(
- 'tname', m,
- Column('id', Integer, primary_key=True),
- Column('parent_id', Integer, ForeignKey('tname.id')),
- Column('data', String)
+ "tname",
+ m,
+ Column("id", Integer, primary_key=True),
+ Column("parent_id", Integer, ForeignKey("tname.id")),
+ Column("data", String),
)
return ApplyBatchImpl(t, table_args, table_kwargs, False)
def _boolean_fixture(self, table_args=(), table_kwargs={}):
m = MetaData()
t = Table(
- 'tname', m,
- Column('id', Integer, primary_key=True),
- Column('flag', Boolean)
+ "tname",
+ m,
+ Column("id", Integer, primary_key=True),
+ Column("flag", Boolean),
)
return ApplyBatchImpl(t, table_args, table_kwargs, False)
def _boolean_no_ck_fixture(self, table_args=(), table_kwargs={}):
m = MetaData()
t = Table(
- 'tname', m,
- Column('id', Integer, primary_key=True),
- Column('flag', Boolean(create_constraint=False))
+ "tname",
+ m,
+ Column("id", Integer, primary_key=True),
+ Column("flag", Boolean(create_constraint=False)),
)
return ApplyBatchImpl(t, table_args, table_kwargs, False)
def _enum_fixture(self, table_args=(), table_kwargs={}):
m = MetaData()
t = Table(
- 'tname', m,
- Column('id', Integer, primary_key=True),
- Column('thing', Enum('a', 'b', 'c'))
+ "tname",
+ m,
+ Column("id", Integer, primary_key=True),
+ Column("thing", Enum("a", "b", "c")),
)
return ApplyBatchImpl(t, table_args, table_kwargs, False)
def _server_default_fixture(self, table_args=(), table_kwargs={}):
m = MetaData()
t = Table(
- 'tname', m,
- Column('id', Integer, primary_key=True),
- Column('thing', String(), server_default='')
+ "tname",
+ m,
+ Column("id", Integer, primary_key=True),
+ Column("thing", String(), server_default=""),
)
return ApplyBatchImpl(t, table_args, table_kwargs, False)
- def _assert_impl(self, impl, colnames=None,
- ddl_contains=None, ddl_not_contains=None,
- dialect='default', schema=None):
+ def _assert_impl(
+ self,
+ impl,
+ colnames=None,
+ ddl_contains=None,
+ ddl_not_contains=None,
+ dialect="default",
+ schema=None,
+ ):
context = op_fixture(dialect=dialect)
impl._create(context.impl)
if colnames is None:
- colnames = ['id', 'x', 'y']
+ colnames = ["id", "x", "y"]
eq_(impl.new_table.c.keys(), colnames)
pk_cols = [col for col in impl.new_table.c if col.primary_key]
eq_(list(impl.new_table.primary_key), pk_cols)
create_stmt = str(
- CreateTable(impl.new_table).compile(dialect=context.dialect))
- create_stmt = re.sub(r'[\n\t]', '', create_stmt)
+ CreateTable(impl.new_table).compile(dialect=context.dialect)
+ )
+ create_stmt = re.sub(r"[\n\t]", "", create_stmt)
idx_stmt = ""
for idx in impl.indexes.values():
@@ -205,17 +240,16 @@ class BatchApplyTest(TestBase):
impl.new_table.name = impl.table.name
idx_stmt += str(CreateIndex(idx).compile(dialect=context.dialect))
impl.new_table.name = ApplyBatchImpl._calc_temp_name(
- impl.table.name)
- idx_stmt = re.sub(r'[\n\t]', '', idx_stmt)
+ impl.table.name
+ )
+ idx_stmt = re.sub(r"[\n\t]", "", idx_stmt)
if ddl_contains:
assert ddl_contains in create_stmt + idx_stmt
if ddl_not_contains:
assert ddl_not_contains not in create_stmt + idx_stmt
- expected = [
- create_stmt,
- ]
+ expected = [create_stmt]
if schema:
args = {"schema": "%s." % schema}
@@ -224,32 +258,40 @@ class BatchApplyTest(TestBase):
args["temp_name"] = impl.new_table.name
- args['colnames'] = ", ".join([
- impl.new_table.c[name].name
- for name in colnames
- if name in impl.table.c])
+ args["colnames"] = ", ".join(
+ [
+ impl.new_table.c[name].name
+ for name in colnames
+ if name in impl.table.c
+ ]
+ )
- args['tname_colnames'] = ", ".join(
- "CAST(%(schema)stname.%(name)s AS %(type)s) AS anon_1" % {
- 'schema': args['schema'],
- 'name': name,
- 'type': impl.new_table.c[name].type
+ args["tname_colnames"] = ", ".join(
+ "CAST(%(schema)stname.%(name)s AS %(type)s) AS anon_1"
+ % {
+ "schema": args["schema"],
+ "name": name,
+ "type": impl.new_table.c[name].type,
}
if (
impl.new_table.c[name].type._type_affinity
- is not impl.table.c[name].type._type_affinity)
- else "%(schema)stname.%(name)s" % {
- 'schema': args['schema'], 'name': name}
- for name in colnames if name in impl.table.c
- )
-
- expected.extend([
- 'INSERT INTO %(schema)s%(temp_name)s (%(colnames)s) '
- 'SELECT %(tname_colnames)s FROM %(schema)stname' % args,
- 'DROP TABLE %(schema)stname' % args,
- 'ALTER TABLE %(schema)s%(temp_name)s '
- 'RENAME TO %(schema)stname' % args
- ])
+ is not impl.table.c[name].type._type_affinity
+ )
+ else "%(schema)stname.%(name)s"
+ % {"schema": args["schema"], "name": name}
+ for name in colnames
+ if name in impl.table.c
+ )
+
+ expected.extend(
+ [
+ "INSERT INTO %(schema)s%(temp_name)s (%(colnames)s) "
+ "SELECT %(tname_colnames)s FROM %(schema)stname" % args,
+ "DROP TABLE %(schema)stname" % args,
+ "ALTER TABLE %(schema)s%(temp_name)s "
+ "RENAME TO %(schema)stname" % args,
+ ]
+ )
if idx_stmt:
expected.append(idx_stmt)
context.assert_(*expected)
@@ -257,36 +299,42 @@ class BatchApplyTest(TestBase):
def test_change_type(self):
impl = self._simple_fixture()
- impl.alter_column('tname', 'x', type_=String)
+ impl.alter_column("tname", "x", type_=String)
new_table = self._assert_impl(impl)
assert new_table.c.x.type._type_affinity is String
def test_rename_col(self):
impl = self._simple_fixture()
- impl.alter_column('tname', 'x', name='q')
+ impl.alter_column("tname", "x", name="q")
new_table = self._assert_impl(impl)
- eq_(new_table.c.x.name, 'q')
+ eq_(new_table.c.x.name, "q")
def test_rename_col_boolean(self):
impl = self._boolean_fixture()
- impl.alter_column('tname', 'flag', name='bflag')
+ impl.alter_column("tname", "flag", name="bflag")
new_table = self._assert_impl(
- impl, ddl_contains="CHECK (bflag IN (0, 1)",
- colnames=["id", "flag"])
- eq_(new_table.c.flag.name, 'bflag')
+ impl,
+ ddl_contains="CHECK (bflag IN (0, 1)",
+ colnames=["id", "flag"],
+ )
+ eq_(new_table.c.flag.name, "bflag")
eq_(
- len([
- const for const
- in new_table.constraints
- if isinstance(const, CheckConstraint)]),
- 1)
+ len(
+ [
+ const
+ for const in new_table.constraints
+ if isinstance(const, CheckConstraint)
+ ]
+ ),
+ 1,
+ )
def test_change_type_schematype_to_non(self):
impl = self._boolean_fixture()
- impl.alter_column('tname', 'flag', type_=Integer)
+ impl.alter_column("tname", "flag", type_=Integer)
new_table = self._assert_impl(
- impl, colnames=['id', 'flag'],
- ddl_not_contains="CHECK")
+ impl, colnames=["id", "flag"], ddl_not_contains="CHECK"
+ )
assert new_table.c.flag.type._type_affinity is Integer
# NOTE: we can't do test_change_type_non_to_schematype
@@ -295,254 +343,310 @@ class BatchApplyTest(TestBase):
def test_rename_col_boolean_no_ck(self):
impl = self._boolean_no_ck_fixture()
- impl.alter_column('tname', 'flag', name='bflag')
+ impl.alter_column("tname", "flag", name="bflag")
new_table = self._assert_impl(
- impl, ddl_not_contains="CHECK",
- colnames=["id", "flag"])
- eq_(new_table.c.flag.name, 'bflag')
+ impl, ddl_not_contains="CHECK", colnames=["id", "flag"]
+ )
+ eq_(new_table.c.flag.name, "bflag")
eq_(
- len([
- const for const
- in new_table.constraints
- if isinstance(const, CheckConstraint)]),
- 0)
+ len(
+ [
+ const
+ for const in new_table.constraints
+ if isinstance(const, CheckConstraint)
+ ]
+ ),
+ 0,
+ )
def test_rename_col_enum(self):
impl = self._enum_fixture()
- impl.alter_column('tname', 'thing', name='thang')
+ impl.alter_column("tname", "thing", name="thang")
new_table = self._assert_impl(
- impl, ddl_contains="CHECK (thang IN ('a', 'b', 'c')",
- colnames=["id", "thing"])
- eq_(new_table.c.thing.name, 'thang')
+ impl,
+ ddl_contains="CHECK (thang IN ('a', 'b', 'c')",
+ colnames=["id", "thing"],
+ )
+ eq_(new_table.c.thing.name, "thang")
eq_(
- len([
- const for const
- in new_table.constraints
- if isinstance(const, CheckConstraint)]),
- 1)
+ len(
+ [
+ const
+ for const in new_table.constraints
+ if isinstance(const, CheckConstraint)
+ ]
+ ),
+ 1,
+ )
def test_rename_col_literal_ck(self):
impl = self._literal_ck_fixture()
- impl.alter_column('tname', 'email', name='emol')
+ impl.alter_column("tname", "email", name="emol")
new_table = self._assert_impl(
# note this is wrong, we don't dig into the SQL
- impl, ddl_contains="CHECK (email LIKE '%@%')",
- colnames=["id", "email"])
+ impl,
+ ddl_contains="CHECK (email LIKE '%@%')",
+ colnames=["id", "email"],
+ )
eq_(
- len([c for c in new_table.constraints
- if isinstance(c, CheckConstraint)]), 1)
+ len(
+ [
+ c
+ for c in new_table.constraints
+ if isinstance(c, CheckConstraint)
+ ]
+ ),
+ 1,
+ )
- eq_(new_table.c.email.name, 'emol')
+ eq_(new_table.c.email.name, "emol")
def test_rename_col_literal_ck_workaround(self):
impl = self._literal_ck_fixture(
copy_from=Table(
- 'tname', MetaData(),
- Column('id', Integer, primary_key=True),
- Column('email', String),
+ "tname",
+ MetaData(),
+ Column("id", Integer, primary_key=True),
+ Column("email", String),
),
- table_args=[CheckConstraint("emol LIKE '%@%'")])
+ table_args=[CheckConstraint("emol LIKE '%@%'")],
+ )
- impl.alter_column('tname', 'email', name='emol')
+ impl.alter_column("tname", "email", name="emol")
new_table = self._assert_impl(
- impl, ddl_contains="CHECK (emol LIKE '%@%')",
- colnames=["id", "email"])
+ impl,
+ ddl_contains="CHECK (emol LIKE '%@%')",
+ colnames=["id", "email"],
+ )
eq_(
- len([c for c in new_table.constraints
- if isinstance(c, CheckConstraint)]), 1)
- eq_(new_table.c.email.name, 'emol')
+ len(
+ [
+ c
+ for c in new_table.constraints
+ if isinstance(c, CheckConstraint)
+ ]
+ ),
+ 1,
+ )
+ eq_(new_table.c.email.name, "emol")
def test_rename_col_sql_ck(self):
impl = self._sql_ck_fixture()
- impl.alter_column('tname', 'email', name='emol')
+ impl.alter_column("tname", "email", name="emol")
new_table = self._assert_impl(
- impl, ddl_contains="CHECK (emol LIKE '%@%')",
- colnames=["id", "email"])
+ impl,
+ ddl_contains="CHECK (emol LIKE '%@%')",
+ colnames=["id", "email"],
+ )
eq_(
- len([c for c in new_table.constraints
- if isinstance(c, CheckConstraint)]), 1)
+ len(
+ [
+ c
+ for c in new_table.constraints
+ if isinstance(c, CheckConstraint)
+ ]
+ ),
+ 1,
+ )
- eq_(new_table.c.email.name, 'emol')
+ eq_(new_table.c.email.name, "emol")
def test_add_col(self):
impl = self._simple_fixture()
- col = Column('g', Integer)
+ col = Column("g", Integer)
# operations.add_column produces a table
- t = self.op.schema_obj.table('tname', col) # noqa
- impl.add_column('tname', col)
- new_table = self._assert_impl(impl, colnames=['id', 'x', 'y', 'g'])
- eq_(new_table.c.g.name, 'g')
+ t = self.op.schema_obj.table("tname", col) # noqa
+ impl.add_column("tname", col)
+ new_table = self._assert_impl(impl, colnames=["id", "x", "y", "g"])
+ eq_(new_table.c.g.name, "g")
def test_add_server_default(self):
impl = self._simple_fixture()
- impl.alter_column('tname', 'y', server_default="10")
- new_table = self._assert_impl(
- impl, ddl_contains="DEFAULT '10'")
- eq_(
- new_table.c.y.server_default.arg, "10"
- )
+ impl.alter_column("tname", "y", server_default="10")
+ new_table = self._assert_impl(impl, ddl_contains="DEFAULT '10'")
+ eq_(new_table.c.y.server_default.arg, "10")
def test_drop_server_default(self):
impl = self._server_default_fixture()
- impl.alter_column('tname', 'thing', server_default=None)
+ impl.alter_column("tname", "thing", server_default=None)
new_table = self._assert_impl(
- impl, colnames=['id', 'thing'], ddl_not_contains="DEFAULT")
+ impl, colnames=["id", "thing"], ddl_not_contains="DEFAULT"
+ )
eq_(new_table.c.thing.server_default, None)
def test_rename_col_pk(self):
impl = self._simple_fixture()
- impl.alter_column('tname', 'id', name='foobar')
+ impl.alter_column("tname", "id", name="foobar")
new_table = self._assert_impl(
- impl, ddl_contains="PRIMARY KEY (foobar)")
- eq_(new_table.c.id.name, 'foobar')
+ impl, ddl_contains="PRIMARY KEY (foobar)"
+ )
+ eq_(new_table.c.id.name, "foobar")
eq_(list(new_table.primary_key), [new_table.c.id])
def test_rename_col_fk(self):
impl = self._fk_fixture()
- impl.alter_column('tname', 'user_id', name='foobar')
+ impl.alter_column("tname", "user_id", name="foobar")
new_table = self._assert_impl(
- impl, colnames=['id', 'email', 'user_id'],
- ddl_contains='FOREIGN KEY(foobar) REFERENCES "user" (id)')
- eq_(new_table.c.user_id.name, 'foobar')
+ impl,
+ colnames=["id", "email", "user_id"],
+ ddl_contains='FOREIGN KEY(foobar) REFERENCES "user" (id)',
+ )
+ eq_(new_table.c.user_id.name, "foobar")
eq_(
- list(new_table.c.user_id.foreign_keys)[0]._get_colspec(),
- "user.id"
+ list(new_table.c.user_id.foreign_keys)[0]._get_colspec(), "user.id"
)
def test_regen_multi_fk(self):
impl = self._multi_fk_fixture()
self._assert_impl(
- impl, colnames=[
- 'id', 'email', 'user_id_1', 'user_id_2',
- 'user_id_3', 'user_id_version'],
- ddl_contains='FOREIGN KEY(user_id_3, user_id_version) '
- 'REFERENCES "user" (id, id_version)')
+ impl,
+ colnames=[
+ "id",
+ "email",
+ "user_id_1",
+ "user_id_2",
+ "user_id_3",
+ "user_id_version",
+ ],
+ ddl_contains="FOREIGN KEY(user_id_3, user_id_version) "
+ 'REFERENCES "user" (id, id_version)',
+ )
def test_regen_multi_fk_schema(self):
- impl = self._multi_fk_fixture(schema='foo_schema')
+ impl = self._multi_fk_fixture(schema="foo_schema")
self._assert_impl(
- impl, colnames=[
- 'id', 'email', 'user_id_1', 'user_id_2',
- 'user_id_3', 'user_id_version'],
- ddl_contains='FOREIGN KEY(user_id_3, user_id_version) '
+ impl,
+ colnames=[
+ "id",
+ "email",
+ "user_id_1",
+ "user_id_2",
+ "user_id_3",
+ "user_id_version",
+ ],
+ ddl_contains="FOREIGN KEY(user_id_3, user_id_version) "
'REFERENCES foo_schema."user" (id, id_version)',
- schema='foo_schema')
+ schema="foo_schema",
+ )
def test_drop_col(self):
impl = self._simple_fixture()
- impl.drop_column('tname', column('x'))
- new_table = self._assert_impl(impl, colnames=['id', 'y'])
- assert 'y' in new_table.c
- assert 'x' not in new_table.c
+ impl.drop_column("tname", column("x"))
+ new_table = self._assert_impl(impl, colnames=["id", "y"])
+ assert "y" in new_table.c
+ assert "x" not in new_table.c
def test_drop_col_remove_pk(self):
impl = self._simple_fixture()
- impl.drop_column('tname', column('id'))
+ impl.drop_column("tname", column("id"))
new_table = self._assert_impl(
- impl, colnames=['x', 'y'], ddl_not_contains="PRIMARY KEY")
- assert 'y' in new_table.c
- assert 'id' not in new_table.c
+ impl, colnames=["x", "y"], ddl_not_contains="PRIMARY KEY"
+ )
+ assert "y" in new_table.c
+ assert "id" not in new_table.c
assert not new_table.primary_key
def test_drop_col_remove_fk(self):
impl = self._fk_fixture()
- impl.drop_column('tname', column('user_id'))
+ impl.drop_column("tname", column("user_id"))
new_table = self._assert_impl(
- impl, colnames=['id', 'email'], ddl_not_contains="FOREIGN KEY")
- assert 'user_id' not in new_table.c
+ impl, colnames=["id", "email"], ddl_not_contains="FOREIGN KEY"
+ )
+ assert "user_id" not in new_table.c
assert not new_table.foreign_keys
def test_drop_col_retain_fk(self):
impl = self._fk_fixture()
- impl.drop_column('tname', column('email'))
+ impl.drop_column("tname", column("email"))
new_table = self._assert_impl(
- impl, colnames=['id', 'user_id'],
- ddl_contains='FOREIGN KEY(user_id) REFERENCES "user" (id)')
- assert 'email' not in new_table.c
+ impl,
+ colnames=["id", "user_id"],
+ ddl_contains='FOREIGN KEY(user_id) REFERENCES "user" (id)',
+ )
+ assert "email" not in new_table.c
assert new_table.c.user_id.foreign_keys
def test_drop_col_retain_fk_selfref(self):
impl = self._selfref_fk_fixture()
- impl.drop_column('tname', column('data'))
- new_table = self._assert_impl(impl, colnames=['id', 'parent_id'])
- assert 'data' not in new_table.c
+ impl.drop_column("tname", column("data"))
+ new_table = self._assert_impl(impl, colnames=["id", "parent_id"])
+ assert "data" not in new_table.c
assert new_table.c.parent_id.foreign_keys
def test_add_fk(self):
impl = self._simple_fixture()
- impl.add_column('tname', Column('user_id', Integer))
+ impl.add_column("tname", Column("user_id", Integer))
fk = self.op.schema_obj.foreign_key_constraint(
- 'fk1', 'tname', 'user',
- ['user_id'], ['id'])
+ "fk1", "tname", "user", ["user_id"], ["id"]
+ )
impl.add_constraint(fk)
new_table = self._assert_impl(
- impl, colnames=['id', 'x', 'y', 'user_id'],
- ddl_contains='CONSTRAINT fk1 FOREIGN KEY(user_id) '
- 'REFERENCES "user" (id)')
+ impl,
+ colnames=["id", "x", "y", "user_id"],
+ ddl_contains="CONSTRAINT fk1 FOREIGN KEY(user_id) "
+ 'REFERENCES "user" (id)',
+ )
eq_(
- list(new_table.c.user_id.foreign_keys)[0]._get_colspec(),
- 'user.id'
+ list(new_table.c.user_id.foreign_keys)[0]._get_colspec(), "user.id"
)
def test_drop_fk(self):
impl = self._named_fk_fixture()
- fk = ForeignKeyConstraint([], [], name='ufk')
+ fk = ForeignKeyConstraint([], [], name="ufk")
impl.drop_constraint(fk)
new_table = self._assert_impl(
- impl, colnames=['id', 'email', 'user_id'],
- ddl_not_contains="CONSTRANT fk1")
- eq_(
- list(new_table.foreign_keys),
- []
+ impl,
+ colnames=["id", "email", "user_id"],
+ ddl_not_contains="CONSTRANT fk1",
)
+ eq_(list(new_table.foreign_keys), [])
def test_add_uq(self):
impl = self._simple_fixture()
- uq = self.op.schema_obj.unique_constraint(
- 'uq1', 'tname', ['y']
- )
+ uq = self.op.schema_obj.unique_constraint("uq1", "tname", ["y"])
impl.add_constraint(uq)
self._assert_impl(
- impl, colnames=['id', 'x', 'y'],
- ddl_contains="CONSTRAINT uq1 UNIQUE")
+ impl,
+ colnames=["id", "x", "y"],
+ ddl_contains="CONSTRAINT uq1 UNIQUE",
+ )
def test_drop_uq(self):
impl = self._uq_fixture()
- uq = self.op.schema_obj.unique_constraint(
- 'uq1', 'tname', ['y']
- )
+ uq = self.op.schema_obj.unique_constraint("uq1", "tname", ["y"])
impl.drop_constraint(uq)
self._assert_impl(
- impl, colnames=['id', 'x', 'y'],
- ddl_not_contains="CONSTRAINT uq1 UNIQUE")
+ impl,
+ colnames=["id", "x", "y"],
+ ddl_not_contains="CONSTRAINT uq1 UNIQUE",
+ )
def test_create_index(self):
impl = self._simple_fixture()
- ix = self.op.schema_obj.index('ix1', 'tname', ['y'])
+ ix = self.op.schema_obj.index("ix1", "tname", ["y"])
impl.create_index(ix)
self._assert_impl(
- impl, colnames=['id', 'x', 'y'],
- ddl_contains="CREATE INDEX ix1")
+ impl, colnames=["id", "x", "y"], ddl_contains="CREATE INDEX ix1"
+ )
def test_drop_index(self):
impl = self._ix_fixture()
- ix = self.op.schema_obj.index('ix1', 'tname', ['y'])
+ ix = self.op.schema_obj.index("ix1", "tname", ["y"])
impl.drop_index(ix)
self._assert_impl(
- impl, colnames=['id', 'x', 'y'],
- ddl_not_contains="CONSTRAINT uq1 UNIQUE")
+ impl,
+ colnames=["id", "x", "y"],
+ ddl_not_contains="CONSTRAINT uq1 UNIQUE",
+ )
def test_add_table_opts(self):
- impl = self._simple_fixture(table_kwargs={'mysql_engine': 'InnoDB'})
- self._assert_impl(
- impl, ddl_contains="ENGINE=InnoDB",
- dialect='mysql'
- )
+ impl = self._simple_fixture(table_kwargs={"mysql_engine": "InnoDB"})
+ self._assert_impl(impl, ddl_contains="ENGINE=InnoDB", dialect="mysql")
def test_drop_pk(self):
impl = self._pk_fixture()
@@ -554,14 +658,15 @@ class BatchApplyTest(TestBase):
class BatchAPITest(TestBase):
-
@contextmanager
def _fixture(self, schema=None):
migration_context = mock.Mock(
- opts={}, impl=mock.MagicMock(__dialect__='sqlite'))
+ opts={}, impl=mock.MagicMock(__dialect__="sqlite")
+ )
op = Operations(migration_context)
batch = op.batch_alter_table(
- 'tname', recreate='never', schema=schema).__enter__()
+ "tname", recreate="never", schema=schema
+ ).__enter__()
mock_schema = mock.MagicMock()
with mock.patch("alembic.operations.schemaobj.sa_schema", mock_schema):
@@ -571,105 +676,131 @@ class BatchAPITest(TestBase):
def test_drop_col(self):
with self._fixture() as batch:
- batch.drop_column('q')
+ batch.drop_column("q")
eq_(
batch.impl.operations.impl.mock_calls,
- [mock.call.drop_column(
- 'tname', self.mock_schema.Column(), schema=None)]
+ [
+ mock.call.drop_column(
+ "tname", self.mock_schema.Column(), schema=None
+ )
+ ],
)
def test_add_col(self):
- column = Column('w', String(50))
+ column = Column("w", String(50))
with self._fixture() as batch:
batch.add_column(column)
eq_(
batch.impl.operations.impl.mock_calls,
- [mock.call.add_column(
- 'tname', column, schema=None)]
+ [mock.call.add_column("tname", column, schema=None)],
)
def test_create_fk(self):
with self._fixture() as batch:
- batch.create_foreign_key('myfk', 'user', ['x'], ['y'])
+ batch.create_foreign_key("myfk", "user", ["x"], ["y"])
eq_(
self.mock_schema.ForeignKeyConstraint.mock_calls,
[
mock.call(
- ['x'], ['user.y'],
- onupdate=None, ondelete=None, name='myfk',
- initially=None, deferrable=None, match=None)
- ]
+ ["x"],
+ ["user.y"],
+ onupdate=None,
+ ondelete=None,
+ name="myfk",
+ initially=None,
+ deferrable=None,
+ match=None,
+ )
+ ],
)
eq_(
self.mock_schema.Table.mock_calls,
[
mock.call(
- 'user', self.mock_schema.MetaData(),
+ "user",
+ self.mock_schema.MetaData(),
self.mock_schema.Column(),
- schema=None
+ schema=None,
),
mock.call(
- 'tname', self.mock_schema.MetaData(),
+ "tname",
+ self.mock_schema.MetaData(),
self.mock_schema.Column(),
- schema=None
+ schema=None,
),
mock.call().append_constraint(
- self.mock_schema.ForeignKeyConstraint())
- ]
+ self.mock_schema.ForeignKeyConstraint()
+ ),
+ ],
)
eq_(
batch.impl.operations.impl.mock_calls,
- [mock.call.add_constraint(
- self.mock_schema.ForeignKeyConstraint())]
+ [
+ mock.call.add_constraint(
+ self.mock_schema.ForeignKeyConstraint()
+ )
+ ],
)
def test_create_fk_schema(self):
- with self._fixture(schema='foo') as batch:
- batch.create_foreign_key('myfk', 'user', ['x'], ['y'])
+ with self._fixture(schema="foo") as batch:
+ batch.create_foreign_key("myfk", "user", ["x"], ["y"])
eq_(
self.mock_schema.ForeignKeyConstraint.mock_calls,
[
mock.call(
- ['x'], ['user.y'],
- onupdate=None, ondelete=None, name='myfk',
- initially=None, deferrable=None, match=None)
- ]
+ ["x"],
+ ["user.y"],
+ onupdate=None,
+ ondelete=None,
+ name="myfk",
+ initially=None,
+ deferrable=None,
+ match=None,
+ )
+ ],
)
eq_(
self.mock_schema.Table.mock_calls,
[
mock.call(
- 'user', self.mock_schema.MetaData(),
+ "user",
+ self.mock_schema.MetaData(),
self.mock_schema.Column(),
- schema=None
+ schema=None,
),
mock.call(
- 'tname', self.mock_schema.MetaData(),
+ "tname",
+ self.mock_schema.MetaData(),
self.mock_schema.Column(),
- schema='foo'
+ schema="foo",
),
mock.call().append_constraint(
- self.mock_schema.ForeignKeyConstraint())
- ]
+ self.mock_schema.ForeignKeyConstraint()
+ ),
+ ],
)
eq_(
batch.impl.operations.impl.mock_calls,
- [mock.call.add_constraint(
- self.mock_schema.ForeignKeyConstraint())]
+ [
+ mock.call.add_constraint(
+ self.mock_schema.ForeignKeyConstraint()
+ )
+ ],
)
def test_create_uq(self):
with self._fixture() as batch:
- batch.create_unique_constraint('uq1', ['a', 'b'])
+ batch.create_unique_constraint("uq1", ["a", "b"])
eq_(
self.mock_schema.Table().c.__getitem__.mock_calls,
- [mock.call('a'), mock.call('b')]
+ [mock.call("a"), mock.call("b")],
)
eq_(
@@ -678,23 +809,22 @@ class BatchAPITest(TestBase):
mock.call(
self.mock_schema.Table().c.__getitem__(),
self.mock_schema.Table().c.__getitem__(),
- name='uq1'
+ name="uq1",
)
- ]
+ ],
)
eq_(
batch.impl.operations.impl.mock_calls,
- [mock.call.add_constraint(
- self.mock_schema.UniqueConstraint())]
+ [mock.call.add_constraint(self.mock_schema.UniqueConstraint())],
)
def test_create_pk(self):
with self._fixture() as batch:
- batch.create_primary_key('pk1', ['a', 'b'])
+ batch.create_primary_key("pk1", ["a", "b"])
eq_(
self.mock_schema.Table().c.__getitem__.mock_calls,
- [mock.call('a'), mock.call('b')]
+ [mock.call("a"), mock.call("b")],
)
eq_(
@@ -703,60 +833,53 @@ class BatchAPITest(TestBase):
mock.call(
self.mock_schema.Table().c.__getitem__(),
self.mock_schema.Table().c.__getitem__(),
- name='pk1'
+ name="pk1",
)
- ]
+ ],
)
eq_(
batch.impl.operations.impl.mock_calls,
- [mock.call.add_constraint(
- self.mock_schema.PrimaryKeyConstraint())]
+ [
+ mock.call.add_constraint(
+ self.mock_schema.PrimaryKeyConstraint()
+ )
+ ],
)
def test_create_check(self):
expr = text("a > b")
with self._fixture() as batch:
- batch.create_check_constraint('ck1', expr)
+ batch.create_check_constraint("ck1", expr)
eq_(
self.mock_schema.CheckConstraint.mock_calls,
- [
- mock.call(
- expr, name="ck1"
- )
- ]
+ [mock.call(expr, name="ck1")],
)
eq_(
batch.impl.operations.impl.mock_calls,
- [mock.call.add_constraint(
- self.mock_schema.CheckConstraint())]
+ [mock.call.add_constraint(self.mock_schema.CheckConstraint())],
)
def test_drop_constraint(self):
with self._fixture() as batch:
- batch.drop_constraint('uq1')
+ batch.drop_constraint("uq1")
- eq_(
- self.mock_schema.Constraint.mock_calls,
- [
- mock.call(name='uq1')
- ]
- )
+ eq_(self.mock_schema.Constraint.mock_calls, [mock.call(name="uq1")])
eq_(
batch.impl.operations.impl.mock_calls,
- [mock.call.drop_constraint(self.mock_schema.Constraint())]
+ [mock.call.drop_constraint(self.mock_schema.Constraint())],
)
class CopyFromTest(TestBase):
-
def _fixture(self):
self.metadata = MetaData()
self.table = Table(
- 'foo', self.metadata,
- Column('id', Integer, primary_key=True),
- Column('data', String(50)),
- Column('x', Integer),
+ "foo",
+ self.metadata,
+ Column("id", Integer, primary_key=True),
+ Column("data", String(50)),
+ Column("x", Integer),
)
context = op_fixture(dialect="sqlite", as_sql=True)
@@ -766,148 +889,151 @@ class CopyFromTest(TestBase):
def test_change_type(self):
context = self._fixture()
with self.op.batch_alter_table(
- "foo", copy_from=self.table) as batch_op:
- batch_op.alter_column('data', type_=Integer)
+ "foo", copy_from=self.table
+ ) as batch_op:
+ batch_op.alter_column("data", type_=Integer)
context.assert_(
- 'CREATE TABLE _alembic_tmp_foo (id INTEGER NOT NULL, '
- 'data INTEGER, x INTEGER, PRIMARY KEY (id))',
- 'INSERT INTO _alembic_tmp_foo (id, data, x) SELECT foo.id, '
- 'CAST(foo.data AS INTEGER) AS anon_1, foo.x FROM foo',
- 'DROP TABLE foo',
- 'ALTER TABLE _alembic_tmp_foo RENAME TO foo'
+ "CREATE TABLE _alembic_tmp_foo (id INTEGER NOT NULL, "
+ "data INTEGER, x INTEGER, PRIMARY KEY (id))",
+ "INSERT INTO _alembic_tmp_foo (id, data, x) SELECT foo.id, "
+ "CAST(foo.data AS INTEGER) AS anon_1, foo.x FROM foo",
+ "DROP TABLE foo",
+ "ALTER TABLE _alembic_tmp_foo RENAME TO foo",
)
def test_change_type_from_schematype(self):
context = self._fixture()
self.table.append_column(
- Column('y', Boolean(
- create_constraint=True, name="ck1")))
+ Column("y", Boolean(create_constraint=True, name="ck1"))
+ )
with self.op.batch_alter_table(
- "foo", copy_from=self.table) as batch_op:
+ "foo", copy_from=self.table
+ ) as batch_op:
batch_op.alter_column(
- 'y', type_=Integer,
- existing_type=Boolean(
- create_constraint=True, name="ck1"))
+ "y",
+ type_=Integer,
+ existing_type=Boolean(create_constraint=True, name="ck1"),
+ )
context.assert_(
- 'CREATE TABLE _alembic_tmp_foo (id INTEGER NOT NULL, '
- 'data VARCHAR(50), x INTEGER, y INTEGER, PRIMARY KEY (id))',
- 'INSERT INTO _alembic_tmp_foo (id, data, x, y) SELECT foo.id, '
- 'foo.data, foo.x, CAST(foo.y AS INTEGER) AS anon_1 FROM foo',
- 'DROP TABLE foo',
- 'ALTER TABLE _alembic_tmp_foo RENAME TO foo'
+ "CREATE TABLE _alembic_tmp_foo (id INTEGER NOT NULL, "
+ "data VARCHAR(50), x INTEGER, y INTEGER, PRIMARY KEY (id))",
+ "INSERT INTO _alembic_tmp_foo (id, data, x, y) SELECT foo.id, "
+ "foo.data, foo.x, CAST(foo.y AS INTEGER) AS anon_1 FROM foo",
+ "DROP TABLE foo",
+ "ALTER TABLE _alembic_tmp_foo RENAME TO foo",
)
def test_change_type_to_schematype(self):
context = self._fixture()
- self.table.append_column(
- Column('y', Integer))
+ self.table.append_column(Column("y", Integer))
with self.op.batch_alter_table(
- "foo", copy_from=self.table) as batch_op:
+ "foo", copy_from=self.table
+ ) as batch_op:
batch_op.alter_column(
- 'y', existing_type=Integer,
- type_=Boolean(
- create_constraint=True, name="ck1"))
+ "y",
+ existing_type=Integer,
+ type_=Boolean(create_constraint=True, name="ck1"),
+ )
context.assert_(
- 'CREATE TABLE _alembic_tmp_foo (id INTEGER NOT NULL, '
- 'data VARCHAR(50), x INTEGER, y BOOLEAN, PRIMARY KEY (id), '
- 'CONSTRAINT ck1 CHECK (y IN (0, 1)))',
- 'INSERT INTO _alembic_tmp_foo (id, data, x, y) SELECT foo.id, '
- 'foo.data, foo.x, CAST(foo.y AS BOOLEAN) AS anon_1 FROM foo',
- 'DROP TABLE foo',
- 'ALTER TABLE _alembic_tmp_foo RENAME TO foo'
+ "CREATE TABLE _alembic_tmp_foo (id INTEGER NOT NULL, "
+ "data VARCHAR(50), x INTEGER, y BOOLEAN, PRIMARY KEY (id), "
+ "CONSTRAINT ck1 CHECK (y IN (0, 1)))",
+ "INSERT INTO _alembic_tmp_foo (id, data, x, y) SELECT foo.id, "
+ "foo.data, foo.x, CAST(foo.y AS BOOLEAN) AS anon_1 FROM foo",
+ "DROP TABLE foo",
+ "ALTER TABLE _alembic_tmp_foo RENAME TO foo",
)
def test_create_drop_index_w_always(self):
context = self._fixture()
with self.op.batch_alter_table(
- "foo", copy_from=self.table, recreate='always') as batch_op:
- batch_op.create_index(
- 'ix_data', ['data'], unique=True)
+ "foo", copy_from=self.table, recreate="always"
+ ) as batch_op:
+ batch_op.create_index("ix_data", ["data"], unique=True)
context.assert_(
- 'CREATE TABLE _alembic_tmp_foo (id INTEGER NOT NULL, '
- 'data VARCHAR(50), '
- 'x INTEGER, PRIMARY KEY (id))',
- 'INSERT INTO _alembic_tmp_foo (id, data, x) '
- 'SELECT foo.id, foo.data, foo.x FROM foo',
- 'DROP TABLE foo',
- 'ALTER TABLE _alembic_tmp_foo RENAME TO foo',
- 'CREATE UNIQUE INDEX ix_data ON foo (data)',
+ "CREATE TABLE _alembic_tmp_foo (id INTEGER NOT NULL, "
+ "data VARCHAR(50), "
+ "x INTEGER, PRIMARY KEY (id))",
+ "INSERT INTO _alembic_tmp_foo (id, data, x) "
+ "SELECT foo.id, foo.data, foo.x FROM foo",
+ "DROP TABLE foo",
+ "ALTER TABLE _alembic_tmp_foo RENAME TO foo",
+ "CREATE UNIQUE INDEX ix_data ON foo (data)",
)
context.clear_assertions()
- Index('ix_data', self.table.c.data, unique=True)
+ Index("ix_data", self.table.c.data, unique=True)
with self.op.batch_alter_table(
- "foo", copy_from=self.table, recreate='always') as batch_op:
- batch_op.drop_index('ix_data')
+ "foo", copy_from=self.table, recreate="always"
+ ) as batch_op:
+ batch_op.drop_index("ix_data")
context.assert_(
- 'CREATE TABLE _alembic_tmp_foo (id INTEGER NOT NULL, '
- 'data VARCHAR(50), x INTEGER, PRIMARY KEY (id))',
- 'INSERT INTO _alembic_tmp_foo (id, data, x) '
- 'SELECT foo.id, foo.data, foo.x FROM foo',
- 'DROP TABLE foo',
- 'ALTER TABLE _alembic_tmp_foo RENAME TO foo'
+ "CREATE TABLE _alembic_tmp_foo (id INTEGER NOT NULL, "
+ "data VARCHAR(50), x INTEGER, PRIMARY KEY (id))",
+ "INSERT INTO _alembic_tmp_foo (id, data, x) "
+ "SELECT foo.id, foo.data, foo.x FROM foo",
+ "DROP TABLE foo",
+ "ALTER TABLE _alembic_tmp_foo RENAME TO foo",
)
def test_create_drop_index_wo_always(self):
context = self._fixture()
with self.op.batch_alter_table(
- "foo", copy_from=self.table) as batch_op:
- batch_op.create_index(
- 'ix_data', ['data'], unique=True)
+ "foo", copy_from=self.table
+ ) as batch_op:
+ batch_op.create_index("ix_data", ["data"], unique=True)
- context.assert_(
- 'CREATE UNIQUE INDEX ix_data ON foo (data)'
- )
+ context.assert_("CREATE UNIQUE INDEX ix_data ON foo (data)")
context.clear_assertions()
- Index('ix_data', self.table.c.data, unique=True)
+ Index("ix_data", self.table.c.data, unique=True)
with self.op.batch_alter_table(
- "foo", copy_from=self.table) as batch_op:
- batch_op.drop_index('ix_data')
+ "foo", copy_from=self.table
+ ) as batch_op:
+ batch_op.drop_index("ix_data")
- context.assert_(
- 'DROP INDEX ix_data'
- )
+ context.assert_("DROP INDEX ix_data")
def test_create_drop_index_w_other_ops(self):
context = self._fixture()
with self.op.batch_alter_table(
- "foo", copy_from=self.table) as batch_op:
- batch_op.alter_column('data', type_=Integer)
- batch_op.create_index(
- 'ix_data', ['data'], unique=True)
+ "foo", copy_from=self.table
+ ) as batch_op:
+ batch_op.alter_column("data", type_=Integer)
+ batch_op.create_index("ix_data", ["data"], unique=True)
context.assert_(
- 'CREATE TABLE _alembic_tmp_foo (id INTEGER NOT NULL, '
- 'data INTEGER, x INTEGER, PRIMARY KEY (id))',
- 'INSERT INTO _alembic_tmp_foo (id, data, x) SELECT foo.id, '
- 'CAST(foo.data AS INTEGER) AS anon_1, foo.x FROM foo',
- 'DROP TABLE foo',
- 'ALTER TABLE _alembic_tmp_foo RENAME TO foo',
- 'CREATE UNIQUE INDEX ix_data ON foo (data)',
+ "CREATE TABLE _alembic_tmp_foo (id INTEGER NOT NULL, "
+ "data INTEGER, x INTEGER, PRIMARY KEY (id))",
+ "INSERT INTO _alembic_tmp_foo (id, data, x) SELECT foo.id, "
+ "CAST(foo.data AS INTEGER) AS anon_1, foo.x FROM foo",
+ "DROP TABLE foo",
+ "ALTER TABLE _alembic_tmp_foo RENAME TO foo",
+ "CREATE UNIQUE INDEX ix_data ON foo (data)",
)
context.clear_assertions()
- Index('ix_data', self.table.c.data, unique=True)
+ Index("ix_data", self.table.c.data, unique=True)
with self.op.batch_alter_table(
- "foo", copy_from=self.table) as batch_op:
- batch_op.drop_index('ix_data')
- batch_op.alter_column('data', type_=String)
+ "foo", copy_from=self.table
+ ) as batch_op:
+ batch_op.drop_index("ix_data")
+ batch_op.alter_column("data", type_=String)
context.assert_(
- 'CREATE TABLE _alembic_tmp_foo (id INTEGER NOT NULL, '
- 'data VARCHAR, x INTEGER, PRIMARY KEY (id))',
- 'INSERT INTO _alembic_tmp_foo (id, data, x) SELECT foo.id, '
- 'foo.data, foo.x FROM foo',
- 'DROP TABLE foo',
- 'ALTER TABLE _alembic_tmp_foo RENAME TO foo'
+ "CREATE TABLE _alembic_tmp_foo (id INTEGER NOT NULL, "
+ "data VARCHAR, x INTEGER, PRIMARY KEY (id))",
+ "INSERT INTO _alembic_tmp_foo (id, data, x) SELECT foo.id, "
+ "foo.data, foo.x FROM foo",
+ "DROP TABLE foo",
+ "ALTER TABLE _alembic_tmp_foo RENAME TO foo",
)
@@ -918,11 +1044,12 @@ class BatchRoundTripTest(TestBase):
self.conn = config.db.connect()
self.metadata = MetaData()
t1 = Table(
- 'foo', self.metadata,
- Column('id', Integer, primary_key=True),
- Column('data', String(50)),
- Column('x', Integer),
- mysql_engine='InnoDB'
+ "foo",
+ self.metadata,
+ Column("id", Integer, primary_key=True),
+ Column("data", String(50)),
+ Column("x", Integer),
+ mysql_engine="InnoDB",
)
t1.create(self.conn)
@@ -933,8 +1060,8 @@ class BatchRoundTripTest(TestBase):
{"id": 2, "data": "22", "x": 6},
{"id": 3, "data": "8.5", "x": 7},
{"id": 4, "data": "9.46", "x": 8},
- {"id": 5, "data": "d5", "x": 9}
- ]
+ {"id": 5, "data": "d5", "x": 9},
+ ],
)
context = MigrationContext.configure(self.conn)
self.op = Operations(context)
@@ -949,80 +1076,75 @@ class BatchRoundTripTest(TestBase):
def _no_pk_fixture(self):
nopk = Table(
- 'nopk', self.metadata,
- Column('a', Integer),
- Column('b', Integer),
- Column('c', Integer),
- mysql_engine='InnoDB'
+ "nopk",
+ self.metadata,
+ Column("a", Integer),
+ Column("b", Integer),
+ Column("c", Integer),
+ mysql_engine="InnoDB",
)
nopk.create(self.conn)
self.conn.execute(
- nopk.insert(),
- [
- {"a": 1, "b": 2, "c": 3},
- {"a": 2, "b": 4, "c": 5},
- ]
-
+ nopk.insert(), [{"a": 1, "b": 2, "c": 3}, {"a": 2, "b": 4, "c": 5}]
)
return nopk
def _table_w_index_fixture(self):
t = Table(
- 't_w_ix', self.metadata,
- Column('id', Integer, primary_key=True),
- Column('thing', Integer),
- Column('data', String(20)),
+ "t_w_ix",
+ self.metadata,
+ Column("id", Integer, primary_key=True),
+ Column("thing", Integer),
+ Column("data", String(20)),
)
- Index('ix_thing', t.c.thing)
+ Index("ix_thing", t.c.thing)
t.create(self.conn)
return t
def _boolean_fixture(self):
t = Table(
- 'hasbool', self.metadata,
- Column('x', Boolean(create_constraint=True, name='ck1')),
- Column('y', Integer)
+ "hasbool",
+ self.metadata,
+ Column("x", Boolean(create_constraint=True, name="ck1")),
+ Column("y", Integer),
)
t.create(self.conn)
def _timestamp_fixture(self):
- t = Table(
- 'hasts', self.metadata,
- Column('x', DateTime()),
- )
+ t = Table("hasts", self.metadata, Column("x", DateTime()))
t.create(self.conn)
return t
def _int_to_boolean_fixture(self):
- t = Table(
- 'hasbool', self.metadata,
- Column('x', Integer)
- )
+ t = Table("hasbool", self.metadata, Column("x", Integer))
t.create(self.conn)
def test_change_type_boolean_to_int(self):
self._boolean_fixture()
- with self.op.batch_alter_table(
- "hasbool"
- ) as batch_op:
+ with self.op.batch_alter_table("hasbool") as batch_op:
batch_op.alter_column(
- 'x', type_=Integer, existing_type=Boolean(
- create_constraint=True, name='ck1'))
+ "x",
+ type_=Integer,
+ existing_type=Boolean(create_constraint=True, name="ck1"),
+ )
insp = Inspector.from_engine(config.db)
eq_(
- [c['type']._type_affinity for c in insp.get_columns('hasbool')
- if c['name'] == 'x'],
- [Integer]
+ [
+ c["type"]._type_affinity
+ for c in insp.get_columns("hasbool")
+ if c["name"] == "x"
+ ],
+ [Integer],
)
def test_no_net_change_timestamp(self):
t = self._timestamp_fixture()
import datetime
+
self.conn.execute(
- t.insert(),
- {"x": datetime.datetime(2012, 5, 18, 15, 32, 5)}
+ t.insert(), {"x": datetime.datetime(2012, 5, 18, 15, 32, 5)}
)
with self.op.batch_alter_table("hasts") as batch_op:
@@ -1030,69 +1152,71 @@ class BatchRoundTripTest(TestBase):
eq_(
self.conn.execute(select([t.c.x])).fetchall(),
- [(datetime.datetime(2012, 5, 18, 15, 32, 5),)]
+ [(datetime.datetime(2012, 5, 18, 15, 32, 5),)],
)
def test_drop_col_schematype(self):
self._boolean_fixture()
- with self.op.batch_alter_table(
- "hasbool"
- ) as batch_op:
- batch_op.drop_column('x')
+ with self.op.batch_alter_table("hasbool") as batch_op:
+ batch_op.drop_column("x")
insp = Inspector.from_engine(config.db)
- assert 'x' not in (c['name'] for c in insp.get_columns('hasbool'))
+ assert "x" not in (c["name"] for c in insp.get_columns("hasbool"))
def test_change_type_int_to_boolean(self):
self._int_to_boolean_fixture()
- with self.op.batch_alter_table(
- "hasbool"
- ) as batch_op:
+ with self.op.batch_alter_table("hasbool") as batch_op:
batch_op.alter_column(
- 'x', type_=Boolean(create_constraint=True, name='ck1'))
+ "x", type_=Boolean(create_constraint=True, name="ck1")
+ )
insp = Inspector.from_engine(config.db)
if exclusions.against(config, "sqlite"):
eq_(
- [c['type']._type_affinity for
- c in insp.get_columns('hasbool') if c['name'] == 'x'],
- [Boolean]
+ [
+ c["type"]._type_affinity
+ for c in insp.get_columns("hasbool")
+ if c["name"] == "x"
+ ],
+ [Boolean],
)
elif exclusions.against(config, "mysql"):
eq_(
- [c['type']._type_affinity for
- c in insp.get_columns('hasbool') if c['name'] == 'x'],
- [Integer]
+ [
+ c["type"]._type_affinity
+ for c in insp.get_columns("hasbool")
+ if c["name"] == "x"
+ ],
+ [Integer],
)
def tearDown(self):
self.metadata.drop_all(self.conn)
self.conn.close()
- def _assert_data(self, data, tablename='foo'):
+ def _assert_data(self, data, tablename="foo"):
eq_(
- [dict(row) for row
- in self.conn.execute("select * from %s" % tablename)],
- data
+ [
+ dict(row)
+ for row in self.conn.execute("select * from %s" % tablename)
+ ],
+ data,
)
def test_ix_existing(self):
self._table_w_index_fixture()
with self.op.batch_alter_table("t_w_ix") as batch_op:
- batch_op.alter_column('data', type_=String(30))
+ batch_op.alter_column("data", type_=String(30))
batch_op.create_index("ix_data", ["data"])
insp = Inspector.from_engine(config.db)
eq_(
set(
- (ix['name'], tuple(ix['column_names'])) for ix in
- insp.get_indexes('t_w_ix')
+ (ix["name"], tuple(ix["column_names"]))
+ for ix in insp.get_indexes("t_w_ix")
),
- set([
- ('ix_data', ('data',)),
- ('ix_thing', ('thing', ))
- ])
+ set([("ix_data", ("data",)), ("ix_thing", ("thing",))]),
)
def test_fk_points_to_me_auto(self):
@@ -1108,31 +1232,39 @@ class BatchRoundTripTest(TestBase):
@exclusions.only_on("sqlite")
@exclusions.fails(
"intentionally asserting that this "
- "doesn't work w/ pragma foreign keys")
+ "doesn't work w/ pragma foreign keys"
+ )
def test_fk_points_to_me_sqlite_refinteg(self):
with self._sqlite_referential_integrity():
self._test_fk_points_to_me("auto")
def _test_fk_points_to_me(self, recreate):
bar = Table(
- 'bar', self.metadata,
- Column('id', Integer, primary_key=True),
- Column('foo_id', Integer, ForeignKey('foo.id')),
- mysql_engine='InnoDB'
+ "bar",
+ self.metadata,
+ Column("id", Integer, primary_key=True),
+ Column("foo_id", Integer, ForeignKey("foo.id")),
+ mysql_engine="InnoDB",
)
bar.create(self.conn)
- self.conn.execute(bar.insert(), {'id': 1, 'foo_id': 3})
+ self.conn.execute(bar.insert(), {"id": 1, "foo_id": 3})
with self.op.batch_alter_table("foo", recreate=recreate) as batch_op:
batch_op.alter_column(
- 'data', new_column_name='newdata', existing_type=String(50))
+ "data", new_column_name="newdata", existing_type=String(50)
+ )
insp = Inspector.from_engine(self.conn)
eq_(
- [(key['referred_table'],
- key['referred_columns'], key['constrained_columns'])
- for key in insp.get_foreign_keys('bar')],
- [('foo', ['id'], ['foo_id'])]
+ [
+ (
+ key["referred_table"],
+ key["referred_columns"],
+ key["constrained_columns"],
+ )
+ for key in insp.get_foreign_keys("bar")
+ ],
+ [("foo", ["id"], ["foo_id"])],
)
def test_selfref_fk_auto(self):
@@ -1145,100 +1277,112 @@ class BatchRoundTripTest(TestBase):
@exclusions.only_on("sqlite")
@exclusions.fails(
"intentionally asserting that this "
- "doesn't work w/ pragma foreign keys")
+ "doesn't work w/ pragma foreign keys"
+ )
def test_selfref_fk_sqlite_refinteg(self):
with self._sqlite_referential_integrity():
self._test_selfref_fk("auto")
def _test_selfref_fk(self, recreate):
bar = Table(
- 'bar', self.metadata,
- Column('id', Integer, primary_key=True),
- Column('bar_id', Integer, ForeignKey('bar.id')),
- Column('data', String(50)),
- mysql_engine='InnoDB'
+ "bar",
+ self.metadata,
+ Column("id", Integer, primary_key=True),
+ Column("bar_id", Integer, ForeignKey("bar.id")),
+ Column("data", String(50)),
+ mysql_engine="InnoDB",
)
bar.create(self.conn)
- self.conn.execute(bar.insert(), {'id': 1, 'data': 'x', 'bar_id': None})
- self.conn.execute(bar.insert(), {'id': 2, 'data': 'y', 'bar_id': 1})
+ self.conn.execute(bar.insert(), {"id": 1, "data": "x", "bar_id": None})
+ self.conn.execute(bar.insert(), {"id": 2, "data": "y", "bar_id": 1})
with self.op.batch_alter_table("bar", recreate=recreate) as batch_op:
batch_op.alter_column(
- 'data', new_column_name='newdata', existing_type=String(50))
+ "data", new_column_name="newdata", existing_type=String(50)
+ )
insp = Inspector.from_engine(self.conn)
insp = Inspector.from_engine(self.conn)
eq_(
- [(key['referred_table'],
- key['referred_columns'], key['constrained_columns'])
- for key in insp.get_foreign_keys('bar')],
- [('bar', ['id'], ['bar_id'])]
+ [
+ (
+ key["referred_table"],
+ key["referred_columns"],
+ key["constrained_columns"],
+ )
+ for key in insp.get_foreign_keys("bar")
+ ],
+ [("bar", ["id"], ["bar_id"])],
)
def test_change_type(self):
with self.op.batch_alter_table("foo") as batch_op:
- batch_op.alter_column('data', type_=Integer)
+ batch_op.alter_column("data", type_=Integer)
- self._assert_data([
- {"id": 1, "data": 0, "x": 5},
- {"id": 2, "data": 22, "x": 6},
- {"id": 3, "data": 8, "x": 7},
- {"id": 4, "data": 9, "x": 8},
- {"id": 5, "data": 0, "x": 9}
- ])
+ self._assert_data(
+ [
+ {"id": 1, "data": 0, "x": 5},
+ {"id": 2, "data": 22, "x": 6},
+ {"id": 3, "data": 8, "x": 7},
+ {"id": 4, "data": 9, "x": 8},
+ {"id": 5, "data": 0, "x": 9},
+ ]
+ )
def test_drop_column(self):
with self.op.batch_alter_table("foo") as batch_op:
- batch_op.drop_column('data')
+ batch_op.drop_column("data")
- self._assert_data([
- {"id": 1, "x": 5},
- {"id": 2, "x": 6},
- {"id": 3, "x": 7},
- {"id": 4, "x": 8},
- {"id": 5, "x": 9}
- ])
+ self._assert_data(
+ [
+ {"id": 1, "x": 5},
+ {"id": 2, "x": 6},
+ {"id": 3, "x": 7},
+ {"id": 4, "x": 8},
+ {"id": 5, "x": 9},
+ ]
+ )
def test_drop_pk_col_readd_col(self):
# drop a column, add it back without primary_key=True, should no
# longer be in the constraint
with self.op.batch_alter_table("foo") as batch_op:
- batch_op.drop_column('id')
- batch_op.add_column(Column('id', Integer))
+ batch_op.drop_column("id")
+ batch_op.add_column(Column("id", Integer))
- pk_const = Inspector.from_engine(self.conn).get_pk_constraint('foo')
- eq_(pk_const['constrained_columns'], [])
+ pk_const = Inspector.from_engine(self.conn).get_pk_constraint("foo")
+ eq_(pk_const["constrained_columns"], [])
def test_drop_pk_col_readd_pk_col(self):
# drop a column, add it back with primary_key=True, should remain
with self.op.batch_alter_table("foo") as batch_op:
- batch_op.drop_column('id')
- batch_op.add_column(Column('id', Integer, primary_key=True))
+ batch_op.drop_column("id")
+ batch_op.add_column(Column("id", Integer, primary_key=True))
- pk_const = Inspector.from_engine(self.conn).get_pk_constraint('foo')
- eq_(pk_const['constrained_columns'], ['id'])
+ pk_const = Inspector.from_engine(self.conn).get_pk_constraint("foo")
+ eq_(pk_const["constrained_columns"], ["id"])
def test_drop_pk_col_readd_col_also_pk_const(self):
# drop a column, add it back without primary_key=True, but then
# also make anew PK constraint that includes it, should remain
with self.op.batch_alter_table("foo") as batch_op:
- batch_op.drop_column('id')
- batch_op.add_column(Column('id', Integer))
- batch_op.create_primary_key('newpk', ['id'])
+ batch_op.drop_column("id")
+ batch_op.add_column(Column("id", Integer))
+ batch_op.create_primary_key("newpk", ["id"])
- pk_const = Inspector.from_engine(self.conn).get_pk_constraint('foo')
- eq_(pk_const['constrained_columns'], ['id'])
+ pk_const = Inspector.from_engine(self.conn).get_pk_constraint("foo")
+ eq_(pk_const["constrained_columns"], ["id"])
def test_add_pk_constraint(self):
self._no_pk_fixture()
with self.op.batch_alter_table("nopk", recreate="always") as batch_op:
- batch_op.create_primary_key('newpk', ['a', 'b'])
+ batch_op.create_primary_key("newpk", ["a", "b"])
- pk_const = Inspector.from_engine(self.conn).get_pk_constraint('nopk')
+ pk_const = Inspector.from_engine(self.conn).get_pk_constraint("nopk")
with config.requirements.reflects_pk_names.fail_if():
- eq_(pk_const['name'], 'newpk')
- eq_(pk_const['constrained_columns'], ['a', 'b'])
+ eq_(pk_const["name"], "newpk")
+ eq_(pk_const["constrained_columns"], ["a", "b"])
@config.requirements.check_constraints_w_enforcement
def test_add_ck_constraint(self):
@@ -1247,203 +1391,219 @@ class BatchRoundTripTest(TestBase):
# we dont support reflection of CHECK constraints
# so test this by just running invalid data in
- foo = self.metadata.tables['foo']
+ foo = self.metadata.tables["foo"]
assert_raises_message(
exc.IntegrityError,
"newck",
self.conn.execute,
- foo.insert(), {"id": 6, "data": 5, "x": -2}
+ foo.insert(),
+ {"id": 6, "data": 5, "x": -2},
)
@config.requirements.sqlalchemy_094
@config.requirements.unnamed_constraints
def test_drop_foreign_key(self):
bar = Table(
- 'bar', self.metadata,
- Column('id', Integer, primary_key=True),
- Column('foo_id', Integer, ForeignKey('foo.id')),
- mysql_engine='InnoDB'
+ "bar",
+ self.metadata,
+ Column("id", Integer, primary_key=True),
+ Column("foo_id", Integer, ForeignKey("foo.id")),
+ mysql_engine="InnoDB",
)
bar.create(self.conn)
- self.conn.execute(bar.insert(), {'id': 1, 'foo_id': 3})
+ self.conn.execute(bar.insert(), {"id": 1, "foo_id": 3})
naming_convention = {
- "fk":
- "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
+ "fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s"
}
with self.op.batch_alter_table(
- "bar", naming_convention=naming_convention) as batch_op:
- batch_op.drop_constraint(
- "fk_bar_foo_id_foo", type_="foreignkey")
- eq_(
- Inspector.from_engine(self.conn).get_foreign_keys('bar'),
- []
- )
+ "bar", naming_convention=naming_convention
+ ) as batch_op:
+ batch_op.drop_constraint("fk_bar_foo_id_foo", type_="foreignkey")
+ eq_(Inspector.from_engine(self.conn).get_foreign_keys("bar"), [])
def test_drop_column_fk_recreate(self):
- with self.op.batch_alter_table("foo", recreate='always') as batch_op:
- batch_op.drop_column('data')
+ with self.op.batch_alter_table("foo", recreate="always") as batch_op:
+ batch_op.drop_column("data")
- self._assert_data([
- {"id": 1, "x": 5},
- {"id": 2, "x": 6},
- {"id": 3, "x": 7},
- {"id": 4, "x": 8},
- {"id": 5, "x": 9}
- ])
+ self._assert_data(
+ [
+ {"id": 1, "x": 5},
+ {"id": 2, "x": 6},
+ {"id": 3, "x": 7},
+ {"id": 4, "x": 8},
+ {"id": 5, "x": 9},
+ ]
+ )
def test_rename_column(self):
with self.op.batch_alter_table("foo") as batch_op:
- batch_op.alter_column('x', new_column_name='y')
+ batch_op.alter_column("x", new_column_name="y")
- self._assert_data([
- {"id": 1, "data": "d1", "y": 5},
- {"id": 2, "data": "22", "y": 6},
- {"id": 3, "data": "8.5", "y": 7},
- {"id": 4, "data": "9.46", "y": 8},
- {"id": 5, "data": "d5", "y": 9}
- ])
+ self._assert_data(
+ [
+ {"id": 1, "data": "d1", "y": 5},
+ {"id": 2, "data": "22", "y": 6},
+ {"id": 3, "data": "8.5", "y": 7},
+ {"id": 4, "data": "9.46", "y": 8},
+ {"id": 5, "data": "d5", "y": 9},
+ ]
+ )
def test_rename_column_boolean(self):
bar = Table(
- 'bar', self.metadata,
- Column('id', Integer, primary_key=True),
- Column('flag', Boolean()),
- mysql_engine='InnoDB'
+ "bar",
+ self.metadata,
+ Column("id", Integer, primary_key=True),
+ Column("flag", Boolean()),
+ mysql_engine="InnoDB",
)
bar.create(self.conn)
- self.conn.execute(bar.insert(), {'id': 1, 'flag': True})
- self.conn.execute(bar.insert(), {'id': 2, 'flag': False})
+ self.conn.execute(bar.insert(), {"id": 1, "flag": True})
+ self.conn.execute(bar.insert(), {"id": 2, "flag": False})
- with self.op.batch_alter_table(
- "bar"
- ) as batch_op:
+ with self.op.batch_alter_table("bar") as batch_op:
batch_op.alter_column(
- 'flag', new_column_name='bflag', existing_type=Boolean)
+ "flag", new_column_name="bflag", existing_type=Boolean
+ )
- self._assert_data([
- {"id": 1, 'bflag': True},
- {"id": 2, 'bflag': False},
- ], 'bar')
+ self._assert_data(
+ [{"id": 1, "bflag": True}, {"id": 2, "bflag": False}], "bar"
+ )
@config.requirements.non_native_boolean
def test_rename_column_non_native_boolean_no_ck(self):
bar = Table(
- 'bar', self.metadata,
- Column('id', Integer, primary_key=True),
- Column('flag', Boolean(create_constraint=False)),
- mysql_engine='InnoDB'
+ "bar",
+ self.metadata,
+ Column("id", Integer, primary_key=True),
+ Column("flag", Boolean(create_constraint=False)),
+ mysql_engine="InnoDB",
)
bar.create(self.conn)
- self.conn.execute(bar.insert(), {'id': 1, 'flag': True})
- self.conn.execute(bar.insert(), {'id': 2, 'flag': False})
+ self.conn.execute(bar.insert(), {"id": 1, "flag": True})
+ self.conn.execute(bar.insert(), {"id": 2, "flag": False})
self.conn.execute(
# override Boolean type which as of 1.1 coerces numerics
# to 1/0
text("insert into bar (id, flag) values (:id, :flag)"),
- {'id': 3, 'flag': 5})
+ {"id": 3, "flag": 5},
+ )
with self.op.batch_alter_table(
"bar",
- reflect_args=[Column('flag', Boolean(create_constraint=False))]
+ reflect_args=[Column("flag", Boolean(create_constraint=False))],
) as batch_op:
batch_op.alter_column(
- 'flag', new_column_name='bflag', existing_type=Boolean)
+ "flag", new_column_name="bflag", existing_type=Boolean
+ )
- self._assert_data([
- {"id": 1, 'bflag': True},
- {"id": 2, 'bflag': False},
- {'id': 3, 'bflag': 5}
- ], 'bar')
+ self._assert_data(
+ [
+ {"id": 1, "bflag": True},
+ {"id": 2, "bflag": False},
+ {"id": 3, "bflag": 5},
+ ],
+ "bar",
+ )
def test_drop_column_pk(self):
with self.op.batch_alter_table("foo") as batch_op:
- batch_op.drop_column('id')
+ batch_op.drop_column("id")
- self._assert_data([
- {"data": "d1", "x": 5},
- {"data": "22", "x": 6},
- {"data": "8.5", "x": 7},
- {"data": "9.46", "x": 8},
- {"data": "d5", "x": 9}
- ])
+ self._assert_data(
+ [
+ {"data": "d1", "x": 5},
+ {"data": "22", "x": 6},
+ {"data": "8.5", "x": 7},
+ {"data": "9.46", "x": 8},
+ {"data": "d5", "x": 9},
+ ]
+ )
def test_rename_column_pk(self):
with self.op.batch_alter_table("foo") as batch_op:
- batch_op.alter_column('id', new_column_name='ident')
+ batch_op.alter_column("id", new_column_name="ident")
- self._assert_data([
- {"ident": 1, "data": "d1", "x": 5},
- {"ident": 2, "data": "22", "x": 6},
- {"ident": 3, "data": "8.5", "x": 7},
- {"ident": 4, "data": "9.46", "x": 8},
- {"ident": 5, "data": "d5", "x": 9}
- ])
+ self._assert_data(
+ [
+ {"ident": 1, "data": "d1", "x": 5},
+ {"ident": 2, "data": "22", "x": 6},
+ {"ident": 3, "data": "8.5", "x": 7},
+ {"ident": 4, "data": "9.46", "x": 8},
+ {"ident": 5, "data": "d5", "x": 9},
+ ]
+ )
def test_add_column_auto(self):
# note this uses ALTER
with self.op.batch_alter_table("foo") as batch_op:
batch_op.add_column(
- Column('data2', String(50), server_default='hi'))
+ Column("data2", String(50), server_default="hi")
+ )
- self._assert_data([
- {"id": 1, "data": "d1", "x": 5, 'data2': 'hi'},
- {"id": 2, "data": "22", "x": 6, 'data2': 'hi'},
- {"id": 3, "data": "8.5", "x": 7, 'data2': 'hi'},
- {"id": 4, "data": "9.46", "x": 8, 'data2': 'hi'},
- {"id": 5, "data": "d5", "x": 9, 'data2': 'hi'}
- ])
+ self._assert_data(
+ [
+ {"id": 1, "data": "d1", "x": 5, "data2": "hi"},
+ {"id": 2, "data": "22", "x": 6, "data2": "hi"},
+ {"id": 3, "data": "8.5", "x": 7, "data2": "hi"},
+ {"id": 4, "data": "9.46", "x": 8, "data2": "hi"},
+ {"id": 5, "data": "d5", "x": 9, "data2": "hi"},
+ ]
+ )
def test_add_column_recreate(self):
- with self.op.batch_alter_table("foo", recreate='always') as batch_op:
+ with self.op.batch_alter_table("foo", recreate="always") as batch_op:
batch_op.add_column(
- Column('data2', String(50), server_default='hi'))
+ Column("data2", String(50), server_default="hi")
+ )
- self._assert_data([
- {"id": 1, "data": "d1", "x": 5, 'data2': 'hi'},
- {"id": 2, "data": "22", "x": 6, 'data2': 'hi'},
- {"id": 3, "data": "8.5", "x": 7, 'data2': 'hi'},
- {"id": 4, "data": "9.46", "x": 8, 'data2': 'hi'},
- {"id": 5, "data": "d5", "x": 9, 'data2': 'hi'}
- ])
+ self._assert_data(
+ [
+ {"id": 1, "data": "d1", "x": 5, "data2": "hi"},
+ {"id": 2, "data": "22", "x": 6, "data2": "hi"},
+ {"id": 3, "data": "8.5", "x": 7, "data2": "hi"},
+ {"id": 4, "data": "9.46", "x": 8, "data2": "hi"},
+ {"id": 5, "data": "d5", "x": 9, "data2": "hi"},
+ ]
+ )
def test_create_drop_index(self):
insp = Inspector.from_engine(config.db)
- eq_(
- insp.get_indexes('foo'), []
- )
+ eq_(insp.get_indexes("foo"), [])
- with self.op.batch_alter_table("foo", recreate='always') as batch_op:
- batch_op.create_index(
- 'ix_data', ['data'], unique=True)
+ with self.op.batch_alter_table("foo", recreate="always") as batch_op:
+ batch_op.create_index("ix_data", ["data"], unique=True)
- self._assert_data([
- {"id": 1, "data": "d1", "x": 5},
- {"id": 2, "data": "22", "x": 6},
- {"id": 3, "data": "8.5", "x": 7},
- {"id": 4, "data": "9.46", "x": 8},
- {"id": 5, "data": "d5", "x": 9}
- ])
+ self._assert_data(
+ [
+ {"id": 1, "data": "d1", "x": 5},
+ {"id": 2, "data": "22", "x": 6},
+ {"id": 3, "data": "8.5", "x": 7},
+ {"id": 4, "data": "9.46", "x": 8},
+ {"id": 5, "data": "d5", "x": 9},
+ ]
+ )
insp = Inspector.from_engine(config.db)
eq_(
[
- dict(unique=ix['unique'],
- name=ix['name'],
- column_names=ix['column_names'])
- for ix in insp.get_indexes('foo')
+ dict(
+ unique=ix["unique"],
+ name=ix["name"],
+ column_names=ix["column_names"],
+ )
+ for ix in insp.get_indexes("foo")
],
- [{'unique': True, 'name': 'ix_data', 'column_names': ['data']}]
+ [{"unique": True, "name": "ix_data", "column_names": ["data"]}],
)
- with self.op.batch_alter_table("foo", recreate='always') as batch_op:
- batch_op.drop_index('ix_data')
+ with self.op.batch_alter_table("foo", recreate="always") as batch_op:
+ batch_op.drop_index("ix_data")
insp = Inspector.from_engine(config.db)
- eq_(
- insp.get_indexes('foo'), []
- )
+ eq_(insp.get_indexes("foo"), [])
class BatchRoundTripMySQLTest(BatchRoundTripTest):
@@ -1496,7 +1656,8 @@ class BatchRoundTripPostgresqlTest(BatchRoundTripTest):
@exclusions.fails()
def test_drop_pk_col_readd_pk_col(self):
super(
- BatchRoundTripPostgresqlTest, self).test_drop_pk_col_readd_pk_col()
+ BatchRoundTripPostgresqlTest, self
+ ).test_drop_pk_col_readd_pk_col()
@exclusions.fails()
def test_drop_pk_col_readd_col_also_pk_const(self):
@@ -1513,10 +1674,12 @@ class BatchRoundTripPostgresqlTest(BatchRoundTripTest):
@exclusions.fails()
def test_change_type_int_to_boolean(self):
- super(BatchRoundTripPostgresqlTest, self).\
- test_change_type_int_to_boolean()
+ super(
+ BatchRoundTripPostgresqlTest, self
+ ).test_change_type_int_to_boolean()
@exclusions.fails()
def test_change_type_boolean_to_int(self):
- super(BatchRoundTripPostgresqlTest, self).\
- test_change_type_boolean_to_int()
+ super(
+ BatchRoundTripPostgresqlTest, self
+ ).test_change_type_boolean_to_int()
diff --git a/tests/test_bulk_insert.py b/tests/test_bulk_insert.py
index 2655630..220719a 100644
--- a/tests/test_bulk_insert.py
+++ b/tests/test_bulk_insert.py
@@ -13,127 +13,131 @@ from alembic.testing import eq_, assert_raises_message, config
class BulkInsertTest(TestBase):
def _table_fixture(self, dialect, as_sql):
context = op_fixture(dialect, as_sql)
- t1 = table("ins_table",
- column('id', Integer),
- column('v1', String()),
- column('v2', String()),
- )
+ t1 = table(
+ "ins_table",
+ column("id", Integer),
+ column("v1", String()),
+ column("v2", String()),
+ )
return context, t1
def _big_t_table_fixture(self, dialect, as_sql):
context = op_fixture(dialect, as_sql)
- t1 = Table("ins_table", MetaData(),
- Column('id', Integer, primary_key=True),
- Column('v1', String()),
- Column('v2', String()),
- )
+ t1 = Table(
+ "ins_table",
+ MetaData(),
+ Column("id", Integer, primary_key=True),
+ Column("v1", String()),
+ Column("v2", String()),
+ )
return context, t1
def _test_bulk_insert(self, dialect, as_sql):
context, t1 = self._table_fixture(dialect, as_sql)
- op.bulk_insert(t1, [
- {'id': 1, 'v1': 'row v1', 'v2': 'row v5'},
- {'id': 2, 'v1': 'row v2', 'v2': 'row v6'},
- {'id': 3, 'v1': 'row v3', 'v2': 'row v7'},
- {'id': 4, 'v1': 'row v4', 'v2': 'row v8'},
- ])
+ op.bulk_insert(
+ t1,
+ [
+ {"id": 1, "v1": "row v1", "v2": "row v5"},
+ {"id": 2, "v1": "row v2", "v2": "row v6"},
+ {"id": 3, "v1": "row v3", "v2": "row v7"},
+ {"id": 4, "v1": "row v4", "v2": "row v8"},
+ ],
+ )
return context
def _test_bulk_insert_single(self, dialect, as_sql):
context, t1 = self._table_fixture(dialect, as_sql)
- op.bulk_insert(t1, [
- {'id': 1, 'v1': 'row v1', 'v2': 'row v5'},
- ])
+ op.bulk_insert(t1, [{"id": 1, "v1": "row v1", "v2": "row v5"}])
return context
def _test_bulk_insert_single_bigt(self, dialect, as_sql):
context, t1 = self._big_t_table_fixture(dialect, as_sql)
- op.bulk_insert(t1, [
- {'id': 1, 'v1': 'row v1', 'v2': 'row v5'},
- ])
+ op.bulk_insert(t1, [{"id": 1, "v1": "row v1", "v2": "row v5"}])
return context
def test_bulk_insert(self):
- context = self._test_bulk_insert('default', False)
+ context = self._test_bulk_insert("default", False)
context.assert_(
- 'INSERT INTO ins_table (id, v1, v2) VALUES (:id, :v1, :v2)'
+ "INSERT INTO ins_table (id, v1, v2) VALUES (:id, :v1, :v2)"
)
def test_bulk_insert_wrong_cols(self):
- context = op_fixture('postgresql')
- t1 = table("ins_table",
- column('id', Integer),
- column('v1', String()),
- column('v2', String()),
- )
- op.bulk_insert(t1, [
- {'v1': 'row v1', },
- ])
+ context = op_fixture("postgresql")
+ t1 = table(
+ "ins_table",
+ column("id", Integer),
+ column("v1", String()),
+ column("v2", String()),
+ )
+ op.bulk_insert(t1, [{"v1": "row v1"}])
context.assert_(
- 'INSERT INTO ins_table (id, v1, v2) VALUES (%(id)s, %(v1)s, %(v2)s)'
+ "INSERT INTO ins_table (id, v1, v2) VALUES (%(id)s, %(v1)s, %(v2)s)"
)
def test_bulk_insert_no_rows(self):
- context, t1 = self._table_fixture('default', False)
+ context, t1 = self._table_fixture("default", False)
op.bulk_insert(t1, [])
context.assert_()
def test_bulk_insert_pg(self):
- context = self._test_bulk_insert('postgresql', False)
+ context = self._test_bulk_insert("postgresql", False)
context.assert_(
- 'INSERT INTO ins_table (id, v1, v2) '
- 'VALUES (%(id)s, %(v1)s, %(v2)s)'
+ "INSERT INTO ins_table (id, v1, v2) "
+ "VALUES (%(id)s, %(v1)s, %(v2)s)"
)
def test_bulk_insert_pg_single(self):
- context = self._test_bulk_insert_single('postgresql', False)
+ context = self._test_bulk_insert_single("postgresql", False)
context.assert_(
- 'INSERT INTO ins_table (id, v1, v2) '
- 'VALUES (%(id)s, %(v1)s, %(v2)s)'
+ "INSERT INTO ins_table (id, v1, v2) "
+ "VALUES (%(id)s, %(v1)s, %(v2)s)"
)
def test_bulk_insert_pg_single_as_sql(self):
- context = self._test_bulk_insert_single('postgresql', True)
+ context = self._test_bulk_insert_single("postgresql", True)
context.assert_(
"INSERT INTO ins_table (id, v1, v2) VALUES (1, 'row v1', 'row v5')"
)
def test_bulk_insert_pg_single_big_t_as_sql(self):
- context = self._test_bulk_insert_single_bigt('postgresql', True)
+ context = self._test_bulk_insert_single_bigt("postgresql", True)
context.assert_(
"INSERT INTO ins_table (id, v1, v2) "
"VALUES (1, 'row v1', 'row v5')"
)
def test_bulk_insert_mssql(self):
- context = self._test_bulk_insert('mssql', False)
+ context = self._test_bulk_insert("mssql", False)
context.assert_(
- 'INSERT INTO ins_table (id, v1, v2) VALUES (:id, :v1, :v2)'
+ "INSERT INTO ins_table (id, v1, v2) VALUES (:id, :v1, :v2)"
)
def test_bulk_insert_inline_literal_as_sql(self):
- context = op_fixture('postgresql', True)
+ context = op_fixture("postgresql", True)
class MyType(TypeEngine):
pass
- t1 = table('t', column('id', Integer), column('data', MyType()))
+ t1 = table("t", column("id", Integer), column("data", MyType()))
- op.bulk_insert(t1, [
- {'id': 1, 'data': op.inline_literal('d1')},
- {'id': 2, 'data': op.inline_literal('d2')},
- ])
+ op.bulk_insert(
+ t1,
+ [
+ {"id": 1, "data": op.inline_literal("d1")},
+ {"id": 2, "data": op.inline_literal("d2")},
+ ],
+ )
context.assert_(
"INSERT INTO t (id, data) VALUES (1, 'd1')",
- "INSERT INTO t (id, data) VALUES (2, 'd2')"
+ "INSERT INTO t (id, data) VALUES (2, 'd2')",
)
def test_bulk_insert_as_sql(self):
- context = self._test_bulk_insert('default', True)
+ context = self._test_bulk_insert("default", True)
context.assert_(
"INSERT INTO ins_table (id, v1, v2) "
"VALUES (1, 'row v1', 'row v5')",
@@ -142,11 +146,11 @@ class BulkInsertTest(TestBase):
"INSERT INTO ins_table (id, v1, v2) "
"VALUES (3, 'row v3', 'row v7')",
"INSERT INTO ins_table (id, v1, v2) "
- "VALUES (4, 'row v4', 'row v8')"
+ "VALUES (4, 'row v4', 'row v8')",
)
def test_bulk_insert_as_sql_pg(self):
- context = self._test_bulk_insert('postgresql', True)
+ context = self._test_bulk_insert("postgresql", True)
context.assert_(
"INSERT INTO ins_table (id, v1, v2) "
"VALUES (1, 'row v1', 'row v5')",
@@ -155,65 +159,68 @@ class BulkInsertTest(TestBase):
"INSERT INTO ins_table (id, v1, v2) "
"VALUES (3, 'row v3', 'row v7')",
"INSERT INTO ins_table (id, v1, v2) "
- "VALUES (4, 'row v4', 'row v8')"
+ "VALUES (4, 'row v4', 'row v8')",
)
def test_bulk_insert_as_sql_mssql(self):
- context = self._test_bulk_insert('mssql', True)
+ context = self._test_bulk_insert("mssql", True)
# SQL server requires IDENTITY_INSERT
# TODO: figure out if this is safe to enable for a table that
# doesn't have an IDENTITY column
context.assert_(
- 'SET IDENTITY_INSERT ins_table ON',
- 'GO',
+ "SET IDENTITY_INSERT ins_table ON",
+ "GO",
"INSERT INTO ins_table (id, v1, v2) "
"VALUES (1, 'row v1', 'row v5')",
- 'GO',
+ "GO",
"INSERT INTO ins_table (id, v1, v2) "
"VALUES (2, 'row v2', 'row v6')",
- 'GO',
+ "GO",
"INSERT INTO ins_table (id, v1, v2) "
"VALUES (3, 'row v3', 'row v7')",
- 'GO',
+ "GO",
"INSERT INTO ins_table (id, v1, v2) "
"VALUES (4, 'row v4', 'row v8')",
- 'GO',
- 'SET IDENTITY_INSERT ins_table OFF',
- 'GO',
+ "GO",
+ "SET IDENTITY_INSERT ins_table OFF",
+ "GO",
)
def test_bulk_insert_from_new_table(self):
context = op_fixture("postgresql", True)
t1 = op.create_table(
"ins_table",
- Column('id', Integer),
- Column('v1', String()),
- Column('v2', String()),
+ Column("id", Integer),
+ Column("v1", String()),
+ Column("v2", String()),
+ )
+ op.bulk_insert(
+ t1,
+ [
+ {"id": 1, "v1": "row v1", "v2": "row v5"},
+ {"id": 2, "v1": "row v2", "v2": "row v6"},
+ ],
)
- op.bulk_insert(t1, [
- {'id': 1, 'v1': 'row v1', 'v2': 'row v5'},
- {'id': 2, 'v1': 'row v2', 'v2': 'row v6'},
- ])
context.assert_(
- 'CREATE TABLE ins_table (id INTEGER, v1 VARCHAR, v2 VARCHAR)',
+ "CREATE TABLE ins_table (id INTEGER, v1 VARCHAR, v2 VARCHAR)",
"INSERT INTO ins_table (id, v1, v2) VALUES "
"(1, 'row v1', 'row v5')",
"INSERT INTO ins_table (id, v1, v2) VALUES "
- "(2, 'row v2', 'row v6')"
+ "(2, 'row v2', 'row v6')",
)
def test_invalid_format(self):
context, t1 = self._table_fixture("sqlite", False)
assert_raises_message(
- TypeError,
- "List expected",
- op.bulk_insert, t1, {"id": 5}
+ TypeError, "List expected", op.bulk_insert, t1, {"id": 5}
)
assert_raises_message(
TypeError,
"List of dictionaries expected",
- op.bulk_insert, t1, [(5, )]
+ op.bulk_insert,
+ t1,
+ [(5,)],
)
@@ -223,86 +230,85 @@ class RoundTripTest(TestBase):
def setUp(self):
from sqlalchemy import create_engine
from alembic.migration import MigrationContext
+
self.conn = config.db.connect()
- self.conn.execute("""
+ self.conn.execute(
+ """
create table foo(
id integer primary key,
data varchar(50),
x integer
)
- """)
+ """
+ )
context = MigrationContext.configure(self.conn)
self.op = op.Operations(context)
- self.t1 = table('foo',
- column('id'),
- column('data'),
- column('x')
- )
+ self.t1 = table("foo", column("id"), column("data"), column("x"))
def tearDown(self):
self.conn.execute("drop table foo")
self.conn.close()
def test_single_insert_round_trip(self):
- self.op.bulk_insert(self.t1,
- [{'data': "d1", "x": "x1"}]
- )
+ self.op.bulk_insert(self.t1, [{"data": "d1", "x": "x1"}])
eq_(
self.conn.execute("select id, data, x from foo").fetchall(),
- [
- (1, "d1", "x1"),
- ]
+ [(1, "d1", "x1")],
)
def test_bulk_insert_round_trip(self):
- self.op.bulk_insert(self.t1, [
- {'data': "d1", "x": "x1"},
- {'data': "d2", "x": "x2"},
- {'data': "d3", "x": "x3"},
- ])
+ self.op.bulk_insert(
+ self.t1,
+ [
+ {"data": "d1", "x": "x1"},
+ {"data": "d2", "x": "x2"},
+ {"data": "d3", "x": "x3"},
+ ],
+ )
eq_(
self.conn.execute("select id, data, x from foo").fetchall(),
- [
- (1, "d1", "x1"),
- (2, "d2", "x2"),
- (3, "d3", "x3")
- ]
+ [(1, "d1", "x1"), (2, "d2", "x2"), (3, "d3", "x3")],
)
def test_bulk_insert_inline_literal(self):
class MyType(TypeEngine):
pass
- t1 = table('foo', column('id', Integer), column('data', MyType()))
+ t1 = table("foo", column("id", Integer), column("data", MyType()))
- self.op.bulk_insert(t1, [
- {'id': 1, 'data': self.op.inline_literal('d1')},
- {'id': 2, 'data': self.op.inline_literal('d2')},
- ], multiinsert=False)
+ self.op.bulk_insert(
+ t1,
+ [
+ {"id": 1, "data": self.op.inline_literal("d1")},
+ {"id": 2, "data": self.op.inline_literal("d2")},
+ ],
+ multiinsert=False,
+ )
eq_(
self.conn.execute("select id, data from foo").fetchall(),
- [
- (1, "d1"),
- (2, "d2"),
- ]
+ [(1, "d1"), (2, "d2")],
)
def test_bulk_insert_from_new_table(self):
t1 = self.op.create_table(
"ins_table",
- Column('id', Integer),
- Column('v1', String()),
- Column('v2', String()),
+ Column("id", Integer),
+ Column("v1", String()),
+ Column("v2", String()),
+ )
+ self.op.bulk_insert(
+ t1,
+ [
+ {"id": 1, "v1": "row v1", "v2": "row v5"},
+ {"id": 2, "v1": "row v2", "v2": "row v6"},
+ ],
)
- self.op.bulk_insert(t1, [
- {'id': 1, 'v1': 'row v1', 'v2': 'row v5'},
- {'id': 2, 'v1': 'row v2', 'v2': 'row v6'},
- ])
eq_(
self.conn.execute(
- "select id, v1, v2 from ins_table order by id").fetchall(),
- [(1, u'row v1', u'row v5'), (2, u'row v2', u'row v6')]
- ) \ No newline at end of file
+ "select id, v1, v2 from ins_table order by id"
+ ).fetchall(),
+ [(1, u"row v1", u"row v5"), (2, u"row v2", u"row v6")],
+ )
diff --git a/tests/test_command.py b/tests/test_command.py
index 3f3daf5..a9f0e5d 100644
--- a/tests/test_command.py
+++ b/tests/test_command.py
@@ -3,9 +3,16 @@ from io import TextIOWrapper, BytesIO
from alembic.script import ScriptDirectory
from alembic import config
from alembic.testing.fixtures import TestBase, capture_context_buffer
-from alembic.testing.env import staging_env, _sqlite_testing_config, \
- three_rev_fixture, clear_staging_env, _no_sql_testing_config, \
- _sqlite_file_db, write_script, env_file_fixture
+from alembic.testing.env import (
+ staging_env,
+ _sqlite_testing_config,
+ three_rev_fixture,
+ clear_staging_env,
+ _no_sql_testing_config,
+ _sqlite_file_db,
+ write_script,
+ env_file_fixture,
+)
from alembic.testing import eq_, assert_raises_message, mock, assert_raises
from alembic import util
from contextlib import contextmanager
@@ -18,13 +25,12 @@ class _BufMixin(object):
# try to simulate how sys.stdout looks - we send it u''
# but then it's trying to encode to something.
buf = BytesIO()
- wrapper = TextIOWrapper(buf, encoding='ascii', line_buffering=True)
+ wrapper = TextIOWrapper(buf, encoding="ascii", line_buffering=True)
wrapper.getvalue = buf.getvalue
return wrapper
class HistoryTest(_BufMixin, TestBase):
-
@classmethod
def setup_class(cls):
cls.env = staging_env()
@@ -41,7 +47,8 @@ class HistoryTest(_BufMixin, TestBase):
@classmethod
def _setup_env_file(self):
- env_file_fixture(r"""
+ env_file_fixture(
+ r"""
from sqlalchemy import MetaData, engine_from_config
target_metadata = MetaData()
@@ -63,7 +70,8 @@ try:
finally:
connection.close()
-""")
+"""
+ )
def _eq_cmd_output(self, buf, expected, env_token=False, currents=()):
script = ScriptDirectory.from_config(self.cfg)
@@ -82,9 +90,11 @@ finally:
assert_lines.insert(0, "environment included OK")
eq_(
- buf.getvalue().decode("ascii", 'replace').strip(),
- "\n".join(assert_lines).
- encode("ascii", "replace").decode("ascii").strip()
+ buf.getvalue().decode("ascii", "replace").strip(),
+ "\n".join(assert_lines)
+ .encode("ascii", "replace")
+ .decode("ascii")
+ .strip(),
)
def test_history_full(self):
@@ -163,11 +173,11 @@ finally:
self.cfg.stdout = buf = self._buf_fixture()
command.history(self.cfg, indicate_current=True, verbose=True)
self._eq_cmd_output(
- buf, [self.c, self.b, self.a], currents=(self.b,), env_token=True)
+ buf, [self.c, self.b, self.a], currents=(self.b,), env_token=True
+ )
class CurrentTest(_BufMixin, TestBase):
-
@classmethod
def setup_class(cls):
cls.env = env = staging_env()
@@ -189,11 +199,15 @@ class CurrentTest(_BufMixin, TestBase):
yield
- lines = set([
- re.match(r'(^.\w)', elem).group(1)
- for elem in re.split(
- "\n",
- buf.getvalue().decode('ascii', 'replace').strip()) if elem])
+ lines = set(
+ [
+ re.match(r"(^.\w)", elem).group(1)
+ for elem in re.split(
+ "\n", buf.getvalue().decode("ascii", "replace").strip()
+ )
+ if elem
+ ]
+ )
eq_(lines, set(revs))
@@ -205,25 +219,25 @@ class CurrentTest(_BufMixin, TestBase):
def test_plain_current(self):
command.stamp(self.cfg, ())
command.stamp(self.cfg, self.a3.revision)
- with self._assert_lines(['a3']):
+ with self._assert_lines(["a3"]):
command.current(self.cfg)
def test_two_heads(self):
command.stamp(self.cfg, ())
command.stamp(self.cfg, (self.a1.revision, self.b1.revision))
- with self._assert_lines(['a1', 'b1']):
+ with self._assert_lines(["a1", "b1"]):
command.current(self.cfg)
def test_heads_one_is_dependent(self):
command.stamp(self.cfg, ())
- command.stamp(self.cfg, (self.b2.revision, ))
- with self._assert_lines(['a2', 'b2']):
+ command.stamp(self.cfg, (self.b2.revision,))
+ with self._assert_lines(["a2", "b2"]):
command.current(self.cfg)
def test_heads_upg(self):
- command.stamp(self.cfg, (self.b2.revision, ))
+ command.stamp(self.cfg, (self.b2.revision,))
command.upgrade(self.cfg, (self.b3.revision))
- with self._assert_lines(['a2', 'b3']):
+ with self._assert_lines(["a2", "b3"]):
command.current(self.cfg)
@@ -236,7 +250,8 @@ class RevisionTest(TestBase):
clear_staging_env()
def _env_fixture(self, version_table_pk=True):
- env_file_fixture("""
+ env_file_fixture(
+ """
from sqlalchemy import MetaData, engine_from_config
target_metadata = MetaData()
@@ -258,7 +273,9 @@ try:
finally:
connection.close()
-""" % (version_table_pk, ))
+"""
+ % (version_table_pk,)
+ )
def test_create_rev_plain_db_not_up_to_date(self):
self._env_fixture()
@@ -275,7 +292,9 @@ finally:
assert_raises_message(
util.CommandError,
"Target database is not up to date.",
- command.revision, self.cfg, autogenerate=True
+ command.revision,
+ self.cfg,
+ autogenerate=True,
)
def test_create_rev_autogen_db_not_up_to_date_multi_heads(self):
@@ -290,7 +309,9 @@ finally:
assert_raises_message(
util.CommandError,
"Target database is not up to date.",
- command.revision, self.cfg, autogenerate=True
+ command.revision,
+ self.cfg,
+ autogenerate=True,
)
def test_create_rev_plain_db_not_up_to_date_multi_heads(self):
@@ -306,7 +327,8 @@ finally:
util.CommandError,
"Multiple heads are present; please specify the head revision "
"on which the new revision should be based, or perform a merge.",
- command.revision, self.cfg
+ command.revision,
+ self.cfg,
)
def test_create_rev_autogen_need_to_select_head(self):
@@ -321,7 +343,9 @@ finally:
util.CommandError,
"Multiple heads are present; please specify the head revision "
"on which the new revision should be based, or perform a merge.",
- command.revision, self.cfg, autogenerate=True
+ command.revision,
+ self.cfg,
+ autogenerate=True,
)
def test_pk_constraint_normally_prevents_dupe_rows(self):
@@ -333,7 +357,7 @@ finally:
assert_raises(
sqla_exc.IntegrityError,
db.execute,
- "insert into alembic_version values ('%s')" % r2.revision
+ "insert into alembic_version values ('%s')" % r2.revision,
)
def test_err_correctly_raised_on_dupe_rows_no_pk(self):
@@ -347,7 +371,9 @@ finally:
util.CommandError,
"Online migration expected to match one row when "
"updating .* in 'alembic_version'; 2 found",
- command.downgrade, self.cfg, "-1"
+ command.downgrade,
+ self.cfg,
+ "-1",
)
def test_create_rev_plain_need_to_select_head(self):
@@ -362,7 +388,8 @@ finally:
util.CommandError,
"Multiple heads are present; please specify the head revision "
"on which the new revision should be based, or perform a merge.",
- command.revision, self.cfg
+ command.revision,
+ self.cfg,
)
def test_create_rev_plain_post_merge(self):
@@ -389,27 +416,20 @@ finally:
command.revision(self.cfg)
rev2 = command.revision(self.cfg)
rev3 = command.revision(self.cfg, depends_on=rev2.revision)
- eq_(
- rev3._resolved_dependencies, (rev2.revision, )
- )
+ eq_(rev3._resolved_dependencies, (rev2.revision,))
rev4 = command.revision(
- self.cfg, depends_on=[rev2.revision, rev3.revision])
- eq_(
- rev4._resolved_dependencies, (rev2.revision, rev3.revision)
+ self.cfg, depends_on=[rev2.revision, rev3.revision]
)
+ eq_(rev4._resolved_dependencies, (rev2.revision, rev3.revision))
def test_create_rev_depends_on_branch_label(self):
self._env_fixture()
command.revision(self.cfg)
- rev2 = command.revision(self.cfg, branch_label='foobar')
- rev3 = command.revision(self.cfg, depends_on='foobar')
- eq_(
- rev3.dependencies, 'foobar'
- )
- eq_(
- rev3._resolved_dependencies, (rev2.revision, )
- )
+ rev2 = command.revision(self.cfg, branch_label="foobar")
+ rev3 = command.revision(self.cfg, depends_on="foobar")
+ eq_(rev3.dependencies, "foobar")
+ eq_(rev3._resolved_dependencies, (rev2.revision,))
def test_create_rev_depends_on_partial_revid(self):
self._env_fixture()
@@ -417,12 +437,8 @@ finally:
rev2 = command.revision(self.cfg)
assert len(rev2.revision) > 7
rev3 = command.revision(self.cfg, depends_on=rev2.revision[0:4])
- eq_(
- rev3.dependencies, rev2.revision
- )
- eq_(
- rev3._resolved_dependencies, (rev2.revision, )
- )
+ eq_(rev3.dependencies, rev2.revision)
+ eq_(rev3._resolved_dependencies, (rev2.revision,))
def test_create_rev_invalid_depends_on(self):
self._env_fixture()
@@ -430,7 +446,9 @@ finally:
assert_raises_message(
util.CommandError,
"Can't locate revision identified by 'invalid'",
- command.revision, self.cfg, depends_on='invalid'
+ command.revision,
+ self.cfg,
+ depends_on="invalid",
)
def test_create_rev_autogenerate_db_not_up_to_date_post_merge(self):
@@ -444,7 +462,9 @@ finally:
assert_raises_message(
util.CommandError,
"Target database is not up to date.",
- command.revision, self.cfg, autogenerate=True
+ command.revision,
+ self.cfg,
+ autogenerate=True,
)
def test_nonsensical_sql_mode_autogen(self):
@@ -452,7 +472,10 @@ finally:
assert_raises_message(
util.CommandError,
"Using --sql with --autogenerate does not make any sense",
- command.revision, self.cfg, autogenerate=True, sql=True
+ command.revision,
+ self.cfg,
+ autogenerate=True,
+ sql=True,
)
def test_nonsensical_sql_no_env(self):
@@ -461,7 +484,9 @@ finally:
util.CommandError,
"Using --sql with the revision command when revision_environment "
"is not configured does not make any sense",
- command.revision, self.cfg, sql=True
+ command.revision,
+ self.cfg,
+ sql=True,
)
def test_sensical_sql_w_env(self):
@@ -471,12 +496,11 @@ finally:
class UpgradeDowngradeStampTest(TestBase):
-
def setUp(self):
self.env = staging_env()
self.cfg = cfg = _no_sql_testing_config()
- cfg.set_main_option('dialect_name', 'sqlite')
- cfg.remove_main_option('url')
+ cfg.set_main_option("dialect_name", "sqlite")
+ cfg.remove_main_option("url")
self.a, self.b, self.c = three_rev_fixture(cfg)
@@ -559,7 +583,7 @@ class UpgradeDowngradeStampTest(TestBase):
class LiveStampTest(TestBase):
- __only_on__ = 'sqlite'
+ __only_on__ = "sqlite"
def setUp(self):
self.bind = _sqlite_file_db()
@@ -569,15 +593,25 @@ class LiveStampTest(TestBase):
self.b = b = util.rev_id()
script = ScriptDirectory.from_config(self.cfg)
script.generate_revision(a, None, refresh=True)
- write_script(script, a, """
+ write_script(
+ script,
+ a,
+ """
revision = '%s'
down_revision = None
-""" % a)
+"""
+ % a,
+ )
script.generate_revision(b, None, refresh=True)
- write_script(script, b, """
+ write_script(
+ script,
+ b,
+ """
revision = '%s'
down_revision = '%s'
-""" % (b, a))
+"""
+ % (b, a),
+ )
def tearDown(self):
clear_staging_env()
@@ -585,29 +619,25 @@ down_revision = '%s'
def test_stamp_creates_table(self):
command.stamp(self.cfg, "head")
eq_(
- self.bind.scalar("select version_num from alembic_version"),
- self.b
+ self.bind.scalar("select version_num from alembic_version"), self.b
)
def test_stamp_existing_upgrade(self):
command.stamp(self.cfg, self.a)
command.stamp(self.cfg, self.b)
eq_(
- self.bind.scalar("select version_num from alembic_version"),
- self.b
+ self.bind.scalar("select version_num from alembic_version"), self.b
)
def test_stamp_existing_downgrade(self):
command.stamp(self.cfg, self.b)
command.stamp(self.cfg, self.a)
eq_(
- self.bind.scalar("select version_num from alembic_version"),
- self.a
+ self.bind.scalar("select version_num from alembic_version"), self.a
)
class EditTest(TestBase):
-
@classmethod
def setup_class(cls):
cls.env = staging_env()
@@ -622,56 +652,61 @@ class EditTest(TestBase):
command.stamp(self.cfg, "base")
def test_edit_head(self):
- expected_call_arg = '%s/scripts/versions/%s_revision_c.py' % (
- EditTest.cfg.config_args['here'],
- EditTest.c
+ expected_call_arg = "%s/scripts/versions/%s_revision_c.py" % (
+ EditTest.cfg.config_args["here"],
+ EditTest.c,
)
- with mock.patch('alembic.util.edit') as edit:
+ with mock.patch("alembic.util.edit") as edit:
command.edit(self.cfg, "head")
edit.assert_called_with(expected_call_arg)
def test_edit_b(self):
- expected_call_arg = '%s/scripts/versions/%s_revision_b.py' % (
- EditTest.cfg.config_args['here'],
- EditTest.b
+ expected_call_arg = "%s/scripts/versions/%s_revision_b.py" % (
+ EditTest.cfg.config_args["here"],
+ EditTest.b,
)
- with mock.patch('alembic.util.edit') as edit:
+ with mock.patch("alembic.util.edit") as edit:
command.edit(self.cfg, self.b[0:3])
edit.assert_called_with(expected_call_arg)
def test_edit_with_missing_editor(self):
- with mock.patch('editor.edit') as edit_mock:
+ with mock.patch("editor.edit") as edit_mock:
edit_mock.side_effect = OSError("file not found")
assert_raises_message(
util.CommandError,
- 'file not found',
+ "file not found",
util.edit,
- "/not/a/file.txt")
+ "/not/a/file.txt",
+ )
def test_edit_no_revs(self):
assert_raises_message(
util.CommandError,
"No revision files indicated by symbol 'base'",
command.edit,
- self.cfg, "base")
+ self.cfg,
+ "base",
+ )
def test_edit_no_current(self):
assert_raises_message(
util.CommandError,
"No current revisions",
command.edit,
- self.cfg, "current")
+ self.cfg,
+ "current",
+ )
def test_edit_current(self):
- expected_call_arg = '%s/scripts/versions/%s_revision_b.py' % (
- EditTest.cfg.config_args['here'],
- EditTest.b
+ expected_call_arg = "%s/scripts/versions/%s_revision_b.py" % (
+ EditTest.cfg.config_args["here"],
+ EditTest.b,
)
command.stamp(self.cfg, self.b)
- with mock.patch('alembic.util.edit') as edit:
+ with mock.patch("alembic.util.edit") as edit:
command.edit(self.cfg, "current")
edit.assert_called_with(expected_call_arg)
@@ -691,16 +726,21 @@ class CommandLineTest(TestBase):
# the command function has "process_revision_directives"
# however the ArgumentParser does not. ensure things work
def revision(
- config, message=None, autogenerate=False, sql=False,
- head="head", splice=False, branch_label=None,
- version_path=None, rev_id=None, depends_on=None,
- process_revision_directives=None
+ config,
+ message=None,
+ autogenerate=False,
+ sql=False,
+ head="head",
+ splice=False,
+ branch_label=None,
+ version_path=None,
+ rev_id=None,
+ depends_on=None,
+ process_revision_directives=None,
):
- canary(
- config, message=message
- )
+ canary(config, message=message)
- revision.__module__ = 'alembic.command'
+ revision.__module__ = "alembic.command"
# CommandLine() pulls the function into the ArgumentParser
# and needs the full signature, so we can't patch the "revision"
@@ -712,7 +752,4 @@ class CommandLineTest(TestBase):
commandline.run_cmd(self.cfg, options)
finally:
config.command.revision = orig_revision
- eq_(
- canary.mock_calls,
- [mock.call(self.cfg, message="foo")]
- )
+ eq_(canary.mock_calls, [mock.call(self.cfg, message="foo")])
diff --git a/tests/test_config.py b/tests/test_config.py
index 50e1b05..b1d1ca1 100644
--- a/tests/test_config.py
+++ b/tests/test_config.py
@@ -11,30 +11,35 @@ from alembic.testing.mock import Mock, call
from alembic.testing import eq_, assert_raises_message
from alembic.testing.fixtures import capture_db
-from alembic.testing.env import _no_sql_testing_config, clear_staging_env,\
- staging_env, _write_config_file
+from alembic.testing.env import (
+ _no_sql_testing_config,
+ clear_staging_env,
+ staging_env,
+ _write_config_file,
+)
class FileConfigTest(TestBase):
-
def test_config_args(self):
- cfg = _write_config_file("""
+ cfg = _write_config_file(
+ """
[alembic]
migrations = %(base_path)s/db/migrations
-""")
+"""
+ )
test_cfg = config.Config(
cfg.config_file_name, config_args=dict(base_path="/home/alembic")
)
eq_(
test_cfg.get_section_option("alembic", "migrations"),
- "/home/alembic/db/migrations")
+ "/home/alembic/db/migrations",
+ )
def tearDown(self):
clear_staging_env()
class ConfigTest(TestBase):
-
def test_config_no_file_main_option(self):
cfg = config.Config()
cfg.set_main_option("url", "postgresql://foo/bar")
@@ -66,12 +71,12 @@ class ConfigTest(TestBase):
cfg = config.Config()
cfg.set_section_option("some_section", "foob", "foob_value")
- cfg.set_section_option(
- "some_section", "bar", "bar with %(foob)s")
+ cfg.set_section_option("some_section", "bar", "bar with %(foob)s")
eq_(
cfg.get_section_option("some_section", "bar"),
- "bar with foob_value")
+ "bar with foob_value",
+ )
def test_standalone_op(self):
eng, buf = capture_db()
@@ -80,71 +85,58 @@ class ConfigTest(TestBase):
op = Operations(env)
op.alter_column("t", "c", nullable=True)
- eq_(buf, ['ALTER TABLE t ALTER COLUMN c DROP NOT NULL'])
+ eq_(buf, ["ALTER TABLE t ALTER COLUMN c DROP NOT NULL"])
def test_no_script_error(self):
cfg = config.Config()
assert_raises_message(
util.CommandError,
"No 'script_location' key found in configuration.",
- ScriptDirectory.from_config, cfg
+ ScriptDirectory.from_config,
+ cfg,
)
def test_attributes_attr(self):
m1 = Mock()
cfg = config.Config()
- cfg.attributes['connection'] = m1
- eq_(
- cfg.attributes['connection'], m1
- )
+ cfg.attributes["connection"] = m1
+ eq_(cfg.attributes["connection"], m1)
def test_attributes_construtor(self):
m1 = Mock()
m2 = Mock()
- cfg = config.Config(attributes={'m1': m1})
- cfg.attributes['connection'] = m2
- eq_(
- cfg.attributes, {'m1': m1, 'connection': m2}
- )
+ cfg = config.Config(attributes={"m1": m1})
+ cfg.attributes["connection"] = m2
+ eq_(cfg.attributes, {"m1": m1, "connection": m2})
class StdoutOutputEncodingTest(TestBase):
-
def test_plain(self):
- stdout = Mock(encoding='latin-1')
+ stdout = Mock(encoding="latin-1")
cfg = config.Config(stdout=stdout)
cfg.print_stdout("test %s %s", "x", "y")
- eq_(
- stdout.mock_calls,
- [call.write('test x y'), call.write('\n')]
- )
+ eq_(stdout.mock_calls, [call.write("test x y"), call.write("\n")])
def test_utf8_unicode(self):
- stdout = Mock(encoding='latin-1')
+ stdout = Mock(encoding="latin-1")
cfg = config.Config(stdout=stdout)
cfg.print_stdout(compat.u("méil %s %s"), "x", "y")
eq_(
stdout.mock_calls,
- [call.write(compat.u('méil x y')), call.write('\n')]
+ [call.write(compat.u("méil x y")), call.write("\n")],
)
def test_ascii_unicode(self):
stdout = Mock(encoding=None)
cfg = config.Config(stdout=stdout)
cfg.print_stdout(compat.u("méil %s %s"), "x", "y")
- eq_(
- stdout.mock_calls,
- [call.write('m?il x y'), call.write('\n')]
- )
+ eq_(stdout.mock_calls, [call.write("m?il x y"), call.write("\n")])
def test_only_formats_output_with_args(self):
stdout = Mock(encoding=None)
cfg = config.Config(stdout=stdout)
cfg.print_stdout(compat.u("test 3%"))
- eq_(
- stdout.mock_calls,
- [call.write('test 3%'), call.write('\n')]
- )
+ eq_(stdout.mock_calls, [call.write("test 3%"), call.write("\n")])
class TemplateOutputEncodingTest(TestBase):
@@ -157,9 +149,9 @@ class TemplateOutputEncodingTest(TestBase):
def test_default(self):
script = ScriptDirectory.from_config(self.cfg)
- eq_(script.output_encoding, 'utf-8')
+ eq_(script.output_encoding, "utf-8")
def test_setting(self):
- self.cfg.set_main_option('output_encoding', 'latin-1')
+ self.cfg.set_main_option("output_encoding", "latin-1")
script = ScriptDirectory.from_config(self.cfg)
- eq_(script.output_encoding, 'latin-1')
+ eq_(script.output_encoding, "latin-1")
diff --git a/tests/test_environment.py b/tests/test_environment.py
index 42ff328..cfa72f6 100644
--- a/tests/test_environment.py
+++ b/tests/test_environment.py
@@ -5,15 +5,19 @@ from alembic.environment import EnvironmentContext
from alembic.migration import MigrationContext
from alembic.testing.fixtures import TestBase
from alembic.testing.mock import Mock, call, MagicMock
-from alembic.testing.env import _no_sql_testing_config, \
- staging_env, clear_staging_env, write_script, _sqlite_file_db
+from alembic.testing.env import (
+ _no_sql_testing_config,
+ staging_env,
+ clear_staging_env,
+ write_script,
+ _sqlite_file_db,
+)
from alembic.testing.assertions import expect_warnings
from alembic.testing import eq_, is_
class EnvironmentTest(TestBase):
-
def setUp(self):
staging_env()
self.cfg = _no_sql_testing_config()
@@ -23,49 +27,30 @@ class EnvironmentTest(TestBase):
def _fixture(self, **kw):
script = ScriptDirectory.from_config(self.cfg)
- env = EnvironmentContext(
- self.cfg,
- script,
- **kw
- )
+ env = EnvironmentContext(self.cfg, script, **kw)
return env
def test_x_arg(self):
env = self._fixture()
self.cfg.cmd_opts = Mock(x="y=5")
- eq_(
- env.get_x_argument(),
- "y=5"
- )
+ eq_(env.get_x_argument(), "y=5")
def test_x_arg_asdict(self):
env = self._fixture()
self.cfg.cmd_opts = Mock(x=["y=5"])
- eq_(
- env.get_x_argument(as_dictionary=True),
- {"y": "5"}
- )
+ eq_(env.get_x_argument(as_dictionary=True), {"y": "5"})
def test_x_arg_no_opts(self):
env = self._fixture()
- eq_(
- env.get_x_argument(),
- []
- )
+ eq_(env.get_x_argument(), [])
def test_x_arg_no_opts_asdict(self):
env = self._fixture()
- eq_(
- env.get_x_argument(as_dictionary=True),
- {}
- )
+ eq_(env.get_x_argument(as_dictionary=True), {})
def test_tag_arg(self):
env = self._fixture(tag="x")
- eq_(
- env.get_tag_argument(),
- "x"
- )
+ eq_(env.get_tag_argument(), "x")
def test_migration_context_has_config(self):
env = self._fixture()
@@ -81,9 +66,12 @@ class EnvironmentTest(TestBase):
engine = _sqlite_file_db()
- a_rev = 'arev'
+ a_rev = "arev"
env.script.generate_revision(a_rev, "revision a", refresh=True)
- write_script(env.script, a_rev, """\
+ write_script(
+ env.script,
+ a_rev,
+ """\
"Rev A"
revision = '%s'
down_revision = None
@@ -98,7 +86,9 @@ def upgrade():
def downgrade():
pass
-""" % a_rev)
+"""
+ % a_rev,
+ )
migration_fn = MagicMock()
def upgrade(rev, context):
@@ -106,15 +96,13 @@ def downgrade():
return env.script._upgrade_revs(a_rev, rev)
with expect_warnings(
- r"'connection' argument to configure\(\) is "
- r"expected to be a sqlalchemy.engine.Connection "):
+ r"'connection' argument to configure\(\) is "
+ r"expected to be a sqlalchemy.engine.Connection "
+ ):
env.configure(
- connection=engine, fn=upgrade,
- transactional_ddl=False)
+ connection=engine, fn=upgrade, transactional_ddl=False
+ )
env.run_migrations()
- eq_(
- migration_fn.mock_calls,
- [call((), env._migration_context)]
- )
+ eq_(migration_fn.mock_calls, [call((), env._migration_context)])
diff --git a/tests/test_external_dialect.py b/tests/test_external_dialect.py
index dc01b75..1c3222d 100644
--- a/tests/test_external_dialect.py
+++ b/tests/test_external_dialect.py
@@ -15,6 +15,7 @@ from sqlalchemy.engine import default
class CustomDialect(default.DefaultDialect):
name = "custom_dialect"
+
try:
from sqlalchemy.dialects import registry
except ImportError:
@@ -24,20 +25,22 @@ else:
class CustomDialectImpl(impl.DefaultImpl):
- __dialect__ = 'custom_dialect'
+ __dialect__ = "custom_dialect"
transactional_ddl = False
def render_type(self, type_, autogen_context):
if type_.__module__ == __name__:
autogen_context.imports.add(
- "from %s import custom_dialect_types" % (__name__, ))
+ "from %s import custom_dialect_types" % (__name__,)
+ )
is_external = True
else:
is_external = False
- if is_external and \
- hasattr(self, '_render_%s_type' % type_.__visit_name__):
- meth = getattr(self, '_render_%s_type' % type_.__visit_name__)
+ if is_external and hasattr(
+ self, "_render_%s_type" % type_.__visit_name__
+ ):
+ meth = getattr(self, "_render_%s_type" % type_.__visit_name__)
return meth(type_, autogen_context)
if is_external:
@@ -47,13 +50,16 @@ class CustomDialectImpl(impl.DefaultImpl):
def _render_EXT_ARRAY_type(self, type_, autogen_context):
return render._render_type_w_subtype(
- type_, autogen_context, 'item_type', r'(.+?\()',
- prefix="custom_dialect_types."
+ type_,
+ autogen_context,
+ "item_type",
+ r"(.+?\()",
+ prefix="custom_dialect_types.",
)
class EXT_ARRAY(sqla_types.TypeEngine):
- __visit_name__ = 'EXT_ARRAY'
+ __visit_name__ = "EXT_ARRAY"
def __init__(self, item_type):
if isinstance(item_type, type):
@@ -63,75 +69,78 @@ class EXT_ARRAY(sqla_types.TypeEngine):
class FOOBARTYPE(sqla_types.TypeEngine):
- __visit_name__ = 'FOOBARTYPE'
+ __visit_name__ = "FOOBARTYPE"
class ExternalDialectRenderTest(TestBase):
-
def setUp(self):
ctx_opts = {
- 'sqlalchemy_module_prefix': 'sa.',
- 'alembic_module_prefix': 'op.',
- 'target_metadata': MetaData(),
- 'user_module_prefix': None
+ "sqlalchemy_module_prefix": "sa.",
+ "alembic_module_prefix": "op.",
+ "target_metadata": MetaData(),
+ "user_module_prefix": None,
}
context = MigrationContext.configure(
- dialect_name="custom_dialect",
- opts=ctx_opts
+ dialect_name="custom_dialect", opts=ctx_opts
)
self.autogen_context = api.AutogenContext(context)
def test_render_type(self):
eq_ignore_whitespace(
- autogenerate.render._repr_type(
- FOOBARTYPE(), self.autogen_context),
- "custom_dialect_types.FOOBARTYPE()"
+ autogenerate.render._repr_type(FOOBARTYPE(), self.autogen_context),
+ "custom_dialect_types.FOOBARTYPE()",
)
eq_(
self.autogen_context.imports,
- set([
- 'from tests.test_external_dialect import custom_dialect_types'
- ])
+ set(
+ [
+ "from tests.test_external_dialect import custom_dialect_types"
+ ]
+ ),
)
def test_external_nested_render_sqla_type(self):
eq_ignore_whitespace(
autogenerate.render._repr_type(
- EXT_ARRAY(sqla_types.Integer), self.autogen_context),
- "custom_dialect_types.EXT_ARRAY(sa.Integer())"
+ EXT_ARRAY(sqla_types.Integer), self.autogen_context
+ ),
+ "custom_dialect_types.EXT_ARRAY(sa.Integer())",
)
eq_ignore_whitespace(
autogenerate.render._repr_type(
- EXT_ARRAY(
- sqla_types.DateTime(timezone=True)
- ),
- self.autogen_context),
- "custom_dialect_types.EXT_ARRAY(sa.DateTime(timezone=True))"
+ EXT_ARRAY(sqla_types.DateTime(timezone=True)),
+ self.autogen_context,
+ ),
+ "custom_dialect_types.EXT_ARRAY(sa.DateTime(timezone=True))",
)
eq_(
self.autogen_context.imports,
- set([
- 'from tests.test_external_dialect import custom_dialect_types'
- ])
+ set(
+ [
+ "from tests.test_external_dialect import custom_dialect_types"
+ ]
+ ),
)
def test_external_nested_render_external_type(self):
eq_ignore_whitespace(
autogenerate.render._repr_type(
- EXT_ARRAY(FOOBARTYPE),
- self.autogen_context),
- "custom_dialect_types.EXT_ARRAY(custom_dialect_types.FOOBARTYPE())"
+ EXT_ARRAY(FOOBARTYPE), self.autogen_context
+ ),
+ "custom_dialect_types.EXT_ARRAY(custom_dialect_types.FOOBARTYPE())",
)
eq_(
self.autogen_context.imports,
- set([
- 'from tests.test_external_dialect import custom_dialect_types'
- ])
+ set(
+ [
+ "from tests.test_external_dialect import custom_dialect_types"
+ ]
+ ),
)
diff --git a/tests/test_mssql.py b/tests/test_mssql.py
index b092dcf..1657ccc 100644
--- a/tests/test_mssql.py
+++ b/tests/test_mssql.py
@@ -8,13 +8,16 @@ from alembic import op, command, util
from alembic.testing import eq_, assert_raises_message
from alembic.testing.fixtures import capture_context_buffer, op_fixture
-from alembic.testing.env import staging_env, _no_sql_testing_config, \
- three_rev_fixture, clear_staging_env
+from alembic.testing.env import (
+ staging_env,
+ _no_sql_testing_config,
+ three_rev_fixture,
+ clear_staging_env,
+)
from alembic.testing import config
class FullEnvironmentTests(TestBase):
-
@classmethod
def setup_class(cls):
staging_env()
@@ -24,8 +27,7 @@ class FullEnvironmentTests(TestBase):
directives = ""
cls.cfg = cfg = _no_sql_testing_config("mssql", directives)
- cls.a, cls.b, cls.c = \
- three_rev_fixture(cfg)
+ cls.a, cls.b, cls.c = three_rev_fixture(cfg)
@classmethod
def teardown_class(cls):
@@ -39,7 +41,7 @@ class FullEnvironmentTests(TestBase):
# ensure ends in COMMIT; GO
eq_(
[x for x in buf.getvalue().splitlines() if x][-2:],
- ['COMMIT;', 'GO']
+ ["COMMIT;", "GO"],
)
def test_batch_separator_default(self):
@@ -54,242 +56,248 @@ class FullEnvironmentTests(TestBase):
class OpTest(TestBase):
-
def test_add_column(self):
- context = op_fixture('mssql')
- op.add_column('t1', Column('c1', Integer, nullable=False))
+ context = op_fixture("mssql")
+ op.add_column("t1", Column("c1", Integer, nullable=False))
context.assert_("ALTER TABLE t1 ADD c1 INTEGER NOT NULL")
def test_add_column_with_default(self):
context = op_fixture("mssql")
op.add_column(
- 't1', Column('c1', Integer, nullable=False, server_default="12"))
+ "t1", Column("c1", Integer, nullable=False, server_default="12")
+ )
context.assert_("ALTER TABLE t1 ADD c1 INTEGER NOT NULL DEFAULT '12'")
def test_alter_column_rename_mssql(self):
- context = op_fixture('mssql')
+ context = op_fixture("mssql")
op.alter_column("t", "c", new_column_name="x")
- context.assert_(
- "EXEC sp_rename 't.c', x, 'COLUMN'"
- )
+ context.assert_("EXEC sp_rename 't.c', x, 'COLUMN'")
def test_alter_column_rename_quoted_mssql(self):
- context = op_fixture('mssql')
+ context = op_fixture("mssql")
op.alter_column("t", "c", new_column_name="SomeFancyName")
- context.assert_(
- "EXEC sp_rename 't.c', [SomeFancyName], 'COLUMN'"
- )
+ context.assert_("EXEC sp_rename 't.c', [SomeFancyName], 'COLUMN'")
def test_alter_column_new_type(self):
- context = op_fixture('mssql')
+ context = op_fixture("mssql")
op.alter_column("t", "c", type_=Integer)
- context.assert_(
- 'ALTER TABLE t ALTER COLUMN c INTEGER'
- )
+ context.assert_("ALTER TABLE t ALTER COLUMN c INTEGER")
def test_alter_column_dont_touch_constraints(self):
- context = op_fixture('mssql')
+ context = op_fixture("mssql")
from sqlalchemy import Boolean
- op.alter_column('tests', 'col',
- existing_type=Boolean(),
- nullable=False)
- context.assert_('ALTER TABLE tests ALTER COLUMN col BIT NOT NULL')
+
+ op.alter_column(
+ "tests", "col", existing_type=Boolean(), nullable=False
+ )
+ context.assert_("ALTER TABLE tests ALTER COLUMN col BIT NOT NULL")
def test_drop_index(self):
- context = op_fixture('mssql')
- op.drop_index('my_idx', 'my_table')
+ context = op_fixture("mssql")
+ op.drop_index("my_idx", "my_table")
context.assert_contains("DROP INDEX my_idx ON my_table")
def test_drop_column_w_default(self):
- context = op_fixture('mssql')
- op.drop_column('t1', 'c1', mssql_drop_default=True)
- op.drop_column('t1', 'c2', mssql_drop_default=True)
+ context = op_fixture("mssql")
+ op.drop_column("t1", "c1", mssql_drop_default=True)
+ op.drop_column("t1", "c2", mssql_drop_default=True)
context.assert_contains(
- "exec('alter table t1 drop constraint ' + @const_name)")
+ "exec('alter table t1 drop constraint ' + @const_name)"
+ )
context.assert_contains("ALTER TABLE t1 DROP COLUMN c1")
def test_drop_column_w_default_in_batch(self):
- context = op_fixture('mssql')
- with op.batch_alter_table('t1', schema=None) as batch_op:
- batch_op.drop_column('c1', mssql_drop_default=True)
- batch_op.drop_column('c2', mssql_drop_default=True)
+ context = op_fixture("mssql")
+ with op.batch_alter_table("t1", schema=None) as batch_op:
+ batch_op.drop_column("c1", mssql_drop_default=True)
+ batch_op.drop_column("c2", mssql_drop_default=True)
context.assert_contains(
- "exec('alter table t1 drop constraint ' + @const_name)")
+ "exec('alter table t1 drop constraint ' + @const_name)"
+ )
context.assert_contains("ALTER TABLE t1 DROP COLUMN c1")
def test_alter_column_drop_default(self):
- context = op_fixture('mssql')
+ context = op_fixture("mssql")
op.alter_column("t", "c", server_default=None)
context.assert_contains(
- "exec('alter table t drop constraint ' + @const_name)")
+ "exec('alter table t drop constraint ' + @const_name)"
+ )
def test_alter_column_dont_drop_default(self):
- context = op_fixture('mssql')
+ context = op_fixture("mssql")
op.alter_column("t", "c", server_default=False)
context.assert_()
def test_drop_column_w_schema(self):
- context = op_fixture('mssql')
- op.drop_column('t1', 'c1', schema='xyz')
+ context = op_fixture("mssql")
+ op.drop_column("t1", "c1", schema="xyz")
context.assert_contains("ALTER TABLE xyz.t1 DROP COLUMN c1")
def test_drop_column_w_check(self):
- context = op_fixture('mssql')
- op.drop_column('t1', 'c1', mssql_drop_check=True)
- op.drop_column('t1', 'c2', mssql_drop_check=True)
+ context = op_fixture("mssql")
+ op.drop_column("t1", "c1", mssql_drop_check=True)
+ op.drop_column("t1", "c2", mssql_drop_check=True)
context.assert_contains(
- "exec('alter table t1 drop constraint ' + @const_name)")
+ "exec('alter table t1 drop constraint ' + @const_name)"
+ )
context.assert_contains("ALTER TABLE t1 DROP COLUMN c1")
def test_drop_column_w_check_in_batch(self):
- context = op_fixture('mssql')
- with op.batch_alter_table('t1', schema=None) as batch_op:
- batch_op.drop_column('c1', mssql_drop_check=True)
- batch_op.drop_column('c2', mssql_drop_check=True)
+ context = op_fixture("mssql")
+ with op.batch_alter_table("t1", schema=None) as batch_op:
+ batch_op.drop_column("c1", mssql_drop_check=True)
+ batch_op.drop_column("c2", mssql_drop_check=True)
context.assert_contains(
- "exec('alter table t1 drop constraint ' + @const_name)")
+ "exec('alter table t1 drop constraint ' + @const_name)"
+ )
context.assert_contains("ALTER TABLE t1 DROP COLUMN c1")
def test_drop_column_w_check_quoting(self):
- context = op_fixture('mssql')
- op.drop_column('table', 'column', mssql_drop_check=True)
+ context = op_fixture("mssql")
+ op.drop_column("table", "column", mssql_drop_check=True)
context.assert_contains(
- "exec('alter table [table] drop constraint ' + @const_name)")
+ "exec('alter table [table] drop constraint ' + @const_name)"
+ )
context.assert_contains("ALTER TABLE [table] DROP COLUMN [column]")
def test_alter_column_nullable_w_existing_type(self):
- context = op_fixture('mssql')
+ context = op_fixture("mssql")
op.alter_column("t", "c", nullable=True, existing_type=Integer)
- context.assert_(
- "ALTER TABLE t ALTER COLUMN c INTEGER NULL"
- )
+ context.assert_("ALTER TABLE t ALTER COLUMN c INTEGER NULL")
def test_drop_column_w_fk(self):
- context = op_fixture('mssql')
- op.drop_column('t1', 'c1', mssql_drop_foreign_key=True)
+ context = op_fixture("mssql")
+ op.drop_column("t1", "c1", mssql_drop_foreign_key=True)
context.assert_contains(
- "exec('alter table t1 drop constraint ' + @const_name)")
+ "exec('alter table t1 drop constraint ' + @const_name)"
+ )
context.assert_contains("ALTER TABLE t1 DROP COLUMN c1")
def test_drop_column_w_fk_in_batch(self):
- context = op_fixture('mssql')
- with op.batch_alter_table('t1', schema=None) as batch_op:
- batch_op.drop_column('c1', mssql_drop_foreign_key=True)
+ context = op_fixture("mssql")
+ with op.batch_alter_table("t1", schema=None) as batch_op:
+ batch_op.drop_column("c1", mssql_drop_foreign_key=True)
context.assert_contains(
- "exec('alter table t1 drop constraint ' + @const_name)")
+ "exec('alter table t1 drop constraint ' + @const_name)"
+ )
context.assert_contains("ALTER TABLE t1 DROP COLUMN c1")
def test_alter_column_not_nullable_w_existing_type(self):
- context = op_fixture('mssql')
+ context = op_fixture("mssql")
op.alter_column("t", "c", nullable=False, existing_type=Integer)
- context.assert_(
- "ALTER TABLE t ALTER COLUMN c INTEGER NOT NULL"
- )
+ context.assert_("ALTER TABLE t ALTER COLUMN c INTEGER NOT NULL")
def test_alter_column_nullable_w_new_type(self):
- context = op_fixture('mssql')
+ context = op_fixture("mssql")
op.alter_column("t", "c", nullable=True, type_=Integer)
- context.assert_(
- "ALTER TABLE t ALTER COLUMN c INTEGER NULL"
- )
+ context.assert_("ALTER TABLE t ALTER COLUMN c INTEGER NULL")
def test_alter_column_not_nullable_w_new_type(self):
- context = op_fixture('mssql')
+ context = op_fixture("mssql")
op.alter_column("t", "c", nullable=False, type_=Integer)
- context.assert_(
- "ALTER TABLE t ALTER COLUMN c INTEGER NOT NULL"
- )
+ context.assert_("ALTER TABLE t ALTER COLUMN c INTEGER NOT NULL")
def test_alter_column_nullable_type_required(self):
- context = op_fixture('mssql')
+ context = op_fixture("mssql")
assert_raises_message(
util.CommandError,
"MS-SQL ALTER COLUMN operations with NULL or "
"NOT NULL require the existing_type or a new "
"type_ be passed.",
- op.alter_column, "t", "c", nullable=False
+ op.alter_column,
+ "t",
+ "c",
+ nullable=False,
)
def test_alter_add_server_default(self):
- context = op_fixture('mssql')
+ context = op_fixture("mssql")
op.alter_column("t", "c", server_default="5")
- context.assert_(
- "ALTER TABLE t ADD DEFAULT '5' FOR c"
- )
+ context.assert_("ALTER TABLE t ADD DEFAULT '5' FOR c")
def test_alter_replace_server_default(self):
- context = op_fixture('mssql')
+ context = op_fixture("mssql")
op.alter_column(
- "t", "c", server_default="5", existing_server_default="6")
- context.assert_contains(
- "exec('alter table t drop constraint ' + @const_name)")
+ "t", "c", server_default="5", existing_server_default="6"
+ )
context.assert_contains(
- "ALTER TABLE t ADD DEFAULT '5' FOR c"
+ "exec('alter table t drop constraint ' + @const_name)"
)
+ context.assert_contains("ALTER TABLE t ADD DEFAULT '5' FOR c")
def test_alter_remove_server_default(self):
- context = op_fixture('mssql')
+ context = op_fixture("mssql")
op.alter_column("t", "c", server_default=None)
context.assert_contains(
- "exec('alter table t drop constraint ' + @const_name)")
+ "exec('alter table t drop constraint ' + @const_name)"
+ )
def test_alter_do_everything(self):
- context = op_fixture('mssql')
- op.alter_column("t", "c", new_column_name="c2", nullable=True,
- type_=Integer, server_default="5")
+ context = op_fixture("mssql")
+ op.alter_column(
+ "t",
+ "c",
+ new_column_name="c2",
+ nullable=True,
+ type_=Integer,
+ server_default="5",
+ )
context.assert_(
- 'ALTER TABLE t ALTER COLUMN c INTEGER NULL',
+ "ALTER TABLE t ALTER COLUMN c INTEGER NULL",
"ALTER TABLE t ADD DEFAULT '5' FOR c",
- "EXEC sp_rename 't.c', c2, 'COLUMN'"
+ "EXEC sp_rename 't.c', c2, 'COLUMN'",
)
def test_rename_table(self):
- context = op_fixture('mssql')
- op.rename_table('t1', 't2')
+ context = op_fixture("mssql")
+ op.rename_table("t1", "t2")
context.assert_contains("EXEC sp_rename 't1', t2")
def test_rename_table_schema(self):
- context = op_fixture('mssql')
- op.rename_table('t1', 't2', schema="foobar")
+ context = op_fixture("mssql")
+ op.rename_table("t1", "t2", schema="foobar")
context.assert_contains("EXEC sp_rename 'foobar.t1', t2")
def test_rename_table_casesens(self):
- context = op_fixture('mssql')
- op.rename_table('TeeOne', 'TeeTwo')
+ context = op_fixture("mssql")
+ op.rename_table("TeeOne", "TeeTwo")
# yup, ran this in SQL Server 2014, the two levels of quoting
# seems to be understood. Can't do the two levels on the
# target name though !
context.assert_contains("EXEC sp_rename '[TeeOne]', [TeeTwo]")
def test_rename_table_schema_casesens(self):
- context = op_fixture('mssql')
- op.rename_table('TeeOne', 'TeeTwo', schema="FooBar")
+ context = op_fixture("mssql")
+ op.rename_table("TeeOne", "TeeTwo", schema="FooBar")
# yup, ran this in SQL Server 2014, the two levels of quoting
# seems to be understood. Can't do the two levels on the
# target name though !
context.assert_contains("EXEC sp_rename '[FooBar].[TeeOne]', [TeeTwo]")
def test_alter_column_rename_mssql_schema(self):
- context = op_fixture('mssql')
+ context = op_fixture("mssql")
op.alter_column("t", "c", name="x", schema="y")
- context.assert_(
- "EXEC sp_rename 'y.t.c', x, 'COLUMN'"
- )
+ context.assert_("EXEC sp_rename 'y.t.c', x, 'COLUMN'")
def test_create_index_mssql_include(self):
- context = op_fixture('mssql')
+ context = op_fixture("mssql")
op.create_index(
- op.f('ix_mytable_a_b'), 'mytable', ['col_a', 'col_b'],
- unique=False, mssql_include=['col_c'])
+ op.f("ix_mytable_a_b"),
+ "mytable",
+ ["col_a", "col_b"],
+ unique=False,
+ mssql_include=["col_c"],
+ )
context.assert_contains(
"CREATE INDEX ix_mytable_a_b ON mytable "
- "(col_a, col_b) INCLUDE (col_c)")
+ "(col_a, col_b) INCLUDE (col_c)"
+ )
def test_create_index_mssql_include_is_none(self):
- context = op_fixture('mssql')
+ context = op_fixture("mssql")
op.create_index(
- op.f('ix_mytable_a_b'), 'mytable', ['col_a', 'col_b'],
- unique=False)
+ op.f("ix_mytable_a_b"), "mytable", ["col_a", "col_b"], unique=False
+ )
context.assert_contains(
- "CREATE INDEX ix_mytable_a_b ON mytable "
- "(col_a, col_b)")
+ "CREATE INDEX ix_mytable_a_b ON mytable " "(col_a, col_b)"
+ )
diff --git a/tests/test_mysql.py b/tests/test_mysql.py
index dd872f7..68746ba 100644
--- a/tests/test_mysql.py
+++ b/tests/test_mysql.py
@@ -7,53 +7,68 @@ from alembic import op, util
from alembic.testing import eq_, assert_raises_message
from alembic.testing.fixtures import capture_context_buffer, op_fixture
-from alembic.testing.env import staging_env, _no_sql_testing_config, \
- three_rev_fixture, clear_staging_env
+from alembic.testing.env import (
+ staging_env,
+ _no_sql_testing_config,
+ three_rev_fixture,
+ clear_staging_env,
+)
from alembic.migration import MigrationContext
class MySQLOpTest(TestBase):
-
def test_rename_column(self):
- context = op_fixture('mysql')
+ context = op_fixture("mysql")
op.alter_column(
- 't1', 'c1', new_column_name="c2", existing_type=Integer)
- context.assert_(
- 'ALTER TABLE t1 CHANGE c1 c2 INTEGER NULL'
+ "t1", "c1", new_column_name="c2", existing_type=Integer
)
+ context.assert_("ALTER TABLE t1 CHANGE c1 c2 INTEGER NULL")
def test_rename_column_quotes_needed_one(self):
- context = op_fixture('mysql')
- op.alter_column('MyTable', 'ColumnOne', new_column_name="ColumnTwo",
- existing_type=Integer)
+ context = op_fixture("mysql")
+ op.alter_column(
+ "MyTable",
+ "ColumnOne",
+ new_column_name="ColumnTwo",
+ existing_type=Integer,
+ )
context.assert_(
- 'ALTER TABLE `MyTable` CHANGE `ColumnOne` `ColumnTwo` INTEGER NULL'
+ "ALTER TABLE `MyTable` CHANGE `ColumnOne` `ColumnTwo` INTEGER NULL"
)
def test_rename_column_quotes_needed_two(self):
- context = op_fixture('mysql')
- op.alter_column('my table', 'column one', new_column_name="column two",
- existing_type=Integer)
+ context = op_fixture("mysql")
+ op.alter_column(
+ "my table",
+ "column one",
+ new_column_name="column two",
+ existing_type=Integer,
+ )
context.assert_(
- 'ALTER TABLE `my table` CHANGE `column one` '
- '`column two` INTEGER NULL'
+ "ALTER TABLE `my table` CHANGE `column one` "
+ "`column two` INTEGER NULL"
)
def test_rename_column_serv_default(self):
- context = op_fixture('mysql')
+ context = op_fixture("mysql")
op.alter_column(
- 't1', 'c1', new_column_name="c2", existing_type=Integer,
- existing_server_default="q")
- context.assert_(
- "ALTER TABLE t1 CHANGE c1 c2 INTEGER NULL DEFAULT 'q'"
+ "t1",
+ "c1",
+ new_column_name="c2",
+ existing_type=Integer,
+ existing_server_default="q",
)
+ context.assert_("ALTER TABLE t1 CHANGE c1 c2 INTEGER NULL DEFAULT 'q'")
def test_rename_column_serv_compiled_default(self):
- context = op_fixture('mysql')
+ context = op_fixture("mysql")
op.alter_column(
- 't1', 'c1', existing_type=Integer,
- server_default=func.utc_thing(func.current_timestamp()))
+ "t1",
+ "c1",
+ existing_type=Integer,
+ server_default=func.utc_thing(func.current_timestamp()),
+ )
# this is not a valid MySQL default but the point is to just
# test SQL expression rendering
context.assert_(
@@ -62,184 +77,183 @@ class MySQLOpTest(TestBase):
)
def test_rename_column_autoincrement(self):
- context = op_fixture('mysql')
+ context = op_fixture("mysql")
op.alter_column(
- 't1', 'c1', new_column_name="c2", existing_type=Integer,
- existing_autoincrement=True)
+ "t1",
+ "c1",
+ new_column_name="c2",
+ existing_type=Integer,
+ existing_autoincrement=True,
+ )
context.assert_(
- 'ALTER TABLE t1 CHANGE c1 c2 INTEGER NULL AUTO_INCREMENT'
+ "ALTER TABLE t1 CHANGE c1 c2 INTEGER NULL AUTO_INCREMENT"
)
def test_col_add_autoincrement(self):
- context = op_fixture('mysql')
- op.alter_column('t1', 'c1', existing_type=Integer,
- autoincrement=True)
- context.assert_(
- 'ALTER TABLE t1 MODIFY c1 INTEGER NULL AUTO_INCREMENT'
- )
+ context = op_fixture("mysql")
+ op.alter_column("t1", "c1", existing_type=Integer, autoincrement=True)
+ context.assert_("ALTER TABLE t1 MODIFY c1 INTEGER NULL AUTO_INCREMENT")
def test_col_remove_autoincrement(self):
- context = op_fixture('mysql')
- op.alter_column('t1', 'c1', existing_type=Integer,
- existing_autoincrement=True,
- autoincrement=False)
- context.assert_(
- 'ALTER TABLE t1 MODIFY c1 INTEGER NULL'
+ context = op_fixture("mysql")
+ op.alter_column(
+ "t1",
+ "c1",
+ existing_type=Integer,
+ existing_autoincrement=True,
+ autoincrement=False,
)
+ context.assert_("ALTER TABLE t1 MODIFY c1 INTEGER NULL")
def test_col_dont_remove_server_default(self):
- context = op_fixture('mysql')
- op.alter_column('t1', 'c1', existing_type=Integer,
- existing_server_default='1',
- server_default=False)
+ context = op_fixture("mysql")
+ op.alter_column(
+ "t1",
+ "c1",
+ existing_type=Integer,
+ existing_server_default="1",
+ server_default=False,
+ )
context.assert_()
def test_alter_column_drop_default(self):
- context = op_fixture('mysql')
+ context = op_fixture("mysql")
op.alter_column("t", "c", existing_type=Integer, server_default=None)
- context.assert_(
- 'ALTER TABLE t ALTER COLUMN c DROP DEFAULT'
- )
+ context.assert_("ALTER TABLE t ALTER COLUMN c DROP DEFAULT")
def test_alter_column_remove_schematype(self):
- context = op_fixture('mysql')
+ context = op_fixture("mysql")
op.alter_column(
- "t", "c",
+ "t",
+ "c",
type_=Integer,
existing_type=Boolean(create_constraint=True, name="ck1"),
- server_default=None)
- context.assert_(
- 'ALTER TABLE t MODIFY c INTEGER NULL'
+ server_default=None,
)
+ context.assert_("ALTER TABLE t MODIFY c INTEGER NULL")
def test_alter_column_modify_default(self):
- context = op_fixture('mysql')
+ context = op_fixture("mysql")
# notice we dont need the existing type on this one...
- op.alter_column("t", "c", server_default='1')
- context.assert_(
- "ALTER TABLE t ALTER COLUMN c SET DEFAULT '1'"
- )
+ op.alter_column("t", "c", server_default="1")
+ context.assert_("ALTER TABLE t ALTER COLUMN c SET DEFAULT '1'")
def test_col_not_nullable(self):
- context = op_fixture('mysql')
- op.alter_column('t1', 'c1', nullable=False, existing_type=Integer)
- context.assert_(
- 'ALTER TABLE t1 MODIFY c1 INTEGER NOT NULL'
- )
+ context = op_fixture("mysql")
+ op.alter_column("t1", "c1", nullable=False, existing_type=Integer)
+ context.assert_("ALTER TABLE t1 MODIFY c1 INTEGER NOT NULL")
def test_col_not_nullable_existing_serv_default(self):
- context = op_fixture('mysql')
- op.alter_column('t1', 'c1', nullable=False, existing_type=Integer,
- existing_server_default='5')
+ context = op_fixture("mysql")
+ op.alter_column(
+ "t1",
+ "c1",
+ nullable=False,
+ existing_type=Integer,
+ existing_server_default="5",
+ )
context.assert_(
"ALTER TABLE t1 MODIFY c1 INTEGER NOT NULL DEFAULT '5'"
)
def test_col_nullable(self):
- context = op_fixture('mysql')
- op.alter_column('t1', 'c1', nullable=True, existing_type=Integer)
- context.assert_(
- 'ALTER TABLE t1 MODIFY c1 INTEGER NULL'
- )
+ context = op_fixture("mysql")
+ op.alter_column("t1", "c1", nullable=True, existing_type=Integer)
+ context.assert_("ALTER TABLE t1 MODIFY c1 INTEGER NULL")
def test_col_multi_alter(self):
- context = op_fixture('mysql')
+ context = op_fixture("mysql")
op.alter_column(
- 't1', 'c1', nullable=False, server_default="q", type_=Integer)
+ "t1", "c1", nullable=False, server_default="q", type_=Integer
+ )
context.assert_(
"ALTER TABLE t1 MODIFY c1 INTEGER NOT NULL DEFAULT 'q'"
)
def test_alter_column_multi_alter_w_drop_default(self):
- context = op_fixture('mysql')
+ context = op_fixture("mysql")
op.alter_column(
- 't1', 'c1', nullable=False, server_default=None, type_=Integer)
- context.assert_(
- "ALTER TABLE t1 MODIFY c1 INTEGER NOT NULL"
+ "t1", "c1", nullable=False, server_default=None, type_=Integer
)
+ context.assert_("ALTER TABLE t1 MODIFY c1 INTEGER NOT NULL")
def test_col_alter_type_required(self):
- op_fixture('mysql')
+ op_fixture("mysql")
assert_raises_message(
util.CommandError,
"MySQL CHANGE/MODIFY COLUMN operations require the existing type.",
- op.alter_column, 't1', 'c1', nullable=False, server_default="q"
+ op.alter_column,
+ "t1",
+ "c1",
+ nullable=False,
+ server_default="q",
)
def test_drop_fk(self):
- context = op_fixture('mysql')
+ context = op_fixture("mysql")
op.drop_constraint("f1", "t1", "foreignkey")
- context.assert_(
- "ALTER TABLE t1 DROP FOREIGN KEY f1"
- )
+ context.assert_("ALTER TABLE t1 DROP FOREIGN KEY f1")
def test_drop_fk_quoted(self):
- context = op_fixture('mysql')
+ context = op_fixture("mysql")
op.drop_constraint("MyFk", "MyTable", "foreignkey")
- context.assert_(
- "ALTER TABLE `MyTable` DROP FOREIGN KEY `MyFk`"
- )
+ context.assert_("ALTER TABLE `MyTable` DROP FOREIGN KEY `MyFk`")
def test_drop_constraint_primary(self):
- context = op_fixture('mysql')
- op.drop_constraint('primary', 't1', type_='primary')
- context.assert_(
- "ALTER TABLE t1 DROP PRIMARY KEY"
- )
+ context = op_fixture("mysql")
+ op.drop_constraint("primary", "t1", type_="primary")
+ context.assert_("ALTER TABLE t1 DROP PRIMARY KEY")
def test_drop_unique(self):
- context = op_fixture('mysql')
+ context = op_fixture("mysql")
op.drop_constraint("f1", "t1", "unique")
- context.assert_(
- "ALTER TABLE t1 DROP INDEX f1"
- )
+ context.assert_("ALTER TABLE t1 DROP INDEX f1")
def test_drop_unique_quoted(self):
- context = op_fixture('mysql')
+ context = op_fixture("mysql")
op.drop_constraint("MyUnique", "MyTable", "unique")
- context.assert_(
- "ALTER TABLE `MyTable` DROP INDEX `MyUnique`"
- )
+ context.assert_("ALTER TABLE `MyTable` DROP INDEX `MyUnique`")
def test_drop_check(self):
- context = op_fixture('mysql')
+ context = op_fixture("mysql")
op.drop_constraint("f1", "t1", "check")
- context.assert_(
- "ALTER TABLE t1 DROP CONSTRAINT f1"
- )
+ context.assert_("ALTER TABLE t1 DROP CONSTRAINT f1")
def test_drop_check_quoted(self):
- context = op_fixture('mysql')
+ context = op_fixture("mysql")
op.drop_constraint("MyCheck", "MyTable", "check")
- context.assert_(
- "ALTER TABLE `MyTable` DROP CONSTRAINT `MyCheck`"
- )
+ context.assert_("ALTER TABLE `MyTable` DROP CONSTRAINT `MyCheck`")
def test_drop_unknown(self):
- op_fixture('mysql')
+ op_fixture("mysql")
assert_raises_message(
TypeError,
"'type' can be one of 'check', 'foreignkey', "
"'primary', 'unique', None",
- op.drop_constraint, "f1", "t1", "typo"
+ op.drop_constraint,
+ "f1",
+ "t1",
+ "typo",
)
def test_drop_generic_constraint(self):
- op_fixture('mysql')
+ op_fixture("mysql")
assert_raises_message(
NotImplementedError,
"No generic 'DROP CONSTRAINT' in MySQL - please "
"specify constraint type",
- op.drop_constraint, "f1", "t1"
+ op.drop_constraint,
+ "f1",
+ "t1",
)
class MySQLDefaultCompareTest(TestBase):
- __only_on__ = 'mysql'
+ __only_on__ = "mysql"
__backend__ = True
- __requires__ = 'mysql_timestamp_reflection',
+ __requires__ = ("mysql_timestamp_reflection",)
@classmethod
def setup_class(cls):
@@ -247,17 +261,14 @@ class MySQLDefaultCompareTest(TestBase):
staging_env()
context = MigrationContext.configure(
connection=cls.bind.connect(),
- opts={
- 'compare_type': True,
- 'compare_server_default': True
- }
+ opts={"compare_type": True, "compare_server_default": True},
)
connection = context.bind
cls.autogen_context = {
- 'imports': set(),
- 'connection': connection,
- 'dialect': connection.dialect,
- 'context': context
+ "imports": set(),
+ "connection": connection,
+ "dialect": connection.dialect,
+ "context": context,
}
@classmethod
@@ -277,64 +288,46 @@ class MySQLDefaultCompareTest(TestBase):
alternate = txt
expected = False
t = Table(
- "test", self.metadata,
+ "test",
+ self.metadata,
Column(
- "somecol", type_,
- server_default=text(txt) if txt else None
- )
+ "somecol", type_, server_default=text(txt) if txt else None
+ ),
+ )
+ t2 = Table(
+ "test",
+ MetaData(),
+ Column("somecol", type_, server_default=text(alternate)),
)
- t2 = Table("test", MetaData(),
- Column("somecol", type_, server_default=text(alternate))
- )
- assert self._compare_default(
- t, t2, t2.c.somecol, alternate
- ) is expected
-
- def _compare_default(
- self,
- t1, t2, col,
- rendered
- ):
+ assert (
+ self._compare_default(t, t2, t2.c.somecol, alternate) is expected
+ )
+
+ def _compare_default(self, t1, t2, col, rendered):
t1.create(self.bind)
insp = Inspector.from_engine(self.bind)
cols = insp.get_columns(t1.name)
refl = Table(t1.name, MetaData())
insp.reflecttable(refl, None)
- ctx = self.autogen_context['context']
+ ctx = self.autogen_context["context"]
return ctx.impl.compare_server_default(
- refl.c[cols[0]['name']],
- col,
- rendered,
- cols[0]['default'])
+ refl.c[cols[0]["name"]], col, rendered, cols[0]["default"]
+ )
def test_compare_timestamp_current_timestamp(self):
- self._compare_default_roundtrip(
- TIMESTAMP(),
- "CURRENT_TIMESTAMP",
- )
+ self._compare_default_roundtrip(TIMESTAMP(), "CURRENT_TIMESTAMP")
def test_compare_timestamp_current_timestamp_diff(self):
- self._compare_default_roundtrip(
- TIMESTAMP(),
- None, "CURRENT_TIMESTAMP",
- )
+ self._compare_default_roundtrip(TIMESTAMP(), None, "CURRENT_TIMESTAMP")
def test_compare_integer_same(self):
- self._compare_default_roundtrip(
- Integer(), "5"
- )
+ self._compare_default_roundtrip(Integer(), "5")
def test_compare_integer_diff(self):
- self._compare_default_roundtrip(
- Integer(), "5", "7"
- )
+ self._compare_default_roundtrip(Integer(), "5", "7")
def test_compare_boolean_same(self):
- self._compare_default_roundtrip(
- Boolean(), "1"
- )
+ self._compare_default_roundtrip(Boolean(), "1")
def test_compare_boolean_diff(self):
- self._compare_default_roundtrip(
- Boolean(), "1", "0"
- )
+ self._compare_default_roundtrip(Boolean(), "1", "0")
diff --git a/tests/test_offline_environment.py b/tests/test_offline_environment.py
index 3920690..fbbbec3 100644
--- a/tests/test_offline_environment.py
+++ b/tests/test_offline_environment.py
@@ -3,16 +3,20 @@ from alembic.testing.fixtures import TestBase, capture_context_buffer
from alembic import command, util
from alembic.testing import assert_raises_message
-from alembic.testing.env import staging_env, _no_sql_testing_config, \
- three_rev_fixture, clear_staging_env, env_file_fixture, \
- multi_heads_fixture
+from alembic.testing.env import (
+ staging_env,
+ _no_sql_testing_config,
+ three_rev_fixture,
+ clear_staging_env,
+ env_file_fixture,
+ multi_heads_fixture,
+)
import re
a = b = c = None
class OfflineEnvironmentTest(TestBase):
-
def setUp(self):
staging_env()
self.cfg = _no_sql_testing_config()
@@ -24,92 +28,122 @@ class OfflineEnvironmentTest(TestBase):
clear_staging_env()
def test_not_requires_connection(self):
- env_file_fixture("""
+ env_file_fixture(
+ """
assert not context.requires_connection()
-""")
+"""
+ )
command.upgrade(self.cfg, a, sql=True)
command.downgrade(self.cfg, "%s:%s" % (b, a), sql=True)
def test_requires_connection(self):
- env_file_fixture("""
+ env_file_fixture(
+ """
assert context.requires_connection()
-""")
+"""
+ )
command.upgrade(self.cfg, a)
command.downgrade(self.cfg, a)
def test_starting_rev_post_context(self):
- env_file_fixture("""
+ env_file_fixture(
+ """
context.configure(dialect_name='sqlite', starting_rev='x')
assert context.get_starting_revision_argument() == 'x'
-""")
+"""
+ )
command.upgrade(self.cfg, a, sql=True)
command.downgrade(self.cfg, "%s:%s" % (b, a), sql=True)
command.current(self.cfg)
command.stamp(self.cfg, a)
def test_starting_rev_pre_context(self):
- env_file_fixture("""
+ env_file_fixture(
+ """
assert context.get_starting_revision_argument() == 'x'
-""")
+"""
+ )
command.upgrade(self.cfg, "x:y", sql=True)
command.downgrade(self.cfg, "x:y", sql=True)
def test_starting_rev_pre_context_cmd_w_no_startrev(self):
- env_file_fixture("""
+ env_file_fixture(
+ """
assert context.get_starting_revision_argument() == 'x'
-""")
+"""
+ )
assert_raises_message(
util.CommandError,
"No starting revision argument is available.",
- command.current, self.cfg)
+ command.current,
+ self.cfg,
+ )
def test_starting_rev_current_pre_context(self):
- env_file_fixture("""
+ env_file_fixture(
+ """
assert context.get_starting_revision_argument() is None
-""")
+"""
+ )
assert_raises_message(
util.CommandError,
"No starting revision argument is available.",
- command.current, self.cfg
+ command.current,
+ self.cfg,
)
def test_destination_rev_pre_context(self):
- env_file_fixture("""
+ env_file_fixture(
+ """
assert context.get_revision_argument() == '%s'
-""" % b)
+"""
+ % b
+ )
command.upgrade(self.cfg, b, sql=True)
command.stamp(self.cfg, b, sql=True)
command.downgrade(self.cfg, "%s:%s" % (c, b), sql=True)
def test_destination_rev_pre_context_multihead(self):
d, e, f = multi_heads_fixture(self.cfg, a, b, c)
- env_file_fixture("""
+ env_file_fixture(
+ """
assert set(context.get_revision_argument()) == set(('%s', '%s', '%s', ))
-""" % (f, e, c))
- command.upgrade(self.cfg, 'heads', sql=True)
+"""
+ % (f, e, c)
+ )
+ command.upgrade(self.cfg, "heads", sql=True)
def test_destination_rev_post_context(self):
- env_file_fixture("""
+ env_file_fixture(
+ """
context.configure(dialect_name='sqlite')
assert context.get_revision_argument() == '%s'
-""" % b)
+"""
+ % b
+ )
command.upgrade(self.cfg, b, sql=True)
command.downgrade(self.cfg, "%s:%s" % (c, b), sql=True)
command.stamp(self.cfg, b, sql=True)
def test_destination_rev_post_context_multihead(self):
d, e, f = multi_heads_fixture(self.cfg, a, b, c)
- env_file_fixture("""
+ env_file_fixture(
+ """
context.configure(dialect_name='sqlite')
assert set(context.get_revision_argument()) == set(('%s', '%s', '%s', ))
-""" % (f, e, c))
- command.upgrade(self.cfg, 'heads', sql=True)
+"""
+ % (f, e, c)
+ )
+ command.upgrade(self.cfg, "heads", sql=True)
def test_head_rev_pre_context(self):
- env_file_fixture("""
+ env_file_fixture(
+ """
assert context.get_head_revision() == '%s'
assert context.get_head_revisions() == ('%s', )
-""" % (c, c))
+"""
+ % (c, c)
+ )
command.upgrade(self.cfg, b, sql=True)
command.downgrade(self.cfg, "%s:%s" % (b, a), sql=True)
command.stamp(self.cfg, b, sql=True)
@@ -117,20 +151,26 @@ assert context.get_head_revisions() == ('%s', )
def test_head_rev_pre_context_multihead(self):
d, e, f = multi_heads_fixture(self.cfg, a, b, c)
- env_file_fixture("""
+ env_file_fixture(
+ """
assert set(context.get_head_revisions()) == set(('%s', '%s', '%s', ))
-""" % (e, f, c))
+"""
+ % (e, f, c)
+ )
command.upgrade(self.cfg, e, sql=True)
command.downgrade(self.cfg, "%s:%s" % (e, b), sql=True)
command.stamp(self.cfg, c, sql=True)
command.current(self.cfg)
def test_head_rev_post_context(self):
- env_file_fixture("""
+ env_file_fixture(
+ """
context.configure(dialect_name='sqlite')
assert context.get_head_revision() == '%s'
assert context.get_head_revisions() == ('%s', )
-""" % (c, c))
+"""
+ % (c, c)
+ )
command.upgrade(self.cfg, b, sql=True)
command.downgrade(self.cfg, "%s:%s" % (b, a), sql=True)
command.stamp(self.cfg, b, sql=True)
@@ -138,70 +178,89 @@ assert context.get_head_revisions() == ('%s', )
def test_head_rev_post_context_multihead(self):
d, e, f = multi_heads_fixture(self.cfg, a, b, c)
- env_file_fixture("""
+ env_file_fixture(
+ """
context.configure(dialect_name='sqlite')
assert set(context.get_head_revisions()) == set(('%s', '%s', '%s', ))
-""" % (e, f, c))
+"""
+ % (e, f, c)
+ )
command.upgrade(self.cfg, e, sql=True)
command.downgrade(self.cfg, "%s:%s" % (e, b), sql=True)
command.stamp(self.cfg, c, sql=True)
command.current(self.cfg)
def test_tag_pre_context(self):
- env_file_fixture("""
+ env_file_fixture(
+ """
assert context.get_tag_argument() == 'hi'
-""")
- command.upgrade(self.cfg, b, sql=True, tag='hi')
- command.downgrade(self.cfg, "%s:%s" % (b, a), sql=True, tag='hi')
+"""
+ )
+ command.upgrade(self.cfg, b, sql=True, tag="hi")
+ command.downgrade(self.cfg, "%s:%s" % (b, a), sql=True, tag="hi")
def test_tag_pre_context_None(self):
- env_file_fixture("""
+ env_file_fixture(
+ """
assert context.get_tag_argument() is None
-""")
+"""
+ )
command.upgrade(self.cfg, b, sql=True)
command.downgrade(self.cfg, "%s:%s" % (b, a), sql=True)
def test_tag_cmd_arg(self):
- env_file_fixture("""
+ env_file_fixture(
+ """
context.configure(dialect_name='sqlite')
assert context.get_tag_argument() == 'hi'
-""")
- command.upgrade(self.cfg, b, sql=True, tag='hi')
- command.downgrade(self.cfg, "%s:%s" % (b, a), sql=True, tag='hi')
+"""
+ )
+ command.upgrade(self.cfg, b, sql=True, tag="hi")
+ command.downgrade(self.cfg, "%s:%s" % (b, a), sql=True, tag="hi")
def test_tag_cfg_arg(self):
- env_file_fixture("""
+ env_file_fixture(
+ """
context.configure(dialect_name='sqlite', tag='there')
assert context.get_tag_argument() == 'there'
-""")
- command.upgrade(self.cfg, b, sql=True, tag='hi')
- command.downgrade(self.cfg, "%s:%s" % (b, a), sql=True, tag='hi')
+"""
+ )
+ command.upgrade(self.cfg, b, sql=True, tag="hi")
+ command.downgrade(self.cfg, "%s:%s" % (b, a), sql=True, tag="hi")
def test_tag_None(self):
- env_file_fixture("""
+ env_file_fixture(
+ """
context.configure(dialect_name='sqlite')
assert context.get_tag_argument() is None
-""")
+"""
+ )
command.upgrade(self.cfg, b, sql=True)
command.downgrade(self.cfg, "%s:%s" % (b, a), sql=True)
def test_downgrade_wo_colon(self):
- env_file_fixture("""
+ env_file_fixture(
+ """
context.configure(dialect_name='sqlite')
-""")
+"""
+ )
assert_raises_message(
util.CommandError,
"downgrade with --sql requires <fromrev>:<torev>",
command.downgrade,
- self.cfg, b, sql=True
+ self.cfg,
+ b,
+ sql=True,
)
def test_upgrade_with_output_encoding(self):
- env_file_fixture("""
+ env_file_fixture(
+ """
url = config.get_main_option('sqlalchemy.url')
context.configure(url=url, output_encoding='utf-8')
assert not context.requires_connection()
-""")
+"""
+ )
command.upgrade(self.cfg, a, sql=True)
command.downgrade(self.cfg, "%s:%s" % (b, a), sql=True)
@@ -213,37 +272,49 @@ assert not context.requires_connection()
with capture_context_buffer(transactional_ddl=True) as buf:
command.upgrade(self.cfg, "%s:%s" % (a, d.revision), sql=True)
- assert not re.match(r".*-- .*and multiline", buf.getvalue(), re.S | re.M)
+ assert not re.match(
+ r".*-- .*and multiline", buf.getvalue(), re.S | re.M
+ )
def test_starting_rev_pre_context_abbreviated(self):
- env_file_fixture("""
+ env_file_fixture(
+ """
assert context.get_starting_revision_argument() == '%s'
-""" % b[0:4])
+"""
+ % b[0:4]
+ )
command.upgrade(self.cfg, "%s:%s" % (b[0:4], c), sql=True)
command.stamp(self.cfg, "%s:%s" % (b[0:4], c), sql=True)
command.downgrade(self.cfg, "%s:%s" % (b[0:4], a), sql=True)
def test_destination_rev_pre_context_abbreviated(self):
- env_file_fixture("""
+ env_file_fixture(
+ """
assert context.get_revision_argument() == '%s'
-""" % b[0:4])
+"""
+ % b[0:4]
+ )
command.upgrade(self.cfg, "%s:%s" % (a, b[0:4]), sql=True)
command.stamp(self.cfg, b[0:4], sql=True)
command.downgrade(self.cfg, "%s:%s" % (c, b[0:4]), sql=True)
def test_starting_rev_context_runs_abbreviated(self):
- env_file_fixture("""
+ env_file_fixture(
+ """
context.configure(dialect_name='sqlite')
context.run_migrations()
-""")
+"""
+ )
command.upgrade(self.cfg, "%s:%s" % (b[0:4], c), sql=True)
command.downgrade(self.cfg, "%s:%s" % (b[0:4], a), sql=True)
def test_destination_rev_context_runs_abbreviated(self):
- env_file_fixture("""
+ env_file_fixture(
+ """
context.configure(dialect_name='sqlite')
context.run_migrations()
-""")
+"""
+ )
command.upgrade(self.cfg, "%s:%s" % (a, b[0:4]), sql=True)
command.stamp(self.cfg, b[0:4], sql=True)
command.downgrade(self.cfg, "%s:%s" % (c, b[0:4]), sql=True)
diff --git a/tests/test_op.py b/tests/test_op.py
index f9a6c51..fb2db5f 100644
--- a/tests/test_op.py
+++ b/tests/test_op.py
@@ -1,7 +1,6 @@
"""Test against the builders in the op.* module."""
-from sqlalchemy import Integer, Column, ForeignKey, \
- Table, String, Boolean
+from sqlalchemy import Integer, Column, ForeignKey, Table, String, Boolean
from sqlalchemy.sql import column, func, text
from sqlalchemy import event
@@ -17,19 +16,18 @@ from alembic.operations import schemaobj, ops
@event.listens_for(Table, "after_parent_attach")
def _add_cols(table, metadata):
if table.name == "tbl_with_auto_appended_column":
- table.append_column(Column('bat', Integer))
+ table.append_column(Column("bat", Integer))
class OpTest(TestBase):
-
def test_rename_table(self):
context = op_fixture()
- op.rename_table('t1', 't2')
+ op.rename_table("t1", "t2")
context.assert_("ALTER TABLE t1 RENAME TO t2")
def test_rename_table_schema(self):
context = op_fixture()
- op.rename_table('t1', 't2', schema="foo")
+ op.rename_table("t1", "t2", schema="foo")
context.assert_("ALTER TABLE foo.t1 RENAME TO foo.t2")
def test_create_index_no_expr_allowed(self):
@@ -37,15 +35,21 @@ class OpTest(TestBase):
assert_raises_message(
ValueError,
r"String or text\(\) construct expected",
- op.create_index, 'name', 'tname', [func.foo(column('x'))]
+ op.create_index,
+ "name",
+ "tname",
+ [func.foo(column("x"))],
)
def test_add_column_schema_hard_quoting(self):
from sqlalchemy.sql.schema import quoted_name
+
context = op_fixture("postgresql")
op.add_column(
- "somename", Column("colname", String),
- schema=quoted_name("some.schema", quote=True))
+ "somename",
+ Column("colname", String),
+ schema=quoted_name("some.schema", quote=True),
+ )
context.assert_(
'ALTER TABLE "some.schema".somename ADD COLUMN colname VARCHAR'
@@ -53,68 +57,67 @@ class OpTest(TestBase):
def test_rename_table_schema_hard_quoting(self):
from sqlalchemy.sql.schema import quoted_name
+
context = op_fixture("postgresql")
op.rename_table(
- 't1', 't2',
- schema=quoted_name("some.schema", quote=True))
-
- context.assert_(
- 'ALTER TABLE "some.schema".t1 RENAME TO t2'
+ "t1", "t2", schema=quoted_name("some.schema", quote=True)
)
+ context.assert_('ALTER TABLE "some.schema".t1 RENAME TO t2')
+
def test_add_constraint_schema_hard_quoting(self):
from sqlalchemy.sql.schema import quoted_name
+
context = op_fixture("postgresql")
op.create_check_constraint(
"ck_user_name_len",
"user_table",
- func.len(column('name')) > 5,
- schema=quoted_name("some.schema", quote=True)
+ func.len(column("name")) > 5,
+ schema=quoted_name("some.schema", quote=True),
)
context.assert_(
'ALTER TABLE "some.schema".user_table ADD '
- 'CONSTRAINT ck_user_name_len CHECK (len(name) > 5)'
+ "CONSTRAINT ck_user_name_len CHECK (len(name) > 5)"
)
def test_create_index_quoting(self):
context = op_fixture("postgresql")
- op.create_index(
- 'geocoded',
- 'locations',
- ["IShouldBeQuoted"])
+ op.create_index("geocoded", "locations", ["IShouldBeQuoted"])
context.assert_(
- 'CREATE INDEX geocoded ON locations ("IShouldBeQuoted")')
+ 'CREATE INDEX geocoded ON locations ("IShouldBeQuoted")'
+ )
def test_create_index_expressions(self):
context = op_fixture()
- op.create_index(
- 'geocoded',
- 'locations',
- [text('lower(coordinates)')])
+ op.create_index("geocoded", "locations", [text("lower(coordinates)")])
context.assert_(
- "CREATE INDEX geocoded ON locations (lower(coordinates))")
+ "CREATE INDEX geocoded ON locations (lower(coordinates))"
+ )
def test_add_column(self):
context = op_fixture()
- op.add_column('t1', Column('c1', Integer, nullable=False))
+ op.add_column("t1", Column("c1", Integer, nullable=False))
context.assert_("ALTER TABLE t1 ADD COLUMN c1 INTEGER NOT NULL")
def test_add_column_schema(self):
context = op_fixture()
- op.add_column('t1', Column('c1', Integer, nullable=False), schema="foo")
+ op.add_column(
+ "t1", Column("c1", Integer, nullable=False), schema="foo"
+ )
context.assert_("ALTER TABLE foo.t1 ADD COLUMN c1 INTEGER NOT NULL")
def test_add_column_with_default(self):
context = op_fixture()
op.add_column(
- 't1', Column('c1', Integer, nullable=False, server_default="12"))
+ "t1", Column("c1", Integer, nullable=False, server_default="12")
+ )
context.assert_(
- "ALTER TABLE t1 ADD COLUMN c1 INTEGER DEFAULT '12' NOT NULL")
+ "ALTER TABLE t1 ADD COLUMN c1 INTEGER DEFAULT '12' NOT NULL"
+ )
def test_add_column_with_index(self):
context = op_fixture()
- op.add_column(
- 't1', Column('c1', Integer, nullable=False, index=True))
+ op.add_column("t1", Column("c1", Integer, nullable=False, index=True))
context.assert_(
"ALTER TABLE t1 ADD COLUMN c1 INTEGER NOT NULL",
"CREATE INDEX ix_t1_c1 ON t1 (c1)",
@@ -122,107 +125,117 @@ class OpTest(TestBase):
def test_add_column_schema_with_default(self):
context = op_fixture()
- op.add_column('t1',
- Column('c1', Integer, nullable=False, server_default="12"),
- schema='foo')
+ op.add_column(
+ "t1",
+ Column("c1", Integer, nullable=False, server_default="12"),
+ schema="foo",
+ )
context.assert_(
- "ALTER TABLE foo.t1 ADD COLUMN c1 INTEGER DEFAULT '12' NOT NULL")
+ "ALTER TABLE foo.t1 ADD COLUMN c1 INTEGER DEFAULT '12' NOT NULL"
+ )
def test_add_column_fk(self):
context = op_fixture()
op.add_column(
- 't1', Column('c1', Integer, ForeignKey('c2.id'), nullable=False))
+ "t1", Column("c1", Integer, ForeignKey("c2.id"), nullable=False)
+ )
context.assert_(
"ALTER TABLE t1 ADD COLUMN c1 INTEGER NOT NULL",
- "ALTER TABLE t1 ADD FOREIGN KEY(c1) REFERENCES c2 (id)"
+ "ALTER TABLE t1 ADD FOREIGN KEY(c1) REFERENCES c2 (id)",
)
def test_add_column_schema_fk(self):
context = op_fixture()
- op.add_column('t1',
- Column('c1', Integer, ForeignKey('c2.id'), nullable=False),
- schema='foo')
+ op.add_column(
+ "t1",
+ Column("c1", Integer, ForeignKey("c2.id"), nullable=False),
+ schema="foo",
+ )
context.assert_(
"ALTER TABLE foo.t1 ADD COLUMN c1 INTEGER NOT NULL",
- "ALTER TABLE foo.t1 ADD FOREIGN KEY(c1) REFERENCES c2 (id)"
+ "ALTER TABLE foo.t1 ADD FOREIGN KEY(c1) REFERENCES c2 (id)",
)
def test_add_column_schema_type(self):
"""Test that a schema type generates its constraints...."""
context = op_fixture()
- op.add_column('t1', Column('c1', Boolean, nullable=False))
+ op.add_column("t1", Column("c1", Boolean, nullable=False))
context.assert_(
- 'ALTER TABLE t1 ADD COLUMN c1 BOOLEAN NOT NULL',
- 'ALTER TABLE t1 ADD CHECK (c1 IN (0, 1))'
+ "ALTER TABLE t1 ADD COLUMN c1 BOOLEAN NOT NULL",
+ "ALTER TABLE t1 ADD CHECK (c1 IN (0, 1))",
)
def test_add_column_schema_schema_type(self):
"""Test that a schema type generates its constraints...."""
context = op_fixture()
- op.add_column('t1', Column('c1', Boolean, nullable=False), schema='foo')
+ op.add_column(
+ "t1", Column("c1", Boolean, nullable=False), schema="foo"
+ )
context.assert_(
- 'ALTER TABLE foo.t1 ADD COLUMN c1 BOOLEAN NOT NULL',
- 'ALTER TABLE foo.t1 ADD CHECK (c1 IN (0, 1))'
+ "ALTER TABLE foo.t1 ADD COLUMN c1 BOOLEAN NOT NULL",
+ "ALTER TABLE foo.t1 ADD CHECK (c1 IN (0, 1))",
)
def test_add_column_schema_type_checks_rule(self):
"""Test that a schema type doesn't generate a
constraint based on check rule."""
- context = op_fixture('postgresql')
- op.add_column('t1', Column('c1', Boolean, nullable=False))
- context.assert_(
- 'ALTER TABLE t1 ADD COLUMN c1 BOOLEAN NOT NULL',
- )
+ context = op_fixture("postgresql")
+ op.add_column("t1", Column("c1", Boolean, nullable=False))
+ context.assert_("ALTER TABLE t1 ADD COLUMN c1 BOOLEAN NOT NULL")
def test_add_column_fk_self_referential(self):
context = op_fixture()
op.add_column(
- 't1', Column('c1', Integer, ForeignKey('t1.c2'), nullable=False))
+ "t1", Column("c1", Integer, ForeignKey("t1.c2"), nullable=False)
+ )
context.assert_(
"ALTER TABLE t1 ADD COLUMN c1 INTEGER NOT NULL",
- "ALTER TABLE t1 ADD FOREIGN KEY(c1) REFERENCES t1 (c2)"
+ "ALTER TABLE t1 ADD FOREIGN KEY(c1) REFERENCES t1 (c2)",
)
def test_add_column_schema_fk_self_referential(self):
context = op_fixture()
op.add_column(
- 't1',
- Column('c1', Integer, ForeignKey('foo.t1.c2'), nullable=False),
- schema='foo')
+ "t1",
+ Column("c1", Integer, ForeignKey("foo.t1.c2"), nullable=False),
+ schema="foo",
+ )
context.assert_(
"ALTER TABLE foo.t1 ADD COLUMN c1 INTEGER NOT NULL",
- "ALTER TABLE foo.t1 ADD FOREIGN KEY(c1) REFERENCES foo.t1 (c2)"
+ "ALTER TABLE foo.t1 ADD FOREIGN KEY(c1) REFERENCES foo.t1 (c2)",
)
def test_add_column_fk_schema(self):
context = op_fixture()
op.add_column(
- 't1',
- Column('c1', Integer, ForeignKey('remote.t2.c2'), nullable=False))
+ "t1",
+ Column("c1", Integer, ForeignKey("remote.t2.c2"), nullable=False),
+ )
context.assert_(
- 'ALTER TABLE t1 ADD COLUMN c1 INTEGER NOT NULL',
- 'ALTER TABLE t1 ADD FOREIGN KEY(c1) REFERENCES remote.t2 (c2)'
+ "ALTER TABLE t1 ADD COLUMN c1 INTEGER NOT NULL",
+ "ALTER TABLE t1 ADD FOREIGN KEY(c1) REFERENCES remote.t2 (c2)",
)
def test_add_column_schema_fk_schema(self):
context = op_fixture()
op.add_column(
- 't1',
- Column('c1', Integer, ForeignKey('remote.t2.c2'), nullable=False),
- schema='foo')
+ "t1",
+ Column("c1", Integer, ForeignKey("remote.t2.c2"), nullable=False),
+ schema="foo",
+ )
context.assert_(
- 'ALTER TABLE foo.t1 ADD COLUMN c1 INTEGER NOT NULL',
- 'ALTER TABLE foo.t1 ADD FOREIGN KEY(c1) REFERENCES remote.t2 (c2)'
+ "ALTER TABLE foo.t1 ADD COLUMN c1 INTEGER NOT NULL",
+ "ALTER TABLE foo.t1 ADD FOREIGN KEY(c1) REFERENCES remote.t2 (c2)",
)
def test_drop_column(self):
context = op_fixture()
- op.drop_column('t1', 'c1')
+ op.drop_column("t1", "c1")
context.assert_("ALTER TABLE t1 DROP COLUMN c1")
def test_drop_column_schema(self):
context = op_fixture()
- op.drop_column('t1', 'c1', schema='foo')
+ op.drop_column("t1", "c1", schema="foo")
context.assert_("ALTER TABLE foo.t1 DROP COLUMN c1")
def test_alter_column_nullable(self):
@@ -236,7 +249,7 @@ class OpTest(TestBase):
def test_alter_column_schema_nullable(self):
context = op_fixture()
- op.alter_column("t", "c", nullable=True, schema='foo')
+ op.alter_column("t", "c", nullable=True, schema="foo")
context.assert_(
# TODO: not sure if this is PG only or standard
# SQL
@@ -254,7 +267,7 @@ class OpTest(TestBase):
def test_alter_column_schema_not_nullable(self):
context = op_fixture()
- op.alter_column("t", "c", nullable=False, schema='foo')
+ op.alter_column("t", "c", nullable=False, schema="foo")
context.assert_(
# TODO: not sure if this is PG only or standard
# SQL
@@ -264,58 +277,50 @@ class OpTest(TestBase):
def test_alter_column_rename(self):
context = op_fixture()
op.alter_column("t", "c", new_column_name="x")
- context.assert_(
- "ALTER TABLE t RENAME c TO x"
- )
+ context.assert_("ALTER TABLE t RENAME c TO x")
def test_alter_column_schema_rename(self):
context = op_fixture()
- op.alter_column("t", "c", new_column_name="x", schema='foo')
- context.assert_(
- "ALTER TABLE foo.t RENAME c TO x"
- )
+ op.alter_column("t", "c", new_column_name="x", schema="foo")
+ context.assert_("ALTER TABLE foo.t RENAME c TO x")
def test_alter_column_type(self):
context = op_fixture()
op.alter_column("t", "c", type_=String(50))
- context.assert_(
- 'ALTER TABLE t ALTER COLUMN c TYPE VARCHAR(50)'
- )
+ context.assert_("ALTER TABLE t ALTER COLUMN c TYPE VARCHAR(50)")
def test_alter_column_schema_type(self):
context = op_fixture()
- op.alter_column("t", "c", type_=String(50), schema='foo')
- context.assert_(
- 'ALTER TABLE foo.t ALTER COLUMN c TYPE VARCHAR(50)'
- )
+ op.alter_column("t", "c", type_=String(50), schema="foo")
+ context.assert_("ALTER TABLE foo.t ALTER COLUMN c TYPE VARCHAR(50)")
def test_alter_column_set_default(self):
context = op_fixture()
op.alter_column("t", "c", server_default="q")
- context.assert_(
- "ALTER TABLE t ALTER COLUMN c SET DEFAULT 'q'"
- )
+ context.assert_("ALTER TABLE t ALTER COLUMN c SET DEFAULT 'q'")
def test_alter_column_schema_set_default(self):
context = op_fixture()
- op.alter_column("t", "c", server_default="q", schema='foo')
- context.assert_(
- "ALTER TABLE foo.t ALTER COLUMN c SET DEFAULT 'q'"
- )
+ op.alter_column("t", "c", server_default="q", schema="foo")
+ context.assert_("ALTER TABLE foo.t ALTER COLUMN c SET DEFAULT 'q'")
def test_alter_column_set_compiled_default(self):
context = op_fixture()
- op.alter_column("t", "c",
- server_default=func.utc_thing(func.current_timestamp()))
+ op.alter_column(
+ "t", "c", server_default=func.utc_thing(func.current_timestamp())
+ )
context.assert_(
"ALTER TABLE t ALTER COLUMN c SET DEFAULT utc_thing(CURRENT_TIMESTAMP)"
)
def test_alter_column_schema_set_compiled_default(self):
context = op_fixture()
- op.alter_column("t", "c",
- server_default=func.utc_thing(func.current_timestamp()),
- schema='foo')
+ op.alter_column(
+ "t",
+ "c",
+ server_default=func.utc_thing(func.current_timestamp()),
+ schema="foo",
+ )
context.assert_(
"ALTER TABLE foo.t ALTER COLUMN c "
"SET DEFAULT utc_thing(CURRENT_TIMESTAMP)"
@@ -324,101 +329,98 @@ class OpTest(TestBase):
def test_alter_column_drop_default(self):
context = op_fixture()
op.alter_column("t", "c", server_default=None)
- context.assert_(
- 'ALTER TABLE t ALTER COLUMN c DROP DEFAULT'
- )
+ context.assert_("ALTER TABLE t ALTER COLUMN c DROP DEFAULT")
def test_alter_column_schema_drop_default(self):
context = op_fixture()
- op.alter_column("t", "c", server_default=None, schema='foo')
- context.assert_(
- 'ALTER TABLE foo.t ALTER COLUMN c DROP DEFAULT'
- )
+ op.alter_column("t", "c", server_default=None, schema="foo")
+ context.assert_("ALTER TABLE foo.t ALTER COLUMN c DROP DEFAULT")
def test_alter_column_schema_type_unnamed(self):
- context = op_fixture('mssql', native_boolean=False)
+ context = op_fixture("mssql", native_boolean=False)
op.alter_column("t", "c", type_=Boolean())
context.assert_(
- 'ALTER TABLE t ALTER COLUMN c BIT',
- 'ALTER TABLE t ADD CHECK (c IN (0, 1))'
+ "ALTER TABLE t ALTER COLUMN c BIT",
+ "ALTER TABLE t ADD CHECK (c IN (0, 1))",
)
def test_alter_column_schema_schema_type_unnamed(self):
- context = op_fixture('mssql', native_boolean=False)
- op.alter_column("t", "c", type_=Boolean(), schema='foo')
+ context = op_fixture("mssql", native_boolean=False)
+ op.alter_column("t", "c", type_=Boolean(), schema="foo")
context.assert_(
- 'ALTER TABLE foo.t ALTER COLUMN c BIT',
- 'ALTER TABLE foo.t ADD CHECK (c IN (0, 1))'
+ "ALTER TABLE foo.t ALTER COLUMN c BIT",
+ "ALTER TABLE foo.t ADD CHECK (c IN (0, 1))",
)
def test_alter_column_schema_type_named(self):
- context = op_fixture('mssql', native_boolean=False)
+ context = op_fixture("mssql", native_boolean=False)
op.alter_column("t", "c", type_=Boolean(name="xyz"))
context.assert_(
- 'ALTER TABLE t ALTER COLUMN c BIT',
- 'ALTER TABLE t ADD CONSTRAINT xyz CHECK (c IN (0, 1))'
+ "ALTER TABLE t ALTER COLUMN c BIT",
+ "ALTER TABLE t ADD CONSTRAINT xyz CHECK (c IN (0, 1))",
)
def test_alter_column_schema_schema_type_named(self):
- context = op_fixture('mssql', native_boolean=False)
- op.alter_column("t", "c", type_=Boolean(name="xyz"), schema='foo')
+ context = op_fixture("mssql", native_boolean=False)
+ op.alter_column("t", "c", type_=Boolean(name="xyz"), schema="foo")
context.assert_(
- 'ALTER TABLE foo.t ALTER COLUMN c BIT',
- 'ALTER TABLE foo.t ADD CONSTRAINT xyz CHECK (c IN (0, 1))'
+ "ALTER TABLE foo.t ALTER COLUMN c BIT",
+ "ALTER TABLE foo.t ADD CONSTRAINT xyz CHECK (c IN (0, 1))",
)
def test_alter_column_schema_type_existing_type(self):
- context = op_fixture('mssql', native_boolean=False)
+ context = op_fixture("mssql", native_boolean=False)
op.alter_column(
- "t", "c", type_=String(10), existing_type=Boolean(name="xyz"))
+ "t", "c", type_=String(10), existing_type=Boolean(name="xyz")
+ )
context.assert_(
- 'ALTER TABLE t DROP CONSTRAINT xyz',
- 'ALTER TABLE t ALTER COLUMN c VARCHAR(10)'
+ "ALTER TABLE t DROP CONSTRAINT xyz",
+ "ALTER TABLE t ALTER COLUMN c VARCHAR(10)",
)
def test_alter_column_schema_schema_type_existing_type(self):
- context = op_fixture('mssql', native_boolean=False)
- op.alter_column("t", "c", type_=String(10),
- existing_type=Boolean(name="xyz"), schema='foo')
+ context = op_fixture("mssql", native_boolean=False)
+ op.alter_column(
+ "t",
+ "c",
+ type_=String(10),
+ existing_type=Boolean(name="xyz"),
+ schema="foo",
+ )
context.assert_(
- 'ALTER TABLE foo.t DROP CONSTRAINT xyz',
- 'ALTER TABLE foo.t ALTER COLUMN c VARCHAR(10)'
+ "ALTER TABLE foo.t DROP CONSTRAINT xyz",
+ "ALTER TABLE foo.t ALTER COLUMN c VARCHAR(10)",
)
def test_alter_column_schema_type_existing_type_no_const(self):
- context = op_fixture('postgresql')
+ context = op_fixture("postgresql")
op.alter_column("t", "c", type_=String(10), existing_type=Boolean())
- context.assert_(
- 'ALTER TABLE t ALTER COLUMN c TYPE VARCHAR(10)'
- )
+ context.assert_("ALTER TABLE t ALTER COLUMN c TYPE VARCHAR(10)")
def test_alter_column_schema_schema_type_existing_type_no_const(self):
- context = op_fixture('postgresql')
- op.alter_column("t", "c", type_=String(10), existing_type=Boolean(),
- schema='foo')
- context.assert_(
- 'ALTER TABLE foo.t ALTER COLUMN c TYPE VARCHAR(10)'
+ context = op_fixture("postgresql")
+ op.alter_column(
+ "t", "c", type_=String(10), existing_type=Boolean(), schema="foo"
)
+ context.assert_("ALTER TABLE foo.t ALTER COLUMN c TYPE VARCHAR(10)")
def test_alter_column_schema_type_existing_type_no_new_type(self):
- context = op_fixture('postgresql')
+ context = op_fixture("postgresql")
op.alter_column("t", "c", nullable=False, existing_type=Boolean())
- context.assert_(
- 'ALTER TABLE t ALTER COLUMN c SET NOT NULL'
- )
+ context.assert_("ALTER TABLE t ALTER COLUMN c SET NOT NULL")
def test_alter_column_schema_schema_type_existing_type_no_new_type(self):
- context = op_fixture('postgresql')
- op.alter_column("t", "c", nullable=False, existing_type=Boolean(),
- schema='foo')
- context.assert_(
- 'ALTER TABLE foo.t ALTER COLUMN c SET NOT NULL'
+ context = op_fixture("postgresql")
+ op.alter_column(
+ "t", "c", nullable=False, existing_type=Boolean(), schema="foo"
)
+ context.assert_("ALTER TABLE foo.t ALTER COLUMN c SET NOT NULL")
def test_add_foreign_key(self):
context = op_fixture()
- op.create_foreign_key('fk_test', 't1', 't2',
- ['foo', 'bar'], ['bat', 'hoho'])
+ op.create_foreign_key(
+ "fk_test", "t1", "t2", ["foo", "bar"], ["bat", "hoho"]
+ )
context.assert_(
"ALTER TABLE t1 ADD CONSTRAINT fk_test FOREIGN KEY(foo, bar) "
"REFERENCES t2 (bat, hoho)"
@@ -426,9 +428,15 @@ class OpTest(TestBase):
def test_add_foreign_key_schema(self):
context = op_fixture()
- op.create_foreign_key('fk_test', 't1', 't2',
- ['foo', 'bar'], ['bat', 'hoho'],
- source_schema='foo2', referent_schema='bar2')
+ op.create_foreign_key(
+ "fk_test",
+ "t1",
+ "t2",
+ ["foo", "bar"],
+ ["bat", "hoho"],
+ source_schema="foo2",
+ referent_schema="bar2",
+ )
context.assert_(
"ALTER TABLE foo2.t1 ADD CONSTRAINT fk_test FOREIGN KEY(foo, bar) "
"REFERENCES bar2.t2 (bat, hoho)"
@@ -436,9 +444,15 @@ class OpTest(TestBase):
def test_add_foreign_key_schema_same_tablename(self):
context = op_fixture()
- op.create_foreign_key('fk_test', 't1', 't1',
- ['foo', 'bar'], ['bat', 'hoho'],
- source_schema='foo2', referent_schema='bar2')
+ op.create_foreign_key(
+ "fk_test",
+ "t1",
+ "t1",
+ ["foo", "bar"],
+ ["bat", "hoho"],
+ source_schema="foo2",
+ referent_schema="bar2",
+ )
context.assert_(
"ALTER TABLE foo2.t1 ADD CONSTRAINT fk_test FOREIGN KEY(foo, bar) "
"REFERENCES bar2.t1 (bat, hoho)"
@@ -446,9 +460,14 @@ class OpTest(TestBase):
def test_add_foreign_key_onupdate(self):
context = op_fixture()
- op.create_foreign_key('fk_test', 't1', 't2',
- ['foo', 'bar'], ['bat', 'hoho'],
- onupdate='CASCADE')
+ op.create_foreign_key(
+ "fk_test",
+ "t1",
+ "t2",
+ ["foo", "bar"],
+ ["bat", "hoho"],
+ onupdate="CASCADE",
+ )
context.assert_(
"ALTER TABLE t1 ADD CONSTRAINT fk_test FOREIGN KEY(foo, bar) "
"REFERENCES t2 (bat, hoho) ON UPDATE CASCADE"
@@ -456,9 +475,14 @@ class OpTest(TestBase):
def test_add_foreign_key_ondelete(self):
context = op_fixture()
- op.create_foreign_key('fk_test', 't1', 't2',
- ['foo', 'bar'], ['bat', 'hoho'],
- ondelete='CASCADE')
+ op.create_foreign_key(
+ "fk_test",
+ "t1",
+ "t2",
+ ["foo", "bar"],
+ ["bat", "hoho"],
+ ondelete="CASCADE",
+ )
context.assert_(
"ALTER TABLE t1 ADD CONSTRAINT fk_test FOREIGN KEY(foo, bar) "
"REFERENCES t2 (bat, hoho) ON DELETE CASCADE"
@@ -466,9 +490,14 @@ class OpTest(TestBase):
def test_add_foreign_key_deferrable(self):
context = op_fixture()
- op.create_foreign_key('fk_test', 't1', 't2',
- ['foo', 'bar'], ['bat', 'hoho'],
- deferrable=True)
+ op.create_foreign_key(
+ "fk_test",
+ "t1",
+ "t2",
+ ["foo", "bar"],
+ ["bat", "hoho"],
+ deferrable=True,
+ )
context.assert_(
"ALTER TABLE t1 ADD CONSTRAINT fk_test FOREIGN KEY(foo, bar) "
"REFERENCES t2 (bat, hoho) DEFERRABLE"
@@ -476,9 +505,14 @@ class OpTest(TestBase):
def test_add_foreign_key_initially(self):
context = op_fixture()
- op.create_foreign_key('fk_test', 't1', 't2',
- ['foo', 'bar'], ['bat', 'hoho'],
- initially='INITIAL')
+ op.create_foreign_key(
+ "fk_test",
+ "t1",
+ "t2",
+ ["foo", "bar"],
+ ["bat", "hoho"],
+ initially="INITIAL",
+ )
context.assert_(
"ALTER TABLE t1 ADD CONSTRAINT fk_test FOREIGN KEY(foo, bar) "
"REFERENCES t2 (bat, hoho) INITIALLY INITIAL"
@@ -487,9 +521,14 @@ class OpTest(TestBase):
@config.requirements.foreign_key_match
def test_add_foreign_key_match(self):
context = op_fixture()
- op.create_foreign_key('fk_test', 't1', 't2',
- ['foo', 'bar'], ['bat', 'hoho'],
- match='SIMPLE')
+ op.create_foreign_key(
+ "fk_test",
+ "t1",
+ "t2",
+ ["foo", "bar"],
+ ["bat", "hoho"],
+ match="SIMPLE",
+ )
context.assert_(
"ALTER TABLE t1 ADD CONSTRAINT fk_test FOREIGN KEY(foo, bar) "
"REFERENCES t2 (bat, hoho) MATCH SIMPLE"
@@ -497,24 +536,44 @@ class OpTest(TestBase):
def test_add_foreign_key_dialect_kw(self):
op_fixture()
- with mock.patch(
- "sqlalchemy.schema.ForeignKeyConstraint"
- ) as fkc:
- op.create_foreign_key('fk_test', 't1', 't2',
- ['foo', 'bar'], ['bat', 'hoho'],
- foobar_arg='xyz')
+ with mock.patch("sqlalchemy.schema.ForeignKeyConstraint") as fkc:
+ op.create_foreign_key(
+ "fk_test",
+ "t1",
+ "t2",
+ ["foo", "bar"],
+ ["bat", "hoho"],
+ foobar_arg="xyz",
+ )
if config.requirements.foreign_key_match.enabled:
- eq_(fkc.mock_calls[0],
- mock.call(['foo', 'bar'], ['t2.bat', 't2.hoho'],
- onupdate=None, ondelete=None, name='fk_test',
- foobar_arg='xyz',
- deferrable=None, initially=None, match=None))
+ eq_(
+ fkc.mock_calls[0],
+ mock.call(
+ ["foo", "bar"],
+ ["t2.bat", "t2.hoho"],
+ onupdate=None,
+ ondelete=None,
+ name="fk_test",
+ foobar_arg="xyz",
+ deferrable=None,
+ initially=None,
+ match=None,
+ ),
+ )
else:
- eq_(fkc.mock_calls[0],
- mock.call(['foo', 'bar'], ['t2.bat', 't2.hoho'],
- onupdate=None, ondelete=None, name='fk_test',
- foobar_arg='xyz',
- deferrable=None, initially=None))
+ eq_(
+ fkc.mock_calls[0],
+ mock.call(
+ ["foo", "bar"],
+ ["t2.bat", "t2.hoho"],
+ onupdate=None,
+ ondelete=None,
+ name="fk_test",
+ foobar_arg="xyz",
+ deferrable=None,
+ initially=None,
+ ),
+ )
def test_add_foreign_key_self_referential(self):
context = op_fixture()
@@ -541,9 +600,7 @@ class OpTest(TestBase):
def test_add_check_constraint(self):
context = op_fixture()
op.create_check_constraint(
- "ck_user_name_len",
- "user_table",
- func.len(column('name')) > 5
+ "ck_user_name_len", "user_table", func.len(column("name")) > 5
)
context.assert_(
"ALTER TABLE user_table ADD CONSTRAINT ck_user_name_len "
@@ -555,8 +612,8 @@ class OpTest(TestBase):
op.create_check_constraint(
"ck_user_name_len",
"user_table",
- func.len(column('name')) > 5,
- schema='foo'
+ func.len(column("name")) > 5,
+ schema="foo",
)
context.assert_(
"ALTER TABLE foo.user_table ADD CONSTRAINT ck_user_name_len "
@@ -565,7 +622,7 @@ class OpTest(TestBase):
def test_add_unique_constraint(self):
context = op_fixture()
- op.create_unique_constraint('uk_test', 't1', ['foo', 'bar'])
+ op.create_unique_constraint("uk_test", "t1", ["foo", "bar"])
context.assert_(
"ALTER TABLE t1 ADD CONSTRAINT uk_test UNIQUE (foo, bar)"
)
@@ -574,12 +631,12 @@ class OpTest(TestBase):
context = op_fixture()
op.create_foreign_key(
- name='some_fk',
- source='some_table',
- referent='referred_table',
- local_cols=['a', 'b'],
- remote_cols=['c', 'd'],
- ondelete='CASCADE'
+ name="some_fk",
+ source="some_table",
+ referent="referred_table",
+ local_cols=["a", "b"],
+ remote_cols=["c", "d"],
+ ondelete="CASCADE",
)
context.assert_(
"ALTER TABLE some_table ADD CONSTRAINT some_fk "
@@ -590,27 +647,26 @@ class OpTest(TestBase):
def test_add_unique_constraint_legacy_kwarg(self):
context = op_fixture()
op.create_unique_constraint(
- name='uk_test',
- source='t1',
- local_cols=['foo', 'bar'])
+ name="uk_test", source="t1", local_cols=["foo", "bar"]
+ )
context.assert_(
"ALTER TABLE t1 ADD CONSTRAINT uk_test UNIQUE (foo, bar)"
)
def test_drop_constraint_legacy_kwarg(self):
context = op_fixture()
- op.drop_constraint(name='pk_name',
- table_name='sometable',
- type_='primary')
- context.assert_(
- "ALTER TABLE sometable DROP CONSTRAINT pk_name"
+ op.drop_constraint(
+ name="pk_name", table_name="sometable", type_="primary"
)
+ context.assert_("ALTER TABLE sometable DROP CONSTRAINT pk_name")
def test_create_pk_legacy_kwarg(self):
context = op_fixture()
- op.create_primary_key(name=None,
- table_name='sometable',
- cols=['router_id', 'l3_agent_id'])
+ op.create_primary_key(
+ name=None,
+ table_name="sometable",
+ cols=["router_id", "l3_agent_id"],
+ )
context.assert_(
"ALTER TABLE sometable ADD PRIMARY KEY (router_id, l3_agent_id)"
)
@@ -623,57 +679,50 @@ class OpTest(TestBase):
"missing required positional argument: columns",
op.create_primary_key,
name=None,
- table_name='sometable',
- wrong_cols=['router_id', 'l3_agent_id']
+ table_name="sometable",
+ wrong_cols=["router_id", "l3_agent_id"],
)
def test_add_unique_constraint_schema(self):
context = op_fixture()
op.create_unique_constraint(
- 'uk_test', 't1', ['foo', 'bar'], schema='foo')
+ "uk_test", "t1", ["foo", "bar"], schema="foo"
+ )
context.assert_(
"ALTER TABLE foo.t1 ADD CONSTRAINT uk_test UNIQUE (foo, bar)"
)
def test_drop_constraint(self):
context = op_fixture()
- op.drop_constraint('foo_bar_bat', 't1')
- context.assert_(
- "ALTER TABLE t1 DROP CONSTRAINT foo_bar_bat"
- )
+ op.drop_constraint("foo_bar_bat", "t1")
+ context.assert_("ALTER TABLE t1 DROP CONSTRAINT foo_bar_bat")
def test_drop_constraint_schema(self):
context = op_fixture()
- op.drop_constraint('foo_bar_bat', 't1', schema='foo')
- context.assert_(
- "ALTER TABLE foo.t1 DROP CONSTRAINT foo_bar_bat"
- )
+ op.drop_constraint("foo_bar_bat", "t1", schema="foo")
+ context.assert_("ALTER TABLE foo.t1 DROP CONSTRAINT foo_bar_bat")
def test_create_index(self):
context = op_fixture()
- op.create_index('ik_test', 't1', ['foo', 'bar'])
- context.assert_(
- "CREATE INDEX ik_test ON t1 (foo, bar)"
- )
+ op.create_index("ik_test", "t1", ["foo", "bar"])
+ context.assert_("CREATE INDEX ik_test ON t1 (foo, bar)")
def test_create_unique_index(self):
context = op_fixture()
- op.create_index('ik_test', 't1', ['foo', 'bar'], unique=True)
- context.assert_(
- "CREATE UNIQUE INDEX ik_test ON t1 (foo, bar)"
- )
+ op.create_index("ik_test", "t1", ["foo", "bar"], unique=True)
+ context.assert_("CREATE UNIQUE INDEX ik_test ON t1 (foo, bar)")
def test_create_index_quote_flag(self):
context = op_fixture()
- op.create_index('ik_test', 't1', ['foo', 'bar'], quote=True)
- context.assert_(
- 'CREATE INDEX "ik_test" ON t1 (foo, bar)'
- )
+ op.create_index("ik_test", "t1", ["foo", "bar"], quote=True)
+ context.assert_('CREATE INDEX "ik_test" ON t1 (foo, bar)')
def test_create_index_table_col_event(self):
context = op_fixture()
- op.create_index('ik_test', 'tbl_with_auto_appended_column', ['foo', 'bar'])
+ op.create_index(
+ "ik_test", "tbl_with_auto_appended_column", ["foo", "bar"]
+ )
context.assert_(
"CREATE INDEX ik_test ON tbl_with_auto_appended_column (foo, bar)"
)
@@ -681,8 +730,8 @@ class OpTest(TestBase):
def test_add_unique_constraint_col_event(self):
context = op_fixture()
op.create_unique_constraint(
- 'ik_test',
- 'tbl_with_auto_appended_column', ['foo', 'bar'])
+ "ik_test", "tbl_with_auto_appended_column", ["foo", "bar"]
+ )
context.assert_(
"ALTER TABLE tbl_with_auto_appended_column "
"ADD CONSTRAINT ik_test UNIQUE (foo, bar)"
@@ -690,45 +739,35 @@ class OpTest(TestBase):
def test_create_index_schema(self):
context = op_fixture()
- op.create_index('ik_test', 't1', ['foo', 'bar'], schema='foo')
- context.assert_(
- "CREATE INDEX ik_test ON foo.t1 (foo, bar)"
- )
+ op.create_index("ik_test", "t1", ["foo", "bar"], schema="foo")
+ context.assert_("CREATE INDEX ik_test ON foo.t1 (foo, bar)")
def test_drop_index(self):
context = op_fixture()
- op.drop_index('ik_test')
- context.assert_(
- "DROP INDEX ik_test"
- )
+ op.drop_index("ik_test")
+ context.assert_("DROP INDEX ik_test")
def test_drop_index_schema(self):
context = op_fixture()
- op.drop_index('ik_test', schema='foo')
- context.assert_(
- "DROP INDEX foo.ik_test"
- )
+ op.drop_index("ik_test", schema="foo")
+ context.assert_("DROP INDEX foo.ik_test")
def test_drop_table(self):
context = op_fixture()
- op.drop_table('tb_test')
- context.assert_(
- "DROP TABLE tb_test"
- )
+ op.drop_table("tb_test")
+ context.assert_("DROP TABLE tb_test")
def test_drop_table_schema(self):
context = op_fixture()
- op.drop_table('tb_test', schema='foo')
- context.assert_(
- "DROP TABLE foo.tb_test"
- )
+ op.drop_table("tb_test", schema="foo")
+ context.assert_("DROP TABLE foo.tb_test")
def test_create_table_selfref(self):
context = op_fixture()
op.create_table(
"some_table",
- Column('id', Integer, primary_key=True),
- Column('st_id', Integer, ForeignKey('some_table.id'))
+ Column("id", Integer, primary_key=True),
+ Column("st_id", Integer, ForeignKey("some_table.id")),
)
context.assert_(
"CREATE TABLE some_table ("
@@ -742,9 +781,9 @@ class OpTest(TestBase):
context = op_fixture()
t1 = op.create_table(
"some_table",
- Column('id', Integer, primary_key=True),
- Column('foo_id', Integer, ForeignKey('foo.id')),
- schema='schema'
+ Column("id", Integer, primary_key=True),
+ Column("foo_id", Integer, ForeignKey("foo.id")),
+ schema="schema",
)
context.assert_(
"CREATE TABLE schema.some_table ("
@@ -760,9 +799,9 @@ class OpTest(TestBase):
context = op_fixture()
t1 = op.create_table(
"some_table",
- Column('x', Integer),
- Column('y', Integer),
- Column('z', Integer),
+ Column("x", Integer),
+ Column("y", Integer),
+ Column("z", Integer),
)
context.assert_(
"CREATE TABLE some_table (x INTEGER, y INTEGER, z INTEGER)"
@@ -773,9 +812,9 @@ class OpTest(TestBase):
context = op_fixture()
op.create_table(
"some_table",
- Column('id', Integer, primary_key=True),
- Column('foo_id', Integer, ForeignKey('foo.id')),
- Column('foo_bar', Integer, ForeignKey('foo.bar')),
+ Column("id", Integer, primary_key=True),
+ Column("foo_id", Integer, ForeignKey("foo.id")),
+ Column("foo_bar", Integer, ForeignKey("foo.bar")),
)
context.assert_(
"CREATE TABLE some_table ("
@@ -792,27 +831,26 @@ class OpTest(TestBase):
from sqlalchemy.sql import table, column
from sqlalchemy import String, Integer
- account = table('account',
- column('name', String),
- column('id', Integer)
- )
+ account = table(
+ "account", column("name", String), column("id", Integer)
+ )
op.execute(
- account.update().
- where(account.c.name == op.inline_literal('account 1')).
- values({'name': op.inline_literal('account 2')})
+ account.update()
+ .where(account.c.name == op.inline_literal("account 1"))
+ .values({"name": op.inline_literal("account 2")})
)
op.execute(
- account.update().
- where(account.c.id == op.inline_literal(1)).
- values({'id': op.inline_literal(2)})
+ account.update()
+ .where(account.c.id == op.inline_literal(1))
+ .values({"id": op.inline_literal(2)})
)
context.assert_(
"UPDATE account SET name='account 2' WHERE account.name = 'account 1'",
- "UPDATE account SET id=2 WHERE account.id = 1"
+ "UPDATE account SET id=2 WHERE account.id = 1",
)
def test_cant_op(self):
- if hasattr(op, '_proxy'):
+ if hasattr(op, "_proxy"):
del op._proxy
assert_raises_message(
NameError,
@@ -820,7 +858,8 @@ class OpTest(TestBase):
"proxy object has not yet been established "
"for the Alembic 'Operations' class. "
"Try placing this code inside a callable.",
- op.inline_literal, "asdf"
+ op.inline_literal,
+ "asdf",
)
def test_naming_changes(self):
@@ -832,17 +871,17 @@ class OpTest(TestBase):
op.alter_column("t", "c", new_column_name="x")
context.assert_("ALTER TABLE t RENAME c TO x")
- context = op_fixture('mysql')
+ context = op_fixture("mysql")
op.drop_constraint("f1", "t1", type="foreignkey")
context.assert_("ALTER TABLE t1 DROP FOREIGN KEY f1")
- context = op_fixture('mysql')
+ context = op_fixture("mysql")
op.drop_constraint("f1", "t1", type_="foreignkey")
context.assert_("ALTER TABLE t1 DROP FOREIGN KEY f1")
def test_naming_changes_drop_idx(self):
- context = op_fixture('mssql')
- op.drop_index('ik_test', tablename='t1')
+ context = op_fixture("mssql")
+ op.drop_index("ik_test", tablename="t1")
context.assert_("DROP INDEX ik_test ON t1")
@@ -852,20 +891,19 @@ class SQLModeOpTest(TestBase):
from sqlalchemy.sql import table, column
from sqlalchemy import String, Integer
- account = table('account',
- column('name', String),
- column('id', Integer)
- )
+ account = table(
+ "account", column("name", String), column("id", Integer)
+ )
op.execute(
- account.update().
- where(account.c.name == op.inline_literal('account 1')).
- values({'name': op.inline_literal('account 2')})
+ account.update()
+ .where(account.c.name == op.inline_literal("account 1"))
+ .values({"name": op.inline_literal("account 2")})
)
- op.execute(text("update table set foo=:bar").bindparams(bar='bat'))
+ op.execute(text("update table set foo=:bar").bindparams(bar="bat"))
context.assert_(
"UPDATE account SET name='account 2' "
"WHERE account.name = 'account 1'",
- "update table set foo='bat'"
+ "update table set foo='bat'",
)
def test_create_table_literal_binds(self):
@@ -873,8 +911,8 @@ class SQLModeOpTest(TestBase):
op.create_table(
"some_table",
- Column('id', Integer, primary_key=True),
- Column('st_id', Integer, ForeignKey('some_table.id'))
+ Column("id", Integer, primary_key=True),
+ Column("st_id", Integer, ForeignKey("some_table.id")),
)
context.assert_(
@@ -907,7 +945,7 @@ class CustomOpTest(TestBase):
operations.execute("CREATE SEQUENCE %s" % operation.sequence_name)
context = op_fixture()
- op.create_sequence('foob')
+ op.create_sequence("foob")
context.assert_("CREATE SEQUENCE foob")
@@ -923,48 +961,36 @@ class EnsureOrigObjectFromToTest(TestBase):
def test_drop_index(self):
schema_obj = schemaobj.SchemaObjects()
- idx = schema_obj.index('x', 'y', ['z'])
+ idx = schema_obj.index("x", "y", ["z"])
op = ops.DropIndexOp.from_index(idx)
- is_(
- op.to_index(), idx
- )
+ is_(op.to_index(), idx)
def test_create_index(self):
schema_obj = schemaobj.SchemaObjects()
- idx = schema_obj.index('x', 'y', ['z'])
+ idx = schema_obj.index("x", "y", ["z"])
op = ops.CreateIndexOp.from_index(idx)
- is_(
- op.to_index(), idx
- )
+ is_(op.to_index(), idx)
def test_drop_table(self):
schema_obj = schemaobj.SchemaObjects()
- table = schema_obj.table('x', Column('q', Integer))
+ table = schema_obj.table("x", Column("q", Integer))
op = ops.DropTableOp.from_table(table)
- is_(
- op.to_table(), table
- )
+ is_(op.to_table(), table)
def test_create_table(self):
schema_obj = schemaobj.SchemaObjects()
- table = schema_obj.table('x', Column('q', Integer))
+ table = schema_obj.table("x", Column("q", Integer))
op = ops.CreateTableOp.from_table(table)
- is_(
- op.to_table(), table
- )
+ is_(op.to_table(), table)
def test_drop_unique_constraint(self):
schema_obj = schemaobj.SchemaObjects()
- const = schema_obj.unique_constraint('x', 'foobar', ['a'])
+ const = schema_obj.unique_constraint("x", "foobar", ["a"])
op = ops.DropConstraintOp.from_constraint(const)
- is_(
- op.to_constraint(), const
- )
+ is_(op.to_constraint(), const)
def test_drop_constraint_not_available(self):
- op = ops.DropConstraintOp('x', 'y', type_='unique')
+ op = ops.DropConstraintOp("x", "y", type_="unique")
assert_raises_message(
- ValueError,
- "constraint cannot be produced",
- op.to_constraint
+ ValueError, "constraint cannot be produced", op.to_constraint
)
diff --git a/tests/test_op_naming_convention.py b/tests/test_op_naming_convention.py
index fd70faa..fbcd181 100644
--- a/tests/test_op_naming_convention.py
+++ b/tests/test_op_naming_convention.py
@@ -1,5 +1,11 @@
-from sqlalchemy import Integer, Column, \
- Table, Boolean, MetaData, CheckConstraint
+from sqlalchemy import (
+ Integer,
+ Column,
+ Table,
+ Boolean,
+ MetaData,
+ CheckConstraint,
+)
from sqlalchemy.sql import column, func
from alembic import op
@@ -9,16 +15,14 @@ from alembic.testing.fixtures import TestBase
class AutoNamingConventionTest(TestBase):
- __requires__ = ('sqlalchemy_094', )
+ __requires__ = ("sqlalchemy_094",)
def test_add_check_constraint(self):
- context = op_fixture(naming_convention={
- "ck": "ck_%(table_name)s_%(constraint_name)s"
- })
+ context = op_fixture(
+ naming_convention={"ck": "ck_%(table_name)s_%(constraint_name)s"}
+ )
op.create_check_constraint(
- "foo",
- "user_table",
- func.len(column('name')) > 5
+ "foo", "user_table", func.len(column("name")) > 5
)
context.assert_(
"ALTER TABLE user_table ADD CONSTRAINT ck_user_table_foo "
@@ -26,13 +30,9 @@ class AutoNamingConventionTest(TestBase):
)
def test_add_check_constraint_name_is_none(self):
- context = op_fixture(naming_convention={
- "ck": "ck_%(table_name)s_foo"
- })
+ context = op_fixture(naming_convention={"ck": "ck_%(table_name)s_foo"})
op.create_check_constraint(
- None,
- "user_table",
- func.len(column('name')) > 5
+ None, "user_table", func.len(column("name")) > 5
)
context.assert_(
"ALTER TABLE user_table ADD CONSTRAINT ck_user_table_foo "
@@ -40,44 +40,29 @@ class AutoNamingConventionTest(TestBase):
)
def test_add_unique_constraint_name_is_none(self):
- context = op_fixture(naming_convention={
- "uq": "uq_%(table_name)s_foo"
- })
- op.create_unique_constraint(
- None,
- "user_table",
- 'x'
- )
+ context = op_fixture(naming_convention={"uq": "uq_%(table_name)s_foo"})
+ op.create_unique_constraint(None, "user_table", "x")
context.assert_(
"ALTER TABLE user_table ADD CONSTRAINT uq_user_table_foo UNIQUE (x)"
)
def test_add_index_name_is_none(self):
- context = op_fixture(naming_convention={
- "ix": "ix_%(table_name)s_foo"
- })
- op.create_index(
- None,
- "user_table",
- 'x'
- )
- context.assert_(
- "CREATE INDEX ix_user_table_foo ON user_table (x)"
- )
+ context = op_fixture(naming_convention={"ix": "ix_%(table_name)s_foo"})
+ op.create_index(None, "user_table", "x")
+ context.assert_("CREATE INDEX ix_user_table_foo ON user_table (x)")
def test_add_check_constraint_already_named_from_schema(self):
m1 = MetaData(
- naming_convention={"ck": "ck_%(table_name)s_%(constraint_name)s"})
+ naming_convention={"ck": "ck_%(table_name)s_%(constraint_name)s"}
+ )
ck = CheckConstraint("im a constraint", name="cc1")
- Table('t', m1, Column('x'), ck)
+ Table("t", m1, Column("x"), ck)
context = op_fixture(
- naming_convention={"ck": "ck_%(table_name)s_%(constraint_name)s"})
-
- op.create_table(
- "some_table",
- Column('x', Integer, ck),
+ naming_convention={"ck": "ck_%(table_name)s_%(constraint_name)s"}
)
+
+ op.create_table("some_table", Column("x", Integer, ck))
context.assert_(
"CREATE TABLE some_table "
"(x INTEGER CONSTRAINT ck_t_cc1 CHECK (im a constraint))"
@@ -85,11 +70,12 @@ class AutoNamingConventionTest(TestBase):
def test_add_check_constraint_inline_on_table(self):
context = op_fixture(
- naming_convention={"ck": "ck_%(table_name)s_%(constraint_name)s"})
+ naming_convention={"ck": "ck_%(table_name)s_%(constraint_name)s"}
+ )
op.create_table(
"some_table",
- Column('x', Integer),
- CheckConstraint("im a constraint", name="cc1")
+ Column("x", Integer),
+ CheckConstraint("im a constraint", name="cc1"),
)
context.assert_(
"CREATE TABLE some_table "
@@ -98,11 +84,12 @@ class AutoNamingConventionTest(TestBase):
def test_add_check_constraint_inline_on_table_w_f(self):
context = op_fixture(
- naming_convention={"ck": "ck_%(table_name)s_%(constraint_name)s"})
+ naming_convention={"ck": "ck_%(table_name)s_%(constraint_name)s"}
+ )
op.create_table(
"some_table",
- Column('x', Integer),
- CheckConstraint("im a constraint", name=op.f("ck_some_table_cc1"))
+ Column("x", Integer),
+ CheckConstraint("im a constraint", name=op.f("ck_some_table_cc1")),
)
context.assert_(
"CREATE TABLE some_table "
@@ -111,10 +98,13 @@ class AutoNamingConventionTest(TestBase):
def test_add_check_constraint_inline_on_column(self):
context = op_fixture(
- naming_convention={"ck": "ck_%(table_name)s_%(constraint_name)s"})
+ naming_convention={"ck": "ck_%(table_name)s_%(constraint_name)s"}
+ )
op.create_table(
"some_table",
- Column('x', Integer, CheckConstraint("im a constraint", name="cc1"))
+ Column(
+ "x", Integer, CheckConstraint("im a constraint", name="cc1")
+ ),
)
context.assert_(
"CREATE TABLE some_table "
@@ -123,12 +113,15 @@ class AutoNamingConventionTest(TestBase):
def test_add_check_constraint_inline_on_column_w_f(self):
context = op_fixture(
- naming_convention={"ck": "ck_%(table_name)s_%(constraint_name)s"})
+ naming_convention={"ck": "ck_%(table_name)s_%(constraint_name)s"}
+ )
op.create_table(
"some_table",
Column(
- 'x', Integer,
- CheckConstraint("im a constraint", name=op.f("ck_q_cc1")))
+ "x",
+ Integer,
+ CheckConstraint("im a constraint", name=op.f("ck_q_cc1")),
+ ),
)
context.assert_(
"CREATE TABLE some_table "
@@ -136,22 +129,23 @@ class AutoNamingConventionTest(TestBase):
)
def test_add_column_schema_type(self):
- context = op_fixture(naming_convention={
- "ck": "ck_%(table_name)s_%(constraint_name)s"
- })
- op.add_column('t1', Column('c1', Boolean(name='foo'), nullable=False))
+ context = op_fixture(
+ naming_convention={"ck": "ck_%(table_name)s_%(constraint_name)s"}
+ )
+ op.add_column("t1", Column("c1", Boolean(name="foo"), nullable=False))
context.assert_(
- 'ALTER TABLE t1 ADD COLUMN c1 BOOLEAN NOT NULL',
- 'ALTER TABLE t1 ADD CONSTRAINT ck_t1_foo CHECK (c1 IN (0, 1))'
+ "ALTER TABLE t1 ADD COLUMN c1 BOOLEAN NOT NULL",
+ "ALTER TABLE t1 ADD CONSTRAINT ck_t1_foo CHECK (c1 IN (0, 1))",
)
def test_add_column_schema_type_w_f(self):
- context = op_fixture(naming_convention={
- "ck": "ck_%(table_name)s_%(constraint_name)s"
- })
+ context = op_fixture(
+ naming_convention={"ck": "ck_%(table_name)s_%(constraint_name)s"}
+ )
op.add_column(
- 't1', Column('c1', Boolean(name=op.f('foo')), nullable=False))
+ "t1", Column("c1", Boolean(name=op.f("foo")), nullable=False)
+ )
context.assert_(
- 'ALTER TABLE t1 ADD COLUMN c1 BOOLEAN NOT NULL',
- 'ALTER TABLE t1 ADD CONSTRAINT foo CHECK (c1 IN (0, 1))'
+ "ALTER TABLE t1 ADD COLUMN c1 BOOLEAN NOT NULL",
+ "ALTER TABLE t1 ADD CONSTRAINT foo CHECK (c1 IN (0, 1))",
)
diff --git a/tests/test_oracle.py b/tests/test_oracle.py
index 8b9c9e5..86e0ece 100644
--- a/tests/test_oracle.py
+++ b/tests/test_oracle.py
@@ -1,23 +1,24 @@
-
from sqlalchemy import Integer, Column
from alembic import op, command
from alembic.testing.fixtures import TestBase
from alembic.testing.fixtures import op_fixture, capture_context_buffer
-from alembic.testing.env import _no_sql_testing_config, staging_env, \
- three_rev_fixture, clear_staging_env
+from alembic.testing.env import (
+ _no_sql_testing_config,
+ staging_env,
+ three_rev_fixture,
+ clear_staging_env,
+)
class FullEnvironmentTests(TestBase):
-
@classmethod
def setup_class(cls):
staging_env()
cls.cfg = cfg = _no_sql_testing_config("oracle")
- cls.a, cls.b, cls.c = \
- three_rev_fixture(cfg)
+ cls.a, cls.b, cls.c = three_rev_fixture(cfg)
@classmethod
def teardown_class(cls):
@@ -42,113 +43,99 @@ class FullEnvironmentTests(TestBase):
class OpTest(TestBase):
-
def test_add_column(self):
- context = op_fixture('oracle')
- op.add_column('t1', Column('c1', Integer, nullable=False))
+ context = op_fixture("oracle")
+ op.add_column("t1", Column("c1", Integer, nullable=False))
context.assert_("ALTER TABLE t1 ADD c1 INTEGER NOT NULL")
def test_add_column_with_default(self):
context = op_fixture("oracle")
op.add_column(
- 't1', Column('c1', Integer, nullable=False, server_default="12"))
+ "t1", Column("c1", Integer, nullable=False, server_default="12")
+ )
context.assert_("ALTER TABLE t1 ADD c1 INTEGER DEFAULT '12' NOT NULL")
def test_alter_column_rename_oracle(self):
- context = op_fixture('oracle')
+ context = op_fixture("oracle")
op.alter_column("t", "c", name="x")
- context.assert_(
- "ALTER TABLE t RENAME COLUMN c TO x"
- )
+ context.assert_("ALTER TABLE t RENAME COLUMN c TO x")
def test_alter_column_new_type(self):
- context = op_fixture('oracle')
+ context = op_fixture("oracle")
op.alter_column("t", "c", type_=Integer)
- context.assert_(
- 'ALTER TABLE t MODIFY c INTEGER'
- )
+ context.assert_("ALTER TABLE t MODIFY c INTEGER")
def test_drop_index(self):
- context = op_fixture('oracle')
- op.drop_index('my_idx', 'my_table')
+ context = op_fixture("oracle")
+ op.drop_index("my_idx", "my_table")
context.assert_contains("DROP INDEX my_idx")
def test_drop_column_w_default(self):
- context = op_fixture('oracle')
- op.drop_column('t1', 'c1')
- context.assert_(
- "ALTER TABLE t1 DROP COLUMN c1"
- )
+ context = op_fixture("oracle")
+ op.drop_column("t1", "c1")
+ context.assert_("ALTER TABLE t1 DROP COLUMN c1")
def test_drop_column_w_check(self):
- context = op_fixture('oracle')
- op.drop_column('t1', 'c1')
- context.assert_(
- "ALTER TABLE t1 DROP COLUMN c1"
- )
+ context = op_fixture("oracle")
+ op.drop_column("t1", "c1")
+ context.assert_("ALTER TABLE t1 DROP COLUMN c1")
def test_alter_column_nullable_w_existing_type(self):
- context = op_fixture('oracle')
+ context = op_fixture("oracle")
op.alter_column("t", "c", nullable=True, existing_type=Integer)
- context.assert_(
- "ALTER TABLE t MODIFY c NULL"
- )
+ context.assert_("ALTER TABLE t MODIFY c NULL")
def test_alter_column_not_nullable_w_existing_type(self):
- context = op_fixture('oracle')
+ context = op_fixture("oracle")
op.alter_column("t", "c", nullable=False, existing_type=Integer)
- context.assert_(
- "ALTER TABLE t MODIFY c NOT NULL"
- )
+ context.assert_("ALTER TABLE t MODIFY c NOT NULL")
def test_alter_column_nullable_w_new_type(self):
- context = op_fixture('oracle')
+ context = op_fixture("oracle")
op.alter_column("t", "c", nullable=True, type_=Integer)
context.assert_(
- "ALTER TABLE t MODIFY c NULL",
- 'ALTER TABLE t MODIFY c INTEGER'
+ "ALTER TABLE t MODIFY c NULL", "ALTER TABLE t MODIFY c INTEGER"
)
def test_alter_column_not_nullable_w_new_type(self):
- context = op_fixture('oracle')
+ context = op_fixture("oracle")
op.alter_column("t", "c", nullable=False, type_=Integer)
context.assert_(
- "ALTER TABLE t MODIFY c NOT NULL",
- "ALTER TABLE t MODIFY c INTEGER"
+ "ALTER TABLE t MODIFY c NOT NULL", "ALTER TABLE t MODIFY c INTEGER"
)
def test_alter_add_server_default(self):
- context = op_fixture('oracle')
+ context = op_fixture("oracle")
op.alter_column("t", "c", server_default="5")
- context.assert_(
- "ALTER TABLE t MODIFY c DEFAULT '5'"
- )
+ context.assert_("ALTER TABLE t MODIFY c DEFAULT '5'")
def test_alter_replace_server_default(self):
- context = op_fixture('oracle')
+ context = op_fixture("oracle")
op.alter_column(
- "t", "c", server_default="5", existing_server_default="6")
- context.assert_(
- "ALTER TABLE t MODIFY c DEFAULT '5'"
+ "t", "c", server_default="5", existing_server_default="6"
)
+ context.assert_("ALTER TABLE t MODIFY c DEFAULT '5'")
def test_alter_remove_server_default(self):
- context = op_fixture('oracle')
+ context = op_fixture("oracle")
op.alter_column("t", "c", server_default=None)
- context.assert_(
- "ALTER TABLE t MODIFY c DEFAULT NULL"
- )
+ context.assert_("ALTER TABLE t MODIFY c DEFAULT NULL")
def test_alter_do_everything(self):
- context = op_fixture('oracle')
+ context = op_fixture("oracle")
op.alter_column(
- "t", "c", name="c2", nullable=True,
- type_=Integer, server_default="5")
+ "t",
+ "c",
+ name="c2",
+ nullable=True,
+ type_=Integer,
+ server_default="5",
+ )
context.assert_(
- 'ALTER TABLE t MODIFY c NULL',
+ "ALTER TABLE t MODIFY c NULL",
"ALTER TABLE t MODIFY c DEFAULT '5'",
- 'ALTER TABLE t MODIFY c INTEGER',
- 'ALTER TABLE t RENAME COLUMN c TO c2'
+ "ALTER TABLE t MODIFY c INTEGER",
+ "ALTER TABLE t RENAME COLUMN c TO c2",
)
# TODO: when we add schema support
diff --git a/tests/test_postgresql.py b/tests/test_postgresql.py
index 23ec49c..61ba2d1 100644
--- a/tests/test_postgresql.py
+++ b/tests/test_postgresql.py
@@ -1,13 +1,28 @@
-
-from sqlalchemy import DateTime, MetaData, Table, Column, text, Integer, \
- String, Interval, Sequence, Numeric, BigInteger, Float, Numeric
+from sqlalchemy import (
+ DateTime,
+ MetaData,
+ Table,
+ Column,
+ text,
+ Integer,
+ String,
+ Interval,
+ Sequence,
+ Numeric,
+ BigInteger,
+ Float,
+ Numeric,
+)
from sqlalchemy.dialects.postgresql import ARRAY, UUID, BYTEA
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy import types
from alembic.operations import Operations
from sqlalchemy.sql import table, column
-from alembic.autogenerate.compare import \
- _compare_server_default, _compare_tables, _render_server_default_for_compare
+from alembic.autogenerate.compare import (
+ _compare_server_default,
+ _compare_tables,
+ _render_server_default_for_compare,
+)
from alembic.operations import ops
from alembic import command, util
@@ -16,8 +31,12 @@ from alembic.script import ScriptDirectory
from alembic.autogenerate import api
from alembic.testing import eq_, provide_metadata
-from alembic.testing.env import staging_env, clear_staging_env, \
- _no_sql_testing_config, write_script
+from alembic.testing.env import (
+ staging_env,
+ clear_staging_env,
+ _no_sql_testing_config,
+ write_script,
+)
from alembic.testing.fixtures import capture_context_buffer
from alembic.testing.fixtures import TestBase
from alembic.testing.fixtures import op_fixture
@@ -36,79 +55,79 @@ if util.sqla_09:
class PostgresqlOpTest(TestBase):
-
def test_rename_table_postgresql(self):
context = op_fixture("postgresql")
- op.rename_table('t1', 't2')
+ op.rename_table("t1", "t2")
context.assert_("ALTER TABLE t1 RENAME TO t2")
def test_rename_table_schema_postgresql(self):
context = op_fixture("postgresql")
- op.rename_table('t1', 't2', schema="foo")
+ op.rename_table("t1", "t2", schema="foo")
context.assert_("ALTER TABLE foo.t1 RENAME TO t2")
def test_create_index_postgresql_expressions(self):
context = op_fixture("postgresql")
op.create_index(
- 'geocoded',
- 'locations',
- [text('lower(coordinates)')],
- postgresql_where=text("locations.coordinates != Null"))
+ "geocoded",
+ "locations",
+ [text("lower(coordinates)")],
+ postgresql_where=text("locations.coordinates != Null"),
+ )
context.assert_(
"CREATE INDEX geocoded ON locations (lower(coordinates)) "
- "WHERE locations.coordinates != Null")
+ "WHERE locations.coordinates != Null"
+ )
def test_create_index_postgresql_where(self):
context = op_fixture("postgresql")
op.create_index(
- 'geocoded',
- 'locations',
- ['coordinates'],
- postgresql_where=text("locations.coordinates != Null"))
+ "geocoded",
+ "locations",
+ ["coordinates"],
+ postgresql_where=text("locations.coordinates != Null"),
+ )
context.assert_(
"CREATE INDEX geocoded ON locations (coordinates) "
- "WHERE locations.coordinates != Null")
+ "WHERE locations.coordinates != Null"
+ )
@config.requirements.fail_before_sqla_099
def test_create_index_postgresql_concurrently(self):
context = op_fixture("postgresql")
op.create_index(
- 'geocoded',
- 'locations',
- ['coordinates'],
- postgresql_concurrently=True)
+ "geocoded",
+ "locations",
+ ["coordinates"],
+ postgresql_concurrently=True,
+ )
context.assert_(
- "CREATE INDEX CONCURRENTLY geocoded ON locations (coordinates)")
+ "CREATE INDEX CONCURRENTLY geocoded ON locations (coordinates)"
+ )
@config.requirements.fail_before_sqla_110
def test_drop_index_postgresql_concurrently(self):
context = op_fixture("postgresql")
- op.drop_index(
- 'geocoded',
- 'locations',
- postgresql_concurrently=True)
- context.assert_(
- "DROP INDEX CONCURRENTLY geocoded")
+ op.drop_index("geocoded", "locations", postgresql_concurrently=True)
+ context.assert_("DROP INDEX CONCURRENTLY geocoded")
def test_alter_column_type_using(self):
- context = op_fixture('postgresql')
- op.alter_column("t", "c", type_=Integer, postgresql_using='c::integer')
+ context = op_fixture("postgresql")
+ op.alter_column("t", "c", type_=Integer, postgresql_using="c::integer")
context.assert_(
- 'ALTER TABLE t ALTER COLUMN c TYPE INTEGER USING c::integer'
+ "ALTER TABLE t ALTER COLUMN c TYPE INTEGER USING c::integer"
)
def test_col_w_pk_is_serial(self):
context = op_fixture("postgresql")
- op.add_column("some_table", Column('q', Integer, primary_key=True))
- context.assert_(
- 'ALTER TABLE some_table ADD COLUMN q SERIAL NOT NULL'
- )
+ op.add_column("some_table", Column("q", Integer, primary_key=True))
+ context.assert_("ALTER TABLE some_table ADD COLUMN q SERIAL NOT NULL")
@config.requirements.fail_before_sqla_100
def test_create_exclude_constraint(self):
context = op_fixture("postgresql")
op.create_exclude_constraint(
- "ex1", "t1", ('x', '>'), where='x > 5', using="gist")
+ "ex1", "t1", ("x", ">"), where="x > 5", using="gist"
+ )
context.assert_(
"ALTER TABLE t1 ADD CONSTRAINT ex1 EXCLUDE USING gist (x WITH >) "
"WHERE (x > 5)"
@@ -118,8 +137,12 @@ class PostgresqlOpTest(TestBase):
def test_create_exclude_constraint_quoted_literal(self):
context = op_fixture("postgresql")
op.create_exclude_constraint(
- "ex1", "SomeTable", ('"SomeColumn"', '>'),
- where='"SomeColumn" > 5', using="gist")
+ "ex1",
+ "SomeTable",
+ ('"SomeColumn"', ">"),
+ where='"SomeColumn" > 5',
+ using="gist",
+ )
context.assert_(
'ALTER TABLE "SomeTable" ADD CONSTRAINT ex1 EXCLUDE USING gist '
'("SomeColumn" WITH >) WHERE ("SomeColumn" > 5)'
@@ -129,8 +152,12 @@ class PostgresqlOpTest(TestBase):
def test_create_exclude_constraint_quoted_column(self):
context = op_fixture("postgresql")
op.create_exclude_constraint(
- "ex1", "SomeTable", (column("SomeColumn"), '>'),
- where=column("SomeColumn") > 5, using="gist")
+ "ex1",
+ "SomeTable",
+ (column("SomeColumn"), ">"),
+ where=column("SomeColumn") > 5,
+ using="gist",
+ )
context.assert_(
'ALTER TABLE "SomeTable" ADD CONSTRAINT ex1 EXCLUDE '
'USING gist ("SomeColumn" WITH >) WHERE ("SomeColumn" > 5)'
@@ -138,7 +165,6 @@ class PostgresqlOpTest(TestBase):
class PGOfflineEnumTest(TestBase):
-
def setUp(self):
staging_env()
self.cfg = cfg = _no_sql_testing_config()
@@ -152,7 +178,10 @@ class PGOfflineEnumTest(TestBase):
clear_staging_env()
def _inline_enum_script(self):
- write_script(self.script, self.rid, """
+ write_script(
+ self.script,
+ self.rid,
+ """
revision = '%s'
down_revision = None
@@ -169,10 +198,15 @@ def upgrade():
def downgrade():
op.drop_table("sometable")
-""" % self.rid)
+"""
+ % self.rid,
+ )
def _distinct_enum_script(self):
- write_script(self.script, self.rid, """
+ write_script(
+ self.script,
+ self.rid,
+ """
revision = '%s'
down_revision = None
@@ -193,14 +227,18 @@ def downgrade():
op.drop_table("sometable")
ENUM(name="pgenum").drop(op.get_bind(), checkfirst=False)
-""" % self.rid)
+"""
+ % self.rid,
+ )
def test_offline_inline_enum_create(self):
self._inline_enum_script()
with capture_context_buffer() as buf:
command.upgrade(self.cfg, self.rid, sql=True)
- assert "CREATE TYPE pgenum AS "\
+ assert (
+ "CREATE TYPE pgenum AS "
"ENUM ('one', 'two', 'three')" in buf.getvalue()
+ )
assert "CREATE TABLE sometable (\n data pgenum\n)" in buf.getvalue()
def test_offline_inline_enum_drop(self):
@@ -215,8 +253,10 @@ def downgrade():
self._distinct_enum_script()
with capture_context_buffer() as buf:
command.upgrade(self.cfg, self.rid, sql=True)
- assert "CREATE TYPE pgenum AS ENUM "\
+ assert (
+ "CREATE TYPE pgenum AS ENUM "
"('one', 'two', 'three')" in buf.getvalue()
+ )
assert "CREATE TABLE sometable (\n data pgenum\n)" in buf.getvalue()
def test_offline_distinct_enum_drop(self):
@@ -228,23 +268,27 @@ def downgrade():
class PostgresqlInlineLiteralTest(TestBase):
- __only_on__ = 'postgresql'
+ __only_on__ = "postgresql"
__backend__ = True
@classmethod
def setup_class(cls):
cls.bind = config.db
- cls.bind.execute("""
+ cls.bind.execute(
+ """
create table tab (
col varchar(50)
)
- """)
- cls.bind.execute("""
+ """
+ )
+ cls.bind.execute(
+ """
insert into tab (col) values
('old data 1'),
('old data 2.1'),
('old data 3')
- """)
+ """
+ )
@classmethod
def teardown_class(cls):
@@ -260,35 +304,32 @@ class PostgresqlInlineLiteralTest(TestBase):
def test_inline_percent(self):
# TODO: here's the issue, you need to escape this.
- tab = table('tab', column('col'))
+ tab = table("tab", column("col"))
self.op.execute(
- tab.update().where(
- tab.c.col.like(self.op.inline_literal('%.%'))
- ).values(col=self.op.inline_literal('new data')),
- execution_options={'no_parameters': True}
+ tab.update()
+ .where(tab.c.col.like(self.op.inline_literal("%.%")))
+ .values(col=self.op.inline_literal("new data")),
+ execution_options={"no_parameters": True},
)
eq_(
self.conn.execute(
- "select count(*) from tab where col='new data'").scalar(),
+ "select count(*) from tab where col='new data'"
+ ).scalar(),
1,
)
class PostgresqlDefaultCompareTest(TestBase):
- __only_on__ = 'postgresql'
+ __only_on__ = "postgresql"
__backend__ = True
-
@classmethod
def setup_class(cls):
cls.bind = config.db
staging_env()
cls.migration_context = MigrationContext.configure(
connection=cls.bind.connect(),
- opts={
- 'compare_type': True,
- 'compare_server_default': True
- }
+ opts={"compare_type": True, "compare_server_default": True},
)
def setUp(self):
@@ -303,216 +344,166 @@ class PostgresqlDefaultCompareTest(TestBase):
self.metadata.drop_all()
def _compare_default_roundtrip(
- self, type_, orig_default, alternate=None, diff_expected=None):
- diff_expected = diff_expected \
- if diff_expected is not None \
+ self, type_, orig_default, alternate=None, diff_expected=None
+ ):
+ diff_expected = (
+ diff_expected
+ if diff_expected is not None
else alternate is not None
+ )
if alternate is None:
alternate = orig_default
- t1 = Table("test", self.metadata,
- Column("somecol", type_, server_default=orig_default))
- t2 = Table("test", MetaData(),
- Column("somecol", type_, server_default=alternate))
+ t1 = Table(
+ "test",
+ self.metadata,
+ Column("somecol", type_, server_default=orig_default),
+ )
+ t2 = Table(
+ "test",
+ MetaData(),
+ Column("somecol", type_, server_default=alternate),
+ )
t1.create(self.bind)
insp = Inspector.from_engine(self.bind)
cols = insp.get_columns(t1.name)
- insp_col = Column("somecol", cols[0]['type'],
- server_default=text(cols[0]['default']))
+ insp_col = Column(
+ "somecol", cols[0]["type"], server_default=text(cols[0]["default"])
+ )
op = ops.AlterColumnOp("test", "somecol")
_compare_server_default(
- self.autogen_context, op,
- None, "test", "somecol", insp_col, t2.c.somecol)
+ self.autogen_context,
+ op,
+ None,
+ "test",
+ "somecol",
+ insp_col,
+ t2.c.somecol,
+ )
diffs = op.to_diff_tuple()
eq_(bool(diffs), diff_expected)
- def _compare_default(
- self,
- t1, t2, col,
- rendered
- ):
+ def _compare_default(self, t1, t2, col, rendered):
t1.create(self.bind, checkfirst=True)
insp = Inspector.from_engine(self.bind)
cols = insp.get_columns(t1.name)
ctx = self.autogen_context.migration_context
return ctx.impl.compare_server_default(
- None,
- col,
- rendered,
- cols[0]['default'])
+ None, col, rendered, cols[0]["default"]
+ )
def test_compare_interval_str(self):
# this form shouldn't be used but testing here
# for compatibility
- self._compare_default_roundtrip(
- Interval,
- "14 days"
- )
+ self._compare_default_roundtrip(Interval, "14 days")
@config.requirements.postgresql_uuid_ossp
def test_compare_uuid_text(self):
- self._compare_default_roundtrip(
- UUID,
- text("uuid_generate_v4()")
- )
+ self._compare_default_roundtrip(UUID, text("uuid_generate_v4()"))
def test_compare_interval_text(self):
- self._compare_default_roundtrip(
- Interval,
- text("'14 days'")
- )
+ self._compare_default_roundtrip(Interval, text("'14 days'"))
def test_compare_array_of_integer_text(self):
self._compare_default_roundtrip(
- ARRAY(Integer),
- text("(ARRAY[]::integer[])")
+ ARRAY(Integer), text("(ARRAY[]::integer[])")
)
def test_compare_current_timestamp_text(self):
self._compare_default_roundtrip(
- DateTime(),
- text("TIMEZONE('utc', CURRENT_TIMESTAMP)"),
+ DateTime(), text("TIMEZONE('utc', CURRENT_TIMESTAMP)")
)
def test_compare_integer_str(self):
- self._compare_default_roundtrip(
- Integer(),
- "5",
- )
+ self._compare_default_roundtrip(Integer(), "5")
def test_compare_integer_text(self):
- self._compare_default_roundtrip(
- Integer(),
- text("5"),
- )
+ self._compare_default_roundtrip(Integer(), text("5"))
def test_compare_integer_text_diff(self):
- self._compare_default_roundtrip(
- Integer(),
- text("5"), "7"
- )
+ self._compare_default_roundtrip(Integer(), text("5"), "7")
def test_compare_float_str(self):
- self._compare_default_roundtrip(
- Float(),
- "5.2",
- )
+ self._compare_default_roundtrip(Float(), "5.2")
def test_compare_float_text(self):
- self._compare_default_roundtrip(
- Float(),
- text("5.2"),
- )
+ self._compare_default_roundtrip(Float(), text("5.2"))
def test_compare_float_no_diff1(self):
self._compare_default_roundtrip(
- Float(),
- text("5.2"), "5.2",
- diff_expected=False
+ Float(), text("5.2"), "5.2", diff_expected=False
)
def test_compare_float_no_diff2(self):
self._compare_default_roundtrip(
- Float(),
- "5.2", text("5.2"),
- diff_expected=False
+ Float(), "5.2", text("5.2"), diff_expected=False
)
def test_compare_float_no_diff3(self):
self._compare_default_roundtrip(
- Float(),
- text("5"), text("5.0"),
- diff_expected=False
+ Float(), text("5"), text("5.0"), diff_expected=False
)
def test_compare_float_no_diff4(self):
self._compare_default_roundtrip(
- Float(),
- "5", "5.0",
- diff_expected=False
+ Float(), "5", "5.0", diff_expected=False
)
def test_compare_float_no_diff5(self):
self._compare_default_roundtrip(
- Float(),
- text("5"), "5.0",
- diff_expected=False
+ Float(), text("5"), "5.0", diff_expected=False
)
def test_compare_float_no_diff6(self):
self._compare_default_roundtrip(
- Float(),
- "5", text("5.0"),
- diff_expected=False
+ Float(), "5", text("5.0"), diff_expected=False
)
def test_compare_numeric_no_diff(self):
self._compare_default_roundtrip(
- Numeric(),
- text("5"), "5.0",
- diff_expected=False
+ Numeric(), text("5"), "5.0", diff_expected=False
)
def test_compare_unicode_literal(self):
- self._compare_default_roundtrip(
- String(),
- u'im a default'
- )
+ self._compare_default_roundtrip(String(), u"im a default")
# TOOD: will need to actually eval() the repr() and
# spend more effort figuring out exactly the kind of expression
# to use
def _TODO_test_compare_character_str_w_singlequote(self):
- self._compare_default_roundtrip(
- String(),
- "hel''lo",
- )
+ self._compare_default_roundtrip(String(), "hel''lo")
def test_compare_character_str(self):
- self._compare_default_roundtrip(
- String(),
- "hello",
- )
+ self._compare_default_roundtrip(String(), "hello")
def test_compare_character_text(self):
- self._compare_default_roundtrip(
- String(),
- text("'hello'"),
- )
+ self._compare_default_roundtrip(String(), text("'hello'"))
def test_compare_character_str_diff(self):
- self._compare_default_roundtrip(
- String(),
- "hello",
- "there"
- )
+ self._compare_default_roundtrip(String(), "hello", "there")
def test_compare_character_text_diff(self):
self._compare_default_roundtrip(
- String(),
- text("'hello'"),
- text("'there'")
+ String(), text("'hello'"), text("'there'")
)
def test_primary_key_skip(self):
"""Test that SERIAL cols are just skipped"""
- t1 = Table("sometable", self.metadata,
- Column("id", Integer, primary_key=True)
- )
- t2 = Table("sometable", MetaData(),
- Column("id", Integer, primary_key=True)
- )
- assert not self._compare_default(
- t1, t2, t2.c.id, ""
+ t1 = Table(
+ "sometable", self.metadata, Column("id", Integer, primary_key=True)
+ )
+ t2 = Table(
+ "sometable", MetaData(), Column("id", Integer, primary_key=True)
)
+ assert not self._compare_default(t1, t2, t2.c.id, "")
class PostgresqlDetectSerialTest(TestBase):
- __only_on__ = 'postgresql'
+ __only_on__ = "postgresql"
__backend__ = True
@classmethod
@@ -522,10 +513,7 @@ class PostgresqlDetectSerialTest(TestBase):
staging_env()
cls.migration_context = MigrationContext.configure(
connection=cls.conn,
- opts={
- 'compare_type': True,
- 'compare_server_default': True
- }
+ opts={"compare_type": True, "compare_server_default": True},
)
def setUp(self):
@@ -538,7 +526,7 @@ class PostgresqlDetectSerialTest(TestBase):
@provide_metadata
def _expect_default(self, c_expected, col, seq=None):
- Table('t', self.metadata, col)
+ Table("t", self.metadata, col)
self.autogen_context.metadata = self.metadata
@@ -550,43 +538,50 @@ class PostgresqlDetectSerialTest(TestBase):
uo = ops.UpgradeOps(ops=[])
_compare_tables(
- set([(None, 't')]), set([]),
- insp, uo, self.autogen_context)
+ set([(None, "t")]), set([]), insp, uo, self.autogen_context
+ )
diffs = uo.as_diffs()
tab = diffs[0][1]
- eq_(_render_server_default_for_compare(
- tab.c.x.server_default, tab.c.x, self.autogen_context),
- c_expected)
+ eq_(
+ _render_server_default_for_compare(
+ tab.c.x.server_default, tab.c.x, self.autogen_context
+ ),
+ c_expected,
+ )
insp = Inspector.from_engine(config.db)
uo = ops.UpgradeOps(ops=[])
m2 = MetaData()
- Table('t', m2, Column('x', BigInteger()))
+ Table("t", m2, Column("x", BigInteger()))
self.autogen_context.metadata = m2
_compare_tables(
- set([(None, 't')]), set([(None, 't')]),
- insp, uo, self.autogen_context)
+ set([(None, "t")]),
+ set([(None, "t")]),
+ insp,
+ uo,
+ self.autogen_context,
+ )
diffs = uo.as_diffs()
- server_default = diffs[0][0][4]['existing_server_default']
- eq_(_render_server_default_for_compare(
- server_default, tab.c.x, self.autogen_context),
- c_expected)
+ server_default = diffs[0][0][4]["existing_server_default"]
+ eq_(
+ _render_server_default_for_compare(
+ server_default, tab.c.x, self.autogen_context
+ ),
+ c_expected,
+ )
def test_serial(self):
- self._expect_default(
- None,
- Column('x', Integer, primary_key=True)
- )
+ self._expect_default(None, Column("x", Integer, primary_key=True))
def test_separate_seq(self):
seq = Sequence("x_id_seq")
self._expect_default(
"nextval('x_id_seq'::regclass)",
Column(
- 'x', Integer,
- server_default=seq.next_value(), primary_key=True),
- seq
+ "x", Integer, server_default=seq.next_value(), primary_key=True
+ ),
+ seq,
)
def test_numeric(self):
@@ -594,29 +589,29 @@ class PostgresqlDetectSerialTest(TestBase):
self._expect_default(
"nextval('x_id_seq'::regclass)",
Column(
- 'x', Numeric(8, 2), server_default=seq.next_value(),
- primary_key=True),
- seq
+ "x",
+ Numeric(8, 2),
+ server_default=seq.next_value(),
+ primary_key=True,
+ ),
+ seq,
)
def test_no_default(self):
self._expect_default(
- None,
- Column('x', Integer, autoincrement=False, primary_key=True)
+ None, Column("x", Integer, autoincrement=False, primary_key=True)
)
class PostgresqlAutogenRenderTest(TestBase):
-
def setUp(self):
ctx_opts = {
- 'sqlalchemy_module_prefix': 'sa.',
- 'alembic_module_prefix': 'op.',
- 'target_metadata': MetaData()
+ "sqlalchemy_module_prefix": "sa.",
+ "alembic_module_prefix": "op.",
+ "target_metadata": MetaData(),
}
context = MigrationContext.configure(
- dialect_name="postgresql",
- opts=ctx_opts
+ dialect_name="postgresql", opts=ctx_opts
)
self.autogen_context = api.AutogenContext(context)
@@ -625,13 +620,11 @@ class PostgresqlAutogenRenderTest(TestBase):
autogen_context = self.autogen_context
m = MetaData()
- t = Table('t', m,
- Column('x', String),
- Column('y', String)
- )
+ t = Table("t", m, Column("x", String), Column("y", String))
- idx = Index('foo_idx', t.c.x, t.c.y,
- postgresql_where=(t.c.y == 'something'))
+ idx = Index(
+ "foo_idx", t.c.x, t.c.y, postgresql_where=(t.c.y == "something")
+ )
op_obj = ops.CreateIndexOp.from_index(idx)
@@ -639,114 +632,124 @@ class PostgresqlAutogenRenderTest(TestBase):
autogenerate.render_op_text(autogen_context, op_obj),
"""op.create_index('foo_idx', 't', \
['x', 'y'], unique=False, """
- """postgresql_where=sa.text(!U"y = 'something'"))"""
+ """postgresql_where=sa.text(!U"y = 'something'"))""",
)
def test_render_server_default_native_boolean(self):
c = Column(
- 'updated_at', Boolean(),
- server_default=false(),
- nullable=False)
- result = autogenerate.render._render_column(
- c, self.autogen_context,
+ "updated_at", Boolean(), server_default=false(), nullable=False
)
+ result = autogenerate.render._render_column(c, self.autogen_context)
eq_ignore_whitespace(
result,
- 'sa.Column(\'updated_at\', sa.Boolean(), '
- 'server_default=sa.text(!U\'false\'), '
- 'nullable=False)'
+ "sa.Column('updated_at', sa.Boolean(), "
+ "server_default=sa.text(!U'false'), "
+ "nullable=False)",
)
def test_postgresql_array_type(self):
eq_ignore_whitespace(
autogenerate.render._repr_type(
- ARRAY(Integer), self.autogen_context),
- "postgresql.ARRAY(sa.Integer())"
+ ARRAY(Integer), self.autogen_context
+ ),
+ "postgresql.ARRAY(sa.Integer())",
)
eq_ignore_whitespace(
autogenerate.render._repr_type(
- ARRAY(DateTime(timezone=True)), self.autogen_context),
- "postgresql.ARRAY(sa.DateTime(timezone=True))"
+ ARRAY(DateTime(timezone=True)), self.autogen_context
+ ),
+ "postgresql.ARRAY(sa.DateTime(timezone=True))",
)
eq_ignore_whitespace(
autogenerate.render._repr_type(
- ARRAY(BYTEA, as_tuple=True, dimensions=2),
- self.autogen_context),
- "postgresql.ARRAY(postgresql.BYTEA(), as_tuple=True, dimensions=2)"
+ ARRAY(BYTEA, as_tuple=True, dimensions=2), self.autogen_context
+ ),
+ "postgresql.ARRAY(postgresql.BYTEA(), as_tuple=True, dimensions=2)",
)
- assert 'from sqlalchemy.dialects import postgresql' in \
- self.autogen_context.imports
+ assert (
+ "from sqlalchemy.dialects import postgresql"
+ in self.autogen_context.imports
+ )
@config.requirements.sqlalchemy_110
def test_postgresql_hstore_subtypes(self):
eq_ignore_whitespace(
- autogenerate.render._repr_type(
- HSTORE(), self.autogen_context),
- "postgresql.HSTORE(text_type=sa.Text())"
+ autogenerate.render._repr_type(HSTORE(), self.autogen_context),
+ "postgresql.HSTORE(text_type=sa.Text())",
)
eq_ignore_whitespace(
autogenerate.render._repr_type(
- HSTORE(text_type=String()), self.autogen_context),
- "postgresql.HSTORE(text_type=sa.String())"
+ HSTORE(text_type=String()), self.autogen_context
+ ),
+ "postgresql.HSTORE(text_type=sa.String())",
)
eq_ignore_whitespace(
autogenerate.render._repr_type(
- HSTORE(text_type=BYTEA()), self.autogen_context),
- "postgresql.HSTORE(text_type=postgresql.BYTEA())"
+ HSTORE(text_type=BYTEA()), self.autogen_context
+ ),
+ "postgresql.HSTORE(text_type=postgresql.BYTEA())",
)
- assert 'from sqlalchemy.dialects import postgresql' in \
- self.autogen_context.imports
+ assert (
+ "from sqlalchemy.dialects import postgresql"
+ in self.autogen_context.imports
+ )
@config.requirements.sqlalchemy_110
def test_generic_array_type(self):
eq_ignore_whitespace(
autogenerate.render._repr_type(
- types.ARRAY(Integer), self.autogen_context),
- "sa.ARRAY(sa.Integer())"
+ types.ARRAY(Integer), self.autogen_context
+ ),
+ "sa.ARRAY(sa.Integer())",
)
eq_ignore_whitespace(
autogenerate.render._repr_type(
- types.ARRAY(DateTime(timezone=True)), self.autogen_context),
- "sa.ARRAY(sa.DateTime(timezone=True))"
+ types.ARRAY(DateTime(timezone=True)), self.autogen_context
+ ),
+ "sa.ARRAY(sa.DateTime(timezone=True))",
)
- assert 'from sqlalchemy.dialects import postgresql' not in \
- self.autogen_context.imports
+ assert (
+ "from sqlalchemy.dialects import postgresql"
+ not in self.autogen_context.imports
+ )
eq_ignore_whitespace(
autogenerate.render._repr_type(
types.ARRAY(BYTEA, as_tuple=True, dimensions=2),
- self.autogen_context),
- "sa.ARRAY(postgresql.BYTEA(), as_tuple=True, dimensions=2)"
+ self.autogen_context,
+ ),
+ "sa.ARRAY(postgresql.BYTEA(), as_tuple=True, dimensions=2)",
)
- assert 'from sqlalchemy.dialects import postgresql' in \
- self.autogen_context.imports
+ assert (
+ "from sqlalchemy.dialects import postgresql"
+ in self.autogen_context.imports
+ )
def test_array_type_user_defined_inner(self):
def repr_type(typestring, object_, autogen_context):
- if typestring == 'type' and isinstance(object_, String):
+ if typestring == "type" and isinstance(object_, String):
return "foobar.MYVARCHAR"
else:
return False
- self.autogen_context.opts.update(
- render_item=repr_type
- )
+ self.autogen_context.opts.update(render_item=repr_type)
eq_ignore_whitespace(
autogenerate.render._repr_type(
- ARRAY(String), self.autogen_context),
- "postgresql.ARRAY(foobar.MYVARCHAR)"
+ ARRAY(String), self.autogen_context
+ ),
+ "postgresql.ARRAY(foobar.MYVARCHAR)",
)
@config.requirements.fail_before_sqla_100
@@ -756,22 +759,18 @@ class PostgresqlAutogenRenderTest(TestBase):
autogen_context = self.autogen_context
m = MetaData()
- t = Table('t', m,
- Column('x', String),
- Column('y', String)
- )
-
- op_obj = ops.AddConstraintOp.from_constraint(ExcludeConstraint(
- (t.c.x, ">"),
- where=t.c.x != 2,
- using="gist",
- name="t_excl_x"
- ))
+ t = Table("t", m, Column("x", String), Column("y", String))
+
+ op_obj = ops.AddConstraintOp.from_constraint(
+ ExcludeConstraint(
+ (t.c.x, ">"), where=t.c.x != 2, using="gist", name="t_excl_x"
+ )
+ )
eq_ignore_whitespace(
autogenerate.render_op_text(autogen_context, op_obj),
"op.create_exclude_constraint('t_excl_x', 't', (sa.column('x'), '>'), "
- "where=sa.text(!U'x != 2'), using='gist')"
+ "where=sa.text(!U'x != 2'), using='gist')",
)
@config.requirements.fail_before_sqla_100
@@ -781,25 +780,25 @@ class PostgresqlAutogenRenderTest(TestBase):
autogen_context = self.autogen_context
m = MetaData()
- t = Table('TTAble', m,
- Column('XColumn', String),
- Column('YColumn', String)
- )
+ t = Table(
+ "TTAble", m, Column("XColumn", String), Column("YColumn", String)
+ )
- op_obj = ops.AddConstraintOp.from_constraint(ExcludeConstraint(
- (t.c.XColumn, ">"),
- where=t.c.XColumn != 2,
- using="gist",
- name="t_excl_x"
- ))
+ op_obj = ops.AddConstraintOp.from_constraint(
+ ExcludeConstraint(
+ (t.c.XColumn, ">"),
+ where=t.c.XColumn != 2,
+ using="gist",
+ name="t_excl_x",
+ )
+ )
eq_ignore_whitespace(
autogenerate.render_op_text(autogen_context, op_obj),
"op.create_exclude_constraint('t_excl_x', 'TTAble', (sa.column('XColumn'), '>'), "
- "where=sa.text(!U'\"XColumn\" != 2'), using='gist')"
+ "where=sa.text(!U'\"XColumn\" != 2'), using='gist')",
)
-
@config.requirements.fail_before_sqla_100
def test_inline_exclude_constraint(self):
from sqlalchemy.dialects.postgresql import ExcludeConstraint
@@ -808,15 +807,13 @@ class PostgresqlAutogenRenderTest(TestBase):
m = MetaData()
t = Table(
- 't', m,
- Column('x', String),
- Column('y', String),
+ "t",
+ m,
+ Column("x", String),
+ Column("y", String),
ExcludeConstraint(
- ('x', ">"),
- using="gist",
- where='x != 2',
- name="t_excl_x"
- )
+ ("x", ">"), using="gist", where="x != 2", name="t_excl_x"
+ ),
)
op_obj = ops.CreateTableOp.from_table(t)
@@ -827,7 +824,7 @@ class PostgresqlAutogenRenderTest(TestBase):
"sa.Column('y', sa.String(), nullable=True),"
"postgresql.ExcludeConstraint((!U'x', '>'), "
"where=sa.text(!U'x != 2'), using='gist', name='t_excl_x')"
- ")"
+ ")",
)
@config.requirements.fail_before_sqla_100
@@ -838,15 +835,13 @@ class PostgresqlAutogenRenderTest(TestBase):
m = MetaData()
t = Table(
- 'TTable', m,
- Column('XColumn', String),
- Column('YColumn', String),
+ "TTable", m, Column("XColumn", String), Column("YColumn", String)
)
ExcludeConstraint(
(t.c.XColumn, ">"),
using="gist",
where='"XColumn" != 2',
- name="TExclX"
+ name="TExclX",
)
op_obj = ops.CreateTableOp.from_table(t)
@@ -858,33 +853,29 @@ class PostgresqlAutogenRenderTest(TestBase):
"sa.Column('YColumn', sa.String(), nullable=True),"
"postgresql.ExcludeConstraint((sa.column('XColumn'), '>'), "
"where=sa.text(!U'\"XColumn\" != 2'), using='gist', "
- "name='TExclX'))"
+ "name='TExclX'))",
)
def test_json_type(self):
if config.requirements.sqlalchemy_110.enabled:
eq_ignore_whitespace(
- autogenerate.render._repr_type(
- JSON(), self.autogen_context),
- "postgresql.JSON(astext_type=sa.Text())"
+ autogenerate.render._repr_type(JSON(), self.autogen_context),
+ "postgresql.JSON(astext_type=sa.Text())",
)
else:
eq_ignore_whitespace(
- autogenerate.render._repr_type(
- JSON(), self.autogen_context),
- "postgresql.JSON()"
+ autogenerate.render._repr_type(JSON(), self.autogen_context),
+ "postgresql.JSON()",
)
def test_jsonb_type(self):
if config.requirements.sqlalchemy_110.enabled:
eq_ignore_whitespace(
- autogenerate.render._repr_type(
- JSONB(), self.autogen_context),
- "postgresql.JSONB(astext_type=sa.Text())"
+ autogenerate.render._repr_type(JSONB(), self.autogen_context),
+ "postgresql.JSONB(astext_type=sa.Text())",
)
else:
eq_ignore_whitespace(
- autogenerate.render._repr_type(
- JSONB(), self.autogen_context),
- "postgresql.JSONB()"
+ autogenerate.render._repr_type(JSONB(), self.autogen_context),
+ "postgresql.JSONB()",
)
diff --git a/tests/test_revision.py b/tests/test_revision.py
index 1f1c342..41a713e 100644
--- a/tests/test_revision.py
+++ b/tests/test_revision.py
@@ -1,7 +1,11 @@
from alembic.testing.fixtures import TestBase
from alembic.testing import eq_, assert_raises_message
-from alembic.script.revision import RevisionMap, Revision, MultipleHeads, \
- RevisionError
+from alembic.script.revision import (
+ RevisionMap,
+ Revision,
+ MultipleHeads,
+ RevisionError,
+)
from . import _large_map
@@ -9,136 +13,144 @@ class APITest(TestBase):
def test_add_revision_one_head(self):
map_ = RevisionMap(
lambda: [
- Revision('a', ()),
- Revision('b', ('a',)),
- Revision('c', ('b',)),
+ Revision("a", ()),
+ Revision("b", ("a",)),
+ Revision("c", ("b",)),
]
)
- eq_(map_.heads, ('c', ))
+ eq_(map_.heads, ("c",))
- map_.add_revision(Revision('d', ('c', )))
- eq_(map_.heads, ('d', ))
+ map_.add_revision(Revision("d", ("c",)))
+ eq_(map_.heads, ("d",))
def test_add_revision_two_head(self):
map_ = RevisionMap(
lambda: [
- Revision('a', ()),
- Revision('b', ('a',)),
- Revision('c1', ('b',)),
- Revision('c2', ('b',)),
+ Revision("a", ()),
+ Revision("b", ("a",)),
+ Revision("c1", ("b",)),
+ Revision("c2", ("b",)),
]
)
- eq_(map_.heads, ('c1', 'c2'))
+ eq_(map_.heads, ("c1", "c2"))
- map_.add_revision(Revision('d1', ('c1', )))
- eq_(map_.heads, ('c2', 'd1'))
+ map_.add_revision(Revision("d1", ("c1",)))
+ eq_(map_.heads, ("c2", "d1"))
def test_get_revision_head_single(self):
map_ = RevisionMap(
lambda: [
- Revision('a', ()),
- Revision('b', ('a',)),
- Revision('c', ('b',)),
+ Revision("a", ()),
+ Revision("b", ("a",)),
+ Revision("c", ("b",)),
]
)
- eq_(map_.get_revision('head'), map_._revision_map['c'])
+ eq_(map_.get_revision("head"), map_._revision_map["c"])
def test_get_revision_base_single(self):
map_ = RevisionMap(
lambda: [
- Revision('a', ()),
- Revision('b', ('a',)),
- Revision('c', ('b',)),
+ Revision("a", ()),
+ Revision("b", ("a",)),
+ Revision("c", ("b",)),
]
)
- eq_(map_.get_revision('base'), None)
+ eq_(map_.get_revision("base"), None)
def test_get_revision_head_multiple(self):
map_ = RevisionMap(
lambda: [
- Revision('a', ()),
- Revision('b', ('a',)),
- Revision('c1', ('b',)),
- Revision('c2', ('b',)),
+ Revision("a", ()),
+ Revision("b", ("a",)),
+ Revision("c1", ("b",)),
+ Revision("c2", ("b",)),
]
)
assert_raises_message(
MultipleHeads,
"Multiple heads are present",
- map_.get_revision, 'head'
+ map_.get_revision,
+ "head",
)
def test_get_revision_heads_multiple(self):
map_ = RevisionMap(
lambda: [
- Revision('a', ()),
- Revision('b', ('a',)),
- Revision('c1', ('b',)),
- Revision('c2', ('b',)),
+ Revision("a", ()),
+ Revision("b", ("a",)),
+ Revision("c1", ("b",)),
+ Revision("c2", ("b",)),
]
)
assert_raises_message(
MultipleHeads,
"Multiple heads are present",
- map_.get_revision, "heads"
+ map_.get_revision,
+ "heads",
)
def test_get_revision_base_multiple(self):
map_ = RevisionMap(
lambda: [
- Revision('a', ()),
- Revision('b', ('a',)),
- Revision('c', ()),
- Revision('d', ('c',)),
+ Revision("a", ()),
+ Revision("b", ("a",)),
+ Revision("c", ()),
+ Revision("d", ("c",)),
]
)
- eq_(map_.get_revision('base'), None)
+ eq_(map_.get_revision("base"), None)
def test_iterate_tolerates_dupe_targets(self):
map_ = RevisionMap(
lambda: [
- Revision('a', ()),
- Revision('b', ('a',)),
- Revision('c', ('b',)),
+ Revision("a", ()),
+ Revision("b", ("a",)),
+ Revision("c", ("b",)),
]
)
eq_(
- [
- r.revision for r in
- map_._iterate_revisions(('c', 'c'), 'a')
- ],
- ['c', 'b', 'a']
+ [r.revision for r in map_._iterate_revisions(("c", "c"), "a")],
+ ["c", "b", "a"],
)
def test_repr_revs(self):
map_ = RevisionMap(
lambda: [
- Revision('a', ()),
- Revision('b', ('a',)),
- Revision('c', (), dependencies=('a', 'b')),
+ Revision("a", ()),
+ Revision("b", ("a",)),
+ Revision("c", (), dependencies=("a", "b")),
]
)
- c = map_._revision_map['c']
+ c = map_._revision_map["c"]
eq_(repr(c), "Revision('c', None, dependencies=('a', 'b'))")
class DownIterateTest(TestBase):
def _assert_iteration(
- self, upper, lower, assertion, inclusive=True, map_=None,
- implicit_base=False, select_for_downgrade=False):
+ self,
+ upper,
+ lower,
+ assertion,
+ inclusive=True,
+ map_=None,
+ implicit_base=False,
+ select_for_downgrade=False,
+ ):
if map_ is None:
map_ = self.map
eq_(
[
- rev.revision for rev in
- map_.iterate_revisions(
- upper, lower,
- inclusive=inclusive, implicit_base=implicit_base,
- select_for_downgrade=select_for_downgrade
+ rev.revision
+ for rev in map_.iterate_revisions(
+ upper,
+ lower,
+ inclusive=inclusive,
+ implicit_base=implicit_base,
+ select_for_downgrade=select_for_downgrade,
)
],
- assertion
+ assertion,
)
@@ -146,173 +158,141 @@ class DiamondTest(DownIterateTest):
def setUp(self):
self.map = RevisionMap(
lambda: [
- Revision('a', ()),
- Revision('b1', ('a',)),
- Revision('b2', ('a',)),
- Revision('c', ('b1', 'b2')),
- Revision('d', ('c',)),
+ Revision("a", ()),
+ Revision("b1", ("a",)),
+ Revision("b2", ("a",)),
+ Revision("c", ("b1", "b2")),
+ Revision("d", ("c",)),
]
)
def test_iterate_simple_diamond(self):
- self._assert_iteration(
- "d", "a",
- ["d", "c", "b1", "b2", "a"]
- )
+ self._assert_iteration("d", "a", ["d", "c", "b1", "b2", "a"])
class EmptyMapTest(DownIterateTest):
# see issue #258
def setUp(self):
- self.map = RevisionMap(
- lambda: []
- )
+ self.map = RevisionMap(lambda: [])
def test_iterate(self):
- self._assert_iteration(
- "head", "base",
- []
- )
+ self._assert_iteration("head", "base", [])
class LabeledBranchTest(DownIterateTest):
def test_dupe_branch_collection(self):
fn = lambda: [
- Revision('a', ()),
- Revision('b', ('a',)),
- Revision('c', ('b',), branch_labels=['xy1']),
- Revision('d', ()),
- Revision('e', ('d',), branch_labels=['xy1']),
- Revision('f', ('e',))
+ Revision("a", ()),
+ Revision("b", ("a",)),
+ Revision("c", ("b",), branch_labels=["xy1"]),
+ Revision("d", ()),
+ Revision("e", ("d",), branch_labels=["xy1"]),
+ Revision("f", ("e",)),
]
assert_raises_message(
RevisionError,
r"Branch name 'xy1' in revision (?:e|c) already "
"used by revision (?:e|c)",
- getattr, RevisionMap(fn), "_revision_map"
+ getattr,
+ RevisionMap(fn),
+ "_revision_map",
)
def test_filter_for_lineage_labeled_head_across_merge(self):
fn = lambda: [
- Revision('a', ()),
- Revision('b', ('a', )),
- Revision('c1', ('b', ), branch_labels='c1branch'),
- Revision('c2', ('b', )),
- Revision('d', ('c1', 'c2')),
-
+ Revision("a", ()),
+ Revision("b", ("a",)),
+ Revision("c1", ("b",), branch_labels="c1branch"),
+ Revision("c2", ("b",)),
+ Revision("d", ("c1", "c2")),
]
map_ = RevisionMap(fn)
- c1 = map_.get_revision('c1')
- c2 = map_.get_revision('c2')
- d = map_.get_revision('d')
- eq_(
- map_.filter_for_lineage([c1, c2, d], "c1branch@head"),
- [c1, c2, d]
- )
+ c1 = map_.get_revision("c1")
+ c2 = map_.get_revision("c2")
+ d = map_.get_revision("d")
+ eq_(map_.filter_for_lineage([c1, c2, d], "c1branch@head"), [c1, c2, d])
def test_filter_for_lineage_heads(self):
eq_(
- self.map.filter_for_lineage(
- [self.map.get_revision("f")],
- "heads"
- ),
- [self.map.get_revision("f")]
+ self.map.filter_for_lineage([self.map.get_revision("f")], "heads"),
+ [self.map.get_revision("f")],
)
def setUp(self):
- self.map = RevisionMap(lambda: [
- Revision('a', (), branch_labels='abranch'),
- Revision('b', ('a',)),
- Revision('somelongername', ('b',)),
- Revision('c', ('somelongername',)),
- Revision('d', ()),
- Revision('e', ('d',), branch_labels=['ebranch']),
- Revision('someothername', ('e',)),
- Revision('f', ('someothername',)),
- ])
+ self.map = RevisionMap(
+ lambda: [
+ Revision("a", (), branch_labels="abranch"),
+ Revision("b", ("a",)),
+ Revision("somelongername", ("b",)),
+ Revision("c", ("somelongername",)),
+ Revision("d", ()),
+ Revision("e", ("d",), branch_labels=["ebranch"]),
+ Revision("someothername", ("e",)),
+ Revision("f", ("someothername",)),
+ ]
+ )
def test_get_base_revisions_labeled(self):
- eq_(
- self.map._get_base_revisions("somelongername@base"),
- ['a']
- )
+ eq_(self.map._get_base_revisions("somelongername@base"), ["a"])
def test_get_current_named_rev(self):
- eq_(
- self.map.get_revision("ebranch@head"),
- self.map.get_revision("f")
- )
+ eq_(self.map.get_revision("ebranch@head"), self.map.get_revision("f"))
def test_get_base_revisions(self):
- eq_(
- self.map._get_base_revisions("base"),
- ['a', 'd']
- )
+ eq_(self.map._get_base_revisions("base"), ["a", "d"])
def test_iterate_head_to_named_base(self):
self._assert_iteration(
- "heads", "ebranch@base",
- ['f', 'someothername', 'e', 'd']
+ "heads", "ebranch@base", ["f", "someothername", "e", "d"]
)
self._assert_iteration(
- "heads", "abranch@base",
- ['c', 'somelongername', 'b', 'a']
+ "heads", "abranch@base", ["c", "somelongername", "b", "a"]
)
def test_iterate_named_head_to_base(self):
self._assert_iteration(
- "ebranch@head", "base",
- ['f', 'someothername', 'e', 'd']
+ "ebranch@head", "base", ["f", "someothername", "e", "d"]
)
self._assert_iteration(
- "abranch@head", "base",
- ['c', 'somelongername', 'b', 'a']
+ "abranch@head", "base", ["c", "somelongername", "b", "a"]
)
def test_iterate_named_head_to_heads(self):
- self._assert_iteration(
- "heads", "ebranch@head",
- ['f'],
- inclusive=True
- )
+ self._assert_iteration("heads", "ebranch@head", ["f"], inclusive=True)
def test_iterate_named_rev_to_heads(self):
self._assert_iteration(
- "heads", "ebranch@d",
- ['f', 'someothername', 'e', 'd'],
- inclusive=True
+ "heads",
+ "ebranch@d",
+ ["f", "someothername", "e", "d"],
+ inclusive=True,
)
def test_iterate_head_to_version_specific_base(self):
self._assert_iteration(
- "heads", "e@base",
- ['f', 'someothername', 'e', 'd']
+ "heads", "e@base", ["f", "someothername", "e", "d"]
)
self._assert_iteration(
- "heads", "c@base",
- ['c', 'somelongername', 'b', 'a']
+ "heads", "c@base", ["c", "somelongername", "b", "a"]
)
def test_iterate_to_branch_at_rev(self):
self._assert_iteration(
- "heads", "ebranch@d",
- ['f', 'someothername', 'e', 'd']
+ "heads", "ebranch@d", ["f", "someothername", "e", "d"]
)
def test_branch_w_down_relative(self):
self._assert_iteration(
- "heads", "ebranch@-2",
- ['f', 'someothername', 'e']
+ "heads", "ebranch@-2", ["f", "someothername", "e"]
)
def test_branch_w_up_relative(self):
self._assert_iteration(
- "ebranch@+2", "base",
- ['someothername', 'e', 'd']
+ "ebranch@+2", "base", ["someothername", "e", "d"]
)
def test_partial_id_resolve(self):
@@ -320,43 +300,43 @@ class LabeledBranchTest(DownIterateTest):
eq_(self.map.get_revision("abranch@some").revision, "somelongername")
def test_branch_at_heads(self):
- eq_(
- self.map.get_revision("abranch@heads").revision,
- "c"
- )
+ eq_(self.map.get_revision("abranch@heads").revision, "c")
def test_branch_at_syntax(self):
- eq_(self.map.get_revision("abranch@head").revision, 'c')
+ eq_(self.map.get_revision("abranch@head").revision, "c")
eq_(self.map.get_revision("abranch@base"), None)
- eq_(self.map.get_revision("ebranch@head").revision, 'f')
+ eq_(self.map.get_revision("ebranch@head").revision, "f")
eq_(self.map.get_revision("abranch@base"), None)
- eq_(self.map.get_revision("ebranch@d").revision, 'd')
+ eq_(self.map.get_revision("ebranch@d").revision, "d")
def test_branch_at_self(self):
- eq_(self.map.get_revision("ebranch@ebranch").revision, 'e')
+ eq_(self.map.get_revision("ebranch@ebranch").revision, "e")
def test_retrieve_branch_revision(self):
- eq_(self.map.get_revision("abranch").revision, 'a')
- eq_(self.map.get_revision("ebranch").revision, 'e')
+ eq_(self.map.get_revision("abranch").revision, "a")
+ eq_(self.map.get_revision("ebranch").revision, "e")
def test_rev_not_in_branch(self):
assert_raises_message(
RevisionError,
"Revision b is not a member of branch 'ebranch'",
- self.map.get_revision, "ebranch@b"
+ self.map.get_revision,
+ "ebranch@b",
)
assert_raises_message(
RevisionError,
"Revision d is not a member of branch 'abranch'",
- self.map.get_revision, "abranch@d"
+ self.map.get_revision,
+ "abranch@d",
)
def test_no_revision_exists(self):
assert_raises_message(
RevisionError,
"No such revision or branch 'q'",
- self.map.get_revision, "abranch@q"
+ self.map.get_revision,
+ "abranch@q",
)
def test_not_actually_a_branch(self):
@@ -367,9 +347,7 @@ class LabeledBranchTest(DownIterateTest):
def test_no_such_branch(self):
assert_raises_message(
- RevisionError,
- "No such branch: 'x'",
- self.map.get_revision, "x@d"
+ RevisionError, "No such branch: 'x'", self.map.get_revision, "x@d"
)
@@ -377,19 +355,18 @@ class LongShortBranchTest(DownIterateTest):
def setUp(self):
self.map = RevisionMap(
lambda: [
- Revision('a', ()),
- Revision('b1', ('a',)),
- Revision('b2', ('a',)),
- Revision('c1', ('b1',)),
- Revision('d11', ('c1',)),
- Revision('d12', ('c1',)),
+ Revision("a", ()),
+ Revision("b1", ("a",)),
+ Revision("b2", ("a",)),
+ Revision("c1", ("b1",)),
+ Revision("d11", ("c1",)),
+ Revision("d12", ("c1",)),
]
)
def test_iterate_full(self):
self._assert_iteration(
- "heads", "base",
- ['b2', 'd11', 'd12', 'c1', 'b1', 'a']
+ "heads", "base", ["b2", "d11", "d12", "c1", "b1", "a"]
)
@@ -397,56 +374,46 @@ class MultipleBranchTest(DownIterateTest):
def setUp(self):
self.map = RevisionMap(
lambda: [
- Revision('a', ()),
- Revision('b1', ('a',)),
- Revision('b2', ('a',)),
- Revision('cb1', ('b1',)),
- Revision('cb2', ('b2',)),
- Revision('d1cb1', ('cb1',)), # head
- Revision('d2cb1', ('cb1',)), # head
- Revision('d1cb2', ('cb2',)),
- Revision('d2cb2', ('cb2',)),
- Revision('d3cb2', ('cb2',)), # head
- Revision('d1d2cb2', ('d1cb2', 'd2cb2')) # head + merge point
+ Revision("a", ()),
+ Revision("b1", ("a",)),
+ Revision("b2", ("a",)),
+ Revision("cb1", ("b1",)),
+ Revision("cb2", ("b2",)),
+ Revision("d1cb1", ("cb1",)), # head
+ Revision("d2cb1", ("cb1",)), # head
+ Revision("d1cb2", ("cb2",)),
+ Revision("d2cb2", ("cb2",)),
+ Revision("d3cb2", ("cb2",)), # head
+ Revision("d1d2cb2", ("d1cb2", "d2cb2")), # head + merge point
]
)
def test_iterate_from_merge_point(self):
self._assert_iteration(
- "d1d2cb2", "a",
- ['d1d2cb2', 'd1cb2', 'd2cb2', 'cb2', 'b2', 'a']
+ "d1d2cb2", "a", ["d1d2cb2", "d1cb2", "d2cb2", "cb2", "b2", "a"]
)
def test_iterate_multiple_heads(self):
self._assert_iteration(
- ["d2cb2", "d3cb2"], "a",
- ['d2cb2', 'd3cb2', 'cb2', 'b2', 'a']
+ ["d2cb2", "d3cb2"], "a", ["d2cb2", "d3cb2", "cb2", "b2", "a"]
)
def test_iterate_single_branch(self):
- self._assert_iteration(
- "d3cb2", "a",
- ['d3cb2', 'cb2', 'b2', 'a']
- )
+ self._assert_iteration("d3cb2", "a", ["d3cb2", "cb2", "b2", "a"])
def test_iterate_single_branch_to_base(self):
- self._assert_iteration(
- "d3cb2", "base",
- ['d3cb2', 'cb2', 'b2', 'a']
- )
+ self._assert_iteration("d3cb2", "base", ["d3cb2", "cb2", "b2", "a"])
def test_iterate_multiple_branch_to_base(self):
self._assert_iteration(
- ["d3cb2", "cb1"], "base",
- ['d3cb2', 'cb2', 'b2', 'cb1', 'b1', 'a']
+ ["d3cb2", "cb1"], "base", ["d3cb2", "cb2", "b2", "cb1", "b1", "a"]
)
def test_iterate_multiple_heads_single_base(self):
# head d1cb1 is omitted as it is not
# a descendant of b2
self._assert_iteration(
- ["d1cb1", "d2cb2", "d3cb2"], "b2",
- ["d2cb2", 'd3cb2', 'cb2', 'b2']
+ ["d1cb1", "d2cb2", "d3cb2"], "b2", ["d2cb2", "d3cb2", "cb2", "b2"]
)
def test_same_branch_wrong_direction(self):
@@ -456,7 +423,7 @@ class MultipleBranchTest(DownIterateTest):
RevisionError,
r"Revision d1cb1 is not an ancestor of revision b1",
list,
- self.map._iterate_revisions('b1', 'd1cb1')
+ self.map._iterate_revisions("b1", "d1cb1"),
)
def test_distinct_branches(self):
@@ -465,7 +432,7 @@ class MultipleBranchTest(DownIterateTest):
RevisionError,
r"Revision b1 is not an ancestor of revision d2cb2",
list,
- self.map._iterate_revisions('d2cb2', 'b1')
+ self.map._iterate_revisions("d2cb2", "b1"),
)
def test_wrong_direction_to_base_as_none(self):
@@ -475,7 +442,7 @@ class MultipleBranchTest(DownIterateTest):
RevisionError,
r"Revision d1cb1 is not an ancestor of revision base",
list,
- self.map._iterate_revisions(None, 'd1cb1')
+ self.map._iterate_revisions(None, "d1cb1"),
)
def test_wrong_direction_to_base_as_empty(self):
@@ -485,7 +452,7 @@ class MultipleBranchTest(DownIterateTest):
RevisionError,
r"Revision d1cb1 is not an ancestor of revision base",
list,
- self.map._iterate_revisions((), 'd1cb1')
+ self.map._iterate_revisions((), "d1cb1"),
)
@@ -501,22 +468,20 @@ class BranchTravellingTest(DownIterateTest):
def setUp(self):
self.map = RevisionMap(
lambda: [
- Revision('a1', ()),
- Revision('a2', ('a1',)),
- Revision('a3', ('a2',)),
- Revision('b1', ('a3',)),
- Revision('b2', ('a3',)),
- Revision('cb1', ('b1',)),
- Revision('cb2', ('b2',)),
- Revision('db1', ('cb1',)),
- Revision('db2', ('cb2',)),
-
- Revision('e1b1', ('db1',)),
- Revision('fe1b1', ('e1b1',)),
-
- Revision('e2b1', ('db1',)),
- Revision('e2b2', ('db2',)),
- Revision("merge", ('e2b1', 'e2b2'))
+ Revision("a1", ()),
+ Revision("a2", ("a1",)),
+ Revision("a3", ("a2",)),
+ Revision("b1", ("a3",)),
+ Revision("b2", ("a3",)),
+ Revision("cb1", ("b1",)),
+ Revision("cb2", ("b2",)),
+ Revision("db1", ("cb1",)),
+ Revision("db2", ("cb2",)),
+ Revision("e1b1", ("db1",)),
+ Revision("fe1b1", ("e1b1",)),
+ Revision("e2b1", ("db1",)),
+ Revision("e2b2", ("db2",)),
+ Revision("merge", ("e2b1", "e2b2")),
]
)
@@ -524,19 +489,31 @@ class BranchTravellingTest(DownIterateTest):
# test that when we hit a merge point, implicit base will
# ensure all branches that supply the merge point are filled in
self._assert_iteration(
- "merge", "db1",
- ['merge',
- 'e2b1', 'db1',
- 'e2b2', 'db2', 'cb2', 'b2'],
- implicit_base=True
+ "merge",
+ "db1",
+ ["merge", "e2b1", "db1", "e2b2", "db2", "cb2", "b2"],
+ implicit_base=True,
)
def test_three_branches_end_in_single_branch(self):
self._assert_iteration(
- ["merge", "fe1b1"], "a3",
- ['merge', 'e2b1', 'e2b2', 'db2', 'cb2', 'b2',
- 'fe1b1', 'e1b1', 'db1', 'cb1', 'b1', 'a3']
+ ["merge", "fe1b1"],
+ "a3",
+ [
+ "merge",
+ "e2b1",
+ "e2b2",
+ "db2",
+ "cb2",
+ "b2",
+ "fe1b1",
+ "e1b1",
+ "db1",
+ "cb1",
+ "b1",
+ "a3",
+ ],
)
def test_two_branches_to_root(self):
@@ -544,80 +521,103 @@ class BranchTravellingTest(DownIterateTest):
# here we want 'a3' as a "stop" branch point, but *not*
# 'db1', as we don't have multiple traversals on db1
self._assert_iteration(
- "merge", "a1",
- ['merge',
- 'e2b1', 'db1', 'cb1', 'b1', # e2b1 branch
- 'e2b2', 'db2', 'cb2', 'b2', # e2b2 branch
- 'a3', # both terminate at a3
- 'a2', 'a1' # finish out
- ] # noqa
+ "merge",
+ "a1",
+ [
+ "merge",
+ "e2b1",
+ "db1",
+ "cb1",
+ "b1", # e2b1 branch
+ "e2b2",
+ "db2",
+ "cb2",
+ "b2", # e2b2 branch
+ "a3", # both terminate at a3
+ "a2",
+ "a1", # finish out
+ ], # noqa
)
def test_two_branches_end_in_branch(self):
self._assert_iteration(
- "merge", "b1",
+ "merge",
+ "b1",
# 'b1' is local to 'e2b1'
# branch so that is all we get
- ['merge', 'e2b1', 'db1', 'cb1', 'b1',
-
- ] # noqa
+ ["merge", "e2b1", "db1", "cb1", "b1"], # noqa
)
def test_two_branches_end_behind_branch(self):
self._assert_iteration(
- "merge", "a2",
- ['merge',
- 'e2b1', 'db1', 'cb1', 'b1', # e2b1 branch
- 'e2b2', 'db2', 'cb2', 'b2', # e2b2 branch
- 'a3', # both terminate at a3
- 'a2'
- ] # noqa
+ "merge",
+ "a2",
+ [
+ "merge",
+ "e2b1",
+ "db1",
+ "cb1",
+ "b1", # e2b1 branch
+ "e2b2",
+ "db2",
+ "cb2",
+ "b2", # e2b2 branch
+ "a3", # both terminate at a3
+ "a2",
+ ], # noqa
)
def test_three_branches_to_root(self):
# in this case, both "a3" and "db1" are stop points
self._assert_iteration(
- ["merge", "fe1b1"], "a1",
- ['merge',
- 'e2b1', # e2b1 branch
- 'e2b2', 'db2', 'cb2', 'b2', # e2b2 branch
- 'fe1b1', 'e1b1', # fe1b1 branch
- 'db1', # fe1b1 and e2b1 branches terminate at db1
- 'cb1', 'b1', # e2b1 branch continued....might be nicer
- # if this was before the e2b2 branch...
- 'a3', # e2b1 and e2b2 branches terminate at a3
- 'a2', 'a1' # finish out
- ] # noqa
+ ["merge", "fe1b1"],
+ "a1",
+ [
+ "merge",
+ "e2b1", # e2b1 branch
+ "e2b2",
+ "db2",
+ "cb2",
+ "b2", # e2b2 branch
+ "fe1b1",
+ "e1b1", # fe1b1 branch
+ "db1", # fe1b1 and e2b1 branches terminate at db1
+ "cb1",
+ "b1", # e2b1 branch continued....might be nicer
+ # if this was before the e2b2 branch...
+ "a3", # e2b1 and e2b2 branches terminate at a3
+ "a2",
+ "a1", # finish out
+ ], # noqa
)
def test_three_branches_end_multiple_bases(self):
# in this case, both "a3" and "db1" are stop points
self._assert_iteration(
- ["merge", "fe1b1"], ["cb1", "cb2"],
+ ["merge", "fe1b1"],
+ ["cb1", "cb2"],
[
- 'merge',
- 'e2b1',
- 'e2b2', 'db2', 'cb2',
- 'fe1b1', 'e1b1',
- 'db1',
- 'cb1'
- ]
+ "merge",
+ "e2b1",
+ "e2b2",
+ "db2",
+ "cb2",
+ "fe1b1",
+ "e1b1",
+ "db1",
+ "cb1",
+ ],
)
def test_three_branches_end_multiple_bases_exclusive(self):
self._assert_iteration(
- ["merge", "fe1b1"], ["cb1", "cb2"],
- [
- 'merge',
- 'e2b1',
- 'e2b2', 'db2',
- 'fe1b1', 'e1b1',
- 'db1',
- ],
- inclusive=False
+ ["merge", "fe1b1"],
+ ["cb1", "cb2"],
+ ["merge", "e2b1", "e2b2", "db2", "fe1b1", "e1b1", "db1"],
+ inclusive=False,
)
def test_detect_invalid_head_selection(self):
@@ -627,26 +627,34 @@ class BranchTravellingTest(DownIterateTest):
"Requested revision fe1b1 overlaps "
"with other requested revisions",
list,
- self.map._iterate_revisions(["db1", "b2", "fe1b1"], ())
+ self.map._iterate_revisions(["db1", "b2", "fe1b1"], ()),
)
def test_three_branches_end_multiple_bases_exclusive_blank(self):
self._assert_iteration(
- ["e2b1", "b2", "fe1b1"], (),
+ ["e2b1", "b2", "fe1b1"],
+ (),
[
- 'e2b1',
- 'b2',
- 'fe1b1', 'e1b1',
- 'db1', 'cb1', 'b1', 'a3', 'a2', 'a1'
+ "e2b1",
+ "b2",
+ "fe1b1",
+ "e1b1",
+ "db1",
+ "cb1",
+ "b1",
+ "a3",
+ "a2",
+ "a1",
],
- inclusive=False
+ inclusive=False,
)
def test_iterate_to_symbolic_base(self):
self._assert_iteration(
- ["fe1b1"], "base",
- ['fe1b1', 'e1b1', 'db1', 'cb1', 'b1', 'a3', 'a2', 'a1'],
- inclusive=False
+ ["fe1b1"],
+ "base",
+ ["fe1b1", "e1b1", "db1", "cb1", "b1", "a3", "a2", "a1"],
+ inclusive=False,
)
def test_ancestor_nodes(self):
@@ -656,8 +664,22 @@ class BranchTravellingTest(DownIterateTest):
rev.revision
for rev in self.map._get_ancestor_nodes([merge], check=True)
),
- set(['a1', 'e2b2', 'e2b1', 'cb2', 'merge',
- 'a3', 'a2', 'b1', 'b2', 'db1', 'db2', 'cb1'])
+ set(
+ [
+ "a1",
+ "e2b2",
+ "e2b1",
+ "cb2",
+ "merge",
+ "a3",
+ "a2",
+ "b1",
+ "b2",
+ "db1",
+ "db2",
+ "cb1",
+ ]
+ ),
)
@@ -665,125 +687,153 @@ class MultipleBaseTest(DownIterateTest):
def setUp(self):
self.map = RevisionMap(
lambda: [
- Revision('base1', ()),
- Revision('base2', ()),
- Revision('base3', ()),
-
- Revision('a1a', ('base1',)),
- Revision('a1b', ('base1',)),
- Revision('a2', ('base2',)),
- Revision('a3', ('base3',)),
-
- Revision('b1a', ('a1a',)),
- Revision('b1b', ('a1b',)),
- Revision('b2', ('a2',)),
- Revision('b3', ('a3',)),
-
- Revision('c2', ('b2',)),
- Revision('d2', ('c2',)),
-
- Revision('mergeb3d2', ('b3', 'd2'))
+ Revision("base1", ()),
+ Revision("base2", ()),
+ Revision("base3", ()),
+ Revision("a1a", ("base1",)),
+ Revision("a1b", ("base1",)),
+ Revision("a2", ("base2",)),
+ Revision("a3", ("base3",)),
+ Revision("b1a", ("a1a",)),
+ Revision("b1b", ("a1b",)),
+ Revision("b2", ("a2",)),
+ Revision("b3", ("a3",)),
+ Revision("c2", ("b2",)),
+ Revision("d2", ("c2",)),
+ Revision("mergeb3d2", ("b3", "d2")),
]
)
def test_heads_to_base(self):
self._assert_iteration(
- "heads", "base",
+ "heads",
+ "base",
[
- 'b1a', 'a1a',
- 'b1b', 'a1b',
- 'mergeb3d2',
- 'b3', 'a3', 'base3',
- 'd2', 'c2', 'b2', 'a2', 'base2',
- 'base1'
- ]
+ "b1a",
+ "a1a",
+ "b1b",
+ "a1b",
+ "mergeb3d2",
+ "b3",
+ "a3",
+ "base3",
+ "d2",
+ "c2",
+ "b2",
+ "a2",
+ "base2",
+ "base1",
+ ],
)
def test_heads_to_base_exclusive(self):
self._assert_iteration(
- "heads", "base",
+ "heads",
+ "base",
[
- 'b1a', 'a1a',
- 'b1b', 'a1b',
- 'mergeb3d2',
- 'b3', 'a3', 'base3',
- 'd2', 'c2', 'b2', 'a2', 'base2',
- 'base1',
+ "b1a",
+ "a1a",
+ "b1b",
+ "a1b",
+ "mergeb3d2",
+ "b3",
+ "a3",
+ "base3",
+ "d2",
+ "c2",
+ "b2",
+ "a2",
+ "base2",
+ "base1",
],
- inclusive=False
+ inclusive=False,
)
def test_heads_to_blank(self):
self._assert_iteration(
- "heads", None,
+ "heads",
+ None,
[
- 'b1a', 'a1a',
- 'b1b', 'a1b',
- 'mergeb3d2',
- 'b3', 'a3', 'base3',
- 'd2', 'c2', 'b2', 'a2', 'base2',
- 'base1'
- ]
+ "b1a",
+ "a1a",
+ "b1b",
+ "a1b",
+ "mergeb3d2",
+ "b3",
+ "a3",
+ "base3",
+ "d2",
+ "c2",
+ "b2",
+ "a2",
+ "base2",
+ "base1",
+ ],
)
def test_detect_invalid_base_selection(self):
assert_raises_message(
RevisionError,
- "Requested revision a2 overlaps with "
- "other requested revisions",
+ "Requested revision a2 overlaps with " "other requested revisions",
list,
- self.map._iterate_revisions(["c2"], ["a2", "b2"])
+ self.map._iterate_revisions(["c2"], ["a2", "b2"]),
)
def test_heads_to_revs_plus_implicit_base_exclusive(self):
self._assert_iteration(
- "heads", ["c2"],
+ "heads",
+ ["c2"],
[
- 'b1a', 'a1a',
- 'b1b', 'a1b',
- 'mergeb3d2',
- 'b3', 'a3', 'base3',
- 'd2',
- 'base1'
+ "b1a",
+ "a1a",
+ "b1b",
+ "a1b",
+ "mergeb3d2",
+ "b3",
+ "a3",
+ "base3",
+ "d2",
+ "base1",
],
inclusive=False,
- implicit_base=True
+ implicit_base=True,
)
def test_heads_to_revs_base_exclusive(self):
self._assert_iteration(
- "heads", ["c2"],
- [
- 'mergeb3d2', 'd2'
- ],
- inclusive=False
+ "heads", ["c2"], ["mergeb3d2", "d2"], inclusive=False
)
def test_heads_to_revs_plus_implicit_base_inclusive(self):
self._assert_iteration(
- "heads", ["c2"],
+ "heads",
+ ["c2"],
[
- 'b1a', 'a1a',
- 'b1b', 'a1b',
- 'mergeb3d2',
- 'b3', 'a3', 'base3',
- 'd2', 'c2',
- 'base1'
+ "b1a",
+ "a1a",
+ "b1b",
+ "a1b",
+ "mergeb3d2",
+ "b3",
+ "a3",
+ "base3",
+ "d2",
+ "c2",
+ "base1",
],
- implicit_base=True
+ implicit_base=True,
)
def test_specific_path_one(self):
- self._assert_iteration(
- "b3", "base3",
- ['b3', 'a3', 'base3']
- )
+ self._assert_iteration("b3", "base3", ["b3", "a3", "base3"])
def test_specific_path_two_implicit_base(self):
self._assert_iteration(
- ["b3", "b2"], "base3",
- ['b3', 'a3', 'b2', 'a2', 'base2'],
- inclusive=False, implicit_base=True
+ ["b3", "b2"],
+ "base3",
+ ["b3", "a3", "b2", "a2", "base2"],
+ inclusive=False,
+ implicit_base=True,
)
@@ -808,21 +858,19 @@ class MultipleBaseCrossDependencyTestOne(DownIterateTest):
"""
self.map = RevisionMap(
lambda: [
- Revision('base1', (), branch_labels='b_1'),
- Revision('a1a', ('base1',)),
- Revision('a1b', ('base1',)),
- Revision('b1a', ('a1a',)),
- Revision('b1b', ('a1b', ), dependencies='a3'),
-
- Revision('base2', (), branch_labels='b_2'),
- Revision('a2', ('base2',)),
- Revision('b2', ('a2',)),
- Revision('c2', ('b2', ), dependencies='a3'),
- Revision('d2', ('c2',)),
-
- Revision('base3', (), branch_labels='b_3'),
- Revision('a3', ('base3',)),
- Revision('b3', ('a3',)),
+ Revision("base1", (), branch_labels="b_1"),
+ Revision("a1a", ("base1",)),
+ Revision("a1b", ("base1",)),
+ Revision("b1a", ("a1a",)),
+ Revision("b1b", ("a1b",), dependencies="a3"),
+ Revision("base2", (), branch_labels="b_2"),
+ Revision("a2", ("base2",)),
+ Revision("b2", ("a2",)),
+ Revision("c2", ("b2",), dependencies="a3"),
+ Revision("d2", ("c2",)),
+ Revision("base3", (), branch_labels="b_3"),
+ Revision("a3", ("base3",)),
+ Revision("b3", ("a3",)),
]
)
@@ -831,25 +879,45 @@ class MultipleBaseCrossDependencyTestOne(DownIterateTest):
def test_heads_to_base(self):
self._assert_iteration(
- "heads", "base",
+ "heads",
+ "base",
[
-
- 'b1a', 'a1a', 'b1b', 'a1b', 'd2', 'c2', 'b2', 'a2', 'base2',
- 'b3', 'a3', 'base3',
- 'base1'
- ]
+ "b1a",
+ "a1a",
+ "b1b",
+ "a1b",
+ "d2",
+ "c2",
+ "b2",
+ "a2",
+ "base2",
+ "b3",
+ "a3",
+ "base3",
+ "base1",
+ ],
)
def test_heads_to_base_downgrade(self):
self._assert_iteration(
- "heads", "base",
+ "heads",
+ "base",
[
-
- 'b1a', 'a1a', 'b1b', 'a1b', 'd2', 'c2', 'b2', 'a2', 'base2',
- 'b3', 'a3', 'base3',
- 'base1'
+ "b1a",
+ "a1a",
+ "b1b",
+ "a1b",
+ "d2",
+ "c2",
+ "b2",
+ "a2",
+ "base2",
+ "b3",
+ "a3",
+ "base3",
+ "base1",
],
- select_for_downgrade=True
+ select_for_downgrade=True,
)
def test_same_branch_wrong_direction(self):
@@ -857,83 +925,79 @@ class MultipleBaseCrossDependencyTestOne(DownIterateTest):
RevisionError,
r"Revision d2 is not an ancestor of revision b2",
list,
- self.map._iterate_revisions('b2', 'd2')
+ self.map._iterate_revisions("b2", "d2"),
)
def test_different_branch_not_wrong_direction(self):
- self._assert_iteration(
- "b3", "d2",
- []
- )
+ self._assert_iteration("b3", "d2", [])
def test_we_need_head2_upgrade(self):
# the 2 branch relies on the 3 branch
self._assert_iteration(
- "b_2@head", "base",
- ['d2', 'c2', 'b2', 'a2', 'base2', 'a3', 'base3']
+ "b_2@head",
+ "base",
+ ["d2", "c2", "b2", "a2", "base2", "a3", "base3"],
)
def test_we_need_head2_downgrade(self):
# the 2 branch relies on the 3 branch, but
# on the downgrade side, don't need to touch the 3 branch
self._assert_iteration(
- "b_2@head", "b_2@base",
- ['d2', 'c2', 'b2', 'a2', 'base2'],
- select_for_downgrade=True
+ "b_2@head",
+ "b_2@base",
+ ["d2", "c2", "b2", "a2", "base2"],
+ select_for_downgrade=True,
)
def test_we_need_head3_upgrade(self):
# the 3 branch can be upgraded alone.
- self._assert_iteration(
- "b_3@head", "base",
- ['b3', 'a3', 'base3']
- )
+ self._assert_iteration("b_3@head", "base", ["b3", "a3", "base3"])
def test_we_need_head3_downgrade(self):
# the 3 branch can be upgraded alone.
self._assert_iteration(
- "b_3@head", "base",
- ['b3', 'a3', 'base3'],
- select_for_downgrade=True
+ "b_3@head",
+ "base",
+ ["b3", "a3", "base3"],
+ select_for_downgrade=True,
)
def test_we_need_head1_upgrade(self):
# the 1 branch relies on the 3 branch
self._assert_iteration(
- "b1b@head", "base",
- ['b1b', 'a1b', 'base1', 'a3', 'base3']
+ "b1b@head", "base", ["b1b", "a1b", "base1", "a3", "base3"]
)
def test_we_need_head1_downgrade(self):
# going down we don't need a3-> base3, as long
# as we are limiting the base target
self._assert_iteration(
- "b1b@head", "b1b@base",
- ['b1b', 'a1b', 'base1'],
- select_for_downgrade=True
+ "b1b@head",
+ "b1b@base",
+ ["b1b", "a1b", "base1"],
+ select_for_downgrade=True,
)
def test_we_need_base2_upgrade(self):
# consider a downgrade to b_2@base - we
# want to run through all the "2"s alone, and we're done.
self._assert_iteration(
- "heads", "b_2@base",
- ['d2', 'c2', 'b2', 'a2', 'base2']
+ "heads", "b_2@base", ["d2", "c2", "b2", "a2", "base2"]
)
def test_we_need_base2_downgrade(self):
# consider a downgrade to b_2@base - we
# want to run through all the "2"s alone, and we're done.
self._assert_iteration(
- "heads", "b_2@base",
- ['d2', 'c2', 'b2', 'a2', 'base2'],
- select_for_downgrade=True
+ "heads",
+ "b_2@base",
+ ["d2", "c2", "b2", "a2", "base2"],
+ select_for_downgrade=True,
)
def test_we_need_base3_upgrade(self):
self._assert_iteration(
- "heads", "b_3@base",
- ['b1b', 'd2', 'c2', 'b3', 'a3', 'base3']
+ "heads", "b_3@base", ["b1b", "d2", "c2", "b3", "a3", "base3"]
)
def test_we_need_base3_downgrade(self):
@@ -942,9 +1006,10 @@ class MultipleBaseCrossDependencyTestOne(DownIterateTest):
# as well, which means b1b and c2. Then we can downgrade
# the 3s.
self._assert_iteration(
- "heads", "b_3@base",
- ['b1b', 'd2', 'c2', 'b3', 'a3', 'base3'],
- select_for_downgrade=True
+ "heads",
+ "b_3@base",
+ ["b1b", "d2", "c2", "b3", "a3", "base3"],
+ select_for_downgrade=True,
)
@@ -952,22 +1017,20 @@ class MultipleBaseCrossDependencyTestTwo(DownIterateTest):
def setUp(self):
self.map = RevisionMap(
lambda: [
- Revision('base1', (), branch_labels='b_1'),
- Revision('a1', 'base1'),
- Revision('b1', 'a1'),
- Revision('c1', 'b1'),
-
- Revision('base2', (), dependencies='b_1', branch_labels='b_2'),
- Revision('a2', 'base2'),
- Revision('b2', 'a2'),
- Revision('c2', 'b2'),
- Revision('d2', 'c2'),
-
- Revision('base3', (), branch_labels='b_3'),
- Revision('a3', 'base3'),
- Revision('b3', 'a3'),
- Revision('c3', 'b3', dependencies='b2'),
- Revision('d3', 'c3'),
+ Revision("base1", (), branch_labels="b_1"),
+ Revision("a1", "base1"),
+ Revision("b1", "a1"),
+ Revision("c1", "b1"),
+ Revision("base2", (), dependencies="b_1", branch_labels="b_2"),
+ Revision("a2", "base2"),
+ Revision("b2", "a2"),
+ Revision("c2", "b2"),
+ Revision("d2", "c2"),
+ Revision("base3", (), branch_labels="b_3"),
+ Revision("a3", "base3"),
+ Revision("b3", "a3"),
+ Revision("c3", "b3", dependencies="b2"),
+ Revision("d3", "c3"),
]
)
@@ -976,55 +1039,68 @@ class MultipleBaseCrossDependencyTestTwo(DownIterateTest):
def test_heads_to_base(self):
self._assert_iteration(
- "heads", "base",
+ "heads",
+ "base",
[
- 'c1', 'b1', 'a1',
- 'd2', 'c2',
- 'd3', 'c3', 'b3', 'a3', 'base3',
- 'b2', 'a2', 'base2',
- 'base1'
- ]
+ "c1",
+ "b1",
+ "a1",
+ "d2",
+ "c2",
+ "d3",
+ "c3",
+ "b3",
+ "a3",
+ "base3",
+ "b2",
+ "a2",
+ "base2",
+ "base1",
+ ],
)
def test_we_need_head2(self):
self._assert_iteration(
- "b_2@head", "base",
- ['d2', 'c2', 'b2', 'a2', 'base2', 'base1']
+ "b_2@head", "base", ["d2", "c2", "b2", "a2", "base2", "base1"]
)
def test_we_need_head3(self):
self._assert_iteration(
- "b_3@head", "base",
- ['d3', 'c3', 'b3', 'a3', 'base3', 'b2', 'a2', 'base2', 'base1']
+ "b_3@head",
+ "base",
+ ["d3", "c3", "b3", "a3", "base3", "b2", "a2", "base2", "base1"],
)
def test_we_need_head1(self):
- self._assert_iteration(
- "b_1@head", "base",
- ['c1', 'b1', 'a1', 'base1']
- )
+ self._assert_iteration("b_1@head", "base", ["c1", "b1", "a1", "base1"])
def test_we_need_base1(self):
self._assert_iteration(
- "heads", "b_1@base",
+ "heads",
+ "b_1@base",
[
- 'c1', 'b1', 'a1',
- 'd2', 'c2',
- 'd3', 'c3', 'b2', 'a2', 'base2',
- 'base1'
- ]
+ "c1",
+ "b1",
+ "a1",
+ "d2",
+ "c2",
+ "d3",
+ "c3",
+ "b2",
+ "a2",
+ "base2",
+ "base1",
+ ],
)
def test_we_need_base2(self):
self._assert_iteration(
- "heads", "b_2@base",
- ['d2', 'c2', 'd3', 'c3', 'b2', 'a2', 'base2']
+ "heads", "b_2@base", ["d2", "c2", "d3", "c3", "b2", "a2", "base2"]
)
def test_we_need_base3(self):
self._assert_iteration(
- "heads", "b_3@base",
- ['d3', 'c3', 'b3', 'a3', 'base3']
+ "heads", "b_3@base", ["d3", "c3", "b3", "a3", "base3"]
)
@@ -1035,24 +1111,21 @@ class LargeMapTest(DownIterateTest):
def test_all(self):
raw = [r for r in self.map._revision_map.values() if r is not None]
- revs = [
- rev for rev in
- self.map.iterate_revisions(
- "heads", "base"
- )
- ]
+ revs = [rev for rev in self.map.iterate_revisions("heads", "base")]
eq_(set(raw), set(revs))
for idx, rev in enumerate(revs):
- ancestors = set(
- self.map._get_ancestor_nodes([rev])).difference([rev])
+ ancestors = set(self.map._get_ancestor_nodes([rev])).difference(
+ [rev]
+ )
descendants = set(
- self.map._get_descendant_nodes([rev])).difference([rev])
+ self.map._get_descendant_nodes([rev])
+ ).difference([rev])
assert not ancestors.intersection(descendants)
- remaining = set(revs[idx + 1:])
+ remaining = set(revs[idx + 1 :])
if remaining:
assert remaining.intersection(ancestors)
@@ -1061,22 +1134,20 @@ class DepResolutionFailedTest(DownIterateTest):
def setUp(self):
self.map = RevisionMap(
lambda: [
- Revision('base1', ()),
- Revision('a1', 'base1'),
- Revision('a2', 'base1'),
- Revision('b1', 'a1'),
- Revision('c1', 'b1'),
+ Revision("base1", ()),
+ Revision("a1", "base1"),
+ Revision("a2", "base1"),
+ Revision("b1", "a1"),
+ Revision("c1", "b1"),
]
)
# intentionally make a broken map
- self.map._revision_map['fake'] = self.map._revision_map['a2']
- self.map._revision_map['b1'].dependencies = 'fake'
- self.map._revision_map['b1']._resolved_dependencies = ('fake', )
+ self.map._revision_map["fake"] = self.map._revision_map["a2"]
+ self.map._revision_map["b1"].dependencies = "fake"
+ self.map._revision_map["b1"]._resolved_dependencies = ("fake",)
def test_failure_message(self):
iter_ = self.map.iterate_revisions("c1", "base1")
assert_raises_message(
- RevisionError,
- "Dependency resolution failed;",
- list, iter_
+ RevisionError, "Dependency resolution failed;", list, iter_
)
diff --git a/tests/test_script_consumption.py b/tests/test_script_consumption.py
index b394784..749b173 100644
--- a/tests/test_script_consumption.py
+++ b/tests/test_script_consumption.py
@@ -7,9 +7,16 @@ import textwrap
from alembic import command, util
from alembic.util import compat
from alembic.script import ScriptDirectory, Script
-from alembic.testing.env import clear_staging_env, staging_env, \
- _sqlite_testing_config, write_script, _sqlite_file_db, \
- three_rev_fixture, _no_sql_testing_config, env_file_fixture
+from alembic.testing.env import (
+ clear_staging_env,
+ staging_env,
+ _sqlite_testing_config,
+ write_script,
+ _sqlite_file_db,
+ three_rev_fixture,
+ _no_sql_testing_config,
+ env_file_fixture,
+)
from alembic.testing import eq_, assert_raises_message
from alembic.testing.fixtures import TestBase, capture_context_buffer
from alembic.environment import EnvironmentContext
@@ -18,7 +25,7 @@ from alembic.testing import mock
class ApplyVersionsFunctionalTest(TestBase):
- __only_on__ = 'sqlite'
+ __only_on__ = "sqlite"
sourceless = False
@@ -46,7 +53,10 @@ class ApplyVersionsFunctionalTest(TestBase):
script = ScriptDirectory.from_config(self.cfg)
script.generate_revision(a, None, refresh=True)
- write_script(script, a, """
+ write_script(
+ script,
+ a,
+ """
revision = '%s'
down_revision = None
@@ -60,10 +70,16 @@ class ApplyVersionsFunctionalTest(TestBase):
def downgrade():
op.execute("DROP TABLE foo")
- """ % a, sourceless=self.sourceless)
+ """
+ % a,
+ sourceless=self.sourceless,
+ )
script.generate_revision(b, None, refresh=True)
- write_script(script, b, """
+ write_script(
+ script,
+ b,
+ """
revision = '%s'
down_revision = '%s'
@@ -77,10 +93,16 @@ class ApplyVersionsFunctionalTest(TestBase):
def downgrade():
op.execute("DROP TABLE bar")
- """ % (b, a), sourceless=self.sourceless)
+ """
+ % (b, a),
+ sourceless=self.sourceless,
+ )
script.generate_revision(c, None, refresh=True)
- write_script(script, c, """
+ write_script(
+ script,
+ c,
+ """
revision = '%s'
down_revision = '%s'
@@ -94,49 +116,52 @@ class ApplyVersionsFunctionalTest(TestBase):
def downgrade():
op.execute("DROP TABLE bat")
- """ % (c, b), sourceless=self.sourceless)
+ """
+ % (c, b),
+ sourceless=self.sourceless,
+ )
def _test_002_upgrade(self):
command.upgrade(self.cfg, self.c)
db = self.bind
- assert db.dialect.has_table(db.connect(), 'foo')
- assert db.dialect.has_table(db.connect(), 'bar')
- assert db.dialect.has_table(db.connect(), 'bat')
+ assert db.dialect.has_table(db.connect(), "foo")
+ assert db.dialect.has_table(db.connect(), "bar")
+ assert db.dialect.has_table(db.connect(), "bat")
def _test_003_downgrade(self):
command.downgrade(self.cfg, self.a)
db = self.bind
- assert db.dialect.has_table(db.connect(), 'foo')
- assert not db.dialect.has_table(db.connect(), 'bar')
- assert not db.dialect.has_table(db.connect(), 'bat')
+ assert db.dialect.has_table(db.connect(), "foo")
+ assert not db.dialect.has_table(db.connect(), "bar")
+ assert not db.dialect.has_table(db.connect(), "bat")
def _test_004_downgrade(self):
- command.downgrade(self.cfg, 'base')
+ command.downgrade(self.cfg, "base")
db = self.bind
- assert not db.dialect.has_table(db.connect(), 'foo')
- assert not db.dialect.has_table(db.connect(), 'bar')
- assert not db.dialect.has_table(db.connect(), 'bat')
+ assert not db.dialect.has_table(db.connect(), "foo")
+ assert not db.dialect.has_table(db.connect(), "bar")
+ assert not db.dialect.has_table(db.connect(), "bat")
def _test_005_upgrade(self):
command.upgrade(self.cfg, self.b)
db = self.bind
- assert db.dialect.has_table(db.connect(), 'foo')
- assert db.dialect.has_table(db.connect(), 'bar')
- assert not db.dialect.has_table(db.connect(), 'bat')
+ assert db.dialect.has_table(db.connect(), "foo")
+ assert db.dialect.has_table(db.connect(), "bar")
+ assert not db.dialect.has_table(db.connect(), "bat")
def _test_006_upgrade_again(self):
command.upgrade(self.cfg, self.b)
db = self.bind
- assert db.dialect.has_table(db.connect(), 'foo')
- assert db.dialect.has_table(db.connect(), 'bar')
- assert not db.dialect.has_table(db.connect(), 'bat')
+ assert db.dialect.has_table(db.connect(), "foo")
+ assert db.dialect.has_table(db.connect(), "bar")
+ assert not db.dialect.has_table(db.connect(), "bat")
def _test_007_stamp_upgrade(self):
command.stamp(self.cfg, self.c)
db = self.bind
- assert db.dialect.has_table(db.connect(), 'foo')
- assert db.dialect.has_table(db.connect(), 'bar')
- assert not db.dialect.has_table(db.connect(), 'bat')
+ assert db.dialect.has_table(db.connect(), "foo")
+ assert db.dialect.has_table(db.connect(), "bar")
+ assert not db.dialect.has_table(db.connect(), "bat")
class SimpleSourcelessApplyVersionsTest(ApplyVersionsFunctionalTest):
@@ -144,25 +169,29 @@ class SimpleSourcelessApplyVersionsTest(ApplyVersionsFunctionalTest):
class NewFangledSourcelessEnvOnlyApplyVersionsTest(
- ApplyVersionsFunctionalTest):
+ ApplyVersionsFunctionalTest
+):
sourceless = "pep3147_envonly"
- __requires__ = "pep3147",
+ __requires__ = ("pep3147",)
class NewFangledSourcelessEverythingApplyVersionsTest(
- ApplyVersionsFunctionalTest):
+ ApplyVersionsFunctionalTest
+):
sourceless = "pep3147_everything"
- __requires__ = "pep3147",
+ __requires__ = ("pep3147",)
class CallbackEnvironmentTest(ApplyVersionsFunctionalTest):
- exp_kwargs = frozenset(('ctx', 'heads', 'run_args', 'step'))
+ exp_kwargs = frozenset(("ctx", "heads", "run_args", "step"))
@staticmethod
def _env_file_fixture():
- env_file_fixture(textwrap.dedent("""\
+ env_file_fixture(
+ textwrap.dedent(
+ """\
import alembic
from alembic import context
from sqlalchemy import engine_from_config, pool
@@ -199,13 +228,16 @@ class CallbackEnvironmentTest(ApplyVersionsFunctionalTest):
run_migrations_offline()
else:
run_migrations_online()
- """))
+ """
+ )
+ )
def test_steps(self):
import alembic
+
alembic.mock_event_listener = None
self._env_file_fixture()
- with mock.patch('alembic.mock_event_listener', mock.Mock()) as mymock:
+ with mock.patch("alembic.mock_event_listener", mock.Mock()) as mymock:
super(CallbackEnvironmentTest, self).test_steps()
calls = mymock.call_args_list
assert calls
@@ -213,27 +245,27 @@ class CallbackEnvironmentTest(ApplyVersionsFunctionalTest):
args, kw = call
assert not args
assert set(kw.keys()) >= self.exp_kwargs
- assert kw['run_args'] == {}
- assert hasattr(kw['ctx'], 'get_current_revision')
+ assert kw["run_args"] == {}
+ assert hasattr(kw["ctx"], "get_current_revision")
- step = kw['step']
+ step = kw["step"]
assert isinstance(step.is_upgrade, bool)
assert isinstance(step.is_stamp, bool)
assert isinstance(step.is_migration, bool)
assert isinstance(step.up_revision_id, compat.string_types)
assert isinstance(step.up_revision, Script)
- for revtype in 'up', 'down', 'source', 'destination':
- revs = getattr(step, '%s_revisions' % revtype)
+ for revtype in "up", "down", "source", "destination":
+ revs = getattr(step, "%s_revisions" % revtype)
assert isinstance(revs, tuple)
for rev in revs:
assert isinstance(rev, Script)
- revids = getattr(step, '%s_revision_ids' % revtype)
+ revids = getattr(step, "%s_revision_ids" % revtype)
for revid in revids:
assert isinstance(revid, compat.string_types)
- heads = kw['heads']
- assert hasattr(heads, '__iter__')
+ heads = kw["heads"]
+ assert hasattr(heads, "__iter__")
for h in heads:
assert h is None or isinstance(h, compat.string_types)
@@ -242,8 +274,8 @@ class OfflineTransactionalDDLTest(TestBase):
def setUp(self):
self.env = staging_env()
self.cfg = cfg = _no_sql_testing_config()
- cfg.set_main_option('dialect_name', 'sqlite')
- cfg.remove_main_option('url')
+ cfg.set_main_option("dialect_name", "sqlite")
+ cfg.remove_main_option("url")
self.a, self.b, self.c = three_rev_fixture(cfg)
@@ -254,11 +286,12 @@ class OfflineTransactionalDDLTest(TestBase):
with capture_context_buffer(transactional_ddl=True) as buf:
command.upgrade(self.cfg, self.c, sql=True)
assert re.match(
- (r"^BEGIN;\s+CREATE TABLE.*?%s.*" % self.a) +
- (r".*%s" % self.b) +
- (r".*%s.*?COMMIT;.*$" % self.c),
-
- buf.getvalue(), re.S)
+ (r"^BEGIN;\s+CREATE TABLE.*?%s.*" % self.a)
+ + (r".*%s" % self.b)
+ + (r".*%s.*?COMMIT;.*$" % self.c),
+ buf.getvalue(),
+ re.S,
+ )
def test_begin_commit_nontransactional_ddl(self):
with capture_context_buffer(transactional_ddl=False) as buf:
@@ -270,11 +303,12 @@ class OfflineTransactionalDDLTest(TestBase):
with capture_context_buffer(transaction_per_migration=True) as buf:
command.upgrade(self.cfg, self.c, sql=True)
assert re.match(
- (r"^BEGIN;\s+CREATE TABLE.*%s.*?COMMIT;.*" % self.a) +
- (r"BEGIN;.*?%s.*?COMMIT;.*" % self.b) +
- (r"BEGIN;.*?%s.*?COMMIT;.*$" % self.c),
-
- buf.getvalue(), re.S)
+ (r"^BEGIN;\s+CREATE TABLE.*%s.*?COMMIT;.*" % self.a)
+ + (r"BEGIN;.*?%s.*?COMMIT;.*" % self.b)
+ + (r"BEGIN;.*?%s.*?COMMIT;.*$" % self.c),
+ buf.getvalue(),
+ re.S,
+ )
class OnlineTransactionalDDLTest(TestBase):
@@ -290,7 +324,10 @@ class OnlineTransactionalDDLTest(TestBase):
b = util.rev_id()
c = util.rev_id()
script.generate_revision(a, "revision a", refresh=True)
- write_script(script, a, """
+ write_script(
+ script,
+ a,
+ """
"rev a"
revision = '%s'
@@ -302,9 +339,14 @@ def upgrade():
def downgrade():
pass
-""" % (a, ))
+"""
+ % (a,),
+ )
script.generate_revision(b, "revision b", refresh=True)
- write_script(script, b, """
+ write_script(
+ script,
+ b,
+ """
"rev b"
revision = '%s'
down_revision = '%s'
@@ -320,9 +362,14 @@ def upgrade():
def downgrade():
pass
-""" % (b, a))
+"""
+ % (b, a),
+ )
script.generate_revision(c, "revision c", refresh=True)
- write_script(script, c, """
+ write_script(
+ script,
+ c,
+ """
"rev c"
revision = '%s'
down_revision = '%s'
@@ -337,7 +384,9 @@ def upgrade():
def downgrade():
pass
-""" % (c, b))
+"""
+ % (c, b),
+ )
return a, b, c
@contextmanager
@@ -347,7 +396,8 @@ def downgrade():
def configure(*arg, **opt):
opt.update(
transactional_ddl=transactional_ddl,
- transaction_per_migration=transaction_per_migration)
+ transaction_per_migration=transaction_per_migration,
+ )
return conf(*arg, **opt)
with mock.patch.object(EnvironmentContext, "configure", configure):
@@ -357,39 +407,47 @@ def downgrade():
a, b, c = self._opened_transaction_fixture()
with self._patch_environment(
- transactional_ddl=False, transaction_per_migration=False):
+ transactional_ddl=False, transaction_per_migration=False
+ ):
assert_raises_message(
util.CommandError,
r'Migration "upgrade .*, rev b" has left an uncommitted '
- r'transaction opened; transactional_ddl is False so Alembic '
- r'is not committing transactions',
- command.upgrade, self.cfg, c
+ r"transaction opened; transactional_ddl is False so Alembic "
+ r"is not committing transactions",
+ command.upgrade,
+ self.cfg,
+ c,
)
def test_raise_when_rev_leaves_open_transaction_tpm(self):
a, b, c = self._opened_transaction_fixture()
with self._patch_environment(
- transactional_ddl=False, transaction_per_migration=True):
+ transactional_ddl=False, transaction_per_migration=True
+ ):
assert_raises_message(
util.CommandError,
r'Migration "upgrade .*, rev b" has left an uncommitted '
- r'transaction opened; transactional_ddl is False so Alembic '
- r'is not committing transactions',
- command.upgrade, self.cfg, c
+ r"transaction opened; transactional_ddl is False so Alembic "
+ r"is not committing transactions",
+ command.upgrade,
+ self.cfg,
+ c,
)
def test_noerr_rev_leaves_open_transaction_transactional_ddl(self):
a, b, c = self._opened_transaction_fixture()
with self._patch_environment(
- transactional_ddl=True, transaction_per_migration=False):
+ transactional_ddl=True, transaction_per_migration=False
+ ):
command.upgrade(self.cfg, c)
def test_noerr_transaction_opened_externally(self):
a, b, c = self._opened_transaction_fixture()
- env_file_fixture("""
+ env_file_fixture(
+ """
from sqlalchemy import engine_from_config, pool
def run_migrations_online():
@@ -411,22 +469,27 @@ def run_migrations_online():
run_migrations_online()
-""")
+"""
+ )
command.stamp(self.cfg, c)
class EncodingTest(TestBase):
-
def setUp(self):
self.env = staging_env()
self.cfg = cfg = _no_sql_testing_config()
- cfg.set_main_option('dialect_name', 'sqlite')
- cfg.remove_main_option('url')
+ cfg.set_main_option("dialect_name", "sqlite")
+ cfg.remove_main_option("url")
self.a = util.rev_id()
script = ScriptDirectory.from_config(cfg)
script.generate_revision(self.a, "revision a", refresh=True)
- write_script(script, self.a, (compat.u("""# coding: utf-8
+ write_script(
+ script,
+ self.a,
+ (
+ compat.u(
+ """# coding: utf-8
from __future__ import unicode_literals
revision = '%s'
down_revision = None
@@ -439,22 +502,25 @@ def upgrade():
def downgrade():
op.execute("drôle de petite voix m’a réveillé")
-""") % self.a), encoding='utf-8')
+"""
+ )
+ % self.a
+ ),
+ encoding="utf-8",
+ )
def tearDown(self):
clear_staging_env()
def test_encode(self):
with capture_context_buffer(
- bytes_io=True,
- output_encoding='utf-8'
+ bytes_io=True, output_encoding="utf-8"
) as buf:
command.upgrade(self.cfg, self.a, sql=True)
assert compat.u("« S’il vous plaît…").encode("utf-8") in buf.getvalue()
class VersionNameTemplateTest(TestBase):
-
def setUp(self):
self.env = staging_env()
self.cfg = _sqlite_testing_config()
@@ -467,7 +533,10 @@ class VersionNameTemplateTest(TestBase):
script = ScriptDirectory.from_config(self.cfg)
a = util.rev_id()
script.generate_revision(a, "some message", refresh=True)
- write_script(script, a, """
+ write_script(
+ script,
+ a,
+ """
revision = '%s'
down_revision = None
@@ -481,7 +550,9 @@ class VersionNameTemplateTest(TestBase):
def downgrade():
op.execute("DROP TABLE foo")
- """ % a)
+ """
+ % a,
+ )
script = ScriptDirectory.from_config(self.cfg)
rev = script.get_revision(a)
@@ -493,7 +564,10 @@ class VersionNameTemplateTest(TestBase):
script = ScriptDirectory.from_config(self.cfg)
a = util.rev_id()
script.generate_revision(a, None, refresh=True)
- write_script(script, a, """
+ write_script(
+ script,
+ a,
+ """
down_revision = None
from alembic import op
@@ -506,7 +580,8 @@ class VersionNameTemplateTest(TestBase):
def downgrade():
op.execute("DROP TABLE foo")
- """)
+ """,
+ )
script = ScriptDirectory.from_config(self.cfg)
rev = script.get_revision(a)
@@ -520,8 +595,9 @@ class VersionNameTemplateTest(TestBase):
script.generate_revision(a, "foobar", refresh=True)
path = script.get_revision(a).path
- with open(path, 'w') as fp:
- fp.write("""
+ with open(path, "w") as fp:
+ fp.write(
+ """
down_revision = None
from alembic import op
@@ -534,7 +610,8 @@ def upgrade():
def downgrade():
op.execute("DROP TABLE foo")
-""")
+"""
+ )
pyc_path = util.pyc_file_from_path(path)
if pyc_path is not None and os.access(pyc_path, os.F_OK):
os.unlink(pyc_path)
@@ -544,7 +621,10 @@ def downgrade():
"Could not determine revision id from filename foobar_%s.py. "
"Be sure the 'revision' variable is declared "
"inside the script." % a,
- Script._from_path, script, path)
+ Script._from_path,
+ script,
+ path,
+ )
class IgnoreFilesTest(TestBase):
@@ -563,13 +643,11 @@ class IgnoreFilesTest(TestBase):
command.revision(self.cfg, message="some rev")
script = ScriptDirectory.from_config(self.cfg)
path = os.path.join(script.versions, fname)
- with open(path, 'w') as f:
- f.write(
- "crap, crap -> crap"
- )
+ with open(path, "w") as f:
+ f.write("crap, crap -> crap")
command.revision(self.cfg, message="another rev")
- script.get_revision('head')
+ script.get_revision("head")
def _test_ignore_init_py(self, ext):
"""test that __init__.py is ignored."""
@@ -613,17 +691,16 @@ class SimpleSourcelessIgnoreFilesTest(IgnoreFilesTest):
class NewFangledEnvOnlySourcelessIgnoreFilesTest(IgnoreFilesTest):
sourceless = "pep3147_envonly"
- __requires__ = "pep3147",
+ __requires__ = ("pep3147",)
class NewFangledEverythingSourcelessIgnoreFilesTest(IgnoreFilesTest):
sourceless = "pep3147_everything"
- __requires__ = "pep3147",
+ __requires__ = ("pep3147",)
class SourcelessNeedsFlagTest(TestBase):
-
def setUp(self):
self.env = staging_env(sourceless=False)
self.cfg = _sqlite_testing_config()
@@ -636,7 +713,10 @@ class SourcelessNeedsFlagTest(TestBase):
script = ScriptDirectory.from_config(self.cfg)
script.generate_revision(a, None, refresh=True)
- write_script(script, a, """
+ write_script(
+ script,
+ a,
+ """
revision = '%s'
down_revision = None
@@ -650,7 +730,10 @@ class SourcelessNeedsFlagTest(TestBase):
def downgrade():
op.execute("DROP TABLE foo")
- """ % a, sourceless=True)
+ """
+ % a,
+ sourceless=True,
+ )
script = ScriptDirectory.from_config(self.cfg)
eq_(script.get_heads(), [])
diff --git a/tests/test_script_production.py b/tests/test_script_production.py
index af01a38..f7837d9 100644
--- a/tests/test_script_production.py
+++ b/tests/test_script_production.py
@@ -1,10 +1,20 @@
from alembic.testing.fixtures import TestBase
from alembic.testing import eq_, ne_, assert_raises_message, is_, assertions
-from alembic.testing.env import clear_staging_env, staging_env, \
- _get_staging_directory, _no_sql_testing_config, env_file_fixture, \
- script_file_fixture, _testing_config, _sqlite_testing_config, \
- three_rev_fixture, _multi_dir_testing_config, write_script,\
- _sqlite_file_db, _multidb_testing_config
+from alembic.testing.env import (
+ clear_staging_env,
+ staging_env,
+ _get_staging_directory,
+ _no_sql_testing_config,
+ env_file_fixture,
+ script_file_fixture,
+ _testing_config,
+ _sqlite_testing_config,
+ three_rev_fixture,
+ _multi_dir_testing_config,
+ write_script,
+ _sqlite_file_db,
+ _multidb_testing_config,
+)
from alembic import command
from alembic.script import ScriptDirectory
from alembic.environment import EnvironmentContext
@@ -24,7 +34,6 @@ env, abc, def_ = None, None, None
class GeneralOrderedTests(TestBase):
-
def setUp(self):
global env
env = staging_env()
@@ -43,11 +52,8 @@ class GeneralOrderedTests(TestBase):
self._test_008_long_name_configurable()
def _test_001_environment(self):
- assert_set = set(['env.py', 'script.py.mako', 'README'])
- eq_(
- assert_set.intersection(os.listdir(env.dir)),
- assert_set
- )
+ assert_set = set(["env.py", "script.py.mako", "README"])
+ eq_(assert_set.intersection(os.listdir(env.dir)), assert_set)
def _test_002_rev_ids(self):
global abc, def_
@@ -66,19 +72,23 @@ class GeneralOrderedTests(TestBase):
eq_(script.revision, abc)
eq_(script.down_revision, None)
assert os.access(
- os.path.join(env.dir, 'versions',
- '%s_this_is_a_message.py' % abc), os.F_OK)
+ os.path.join(env.dir, "versions", "%s_this_is_a_message.py" % abc),
+ os.F_OK,
+ )
assert callable(script.module.upgrade)
eq_(env.get_heads(), [abc])
eq_(env.get_base(), abc)
def _test_005_nextrev(self):
script = env.generate_revision(
- def_, "this is the next rev", refresh=True)
+ def_, "this is the next rev", refresh=True
+ )
assert os.access(
os.path.join(
- env.dir, 'versions',
- '%s_this_is_the_next_rev.py' % def_), os.F_OK)
+ env.dir, "versions", "%s_this_is_the_next_rev.py" % def_
+ ),
+ os.F_OK,
+ )
eq_(script.revision, def_)
eq_(script.down_revision, abc)
eq_(env.get_revision(abc).nextrev, set([def_]))
@@ -103,32 +113,42 @@ class GeneralOrderedTests(TestBase):
def _test_007_long_name(self):
rid = util.rev_id()
- env.generate_revision(rid,
- "this is a really long name with "
- "lots of characters and also "
- "I'd like it to\nhave\nnewlines")
+ env.generate_revision(
+ rid,
+ "this is a really long name with "
+ "lots of characters and also "
+ "I'd like it to\nhave\nnewlines",
+ )
assert os.access(
os.path.join(
- env.dir, 'versions',
- '%s_this_is_a_really_long_name_with_lots_of_.py' % rid),
- os.F_OK)
+ env.dir,
+ "versions",
+ "%s_this_is_a_really_long_name_with_lots_of_.py" % rid,
+ ),
+ os.F_OK,
+ )
def _test_008_long_name_configurable(self):
env.truncate_slug_length = 60
rid = util.rev_id()
- env.generate_revision(rid,
- "this is a really long name with "
- "lots of characters and also "
- "I'd like it to\nhave\nnewlines")
+ env.generate_revision(
+ rid,
+ "this is a really long name with "
+ "lots of characters and also "
+ "I'd like it to\nhave\nnewlines",
+ )
assert os.access(
- os.path.join(env.dir, 'versions',
- '%s_this_is_a_really_long_name_with_lots_'
- 'of_characters_and_also_.py' % rid),
- os.F_OK)
+ os.path.join(
+ env.dir,
+ "versions",
+ "%s_this_is_a_really_long_name_with_lots_"
+ "of_characters_and_also_.py" % rid,
+ ),
+ os.F_OK,
+ )
class ScriptNamingTest(TestBase):
-
@classmethod
def setup_class(cls):
_testing_config()
@@ -143,15 +163,17 @@ class ScriptNamingTest(TestBase):
file_template="%(rev)s_%(slug)s_"
"%(year)s_%(month)s_"
"%(day)s_%(hour)s_"
- "%(minute)s_%(second)s"
+ "%(minute)s_%(second)s",
)
create_date = datetime.datetime(2012, 7, 25, 15, 8, 5)
eq_(
script._rev_path(
- script.versions, "12345", "this is a message", create_date),
+ script.versions, "12345", "this is a message", create_date
+ ),
os.path.abspath(
"%s/versions/12345_this_is_a_"
- "message_2012_7_25_15_8_5.py" % _get_staging_directory())
+ "message_2012_7_25_15_8_5.py" % _get_staging_directory()
+ ),
)
def _test_tz(self, timezone_arg, given, expected):
@@ -161,61 +183,57 @@ class ScriptNamingTest(TestBase):
"%(year)s_%(month)s_"
"%(day)s_%(hour)s_"
"%(minute)s_%(second)s",
- timezone=timezone_arg
+ timezone=timezone_arg,
)
with mock.patch(
- "alembic.script.base.datetime",
- mock.Mock(
- datetime=mock.Mock(
- utcnow=lambda: given,
- now=lambda: given
- )
- )
+ "alembic.script.base.datetime",
+ mock.Mock(
+ datetime=mock.Mock(utcnow=lambda: given, now=lambda: given)
+ ),
):
create_date = script._generate_create_date()
- eq_(
- create_date,
- expected
- )
+ eq_(create_date, expected)
def test_custom_tz(self):
self._test_tz(
- 'EST5EDT',
+ "EST5EDT",
datetime.datetime(2012, 7, 25, 15, 8, 5),
datetime.datetime(
- 2012, 7, 25, 11, 8, 5, tzinfo=tz.gettz('EST5EDT'))
+ 2012, 7, 25, 11, 8, 5, tzinfo=tz.gettz("EST5EDT")
+ ),
)
def test_custom_tz_lowercase(self):
self._test_tz(
- 'est5edt',
+ "est5edt",
datetime.datetime(2012, 7, 25, 15, 8, 5),
datetime.datetime(
- 2012, 7, 25, 11, 8, 5, tzinfo=tz.gettz('EST5EDT'))
+ 2012, 7, 25, 11, 8, 5, tzinfo=tz.gettz("EST5EDT")
+ ),
)
def test_custom_tz_utc(self):
self._test_tz(
- 'utc',
+ "utc",
datetime.datetime(2012, 7, 25, 15, 8, 5),
- datetime.datetime(
- 2012, 7, 25, 15, 8, 5, tzinfo=tz.gettz('UTC'))
+ datetime.datetime(2012, 7, 25, 15, 8, 5, tzinfo=tz.gettz("UTC")),
)
def test_custom_tzdata_tz(self):
self._test_tz(
- 'Europe/Berlin',
+ "Europe/Berlin",
datetime.datetime(2012, 7, 25, 15, 8, 5),
datetime.datetime(
- 2012, 7, 25, 17, 8, 5, tzinfo=tz.gettz('Europe/Berlin'))
+ 2012, 7, 25, 17, 8, 5, tzinfo=tz.gettz("Europe/Berlin")
+ ),
)
def test_default_tz(self):
self._test_tz(
None,
datetime.datetime(2012, 7, 25, 15, 8, 5),
- datetime.datetime(2012, 7, 25, 15, 8, 5)
+ datetime.datetime(2012, 7, 25, 15, 8, 5),
)
def test_tz_cant_locate(self):
@@ -225,7 +243,7 @@ class ScriptNamingTest(TestBase):
self._test_tz,
"fake",
datetime.datetime(2012, 7, 25, 15, 8, 5),
- datetime.datetime(2012, 7, 25, 15, 8, 5)
+ datetime.datetime(2012, 7, 25, 15, 8, 5),
)
@@ -247,7 +265,8 @@ class RevisionCommandTest(TestBase):
def test_create_script_splice(self):
rev = command.revision(
- self.cfg, message="some message", head=self.b, splice=True)
+ self.cfg, message="some message", head=self.b, splice=True
+ )
script = ScriptDirectory.from_config(self.cfg)
rev = script.get_revision(rev.revision)
eq_(rev.down_revision, self.b)
@@ -260,7 +279,9 @@ class RevisionCommandTest(TestBase):
"Revision %s is not a head revision; please specify --splice "
"to create a new branch from this revision" % self.b,
command.revision,
- self.cfg, message="some message", head=self.b
+ self.cfg,
+ message="some message",
+ head=self.b,
)
def test_illegal_revision_chars(self):
@@ -269,19 +290,23 @@ class RevisionCommandTest(TestBase):
r"Character\(s\) '-' not allowed in "
"revision identifier 'no-dashes'",
command.revision,
- self.cfg, message="some message", rev_id="no-dashes"
+ self.cfg,
+ message="some message",
+ rev_id="no-dashes",
)
assert not os.path.exists(
- os.path.join(
- self.env.dir, "versions", "no-dashes_some_message.py"))
+ os.path.join(self.env.dir, "versions", "no-dashes_some_message.py")
+ )
assert_raises_message(
util.CommandError,
r"Character\(s\) '@' not allowed in "
"revision identifier 'no@atsigns'",
command.revision,
- self.cfg, message="some message", rev_id="no@atsigns"
+ self.cfg,
+ message="some message",
+ rev_id="no@atsigns",
)
assert_raises_message(
@@ -289,7 +314,9 @@ class RevisionCommandTest(TestBase):
r"Character\(s\) '-, @' not allowed in revision "
"identifier 'no@atsigns-ordashes'",
command.revision,
- self.cfg, message="some message", rev_id="no@atsigns-ordashes"
+ self.cfg,
+ message="some message",
+ rev_id="no@atsigns-ordashes",
)
assert_raises_message(
@@ -297,12 +324,15 @@ class RevisionCommandTest(TestBase):
r"Character\(s\) '\+' not allowed in revision "
r"identifier 'no\+plussignseither'",
command.revision,
- self.cfg, message="some message", rev_id="no+plussignseither"
+ self.cfg,
+ message="some message",
+ rev_id="no+plussignseither",
)
def test_create_script_branches(self):
rev = command.revision(
- self.cfg, message="some message", branch_label="foobar")
+ self.cfg, message="some message", branch_label="foobar"
+ )
script = ScriptDirectory.from_config(self.cfg)
rev = script.get_revision(rev.revision)
eq_(script.get_revision("foobar"), rev)
@@ -330,7 +360,9 @@ class RevisionCommandTest(TestBase):
"upgraded your script.py.mako to include the 'branch_labels' "
r"section\?",
command.revision,
- self.cfg, message="some message", branch_label="foobar"
+ self.cfg,
+ message="some message",
+ branch_label="foobar",
)
@@ -350,11 +382,17 @@ class CustomizeRevisionTest(TestBase):
(self.model3, "model3"),
]:
script.generate_revision(
- model, name, refresh=True,
+ model,
+ name,
+ refresh=True,
version_path=os.path.join(_get_staging_directory(), name),
- head="base")
+ head="base",
+ )
- write_script(script, model, """\
+ write_script(
+ script,
+ model,
+ """\
"%s"
revision = '%s'
down_revision = None
@@ -370,7 +408,9 @@ def upgrade():
def downgrade():
pass
-""" % (name, model, name))
+"""
+ % (name, model, name),
+ )
def tearDown(self):
clear_staging_env()
@@ -385,13 +425,13 @@ def downgrade():
context.configure(
connection=connection,
target_metadata=target_metadata,
- process_revision_directives=fn)
+ process_revision_directives=fn,
+ )
with context.begin_transaction():
context.run_migrations()
return mock.patch(
- "alembic.script.base.ScriptDirectory.run_env",
- run_env
+ "alembic.script.base.ScriptDirectory.run_env", run_env
)
def test_new_locations_no_autogen(self):
@@ -404,24 +444,27 @@ def downgrade():
ops.UpgradeOps(),
ops.DowngradeOps(),
version_path=os.path.join(
- _get_staging_directory(), "model1"),
- head="model1@head"
+ _get_staging_directory(), "model1"
+ ),
+ head="model1@head",
),
ops.MigrationScript(
util.rev_id(),
ops.UpgradeOps(),
ops.DowngradeOps(),
version_path=os.path.join(
- _get_staging_directory(), "model2"),
- head="model2@head"
+ _get_staging_directory(), "model2"
+ ),
+ head="model2@head",
),
ops.MigrationScript(
util.rev_id(),
ops.UpgradeOps(),
ops.DowngradeOps(),
version_path=os.path.join(
- _get_staging_directory(), "model3"),
- head="model3@head"
+ _get_staging_directory(), "model3"
+ ),
+ head="model3@head",
),
]
@@ -438,10 +481,13 @@ def downgrade():
rev_script = script.get_revision(rev.revision)
eq_(
rev_script.path,
- os.path.abspath(os.path.join(
- _get_staging_directory(), model,
- "%s_.py" % (rev_script.revision, )
- ))
+ os.path.abspath(
+ os.path.join(
+ _get_staging_directory(),
+ model,
+ "%s_.py" % (rev_script.revision,),
+ )
+ ),
)
assert os.path.exists(rev_script.path)
@@ -455,19 +501,23 @@ def downgrade():
with self._env_fixture(process_revision_directives, m):
rev = command.revision(
- self.cfg, message="some message", head="model1@head", sql=True)
+ self.cfg, message="some message", head="model1@head", sql=True
+ )
with mock.patch.object(rev.module, "op") as op_mock:
rev.module.upgrade()
eq_(
op_mock.mock_calls,
- [mock.call.create_index(
- 'some_index', 'some_table', ['a', 'b'], unique=False)]
+ [
+ mock.call.create_index(
+ "some_index", "some_table", ["a", "b"], unique=False
+ )
+ ],
)
def test_autogen(self):
m = sa.MetaData()
- sa.Table('t', m, sa.Column('x', sa.Integer))
+ sa.Table("t", m, sa.Column("x", sa.Integer))
def process_revision_directives(context, rev, generate_revisions):
existing_upgrades = generate_revisions[0].upgrade_ops
@@ -483,17 +533,19 @@ def downgrade():
existing_upgrades,
ops.DowngradeOps(),
version_path=os.path.join(
- _get_staging_directory(), "model1"),
- head="model1@head"
+ _get_staging_directory(), "model1"
+ ),
+ head="model1@head",
),
ops.MigrationScript(
util.rev_id(),
ops.UpgradeOps(ops=existing_downgrades.ops),
ops.DowngradeOps(),
version_path=os.path.join(
- _get_staging_directory(), "model2"),
- head="model2@head"
- )
+ _get_staging_directory(), "model2"
+ ),
+ head="model2@head",
+ ),
]
with self._env_fixture(process_revision_directives, m):
@@ -501,57 +553,57 @@ def downgrade():
eq_(
Inspector.from_engine(self.engine).get_table_names(),
- ["alembic_version"]
+ ["alembic_version"],
)
command.revision(
- self.cfg, message="some message",
- autogenerate=True)
+ self.cfg, message="some message", autogenerate=True
+ )
command.upgrade(self.cfg, "model1@head")
eq_(
Inspector.from_engine(self.engine).get_table_names(),
- ["alembic_version", "t"]
+ ["alembic_version", "t"],
)
command.upgrade(self.cfg, "model2@head")
eq_(
Inspector.from_engine(self.engine).get_table_names(),
- ["alembic_version"]
+ ["alembic_version"],
)
def test_programmatic_command_option(self):
-
def process_revision_directives(context, rev, generate_revisions):
generate_revisions[0].message = "test programatic"
generate_revisions[0].upgrade_ops = ops.UpgradeOps(
ops=[
ops.CreateTableOp(
- 'test_table',
+ "test_table",
[
- sa.Column('id', sa.Integer(), primary_key=True),
- sa.Column('name', sa.String(50), nullable=False)
- ]
- ),
+ sa.Column("id", sa.Integer(), primary_key=True),
+ sa.Column("name", sa.String(50), nullable=False),
+ ],
+ )
]
)
generate_revisions[0].downgrade_ops = ops.DowngradeOps(
- ops=[
- ops.DropTableOp('test_table')
- ]
+ ops=[ops.DropTableOp("test_table")]
)
with self._env_fixture(None, None):
rev = command.revision(
self.cfg,
head="model1@head",
- process_revision_directives=process_revision_directives)
+ process_revision_directives=process_revision_directives,
+ )
with open(rev.path) as handle:
result = handle.read()
- assert ("""
+ assert (
+ (
+ """
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('test_table',
@@ -560,22 +612,19 @@ def upgrade():
sa.PrimaryKeyConstraint('id')
)
# ### end Alembic commands ###
-""") in result
+"""
+ )
+ in result
+ )
class ScriptAccessorTest(TestBase):
def test_upgrade_downgrade_ops_list_accessors(self):
u1 = ops.UpgradeOps(ops=[])
d1 = ops.DowngradeOps(ops=[])
- m1 = ops.MigrationScript(
- "somerev", u1, d1
- )
- is_(
- m1.upgrade_ops, u1
- )
- is_(
- m1.downgrade_ops, d1
- )
+ m1 = ops.MigrationScript("somerev", u1, d1)
+ is_(m1.upgrade_ops, u1)
+ is_(m1.downgrade_ops, d1)
u2 = ops.UpgradeOps(ops=[])
d2 = ops.DowngradeOps(ops=[])
m1._upgrade_ops.append(u2)
@@ -585,13 +634,17 @@ class ScriptAccessorTest(TestBase):
ValueError,
"This MigrationScript instance has a multiple-entry list for "
"UpgradeOps; please use the upgrade_ops_list attribute.",
- getattr, m1, "upgrade_ops"
+ getattr,
+ m1,
+ "upgrade_ops",
)
assert_raises_message(
ValueError,
"This MigrationScript instance has a multiple-entry list for "
"DowngradeOps; please use the downgrade_ops_list attribute.",
- getattr, m1, "downgrade_ops"
+ getattr,
+ m1,
+ "downgrade_ops",
)
eq_(m1.upgrade_ops_list, [u1, u2])
eq_(m1.downgrade_ops_list, [d1, d2])
@@ -615,39 +668,36 @@ class ImportsTest(TestBase):
context.configure(
connection=connection,
target_metadata=target_metadata,
- **kw)
+ **kw
+ )
with context.begin_transaction():
context.run_migrations()
return mock.patch(
- "alembic.script.base.ScriptDirectory.run_env",
- run_env
+ "alembic.script.base.ScriptDirectory.run_env", run_env
)
def test_imports_in_script(self):
from sqlalchemy import MetaData, Table, Column
from sqlalchemy.dialects.mysql import VARCHAR
- type_ = VARCHAR(20, charset='utf8', national=True)
+ type_ = VARCHAR(20, charset="utf8", national=True)
m = MetaData()
- Table(
- 't', m,
- Column('x', type_)
- )
+ Table("t", m, Column("x", type_))
def process_revision_directives(context, rev, generate_revisions):
generate_revisions[0].imports.add(
- "from sqlalchemy.dialects.mysql import TINYINT")
+ "from sqlalchemy.dialects.mysql import TINYINT"
+ )
with self._env_fixture(
- m,
- process_revision_directives=process_revision_directives
+ m, process_revision_directives=process_revision_directives
):
rev = command.revision(
- self.cfg, message="some message",
- autogenerate=True)
+ self.cfg, message="some message", autogenerate=True
+ )
with open(rev.path) as file_:
contents = file_.read()
@@ -659,24 +709,24 @@ class MultiContextTest(TestBase):
"""test the multidb template for autogenerate front-to-back"""
def setUp(self):
- self.engine1 = _sqlite_file_db(tempname='eng1.db')
- self.engine2 = _sqlite_file_db(tempname='eng2.db')
- self.engine3 = _sqlite_file_db(tempname='eng3.db')
+ self.engine1 = _sqlite_file_db(tempname="eng1.db")
+ self.engine2 = _sqlite_file_db(tempname="eng2.db")
+ self.engine3 = _sqlite_file_db(tempname="eng3.db")
self.env = staging_env(template="multidb")
- self.cfg = _multidb_testing_config({
- "engine1": self.engine1,
- "engine2": self.engine2,
- "engine3": self.engine3
- })
+ self.cfg = _multidb_testing_config(
+ {
+ "engine1": self.engine1,
+ "engine2": self.engine2,
+ "engine3": self.engine3,
+ }
+ )
def _write_metadata(self, meta):
- path = os.path.join(_get_staging_directory(), 'scripts', 'env.py')
+ path = os.path.join(_get_staging_directory(), "scripts", "env.py")
with open(path) as env_:
existing_env = env_.read()
- existing_env = existing_env.replace(
- "target_metadata = {}",
- meta)
+ existing_env = existing_env.replace("target_metadata = {}", meta)
with open(path, "w") as env_:
env_.write(existing_env)
@@ -701,40 +751,30 @@ sa.Table('e3t1', m3, sa.Column('z', sa.Integer))
)
rev = command.revision(
- self.cfg, message="some message",
- autogenerate=True
+ self.cfg, message="some message", autogenerate=True
)
with mock.patch.object(rev.module, "op") as op_mock:
rev.module.upgrade_engine1()
eq_(
op_mock.mock_calls[-1],
- mock.call.create_table('e1t1', mock.ANY)
+ mock.call.create_table("e1t1", mock.ANY),
)
rev.module.upgrade_engine2()
eq_(
op_mock.mock_calls[-1],
- mock.call.create_table('e2t1', mock.ANY)
+ mock.call.create_table("e2t1", mock.ANY),
)
rev.module.upgrade_engine3()
eq_(
op_mock.mock_calls[-1],
- mock.call.create_table('e3t1', mock.ANY)
+ mock.call.create_table("e3t1", mock.ANY),
)
rev.module.downgrade_engine1()
- eq_(
- op_mock.mock_calls[-1],
- mock.call.drop_table('e1t1')
- )
+ eq_(op_mock.mock_calls[-1], mock.call.drop_table("e1t1"))
rev.module.downgrade_engine2()
- eq_(
- op_mock.mock_calls[-1],
- mock.call.drop_table('e2t1')
- )
+ eq_(op_mock.mock_calls[-1], mock.call.drop_table("e2t1"))
rev.module.downgrade_engine3()
- eq_(
- op_mock.mock_calls[-1],
- mock.call.drop_table('e3t1')
- )
+ eq_(op_mock.mock_calls[-1], mock.call.drop_table("e3t1"))
class RewriterTest(TestBase):
@@ -744,20 +784,13 @@ class RewriterTest(TestBase):
mocker = mock.Mock(side_effect=lambda context, revision, op: op)
writer.rewrites(ops.MigrateOperation)(mocker)
- addcolop = ops.AddColumnOp(
- 't1', sa.Column('x', sa.Integer())
- )
+ addcolop = ops.AddColumnOp("t1", sa.Column("x", sa.Integer()))
directives = [
ops.MigrationScript(
util.rev_id(),
- ops.UpgradeOps(ops=[
- ops.ModifyTableOps('t1', ops=[
- addcolop
- ])
- ]),
- ops.DowngradeOps(ops=[
- ]),
+ ops.UpgradeOps(ops=[ops.ModifyTableOps("t1", ops=[addcolop])]),
+ ops.DowngradeOps(ops=[]),
)
]
@@ -771,7 +804,7 @@ class RewriterTest(TestBase):
mock.call(ctx, rev, directives[0].upgrade_ops.ops[0]),
mock.call(ctx, rev, addcolop),
mock.call(ctx, rev, directives[0].downgrade_ops),
- ]
+ ],
)
def test_double_migrate_table(self):
@@ -783,28 +816,33 @@ class RewriterTest(TestBase):
def second_table(context, revision, op):
return [
op,
- ops.ModifyTableOps('t2', ops=[
- ops.AddColumnOp('t2', sa.Column('x', sa.Integer()))
- ])
+ ops.ModifyTableOps(
+ "t2",
+ ops=[ops.AddColumnOp("t2", sa.Column("x", sa.Integer()))],
+ ),
]
@writer.rewrites(ops.AddColumnOp)
def add_column(context, revision, op):
- idx_op = ops.CreateIndexOp('ixt', op.table_name, [op.column.name])
+ idx_op = ops.CreateIndexOp("ixt", op.table_name, [op.column.name])
idx_ops.append(idx_op)
- return [
- op,
- idx_op
- ]
+ return [op, idx_op]
directives = [
ops.MigrationScript(
util.rev_id(),
- ops.UpgradeOps(ops=[
- ops.ModifyTableOps('t1', ops=[
- ops.AddColumnOp('t1', sa.Column('x', sa.Integer()))
- ])
- ]),
+ ops.UpgradeOps(
+ ops=[
+ ops.ModifyTableOps(
+ "t1",
+ ops=[
+ ops.AddColumnOp(
+ "t1", sa.Column("x", sa.Integer())
+ )
+ ],
+ )
+ ]
+ ),
ops.DowngradeOps(ops=[]),
)
]
@@ -812,17 +850,10 @@ class RewriterTest(TestBase):
ctx, rev = mock.Mock(), mock.Mock()
writer(ctx, rev, directives)
eq_(
- [d.table_name for d in directives[0].upgrade_ops.ops],
- ['t1', 't2']
- )
- is_(
- directives[0].upgrade_ops.ops[0].ops[1],
- idx_ops[0]
- )
- is_(
- directives[0].upgrade_ops.ops[1].ops[1],
- idx_ops[1]
+ [d.table_name for d in directives[0].upgrade_ops.ops], ["t1", "t2"]
)
+ is_(directives[0].upgrade_ops.ops[0].ops[1], idx_ops[0])
+ is_(directives[0].upgrade_ops.ops[1].ops[1], idx_ops[1])
def test_chained_ops(self):
writer1 = autogenerate.Rewriter()
@@ -841,26 +872,32 @@ class RewriterTest(TestBase):
op.column.name,
modify_nullable=False,
existing_type=op.column.type,
- )
+ ),
]
@writer2.rewrites(ops.AddColumnOp)
def add_column_idx(context, revision, op):
- idx_op = ops.CreateIndexOp('ixt', op.table_name, [op.column.name])
- return [
- op,
- idx_op
- ]
+ idx_op = ops.CreateIndexOp("ixt", op.table_name, [op.column.name])
+ return [op, idx_op]
directives = [
ops.MigrationScript(
util.rev_id(),
- ops.UpgradeOps(ops=[
- ops.ModifyTableOps('t1', ops=[
- ops.AddColumnOp(
- 't1', sa.Column('x', sa.Integer(), nullable=False))
- ])
- ]),
+ ops.UpgradeOps(
+ ops=[
+ ops.ModifyTableOps(
+ "t1",
+ ops=[
+ ops.AddColumnOp(
+ "t1",
+ sa.Column(
+ "x", sa.Integer(), nullable=False
+ ),
+ )
+ ],
+ )
+ ]
+ ),
ops.DowngradeOps(ops=[]),
)
]
@@ -877,7 +914,7 @@ class RewriterTest(TestBase):
" op.alter_column('t1', 'x',\n"
" existing_type=sa.Integer(),\n"
" nullable=False)\n"
- " # ### end Alembic commands ###"
+ " # ### end Alembic commands ###",
)
@@ -894,7 +931,9 @@ class MultiDirRevisionCommandTest(TestBase):
util.CommandError,
"Multiple version locations present, please specify "
"--version-path",
- command.revision, self.cfg, message="some message"
+ command.revision,
+ self.cfg,
+ message="some message",
)
def test_multiple_dir_no_bases_invalid_version_path(self):
@@ -902,40 +941,46 @@ class MultiDirRevisionCommandTest(TestBase):
util.CommandError,
"Path foo/bar/ is not represented in current version locations",
command.revision,
- self.cfg, message="x",
- version_path=os.path.join("foo/bar/")
+ self.cfg,
+ message="x",
+ version_path=os.path.join("foo/bar/"),
)
def test_multiple_dir_no_bases_version_path(self):
script = command.revision(
- self.cfg, message="x",
- version_path=os.path.join(_get_staging_directory(), "model1"))
+ self.cfg,
+ message="x",
+ version_path=os.path.join(_get_staging_directory(), "model1"),
+ )
assert os.access(script.path, os.F_OK)
def test_multiple_dir_chooses_base(self):
command.revision(
- self.cfg, message="x",
+ self.cfg,
+ message="x",
head="base",
- version_path=os.path.join(_get_staging_directory(), "model1"))
+ version_path=os.path.join(_get_staging_directory(), "model1"),
+ )
script2 = command.revision(
- self.cfg, message="y",
+ self.cfg,
+ message="y",
head="base",
- version_path=os.path.join(_get_staging_directory(), "model2"))
+ version_path=os.path.join(_get_staging_directory(), "model2"),
+ )
script3 = command.revision(
- self.cfg, message="y2",
- head=script2.revision)
+ self.cfg, message="y2", head=script2.revision
+ )
eq_(
os.path.dirname(script3.path),
- os.path.abspath(os.path.join(_get_staging_directory(), "model2"))
+ os.path.abspath(os.path.join(_get_staging_directory(), "model2")),
)
assert os.access(script3.path, os.F_OK)
class TemplateArgsTest(TestBase):
-
def setUp(self):
staging_env()
self.cfg = _no_sql_testing_config(
@@ -949,43 +994,45 @@ class TemplateArgsTest(TestBase):
config = _no_sql_testing_config()
script = ScriptDirectory.from_config(config)
template_args = {"x": "x1", "y": "y1", "z": "z1"}
- env = EnvironmentContext(
- config,
- script,
- template_args=template_args
- )
- env.configure(dialect_name="sqlite",
- template_args={"y": "y2", "q": "q1"})
- eq_(
- template_args,
- {"x": "x1", "y": "y2", "z": "z1", "q": "q1"}
+ env = EnvironmentContext(config, script, template_args=template_args)
+ env.configure(
+ dialect_name="sqlite", template_args={"y": "y2", "q": "q1"}
)
+ eq_(template_args, {"x": "x1", "y": "y2", "z": "z1", "q": "q1"})
def test_tmpl_args_revision(self):
- env_file_fixture("""
+ env_file_fixture(
+ """
context.configure(dialect_name='sqlite', template_args={"somearg":"somevalue"})
-""")
- script_file_fixture("""
+"""
+ )
+ script_file_fixture(
+ """
# somearg: ${somearg}
revision = ${repr(up_revision)}
down_revision = ${repr(down_revision)}
-""")
+"""
+ )
command.revision(self.cfg, message="some rev")
script = ScriptDirectory.from_config(self.cfg)
- rev = script.get_revision('head')
+ rev = script.get_revision("head")
with open(rev.path) as f:
text = f.read()
assert "somearg: somevalue" in text
def test_bad_render(self):
- env_file_fixture("""
+ env_file_fixture(
+ """
context.configure(dialect_name='sqlite', template_args={"somearg":"somevalue"})
-""")
- script_file_fixture("""
+"""
+ )
+ script_file_fixture(
+ """
<% z = x + y %>
-""")
+"""
+ )
try:
command.revision(self.cfg, message="some rev")
@@ -993,7 +1040,7 @@ context.configure(dialect_name='sqlite', template_args={"somearg":"somevalue"})
m = re.match(
r"^Template rendering failed; see (.+?) "
"for a template-oriented",
- str(ce)
+ str(ce),
)
assert m, "Command error did not produce a file"
with open(m.group(1)) as handle:
@@ -1003,13 +1050,12 @@ context.configure(dialect_name='sqlite', template_args={"somearg":"somevalue"})
class DuplicateVersionLocationsTest(TestBase):
-
def setUp(self):
self.env = staging_env()
self.cfg = _multi_dir_testing_config(
# this is a duplicate of one of the paths
# already present in this fixture
- extra_version_location='%(here)s/model1'
+ extra_version_location="%(here)s/model1"
)
script = ScriptDirectory.from_config(self.cfg)
@@ -1022,10 +1068,16 @@ class DuplicateVersionLocationsTest(TestBase):
(self.model3, "model3"),
]:
script.generate_revision(
- model, name, refresh=True,
+ model,
+ name,
+ refresh=True,
version_path=os.path.join(_get_staging_directory(), name),
- head="base")
- write_script(script, model, """\
+ head="base",
+ )
+ write_script(
+ script,
+ model,
+ """\
"%s"
revision = '%s'
down_revision = None
@@ -1041,7 +1093,9 @@ def upgrade():
def downgrade():
pass
-""" % (name, model, name))
+"""
+ % (name, model, name),
+ )
def tearDown(self):
clear_staging_env()
@@ -1049,16 +1103,20 @@ def downgrade():
def test_env_emits_warning(self):
with assertions.expect_warnings(
"File %s loaded twice! ignoring. "
- "Please ensure version_locations is unique" % (
- os.path.realpath(os.path.join(
- _get_staging_directory(),
- "model1",
- "%s_model1.py" % self.model1
- )))
+ "Please ensure version_locations is unique"
+ % (
+ os.path.realpath(
+ os.path.join(
+ _get_staging_directory(),
+ "model1",
+ "%s_model1.py" % self.model1,
+ )
+ )
+ )
):
script = ScriptDirectory.from_config(self.cfg)
script.revision_map.heads
eq_(
[rev.revision for rev in script.walk_revisions()],
- [self.model1, self.model2, self.model3]
+ [self.model1, self.model2, self.model3],
)
diff --git a/tests/test_sqlite.py b/tests/test_sqlite.py
index 75972d4..6718be9 100644
--- a/tests/test_sqlite.py
+++ b/tests/test_sqlite.py
@@ -7,34 +7,29 @@ from alembic.testing.fixtures import TestBase
class SQLiteTest(TestBase):
-
def test_add_column(self):
- context = op_fixture('sqlite')
- op.add_column('t1', Column('c1', Integer))
- context.assert_(
- 'ALTER TABLE t1 ADD COLUMN c1 INTEGER'
- )
+ context = op_fixture("sqlite")
+ op.add_column("t1", Column("c1", Integer))
+ context.assert_("ALTER TABLE t1 ADD COLUMN c1 INTEGER")
def test_add_column_implicit_constraint(self):
- context = op_fixture('sqlite')
- op.add_column('t1', Column('c1', Boolean))
- context.assert_(
- 'ALTER TABLE t1 ADD COLUMN c1 BOOLEAN'
- )
+ context = op_fixture("sqlite")
+ op.add_column("t1", Column("c1", Boolean))
+ context.assert_("ALTER TABLE t1 ADD COLUMN c1 BOOLEAN")
def test_add_explicit_constraint(self):
- op_fixture('sqlite')
+ op_fixture("sqlite")
assert_raises_message(
NotImplementedError,
"No support for ALTER of constraints in SQLite dialect",
op.create_check_constraint,
"foo",
"sometable",
- column('name') > 5
+ column("name") > 5,
)
def test_drop_explicit_constraint(self):
- op_fixture('sqlite')
+ op_fixture("sqlite")
assert_raises_message(
NotImplementedError,
"No support for ALTER of constraints in SQLite dialect",
diff --git a/tests/test_version_table.py b/tests/test_version_table.py
index 29530c0..0a545cf 100644
--- a/tests/test_version_table.py
+++ b/tests/test_version_table.py
@@ -8,24 +8,22 @@ from alembic import migration
from alembic.util import CommandError
-version_table = Table('version_table', MetaData(),
- Column('version_num', String(32), nullable=False))
+version_table = Table(
+ "version_table",
+ MetaData(),
+ Column("version_num", String(32), nullable=False),
+)
def _up(from_, to_, branch_presence_changed=False):
- return migration.StampStep(
- from_, to_, True, branch_presence_changed
- )
+ return migration.StampStep(from_, to_, True, branch_presence_changed)
def _down(from_, to_, branch_presence_changed=False):
- return migration.StampStep(
- from_, to_, False, branch_presence_changed
- )
+ return migration.StampStep(from_, to_, False, branch_presence_changed)
class TestMigrationContext(TestBase):
-
@classmethod
def setup_class(cls):
cls.bind = config.db
@@ -48,112 +46,126 @@ class TestMigrationContext(TestBase):
if len(rows) == 0:
return None
eq_(len(rows), 1)
- return rows[0]['version_num']
+ return rows[0]["version_num"]
def test_config_default_version_table_name(self):
- context = self.make_one(dialect_name='sqlite')
- eq_(context._version.name, 'alembic_version')
+ context = self.make_one(dialect_name="sqlite")
+ eq_(context._version.name, "alembic_version")
def test_config_explicit_version_table_name(self):
- context = self.make_one(dialect_name='sqlite',
- opts={'version_table': 'explicit'})
- eq_(context._version.name, 'explicit')
- eq_(context._version.primary_key.name, 'explicit_pkc')
+ context = self.make_one(
+ dialect_name="sqlite", opts={"version_table": "explicit"}
+ )
+ eq_(context._version.name, "explicit")
+ eq_(context._version.primary_key.name, "explicit_pkc")
def test_config_explicit_version_table_schema(self):
- context = self.make_one(dialect_name='sqlite',
- opts={'version_table_schema': 'explicit'})
- eq_(context._version.schema, 'explicit')
+ context = self.make_one(
+ dialect_name="sqlite", opts={"version_table_schema": "explicit"}
+ )
+ eq_(context._version.schema, "explicit")
def test_config_explicit_no_pk(self):
- context = self.make_one(dialect_name='sqlite',
- opts={'version_table_pk': False})
+ context = self.make_one(
+ dialect_name="sqlite", opts={"version_table_pk": False}
+ )
eq_(len(context._version.primary_key), 0)
def test_config_explicit_w_pk(self):
- context = self.make_one(dialect_name='sqlite',
- opts={'version_table_pk': True})
+ context = self.make_one(
+ dialect_name="sqlite", opts={"version_table_pk": True}
+ )
eq_(len(context._version.primary_key), 1)
eq_(context._version.primary_key.name, "alembic_version_pkc")
def test_get_current_revision_doesnt_create_version_table(self):
- context = self.make_one(connection=self.connection,
- opts={'version_table': 'version_table'})
+ context = self.make_one(
+ connection=self.connection, opts={"version_table": "version_table"}
+ )
eq_(context.get_current_revision(), None)
insp = Inspector(self.connection)
- assert ('version_table' not in insp.get_table_names())
+ assert "version_table" not in insp.get_table_names()
def test_get_current_revision(self):
- context = self.make_one(connection=self.connection,
- opts={'version_table': 'version_table'})
+ context = self.make_one(
+ connection=self.connection, opts={"version_table": "version_table"}
+ )
version_table.create(self.connection)
eq_(context.get_current_revision(), None)
self.connection.execute(
- version_table.insert().values(version_num='revid'))
- eq_(context.get_current_revision(), 'revid')
+ version_table.insert().values(version_num="revid")
+ )
+ eq_(context.get_current_revision(), "revid")
def test_get_current_revision_error_if_starting_rev_given_online(self):
- context = self.make_one(connection=self.connection,
- opts={'starting_rev': 'boo'})
- assert_raises(
- CommandError,
- context.get_current_revision
+ context = self.make_one(
+ connection=self.connection, opts={"starting_rev": "boo"}
)
+ assert_raises(CommandError, context.get_current_revision)
def test_get_current_revision_offline(self):
- context = self.make_one(dialect_name='sqlite',
- opts={'starting_rev': 'startrev',
- 'as_sql': True})
- eq_(context.get_current_revision(), 'startrev')
+ context = self.make_one(
+ dialect_name="sqlite",
+ opts={"starting_rev": "startrev", "as_sql": True},
+ )
+ eq_(context.get_current_revision(), "startrev")
def test_get_current_revision_multiple_heads(self):
version_table.create(self.connection)
- context = self.make_one(connection=self.connection,
- opts={'version_table': 'version_table'})
+ context = self.make_one(
+ connection=self.connection, opts={"version_table": "version_table"}
+ )
updater = migration.HeadMaintainer(context, ())
- updater.update_to_step(_up(None, 'a', True))
- updater.update_to_step(_up(None, 'b', True))
+ updater.update_to_step(_up(None, "a", True))
+ updater.update_to_step(_up(None, "b", True))
assert_raises_message(
CommandError,
"Version table 'version_table' has more than one head present; "
"please use get_current_heads()",
- context.get_current_revision
+ context.get_current_revision,
)
def test_get_heads(self):
version_table.create(self.connection)
- context = self.make_one(connection=self.connection,
- opts={'version_table': 'version_table'})
+ context = self.make_one(
+ connection=self.connection, opts={"version_table": "version_table"}
+ )
updater = migration.HeadMaintainer(context, ())
- updater.update_to_step(_up(None, 'a', True))
- updater.update_to_step(_up(None, 'b', True))
- eq_(context.get_current_heads(), ('a', 'b'))
+ updater.update_to_step(_up(None, "a", True))
+ updater.update_to_step(_up(None, "b", True))
+ eq_(context.get_current_heads(), ("a", "b"))
def test_get_heads_offline(self):
version_table.create(self.connection)
- context = self.make_one(connection=self.connection,
- opts={
- 'starting_rev': 'q',
- 'version_table': 'version_table',
- 'as_sql': True})
- eq_(context.get_current_heads(), ('q', ))
+ context = self.make_one(
+ connection=self.connection,
+ opts={
+ "starting_rev": "q",
+ "version_table": "version_table",
+ "as_sql": True,
+ },
+ )
+ eq_(context.get_current_heads(), ("q",))
def test_stamp_api_creates_table(self):
context = self.make_one(connection=self.connection)
assert (
- 'alembic_version'
- not in Inspector(self.connection).get_table_names())
+ "alembic_version"
+ not in Inspector(self.connection).get_table_names()
+ )
- script = mock.Mock(_stamp_revs=lambda revision, heads: [
- _up(None, 'a', True),
- _up(None, 'b', True)
- ])
+ script = mock.Mock(
+ _stamp_revs=lambda revision, heads: [
+ _up(None, "a", True),
+ _up(None, "b", True),
+ ]
+ )
- context.stamp(script, 'b')
- eq_(context.get_current_heads(), ('a', 'b'))
+ context.stamp(script, "b")
+ eq_(context.get_current_heads(), ("a", "b"))
assert (
- 'alembic_version'
- in Inspector(self.connection).get_table_names())
+ "alembic_version" in Inspector(self.connection).get_table_names()
+ )
class UpdateRevTest(TestBase):
@@ -166,8 +178,8 @@ class UpdateRevTest(TestBase):
def setUp(self):
self.connection = self.bind.connect()
self.context = migration.MigrationContext.configure(
- connection=self.connection,
- opts={"version_table": "version_table"})
+ connection=self.connection, opts={"version_table": "version_table"}
+ )
version_table.create(self.connection)
self.updater = migration.HeadMaintainer(self.context, ())
@@ -180,105 +192,108 @@ class UpdateRevTest(TestBase):
eq_(self.updater.heads, set(heads))
def test_update_none_to_single(self):
- self.updater.update_to_step(_up(None, 'a', True))
- self._assert_heads(('a',))
+ self.updater.update_to_step(_up(None, "a", True))
+ self._assert_heads(("a",))
def test_update_single_to_single(self):
- self.updater.update_to_step(_up(None, 'a', True))
- self.updater.update_to_step(_up('a', 'b'))
- self._assert_heads(('b',))
+ self.updater.update_to_step(_up(None, "a", True))
+ self.updater.update_to_step(_up("a", "b"))
+ self._assert_heads(("b",))
def test_update_single_to_none(self):
- self.updater.update_to_step(_up(None, 'a', True))
- self.updater.update_to_step(_down('a', None, True))
+ self.updater.update_to_step(_up(None, "a", True))
+ self.updater.update_to_step(_down("a", None, True))
self._assert_heads(())
def test_add_branches(self):
- self.updater.update_to_step(_up(None, 'a', True))
- self.updater.update_to_step(_up('a', 'b'))
- self.updater.update_to_step(_up(None, 'c', True))
- self._assert_heads(('b', 'c'))
- self.updater.update_to_step(_up('c', 'd'))
- self.updater.update_to_step(_up('d', 'e1'))
- self.updater.update_to_step(_up('d', 'e2', True))
- self._assert_heads(('b', 'e1', 'e2'))
+ self.updater.update_to_step(_up(None, "a", True))
+ self.updater.update_to_step(_up("a", "b"))
+ self.updater.update_to_step(_up(None, "c", True))
+ self._assert_heads(("b", "c"))
+ self.updater.update_to_step(_up("c", "d"))
+ self.updater.update_to_step(_up("d", "e1"))
+ self.updater.update_to_step(_up("d", "e2", True))
+ self._assert_heads(("b", "e1", "e2"))
def test_teardown_branches(self):
- self.updater.update_to_step(_up(None, 'd1', True))
- self.updater.update_to_step(_up(None, 'd2', True))
- self._assert_heads(('d1', 'd2'))
+ self.updater.update_to_step(_up(None, "d1", True))
+ self.updater.update_to_step(_up(None, "d2", True))
+ self._assert_heads(("d1", "d2"))
- self.updater.update_to_step(_down('d1', 'c'))
- self._assert_heads(('c', 'd2'))
+ self.updater.update_to_step(_down("d1", "c"))
+ self._assert_heads(("c", "d2"))
- self.updater.update_to_step(_down('d2', 'c', True))
+ self.updater.update_to_step(_down("d2", "c", True))
- self._assert_heads(('c',))
- self.updater.update_to_step(_down('c', 'b'))
- self._assert_heads(('b',))
+ self._assert_heads(("c",))
+ self.updater.update_to_step(_down("c", "b"))
+ self._assert_heads(("b",))
def test_resolve_merges(self):
- self.updater.update_to_step(_up(None, 'a', True))
- self.updater.update_to_step(_up('a', 'b'))
- self.updater.update_to_step(_up('b', 'c1'))
- self.updater.update_to_step(_up('b', 'c2', True))
- self.updater.update_to_step(_up('c1', 'd1'))
- self.updater.update_to_step(_up('c2', 'd2'))
- self._assert_heads(('d1', 'd2'))
- self.updater.update_to_step(_up(('d1', 'd2'), 'e'))
- self._assert_heads(('e',))
+ self.updater.update_to_step(_up(None, "a", True))
+ self.updater.update_to_step(_up("a", "b"))
+ self.updater.update_to_step(_up("b", "c1"))
+ self.updater.update_to_step(_up("b", "c2", True))
+ self.updater.update_to_step(_up("c1", "d1"))
+ self.updater.update_to_step(_up("c2", "d2"))
+ self._assert_heads(("d1", "d2"))
+ self.updater.update_to_step(_up(("d1", "d2"), "e"))
+ self._assert_heads(("e",))
def test_unresolve_merges(self):
- self.updater.update_to_step(_up(None, 'e', True))
+ self.updater.update_to_step(_up(None, "e", True))
- self.updater.update_to_step(_down('e', ('d1', 'd2')))
- self._assert_heads(('d2', 'd1'))
+ self.updater.update_to_step(_down("e", ("d1", "d2")))
+ self._assert_heads(("d2", "d1"))
- self.updater.update_to_step(_down('d2', 'c2'))
- self._assert_heads(('c2', 'd1'))
+ self.updater.update_to_step(_down("d2", "c2"))
+ self._assert_heads(("c2", "d1"))
def test_update_no_match(self):
- self.updater.update_to_step(_up(None, 'a', True))
- self.updater.heads.add('x')
+ self.updater.update_to_step(_up(None, "a", True))
+ self.updater.heads.add("x")
assert_raises_message(
CommandError,
"Online migration expected to match one row when updating "
"'x' to 'b' in 'version_table'; 0 found",
- self.updater.update_to_step, _up('x', 'b')
+ self.updater.update_to_step,
+ _up("x", "b"),
)
def test_update_multi_match(self):
- self.connection.execute(version_table.insert(), version_num='a')
- self.connection.execute(version_table.insert(), version_num='a')
+ self.connection.execute(version_table.insert(), version_num="a")
+ self.connection.execute(version_table.insert(), version_num="a")
- self.updater.heads.add('a')
+ self.updater.heads.add("a")
assert_raises_message(
CommandError,
"Online migration expected to match one row when updating "
"'a' to 'b' in 'version_table'; 2 found",
- self.updater.update_to_step, _up('a', 'b')
+ self.updater.update_to_step,
+ _up("a", "b"),
)
def test_delete_no_match(self):
- self.updater.update_to_step(_up(None, 'a', True))
+ self.updater.update_to_step(_up(None, "a", True))
- self.updater.heads.add('x')
+ self.updater.heads.add("x")
assert_raises_message(
CommandError,
"Online migration expected to match one row when "
"deleting 'x' in 'version_table'; 0 found",
- self.updater.update_to_step, _down('x', None, True)
+ self.updater.update_to_step,
+ _down("x", None, True),
)
def test_delete_multi_match(self):
- self.connection.execute(version_table.insert(), version_num='a')
- self.connection.execute(version_table.insert(), version_num='a')
+ self.connection.execute(version_table.insert(), version_num="a")
+ self.connection.execute(version_table.insert(), version_num="a")
- self.updater.heads.add('a')
+ self.updater.heads.add("a")
assert_raises_message(
CommandError,
"Online migration expected to match one row when "
"deleting 'a' in 'version_table'; 2 found",
- self.updater.update_to_step, _down('a', None, True)
+ self.updater.update_to_step,
+ _down("a", None, True),
)
-
diff --git a/tests/test_version_traversal.py b/tests/test_version_traversal.py
index f69a9bd..c32c0c8 100644
--- a/tests/test_version_traversal.py
+++ b/tests/test_version_traversal.py
@@ -7,20 +7,15 @@ from alembic.migration import MigrationStep, HeadMaintainer
class MigrationTest(TestBase):
-
def up_(self, rev):
- return MigrationStep.upgrade_from_script(
- self.env.revision_map, rev)
+ return MigrationStep.upgrade_from_script(self.env.revision_map, rev)
def down_(self, rev):
- return MigrationStep.downgrade_from_script(
- self.env.revision_map, rev)
+ return MigrationStep.downgrade_from_script(self.env.revision_map, rev)
def _assert_downgrade(self, destination, source, expected, expected_heads):
revs = self.env._downgrade_revs(destination, source)
- eq_(
- revs, expected
- )
+ eq_(revs, expected)
heads = set(util.to_tuple(source, default=()))
head = HeadMaintainer(mock.Mock(), heads)
for rev in revs:
@@ -29,9 +24,7 @@ class MigrationTest(TestBase):
def _assert_upgrade(self, destination, source, expected, expected_heads):
revs = self.env._upgrade_revs(destination, source)
- eq_(
- revs, expected
- )
+ eq_(revs, expected)
heads = set(util.to_tuple(source, default=()))
head = HeadMaintainer(mock.Mock(), heads)
for rev in revs:
@@ -40,15 +33,14 @@ class MigrationTest(TestBase):
class RevisionPathTest(MigrationTest):
-
@classmethod
def setup_class(cls):
cls.env = env = staging_env()
- cls.a = env.generate_revision(util.rev_id(), '->a')
- cls.b = env.generate_revision(util.rev_id(), 'a->b')
- cls.c = env.generate_revision(util.rev_id(), 'b->c')
- cls.d = env.generate_revision(util.rev_id(), 'c->d')
- cls.e = env.generate_revision(util.rev_id(), 'd->e')
+ cls.a = env.generate_revision(util.rev_id(), "->a")
+ cls.b = env.generate_revision(util.rev_id(), "a->b")
+ cls.c = env.generate_revision(util.rev_id(), "b->c")
+ cls.d = env.generate_revision(util.rev_id(), "c->d")
+ cls.e = env.generate_revision(util.rev_id(), "d->e")
@classmethod
def teardown_class(cls):
@@ -56,58 +48,50 @@ class RevisionPathTest(MigrationTest):
def test_upgrade_path(self):
self._assert_upgrade(
- self.e.revision, self.c.revision,
- [
- self.up_(self.d),
- self.up_(self.e)
- ],
- set([self.e.revision])
+ self.e.revision,
+ self.c.revision,
+ [self.up_(self.d), self.up_(self.e)],
+ set([self.e.revision]),
)
self._assert_upgrade(
- self.c.revision, None,
- [
- self.up_(self.a),
- self.up_(self.b),
- self.up_(self.c),
- ],
- set([self.c.revision])
+ self.c.revision,
+ None,
+ [self.up_(self.a), self.up_(self.b), self.up_(self.c)],
+ set([self.c.revision]),
)
def test_relative_upgrade_path(self):
self._assert_upgrade(
- "+2", self.a.revision,
- [
- self.up_(self.b),
- self.up_(self.c),
- ],
- set([self.c.revision])
+ "+2",
+ self.a.revision,
+ [self.up_(self.b), self.up_(self.c)],
+ set([self.c.revision]),
)
self._assert_upgrade(
- "+1", self.a.revision,
- [
- self.up_(self.b)
- ],
- set([self.b.revision])
+ "+1", self.a.revision, [self.up_(self.b)], set([self.b.revision])
)
self._assert_upgrade(
- "+3", self.b.revision,
+ "+3",
+ self.b.revision,
[self.up_(self.c), self.up_(self.d), self.up_(self.e)],
- set([self.e.revision])
+ set([self.e.revision]),
)
self._assert_upgrade(
- "%s+2" % self.b.revision, self.a.revision,
+ "%s+2" % self.b.revision,
+ self.a.revision,
[self.up_(self.b), self.up_(self.c), self.up_(self.d)],
- set([self.d.revision])
+ set([self.d.revision]),
)
self._assert_upgrade(
- "%s-2" % self.d.revision, self.a.revision,
+ "%s-2" % self.d.revision,
+ self.a.revision,
[self.up_(self.b)],
- set([self.b.revision])
+ set([self.b.revision]),
)
def test_invalid_relative_upgrade_path(self):
@@ -115,53 +99,60 @@ class RevisionPathTest(MigrationTest):
assert_raises_message(
util.CommandError,
"Relative revision -2 didn't produce 2 migrations",
- self.env._upgrade_revs, "-2", self.b.revision
+ self.env._upgrade_revs,
+ "-2",
+ self.b.revision,
)
assert_raises_message(
util.CommandError,
r"Relative revision \+5 didn't produce 5 migrations",
- self.env._upgrade_revs, "+5", self.b.revision
+ self.env._upgrade_revs,
+ "+5",
+ self.b.revision,
)
def test_downgrade_path(self):
self._assert_downgrade(
- self.c.revision, self.e.revision,
+ self.c.revision,
+ self.e.revision,
[self.down_(self.e), self.down_(self.d)],
- set([self.c.revision])
+ set([self.c.revision]),
)
self._assert_downgrade(
- None, self.c.revision,
+ None,
+ self.c.revision,
[self.down_(self.c), self.down_(self.b), self.down_(self.a)],
- set()
+ set(),
)
def test_relative_downgrade_path(self):
self._assert_downgrade(
- "-1", self.c.revision,
- [self.down_(self.c)],
- set([self.b.revision])
+ "-1", self.c.revision, [self.down_(self.c)], set([self.b.revision])
)
self._assert_downgrade(
- "-3", self.e.revision,
+ "-3",
+ self.e.revision,
[self.down_(self.e), self.down_(self.d), self.down_(self.c)],
- set([self.b.revision])
+ set([self.b.revision]),
)
self._assert_downgrade(
- "%s+2" % self.a.revision, self.d.revision,
+ "%s+2" % self.a.revision,
+ self.d.revision,
[self.down_(self.d)],
- set([self.c.revision])
+ set([self.c.revision]),
)
self._assert_downgrade(
- "%s-2" % self.c.revision, self.d.revision,
+ "%s-2" % self.c.revision,
+ self.d.revision,
[self.down_(self.d), self.down_(self.c), self.down_(self.b)],
- set([self.a.revision])
+ set([self.a.revision]),
)
def test_invalid_relative_downgrade_path(self):
@@ -169,13 +160,17 @@ class RevisionPathTest(MigrationTest):
assert_raises_message(
util.CommandError,
"Relative revision -5 didn't produce 5 migrations",
- self.env._downgrade_revs, "-5", self.b.revision
+ self.env._downgrade_revs,
+ "-5",
+ self.b.revision,
)
assert_raises_message(
util.CommandError,
r"Relative revision \+2 didn't produce 2 migrations",
- self.env._downgrade_revs, "+2", self.b.revision
+ self.env._downgrade_revs,
+ "+2",
+ self.b.revision,
)
def test_invalid_move_rev_to_none(self):
@@ -184,7 +179,9 @@ class RevisionPathTest(MigrationTest):
util.CommandError,
r"Destination %s is not a valid downgrade "
r"target from current head\(s\)" % self.b.revision[0:3],
- self.env._downgrade_revs, self.b.revision[0:3], None
+ self.env._downgrade_revs,
+ self.b.revision[0:3],
+ None,
)
def test_invalid_move_higher_to_lower(self):
@@ -193,7 +190,9 @@ class RevisionPathTest(MigrationTest):
util.CommandError,
r"Destination %s is not a valid downgrade "
r"target from current head\(s\)" % self.c.revision[0:4],
- self.env._downgrade_revs, self.c.revision[0:4], self.b.revision
+ self.env._downgrade_revs,
+ self.c.revision[0:4],
+ self.b.revision,
)
def test_stamp_to_base(self):
@@ -204,26 +203,27 @@ class RevisionPathTest(MigrationTest):
class BranchedPathTest(MigrationTest):
-
@classmethod
def setup_class(cls):
cls.env = env = staging_env()
- cls.a = env.generate_revision(util.rev_id(), '->a')
- cls.b = env.generate_revision(util.rev_id(), 'a->b')
+ cls.a = env.generate_revision(util.rev_id(), "->a")
+ cls.b = env.generate_revision(util.rev_id(), "a->b")
cls.c1 = env.generate_revision(
- util.rev_id(), 'b->c1',
- branch_labels='c1branch',
- refresh=True)
- cls.d1 = env.generate_revision(util.rev_id(), 'c1->d1')
+ util.rev_id(), "b->c1", branch_labels="c1branch", refresh=True
+ )
+ cls.d1 = env.generate_revision(util.rev_id(), "c1->d1")
cls.c2 = env.generate_revision(
- util.rev_id(), 'b->c2',
- branch_labels='c2branch',
- head=cls.b.revision, splice=True)
+ util.rev_id(),
+ "b->c2",
+ branch_labels="c2branch",
+ head=cls.b.revision,
+ splice=True,
+ )
cls.d2 = env.generate_revision(
- util.rev_id(), 'c2->d2',
- head=cls.c2.revision)
+ util.rev_id(), "c2->d2", head=cls.c2.revision
+ )
@classmethod
def teardown_class(cls):
@@ -231,73 +231,87 @@ class BranchedPathTest(MigrationTest):
def test_stamp_down_across_multiple_branch_to_branchpoint(self):
heads = [self.d1.revision, self.c2.revision]
- revs = self.env._stamp_revs(
- self.b.revision, heads)
+ revs = self.env._stamp_revs(self.b.revision, heads)
eq_(len(revs), 1)
eq_(
revs[0].merge_branch_idents(heads),
# DELETE d1 revision, UPDATE c2 to b
- ([self.d1.revision], self.c2.revision, self.b.revision)
+ ([self.d1.revision], self.c2.revision, self.b.revision),
)
def test_stamp_to_labeled_base_multiple_heads(self):
revs = self.env._stamp_revs(
- "c1branch@base", [self.d1.revision, self.c2.revision])
+ "c1branch@base", [self.d1.revision, self.c2.revision]
+ )
eq_(len(revs), 1)
assert revs[0].should_delete_branch
eq_(revs[0].delete_version_num, self.d1.revision)
def test_stamp_to_labeled_head_multiple_heads(self):
heads = [self.d1.revision, self.c2.revision]
- revs = self.env._stamp_revs(
- "c2branch@head", heads)
+ revs = self.env._stamp_revs("c2branch@head", heads)
eq_(len(revs), 1)
eq_(
revs[0].merge_branch_idents(heads),
# the c1branch remains unchanged
- ([], self.c2.revision, self.d2.revision)
+ ([], self.c2.revision, self.d2.revision),
)
def test_upgrade_single_branch(self):
self._assert_upgrade(
- self.d1.revision, self.b.revision,
+ self.d1.revision,
+ self.b.revision,
[self.up_(self.c1), self.up_(self.d1)],
- set([self.d1.revision])
+ set([self.d1.revision]),
)
def test_upgrade_multiple_branch(self):
# move from a single head to multiple heads
self._assert_upgrade(
- (self.d1.revision, self.d2.revision), self.a.revision,
- [self.up_(self.b), self.up_(self.c2), self.up_(self.d2),
- self.up_(self.c1), self.up_(self.d1)],
- set([self.d1.revision, self.d2.revision])
+ (self.d1.revision, self.d2.revision),
+ self.a.revision,
+ [
+ self.up_(self.b),
+ self.up_(self.c2),
+ self.up_(self.d2),
+ self.up_(self.c1),
+ self.up_(self.d1),
+ ],
+ set([self.d1.revision, self.d2.revision]),
)
def test_downgrade_multiple_branch(self):
self._assert_downgrade(
- self.a.revision, (self.d1.revision, self.d2.revision),
- [self.down_(self.d1), self.down_(self.c1), self.down_(self.d2),
- self.down_(self.c2), self.down_(self.b)],
- set([self.a.revision])
+ self.a.revision,
+ (self.d1.revision, self.d2.revision),
+ [
+ self.down_(self.d1),
+ self.down_(self.c1),
+ self.down_(self.d2),
+ self.down_(self.c2),
+ self.down_(self.b),
+ ],
+ set([self.a.revision]),
)
def test_relative_upgrade(self):
self._assert_upgrade(
- "c2branch@head-1", self.b.revision,
+ "c2branch@head-1",
+ self.b.revision,
[self.up_(self.c2)],
- set([self.c2.revision])
+ set([self.c2.revision]),
)
def test_relative_downgrade(self):
self._assert_downgrade(
- "c2branch@base+2", [self.d2.revision, self.d1.revision],
+ "c2branch@base+2",
+ [self.d2.revision, self.d1.revision],
[self.down_(self.d2), self.down_(self.c2), self.down_(self.d1)],
- set([self.c1.revision])
+ set([self.c1.revision]),
)
@@ -311,43 +325,54 @@ class BranchFromMergepointTest(MigrationTest):
@classmethod
def setup_class(cls):
cls.env = env = staging_env()
- cls.a1 = env.generate_revision(util.rev_id(), '->a1')
- cls.b1 = env.generate_revision(util.rev_id(), 'a1->b1')
- cls.c1 = env.generate_revision(util.rev_id(), 'b1->c1')
+ cls.a1 = env.generate_revision(util.rev_id(), "->a1")
+ cls.b1 = env.generate_revision(util.rev_id(), "a1->b1")
+ cls.c1 = env.generate_revision(util.rev_id(), "b1->c1")
cls.a2 = env.generate_revision(
- util.rev_id(), '->a2', head=(),
- refresh=True)
+ util.rev_id(), "->a2", head=(), refresh=True
+ )
cls.b2 = env.generate_revision(
- util.rev_id(), 'a2->b2', head=cls.a2.revision)
+ util.rev_id(), "a2->b2", head=cls.a2.revision
+ )
cls.c2 = env.generate_revision(
- util.rev_id(), 'b2->c2', head=cls.b2.revision)
+ util.rev_id(), "b2->c2", head=cls.b2.revision
+ )
# mergepoint between c1, c2
# d1 dependent on c2
cls.d1 = env.generate_revision(
- util.rev_id(), 'd1', head=(cls.c1.revision, cls.c2.revision),
- refresh=True)
+ util.rev_id(),
+ "d1",
+ head=(cls.c1.revision, cls.c2.revision),
+ refresh=True,
+ )
# but then c2 keeps going into d2
cls.d2 = env.generate_revision(
- util.rev_id(), 'd2', head=cls.c2.revision,
- refresh=True, splice=True)
+ util.rev_id(),
+ "d2",
+ head=cls.c2.revision,
+ refresh=True,
+ splice=True,
+ )
def test_mergepoint_to_only_one_side_upgrade(self):
self._assert_upgrade(
- self.d1.revision, (self.d2.revision, self.b1.revision),
+ self.d1.revision,
+ (self.d2.revision, self.b1.revision),
[self.up_(self.c1), self.up_(self.d1)],
- set([self.d2.revision, self.d1.revision])
+ set([self.d2.revision, self.d1.revision]),
)
def test_mergepoint_to_only_one_side_downgrade(self):
self._assert_downgrade(
- self.b1.revision, (self.d2.revision, self.d1.revision),
+ self.b1.revision,
+ (self.d2.revision, self.d1.revision),
[self.down_(self.d1), self.down_(self.c1)],
- set([self.d2.revision, self.b1.revision])
+ set([self.d2.revision, self.b1.revision]),
)
@@ -361,42 +386,56 @@ class BranchFrom3WayMergepointTest(MigrationTest):
@classmethod
def setup_class(cls):
cls.env = env = staging_env()
- cls.a1 = env.generate_revision(util.rev_id(), '->a1')
- cls.b1 = env.generate_revision(util.rev_id(), 'a1->b1')
- cls.c1 = env.generate_revision(util.rev_id(), 'b1->c1')
+ cls.a1 = env.generate_revision(util.rev_id(), "->a1")
+ cls.b1 = env.generate_revision(util.rev_id(), "a1->b1")
+ cls.c1 = env.generate_revision(util.rev_id(), "b1->c1")
cls.a2 = env.generate_revision(
- util.rev_id(), '->a2', head=(),
- refresh=True)
+ util.rev_id(), "->a2", head=(), refresh=True
+ )
cls.b2 = env.generate_revision(
- util.rev_id(), 'a2->b2', head=cls.a2.revision)
+ util.rev_id(), "a2->b2", head=cls.a2.revision
+ )
cls.c2 = env.generate_revision(
- util.rev_id(), 'b2->c2', head=cls.b2.revision)
+ util.rev_id(), "b2->c2", head=cls.b2.revision
+ )
cls.a3 = env.generate_revision(
- util.rev_id(), '->a3', head=(),
- refresh=True)
+ util.rev_id(), "->a3", head=(), refresh=True
+ )
cls.b3 = env.generate_revision(
- util.rev_id(), 'a3->b3', head=cls.a3.revision)
+ util.rev_id(), "a3->b3", head=cls.a3.revision
+ )
cls.c3 = env.generate_revision(
- util.rev_id(), 'b3->c3', head=cls.b3.revision)
+ util.rev_id(), "b3->c3", head=cls.b3.revision
+ )
# mergepoint between c1, c2, c3
# d1 dependent on c2, c3
cls.d1 = env.generate_revision(
- util.rev_id(), 'd1', head=(
- cls.c1.revision, cls.c2.revision, cls.c3.revision),
- refresh=True)
+ util.rev_id(),
+ "d1",
+ head=(cls.c1.revision, cls.c2.revision, cls.c3.revision),
+ refresh=True,
+ )
# but then c2 keeps going into d2
cls.d2 = env.generate_revision(
- util.rev_id(), 'd2', head=cls.c2.revision,
- refresh=True, splice=True)
+ util.rev_id(),
+ "d2",
+ head=cls.c2.revision,
+ refresh=True,
+ splice=True,
+ )
# c3 keeps going into d3
cls.d3 = env.generate_revision(
- util.rev_id(), 'd3', head=cls.c3.revision,
- refresh=True, splice=True)
+ util.rev_id(),
+ "d3",
+ head=cls.c3.revision,
+ refresh=True,
+ splice=True,
+ )
def test_mergepoint_to_only_one_side_upgrade(self):
@@ -404,7 +443,7 @@ class BranchFrom3WayMergepointTest(MigrationTest):
self.d1.revision,
(self.d3.revision, self.d2.revision, self.b1.revision),
[self.up_(self.c1), self.up_(self.d1)],
- set([self.d3.revision, self.d2.revision, self.d1.revision])
+ set([self.d3.revision, self.d2.revision, self.d1.revision]),
)
def test_mergepoint_to_only_one_side_downgrade(self):
@@ -412,7 +451,7 @@ class BranchFrom3WayMergepointTest(MigrationTest):
self.b1.revision,
(self.d3.revision, self.d2.revision, self.d1.revision),
[self.down_(self.d1), self.down_(self.c1)],
- set([self.d3.revision, self.d2.revision, self.b1.revision])
+ set([self.d3.revision, self.d2.revision, self.b1.revision]),
)
def test_mergepoint_to_two_sides_upgrade(self):
@@ -422,14 +461,15 @@ class BranchFrom3WayMergepointTest(MigrationTest):
(self.d3.revision, self.b2.revision, self.b1.revision),
[self.up_(self.c2), self.up_(self.c1), self.up_(self.d1)],
# this will merge b2 and b1 into d1
- set([self.d3.revision, self.d1.revision])
+ set([self.d3.revision, self.d1.revision]),
)
# but then! b2 will break out again if we keep going with it
self._assert_upgrade(
- self.d2.revision, (self.d3.revision, self.d1.revision),
+ self.d2.revision,
+ (self.d3.revision, self.d1.revision),
[self.up_(self.d2)],
- set([self.d3.revision, self.d2.revision, self.d1.revision])
+ set([self.d3.revision, self.d2.revision, self.d1.revision]),
)
@@ -438,6 +478,7 @@ class TwinMergeTest(MigrationTest):
originating branches.
"""
+
@classmethod
def setup_class(cls):
"""
@@ -463,44 +504,43 @@ class TwinMergeTest(MigrationTest):
"""
cls.env = env = staging_env()
- cls.a = env.generate_revision(
- 'a', 'a'
+ cls.a = env.generate_revision("a", "a")
+ cls.b1 = env.generate_revision("b1", "b1", head=cls.a.revision)
+ cls.b2 = env.generate_revision(
+ "b2", "b2", splice=True, head=cls.a.revision
+ )
+ cls.b3 = env.generate_revision(
+ "b3", "b3", splice=True, head=cls.a.revision
)
- cls.b1 = env.generate_revision('b1', 'b1',
- head=cls.a.revision)
- cls.b2 = env.generate_revision('b2', 'b2',
- splice=True,
- head=cls.a.revision)
- cls.b3 = env.generate_revision('b3', 'b3',
- splice=True,
- head=cls.a.revision)
cls.c1 = env.generate_revision(
- 'c1', 'c1',
- head=(cls.b1.revision, cls.b2.revision, cls.b3.revision))
+ "c1",
+ "c1",
+ head=(cls.b1.revision, cls.b2.revision, cls.b3.revision),
+ )
cls.c2 = env.generate_revision(
- 'c2', 'c2',
+ "c2",
+ "c2",
splice=True,
- head=(cls.b1.revision, cls.b2.revision, cls.b3.revision))
+ head=(cls.b1.revision, cls.b2.revision, cls.b3.revision),
+ )
- cls.d1 = env.generate_revision(
- 'd1', 'd1', head=cls.c1.revision)
+ cls.d1 = env.generate_revision("d1", "d1", head=cls.c1.revision)
- cls.d2 = env.generate_revision(
- 'd2', 'd2', head=cls.c2.revision)
+ cls.d2 = env.generate_revision("d2", "d2", head=cls.c2.revision)
def test_upgrade(self):
head = HeadMaintainer(mock.Mock(), [self.a.revision])
steps = [
- (self.up_(self.b3), ('b3',)),
- (self.up_(self.b1), ('b1', 'b3',)),
- (self.up_(self.b2), ('b1', 'b2', 'b3',)),
- (self.up_(self.c2), ('c2',)),
- (self.up_(self.d2), ('d2',)),
- (self.up_(self.c1), ('c1', 'd2')),
- (self.up_(self.d1), ('d1', 'd2')),
+ (self.up_(self.b3), ("b3",)),
+ (self.up_(self.b1), ("b1", "b3")),
+ (self.up_(self.b2), ("b1", "b2", "b3")),
+ (self.up_(self.c2), ("c2",)),
+ (self.up_(self.d2), ("d2",)),
+ (self.up_(self.c1), ("c1", "d2")),
+ (self.up_(self.d1), ("d1", "d2")),
]
for step, assert_ in steps:
head.update_to_step(step)
@@ -511,6 +551,7 @@ class NotQuiteTwinMergeTest(MigrationTest):
"""Test a variant of #297.
"""
+
@classmethod
def setup_class(cls):
"""
@@ -527,32 +568,26 @@ class NotQuiteTwinMergeTest(MigrationTest):
"""
cls.env = env = staging_env()
- cls.a = env.generate_revision(
- 'a', 'a'
+ cls.a = env.generate_revision("a", "a")
+ cls.b1 = env.generate_revision("b1", "b1", head=cls.a.revision)
+ cls.b2 = env.generate_revision(
+ "b2", "b2", splice=True, head=cls.a.revision
+ )
+ cls.b3 = env.generate_revision(
+ "b3", "b3", splice=True, head=cls.a.revision
)
- cls.b1 = env.generate_revision('b1', 'b1',
- head=cls.a.revision)
- cls.b2 = env.generate_revision('b2', 'b2',
- splice=True,
- head=cls.a.revision)
- cls.b3 = env.generate_revision('b3', 'b3',
- splice=True,
- head=cls.a.revision)
cls.c1 = env.generate_revision(
- 'c1', 'c1',
- head=(cls.b1.revision, cls.b2.revision))
+ "c1", "c1", head=(cls.b1.revision, cls.b2.revision)
+ )
cls.c2 = env.generate_revision(
- 'c2', 'c2',
- splice=True,
- head=(cls.b2.revision, cls.b3.revision))
+ "c2", "c2", splice=True, head=(cls.b2.revision, cls.b3.revision)
+ )
- cls.d1 = env.generate_revision(
- 'd1', 'd1', head=cls.c1.revision)
+ cls.d1 = env.generate_revision("d1", "d1", head=cls.c1.revision)
- cls.d2 = env.generate_revision(
- 'd2', 'd2', head=cls.c2.revision)
+ cls.d2 = env.generate_revision("d2", "d2", head=cls.c2.revision)
def test_upgrade(self):
head = HeadMaintainer(mock.Mock(), [self.a.revision])
@@ -568,14 +603,13 @@ class NotQuiteTwinMergeTest(MigrationTest):
"""
steps = [
- (self.up_(self.b2), ('b2',)),
- (self.up_(self.b3), ('b2', 'b3',)),
- (self.up_(self.c2), ('c2',)),
- (self.up_(self.d2), ('d2',)),
-
- (self.up_(self.b1), ('b1', 'd2',)),
- (self.up_(self.c1), ('c1', 'd2')),
- (self.up_(self.d1), ('d1', 'd2')),
+ (self.up_(self.b2), ("b2",)),
+ (self.up_(self.b3), ("b2", "b3")),
+ (self.up_(self.c2), ("c2",)),
+ (self.up_(self.d2), ("d2",)),
+ (self.up_(self.b1), ("b1", "d2")),
+ (self.up_(self.c1), ("c1", "d2")),
+ (self.up_(self.d1), ("d1", "d2")),
]
for step, assert_ in steps:
head.update_to_step(step)
@@ -583,32 +617,35 @@ class NotQuiteTwinMergeTest(MigrationTest):
class DependsOnBranchTestOne(MigrationTest):
-
@classmethod
def setup_class(cls):
cls.env = env = staging_env()
cls.a1 = env.generate_revision(
- util.rev_id(), '->a1',
- branch_labels=['lib1'])
- cls.b1 = env.generate_revision(util.rev_id(), 'a1->b1')
- cls.c1 = env.generate_revision(util.rev_id(), 'b1->c1')
+ util.rev_id(), "->a1", branch_labels=["lib1"]
+ )
+ cls.b1 = env.generate_revision(util.rev_id(), "a1->b1")
+ cls.c1 = env.generate_revision(util.rev_id(), "b1->c1")
- cls.a2 = env.generate_revision(util.rev_id(), '->a2', head=())
+ cls.a2 = env.generate_revision(util.rev_id(), "->a2", head=())
cls.b2 = env.generate_revision(
- util.rev_id(), 'a2->b2', head=cls.a2.revision)
+ util.rev_id(), "a2->b2", head=cls.a2.revision
+ )
cls.c2 = env.generate_revision(
- util.rev_id(), 'b2->c2', head=cls.b2.revision,
- depends_on=cls.c1.revision)
+ util.rev_id(),
+ "b2->c2",
+ head=cls.b2.revision,
+ depends_on=cls.c1.revision,
+ )
cls.d1 = env.generate_revision(
- util.rev_id(), 'c1->d1',
- head=cls.c1.revision)
+ util.rev_id(), "c1->d1", head=cls.c1.revision
+ )
cls.e1 = env.generate_revision(
- util.rev_id(), 'd1->e1',
- head=cls.d1.revision)
+ util.rev_id(), "d1->e1", head=cls.d1.revision
+ )
cls.f1 = env.generate_revision(
- util.rev_id(), 'e1->f1',
- head=cls.e1.revision)
+ util.rev_id(), "e1->f1", head=cls.e1.revision
+ )
def test_downgrade_to_dependency(self):
heads = [self.c2.revision, self.d1.revision]
@@ -625,7 +662,6 @@ class DependsOnBranchTestOne(MigrationTest):
class DependsOnBranchTestTwo(MigrationTest):
-
@classmethod
def setup_class(cls):
"""
@@ -656,32 +692,36 @@ class DependsOnBranchTestTwo(MigrationTest):
"""
cls.env = env = staging_env()
- cls.a1 = env.generate_revision("a1", '->a1', head='base')
- cls.a2 = env.generate_revision("a2", '->a2', head='base')
- cls.a3 = env.generate_revision("a3", '->a3', head='base')
- cls.amerge = env.generate_revision("amerge", 'amerge', head=[
- cls.a1.revision, cls.a2.revision, cls.a3.revision
- ])
-
- cls.b1 = env.generate_revision("b1", '->b1', head='base')
- cls.b2 = env.generate_revision("b2", '->b2', head='base')
- cls.bmerge = env.generate_revision("bmerge", 'bmerge', head=[
- cls.b1.revision, cls.b2.revision
- ])
-
- cls.c1 = env.generate_revision("c1", '->c1', head='base')
- cls.c2 = env.generate_revision("c2", '->c2', head='base')
- cls.c3 = env.generate_revision("c3", '->c3', head='base')
- cls.cmerge = env.generate_revision("cmerge", 'cmerge', head=[
- cls.c1.revision, cls.c2.revision, cls.c3.revision
- ])
+ cls.a1 = env.generate_revision("a1", "->a1", head="base")
+ cls.a2 = env.generate_revision("a2", "->a2", head="base")
+ cls.a3 = env.generate_revision("a3", "->a3", head="base")
+ cls.amerge = env.generate_revision(
+ "amerge",
+ "amerge",
+ head=[cls.a1.revision, cls.a2.revision, cls.a3.revision],
+ )
+
+ cls.b1 = env.generate_revision("b1", "->b1", head="base")
+ cls.b2 = env.generate_revision("b2", "->b2", head="base")
+ cls.bmerge = env.generate_revision(
+ "bmerge", "bmerge", head=[cls.b1.revision, cls.b2.revision]
+ )
+
+ cls.c1 = env.generate_revision("c1", "->c1", head="base")
+ cls.c2 = env.generate_revision("c2", "->c2", head="base")
+ cls.c3 = env.generate_revision("c3", "->c3", head="base")
+ cls.cmerge = env.generate_revision(
+ "cmerge",
+ "cmerge",
+ head=[cls.c1.revision, cls.c2.revision, cls.c3.revision],
+ )
cls.d1 = env.generate_revision(
- "d1", 'o',
+ "d1",
+ "o",
head="base",
- depends_on=[
- cls.a3.revision, cls.b2.revision, cls.c1.revision
- ])
+ depends_on=[cls.a3.revision, cls.b2.revision, cls.c1.revision],
+ )
def test_kaboom(self):
# here's the upgrade path:
@@ -690,55 +730,77 @@ class DependsOnBranchTestTwo(MigrationTest):
heads = [
self.amerge.revision,
- self.bmerge.revision, self.cmerge.revision,
- self.d1.revision
+ self.bmerge.revision,
+ self.cmerge.revision,
+ self.d1.revision,
]
self._assert_downgrade(
- self.b2.revision, heads,
+ self.b2.revision,
+ heads,
[self.down_(self.bmerge)],
- set([
- self.amerge.revision,
- self.b1.revision, self.cmerge.revision, self.d1.revision])
+ set(
+ [
+ self.amerge.revision,
+ self.b1.revision,
+ self.cmerge.revision,
+ self.d1.revision,
+ ]
+ ),
)
# start with those heads..
heads = [
- self.amerge.revision, self.d1.revision,
- self.b1.revision, self.cmerge.revision]
+ self.amerge.revision,
+ self.d1.revision,
+ self.b1.revision,
+ self.cmerge.revision,
+ ]
# downgrade d1...
self._assert_downgrade(
- "d1@base", heads,
+ "d1@base",
+ heads,
[self.down_(self.d1)],
-
# b2 has to be INSERTed, because it was implied by d1
- set([
- self.amerge.revision, self.b1.revision,
- self.b2.revision, self.cmerge.revision])
+ set(
+ [
+ self.amerge.revision,
+ self.b1.revision,
+ self.b2.revision,
+ self.cmerge.revision,
+ ]
+ ),
)
# start with those heads ...
heads = [
- self.amerge.revision, self.b1.revision,
- self.b2.revision, self.cmerge.revision
+ self.amerge.revision,
+ self.b1.revision,
+ self.b2.revision,
+ self.cmerge.revision,
]
self._assert_downgrade(
- "base", heads,
+ "base",
+ heads,
[
- self.down_(self.amerge), self.down_(self.a1),
- self.down_(self.a2), self.down_(self.a3),
- self.down_(self.b1), self.down_(self.b2),
- self.down_(self.cmerge), self.down_(self.c1),
- self.down_(self.c2), self.down_(self.c3)
+ self.down_(self.amerge),
+ self.down_(self.a1),
+ self.down_(self.a2),
+ self.down_(self.a3),
+ self.down_(self.b1),
+ self.down_(self.b2),
+ self.down_(self.cmerge),
+ self.down_(self.c1),
+ self.down_(self.c2),
+ self.down_(self.c3),
],
- set([])
+ set([]),
)
class DependsOnBranchTestThree(MigrationTest):
-
@classmethod
def setup_class(cls):
"""
@@ -755,14 +817,18 @@ class DependsOnBranchTestThree(MigrationTest):
"""
cls.env = env = staging_env()
- cls.a1 = env.generate_revision("a1", '->a1', head='base')
- cls.a2 = env.generate_revision("a2", '->a2')
+ cls.a1 = env.generate_revision("a1", "->a1", head="base")
+ cls.a2 = env.generate_revision("a2", "->a2")
- cls.b1 = env.generate_revision("b1", '->b1', head='base')
- cls.b2 = env.generate_revision("b2", '->b2', depends_on='a2', head='b1')
- cls.b3 = env.generate_revision("b3", '->b3', head='b2')
+ cls.b1 = env.generate_revision("b1", "->b1", head="base")
+ cls.b2 = env.generate_revision(
+ "b2", "->b2", depends_on="a2", head="b1"
+ )
+ cls.b3 = env.generate_revision("b3", "->b3", head="b2")
- cls.a3 = env.generate_revision("a3", '->a3', head='a2', depends_on='b1')
+ cls.a3 = env.generate_revision(
+ "a3", "->a3", head="a2", depends_on="b1"
+ )
def test_downgrade_over_crisscross(self):
# this state was not possible prior to
@@ -772,9 +838,10 @@ class DependsOnBranchTestThree(MigrationTest):
# b2 because a2 is dependent on it, hence we add the ability
# to remove half of a merge point.
self._assert_downgrade(
- 'b1', ['a3', 'b2'],
+ "b1",
+ ["a3", "b2"],
[self.down_(self.b2)],
- set(['a3']) # we have b1 also, which is implied by a3
+ set(["a3"]), # we have b1 also, which is implied by a3
)
@@ -783,33 +850,35 @@ class DependsOnBranchLabelTest(MigrationTest):
def setup_class(cls):
cls.env = env = staging_env()
cls.a1 = env.generate_revision(
- util.rev_id(), '->a1',
- branch_labels=['lib1'])
- cls.b1 = env.generate_revision(util.rev_id(), 'a1->b1')
+ util.rev_id(), "->a1", branch_labels=["lib1"]
+ )
+ cls.b1 = env.generate_revision(util.rev_id(), "a1->b1")
cls.c1 = env.generate_revision(
- util.rev_id(), 'b1->c1',
- branch_labels=['c1lib'])
+ util.rev_id(), "b1->c1", branch_labels=["c1lib"]
+ )
- cls.a2 = env.generate_revision(util.rev_id(), '->a2', head=())
+ cls.a2 = env.generate_revision(util.rev_id(), "->a2", head=())
cls.b2 = env.generate_revision(
- util.rev_id(), 'a2->b2', head=cls.a2.revision)
+ util.rev_id(), "a2->b2", head=cls.a2.revision
+ )
cls.c2 = env.generate_revision(
- util.rev_id(), 'b2->c2', head=cls.b2.revision,
- depends_on=['c1lib'])
+ util.rev_id(), "b2->c2", head=cls.b2.revision, depends_on=["c1lib"]
+ )
cls.d1 = env.generate_revision(
- util.rev_id(), 'c1->d1',
- head=cls.c1.revision)
+ util.rev_id(), "c1->d1", head=cls.c1.revision
+ )
cls.e1 = env.generate_revision(
- util.rev_id(), 'd1->e1',
- head=cls.d1.revision)
+ util.rev_id(), "d1->e1", head=cls.d1.revision
+ )
cls.f1 = env.generate_revision(
- util.rev_id(), 'e1->f1',
- head=cls.e1.revision)
+ util.rev_id(), "e1->f1", head=cls.e1.revision
+ )
def test_upgrade_path(self):
self._assert_upgrade(
- self.c2.revision, self.a2.revision,
+ self.c2.revision,
+ self.a2.revision,
[
self.up_(self.a1),
self.up_(self.b1),
@@ -817,23 +886,23 @@ class DependsOnBranchLabelTest(MigrationTest):
self.up_(self.b2),
self.up_(self.c2),
],
- set([self.c2.revision])
+ set([self.c2.revision]),
)
class ForestTest(MigrationTest):
-
@classmethod
def setup_class(cls):
cls.env = env = staging_env()
- cls.a1 = env.generate_revision(util.rev_id(), '->a1')
- cls.b1 = env.generate_revision(util.rev_id(), 'a1->b1')
+ cls.a1 = env.generate_revision(util.rev_id(), "->a1")
+ cls.b1 = env.generate_revision(util.rev_id(), "a1->b1")
cls.a2 = env.generate_revision(
- util.rev_id(), '->a2', head=(),
- refresh=True)
+ util.rev_id(), "->a2", head=(), refresh=True
+ )
cls.b2 = env.generate_revision(
- util.rev_id(), 'a2->b2', head=cls.a2.revision)
+ util.rev_id(), "a2->b2", head=cls.a2.revision
+ )
@classmethod
def teardown_class(cls):
@@ -842,8 +911,12 @@ class ForestTest(MigrationTest):
def test_base_to_heads(self):
eq_(
self.env._upgrade_revs("heads", "base"),
- [self.up_(self.a2), self.up_(self.b2),
- self.up_(self.a1), self.up_(self.b1)]
+ [
+ self.up_(self.a2),
+ self.up_(self.b2),
+ self.up_(self.a1),
+ self.up_(self.b1),
+ ],
)
def test_stamp_to_heads(self):
@@ -851,40 +924,44 @@ class ForestTest(MigrationTest):
eq_(len(revs), 2)
eq_(
set(r.to_revisions for r in revs),
- set([(self.b1.revision,), (self.b2.revision,)])
+ set([(self.b1.revision,), (self.b2.revision,)]),
)
def test_stamp_to_heads_no_moves_needed(self):
revs = self.env._stamp_revs(
- "heads", (self.b1.revision, self.b2.revision))
+ "heads", (self.b1.revision, self.b2.revision)
+ )
eq_(len(revs), 0)
class MergedPathTest(MigrationTest):
-
@classmethod
def setup_class(cls):
cls.env = env = staging_env()
- cls.a = env.generate_revision(util.rev_id(), '->a')
- cls.b = env.generate_revision(util.rev_id(), 'a->b')
+ cls.a = env.generate_revision(util.rev_id(), "->a")
+ cls.b = env.generate_revision(util.rev_id(), "a->b")
- cls.c1 = env.generate_revision(util.rev_id(), 'b->c1')
- cls.d1 = env.generate_revision(util.rev_id(), 'c1->d1')
+ cls.c1 = env.generate_revision(util.rev_id(), "b->c1")
+ cls.d1 = env.generate_revision(util.rev_id(), "c1->d1")
cls.c2 = env.generate_revision(
- util.rev_id(), 'b->c2',
- branch_labels='c2branch',
- head=cls.b.revision, splice=True)
+ util.rev_id(),
+ "b->c2",
+ branch_labels="c2branch",
+ head=cls.b.revision,
+ splice=True,
+ )
cls.d2 = env.generate_revision(
- util.rev_id(), 'c2->d2',
- head=cls.c2.revision)
+ util.rev_id(), "c2->d2", head=cls.c2.revision
+ )
cls.e = env.generate_revision(
- util.rev_id(), 'merge d1 and d2',
- head=(cls.d1.revision, cls.d2.revision)
+ util.rev_id(),
+ "merge d1 and d2",
+ head=(cls.d1.revision, cls.d2.revision),
)
- cls.f = env.generate_revision(util.rev_id(), 'e->f')
+ cls.f = env.generate_revision(util.rev_id(), "e->f")
@classmethod
def teardown_class(cls):
@@ -897,7 +974,7 @@ class MergedPathTest(MigrationTest):
eq_(
revs[0].merge_branch_idents(heads),
# no deletes, UPDATE e to c2
- ([], self.e.revision, self.c2.revision)
+ ([], self.e.revision, self.c2.revision),
)
def test_stamp_down_across_merge_prior_branching(self):
@@ -907,7 +984,7 @@ class MergedPathTest(MigrationTest):
eq_(
revs[0].merge_branch_idents(heads),
# no deletes, UPDATE e to c2
- ([], self.e.revision, self.a.revision)
+ ([], self.e.revision, self.a.revision),
)
def test_stamp_up_across_merge_from_single_branch(self):
@@ -916,7 +993,7 @@ class MergedPathTest(MigrationTest):
eq_(
revs[0].merge_branch_idents([self.c2.revision]),
# no deletes, UPDATE e to c2
- ([], self.c2.revision, self.e.revision)
+ ([], self.c2.revision, self.e.revision),
)
def test_stamp_labled_head_across_merge_from_multiple_branch(self):
@@ -924,23 +1001,23 @@ class MergedPathTest(MigrationTest):
# d1 both in terms of "c2branch" as well as that the "head"
# revision "f" is the head of both d1 and d2
revs = self.env._stamp_revs(
- "c2branch@head", [self.d1.revision, self.c2.revision])
+ "c2branch@head", [self.d1.revision, self.c2.revision]
+ )
eq_(len(revs), 1)
eq_(
revs[0].merge_branch_idents([self.d1.revision, self.c2.revision]),
# DELETE d1 revision, UPDATE c2 to e
- ([self.d1.revision], self.c2.revision, self.f.revision)
+ ([self.d1.revision], self.c2.revision, self.f.revision),
)
def test_stamp_up_across_merge_from_multiple_branch(self):
heads = [self.d1.revision, self.c2.revision]
- revs = self.env._stamp_revs(
- self.e.revision, heads)
+ revs = self.env._stamp_revs(self.e.revision, heads)
eq_(len(revs), 1)
eq_(
revs[0].merge_branch_idents(heads),
# DELETE d1 revision, UPDATE c2 to e
- ([self.d1.revision], self.c2.revision, self.e.revision)
+ ([self.d1.revision], self.c2.revision, self.e.revision),
)
def test_stamp_up_across_merge_prior_branching(self):
@@ -950,7 +1027,7 @@ class MergedPathTest(MigrationTest):
eq_(
revs[0].merge_branch_idents(heads),
# no deletes, UPDATE e to c2
- ([], self.b.revision, self.e.revision)
+ ([], self.b.revision, self.e.revision),
)
def test_upgrade_across_merge_point(self):
@@ -963,9 +1040,9 @@ class MergedPathTest(MigrationTest):
self.up_(self.c1), # b->c1, create new branch
self.up_(self.d1),
self.up_(self.e), # d1/d2 -> e, merge branches
- # (DELETE d2, UPDATE d1->e)
- self.up_(self.f)
- ]
+ # (DELETE d2, UPDATE d1->e)
+ self.up_(self.f),
+ ],
)
def test_downgrade_across_merge_point(self):
@@ -975,10 +1052,10 @@ class MergedPathTest(MigrationTest):
[
self.down_(self.f),
self.down_(self.e), # e -> d1 and d2, unmerge branches
- # (UPDATE e->d1, INSERT d2)
+ # (UPDATE e->d1, INSERT d2)
self.down_(self.d1),
self.down_(self.c1),
self.down_(self.d2),
self.down_(self.c2), # c2->b, delete branch
- ]
+ ],
)
diff --git a/tox.ini b/tox.ini
index 8f3640d..660761a 100644
--- a/tox.ini
+++ b/tox.ini
@@ -55,16 +55,14 @@ commands=
{oracle}: python reap_oracle_dbs.py oracle_idents.txt
+# thanks to https://julien.danjou.info/the-best-flake8-extensions/
[testenv:pep8]
-deps=flake8
-commands = python -m flake8 {posargs}
-
-
-[flake8]
-
-show-source = True
-ignore = E711,E712,E721,D,N
-# F841,F811,F401
-exclude=.venv,.git,.tox,dist,doc,*egg,build
-
-
+deps=
+ flake8
+ flake8-import-order
+ flake8-builtins
+ flake8-docstrings
+ flake8-rst-docstrings
+ # used by flake8-rst-docstrings
+ pygments
+commands = flake8 ./alembic/ ./tests/ setup.py