diff options
42 files changed, 1213 insertions, 1157 deletions
diff --git a/alembic/autogenerate/api.py b/alembic/autogenerate/api.py index 5d1e848..db5fe12 100644 --- a/alembic/autogenerate/api.py +++ b/alembic/autogenerate/api.py @@ -31,22 +31,23 @@ def compare_metadata(context, metadata): from sqlalchemy.schema import SchemaItem from sqlalchemy.types import TypeEngine from sqlalchemy import (create_engine, MetaData, Column, - Integer, String, Table) + Integer, String, Table, text) import pprint engine = create_engine("sqlite://") - engine.execute(''' - create table foo ( - id integer not null primary key, - old_data varchar, - x integer - )''') - - engine.execute(''' - create table bar ( - data varchar - )''') + with engine.begin() as conn: + conn.execute(text(''' + create table foo ( + id integer not null primary key, + old_data varchar, + x integer + )''')) + + conn.execute(text(''' + create table bar ( + data varchar + )''')) metadata = MetaData() Table('foo', metadata, diff --git a/alembic/ddl/impl.py b/alembic/ddl/impl.py index 3674c67..923fd8b 100644 --- a/alembic/ddl/impl.py +++ b/alembic/ddl/impl.py @@ -140,7 +140,10 @@ class DefaultImpl(with_metaclass(ImplMeta)): conn = self.connection if execution_options: conn = conn.execution_options(**execution_options) - return conn.execute(construct, *multiparams, **params) + if params: + multiparams += (params,) + + return conn.execute(construct, multiparams) def execute(self, sql, execution_options=None): self._exec(sql, execution_options) @@ -316,7 +319,7 @@ class DefaultImpl(with_metaclass(ImplMeta)): if self.as_sql: for row in rows: self._exec( - table.insert(inline=True).values( + sqla_compat._insert_inline(table).values( **dict( ( k, @@ -338,10 +341,14 @@ class DefaultImpl(with_metaclass(ImplMeta)): table._autoincrement_column = None if rows: if multiinsert: - self._exec(table.insert(inline=True), multiparams=rows) + self._exec( + sqla_compat._insert_inline(table), multiparams=rows + ) else: for row in rows: - self._exec(table.insert(inline=True).values(**row)) + self._exec( + sqla_compat._insert_inline(table).values(**row) + ) def _tokenize_column_type(self, column): definition = self.dialect.type_compiler.process(column.type).lower() diff --git a/alembic/operations/batch.py b/alembic/operations/batch.py index 81e3b99..f029193 100644 --- a/alembic/operations/batch.py +++ b/alembic/operations/batch.py @@ -5,7 +5,6 @@ from sqlalchemy import Index from sqlalchemy import MetaData from sqlalchemy import PrimaryKeyConstraint from sqlalchemy import schema as sql_schema -from sqlalchemy import select from sqlalchemy import Table from sqlalchemy import types as sqltypes from sqlalchemy.events import SchemaEventTarget @@ -14,9 +13,12 @@ from sqlalchemy.util import topological from ..util import exc from ..util.sqla_compat import _columns_for_constraint +from ..util.sqla_compat import _ensure_scope_for_ddl from ..util.sqla_compat import _fk_is_self_referential +from ..util.sqla_compat import _insert_inline from ..util.sqla_compat import _is_type_bound from ..util.sqla_compat import _remove_column_from_collection +from ..util.sqla_compat import _select class BatchOperationsImpl(object): @@ -76,44 +78,44 @@ class BatchOperationsImpl(object): def flush(self): should_recreate = self._should_recreate() - if not should_recreate: - for opname, arg, kw in self.batch: - fn = getattr(self.operations.impl, opname) - fn(*arg, **kw) - else: - if self.naming_convention: - m1 = MetaData(naming_convention=self.naming_convention) + with _ensure_scope_for_ddl(self.impl.connection): + if not should_recreate: + for opname, arg, kw in self.batch: + fn = getattr(self.operations.impl, opname) + fn(*arg, **kw) else: - m1 = MetaData() + if self.naming_convention: + m1 = MetaData(naming_convention=self.naming_convention) + else: + m1 = MetaData() - if self.copy_from is not None: - existing_table = self.copy_from - reflected = False - else: - existing_table = Table( - self.table_name, - m1, - schema=self.schema, - autoload=True, - autoload_with=self.operations.get_bind(), - *self.reflect_args, - **self.reflect_kwargs + if self.copy_from is not None: + existing_table = self.copy_from + reflected = False + else: + existing_table = Table( + self.table_name, + m1, + schema=self.schema, + autoload_with=self.operations.get_bind(), + *self.reflect_args, + **self.reflect_kwargs + ) + reflected = True + + batch_impl = ApplyBatchImpl( + self.impl, + existing_table, + self.table_args, + self.table_kwargs, + reflected, + partial_reordering=self.partial_reordering, ) - reflected = True - - batch_impl = ApplyBatchImpl( - self.impl, - existing_table, - self.table_args, - self.table_kwargs, - reflected, - partial_reordering=self.partial_reordering, - ) - for opname, arg, kw in self.batch: - fn = getattr(batch_impl, opname) - fn(*arg, **kw) + for opname, arg, kw in self.batch: + fn = getattr(batch_impl, opname) + fn(*arg, **kw) - batch_impl._create(self.impl) + batch_impl._create(self.impl) def alter_column(self, *arg, **kw): self.batch.append(("alter_column", arg, kw)) @@ -362,14 +364,14 @@ class ApplyBatchImpl(object): try: op_impl._exec( - self.new_table.insert(inline=True).from_select( + _insert_inline(self.new_table).from_select( list( k for k, transfer in self.column_transfers.items() if "expr" in transfer ), - select( - [ + _select( + *[ transfer["expr"] for transfer in self.column_transfers.values() if "expr" in transfer diff --git a/alembic/operations/schemaobj.py b/alembic/operations/schemaobj.py index d90b5e6..5e8aa4f 100644 --- a/alembic/operations/schemaobj.py +++ b/alembic/operations/schemaobj.py @@ -3,6 +3,7 @@ from sqlalchemy.types import Integer from sqlalchemy.types import NULLTYPE from .. import util +from ..util.compat import raise_ from ..util.compat import string_types @@ -113,10 +114,13 @@ class SchemaObjects(object): } try: const = types[type_] - except KeyError: - raise TypeError( - "'type' can be one of %s" - % ", ".join(sorted(repr(x) for x in types)) + except KeyError as ke: + raise_( + TypeError( + "'type' can be one of %s" + % ", ".join(sorted(repr(x) for x in types)) + ), + from_=ke, ) else: const = const(name=name) diff --git a/alembic/runtime/migration.py b/alembic/runtime/migration.py index 5c8590d..48bb842 100644 --- a/alembic/runtime/migration.py +++ b/alembic/runtime/migration.py @@ -8,7 +8,7 @@ from sqlalchemy import MetaData from sqlalchemy import PrimaryKeyConstraint from sqlalchemy import String from sqlalchemy import Table -from sqlalchemy.engine import Connection +from sqlalchemy.engine import Engine from sqlalchemy.engine import url as sqla_url from sqlalchemy.engine.strategies import MockEngineStrategy @@ -31,15 +31,18 @@ class _ProxyTransaction(object): def rollback(self): self._proxied_transaction.rollback() + self.migration_context._transaction = None def commit(self): self._proxied_transaction.commit() + self.migration_context._transaction = None def __enter__(self): return self def __exit__(self, type_, value, traceback): self._proxied_transaction.__exit__(type_, value, traceback) + self.migration_context._transaction = None class MigrationContext(object): @@ -105,8 +108,13 @@ class MigrationContext(object): if as_sql: self.connection = self._stdout_connection(connection) assert self.connection is not None + self._in_external_transaction = False else: self.connection = connection + self._in_external_transaction = ( + sqla_compat._get_connection_in_transaction(connection) + ) + self._migrations_fn = opts.get("fn") self.as_sql = as_sql @@ -199,12 +207,11 @@ class MigrationContext(object): dialect_opts = {} if connection: - if not isinstance(connection, Connection): - util.warn( + if isinstance(connection, Engine): + raise util.CommandError( "'connection' argument to configure() is expected " "to be a sqlalchemy.engine.Connection instance, " "got %r" % connection, - stacklevel=3, ) dialect = connection.dialect @@ -268,19 +275,27 @@ class MigrationContext(object): """ _in_connection_transaction = self._in_connection_transaction() - if self.impl.transactional_ddl: - if self.as_sql: - self.impl.emit_commit() + if self.impl.transactional_ddl and self.as_sql: + self.impl.emit_commit() - elif _in_connection_transaction: - assert self._transaction is not None + elif _in_connection_transaction: + assert self._transaction is not None - self._transaction.commit() - self._transaction = None + self._transaction.commit() + self._transaction = None if not self.as_sql: current_level = self.connection.get_isolation_level() - self.connection.execution_options(isolation_level="AUTOCOMMIT") + base_connection = self.connection + + # in 1.3 and 1.4 non-future mode, the connection gets switched + # out. we can use the base connection with the new mode + # except that it will not know it's in "autocommit" and will + # emit deprecation warnings when an autocommit action takes + # place. + self.connection = ( + self.impl.connection + ) = base_connection.execution_options(isolation_level="AUTOCOMMIT") try: yield finally: @@ -288,13 +303,13 @@ class MigrationContext(object): self.connection.execution_options( isolation_level=current_level ) + self.connection = self.impl.connection = base_connection - if self.impl.transactional_ddl: - if self.as_sql: - self.impl.emit_begin() + if self.impl.transactional_ddl and self.as_sql: + self.impl.emit_begin() - elif _in_connection_transaction: - self._transaction = self.bind.begin() + elif _in_connection_transaction: + self._transaction = self.connection.begin() def begin_transaction(self, _per_migration=False): """Begin a logical transaction for migration operations. @@ -337,23 +352,50 @@ class MigrationContext(object): :meth:`.MigrationContext.autocommit_block` """ - transaction_now = _per_migration == self._transaction_per_migration - if not transaction_now: + @contextmanager + def do_nothing(): + yield - @contextmanager - def do_nothing(): - yield + if self._in_external_transaction: + return do_nothing() + if self.impl.transactional_ddl: + transaction_now = _per_migration == self._transaction_per_migration + else: + transaction_now = _per_migration is True + + if not transaction_now: return do_nothing() elif not self.impl.transactional_ddl: + assert _per_migration - @contextmanager - def do_nothing(): - yield - - return do_nothing() + if self.as_sql: + return do_nothing() + else: + # track our own notion of a "transaction block", which must be + # committed when complete. Don't rely upon whether or not the + # SQLAlchemy connection reports as "in transaction"; this + # because SQLAlchemy future connection features autobegin + # behavior, so it may already be in a transaction from our + # emitting of queries like "has_version_table", etc. While we + # could track these operations as well, that leaves open the + # possibility of new operations or other things happening in + # the user environment that still may be triggering + # "autobegin". + + in_transaction = self._transaction is not None + + if in_transaction: + return do_nothing() + else: + self._transaction = ( + sqla_compat._safe_begin_connection_transaction( + self.connection + ) + ) + return _ProxyTransaction(self) elif self.as_sql: @contextmanager @@ -364,7 +406,9 @@ class MigrationContext(object): return begin_commit() else: - self._transaction = self.bind.begin() + self._transaction = sqla_compat._safe_begin_connection_transaction( + self.connection + ) return _ProxyTransaction(self) def get_current_revision(self): @@ -439,9 +483,10 @@ class MigrationContext(object): ) def _ensure_version_table(self, purge=False): - self._version.create(self.connection, checkfirst=True) - if purge: - self.connection.execute(self._version.delete()) + with sqla_compat._ensure_scope_for_ddl(self.connection): + self._version.create(self.connection, checkfirst=True) + if purge: + self.connection.execute(self._version.delete()) def _has_version_table(self): return sqla_compat._connectable_has_table( @@ -504,12 +549,9 @@ class MigrationContext(object): head_maintainer = HeadMaintainer(self, heads) - starting_in_transaction = ( - not self.as_sql and self._in_connection_transaction() - ) - for step in self._migrations_fn(heads, self): with self.begin_transaction(_per_migration=True): + if self.as_sql and not head_maintainer.heads: # for offline mode, include a CREATE TABLE from # the base @@ -535,18 +577,6 @@ class MigrationContext(object): run_args=kw, ) - if ( - not starting_in_transaction - and not self.as_sql - and not self.impl.transactional_ddl - and self._in_connection_transaction() - ): - raise util.CommandError( - 'Migration "%s" has left an uncommitted ' - "transaction opened; transactional_ddl is False so " - "Alembic is not committing transactions" % step - ) - if self.as_sql and not head_maintainer.heads: self._version.drop(self.connection) diff --git a/alembic/script/base.py b/alembic/script/base.py index fea9e87..363895c 100644 --- a/alembic/script/base.py +++ b/alembic/script/base.py @@ -171,7 +171,7 @@ class ScriptDirectory(object): "ancestor/descendant revisions along the same branch" ) ancestor = ancestor % {"start": start, "end": end} - compat.raise_from_cause(util.CommandError(ancestor)) + compat.raise_(util.CommandError(ancestor), from_=rna) except revision.MultipleHeads as mh: if not multiple_heads: multiple_heads = ( @@ -185,15 +185,15 @@ class ScriptDirectory(object): "head_arg": end or mh.argument, "heads": util.format_as_comma(mh.heads), } - compat.raise_from_cause(util.CommandError(multiple_heads)) + compat.raise_(util.CommandError(multiple_heads), from_=mh) except revision.ResolutionError as re: if resolution is None: resolution = "Can't locate revision identified by '%s'" % ( re.argument ) - compat.raise_from_cause(util.CommandError(resolution)) + compat.raise_(util.CommandError(resolution), from_=re) except revision.RevisionError as err: - compat.raise_from_cause(util.CommandError(err.args[0])) + compat.raise_(util.CommandError(err.args[0]), from_=err) def walk_revisions(self, base="base", head="heads"): """Iterate through all revisions. @@ -571,7 +571,7 @@ class ScriptDirectory(object): try: Script.verify_rev_id(revid) except revision.RevisionError as err: - compat.raise_from_cause(util.CommandError(err.args[0])) + compat.raise_(util.CommandError(err.args[0]), from_=err) with self._catch_revision_errors( multiple_heads=( @@ -659,7 +659,7 @@ class ScriptDirectory(object): try: script = Script._from_path(self, path) except revision.RevisionError as err: - compat.raise_from_cause(util.CommandError(err.args[0])) + compat.raise_(util.CommandError(err.args[0]), from_=err) if branch_labels and not script.branch_labels: raise util.CommandError( "Version %s specified branch_labels %s, however the " diff --git a/alembic/script/revision.py b/alembic/script/revision.py index 683d322..c75d1c0 100644 --- a/alembic/script/revision.py +++ b/alembic/script/revision.py @@ -422,9 +422,12 @@ class RevisionMap(object): except KeyError: try: nonbranch_rev = self._revision_for_ident(branch_label) - except ResolutionError: - raise ResolutionError( - "No such branch: '%s'" % branch_label, branch_label + except ResolutionError as re: + util.raise_( + ResolutionError( + "No such branch: '%s'" % branch_label, branch_label + ), + from_=re, ) else: return nonbranch_rev diff --git a/alembic/script/write_hooks.py b/alembic/script/write_hooks.py index 7d0843b..d6d1d38 100644 --- a/alembic/script/write_hooks.py +++ b/alembic/script/write_hooks.py @@ -39,9 +39,10 @@ def _invoke(name, revision, options): """ try: hook = _registry[name] - except KeyError: - compat.raise_from_cause( - util.CommandError("No formatter with name '%s' registered" % name) + except KeyError as ke: + compat.raise_( + util.CommandError("No formatter with name '%s' registered" % name), + from_=ke, ) else: return hook(revision, options) @@ -65,12 +66,13 @@ def _run_hooks(path, hook_config): opts["_hook_name"] = name try: type_ = opts["type"] - except KeyError: - compat.raise_from_cause( + except KeyError as ke: + compat.raise_( util.CommandError( "Key %s.type is required for post write hook %r" % (name, name) - ) + ), + from_=ke, ) else: util.status( @@ -89,12 +91,13 @@ def console_scripts(path, options): try: entrypoint_name = options["entrypoint"] - except KeyError: - compat.raise_from_cause( + except KeyError as ke: + compat.raise_( util.CommandError( "Key %s.entrypoint is required for post write hook %r" % (options["_hook_name"], options["_hook_name"]) - ) + ), + from_=ke, ) iter_ = pkg_resources.iter_entry_points("console_scripts", entrypoint_name) impl = next(iter_) diff --git a/alembic/testing/__init__.py b/alembic/testing/__init__.py index 23c0f19..5f497a6 100644 --- a/alembic/testing/__init__.py +++ b/alembic/testing/__init__.py @@ -1,10 +1,12 @@ from sqlalchemy.testing import config # noqa -from sqlalchemy.testing import exclusions # noqa from sqlalchemy.testing import emits_warning # noqa from sqlalchemy.testing import engines # noqa +from sqlalchemy.testing import exclusions # noqa from sqlalchemy.testing import mock # noqa from sqlalchemy.testing import provide_metadata # noqa from sqlalchemy.testing import uses_deprecated # noqa +from sqlalchemy.testing.config import combinations # noqa +from sqlalchemy.testing.config import fixture # noqa from sqlalchemy.testing.config import requirements as requires # noqa from alembic import util # noqa @@ -13,12 +15,14 @@ from .assertions import assert_raises_message # noqa from .assertions import emits_python_deprecation_warning # noqa from .assertions import eq_ # noqa from .assertions import eq_ignore_whitespace # noqa +from .assertions import expect_raises # noqa +from .assertions import expect_raises_message # noqa +from .assertions import expect_sqlalchemy_deprecated # noqa +from .assertions import expect_sqlalchemy_deprecated_20 # noqa from .assertions import is_ # noqa from .assertions import is_false # noqa from .assertions import is_not_ # noqa from .assertions import is_true # noqa from .assertions import ne_ # noqa -from .fixture_functions import combinations # noqa -from .fixture_functions import fixture # noqa from .fixtures import TestBase # noqa from .util import resolve_lambda # noqa diff --git a/alembic/testing/assertions.py b/alembic/testing/assertions.py index b09e09f..6d39f4c 100644 --- a/alembic/testing/assertions.py +++ b/alembic/testing/assertions.py @@ -1,7 +1,10 @@ from __future__ import absolute_import +import contextlib import re +import sys +from sqlalchemy import exc as sa_exc from sqlalchemy import util from sqlalchemy.engine import default from sqlalchemy.testing.assertions import _expect_warnings @@ -17,27 +20,92 @@ from ..util import sqla_compat from ..util.compat import py3k +def _assert_proper_exception_context(exception): + """assert that any exception we're catching does not have a __context__ + without a __cause__, and that __suppress_context__ is never set. + + Python 3 will report nested as exceptions as "during the handling of + error X, error Y occurred". That's not what we want to do. we want + these exceptions in a cause chain. + + """ + + if not util.py3k: + return + + if ( + exception.__context__ is not exception.__cause__ + and not exception.__suppress_context__ + ): + assert False, ( + "Exception %r was correctly raised but did not set a cause, " + "within context %r as its cause." + % (exception, exception.__context__) + ) + + def assert_raises(except_cls, callable_, *args, **kw): + return _assert_raises(except_cls, callable_, args, kw, check_context=True) + + +def assert_raises_context_ok(except_cls, callable_, *args, **kw): + return _assert_raises(except_cls, callable_, args, kw) + + +def assert_raises_message(except_cls, msg, callable_, *args, **kwargs): + return _assert_raises( + except_cls, callable_, args, kwargs, msg=msg, check_context=True + ) + + +def assert_raises_message_context_ok( + except_cls, msg, callable_, *args, **kwargs +): + return _assert_raises(except_cls, callable_, args, kwargs, msg=msg) + + +def _assert_raises( + except_cls, callable_, args, kwargs, msg=None, check_context=False +): + + with _expect_raises(except_cls, msg, check_context) as ec: + callable_(*args, **kwargs) + return ec.error + + +class _ErrorContainer(object): + error = None + + +@contextlib.contextmanager +def _expect_raises(except_cls, msg=None, check_context=False): + ec = _ErrorContainer() + if check_context: + are_we_already_in_a_traceback = sys.exc_info()[0] try: - callable_(*args, **kw) + yield ec success = False - except except_cls: + except except_cls as err: + ec.error = err success = True + if msg is not None: + assert re.search( + msg, util.text_type(err), re.UNICODE + ), "%r !~ %s" % (msg, err) + if check_context and not are_we_already_in_a_traceback: + _assert_proper_exception_context(err) + print(util.text_type(err).encode("utf-8")) # assert outside the block so it works for AssertionError too ! assert success, "Callable did not raise an exception" -def assert_raises_message(except_cls, msg, callable_, *args, **kwargs): - try: - callable_(*args, **kwargs) - assert False, "Callable did not raise an exception" - except except_cls as e: - assert re.search(msg, util.text_type(e), re.UNICODE), "%r !~ %s" % ( - msg, - e, - ) - print(util.text_type(e).encode("utf-8")) +def expect_raises(except_cls, check_context=True): + return _expect_raises(except_cls, check_context=check_context) + + +def expect_raises_message(except_cls, msg, check_context=True): + return _expect_raises(except_cls, msg=msg, check_context=check_context) def eq_ignore_whitespace(a, b, msg=None): @@ -106,3 +174,11 @@ def emits_python_deprecation_warning(*messages): return fn(*args, **kw) return decorate + + +def expect_sqlalchemy_deprecated(*messages, **kw): + return _expect_warnings(sa_exc.SADeprecationWarning, messages, **kw) + + +def expect_sqlalchemy_deprecated_20(*messages, **kw): + return _expect_warnings(sa_exc.RemovedIn20Warning, messages, **kw) diff --git a/alembic/testing/env.py b/alembic/testing/env.py index 473c73e..62b74ec 100644 --- a/alembic/testing/env.py +++ b/alembic/testing/env.py @@ -4,9 +4,10 @@ import os import shutil import textwrap -from sqlalchemy.testing import engines +from sqlalchemy.testing import config from sqlalchemy.testing import provision +from . import util as testing_util from .. import util from ..script import Script from ..script import ScriptDirectory @@ -93,25 +94,28 @@ config = context.config f.write(txt) -def _sqlite_file_db(tempname="foo.db"): +def _sqlite_file_db(tempname="foo.db", future=False): dir_ = os.path.join(_get_staging_directory(), "scripts") url = "sqlite:///%s/%s" % (dir_, tempname) - return engines.testing_engine(url=url) + return testing_util.testing_engine(url=url, future=future) -def _sqlite_testing_config(sourceless=False): +def _sqlite_testing_config(sourceless=False, future=False): dir_ = os.path.join(_get_staging_directory(), "scripts") url = "sqlite:///%s/foo.db" % dir_ + sqlalchemy_future = future or ("future" in config.db.__class__.__module__) + return _write_config_file( """ [alembic] script_location = %s sqlalchemy.url = %s sourceless = %s +%s [loggers] -keys = root +keys = root,sqlalchemy [handlers] keys = console @@ -121,6 +125,11 @@ level = WARN handlers = console qualname = +[logger_sqlalchemy] +level = DEBUG +handlers = +qualname = sqlalchemy.engine + [handler_console] class = StreamHandler args = (sys.stderr,) @@ -134,12 +143,19 @@ keys = generic format = %%(levelname)-5.5s [%%(name)s] %%(message)s datefmt = %%H:%%M:%%S """ - % (dir_, url, "true" if sourceless else "false") + % ( + dir_, + url, + "true" if sourceless else "false", + "sqlalchemy.future = true" if sqlalchemy_future else "", + ) ) def _multi_dir_testing_config(sourceless=False, extra_version_location=""): dir_ = os.path.join(_get_staging_directory(), "scripts") + sqlalchemy_future = "future" in config.db.__class__.__module__ + url = "sqlite:///%s/foo.db" % dir_ return _write_config_file( @@ -147,6 +163,7 @@ def _multi_dir_testing_config(sourceless=False, extra_version_location=""): [alembic] script_location = %s sqlalchemy.url = %s +sqlalchemy.future = %s sourceless = %s version_locations = %%(here)s/model1/ %%(here)s/model2/ %%(here)s/model3/ %s @@ -177,6 +194,7 @@ datefmt = %%H:%%M:%%S % ( dir_, url, + "true" if sqlalchemy_future else "false", "true" if sourceless else "false", extra_version_location, ) @@ -463,6 +481,8 @@ def _multidb_testing_config(engines): dir_ = os.path.join(_get_staging_directory(), "scripts") + sqlalchemy_future = "future" in config.db.__class__.__module__ + databases = ", ".join(engines.keys()) engines = "\n\n".join( "[%s]\n" "sqlalchemy.url = %s" % (key, value.url) @@ -474,7 +494,7 @@ def _multidb_testing_config(engines): [alembic] script_location = %s sourceless = false - +sqlalchemy.future = %s databases = %s %s @@ -502,5 +522,5 @@ keys = generic format = %%(levelname)-5.5s [%%(name)s] %%(message)s datefmt = %%H:%%M:%%S """ - % (dir_, databases, engines) + % (dir_, "true" if sqlalchemy_future else "false", databases, engines) ) diff --git a/alembic/testing/fixture_functions.py b/alembic/testing/fixture_functions.py deleted file mode 100644 index 2640693..0000000 --- a/alembic/testing/fixture_functions.py +++ /dev/null @@ -1,79 +0,0 @@ -_fixture_functions = None # installed by plugin_base - - -def combinations(*comb, **kw): - r"""Deliver multiple versions of a test based on positional combinations. - - This is a facade over pytest.mark.parametrize. - - - :param \*comb: argument combinations. These are tuples that will be passed - positionally to the decorated function. - - :param argnames: optional list of argument names. These are the names - of the arguments in the test function that correspond to the entries - in each argument tuple. pytest.mark.parametrize requires this, however - the combinations function will derive it automatically if not present - by using ``inspect.getfullargspec(fn).args[1:]``. Note this assumes the - first argument is "self" which is discarded. - - :param id\_: optional id template. This is a string template that - describes how the "id" for each parameter set should be defined, if any. - The number of characters in the template should match the number of - entries in each argument tuple. Each character describes how the - corresponding entry in the argument tuple should be handled, as far as - whether or not it is included in the arguments passed to the function, as - well as if it is included in the tokens used to create the id of the - parameter set. - - If omitted, the argument combinations are passed to parametrize as is. If - passed, each argument combination is turned into a pytest.param() object, - mapping the elements of the argument tuple to produce an id based on a - character value in the same position within the string template using the - following scheme:: - - i - the given argument is a string that is part of the id only, don't - pass it as an argument - - n - the given argument should be passed and it should be added to the - id by calling the .__name__ attribute - - r - the given argument should be passed and it should be added to the - id by calling repr() - - s - the given argument should be passed and it should be added to the - id by calling str() - - a - (argument) the given argument should be passed and it should not - be used to generated the id - - e.g.:: - - @testing.combinations( - (operator.eq, "eq"), - (operator.ne, "ne"), - (operator.gt, "gt"), - (operator.lt, "lt"), - id_="na" - ) - def test_operator(self, opfunc, name): - pass - - The above combination will call ``.__name__`` on the first member of - each tuple and use that as the "id" to pytest.param(). - - - """ - return _fixture_functions.combinations(*comb, **kw) - - -def fixture(*arg, **kw): - return _fixture_functions.fixture(*arg, **kw) - - -def get_current_test_name(): - return _fixture_functions.get_current_test_name() - - -def skip_test(msg): - raise _fixture_functions.skip_test_exception(msg) diff --git a/alembic/testing/fixtures.py b/alembic/testing/fixtures.py index dd1d2f1..d5d45ac 100644 --- a/alembic/testing/fixtures.py +++ b/alembic/testing/fixtures.py @@ -8,11 +8,13 @@ from sqlalchemy import inspect from sqlalchemy import MetaData from sqlalchemy import String from sqlalchemy import Table +from sqlalchemy import testing from sqlalchemy import text from sqlalchemy.testing import config from sqlalchemy.testing import mock from sqlalchemy.testing.assertions import eq_ -from sqlalchemy.testing.fixtures import TestBase # noqa +from sqlalchemy.testing.fixtures import TablesTest as SQLAlchemyTablesTest +from sqlalchemy.testing.fixtures import TestBase as SQLAlchemyTestBase import alembic from .assertions import _get_dialect @@ -20,16 +22,53 @@ from ..environment import EnvironmentContext from ..migration import MigrationContext from ..operations import Operations from ..util import compat +from ..util import sqla_compat from ..util.compat import configparser from ..util.compat import string_types from ..util.compat import text_type from ..util.sqla_compat import create_mock_engine from ..util.sqla_compat import sqla_14 + testing_config = configparser.ConfigParser() testing_config.read(["test.cfg"]) +class TestBase(SQLAlchemyTestBase): + is_sqlalchemy_future = False + + @testing.fixture() + def ops_context(self, migration_context): + with migration_context.begin_transaction(_per_migration=True): + yield Operations(migration_context) + + @testing.fixture + def migration_context(self, connection): + return MigrationContext.configure( + connection, opts=dict(transaction_per_migration=True) + ) + + @testing.fixture + def connection(self): + with config.db.connect() as conn: + yield conn + + +class TablesTest(TestBase, SQLAlchemyTablesTest): + pass + + +if sqla_14: + from sqlalchemy.testing.fixtures import FutureEngineMixin +else: + + class FutureEngineMixin(object): + __requires__ = ("sqlalchemy_14",) + + +FutureEngineMixin.is_sqlalchemy_future = True + + def capture_db(dialect="postgresql://"): buf = [] @@ -205,7 +244,8 @@ class AlterColRoundTripFixture(object): ), "server defaults %r and %r didn't compare as equivalent" % (s1, s2) def tearDown(self): - self.metadata.drop_all(self.conn) + with self.conn.begin(): + self.metadata.drop_all(self.conn) self.conn.close() def _run_alter_col(self, from_, to_, compare=None): @@ -218,26 +258,27 @@ class AlterColRoundTripFixture(object): ) t = Table("x", self.metadata, column) - t.create(self.conn) - insp = inspect(self.conn) - old_col = insp.get_columns("x")[0] - - # TODO: conditional comment support - self.op.alter_column( - "x", - column.name, - existing_type=column.type, - existing_server_default=column.server_default - if column.server_default is not None - else False, - existing_nullable=True if column.nullable else False, - # existing_comment=column.comment, - nullable=to_.get("nullable", None), - # modify_comment=False, - server_default=to_.get("server_default", False), - new_column_name=to_.get("name", None), - type_=to_.get("type", None), - ) + with sqla_compat._ensure_scope_for_ddl(self.conn): + t.create(self.conn) + insp = inspect(self.conn) + old_col = insp.get_columns("x")[0] + + # TODO: conditional comment support + self.op.alter_column( + "x", + column.name, + existing_type=column.type, + existing_server_default=column.server_default + if column.server_default is not None + else False, + existing_nullable=True if column.nullable else False, + # existing_comment=column.comment, + nullable=to_.get("nullable", None), + # modify_comment=False, + server_default=to_.get("server_default", False), + new_column_name=to_.get("name", None), + type_=to_.get("type", None), + ) insp = inspect(self.conn) new_col = insp.get_columns("x")[0] diff --git a/alembic/testing/plugin/bootstrap.py b/alembic/testing/plugin/bootstrap.py index 8200ec1..d4a2c55 100644 --- a/alembic/testing/plugin/bootstrap.py +++ b/alembic/testing/plugin/bootstrap.py @@ -1,35 +1,4 @@ """ Bootstrapper for test framework plugins. -This is vendored from SQLAlchemy so that we can use local overrides -for plugin_base.py and pytestplugin.py. - """ - - -import os -import sys - - -bootstrap_file = locals()["bootstrap_file"] -to_bootstrap = locals()["to_bootstrap"] - - -def load_file_as_module(name): - path = os.path.join(os.path.dirname(bootstrap_file), "%s.py" % name) - if sys.version_info >= (3, 3): - from importlib import machinery - - mod = machinery.SourceFileLoader(name, path).load_module() - else: - import imp - - mod = imp.load_source(name, path) - return mod - - -if to_bootstrap == "pytest": - sys.modules["sqla_plugin_base"] = load_file_as_module("plugin_base") - sys.modules["sqla_pytestplugin"] = load_file_as_module("pytestplugin") -else: - raise Exception("unknown bootstrap: %s" % to_bootstrap) # noqa diff --git a/alembic/testing/plugin/plugin_base.py b/alembic/testing/plugin/plugin_base.py deleted file mode 100644 index 2d5e95a..0000000 --- a/alembic/testing/plugin/plugin_base.py +++ /dev/null @@ -1,125 +0,0 @@ -"""vendored plugin_base functions from the most recent SQLAlchemy versions. - -Alembic tests need to run on older versions of SQLAlchemy that don't -necessarily have all the latest testing fixtures. - -""" -from __future__ import absolute_import - -import abc -import sys - -from sqlalchemy.testing.plugin.plugin_base import * # noqa -from sqlalchemy.testing.plugin.plugin_base import post -from sqlalchemy.testing.plugin.plugin_base import post_begin as sqla_post_begin -from sqlalchemy.testing.plugin.plugin_base import stop_test_class as sqla_stc - -py3k = sys.version_info >= (3, 0) - - -if py3k: - - ABC = abc.ABC -else: - - class ABC(object): - __metaclass__ = abc.ABCMeta - - -def post_begin(): - sqla_post_begin() - - import warnings - - try: - import pytest - except ImportError: - pass - else: - warnings.filterwarnings( - "once", category=pytest.PytestDeprecationWarning - ) - - from sqlalchemy import exc - - if hasattr(exc, "RemovedIn20Warning"): - warnings.filterwarnings( - "error", - category=exc.RemovedIn20Warning, - message=".*Engine.execute", - ) - warnings.filterwarnings( - "error", - category=exc.RemovedIn20Warning, - message=".*Passing a string", - ) - - -# override selected SQLAlchemy pytest hooks with vendored functionality -def stop_test_class(cls): - sqla_stc(cls) - import os - from alembic.testing.env import _get_staging_directory - - assert not os.path.exists(_get_staging_directory()), ( - "staging directory %s was not cleaned up" % _get_staging_directory() - ) - - -def want_class(name, cls): - from sqlalchemy.testing import config - from sqlalchemy.testing import fixtures - - if not issubclass(cls, fixtures.TestBase): - return False - elif name.startswith("_"): - return False - elif ( - config.options.backend_only - and not getattr(cls, "__backend__", False) - and not getattr(cls, "__sparse_backend__", False) - ): - return False - else: - return True - - -@post -def _init_symbols(options, file_config): - from sqlalchemy.testing import config - from alembic.testing import fixture_functions as alembic_config - - config._fixture_functions = ( - alembic_config._fixture_functions - ) = _fixture_fn_class() - - -class FixtureFunctions(ABC): - @abc.abstractmethod - def skip_test_exception(self, *arg, **kw): - raise NotImplementedError() - - @abc.abstractmethod - def combinations(self, *args, **kw): - raise NotImplementedError() - - @abc.abstractmethod - def param_ident(self, *args, **kw): - raise NotImplementedError() - - @abc.abstractmethod - def fixture(self, *arg, **kw): - raise NotImplementedError() - - def get_current_test_name(self): - raise NotImplementedError() - - -_fixture_fn_class = None - - -def set_fixture_functions(fixture_fn_class): - from sqlalchemy.testing.plugin import plugin_base - - global _fixture_fn_class - _fixture_fn_class = plugin_base._fixture_fn_class = fixture_fn_class diff --git a/alembic/testing/plugin/pytestplugin.py b/alembic/testing/plugin/pytestplugin.py deleted file mode 100644 index 6b76a17..0000000 --- a/alembic/testing/plugin/pytestplugin.py +++ /dev/null @@ -1,314 +0,0 @@ -"""vendored pytestplugin functions from the most recent SQLAlchemy versions. - -Alembic tests need to run on older versions of SQLAlchemy that don't -necessarily have all the latest testing fixtures. - -""" -try: - # installed by bootstrap.py - import sqla_plugin_base as plugin_base -except ImportError: - # assume we're a package, use traditional import - from . import plugin_base - -from functools import update_wrapper -import inspect -import itertools -import operator -import os -import re -import sys - -import pytest -from sqlalchemy.testing.plugin.pytestplugin import * # noqa -from sqlalchemy.testing.plugin.pytestplugin import pytest_configure as spc - -py3k = sys.version_info.major >= 3 - -if py3k: - from typing import TYPE_CHECKING -else: - TYPE_CHECKING = False - -if TYPE_CHECKING: - from typing import Sequence - - -# override selected SQLAlchemy pytest hooks with vendored functionality -def pytest_configure(config): - spc(config) - - plugin_base.set_fixture_functions(PytestFixtureFunctions) - - -def pytest_pycollect_makeitem(collector, name, obj): - - if inspect.isclass(obj) and plugin_base.want_class(name, obj): - ctor = getattr(pytest.Class, "from_parent", pytest.Class) - - return [ - ctor(name=parametrize_cls.__name__, parent=collector) - for parametrize_cls in _parametrize_cls(collector.module, obj) - ] - elif ( - inspect.isfunction(obj) - and isinstance(collector, pytest.Instance) - and plugin_base.want_method(collector.cls, obj) - ): - # None means, fall back to default logic, which includes - # method-level parametrize - return None - else: - # empty list means skip this item - return [] - - -_current_class = None - - -def _parametrize_cls(module, cls): - """implement a class-based version of pytest parametrize.""" - - if "_sa_parametrize" not in cls.__dict__: - return [cls] - - _sa_parametrize = cls._sa_parametrize - classes = [] - for full_param_set in itertools.product( - *[params for argname, params in _sa_parametrize] - ): - cls_variables = {} - - for argname, param in zip( - [_sa_param[0] for _sa_param in _sa_parametrize], full_param_set - ): - if not argname: - raise TypeError("need argnames for class-based combinations") - argname_split = re.split(r",\s*", argname) - for arg, val in zip(argname_split, param.values): - cls_variables[arg] = val - parametrized_name = "_".join( - # token is a string, but in py2k py.test is giving us a unicode, - # so call str() on it. - str(re.sub(r"\W", "", token)) - for param in full_param_set - for token in param.id.split("-") - ) - name = "%s_%s" % (cls.__name__, parametrized_name) - newcls = type.__new__(type, name, (cls,), cls_variables) - setattr(module, name, newcls) - classes.append(newcls) - return classes - - -def getargspec(fn): - if sys.version_info.major == 3: - return inspect.getfullargspec(fn) - else: - return inspect.getargspec(fn) - - -def _pytest_fn_decorator(target): - """Port of langhelpers.decorator with pytest-specific tricks.""" - # from sqlalchemy rel_1_3_14 - - from sqlalchemy.util.langhelpers import format_argspec_plus - from sqlalchemy.util.compat import inspect_getfullargspec - - def _exec_code_in_env(code, env, fn_name): - exec(code, env) - return env[fn_name] - - def decorate(fn, add_positional_parameters=()): - - spec = inspect_getfullargspec(fn) - if add_positional_parameters: - spec.args.extend(add_positional_parameters) - - metadata = dict(target="target", fn="__fn", name=fn.__name__) - metadata.update(format_argspec_plus(spec, grouped=False)) - code = ( - """\ -def %(name)s(%(args)s): - return %(target)s(%(fn)s, %(apply_kw)s) -""" - % metadata - ) - decorated = _exec_code_in_env( - code, {"target": target, "__fn": fn}, fn.__name__ - ) - if not add_positional_parameters: - decorated.__defaults__ = getattr(fn, "__func__", fn).__defaults__ - decorated.__wrapped__ = fn - return update_wrapper(decorated, fn) - else: - # this is the pytest hacky part. don't do a full update wrapper - # because pytest is really being sneaky about finding the args - # for the wrapped function - decorated.__module__ = fn.__module__ - decorated.__name__ = fn.__name__ - return decorated - - return decorate - - -class PytestFixtureFunctions(plugin_base.FixtureFunctions): - def skip_test_exception(self, *arg, **kw): - return pytest.skip.Exception(*arg, **kw) - - _combination_id_fns = { - "i": lambda obj: obj, - "r": repr, - "s": str, - "n": operator.attrgetter("__name__"), - } - - def combinations(self, *arg_sets, **kw): - """Facade for pytest.mark.parametrize. - - Automatically derives argument names from the callable which in our - case is always a method on a class with positional arguments. - - ids for parameter sets are derived using an optional template. - - """ - # from sqlalchemy rel_1_3_14 - from alembic.testing import exclusions - - if sys.version_info.major == 3: - if len(arg_sets) == 1 and hasattr(arg_sets[0], "__next__"): - arg_sets = list(arg_sets[0]) - else: - if len(arg_sets) == 1 and hasattr(arg_sets[0], "next"): - arg_sets = list(arg_sets[0]) - - argnames = kw.pop("argnames", None) - - def _filter_exclusions(args): - result = [] - gathered_exclusions = [] - for a in args: - if isinstance(a, exclusions.compound): - gathered_exclusions.append(a) - else: - result.append(a) - - return result, gathered_exclusions - - id_ = kw.pop("id_", None) - - tobuild_pytest_params = [] - has_exclusions = False - if id_: - _combination_id_fns = self._combination_id_fns - - # because itemgetter is not consistent for one argument vs. - # multiple, make it multiple in all cases and use a slice - # to omit the first argument - _arg_getter = operator.itemgetter( - 0, - *[ - idx - for idx, char in enumerate(id_) - if char in ("n", "r", "s", "a") - ] - ) - fns = [ - (operator.itemgetter(idx), _combination_id_fns[char]) - for idx, char in enumerate(id_) - if char in _combination_id_fns - ] - - for arg in arg_sets: - if not isinstance(arg, tuple): - arg = (arg,) - - fn_params, param_exclusions = _filter_exclusions(arg) - - parameters = _arg_getter(fn_params)[1:] - - if param_exclusions: - has_exclusions = True - - tobuild_pytest_params.append( - ( - parameters, - param_exclusions, - "-".join( - comb_fn(getter(arg)) for getter, comb_fn in fns - ), - ) - ) - - else: - - for arg in arg_sets: - if not isinstance(arg, tuple): - arg = (arg,) - - fn_params, param_exclusions = _filter_exclusions(arg) - - if param_exclusions: - has_exclusions = True - - tobuild_pytest_params.append( - (fn_params, param_exclusions, None) - ) - - pytest_params = [] - for parameters, param_exclusions, id_ in tobuild_pytest_params: - if has_exclusions: - parameters += (param_exclusions,) - - param = pytest.param(*parameters, id=id_) - pytest_params.append(param) - - def decorate(fn): - if inspect.isclass(fn): - if has_exclusions: - raise NotImplementedError( - "exclusions not supported for class level combinations" - ) - if "_sa_parametrize" not in fn.__dict__: - fn._sa_parametrize = [] - fn._sa_parametrize.append((argnames, pytest_params)) - return fn - else: - if argnames is None: - _argnames = getargspec(fn).args[1:] # type: Sequence(str) - else: - _argnames = re.split( - r", *", argnames - ) # type: Sequence(str) - - if has_exclusions: - _argnames += ["_exclusions"] - - @_pytest_fn_decorator - def check_exclusions(fn, *args, **kw): - _exclusions = args[-1] - if _exclusions: - exlu = exclusions.compound().add(*_exclusions) - fn = exlu(fn) - return fn(*args[0:-1], **kw) - - def process_metadata(spec): - spec.args.append("_exclusions") - - fn = check_exclusions( - fn, add_positional_parameters=("_exclusions",) - ) - - return pytest.mark.parametrize(_argnames, pytest_params)(fn) - - return decorate - - def param_ident(self, *parameters): - ident = parameters[0] - return pytest.param(*parameters[1:], id=ident) - - def fixture(self, *arg, **kw): - return pytest.fixture(*arg, **kw) - - def get_current_test_name(self): - return os.environ.get("PYTEST_CURRENT_TEST") diff --git a/alembic/testing/requirements.py b/alembic/testing/requirements.py index 456002d..2de8a71 100644 --- a/alembic/testing/requirements.py +++ b/alembic/testing/requirements.py @@ -44,6 +44,15 @@ class SuiteRequirements(Requirements): return exclusions.skip_if(doesnt_have_check_uq_constraints) @property + def sequences(self): + """Target database must support SEQUENCEs.""" + + return exclusions.only_if( + [lambda config: config.db.dialect.supports_sequences], + "no sequence support", + ) + + @property def foreign_key_match(self): return exclusions.open() diff --git a/alembic/testing/util.py b/alembic/testing/util.py index 3e76645..ccabf9c 100644 --- a/alembic/testing/util.py +++ b/alembic/testing/util.py @@ -95,3 +95,13 @@ def metadata_fixture(ddl="function"): return fixture_functions.fixture(scope=ddl)(run_ddl) return decorate + + +def testing_engine(url=None, options=None, future=False): + from sqlalchemy.testing import config + from sqlalchemy.testing.engines import testing_engine + + if not future: + future = getattr(config._current.options, "future_engine", False) + kw = {"future": future} if future else {} + return testing_engine(url, options, **kw) diff --git a/alembic/testing/warnings.py b/alembic/testing/warnings.py new file mode 100644 index 0000000..0182032 --- /dev/null +++ b/alembic/testing/warnings.py @@ -0,0 +1,48 @@ +# testing/warnings.py +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +from __future__ import absolute_import + +import warnings + +from sqlalchemy import exc as sa_exc + + +def setup_filters(): + """Set global warning behavior for the test suite.""" + + warnings.resetwarnings() + + warnings.filterwarnings("error", category=sa_exc.SADeprecationWarning) + warnings.filterwarnings("error", category=sa_exc.SAWarning) + + # some selected deprecations... + warnings.filterwarnings("error", category=DeprecationWarning) + try: + import pytest + except ImportError: + pass + else: + warnings.filterwarnings( + "once", category=pytest.PytestDeprecationWarning + ) + + if hasattr(sa_exc, "RemovedIn20Warning"): + for msg in [ + # + # Core execution - need to remove this after SQLAlchemy + # repairs it in provisioning + # + r"The connection.execute\(\) method in SQLAlchemy 2.0 will accept " + "parameters as a single dictionary or a single sequence of " + "dictionaries only.", + ]: + warnings.filterwarnings( + "ignore", + message=msg, + category=sa_exc.RemovedIn20Warning, + ) diff --git a/alembic/util/__init__.py b/alembic/util/__init__.py index 141ba45..cfdab49 100644 --- a/alembic/util/__init__.py +++ b/alembic/util/__init__.py @@ -1,4 +1,4 @@ -from .compat import raise_from_cause # noqa +from .compat import raise_ # noqa from .exc import CommandError from .langhelpers import _with_legacy_names # noqa from .langhelpers import asbool # noqa diff --git a/alembic/util/compat.py b/alembic/util/compat.py index f5a04ef..c8919b6 100644 --- a/alembic/util/compat.py +++ b/alembic/util/compat.py @@ -261,34 +261,54 @@ def with_metaclass(meta, base=object): if py3k: - def reraise(tp, value, tb=None, cause=None): - if cause is not None: - value.__cause__ = cause - if value.__traceback__ is not tb: - raise value.with_traceback(tb) - raise value + def raise_( + exception, with_traceback=None, replace_context=None, from_=False + ): + r"""implement "raise" with cause support. + + :param exception: exception to raise + :param with_traceback: will call exception.with_traceback() + :param replace_context: an as-yet-unsupported feature. This is + an exception object which we are "replacing", e.g., it's our + "cause" but we don't want it printed. Basically just what + ``__suppress_context__`` does but we don't want to suppress + the enclosing context, if any. So for now we make it the + cause. + :param from\_: the cause. this actually sets the cause and doesn't + hope to hide it someday. - def raise_from_cause(exception, exc_info=None): - if exc_info is None: - exc_info = sys.exc_info() - exc_type, exc_value, exc_tb = exc_info - reraise(type(exception), exception, tb=exc_tb, cause=exc_value) + """ + if with_traceback is not None: + exception = exception.with_traceback(with_traceback) + + if from_ is not False: + exception.__cause__ = from_ + elif replace_context is not None: + # no good solution here, we would like to have the exception + # have only the context of replace_context.__context__ so that the + # intermediary exception does not change, but we can't figure + # that out. + exception.__cause__ = replace_context + + try: + raise exception + finally: + # credit to + # https://cosmicpercolator.com/2016/01/13/exception-leaks-in-python-2-and-3/ + # as the __traceback__ object creates a cycle + del exception, replace_context, from_, with_traceback else: exec( - "def reraise(tp, value, tb=None, cause=None):\n" - " raise tp, value, tb\n" + "def raise_(exception, with_traceback=None, replace_context=None, " + "from_=False):\n" + " if with_traceback:\n" + " raise type(exception), exception, with_traceback\n" + " else:\n" + " raise exception\n" ) - def raise_from_cause(exception, exc_info=None): - # not as nice as that of Py3K, but at least preserves - # the code line where the issue occurred - if exc_info is None: - exc_info = sys.exc_info() - exc_type, exc_value, exc_tb = exc_info - reraise(type(exception), exception, tb=exc_tb) - # produce a wrapper that allows encoded text to stream # into a given buffer, but doesn't close it. diff --git a/alembic/util/langhelpers.py b/alembic/util/langhelpers.py index bb9c8f5..cc07f4b 100644 --- a/alembic/util/langhelpers.py +++ b/alembic/util/langhelpers.py @@ -7,6 +7,7 @@ from .compat import callable from .compat import collections_abc from .compat import exec_ from .compat import inspect_getargspec +from .compat import raise_ from .compat import string_types from .compat import with_metaclass @@ -74,13 +75,16 @@ class ModuleClsProxy(with_metaclass(_ModuleClsMeta)): def _create_method_proxy(cls, name, globals_, locals_): fn = getattr(cls, name) - def _name_error(name): - raise NameError( - "Can't invoke function '%s', as the proxy object has " - "not yet been " - "established for the Alembic '%s' class. " - "Try placing this code inside a callable." - % (name, cls.__name__) + def _name_error(name, from_): + raise_( + NameError( + "Can't invoke function '%s', as the proxy object has " + "not yet been " + "established for the Alembic '%s' class. " + "Try placing this code inside a callable." + % (name, cls.__name__) + ), + from_=from_, ) globals_["_name_error"] = _name_error @@ -142,8 +146,8 @@ class ModuleClsProxy(with_metaclass(_ModuleClsMeta)): %(translate)s try: p = _proxy - except NameError: - _name_error('%(name)s') + except NameError as ne: + _name_error('%(name)s', ne) return _proxy.%(name)s(%(apply_kw)s) e """ diff --git a/alembic/util/pyfiles.py b/alembic/util/pyfiles.py index b65df2c..8dfb9db 100644 --- a/alembic/util/pyfiles.py +++ b/alembic/util/pyfiles.py @@ -10,6 +10,7 @@ from .compat import has_pep3147 from .compat import load_module_py from .compat import load_module_pyc from .compat import py3k +from .compat import raise_ from .exc import CommandError @@ -82,7 +83,7 @@ def edit(path): try: editor.edit(path) except Exception as exc: - raise CommandError("Error executing editor (%s)" % (exc,)) + raise_(CommandError("Error executing editor (%s)" % (exc,)), from_=exc) def load_python_file(dir_, filename): diff --git a/alembic/util/sqla_compat.py b/alembic/util/sqla_compat.py index 159d0f0..29c2519 100644 --- a/alembic/util/sqla_compat.py +++ b/alembic/util/sqla_compat.py @@ -1,3 +1,4 @@ +import contextlib import re from sqlalchemy import __version__ @@ -64,6 +65,46 @@ else: AUTOINCREMENT_DEFAULT = "auto" +@contextlib.contextmanager +def _ensure_scope_for_ddl(connection): + try: + in_transaction = connection.in_transaction + except AttributeError: + # catch for MockConnection + yield + else: + if not in_transaction(): + with connection.begin(): + yield + else: + yield + + +def _safe_begin_connection_transaction(connection): + transaction = _get_connection_transaction(connection) + if transaction: + return transaction + else: + return connection.begin() + + +def _get_connection_in_transaction(connection): + try: + in_transaction = connection.in_transaction + except AttributeError: + # catch for MockConnection + return False + else: + return in_transaction() + + +def _get_connection_transaction(connection): + if sqla_14: + return connection.get_transaction() + else: + return connection._Connection__transaction + + def _create_url(*arg, **kw): if hasattr(url.URL, "create"): return url.URL.create(*arg, **kw) @@ -314,8 +355,16 @@ def _mariadb_normalized_version_info(mysql_dialect): return mysql_dialect._mariadb_normalized_version_info +def _insert_inline(table): + if sqla_14: + return table.insert().inline() + else: + return table.insert(inline=True) + + if sqla_14: from sqlalchemy import create_mock_engine + from sqlalchemy import select as _select else: from sqlalchemy import create_engine @@ -323,3 +372,6 @@ else: return create_engine( "postgresql://", strategy="mock", executor=executor ) + + def _select(*columns): + return sql.select(list(columns)) diff --git a/docs/build/autogenerate.rst b/docs/build/autogenerate.rst index 7072aa3..46dde52 100644 --- a/docs/build/autogenerate.rst +++ b/docs/build/autogenerate.rst @@ -465,7 +465,7 @@ are being used:: Above, ``inspected_column`` is a :class:`sqlalchemy.schema.Column` as returned by -:meth:`sqlalchemy.engine.reflection.Inspector.reflecttable`, whereas +:meth:`sqlalchemy.engine.reflection.Inspector.reflect_table`, whereas ``metadata_column`` is a :class:`sqlalchemy.schema.Column` from the local model environment. A return value of ``None`` indicates that default type comparison to proceed. diff --git a/docs/build/unreleased/autocommit.rst b/docs/build/unreleased/autocommit.rst new file mode 100644 index 0000000..39d6098 --- /dev/null +++ b/docs/build/unreleased/autocommit.rst @@ -0,0 +1,21 @@ +.. change:: + :tags: change, environment + + To accommodate SQLAlchemy 1.4 and 2.0, the migration model now no longer + assumes that the SQLAlchemy Connection will autocommit an individual + operation. This essentially means that for databases that use + non-transactional DDL (pysqlite current driver behavior, MySQL), there is + still a BEGIN/COMMIT block that will surround each individual migration. + Databases that support transactional DDL should continue to have the + same flow, either per migration or per-entire run, depending on the + value of the :paramref:`.Environment.configure.transaction_per_migration` + flag. + + +.. change:: + :tags: change, environment + + It now raises a :class:`.CommandError` if a ``sqlalchemy.engine.Engine`` + is passed to the :meth:`.MigrationContext.configure` method instead of + a ``sqlalchemy.engine.Connection`` object. Previously, this would + be a warning only.
\ No newline at end of file @@ -1,3 +1,66 @@ +[metadata] + +name = alembic + +# version comes from setup.py; setuptools +# can't read the "attr:" here without importing +# until version 47.0.0 which is too recent + + +description = A database migration tool for SQLAlchemy. +long_description = file: README.rst +long_description_content_type = text/x-rst +url=https://alembic.sqlalchemy.org +author = Mike Bayer +author_email = mike_mp@zzzcomputing.com +license = MIT +license_file = LICENSE + + +classifiers = + Development Status :: 5 - Production/Stable + Intended Audience :: Developers + Environment :: Console + License :: OSI Approved :: MIT License + Operating System :: OS Independent + Programming Language :: Python + Programming Language :: Python :: 2 + Programming Language :: Python :: 2.7 + Programming Language :: Python :: 3 + Programming Language :: Python :: 3.6 + Programming Language :: Python :: 3.7 + Programming Language :: Python :: 3.8 + Programming Language :: Python :: 3.9 + Programming Language :: Python :: Implementation :: CPython + Programming Language :: Python :: Implementation :: PyPy + Topic :: Database :: Front-Ends + +[options] +packages = find: +include_package_data = true +zip_safe = false +python_requires = >=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.* +package_dir = + =. + +install_requires = + SQLAlchemy>=1.3.0 + Mako + python-editor>=0.3 + python-dateutil + +[options.packages.find] +exclude = + test* + examples* + +[options.exclude_package_data] +'' = test* + +[options.entry_points] +console_scripts = + alembic = alembic.config:main + [egg_info] tag_build=dev @@ -40,8 +103,9 @@ default=sqlite:///:memory: sqlite=sqlite:///:memory: sqlite_file=sqlite:///querytest.db postgresql=postgresql://scott:tiger@127.0.0.1:5432/test -mysql=mysql://scott:tiger@127.0.0.1:3306/test?charset=utf8 -mssql=mssql+pyodbc://scott:tiger@ms_2008 +mysql=mysql://scott:tiger@127.0.0.1:3306/test?charset=utf8mb4 +mariadb = mariadb://scott:tiger@127.0.0.1:3306/test?charset=utf8mb4 +mssql = mssql+pyodbc://scott:tiger^5HHH@mssql2017:1433/test?driver=ODBC+Driver+13+for+SQL+Server oracle=oracle://scott:tiger@127.0.0.1:1521 oracle8=oracle://scott:tiger@127.0.0.1:1521/?use_ansi=0 @@ -49,7 +113,7 @@ oracle8=oracle://scott:tiger@127.0.0.1:1521/?use_ansi=0 [tool:pytest] -addopts= --tb native -v -r fxX -p no:warnings -p no:logging --maxfail=25 +addopts= --tb native -v -r sfxX -p no:warnings -p no:logging --maxfail=25 python_files=tests/test_*.py @@ -2,7 +2,6 @@ import os import re import sys -from setuptools import find_packages from setuptools import setup from setuptools.command.test import test as TestCommand @@ -16,16 +15,6 @@ VERSION = ( v.close() -readme = os.path.join(os.path.dirname(__file__), "README.rst") - -requires = [ - "SQLAlchemy>=1.3.0", - "Mako", - "python-editor>=0.3", - "python-dateutil", -] - - class UseTox(TestCommand): RED = 31 RESET_SEQ = "\033[0m" @@ -42,40 +31,6 @@ class UseTox(TestCommand): setup( - name="alembic", version=VERSION, - description="A database migration tool for SQLAlchemy.", - long_description=open(readme).read(), - python_requires=( - ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*" - ), - classifiers=[ - "Development Status :: 5 - Production/Stable", - "Environment :: Console", - "License :: OSI Approved :: MIT License", - "Intended Audience :: Developers", - "Programming Language :: Python", - "Programming Language :: Python :: 2", - "Programming Language :: Python :: 2.7", - "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.6", - "Programming Language :: Python :: 3.7", - "Programming Language :: Python :: 3.8", - "Programming Language :: Python :: 3.9", - "Programming Language :: Python :: Implementation :: CPython", - "Programming Language :: Python :: Implementation :: PyPy", - "Topic :: Database :: Front-Ends", - ], - keywords="SQLAlchemy migrations", - author="Mike Bayer", - author_email="mike@zzzcomputing.com", - url="https://alembic.sqlalchemy.org", - project_urls={"Issue Tracker": "https://github.com/sqlalchemy/alembic/"}, - license="MIT", - packages=find_packages(".", exclude=["examples*", "test*"]), - include_package_data=True, cmdclass={"test": UseTox}, - zip_safe=False, - install_requires=requires, - entry_points={"console_scripts": ["alembic = alembic.config:main"]}, ) diff --git a/tests/conftest.py b/tests/conftest.py index a83dff5..325bb45 100755 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -10,6 +10,8 @@ import os import pytest +os.environ["SQLALCHEMY_WARN_20"] = "true" + pytest.register_assert_rewrite("sqlalchemy.testing.assertions") @@ -32,4 +34,12 @@ with open(bootstrap_file) as f: code = compile(f.read(), "bootstrap.py", "exec") to_bootstrap = "pytest" exec(code, globals(), locals()) - from pytestplugin import * # noqa + from sqlalchemy.testing.plugin.pytestplugin import * # noqa + + wrap_pytest_sessionstart = pytest_sessionstart # noqa + + def pytest_sessionstart(session): + wrap_pytest_sessionstart(session) + from alembic.testing import warnings + + warnings.setup_filters() diff --git a/tests/requirements.py b/tests/requirements.py index 830d4de..8c81889 100644 --- a/tests/requirements.py +++ b/tests/requirements.py @@ -271,3 +271,9 @@ class DefaultRequirements(SuiteRequirements): @property def supports_identity_on_null(self): return self.identity_columns + exclusions.only_on(["oracle"]) + + @property + def legacy_engine(self): + return exclusions.only_if( + lambda config: not getattr(config.db, "_is_future", False) + ) diff --git a/tests/test_autogen_indexes.py b/tests/test_autogen_indexes.py index 94546ff..943e61a 100644 --- a/tests/test_autogen_indexes.py +++ b/tests/test_autogen_indexes.py @@ -14,9 +14,9 @@ from sqlalchemy import UniqueConstraint from alembic.testing import assertions from alembic.testing import config -from alembic.testing import engines from alembic.testing import eq_ from alembic.testing import TestBase +from alembic.testing import util from alembic.testing.env import staging_env from alembic.util import sqla_compat from ._autogen_fixtures import AutogenFixtureTest @@ -29,7 +29,7 @@ class NoUqReflection(object): def setUp(self): staging_env() - self.bind = eng = engines.testing_engine() + self.bind = eng = util.testing_engine() def unimpl(*arg, **kw): raise NotImplementedError() @@ -1508,7 +1508,7 @@ class IncludeHooksTest(AutogenFixtureTest, TestBase): class TruncatedIdxTest(AutogenFixtureTest, TestBase): def setUp(self): - self.bind = engines.testing_engine() + self.bind = util.testing_engine() self.bind.dialect.max_identifier_length = 30 def test_idx_matches_long(self): diff --git a/tests/test_batch.py b/tests/test_batch.py index c9785f2..23ab364 100644 --- a/tests/test_batch.py +++ b/tests/test_batch.py @@ -24,7 +24,6 @@ from sqlalchemy.dialects import sqlite as sqlite_dialect from sqlalchemy.schema import CreateIndex from sqlalchemy.schema import CreateTable from sqlalchemy.sql import column -from sqlalchemy.sql import select from sqlalchemy.sql import text from alembic.ddl import sqlite @@ -40,6 +39,7 @@ from alembic.testing import mock from alembic.testing import TestBase from alembic.testing.fixtures import op_fixture from alembic.util import exc as alembic_exc +from alembic.util.sqla_compat import _select from alembic.util.sqla_compat import sqla_14 @@ -851,8 +851,10 @@ class BatchApplyTest(TestBase): class BatchAPITest(TestBase): @contextmanager def _fixture(self, schema=None): + migration_context = mock.Mock( - opts={}, impl=mock.MagicMock(__dialect__="sqlite") + opts={}, + impl=mock.MagicMock(__dialect__="sqlite", connection=object()), ) op = Operations(migration_context) batch = op.batch_alter_table( @@ -1256,90 +1258,105 @@ class BatchRoundTripTest(TestBase): Column("x", Integer), mysql_engine="InnoDB", ) - t1.create(self.conn) + with self.conn.begin(): + t1.create(self.conn) - self.conn.execute( - t1.insert(), - [ - {"id": 1, "data": "d1", "x": 5}, - {"id": 2, "data": "22", "x": 6}, - {"id": 3, "data": "8.5", "x": 7}, - {"id": 4, "data": "9.46", "x": 8}, - {"id": 5, "data": "d5", "x": 9}, - ], - ) + self.conn.execute( + t1.insert(), + [ + {"id": 1, "data": "d1", "x": 5}, + {"id": 2, "data": "22", "x": 6}, + {"id": 3, "data": "8.5", "x": 7}, + {"id": 4, "data": "9.46", "x": 8}, + {"id": 5, "data": "d5", "x": 9}, + ], + ) context = MigrationContext.configure(self.conn) self.op = Operations(context) @contextmanager def _sqlite_referential_integrity(self): - self.conn.execute("PRAGMA foreign_keys=ON") + self.conn.exec_driver_sql("PRAGMA foreign_keys=ON") try: yield finally: - self.conn.execute("PRAGMA foreign_keys=OFF") + self.conn.exec_driver_sql("PRAGMA foreign_keys=OFF") + + # as these tests are typically intentional fails, clean out + # tables left over + m = MetaData() + m.reflect(self.conn) + with self.conn.begin(): + m.drop_all(self.conn) def _no_pk_fixture(self): - nopk = Table( - "nopk", - self.metadata, - Column("a", Integer), - Column("b", Integer), - Column("c", Integer), - mysql_engine="InnoDB", - ) - nopk.create(self.conn) - self.conn.execute( - nopk.insert(), [{"a": 1, "b": 2, "c": 3}, {"a": 2, "b": 4, "c": 5}] - ) - return nopk + with self.conn.begin(): + nopk = Table( + "nopk", + self.metadata, + Column("a", Integer), + Column("b", Integer), + Column("c", Integer), + mysql_engine="InnoDB", + ) + nopk.create(self.conn) + self.conn.execute( + nopk.insert(), + [{"a": 1, "b": 2, "c": 3}, {"a": 2, "b": 4, "c": 5}], + ) + return nopk def _table_w_index_fixture(self): - t = Table( - "t_w_ix", - self.metadata, - Column("id", Integer, primary_key=True), - Column("thing", Integer), - Column("data", String(20)), - ) - Index("ix_thing", t.c.thing) - t.create(self.conn) - return t + with self.conn.begin(): + t = Table( + "t_w_ix", + self.metadata, + Column("id", Integer, primary_key=True), + Column("thing", Integer), + Column("data", String(20)), + ) + Index("ix_thing", t.c.thing) + t.create(self.conn) + return t def _boolean_fixture(self): - t = Table( - "hasbool", - self.metadata, - Column("x", Boolean(create_constraint=True, name="ck1")), - Column("y", Integer), - ) - t.create(self.conn) + with self.conn.begin(): + t = Table( + "hasbool", + self.metadata, + Column("x", Boolean(create_constraint=True, name="ck1")), + Column("y", Integer), + ) + t.create(self.conn) def _timestamp_fixture(self): - t = Table("hasts", self.metadata, Column("x", DateTime())) - t.create(self.conn) - return t + with self.conn.begin(): + t = Table("hasts", self.metadata, Column("x", DateTime())) + t.create(self.conn) + return t def _datetime_server_default_fixture(self): return func.datetime("now", "localtime") def _timestamp_w_expr_default_fixture(self): - t = Table( - "hasts", - self.metadata, - Column( - "x", - DateTime(), - server_default=self._datetime_server_default_fixture(), - nullable=False, - ), - ) - t.create(self.conn) - return t + with self.conn.begin(): + t = Table( + "hasts", + self.metadata, + Column( + "x", + DateTime(), + server_default=self._datetime_server_default_fixture(), + nullable=False, + ), + ) + t.create(self.conn) + return t def _int_to_boolean_fixture(self): - t = Table("hasbool", self.metadata, Column("x", Integer)) - t.create(self.conn) + with self.conn.begin(): + t = Table("hasbool", self.metadata, Column("x", Integer)) + t.create(self.conn) def test_change_type_boolean_to_int(self): self._boolean_fixture() @@ -1365,15 +1382,16 @@ class BatchRoundTripTest(TestBase): import datetime - self.conn.execute( - t.insert(), {"x": datetime.datetime(2012, 5, 18, 15, 32, 5)} - ) + with self.conn.begin(): + self.conn.execute( + t.insert(), {"x": datetime.datetime(2012, 5, 18, 15, 32, 5)} + ) with self.op.batch_alter_table("hasts") as batch_op: batch_op.alter_column("x", type_=DateTime()) eq_( - self.conn.execute(select([t.c.x])).fetchall(), + self.conn.execute(_select(t.c.x)).fetchall(), [(datetime.datetime(2012, 5, 18, 15, 32, 5),)], ) @@ -1388,10 +1406,14 @@ class BatchRoundTripTest(TestBase): server_default=self._datetime_server_default_fixture(), ) - self.conn.execute(t.insert()) - - row = self.conn.execute(select([t.c.x])).fetchone() - assert row["x"] is not None + with self.conn.begin(): + self.conn.execute(t.insert()) + res = self.conn.execute(_select(t.c.x)) + if sqla_14: + assert res.scalar_one_or_none() is not None + else: + row = res.fetchone() + assert row["x"] is not None def test_drop_col_schematype(self): self._boolean_fixture() @@ -1429,19 +1451,18 @@ class BatchRoundTripTest(TestBase): ) def tearDown(self): - self.metadata.drop_all(self.conn) + in_t = getattr(self.conn, "in_transaction", lambda: False) + if in_t(): + self.conn.rollback() + with self.conn.begin(): + self.metadata.drop_all(self.conn) self.conn.close() def _assert_data(self, data, tablename="foo"): - eq_( - [ - dict(row) - for row in self.conn.execute( - text("select * from %s" % tablename) - ) - ], - data, - ) + res = self.conn.execute(text("select * from %s" % tablename)) + if sqla_14: + res = res.mappings() + eq_([dict(row) for row in res], data) def test_ix_existing(self): self._table_w_index_fixture() @@ -1486,8 +1507,9 @@ class BatchRoundTripTest(TestBase): Column("foo_id", Integer, ForeignKey("foo.id")), mysql_engine="InnoDB", ) - bar.create(self.conn) - self.conn.execute(bar.insert(), {"id": 1, "foo_id": 3}) + with self.conn.begin(): + bar.create(self.conn) + self.conn.execute(bar.insert(), {"id": 1, "foo_id": 3}) with self.op.batch_alter_table("foo", recreate=recreate) as batch_op: batch_op.alter_column( @@ -1532,9 +1554,14 @@ class BatchRoundTripTest(TestBase): Column("data", String(50)), mysql_engine="InnoDB", ) - bar.create(self.conn) - self.conn.execute(bar.insert(), {"id": 1, "data": "x", "bar_id": None}) - self.conn.execute(bar.insert(), {"id": 2, "data": "y", "bar_id": 1}) + with self.conn.begin(): + bar.create(self.conn) + self.conn.execute( + bar.insert(), {"id": 1, "data": "x", "bar_id": None} + ) + self.conn.execute( + bar.insert(), {"id": 2, "data": "y", "bar_id": 1} + ) with self.op.batch_alter_table("bar", recreate=recreate) as batch_op: batch_op.alter_column( @@ -1649,8 +1676,9 @@ class BatchRoundTripTest(TestBase): Column("foo_id", Integer, ForeignKey("foo.id")), mysql_engine="InnoDB", ) - bar.create(self.conn) - self.conn.execute(bar.insert(), {"id": 1, "foo_id": 3}) + with self.conn.begin(): + bar.create(self.conn) + self.conn.execute(bar.insert(), {"id": 1, "foo_id": 3}) naming_convention = { "fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s" @@ -1773,9 +1801,10 @@ class BatchRoundTripTest(TestBase): Column("flag", Boolean(create_constraint=True)), mysql_engine="InnoDB", ) - bar.create(self.conn) - self.conn.execute(bar.insert(), {"id": 1, "flag": True}) - self.conn.execute(bar.insert(), {"id": 2, "flag": False}) + with self.conn.begin(): + bar.create(self.conn) + self.conn.execute(bar.insert(), {"id": 1, "flag": True}) + self.conn.execute(bar.insert(), {"id": 2, "flag": False}) with self.op.batch_alter_table("bar") as batch_op: batch_op.alter_column( @@ -1795,15 +1824,16 @@ class BatchRoundTripTest(TestBase): Column("flag", Boolean(create_constraint=False)), mysql_engine="InnoDB", ) - bar.create(self.conn) - self.conn.execute(bar.insert(), {"id": 1, "flag": True}) - self.conn.execute(bar.insert(), {"id": 2, "flag": False}) - self.conn.execute( - # override Boolean type which as of 1.1 coerces numerics - # to 1/0 - text("insert into bar (id, flag) values (:id, :flag)"), - {"id": 3, "flag": 5}, - ) + with self.conn.begin(): + bar.create(self.conn) + self.conn.execute(bar.insert(), {"id": 1, "flag": True}) + self.conn.execute(bar.insert(), {"id": 2, "flag": False}) + self.conn.execute( + # override Boolean type which as of 1.1 coerces numerics + # to 1/0 + text("insert into bar (id, flag) values (:id, :flag)"), + {"id": 3, "flag": 5}, + ) with self.op.batch_alter_table( "bar", @@ -2042,7 +2072,8 @@ class BatchRoundTripPostgresqlTest(BatchRoundTripTest): ), Column("y", Integer), ) - t.create(self.conn) + with self.conn.begin(): + t.create(self.conn) def _datetime_server_default_fixture(self): return func.current_timestamp() diff --git a/tests/test_bulk_insert.py b/tests/test_bulk_insert.py index aedf6e9..09c641a 100644 --- a/tests/test_bulk_insert.py +++ b/tests/test_bulk_insert.py @@ -237,23 +237,28 @@ class RoundTripTest(TestBase): def setUp(self): self.conn = config.db.connect() - self.conn.execute( - text( - """ - create table foo( - id integer primary key, - data varchar(50), - x integer - ) - """ + with self.conn.begin(): + self.conn.execute( + text( + """ + create table foo( + id integer primary key, + data varchar(50), + x integer + ) + """ + ) ) - ) context = MigrationContext.configure(self.conn) self.op = op.Operations(context) self.t1 = table("foo", column("id"), column("data"), column("x")) + self.trans = self.conn.begin() + def tearDown(self): - self.conn.execute(text("drop table foo")) + self.trans.rollback() + with self.conn.begin(): + self.conn.execute(text("drop table foo")) self.conn.close() def test_single_insert_round_trip(self): diff --git a/tests/test_command.py b/tests/test_command.py index a616e9c..8350e82 100644 --- a/tests/test_command.py +++ b/tests/test_command.py @@ -399,7 +399,7 @@ finally: r2 = command.revision(self.cfg) db = _sqlite_file_db() command.upgrade(self.cfg, "head") - with db.connect() as conn: + with db.begin() as conn: conn.execute( text("insert into alembic_version values ('%s')" % r2.revision) ) @@ -681,7 +681,7 @@ class StampMultipleHeadsTest(TestBase, _StampTest): command.stamp(self.cfg, [self.a]) eng = _sqlite_file_db() - with eng.connect() as conn: + with eng.begin() as conn: result = conn.execute( text("update alembic_version set version_num='fake'") ) diff --git a/tests/test_environment.py b/tests/test_environment.py index 4fb6bbe..63de6cd 100644 --- a/tests/test_environment.py +++ b/tests/test_environment.py @@ -1,6 +1,7 @@ #!coding: utf-8 from alembic import command from alembic import testing +from alembic import util from alembic.environment import EnvironmentContext from alembic.migration import MigrationContext from alembic.script import ScriptDirectory @@ -11,7 +12,7 @@ from alembic.testing import is_ from alembic.testing import is_false from alembic.testing import is_true from alembic.testing import mock -from alembic.testing.assertions import expect_warnings +from alembic.testing.assertions import expect_raises_message from alembic.testing.env import _no_sql_testing_config from alembic.testing.env import _sqlite_file_db from alembic.testing.env import clear_staging_env @@ -94,10 +95,11 @@ def upgrade(): command.upgrade(self.cfg, "arev", sql=True) assert "do some SQL thing with a % percent sign %" in buf.getvalue() + @config.requirements.legacy_engine @testing.uses_deprecated( r"The Engine.execute\(\) function/method is considered legacy" ) - def test_warning_on_passing_engine(self): + def test_error_on_passing_engine(self): env = self._fixture() engine = _sqlite_file_db() @@ -131,18 +133,15 @@ def downgrade(): migration_fn(rev, context) return env.script._upgrade_revs(a_rev, rev) - with expect_warnings( + with expect_raises_message( + util.CommandError, r"'connection' argument to configure\(\) is " - r"expected to be a sqlalchemy.engine.Connection " + r"expected to be a sqlalchemy.engine.Connection ", ): env.configure( connection=engine, fn=upgrade, transactional_ddl=False ) - env.run_migrations() - - eq_(migration_fn.mock_calls, [mock.call((), env._migration_context)]) - class MigrationTransactionTest(TestBase): __backend__ = True @@ -238,7 +237,7 @@ class MigrationTransactionTest(TestBase): with context.begin_transaction(): is_false(self.conn.in_transaction()) with context.begin_transaction(_per_migration=True): - is_false(self.conn.in_transaction()) + is_true(self.conn.in_transaction()) is_false(self.conn.in_transaction()) is_false(self.conn.in_transaction()) @@ -264,7 +263,7 @@ class MigrationTransactionTest(TestBase): with context.begin_transaction(): is_false(self.conn.in_transaction()) with context.begin_transaction(_per_migration=True): - is_false(self.conn.in_transaction()) + is_true(self.conn.in_transaction()) is_false(self.conn.in_transaction()) is_false(self.conn.in_transaction()) @@ -334,18 +333,12 @@ class MigrationTransactionTest(TestBase): with context.begin_transaction(): is_false(self.conn.in_transaction()) with context.begin_transaction(_per_migration=True): - if context.impl.transactional_ddl: - is_true(self.conn.in_transaction()) - else: - is_false(self.conn.in_transaction()) + is_true(self.conn.in_transaction()) with context.autocommit_block(): is_false(self.conn.in_transaction()) - if context.impl.transactional_ddl: - is_true(self.conn.in_transaction()) - else: - is_false(self.conn.in_transaction()) + is_true(self.conn.in_transaction()) is_false(self.conn.in_transaction()) is_false(self.conn.in_transaction()) diff --git a/tests/test_impl.py b/tests/test_impl.py new file mode 100644 index 0000000..8a73b87 --- /dev/null +++ b/tests/test_impl.py @@ -0,0 +1,45 @@ +from sqlalchemy import Column +from sqlalchemy import Integer +from sqlalchemy import Table +from sqlalchemy.sql import text + +from alembic import testing +from alembic.testing import eq_ +from alembic.testing.fixtures import FutureEngineMixin +from alembic.testing.fixtures import TablesTest + + +class ImplTest(TablesTest): + __only_on__ = "sqlite" + + @classmethod + def define_tables(cls, metadata): + Table( + "some_table", metadata, Column("x", Integer), Column("y", Integer) + ) + + @testing.fixture + def impl(self, migration_context): + with migration_context.begin_transaction(_per_migration=True): + yield migration_context.impl + + def test_execute_params(self, impl): + result = impl._exec(text("select :my_param"), params={"my_param": 5}) + eq_(result.scalar(), 5) + + def test_execute_multiparams(self, impl): + some_table = self.tables.some_table + impl._exec( + some_table.insert(), + multiparams=[{"x": 1, "y": 2}, {"x": 2, "y": 3}, {"x": 5, "y": 7}], + ) + eq_( + impl._exec( + some_table.select().order_by(some_table.c.x) + ).fetchall(), + [(1, 2), (2, 3), (5, 7)], + ) + + +class FutureImplTest(FutureEngineMixin, ImplTest): + pass diff --git a/tests/test_mysql.py b/tests/test_mysql.py index caef197..ba43e3a 100644 --- a/tests/test_mysql.py +++ b/tests/test_mysql.py @@ -594,10 +594,11 @@ class MySQLDefaultCompareTest(TestBase): clear_staging_env() def setUp(self): - self.metadata = MetaData(self.bind) + self.metadata = MetaData() def tearDown(self): - self.metadata.drop_all() + with config.db.begin() as conn: + self.metadata.drop_all(conn) def _compare_default_roundtrip(self, type_, txt, alternate=None): if alternate: diff --git a/tests/test_postgresql.py b/tests/test_postgresql.py index 08f70d8..10f17d4 100644 --- a/tests/test_postgresql.py +++ b/tests/test_postgresql.py @@ -35,7 +35,6 @@ from alembic.autogenerate.compare import _compare_server_default from alembic.autogenerate.compare import _compare_tables from alembic.autogenerate.compare import _render_server_default_for_compare from alembic.migration import MigrationContext -from alembic.operations import Operations from alembic.operations import ops from alembic.script import ScriptDirectory from alembic.testing import assert_raises_message @@ -50,6 +49,7 @@ from alembic.testing.env import staging_env from alembic.testing.env import write_script from alembic.testing.fixtures import capture_context_buffer from alembic.testing.fixtures import op_fixture +from alembic.testing.fixtures import TablesTest from alembic.testing.fixtures import TestBase from alembic.util import sqla_compat @@ -436,11 +436,12 @@ class PGAutocommitBlockTest(TestBase): with self.conn.begin(): self.conn.execute(text("DROP TYPE mood")) - def test_alter_enum(self): - context = MigrationContext.configure(connection=self.conn) - with context.begin_transaction(_per_migration=True): - with context.autocommit_block(): - context.execute(text("ALTER TYPE mood ADD VALUE 'soso'")) + def test_alter_enum(self, migration_context): + with migration_context.begin_transaction(_per_migration=True): + with migration_context.autocommit_block(): + migration_context.execute( + text("ALTER TYPE mood ADD VALUE 'soso'") + ) class PGOfflineEnumTest(TestBase): @@ -546,58 +547,38 @@ def downgrade(): assert "DROP TYPE pgenum" in buf.getvalue() -class PostgresqlInlineLiteralTest(TestBase): +class PostgresqlInlineLiteralTest(TablesTest): __only_on__ = "postgresql" __backend__ = True @classmethod - def setup_class(cls): - cls.bind = config.db - with config.db.connect() as conn: - conn.execute( - text( - """ - create table tab ( - col varchar(50) - ) - """ - ) - ) - conn.execute( - text( - """ + def define_tables(cls, metadata): + Table("tab", metadata, Column("col", String(50))) + + @classmethod + def insert_data(cls, connection): + connection.execute( + text( + """ insert into tab (col) values ('old data 1'), ('old data 2.1'), ('old data 3') """ - ) ) + ) - @classmethod - def teardown_class(cls): - with cls.bind.connect() as conn: - conn.execute(text("drop table tab")) - - def setUp(self): - self.conn = self.bind.connect() - ctx = MigrationContext.configure(self.conn) - self.op = Operations(ctx) - - def tearDown(self): - self.conn.close() - - def test_inline_percent(self): + def test_inline_percent(self, connection, ops_context): # TODO: here's the issue, you need to escape this. tab = table("tab", column("col")) - self.op.execute( + ops_context.execute( tab.update() - .where(tab.c.col.like(self.op.inline_literal("%.%"))) - .values(col=self.op.inline_literal("new data")), + .where(tab.c.col.like(ops_context.inline_literal("%.%"))) + .values(col=ops_context.inline_literal("new data")), execution_options={"no_parameters": True}, ) eq_( - self.conn.execute( + connection.execute( text("select count(*) from tab where col='new data'") ).scalar(), 1, @@ -618,7 +599,7 @@ class PostgresqlDefaultCompareTest(TestBase): ) def setUp(self): - self.metadata = MetaData(self.bind) + self.metadata = MetaData() self.autogen_context = api.AutogenContext(self.migration_context) @classmethod @@ -626,7 +607,8 @@ class PostgresqlDefaultCompareTest(TestBase): clear_staging_env() def tearDown(self): - self.metadata.drop_all() + with config.db.begin() as conn: + self.metadata.drop_all(conn) def _compare_default_roundtrip( self, type_, orig_default, alternate=None, diff_expected=None diff --git a/tests/test_script_consumption.py b/tests/test_script_consumption.py index 17bf037..e1b094f 100644 --- a/tests/test_script_consumption.py +++ b/tests/test_script_consumption.py @@ -5,12 +5,16 @@ import os import re import textwrap +import sqlalchemy as sa + from alembic import command +from alembic import testing from alembic import util from alembic.environment import EnvironmentContext from alembic.script import Script from alembic.script import ScriptDirectory from alembic.testing import assert_raises_message +from alembic.testing import config from alembic.testing import eq_ from alembic.testing import mock from alembic.testing.env import _no_sql_testing_config @@ -22,31 +26,81 @@ from alembic.testing.env import staging_env from alembic.testing.env import three_rev_fixture from alembic.testing.env import write_script from alembic.testing.fixtures import capture_context_buffer +from alembic.testing.fixtures import FutureEngineMixin from alembic.testing.fixtures import TestBase from alembic.util import compat -class ApplyVersionsFunctionalTest(TestBase): +class PatchEnvironment(object): + @contextmanager + def _patch_environment(self, transactional_ddl, transaction_per_migration): + conf = EnvironmentContext.configure + + conn = [None] + + def configure(*arg, **opt): + opt.update( + transactional_ddl=transactional_ddl, + transaction_per_migration=transaction_per_migration, + ) + conn[0] = opt["connection"] + return conf(*arg, **opt) + + with mock.patch.object(EnvironmentContext, "configure", configure): + yield + + # it's no longer possible for the conn to be in a transaction + # assuming normal env.py as context.begin_transaction() + # will always run a real DB transaction, no longer uses autocommit + # mode + assert not conn[0].in_transaction() + + +@testing.combinations( + ( + False, + True, + ), + ( + True, + False, + ), + ( + True, + True, + ), + argnames="transactional_ddl,transaction_per_migration", + id_="rr", +) +class ApplyVersionsFunctionalTest(PatchEnvironment, TestBase): __only_on__ = "sqlite" sourceless = False + future = False + transactional_ddl = False + transaction_per_migration = True def setUp(self): - self.bind = _sqlite_file_db() + self.bind = _sqlite_file_db(future=self.future) self.env = staging_env(sourceless=self.sourceless) - self.cfg = _sqlite_testing_config(sourceless=self.sourceless) + self.cfg = _sqlite_testing_config( + sourceless=self.sourceless, future=self.future + ) def tearDown(self): clear_staging_env() def test_steps(self): - self._test_001_revisions() - self._test_002_upgrade() - self._test_003_downgrade() - self._test_004_downgrade() - self._test_005_upgrade() - self._test_006_upgrade_again() - self._test_007_stamp_upgrade() + with self._patch_environment( + self.transactional_ddl, self.transaction_per_migration + ): + self._test_001_revisions() + self._test_002_upgrade() + self._test_003_downgrade() + self._test_004_downgrade() + self._test_005_upgrade() + self._test_006_upgrade_again() + self._test_007_stamp_upgrade() def _test_001_revisions(self): self.a = a = util.rev_id() @@ -166,22 +220,39 @@ class ApplyVersionsFunctionalTest(TestBase): assert not db.dialect.has_table(db.connect(), "bat") +# class level combinations can't do the skips for SQLAlchemy 1.3 +# so we have a separate class +@testing.combinations( + ( + False, + True, + ), + ( + True, + False, + ), + ( + True, + True, + ), + argnames="transactional_ddl,transaction_per_migration", + id_="rr", +) +class FutureApplyVersionsTest(FutureEngineMixin, ApplyVersionsFunctionalTest): + future = True + + class SimpleSourcelessApplyVersionsTest(ApplyVersionsFunctionalTest): sourceless = "simple" -class NewFangledSourcelessEnvOnlyApplyVersionsTest( - ApplyVersionsFunctionalTest -): - sourceless = "pep3147_envonly" - - __requires__ = ("pep3147",) - - -class NewFangledSourcelessEverythingApplyVersionsTest( - ApplyVersionsFunctionalTest -): - sourceless = "pep3147_everything" +@testing.combinations( + ("pep3147_envonly",), + ("pep3147_everything",), + argnames="sourceless", + id_="r", +) +class NewFangledSourcelessApplyVersionsTest(ApplyVersionsFunctionalTest): __requires__ = ("pep3147",) @@ -313,13 +384,17 @@ class OfflineTransactionalDDLTest(TestBase): ) -class OnlineTransactionalDDLTest(TestBase): +class OnlineTransactionalDDLTest(PatchEnvironment, TestBase): def tearDown(self): clear_staging_env() - def _opened_transaction_fixture(self): + def _opened_transaction_fixture(self, future=False): self.env = staging_env() - self.cfg = _sqlite_testing_config() + + if future: + self.cfg = _sqlite_testing_config(future=future) + else: + self.cfg = _sqlite_testing_config() script = ScriptDirectory.from_config(self.cfg) a = util.rev_id() @@ -358,6 +433,8 @@ from alembic import op def upgrade(): conn = op.get_bind() + # this should fail for a SQLAlchemy 2.0 connection b.c. there is + # already a transaction. trans = conn.begin() @@ -391,59 +468,89 @@ def downgrade(): ) return a, b, c - @contextmanager - def _patch_environment(self, transactional_ddl, transaction_per_migration): - conf = EnvironmentContext.configure - - def configure(*arg, **opt): - opt.update( - transactional_ddl=transactional_ddl, - transaction_per_migration=transaction_per_migration, - ) - return conf(*arg, **opt) - - with mock.patch.object(EnvironmentContext, "configure", configure): - yield + # these tests might not be supported anymore; the connection is always + # going to be in a transaction now even on 1.3. - def test_raise_when_rev_leaves_open_transaction(self): - a, b, c = self._opened_transaction_fixture() + @testing.combinations((False,), (True, config.requirements.sqlalchemy_14)) + def test_raise_when_rev_leaves_open_transaction(self, future): + a, b, c = self._opened_transaction_fixture(future) with self._patch_environment( transactional_ddl=False, transaction_per_migration=False ): - assert_raises_message( - util.CommandError, - r'Migration "upgrade .*, rev b" has left an uncommitted ' - r"transaction opened; transactional_ddl is False so Alembic " - r"is not committing transactions", - command.upgrade, - self.cfg, - c, - ) + if future: + with testing.expect_raises_message( + sa.exc.InvalidRequestError, + "a transaction is already begun", + ): + command.upgrade(self.cfg, c) + elif config.requirements.sqlalchemy_14.enabled: + if self.is_sqlalchemy_future: + with testing.expect_raises_message( + sa.exc.InvalidRequestError, + r"a transaction is already begun for this connection", + ): + command.upgrade(self.cfg, c) + else: + with testing.expect_sqlalchemy_deprecated_20( + r"Calling .begin\(\) when a transaction " + "is already begun" + ): + command.upgrade(self.cfg, c) + else: + command.upgrade(self.cfg, c) - def test_raise_when_rev_leaves_open_transaction_tpm(self): - a, b, c = self._opened_transaction_fixture() + @testing.combinations((False,), (True, config.requirements.sqlalchemy_14)) + def test_raise_when_rev_leaves_open_transaction_tpm(self, future): + a, b, c = self._opened_transaction_fixture(future) with self._patch_environment( transactional_ddl=False, transaction_per_migration=True ): - assert_raises_message( - util.CommandError, - r'Migration "upgrade .*, rev b" has left an uncommitted ' - r"transaction opened; transactional_ddl is False so Alembic " - r"is not committing transactions", - command.upgrade, - self.cfg, - c, - ) + if future: + with testing.expect_raises_message( + sa.exc.InvalidRequestError, + "a transaction is already begun", + ): + command.upgrade(self.cfg, c) + elif config.requirements.sqlalchemy_14.enabled: + if self.is_sqlalchemy_future: + with testing.expect_raises_message( + sa.exc.InvalidRequestError, + r"a transaction is already begun for this connection", + ): + command.upgrade(self.cfg, c) + else: + with testing.expect_sqlalchemy_deprecated_20( + r"Calling .begin\(\) when a transaction is " + "already begun" + ): + command.upgrade(self.cfg, c) + else: + command.upgrade(self.cfg, c) - def test_noerr_rev_leaves_open_transaction_transactional_ddl(self): + @testing.combinations((False,), (True, config.requirements.sqlalchemy_14)) + def test_noerr_rev_leaves_open_transaction_transactional_ddl(self, future): a, b, c = self._opened_transaction_fixture() with self._patch_environment( transactional_ddl=True, transaction_per_migration=False ): - command.upgrade(self.cfg, c) + if config.requirements.sqlalchemy_14.enabled: + if self.is_sqlalchemy_future: + with testing.expect_raises_message( + sa.exc.InvalidRequestError, + r"a transaction is already begun for this connection", + ): + command.upgrade(self.cfg, c) + else: + with testing.expect_sqlalchemy_deprecated_20( + r"Calling .begin\(\) when a transaction " + "is already begun" + ): + command.upgrade(self.cfg, c) + else: + command.upgrade(self.cfg, c) def test_noerr_transaction_opened_externally(self): a, b, c = self._opened_transaction_fixture() @@ -477,6 +584,12 @@ run_migrations_online() command.stamp(self.cfg, c) +class FutureOnlineTransactionalDDLTest( + FutureEngineMixin, OnlineTransactionalDDLTest +): + pass + + class EncodingTest(TestBase): def setUp(self): self.env = staging_env() diff --git a/tests/test_sqlite.py b/tests/test_sqlite.py index 3ea1975..946f69f 100644 --- a/tests/test_sqlite.py +++ b/tests/test_sqlite.py @@ -96,7 +96,7 @@ class SQLiteDefaultCompareTest(TestBase): ) def setUp(self): - self.metadata = MetaData(self.bind) + self.metadata = MetaData() self.autogen_context = api.AutogenContext(self.migration_context) @classmethod @@ -104,7 +104,7 @@ class SQLiteDefaultCompareTest(TestBase): clear_staging_env() def tearDown(self): - self.metadata.drop_all() + self.metadata.drop_all(config.db) def _compare_default_roundtrip( self, type_, orig_default, alternate=None, diff_expected=None diff --git a/tests/test_version_table.py b/tests/test_version_table.py index 1801346..5ad3c21 100644 --- a/tests/test_version_table.py +++ b/tests/test_version_table.py @@ -39,7 +39,8 @@ class TestMigrationContext(TestBase): def tearDown(self): self.transaction.rollback() - version_table.drop(self.connection, checkfirst=True) + with self.connection.begin(): + version_table.drop(self.connection, checkfirst=True) self.connection.close() def make_one(self, **kwargs): @@ -182,11 +183,16 @@ class UpdateRevTest(TestBase): self.context = migration.MigrationContext.configure( connection=self.connection, opts={"version_table": "version_table"} ) - version_table.create(self.connection) + with self.connection.begin(): + version_table.create(self.connection) self.updater = migration.HeadMaintainer(self.context, ()) def tearDown(self): - version_table.drop(self.connection, checkfirst=True) + in_t = getattr(self.connection, "in_transaction", lambda: False) + if in_t(): + self.connection.rollback() + with self.connection.begin(): + version_table.drop(self.connection, checkfirst=True) self.connection.close() def _assert_heads(self, heads): @@ -194,145 +200,176 @@ class UpdateRevTest(TestBase): eq_(self.updater.heads, set(heads)) def test_update_none_to_single(self): - self.updater.update_to_step(_up(None, "a", True)) - self._assert_heads(("a",)) + with self.connection.begin(): + self.updater.update_to_step(_up(None, "a", True)) + self._assert_heads(("a",)) def test_update_single_to_single(self): - self.updater.update_to_step(_up(None, "a", True)) - self.updater.update_to_step(_up("a", "b")) - self._assert_heads(("b",)) + with self.connection.begin(): + self.updater.update_to_step(_up(None, "a", True)) + self.updater.update_to_step(_up("a", "b")) + self._assert_heads(("b",)) def test_update_single_to_none(self): - self.updater.update_to_step(_up(None, "a", True)) - self.updater.update_to_step(_down("a", None, True)) - self._assert_heads(()) + with self.connection.begin(): + self.updater.update_to_step(_up(None, "a", True)) + self.updater.update_to_step(_down("a", None, True)) + self._assert_heads(()) def test_add_branches(self): - self.updater.update_to_step(_up(None, "a", True)) - self.updater.update_to_step(_up("a", "b")) - self.updater.update_to_step(_up(None, "c", True)) - self._assert_heads(("b", "c")) - self.updater.update_to_step(_up("c", "d")) - self.updater.update_to_step(_up("d", "e1")) - self.updater.update_to_step(_up("d", "e2", True)) - self._assert_heads(("b", "e1", "e2")) + with self.connection.begin(): + self.updater.update_to_step(_up(None, "a", True)) + self.updater.update_to_step(_up("a", "b")) + self.updater.update_to_step(_up(None, "c", True)) + self._assert_heads(("b", "c")) + self.updater.update_to_step(_up("c", "d")) + self.updater.update_to_step(_up("d", "e1")) + self.updater.update_to_step(_up("d", "e2", True)) + self._assert_heads(("b", "e1", "e2")) def test_teardown_branches(self): - self.updater.update_to_step(_up(None, "d1", True)) - self.updater.update_to_step(_up(None, "d2", True)) - self._assert_heads(("d1", "d2")) + with self.connection.begin(): + self.updater.update_to_step(_up(None, "d1", True)) + self.updater.update_to_step(_up(None, "d2", True)) + self._assert_heads(("d1", "d2")) - self.updater.update_to_step(_down("d1", "c")) - self._assert_heads(("c", "d2")) + self.updater.update_to_step(_down("d1", "c")) + self._assert_heads(("c", "d2")) - self.updater.update_to_step(_down("d2", "c", True)) + self.updater.update_to_step(_down("d2", "c", True)) - self._assert_heads(("c",)) - self.updater.update_to_step(_down("c", "b")) - self._assert_heads(("b",)) + self._assert_heads(("c",)) + self.updater.update_to_step(_down("c", "b")) + self._assert_heads(("b",)) def test_resolve_merges(self): - self.updater.update_to_step(_up(None, "a", True)) - self.updater.update_to_step(_up("a", "b")) - self.updater.update_to_step(_up("b", "c1")) - self.updater.update_to_step(_up("b", "c2", True)) - self.updater.update_to_step(_up("c1", "d1")) - self.updater.update_to_step(_up("c2", "d2")) - self._assert_heads(("d1", "d2")) - self.updater.update_to_step(_up(("d1", "d2"), "e")) - self._assert_heads(("e",)) + with self.connection.begin(): + self.updater.update_to_step(_up(None, "a", True)) + self.updater.update_to_step(_up("a", "b")) + self.updater.update_to_step(_up("b", "c1")) + self.updater.update_to_step(_up("b", "c2", True)) + self.updater.update_to_step(_up("c1", "d1")) + self.updater.update_to_step(_up("c2", "d2")) + self._assert_heads(("d1", "d2")) + self.updater.update_to_step(_up(("d1", "d2"), "e")) + self._assert_heads(("e",)) def test_unresolve_merges(self): - self.updater.update_to_step(_up(None, "e", True)) + with self.connection.begin(): + self.updater.update_to_step(_up(None, "e", True)) - self.updater.update_to_step(_down("e", ("d1", "d2"))) - self._assert_heads(("d2", "d1")) + self.updater.update_to_step(_down("e", ("d1", "d2"))) + self._assert_heads(("d2", "d1")) - self.updater.update_to_step(_down("d2", "c2")) - self._assert_heads(("c2", "d1")) + self.updater.update_to_step(_down("d2", "c2")) + self._assert_heads(("c2", "d1")) def test_update_no_match(self): - self.updater.update_to_step(_up(None, "a", True)) - self.updater.heads.add("x") - assert_raises_message( - CommandError, - "Online migration expected to match one row when updating " - "'x' to 'b' in 'version_table'; 0 found", - self.updater.update_to_step, - _up("x", "b"), - ) + with self.connection.begin(): + self.updater.update_to_step(_up(None, "a", True)) + self.updater.heads.add("x") + assert_raises_message( + CommandError, + "Online migration expected to match one row when updating " + "'x' to 'b' in 'version_table'; 0 found", + self.updater.update_to_step, + _up("x", "b"), + ) def test_update_no_match_no_sane_rowcount(self): - self.updater.update_to_step(_up(None, "a", True)) - self.updater.heads.add("x") - with mock.patch.object( - self.connection.dialect, "supports_sane_rowcount", False - ): - self.updater.update_to_step(_up("x", "b")) + with self.connection.begin(): + self.updater.update_to_step(_up(None, "a", True)) + self.updater.heads.add("x") + with mock.patch.object( + self.connection.dialect, "supports_sane_rowcount", False + ): + self.updater.update_to_step(_up("x", "b")) def test_update_multi_match(self): - self.connection.execute(version_table.insert(), version_num="a") - self.connection.execute(version_table.insert(), version_num="a") - - self.updater.heads.add("a") - assert_raises_message( - CommandError, - "Online migration expected to match one row when updating " - "'a' to 'b' in 'version_table'; 2 found", - self.updater.update_to_step, - _up("a", "b"), - ) + with self.connection.begin(): + self.connection.execute( + version_table.insert(), dict(version_num="a") + ) + self.connection.execute( + version_table.insert(), dict(version_num="a") + ) + + self.updater.heads.add("a") + assert_raises_message( + CommandError, + "Online migration expected to match one row when updating " + "'a' to 'b' in 'version_table'; 2 found", + self.updater.update_to_step, + _up("a", "b"), + ) def test_update_multi_match_no_sane_rowcount(self): - self.connection.execute(version_table.insert(), version_num="a") - self.connection.execute(version_table.insert(), version_num="a") - - self.updater.heads.add("a") - with mock.patch.object( - self.connection.dialect, "supports_sane_rowcount", False - ): - self.updater.update_to_step(_up("a", "b")) + with self.connection.begin(): + self.connection.execute( + version_table.insert(), dict(version_num="a") + ) + self.connection.execute( + version_table.insert(), dict(version_num="a") + ) + + self.updater.heads.add("a") + with mock.patch.object( + self.connection.dialect, "supports_sane_rowcount", False + ): + self.updater.update_to_step(_up("a", "b")) def test_delete_no_match(self): - self.updater.update_to_step(_up(None, "a", True)) - - self.updater.heads.add("x") - assert_raises_message( - CommandError, - "Online migration expected to match one row when " - "deleting 'x' in 'version_table'; 0 found", - self.updater.update_to_step, - _down("x", None, True), - ) + with self.connection.begin(): + self.updater.update_to_step(_up(None, "a", True)) + + self.updater.heads.add("x") + assert_raises_message( + CommandError, + "Online migration expected to match one row when " + "deleting 'x' in 'version_table'; 0 found", + self.updater.update_to_step, + _down("x", None, True), + ) def test_delete_no_matchno_sane_rowcount(self): - self.updater.update_to_step(_up(None, "a", True)) + with self.connection.begin(): + self.updater.update_to_step(_up(None, "a", True)) - self.updater.heads.add("x") - with mock.patch.object( - self.connection.dialect, "supports_sane_rowcount", False - ): - self.updater.update_to_step(_down("x", None, True)) + self.updater.heads.add("x") + with mock.patch.object( + self.connection.dialect, "supports_sane_rowcount", False + ): + self.updater.update_to_step(_down("x", None, True)) def test_delete_multi_match(self): - self.connection.execute(version_table.insert(), version_num="a") - self.connection.execute(version_table.insert(), version_num="a") - - self.updater.heads.add("a") - assert_raises_message( - CommandError, - "Online migration expected to match one row when " - "deleting 'a' in 'version_table'; 2 found", - self.updater.update_to_step, - _down("a", None, True), - ) + with self.connection.begin(): + self.connection.execute( + version_table.insert(), dict(version_num="a") + ) + self.connection.execute( + version_table.insert(), dict(version_num="a") + ) + + self.updater.heads.add("a") + assert_raises_message( + CommandError, + "Online migration expected to match one row when " + "deleting 'a' in 'version_table'; 2 found", + self.updater.update_to_step, + _down("a", None, True), + ) def test_delete_multi_match_no_sane_rowcount(self): - self.connection.execute(version_table.insert(), version_num="a") - self.connection.execute(version_table.insert(), version_num="a") - - self.updater.heads.add("a") - with mock.patch.object( - self.connection.dialect, "supports_sane_rowcount", False - ): - self.updater.update_to_step(_down("a", None, True)) + with self.connection.begin(): + self.connection.execute( + version_table.insert(), dict(version_num="a") + ) + self.connection.execute( + version_table.insert(), dict(version_num="a") + ) + + self.updater.heads.add("a") + with mock.patch.object( + self.connection.dialect, "supports_sane_rowcount", False + ): + self.updater.update_to_step(_down("a", None, True)) @@ -1,6 +1,6 @@ [tox] -envlist = py +envlist = py-sqlalchemy SQLA_REPO = {env:SQLA_REPO:git+https://github.com/sqlalchemy/sqlalchemy.git} @@ -10,8 +10,8 @@ cov_args=--cov=alembic --cov-report term --cov-report xml deps=pytest>4.6 pytest-xdist mock - sqla13: {[tox]SQLA_REPO}@rel_1_3 - sqlamaster: {[tox]SQLA_REPO}@master + sqla13: {[tox]SQLA_REPO}@rel_1_3#egg=sqlalchemy + sqlamaster: {[tox]SQLA_REPO}@master#egg=sqlalchemy postgresql: psycopg2 mysql: mysqlclient mysql: pymysql @@ -19,6 +19,10 @@ deps=pytest>4.6 oracle: cx_oracle>=7;python_version>="3" mssql: pymssql cov: pytest-cov + sqlalchemy: sqlalchemy>=1.3.0 + mako + python-editor>=0.3 + python-dateutil @@ -39,6 +43,9 @@ setenv= mssql: MSSQL={env:TOX_MSSQL:--db pymssql} pyoptimize: PYTHONOPTIMIZE=1 pyoptimize: LIMITTESTS="tests/test_script_consumption.py" + future: SQLALCHEMY_TESTING_FUTURE_ENGINE=1 + SQLALCHEMY_WARN_20=1 + # tox as of 2.0 blocks all environment variables from the # outside, unless they are here (or in TOX_TESTENV_PASSENV, @@ -64,4 +71,4 @@ deps= black==20.8b1 commands = flake8 ./alembic/ ./tests/ setup.py docs/build/conf.py {posargs} - black --check . + black --check setup.py tests alembic |