-rw-r--r--  alembic/autogenerate/api.py | 25
-rw-r--r--  alembic/ddl/impl.py | 15
-rw-r--r--  alembic/operations/batch.py | 78
-rw-r--r--  alembic/operations/schemaobj.py | 12
-rw-r--r--  alembic/runtime/migration.py | 124
-rw-r--r--  alembic/script/base.py | 12
-rw-r--r--  alembic/script/revision.py | 9
-rw-r--r--  alembic/script/write_hooks.py | 21
-rw-r--r--  alembic/testing/__init__.py | 10
-rw-r--r--  alembic/testing/assertions.py | 100
-rw-r--r--  alembic/testing/env.py | 36
-rw-r--r--  alembic/testing/fixture_functions.py | 79
-rw-r--r--  alembic/testing/fixtures.py | 85
-rw-r--r--  alembic/testing/plugin/bootstrap.py | 31
-rw-r--r--  alembic/testing/plugin/plugin_base.py | 125
-rw-r--r--  alembic/testing/plugin/pytestplugin.py | 314
-rw-r--r--  alembic/testing/requirements.py | 9
-rw-r--r--  alembic/testing/util.py | 10
-rw-r--r--  alembic/testing/warnings.py | 48
-rw-r--r--  alembic/util/__init__.py | 2
-rw-r--r--  alembic/util/compat.py | 62
-rw-r--r--  alembic/util/langhelpers.py | 22
-rw-r--r--  alembic/util/pyfiles.py | 3
-rw-r--r--  alembic/util/sqla_compat.py | 52
-rw-r--r--  docs/build/autogenerate.rst | 2
-rw-r--r--  docs/build/unreleased/autocommit.rst | 21
-rw-r--r--  setup.cfg | 70
-rw-r--r--  setup.py | 45
-rwxr-xr-x  tests/conftest.py | 12
-rw-r--r--  tests/requirements.py | 6
-rw-r--r--  tests/test_autogen_indexes.py | 6
-rw-r--r--  tests/test_batch.py | 231
-rw-r--r--  tests/test_bulk_insert.py | 27
-rw-r--r--  tests/test_command.py | 4
-rw-r--r--  tests/test_environment.py | 29
-rw-r--r--  tests/test_impl.py | 45
-rw-r--r--  tests/test_mysql.py | 5
-rw-r--r--  tests/test_postgresql.py | 68
-rw-r--r--  tests/test_script_consumption.py | 237
-rw-r--r--  tests/test_sqlite.py | 4
-rw-r--r--  tests/test_version_table.py | 259
-rw-r--r--  tox.ini | 15
42 files changed, 1213 insertions, 1157 deletions
diff --git a/alembic/autogenerate/api.py b/alembic/autogenerate/api.py
index 5d1e848..db5fe12 100644
--- a/alembic/autogenerate/api.py
+++ b/alembic/autogenerate/api.py
@@ -31,22 +31,23 @@ def compare_metadata(context, metadata):
from sqlalchemy.schema import SchemaItem
from sqlalchemy.types import TypeEngine
from sqlalchemy import (create_engine, MetaData, Column,
- Integer, String, Table)
+ Integer, String, Table, text)
import pprint
engine = create_engine("sqlite://")
- engine.execute('''
- create table foo (
- id integer not null primary key,
- old_data varchar,
- x integer
- )''')
-
- engine.execute('''
- create table bar (
- data varchar
- )''')
+ with engine.begin() as conn:
+ conn.execute(text('''
+ create table foo (
+ id integer not null primary key,
+ old_data varchar,
+ x integer
+ )'''))
+
+ conn.execute(text('''
+ create table bar (
+ data varchar
+ )'''))
metadata = MetaData()
Table('foo', metadata,
diff --git a/alembic/ddl/impl.py b/alembic/ddl/impl.py
index 3674c67..923fd8b 100644
--- a/alembic/ddl/impl.py
+++ b/alembic/ddl/impl.py
@@ -140,7 +140,10 @@ class DefaultImpl(with_metaclass(ImplMeta)):
conn = self.connection
if execution_options:
conn = conn.execution_options(**execution_options)
- return conn.execute(construct, *multiparams, **params)
+ if params:
+ multiparams += (params,)
+
+ return conn.execute(construct, multiparams)
def execute(self, sql, execution_options=None):
self._exec(sql, execution_options)
@@ -316,7 +319,7 @@ class DefaultImpl(with_metaclass(ImplMeta)):
if self.as_sql:
for row in rows:
self._exec(
- table.insert(inline=True).values(
+ sqla_compat._insert_inline(table).values(
**dict(
(
k,
@@ -338,10 +341,14 @@ class DefaultImpl(with_metaclass(ImplMeta)):
table._autoincrement_column = None
if rows:
if multiinsert:
- self._exec(table.insert(inline=True), multiparams=rows)
+ self._exec(
+ sqla_compat._insert_inline(table), multiparams=rows
+ )
else:
for row in rows:
- self._exec(table.insert(inline=True).values(**row))
+ self._exec(
+ sqla_compat._insert_inline(table).values(**row)
+ )
def _tokenize_column_type(self, column):
definition = self.dialect.type_compiler.process(column.type).lower()
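
The ``_exec()`` change above follows SQLAlchemy 1.4/2.0's execution convention, where parameters are passed as a single dictionary or a single sequence of dictionaries rather than as ``*multiparams``/``**params``. A minimal sketch of that convention (illustrative only, not part of this commit)::

    from sqlalchemy import create_engine, text

    engine = create_engine("sqlite://")
    with engine.begin() as conn:
        conn.execute(text("create table t (x integer, y integer)"))
        # executemany-style: one sequence of dictionaries, not *multiparams
        conn.execute(
            text("insert into t (x, y) values (:x, :y)"),
            [{"x": 1, "y": 2}, {"x": 3, "y": 4}],
        )
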
diff --git a/alembic/operations/batch.py b/alembic/operations/batch.py
index 81e3b99..f029193 100644
--- a/alembic/operations/batch.py
+++ b/alembic/operations/batch.py
@@ -5,7 +5,6 @@ from sqlalchemy import Index
from sqlalchemy import MetaData
from sqlalchemy import PrimaryKeyConstraint
from sqlalchemy import schema as sql_schema
-from sqlalchemy import select
from sqlalchemy import Table
from sqlalchemy import types as sqltypes
from sqlalchemy.events import SchemaEventTarget
@@ -14,9 +13,12 @@ from sqlalchemy.util import topological
from ..util import exc
from ..util.sqla_compat import _columns_for_constraint
+from ..util.sqla_compat import _ensure_scope_for_ddl
from ..util.sqla_compat import _fk_is_self_referential
+from ..util.sqla_compat import _insert_inline
from ..util.sqla_compat import _is_type_bound
from ..util.sqla_compat import _remove_column_from_collection
+from ..util.sqla_compat import _select
class BatchOperationsImpl(object):
@@ -76,44 +78,44 @@ class BatchOperationsImpl(object):
def flush(self):
should_recreate = self._should_recreate()
- if not should_recreate:
- for opname, arg, kw in self.batch:
- fn = getattr(self.operations.impl, opname)
- fn(*arg, **kw)
- else:
- if self.naming_convention:
- m1 = MetaData(naming_convention=self.naming_convention)
+ with _ensure_scope_for_ddl(self.impl.connection):
+ if not should_recreate:
+ for opname, arg, kw in self.batch:
+ fn = getattr(self.operations.impl, opname)
+ fn(*arg, **kw)
else:
- m1 = MetaData()
+ if self.naming_convention:
+ m1 = MetaData(naming_convention=self.naming_convention)
+ else:
+ m1 = MetaData()
- if self.copy_from is not None:
- existing_table = self.copy_from
- reflected = False
- else:
- existing_table = Table(
- self.table_name,
- m1,
- schema=self.schema,
- autoload=True,
- autoload_with=self.operations.get_bind(),
- *self.reflect_args,
- **self.reflect_kwargs
+ if self.copy_from is not None:
+ existing_table = self.copy_from
+ reflected = False
+ else:
+ existing_table = Table(
+ self.table_name,
+ m1,
+ schema=self.schema,
+ autoload_with=self.operations.get_bind(),
+ *self.reflect_args,
+ **self.reflect_kwargs
+ )
+ reflected = True
+
+ batch_impl = ApplyBatchImpl(
+ self.impl,
+ existing_table,
+ self.table_args,
+ self.table_kwargs,
+ reflected,
+ partial_reordering=self.partial_reordering,
)
- reflected = True
-
- batch_impl = ApplyBatchImpl(
- self.impl,
- existing_table,
- self.table_args,
- self.table_kwargs,
- reflected,
- partial_reordering=self.partial_reordering,
- )
- for opname, arg, kw in self.batch:
- fn = getattr(batch_impl, opname)
- fn(*arg, **kw)
+ for opname, arg, kw in self.batch:
+ fn = getattr(batch_impl, opname)
+ fn(*arg, **kw)
- batch_impl._create(self.impl)
+ batch_impl._create(self.impl)
def alter_column(self, *arg, **kw):
self.batch.append(("alter_column", arg, kw))
@@ -362,14 +364,14 @@ class ApplyBatchImpl(object):
try:
op_impl._exec(
- self.new_table.insert(inline=True).from_select(
+ _insert_inline(self.new_table).from_select(
list(
k
for k, transfer in self.column_transfers.items()
if "expr" in transfer
),
- select(
- [
+ _select(
+ *[
transfer["expr"]
for transfer in self.column_transfers.values()
if "expr" in transfer
diff --git a/alembic/operations/schemaobj.py b/alembic/operations/schemaobj.py
index d90b5e6..5e8aa4f 100644
--- a/alembic/operations/schemaobj.py
+++ b/alembic/operations/schemaobj.py
@@ -3,6 +3,7 @@ from sqlalchemy.types import Integer
from sqlalchemy.types import NULLTYPE
from .. import util
+from ..util.compat import raise_
from ..util.compat import string_types
@@ -113,10 +114,13 @@ class SchemaObjects(object):
}
try:
const = types[type_]
- except KeyError:
- raise TypeError(
- "'type' can be one of %s"
- % ", ".join(sorted(repr(x) for x in types))
+ except KeyError as ke:
+ raise_(
+ TypeError(
+ "'type' can be one of %s"
+ % ", ".join(sorted(repr(x) for x in types))
+ ),
+ from_=ke,
)
else:
const = const(name=name)
diff --git a/alembic/runtime/migration.py b/alembic/runtime/migration.py
index 5c8590d..48bb842 100644
--- a/alembic/runtime/migration.py
+++ b/alembic/runtime/migration.py
@@ -8,7 +8,7 @@ from sqlalchemy import MetaData
from sqlalchemy import PrimaryKeyConstraint
from sqlalchemy import String
from sqlalchemy import Table
-from sqlalchemy.engine import Connection
+from sqlalchemy.engine import Engine
from sqlalchemy.engine import url as sqla_url
from sqlalchemy.engine.strategies import MockEngineStrategy
@@ -31,15 +31,18 @@ class _ProxyTransaction(object):
def rollback(self):
self._proxied_transaction.rollback()
+ self.migration_context._transaction = None
def commit(self):
self._proxied_transaction.commit()
+ self.migration_context._transaction = None
def __enter__(self):
return self
def __exit__(self, type_, value, traceback):
self._proxied_transaction.__exit__(type_, value, traceback)
+ self.migration_context._transaction = None
class MigrationContext(object):
@@ -105,8 +108,13 @@ class MigrationContext(object):
if as_sql:
self.connection = self._stdout_connection(connection)
assert self.connection is not None
+ self._in_external_transaction = False
else:
self.connection = connection
+ self._in_external_transaction = (
+ sqla_compat._get_connection_in_transaction(connection)
+ )
+
self._migrations_fn = opts.get("fn")
self.as_sql = as_sql
@@ -199,12 +207,11 @@ class MigrationContext(object):
dialect_opts = {}
if connection:
- if not isinstance(connection, Connection):
- util.warn(
+ if isinstance(connection, Engine):
+ raise util.CommandError(
"'connection' argument to configure() is expected "
"to be a sqlalchemy.engine.Connection instance, "
"got %r" % connection,
- stacklevel=3,
)
dialect = connection.dialect
@@ -268,19 +275,27 @@ class MigrationContext(object):
"""
_in_connection_transaction = self._in_connection_transaction()
- if self.impl.transactional_ddl:
- if self.as_sql:
- self.impl.emit_commit()
+ if self.impl.transactional_ddl and self.as_sql:
+ self.impl.emit_commit()
- elif _in_connection_transaction:
- assert self._transaction is not None
+ elif _in_connection_transaction:
+ assert self._transaction is not None
- self._transaction.commit()
- self._transaction = None
+ self._transaction.commit()
+ self._transaction = None
if not self.as_sql:
current_level = self.connection.get_isolation_level()
- self.connection.execution_options(isolation_level="AUTOCOMMIT")
+ base_connection = self.connection
+
+ # in 1.3 and 1.4 non-future mode, the connection gets switched
+ # out. we can use the base connection with the new mode
+ # except that it will not know it's in "autocommit" and will
+ # emit deprecation warnings when an autocommit action takes
+ # place.
+ self.connection = (
+ self.impl.connection
+ ) = base_connection.execution_options(isolation_level="AUTOCOMMIT")
try:
yield
finally:
@@ -288,13 +303,13 @@ class MigrationContext(object):
self.connection.execution_options(
isolation_level=current_level
)
+ self.connection = self.impl.connection = base_connection
- if self.impl.transactional_ddl:
- if self.as_sql:
- self.impl.emit_begin()
+ if self.impl.transactional_ddl and self.as_sql:
+ self.impl.emit_begin()
- elif _in_connection_transaction:
- self._transaction = self.bind.begin()
+ elif _in_connection_transaction:
+ self._transaction = self.connection.begin()
def begin_transaction(self, _per_migration=False):
"""Begin a logical transaction for migration operations.
@@ -337,23 +352,50 @@ class MigrationContext(object):
:meth:`.MigrationContext.autocommit_block`
"""
- transaction_now = _per_migration == self._transaction_per_migration
- if not transaction_now:
+ @contextmanager
+ def do_nothing():
+ yield
- @contextmanager
- def do_nothing():
- yield
+ if self._in_external_transaction:
+ return do_nothing()
+ if self.impl.transactional_ddl:
+ transaction_now = _per_migration == self._transaction_per_migration
+ else:
+ transaction_now = _per_migration is True
+
+ if not transaction_now:
return do_nothing()
elif not self.impl.transactional_ddl:
+ assert _per_migration
- @contextmanager
- def do_nothing():
- yield
-
- return do_nothing()
+ if self.as_sql:
+ return do_nothing()
+ else:
+ # track our own notion of a "transaction block", which must be
+ # committed when complete. Don't rely upon whether or not the
+ # SQLAlchemy connection reports as "in transaction"; this
+ # because SQLAlchemy future connection features autobegin
+ # behavior, so it may already be in a transaction from our
+ # emitting of queries like "has_version_table", etc. While we
+ # could track these operations as well, that leaves open the
+ # possibility of new operations or other things happening in
+ # the user environment that still may be triggering
+ # "autobegin".
+
+ in_transaction = self._transaction is not None
+
+ if in_transaction:
+ return do_nothing()
+ else:
+ self._transaction = (
+ sqla_compat._safe_begin_connection_transaction(
+ self.connection
+ )
+ )
+ return _ProxyTransaction(self)
elif self.as_sql:
@contextmanager
@@ -364,7 +406,9 @@ class MigrationContext(object):
return begin_commit()
else:
- self._transaction = self.bind.begin()
+ self._transaction = sqla_compat._safe_begin_connection_transaction(
+ self.connection
+ )
return _ProxyTransaction(self)
def get_current_revision(self):
@@ -439,9 +483,10 @@ class MigrationContext(object):
)
def _ensure_version_table(self, purge=False):
- self._version.create(self.connection, checkfirst=True)
- if purge:
- self.connection.execute(self._version.delete())
+ with sqla_compat._ensure_scope_for_ddl(self.connection):
+ self._version.create(self.connection, checkfirst=True)
+ if purge:
+ self.connection.execute(self._version.delete())
def _has_version_table(self):
return sqla_compat._connectable_has_table(
@@ -504,12 +549,9 @@ class MigrationContext(object):
head_maintainer = HeadMaintainer(self, heads)
- starting_in_transaction = (
- not self.as_sql and self._in_connection_transaction()
- )
-
for step in self._migrations_fn(heads, self):
with self.begin_transaction(_per_migration=True):
+
if self.as_sql and not head_maintainer.heads:
# for offline mode, include a CREATE TABLE from
# the base
@@ -535,18 +577,6 @@ class MigrationContext(object):
run_args=kw,
)
- if (
- not starting_in_transaction
- and not self.as_sql
- and not self.impl.transactional_ddl
- and self._in_connection_transaction()
- ):
- raise util.CommandError(
- 'Migration "%s" has left an uncommitted '
- "transaction opened; transactional_ddl is False so "
- "Alembic is not committing transactions" % step
- )
-
if self.as_sql and not head_maintainer.heads:
self._version.drop(self.connection)
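
The reworked ``autocommit_block()`` above temporarily runs the migration connection with ``isolation_level="AUTOCOMMIT"`` and restores the prior isolation level afterwards. A sketch of how a migration script uses it (illustrative only; assumes a PostgreSQL ENUM type named ``mood``)::

    from alembic import op

    def upgrade():
        # statements that must run outside of any transaction, such as
        # PostgreSQL's ALTER TYPE ... ADD VALUE, go inside the block
        with op.get_context().autocommit_block():
            op.execute("ALTER TYPE mood ADD VALUE 'soso'")
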
diff --git a/alembic/script/base.py b/alembic/script/base.py
index fea9e87..363895c 100644
--- a/alembic/script/base.py
+++ b/alembic/script/base.py
@@ -171,7 +171,7 @@ class ScriptDirectory(object):
"ancestor/descendant revisions along the same branch"
)
ancestor = ancestor % {"start": start, "end": end}
- compat.raise_from_cause(util.CommandError(ancestor))
+ compat.raise_(util.CommandError(ancestor), from_=rna)
except revision.MultipleHeads as mh:
if not multiple_heads:
multiple_heads = (
@@ -185,15 +185,15 @@ class ScriptDirectory(object):
"head_arg": end or mh.argument,
"heads": util.format_as_comma(mh.heads),
}
- compat.raise_from_cause(util.CommandError(multiple_heads))
+ compat.raise_(util.CommandError(multiple_heads), from_=mh)
except revision.ResolutionError as re:
if resolution is None:
resolution = "Can't locate revision identified by '%s'" % (
re.argument
)
- compat.raise_from_cause(util.CommandError(resolution))
+ compat.raise_(util.CommandError(resolution), from_=re)
except revision.RevisionError as err:
- compat.raise_from_cause(util.CommandError(err.args[0]))
+ compat.raise_(util.CommandError(err.args[0]), from_=err)
def walk_revisions(self, base="base", head="heads"):
"""Iterate through all revisions.
@@ -571,7 +571,7 @@ class ScriptDirectory(object):
try:
Script.verify_rev_id(revid)
except revision.RevisionError as err:
- compat.raise_from_cause(util.CommandError(err.args[0]))
+ compat.raise_(util.CommandError(err.args[0]), from_=err)
with self._catch_revision_errors(
multiple_heads=(
@@ -659,7 +659,7 @@ class ScriptDirectory(object):
try:
script = Script._from_path(self, path)
except revision.RevisionError as err:
- compat.raise_from_cause(util.CommandError(err.args[0]))
+ compat.raise_(util.CommandError(err.args[0]), from_=err)
if branch_labels and not script.branch_labels:
raise util.CommandError(
"Version %s specified branch_labels %s, however the "
diff --git a/alembic/script/revision.py b/alembic/script/revision.py
index 683d322..c75d1c0 100644
--- a/alembic/script/revision.py
+++ b/alembic/script/revision.py
@@ -422,9 +422,12 @@ class RevisionMap(object):
except KeyError:
try:
nonbranch_rev = self._revision_for_ident(branch_label)
- except ResolutionError:
- raise ResolutionError(
- "No such branch: '%s'" % branch_label, branch_label
+ except ResolutionError as re:
+ util.raise_(
+ ResolutionError(
+ "No such branch: '%s'" % branch_label, branch_label
+ ),
+ from_=re,
)
else:
return nonbranch_rev
diff --git a/alembic/script/write_hooks.py b/alembic/script/write_hooks.py
index 7d0843b..d6d1d38 100644
--- a/alembic/script/write_hooks.py
+++ b/alembic/script/write_hooks.py
@@ -39,9 +39,10 @@ def _invoke(name, revision, options):
"""
try:
hook = _registry[name]
- except KeyError:
- compat.raise_from_cause(
- util.CommandError("No formatter with name '%s' registered" % name)
+ except KeyError as ke:
+ compat.raise_(
+ util.CommandError("No formatter with name '%s' registered" % name),
+ from_=ke,
)
else:
return hook(revision, options)
@@ -65,12 +66,13 @@ def _run_hooks(path, hook_config):
opts["_hook_name"] = name
try:
type_ = opts["type"]
- except KeyError:
- compat.raise_from_cause(
+ except KeyError as ke:
+ compat.raise_(
util.CommandError(
"Key %s.type is required for post write hook %r"
% (name, name)
- )
+ ),
+ from_=ke,
)
else:
util.status(
@@ -89,12 +91,13 @@ def console_scripts(path, options):
try:
entrypoint_name = options["entrypoint"]
- except KeyError:
- compat.raise_from_cause(
+ except KeyError as ke:
+ compat.raise_(
util.CommandError(
"Key %s.entrypoint is required for post write hook %r"
% (options["_hook_name"], options["_hook_name"])
- )
+ ),
+ from_=ke,
)
iter_ = pkg_resources.iter_entry_points("console_scripts", entrypoint_name)
impl = next(iter_)
diff --git a/alembic/testing/__init__.py b/alembic/testing/__init__.py
index 23c0f19..5f497a6 100644
--- a/alembic/testing/__init__.py
+++ b/alembic/testing/__init__.py
@@ -1,10 +1,12 @@
from sqlalchemy.testing import config # noqa
-from sqlalchemy.testing import exclusions # noqa
from sqlalchemy.testing import emits_warning # noqa
from sqlalchemy.testing import engines # noqa
+from sqlalchemy.testing import exclusions # noqa
from sqlalchemy.testing import mock # noqa
from sqlalchemy.testing import provide_metadata # noqa
from sqlalchemy.testing import uses_deprecated # noqa
+from sqlalchemy.testing.config import combinations # noqa
+from sqlalchemy.testing.config import fixture # noqa
from sqlalchemy.testing.config import requirements as requires # noqa
from alembic import util # noqa
@@ -13,12 +15,14 @@ from .assertions import assert_raises_message # noqa
from .assertions import emits_python_deprecation_warning # noqa
from .assertions import eq_ # noqa
from .assertions import eq_ignore_whitespace # noqa
+from .assertions import expect_raises # noqa
+from .assertions import expect_raises_message # noqa
+from .assertions import expect_sqlalchemy_deprecated # noqa
+from .assertions import expect_sqlalchemy_deprecated_20 # noqa
from .assertions import is_ # noqa
from .assertions import is_false # noqa
from .assertions import is_not_ # noqa
from .assertions import is_true # noqa
from .assertions import ne_ # noqa
-from .fixture_functions import combinations # noqa
-from .fixture_functions import fixture # noqa
from .fixtures import TestBase # noqa
from .util import resolve_lambda # noqa
diff --git a/alembic/testing/assertions.py b/alembic/testing/assertions.py
index b09e09f..6d39f4c 100644
--- a/alembic/testing/assertions.py
+++ b/alembic/testing/assertions.py
@@ -1,7 +1,10 @@
from __future__ import absolute_import
+import contextlib
import re
+import sys
+from sqlalchemy import exc as sa_exc
from sqlalchemy import util
from sqlalchemy.engine import default
from sqlalchemy.testing.assertions import _expect_warnings
@@ -17,27 +20,92 @@ from ..util import sqla_compat
from ..util.compat import py3k
+def _assert_proper_exception_context(exception):
+ """assert that any exception we're catching does not have a __context__
+ without a __cause__, and that __suppress_context__ is never set.
+
+ Python 3 will report nested exceptions as "during the handling of
+ error X, error Y occurred". That's not what we want; we want
+ these exceptions in a cause chain.
+
+ """
+
+ if not util.py3k:
+ return
+
+ if (
+ exception.__context__ is not exception.__cause__
+ and not exception.__suppress_context__
+ ):
+ assert False, (
+ "Exception %r was correctly raised but did not set a cause, "
+ "within context %r as its cause."
+ % (exception, exception.__context__)
+ )
+
+
def assert_raises(except_cls, callable_, *args, **kw):
+ return _assert_raises(except_cls, callable_, args, kw, check_context=True)
+
+
+def assert_raises_context_ok(except_cls, callable_, *args, **kw):
+ return _assert_raises(except_cls, callable_, args, kw)
+
+
+def assert_raises_message(except_cls, msg, callable_, *args, **kwargs):
+ return _assert_raises(
+ except_cls, callable_, args, kwargs, msg=msg, check_context=True
+ )
+
+
+def assert_raises_message_context_ok(
+ except_cls, msg, callable_, *args, **kwargs
+):
+ return _assert_raises(except_cls, callable_, args, kwargs, msg=msg)
+
+
+def _assert_raises(
+ except_cls, callable_, args, kwargs, msg=None, check_context=False
+):
+
+ with _expect_raises(except_cls, msg, check_context) as ec:
+ callable_(*args, **kwargs)
+ return ec.error
+
+
+class _ErrorContainer(object):
+ error = None
+
+
+@contextlib.contextmanager
+def _expect_raises(except_cls, msg=None, check_context=False):
+ ec = _ErrorContainer()
+ if check_context:
+ are_we_already_in_a_traceback = sys.exc_info()[0]
try:
- callable_(*args, **kw)
+ yield ec
success = False
- except except_cls:
+ except except_cls as err:
+ ec.error = err
success = True
+ if msg is not None:
+ assert re.search(
+ msg, util.text_type(err), re.UNICODE
+ ), "%r !~ %s" % (msg, err)
+ if check_context and not are_we_already_in_a_traceback:
+ _assert_proper_exception_context(err)
+ print(util.text_type(err).encode("utf-8"))
# assert outside the block so it works for AssertionError too !
assert success, "Callable did not raise an exception"
-def assert_raises_message(except_cls, msg, callable_, *args, **kwargs):
- try:
- callable_(*args, **kwargs)
- assert False, "Callable did not raise an exception"
- except except_cls as e:
- assert re.search(msg, util.text_type(e), re.UNICODE), "%r !~ %s" % (
- msg,
- e,
- )
- print(util.text_type(e).encode("utf-8"))
+def expect_raises(except_cls, check_context=True):
+ return _expect_raises(except_cls, check_context=check_context)
+
+
+def expect_raises_message(except_cls, msg, check_context=True):
+ return _expect_raises(except_cls, msg=msg, check_context=check_context)
def eq_ignore_whitespace(a, b, msg=None):
@@ -106,3 +174,11 @@ def emits_python_deprecation_warning(*messages):
return fn(*args, **kw)
return decorate
+
+
+def expect_sqlalchemy_deprecated(*messages, **kw):
+ return _expect_warnings(sa_exc.SADeprecationWarning, messages, **kw)
+
+
+def expect_sqlalchemy_deprecated_20(*messages, **kw):
+ return _expect_warnings(sa_exc.RemovedIn20Warning, messages, **kw)
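
``expect_raises_message()`` added above is the context-manager form of ``assert_raises_message()`` and also verifies the ``__cause__`` chain. A hypothetical test (not part of this commit) combining it with the new ``configure()`` behavior from ``runtime/migration.py`` might read::

    from sqlalchemy import create_engine

    from alembic import util
    from alembic.runtime.migration import MigrationContext
    from alembic.testing import expect_raises_message

    def test_configure_rejects_engine():
        engine = create_engine("sqlite://")
        # passing an Engine rather than a Connection should now raise
        with expect_raises_message(
            util.CommandError,
            r"'connection' argument to configure\(\) is expected",
        ):
            MigrationContext.configure(engine)
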
diff --git a/alembic/testing/env.py b/alembic/testing/env.py
index 473c73e..62b74ec 100644
--- a/alembic/testing/env.py
+++ b/alembic/testing/env.py
@@ -4,9 +4,10 @@ import os
import shutil
import textwrap
-from sqlalchemy.testing import engines
+from sqlalchemy.testing import config
from sqlalchemy.testing import provision
+from . import util as testing_util
from .. import util
from ..script import Script
from ..script import ScriptDirectory
@@ -93,25 +94,28 @@ config = context.config
f.write(txt)
-def _sqlite_file_db(tempname="foo.db"):
+def _sqlite_file_db(tempname="foo.db", future=False):
dir_ = os.path.join(_get_staging_directory(), "scripts")
url = "sqlite:///%s/%s" % (dir_, tempname)
- return engines.testing_engine(url=url)
+ return testing_util.testing_engine(url=url, future=future)
-def _sqlite_testing_config(sourceless=False):
+def _sqlite_testing_config(sourceless=False, future=False):
dir_ = os.path.join(_get_staging_directory(), "scripts")
url = "sqlite:///%s/foo.db" % dir_
+ sqlalchemy_future = future or ("future" in config.db.__class__.__module__)
+
return _write_config_file(
"""
[alembic]
script_location = %s
sqlalchemy.url = %s
sourceless = %s
+%s
[loggers]
-keys = root
+keys = root,sqlalchemy
[handlers]
keys = console
@@ -121,6 +125,11 @@ level = WARN
handlers = console
qualname =
+[logger_sqlalchemy]
+level = DEBUG
+handlers =
+qualname = sqlalchemy.engine
+
[handler_console]
class = StreamHandler
args = (sys.stderr,)
@@ -134,12 +143,19 @@ keys = generic
format = %%(levelname)-5.5s [%%(name)s] %%(message)s
datefmt = %%H:%%M:%%S
"""
- % (dir_, url, "true" if sourceless else "false")
+ % (
+ dir_,
+ url,
+ "true" if sourceless else "false",
+ "sqlalchemy.future = true" if sqlalchemy_future else "",
+ )
)
def _multi_dir_testing_config(sourceless=False, extra_version_location=""):
dir_ = os.path.join(_get_staging_directory(), "scripts")
+ sqlalchemy_future = "future" in config.db.__class__.__module__
+
url = "sqlite:///%s/foo.db" % dir_
return _write_config_file(
@@ -147,6 +163,7 @@ def _multi_dir_testing_config(sourceless=False, extra_version_location=""):
[alembic]
script_location = %s
sqlalchemy.url = %s
+sqlalchemy.future = %s
sourceless = %s
version_locations = %%(here)s/model1/ %%(here)s/model2/ %%(here)s/model3/ %s
@@ -177,6 +194,7 @@ datefmt = %%H:%%M:%%S
% (
dir_,
url,
+ "true" if sqlalchemy_future else "false",
"true" if sourceless else "false",
extra_version_location,
)
@@ -463,6 +481,8 @@ def _multidb_testing_config(engines):
dir_ = os.path.join(_get_staging_directory(), "scripts")
+ sqlalchemy_future = "future" in config.db.__class__.__module__
+
databases = ", ".join(engines.keys())
engines = "\n\n".join(
"[%s]\n" "sqlalchemy.url = %s" % (key, value.url)
@@ -474,7 +494,7 @@ def _multidb_testing_config(engines):
[alembic]
script_location = %s
sourceless = false
-
+sqlalchemy.future = %s
databases = %s
%s
@@ -502,5 +522,5 @@ keys = generic
format = %%(levelname)-5.5s [%%(name)s] %%(message)s
datefmt = %%H:%%M:%%S
"""
- % (dir_, databases, engines)
+ % (dir_, "true" if sqlalchemy_future else "false", databases, engines)
)
diff --git a/alembic/testing/fixture_functions.py b/alembic/testing/fixture_functions.py
deleted file mode 100644
index 2640693..0000000
--- a/alembic/testing/fixture_functions.py
+++ /dev/null
@@ -1,79 +0,0 @@
-_fixture_functions = None # installed by plugin_base
-
-
-def combinations(*comb, **kw):
- r"""Deliver multiple versions of a test based on positional combinations.
-
- This is a facade over pytest.mark.parametrize.
-
-
- :param \*comb: argument combinations. These are tuples that will be passed
- positionally to the decorated function.
-
- :param argnames: optional list of argument names. These are the names
- of the arguments in the test function that correspond to the entries
- in each argument tuple. pytest.mark.parametrize requires this, however
- the combinations function will derive it automatically if not present
- by using ``inspect.getfullargspec(fn).args[1:]``. Note this assumes the
- first argument is "self" which is discarded.
-
- :param id\_: optional id template. This is a string template that
- describes how the "id" for each parameter set should be defined, if any.
- The number of characters in the template should match the number of
- entries in each argument tuple. Each character describes how the
- corresponding entry in the argument tuple should be handled, as far as
- whether or not it is included in the arguments passed to the function, as
- well as if it is included in the tokens used to create the id of the
- parameter set.
-
- If omitted, the argument combinations are passed to parametrize as is. If
- passed, each argument combination is turned into a pytest.param() object,
- mapping the elements of the argument tuple to produce an id based on a
- character value in the same position within the string template using the
- following scheme::
-
- i - the given argument is a string that is part of the id only, don't
- pass it as an argument
-
- n - the given argument should be passed and it should be added to the
- id by calling the .__name__ attribute
-
- r - the given argument should be passed and it should be added to the
- id by calling repr()
-
- s - the given argument should be passed and it should be added to the
- id by calling str()
-
- a - (argument) the given argument should be passed and it should not
- be used to generated the id
-
- e.g.::
-
- @testing.combinations(
- (operator.eq, "eq"),
- (operator.ne, "ne"),
- (operator.gt, "gt"),
- (operator.lt, "lt"),
- id_="na"
- )
- def test_operator(self, opfunc, name):
- pass
-
- The above combination will call ``.__name__`` on the first member of
- each tuple and use that as the "id" to pytest.param().
-
-
- """
- return _fixture_functions.combinations(*comb, **kw)
-
-
-def fixture(*arg, **kw):
- return _fixture_functions.fixture(*arg, **kw)
-
-
-def get_current_test_name():
- return _fixture_functions.get_current_test_name()
-
-
-def skip_test(msg):
- raise _fixture_functions.skip_test_exception(msg)
diff --git a/alembic/testing/fixtures.py b/alembic/testing/fixtures.py
index dd1d2f1..d5d45ac 100644
--- a/alembic/testing/fixtures.py
+++ b/alembic/testing/fixtures.py
@@ -8,11 +8,13 @@ from sqlalchemy import inspect
from sqlalchemy import MetaData
from sqlalchemy import String
from sqlalchemy import Table
+from sqlalchemy import testing
from sqlalchemy import text
from sqlalchemy.testing import config
from sqlalchemy.testing import mock
from sqlalchemy.testing.assertions import eq_
-from sqlalchemy.testing.fixtures import TestBase # noqa
+from sqlalchemy.testing.fixtures import TablesTest as SQLAlchemyTablesTest
+from sqlalchemy.testing.fixtures import TestBase as SQLAlchemyTestBase
import alembic
from .assertions import _get_dialect
@@ -20,16 +22,53 @@ from ..environment import EnvironmentContext
from ..migration import MigrationContext
from ..operations import Operations
from ..util import compat
+from ..util import sqla_compat
from ..util.compat import configparser
from ..util.compat import string_types
from ..util.compat import text_type
from ..util.sqla_compat import create_mock_engine
from ..util.sqla_compat import sqla_14
+
testing_config = configparser.ConfigParser()
testing_config.read(["test.cfg"])
+class TestBase(SQLAlchemyTestBase):
+ is_sqlalchemy_future = False
+
+ @testing.fixture()
+ def ops_context(self, migration_context):
+ with migration_context.begin_transaction(_per_migration=True):
+ yield Operations(migration_context)
+
+ @testing.fixture
+ def migration_context(self, connection):
+ return MigrationContext.configure(
+ connection, opts=dict(transaction_per_migration=True)
+ )
+
+ @testing.fixture
+ def connection(self):
+ with config.db.connect() as conn:
+ yield conn
+
+
+class TablesTest(TestBase, SQLAlchemyTablesTest):
+ pass
+
+
+if sqla_14:
+ from sqlalchemy.testing.fixtures import FutureEngineMixin
+else:
+
+ class FutureEngineMixin(object):
+ __requires__ = ("sqlalchemy_14",)
+
+
+FutureEngineMixin.is_sqlalchemy_future = True
+
+
def capture_db(dialect="postgresql://"):
buf = []
@@ -205,7 +244,8 @@ class AlterColRoundTripFixture(object):
), "server defaults %r and %r didn't compare as equivalent" % (s1, s2)
def tearDown(self):
- self.metadata.drop_all(self.conn)
+ with self.conn.begin():
+ self.metadata.drop_all(self.conn)
self.conn.close()
def _run_alter_col(self, from_, to_, compare=None):
@@ -218,26 +258,27 @@ class AlterColRoundTripFixture(object):
)
t = Table("x", self.metadata, column)
- t.create(self.conn)
- insp = inspect(self.conn)
- old_col = insp.get_columns("x")[0]
-
- # TODO: conditional comment support
- self.op.alter_column(
- "x",
- column.name,
- existing_type=column.type,
- existing_server_default=column.server_default
- if column.server_default is not None
- else False,
- existing_nullable=True if column.nullable else False,
- # existing_comment=column.comment,
- nullable=to_.get("nullable", None),
- # modify_comment=False,
- server_default=to_.get("server_default", False),
- new_column_name=to_.get("name", None),
- type_=to_.get("type", None),
- )
+ with sqla_compat._ensure_scope_for_ddl(self.conn):
+ t.create(self.conn)
+ insp = inspect(self.conn)
+ old_col = insp.get_columns("x")[0]
+
+ # TODO: conditional comment support
+ self.op.alter_column(
+ "x",
+ column.name,
+ existing_type=column.type,
+ existing_server_default=column.server_default
+ if column.server_default is not None
+ else False,
+ existing_nullable=True if column.nullable else False,
+ # existing_comment=column.comment,
+ nullable=to_.get("nullable", None),
+ # modify_comment=False,
+ server_default=to_.get("server_default", False),
+ new_column_name=to_.get("name", None),
+ type_=to_.get("type", None),
+ )
insp = inspect(self.conn)
new_col = insp.get_columns("x")[0]
diff --git a/alembic/testing/plugin/bootstrap.py b/alembic/testing/plugin/bootstrap.py
index 8200ec1..d4a2c55 100644
--- a/alembic/testing/plugin/bootstrap.py
+++ b/alembic/testing/plugin/bootstrap.py
@@ -1,35 +1,4 @@
"""
Bootstrapper for test framework plugins.
-This is vendored from SQLAlchemy so that we can use local overrides
-for plugin_base.py and pytestplugin.py.
-
"""
-
-
-import os
-import sys
-
-
-bootstrap_file = locals()["bootstrap_file"]
-to_bootstrap = locals()["to_bootstrap"]
-
-
-def load_file_as_module(name):
- path = os.path.join(os.path.dirname(bootstrap_file), "%s.py" % name)
- if sys.version_info >= (3, 3):
- from importlib import machinery
-
- mod = machinery.SourceFileLoader(name, path).load_module()
- else:
- import imp
-
- mod = imp.load_source(name, path)
- return mod
-
-
-if to_bootstrap == "pytest":
- sys.modules["sqla_plugin_base"] = load_file_as_module("plugin_base")
- sys.modules["sqla_pytestplugin"] = load_file_as_module("pytestplugin")
-else:
- raise Exception("unknown bootstrap: %s" % to_bootstrap) # noqa
diff --git a/alembic/testing/plugin/plugin_base.py b/alembic/testing/plugin/plugin_base.py
deleted file mode 100644
index 2d5e95a..0000000
--- a/alembic/testing/plugin/plugin_base.py
+++ /dev/null
@@ -1,125 +0,0 @@
-"""vendored plugin_base functions from the most recent SQLAlchemy versions.
-
-Alembic tests need to run on older versions of SQLAlchemy that don't
-necessarily have all the latest testing fixtures.
-
-"""
-from __future__ import absolute_import
-
-import abc
-import sys
-
-from sqlalchemy.testing.plugin.plugin_base import * # noqa
-from sqlalchemy.testing.plugin.plugin_base import post
-from sqlalchemy.testing.plugin.plugin_base import post_begin as sqla_post_begin
-from sqlalchemy.testing.plugin.plugin_base import stop_test_class as sqla_stc
-
-py3k = sys.version_info >= (3, 0)
-
-
-if py3k:
-
- ABC = abc.ABC
-else:
-
- class ABC(object):
- __metaclass__ = abc.ABCMeta
-
-
-def post_begin():
- sqla_post_begin()
-
- import warnings
-
- try:
- import pytest
- except ImportError:
- pass
- else:
- warnings.filterwarnings(
- "once", category=pytest.PytestDeprecationWarning
- )
-
- from sqlalchemy import exc
-
- if hasattr(exc, "RemovedIn20Warning"):
- warnings.filterwarnings(
- "error",
- category=exc.RemovedIn20Warning,
- message=".*Engine.execute",
- )
- warnings.filterwarnings(
- "error",
- category=exc.RemovedIn20Warning,
- message=".*Passing a string",
- )
-
-
-# override selected SQLAlchemy pytest hooks with vendored functionality
-def stop_test_class(cls):
- sqla_stc(cls)
- import os
- from alembic.testing.env import _get_staging_directory
-
- assert not os.path.exists(_get_staging_directory()), (
- "staging directory %s was not cleaned up" % _get_staging_directory()
- )
-
-
-def want_class(name, cls):
- from sqlalchemy.testing import config
- from sqlalchemy.testing import fixtures
-
- if not issubclass(cls, fixtures.TestBase):
- return False
- elif name.startswith("_"):
- return False
- elif (
- config.options.backend_only
- and not getattr(cls, "__backend__", False)
- and not getattr(cls, "__sparse_backend__", False)
- ):
- return False
- else:
- return True
-
-
-@post
-def _init_symbols(options, file_config):
- from sqlalchemy.testing import config
- from alembic.testing import fixture_functions as alembic_config
-
- config._fixture_functions = (
- alembic_config._fixture_functions
- ) = _fixture_fn_class()
-
-
-class FixtureFunctions(ABC):
- @abc.abstractmethod
- def skip_test_exception(self, *arg, **kw):
- raise NotImplementedError()
-
- @abc.abstractmethod
- def combinations(self, *args, **kw):
- raise NotImplementedError()
-
- @abc.abstractmethod
- def param_ident(self, *args, **kw):
- raise NotImplementedError()
-
- @abc.abstractmethod
- def fixture(self, *arg, **kw):
- raise NotImplementedError()
-
- def get_current_test_name(self):
- raise NotImplementedError()
-
-
-_fixture_fn_class = None
-
-
-def set_fixture_functions(fixture_fn_class):
- from sqlalchemy.testing.plugin import plugin_base
-
- global _fixture_fn_class
- _fixture_fn_class = plugin_base._fixture_fn_class = fixture_fn_class
diff --git a/alembic/testing/plugin/pytestplugin.py b/alembic/testing/plugin/pytestplugin.py
deleted file mode 100644
index 6b76a17..0000000
--- a/alembic/testing/plugin/pytestplugin.py
+++ /dev/null
@@ -1,314 +0,0 @@
-"""vendored pytestplugin functions from the most recent SQLAlchemy versions.
-
-Alembic tests need to run on older versions of SQLAlchemy that don't
-necessarily have all the latest testing fixtures.
-
-"""
-try:
- # installed by bootstrap.py
- import sqla_plugin_base as plugin_base
-except ImportError:
- # assume we're a package, use traditional import
- from . import plugin_base
-
-from functools import update_wrapper
-import inspect
-import itertools
-import operator
-import os
-import re
-import sys
-
-import pytest
-from sqlalchemy.testing.plugin.pytestplugin import * # noqa
-from sqlalchemy.testing.plugin.pytestplugin import pytest_configure as spc
-
-py3k = sys.version_info.major >= 3
-
-if py3k:
- from typing import TYPE_CHECKING
-else:
- TYPE_CHECKING = False
-
-if TYPE_CHECKING:
- from typing import Sequence
-
-
-# override selected SQLAlchemy pytest hooks with vendored functionality
-def pytest_configure(config):
- spc(config)
-
- plugin_base.set_fixture_functions(PytestFixtureFunctions)
-
-
-def pytest_pycollect_makeitem(collector, name, obj):
-
- if inspect.isclass(obj) and plugin_base.want_class(name, obj):
- ctor = getattr(pytest.Class, "from_parent", pytest.Class)
-
- return [
- ctor(name=parametrize_cls.__name__, parent=collector)
- for parametrize_cls in _parametrize_cls(collector.module, obj)
- ]
- elif (
- inspect.isfunction(obj)
- and isinstance(collector, pytest.Instance)
- and plugin_base.want_method(collector.cls, obj)
- ):
- # None means, fall back to default logic, which includes
- # method-level parametrize
- return None
- else:
- # empty list means skip this item
- return []
-
-
-_current_class = None
-
-
-def _parametrize_cls(module, cls):
- """implement a class-based version of pytest parametrize."""
-
- if "_sa_parametrize" not in cls.__dict__:
- return [cls]
-
- _sa_parametrize = cls._sa_parametrize
- classes = []
- for full_param_set in itertools.product(
- *[params for argname, params in _sa_parametrize]
- ):
- cls_variables = {}
-
- for argname, param in zip(
- [_sa_param[0] for _sa_param in _sa_parametrize], full_param_set
- ):
- if not argname:
- raise TypeError("need argnames for class-based combinations")
- argname_split = re.split(r",\s*", argname)
- for arg, val in zip(argname_split, param.values):
- cls_variables[arg] = val
- parametrized_name = "_".join(
- # token is a string, but in py2k py.test is giving us a unicode,
- # so call str() on it.
- str(re.sub(r"\W", "", token))
- for param in full_param_set
- for token in param.id.split("-")
- )
- name = "%s_%s" % (cls.__name__, parametrized_name)
- newcls = type.__new__(type, name, (cls,), cls_variables)
- setattr(module, name, newcls)
- classes.append(newcls)
- return classes
-
-
-def getargspec(fn):
- if sys.version_info.major == 3:
- return inspect.getfullargspec(fn)
- else:
- return inspect.getargspec(fn)
-
-
-def _pytest_fn_decorator(target):
- """Port of langhelpers.decorator with pytest-specific tricks."""
- # from sqlalchemy rel_1_3_14
-
- from sqlalchemy.util.langhelpers import format_argspec_plus
- from sqlalchemy.util.compat import inspect_getfullargspec
-
- def _exec_code_in_env(code, env, fn_name):
- exec(code, env)
- return env[fn_name]
-
- def decorate(fn, add_positional_parameters=()):
-
- spec = inspect_getfullargspec(fn)
- if add_positional_parameters:
- spec.args.extend(add_positional_parameters)
-
- metadata = dict(target="target", fn="__fn", name=fn.__name__)
- metadata.update(format_argspec_plus(spec, grouped=False))
- code = (
- """\
-def %(name)s(%(args)s):
- return %(target)s(%(fn)s, %(apply_kw)s)
-"""
- % metadata
- )
- decorated = _exec_code_in_env(
- code, {"target": target, "__fn": fn}, fn.__name__
- )
- if not add_positional_parameters:
- decorated.__defaults__ = getattr(fn, "__func__", fn).__defaults__
- decorated.__wrapped__ = fn
- return update_wrapper(decorated, fn)
- else:
- # this is the pytest hacky part. don't do a full update wrapper
- # because pytest is really being sneaky about finding the args
- # for the wrapped function
- decorated.__module__ = fn.__module__
- decorated.__name__ = fn.__name__
- return decorated
-
- return decorate
-
-
-class PytestFixtureFunctions(plugin_base.FixtureFunctions):
- def skip_test_exception(self, *arg, **kw):
- return pytest.skip.Exception(*arg, **kw)
-
- _combination_id_fns = {
- "i": lambda obj: obj,
- "r": repr,
- "s": str,
- "n": operator.attrgetter("__name__"),
- }
-
- def combinations(self, *arg_sets, **kw):
- """Facade for pytest.mark.parametrize.
-
- Automatically derives argument names from the callable which in our
- case is always a method on a class with positional arguments.
-
- ids for parameter sets are derived using an optional template.
-
- """
- # from sqlalchemy rel_1_3_14
- from alembic.testing import exclusions
-
- if sys.version_info.major == 3:
- if len(arg_sets) == 1 and hasattr(arg_sets[0], "__next__"):
- arg_sets = list(arg_sets[0])
- else:
- if len(arg_sets) == 1 and hasattr(arg_sets[0], "next"):
- arg_sets = list(arg_sets[0])
-
- argnames = kw.pop("argnames", None)
-
- def _filter_exclusions(args):
- result = []
- gathered_exclusions = []
- for a in args:
- if isinstance(a, exclusions.compound):
- gathered_exclusions.append(a)
- else:
- result.append(a)
-
- return result, gathered_exclusions
-
- id_ = kw.pop("id_", None)
-
- tobuild_pytest_params = []
- has_exclusions = False
- if id_:
- _combination_id_fns = self._combination_id_fns
-
- # because itemgetter is not consistent for one argument vs.
- # multiple, make it multiple in all cases and use a slice
- # to omit the first argument
- _arg_getter = operator.itemgetter(
- 0,
- *[
- idx
- for idx, char in enumerate(id_)
- if char in ("n", "r", "s", "a")
- ]
- )
- fns = [
- (operator.itemgetter(idx), _combination_id_fns[char])
- for idx, char in enumerate(id_)
- if char in _combination_id_fns
- ]
-
- for arg in arg_sets:
- if not isinstance(arg, tuple):
- arg = (arg,)
-
- fn_params, param_exclusions = _filter_exclusions(arg)
-
- parameters = _arg_getter(fn_params)[1:]
-
- if param_exclusions:
- has_exclusions = True
-
- tobuild_pytest_params.append(
- (
- parameters,
- param_exclusions,
- "-".join(
- comb_fn(getter(arg)) for getter, comb_fn in fns
- ),
- )
- )
-
- else:
-
- for arg in arg_sets:
- if not isinstance(arg, tuple):
- arg = (arg,)
-
- fn_params, param_exclusions = _filter_exclusions(arg)
-
- if param_exclusions:
- has_exclusions = True
-
- tobuild_pytest_params.append(
- (fn_params, param_exclusions, None)
- )
-
- pytest_params = []
- for parameters, param_exclusions, id_ in tobuild_pytest_params:
- if has_exclusions:
- parameters += (param_exclusions,)
-
- param = pytest.param(*parameters, id=id_)
- pytest_params.append(param)
-
- def decorate(fn):
- if inspect.isclass(fn):
- if has_exclusions:
- raise NotImplementedError(
- "exclusions not supported for class level combinations"
- )
- if "_sa_parametrize" not in fn.__dict__:
- fn._sa_parametrize = []
- fn._sa_parametrize.append((argnames, pytest_params))
- return fn
- else:
- if argnames is None:
- _argnames = getargspec(fn).args[1:] # type: Sequence(str)
- else:
- _argnames = re.split(
- r", *", argnames
- ) # type: Sequence(str)
-
- if has_exclusions:
- _argnames += ["_exclusions"]
-
- @_pytest_fn_decorator
- def check_exclusions(fn, *args, **kw):
- _exclusions = args[-1]
- if _exclusions:
- exlu = exclusions.compound().add(*_exclusions)
- fn = exlu(fn)
- return fn(*args[0:-1], **kw)
-
- def process_metadata(spec):
- spec.args.append("_exclusions")
-
- fn = check_exclusions(
- fn, add_positional_parameters=("_exclusions",)
- )
-
- return pytest.mark.parametrize(_argnames, pytest_params)(fn)
-
- return decorate
-
- def param_ident(self, *parameters):
- ident = parameters[0]
- return pytest.param(*parameters[1:], id=ident)
-
- def fixture(self, *arg, **kw):
- return pytest.fixture(*arg, **kw)
-
- def get_current_test_name(self):
- return os.environ.get("PYTEST_CURRENT_TEST")
diff --git a/alembic/testing/requirements.py b/alembic/testing/requirements.py
index 456002d..2de8a71 100644
--- a/alembic/testing/requirements.py
+++ b/alembic/testing/requirements.py
@@ -44,6 +44,15 @@ class SuiteRequirements(Requirements):
return exclusions.skip_if(doesnt_have_check_uq_constraints)
@property
+ def sequences(self):
+ """Target database must support SEQUENCEs."""
+
+ return exclusions.only_if(
+ [lambda config: config.db.dialect.supports_sequences],
+ "no sequence support",
+ )
+
+ @property
def foreign_key_match(self):
return exclusions.open()
diff --git a/alembic/testing/util.py b/alembic/testing/util.py
index 3e76645..ccabf9c 100644
--- a/alembic/testing/util.py
+++ b/alembic/testing/util.py
@@ -95,3 +95,13 @@ def metadata_fixture(ddl="function"):
return fixture_functions.fixture(scope=ddl)(run_ddl)
return decorate
+
+
+def testing_engine(url=None, options=None, future=False):
+ from sqlalchemy.testing import config
+ from sqlalchemy.testing.engines import testing_engine
+
+ if not future:
+ future = getattr(config._current.options, "future_engine", False)
+ kw = {"future": future} if future else {}
+ return testing_engine(url, options, **kw)
diff --git a/alembic/testing/warnings.py b/alembic/testing/warnings.py
new file mode 100644
index 0000000..0182032
--- /dev/null
+++ b/alembic/testing/warnings.py
@@ -0,0 +1,48 @@
+# testing/warnings.py
+# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: http://www.opensource.org/licenses/mit-license.php
+
+from __future__ import absolute_import
+
+import warnings
+
+from sqlalchemy import exc as sa_exc
+
+
+def setup_filters():
+ """Set global warning behavior for the test suite."""
+
+ warnings.resetwarnings()
+
+ warnings.filterwarnings("error", category=sa_exc.SADeprecationWarning)
+ warnings.filterwarnings("error", category=sa_exc.SAWarning)
+
+ # some selected deprecations...
+ warnings.filterwarnings("error", category=DeprecationWarning)
+ try:
+ import pytest
+ except ImportError:
+ pass
+ else:
+ warnings.filterwarnings(
+ "once", category=pytest.PytestDeprecationWarning
+ )
+
+ if hasattr(sa_exc, "RemovedIn20Warning"):
+ for msg in [
+ #
+ # Core execution - need to remove this after SQLAlchemy
+ # repairs it in provisioning
+ #
+ r"The connection.execute\(\) method in SQLAlchemy 2.0 will accept "
+ "parameters as a single dictionary or a single sequence of "
+ "dictionaries only.",
+ ]:
+ warnings.filterwarnings(
+ "ignore",
+ message=msg,
+ category=sa_exc.RemovedIn20Warning,
+ )
diff --git a/alembic/util/__init__.py b/alembic/util/__init__.py
index 141ba45..cfdab49 100644
--- a/alembic/util/__init__.py
+++ b/alembic/util/__init__.py
@@ -1,4 +1,4 @@
-from .compat import raise_from_cause # noqa
+from .compat import raise_ # noqa
from .exc import CommandError
from .langhelpers import _with_legacy_names # noqa
from .langhelpers import asbool # noqa
diff --git a/alembic/util/compat.py b/alembic/util/compat.py
index f5a04ef..c8919b6 100644
--- a/alembic/util/compat.py
+++ b/alembic/util/compat.py
@@ -261,34 +261,54 @@ def with_metaclass(meta, base=object):
if py3k:
- def reraise(tp, value, tb=None, cause=None):
- if cause is not None:
- value.__cause__ = cause
- if value.__traceback__ is not tb:
- raise value.with_traceback(tb)
- raise value
+ def raise_(
+ exception, with_traceback=None, replace_context=None, from_=False
+ ):
+ r"""implement "raise" with cause support.
+
+ :param exception: exception to raise
+ :param with_traceback: will call exception.with_traceback()
+ :param replace_context: an as-yet-unsupported feature. This is
+ an exception object which we are "replacing", e.g., it's our
+ "cause" but we don't want it printed. Basically just what
+ ``__suppress_context__`` does but we don't want to suppress
+ the enclosing context, if any. So for now we make it the
+ cause.
+ :param from\_: the cause. this actually sets the cause and doesn't
+ hope to hide it someday.
- def raise_from_cause(exception, exc_info=None):
- if exc_info is None:
- exc_info = sys.exc_info()
- exc_type, exc_value, exc_tb = exc_info
- reraise(type(exception), exception, tb=exc_tb, cause=exc_value)
+ """
+ if with_traceback is not None:
+ exception = exception.with_traceback(with_traceback)
+
+ if from_ is not False:
+ exception.__cause__ = from_
+ elif replace_context is not None:
+ # no good solution here, we would like to have the exception
+ # have only the context of replace_context.__context__ so that the
+ # intermediary exception does not change, but we can't figure
+ # that out.
+ exception.__cause__ = replace_context
+
+ try:
+ raise exception
+ finally:
+ # credit to
+ # https://cosmicpercolator.com/2016/01/13/exception-leaks-in-python-2-and-3/
+ # as the __traceback__ object creates a cycle
+ del exception, replace_context, from_, with_traceback
else:
exec(
- "def reraise(tp, value, tb=None, cause=None):\n"
- " raise tp, value, tb\n"
+ "def raise_(exception, with_traceback=None, replace_context=None, "
+ "from_=False):\n"
+ " if with_traceback:\n"
+ " raise type(exception), exception, with_traceback\n"
+ " else:\n"
+ " raise exception\n"
)
- def raise_from_cause(exception, exc_info=None):
- # not as nice as that of Py3K, but at least preserves
- # the code line where the issue occurred
- if exc_info is None:
- exc_info = sys.exc_info()
- exc_type, exc_value, exc_tb = exc_info
- reraise(type(exception), exception, tb=exc_tb)
-
# produce a wrapper that allows encoded text to stream
# into a given buffer, but doesn't close it.
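
``raise_()`` replaces ``raise_from_cause()`` throughout: instead of pulling the cause out of ``sys.exc_info()``, the caller names it explicitly with ``from_=``, mirroring Python 3's ``raise ... from ...``. A small sketch of the calling pattern (the ``load_setting`` function is hypothetical)::

    from alembic.util.compat import raise_
    from alembic.util.exc import CommandError

    def load_setting(config, key):
        try:
            return config[key]
        except KeyError as ke:
            # re-raise as a CommandError, keeping the KeyError as __cause__
            # so Python 3 prints a proper "direct cause" chain
            raise_(CommandError("missing required setting %r" % key), from_=ke)
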
diff --git a/alembic/util/langhelpers.py b/alembic/util/langhelpers.py
index bb9c8f5..cc07f4b 100644
--- a/alembic/util/langhelpers.py
+++ b/alembic/util/langhelpers.py
@@ -7,6 +7,7 @@ from .compat import callable
from .compat import collections_abc
from .compat import exec_
from .compat import inspect_getargspec
+from .compat import raise_
from .compat import string_types
from .compat import with_metaclass
@@ -74,13 +75,16 @@ class ModuleClsProxy(with_metaclass(_ModuleClsMeta)):
def _create_method_proxy(cls, name, globals_, locals_):
fn = getattr(cls, name)
- def _name_error(name):
- raise NameError(
- "Can't invoke function '%s', as the proxy object has "
- "not yet been "
- "established for the Alembic '%s' class. "
- "Try placing this code inside a callable."
- % (name, cls.__name__)
+ def _name_error(name, from_):
+ raise_(
+ NameError(
+ "Can't invoke function '%s', as the proxy object has "
+ "not yet been "
+ "established for the Alembic '%s' class. "
+ "Try placing this code inside a callable."
+ % (name, cls.__name__)
+ ),
+ from_=from_,
)
globals_["_name_error"] = _name_error
@@ -142,8 +146,8 @@ class ModuleClsProxy(with_metaclass(_ModuleClsMeta)):
%(translate)s
try:
p = _proxy
- except NameError:
- _name_error('%(name)s')
+ except NameError as ne:
+ _name_error('%(name)s', ne)
return _proxy.%(name)s(%(apply_kw)s)
e
"""
diff --git a/alembic/util/pyfiles.py b/alembic/util/pyfiles.py
index b65df2c..8dfb9db 100644
--- a/alembic/util/pyfiles.py
+++ b/alembic/util/pyfiles.py
@@ -10,6 +10,7 @@ from .compat import has_pep3147
from .compat import load_module_py
from .compat import load_module_pyc
from .compat import py3k
+from .compat import raise_
from .exc import CommandError
@@ -82,7 +83,7 @@ def edit(path):
try:
editor.edit(path)
except Exception as exc:
- raise CommandError("Error executing editor (%s)" % (exc,))
+ raise_(CommandError("Error executing editor (%s)" % (exc,)), from_=exc)
def load_python_file(dir_, filename):
diff --git a/alembic/util/sqla_compat.py b/alembic/util/sqla_compat.py
index 159d0f0..29c2519 100644
--- a/alembic/util/sqla_compat.py
+++ b/alembic/util/sqla_compat.py
@@ -1,3 +1,4 @@
+import contextlib
import re
from sqlalchemy import __version__
@@ -64,6 +65,46 @@ else:
AUTOINCREMENT_DEFAULT = "auto"
+@contextlib.contextmanager
+def _ensure_scope_for_ddl(connection):
+ try:
+ in_transaction = connection.in_transaction
+ except AttributeError:
+ # catch for MockConnection
+ yield
+ else:
+ if not in_transaction():
+ with connection.begin():
+ yield
+ else:
+ yield
+
+
+def _safe_begin_connection_transaction(connection):
+ transaction = _get_connection_transaction(connection)
+ if transaction:
+ return transaction
+ else:
+ return connection.begin()
+
+
+def _get_connection_in_transaction(connection):
+ try:
+ in_transaction = connection.in_transaction
+ except AttributeError:
+ # catch for MockConnection
+ return False
+ else:
+ return in_transaction()
+
+
+def _get_connection_transaction(connection):
+ if sqla_14:
+ return connection.get_transaction()
+ else:
+ return connection._Connection__transaction
+
+
def _create_url(*arg, **kw):
if hasattr(url.URL, "create"):
return url.URL.create(*arg, **kw)
@@ -314,8 +355,16 @@ def _mariadb_normalized_version_info(mysql_dialect):
return mysql_dialect._mariadb_normalized_version_info
+def _insert_inline(table):
+ if sqla_14:
+ return table.insert().inline()
+ else:
+ return table.insert(inline=True)
+
+
if sqla_14:
from sqlalchemy import create_mock_engine
+ from sqlalchemy import select as _select
else:
from sqlalchemy import create_engine
@@ -323,3 +372,6 @@ else:
return create_engine(
"postgresql://", strategy="mock", executor=executor
)
+
+ def _select(*columns):
+ return sql.select(list(columns))
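
``_ensure_scope_for_ddl()`` begins a transaction only when the given connection is not already inside one, which is how callers such as ``_ensure_version_table()`` and the batch ``flush()`` above keep DDL inside an explicit scope on SQLAlchemy 1.4 connections. A standalone sketch of the pattern (illustrative only, not part of this commit)::

    from sqlalchemy import Column, Integer, MetaData, Table, create_engine

    from alembic.util import sqla_compat

    engine = create_engine("sqlite://")
    metadata = MetaData()
    example = Table("example", metadata, Column("id", Integer, primary_key=True))

    with engine.connect() as connection:
        # begins (and commits) a transaction only if the caller has not
        # already started one; otherwise runs within the existing scope
        with sqla_compat._ensure_scope_for_ddl(connection):
            example.create(connection, checkfirst=True)
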
diff --git a/docs/build/autogenerate.rst b/docs/build/autogenerate.rst
index 7072aa3..46dde52 100644
--- a/docs/build/autogenerate.rst
+++ b/docs/build/autogenerate.rst
@@ -465,7 +465,7 @@ are being used::
Above, ``inspected_column`` is a :class:`sqlalchemy.schema.Column` as
returned by
-:meth:`sqlalchemy.engine.reflection.Inspector.reflecttable`, whereas
+:meth:`sqlalchemy.engine.reflection.Inspector.reflect_table`, whereas
``metadata_column`` is a :class:`sqlalchemy.schema.Column` from the
 local model environment. A return value of ``None`` indicates that default
 type comparison should proceed.
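
For reference, the ``compare_type`` hook this documentation section describes is a callable passed to ``context.configure(compare_type=...)`` in ``env.py``; a no-op implementation that defers to Alembic's default comparison would be::

    def my_compare_type(context, inspected_column,
                        metadata_column, inspected_type, metadata_type):
        # return False if the types are to be considered equivalent,
        # True if they differ, or None to fall back to the default
        # type comparison
        return None
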
diff --git a/docs/build/unreleased/autocommit.rst b/docs/build/unreleased/autocommit.rst
new file mode 100644
index 0000000..39d6098
--- /dev/null
+++ b/docs/build/unreleased/autocommit.rst
@@ -0,0 +1,21 @@
+.. change::
+ :tags: change, environment
+
+ To accommodate SQLAlchemy 1.4 and 2.0, the migration model now no longer
+ assumes that the SQLAlchemy Connection will autocommit an individual
+ operation. This essentially means that for databases that use
+ non-transactional DDL (pysqlite current driver behavior, MySQL), there is
+ still a BEGIN/COMMIT block that will surround each individual migration.
+ Databases that support transactional DDL should continue to have the
+ same flow, either per migration or per-entire run, depending on the
+ value of the :paramref:`.Environment.configure.transaction_per_migration`
+ flag.
+
+
+.. change::
+ :tags: change, environment
+
+ It now raises a :class:`.CommandError` if a ``sqlalchemy.engine.Engine``
+ is passed to the :meth:`.MigrationContext.configure` method instead of
+ a ``sqlalchemy.engine.Connection`` object. Previously, this would
+ be a warning only.
\ No newline at end of file
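
The ``Connection``-only requirement described above matches the standard online-mode ``env.py`` pattern; a minimal sketch (illustrative only; the URL is a placeholder)::

    from sqlalchemy import create_engine

    from alembic import context

    def run_migrations_online():
        connectable = create_engine("sqlite:///example.db")  # placeholder URL
        with connectable.connect() as connection:
            # pass the Connection, never the Engine itself
            context.configure(connection=connection, target_metadata=None)
            with context.begin_transaction():
                context.run_migrations()
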
diff --git a/setup.cfg b/setup.cfg
index e41b58f..054c9fc 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -1,3 +1,66 @@
+[metadata]
+
+name = alembic
+
+# version comes from setup.py; setuptools
+# can't read the "attr:" value here without importing it
+# until version 47.0.0, which is too recent to require
+
+
+description = A database migration tool for SQLAlchemy.
+long_description = file: README.rst
+long_description_content_type = text/x-rst
+url=https://alembic.sqlalchemy.org
+author = Mike Bayer
+author_email = mike_mp@zzzcomputing.com
+license = MIT
+license_file = LICENSE
+
+
+classifiers =
+ Development Status :: 5 - Production/Stable
+ Intended Audience :: Developers
+ Environment :: Console
+ License :: OSI Approved :: MIT License
+ Operating System :: OS Independent
+ Programming Language :: Python
+ Programming Language :: Python :: 2
+ Programming Language :: Python :: 2.7
+ Programming Language :: Python :: 3
+ Programming Language :: Python :: 3.6
+ Programming Language :: Python :: 3.7
+ Programming Language :: Python :: 3.8
+ Programming Language :: Python :: 3.9
+ Programming Language :: Python :: Implementation :: CPython
+ Programming Language :: Python :: Implementation :: PyPy
+ Topic :: Database :: Front-Ends
+
+[options]
+packages = find:
+include_package_data = true
+zip_safe = false
+python_requires = >=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*
+package_dir =
+ =.
+
+install_requires =
+ SQLAlchemy>=1.3.0
+ Mako
+ python-editor>=0.3
+ python-dateutil
+
+[options.packages.find]
+exclude =
+ test*
+ examples*
+
+[options.exclude_package_data]
+'' = test*
+
+[options.entry_points]
+console_scripts =
+ alembic = alembic.config:main
+
[egg_info]
tag_build=dev
@@ -40,8 +103,9 @@ default=sqlite:///:memory:
sqlite=sqlite:///:memory:
sqlite_file=sqlite:///querytest.db
postgresql=postgresql://scott:tiger@127.0.0.1:5432/test
-mysql=mysql://scott:tiger@127.0.0.1:3306/test?charset=utf8
-mssql=mssql+pyodbc://scott:tiger@ms_2008
+mysql=mysql://scott:tiger@127.0.0.1:3306/test?charset=utf8mb4
+mariadb = mariadb://scott:tiger@127.0.0.1:3306/test?charset=utf8mb4
+mssql = mssql+pyodbc://scott:tiger^5HHH@mssql2017:1433/test?driver=ODBC+Driver+13+for+SQL+Server
oracle=oracle://scott:tiger@127.0.0.1:1521
oracle8=oracle://scott:tiger@127.0.0.1:1521/?use_ansi=0
@@ -49,7 +113,7 @@ oracle8=oracle://scott:tiger@127.0.0.1:1521/?use_ansi=0
[tool:pytest]
-addopts= --tb native -v -r fxX -p no:warnings -p no:logging --maxfail=25
+addopts= --tb native -v -r sfxX -p no:warnings -p no:logging --maxfail=25
python_files=tests/test_*.py
diff --git a/setup.py b/setup.py
index 5437439..1ca5fd7 100644
--- a/setup.py
+++ b/setup.py
@@ -2,7 +2,6 @@ import os
import re
import sys
-from setuptools import find_packages
from setuptools import setup
from setuptools.command.test import test as TestCommand
@@ -16,16 +15,6 @@ VERSION = (
v.close()
-readme = os.path.join(os.path.dirname(__file__), "README.rst")
-
-requires = [
- "SQLAlchemy>=1.3.0",
- "Mako",
- "python-editor>=0.3",
- "python-dateutil",
-]
-
-
class UseTox(TestCommand):
RED = 31
RESET_SEQ = "\033[0m"
@@ -42,40 +31,6 @@ class UseTox(TestCommand):
setup(
- name="alembic",
version=VERSION,
- description="A database migration tool for SQLAlchemy.",
- long_description=open(readme).read(),
- python_requires=(
- ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*"
- ),
- classifiers=[
- "Development Status :: 5 - Production/Stable",
- "Environment :: Console",
- "License :: OSI Approved :: MIT License",
- "Intended Audience :: Developers",
- "Programming Language :: Python",
- "Programming Language :: Python :: 2",
- "Programming Language :: Python :: 2.7",
- "Programming Language :: Python :: 3",
- "Programming Language :: Python :: 3.6",
- "Programming Language :: Python :: 3.7",
- "Programming Language :: Python :: 3.8",
- "Programming Language :: Python :: 3.9",
- "Programming Language :: Python :: Implementation :: CPython",
- "Programming Language :: Python :: Implementation :: PyPy",
- "Topic :: Database :: Front-Ends",
- ],
- keywords="SQLAlchemy migrations",
- author="Mike Bayer",
- author_email="mike@zzzcomputing.com",
- url="https://alembic.sqlalchemy.org",
- project_urls={"Issue Tracker": "https://github.com/sqlalchemy/alembic/"},
- license="MIT",
- packages=find_packages(".", exclude=["examples*", "test*"]),
- include_package_data=True,
cmdclass={"test": UseTox},
- zip_safe=False,
- install_requires=requires,
- entry_points={"console_scripts": ["alembic = alembic.config:main"]},
)
diff --git a/tests/conftest.py b/tests/conftest.py
index a83dff5..325bb45 100755
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -10,6 +10,8 @@ import os
import pytest
+os.environ["SQLALCHEMY_WARN_20"] = "true"
+
pytest.register_assert_rewrite("sqlalchemy.testing.assertions")
@@ -32,4 +34,12 @@ with open(bootstrap_file) as f:
code = compile(f.read(), "bootstrap.py", "exec")
to_bootstrap = "pytest"
exec(code, globals(), locals())
- from pytestplugin import * # noqa
+ from sqlalchemy.testing.plugin.pytestplugin import * # noqa
+
+ wrap_pytest_sessionstart = pytest_sessionstart # noqa
+
+ def pytest_sessionstart(session):
+ wrap_pytest_sessionstart(session)
+ from alembic.testing import warnings
+
+ warnings.setup_filters()
diff --git a/tests/requirements.py b/tests/requirements.py
index 830d4de..8c81889 100644
--- a/tests/requirements.py
+++ b/tests/requirements.py
@@ -271,3 +271,9 @@ class DefaultRequirements(SuiteRequirements):
@property
def supports_identity_on_null(self):
return self.identity_columns + exclusions.only_on(["oracle"])
+
+ @property
+ def legacy_engine(self):
+ return exclusions.only_if(
+ lambda config: not getattr(config.db, "_is_future", False)
+ )
diff --git a/tests/test_autogen_indexes.py b/tests/test_autogen_indexes.py
index 94546ff..943e61a 100644
--- a/tests/test_autogen_indexes.py
+++ b/tests/test_autogen_indexes.py
@@ -14,9 +14,9 @@ from sqlalchemy import UniqueConstraint
from alembic.testing import assertions
from alembic.testing import config
-from alembic.testing import engines
from alembic.testing import eq_
from alembic.testing import TestBase
+from alembic.testing import util
from alembic.testing.env import staging_env
from alembic.util import sqla_compat
from ._autogen_fixtures import AutogenFixtureTest
@@ -29,7 +29,7 @@ class NoUqReflection(object):
def setUp(self):
staging_env()
- self.bind = eng = engines.testing_engine()
+ self.bind = eng = util.testing_engine()
def unimpl(*arg, **kw):
raise NotImplementedError()
@@ -1508,7 +1508,7 @@ class IncludeHooksTest(AutogenFixtureTest, TestBase):
class TruncatedIdxTest(AutogenFixtureTest, TestBase):
def setUp(self):
- self.bind = engines.testing_engine()
+ self.bind = util.testing_engine()
self.bind.dialect.max_identifier_length = 30
def test_idx_matches_long(self):
diff --git a/tests/test_batch.py b/tests/test_batch.py
index c9785f2..23ab364 100644
--- a/tests/test_batch.py
+++ b/tests/test_batch.py
@@ -24,7 +24,6 @@ from sqlalchemy.dialects import sqlite as sqlite_dialect
from sqlalchemy.schema import CreateIndex
from sqlalchemy.schema import CreateTable
from sqlalchemy.sql import column
-from sqlalchemy.sql import select
from sqlalchemy.sql import text
from alembic.ddl import sqlite
@@ -40,6 +39,7 @@ from alembic.testing import mock
from alembic.testing import TestBase
from alembic.testing.fixtures import op_fixture
from alembic.util import exc as alembic_exc
+from alembic.util.sqla_compat import _select
from alembic.util.sqla_compat import sqla_14
@@ -851,8 +851,10 @@ class BatchApplyTest(TestBase):
class BatchAPITest(TestBase):
@contextmanager
def _fixture(self, schema=None):
+
migration_context = mock.Mock(
- opts={}, impl=mock.MagicMock(__dialect__="sqlite")
+ opts={},
+ impl=mock.MagicMock(__dialect__="sqlite", connection=object()),
)
op = Operations(migration_context)
batch = op.batch_alter_table(
@@ -1256,90 +1258,105 @@ class BatchRoundTripTest(TestBase):
Column("x", Integer),
mysql_engine="InnoDB",
)
- t1.create(self.conn)
+ with self.conn.begin():
+ t1.create(self.conn)
- self.conn.execute(
- t1.insert(),
- [
- {"id": 1, "data": "d1", "x": 5},
- {"id": 2, "data": "22", "x": 6},
- {"id": 3, "data": "8.5", "x": 7},
- {"id": 4, "data": "9.46", "x": 8},
- {"id": 5, "data": "d5", "x": 9},
- ],
- )
+ self.conn.execute(
+ t1.insert(),
+ [
+ {"id": 1, "data": "d1", "x": 5},
+ {"id": 2, "data": "22", "x": 6},
+ {"id": 3, "data": "8.5", "x": 7},
+ {"id": 4, "data": "9.46", "x": 8},
+ {"id": 5, "data": "d5", "x": 9},
+ ],
+ )
context = MigrationContext.configure(self.conn)
self.op = Operations(context)
@contextmanager
def _sqlite_referential_integrity(self):
- self.conn.execute("PRAGMA foreign_keys=ON")
+ self.conn.exec_driver_sql("PRAGMA foreign_keys=ON")
try:
yield
finally:
- self.conn.execute("PRAGMA foreign_keys=OFF")
+ self.conn.exec_driver_sql("PRAGMA foreign_keys=OFF")
+
+        # as these tests typically fail intentionally, clean out any
+        # tables left over
+ m = MetaData()
+ m.reflect(self.conn)
+ with self.conn.begin():
+ m.drop_all(self.conn)
def _no_pk_fixture(self):
- nopk = Table(
- "nopk",
- self.metadata,
- Column("a", Integer),
- Column("b", Integer),
- Column("c", Integer),
- mysql_engine="InnoDB",
- )
- nopk.create(self.conn)
- self.conn.execute(
- nopk.insert(), [{"a": 1, "b": 2, "c": 3}, {"a": 2, "b": 4, "c": 5}]
- )
- return nopk
+ with self.conn.begin():
+ nopk = Table(
+ "nopk",
+ self.metadata,
+ Column("a", Integer),
+ Column("b", Integer),
+ Column("c", Integer),
+ mysql_engine="InnoDB",
+ )
+ nopk.create(self.conn)
+ self.conn.execute(
+ nopk.insert(),
+ [{"a": 1, "b": 2, "c": 3}, {"a": 2, "b": 4, "c": 5}],
+ )
+ return nopk
def _table_w_index_fixture(self):
- t = Table(
- "t_w_ix",
- self.metadata,
- Column("id", Integer, primary_key=True),
- Column("thing", Integer),
- Column("data", String(20)),
- )
- Index("ix_thing", t.c.thing)
- t.create(self.conn)
- return t
+ with self.conn.begin():
+ t = Table(
+ "t_w_ix",
+ self.metadata,
+ Column("id", Integer, primary_key=True),
+ Column("thing", Integer),
+ Column("data", String(20)),
+ )
+ Index("ix_thing", t.c.thing)
+ t.create(self.conn)
+ return t
def _boolean_fixture(self):
- t = Table(
- "hasbool",
- self.metadata,
- Column("x", Boolean(create_constraint=True, name="ck1")),
- Column("y", Integer),
- )
- t.create(self.conn)
+ with self.conn.begin():
+ t = Table(
+ "hasbool",
+ self.metadata,
+ Column("x", Boolean(create_constraint=True, name="ck1")),
+ Column("y", Integer),
+ )
+ t.create(self.conn)
def _timestamp_fixture(self):
- t = Table("hasts", self.metadata, Column("x", DateTime()))
- t.create(self.conn)
- return t
+ with self.conn.begin():
+ t = Table("hasts", self.metadata, Column("x", DateTime()))
+ t.create(self.conn)
+ return t
def _datetime_server_default_fixture(self):
return func.datetime("now", "localtime")
def _timestamp_w_expr_default_fixture(self):
- t = Table(
- "hasts",
- self.metadata,
- Column(
- "x",
- DateTime(),
- server_default=self._datetime_server_default_fixture(),
- nullable=False,
- ),
- )
- t.create(self.conn)
- return t
+ with self.conn.begin():
+ t = Table(
+ "hasts",
+ self.metadata,
+ Column(
+ "x",
+ DateTime(),
+ server_default=self._datetime_server_default_fixture(),
+ nullable=False,
+ ),
+ )
+ t.create(self.conn)
+ return t
def _int_to_boolean_fixture(self):
- t = Table("hasbool", self.metadata, Column("x", Integer))
- t.create(self.conn)
+ with self.conn.begin():
+ t = Table("hasbool", self.metadata, Column("x", Integer))
+ t.create(self.conn)
def test_change_type_boolean_to_int(self):
self._boolean_fixture()
@@ -1365,15 +1382,16 @@ class BatchRoundTripTest(TestBase):
import datetime
- self.conn.execute(
- t.insert(), {"x": datetime.datetime(2012, 5, 18, 15, 32, 5)}
- )
+ with self.conn.begin():
+ self.conn.execute(
+ t.insert(), {"x": datetime.datetime(2012, 5, 18, 15, 32, 5)}
+ )
with self.op.batch_alter_table("hasts") as batch_op:
batch_op.alter_column("x", type_=DateTime())
eq_(
- self.conn.execute(select([t.c.x])).fetchall(),
+ self.conn.execute(_select(t.c.x)).fetchall(),
[(datetime.datetime(2012, 5, 18, 15, 32, 5),)],
)
@@ -1388,10 +1406,14 @@ class BatchRoundTripTest(TestBase):
server_default=self._datetime_server_default_fixture(),
)
- self.conn.execute(t.insert())
-
- row = self.conn.execute(select([t.c.x])).fetchone()
- assert row["x"] is not None
+ with self.conn.begin():
+ self.conn.execute(t.insert())
+ res = self.conn.execute(_select(t.c.x))
+ if sqla_14:
+ assert res.scalar_one_or_none() is not None
+ else:
+ row = res.fetchone()
+ assert row["x"] is not None
def test_drop_col_schematype(self):
self._boolean_fixture()
@@ -1429,19 +1451,18 @@ class BatchRoundTripTest(TestBase):
)
def tearDown(self):
- self.metadata.drop_all(self.conn)
+ in_t = getattr(self.conn, "in_transaction", lambda: False)
+ if in_t():
+ self.conn.rollback()
+ with self.conn.begin():
+ self.metadata.drop_all(self.conn)
self.conn.close()
def _assert_data(self, data, tablename="foo"):
- eq_(
- [
- dict(row)
- for row in self.conn.execute(
- text("select * from %s" % tablename)
- )
- ],
- data,
- )
+ res = self.conn.execute(text("select * from %s" % tablename))
+ if sqla_14:
+ res = res.mappings()
+ eq_([dict(row) for row in res], data)
def test_ix_existing(self):
self._table_w_index_fixture()
@@ -1486,8 +1507,9 @@ class BatchRoundTripTest(TestBase):
Column("foo_id", Integer, ForeignKey("foo.id")),
mysql_engine="InnoDB",
)
- bar.create(self.conn)
- self.conn.execute(bar.insert(), {"id": 1, "foo_id": 3})
+ with self.conn.begin():
+ bar.create(self.conn)
+ self.conn.execute(bar.insert(), {"id": 1, "foo_id": 3})
with self.op.batch_alter_table("foo", recreate=recreate) as batch_op:
batch_op.alter_column(
@@ -1532,9 +1554,14 @@ class BatchRoundTripTest(TestBase):
Column("data", String(50)),
mysql_engine="InnoDB",
)
- bar.create(self.conn)
- self.conn.execute(bar.insert(), {"id": 1, "data": "x", "bar_id": None})
- self.conn.execute(bar.insert(), {"id": 2, "data": "y", "bar_id": 1})
+ with self.conn.begin():
+ bar.create(self.conn)
+ self.conn.execute(
+ bar.insert(), {"id": 1, "data": "x", "bar_id": None}
+ )
+ self.conn.execute(
+ bar.insert(), {"id": 2, "data": "y", "bar_id": 1}
+ )
with self.op.batch_alter_table("bar", recreate=recreate) as batch_op:
batch_op.alter_column(
@@ -1649,8 +1676,9 @@ class BatchRoundTripTest(TestBase):
Column("foo_id", Integer, ForeignKey("foo.id")),
mysql_engine="InnoDB",
)
- bar.create(self.conn)
- self.conn.execute(bar.insert(), {"id": 1, "foo_id": 3})
+ with self.conn.begin():
+ bar.create(self.conn)
+ self.conn.execute(bar.insert(), {"id": 1, "foo_id": 3})
naming_convention = {
"fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s"
@@ -1773,9 +1801,10 @@ class BatchRoundTripTest(TestBase):
Column("flag", Boolean(create_constraint=True)),
mysql_engine="InnoDB",
)
- bar.create(self.conn)
- self.conn.execute(bar.insert(), {"id": 1, "flag": True})
- self.conn.execute(bar.insert(), {"id": 2, "flag": False})
+ with self.conn.begin():
+ bar.create(self.conn)
+ self.conn.execute(bar.insert(), {"id": 1, "flag": True})
+ self.conn.execute(bar.insert(), {"id": 2, "flag": False})
with self.op.batch_alter_table("bar") as batch_op:
batch_op.alter_column(
@@ -1795,15 +1824,16 @@ class BatchRoundTripTest(TestBase):
Column("flag", Boolean(create_constraint=False)),
mysql_engine="InnoDB",
)
- bar.create(self.conn)
- self.conn.execute(bar.insert(), {"id": 1, "flag": True})
- self.conn.execute(bar.insert(), {"id": 2, "flag": False})
- self.conn.execute(
- # override Boolean type which as of 1.1 coerces numerics
- # to 1/0
- text("insert into bar (id, flag) values (:id, :flag)"),
- {"id": 3, "flag": 5},
- )
+ with self.conn.begin():
+ bar.create(self.conn)
+ self.conn.execute(bar.insert(), {"id": 1, "flag": True})
+ self.conn.execute(bar.insert(), {"id": 2, "flag": False})
+ self.conn.execute(
+ # override Boolean type which as of 1.1 coerces numerics
+ # to 1/0
+ text("insert into bar (id, flag) values (:id, :flag)"),
+ {"id": 3, "flag": 5},
+ )
with self.op.batch_alter_table(
"bar",
@@ -2042,7 +2072,8 @@ class BatchRoundTripPostgresqlTest(BatchRoundTripTest):
),
Column("y", Integer),
)
- t.create(self.conn)
+ with self.conn.begin():
+ t.create(self.conn)
def _datetime_server_default_fixture(self):
return func.current_timestamp()
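
The recurring change throughout these fixtures is the same pattern: setup
DDL/DML is wrapped in an explicit ``conn.begin()`` block rather than relying
on connection-level autocommit of each statement. A generic sketch of that
pattern, independent of the test harness and using an illustrative table::

    from sqlalchemy import Column, Integer, MetaData, Table, create_engine

    engine = create_engine("sqlite://")
    metadata = MetaData()
    foo = Table("foo", metadata, Column("id", Integer, primary_key=True))

    with engine.connect() as conn:
        # one explicit transaction around all of the setup work
        with conn.begin():
            foo.create(conn)
            conn.execute(foo.insert(), [{"id": 1}, {"id": 2}])
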
diff --git a/tests/test_bulk_insert.py b/tests/test_bulk_insert.py
index aedf6e9..09c641a 100644
--- a/tests/test_bulk_insert.py
+++ b/tests/test_bulk_insert.py
@@ -237,23 +237,28 @@ class RoundTripTest(TestBase):
def setUp(self):
self.conn = config.db.connect()
- self.conn.execute(
- text(
- """
- create table foo(
- id integer primary key,
- data varchar(50),
- x integer
- )
- """
+ with self.conn.begin():
+ self.conn.execute(
+ text(
+ """
+ create table foo(
+ id integer primary key,
+ data varchar(50),
+ x integer
+ )
+ """
+ )
)
- )
context = MigrationContext.configure(self.conn)
self.op = op.Operations(context)
self.t1 = table("foo", column("id"), column("data"), column("x"))
+ self.trans = self.conn.begin()
+
def tearDown(self):
- self.conn.execute(text("drop table foo"))
+ self.trans.rollback()
+ with self.conn.begin():
+ self.conn.execute(text("drop table foo"))
self.conn.close()
def test_single_insert_round_trip(self):
diff --git a/tests/test_command.py b/tests/test_command.py
index a616e9c..8350e82 100644
--- a/tests/test_command.py
+++ b/tests/test_command.py
@@ -399,7 +399,7 @@ finally:
r2 = command.revision(self.cfg)
db = _sqlite_file_db()
command.upgrade(self.cfg, "head")
- with db.connect() as conn:
+ with db.begin() as conn:
conn.execute(
text("insert into alembic_version values ('%s')" % r2.revision)
)
@@ -681,7 +681,7 @@ class StampMultipleHeadsTest(TestBase, _StampTest):
command.stamp(self.cfg, [self.a])
eng = _sqlite_file_db()
- with eng.connect() as conn:
+ with eng.begin() as conn:
result = conn.execute(
text("update alembic_version set version_num='fake'")
)
diff --git a/tests/test_environment.py b/tests/test_environment.py
index 4fb6bbe..63de6cd 100644
--- a/tests/test_environment.py
+++ b/tests/test_environment.py
@@ -1,6 +1,7 @@
#!coding: utf-8
from alembic import command
from alembic import testing
+from alembic import util
from alembic.environment import EnvironmentContext
from alembic.migration import MigrationContext
from alembic.script import ScriptDirectory
@@ -11,7 +12,7 @@ from alembic.testing import is_
from alembic.testing import is_false
from alembic.testing import is_true
from alembic.testing import mock
-from alembic.testing.assertions import expect_warnings
+from alembic.testing.assertions import expect_raises_message
from alembic.testing.env import _no_sql_testing_config
from alembic.testing.env import _sqlite_file_db
from alembic.testing.env import clear_staging_env
@@ -94,10 +95,11 @@ def upgrade():
command.upgrade(self.cfg, "arev", sql=True)
assert "do some SQL thing with a % percent sign %" in buf.getvalue()
+ @config.requirements.legacy_engine
@testing.uses_deprecated(
r"The Engine.execute\(\) function/method is considered legacy"
)
- def test_warning_on_passing_engine(self):
+ def test_error_on_passing_engine(self):
env = self._fixture()
engine = _sqlite_file_db()
@@ -131,18 +133,15 @@ def downgrade():
migration_fn(rev, context)
return env.script._upgrade_revs(a_rev, rev)
- with expect_warnings(
+ with expect_raises_message(
+ util.CommandError,
r"'connection' argument to configure\(\) is "
- r"expected to be a sqlalchemy.engine.Connection "
+ r"expected to be a sqlalchemy.engine.Connection ",
):
env.configure(
connection=engine, fn=upgrade, transactional_ddl=False
)
- env.run_migrations()
-
- eq_(migration_fn.mock_calls, [mock.call((), env._migration_context)])
-
class MigrationTransactionTest(TestBase):
__backend__ = True
@@ -238,7 +237,7 @@ class MigrationTransactionTest(TestBase):
with context.begin_transaction():
is_false(self.conn.in_transaction())
with context.begin_transaction(_per_migration=True):
- is_false(self.conn.in_transaction())
+ is_true(self.conn.in_transaction())
is_false(self.conn.in_transaction())
is_false(self.conn.in_transaction())
@@ -264,7 +263,7 @@ class MigrationTransactionTest(TestBase):
with context.begin_transaction():
is_false(self.conn.in_transaction())
with context.begin_transaction(_per_migration=True):
- is_false(self.conn.in_transaction())
+ is_true(self.conn.in_transaction())
is_false(self.conn.in_transaction())
is_false(self.conn.in_transaction())
@@ -334,18 +333,12 @@ class MigrationTransactionTest(TestBase):
with context.begin_transaction():
is_false(self.conn.in_transaction())
with context.begin_transaction(_per_migration=True):
- if context.impl.transactional_ddl:
- is_true(self.conn.in_transaction())
- else:
- is_false(self.conn.in_transaction())
+ is_true(self.conn.in_transaction())
with context.autocommit_block():
is_false(self.conn.in_transaction())
- if context.impl.transactional_ddl:
- is_true(self.conn.in_transaction())
- else:
- is_false(self.conn.in_transaction())
+ is_true(self.conn.in_transaction())
is_false(self.conn.in_transaction())
is_false(self.conn.in_transaction())
diff --git a/tests/test_impl.py b/tests/test_impl.py
new file mode 100644
index 0000000..8a73b87
--- /dev/null
+++ b/tests/test_impl.py
@@ -0,0 +1,45 @@
+from sqlalchemy import Column
+from sqlalchemy import Integer
+from sqlalchemy import Table
+from sqlalchemy.sql import text
+
+from alembic import testing
+from alembic.testing import eq_
+from alembic.testing.fixtures import FutureEngineMixin
+from alembic.testing.fixtures import TablesTest
+
+
+class ImplTest(TablesTest):
+ __only_on__ = "sqlite"
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ "some_table", metadata, Column("x", Integer), Column("y", Integer)
+ )
+
+ @testing.fixture
+ def impl(self, migration_context):
+ with migration_context.begin_transaction(_per_migration=True):
+ yield migration_context.impl
+
+ def test_execute_params(self, impl):
+ result = impl._exec(text("select :my_param"), params={"my_param": 5})
+ eq_(result.scalar(), 5)
+
+ def test_execute_multiparams(self, impl):
+ some_table = self.tables.some_table
+ impl._exec(
+ some_table.insert(),
+ multiparams=[{"x": 1, "y": 2}, {"x": 2, "y": 3}, {"x": 5, "y": 7}],
+ )
+ eq_(
+ impl._exec(
+ some_table.select().order_by(some_table.c.x)
+ ).fetchall(),
+ [(1, 2), (2, 3), (5, 7)],
+ )
+
+
+class FutureImplTest(FutureEngineMixin, ImplTest):
+ pass
diff --git a/tests/test_mysql.py b/tests/test_mysql.py
index caef197..ba43e3a 100644
--- a/tests/test_mysql.py
+++ b/tests/test_mysql.py
@@ -594,10 +594,11 @@ class MySQLDefaultCompareTest(TestBase):
clear_staging_env()
def setUp(self):
- self.metadata = MetaData(self.bind)
+ self.metadata = MetaData()
def tearDown(self):
- self.metadata.drop_all()
+ with config.db.begin() as conn:
+ self.metadata.drop_all(conn)
def _compare_default_roundtrip(self, type_, txt, alternate=None):
if alternate:
diff --git a/tests/test_postgresql.py b/tests/test_postgresql.py
index 08f70d8..10f17d4 100644
--- a/tests/test_postgresql.py
+++ b/tests/test_postgresql.py
@@ -35,7 +35,6 @@ from alembic.autogenerate.compare import _compare_server_default
from alembic.autogenerate.compare import _compare_tables
from alembic.autogenerate.compare import _render_server_default_for_compare
from alembic.migration import MigrationContext
-from alembic.operations import Operations
from alembic.operations import ops
from alembic.script import ScriptDirectory
from alembic.testing import assert_raises_message
@@ -50,6 +49,7 @@ from alembic.testing.env import staging_env
from alembic.testing.env import write_script
from alembic.testing.fixtures import capture_context_buffer
from alembic.testing.fixtures import op_fixture
+from alembic.testing.fixtures import TablesTest
from alembic.testing.fixtures import TestBase
from alembic.util import sqla_compat
@@ -436,11 +436,12 @@ class PGAutocommitBlockTest(TestBase):
with self.conn.begin():
self.conn.execute(text("DROP TYPE mood"))
- def test_alter_enum(self):
- context = MigrationContext.configure(connection=self.conn)
- with context.begin_transaction(_per_migration=True):
- with context.autocommit_block():
- context.execute(text("ALTER TYPE mood ADD VALUE 'soso'"))
+ def test_alter_enum(self, migration_context):
+ with migration_context.begin_transaction(_per_migration=True):
+ with migration_context.autocommit_block():
+ migration_context.execute(
+ text("ALTER TYPE mood ADD VALUE 'soso'")
+ )
class PGOfflineEnumTest(TestBase):
@@ -546,58 +547,38 @@ def downgrade():
assert "DROP TYPE pgenum" in buf.getvalue()
-class PostgresqlInlineLiteralTest(TestBase):
+class PostgresqlInlineLiteralTest(TablesTest):
__only_on__ = "postgresql"
__backend__ = True
@classmethod
- def setup_class(cls):
- cls.bind = config.db
- with config.db.connect() as conn:
- conn.execute(
- text(
- """
- create table tab (
- col varchar(50)
- )
- """
- )
- )
- conn.execute(
- text(
- """
+ def define_tables(cls, metadata):
+ Table("tab", metadata, Column("col", String(50)))
+
+ @classmethod
+ def insert_data(cls, connection):
+ connection.execute(
+ text(
+ """
insert into tab (col) values
('old data 1'),
('old data 2.1'),
('old data 3')
"""
- )
)
+ )
- @classmethod
- def teardown_class(cls):
- with cls.bind.connect() as conn:
- conn.execute(text("drop table tab"))
-
- def setUp(self):
- self.conn = self.bind.connect()
- ctx = MigrationContext.configure(self.conn)
- self.op = Operations(ctx)
-
- def tearDown(self):
- self.conn.close()
-
- def test_inline_percent(self):
+ def test_inline_percent(self, connection, ops_context):
# TODO: here's the issue, you need to escape this.
tab = table("tab", column("col"))
- self.op.execute(
+ ops_context.execute(
tab.update()
- .where(tab.c.col.like(self.op.inline_literal("%.%")))
- .values(col=self.op.inline_literal("new data")),
+ .where(tab.c.col.like(ops_context.inline_literal("%.%")))
+ .values(col=ops_context.inline_literal("new data")),
execution_options={"no_parameters": True},
)
eq_(
- self.conn.execute(
+ connection.execute(
text("select count(*) from tab where col='new data'")
).scalar(),
1,
@@ -618,7 +599,7 @@ class PostgresqlDefaultCompareTest(TestBase):
)
def setUp(self):
- self.metadata = MetaData(self.bind)
+ self.metadata = MetaData()
self.autogen_context = api.AutogenContext(self.migration_context)
@classmethod
@@ -626,7 +607,8 @@ class PostgresqlDefaultCompareTest(TestBase):
clear_staging_env()
def tearDown(self):
- self.metadata.drop_all()
+ with config.db.begin() as conn:
+ self.metadata.drop_all(conn)
def _compare_default_roundtrip(
self, type_, orig_default, alternate=None, diff_expected=None
diff --git a/tests/test_script_consumption.py b/tests/test_script_consumption.py
index 17bf037..e1b094f 100644
--- a/tests/test_script_consumption.py
+++ b/tests/test_script_consumption.py
@@ -5,12 +5,16 @@ import os
import re
import textwrap
+import sqlalchemy as sa
+
from alembic import command
+from alembic import testing
from alembic import util
from alembic.environment import EnvironmentContext
from alembic.script import Script
from alembic.script import ScriptDirectory
from alembic.testing import assert_raises_message
+from alembic.testing import config
from alembic.testing import eq_
from alembic.testing import mock
from alembic.testing.env import _no_sql_testing_config
@@ -22,31 +26,81 @@ from alembic.testing.env import staging_env
from alembic.testing.env import three_rev_fixture
from alembic.testing.env import write_script
from alembic.testing.fixtures import capture_context_buffer
+from alembic.testing.fixtures import FutureEngineMixin
from alembic.testing.fixtures import TestBase
from alembic.util import compat
-class ApplyVersionsFunctionalTest(TestBase):
+class PatchEnvironment(object):
+ @contextmanager
+ def _patch_environment(self, transactional_ddl, transaction_per_migration):
+ conf = EnvironmentContext.configure
+
+ conn = [None]
+
+ def configure(*arg, **opt):
+ opt.update(
+ transactional_ddl=transactional_ddl,
+ transaction_per_migration=transaction_per_migration,
+ )
+ conn[0] = opt["connection"]
+ return conf(*arg, **opt)
+
+ with mock.patch.object(EnvironmentContext, "configure", configure):
+ yield
+
+        # with a normal env.py it's no longer possible for the conn to
+        # be in a transaction here, as context.begin_transaction() always
+        # runs a real DB transaction and no longer uses autocommit
+        # mode
+ assert not conn[0].in_transaction()
+
+
+@testing.combinations(
+ (
+ False,
+ True,
+ ),
+ (
+ True,
+ False,
+ ),
+ (
+ True,
+ True,
+ ),
+ argnames="transactional_ddl,transaction_per_migration",
+ id_="rr",
+)
+class ApplyVersionsFunctionalTest(PatchEnvironment, TestBase):
__only_on__ = "sqlite"
sourceless = False
+ future = False
+ transactional_ddl = False
+ transaction_per_migration = True
def setUp(self):
- self.bind = _sqlite_file_db()
+ self.bind = _sqlite_file_db(future=self.future)
self.env = staging_env(sourceless=self.sourceless)
- self.cfg = _sqlite_testing_config(sourceless=self.sourceless)
+ self.cfg = _sqlite_testing_config(
+ sourceless=self.sourceless, future=self.future
+ )
def tearDown(self):
clear_staging_env()
def test_steps(self):
- self._test_001_revisions()
- self._test_002_upgrade()
- self._test_003_downgrade()
- self._test_004_downgrade()
- self._test_005_upgrade()
- self._test_006_upgrade_again()
- self._test_007_stamp_upgrade()
+ with self._patch_environment(
+ self.transactional_ddl, self.transaction_per_migration
+ ):
+ self._test_001_revisions()
+ self._test_002_upgrade()
+ self._test_003_downgrade()
+ self._test_004_downgrade()
+ self._test_005_upgrade()
+ self._test_006_upgrade_again()
+ self._test_007_stamp_upgrade()
def _test_001_revisions(self):
self.a = a = util.rev_id()
@@ -166,22 +220,39 @@ class ApplyVersionsFunctionalTest(TestBase):
assert not db.dialect.has_table(db.connect(), "bat")
+# class level combinations can't do the skips for SQLAlchemy 1.3
+# so we have a separate class
+@testing.combinations(
+ (
+ False,
+ True,
+ ),
+ (
+ True,
+ False,
+ ),
+ (
+ True,
+ True,
+ ),
+ argnames="transactional_ddl,transaction_per_migration",
+ id_="rr",
+)
+class FutureApplyVersionsTest(FutureEngineMixin, ApplyVersionsFunctionalTest):
+ future = True
+
+
class SimpleSourcelessApplyVersionsTest(ApplyVersionsFunctionalTest):
sourceless = "simple"
-class NewFangledSourcelessEnvOnlyApplyVersionsTest(
- ApplyVersionsFunctionalTest
-):
- sourceless = "pep3147_envonly"
-
- __requires__ = ("pep3147",)
-
-
-class NewFangledSourcelessEverythingApplyVersionsTest(
- ApplyVersionsFunctionalTest
-):
- sourceless = "pep3147_everything"
+@testing.combinations(
+ ("pep3147_envonly",),
+ ("pep3147_everything",),
+ argnames="sourceless",
+ id_="r",
+)
+class NewFangledSourcelessApplyVersionsTest(ApplyVersionsFunctionalTest):
__requires__ = ("pep3147",)
@@ -313,13 +384,17 @@ class OfflineTransactionalDDLTest(TestBase):
)
-class OnlineTransactionalDDLTest(TestBase):
+class OnlineTransactionalDDLTest(PatchEnvironment, TestBase):
def tearDown(self):
clear_staging_env()
- def _opened_transaction_fixture(self):
+ def _opened_transaction_fixture(self, future=False):
self.env = staging_env()
- self.cfg = _sqlite_testing_config()
+
+ if future:
+ self.cfg = _sqlite_testing_config(future=future)
+ else:
+ self.cfg = _sqlite_testing_config()
script = ScriptDirectory.from_config(self.cfg)
a = util.rev_id()
@@ -358,6 +433,8 @@ from alembic import op
def upgrade():
conn = op.get_bind()
+    # this should fail for a SQLAlchemy 2.0 connection because there is
+    # already a transaction in progress.
trans = conn.begin()
@@ -391,59 +468,89 @@ def downgrade():
)
return a, b, c
- @contextmanager
- def _patch_environment(self, transactional_ddl, transaction_per_migration):
- conf = EnvironmentContext.configure
-
- def configure(*arg, **opt):
- opt.update(
- transactional_ddl=transactional_ddl,
- transaction_per_migration=transaction_per_migration,
- )
- return conf(*arg, **opt)
-
- with mock.patch.object(EnvironmentContext, "configure", configure):
- yield
+ # these tests might not be supported anymore; the connection is always
+ # going to be in a transaction now even on 1.3.
- def test_raise_when_rev_leaves_open_transaction(self):
- a, b, c = self._opened_transaction_fixture()
+ @testing.combinations((False,), (True, config.requirements.sqlalchemy_14))
+ def test_raise_when_rev_leaves_open_transaction(self, future):
+ a, b, c = self._opened_transaction_fixture(future)
with self._patch_environment(
transactional_ddl=False, transaction_per_migration=False
):
- assert_raises_message(
- util.CommandError,
- r'Migration "upgrade .*, rev b" has left an uncommitted '
- r"transaction opened; transactional_ddl is False so Alembic "
- r"is not committing transactions",
- command.upgrade,
- self.cfg,
- c,
- )
+ if future:
+ with testing.expect_raises_message(
+ sa.exc.InvalidRequestError,
+ "a transaction is already begun",
+ ):
+ command.upgrade(self.cfg, c)
+ elif config.requirements.sqlalchemy_14.enabled:
+ if self.is_sqlalchemy_future:
+ with testing.expect_raises_message(
+ sa.exc.InvalidRequestError,
+ r"a transaction is already begun for this connection",
+ ):
+ command.upgrade(self.cfg, c)
+ else:
+ with testing.expect_sqlalchemy_deprecated_20(
+ r"Calling .begin\(\) when a transaction "
+ "is already begun"
+ ):
+ command.upgrade(self.cfg, c)
+ else:
+ command.upgrade(self.cfg, c)
- def test_raise_when_rev_leaves_open_transaction_tpm(self):
- a, b, c = self._opened_transaction_fixture()
+ @testing.combinations((False,), (True, config.requirements.sqlalchemy_14))
+ def test_raise_when_rev_leaves_open_transaction_tpm(self, future):
+ a, b, c = self._opened_transaction_fixture(future)
with self._patch_environment(
transactional_ddl=False, transaction_per_migration=True
):
- assert_raises_message(
- util.CommandError,
- r'Migration "upgrade .*, rev b" has left an uncommitted '
- r"transaction opened; transactional_ddl is False so Alembic "
- r"is not committing transactions",
- command.upgrade,
- self.cfg,
- c,
- )
+ if future:
+ with testing.expect_raises_message(
+ sa.exc.InvalidRequestError,
+ "a transaction is already begun",
+ ):
+ command.upgrade(self.cfg, c)
+ elif config.requirements.sqlalchemy_14.enabled:
+ if self.is_sqlalchemy_future:
+ with testing.expect_raises_message(
+ sa.exc.InvalidRequestError,
+ r"a transaction is already begun for this connection",
+ ):
+ command.upgrade(self.cfg, c)
+ else:
+ with testing.expect_sqlalchemy_deprecated_20(
+ r"Calling .begin\(\) when a transaction is "
+ "already begun"
+ ):
+ command.upgrade(self.cfg, c)
+ else:
+ command.upgrade(self.cfg, c)
- def test_noerr_rev_leaves_open_transaction_transactional_ddl(self):
+ @testing.combinations((False,), (True, config.requirements.sqlalchemy_14))
+ def test_noerr_rev_leaves_open_transaction_transactional_ddl(self, future):
a, b, c = self._opened_transaction_fixture()
with self._patch_environment(
transactional_ddl=True, transaction_per_migration=False
):
- command.upgrade(self.cfg, c)
+ if config.requirements.sqlalchemy_14.enabled:
+ if self.is_sqlalchemy_future:
+ with testing.expect_raises_message(
+ sa.exc.InvalidRequestError,
+ r"a transaction is already begun for this connection",
+ ):
+ command.upgrade(self.cfg, c)
+ else:
+ with testing.expect_sqlalchemy_deprecated_20(
+ r"Calling .begin\(\) when a transaction "
+ "is already begun"
+ ):
+ command.upgrade(self.cfg, c)
+ else:
+ command.upgrade(self.cfg, c)
def test_noerr_transaction_opened_externally(self):
a, b, c = self._opened_transaction_fixture()
@@ -477,6 +584,12 @@ run_migrations_online()
command.stamp(self.cfg, c)
+class FutureOnlineTransactionalDDLTest(
+ FutureEngineMixin, OnlineTransactionalDDLTest
+):
+ pass
+
+
class EncodingTest(TestBase):
def setUp(self):
self.env = staging_env()
diff --git a/tests/test_sqlite.py b/tests/test_sqlite.py
index 3ea1975..946f69f 100644
--- a/tests/test_sqlite.py
+++ b/tests/test_sqlite.py
@@ -96,7 +96,7 @@ class SQLiteDefaultCompareTest(TestBase):
)
def setUp(self):
- self.metadata = MetaData(self.bind)
+ self.metadata = MetaData()
self.autogen_context = api.AutogenContext(self.migration_context)
@classmethod
@@ -104,7 +104,7 @@ class SQLiteDefaultCompareTest(TestBase):
clear_staging_env()
def tearDown(self):
- self.metadata.drop_all()
+ self.metadata.drop_all(config.db)
def _compare_default_roundtrip(
self, type_, orig_default, alternate=None, diff_expected=None
diff --git a/tests/test_version_table.py b/tests/test_version_table.py
index 1801346..5ad3c21 100644
--- a/tests/test_version_table.py
+++ b/tests/test_version_table.py
@@ -39,7 +39,8 @@ class TestMigrationContext(TestBase):
def tearDown(self):
self.transaction.rollback()
- version_table.drop(self.connection, checkfirst=True)
+ with self.connection.begin():
+ version_table.drop(self.connection, checkfirst=True)
self.connection.close()
def make_one(self, **kwargs):
@@ -182,11 +183,16 @@ class UpdateRevTest(TestBase):
self.context = migration.MigrationContext.configure(
connection=self.connection, opts={"version_table": "version_table"}
)
- version_table.create(self.connection)
+ with self.connection.begin():
+ version_table.create(self.connection)
self.updater = migration.HeadMaintainer(self.context, ())
def tearDown(self):
- version_table.drop(self.connection, checkfirst=True)
+ in_t = getattr(self.connection, "in_transaction", lambda: False)
+ if in_t():
+ self.connection.rollback()
+ with self.connection.begin():
+ version_table.drop(self.connection, checkfirst=True)
self.connection.close()
def _assert_heads(self, heads):
@@ -194,145 +200,176 @@ class UpdateRevTest(TestBase):
eq_(self.updater.heads, set(heads))
def test_update_none_to_single(self):
- self.updater.update_to_step(_up(None, "a", True))
- self._assert_heads(("a",))
+ with self.connection.begin():
+ self.updater.update_to_step(_up(None, "a", True))
+ self._assert_heads(("a",))
def test_update_single_to_single(self):
- self.updater.update_to_step(_up(None, "a", True))
- self.updater.update_to_step(_up("a", "b"))
- self._assert_heads(("b",))
+ with self.connection.begin():
+ self.updater.update_to_step(_up(None, "a", True))
+ self.updater.update_to_step(_up("a", "b"))
+ self._assert_heads(("b",))
def test_update_single_to_none(self):
- self.updater.update_to_step(_up(None, "a", True))
- self.updater.update_to_step(_down("a", None, True))
- self._assert_heads(())
+ with self.connection.begin():
+ self.updater.update_to_step(_up(None, "a", True))
+ self.updater.update_to_step(_down("a", None, True))
+ self._assert_heads(())
def test_add_branches(self):
- self.updater.update_to_step(_up(None, "a", True))
- self.updater.update_to_step(_up("a", "b"))
- self.updater.update_to_step(_up(None, "c", True))
- self._assert_heads(("b", "c"))
- self.updater.update_to_step(_up("c", "d"))
- self.updater.update_to_step(_up("d", "e1"))
- self.updater.update_to_step(_up("d", "e2", True))
- self._assert_heads(("b", "e1", "e2"))
+ with self.connection.begin():
+ self.updater.update_to_step(_up(None, "a", True))
+ self.updater.update_to_step(_up("a", "b"))
+ self.updater.update_to_step(_up(None, "c", True))
+ self._assert_heads(("b", "c"))
+ self.updater.update_to_step(_up("c", "d"))
+ self.updater.update_to_step(_up("d", "e1"))
+ self.updater.update_to_step(_up("d", "e2", True))
+ self._assert_heads(("b", "e1", "e2"))
def test_teardown_branches(self):
- self.updater.update_to_step(_up(None, "d1", True))
- self.updater.update_to_step(_up(None, "d2", True))
- self._assert_heads(("d1", "d2"))
+ with self.connection.begin():
+ self.updater.update_to_step(_up(None, "d1", True))
+ self.updater.update_to_step(_up(None, "d2", True))
+ self._assert_heads(("d1", "d2"))
- self.updater.update_to_step(_down("d1", "c"))
- self._assert_heads(("c", "d2"))
+ self.updater.update_to_step(_down("d1", "c"))
+ self._assert_heads(("c", "d2"))
- self.updater.update_to_step(_down("d2", "c", True))
+ self.updater.update_to_step(_down("d2", "c", True))
- self._assert_heads(("c",))
- self.updater.update_to_step(_down("c", "b"))
- self._assert_heads(("b",))
+ self._assert_heads(("c",))
+ self.updater.update_to_step(_down("c", "b"))
+ self._assert_heads(("b",))
def test_resolve_merges(self):
- self.updater.update_to_step(_up(None, "a", True))
- self.updater.update_to_step(_up("a", "b"))
- self.updater.update_to_step(_up("b", "c1"))
- self.updater.update_to_step(_up("b", "c2", True))
- self.updater.update_to_step(_up("c1", "d1"))
- self.updater.update_to_step(_up("c2", "d2"))
- self._assert_heads(("d1", "d2"))
- self.updater.update_to_step(_up(("d1", "d2"), "e"))
- self._assert_heads(("e",))
+ with self.connection.begin():
+ self.updater.update_to_step(_up(None, "a", True))
+ self.updater.update_to_step(_up("a", "b"))
+ self.updater.update_to_step(_up("b", "c1"))
+ self.updater.update_to_step(_up("b", "c2", True))
+ self.updater.update_to_step(_up("c1", "d1"))
+ self.updater.update_to_step(_up("c2", "d2"))
+ self._assert_heads(("d1", "d2"))
+ self.updater.update_to_step(_up(("d1", "d2"), "e"))
+ self._assert_heads(("e",))
def test_unresolve_merges(self):
- self.updater.update_to_step(_up(None, "e", True))
+ with self.connection.begin():
+ self.updater.update_to_step(_up(None, "e", True))
- self.updater.update_to_step(_down("e", ("d1", "d2")))
- self._assert_heads(("d2", "d1"))
+ self.updater.update_to_step(_down("e", ("d1", "d2")))
+ self._assert_heads(("d2", "d1"))
- self.updater.update_to_step(_down("d2", "c2"))
- self._assert_heads(("c2", "d1"))
+ self.updater.update_to_step(_down("d2", "c2"))
+ self._assert_heads(("c2", "d1"))
def test_update_no_match(self):
- self.updater.update_to_step(_up(None, "a", True))
- self.updater.heads.add("x")
- assert_raises_message(
- CommandError,
- "Online migration expected to match one row when updating "
- "'x' to 'b' in 'version_table'; 0 found",
- self.updater.update_to_step,
- _up("x", "b"),
- )
+ with self.connection.begin():
+ self.updater.update_to_step(_up(None, "a", True))
+ self.updater.heads.add("x")
+ assert_raises_message(
+ CommandError,
+ "Online migration expected to match one row when updating "
+ "'x' to 'b' in 'version_table'; 0 found",
+ self.updater.update_to_step,
+ _up("x", "b"),
+ )
def test_update_no_match_no_sane_rowcount(self):
- self.updater.update_to_step(_up(None, "a", True))
- self.updater.heads.add("x")
- with mock.patch.object(
- self.connection.dialect, "supports_sane_rowcount", False
- ):
- self.updater.update_to_step(_up("x", "b"))
+ with self.connection.begin():
+ self.updater.update_to_step(_up(None, "a", True))
+ self.updater.heads.add("x")
+ with mock.patch.object(
+ self.connection.dialect, "supports_sane_rowcount", False
+ ):
+ self.updater.update_to_step(_up("x", "b"))
def test_update_multi_match(self):
- self.connection.execute(version_table.insert(), version_num="a")
- self.connection.execute(version_table.insert(), version_num="a")
-
- self.updater.heads.add("a")
- assert_raises_message(
- CommandError,
- "Online migration expected to match one row when updating "
- "'a' to 'b' in 'version_table'; 2 found",
- self.updater.update_to_step,
- _up("a", "b"),
- )
+ with self.connection.begin():
+ self.connection.execute(
+ version_table.insert(), dict(version_num="a")
+ )
+ self.connection.execute(
+ version_table.insert(), dict(version_num="a")
+ )
+
+ self.updater.heads.add("a")
+ assert_raises_message(
+ CommandError,
+ "Online migration expected to match one row when updating "
+ "'a' to 'b' in 'version_table'; 2 found",
+ self.updater.update_to_step,
+ _up("a", "b"),
+ )
def test_update_multi_match_no_sane_rowcount(self):
- self.connection.execute(version_table.insert(), version_num="a")
- self.connection.execute(version_table.insert(), version_num="a")
-
- self.updater.heads.add("a")
- with mock.patch.object(
- self.connection.dialect, "supports_sane_rowcount", False
- ):
- self.updater.update_to_step(_up("a", "b"))
+ with self.connection.begin():
+ self.connection.execute(
+ version_table.insert(), dict(version_num="a")
+ )
+ self.connection.execute(
+ version_table.insert(), dict(version_num="a")
+ )
+
+ self.updater.heads.add("a")
+ with mock.patch.object(
+ self.connection.dialect, "supports_sane_rowcount", False
+ ):
+ self.updater.update_to_step(_up("a", "b"))
def test_delete_no_match(self):
- self.updater.update_to_step(_up(None, "a", True))
-
- self.updater.heads.add("x")
- assert_raises_message(
- CommandError,
- "Online migration expected to match one row when "
- "deleting 'x' in 'version_table'; 0 found",
- self.updater.update_to_step,
- _down("x", None, True),
- )
+ with self.connection.begin():
+ self.updater.update_to_step(_up(None, "a", True))
+
+ self.updater.heads.add("x")
+ assert_raises_message(
+ CommandError,
+ "Online migration expected to match one row when "
+ "deleting 'x' in 'version_table'; 0 found",
+ self.updater.update_to_step,
+ _down("x", None, True),
+ )
def test_delete_no_matchno_sane_rowcount(self):
- self.updater.update_to_step(_up(None, "a", True))
+ with self.connection.begin():
+ self.updater.update_to_step(_up(None, "a", True))
- self.updater.heads.add("x")
- with mock.patch.object(
- self.connection.dialect, "supports_sane_rowcount", False
- ):
- self.updater.update_to_step(_down("x", None, True))
+ self.updater.heads.add("x")
+ with mock.patch.object(
+ self.connection.dialect, "supports_sane_rowcount", False
+ ):
+ self.updater.update_to_step(_down("x", None, True))
def test_delete_multi_match(self):
- self.connection.execute(version_table.insert(), version_num="a")
- self.connection.execute(version_table.insert(), version_num="a")
-
- self.updater.heads.add("a")
- assert_raises_message(
- CommandError,
- "Online migration expected to match one row when "
- "deleting 'a' in 'version_table'; 2 found",
- self.updater.update_to_step,
- _down("a", None, True),
- )
+ with self.connection.begin():
+ self.connection.execute(
+ version_table.insert(), dict(version_num="a")
+ )
+ self.connection.execute(
+ version_table.insert(), dict(version_num="a")
+ )
+
+ self.updater.heads.add("a")
+ assert_raises_message(
+ CommandError,
+ "Online migration expected to match one row when "
+ "deleting 'a' in 'version_table'; 2 found",
+ self.updater.update_to_step,
+ _down("a", None, True),
+ )
def test_delete_multi_match_no_sane_rowcount(self):
- self.connection.execute(version_table.insert(), version_num="a")
- self.connection.execute(version_table.insert(), version_num="a")
-
- self.updater.heads.add("a")
- with mock.patch.object(
- self.connection.dialect, "supports_sane_rowcount", False
- ):
- self.updater.update_to_step(_down("a", None, True))
+ with self.connection.begin():
+ self.connection.execute(
+ version_table.insert(), dict(version_num="a")
+ )
+ self.connection.execute(
+ version_table.insert(), dict(version_num="a")
+ )
+
+ self.updater.heads.add("a")
+ with mock.patch.object(
+ self.connection.dialect, "supports_sane_rowcount", False
+ ):
+ self.updater.update_to_step(_down("a", None, True))
diff --git a/tox.ini b/tox.ini
index 52f6842..a332635 100644
--- a/tox.ini
+++ b/tox.ini
@@ -1,6 +1,6 @@
[tox]
-envlist = py
+envlist = py-sqlalchemy
SQLA_REPO = {env:SQLA_REPO:git+https://github.com/sqlalchemy/sqlalchemy.git}
@@ -10,8 +10,8 @@ cov_args=--cov=alembic --cov-report term --cov-report xml
deps=pytest>4.6
pytest-xdist
mock
- sqla13: {[tox]SQLA_REPO}@rel_1_3
- sqlamaster: {[tox]SQLA_REPO}@master
+ sqla13: {[tox]SQLA_REPO}@rel_1_3#egg=sqlalchemy
+ sqlamaster: {[tox]SQLA_REPO}@master#egg=sqlalchemy
postgresql: psycopg2
mysql: mysqlclient
mysql: pymysql
@@ -19,6 +19,10 @@ deps=pytest>4.6
oracle: cx_oracle>=7;python_version>="3"
mssql: pymssql
cov: pytest-cov
+ sqlalchemy: sqlalchemy>=1.3.0
+ mako
+ python-editor>=0.3
+ python-dateutil
@@ -39,6 +43,9 @@ setenv=
mssql: MSSQL={env:TOX_MSSQL:--db pymssql}
pyoptimize: PYTHONOPTIMIZE=1
pyoptimize: LIMITTESTS="tests/test_script_consumption.py"
+ future: SQLALCHEMY_TESTING_FUTURE_ENGINE=1
+ SQLALCHEMY_WARN_20=1
+
# tox as of 2.0 blocks all environment variables from the
# outside, unless they are here (or in TOX_TESTENV_PASSENV,
@@ -64,4 +71,4 @@ deps=
black==20.8b1
commands =
flake8 ./alembic/ ./tests/ setup.py docs/build/conf.py {posargs}
- black --check .
+ black --check setup.py tests alembic