summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--alembic/__init__.py2
-rw-r--r--alembic/autogenerate/api.py48
-rw-r--r--alembic/autogenerate/compare.py224
-rw-r--r--alembic/autogenerate/render.py170
-rw-r--r--alembic/command.py37
-rw-r--r--alembic/compat.py13
-rw-r--r--alembic/config.py63
-rw-r--r--alembic/ddl/base.py60
-rw-r--r--alembic/ddl/impl.py116
-rw-r--r--alembic/ddl/mssql.py94
-rw-r--r--alembic/ddl/mysql.py92
-rw-r--r--alembic/ddl/oracle.py16
-rw-r--r--alembic/ddl/postgresql.py11
-rw-r--r--alembic/ddl/sqlite.py30
-rw-r--r--alembic/environment.py65
-rw-r--r--alembic/migration.py91
-rw-r--r--alembic/operations.py150
-rw-r--r--alembic/script.py94
-rw-r--r--alembic/templates/generic/env.py15
-rw-r--r--alembic/templates/multidb/env.py22
-rw-r--r--alembic/templates/pylons/env.py8
-rw-r--r--alembic/util.py46
-rw-r--r--setup.py26
-rw-r--r--tests/__init__.py48
-rw-r--r--tests/test_autogen_indexes.py373
-rw-r--r--tests/test_autogen_render.py329
-rw-r--r--tests/test_autogenerate.py243
-rw-r--r--tests/test_bulk_insert.py79
-rw-r--r--tests/test_command.py1
-rw-r--r--tests/test_config.py3
-rw-r--r--tests/test_environment.py2
-rw-r--r--tests/test_mssql.py13
-rw-r--r--tests/test_mysql.py40
-rw-r--r--tests/test_offline_environment.py2
-rw-r--r--tests/test_op.py216
-rw-r--r--tests/test_op_naming_convention.py46
-rw-r--r--tests/test_oracle.py7
-rw-r--r--tests/test_postgresql.py28
-rw-r--r--tests/test_revision_create.py44
-rw-r--r--tests/test_revision_paths.py12
-rw-r--r--tests/test_sql_script.py24
-rw-r--r--tests/test_sqlite.py7
-rw-r--r--tests/test_version_table.py1
-rw-r--r--tests/test_versioning.py7
44 files changed, 1661 insertions, 1357 deletions
diff --git a/alembic/__init__.py b/alembic/__init__.py
index f2d3932..56a254d 100644
--- a/alembic/__init__.py
+++ b/alembic/__init__.py
@@ -7,5 +7,3 @@ package_dir = path.abspath(path.dirname(__file__))
from . import op
from . import context
-
-
diff --git a/alembic/autogenerate/api.py b/alembic/autogenerate/api.py
index 148e352..b13a57b 100644
--- a/alembic/autogenerate/api.py
+++ b/alembic/autogenerate/api.py
@@ -8,13 +8,15 @@ from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.util import OrderedSet
from .compare import _compare_tables
from .render import _drop_table, _drop_column, _drop_index, _drop_constraint, \
- _add_table, _add_column, _add_index, _add_constraint, _modify_col
+ _add_table, _add_column, _add_index, _add_constraint, _modify_col
from .. import util
log = logging.getLogger(__name__)
###################################################
# public
+
+
def compare_metadata(context, metadata):
"""Compare a database schema to that given in a
:class:`~sqlalchemy.schema.MetaData` instance.
@@ -113,10 +115,11 @@ def compare_metadata(context, metadata):
###################################################
# top level
+
def _produce_migration_diffs(context, template_args,
- imports, include_symbol=None,
- include_object=None,
- include_schemas=False):
+ imports, include_symbol=None,
+ include_object=None,
+ include_schemas=False):
opts = context.opts
metadata = opts['target_metadata']
include_schemas = opts.get('include_schemas', include_schemas)
@@ -125,20 +128,20 @@ def _produce_migration_diffs(context, template_args,
if metadata is None:
raise util.CommandError(
- "Can't proceed with --autogenerate option; environment "
- "script %s does not provide "
- "a MetaData object to the context." % (
- context.script.env_py_location
- ))
+ "Can't proceed with --autogenerate option; environment "
+ "script %s does not provide "
+ "a MetaData object to the context." % (
+ context.script.env_py_location
+ ))
autogen_context, connection = _autogen_context(context, imports)
diffs = []
_produce_net_changes(connection, metadata, diffs,
- autogen_context, object_filters, include_schemas)
+ autogen_context, object_filters, include_schemas)
template_args[opts['upgrade_token']] = \
- _indent(_produce_upgrade_commands(diffs, autogen_context))
+ _indent(_produce_upgrade_commands(diffs, autogen_context))
template_args[opts['downgrade_token']] = \
- _indent(_produce_downgrade_commands(diffs, autogen_context))
+ _indent(_produce_downgrade_commands(diffs, autogen_context))
template_args['imports'] = "\n".join(sorted(imports))
@@ -171,9 +174,10 @@ def _autogen_context(context, imports):
'opts': opts
}, connection
+
def _indent(text):
text = "### commands auto generated by Alembic - "\
- "please adjust! ###\n" + text
+ "please adjust! ###\n" + text
text += "\n### end Alembic commands ###"
text = re.compile(r'^', re.M).sub(" ", text).strip()
return text
@@ -183,8 +187,8 @@ def _indent(text):
def _produce_net_changes(connection, metadata, diffs, autogen_context,
- object_filters=(),
- include_schemas=False):
+ object_filters=(),
+ include_schemas=False):
inspector = Inspector.from_engine(connection)
# TODO: not hardcode alembic_version here ?
conn_table_names = set()
@@ -202,11 +206,11 @@ def _produce_net_changes(connection, metadata, diffs, autogen_context,
for s in schemas:
tables = set(inspector.get_table_names(schema=s)).\
- difference(['alembic_version'])
+ difference(['alembic_version'])
conn_table_names.update(zip([s] * len(tables), tables))
metadata_table_names = OrderedSet([(table.schema, table.name)
- for table in metadata.sorted_tables])
+ for table in metadata.sorted_tables])
_compare_tables(conn_table_names, metadata_table_names,
object_filters,
@@ -232,6 +236,7 @@ def _produce_upgrade_commands(diffs, autogen_context):
buf = ["pass"]
return "\n".join(buf)
+
def _produce_downgrade_commands(diffs, autogen_context):
buf = []
for diff in reversed(diffs):
@@ -240,12 +245,14 @@ def _produce_downgrade_commands(diffs, autogen_context):
buf = ["pass"]
return "\n".join(buf)
+
def _invoke_command(updown, args, autogen_context):
if isinstance(args, tuple):
return _invoke_adddrop_command(updown, args, autogen_context)
else:
return _invoke_modify_command(updown, args, autogen_context)
+
def _invoke_adddrop_command(updown, args, autogen_context):
cmd_type = args[0]
adddrop, cmd_type = cmd_type.split("_")
@@ -270,6 +277,7 @@ def _invoke_adddrop_command(updown, args, autogen_context):
else:
return cmd_callables[0](*cmd_args)
+
def _invoke_modify_command(updown, args, autogen_context):
sname, tname, cname = args[0][1:4]
kw = {}
@@ -281,9 +289,9 @@ def _invoke_modify_command(updown, args, autogen_context):
}
for diff in args:
diff_kw = diff[4]
- for arg in ("existing_type", \
- "existing_nullable", \
- "existing_server_default"):
+ for arg in ("existing_type",
+ "existing_nullable",
+ "existing_server_default"):
if arg in diff_kw:
kw.setdefault(arg, diff_kw[arg])
old_kw, new_kw = _arg_struct[diff[0]]
diff --git a/alembic/autogenerate/compare.py b/alembic/autogenerate/compare.py
index a50bc6d..cc24173 100644
--- a/alembic/autogenerate/compare.py
+++ b/alembic/autogenerate/compare.py
@@ -8,6 +8,7 @@ from sqlalchemy.util import OrderedSet
log = logging.getLogger(__name__)
+
def _run_filters(object_, name, type_, reflected, compare_to, object_filters):
for fn in object_filters:
if not fn(object_, name, type_, reflected, compare_to):
@@ -15,6 +16,7 @@ def _run_filters(object_, name, type_, reflected, compare_to, object_filters):
else:
return True
+
def _compare_tables(conn_table_names, metadata_table_names,
object_filters,
inspector, metadata, diffs, autogen_context):
@@ -35,14 +37,14 @@ def _compare_tables(conn_table_names, metadata_table_names,
# as "schemaname.tablename" or just "tablename", create a new lookup
# which will match the "non-default-schema" keys to the Table object.
tname_to_table = dict(
- (
- no_dflt_schema,
- metadata.tables[sa_schema._get_table_key(tname, schema)]
- )
- for no_dflt_schema, (schema, tname) in zip(
- metadata_table_names_no_dflt_schema,
- metadata_table_names)
- )
+ (
+ no_dflt_schema,
+ metadata.tables[sa_schema._get_table_key(tname, schema)]
+ )
+ for no_dflt_schema, (schema, tname) in zip(
+ metadata_table_names_no_dflt_schema,
+ metadata_table_names)
+ )
metadata_table_names = metadata_table_names_no_dflt_schema
for s, tname in metadata_table_names.difference(conn_table_names):
@@ -52,9 +54,9 @@ def _compare_tables(conn_table_names, metadata_table_names,
diffs.append(("add_table", metadata_table))
log.info("Detected added table %r", name)
_compare_indexes_and_uniques(s, tname, object_filters,
- None,
- metadata_table,
- diffs, autogen_context, inspector)
+ None,
+ metadata_table,
+ diffs, autogen_context, inspector)
removal_metadata = sa_schema.MetaData()
for s, tname in conn_table_names.difference(metadata_table_names):
@@ -87,33 +89,36 @@ def _compare_tables(conn_table_names, metadata_table_names,
if _run_filters(metadata_table, tname, "table", False, conn_table, object_filters):
_compare_columns(s, tname, object_filters,
- conn_table,
- metadata_table,
- diffs, autogen_context, inspector)
+ conn_table,
+ metadata_table,
+ diffs, autogen_context, inspector)
_compare_indexes_and_uniques(s, tname, object_filters,
- conn_table,
- metadata_table,
- diffs, autogen_context, inspector)
+ conn_table,
+ metadata_table,
+ diffs, autogen_context, inspector)
# TODO:
# table constraints
# sequences
+
def _make_index(params, conn_table):
return sa_schema.Index(
- params['name'],
- *[conn_table.c[cname] for cname in params['column_names']],
- unique=params['unique']
+ params['name'],
+ *[conn_table.c[cname] for cname in params['column_names']],
+ unique=params['unique']
)
+
def _make_unique_constraint(params, conn_table):
return sa_schema.UniqueConstraint(
- *[conn_table.c[cname] for cname in params['column_names']],
- name=params['name']
+ *[conn_table.c[cname] for cname in params['column_names']],
+ name=params['name']
)
+
def _compare_columns(schema, tname, object_filters, conn_table, metadata_table,
- diffs, autogen_context, inspector):
+ diffs, autogen_context, inspector):
name = '%s.%s' % (schema, tname) if schema else tname
metadata_cols_by_name = dict((c.name, c) for c in metadata_table.c)
conn_col_names = dict((c.name, c) for c in conn_table.c)
@@ -121,7 +126,7 @@ def _compare_columns(schema, tname, object_filters, conn_table, metadata_table,
for cname in metadata_col_names.difference(conn_col_names):
if _run_filters(metadata_cols_by_name[cname], cname,
- "column", False, None, object_filters):
+ "column", False, None, object_filters):
diffs.append(
("add_column", schema, tname, metadata_cols_by_name[cname])
)
@@ -129,7 +134,7 @@ def _compare_columns(schema, tname, object_filters, conn_table, metadata_table,
for cname in set(conn_col_names).difference(metadata_col_names):
if _run_filters(conn_table.c[cname], cname,
- "column", True, None, object_filters):
+ "column", True, None, object_filters):
diffs.append(
("remove_column", schema, tname, conn_table.c[cname])
)
@@ -139,28 +144,30 @@ def _compare_columns(schema, tname, object_filters, conn_table, metadata_table,
metadata_col = metadata_cols_by_name[colname]
conn_col = conn_table.c[colname]
if not _run_filters(
- metadata_col, colname, "column", False, conn_col, object_filters):
+ metadata_col, colname, "column", False, conn_col, object_filters):
continue
col_diff = []
_compare_type(schema, tname, colname,
- conn_col,
- metadata_col,
- col_diff, autogen_context
- )
+ conn_col,
+ metadata_col,
+ col_diff, autogen_context
+ )
_compare_nullable(schema, tname, colname,
- conn_col,
- metadata_col.nullable,
- col_diff, autogen_context
- )
+ conn_col,
+ metadata_col.nullable,
+ col_diff, autogen_context
+ )
_compare_server_default(schema, tname, colname,
- conn_col,
- metadata_col,
- col_diff, autogen_context
- )
+ conn_col,
+ metadata_col,
+ col_diff, autogen_context
+ )
if col_diff:
diffs.append(col_diff)
+
class _constraint_sig(object):
+
def __eq__(self, other):
return self.const == other.const
@@ -170,6 +177,7 @@ class _constraint_sig(object):
def __hash__(self):
return hash(self.const)
+
class _uq_constraint_sig(_constraint_sig):
is_index = False
is_unique = True
@@ -183,6 +191,7 @@ class _uq_constraint_sig(_constraint_sig):
def column_names(self):
return [col.name for col in self.const.columns]
+
class _ix_constraint_sig(_constraint_sig):
is_index = True
@@ -196,21 +205,23 @@ class _ix_constraint_sig(_constraint_sig):
def column_names(self):
return _get_index_column_names(self.const)
+
def _get_index_column_names(idx):
if compat.sqla_08:
return [getattr(exp, "name", None) for exp in idx.expressions]
else:
return [getattr(col, "name", None) for col in idx.columns]
+
def _compare_indexes_and_uniques(schema, tname, object_filters, conn_table,
- metadata_table, diffs, autogen_context, inspector):
+ metadata_table, diffs, autogen_context, inspector):
is_create_table = conn_table is None
# 1a. get raw indexes and unique constraints from metadata ...
metadata_unique_constraints = set(uq for uq in metadata_table.constraints
- if isinstance(uq, sa_schema.UniqueConstraint)
- )
+ if isinstance(uq, sa_schema.UniqueConstraint)
+ )
metadata_indexes = set(metadata_table.indexes)
conn_uniques = conn_indexes = frozenset()
@@ -222,7 +233,7 @@ def _compare_indexes_and_uniques(schema, tname, object_filters, conn_table,
if hasattr(inspector, "get_unique_constraints"):
try:
conn_uniques = inspector.get_unique_constraints(
- tname, schema=schema)
+ tname, schema=schema)
supports_unique_constraints = True
except NotImplementedError:
pass
@@ -234,26 +245,26 @@ def _compare_indexes_and_uniques(schema, tname, object_filters, conn_table,
# 2. convert conn-level objects from raw inspector records
# into schema objects
conn_uniques = set(_make_unique_constraint(uq_def, conn_table)
- for uq_def in conn_uniques)
+ for uq_def in conn_uniques)
conn_indexes = set(_make_index(ix, conn_table) for ix in conn_indexes)
# 3. give the dialect a chance to omit indexes and constraints that
# we know are either added implicitly by the DB or that the DB
# can't accurately report on
autogen_context['context'].impl.\
- correct_for_autogen_constraints(
- conn_uniques, conn_indexes,
- metadata_unique_constraints,
- metadata_indexes
- )
+ correct_for_autogen_constraints(
+ conn_uniques, conn_indexes,
+ metadata_unique_constraints,
+ metadata_indexes
+ )
# 4. organize the constraints into "signature" collections, the
# _constraint_sig() objects provide a consistent facade over both
# Index and UniqueConstraint so we can easily work with them
# interchangeably
metadata_unique_constraints = set(_uq_constraint_sig(uq)
- for uq in metadata_unique_constraints
- )
+ for uq in metadata_unique_constraints
+ )
metadata_indexes = set(_ix_constraint_sig(ix) for ix in metadata_indexes)
@@ -263,16 +274,16 @@ def _compare_indexes_and_uniques(schema, tname, object_filters, conn_table,
# 5. index things by name, for those objects that have names
metadata_names = dict(
- (c.name, c) for c in
- metadata_unique_constraints.union(metadata_indexes)
- if c.name is not None)
+ (c.name, c) for c in
+ metadata_unique_constraints.union(metadata_indexes)
+ if c.name is not None)
conn_uniques_by_name = dict((c.name, c) for c in conn_unique_constraints)
conn_indexes_by_name = dict((c.name, c) for c in conn_indexes)
conn_names = dict((c.name, c) for c in
- conn_unique_constraints.union(conn_indexes)
- if c.name is not None)
+ conn_unique_constraints.union(conn_indexes)
+ if c.name is not None)
doubled_constraints = dict(
(name, (conn_uniques_by_name[name], conn_indexes_by_name[name]))
@@ -283,11 +294,11 @@ def _compare_indexes_and_uniques(schema, tname, object_filters, conn_table,
# constraints.
conn_uniques_by_sig = dict((uq.sig, uq) for uq in conn_unique_constraints)
metadata_uniques_by_sig = dict(
- (uq.sig, uq) for uq in metadata_unique_constraints)
+ (uq.sig, uq) for uq in metadata_unique_constraints)
metadata_indexes_by_sig = dict(
- (ix.sig, ix) for ix in metadata_indexes)
+ (ix.sig, ix) for ix in metadata_indexes)
unnamed_metadata_uniques = dict((uq.sig, uq) for uq in
- metadata_unique_constraints if uq.name is None)
+ metadata_unique_constraints if uq.name is None)
# assumptions:
# 1. a unique constraint or an index from the connection *always*
@@ -301,10 +312,10 @@ def _compare_indexes_and_uniques(schema, tname, object_filters, conn_table,
if obj.is_index:
diffs.append(("add_index", obj.const))
log.info("Detected added index '%s' on %s",
- obj.name, ', '.join([
- "'%s'" % obj.column_names
- ])
- )
+ obj.name, ', '.join([
+ "'%s'" % obj.column_names
+ ])
+ )
else:
if not supports_unique_constraints:
# can't report unique indexes as added if we don't
@@ -315,10 +326,10 @@ def _compare_indexes_and_uniques(schema, tname, object_filters, conn_table,
return
diffs.append(("add_constraint", obj.const))
log.info("Detected added unique constraint '%s' on %s",
- obj.name, ', '.join([
- "'%s'" % obj.column_names
- ])
- )
+ obj.name, ', '.join([
+ "'%s'" % obj.column_names
+ ])
+ )
def obj_removed(obj):
if obj.is_index:
@@ -333,20 +344,20 @@ def _compare_indexes_and_uniques(schema, tname, object_filters, conn_table,
else:
diffs.append(("remove_constraint", obj.const))
log.info("Detected removed unique constraint '%s' on '%s'",
- obj.name, tname
- )
+ obj.name, tname
+ )
def obj_changed(old, new, msg):
if old.is_index:
log.info("Detected changed index '%s' on '%s':%s",
- old.name, tname, ', '.join(msg)
- )
+ old.name, tname, ', '.join(msg)
+ )
diffs.append(("remove_index", old.const))
diffs.append(("add_index", new.const))
else:
log.info("Detected changed unique constraint '%s' on '%s':%s",
- old.name, tname, ', '.join(msg)
- )
+ old.name, tname, ', '.join(msg)
+ )
diffs.append(("remove_constraint", old.const))
diffs.append(("add_constraint", new.const))
@@ -354,7 +365,6 @@ def _compare_indexes_and_uniques(schema, tname, object_filters, conn_table,
obj = metadata_names[added_name]
obj_added(obj)
-
for existing_name in sorted(set(metadata_names).intersection(conn_names)):
metadata_obj = metadata_names[existing_name]
@@ -384,14 +394,13 @@ def _compare_indexes_and_uniques(schema, tname, object_filters, conn_table,
if msg:
obj_changed(conn_obj, metadata_obj, msg)
-
for removed_name in sorted(set(conn_names).difference(metadata_names)):
conn_obj = conn_names[removed_name]
if not conn_obj.is_index and conn_obj.sig in unnamed_metadata_uniques:
continue
elif removed_name in doubled_constraints:
if conn_obj.sig not in metadata_indexes_by_sig and \
- conn_obj.sig not in metadata_uniques_by_sig:
+ conn_obj.sig not in metadata_uniques_by_sig:
conn_uq, conn_idx = doubled_constraints[removed_name]
obj_removed(conn_uq)
obj_removed(conn_idx)
@@ -404,8 +413,8 @@ def _compare_indexes_and_uniques(schema, tname, object_filters, conn_table,
def _compare_nullable(schema, tname, cname, conn_col,
- metadata_col_nullable, diffs,
- autogen_context):
+ metadata_col_nullable, diffs,
+ autogen_context):
conn_col_nullable = conn_col.nullable
if conn_col_nullable is not metadata_col_nullable:
diffs.append(
@@ -418,24 +427,25 @@ def _compare_nullable(schema, tname, cname, conn_col,
metadata_col_nullable),
)
log.info("Detected %s on column '%s.%s'",
- "NULL" if metadata_col_nullable else "NOT NULL",
- tname,
- cname
- )
+ "NULL" if metadata_col_nullable else "NOT NULL",
+ tname,
+ cname
+ )
+
def _compare_type(schema, tname, cname, conn_col,
- metadata_col, diffs,
- autogen_context):
+ metadata_col, diffs,
+ autogen_context):
conn_type = conn_col.type
metadata_type = metadata_col.type
if conn_type._type_affinity is sqltypes.NullType:
log.info("Couldn't determine database type "
- "for column '%s.%s'", tname, cname)
+ "for column '%s.%s'", tname, cname)
return
if metadata_type._type_affinity is sqltypes.NullType:
log.info("Column '%s.%s' has no type within "
- "the model; can't compare", tname, cname)
+ "the model; can't compare", tname, cname)
return
isdiff = autogen_context['context']._compare_type(conn_col, metadata_col)
@@ -444,40 +454,42 @@ def _compare_type(schema, tname, cname, conn_col,
diffs.append(
("modify_type", schema, tname, cname,
- {
- "existing_nullable": conn_col.nullable,
- "existing_server_default": conn_col.server_default,
- },
- conn_type,
- metadata_type),
+ {
+ "existing_nullable": conn_col.nullable,
+ "existing_server_default": conn_col.server_default,
+ },
+ conn_type,
+ metadata_type),
)
log.info("Detected type change from %r to %r on '%s.%s'",
- conn_type, metadata_type, tname, cname
- )
+ conn_type, metadata_type, tname, cname
+ )
+
def _render_server_default_for_compare(metadata_default,
- metadata_col, autogen_context):
+ metadata_col, autogen_context):
return _render_server_default(
- metadata_default, autogen_context,
- repr_=metadata_col.type._type_affinity is sqltypes.String)
+ metadata_default, autogen_context,
+ repr_=metadata_col.type._type_affinity is sqltypes.String)
+
def _compare_server_default(schema, tname, cname, conn_col, metadata_col,
- diffs, autogen_context):
+ diffs, autogen_context):
metadata_default = metadata_col.server_default
conn_col_default = conn_col.server_default
if conn_col_default is None and metadata_default is None:
return False
rendered_metadata_default = _render_server_default_for_compare(
- metadata_default, metadata_col, autogen_context)
+ metadata_default, metadata_col, autogen_context)
rendered_conn_default = conn_col.server_default.arg.text \
- if conn_col.server_default else None
+ if conn_col.server_default else None
isdiff = autogen_context['context']._compare_server_default(
- conn_col, metadata_col,
- rendered_metadata_default,
- rendered_conn_default
- )
+ conn_col, metadata_col,
+ rendered_metadata_default,
+ rendered_conn_default
+ )
if isdiff:
conn_col_default = rendered_conn_default
diffs.append(
@@ -490,6 +502,6 @@ def _compare_server_default(schema, tname, cname, conn_col, metadata_col,
metadata_default),
)
log.info("Detected server default on column '%s.%s'",
- tname,
- cname
- )
+ tname,
+ cname
+ )
diff --git a/alembic/autogenerate/render.py b/alembic/autogenerate/render.py
index 81bd774..447870b 100644
--- a/alembic/autogenerate/render.py
+++ b/alembic/autogenerate/render.py
@@ -10,6 +10,7 @@ MAX_PYTHON_ARGS = 255
try:
from sqlalchemy.sql.naming import conv
+
def _render_gen_name(autogen_context, name):
if isinstance(name, conv):
return _f_name(_alembic_autogenerate_prefix(autogen_context), name)
@@ -19,7 +20,9 @@ except ImportError:
def _render_gen_name(autogen_context, name):
return name
+
class _f_name(object):
+
def __init__(self, prefix, name):
self.prefix = prefix
self.name = name
@@ -27,6 +30,7 @@ class _f_name(object):
def __repr__(self):
return "%sf(%r)" % (self.prefix, self.name)
+
def _render_potential_expr(value, autogen_context):
if isinstance(value, sql.ClauseElement):
if compat.sqla_08:
@@ -37,23 +41,24 @@ def _render_potential_expr(value, autogen_context):
return "%(prefix)stext(%(sql)r)" % {
"prefix": _sqlalchemy_autogenerate_prefix(autogen_context),
"sql": str(
- value.compile(dialect=autogen_context['dialect'],
- **compile_kw)
- )
+ value.compile(dialect=autogen_context['dialect'],
+ **compile_kw)
+ )
}
else:
return repr(value)
+
def _add_table(table, autogen_context):
args = [col for col in
[_render_column(col, autogen_context) for col in table.c]
- if col] + \
+ if col] + \
sorted([rcons for rcons in
- [_render_constraint(cons, autogen_context) for cons in
- table.constraints]
- if rcons is not None
- ])
+ [_render_constraint(cons, autogen_context) for cons in
+ table.constraints]
+ if rcons is not None
+ ])
if len(args) > MAX_PYTHON_ARGS:
args = '*[' + ',\n'.join(args) + ']'
@@ -72,16 +77,18 @@ def _add_table(table, autogen_context):
text += "\n)"
return text
+
def _drop_table(table, autogen_context):
text = "%(prefix)sdrop_table(%(tname)r" % {
- "prefix": _alembic_autogenerate_prefix(autogen_context),
- "tname": table.name
- }
+ "prefix": _alembic_autogenerate_prefix(autogen_context),
+ "tname": table.name
+ }
if table.schema:
text += ", schema=%r" % table.schema
text += ")"
return text
+
def _add_index(index, autogen_context):
"""
Generate Alembic operations for the CREATE INDEX of an
@@ -90,27 +97,28 @@ def _add_index(index, autogen_context):
from .compare import _get_index_column_names
text = "%(prefix)screate_index(%(name)r, '%(table)s', %(columns)s, "\
- "unique=%(unique)r%(schema)s%(kwargs)s)" % {
- 'prefix': _alembic_autogenerate_prefix(autogen_context),
- 'name': _render_gen_name(autogen_context, index.name),
- 'table': index.table.name,
- 'columns': _get_index_column_names(index),
- 'unique': index.unique or False,
- 'schema': (", schema='%s'" % index.table.schema) if index.table.schema else '',
- 'kwargs': (', '+', '.join(
- ["%s=%s" % (key, _render_potential_expr(val, autogen_context))
- for key, val in index.kwargs.items()]))\
+ "unique=%(unique)r%(schema)s%(kwargs)s)" % {
+ 'prefix': _alembic_autogenerate_prefix(autogen_context),
+ 'name': _render_gen_name(autogen_context, index.name),
+ 'table': index.table.name,
+ 'columns': _get_index_column_names(index),
+ 'unique': index.unique or False,
+ 'schema': (", schema='%s'" % index.table.schema) if index.table.schema else '',
+ 'kwargs': (', ' + ', '.join(
+ ["%s=%s" % (key, _render_potential_expr(val, autogen_context))
+ for key, val in index.kwargs.items()]))
if len(index.kwargs) else ''
- }
+ }
return text
+
def _drop_index(index, autogen_context):
"""
Generate Alembic operations for the DROP INDEX of an
:class:`~sqlalchemy.schema.Index` instance.
"""
text = "%(prefix)sdrop_index(%(name)r, "\
- "table_name='%(table_name)s'%(schema)s)" % {
+ "table_name='%(table_name)s'%(schema)s)" % {
'prefix': _alembic_autogenerate_prefix(autogen_context),
'name': _render_gen_name(autogen_context, index.name),
'table_name': index.table.name,
@@ -135,6 +143,7 @@ def _add_unique_constraint(constraint, autogen_context):
"""
return _uq_constraint(constraint, autogen_context, True)
+
def _uq_constraint(constraint, autogen_context, alter):
opts = []
if constraint.deferrable:
@@ -148,13 +157,13 @@ def _uq_constraint(constraint, autogen_context, alter):
if alter:
args = [repr(_render_gen_name(autogen_context, constraint.name)),
- repr(constraint.table.name)]
+ repr(constraint.table.name)]
args.append(repr([col.name for col in constraint.columns]))
args.extend(["%s=%r" % (k, v) for k, v in opts])
return "%(prefix)screate_unique_constraint(%(args)s)" % {
- 'prefix': _alembic_autogenerate_prefix(autogen_context),
- 'args': ", ".join(args)
- }
+ 'prefix': _alembic_autogenerate_prefix(autogen_context),
+ 'args': ", ".join(args)
+ }
else:
args = [repr(col.name) for col in constraint.columns]
args.extend(["%s=%r" % (k, v) for k, v in opts])
@@ -167,12 +176,15 @@ def _uq_constraint(constraint, autogen_context, alter):
def _add_fk_constraint(constraint, autogen_context):
raise NotImplementedError()
+
def _add_pk_constraint(constraint, autogen_context):
raise NotImplementedError()
+
def _add_check_constraint(constraint, autogen_context):
raise NotImplementedError()
+
def _add_constraint(constraint, autogen_context):
"""
Dispatcher for the different types of constraints.
@@ -186,42 +198,46 @@ def _add_constraint(constraint, autogen_context):
}
return funcs[constraint.__visit_name__](constraint, autogen_context)
+
def _drop_constraint(constraint, autogen_context):
"""
Generate Alembic operations for the ALTER TABLE ... DROP CONSTRAINT
of a :class:`~sqlalchemy.schema.UniqueConstraint` instance.
"""
text = "%(prefix)sdrop_constraint(%(name)r, '%(table_name)s'%(schema)s)" % {
- 'prefix': _alembic_autogenerate_prefix(autogen_context),
- 'name': _render_gen_name(autogen_context, constraint.name),
- 'table_name': constraint.table.name,
- 'schema': (", schema='%s'" % constraint.table.schema)
- if constraint.table.schema else '',
+ 'prefix': _alembic_autogenerate_prefix(autogen_context),
+ 'name': _render_gen_name(autogen_context, constraint.name),
+ 'table_name': constraint.table.name,
+ 'schema': (", schema='%s'" % constraint.table.schema)
+ if constraint.table.schema else '',
}
return text
+
def _add_column(schema, tname, column, autogen_context):
text = "%(prefix)sadd_column(%(tname)r, %(column)s" % {
- "prefix": _alembic_autogenerate_prefix(autogen_context),
- "tname": tname,
- "column": _render_column(column, autogen_context)
- }
+ "prefix": _alembic_autogenerate_prefix(autogen_context),
+ "tname": tname,
+ "column": _render_column(column, autogen_context)
+ }
if schema:
text += ", schema=%r" % schema
text += ")"
return text
+
def _drop_column(schema, tname, column, autogen_context):
text = "%(prefix)sdrop_column(%(tname)r, %(cname)r" % {
- "prefix": _alembic_autogenerate_prefix(autogen_context),
- "tname": tname,
- "cname": column.name
- }
+ "prefix": _alembic_autogenerate_prefix(autogen_context),
+ "tname": tname,
+ "cname": column.name
+ }
if schema:
text += ", schema=%r" % schema
text += ")"
return text
+
def _modify_col(tname, cname,
autogen_context,
server_default=False,
@@ -233,37 +249,38 @@ def _modify_col(tname, cname,
schema=None):
indent = " " * 11
text = "%(prefix)salter_column(%(tname)r, %(cname)r" % {
- 'prefix': _alembic_autogenerate_prefix(
- autogen_context),
- 'tname': tname,
- 'cname': cname}
+ 'prefix': _alembic_autogenerate_prefix(
+ autogen_context),
+ 'tname': tname,
+ 'cname': cname}
text += ",\n%sexisting_type=%s" % (indent,
- _repr_type(existing_type, autogen_context))
+ _repr_type(existing_type, autogen_context))
if server_default is not False:
rendered = _render_server_default(
- server_default, autogen_context)
+ server_default, autogen_context)
text += ",\n%sserver_default=%s" % (indent, rendered)
if type_ is not None:
text += ",\n%stype_=%s" % (indent,
- _repr_type(type_, autogen_context))
+ _repr_type(type_, autogen_context))
if nullable is not None:
text += ",\n%snullable=%r" % (
- indent, nullable,)
+ indent, nullable,)
if existing_nullable is not None:
text += ",\n%sexisting_nullable=%r" % (
- indent, existing_nullable)
+ indent, existing_nullable)
if existing_server_default:
rendered = _render_server_default(
- existing_server_default,
- autogen_context)
+ existing_server_default,
+ autogen_context)
text += ",\n%sexisting_server_default=%s" % (
- indent, rendered)
+ indent, rendered)
if schema:
text += ",\n%sschema=%r" % (indent, schema)
text += ")"
return text
+
def _user_autogenerate_prefix(autogen_context):
prefix = autogen_context['opts']['user_module_prefix']
if prefix is None:
@@ -271,12 +288,15 @@ def _user_autogenerate_prefix(autogen_context):
else:
return prefix
+
def _sqlalchemy_autogenerate_prefix(autogen_context):
return autogen_context['opts']['sqlalchemy_module_prefix'] or ''
+
def _alembic_autogenerate_prefix(autogen_context):
return autogen_context['opts']['alembic_module_prefix'] or ''
+
def _user_defined_render(type_, object_, autogen_context):
if 'opts' in autogen_context and \
'render_item' in autogen_context['opts']:
@@ -287,6 +307,7 @@ def _user_defined_render(type_, object_, autogen_context):
return rendered
return False
+
def _render_column(column, autogen_context):
rendered = _user_defined_render("column", column, autogen_context)
if rendered is not False:
@@ -295,8 +316,8 @@ def _render_column(column, autogen_context):
opts = []
if column.server_default:
rendered = _render_server_default(
- column.server_default, autogen_context
- )
+ column.server_default, autogen_context
+ )
if rendered:
opts.append(("server_default", rendered))
@@ -314,6 +335,7 @@ def _render_column(column, autogen_context):
'kw': ", ".join(["%s=%s" % (kwname, val) for kwname, val in opts])
}
+
def _render_server_default(default, autogen_context, repr_=True):
rendered = _user_defined_render("server_default", default, autogen_context)
if rendered is not False:
@@ -324,7 +346,7 @@ def _render_server_default(default, autogen_context, repr_=True):
default = default.arg
else:
default = str(default.arg.compile(
- dialect=autogen_context['dialect']))
+ dialect=autogen_context['dialect']))
if isinstance(default, string_types):
if repr_:
default = re.sub(r"^'|'$", "", default)
@@ -334,6 +356,7 @@ def _render_server_default(default, autogen_context, repr_=True):
else:
return None
+
def _repr_type(type_, autogen_context):
rendered = _user_defined_render("type", type_, autogen_context)
if rendered is not False:
@@ -353,6 +376,7 @@ def _repr_type(type_, autogen_context):
prefix = _user_autogenerate_prefix(autogen_context)
return "%s%r" % (prefix, type_)
+
def _render_constraint(constraint, autogen_context):
renderer = _constraint_renderers.get(type(constraint), None)
if renderer:
@@ -360,6 +384,7 @@ def _render_constraint(constraint, autogen_context):
else:
return None
+
def _render_primary_key(constraint, autogen_context):
rendered = _user_defined_render("primary_key", constraint, autogen_context)
if rendered is not False:
@@ -379,6 +404,7 @@ def _render_primary_key(constraint, autogen_context):
),
}
+
def _fk_colspec(fk, metadata_schema):
"""Implement a 'safe' version of ForeignKey._get_colspec() that
never tries to resolve the remote table.
@@ -393,6 +419,7 @@ def _fk_colspec(fk, metadata_schema):
colspec = "%s.%s" % (metadata_schema, colspec)
return colspec
+
def _render_foreign_key(constraint, autogen_context):
rendered = _user_defined_render("foreign_key", constraint, autogen_context)
if rendered is not False:
@@ -414,15 +441,16 @@ def _render_foreign_key(constraint, autogen_context):
apply_metadata_schema = constraint.parent.metadata.schema
return "%(prefix)sForeignKeyConstraint([%(cols)s], "\
- "[%(refcols)s], %(args)s)" % {
- "prefix": _sqlalchemy_autogenerate_prefix(autogen_context),
- "cols": ", ".join("'%s'" % f.parent.key for f in constraint.elements),
- "refcols": ", ".join(repr(_fk_colspec(f, apply_metadata_schema))
- for f in constraint.elements),
- "args": ", ".join(
- ["%s=%s" % (kwname, val) for kwname, val in opts]
- ),
- }
+ "[%(refcols)s], %(args)s)" % {
+ "prefix": _sqlalchemy_autogenerate_prefix(autogen_context),
+ "cols": ", ".join("'%s'" % f.parent.key for f in constraint.elements),
+ "refcols": ", ".join(repr(_fk_colspec(f, apply_metadata_schema))
+ for f in constraint.elements),
+ "args": ", ".join(
+ ["%s=%s" % (kwname, val) for kwname, val in opts]
+ ),
+ }
+
def _render_check_constraint(constraint, autogen_context):
rendered = _user_defined_render("check", constraint, autogen_context)
@@ -436,21 +464,21 @@ def _render_check_constraint(constraint, autogen_context):
if constraint._create_rule and \
hasattr(constraint._create_rule, 'target') and \
isinstance(constraint._create_rule.target,
- sqltypes.TypeEngine):
+ sqltypes.TypeEngine):
return None
opts = []
if constraint.name:
opts.append(("name", repr(_render_gen_name(autogen_context, constraint.name))))
return "%(prefix)sCheckConstraint(%(sqltext)r%(opts)s)" % {
- "prefix": _sqlalchemy_autogenerate_prefix(autogen_context),
- "opts": ", " + (", ".join("%s=%s" % (k, v)
- for k, v in opts)) if opts else "",
- "sqltext": str(
+ "prefix": _sqlalchemy_autogenerate_prefix(autogen_context),
+ "opts": ", " + (", ".join("%s=%s" % (k, v)
+ for k, v in opts)) if opts else "",
+ "sqltext": str(
constraint.sqltext.compile(
dialect=autogen_context['dialect']
)
- )
- }
+ )
+ }
_constraint_renderers = {
sa_schema.PrimaryKeyConstraint: _render_primary_key,
diff --git a/alembic/command.py b/alembic/command.py
index f1c5962..a6d7995 100644
--- a/alembic/command.py
+++ b/alembic/command.py
@@ -4,21 +4,23 @@ from .script import ScriptDirectory
from .environment import EnvironmentContext
from . import util, autogenerate as autogen
+
def list_templates(config):
"""List available templates"""
config.print_stdout("Available templates:\n")
for tempname in os.listdir(config.get_template_directory()):
with open(os.path.join(
- config.get_template_directory(),
- tempname,
- 'README')) as readme:
+ config.get_template_directory(),
+ tempname,
+ 'README')) as readme:
synopsis = next(readme)
config.print_stdout("%s - %s", tempname, synopsis)
config.print_stdout("\nTemplates are used via the 'init' command, e.g.:")
config.print_stdout("\n alembic init --template pylons ./scripts")
+
def init(config, directory, template='generic'):
"""Initialize a new scripts directory."""
@@ -26,7 +28,7 @@ def init(config, directory, template='generic'):
raise util.CommandError("Directory %s already exists" % directory)
template_dir = os.path.join(config.get_template_directory(),
- template)
+ template)
if not os.access(template_dir, os.F_OK):
raise util.CommandError("No such template %r" % template)
@@ -58,8 +60,9 @@ def init(config, directory, template='generic'):
output_file
)
- util.msg("Please edit configuration/connection/logging "\
- "settings in %r before proceeding." % config_file)
+ util.msg("Please edit configuration/connection/logging "
+ "settings in %r before proceeding." % config_file)
+
def revision(config, message=None, autogenerate=False, sql=False):
"""Create a new revision file."""
@@ -77,6 +80,7 @@ def revision(config, message=None, autogenerate=False, sql=False):
if autogenerate:
environment = True
+
def retrieve_migrations(rev, context):
if script.get_revision(rev) is not script.get_revision("head"):
raise util.CommandError("Target database is not up to date.")
@@ -124,6 +128,7 @@ def upgrade(config, revision, sql=False, tag=None):
):
script.run_env()
+
def downgrade(config, revision, sql=False, tag=None):
"""Revert to a previous version."""
@@ -150,6 +155,7 @@ def downgrade(config, revision, sql=False, tag=None):
):
script.run_env()
+
def history(config, rev_range=None):
"""List changeset scripts in chronological order."""
@@ -157,16 +163,16 @@ def history(config, rev_range=None):
if rev_range is not None:
if ":" not in rev_range:
raise util.CommandError(
- "History range requires [start]:[end], "
- "[start]:, or :[end]")
+ "History range requires [start]:[end], "
+ "[start]:, or :[end]")
base, head = rev_range.strip().split(":")
else:
base = head = None
def _display_history(config, script, base, head):
for sc in script.walk_revisions(
- base=base or "base",
- head=head or "head"):
+ base=base or "base",
+ head=head or "head"):
if sc.is_head:
config.print_stdout("")
config.print_stdout(sc.log_entry)
@@ -202,14 +208,16 @@ def branches(config):
config.print_stdout(sc)
for rev in sc.nextrev:
config.print_stdout("%s -> %s",
- " " * len(str(sc.down_revision)),
- script.get_revision(rev)
- )
+ " " * len(str(sc.down_revision)),
+ script.get_revision(rev)
+ )
+
def current(config, head_only=False):
"""Display the current revision for each database."""
script = ScriptDirectory.from_config(config)
+
def display_version(rev, context):
rev = script.get_revision(rev)
@@ -232,11 +240,13 @@ def current(config, head_only=False):
):
script.run_env()
+
def stamp(config, revision, sql=False, tag=None):
"""'stamp' the revision table with the given revision; don't
run any migrations."""
script = ScriptDirectory.from_config(config)
+
def do_stamp(rev, context):
if sql:
current = False
@@ -257,6 +267,7 @@ def stamp(config, revision, sql=False, tag=None):
):
script.run_env()
+
def splice(config, parent, child):
"""'splice' two branches, creating a new revision file.
diff --git a/alembic/compat.py b/alembic/compat.py
index aac0560..cded54b 100644
--- a/alembic/compat.py
+++ b/alembic/compat.py
@@ -17,6 +17,7 @@ if py3k:
string_types = str,
binary_type = bytes
text_type = str
+
def callable(fn):
return hasattr(fn, '__call__')
@@ -45,6 +46,7 @@ if py2k:
if py33:
from importlib import machinery
+
def load_module_py(module_id, path):
return machinery.SourceFileLoader(module_id, path).load_module(module_id)
@@ -53,6 +55,7 @@ if py33:
else:
import imp
+
def load_module_py(module_id, path):
with open(path, 'rb') as fp:
mod = imp.load_source(module_id, path, fp)
@@ -78,6 +81,8 @@ except AttributeError:
################################################
# cross-compatible metaclass implementation
# Copyright (c) 2010-2012 Benjamin Peterson
+
+
def with_metaclass(meta, base=object):
"""Create a base class with a metaclass."""
return meta("%sBase" % meta.__name__, (base,), {})
@@ -88,6 +93,7 @@ def with_metaclass(meta, base=object):
# into a given buffer, but doesn't close it.
# not sure of a more idiomatic approach to this.
class EncodedIO(io.TextIOWrapper):
+
def close(self):
pass
@@ -99,10 +105,12 @@ if py2k:
# adapter.
class ActLikePy3kIO(object):
+
"""Produce an object capable of wrapping either
sys.stdout (e.g. file) *or* StringIO.StringIO().
"""
+
def _false(self):
return False
@@ -123,8 +131,7 @@ if py2k:
return self.file_.flush()
class EncodedIO(EncodedIO):
+
def __init__(self, file_, encoding):
super(EncodedIO, self).__init__(
- ActLikePy3kIO(file_), encoding=encoding)
-
-
+ ActLikePy3kIO(file_), encoding=encoding)
diff --git a/alembic/config.py b/alembic/config.py
index 86ff1df..003949b 100644
--- a/alembic/config.py
+++ b/alembic/config.py
@@ -6,7 +6,9 @@ import sys
from . import command, util, package_dir, compat
+
class Config(object):
+
"""Represent an Alembic configuration.
Within an ``env.py`` script, this is available
@@ -50,8 +52,9 @@ class Config(object):
..versionadded:: 0.4
"""
+
def __init__(self, file_=None, ini_section='alembic', output_buffer=None,
- stdout=sys.stdout, cmd_opts=None):
+ stdout=sys.stdout, cmd_opts=None):
"""Construct a new :class:`.Config`
"""
@@ -90,9 +93,9 @@ class Config(object):
"""Render a message to standard out."""
util.write_outstream(
- self.stdout,
- (compat.text_type(text) % arg),
- "\n"
+ self.stdout,
+ (compat.text_type(text) % arg),
+ "\n"
)
@util.memoized_property
@@ -162,8 +165,8 @@ class Config(object):
"""
if not self.file_config.has_section(section):
raise util.CommandError("No config file %r found, or file has no "
- "'[%s]' section" %
- (self.config_file_name, section))
+ "'[%s]' section" %
+ (self.config_file_name, section))
if self.file_config.has_option(section, name):
return self.file_config.get(section, name)
else:
@@ -181,35 +184,35 @@ class Config(object):
class CommandLine(object):
+
def __init__(self, prog=None):
self._generate_args(prog)
-
def _generate_args(self, prog):
def add_options(parser, positional, kwargs):
if 'template' in kwargs:
parser.add_argument("-t", "--template",
- default='generic',
- type=str,
- help="Setup template for use with 'init'")
+ default='generic',
+ type=str,
+ help="Setup template for use with 'init'")
if 'message' in kwargs:
parser.add_argument("-m", "--message",
- type=str,
- help="Message string to use with 'revision'")
+ type=str,
+ help="Message string to use with 'revision'")
if 'sql' in kwargs:
parser.add_argument("--sql",
- action="store_true",
- help="Don't emit SQL to database - dump to "
- "standard output/file instead")
+ action="store_true",
+ help="Don't emit SQL to database - dump to "
+ "standard output/file instead")
if 'tag' in kwargs:
parser.add_argument("--tag",
- type=str,
- help="Arbitrary 'tag' name - can be used by "
- "custom env.py scripts.")
+ type=str,
+ help="Arbitrary 'tag' name - can be used by "
+ "custom env.py scripts.")
if 'autogenerate' in kwargs:
parser.add_argument("--autogenerate",
- action="store_true",
- help="Populate revision script with candidate "
+ action="store_true",
+ help="Populate revision script with candidate "
"migration operations, based on comparison "
"of database to model.")
# "current" command
@@ -225,7 +228,6 @@ class CommandLine(object):
help="Specify a revision range; "
"format is [start]:[end]")
-
positional_help = {
'directory': "location of scripts directory",
'revision': "revision identifier"
@@ -252,8 +254,8 @@ class CommandLine(object):
for fn in [getattr(command, n) for n in dir(command)]:
if inspect.isfunction(fn) and \
- fn.__name__[0] != '_' and \
- fn.__module__ == 'alembic.command':
+ fn.__name__[0] != '_' and \
+ fn.__module__ == 'alembic.command':
spec = inspect.getargspec(fn)
if spec[3]:
@@ -264,8 +266,8 @@ class CommandLine(object):
kwarg = []
subparser = subparsers.add_parser(
- fn.__name__,
- help=fn.__doc__)
+ fn.__name__,
+ help=fn.__doc__)
add_options(subparser, positional, kwarg)
subparser.set_defaults(cmd=(fn, positional, kwarg))
self.parser = parser
@@ -275,9 +277,9 @@ class CommandLine(object):
try:
fn(config,
- *[getattr(options, k) for k in positional],
- **dict((k, getattr(options, k)) for k in kwarg)
- )
+ *[getattr(options, k) for k in positional],
+ **dict((k, getattr(options, k)) for k in kwarg)
+ )
except util.CommandError as e:
util.err(str(e))
@@ -289,13 +291,14 @@ class CommandLine(object):
self.parser.error("too few arguments")
else:
cfg = Config(file_=options.config,
- ini_section=options.name, cmd_opts=options)
+ ini_section=options.name, cmd_opts=options)
self.run_cmd(cfg, options)
+
def main(argv=None, prog=None, **kwargs):
"""The console runner function for Alembic."""
CommandLine(prog=prog).main(argv=argv)
if __name__ == '__main__':
- main() \ No newline at end of file
+ main()
diff --git a/alembic/ddl/base.py b/alembic/ddl/base.py
index 5d703a5..3a60926 100644
--- a/alembic/ddl/base.py
+++ b/alembic/ddl/base.py
@@ -5,62 +5,81 @@ from sqlalchemy.schema import DDLElement, Column
from sqlalchemy import Integer
from sqlalchemy import types as sqltypes
+
class AlterTable(DDLElement):
+
"""Represent an ALTER TABLE statement.
Only the string name and optional schema name of the table
is required, not a full Table object.
"""
+
def __init__(self, table_name, schema=None):
self.table_name = table_name
self.schema = schema
+
class RenameTable(AlterTable):
+
def __init__(self, old_table_name, new_table_name, schema=None):
super(RenameTable, self).__init__(old_table_name, schema=schema)
self.new_table_name = new_table_name
+
class AlterColumn(AlterTable):
+
def __init__(self, name, column_name, schema=None,
- existing_type=None,
- existing_nullable=None,
- existing_server_default=None):
+ existing_type=None,
+ existing_nullable=None,
+ existing_server_default=None):
super(AlterColumn, self).__init__(name, schema=schema)
self.column_name = column_name
- self.existing_type=sqltypes.to_instance(existing_type) \
- if existing_type is not None else None
- self.existing_nullable=existing_nullable
- self.existing_server_default=existing_server_default
+ self.existing_type = sqltypes.to_instance(existing_type) \
+ if existing_type is not None else None
+ self.existing_nullable = existing_nullable
+ self.existing_server_default = existing_server_default
+
class ColumnNullable(AlterColumn):
+
def __init__(self, name, column_name, nullable, **kw):
super(ColumnNullable, self).__init__(name, column_name,
- **kw)
+ **kw)
self.nullable = nullable
+
class ColumnType(AlterColumn):
+
def __init__(self, name, column_name, type_, **kw):
super(ColumnType, self).__init__(name, column_name,
- **kw)
+ **kw)
self.type_ = sqltypes.to_instance(type_)
+
class ColumnName(AlterColumn):
+
def __init__(self, name, column_name, newname, **kw):
super(ColumnName, self).__init__(name, column_name, **kw)
self.newname = newname
+
class ColumnDefault(AlterColumn):
+
def __init__(self, name, column_name, default, **kw):
super(ColumnDefault, self).__init__(name, column_name, **kw)
self.default = default
+
class AddColumn(AlterTable):
+
def __init__(self, name, column, schema=None):
super(AddColumn, self).__init__(name, schema=schema)
self.column = column
+
class DropColumn(AlterTable):
+
def __init__(self, name, column, schema=None):
super(DropColumn, self).__init__(name, schema=schema)
self.column = column
@@ -73,6 +92,7 @@ def visit_rename_table(element, compiler, **kw):
format_table_name(compiler, element.new_table_name, element.schema)
)
+
@compiles(AddColumn)
def visit_add_column(element, compiler, **kw):
return "%s %s" % (
@@ -80,6 +100,7 @@ def visit_add_column(element, compiler, **kw):
add_column(compiler, element.column, **kw)
)
+
@compiles(DropColumn)
def visit_drop_column(element, compiler, **kw):
return "%s %s" % (
@@ -87,6 +108,7 @@ def visit_drop_column(element, compiler, **kw):
drop_column(compiler, element.column.name, **kw)
)
+
@compiles(ColumnNullable)
def visit_column_nullable(element, compiler, **kw):
return "%s %s %s" % (
@@ -95,6 +117,7 @@ def visit_column_nullable(element, compiler, **kw):
"DROP NOT NULL" if element.nullable else "SET NOT NULL"
)
+
@compiles(ColumnType)
def visit_column_type(element, compiler, **kw):
return "%s %s %s" % (
@@ -103,6 +126,7 @@ def visit_column_type(element, compiler, **kw):
"TYPE %s" % format_type(compiler, element.type_)
)
+
@compiles(ColumnName)
def visit_column_name(element, compiler, **kw):
return "%s RENAME %s TO %s" % (
@@ -111,23 +135,26 @@ def visit_column_name(element, compiler, **kw):
format_column_name(compiler, element.newname)
)
+
@compiles(ColumnDefault)
def visit_column_default(element, compiler, **kw):
return "%s %s %s" % (
alter_table(compiler, element.table_name, element.schema),
alter_column(compiler, element.column_name),
"SET DEFAULT %s" %
- format_server_default(compiler, element.default)
+ format_server_default(compiler, element.default)
if element.default is not None
else "DROP DEFAULT"
)
+
def quote_dotted(name, quote):
"""quote the elements of a dotted name"""
result = '.'.join([quote(x) for x in name.split('.')])
return result
+
def format_table_name(compiler, name, schema):
quote = functools.partial(compiler.preparer.quote, force=None)
if schema:
@@ -135,27 +162,32 @@ def format_table_name(compiler, name, schema):
else:
return quote(name)
+
def format_column_name(compiler, name):
return compiler.preparer.quote(name, None)
+
def format_server_default(compiler, default):
return compiler.get_column_default_string(
- Column("x", Integer, server_default=default)
- )
+ Column("x", Integer, server_default=default)
+ )
+
def format_type(compiler, type_):
return compiler.dialect.type_compiler.process(type_)
+
def alter_table(compiler, name, schema):
return "ALTER TABLE %s" % format_table_name(compiler, name, schema)
+
def drop_column(compiler, name):
return 'DROP COLUMN %s' % format_column_name(compiler, name)
+
def alter_column(compiler, name):
return 'ALTER COLUMN %s' % format_column_name(compiler, name)
+
def add_column(compiler, column, **kw):
return "ADD COLUMN %s" % compiler.get_column_specification(column, **kw)
-
-
diff --git a/alembic/ddl/impl.py b/alembic/ddl/impl.py
index 664158f..a22a4fb 100644
--- a/alembic/ddl/impl.py
+++ b/alembic/ddl/impl.py
@@ -8,7 +8,9 @@ from ..compat import string_types, text_type, with_metaclass
from .. import util
from . import base
+
class ImplMeta(type):
+
def __init__(cls, classname, bases, dict_):
newtype = type.__init__(cls, classname, bases, dict_)
if '__dialect__' in dict_:
@@ -17,7 +19,9 @@ class ImplMeta(type):
_impls = {}
+
class DefaultImpl(with_metaclass(ImplMeta)):
+
"""Provide the entrypoint for major migration operations,
including database-specific behavioral variances.
@@ -35,8 +39,8 @@ class DefaultImpl(with_metaclass(ImplMeta)):
command_terminator = ";"
def __init__(self, dialect, connection, as_sql,
- transactional_ddl, output_buffer,
- context_opts):
+ transactional_ddl, output_buffer,
+ context_opts):
self.dialect = dialect
self.connection = connection
self.as_sql = as_sql
@@ -59,8 +63,8 @@ class DefaultImpl(with_metaclass(ImplMeta)):
return self.connection
def _exec(self, construct, execution_options=None,
- multiparams=(),
- params=util.immutabledict()):
+ multiparams=(),
+ params=util.immutabledict()):
if isinstance(construct, string_types):
construct = text(construct)
if self.as_sql:
@@ -68,8 +72,8 @@ class DefaultImpl(with_metaclass(ImplMeta)):
# TODO: coverage
raise Exception("Execution arguments not allowed with as_sql")
self.static_output(text_type(
- construct.compile(dialect=self.dialect)
- ).replace("\t", " ").strip() + self.command_terminator)
+ construct.compile(dialect=self.dialect)
+ ).replace("\t", " ").strip() + self.command_terminator)
else:
conn = self.connection
if execution_options:
@@ -80,49 +84,49 @@ class DefaultImpl(with_metaclass(ImplMeta)):
self._exec(sql, execution_options)
def alter_column(self, table_name, column_name,
- nullable=None,
- server_default=False,
- name=None,
- type_=None,
- schema=None,
- autoincrement=None,
- existing_type=None,
- existing_server_default=None,
- existing_nullable=None,
- existing_autoincrement=None
- ):
+ nullable=None,
+ server_default=False,
+ name=None,
+ type_=None,
+ schema=None,
+ autoincrement=None,
+ existing_type=None,
+ existing_server_default=None,
+ existing_nullable=None,
+ existing_autoincrement=None
+ ):
if autoincrement is not None or existing_autoincrement is not None:
util.warn("nautoincrement and existing_autoincrement only make sense for MySQL")
if nullable is not None:
self._exec(base.ColumnNullable(table_name, column_name,
- nullable, schema=schema,
- existing_type=existing_type,
- existing_server_default=existing_server_default,
- existing_nullable=existing_nullable,
- ))
+ nullable, schema=schema,
+ existing_type=existing_type,
+ existing_server_default=existing_server_default,
+ existing_nullable=existing_nullable,
+ ))
if server_default is not False:
self._exec(base.ColumnDefault(
- table_name, column_name, server_default,
- schema=schema,
- existing_type=existing_type,
- existing_server_default=existing_server_default,
- existing_nullable=existing_nullable,
- ))
+ table_name, column_name, server_default,
+ schema=schema,
+ existing_type=existing_type,
+ existing_server_default=existing_server_default,
+ existing_nullable=existing_nullable,
+ ))
if type_ is not None:
self._exec(base.ColumnType(
- table_name, column_name, type_, schema=schema,
- existing_type=existing_type,
- existing_server_default=existing_server_default,
- existing_nullable=existing_nullable,
- ))
+ table_name, column_name, type_, schema=schema,
+ existing_type=existing_type,
+ existing_server_default=existing_server_default,
+ existing_nullable=existing_nullable,
+ ))
# do the new name last ;)
if name is not None:
self._exec(base.ColumnName(
- table_name, column_name, name, schema=schema,
- existing_type=existing_type,
- existing_server_default=existing_server_default,
- existing_nullable=existing_nullable,
- ))
+ table_name, column_name, name, schema=schema,
+ existing_type=existing_type,
+ existing_server_default=existing_server_default,
+ existing_nullable=existing_nullable,
+ ))
def add_column(self, table_name, column, schema=None):
self._exec(base.AddColumn(table_name, column, schema=schema))
@@ -132,7 +136,7 @@ class DefaultImpl(with_metaclass(ImplMeta)):
def add_constraint(self, const):
if const._create_rule is None or \
- const._create_rule(self):
+ const._create_rule(self):
self._exec(schema.AddConstraint(const))
def drop_constraint(self, const):
@@ -140,18 +144,18 @@ class DefaultImpl(with_metaclass(ImplMeta)):
def rename_table(self, old_table_name, new_table_name, schema=None):
self._exec(base.RenameTable(old_table_name,
- new_table_name, schema=schema))
+ new_table_name, schema=schema))
def create_table(self, table):
if util.sqla_07:
table.dispatch.before_create(table, self.connection,
- checkfirst=False,
- _ddl_runner=self)
+ checkfirst=False,
+ _ddl_runner=self)
self._exec(schema.CreateTable(table))
if util.sqla_07:
table.dispatch.after_create(table, self.connection,
checkfirst=False,
- _ddl_runner=self)
+ _ddl_runner=self)
for index in table.indexes:
self._exec(schema.CreateIndex(index))
@@ -200,8 +204,8 @@ class DefaultImpl(with_metaclass(ImplMeta)):
metadata_impl.__dict__.pop('_type_affinity', None)
if conn_type._compare_type_affinity(
- metadata_impl
- ):
+ metadata_impl
+ ):
comparator = _type_comparators.get(conn_type._type_affinity, None)
return comparator and comparator(metadata_type, conn_type)
@@ -209,9 +213,9 @@ class DefaultImpl(with_metaclass(ImplMeta)):
return True
def compare_server_default(self, inspector_column,
- metadata_column,
- rendered_metadata_default,
- rendered_inspector_default):
+ metadata_column,
+ rendered_metadata_default,
+ rendered_inspector_default):
return rendered_inspector_default != rendered_metadata_default
def correct_for_autogen_constraints(self, conn_uniques, conn_indexes,
@@ -247,9 +251,11 @@ class DefaultImpl(with_metaclass(ImplMeta)):
"""
self.static_output("COMMIT" + self.command_terminator)
+
class _literal_bindparam(_BindParamClause):
pass
+
@compiles(_literal_bindparam)
def _render_literal_bindparam(element, compiler, **kw):
return compiler.render_literal_bindparam(element, **kw)
@@ -268,6 +274,7 @@ def _textual_index_column(table, text_):
class _textual_index_element(sql.ColumnElement):
+
"""Wrap around a sqlalchemy text() construct in such a way that
we appear like a column-oriented SQL expression to an Index
construct.
@@ -305,21 +312,18 @@ def _string_compare(t1, t2):
t1.length is not None and \
t1.length != t2.length
+
def _numeric_compare(t1, t2):
return \
(
- t1.precision is not None and \
+ t1.precision is not None and
t1.precision != t2.precision
) or \
(
- t1.scale is not None and \
+ t1.scale is not None and
t1.scale != t2.scale
)
_type_comparators = {
- sqltypes.String:_string_compare,
- sqltypes.Numeric:_numeric_compare
+ sqltypes.String: _string_compare,
+ sqltypes.Numeric: _numeric_compare
}
-
-
-
-
diff --git a/alembic/ddl/mssql.py b/alembic/ddl/mssql.py
index a3c67d6..d6c835c 100644
--- a/alembic/ddl/mssql.py
+++ b/alembic/ddl/mssql.py
@@ -4,9 +4,10 @@ from .. import util
from .impl import DefaultImpl
from .base import alter_table, AddColumn, ColumnName, RenameTable,\
format_table_name, format_column_name, ColumnNullable, alter_column,\
- format_server_default,ColumnDefault, format_type, ColumnType
+ format_server_default, ColumnDefault, format_type, ColumnType
from sqlalchemy.sql.expression import ClauseElement, Executable
+
class MSSQLImpl(DefaultImpl):
__dialect__ = 'mssql'
transactional_ddl = True
@@ -15,8 +16,8 @@ class MSSQLImpl(DefaultImpl):
def __init__(self, *arg, **kw):
super(MSSQLImpl, self).__init__(*arg, **kw)
self.batch_separator = self.context_opts.get(
- "mssql_batch_separator",
- self.batch_separator)
+ "mssql_batch_separator",
+ self.batch_separator)
def _exec(self, construct, *args, **kw):
super(MSSQLImpl, self)._exec(construct, *args, **kw)
@@ -32,17 +33,17 @@ class MSSQLImpl(DefaultImpl):
self.static_output(self.batch_separator)
def alter_column(self, table_name, column_name,
- nullable=None,
- server_default=False,
- name=None,
- type_=None,
- schema=None,
- autoincrement=None,
- existing_type=None,
- existing_server_default=None,
- existing_nullable=None,
- existing_autoincrement=None
- ):
+ nullable=None,
+ server_default=False,
+ name=None,
+ type_=None,
+ schema=None,
+ autoincrement=None,
+ existing_type=None,
+ existing_server_default=None,
+ existing_nullable=None,
+ existing_autoincrement=None
+ ):
if nullable is not None and existing_type is None:
if type_ is not None:
@@ -52,70 +53,69 @@ class MSSQLImpl(DefaultImpl):
type_ = None
else:
raise util.CommandError(
- "MS-SQL ALTER COLUMN operations "
- "with NULL or NOT NULL require the "
- "existing_type or a new type_ be passed.")
+ "MS-SQL ALTER COLUMN operations "
+ "with NULL or NOT NULL require the "
+ "existing_type or a new type_ be passed.")
super(MSSQLImpl, self).alter_column(
- table_name, column_name,
- nullable=nullable,
- type_=type_,
- schema=schema,
- autoincrement=autoincrement,
- existing_type=existing_type,
- existing_nullable=existing_nullable,
- existing_autoincrement=existing_autoincrement
+ table_name, column_name,
+ nullable=nullable,
+ type_=type_,
+ schema=schema,
+ autoincrement=autoincrement,
+ existing_type=existing_type,
+ existing_nullable=existing_nullable,
+ existing_autoincrement=existing_autoincrement
)
if server_default is not False:
if existing_server_default is not False or \
- server_default is None:
+ server_default is None:
self._exec(
_ExecDropConstraint(
- table_name, column_name,
- 'sys.default_constraints')
+ table_name, column_name,
+ 'sys.default_constraints')
)
if server_default is not None:
super(MSSQLImpl, self).alter_column(
- table_name, column_name,
- schema=schema,
- server_default=server_default)
+ table_name, column_name,
+ schema=schema,
+ server_default=server_default)
if name is not None:
super(MSSQLImpl, self).alter_column(
- table_name, column_name,
- schema=schema,
- name=name)
+ table_name, column_name,
+ schema=schema,
+ name=name)
def bulk_insert(self, table, rows, **kw):
if self.as_sql:
self._exec(
"SET IDENTITY_INSERT %s ON" %
- self.dialect.identifier_preparer.format_table(table)
+ self.dialect.identifier_preparer.format_table(table)
)
super(MSSQLImpl, self).bulk_insert(table, rows, **kw)
self._exec(
"SET IDENTITY_INSERT %s OFF" %
- self.dialect.identifier_preparer.format_table(table)
+ self.dialect.identifier_preparer.format_table(table)
)
else:
super(MSSQLImpl, self).bulk_insert(table, rows, **kw)
-
def drop_column(self, table_name, column, **kw):
drop_default = kw.pop('mssql_drop_default', False)
if drop_default:
self._exec(
_ExecDropConstraint(
- table_name, column,
- 'sys.default_constraints')
+ table_name, column,
+ 'sys.default_constraints')
)
drop_check = kw.pop('mssql_drop_check', False)
if drop_check:
self._exec(
_ExecDropConstraint(
- table_name, column,
- 'sys.check_constraints')
+ table_name, column,
+ 'sys.check_constraints')
)
drop_fks = kw.pop('mssql_drop_foreign_key', False)
if drop_fks:
@@ -124,13 +124,17 @@ class MSSQLImpl(DefaultImpl):
)
super(MSSQLImpl, self).drop_column(table_name, column)
+
class _ExecDropConstraint(Executable, ClauseElement):
+
def __init__(self, tname, colname, type_):
self.tname = tname
self.colname = colname
self.type_ = type_
+
class _ExecDropFKConstraint(Executable, ClauseElement):
+
def __init__(self, tname, colname):
self.tname = tname
self.colname = colname
@@ -152,6 +156,7 @@ exec('alter table %(tname_quoted)s drop constraint ' + @const_name)""" % {
'tname_quoted': format_table_name(compiler, tname, None),
}
+
@compiles(_ExecDropFKConstraint, 'mssql')
def _exec_drop_col_fk_constraint(element, compiler, **kw):
tname, colname = element.tname, element.colname
@@ -169,7 +174,6 @@ exec('alter table %(tname_quoted)s drop constraint ' + @const_name)""" % {
}
-
@compiles(AddColumn, 'mssql')
def visit_add_column(element, compiler, **kw):
return "%s %s" % (
@@ -177,9 +181,11 @@ def visit_add_column(element, compiler, **kw):
mssql_add_column(compiler, element.column, **kw)
)
+
def mssql_add_column(compiler, column, **kw):
return "ADD %s" % compiler.get_column_specification(column, **kw)
+
@compiles(ColumnNullable, 'mssql')
def visit_column_nullable(element, compiler, **kw):
return "%s %s %s %s" % (
@@ -189,6 +195,7 @@ def visit_column_nullable(element, compiler, **kw):
"NULL" if element.nullable else "NOT NULL"
)
+
@compiles(ColumnDefault, 'mssql')
def visit_column_default(element, compiler, **kw):
# TODO: there can also be a named constraint
@@ -199,6 +206,7 @@ def visit_column_default(element, compiler, **kw):
format_column_name(compiler, element.column_name)
)
+
@compiles(ColumnName, 'mssql')
def visit_rename_column(element, compiler, **kw):
return "EXEC sp_rename '%s.%s', %s, 'COLUMN'" % (
@@ -207,6 +215,7 @@ def visit_rename_column(element, compiler, **kw):
format_column_name(compiler, element.newname)
)
+
@compiles(ColumnType, 'mssql')
def visit_column_type(element, compiler, **kw):
return "%s %s %s" % (
@@ -215,6 +224,7 @@ def visit_column_type(element, compiler, **kw):
format_type(compiler, element.type_)
)
+
@compiles(RenameTable, 'mssql')
def visit_rename_table(element, compiler, **kw):
return "EXEC sp_rename '%s', %s" % (
diff --git a/alembic/ddl/mysql.py b/alembic/ddl/mysql.py
index 58d5c70..7545df7 100644
--- a/alembic/ddl/mysql.py
+++ b/alembic/ddl/mysql.py
@@ -6,27 +6,28 @@ from ..compat import string_types
from .. import util
from .impl import DefaultImpl
from .base import ColumnNullable, ColumnName, ColumnDefault, \
- ColumnType, AlterColumn, format_column_name, \
- format_server_default
+ ColumnType, AlterColumn, format_column_name, \
+ format_server_default
from .base import alter_table
+
class MySQLImpl(DefaultImpl):
__dialect__ = 'mysql'
transactional_ddl = False
def alter_column(self, table_name, column_name,
- nullable=None,
- server_default=False,
- name=None,
- type_=None,
- schema=None,
- autoincrement=None,
- existing_type=None,
- existing_server_default=None,
- existing_nullable=None,
- existing_autoincrement=None
- ):
+ nullable=None,
+ server_default=False,
+ name=None,
+ type_=None,
+ schema=None,
+ autoincrement=None,
+ existing_type=None,
+ existing_server_default=None,
+ existing_nullable=None,
+ existing_autoincrement=None
+ ):
if name is not None:
self._exec(
MySQLChangeColumn(
@@ -34,33 +35,33 @@ class MySQLImpl(DefaultImpl):
schema=schema,
newname=name,
nullable=nullable if nullable is not None else
- existing_nullable
- if existing_nullable is not None
- else True,
+ existing_nullable
+ if existing_nullable is not None
+ else True,
type_=type_ if type_ is not None else existing_type,
default=server_default if server_default is not False
- else existing_server_default,
+ else existing_server_default,
autoincrement=autoincrement if autoincrement is not None
- else existing_autoincrement
+ else existing_autoincrement
)
)
elif nullable is not None or \
- type_ is not None or \
- autoincrement is not None:
+ type_ is not None or \
+ autoincrement is not None:
self._exec(
MySQLModifyColumn(
table_name, column_name,
schema=schema,
newname=name if name is not None else column_name,
nullable=nullable if nullable is not None else
- existing_nullable
- if existing_nullable is not None
- else True,
+ existing_nullable
+ if existing_nullable is not None
+ else True,
type_=type_ if type_ is not None else existing_type,
default=server_default if server_default is not False
- else existing_server_default,
+ else existing_server_default,
autoincrement=autoincrement if autoincrement is not None
- else existing_autoincrement
+ else existing_autoincrement
)
)
elif server_default is not False:
@@ -99,7 +100,9 @@ class MySQLImpl(DefaultImpl):
if idx.name in removed:
metadata_indexes.remove(idx)
+
class MySQLAlterDefault(AlterColumn):
+
def __init__(self, name, column_name, default, schema=None):
super(AlterColumn, self).__init__(name, schema=schema)
self.column_name = column_name
@@ -107,12 +110,13 @@ class MySQLAlterDefault(AlterColumn):
class MySQLChangeColumn(AlterColumn):
+
def __init__(self, name, column_name, schema=None,
- newname=None,
- type_=None,
- nullable=None,
- default=False,
- autoincrement=None):
+ newname=None,
+ type_=None,
+ nullable=None,
+ default=False,
+ autoincrement=None):
super(AlterColumn, self).__init__(name, schema=schema)
self.column_name = column_name
self.nullable = nullable
@@ -127,6 +131,7 @@ class MySQLChangeColumn(AlterColumn):
self.type_ = sqltypes.to_instance(type_)
+
class MySQLModifyColumn(MySQLChangeColumn):
pass
@@ -137,8 +142,8 @@ class MySQLModifyColumn(MySQLChangeColumn):
@compiles(ColumnType, 'mysql')
def _mysql_doesnt_support_individual(element, compiler, **kw):
raise NotImplementedError(
- "Individual alter column constructs not supported by MySQL"
- )
+ "Individual alter column constructs not supported by MySQL"
+ )
@compiles(MySQLAlterDefault, "mysql")
@@ -147,10 +152,11 @@ def _mysql_alter_default(element, compiler, **kw):
alter_table(compiler, element.table_name, element.schema),
format_column_name(compiler, element.column_name),
"SET DEFAULT %s" % format_server_default(compiler, element.default)
- if element.default is not None
- else "DROP DEFAULT"
+ if element.default is not None
+ else "DROP DEFAULT"
)
+
@compiles(MySQLModifyColumn, "mysql")
def _mysql_modify_column(element, compiler, **kw):
return "%s MODIFY %s %s" % (
@@ -181,14 +187,16 @@ def _mysql_change_column(element, compiler, **kw):
),
)
+
def _render_value(compiler, expr):
if isinstance(expr, string_types):
return "'%s'" % expr
else:
return compiler.sql_compiler.process(expr)
+
def _mysql_colspec(compiler, nullable, server_default, type_,
- autoincrement):
+ autoincrement):
spec = "%s %s" % (
compiler.dialect.type_compiler.process(type_),
"NULL" if nullable else "NOT NULL"
@@ -200,6 +208,7 @@ def _mysql_colspec(compiler, nullable, server_default, type_,
return spec
+
@compiles(schema.DropConstraint, "mysql")
def _mysql_drop_constraint(element, compiler, **kw):
"""Redefine SQLAlchemy's drop constraint to
@@ -207,15 +216,14 @@ def _mysql_drop_constraint(element, compiler, **kw):
constraint = element.element
if isinstance(constraint, (schema.ForeignKeyConstraint,
- schema.PrimaryKeyConstraint,
- schema.UniqueConstraint)
- ):
+ schema.PrimaryKeyConstraint,
+ schema.UniqueConstraint)
+ ):
return compiler.visit_drop_constraint(element, **kw)
elif isinstance(constraint, schema.CheckConstraint):
raise NotImplementedError(
- "MySQL does not support CHECK constraints.")
+ "MySQL does not support CHECK constraints.")
else:
raise NotImplementedError(
- "No generic 'DROP CONSTRAINT' in MySQL - "
- "please specify constraint type")
-
+ "No generic 'DROP CONSTRAINT' in MySQL - "
+ "please specify constraint type")
diff --git a/alembic/ddl/oracle.py b/alembic/ddl/oracle.py
index 28eb246..93e71e5 100644
--- a/alembic/ddl/oracle.py
+++ b/alembic/ddl/oracle.py
@@ -3,7 +3,8 @@ from sqlalchemy.ext.compiler import compiles
from .impl import DefaultImpl
from .base import alter_table, AddColumn, ColumnName, \
format_column_name, ColumnNullable, \
- format_server_default,ColumnDefault, format_type, ColumnType
+ format_server_default, ColumnDefault, format_type, ColumnType
+
class OracleImpl(DefaultImpl):
__dialect__ = 'oracle'
@@ -14,8 +15,8 @@ class OracleImpl(DefaultImpl):
def __init__(self, *arg, **kw):
super(OracleImpl, self).__init__(*arg, **kw)
self.batch_separator = self.context_opts.get(
- "oracle_batch_separator",
- self.batch_separator)
+ "oracle_batch_separator",
+ self.batch_separator)
def _exec(self, construct, *args, **kw):
super(OracleImpl, self)._exec(construct, *args, **kw)
@@ -28,6 +29,7 @@ class OracleImpl(DefaultImpl):
def emit_commit(self):
self._exec("COMMIT")
+
@compiles(AddColumn, 'oracle')
def visit_add_column(element, compiler, **kw):
return "%s %s" % (
@@ -35,6 +37,7 @@ def visit_add_column(element, compiler, **kw):
add_column(compiler, element.column, **kw),
)
+
@compiles(ColumnNullable, 'oracle')
def visit_column_nullable(element, compiler, **kw):
return "%s %s %s" % (
@@ -43,6 +46,7 @@ def visit_column_nullable(element, compiler, **kw):
"NULL" if element.nullable else "NOT NULL"
)
+
@compiles(ColumnType, 'oracle')
def visit_column_type(element, compiler, **kw):
return "%s %s %s" % (
@@ -51,6 +55,7 @@ def visit_column_type(element, compiler, **kw):
"%s" % format_type(compiler, element.type_)
)
+
@compiles(ColumnName, 'oracle')
def visit_column_name(element, compiler, **kw):
return "%s RENAME COLUMN %s TO %s" % (
@@ -59,19 +64,22 @@ def visit_column_name(element, compiler, **kw):
format_column_name(compiler, element.newname)
)
+
@compiles(ColumnDefault, 'oracle')
def visit_column_default(element, compiler, **kw):
return "%s %s %s" % (
alter_table(compiler, element.table_name, element.schema),
alter_column(compiler, element.column_name),
"DEFAULT %s" %
- format_server_default(compiler, element.default)
+ format_server_default(compiler, element.default)
if element.default is not None
else "DEFAULT NULL"
)
+
def alter_column(compiler, name):
return 'MODIFY %s' % format_column_name(compiler, name)
+
def add_column(compiler, column, **kw):
return "ADD %s" % compiler.get_column_specification(column, **kw)
diff --git a/alembic/ddl/postgresql.py b/alembic/ddl/postgresql.py
index 27f31b0..eab1f4d 100644
--- a/alembic/ddl/postgresql.py
+++ b/alembic/ddl/postgresql.py
@@ -5,18 +5,19 @@ from .. import compat
from .base import compiles, alter_table, format_table_name, RenameTable
from .impl import DefaultImpl
+
class PostgresqlImpl(DefaultImpl):
__dialect__ = 'postgresql'
transactional_ddl = True
def compare_server_default(self, inspector_column,
- metadata_column,
- rendered_metadata_default,
- rendered_inspector_default):
+ metadata_column,
+ rendered_metadata_default,
+ rendered_inspector_default):
# don't do defaults for SERIAL columns
if metadata_column.primary_key and \
- metadata_column is metadata_column.table._autoincrement_column:
+ metadata_column is metadata_column.table._autoincrement_column:
return False
conn_col_default = rendered_inspector_default
@@ -26,7 +27,7 @@ class PostgresqlImpl(DefaultImpl):
if metadata_column.server_default is not None and \
isinstance(metadata_column.server_default.arg,
- compat.string_types) and \
+ compat.string_types) and \
not re.match(r"^'.+'$", rendered_metadata_default):
rendered_metadata_default = "'%s'" % rendered_metadata_default
diff --git a/alembic/ddl/sqlite.py b/alembic/ddl/sqlite.py
index 85c829e..1a00be1 100644
--- a/alembic/ddl/sqlite.py
+++ b/alembic/ddl/sqlite.py
@@ -6,6 +6,7 @@ import re
#from .base import AddColumn, alter_table
#from sqlalchemy.schema import AddConstraint
+
class SQLiteImpl(DefaultImpl):
__dialect__ = 'sqlite'
@@ -19,21 +20,20 @@ class SQLiteImpl(DefaultImpl):
# auto-gen constraint and an explicit one
if const._create_rule is None:
raise NotImplementedError(
- "No support for ALTER of constraints in SQLite dialect")
+ "No support for ALTER of constraints in SQLite dialect")
elif const._create_rule(self):
util.warn("Skipping unsupported ALTER for "
- "creation of implicit constraint")
-
+ "creation of implicit constraint")
def drop_constraint(self, const):
if const._create_rule is None:
raise NotImplementedError(
- "No support for ALTER of constraints in SQLite dialect")
+ "No support for ALTER of constraints in SQLite dialect")
def compare_server_default(self, inspector_column,
- metadata_column,
- rendered_metadata_default,
- rendered_inspector_default):
+ metadata_column,
+ rendered_metadata_default,
+ rendered_inspector_default):
rendered_metadata_default = re.sub(r"^'|'$", "", rendered_metadata_default)
return rendered_inspector_default != repr(rendered_metadata_default)
@@ -46,9 +46,9 @@ class SQLiteImpl(DefaultImpl):
return tuple(sorted(uq.columns.keys()))
conn_unique_sigs = set(
- uq_sig(uq)
- for uq in conn_unique_constraints
- )
+ uq_sig(uq)
+ for uq in conn_unique_constraints
+ )
for idx in list(metadata_unique_constraints):
# SQLite backend can't report on unnamed UNIQUE constraints,
@@ -65,18 +65,18 @@ class SQLiteImpl(DefaultImpl):
conn_uniques.remove(idx)
#@compiles(AddColumn, 'sqlite')
-#def visit_add_column(element, compiler, **kw):
+# def visit_add_column(element, compiler, **kw):
# return "%s %s" % (
# alter_table(compiler, element.table_name, element.schema),
# add_column(compiler, element.column, **kw)
# )
-#def add_column(compiler, column, **kw):
+# def add_column(compiler, column, **kw):
# text = "ADD COLUMN %s" % compiler.get_column_specification(column, **kw)
-# # need to modify SQLAlchemy so that the CHECK associated with a Boolean
-# # or Enum gets placed as part of the column constraints, not the Table
-# # see ticket 98
+# need to modify SQLAlchemy so that the CHECK associated with a Boolean
+# or Enum gets placed as part of the column constraints, not the Table
+# see ticket 98
# for const in column.constraints:
# text += compiler.process(AddConstraint(const))
# return text
diff --git a/alembic/environment.py b/alembic/environment.py
index c3e7a38..405e2f2 100644
--- a/alembic/environment.py
+++ b/alembic/environment.py
@@ -2,7 +2,9 @@ from .operations import Operations
from .migration import MigrationContext
from . import util
+
class EnvironmentContext(object):
+
"""Represent the state made available to an ``env.py`` script.
:class:`.EnvironmentContext` is normally instantiated
@@ -156,13 +158,13 @@ class EnvironmentContext(object):
"""
if self._migration_context is not None:
return self.script._as_rev_number(
- self.get_context()._start_from_rev)
+ self.get_context()._start_from_rev)
elif 'starting_rev' in self.context_opts:
return self.script._as_rev_number(
- self.context_opts['starting_rev'])
+ self.context_opts['starting_rev'])
else:
raise util.CommandError(
- "No starting revision argument is available.")
+ "No starting revision argument is available.")
def get_revision_argument(self):
"""Get the 'destination' revision argument.
@@ -179,7 +181,7 @@ class EnvironmentContext(object):
"""
return self.script._as_rev_number(
- self.context_opts['destination_rev'])
+ self.context_opts['destination_rev'])
def get_tag_argument(self):
"""Return the value passed for the ``--tag`` argument, if any.
@@ -247,34 +249,34 @@ class EnvironmentContext(object):
value = []
if as_dictionary:
value = dict(
- arg.split('=', 1) for arg in value
- )
+ arg.split('=', 1) for arg in value
+ )
return value
def configure(self,
- connection=None,
- url=None,
- dialect_name=None,
- transactional_ddl=None,
- transaction_per_migration=False,
- output_buffer=None,
- starting_rev=None,
- tag=None,
- template_args=None,
- target_metadata=None,
- include_symbol=None,
- include_object=None,
- include_schemas=False,
- compare_type=False,
- compare_server_default=False,
- render_item=None,
- upgrade_token="upgrades",
- downgrade_token="downgrades",
- alembic_module_prefix="op.",
- sqlalchemy_module_prefix="sa.",
- user_module_prefix=None,
- **kw
- ):
+ connection=None,
+ url=None,
+ dialect_name=None,
+ transactional_ddl=None,
+ transaction_per_migration=False,
+ output_buffer=None,
+ starting_rev=None,
+ tag=None,
+ template_args=None,
+ target_metadata=None,
+ include_symbol=None,
+ include_object=None,
+ include_schemas=False,
+ compare_type=False,
+ compare_server_default=False,
+ render_item=None,
+ upgrade_token="upgrades",
+ downgrade_token="downgrades",
+ alembic_module_prefix="op.",
+ sqlalchemy_module_prefix="sa.",
+ user_module_prefix=None,
+ **kw
+ ):
"""Configure a :class:`.MigrationContext` within this
:class:`.EnvironmentContext` which will provide database
connectivity and other configuration to a series of
@@ -701,7 +703,7 @@ class EnvironmentContext(object):
"""
self.get_context().execute(sql,
- execution_options=execution_options)
+ execution_options=execution_options)
def static_output(self, text):
"""Emit text directly to the "offline" SQL stream.
@@ -714,7 +716,6 @@ class EnvironmentContext(object):
"""
self.get_context().impl.static_output(text)
-
def begin_transaction(self):
"""Return a context manager that will
enclose an operation within a "transaction",
@@ -761,7 +762,6 @@ class EnvironmentContext(object):
return self.get_context().begin_transaction()
-
def get_context(self):
"""Return the current :class:`.MigrationContext` object.
@@ -789,4 +789,3 @@ class EnvironmentContext(object):
def get_impl(self):
return self.get_context().impl
-
diff --git a/alembic/migration.py b/alembic/migration.py
index dadf49a..0c91fd1 100644
--- a/alembic/migration.py
+++ b/alembic/migration.py
@@ -13,7 +13,9 @@ from . import ddl, util
log = logging.getLogger(__name__)
+
class MigrationContext(object):
+
"""Represent the database state made available to a migration
script.
@@ -58,6 +60,7 @@ class MigrationContext(object):
op.alter_column("mytable", "somecolumn", nullable=True)
"""
+
def __init__(self, dialect, connection, opts, environment_context=None):
self.environment_context = environment_context
self.opts = opts
@@ -68,7 +71,7 @@ class MigrationContext(object):
transactional_ddl = opts.get("transactional_ddl")
self._transaction_per_migration = opts.get(
- "transaction_per_migration", False)
+ "transaction_per_migration", False)
if as_sql:
self.connection = self._stdout_connection(connection)
@@ -88,8 +91,8 @@ class MigrationContext(object):
self._user_compare_type = opts.get('compare_type', False)
self._user_compare_server_default = opts.get(
- 'compare_server_default',
- False)
+ 'compare_server_default',
+ False)
version_table = opts.get('version_table', 'alembic_version')
version_table_schema = opts.get('version_table_schema', None)
self._version = Table(
@@ -99,26 +102,26 @@ class MigrationContext(object):
self._start_from_rev = opts.get("starting_rev")
self.impl = ddl.DefaultImpl.get_by_dialect(dialect)(
- dialect, self.connection, self.as_sql,
- transactional_ddl,
- self.output_buffer,
- opts
- )
+ dialect, self.connection, self.as_sql,
+ transactional_ddl,
+ self.output_buffer,
+ opts
+ )
log.info("Context impl %s.", self.impl.__class__.__name__)
if self.as_sql:
log.info("Generating static SQL")
log.info("Will assume %s DDL.",
- "transactional" if self.impl.transactional_ddl
- else "non-transactional")
+ "transactional" if self.impl.transactional_ddl
+ else "non-transactional")
@classmethod
def configure(cls,
- connection=None,
- url=None,
- dialect_name=None,
- environment_context=None,
- opts=None,
- ):
+ connection=None,
+ url=None,
+ dialect_name=None,
+ environment_context=None,
+ opts=None,
+ ):
"""Create a new :class:`.MigrationContext`.
This is a factory method usually called
@@ -155,7 +158,6 @@ class MigrationContext(object):
return MigrationContext(dialect, connection, opts, environment_context)
-
def begin_transaction(self, _per_migration=False):
transaction_now = _per_migration == self._transaction_per_migration
@@ -209,12 +211,12 @@ class MigrationContext(object):
self.impl._exec(self._version.delete())
elif old is None:
self.impl._exec(self._version.insert().
- values(version_num=literal_column("'%s'" % new))
- )
+ values(version_num=literal_column("'%s'" % new))
+ )
else:
self.impl._exec(self._version.update().
- values(version_num=literal_column("'%s'" % new))
- )
+ values(version_num=literal_column("'%s'" % new))
+ )
def run_migrations(self, **kw):
"""Run the migration scripts established for this :class:`.MigrationContext`,
@@ -239,12 +241,12 @@ class MigrationContext(object):
"""
current_rev = rev = False
stamp_per_migration = not self.impl.transactional_ddl or \
- self._transaction_per_migration
+ self._transaction_per_migration
self.impl.start_migrations()
for change, prev_rev, rev, doc in self._migrations_fn(
- self.get_current_revision(),
- self):
+ self.get_current_revision(),
+ self):
with self.begin_transaction(_per_migration=True):
if current_rev is False:
current_rev = prev_rev
@@ -252,14 +254,14 @@ class MigrationContext(object):
self._version.create(self.connection)
if doc:
log.info("Running %s %s -> %s, %s", change.__name__, prev_rev,
- rev, doc)
+ rev, doc)
else:
log.info("Running %s %s -> %s", change.__name__, prev_rev, rev)
if self.as_sql:
self.impl.static_output(
- "-- Running %s %s -> %s" %
- (change.__name__, prev_rev, rev)
- )
+ "-- Running %s %s -> %s" %
+ (change.__name__, prev_rev, rev)
+ )
change(**kw)
if stamp_per_migration:
self._update_current_rev(prev_rev, rev)
@@ -288,7 +290,7 @@ class MigrationContext(object):
self.impl._exec(construct)
return create_engine("%s://" % self.dialect.name,
- strategy="mock", executor=dump)
+ strategy="mock", executor=dump)
@property
def bind(self):
@@ -338,32 +340,31 @@ class MigrationContext(object):
return user_value
return self.impl.compare_type(
- inspector_column,
- metadata_column)
+ inspector_column,
+ metadata_column)
def _compare_server_default(self, inspector_column,
- metadata_column,
- rendered_metadata_default,
- rendered_column_default):
+ metadata_column,
+ rendered_metadata_default,
+ rendered_column_default):
if self._user_compare_server_default is False:
return False
if callable(self._user_compare_server_default):
user_value = self._user_compare_server_default(
- self,
- inspector_column,
- metadata_column,
- rendered_column_default,
- metadata_column.server_default,
- rendered_metadata_default
+ self,
+ inspector_column,
+ metadata_column,
+ rendered_column_default,
+ metadata_column.server_default,
+ rendered_metadata_default
)
if user_value is not None:
return user_value
return self.impl.compare_server_default(
- inspector_column,
- metadata_column,
- rendered_metadata_default,
- rendered_column_default)
-
+ inspector_column,
+ metadata_column,
+ rendered_metadata_default,
+ rendered_column_default)
diff --git a/alembic/operations.py b/alembic/operations.py
index d028688..a1f3dee 100644
--- a/alembic/operations.py
+++ b/alembic/operations.py
@@ -14,7 +14,9 @@ try:
except:
conv = None
+
class Operations(object):
+
"""Define high level migration operations.
Each operation corresponds to some schema migration operation,
@@ -39,6 +41,7 @@ class Operations(object):
op.alter_column("t", "c", nullable=True)
"""
+
def __init__(self, migration_context):
"""Construct a new :class:`.Operations`
@@ -58,57 +61,56 @@ class Operations(object):
yield op
_remove_proxy()
-
def _primary_key_constraint(self, name, table_name, cols, schema=None):
m = self._metadata()
columns = [sa_schema.Column(n, NULLTYPE) for n in cols]
t1 = sa_schema.Table(table_name, m,
- *columns,
- schema=schema)
+ *columns,
+ schema=schema)
p = sa_schema.PrimaryKeyConstraint(*columns, name=name)
t1.append_constraint(p)
return p
def _foreign_key_constraint(self, name, source, referent,
- local_cols, remote_cols,
- onupdate=None, ondelete=None,
- deferrable=None, source_schema=None,
- referent_schema=None, initially=None,
- match=None, **dialect_kw):
+ local_cols, remote_cols,
+ onupdate=None, ondelete=None,
+ deferrable=None, source_schema=None,
+ referent_schema=None, initially=None,
+ match=None, **dialect_kw):
m = self._metadata()
if source == referent:
t1_cols = local_cols + remote_cols
else:
t1_cols = local_cols
sa_schema.Table(referent, m,
- *[sa_schema.Column(n, NULLTYPE) for n in remote_cols],
- schema=referent_schema)
+ *[sa_schema.Column(n, NULLTYPE) for n in remote_cols],
+ schema=referent_schema)
t1 = sa_schema.Table(source, m,
- *[sa_schema.Column(n, NULLTYPE) for n in t1_cols],
- schema=source_schema)
+ *[sa_schema.Column(n, NULLTYPE) for n in t1_cols],
+ schema=source_schema)
tname = "%s.%s" % (referent_schema, referent) if referent_schema \
else referent
f = sa_schema.ForeignKeyConstraint(local_cols,
- ["%s.%s" % (tname, n)
+ ["%s.%s" % (tname, n)
for n in remote_cols],
- name=name,
- onupdate=onupdate,
- ondelete=ondelete,
- deferrable=deferrable,
- initially=initially,
- match=match,
- **dialect_kw
- )
+ name=name,
+ onupdate=onupdate,
+ ondelete=ondelete,
+ deferrable=deferrable,
+ initially=initially,
+ match=match,
+ **dialect_kw
+ )
t1.append_constraint(f)
return f
def _unique_constraint(self, name, source, local_cols, schema=None, **kw):
t = sa_schema.Table(source, self._metadata(),
- *[sa_schema.Column(n, NULLTYPE) for n in local_cols],
- schema=schema)
+ *[sa_schema.Column(n, NULLTYPE) for n in local_cols],
+ schema=schema)
kw['name'] = name
uq = sa_schema.UniqueConstraint(*[t.c[n] for n in local_cols], **kw)
# TODO: need event tests to ensure the event
@@ -118,7 +120,7 @@ class Operations(object):
def _check_constraint(self, name, source, condition, schema=None, **kw):
t = sa_schema.Table(source, self._metadata(),
- sa_schema.Column('x', Integer), schema=schema)
+ sa_schema.Column('x', Integer), schema=schema)
ck = sa_schema.CheckConstraint(condition, name=name, **kw)
t.append_constraint(ck)
return ck
@@ -201,17 +203,17 @@ class Operations(object):
@util._with_legacy_names([('name', 'new_column_name')])
def alter_column(self, table_name, column_name,
- nullable=None,
- server_default=False,
- new_column_name=None,
- type_=None,
- autoincrement=None,
- existing_type=None,
- existing_server_default=False,
- existing_nullable=None,
- existing_autoincrement=None,
- schema=None
- ):
+ nullable=None,
+ server_default=False,
+ new_column_name=None,
+ type_=None,
+ autoincrement=None,
+ existing_type=None,
+ existing_server_default=False,
+ existing_nullable=None,
+ existing_autoincrement=None,
+ schema=None
+ ):
"""Issue an "alter column" instruction using the
current migration context.
@@ -291,9 +293,10 @@ class Operations(object):
"""
compiler = self.impl.dialect.statement_compiler(
- self.impl.dialect,
- None
- )
+ self.impl.dialect,
+ None
+ )
+
def _count_constraint(constraint):
return not isinstance(constraint, sa_schema.PrimaryKeyConstraint) and \
(not constraint._create_rule or
@@ -301,31 +304,31 @@ class Operations(object):
if existing_type and type_:
t = self._table(table_name,
- sa_schema.Column(column_name, existing_type),
- schema=schema
- )
+ sa_schema.Column(column_name, existing_type),
+ schema=schema
+ )
for constraint in t.constraints:
if _count_constraint(constraint):
self.impl.drop_constraint(constraint)
self.impl.alter_column(table_name, column_name,
- nullable=nullable,
- server_default=server_default,
- name=new_column_name,
- type_=type_,
- schema=schema,
- autoincrement=autoincrement,
- existing_type=existing_type,
- existing_server_default=existing_server_default,
- existing_nullable=existing_nullable,
- existing_autoincrement=existing_autoincrement
- )
+ nullable=nullable,
+ server_default=server_default,
+ name=new_column_name,
+ type_=type_,
+ schema=schema,
+ autoincrement=autoincrement,
+ existing_type=existing_type,
+ existing_server_default=existing_server_default,
+ existing_nullable=existing_nullable,
+ existing_autoincrement=existing_autoincrement
+ )
if type_:
t = self._table(table_name,
- sa_schema.Column(column_name, type_),
- schema=schema
- )
+ sa_schema.Column(column_name, type_),
+ schema=schema
+ )
for constraint in t.constraints:
if _count_constraint(constraint):
self.impl.add_constraint(constraint)
@@ -374,7 +377,7 @@ class Operations(object):
return conv(name)
else:
raise NotImplementedError(
- "op.f() feature requires SQLAlchemy 0.9.4 or greater.")
+ "op.f() feature requires SQLAlchemy 0.9.4 or greater.")
def add_column(self, table_name, column, schema=None):
"""Issue an "add column" instruction using the current
@@ -481,7 +484,6 @@ class Operations(object):
**kw
)
-
def create_primary_key(self, name, table_name, cols, schema=None):
"""Issue a "create primary key" instruction using the current
migration context.
@@ -518,10 +520,9 @@ class Operations(object):
"""
self.impl.add_constraint(
- self._primary_key_constraint(name, table_name, cols,
- schema)
- )
-
+ self._primary_key_constraint(name, table_name, cols,
+ schema)
+ )
def create_foreign_key(self, name, source, referent, local_cols,
remote_cols, onupdate=None, ondelete=None,
@@ -573,13 +574,13 @@ class Operations(object):
"""
self.impl.add_constraint(
- self._foreign_key_constraint(name, source, referent,
- local_cols, remote_cols,
- onupdate=onupdate, ondelete=ondelete,
- deferrable=deferrable, source_schema=source_schema,
- referent_schema=referent_schema,
- initially=initially, match=match, **dialect_kw)
- )
+ self._foreign_key_constraint(name, source, referent,
+ local_cols, remote_cols,
+ onupdate=onupdate, ondelete=ondelete,
+ deferrable=deferrable, source_schema=source_schema,
+ referent_schema=referent_schema,
+ initially=initially, match=match, **dialect_kw)
+ )
def create_unique_constraint(self, name, source, local_cols,
schema=None, **kw):
@@ -621,9 +622,9 @@ class Operations(object):
"""
self.impl.add_constraint(
- self._unique_constraint(name, source, local_cols,
- schema=schema, **kw)
- )
+ self._unique_constraint(name, source, local_cols,
+ schema=schema, **kw)
+ )
def create_check_constraint(self, name, source, condition,
schema=None, **kw):
@@ -841,7 +842,7 @@ class Operations(object):
t = self._table(table_name, schema=schema)
types = {
'foreignkey': lambda name: sa_schema.ForeignKeyConstraint(
- [], [], name=name),
+ [], [], name=name),
'primary': sa_schema.PrimaryKeyConstraint,
'unique': sa_schema.UniqueConstraint,
'check': lambda name: sa_schema.CheckConstraint("", name=name),
@@ -851,7 +852,7 @@ class Operations(object):
const = types[type_]
except KeyError:
raise TypeError("'type' can be one of %s" %
- ", ".join(sorted(repr(x) for x in types)))
+ ", ".join(sorted(repr(x) for x in types)))
const = const(name=name)
t.append_constraint(const)
@@ -1038,7 +1039,7 @@ class Operations(object):
:meth:`sqlalchemy.engine.Connection.execution_options`.
"""
self.migration_context.impl.execute(sql,
- execution_options=execution_options)
+ execution_options=execution_options)
def get_bind(self):
"""Return the current 'bind'.
@@ -1051,4 +1052,3 @@ class Operations(object):
"""
return self.migration_context.impl.bind
-
diff --git a/alembic/script.py b/alembic/script.py
index ed44f71..a97fc9c 100644
--- a/alembic/script.py
+++ b/alembic/script.py
@@ -12,7 +12,9 @@ _slug_re = re.compile(r'\w+')
_default_file_template = "%(rev)s_%(slug)s"
_relative_destination = re.compile(r'(?:\+|-)\d+')
+
class ScriptDirectory(object):
+
"""Provides operations upon an Alembic script directory.
This object is useful to get information as to current revisions,
@@ -31,9 +33,10 @@ class ScriptDirectory(object):
"""
+
def __init__(self, dir, file_template=_default_file_template,
- truncate_slug_length=40,
- sourceless=False):
+ truncate_slug_length=40,
+ sourceless=False):
self.dir = dir
self.versions = os.path.join(self.dir, 'versions')
self.file_template = file_template
@@ -42,8 +45,8 @@ class ScriptDirectory(object):
if not os.access(dir, os.F_OK):
raise util.CommandError("Path doesn't exist: %r. Please use "
- "the 'init' command to create a new "
- "scripts folder." % dir)
+ "the 'init' command to create a new "
+ "scripts folder." % dir)
@classmethod
def from_config(cls, config):
@@ -62,13 +65,13 @@ class ScriptDirectory(object):
if truncate_slug_length is not None:
truncate_slug_length = int(truncate_slug_length)
return ScriptDirectory(
- util.coerce_resource_to_filename(script_location),
- file_template=config.get_main_option(
- 'file_template',
- _default_file_template),
- truncate_slug_length=truncate_slug_length,
- sourceless=config.get_main_option("sourceless") == "true"
- )
+ util.coerce_resource_to_filename(script_location),
+ file_template=config.get_main_option(
+ 'file_template',
+ _default_file_template),
+ truncate_slug_length=truncate_slug_length,
+ sourceless=config.get_main_option("sourceless") == "true"
+ )
def walk_revisions(self, base="base", head="head"):
"""Iterate through all revisions.
@@ -108,11 +111,11 @@ class ScriptDirectory(object):
raise util.CommandError("No such revision '%s'" % id_)
elif len(revs) > 1:
raise util.CommandError(
- "Multiple revisions start "
- "with '%s', %s..." % (
- id_,
- ", ".join("'%s'" % r for r in revs[0:3])
- ))
+ "Multiple revisions start "
+ "with '%s', %s..." % (
+ id_,
+ ", ".join("'%s'" % r for r in revs[0:3])
+ ))
else:
return self._revision_map[revs[0]]
@@ -148,7 +151,7 @@ class ScriptDirectory(object):
revs = revs[-relative:]
if len(revs) != abs(relative):
raise util.CommandError("Relative revision %s didn't "
- "produce %d migrations" % (upper, abs(relative)))
+ "produce %d migrations" % (upper, abs(relative)))
return iter(revs)
elif lower is not None and _relative_destination.match(lower):
relative = int(lower)
@@ -156,7 +159,7 @@ class ScriptDirectory(object):
revs = revs[0:-relative]
if len(revs) != abs(relative):
raise util.CommandError("Relative revision %s didn't "
- "produce %d migrations" % (lower, abs(relative)))
+ "produce %d migrations" % (lower, abs(relative)))
return iter(revs)
else:
return self._iterate_revisions(upper, lower)
@@ -165,12 +168,12 @@ class ScriptDirectory(object):
lower = self.get_revision(lower)
upper = self.get_revision(upper)
orig = lower.revision if lower else 'base', \
- upper.revision if upper else 'base'
+ upper.revision if upper else 'base'
script = upper
while script != lower:
if script is None and lower is not None:
raise util.CommandError(
- "Revision %s is not an ancestor of %s" % orig)
+ "Revision %s is not an ancestor of %s" % orig)
yield script
downrev = script.down_revision
script = self._revision_map[downrev]
@@ -181,7 +184,7 @@ class ScriptDirectory(object):
(script.module.upgrade, script.down_revision, script.revision,
script.doc)
for script in reversed(list(revs))
- ]
+ ]
def _downgrade_revs(self, destination, current_rev):
revs = self.iterate_revisions(current_rev, destination)
@@ -189,7 +192,7 @@ class ScriptDirectory(object):
(script.module.downgrade, script.revision, script.down_revision,
script.doc)
for script in revs
- ]
+ ]
def run_env(self):
"""Run the script environment.
@@ -216,14 +219,14 @@ class ScriptDirectory(object):
continue
if script.revision in map_:
util.warn("Revision %s is present more than once" %
- script.revision)
+ script.revision)
map_[script.revision] = script
for rev in map_.values():
if rev.down_revision is None:
continue
if rev.down_revision not in map_:
util.warn("Revision %s referenced from %s is not present"
- % (rev.down_revision, rev))
+ % (rev.down_revision, rev))
rev.down_revision = None
else:
map_[rev.down_revision].add_nextrev(rev.revision)
@@ -260,10 +263,10 @@ class ScriptDirectory(object):
current_heads = self.get_heads()
if len(current_heads) > 1:
raise util.CommandError('Only a single head is supported. The '
- 'script directory has multiple heads (due to branching), which '
- 'must be resolved by manually editing the revision files to '
- 'form a linear sequence. Run `alembic branches` to see the '
- 'divergence(s).')
+ 'script directory has multiple heads (due to branching), which '
+ 'must be resolved by manually editing the revision files to '
+ 'form a linear sequence. Run `alembic branches` to see the '
+ 'divergence(s).')
if current_heads:
return current_heads[0]
@@ -303,18 +306,18 @@ class ScriptDirectory(object):
"""
for script in self._revision_map.values():
if script and script.down_revision is None \
- and script.revision in self._revision_map:
+ and script.revision in self._revision_map:
return script.revision
else:
return None
def _generate_template(self, src, dest, **kw):
util.status("Generating %s" % os.path.abspath(dest),
- util.template_to_file,
- src,
- dest,
- **kw
- )
+ util.template_to_file,
+ src,
+ dest,
+ **kw
+ )
def _copy_file(self, src, dest):
util.status("Generating %s" % os.path.abspath(dest),
@@ -357,13 +360,14 @@ class ScriptDirectory(object):
self._revision_map[script.revision] = script
if script.down_revision:
self._revision_map[script.down_revision].\
- add_nextrev(script.revision)
+ add_nextrev(script.revision)
return script
else:
return None
class Script(object):
+
"""Represent a single revision file in a ``versions/`` directory.
The :class:`.Script` instance is returned by methods
@@ -455,11 +459,11 @@ class Script(object):
def __str__(self):
return "%s -> %s%s%s, %s" % (
- self.down_revision,
- self.revision,
- " (head)" if self.is_head else "",
- " (branchpoint)" if self.is_branch_point else "",
- self.doc)
+ self.down_revision,
+ self.revision,
+ " (head)" if self.is_head else "",
+ " (branchpoint)" if self.is_branch_point else "",
+ self.doc)
@classmethod
def _from_path(cls, scriptdir, path):
@@ -502,11 +506,11 @@ class Script(object):
m = _legacy_rev.match(filename)
if not m:
raise util.CommandError(
- "Could not determine revision id from filename %s. "
- "Be sure the 'revision' variable is "
- "declared inside the script (please see 'Upgrading "
- "from Alembic 0.1 to 0.2' in the documentation)."
- % filename)
+ "Could not determine revision id from filename %s. "
+ "Be sure the 'revision' variable is "
+ "declared inside the script (please see 'Upgrading "
+ "from Alembic 0.1 to 0.2' in the documentation)."
+ % filename)
else:
revision = m.group(1)
else:
diff --git a/alembic/templates/generic/env.py b/alembic/templates/generic/env.py
index 712b616..fccd445 100644
--- a/alembic/templates/generic/env.py
+++ b/alembic/templates/generic/env.py
@@ -22,6 +22,7 @@ target_metadata = None
# my_important_option = config.get_main_option("my_important_option")
# ... etc.
+
def run_migrations_offline():
"""Run migrations in 'offline' mode.
@@ -40,6 +41,7 @@ def run_migrations_offline():
with context.begin_transaction():
context.run_migrations()
+
def run_migrations_online():
"""Run migrations in 'online' mode.
@@ -48,15 +50,15 @@ def run_migrations_online():
"""
engine = engine_from_config(
- config.get_section(config.config_ini_section),
- prefix='sqlalchemy.',
- poolclass=pool.NullPool)
+ config.get_section(config.config_ini_section),
+ prefix='sqlalchemy.',
+ poolclass=pool.NullPool)
connection = engine.connect()
context.configure(
- connection=connection,
- target_metadata=target_metadata
- )
+ connection=connection,
+ target_metadata=target_metadata
+ )
try:
with context.begin_transaction():
@@ -68,4 +70,3 @@ if context.is_offline_mode():
run_migrations_offline()
else:
run_migrations_online()
-
diff --git a/alembic/templates/multidb/env.py b/alembic/templates/multidb/env.py
index e3511de..ab37199 100644
--- a/alembic/templates/multidb/env.py
+++ b/alembic/templates/multidb/env.py
@@ -39,6 +39,7 @@ target_metadata = {}
# my_important_option = config.get_main_option("my_important_option")
# ... etc.
+
def run_migrations_offline():
"""Run migrations in 'offline' mode.
@@ -58,7 +59,7 @@ def run_migrations_offline():
for name in re.split(r',\s*', db_names):
engines[name] = rec = {}
rec['url'] = context.config.get_section_option(name,
- "sqlalchemy.url")
+ "sqlalchemy.url")
for name, rec in engines.items():
logger.info("Migrating database %s" % name)
@@ -66,10 +67,11 @@ def run_migrations_offline():
logger.info("Writing output to %s" % file_)
with open(file_, 'w') as buffer:
context.configure(url=rec['url'], output_buffer=buffer,
- target_metadata=target_metadata.get(name))
+ target_metadata=target_metadata.get(name))
with context.begin_transaction():
context.run_migrations(engine_name=name)
+
def run_migrations_online():
"""Run migrations in 'online' mode.
@@ -85,9 +87,9 @@ def run_migrations_online():
for name in re.split(r',\s*', db_names):
engines[name] = rec = {}
rec['engine'] = engine_from_config(
- context.config.get_section(name),
- prefix='sqlalchemy.',
- poolclass=pool.NullPool)
+ context.config.get_section(name),
+ prefix='sqlalchemy.',
+ poolclass=pool.NullPool)
for name, rec in engines.items():
engine = rec['engine']
@@ -102,11 +104,11 @@ def run_migrations_online():
for name, rec in engines.items():
logger.info("Migrating database %s" % name)
context.configure(
- connection=rec['connection'],
- upgrade_token="%s_upgrades" % name,
- downgrade_token="%s_downgrades" % name,
- target_metadata=target_metadata.get(name)
- )
+ connection=rec['connection'],
+ upgrade_token="%s_upgrades" % name,
+ downgrade_token="%s_downgrades" % name,
+ target_metadata=target_metadata.get(name)
+ )
context.run_migrations(engine_name=name)
if USE_TWOPHASE:
diff --git a/alembic/templates/pylons/env.py b/alembic/templates/pylons/env.py
index 36c3fca..3329428 100644
--- a/alembic/templates/pylons/env.py
+++ b/alembic/templates/pylons/env.py
@@ -46,7 +46,7 @@ def run_migrations_offline():
"""
context.configure(
- url=meta.engine.url, target_metadata=target_metadata)
+ url=meta.engine.url, target_metadata=target_metadata)
with context.begin_transaction():
context.run_migrations()
@@ -70,9 +70,9 @@ def run_migrations_online():
)
context.configure(
- connection=connection,
- target_metadata=target_metadata
- )
+ connection=connection,
+ target_metadata=target_metadata
+ )
try:
with context.begin_transaction():
diff --git a/alembic/util.py b/alembic/util.py
index 8c02d57..e0d62eb 100644
--- a/alembic/util.py
+++ b/alembic/util.py
@@ -12,9 +12,11 @@ from sqlalchemy import __version__
from .compat import callable, exec_, load_module_py, load_module_pyc, binary_type
+
class CommandError(Exception):
pass
+
def _safe_int(value):
try:
return int(value)
@@ -28,7 +30,7 @@ sqla_092 = _vers >= (0, 9, 2)
sqla_094 = _vers >= (0, 9, 4)
if not sqla_07:
raise CommandError(
- "SQLAlchemy 0.7.3 or greater is required. ")
+ "SQLAlchemy 0.7.3 or greater is required. ")
from sqlalchemy.util import format_argspec_plus, update_wrapper
from sqlalchemy.util.compat import inspect_getfullargspec
@@ -41,7 +43,7 @@ try:
import termios
import struct
ioctl = fcntl.ioctl(0, termios.TIOCGWINSZ,
- struct.pack('HHHH', 0, 0, 0, 0))
+ struct.pack('HHHH', 0, 0, 0, 0))
_h, TERMWIDTH, _hp, _wp = struct.unpack('HHHH', ioctl)
if TERMWIDTH <= 0: # can occur if running in emacs pseudo-tty
TERMWIDTH = None
@@ -55,6 +57,7 @@ def template_to_file(template_file, dest, **kw):
Template(filename=template_file).render(**kw)
)
+
def create_module_class_proxy(cls, globals_, locals_):
"""Create module level proxy functions for the
methods on a given class.
@@ -97,18 +100,18 @@ def create_module_class_proxy(cls, globals_, locals_):
defaulted_vals = ()
apply_kw = inspect.formatargspec(
- name_args, spec[1], spec[2],
- defaulted_vals,
- formatvalue=lambda x: '=' + x)
+ name_args, spec[1], spec[2],
+ defaulted_vals,
+ formatvalue=lambda x: '=' + x)
def _name_error(name):
raise NameError(
- "Can't invoke function '%s', as the proxy object has "\
- "not yet been "
- "established for the Alembic '%s' class. "
- "Try placing this code inside a callable." % (
- name, cls.__name__
- ))
+ "Can't invoke function '%s', as the proxy object has "
+ "not yet been "
+ "established for the Alembic '%s' class. "
+ "Try placing this code inside a callable." % (
+ name, cls.__name__
+ ))
globals_['_name_error'] = _name_error
func_text = textwrap.dedent("""\
@@ -137,6 +140,7 @@ def create_module_class_proxy(cls, globals_, locals_):
else:
attr_names.add(methname)
+
def write_outstream(stream, *text):
encoding = getattr(stream, 'encoding', 'ascii') or 'ascii'
for t in text:
@@ -151,6 +155,7 @@ def write_outstream(stream, *text):
# as the exception is "ignored" (noisily) in TextIOWrapper.
break
+
def coerce_resource_to_filename(fname):
"""Interpret a filename as either a filesystem location or as a package resource.
@@ -163,6 +168,7 @@ def coerce_resource_to_filename(fname):
fname = pkg_resources.resource_filename(*fname.split(':'))
return fname
+
def status(_statmsg, fn, *arg, **kw):
msg(_statmsg + " ...", False)
try:
@@ -173,24 +179,29 @@ def status(_statmsg, fn, *arg, **kw):
write_outstream(sys.stdout, " FAILED\n")
raise
+
def err(message):
log.error(message)
msg("FAILED: %s" % message)
sys.exit(-1)
+
def obfuscate_url_pw(u):
u = url.make_url(u)
if u.password:
u.password = 'XXXXX'
return str(u)
+
def asbool(value):
return value is not None and \
value.lower() == 'true'
+
def warn(msg):
warnings.warn(msg)
+
def msg(msg, newline=True):
if TERMWIDTH is None:
write_outstream(sys.stdout, msg)
@@ -204,6 +215,7 @@ def msg(msg, newline=True):
write_outstream(sys.stdout, " ", line, "\n")
write_outstream(sys.stdout, " ", lines[-1], ("\n" if newline else ""))
+
def load_python_file(dir_, filename):
"""Load a file from the given path as a Python module."""
@@ -223,6 +235,7 @@ def load_python_file(dir_, filename):
del sys.modules[module_id]
return module
+
def simple_pyc_file_from_path(path):
"""Given a python source path, return the so-called
"sourceless" .pyc or .pyo path.
@@ -238,6 +251,7 @@ def simple_pyc_file_from_path(path):
else:
return path + "c" # e.g. .pyc
+
def pyc_file_from_path(path):
"""Given a python source path, locate the .pyc.
@@ -253,11 +267,14 @@ def pyc_file_from_path(path):
else:
return simple_pyc_file_from_path(path)
+
def rev_id():
val = int(uuid.uuid4()) % 100000000000000
return hex(val)[2:-1]
+
class memoized_property(object):
+
"""A read-only @property that is only evaluated once."""
def __init__(self, fget, doc=None):
@@ -278,7 +295,7 @@ class immutabledict(dict):
raise TypeError("%s object is immutable" % self.__class__.__name__)
__delitem__ = __setitem__ = __setattr__ = \
- clear = pop = popitem = setdefault = \
+ clear = pop = popitem = setdefault = \
update = _immutable
def __new__(cls, *args):
@@ -332,7 +349,7 @@ def _with_legacy_names(translations):
return fn(*arg, **kw)
code = 'lambda %(args)s: %(target)s(%(apply_kw)s)' % (
- metadata)
+ metadata)
decorated = eval(code, {"target": go})
decorated.__defaults__ = getattr(fn, '__func__', fn).__defaults__
update_wrapper(decorated, fn)
@@ -346,6 +363,3 @@ def _with_legacy_names(translations):
return decorated
return decorate
-
-
-
diff --git a/setup.py b/setup.py
index f736cb0..cf9542e 100644
--- a/setup.py
+++ b/setup.py
@@ -34,14 +34,14 @@ setup(name='alembic',
description="A database migration tool for SQLAlchemy.",
long_description=open(readme).read(),
classifiers=[
- 'Development Status :: 4 - Beta',
- 'Environment :: Console',
- 'Intended Audience :: Developers',
- 'Programming Language :: Python',
- 'Programming Language :: Python :: 3',
- 'Programming Language :: Python :: Implementation :: CPython',
- 'Programming Language :: Python :: Implementation :: PyPy',
- 'Topic :: Database :: Front-Ends',
+ 'Development Status :: 4 - Beta',
+ 'Environment :: Console',
+ 'Intended Audience :: Developers',
+ 'Programming Language :: Python',
+ 'Programming Language :: Python :: 3',
+ 'Programming Language :: Python :: Implementation :: CPython',
+ 'Programming Language :: Python :: Implementation :: PyPy',
+ 'Topic :: Database :: Front-Ends',
],
keywords='SQLAlchemy migrations',
author='Mike Bayer',
@@ -50,11 +50,11 @@ setup(name='alembic',
license='MIT',
packages=find_packages('.', exclude=['examples*', 'test*']),
include_package_data=True,
- tests_require = ['nose >= 0.11', 'mock'],
- test_suite = "nose.collector",
+ tests_require=['nose >= 0.11', 'mock'],
+ test_suite="nose.collector",
zip_safe=False,
install_requires=requires,
- entry_points = {
- 'console_scripts': [ 'alembic = alembic.config:main' ],
+ entry_points={
+ 'console_scripts': ['alembic = alembic.config:main'],
}
-)
+ )
diff --git a/tests/__init__.py b/tests/__init__.py
index ba8c0eb..9b5944f 100644
--- a/tests/__init__.py
+++ b/tests/__init__.py
@@ -37,8 +37,8 @@ else:
import mock
except ImportError:
raise ImportError(
- "Alembic's test suite requires the "
- "'mock' library as of 0.6.1.")
+ "Alembic's test suite requires the "
+ "'mock' library as of 0.6.1.")
def sqlite_db():
@@ -48,14 +48,18 @@ def sqlite_db():
dir_ = os.path.join(staging_directory, 'scripts')
return create_engine('sqlite:///%s/foo.db' % dir_)
+
def capture_db():
buf = []
+
def dump(sql, *multiparams, **params):
buf.append(str(sql.compile(dialect=engine.dialect)))
engine = create_engine("postgresql://", strategy="mock", executor=dump)
return engine, buf
_engs = {}
+
+
def db_for_dialect(name):
if name in _engs:
return _engs[name]
@@ -82,18 +86,21 @@ def requires_08(fn, *arg, **kw):
raise SkipTest("SQLAlchemy 0.8.0b2 or greater required")
return fn(*arg, **kw)
+
@decorator
def requires_09(fn, *arg, **kw):
if not util.sqla_09:
raise SkipTest("SQLAlchemy 0.9 or greater required")
return fn(*arg, **kw)
+
@decorator
def requires_092(fn, *arg, **kw):
if not util.sqla_092:
raise SkipTest("SQLAlchemy 0.9.2 or greater required")
return fn(*arg, **kw)
+
@decorator
def requires_094(fn, *arg, **kw):
if not util.sqla_094:
@@ -101,6 +108,8 @@ def requires_094(fn, *arg, **kw):
return fn(*arg, **kw)
_dialects = {}
+
+
def _get_dialect(name):
if name is None or name == 'default':
return default.DefaultDialect()
@@ -114,14 +123,16 @@ def _get_dialect(name):
d.implicit_returning = True
return d
+
def assert_compiled(element, assert_string, dialect=None):
dialect = _get_dialect(dialect)
eq_(
- text_type(element.compile(dialect=dialect)).\
- replace("\n", "").replace("\t", ""),
+ text_type(element.compile(dialect=dialect)).
+ replace("\n", "").replace("\t", ""),
assert_string.replace("\n", "").replace("\t", "")
)
+
@contextmanager
def capture_context_buffer(**kw):
if kw.pop('bytes_io', False):
@@ -130,10 +141,11 @@ def capture_context_buffer(**kw):
buf = io.StringIO()
kw.update({
- 'dialect_name': "sqlite",
- 'output_buffer': buf
+ 'dialect_name': "sqlite",
+ 'output_buffer': buf
})
conf = EnvironmentContext.configure
+
def configure(*arg, **opt):
opt.update(**kw)
return conf(*arg, **opt)
@@ -141,6 +153,7 @@ def capture_context_buffer(**kw):
with mock.patch.object(EnvironmentContext, "configure", configure):
yield buf
+
def eq_ignore_whitespace(a, b, msg=None):
a = re.sub(r'^\s+?|\n', "", a)
a = re.sub(r' {2,}', " ", a)
@@ -148,18 +161,22 @@ def eq_ignore_whitespace(a, b, msg=None):
b = re.sub(r' {2,}', " ", b)
assert a == b, msg or "%r != %r" % (a, b)
+
def eq_(a, b, msg=None):
"""Assert a == b, with repr messaging on failure."""
assert a == b, msg or "%r != %r" % (a, b)
+
def ne_(a, b, msg=None):
"""Assert a != b, with repr messaging on failure."""
assert a != b, msg or "%r == %r" % (a, b)
+
def is_(a, b, msg=None):
"""Assert a is b, with repr messaging on failure."""
assert a is b, msg or "%r is not %r" % (a, b)
+
def assert_raises_message(except_cls, msg, callable_, *args, **kwargs):
try:
callable_(*args, **kwargs)
@@ -168,9 +185,12 @@ def assert_raises_message(except_cls, msg, callable_, *args, **kwargs):
assert re.search(msg, str(e)), "%r !~ %s" % (msg, e)
print(text_type(e))
+
def op_fixture(dialect='default', as_sql=False, naming_convention=None):
impl = _impls[dialect]
+
class Impl(impl):
+
def __init__(self, dialect, as_sql):
self.assertion = []
self.dialect = dialect
@@ -179,6 +199,7 @@ def op_fixture(dialect='default', as_sql=False, naming_convention=None):
# be more like a real connection
# as tests get more involved
self.connection = None
+
def _exec(self, construct, *args, **kw):
if isinstance(construct, string_types):
construct = text(construct)
@@ -193,11 +214,12 @@ def op_fixture(dialect='default', as_sql=False, naming_convention=None):
if naming_convention:
if not util.sqla_092:
raise SkipTest(
- "naming_convention feature requires "
- "sqla 0.9.2 or greater")
+ "naming_convention feature requires "
+ "sqla 0.9.2 or greater")
opts['target_metadata'] = MetaData(naming_convention=naming_convention)
class ctx(MigrationContext):
+
def __init__(self, dialect='default', as_sql=False):
self.dialect = _get_dialect(dialect)
self.impl = Impl(self.dialect, as_sql)
@@ -222,12 +244,14 @@ def op_fixture(dialect='default', as_sql=False, naming_convention=None):
alembic.op._proxy = Operations(context)
return context
+
def script_file_fixture(txt):
dir_ = os.path.join(staging_directory, 'scripts')
path = os.path.join(dir_, "script.py.mako")
with open(path, 'w') as f:
f.write(txt)
+
def env_file_fixture(txt):
dir_ = os.path.join(staging_directory, 'scripts')
txt = """
@@ -244,6 +268,7 @@ config = context.config
with open(path, 'w') as f:
f.write(txt)
+
def _sqlite_testing_config(sourceless=False):
dir_ = os.path.join(staging_directory, 'scripts')
return _write_config_file("""
@@ -313,12 +338,14 @@ datefmt = %%H:%%M:%%S
""" % (dir_, dialect, directives))
+
def _write_config_file(text):
cfg = _testing_config()
with open(cfg.config_file_name, 'w') as f:
f.write(text)
return cfg
+
def _testing_config():
from alembic.config import Config
if not os.access(staging_directory, os.F_OK):
@@ -350,6 +377,7 @@ def staging_env(create=True, template="generic", sourceless=False):
sc = script.ScriptDirectory.from_config(cfg)
return sc
+
def clear_staging_env():
shutil.rmtree(staging_directory, True)
@@ -370,13 +398,14 @@ def write_script(scriptdir, rev_id, content, encoding='ascii', sourceless=False)
old = scriptdir._revision_map[script.revision]
if old.down_revision != script.down_revision:
raise Exception("Can't change down_revision "
- "on a refresh operation.")
+ "on a refresh operation.")
scriptdir._revision_map[script.revision] = script
script.nextrev = old.nextrev
if sourceless:
make_sourceless(path)
+
def make_sourceless(path):
# note that if -O is set, you'd see pyo files here,
# the pyc util function looks at sys.flags.optimize to handle this
@@ -391,6 +420,7 @@ def make_sourceless(path):
shutil.copyfile(pyc_path, simple_pyc_path)
os.unlink(path)
+
def three_rev_fixture(cfg):
a = util.rev_id()
b = util.rev_id()
diff --git a/tests/test_autogen_indexes.py b/tests/test_autogen_indexes.py
index 2f7a4a1..0885477 100644
--- a/tests/test_autogen_indexes.py
+++ b/tests/test_autogen_indexes.py
@@ -14,6 +14,7 @@ py3k = sys.version_info >= (3, )
from .test_autogenerate import AutogenFixtureTest
+
class AutogenerateUniqueIndexTest(AutogenFixtureTest, TestCase):
reports_unique_constraints = True
@@ -22,17 +23,17 @@ class AutogenerateUniqueIndexTest(AutogenFixtureTest, TestCase):
m2 = MetaData()
Table('user', m1,
- Column('id', Integer, primary_key=True),
- Column('name', String(50), nullable=False, index=True),
- Column('a1', String(10), server_default="x")
- )
+ Column('id', Integer, primary_key=True),
+ Column('name', String(50), nullable=False, index=True),
+ Column('a1', String(10), server_default="x")
+ )
Table('user', m2,
- Column('id', Integer, primary_key=True),
- Column('name', String(50), nullable=False),
- Column('a1', String(10), server_default="x"),
- UniqueConstraint("name", name="uq_user_name")
- )
+ Column('id', Integer, primary_key=True),
+ Column('name', String(50), nullable=False),
+ Column('a1', String(10), server_default="x"),
+ UniqueConstraint("name", name="uq_user_name")
+ )
diffs = self._fixture(m1, m2)
@@ -46,21 +47,20 @@ class AutogenerateUniqueIndexTest(AutogenFixtureTest, TestCase):
eq_(diffs[0][0], "remove_index")
eq_(diffs[0][1].name, "ix_user_name")
-
def test_add_unique_constraint(self):
m1 = MetaData()
m2 = MetaData()
Table('address', m1,
- Column('id', Integer, primary_key=True),
- Column('email_address', String(100), nullable=False),
- Column('qpr', String(10), index=True),
- )
+ Column('id', Integer, primary_key=True),
+ Column('email_address', String(100), nullable=False),
+ Column('qpr', String(10), index=True),
+ )
Table('address', m2,
- Column('id', Integer, primary_key=True),
- Column('email_address', String(100), nullable=False),
- Column('qpr', String(10), index=True),
- UniqueConstraint("email_address", name="uq_email_address")
- )
+ Column('id', Integer, primary_key=True),
+ Column('email_address', String(100), nullable=False),
+ Column('qpr', String(10), index=True),
+ UniqueConstraint("email_address", name="uq_email_address")
+ )
diffs = self._fixture(m1, m2)
@@ -70,29 +70,28 @@ class AutogenerateUniqueIndexTest(AutogenFixtureTest, TestCase):
else:
eq_(diffs, [])
-
def test_index_becomes_unique(self):
m1 = MetaData()
m2 = MetaData()
Table('order', m1,
- Column('order_id', Integer, primary_key=True),
- Column('amount', Numeric(10, 2), nullable=True),
- Column('user_id', Integer),
- UniqueConstraint('order_id', 'user_id',
- name='order_order_id_user_id_unique'
- ),
- Index('order_user_id_amount_idx', 'user_id', 'amount')
- )
+ Column('order_id', Integer, primary_key=True),
+ Column('amount', Numeric(10, 2), nullable=True),
+ Column('user_id', Integer),
+ UniqueConstraint('order_id', 'user_id',
+ name='order_order_id_user_id_unique'
+ ),
+ Index('order_user_id_amount_idx', 'user_id', 'amount')
+ )
Table('order', m2,
- Column('order_id', Integer, primary_key=True),
- Column('amount', Numeric(10, 2), nullable=True),
- Column('user_id', Integer),
- UniqueConstraint('order_id', 'user_id',
- name='order_order_id_user_id_unique'
- ),
- Index('order_user_id_amount_idx', 'user_id', 'amount', unique=True),
- )
+ Column('order_id', Integer, primary_key=True),
+ Column('amount', Numeric(10, 2), nullable=True),
+ Column('user_id', Integer),
+ UniqueConstraint('order_id', 'user_id',
+ name='order_order_id_user_id_unique'
+ ),
+ Index('order_user_id_amount_idx', 'user_id', 'amount', unique=True),
+ )
diffs = self._fixture(m1, m2)
eq_(diffs[0][0], "remove_index")
@@ -103,21 +102,19 @@ class AutogenerateUniqueIndexTest(AutogenFixtureTest, TestCase):
eq_(diffs[1][1].name, "order_user_id_amount_idx")
eq_(diffs[1][1].unique, True)
-
-
def test_mismatch_db_named_col_flag(self):
m1 = MetaData()
m2 = MetaData()
Table('item', m1,
- Column('x', Integer),
- UniqueConstraint('x', name="db_generated_name")
- )
+ Column('x', Integer),
+ UniqueConstraint('x', name="db_generated_name")
+ )
# test mismatch between unique=True and
# named uq constraint
Table('item', m2,
- Column('x', Integer, unique=True)
- )
+ Column('x', Integer, unique=True)
+ )
diffs = self._fixture(m1, m2)
@@ -127,10 +124,10 @@ class AutogenerateUniqueIndexTest(AutogenFixtureTest, TestCase):
m1 = MetaData()
m2 = MetaData()
Table('extra', m2,
- Column('foo', Integer, index=True),
- Column('bar', Integer),
- Index('newtable_idx', 'bar')
- )
+ Column('foo', Integer, index=True),
+ Column('bar', Integer),
+ Index('newtable_idx', 'bar')
+ )
diffs = self._fixture(m1, m2)
@@ -142,20 +139,19 @@ class AutogenerateUniqueIndexTest(AutogenFixtureTest, TestCase):
eq_(diffs[2][0], "add_index")
eq_(diffs[2][1].name, "newtable_idx")
-
def test_named_cols_changed(self):
m1 = MetaData()
m2 = MetaData()
Table('col_change', m1,
- Column('x', Integer),
- Column('y', Integer),
- UniqueConstraint('x', name="nochange")
- )
+ Column('x', Integer),
+ Column('y', Integer),
+ UniqueConstraint('x', name="nochange")
+ )
Table('col_change', m2,
- Column('x', Integer),
- Column('y', Integer),
- UniqueConstraint('x', 'y', name="nochange")
- )
+ Column('x', Integer),
+ Column('y', Integer),
+ UniqueConstraint('x', 'y', name="nochange")
+ )
diffs = self._fixture(m1, m2)
@@ -173,72 +169,68 @@ class AutogenerateUniqueIndexTest(AutogenFixtureTest, TestCase):
m2 = MetaData()
Table('nothing_changed', m1,
- Column('x', String(20), unique=True, index=True)
- )
+ Column('x', String(20), unique=True, index=True)
+ )
Table('nothing_changed', m2,
- Column('x', String(20), unique=True, index=True)
- )
+ Column('x', String(20), unique=True, index=True)
+ )
diffs = self._fixture(m1, m2)
eq_(diffs, [])
-
def test_nothing_changed_two(self):
m1 = MetaData()
m2 = MetaData()
Table('nothing_changed', m1,
- Column('id1', Integer, primary_key=True),
- Column('id2', Integer, primary_key=True),
- Column('x', String(20), unique=True),
- mysql_engine='InnoDB'
- )
+ Column('id1', Integer, primary_key=True),
+ Column('id2', Integer, primary_key=True),
+ Column('x', String(20), unique=True),
+ mysql_engine='InnoDB'
+ )
Table('nothing_changed_related', m1,
- Column('id1', Integer),
- Column('id2', Integer),
- ForeignKeyConstraint(['id1', 'id2'],
- ['nothing_changed.id1', 'nothing_changed.id2']),
- mysql_engine='InnoDB'
- )
+ Column('id1', Integer),
+ Column('id2', Integer),
+ ForeignKeyConstraint(['id1', 'id2'],
+ ['nothing_changed.id1', 'nothing_changed.id2']),
+ mysql_engine='InnoDB'
+ )
Table('nothing_changed', m2,
- Column('id1', Integer, primary_key=True),
- Column('id2', Integer, primary_key=True),
- Column('x', String(20), unique=True),
- mysql_engine='InnoDB'
- )
+ Column('id1', Integer, primary_key=True),
+ Column('id2', Integer, primary_key=True),
+ Column('x', String(20), unique=True),
+ mysql_engine='InnoDB'
+ )
Table('nothing_changed_related', m2,
- Column('id1', Integer),
- Column('id2', Integer),
- ForeignKeyConstraint(['id1', 'id2'],
- ['nothing_changed.id1', 'nothing_changed.id2']),
- mysql_engine='InnoDB'
- )
-
+ Column('id1', Integer),
+ Column('id2', Integer),
+ ForeignKeyConstraint(['id1', 'id2'],
+ ['nothing_changed.id1', 'nothing_changed.id2']),
+ mysql_engine='InnoDB'
+ )
diffs = self._fixture(m1, m2)
eq_(diffs, [])
-
-
def test_nothing_changed_index_named_as_column(self):
m1 = MetaData()
m2 = MetaData()
Table('nothing_changed', m1,
- Column('id1', Integer, primary_key=True),
- Column('id2', Integer, primary_key=True),
- Column('x', String(20)),
- Index('x', 'x')
- )
+ Column('id1', Integer, primary_key=True),
+ Column('id2', Integer, primary_key=True),
+ Column('x', String(20)),
+ Index('x', 'x')
+ )
Table('nothing_changed', m2,
- Column('id1', Integer, primary_key=True),
- Column('id2', Integer, primary_key=True),
- Column('x', String(20)),
- Index('x', 'x')
- )
+ Column('id1', Integer, primary_key=True),
+ Column('id2', Integer, primary_key=True),
+ Column('x', String(20)),
+ Index('x', 'x')
+ )
diffs = self._fixture(m1, m2)
eq_(diffs, [])
@@ -248,28 +240,28 @@ class AutogenerateUniqueIndexTest(AutogenFixtureTest, TestCase):
m2 = MetaData()
Table("nothing_changed", m1,
- Column('id', Integer, primary_key=True),
- Column('other_id',
- ForeignKey('nc2.id',
+ Column('id', Integer, primary_key=True),
+ Column('other_id',
+ ForeignKey('nc2.id',
name='fk_my_table_other_table'
),
- nullable=False),
- Column('foo', Integer),
- mysql_engine='InnoDB')
+ nullable=False),
+ Column('foo', Integer),
+ mysql_engine='InnoDB')
Table('nc2', m1,
- Column('id', Integer, primary_key=True),
- mysql_engine='InnoDB')
+ Column('id', Integer, primary_key=True),
+ mysql_engine='InnoDB')
Table("nothing_changed", m2,
- Column('id', Integer, primary_key=True),
- Column('other_id', ForeignKey('nc2.id',
- name='fk_my_table_other_table'),
- nullable=False),
- Column('foo', Integer),
- mysql_engine='InnoDB')
+ Column('id', Integer, primary_key=True),
+ Column('other_id', ForeignKey('nc2.id',
+ name='fk_my_table_other_table'),
+ nullable=False),
+ Column('foo', Integer),
+ mysql_engine='InnoDB')
Table('nc2', m2,
- Column('id', Integer, primary_key=True),
- mysql_engine='InnoDB')
+ Column('id', Integer, primary_key=True),
+ mysql_engine='InnoDB')
diffs = self._fixture(m1, m2)
eq_(diffs, [])
@@ -278,18 +270,18 @@ class AutogenerateUniqueIndexTest(AutogenFixtureTest, TestCase):
m2 = MetaData()
Table('new_idx', m1,
- Column('id1', Integer, primary_key=True),
- Column('id2', Integer, primary_key=True),
- Column('x', String(20)),
- )
+ Column('id1', Integer, primary_key=True),
+ Column('id2', Integer, primary_key=True),
+ Column('x', String(20)),
+ )
idx = Index('x', 'x')
Table('new_idx', m2,
- Column('id1', Integer, primary_key=True),
- Column('id2', Integer, primary_key=True),
- Column('x', String(20)),
- idx
- )
+ Column('id1', Integer, primary_key=True),
+ Column('id2', Integer, primary_key=True),
+ Column('x', String(20)),
+ idx
+ )
diffs = self._fixture(m1, m2)
eq_(diffs, [('add_index', idx)])
@@ -300,17 +292,17 @@ class AutogenerateUniqueIndexTest(AutogenFixtureTest, TestCase):
idx = Index('x', 'x')
Table('new_idx', m1,
- Column('id1', Integer, primary_key=True),
- Column('id2', Integer, primary_key=True),
- Column('x', String(20)),
- idx
- )
+ Column('id1', Integer, primary_key=True),
+ Column('id2', Integer, primary_key=True),
+ Column('x', String(20)),
+ idx
+ )
Table('new_idx', m2,
- Column('id1', Integer, primary_key=True),
- Column('id2', Integer, primary_key=True),
- Column('x', String(20))
- )
+ Column('id1', Integer, primary_key=True),
+ Column('id2', Integer, primary_key=True),
+ Column('x', String(20))
+ )
diffs = self._fixture(m1, m2)
eq_(diffs[0][0], 'remove_index')
@@ -319,38 +311,36 @@ class AutogenerateUniqueIndexTest(AutogenFixtureTest, TestCase):
m1 = MetaData()
m2 = MetaData()
Table('col_change', m1,
- Column('x', Integer),
- Column('y', Integer),
- UniqueConstraint('x')
- )
+ Column('x', Integer),
+ Column('y', Integer),
+ UniqueConstraint('x')
+ )
Table('col_change', m2,
- Column('x', Integer),
- Column('y', Integer),
- UniqueConstraint('x', 'y')
- )
+ Column('x', Integer),
+ Column('y', Integer),
+ UniqueConstraint('x', 'y')
+ )
diffs = self._fixture(m1, m2)
diffs = set((cmd,
- ('x' in obj.name) if obj.name is not None else False)
+ ('x' in obj.name) if obj.name is not None else False)
for cmd, obj in diffs)
if self.reports_unnamed_constraints:
assert ("remove_constraint", True) in diffs
assert ("add_constraint", False) in diffs
-
-
def test_remove_named_unique_index(self):
m1 = MetaData()
m2 = MetaData()
Table('remove_idx', m1,
- Column('x', Integer),
- Index('xidx', 'x', unique=True)
- )
+ Column('x', Integer),
+ Index('xidx', 'x', unique=True)
+ )
Table('remove_idx', m2,
- Column('x', Integer),
- )
+ Column('x', Integer),
+ )
diffs = self._fixture(m1, m2)
@@ -360,18 +350,17 @@ class AutogenerateUniqueIndexTest(AutogenFixtureTest, TestCase):
else:
eq_(diffs, [])
-
def test_remove_named_unique_constraint(self):
m1 = MetaData()
m2 = MetaData()
Table('remove_idx', m1,
- Column('x', Integer),
- UniqueConstraint('x', name='xidx')
- )
+ Column('x', Integer),
+ UniqueConstraint('x', name='xidx')
+ )
Table('remove_idx', m2,
- Column('x', Integer),
- )
+ Column('x', Integer),
+ )
diffs = self._fixture(m1, m2)
@@ -437,7 +426,6 @@ class AutogenerateUniqueIndexTest(AutogenFixtureTest, TestCase):
eq_(diffs, [])
-
class PGUniqueIndexTest(AutogenerateUniqueIndexTest):
reports_unnamed_constraints = True
@@ -450,7 +438,7 @@ class PGUniqueIndexTest(AutogenerateUniqueIndexTest):
m2 = MetaData()
Table('add_ix', m1, Column('x', String(50)), schema="test_schema")
Table('add_ix', m2, Column('x', String(50)),
- Index('ix_1', 'x'), schema="test_schema")
+ Index('ix_1', 'x'), schema="test_schema")
diffs = self._fixture(m1, m2, include_schemas=True)
eq_(diffs[0][0], "add_index")
@@ -460,9 +448,9 @@ class PGUniqueIndexTest(AutogenerateUniqueIndexTest):
m1 = MetaData()
m2 = MetaData()
Table('add_ix', m1, Column('x', String(50)), Index('ix_1', 'x'),
- schema="test_schema")
+ schema="test_schema")
Table('add_ix', m2, Column('x', String(50)),
- Index('ix_1', 'x'), schema="test_schema")
+ Index('ix_1', 'x'), schema="test_schema")
diffs = self._fixture(m1, m2, include_schemas=True)
eq_(diffs, [])
@@ -472,7 +460,7 @@ class PGUniqueIndexTest(AutogenerateUniqueIndexTest):
m2 = MetaData()
Table('add_uq', m1, Column('x', String(50)), schema="test_schema")
Table('add_uq', m2, Column('x', String(50)),
- UniqueConstraint('x', name='ix_1'), schema="test_schema")
+ UniqueConstraint('x', name='ix_1'), schema="test_schema")
diffs = self._fixture(m1, m2, include_schemas=True)
eq_(diffs[0][0], "add_constraint")
@@ -482,11 +470,11 @@ class PGUniqueIndexTest(AutogenerateUniqueIndexTest):
m1 = MetaData()
m2 = MetaData()
Table('add_uq', m1, Column('x', String(50)),
- UniqueConstraint('x', name='ix_1'),
- schema="test_schema")
+ UniqueConstraint('x', name='ix_1'),
+ schema="test_schema")
Table('add_uq', m2, Column('x', String(50)),
- UniqueConstraint('x', name='ix_1'),
- schema="test_schema")
+ UniqueConstraint('x', name='ix_1'),
+ schema="test_schema")
diffs = self._fixture(m1, m2, include_schemas=True)
eq_(diffs, [])
@@ -511,7 +499,7 @@ class MySQLUniqueIndexTest(AutogenerateUniqueIndexTest):
def test_removed_idx_index_named_as_column(self):
try:
super(MySQLUniqueIndexTest,
- self).test_removed_idx_index_named_as_column()
+ self).test_removed_idx_index_named_as_column()
except IndexError:
assert True
else:
@@ -521,6 +509,7 @@ class MySQLUniqueIndexTest(AutogenerateUniqueIndexTest):
def _get_bind(cls):
return db_for_dialect('mysql')
+
class NoUqReflectionIndexTest(AutogenerateUniqueIndexTest):
reports_unique_constraints = False
@@ -536,13 +525,13 @@ class NoUqReflectionIndexTest(AutogenerateUniqueIndexTest):
def test_unique_not_reported(self):
m1 = MetaData()
Table('order', m1,
- Column('order_id', Integer, primary_key=True),
- Column('amount', Numeric(10, 2), nullable=True),
- Column('user_id', Integer),
- UniqueConstraint('order_id', 'user_id',
- name='order_order_id_user_id_unique'
- )
- )
+ Column('order_id', Integer, primary_key=True),
+ Column('amount', Numeric(10, 2), nullable=True),
+ Column('user_id', Integer),
+ UniqueConstraint('order_id', 'user_id',
+ name='order_order_id_user_id_unique'
+ )
+ )
diffs = self._fixture(m1, m1)
eq_(diffs, [])
@@ -550,19 +539,19 @@ class NoUqReflectionIndexTest(AutogenerateUniqueIndexTest):
def test_remove_unique_index_not_reported(self):
m1 = MetaData()
Table('order', m1,
- Column('order_id', Integer, primary_key=True),
- Column('amount', Numeric(10, 2), nullable=True),
- Column('user_id', Integer),
- Index('oid_ix', 'order_id', 'user_id',
- unique=True
- )
- )
+ Column('order_id', Integer, primary_key=True),
+ Column('amount', Numeric(10, 2), nullable=True),
+ Column('user_id', Integer),
+ Index('oid_ix', 'order_id', 'user_id',
+ unique=True
+ )
+ )
m2 = MetaData()
Table('order', m2,
- Column('order_id', Integer, primary_key=True),
- Column('amount', Numeric(10, 2), nullable=True),
- Column('user_id', Integer),
- )
+ Column('order_id', Integer, primary_key=True),
+ Column('amount', Numeric(10, 2), nullable=True),
+ Column('user_id', Integer),
+ )
diffs = self._fixture(m1, m2)
eq_(diffs, [])
@@ -570,23 +559,24 @@ class NoUqReflectionIndexTest(AutogenerateUniqueIndexTest):
def test_remove_plain_index_is_reported(self):
m1 = MetaData()
Table('order', m1,
- Column('order_id', Integer, primary_key=True),
- Column('amount', Numeric(10, 2), nullable=True),
- Column('user_id', Integer),
- Index('oid_ix', 'order_id', 'user_id')
- )
+ Column('order_id', Integer, primary_key=True),
+ Column('amount', Numeric(10, 2), nullable=True),
+ Column('user_id', Integer),
+ Index('oid_ix', 'order_id', 'user_id')
+ )
m2 = MetaData()
Table('order', m2,
- Column('order_id', Integer, primary_key=True),
- Column('amount', Numeric(10, 2), nullable=True),
- Column('user_id', Integer),
- )
+ Column('order_id', Integer, primary_key=True),
+ Column('amount', Numeric(10, 2), nullable=True),
+ Column('user_id', Integer),
+ )
diffs = self._fixture(m1, m2)
eq_(diffs[0][0], 'remove_index')
class NoUqReportsIndAsUqTest(NoUqReflectionIndexTest):
+
"""this test suite simulates the condition where:
a. the dialect doesn't report unique constraints
@@ -612,8 +602,8 @@ class NoUqReportsIndAsUqTest(NoUqReflectionIndexTest):
def get_indexes(self, connection, tablename, **kw):
indexes = _get_indexes(self, connection, tablename, **kw)
for uq in _get_unique_constraints(
- self, connection, tablename, **kw
- ):
+ self, connection, tablename, **kw
+ ):
uq['unique'] = True
indexes.append(uq)
return indexes
@@ -621,4 +611,3 @@ class NoUqReportsIndAsUqTest(NoUqReflectionIndexTest):
eng.dialect.get_unique_constraints = unimpl
eng.dialect.get_indexes = get_indexes
return eng
-
diff --git a/tests/test_autogen_render.py b/tests/test_autogen_render.py
index d253410..901e9f2 100644
--- a/tests/test_autogen_render.py
+++ b/tests/test_autogen_render.py
@@ -18,7 +18,9 @@ from . import eq_, eq_ignore_whitespace, requires_092, requires_09, requires_094
py3k = sys.version_info >= (3, )
+
class AutogenRenderTest(TestCase):
+
"""test individual directives"""
@classmethod
@@ -38,17 +40,16 @@ class AutogenRenderTest(TestCase):
'dialect': postgresql.dialect()
}
-
def test_render_add_index(self):
"""
autogenerate.render._add_index
"""
m = MetaData()
t = Table('test', m,
- Column('id', Integer, primary_key=True),
- Column('active', Boolean()),
- Column('code', String(255)),
- )
+ Column('id', Integer, primary_key=True),
+ Column('active', Boolean()),
+ Column('code', String(255)),
+ )
idx = Index('test_active_code_idx', t.c.active, t.c.code)
eq_ignore_whitespace(
autogenerate.render._add_index(idx, self.autogen_context),
@@ -62,11 +63,11 @@ class AutogenRenderTest(TestCase):
"""
m = MetaData()
t = Table('test', m,
- Column('id', Integer, primary_key=True),
- Column('active', Boolean()),
- Column('code', String(255)),
- schema='CamelSchema'
- )
+ Column('id', Integer, primary_key=True),
+ Column('active', Boolean()),
+ Column('code', String(255)),
+ schema='CamelSchema'
+ )
idx = Index('test_active_code_idx', t.c.active, t.c.code)
eq_ignore_whitespace(
autogenerate.render._add_index(idx, self.autogen_context),
@@ -79,24 +80,24 @@ class AutogenRenderTest(TestCase):
m = MetaData()
t = Table('t', m,
- Column('x', String),
- Column('y', String)
- )
+ Column('x', String),
+ Column('y', String)
+ )
idx = Index('foo_idx', t.c.x, t.c.y,
- postgresql_where=(t.c.y == 'something'))
+ postgresql_where=(t.c.y == 'something'))
if compat.sqla_08:
eq_ignore_whitespace(
autogenerate.render._add_index(idx, autogen_context),
"""op.create_index('foo_idx', 't', ['x', 'y'], unique=False, """
- """postgresql_where=sa.text("t.y = 'something'"))"""
+ """postgresql_where=sa.text("t.y = 'something'"))"""
)
else:
eq_ignore_whitespace(
autogenerate.render._add_index(idx, autogen_context),
"""op.create_index('foo_idx', 't', ['x', 'y'], unique=False, """
- """postgresql_where=sa.text('t.y = %(y_1)s'))"""
+ """postgresql_where=sa.text('t.y = %(y_1)s'))"""
)
# def test_render_add_index_func(self):
@@ -122,10 +123,10 @@ class AutogenRenderTest(TestCase):
"""
m = MetaData()
t = Table('test', m,
- Column('id', Integer, primary_key=True),
- Column('active', Boolean()),
- Column('code', String(255)),
- )
+ Column('id', Integer, primary_key=True),
+ Column('active', Boolean()),
+ Column('code', String(255)),
+ )
idx = Index('test_active_code_idx', t.c.active, t.c.code)
eq_ignore_whitespace(
autogenerate.render._drop_index(idx, self.autogen_context),
@@ -138,16 +139,16 @@ class AutogenRenderTest(TestCase):
"""
m = MetaData()
t = Table('test', m,
- Column('id', Integer, primary_key=True),
- Column('active', Boolean()),
- Column('code', String(255)),
- schema='CamelSchema'
- )
+ Column('id', Integer, primary_key=True),
+ Column('active', Boolean()),
+ Column('code', String(255)),
+ schema='CamelSchema'
+ )
idx = Index('test_active_code_idx', t.c.active, t.c.code)
eq_ignore_whitespace(
autogenerate.render._drop_index(idx, self.autogen_context),
"op.drop_index('test_active_code_idx', " +
- "table_name='test', schema='CamelSchema')"
+ "table_name='test', schema='CamelSchema')"
)
def test_add_unique_constraint(self):
@@ -156,10 +157,10 @@ class AutogenRenderTest(TestCase):
"""
m = MetaData()
t = Table('test', m,
- Column('id', Integer, primary_key=True),
- Column('active', Boolean()),
- Column('code', String(255)),
- )
+ Column('id', Integer, primary_key=True),
+ Column('active', Boolean()),
+ Column('code', String(255)),
+ )
uq = UniqueConstraint(t.c.code, name='uq_test_code')
eq_ignore_whitespace(
autogenerate.render._add_unique_constraint(uq, self.autogen_context),
@@ -172,11 +173,11 @@ class AutogenRenderTest(TestCase):
"""
m = MetaData()
t = Table('test', m,
- Column('id', Integer, primary_key=True),
- Column('active', Boolean()),
- Column('code', String(255)),
- schema='CamelSchema'
- )
+ Column('id', Integer, primary_key=True),
+ Column('active', Boolean()),
+ Column('code', String(255)),
+ schema='CamelSchema'
+ )
uq = UniqueConstraint(t.c.code, name='uq_test_code')
eq_ignore_whitespace(
autogenerate.render._add_unique_constraint(uq, self.autogen_context),
@@ -189,10 +190,10 @@ class AutogenRenderTest(TestCase):
"""
m = MetaData()
t = Table('test', m,
- Column('id', Integer, primary_key=True),
- Column('active', Boolean()),
- Column('code', String(255)),
- )
+ Column('id', Integer, primary_key=True),
+ Column('active', Boolean()),
+ Column('code', String(255)),
+ )
uq = UniqueConstraint(t.c.code, name='uq_test_code')
eq_ignore_whitespace(
autogenerate.render._drop_constraint(uq, self.autogen_context),
@@ -205,11 +206,11 @@ class AutogenRenderTest(TestCase):
"""
m = MetaData()
t = Table('test', m,
- Column('id', Integer, primary_key=True),
- Column('active', Boolean()),
- Column('code', String(255)),
- schema='CamelSchema'
- )
+ Column('id', Integer, primary_key=True),
+ Column('active', Boolean()),
+ Column('code', String(255)),
+ schema='CamelSchema'
+ )
uq = UniqueConstraint(t.c.code, name='uq_test_code')
eq_ignore_whitespace(
autogenerate.render._drop_constraint(uq, self.autogen_context),
@@ -219,14 +220,14 @@ class AutogenRenderTest(TestCase):
def test_render_table_upgrade(self):
m = MetaData()
t = Table('test', m,
- Column('id', Integer, primary_key=True),
- Column('name', Unicode(255)),
- Column("address_id", Integer, ForeignKey("address.id")),
- Column("timestamp", DATETIME, server_default="NOW()"),
- Column("amount", Numeric(5, 2)),
- UniqueConstraint("name", name="uq_name"),
- UniqueConstraint("timestamp"),
- )
+ Column('id', Integer, primary_key=True),
+ Column('name', Unicode(255)),
+ Column("address_id", Integer, ForeignKey("address.id")),
+ Column("timestamp", DATETIME, server_default="NOW()"),
+ Column("amount", Numeric(5, 2)),
+ UniqueConstraint("name", name="uq_name"),
+ UniqueConstraint("timestamp"),
+ )
eq_ignore_whitespace(
autogenerate.render._add_table(t, self.autogen_context),
"op.create_table('test',"
@@ -234,8 +235,8 @@ class AutogenRenderTest(TestCase):
"sa.Column('name', sa.Unicode(length=255), nullable=True),"
"sa.Column('address_id', sa.Integer(), nullable=True),"
"sa.Column('timestamp', sa.DATETIME(), "
- "server_default='NOW()', "
- "nullable=True),"
+ "server_default='NOW()', "
+ "nullable=True),"
"sa.Column('amount', sa.Numeric(precision=5, scale=2), nullable=True),"
"sa.ForeignKeyConstraint(['address_id'], ['address.id'], ),"
"sa.PrimaryKeyConstraint('id'),"
@@ -247,10 +248,10 @@ class AutogenRenderTest(TestCase):
def test_render_table_w_schema(self):
m = MetaData()
t = Table('test', m,
- Column('id', Integer, primary_key=True),
- Column('q', Integer, ForeignKey('address.id')),
- schema='foo'
- )
+ Column('id', Integer, primary_key=True),
+ Column('q', Integer, ForeignKey('address.id')),
+ schema='foo'
+ )
eq_ignore_whitespace(
autogenerate.render._add_table(t, self.autogen_context),
"op.create_table('test',"
@@ -299,9 +300,9 @@ class AutogenRenderTest(TestCase):
def test_render_table_w_fk_schema(self):
m = MetaData()
t = Table('test', m,
- Column('id', Integer, primary_key=True),
- Column('q', Integer, ForeignKey('foo.address.id')),
- )
+ Column('id', Integer, primary_key=True),
+ Column('q', Integer, ForeignKey('foo.address.id')),
+ )
eq_ignore_whitespace(
autogenerate.render._add_table(t, self.autogen_context),
"op.create_table('test',"
@@ -315,9 +316,9 @@ class AutogenRenderTest(TestCase):
def test_render_table_w_metadata_schema(self):
m = MetaData(schema="foo")
t = Table('test', m,
- Column('id', Integer, primary_key=True),
- Column('q', Integer, ForeignKey('address.id')),
- )
+ Column('id', Integer, primary_key=True),
+ Column('q', Integer, ForeignKey('address.id')),
+ )
eq_ignore_whitespace(
re.sub(r"u'", "'", autogenerate.render._add_table(t, self.autogen_context)),
"op.create_table('test',"
@@ -332,9 +333,9 @@ class AutogenRenderTest(TestCase):
def test_render_table_w_metadata_schema_override(self):
m = MetaData(schema="foo")
t = Table('test', m,
- Column('id', Integer, primary_key=True),
- Column('q', Integer, ForeignKey('bar.address.id')),
- )
+ Column('id', Integer, primary_key=True),
+ Column('q', Integer, ForeignKey('bar.address.id')),
+ )
eq_ignore_whitespace(
autogenerate.render._add_table(t, self.autogen_context),
"op.create_table('test',"
@@ -349,10 +350,10 @@ class AutogenRenderTest(TestCase):
def test_render_addtl_args(self):
m = MetaData()
t = Table('test', m,
- Column('id', Integer, primary_key=True),
- Column('q', Integer, ForeignKey('bar.address.id')),
- sqlite_autoincrement=True, mysql_engine="InnoDB"
- )
+ Column('id', Integer, primary_key=True),
+ Column('q', Integer, ForeignKey('bar.address.id')),
+ sqlite_autoincrement=True, mysql_engine="InnoDB"
+ )
eq_ignore_whitespace(
autogenerate.render._add_table(t, self.autogen_context),
"op.create_table('test',"
@@ -366,7 +367,7 @@ class AutogenRenderTest(TestCase):
def test_render_drop_table(self):
eq_(
autogenerate.render._drop_table(Table("sometable", MetaData()),
- self.autogen_context),
+ self.autogen_context),
"op.drop_table('sometable')"
)
@@ -407,26 +408,26 @@ class AutogenRenderTest(TestCase):
def test_render_add_column(self):
eq_(
autogenerate.render._add_column(
- None, "foo", Column("x", Integer, server_default="5"),
- self.autogen_context),
+ None, "foo", Column("x", Integer, server_default="5"),
+ self.autogen_context),
"op.add_column('foo', sa.Column('x', sa.Integer(), "
- "server_default='5', nullable=True))"
+ "server_default='5', nullable=True))"
)
def test_render_add_column_w_schema(self):
eq_(
autogenerate.render._add_column(
- "foo", "bar", Column("x", Integer, server_default="5"),
- self.autogen_context),
+ "foo", "bar", Column("x", Integer, server_default="5"),
+ self.autogen_context),
"op.add_column('bar', sa.Column('x', sa.Integer(), "
- "server_default='5', nullable=True), schema='foo')"
+ "server_default='5', nullable=True), schema='foo')"
)
def test_render_drop_column(self):
eq_(
autogenerate.render._drop_column(
- None, "foo", Column("x", Integer, server_default="5"),
- self.autogen_context),
+ None, "foo", Column("x", Integer, server_default="5"),
+ self.autogen_context),
"op.drop_column('foo', 'x')"
)
@@ -434,8 +435,8 @@ class AutogenRenderTest(TestCase):
def test_render_drop_column_w_schema(self):
eq_(
autogenerate.render._drop_column(
- "foo", "bar", Column("x", Integer, server_default="5"),
- self.autogen_context),
+ "foo", "bar", Column("x", Integer, server_default="5"),
+ self.autogen_context),
"op.drop_column('bar', 'x', schema='foo')"
)
@@ -444,35 +445,35 @@ class AutogenRenderTest(TestCase):
eq_(
autogenerate.render._render_server_default(
"nextval('group_to_perm_group_to_perm_id_seq'::regclass)",
- self.autogen_context),
+ self.autogen_context),
'"nextval(\'group_to_perm_group_to_perm_id_seq\'::regclass)"'
)
def test_render_col_with_server_default(self):
c = Column('updated_at', TIMESTAMP(),
- server_default='TIMEZONE("utc", CURRENT_TIMESTAMP)',
- nullable=False)
+ server_default='TIMEZONE("utc", CURRENT_TIMESTAMP)',
+ nullable=False)
result = autogenerate.render._render_column(
- c, self.autogen_context
- )
+ c, self.autogen_context
+ )
eq_(
result,
'sa.Column(\'updated_at\', sa.TIMESTAMP(), '
- 'server_default=\'TIMEZONE("utc", CURRENT_TIMESTAMP)\', '
- 'nullable=False)'
+ 'server_default=\'TIMEZONE("utc", CURRENT_TIMESTAMP)\', '
+ 'nullable=False)'
)
def test_render_col_autoinc_false_mysql(self):
c = Column('some_key', Integer, primary_key=True, autoincrement=False)
Table('some_table', MetaData(), c)
result = autogenerate.render._render_column(
- c, self.autogen_context
- )
+ c, self.autogen_context
+ )
eq_(
result,
'sa.Column(\'some_key\', sa.Integer(), '
- 'autoincrement=False, '
- 'nullable=False)'
+ 'autoincrement=False, '
+ 'nullable=False)'
)
def test_render_custom(self):
@@ -493,14 +494,14 @@ class AutogenRenderTest(TestCase):
}}
t = Table('t', MetaData(),
- Column('x', Integer),
- Column('y', Integer),
- PrimaryKeyConstraint('x'),
- ForeignKeyConstraint(['x'], ['y'])
- )
+ Column('x', Integer),
+ Column('y', Integer),
+ PrimaryKeyConstraint('x'),
+ ForeignKeyConstraint(['x'], ['y'])
+ )
result = autogenerate.render._add_table(
- t, autogen_context
- )
+ t, autogen_context
+ )
eq_(
result, """sa.create_table('t',
col(x),
@@ -510,32 +511,32 @@ render:primary_key\n)"""
def test_render_modify_type(self):
eq_ignore_whitespace(
autogenerate.render._modify_col(
- "sometable", "somecolumn",
- self.autogen_context,
- type_=CHAR(10), existing_type=CHAR(20)),
+ "sometable", "somecolumn",
+ self.autogen_context,
+ type_=CHAR(10), existing_type=CHAR(20)),
"op.alter_column('sometable', 'somecolumn', "
- "existing_type=sa.CHAR(length=20), type_=sa.CHAR(length=10))"
+ "existing_type=sa.CHAR(length=20), type_=sa.CHAR(length=10))"
)
def test_render_modify_type_w_schema(self):
eq_ignore_whitespace(
autogenerate.render._modify_col(
- "sometable", "somecolumn",
- self.autogen_context,
- type_=CHAR(10), existing_type=CHAR(20),
- schema='foo'),
+ "sometable", "somecolumn",
+ self.autogen_context,
+ type_=CHAR(10), existing_type=CHAR(20),
+ schema='foo'),
"op.alter_column('sometable', 'somecolumn', "
- "existing_type=sa.CHAR(length=20), type_=sa.CHAR(length=10), "
- "schema='foo')"
+ "existing_type=sa.CHAR(length=20), type_=sa.CHAR(length=10), "
+ "schema='foo')"
)
def test_render_modify_nullable(self):
eq_ignore_whitespace(
autogenerate.render._modify_col(
- "sometable", "somecolumn",
- self.autogen_context,
- existing_type=Integer(),
- nullable=True),
+ "sometable", "somecolumn",
+ self.autogen_context,
+ existing_type=Integer(),
+ nullable=True),
"op.alter_column('sometable', 'somecolumn', "
"existing_type=sa.Integer(), nullable=True)"
)
@@ -543,10 +544,10 @@ render:primary_key\n)"""
def test_render_modify_nullable_w_schema(self):
eq_ignore_whitespace(
autogenerate.render._modify_col(
- "sometable", "somecolumn",
- self.autogen_context,
- existing_type=Integer(),
- nullable=True, schema='foo'),
+ "sometable", "somecolumn",
+ self.autogen_context,
+ existing_type=Integer(),
+ nullable=True, schema='foo'),
"op.alter_column('sometable', 'somecolumn', "
"existing_type=sa.Integer(), nullable=True, schema='foo')"
)
@@ -594,13 +595,13 @@ render:primary_key\n)"""
m = MetaData()
Table('t', m, Column('c', Integer))
t2 = Table('t2', m, Column('c_rem', Integer,
- ForeignKey('t.c', name="fk1", use_alter=True)))
+ ForeignKey('t.c', name="fk1", use_alter=True)))
const = list(t2.foreign_keys)[0].constraint
eq_ignore_whitespace(
autogenerate.render._render_constraint(const, self.autogen_context),
"sa.ForeignKeyConstraint(['c_rem'], ['t.c'], "
- "name='fk1', use_alter=True)"
+ "name='fk1', use_alter=True)"
)
def test_render_fk_constraint_w_metadata_schema(self):
@@ -626,7 +627,6 @@ render:primary_key\n)"""
"sa.CheckConstraint('im a constraint', name='cc1')"
)
-
def test_render_check_constraint_sqlexpr(self):
c = column('c')
five = literal_column('5')
@@ -653,29 +653,27 @@ render:primary_key\n)"""
def test_render_modify_nullable_w_default(self):
eq_ignore_whitespace(
autogenerate.render._modify_col(
- "sometable", "somecolumn",
- self.autogen_context,
- existing_type=Integer(),
- existing_server_default="5",
- nullable=True),
+ "sometable", "somecolumn",
+ self.autogen_context,
+ existing_type=Integer(),
+ existing_server_default="5",
+ nullable=True),
"op.alter_column('sometable', 'somecolumn', "
"existing_type=sa.Integer(), nullable=True, "
"existing_server_default='5')"
)
-
-
def test_render_enum(self):
eq_ignore_whitespace(
autogenerate.render._repr_type(
- Enum("one", "two", "three", name="myenum"),
- self.autogen_context),
+ Enum("one", "two", "three", name="myenum"),
+ self.autogen_context),
"sa.Enum('one', 'two', 'three', name='myenum')"
)
eq_ignore_whitespace(
autogenerate.render._repr_type(
- Enum("one", "two", "three"),
- self.autogen_context),
+ Enum("one", "two", "three"),
+ self.autogen_context),
"sa.Enum('one', 'two', 'three')"
)
@@ -696,7 +694,9 @@ render:primary_key\n)"""
def test_repr_user_type_user_prefix_None(self):
from sqlalchemy.types import UserDefinedType
+
class MyType(UserDefinedType):
+
def get_col_spec(self):
return "MYTYPE"
@@ -717,7 +717,9 @@ render:primary_key\n)"""
def test_repr_user_type_user_prefix_present(self):
from sqlalchemy.types import UserDefinedType
+
class MyType(UserDefinedType):
+
def get_col_spec(self):
return "MYTYPE"
@@ -755,9 +757,10 @@ render:primary_key\n)"""
"mysql.VARCHAR(charset='utf8', national=True, length=20)"
)
eq_(autogen_context['imports'],
- set(['from sqlalchemy.dialects import mysql'])
+ set(['from sqlalchemy.dialects import mysql'])
)
+
class RenderNamingConventionTest(TestCase):
@classmethod
@@ -771,30 +774,29 @@ class RenderNamingConventionTest(TestCase):
'dialect': postgresql.dialect()
}
-
def setUp(self):
convention = {
- "ix": 'ix_%(custom)s_%(column_0_label)s',
- "uq": "uq_%(custom)s_%(table_name)s_%(column_0_name)s",
- "ck": "ck_%(custom)s_%(table_name)s",
- "fk": "fk_%(custom)s_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
- "pk": "pk_%(custom)s_%(table_name)s",
- "custom": lambda const, table: "ct"
+ "ix": 'ix_%(custom)s_%(column_0_label)s',
+ "uq": "uq_%(custom)s_%(table_name)s_%(column_0_name)s",
+ "ck": "ck_%(custom)s_%(table_name)s",
+ "fk": "fk_%(custom)s_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
+ "pk": "pk_%(custom)s_%(table_name)s",
+ "custom": lambda const, table: "ct"
}
self.metadata = MetaData(
- naming_convention=convention
- )
+ naming_convention=convention
+ )
def test_schema_type_boolean(self):
t = Table('t', self.metadata, Column('c', Boolean(name='xyz')))
eq_ignore_whitespace(
autogenerate.render._add_column(
- None, "t", t.c.c,
- self.autogen_context),
+ None, "t", t.c.c,
+ self.autogen_context),
"op.add_column('t', "
- "sa.Column('c', sa.Boolean(name='xyz'), nullable=True))"
+ "sa.Column('c', sa.Boolean(name='xyz'), nullable=True))"
)
def test_explicit_unique_constraint(self):
@@ -819,10 +821,10 @@ class RenderNamingConventionTest(TestCase):
def test_render_add_index(self):
t = Table('test', self.metadata,
- Column('id', Integer, primary_key=True),
- Column('active', Boolean()),
- Column('code', String(255)),
- )
+ Column('id', Integer, primary_key=True),
+ Column('active', Boolean()),
+ Column('code', String(255)),
+ )
idx = Index(None, t.c.active, t.c.code)
eq_ignore_whitespace(
autogenerate.render._add_index(idx, self.autogen_context),
@@ -832,10 +834,10 @@ class RenderNamingConventionTest(TestCase):
def test_render_drop_index(self):
t = Table('test', self.metadata,
- Column('id', Integer, primary_key=True),
- Column('active', Boolean()),
- Column('code', String(255)),
- )
+ Column('id', Integer, primary_key=True),
+ Column('active', Boolean()),
+ Column('code', String(255)),
+ )
idx = Index(None, t.c.active, t.c.code)
eq_ignore_whitespace(
autogenerate.render._drop_index(idx, self.autogen_context),
@@ -844,11 +846,11 @@ class RenderNamingConventionTest(TestCase):
def test_render_add_index_schema(self):
t = Table('test', self.metadata,
- Column('id', Integer, primary_key=True),
- Column('active', Boolean()),
- Column('code', String(255)),
- schema='CamelSchema'
- )
+ Column('id', Integer, primary_key=True),
+ Column('active', Boolean()),
+ Column('code', String(255)),
+ schema='CamelSchema'
+ )
idx = Index(None, t.c.active, t.c.code)
eq_ignore_whitespace(
autogenerate.render._add_index(idx, self.autogen_context),
@@ -856,14 +858,13 @@ class RenderNamingConventionTest(TestCase):
"['active', 'code'], unique=False, schema='CamelSchema')"
)
-
def test_implicit_unique_constraint(self):
t = Table('t', self.metadata, Column('c', Integer, unique=True))
uq = [c for c in t.constraints if isinstance(c, UniqueConstraint)][0]
eq_ignore_whitespace(
autogenerate.render._render_unique_constraint(uq,
- self.autogen_context
- ),
+ self.autogen_context
+ ),
"sa.UniqueConstraint('c', name=op.f('uq_ct_t_c'))"
)
@@ -872,7 +873,7 @@ class RenderNamingConventionTest(TestCase):
eq_ignore_whitespace(
autogenerate.render._add_table(t, self.autogen_context),
"op.create_table('t',sa.Column('c', sa.Integer(), nullable=False),"
- "sa.PrimaryKeyConstraint('c', name=op.f('pk_ct_t')))"
+ "sa.PrimaryKeyConstraint('c', name=op.f('pk_ct_t')))"
)
def test_inline_ck_constraint(self):
@@ -880,7 +881,7 @@ class RenderNamingConventionTest(TestCase):
eq_ignore_whitespace(
autogenerate.render._add_table(t, self.autogen_context),
"op.create_table('t',sa.Column('c', sa.Integer(), nullable=True),"
- "sa.CheckConstraint('c > 5', name=op.f('ck_ct_t')))"
+ "sa.CheckConstraint('c > 5', name=op.f('ck_ct_t')))"
)
def test_inline_fk(self):
@@ -888,7 +889,7 @@ class RenderNamingConventionTest(TestCase):
eq_ignore_whitespace(
autogenerate.render._add_table(t, self.autogen_context),
"op.create_table('t',sa.Column('c', sa.Integer(), nullable=True),"
- "sa.ForeignKeyConstraint(['c'], ['q.id'], name=op.f('fk_ct_t_c_q')))"
+ "sa.ForeignKeyConstraint(['c'], ['q.id'], name=op.f('fk_ct_t_c_q')))"
)
def test_render_check_constraint_renamed(self):
diff --git a/tests/test_autogenerate.py b/tests/test_autogenerate.py
index cc9e118..f52aebb 100644
--- a/tests/test_autogenerate.py
+++ b/tests/test_autogenerate.py
@@ -14,11 +14,13 @@ from sqlalchemy.engine.reflection import Inspector
from alembic import autogenerate
from alembic.migration import MigrationContext
from . import staging_env, sqlite_db, clear_staging_env, eq_, \
- db_for_dialect
+ db_for_dialect
py3k = sys.version_info >= (3, )
names_in_this_test = set()
+
+
def _default_include_object(obj, name, type_, reflected, compare_to):
if type_ == "table":
return name in names_in_this_test
@@ -29,11 +31,15 @@ _default_object_filters = [
_default_include_object
]
from sqlalchemy import event
+
+
@event.listens_for(Table, "after_parent_attach")
def new_table(table, parent):
names_in_this_test.add(table.name)
+
class AutogenTest(object):
+
@classmethod
def _get_bind(cls):
return sqlite_db()
@@ -66,14 +72,16 @@ class AutogenTest(object):
'connection': connection,
'dialect': connection.dialect,
'context': context
- }
+ }
@classmethod
def teardown_class(cls):
cls.m1.drop_all(cls.bind)
clear_staging_env()
+
class AutogenFixtureTest(object):
+
def _fixture(self, m1, m2, include_schemas=False):
self.metadata, model_metadata = m1, m2
self.metadata.create_all(self.bind)
@@ -98,13 +106,13 @@ class AutogenFixtureTest(object):
'connection': connection,
'dialect': connection.dialect,
'context': context
- }
+ }
diffs = []
autogenerate._produce_net_changes(connection, model_metadata, diffs,
autogen_context,
object_filters=_default_object_filters,
include_schemas=include_schemas
- )
+ )
return diffs
reports_unnamed_constraints = False
@@ -124,6 +132,7 @@ class AutogenFixtureTest(object):
class AutogenCrossSchemaTest(AutogenTest, TestCase):
+
@classmethod
def _get_bind(cls):
cls.test_schema_name = "test_schema"
@@ -133,19 +142,19 @@ class AutogenCrossSchemaTest(AutogenTest, TestCase):
def _get_db_schema(cls):
m = MetaData()
Table('t1', m,
- Column('x', Integer)
- )
+ Column('x', Integer)
+ )
Table('t2', m,
- Column('y', Integer),
- schema=cls.test_schema_name
- )
+ Column('y', Integer),
+ schema=cls.test_schema_name
+ )
Table('t6', m,
- Column('u', Integer)
- )
+ Column('u', Integer)
+ )
Table('t7', m,
- Column('v', Integer),
- schema=cls.test_schema_name
- )
+ Column('v', Integer),
+ schema=cls.test_schema_name
+ )
return m
@@ -153,25 +162,26 @@ class AutogenCrossSchemaTest(AutogenTest, TestCase):
def _get_model_schema(cls):
m = MetaData()
Table('t3', m,
- Column('q', Integer)
- )
+ Column('q', Integer)
+ )
Table('t4', m,
- Column('z', Integer),
- schema=cls.test_schema_name
- )
+ Column('z', Integer),
+ schema=cls.test_schema_name
+ )
Table('t6', m,
- Column('u', Integer)
- )
+ Column('u', Integer)
+ )
Table('t7', m,
- Column('v', Integer),
- schema=cls.test_schema_name
- )
+ Column('v', Integer),
+ schema=cls.test_schema_name
+ )
return m
def test_default_schema_omitted_upgrade(self):
metadata = self.m2
connection = self.context.bind
diffs = []
+
def include_object(obj, name, type_, reflected, compare_to):
if type_ == "table":
return name == "t3"
@@ -189,6 +199,7 @@ class AutogenCrossSchemaTest(AutogenTest, TestCase):
metadata = self.m2
connection = self.context.bind
diffs = []
+
def include_object(obj, name, type_, reflected, compare_to):
if type_ == "table":
return name == "t4"
@@ -206,6 +217,7 @@ class AutogenCrossSchemaTest(AutogenTest, TestCase):
metadata = self.m2
connection = self.context.bind
diffs = []
+
def include_object(obj, name, type_, reflected, compare_to):
if type_ == "table":
return name == "t1"
@@ -223,6 +235,7 @@ class AutogenCrossSchemaTest(AutogenTest, TestCase):
metadata = self.m2
connection = self.context.bind
diffs = []
+
def include_object(obj, name, type_, reflected, compare_to):
if type_ == "table":
return name == "t2"
@@ -238,6 +251,7 @@ class AutogenCrossSchemaTest(AutogenTest, TestCase):
class AutogenDefaultSchemaTest(AutogenFixtureTest, TestCase):
+
@classmethod
def _get_bind(cls):
cls.test_schema_name = "test_schema"
@@ -285,7 +299,6 @@ class AutogenDefaultSchemaTest(AutogenFixtureTest, TestCase):
Table('a', m2, Column('x', String(50)), schema=default_schema)
Table('a', m2, Column('y', String(50)), schema="test_schema")
-
diffs = self._fixture(m1, m2, include_schemas=True)
eq_(len(diffs), 1)
eq_(diffs[0][0], "add_table")
@@ -303,28 +316,28 @@ class ModelOne(object):
m = MetaData(schema=schema)
Table('user', m,
- Column('id', Integer, primary_key=True),
- Column('name', String(50)),
- Column('a1', Text),
- Column("pw", String(50))
- )
+ Column('id', Integer, primary_key=True),
+ Column('name', String(50)),
+ Column('a1', Text),
+ Column("pw", String(50))
+ )
Table('address', m,
- Column('id', Integer, primary_key=True),
- Column('email_address', String(100), nullable=False),
- )
+ Column('id', Integer, primary_key=True),
+ Column('email_address', String(100), nullable=False),
+ )
Table('order', m,
- Column('order_id', Integer, primary_key=True),
- Column("amount", Numeric(8, 2), nullable=False,
- server_default="0"),
- CheckConstraint('amount >= 0', name='ck_order_amount')
- )
+ Column('order_id', Integer, primary_key=True),
+ Column("amount", Numeric(8, 2), nullable=False,
+ server_default="0"),
+ CheckConstraint('amount >= 0', name='ck_order_amount')
+ )
Table('extra', m,
- Column("x", CHAR),
- Column('uid', Integer, ForeignKey('user.id'))
- )
+ Column("x", CHAR),
+ Column('uid', Integer, ForeignKey('user.id'))
+ )
return m
@@ -335,35 +348,34 @@ class ModelOne(object):
m = MetaData(schema=schema)
Table('user', m,
- Column('id', Integer, primary_key=True),
- Column('name', String(50), nullable=False),
- Column('a1', Text, server_default="x")
- )
+ Column('id', Integer, primary_key=True),
+ Column('name', String(50), nullable=False),
+ Column('a1', Text, server_default="x")
+ )
Table('address', m,
- Column('id', Integer, primary_key=True),
- Column('email_address', String(100), nullable=False),
- Column('street', String(50)),
- )
+ Column('id', Integer, primary_key=True),
+ Column('email_address', String(100), nullable=False),
+ Column('street', String(50)),
+ )
Table('order', m,
- Column('order_id', Integer, primary_key=True),
- Column('amount', Numeric(10, 2), nullable=True,
- server_default="0"),
- Column('user_id', Integer, ForeignKey('user.id')),
- CheckConstraint('amount > -1', name='ck_order_amount'),
- )
+ Column('order_id', Integer, primary_key=True),
+ Column('amount', Numeric(10, 2), nullable=True,
+ server_default="0"),
+ Column('user_id', Integer, ForeignKey('user.id')),
+ CheckConstraint('amount > -1', name='ck_order_amount'),
+ )
Table('item', m,
- Column('id', Integer, primary_key=True),
- Column('description', String(100)),
- Column('order_id', Integer, ForeignKey('order.order_id')),
- CheckConstraint('len(description) > 5')
- )
+ Column('id', Integer, primary_key=True),
+ Column('description', String(100)),
+ Column('order_id', Integer, ForeignKey('order.order_id')),
+ CheckConstraint('len(description) > 5')
+ )
return m
-
class AutogenerateDiffTest(ModelOne, AutogenTest, TestCase):
def test_diffs(self):
@@ -375,7 +387,7 @@ class AutogenerateDiffTest(ModelOne, AutogenTest, TestCase):
autogenerate._produce_net_changes(connection, metadata, diffs,
self.autogen_context,
object_filters=_default_object_filters,
- )
+ )
eq_(
diffs[0],
@@ -415,7 +427,6 @@ class AutogenerateDiffTest(ModelOne, AutogenTest, TestCase):
eq_(diffs[7][0][5], True)
eq_(diffs[7][0][6], False)
-
def test_render_nothing(self):
context = MigrationContext.configure(
connection=self.bind.connect(),
@@ -431,11 +442,11 @@ class AutogenerateDiffTest(ModelOne, AutogenTest, TestCase):
autogenerate._produce_migration_diffs(context, template_args, set())
eq_(re.sub(r"u'", "'", template_args['upgrades']),
-"""### commands auto generated by Alembic - please adjust! ###
+ """### commands auto generated by Alembic - please adjust! ###
pass
### end Alembic commands ###""")
eq_(re.sub(r"u'", "'", template_args['downgrades']),
-"""### commands auto generated by Alembic - please adjust! ###
+ """### commands auto generated by Alembic - please adjust! ###
pass
### end Alembic commands ###""")
@@ -446,7 +457,7 @@ class AutogenerateDiffTest(ModelOne, AutogenTest, TestCase):
autogenerate._produce_migration_diffs(self.context, template_args, set())
eq_(re.sub(r"u'", "'", template_args['upgrades']),
-"""### commands auto generated by Alembic - please adjust! ###
+ """### commands auto generated by Alembic - please adjust! ###
op.create_table('item',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('description', sa.String(length=100), nullable=True),
@@ -474,7 +485,7 @@ class AutogenerateDiffTest(ModelOne, AutogenTest, TestCase):
### end Alembic commands ###""")
eq_(re.sub(r"u'", "'", template_args['downgrades']),
-"""### commands auto generated by Alembic - please adjust! ###
+ """### commands auto generated by Alembic - please adjust! ###
op.alter_column('user', 'name',
existing_type=sa.VARCHAR(length=50),
nullable=True)
@@ -506,7 +517,7 @@ class AutogenerateDiffTest(ModelOne, AutogenTest, TestCase):
'compare_server_default': True,
'target_metadata': self.m2,
'include_symbol': lambda name, schema=None:
- name in ('address', 'order'),
+ name in ('address', 'order'),
'upgrade_token': "upgrades",
'downgrade_token': "downgrades",
'alembic_module_prefix': 'op.',
@@ -517,7 +528,7 @@ class AutogenerateDiffTest(ModelOne, AutogenTest, TestCase):
autogenerate._produce_migration_diffs(context, template_args, set())
template_args['upgrades'] = template_args['upgrades'].replace("u'", "'")
template_args['downgrades'] = template_args['downgrades'].\
- replace("u'", "'")
+ replace("u'", "'")
assert "alter_column('user'" not in template_args['upgrades']
assert "alter_column('user'" not in template_args['downgrades']
assert "alter_column('order'" in template_args['upgrades']
@@ -559,7 +570,7 @@ class AutogenerateDiffTest(ModelOne, AutogenTest, TestCase):
template_args['upgrades'] = template_args['upgrades'].replace("u'", "'")
template_args['downgrades'] = template_args['downgrades'].\
- replace("u'", "'")
+ replace("u'", "'")
assert "op.create_table('item'" not in template_args['upgrades']
assert "op.create_table('item'" not in template_args['downgrades']
@@ -573,19 +584,19 @@ class AutogenerateDiffTest(ModelOne, AutogenTest, TestCase):
def test_skip_null_type_comparison_reflected(self):
diff = []
autogenerate.compare._compare_type(None, "sometable", "somecol",
- Column("somecol", NULLTYPE),
- Column("somecol", Integer()),
- diff, self.autogen_context
- )
+ Column("somecol", NULLTYPE),
+ Column("somecol", Integer()),
+ diff, self.autogen_context
+ )
assert not diff
def test_skip_null_type_comparison_local(self):
diff = []
autogenerate.compare._compare_type(None, "sometable", "somecol",
- Column("somecol", Integer()),
- Column("somecol", NULLTYPE),
- diff, self.autogen_context
- )
+ Column("somecol", Integer()),
+ Column("somecol", NULLTYPE),
+ diff, self.autogen_context
+ )
assert not diff
def test_affinity_typedec(self):
@@ -600,10 +611,10 @@ class AutogenerateDiffTest(ModelOne, AutogenTest, TestCase):
diff = []
autogenerate.compare._compare_type(None, "sometable", "somecol",
- Column("somecol", Integer, nullable=True),
- Column("somecol", MyType()),
- diff, self.autogen_context
- )
+ Column("somecol", Integer, nullable=True),
+ Column("somecol", MyType()),
+ diff, self.autogen_context
+ )
assert not diff
def test_dont_barf_on_already_reflected(self):
@@ -613,17 +624,17 @@ class AutogenerateDiffTest(ModelOne, AutogenTest, TestCase):
autogenerate.compare._compare_tables(
OrderedSet([(None, 'extra'), (None, 'user')]),
OrderedSet(), [], inspector,
- MetaData(), diffs, self.autogen_context
+ MetaData(), diffs, self.autogen_context
)
eq_(
[(rec[0], rec[1].name) for rec in diffs],
[('remove_table', 'extra'), ('remove_table', 'user')]
)
+
class AutogenerateDiffTestWSchema(ModelOne, AutogenTest, TestCase):
schema = "test_schema"
-
@classmethod
def _get_bind(cls):
return db_for_dialect('postgresql')
@@ -693,14 +704,14 @@ class AutogenerateDiffTestWSchema(ModelOne, AutogenTest, TestCase):
)
template_args = {}
autogenerate._produce_migration_diffs(context, template_args, set(),
- include_symbol=lambda name, schema: False
- )
+ include_symbol=lambda name, schema: False
+ )
eq_(re.sub(r"u'", "'", template_args['upgrades']),
-"""### commands auto generated by Alembic - please adjust! ###
+ """### commands auto generated by Alembic - please adjust! ###
pass
### end Alembic commands ###""")
eq_(re.sub(r"u'", "'", template_args['downgrades']),
-"""### commands auto generated by Alembic - please adjust! ###
+ """### commands auto generated by Alembic - please adjust! ###
pass
### end Alembic commands ###""")
@@ -709,13 +720,13 @@ class AutogenerateDiffTestWSchema(ModelOne, AutogenTest, TestCase):
template_args = {}
autogenerate._produce_migration_diffs(
- self.context, template_args, set(),
- include_object=_default_include_object,
- include_schemas=True
- )
+ self.context, template_args, set(),
+ include_object=_default_include_object,
+ include_schemas=True
+ )
eq_(re.sub(r"u'", "'", template_args['upgrades']),
-"""### commands auto generated by Alembic - please adjust! ###
+ """### commands auto generated by Alembic - please adjust! ###
op.create_table('item',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('description', sa.String(length=100), nullable=True),
@@ -747,7 +758,7 @@ class AutogenerateDiffTestWSchema(ModelOne, AutogenTest, TestCase):
### end Alembic commands ###""" % {"schema": self.schema})
eq_(re.sub(r"u'", "'", template_args['downgrades']),
-"""### commands auto generated by Alembic - please adjust! ###
+ """### commands auto generated by Alembic - please adjust! ###
op.alter_column('user', 'name',
existing_type=sa.VARCHAR(length=50),
nullable=True,
@@ -776,10 +787,8 @@ class AutogenerateDiffTestWSchema(ModelOne, AutogenTest, TestCase):
### end Alembic commands ###""" % {"schema": self.schema})
-
-
-
class AutogenerateCustomCompareTypeTest(AutogenTest, TestCase):
+
@classmethod
def _get_db_schema(cls):
m = MetaData()
@@ -804,7 +813,7 @@ class AutogenerateCustomCompareTypeTest(AutogenTest, TestCase):
diffs = []
autogenerate._produce_net_changes(self.context.bind, self.m2,
- diffs, self.autogen_context)
+ diffs, self.autogen_context)
first_table = self.m2.tables['sometable']
first_column = first_table.columns['id']
@@ -827,7 +836,7 @@ class AutogenerateCustomCompareTypeTest(AutogenTest, TestCase):
diffs = []
autogenerate._produce_net_changes(self.context.bind, self.m2,
- diffs, self.autogen_context)
+ diffs, self.autogen_context)
eq_(diffs, [])
@@ -838,21 +847,22 @@ class AutogenerateCustomCompareTypeTest(AutogenTest, TestCase):
diffs = []
autogenerate._produce_net_changes(self.context.bind, self.m2,
- diffs, self.autogen_context)
+ diffs, self.autogen_context)
eq_(diffs[0][0][0], 'modify_type')
eq_(diffs[1][0][0], 'modify_type')
class AutogenKeyTest(AutogenTest, TestCase):
+
@classmethod
def _get_db_schema(cls):
m = MetaData()
Table('someothertable', m,
- Column('id', Integer, primary_key=True),
- Column('value', Integer, key="somekey"),
- )
+ Column('id', Integer, primary_key=True),
+ Column('value', Integer, key="somekey"),
+ )
return m
@classmethod
@@ -860,17 +870,18 @@ class AutogenKeyTest(AutogenTest, TestCase):
m = MetaData()
Table('sometable', m,
- Column('id', Integer, primary_key=True),
- Column('value', Integer, key="someotherkey"),
- )
+ Column('id', Integer, primary_key=True),
+ Column('value', Integer, key="someotherkey"),
+ )
Table('someothertable', m,
- Column('id', Integer, primary_key=True),
- Column('value', Integer, key="somekey"),
- Column("othervalue", Integer, key="otherkey")
- )
+ Column('id', Integer, primary_key=True),
+ Column('value', Integer, key="somekey"),
+ Column("othervalue", Integer, key="otherkey")
+ )
return m
symbols = ['someothertable', 'sometable']
+
def test_autogen(self):
metadata = self.m2
connection = self.context.bind
@@ -886,7 +897,9 @@ class AutogenKeyTest(AutogenTest, TestCase):
eq_(diffs[1][0], "add_column")
eq_(diffs[1][3].key, "otherkey")
+
class AutogenerateDiffOrderTest(AutogenTest, TestCase):
+
@classmethod
def _get_db_schema(cls):
return MetaData()
@@ -895,12 +908,12 @@ class AutogenerateDiffOrderTest(AutogenTest, TestCase):
def _get_model_schema(cls):
m = MetaData()
Table('parent', m,
- Column('id', Integer, primary_key=True)
- )
+ Column('id', Integer, primary_key=True)
+ )
Table('child', m,
- Column('parent_id', Integer, ForeignKey('parent.id')),
- )
+ Column('parent_id', Integer, ForeignKey('parent.id')),
+ )
return m
@@ -925,6 +938,7 @@ class AutogenerateDiffOrderTest(AutogenTest, TestCase):
class CompareMetadataTest(ModelOne, AutogenTest, TestCase):
+
def test_compare_metadata(self):
metadata = self.m2
@@ -1035,6 +1049,7 @@ class CompareMetadataTest(ModelOne, AutogenTest, TestCase):
eq_(diffs[2][1][5], False)
eq_(diffs[2][1][6], True)
+
class PGCompareMetaData(ModelOne, AutogenTest, TestCase):
schema = "test_schema"
diff --git a/tests/test_bulk_insert.py b/tests/test_bulk_insert.py
index cc56731..13029c7 100644
--- a/tests/test_bulk_insert.py
+++ b/tests/test_bulk_insert.py
@@ -8,107 +8,121 @@ from sqlalchemy.types import TypeEngine
from . import op_fixture, eq_, assert_raises_message
+
def _table_fixture(dialect, as_sql):
context = op_fixture(dialect, as_sql)
t1 = table("ins_table",
- column('id', Integer),
- column('v1', String()),
- column('v2', String()),
- )
+ column('id', Integer),
+ column('v1', String()),
+ column('v2', String()),
+ )
return context, t1
+
def _big_t_table_fixture(dialect, as_sql):
context = op_fixture(dialect, as_sql)
t1 = Table("ins_table", MetaData(),
- Column('id', Integer, primary_key=True),
- Column('v1', String()),
- Column('v2', String()),
- )
+ Column('id', Integer, primary_key=True),
+ Column('v1', String()),
+ Column('v2', String()),
+ )
return context, t1
+
def _test_bulk_insert(dialect, as_sql):
context, t1 = _table_fixture(dialect, as_sql)
op.bulk_insert(t1, [
- {'id':1, 'v1':'row v1', 'v2':'row v5'},
- {'id':2, 'v1':'row v2', 'v2':'row v6'},
- {'id':3, 'v1':'row v3', 'v2':'row v7'},
- {'id':4, 'v1':'row v4', 'v2':'row v8'},
+ {'id': 1, 'v1': 'row v1', 'v2': 'row v5'},
+ {'id': 2, 'v1': 'row v2', 'v2': 'row v6'},
+ {'id': 3, 'v1': 'row v3', 'v2': 'row v7'},
+ {'id': 4, 'v1': 'row v4', 'v2': 'row v8'},
])
return context
+
def _test_bulk_insert_single(dialect, as_sql):
context, t1 = _table_fixture(dialect, as_sql)
op.bulk_insert(t1, [
- {'id':1, 'v1':'row v1', 'v2':'row v5'},
+ {'id': 1, 'v1': 'row v1', 'v2': 'row v5'},
])
return context
+
def _test_bulk_insert_single_bigt(dialect, as_sql):
context, t1 = _big_t_table_fixture(dialect, as_sql)
op.bulk_insert(t1, [
- {'id':1, 'v1':'row v1', 'v2':'row v5'},
+ {'id': 1, 'v1': 'row v1', 'v2': 'row v5'},
])
return context
+
def test_bulk_insert():
context = _test_bulk_insert('default', False)
context.assert_(
'INSERT INTO ins_table (id, v1, v2) VALUES (:id, :v1, :v2)'
)
+
def test_bulk_insert_wrong_cols():
context = op_fixture('postgresql')
t1 = table("ins_table",
- column('id', Integer),
- column('v1', String()),
- column('v2', String()),
- )
+ column('id', Integer),
+ column('v1', String()),
+ column('v2', String()),
+ )
op.bulk_insert(t1, [
- {'v1':'row v1', },
+ {'v1': 'row v1', },
])
context.assert_(
'INSERT INTO ins_table (id, v1, v2) VALUES (%(id)s, %(v1)s, %(v2)s)'
)
+
def test_bulk_insert_no_rows():
context, t1 = _table_fixture('default', False)
op.bulk_insert(t1, [])
context.assert_()
+
def test_bulk_insert_pg():
context = _test_bulk_insert('postgresql', False)
context.assert_(
'INSERT INTO ins_table (id, v1, v2) VALUES (%(id)s, %(v1)s, %(v2)s)'
)
+
def test_bulk_insert_pg_single():
context = _test_bulk_insert_single('postgresql', False)
context.assert_(
'INSERT INTO ins_table (id, v1, v2) VALUES (%(id)s, %(v1)s, %(v2)s)'
)
+
def test_bulk_insert_pg_single_as_sql():
context = _test_bulk_insert_single('postgresql', True)
context.assert_(
"INSERT INTO ins_table (id, v1, v2) VALUES (1, 'row v1', 'row v5')"
)
+
def test_bulk_insert_pg_single_big_t_as_sql():
context = _test_bulk_insert_single_bigt('postgresql', True)
context.assert_(
"INSERT INTO ins_table (id, v1, v2) VALUES (1, 'row v1', 'row v5')"
)
+
def test_bulk_insert_mssql():
context = _test_bulk_insert('mssql', False)
context.assert_(
'INSERT INTO ins_table (id, v1, v2) VALUES (:id, :v1, :v2)'
)
+
def test_bulk_insert_inline_literal_as_sql():
context = op_fixture('postgresql', True)
@@ -136,6 +150,7 @@ def test_bulk_insert_as_sql():
"INSERT INTO ins_table (id, v1, v2) VALUES (4, 'row v4', 'row v8')"
)
+
def test_bulk_insert_as_sql_pg():
context = _test_bulk_insert('postgresql', True)
context.assert_(
@@ -145,6 +160,7 @@ def test_bulk_insert_as_sql_pg():
"INSERT INTO ins_table (id, v1, v2) VALUES (4, 'row v4', 'row v8')"
)
+
def test_bulk_insert_as_sql_mssql():
context = _test_bulk_insert('mssql', True)
# SQL server requires IDENTITY_INSERT
@@ -159,12 +175,13 @@ def test_bulk_insert_as_sql_mssql():
'SET IDENTITY_INSERT ins_table OFF'
)
+
def test_invalid_format():
context, t1 = _table_fixture("sqlite", False)
assert_raises_message(
TypeError,
"List expected",
- op.bulk_insert, t1, {"id":5}
+ op.bulk_insert, t1, {"id": 5}
)
assert_raises_message(
@@ -173,7 +190,9 @@ def test_invalid_format():
op.bulk_insert, t1, [(5, )]
)
+
class RoundTripTest(TestCase):
+
def setUp(self):
from sqlalchemy import create_engine
from alembic.migration import MigrationContext
@@ -188,17 +207,18 @@ class RoundTripTest(TestCase):
context = MigrationContext.configure(self.conn)
self.op = op.Operations(context)
self.t1 = table('foo',
- column('id'),
- column('data'),
- column('x')
- )
+ column('id'),
+ column('data'),
+ column('x')
+ )
+
def tearDown(self):
self.conn.close()
def test_single_insert_round_trip(self):
self.op.bulk_insert(self.t1,
- [{'data':"d1", "x":"x1"}]
- )
+ [{'data': "d1", "x": "x1"}]
+ )
eq_(
self.conn.execute("select id, data, x from foo").fetchall(),
@@ -209,9 +229,9 @@ class RoundTripTest(TestCase):
def test_bulk_insert_round_trip(self):
self.op.bulk_insert(self.t1, [
- {'data':"d1", "x":"x1"},
- {'data':"d2", "x":"x2"},
- {'data':"d3", "x":"x3"},
+ {'data': "d1", "x": "x1"},
+ {'data': "d2", "x": "x2"},
+ {'data': "d3", "x": "x3"},
])
eq_(
@@ -241,4 +261,3 @@ class RoundTripTest(TestCase):
(2, "d2"),
]
)
-
diff --git a/tests/test_command.py b/tests/test_command.py
index 53a9538..b550471 100644
--- a/tests/test_command.py
+++ b/tests/test_command.py
@@ -7,7 +7,6 @@ from io import TextIOWrapper, BytesIO
from alembic.script import ScriptDirectory
-
class StdoutCommandTest(unittest.TestCase):
@classmethod
diff --git a/tests/test_config.py b/tests/test_config.py
index 6164eb9..cd56d13 100644
--- a/tests/test_config.py
+++ b/tests/test_config.py
@@ -9,6 +9,7 @@ from . import Mock, call
from . import eq_, capture_db, assert_raises_message
+
def test_config_no_file_main_option():
cfg = config.Config()
cfg.set_main_option("url", "postgresql://foo/bar")
@@ -35,6 +36,7 @@ def test_standalone_op():
op.alter_column("t", "c", nullable=True)
eq_(buf, ['ALTER TABLE t ALTER COLUMN c DROP NOT NULL'])
+
def test_no_script_error():
cfg = config.Config()
assert_raises_message(
@@ -72,4 +74,3 @@ class OutputEncodingTest(unittest.TestCase):
stdout.mock_calls,
[call.write('m?il x y'), call.write('\n')]
)
-
diff --git a/tests/test_environment.py b/tests/test_environment.py
index cc5ccb8..ad47cf9 100644
--- a/tests/test_environment.py
+++ b/tests/test_environment.py
@@ -8,7 +8,9 @@ from . import Mock, call, _no_sql_testing_config, staging_env, clear_staging_env
from . import eq_, is_
+
class EnvironmentTest(unittest.TestCase):
+
def setUp(self):
staging_env()
self.cfg = _no_sql_testing_config()
diff --git a/tests/test_mssql.py b/tests/test_mssql.py
index 3205959..396692f 100644
--- a/tests/test_mssql.py
+++ b/tests/test_mssql.py
@@ -11,6 +11,7 @@ from . import op_fixture, capture_context_buffer, \
class FullEnvironmentTests(TestCase):
+
@classmethod
def setup_class(cls):
env = staging_env()
@@ -41,13 +42,14 @@ class FullEnvironmentTests(TestCase):
command.upgrade(self.cfg, self.a, sql=True)
assert "BYE" in buf.getvalue()
+
class OpTest(TestCase):
+
def test_add_column(self):
context = op_fixture('mssql')
op.add_column('t1', Column('c1', Integer, nullable=False))
context.assert_("ALTER TABLE t1 ADD c1 INTEGER NOT NULL")
-
def test_add_column_with_default(self):
context = op_fixture("mssql")
op.add_column('t1', Column('c1', Integer, nullable=False, server_default="12"))
@@ -78,8 +80,8 @@ class OpTest(TestCase):
context = op_fixture('mssql')
from sqlalchemy import Boolean
op.alter_column('tests', 'col',
- existing_type=Boolean(),
- nullable=False)
+ existing_type=Boolean(),
+ nullable=False)
context.assert_('ALTER TABLE tests ALTER COLUMN col BIT NOT NULL')
def test_drop_index(self):
@@ -95,7 +97,6 @@ class OpTest(TestCase):
context.assert_contains("exec('alter table t1 drop constraint ' + @const_name)")
context.assert_contains("ALTER TABLE t1 DROP COLUMN c1")
-
def test_alter_column_drop_default(self):
context = op_fixture('mssql')
op.alter_column("t", "c", server_default=None)
@@ -186,7 +187,7 @@ class OpTest(TestCase):
def test_alter_do_everything(self):
context = op_fixture('mssql')
op.alter_column("t", "c", new_column_name="c2", nullable=True,
- type_=Integer, server_default="5")
+ type_=Integer, server_default="5")
context.assert_(
'ALTER TABLE t ALTER COLUMN c INTEGER NULL',
"ALTER TABLE t ADD DEFAULT '5' FOR c",
@@ -199,7 +200,7 @@ class OpTest(TestCase):
context.assert_contains("EXEC sp_rename 't1', t2")
# TODO: when we add schema support
- #def test_alter_column_rename_mssql_schema(self):
+ # def test_alter_column_rename_mssql_schema(self):
# context = op_fixture('mssql')
# op.alter_column("t", "c", name="x", schema="y")
# context.assert_(
diff --git a/tests/test_mysql.py b/tests/test_mysql.py
index 16b171c..1ad1453 100644
--- a/tests/test_mysql.py
+++ b/tests/test_mysql.py
@@ -7,7 +7,9 @@ from . import op_fixture, assert_raises_message, db_for_dialect, \
staging_env, clear_staging_env
from alembic.migration import MigrationContext
+
class MySQLOpTest(TestCase):
+
def test_rename_column(self):
context = op_fixture('mysql')
op.alter_column('t1', 'c1', new_column_name="c2", existing_type=Integer)
@@ -18,7 +20,7 @@ class MySQLOpTest(TestCase):
def test_rename_column_quotes_needed_one(self):
context = op_fixture('mysql')
op.alter_column('MyTable', 'ColumnOne', new_column_name="ColumnTwo",
- existing_type=Integer)
+ existing_type=Integer)
context.assert_(
'ALTER TABLE `MyTable` CHANGE `ColumnOne` `ColumnTwo` INTEGER NULL'
)
@@ -26,7 +28,7 @@ class MySQLOpTest(TestCase):
def test_rename_column_quotes_needed_two(self):
context = op_fixture('mysql')
op.alter_column('my table', 'column one', new_column_name="column two",
- existing_type=Integer)
+ existing_type=Integer)
context.assert_(
'ALTER TABLE `my table` CHANGE `column one` `column two` INTEGER NULL'
)
@@ -34,7 +36,7 @@ class MySQLOpTest(TestCase):
def test_rename_column_serv_default(self):
context = op_fixture('mysql')
op.alter_column('t1', 'c1', new_column_name="c2", existing_type=Integer,
- existing_server_default="q")
+ existing_server_default="q")
context.assert_(
"ALTER TABLE t1 CHANGE c1 c2 INTEGER NULL DEFAULT 'q'"
)
@@ -42,7 +44,7 @@ class MySQLOpTest(TestCase):
def test_rename_column_serv_compiled_default(self):
context = op_fixture('mysql')
op.alter_column('t1', 'c1', existing_type=Integer,
- server_default=func.utc_thing(func.current_timestamp()))
+ server_default=func.utc_thing(func.current_timestamp()))
# this is not a valid MySQL default but the point is to just
# test SQL expression rendering
context.assert_(
@@ -52,7 +54,7 @@ class MySQLOpTest(TestCase):
def test_rename_column_autoincrement(self):
context = op_fixture('mysql')
op.alter_column('t1', 'c1', new_column_name="c2", existing_type=Integer,
- existing_autoincrement=True)
+ existing_autoincrement=True)
context.assert_(
'ALTER TABLE t1 CHANGE c1 c2 INTEGER NULL AUTO_INCREMENT'
)
@@ -60,7 +62,7 @@ class MySQLOpTest(TestCase):
def test_col_add_autoincrement(self):
context = op_fixture('mysql')
op.alter_column('t1', 'c1', existing_type=Integer,
- autoincrement=True)
+ autoincrement=True)
context.assert_(
'ALTER TABLE t1 MODIFY c1 INTEGER NULL AUTO_INCREMENT'
)
@@ -68,18 +70,17 @@ class MySQLOpTest(TestCase):
def test_col_remove_autoincrement(self):
context = op_fixture('mysql')
op.alter_column('t1', 'c1', existing_type=Integer,
- existing_autoincrement=True,
- autoincrement=False)
+ existing_autoincrement=True,
+ autoincrement=False)
context.assert_(
'ALTER TABLE t1 MODIFY c1 INTEGER NULL'
)
-
def test_col_dont_remove_server_default(self):
context = op_fixture('mysql')
op.alter_column('t1', 'c1', existing_type=Integer,
- existing_server_default='1',
- server_default=False)
+ existing_server_default='1',
+ server_default=False)
context.assert_()
@@ -90,8 +91,6 @@ class MySQLOpTest(TestCase):
'ALTER TABLE t ALTER COLUMN c DROP DEFAULT'
)
-
-
def test_alter_column_modify_default(self):
context = op_fixture('mysql')
# notice we dont need the existing type on this one...
@@ -110,7 +109,7 @@ class MySQLOpTest(TestCase):
def test_col_not_nullable_existing_serv_default(self):
context = op_fixture('mysql')
op.alter_column('t1', 'c1', nullable=False, existing_type=Integer,
- existing_server_default='5')
+ existing_server_default='5')
context.assert_(
"ALTER TABLE t1 MODIFY c1 INTEGER NOT NULL DEFAULT '5'"
)
@@ -191,7 +190,9 @@ class MySQLOpTest(TestCase):
op.drop_constraint, "f1", "t1"
)
+
class MySQLDefaultCompareTest(TestCase):
+
@classmethod
def setup_class(cls):
cls.bind = db_for_dialect("mysql")
@@ -209,7 +210,7 @@ class MySQLDefaultCompareTest(TestCase):
'connection': connection,
'dialect': connection.dialect,
'context': context
- }
+ }
@classmethod
def teardown_class(cls):
@@ -228,11 +229,11 @@ class MySQLDefaultCompareTest(TestCase):
alternate = txt
expected = False
t = Table("test", self.metadata,
- Column("somecol", type_, server_default=text(txt) if txt else None)
- )
+ Column("somecol", type_, server_default=text(txt) if txt else None)
+ )
t2 = Table("test", MetaData(),
- Column("somecol", type_, server_default=text(alternate))
- )
+ Column("somecol", type_, server_default=text(alternate))
+ )
assert self._compare_default(
t, t2, t2.c.somecol, alternate
) is expected
@@ -263,4 +264,3 @@ class MySQLDefaultCompareTest(TestCase):
TIMESTAMP(),
None, "CURRENT_TIMESTAMP",
)
-
diff --git a/tests/test_offline_environment.py b/tests/test_offline_environment.py
index 7026e8c..9623bcc 100644
--- a/tests/test_offline_environment.py
+++ b/tests/test_offline_environment.py
@@ -9,6 +9,7 @@ from . import clear_staging_env, staging_env, \
class OfflineEnvironmentTest(TestCase):
+
def setUp(self):
env = staging_env()
self.cfg = _no_sql_testing_config()
@@ -33,7 +34,6 @@ assert context.requires_connection()
command.upgrade(self.cfg, a)
command.downgrade(self.cfg, a)
-
def test_starting_rev_post_context(self):
env_file_fixture("""
context.configure(dialect_name='sqlite', starting_rev='x')
diff --git a/tests/test_op.py b/tests/test_op.py
index eaa0d5d..8c4e964 100644
--- a/tests/test_op.py
+++ b/tests/test_op.py
@@ -1,7 +1,7 @@
"""Test against the builders in the op.* module."""
from sqlalchemy import Integer, Column, ForeignKey, \
- Table, String, Boolean, MetaData, CheckConstraint
+ Table, String, Boolean, MetaData, CheckConstraint
from sqlalchemy.sql import column, func, text
from sqlalchemy import event
@@ -9,6 +9,7 @@ from alembic import op
from . import op_fixture, assert_raises_message, requires_094, eq_
from . import mock
+
@event.listens_for(Table, "after_parent_attach")
def _add_cols(table, metadata):
if table.name == "tbl_with_auto_appended_column":
@@ -20,16 +21,19 @@ def test_rename_table():
op.rename_table('t1', 't2')
context.assert_("ALTER TABLE t1 RENAME TO t2")
+
def test_rename_table_schema():
context = op_fixture()
op.rename_table('t1', 't2', schema="foo")
context.assert_("ALTER TABLE foo.t1 RENAME TO foo.t2")
+
def test_rename_table_postgresql():
context = op_fixture("postgresql")
op.rename_table('t1', 't2')
context.assert_("ALTER TABLE t1 RENAME TO t2")
+
def test_rename_table_schema_postgresql():
context = op_fixture("postgresql")
op.rename_table('t1', 't2', schema="foo")
@@ -76,6 +80,7 @@ def test_create_index_postgresql_expressions():
"CREATE INDEX geocoded ON locations (lower(coordinates)) "
"WHERE locations.coordinates != Null")
+
def test_create_index_postgresql_where():
context = op_fixture("postgresql")
op.create_index(
@@ -84,31 +89,36 @@ def test_create_index_postgresql_where():
['coordinates'],
postgresql_where=text("locations.coordinates != Null"))
context.assert_(
- "CREATE INDEX geocoded ON locations (coordinates) "
- "WHERE locations.coordinates != Null")
+ "CREATE INDEX geocoded ON locations (coordinates) "
+ "WHERE locations.coordinates != Null")
+
def test_add_column():
context = op_fixture()
op.add_column('t1', Column('c1', Integer, nullable=False))
context.assert_("ALTER TABLE t1 ADD COLUMN c1 INTEGER NOT NULL")
+
def test_add_column_schema():
context = op_fixture()
op.add_column('t1', Column('c1', Integer, nullable=False), schema="foo")
context.assert_("ALTER TABLE foo.t1 ADD COLUMN c1 INTEGER NOT NULL")
+
def test_add_column_with_default():
context = op_fixture()
op.add_column('t1', Column('c1', Integer, nullable=False, server_default="12"))
context.assert_("ALTER TABLE t1 ADD COLUMN c1 INTEGER DEFAULT '12' NOT NULL")
+
def test_add_column_schema_with_default():
context = op_fixture()
op.add_column('t1',
- Column('c1', Integer, nullable=False, server_default="12"),
- schema='foo')
+ Column('c1', Integer, nullable=False, server_default="12"),
+ schema='foo')
context.assert_("ALTER TABLE foo.t1 ADD COLUMN c1 INTEGER DEFAULT '12' NOT NULL")
+
def test_add_column_fk():
context = op_fixture()
op.add_column('t1', Column('c1', Integer, ForeignKey('c2.id'), nullable=False))
@@ -117,16 +127,18 @@ def test_add_column_fk():
"ALTER TABLE t1 ADD FOREIGN KEY(c1) REFERENCES c2 (id)"
)
+
def test_add_column_schema_fk():
context = op_fixture()
op.add_column('t1',
- Column('c1', Integer, ForeignKey('c2.id'), nullable=False),
- schema='foo')
+ Column('c1', Integer, ForeignKey('c2.id'), nullable=False),
+ schema='foo')
context.assert_(
"ALTER TABLE foo.t1 ADD COLUMN c1 INTEGER NOT NULL",
"ALTER TABLE foo.t1 ADD FOREIGN KEY(c1) REFERENCES c2 (id)"
)
+
def test_add_column_schema_type():
"""Test that a schema type generates its constraints...."""
context = op_fixture()
@@ -146,6 +158,7 @@ def test_add_column_schema_schema_type():
'ALTER TABLE foo.t1 ADD CHECK (c1 IN (0, 1))'
)
+
def test_add_column_schema_type_checks_rule():
"""Test that a schema type doesn't generate a
constraint based on check rule."""
@@ -155,6 +168,7 @@ def test_add_column_schema_type_checks_rule():
'ALTER TABLE t1 ADD COLUMN c1 BOOLEAN NOT NULL',
)
+
def test_add_column_fk_self_referential():
context = op_fixture()
op.add_column('t1', Column('c1', Integer, ForeignKey('t1.c2'), nullable=False))
@@ -163,44 +177,50 @@ def test_add_column_fk_self_referential():
"ALTER TABLE t1 ADD FOREIGN KEY(c1) REFERENCES t1 (c2)"
)
+
def test_add_column_schema_fk_self_referential():
context = op_fixture()
op.add_column('t1',
- Column('c1', Integer, ForeignKey('foo.t1.c2'), nullable=False),
- schema='foo')
+ Column('c1', Integer, ForeignKey('foo.t1.c2'), nullable=False),
+ schema='foo')
context.assert_(
"ALTER TABLE foo.t1 ADD COLUMN c1 INTEGER NOT NULL",
"ALTER TABLE foo.t1 ADD FOREIGN KEY(c1) REFERENCES foo.t1 (c2)"
)
+
def test_add_column_fk_schema():
context = op_fixture()
op.add_column('t1', Column('c1', Integer, ForeignKey('remote.t2.c2'), nullable=False))
context.assert_(
- 'ALTER TABLE t1 ADD COLUMN c1 INTEGER NOT NULL',
- 'ALTER TABLE t1 ADD FOREIGN KEY(c1) REFERENCES remote.t2 (c2)'
+ 'ALTER TABLE t1 ADD COLUMN c1 INTEGER NOT NULL',
+ 'ALTER TABLE t1 ADD FOREIGN KEY(c1) REFERENCES remote.t2 (c2)'
)
+
def test_add_column_schema_fk_schema():
context = op_fixture()
op.add_column('t1',
- Column('c1', Integer, ForeignKey('remote.t2.c2'), nullable=False),
- schema='foo')
+ Column('c1', Integer, ForeignKey('remote.t2.c2'), nullable=False),
+ schema='foo')
context.assert_(
- 'ALTER TABLE foo.t1 ADD COLUMN c1 INTEGER NOT NULL',
- 'ALTER TABLE foo.t1 ADD FOREIGN KEY(c1) REFERENCES remote.t2 (c2)'
+ 'ALTER TABLE foo.t1 ADD COLUMN c1 INTEGER NOT NULL',
+ 'ALTER TABLE foo.t1 ADD FOREIGN KEY(c1) REFERENCES remote.t2 (c2)'
)
+
def test_drop_column():
context = op_fixture()
op.drop_column('t1', 'c1')
context.assert_("ALTER TABLE t1 DROP COLUMN c1")
+
def test_drop_column_schema():
context = op_fixture()
op.drop_column('t1', 'c1', schema='foo')
context.assert_("ALTER TABLE foo.t1 DROP COLUMN c1")
+
def test_alter_column_nullable():
context = op_fixture()
op.alter_column("t", "c", nullable=True)
@@ -210,6 +230,7 @@ def test_alter_column_nullable():
"ALTER TABLE t ALTER COLUMN c DROP NOT NULL"
)
+
def test_alter_column_schema_nullable():
context = op_fixture()
op.alter_column("t", "c", nullable=True, schema='foo')
@@ -219,6 +240,7 @@ def test_alter_column_schema_nullable():
"ALTER TABLE foo.t ALTER COLUMN c DROP NOT NULL"
)
+
def test_alter_column_not_nullable():
context = op_fixture()
op.alter_column("t", "c", nullable=False)
@@ -228,6 +250,7 @@ def test_alter_column_not_nullable():
"ALTER TABLE t ALTER COLUMN c SET NOT NULL"
)
+
def test_alter_column_schema_not_nullable():
context = op_fixture()
op.alter_column("t", "c", nullable=False, schema='foo')
@@ -237,6 +260,7 @@ def test_alter_column_schema_not_nullable():
"ALTER TABLE foo.t ALTER COLUMN c SET NOT NULL"
)
+
def test_alter_column_rename():
context = op_fixture()
op.alter_column("t", "c", new_column_name="x")
@@ -244,6 +268,7 @@ def test_alter_column_rename():
"ALTER TABLE t RENAME c TO x"
)
+
def test_alter_column_schema_rename():
context = op_fixture()
op.alter_column("t", "c", new_column_name="x", schema='foo')
@@ -251,6 +276,7 @@ def test_alter_column_schema_rename():
"ALTER TABLE foo.t RENAME c TO x"
)
+
def test_alter_column_type():
context = op_fixture()
op.alter_column("t", "c", type_=String(50))
@@ -258,6 +284,7 @@ def test_alter_column_type():
'ALTER TABLE t ALTER COLUMN c TYPE VARCHAR(50)'
)
+
def test_alter_column_schema_type():
context = op_fixture()
op.alter_column("t", "c", type_=String(50), schema='foo')
@@ -265,6 +292,7 @@ def test_alter_column_schema_type():
'ALTER TABLE foo.t ALTER COLUMN c TYPE VARCHAR(50)'
)
+
def test_alter_column_set_default():
context = op_fixture()
op.alter_column("t", "c", server_default="q")
@@ -272,6 +300,7 @@ def test_alter_column_set_default():
"ALTER TABLE t ALTER COLUMN c SET DEFAULT 'q'"
)
+
def test_alter_column_schema_set_default():
context = op_fixture()
op.alter_column("t", "c", server_default="q", schema='foo')
@@ -279,23 +308,26 @@ def test_alter_column_schema_set_default():
"ALTER TABLE foo.t ALTER COLUMN c SET DEFAULT 'q'"
)
+
def test_alter_column_set_compiled_default():
context = op_fixture()
op.alter_column("t", "c",
- server_default=func.utc_thing(func.current_timestamp()))
+ server_default=func.utc_thing(func.current_timestamp()))
context.assert_(
"ALTER TABLE t ALTER COLUMN c SET DEFAULT utc_thing(CURRENT_TIMESTAMP)"
)
+
def test_alter_column_schema_set_compiled_default():
context = op_fixture()
op.alter_column("t", "c",
- server_default=func.utc_thing(func.current_timestamp()),
- schema='foo')
+ server_default=func.utc_thing(func.current_timestamp()),
+ schema='foo')
context.assert_(
"ALTER TABLE foo.t ALTER COLUMN c SET DEFAULT utc_thing(CURRENT_TIMESTAMP)"
)
+
def test_alter_column_drop_default():
context = op_fixture()
op.alter_column("t", "c", server_default=None)
@@ -303,6 +335,7 @@ def test_alter_column_drop_default():
'ALTER TABLE t ALTER COLUMN c DROP DEFAULT'
)
+
def test_alter_column_schema_drop_default():
context = op_fixture()
op.alter_column("t", "c", server_default=None, schema='foo')
@@ -319,6 +352,7 @@ def test_alter_column_schema_type_unnamed():
'ALTER TABLE t ADD CHECK (c IN (0, 1))'
)
+
def test_alter_column_schema_schema_type_unnamed():
context = op_fixture('mssql')
op.alter_column("t", "c", type_=Boolean(), schema='foo')
@@ -327,6 +361,7 @@ def test_alter_column_schema_schema_type_unnamed():
'ALTER TABLE foo.t ADD CHECK (c IN (0, 1))'
)
+
def test_alter_column_schema_type_named():
context = op_fixture('mssql')
op.alter_column("t", "c", type_=Boolean(name="xyz"))
@@ -335,6 +370,7 @@ def test_alter_column_schema_type_named():
'ALTER TABLE t ADD CONSTRAINT xyz CHECK (c IN (0, 1))'
)
+
def test_alter_column_schema_schema_type_named():
context = op_fixture('mssql')
op.alter_column("t", "c", type_=Boolean(name="xyz"), schema='foo')
@@ -343,6 +379,7 @@ def test_alter_column_schema_schema_type_named():
'ALTER TABLE foo.t ADD CONSTRAINT xyz CHECK (c IN (0, 1))'
)
+
def test_alter_column_schema_type_existing_type():
context = op_fixture('mssql')
op.alter_column("t", "c", type_=String(10), existing_type=Boolean(name="xyz"))
@@ -351,15 +388,17 @@ def test_alter_column_schema_type_existing_type():
'ALTER TABLE t ALTER COLUMN c VARCHAR(10)'
)
+
def test_alter_column_schema_schema_type_existing_type():
context = op_fixture('mssql')
op.alter_column("t", "c", type_=String(10),
- existing_type=Boolean(name="xyz"), schema='foo')
+ existing_type=Boolean(name="xyz"), schema='foo')
context.assert_(
'ALTER TABLE foo.t DROP CONSTRAINT xyz',
'ALTER TABLE foo.t ALTER COLUMN c VARCHAR(10)'
)
+
def test_alter_column_schema_type_existing_type_no_const():
context = op_fixture('postgresql')
op.alter_column("t", "c", type_=String(10), existing_type=Boolean())
@@ -367,14 +406,16 @@ def test_alter_column_schema_type_existing_type_no_const():
'ALTER TABLE t ALTER COLUMN c TYPE VARCHAR(10)'
)
+
def test_alter_column_schema_schema_type_existing_type_no_const():
context = op_fixture('postgresql')
op.alter_column("t", "c", type_=String(10), existing_type=Boolean(),
- schema='foo')
+ schema='foo')
context.assert_(
'ALTER TABLE foo.t ALTER COLUMN c TYPE VARCHAR(10)'
)
+
def test_alter_column_schema_type_existing_type_no_new_type():
context = op_fixture('postgresql')
op.alter_column("t", "c", nullable=False, existing_type=Boolean())
@@ -382,94 +423,104 @@ def test_alter_column_schema_type_existing_type_no_new_type():
'ALTER TABLE t ALTER COLUMN c SET NOT NULL'
)
+
def test_alter_column_schema_schema_type_existing_type_no_new_type():
context = op_fixture('postgresql')
op.alter_column("t", "c", nullable=False, existing_type=Boolean(),
- schema='foo')
+ schema='foo')
context.assert_(
'ALTER TABLE foo.t ALTER COLUMN c SET NOT NULL'
)
+
def test_add_foreign_key():
context = op_fixture()
op.create_foreign_key('fk_test', 't1', 't2',
- ['foo', 'bar'], ['bat', 'hoho'])
+ ['foo', 'bar'], ['bat', 'hoho'])
context.assert_(
"ALTER TABLE t1 ADD CONSTRAINT fk_test FOREIGN KEY(foo, bar) "
- "REFERENCES t2 (bat, hoho)"
+ "REFERENCES t2 (bat, hoho)"
)
+
def test_add_foreign_key_schema():
context = op_fixture()
op.create_foreign_key('fk_test', 't1', 't2',
- ['foo', 'bar'], ['bat', 'hoho'],
- source_schema='foo2', referent_schema='bar2')
+ ['foo', 'bar'], ['bat', 'hoho'],
+ source_schema='foo2', referent_schema='bar2')
context.assert_(
"ALTER TABLE foo2.t1 ADD CONSTRAINT fk_test FOREIGN KEY(foo, bar) "
- "REFERENCES bar2.t2 (bat, hoho)"
+ "REFERENCES bar2.t2 (bat, hoho)"
)
+
def test_add_foreign_key_onupdate():
context = op_fixture()
op.create_foreign_key('fk_test', 't1', 't2',
- ['foo', 'bar'], ['bat', 'hoho'],
- onupdate='CASCADE')
+ ['foo', 'bar'], ['bat', 'hoho'],
+ onupdate='CASCADE')
context.assert_(
"ALTER TABLE t1 ADD CONSTRAINT fk_test FOREIGN KEY(foo, bar) "
- "REFERENCES t2 (bat, hoho) ON UPDATE CASCADE"
+ "REFERENCES t2 (bat, hoho) ON UPDATE CASCADE"
)
+
def test_add_foreign_key_ondelete():
context = op_fixture()
op.create_foreign_key('fk_test', 't1', 't2',
- ['foo', 'bar'], ['bat', 'hoho'],
- ondelete='CASCADE')
+ ['foo', 'bar'], ['bat', 'hoho'],
+ ondelete='CASCADE')
context.assert_(
"ALTER TABLE t1 ADD CONSTRAINT fk_test FOREIGN KEY(foo, bar) "
- "REFERENCES t2 (bat, hoho) ON DELETE CASCADE"
+ "REFERENCES t2 (bat, hoho) ON DELETE CASCADE"
)
+
def test_add_foreign_key_deferrable():
context = op_fixture()
op.create_foreign_key('fk_test', 't1', 't2',
- ['foo', 'bar'], ['bat', 'hoho'],
- deferrable=True)
+ ['foo', 'bar'], ['bat', 'hoho'],
+ deferrable=True)
context.assert_(
"ALTER TABLE t1 ADD CONSTRAINT fk_test FOREIGN KEY(foo, bar) "
- "REFERENCES t2 (bat, hoho) DEFERRABLE"
+ "REFERENCES t2 (bat, hoho) DEFERRABLE"
)
+
def test_add_foreign_key_initially():
context = op_fixture()
op.create_foreign_key('fk_test', 't1', 't2',
- ['foo', 'bar'], ['bat', 'hoho'],
- initially='INITIAL')
+ ['foo', 'bar'], ['bat', 'hoho'],
+ initially='INITIAL')
context.assert_(
"ALTER TABLE t1 ADD CONSTRAINT fk_test FOREIGN KEY(foo, bar) "
- "REFERENCES t2 (bat, hoho) INITIALLY INITIAL"
+ "REFERENCES t2 (bat, hoho) INITIALLY INITIAL"
)
+
def test_add_foreign_key_match():
context = op_fixture()
op.create_foreign_key('fk_test', 't1', 't2',
- ['foo', 'bar'], ['bat', 'hoho'],
- match='SIMPLE')
+ ['foo', 'bar'], ['bat', 'hoho'],
+ match='SIMPLE')
context.assert_(
"ALTER TABLE t1 ADD CONSTRAINT fk_test FOREIGN KEY(foo, bar) "
- "REFERENCES t2 (bat, hoho) MATCH SIMPLE"
+ "REFERENCES t2 (bat, hoho) MATCH SIMPLE"
)
+
def test_add_foreign_key_dialect_kw():
context = op_fixture()
with mock.patch("alembic.operations.sa_schema.ForeignKeyConstraint") as fkc:
op.create_foreign_key('fk_test', 't1', 't2',
- ['foo', 'bar'], ['bat', 'hoho'],
- foobar_arg='xyz')
+ ['foo', 'bar'], ['bat', 'hoho'],
+ foobar_arg='xyz')
eq_(fkc.mock_calls[0],
- mock.call(['foo', 'bar'], ['t2.bat', 't2.hoho'],
- onupdate=None, ondelete=None, name='fk_test',
- foobar_arg='xyz',
- deferrable=None, initially=None, match=None))
+ mock.call(['foo', 'bar'], ['t2.bat', 't2.hoho'],
+ onupdate=None, ondelete=None, name='fk_test',
+ foobar_arg='xyz',
+ deferrable=None, initially=None, match=None))
+
def test_add_foreign_key_self_referential():
context = op_fixture()
@@ -479,6 +530,7 @@ def test_add_foreign_key_self_referential():
"FOREIGN KEY(foo) REFERENCES t1 (bar)"
)
+
def test_add_primary_key_constraint():
context = op_fixture()
op.create_primary_key("pk_test", "t1", ["foo", "bar"])
@@ -486,6 +538,7 @@ def test_add_primary_key_constraint():
"ALTER TABLE t1 ADD CONSTRAINT pk_test PRIMARY KEY (foo, bar)"
)
+
def test_add_primary_key_constraint_schema():
context = op_fixture()
op.create_primary_key("pk_test", "t1", ["foo"], schema="bar")
@@ -506,6 +559,7 @@ def test_add_check_constraint():
"CHECK (len(name) > 5)"
)
+
def test_add_check_constraint_schema():
context = op_fixture()
op.create_check_constraint(
@@ -519,6 +573,7 @@ def test_add_check_constraint_schema():
"CHECK (len(name) > 5)"
)
+
def test_add_unique_constraint():
context = op_fixture()
op.create_unique_constraint('uk_test', 't1', ['foo', 'bar'])
@@ -526,6 +581,7 @@ def test_add_unique_constraint():
"ALTER TABLE t1 ADD CONSTRAINT uk_test UNIQUE (foo, bar)"
)
+
def test_add_unique_constraint_schema():
context = op_fixture()
op.create_unique_constraint('uk_test', 't1', ['foo', 'bar'], schema='foo')
@@ -541,6 +597,7 @@ def test_drop_constraint():
"ALTER TABLE t1 DROP CONSTRAINT foo_bar_bat"
)
+
def test_drop_constraint_schema():
context = op_fixture()
op.drop_constraint('foo_bar_bat', 't1', schema='foo')
@@ -548,6 +605,7 @@ def test_drop_constraint_schema():
"ALTER TABLE foo.t1 DROP CONSTRAINT foo_bar_bat"
)
+
def test_create_index():
context = op_fixture()
op.create_index('ik_test', 't1', ['foo', 'bar'])
@@ -564,10 +622,11 @@ def test_create_index_table_col_event():
"CREATE INDEX ik_test ON tbl_with_auto_appended_column (foo, bar)"
)
+
def test_add_unique_constraint_col_event():
context = op_fixture()
op.create_unique_constraint('ik_test',
- 'tbl_with_auto_appended_column', ['foo', 'bar'])
+ 'tbl_with_auto_appended_column', ['foo', 'bar'])
context.assert_(
"ALTER TABLE tbl_with_auto_appended_column "
"ADD CONSTRAINT ik_test UNIQUE (foo, bar)"
@@ -581,6 +640,7 @@ def test_create_index_schema():
"CREATE INDEX ik_test ON foo.t1 (foo, bar)"
)
+
def test_drop_index():
context = op_fixture()
op.drop_index('ik_test')
@@ -588,6 +648,7 @@ def test_drop_index():
"DROP INDEX ik_test"
)
+
def test_drop_index_schema():
context = op_fixture()
op.drop_index('ik_test', schema='foo')
@@ -595,6 +656,7 @@ def test_drop_index_schema():
"DROP INDEX foo.ik_test"
)
+
def test_drop_table():
context = op_fixture()
op.drop_table('tb_test')
@@ -602,6 +664,7 @@ def test_drop_table():
"DROP TABLE tb_test"
)
+
def test_drop_table_schema():
context = op_fixture()
op.drop_table('tb_test', schema='foo')
@@ -609,6 +672,7 @@ def test_drop_table_schema():
"DROP TABLE foo.tb_test"
)
+
def test_create_table_selfref():
context = op_fixture()
op.create_table(
@@ -618,12 +682,13 @@ def test_create_table_selfref():
)
context.assert_(
"CREATE TABLE some_table ("
- "id INTEGER NOT NULL, "
- "st_id INTEGER, "
- "PRIMARY KEY (id), "
- "FOREIGN KEY(st_id) REFERENCES some_table (id))"
+ "id INTEGER NOT NULL, "
+ "st_id INTEGER, "
+ "PRIMARY KEY (id), "
+ "FOREIGN KEY(st_id) REFERENCES some_table (id))"
)
+
def test_create_table_fk_and_schema():
context = op_fixture()
op.create_table(
@@ -634,12 +699,13 @@ def test_create_table_fk_and_schema():
)
context.assert_(
"CREATE TABLE schema.some_table ("
- "id INTEGER NOT NULL, "
- "foo_id INTEGER, "
- "PRIMARY KEY (id), "
- "FOREIGN KEY(foo_id) REFERENCES foo (id))"
+ "id INTEGER NOT NULL, "
+ "foo_id INTEGER, "
+ "PRIMARY KEY (id), "
+ "FOREIGN KEY(foo_id) REFERENCES foo (id))"
)
+
def test_create_table_no_pk():
context = op_fixture()
op.create_table(
@@ -652,6 +718,7 @@ def test_create_table_no_pk():
"CREATE TABLE some_table (x INTEGER, y INTEGER, z INTEGER)"
)
+
def test_create_table_two_fk():
context = op_fixture()
op.create_table(
@@ -662,38 +729,40 @@ def test_create_table_two_fk():
)
context.assert_(
"CREATE TABLE some_table ("
- "id INTEGER NOT NULL, "
- "foo_id INTEGER, "
- "foo_bar INTEGER, "
- "PRIMARY KEY (id), "
- "FOREIGN KEY(foo_id) REFERENCES foo (id), "
- "FOREIGN KEY(foo_bar) REFERENCES foo (bar))"
+ "id INTEGER NOT NULL, "
+ "foo_id INTEGER, "
+ "foo_bar INTEGER, "
+ "PRIMARY KEY (id), "
+ "FOREIGN KEY(foo_id) REFERENCES foo (id), "
+ "FOREIGN KEY(foo_bar) REFERENCES foo (bar))"
)
+
def test_inline_literal():
context = op_fixture()
from sqlalchemy.sql import table, column
from sqlalchemy import String, Integer
account = table('account',
- column('name', String),
- column('id', Integer)
- )
+ column('name', String),
+ column('id', Integer)
+ )
op.execute(
- account.update().\
- where(account.c.name == op.inline_literal('account 1')).\
- values({'name': op.inline_literal('account 2')})
- )
+ account.update().
+ where(account.c.name == op.inline_literal('account 1')).
+ values({'name': op.inline_literal('account 2')})
+ )
op.execute(
- account.update().\
- where(account.c.id == op.inline_literal(1)).\
- values({'id': op.inline_literal(2)})
- )
+ account.update().
+ where(account.c.id == op.inline_literal(1)).
+ values({'id': op.inline_literal(2)})
+ )
context.assert_(
"UPDATE account SET name='account 2' WHERE account.name = 'account 1'",
"UPDATE account SET id=2 WHERE account.id = 1"
)
+
def test_cant_op():
if hasattr(op, '_proxy'):
del op._proxy
@@ -733,4 +802,3 @@ def test_naming_changes():
r"Unknown arguments: badarg\d, badarg\d",
op.alter_column, "t", "c", badarg1="x", badarg2="y"
)
-
diff --git a/tests/test_op_naming_convention.py b/tests/test_op_naming_convention.py
index b0b5b76..3f80ecf 100644
--- a/tests/test_op_naming_convention.py
+++ b/tests/test_op_naming_convention.py
@@ -1,16 +1,17 @@
from sqlalchemy import Integer, Column, ForeignKey, \
- Table, String, Boolean, MetaData, CheckConstraint
+ Table, String, Boolean, MetaData, CheckConstraint
from sqlalchemy.sql import column, func, text
from sqlalchemy import event
from alembic import op
from . import op_fixture, assert_raises_message, requires_094
+
@requires_094
def test_add_check_constraint():
context = op_fixture(naming_convention={
- "ck": "ck_%(table_name)s_%(constraint_name)s"
- })
+ "ck": "ck_%(table_name)s_%(constraint_name)s"
+ })
op.create_check_constraint(
"foo",
"user_table",
@@ -21,11 +22,12 @@ def test_add_check_constraint():
"CHECK (len(name) > 5)"
)
+
@requires_094
def test_add_check_constraint_name_is_none():
context = op_fixture(naming_convention={
- "ck": "ck_%(table_name)s_foo"
- })
+ "ck": "ck_%(table_name)s_foo"
+ })
op.create_check_constraint(
None,
"user_table",
@@ -36,11 +38,12 @@ def test_add_check_constraint_name_is_none():
"CHECK (len(name) > 5)"
)
+
@requires_094
def test_add_unique_constraint_name_is_none():
context = op_fixture(naming_convention={
- "uq": "uq_%(table_name)s_foo"
- })
+ "uq": "uq_%(table_name)s_foo"
+ })
op.create_unique_constraint(
None,
"user_table",
@@ -54,8 +57,8 @@ def test_add_unique_constraint_name_is_none():
@requires_094
def test_add_index_name_is_none():
context = op_fixture(naming_convention={
- "ix": "ix_%(table_name)s_foo"
- })
+ "ix": "ix_%(table_name)s_foo"
+ })
op.create_index(
None,
"user_table",
@@ -66,7 +69,6 @@ def test_add_index_name_is_none():
)
-
@requires_094
def test_add_check_constraint_already_named_from_schema():
m1 = MetaData(naming_convention={"ck": "ck_%(table_name)s_%(constraint_name)s"})
@@ -74,7 +76,7 @@ def test_add_check_constraint_already_named_from_schema():
Table('t', m1, Column('x'), ck)
context = op_fixture(
- naming_convention={"ck": "ck_%(table_name)s_%(constraint_name)s"})
+ naming_convention={"ck": "ck_%(table_name)s_%(constraint_name)s"})
op.create_table(
"some_table",
@@ -85,10 +87,11 @@ def test_add_check_constraint_already_named_from_schema():
"(x INTEGER CONSTRAINT ck_t_cc1 CHECK (im a constraint))"
)
+
@requires_094
def test_add_check_constraint_inline_on_table():
context = op_fixture(
- naming_convention={"ck": "ck_%(table_name)s_%(constraint_name)s"})
+ naming_convention={"ck": "ck_%(table_name)s_%(constraint_name)s"})
op.create_table(
"some_table",
Column('x', Integer),
@@ -99,10 +102,11 @@ def test_add_check_constraint_inline_on_table():
"(x INTEGER, CONSTRAINT ck_some_table_cc1 CHECK (im a constraint))"
)
+
@requires_094
def test_add_check_constraint_inline_on_table_w_f():
context = op_fixture(
- naming_convention={"ck": "ck_%(table_name)s_%(constraint_name)s"})
+ naming_convention={"ck": "ck_%(table_name)s_%(constraint_name)s"})
op.create_table(
"some_table",
Column('x', Integer),
@@ -113,10 +117,11 @@ def test_add_check_constraint_inline_on_table_w_f():
"(x INTEGER, CONSTRAINT ck_some_table_cc1 CHECK (im a constraint))"
)
+
@requires_094
def test_add_check_constraint_inline_on_column():
context = op_fixture(
- naming_convention={"ck": "ck_%(table_name)s_%(constraint_name)s"})
+ naming_convention={"ck": "ck_%(table_name)s_%(constraint_name)s"})
op.create_table(
"some_table",
Column('x', Integer, CheckConstraint("im a constraint", name="cc1"))
@@ -126,10 +131,11 @@ def test_add_check_constraint_inline_on_column():
"(x INTEGER CONSTRAINT ck_some_table_cc1 CHECK (im a constraint))"
)
+
@requires_094
def test_add_check_constraint_inline_on_column_w_f():
context = op_fixture(
- naming_convention={"ck": "ck_%(table_name)s_%(constraint_name)s"})
+ naming_convention={"ck": "ck_%(table_name)s_%(constraint_name)s"})
op.create_table(
"some_table",
Column('x', Integer, CheckConstraint("im a constraint", name=op.f("ck_q_cc1")))
@@ -143,8 +149,8 @@ def test_add_check_constraint_inline_on_column_w_f():
@requires_094
def test_add_column_schema_type():
context = op_fixture(naming_convention={
- "ck": "ck_%(table_name)s_%(constraint_name)s"
- })
+ "ck": "ck_%(table_name)s_%(constraint_name)s"
+ })
op.add_column('t1', Column('c1', Boolean(name='foo'), nullable=False))
context.assert_(
'ALTER TABLE t1 ADD COLUMN c1 BOOLEAN NOT NULL',
@@ -155,12 +161,10 @@ def test_add_column_schema_type():
@requires_094
def test_add_column_schema_type_w_f():
context = op_fixture(naming_convention={
- "ck": "ck_%(table_name)s_%(constraint_name)s"
- })
+ "ck": "ck_%(table_name)s_%(constraint_name)s"
+ })
op.add_column('t1', Column('c1', Boolean(name=op.f('foo')), nullable=False))
context.assert_(
'ALTER TABLE t1 ADD COLUMN c1 BOOLEAN NOT NULL',
'ALTER TABLE t1 ADD CONSTRAINT foo CHECK (c1 IN (0, 1))'
)
-
-
diff --git a/tests/test_oracle.py b/tests/test_oracle.py
index d443a71..781a1ab 100644
--- a/tests/test_oracle.py
+++ b/tests/test_oracle.py
@@ -11,6 +11,7 @@ from . import op_fixture, capture_context_buffer, \
class FullEnvironmentTests(TestCase):
+
@classmethod
def setup_class(cls):
env = staging_env()
@@ -40,13 +41,14 @@ class FullEnvironmentTests(TestCase):
command.upgrade(self.cfg, self.a, sql=True)
assert "BYE" in buf.getvalue()
+
class OpTest(TestCase):
+
def test_add_column(self):
context = op_fixture('oracle')
op.add_column('t1', Column('c1', Integer, nullable=False))
context.assert_("ALTER TABLE t1 ADD c1 INTEGER NOT NULL")
-
def test_add_column_with_default(self):
context = op_fixture("oracle")
op.add_column('t1', Column('c1', Integer, nullable=False, server_default="12"))
@@ -147,10 +149,9 @@ class OpTest(TestCase):
)
# TODO: when we add schema support
- #def test_alter_column_rename_oracle_schema(self):
+ # def test_alter_column_rename_oracle_schema(self):
# context = op_fixture('oracle')
# op.alter_column("t", "c", name="x", schema="y")
# context.assert_(
# 'ALTER TABLE y.t RENAME COLUMN c TO c2'
# )
-
diff --git a/tests/test_postgresql.py b/tests/test_postgresql.py
index 2e0965e..4cd160b 100644
--- a/tests/test_postgresql.py
+++ b/tests/test_postgresql.py
@@ -1,7 +1,7 @@
from unittest import TestCase
from sqlalchemy import DateTime, MetaData, Table, Column, text, Integer, \
- String, Interval
+ String, Interval
from sqlalchemy.dialects.postgresql import ARRAY
from sqlalchemy.schema import DefaultClause
from sqlalchemy.engine.reflection import Inspector
@@ -13,10 +13,12 @@ from alembic import command, util
from alembic.migration import MigrationContext
from alembic.script import ScriptDirectory
from . import db_for_dialect, eq_, staging_env, \
- clear_staging_env, _no_sql_testing_config,\
- capture_context_buffer, requires_09, write_script
+ clear_staging_env, _no_sql_testing_config,\
+ capture_context_buffer, requires_09, write_script
+
class PGOfflineEnumTest(TestCase):
+
def setUp(self):
staging_env()
self.cfg = cfg = _no_sql_testing_config()
@@ -29,7 +31,6 @@ class PGOfflineEnumTest(TestCase):
def tearDown(self):
clear_staging_env()
-
def _inline_enum_script(self):
write_script(self.script, self.rid, """
revision = '%s'
@@ -103,6 +104,7 @@ def downgrade():
class PostgresqlInlineLiteralTest(TestCase):
+
@classmethod
def setup_class(cls):
cls.bind = db_for_dialect("postgresql")
@@ -144,7 +146,9 @@ class PostgresqlInlineLiteralTest(TestCase):
1,
)
+
class PostgresqlDefaultCompareTest(TestCase):
+
@classmethod
def setup_class(cls):
cls.bind = db_for_dialect("postgresql")
@@ -180,19 +184,19 @@ class PostgresqlDefaultCompareTest(TestCase):
alternate = orig_default
t1 = Table("test", self.metadata,
- Column("somecol", type_, server_default=orig_default))
+ Column("somecol", type_, server_default=orig_default))
t2 = Table("test", MetaData(),
- Column("somecol", type_, server_default=alternate))
+ Column("somecol", type_, server_default=alternate))
t1.create(self.bind)
insp = Inspector.from_engine(self.bind)
cols = insp.get_columns(t1.name)
insp_col = Column("somecol", cols[0]['type'],
- server_default=text(cols[0]['default']))
+ server_default=text(cols[0]['default']))
diffs = []
_compare_server_default(None, "test", "somecol", insp_col,
- t2.c.somecol, diffs, self.autogen_context)
+ t2.c.somecol, diffs, self.autogen_context)
eq_(bool(diffs), diff_expected)
def _compare_default(
@@ -284,11 +288,11 @@ class PostgresqlDefaultCompareTest(TestCase):
def test_primary_key_skip(self):
"""Test that SERIAL cols are just skipped"""
t1 = Table("sometable", self.metadata,
- Column("id", Integer, primary_key=True)
- )
+ Column("id", Integer, primary_key=True)
+ )
t2 = Table("sometable", MetaData(),
- Column("id", Integer, primary_key=True)
- )
+ Column("id", Integer, primary_key=True)
+ )
assert not self._compare_default(
t1, t2, t2.c.id, ""
)
diff --git a/tests/test_revision_create.py b/tests/test_revision_create.py
index 5bf12cf..cbe2a6e 100644
--- a/tests/test_revision_create.py
+++ b/tests/test_revision_create.py
@@ -10,7 +10,9 @@ import datetime
env, abc, def_ = None, None, None
+
class GeneralOrderedTests(unittest.TestCase):
+
def test_001_environment(self):
assert_set = set(['env.py', 'script.py.mako', 'README'])
eq_(
@@ -76,28 +78,26 @@ class GeneralOrderedTests(unittest.TestCase):
def test_008_long_name(self):
rid = util.rev_id()
env.generate_revision(rid,
- "this is a really long name with "
- "lots of characters and also "
- "I'd like it to\nhave\nnewlines")
+ "this is a really long name with "
+ "lots of characters and also "
+ "I'd like it to\nhave\nnewlines")
assert os.access(
os.path.join(env.dir, 'versions',
- '%s_this_is_a_really_long_name_with_lots_of_.py' % rid),
- os.F_OK)
-
+ '%s_this_is_a_really_long_name_with_lots_of_.py' % rid),
+ os.F_OK)
def test_009_long_name_configurable(self):
env.truncate_slug_length = 60
rid = util.rev_id()
env.generate_revision(rid,
- "this is a really long name with "
- "lots of characters and also "
- "I'd like it to\nhave\nnewlines")
+ "this is a really long name with "
+ "lots of characters and also "
+ "I'd like it to\nhave\nnewlines")
assert os.access(
os.path.join(env.dir, 'versions',
- '%s_this_is_a_really_long_name_with_lots_'
- 'of_characters_and_also_.py' % rid),
- os.F_OK)
-
+ '%s_this_is_a_really_long_name_with_lots_'
+ 'of_characters_and_also_.py' % rid),
+ os.F_OK)
@classmethod
def setup_class(cls):
@@ -108,7 +108,9 @@ class GeneralOrderedTests(unittest.TestCase):
def teardown_class(cls):
clear_staging_env()
+
class ScriptNamingTest(unittest.TestCase):
+
@classmethod
def setup_class(cls):
_testing_config()
@@ -119,12 +121,12 @@ class ScriptNamingTest(unittest.TestCase):
def test_args(self):
script = ScriptDirectory(
- staging_directory,
- file_template="%(rev)s_%(slug)s_"
- "%(year)s_%(month)s_"
- "%(day)s_%(hour)s_"
- "%(minute)s_%(second)s"
- )
+ staging_directory,
+ file_template="%(rev)s_%(slug)s_"
+ "%(year)s_%(month)s_"
+ "%(day)s_%(hour)s_"
+ "%(minute)s_%(second)s"
+ )
create_date = datetime.datetime(2012, 7, 25, 15, 8, 5)
eq_(
script._rev_path("12345", "this is a message", create_date),
@@ -134,6 +136,7 @@ class ScriptNamingTest(unittest.TestCase):
class TemplateArgsTest(unittest.TestCase):
+
def setUp(self):
staging_env()
self.cfg = _no_sql_testing_config(
@@ -153,7 +156,7 @@ class TemplateArgsTest(unittest.TestCase):
template_args=template_args
)
env.configure(dialect_name="sqlite",
- template_args={"y": "y2", "q": "q1"})
+ template_args={"y": "y2", "q": "q1"})
eq_(
template_args,
{"x": "x1", "y": "y2", "z": "z1", "q": "q1"}
@@ -206,4 +209,3 @@ down_revision = ${repr(down_revision)}
with open(rev.path) as f:
text = f.read()
assert "somearg: somevalue" in text
-
diff --git a/tests/test_revision_paths.py b/tests/test_revision_paths.py
index 5a02189..15da250 100644
--- a/tests/test_revision_paths.py
+++ b/tests/test_revision_paths.py
@@ -6,6 +6,7 @@ env = None
a, b, c, d, e = None, None, None, None, None
cfg = None
+
def setup():
global env
env = staging_env()
@@ -16,6 +17,7 @@ def setup():
d = env.generate_revision(util.rev_id(), 'c->d', refresh=True)
e = env.generate_revision(util.rev_id(), 'd->e', refresh=True)
+
def teardown():
clear_staging_env()
@@ -39,6 +41,7 @@ def test_upgrade_path():
]
)
+
def test_relative_upgrade_path():
eq_(
env._upgrade_revs("+2", a.revision),
@@ -64,6 +67,7 @@ def test_relative_upgrade_path():
]
)
+
def test_invalid_relative_upgrade_path():
assert_raises_message(
util.CommandError,
@@ -77,6 +81,7 @@ def test_invalid_relative_upgrade_path():
env._upgrade_revs, "+5", b.revision
)
+
def test_downgrade_path():
eq_(
@@ -96,6 +101,7 @@ def test_downgrade_path():
]
)
+
def test_relative_downgrade_path():
eq_(
env._downgrade_revs("-1", c.revision),
@@ -113,6 +119,7 @@ def test_relative_downgrade_path():
]
)
+
def test_invalid_relative_downgrade_path():
assert_raises_message(
util.CommandError,
@@ -126,6 +133,7 @@ def test_invalid_relative_downgrade_path():
env._downgrade_revs, "+2", b.revision
)
+
def test_invalid_move_rev_to_none():
assert_raises_message(
util.CommandError,
@@ -133,10 +141,10 @@ def test_invalid_move_rev_to_none():
env._downgrade_revs, b.revision[0:3], None
)
+
def test_invalid_move_higher_to_lower():
assert_raises_message(
- util.CommandError,
+ util.CommandError,
"Revision %s is not an ancestor of %s" % (c.revision, b.revision),
env._downgrade_revs, c.revision[0:4], b.revision
)
-
diff --git a/tests/test_sql_script.py b/tests/test_sql_script.py
index 7aae797..ba64df7 100644
--- a/tests/test_sql_script.py
+++ b/tests/test_sql_script.py
@@ -14,6 +14,7 @@ import re
cfg = None
a, b, c = None, None, None
+
class ThreeRevTest(unittest.TestCase):
def setUp(self):
@@ -32,11 +33,11 @@ class ThreeRevTest(unittest.TestCase):
with capture_context_buffer(transactional_ddl=True) as buf:
command.upgrade(cfg, c, sql=True)
assert re.match(
- (r"^BEGIN;\s+CREATE TABLE.*?%s.*" % a) +
- (r".*%s" % b) +
- (r".*%s.*?COMMIT;.*$" % c),
+ (r"^BEGIN;\s+CREATE TABLE.*?%s.*" % a) +
+ (r".*%s" % b) +
+ (r".*%s.*?COMMIT;.*$" % c),
- buf.getvalue(), re.S)
+ buf.getvalue(), re.S)
def test_begin_commit_nontransactional_ddl(self):
with capture_context_buffer(transactional_ddl=False) as buf:
@@ -48,11 +49,11 @@ class ThreeRevTest(unittest.TestCase):
with capture_context_buffer(transaction_per_migration=True) as buf:
command.upgrade(cfg, c, sql=True)
assert re.match(
- (r"^BEGIN;\s+CREATE TABLE.*%s.*?COMMIT;.*" % a) +
- (r"BEGIN;.*?%s.*?COMMIT;.*" % b) +
- (r"BEGIN;.*?%s.*?COMMIT;.*$" % c),
+ (r"^BEGIN;\s+CREATE TABLE.*%s.*?COMMIT;.*" % a) +
+ (r"BEGIN;.*?%s.*?COMMIT;.*" % b) +
+ (r"BEGIN;.*?%s.*?COMMIT;.*$" % c),
- buf.getvalue(), re.S)
+ buf.getvalue(), re.S)
def test_version_from_none_insert(self):
with capture_context_buffer() as buf:
@@ -99,6 +100,7 @@ class ThreeRevTest(unittest.TestCase):
class EncodingTest(unittest.TestCase):
+
def setUp(self):
global cfg, env, a
env = staging_env()
@@ -128,8 +130,8 @@ def downgrade():
def test_encode(self):
with capture_context_buffer(
- bytes_io=True,
- output_encoding='utf-8'
- ) as buf:
+ bytes_io=True,
+ output_encoding='utf-8'
+ ) as buf:
command.upgrade(cfg, a, sql=True)
assert "« S’il vous plaît…".encode("utf-8") in buf.getvalue()
diff --git a/tests/test_sqlite.py b/tests/test_sqlite.py
index 9ceb78e..ea9411d 100644
--- a/tests/test_sqlite.py
+++ b/tests/test_sqlite.py
@@ -1,8 +1,9 @@
from tests import op_fixture, assert_raises_message
from alembic import op
-from sqlalchemy import Integer, Column, Boolean
+from sqlalchemy import Integer, Column, Boolean
from sqlalchemy.sql import column
+
def test_add_column():
context = op_fixture('sqlite')
op.add_column('t1', Column('c1', Integer))
@@ -10,6 +11,7 @@ def test_add_column():
'ALTER TABLE t1 ADD COLUMN c1 INTEGER'
)
+
def test_add_column_implicit_constraint():
context = op_fixture('sqlite')
op.add_column('t1', Column('c1', Boolean))
@@ -17,6 +19,7 @@ def test_add_column_implicit_constraint():
'ALTER TABLE t1 ADD COLUMN c1 BOOLEAN'
)
+
def test_add_explicit_constraint():
context = op_fixture('sqlite')
assert_raises_message(
@@ -28,6 +31,7 @@ def test_add_explicit_constraint():
column('name') > 5
)
+
def test_drop_explicit_constraint():
context = op_fixture('sqlite')
assert_raises_message(
@@ -37,4 +41,3 @@ def test_drop_explicit_constraint():
"foo",
"sometable",
)
-
diff --git a/tests/test_version_table.py b/tests/test_version_table.py
index 3a0a54d..98dec50 100644
--- a/tests/test_version_table.py
+++ b/tests/test_version_table.py
@@ -8,6 +8,7 @@ from alembic.util import CommandError
version_table = Table('version_table', MetaData(),
Column('version_num', String(32), nullable=False))
+
class TestMigrationContext(unittest.TestCase):
_bind = []
diff --git a/tests/test_versioning.py b/tests/test_versioning.py
index 68440fc..7c59a12 100644
--- a/tests/test_versioning.py
+++ b/tests/test_versioning.py
@@ -7,6 +7,7 @@ from . import clear_staging_env, staging_env, \
_sqlite_testing_config, sqlite_db, eq_, write_script, \
assert_raises_message
+
class VersioningTest(unittest.TestCase):
sourceless = False
@@ -62,7 +63,6 @@ class VersioningTest(unittest.TestCase):
""" % (c, b), sourceless=self.sourceless)
-
def test_002_upgrade(self):
command.upgrade(self.cfg, c)
db = sqlite_db()
@@ -94,7 +94,6 @@ class VersioningTest(unittest.TestCase):
def test_006_upgrade_again(self):
command.upgrade(self.cfg, b)
-
# TODO: test some invalid movements
@classmethod
@@ -106,7 +105,9 @@ class VersioningTest(unittest.TestCase):
def teardown_class(cls):
clear_staging_env()
+
class VersionNameTemplateTest(unittest.TestCase):
+
def setUp(self):
self.env = staging_env()
self.cfg = _sqlite_testing_config()
@@ -188,7 +189,9 @@ class VersionNameTemplateTest(unittest.TestCase):
class SourcelessVersioningTest(VersioningTest):
sourceless = True
+
class SourcelessNeedsFlagTest(unittest.TestCase):
+
def setUp(self):
self.env = staging_env(sourceless=False)
self.cfg = _sqlite_testing_config()