summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r--lib/sqlalchemy/dialects/postgresql/provision.py4
-rw-r--r--lib/sqlalchemy/testing/__init__.py2
-rw-r--r--lib/sqlalchemy/testing/assertions.py59
-rw-r--r--lib/sqlalchemy/testing/plugin/pytestplugin.py19
-rw-r--r--lib/sqlalchemy/testing/warnings.py32
5 files changed, 78 insertions, 38 deletions
diff --git a/lib/sqlalchemy/dialects/postgresql/provision.py b/lib/sqlalchemy/dialects/postgresql/provision.py
index 289dda4b6..29926ee3d 100644
--- a/lib/sqlalchemy/dialects/postgresql/provision.py
+++ b/lib/sqlalchemy/dialects/postgresql/provision.py
@@ -19,10 +19,6 @@ def _pg_create_db(cfg, eng, ident):
template_db = cfg.options.postgresql_templatedb
with eng.execution_options(isolation_level="AUTOCOMMIT").begin() as conn:
- try:
- _pg_drop_db(cfg, conn, ident)
- except Exception:
- pass
if not template_db:
template_db = conn.exec_driver_sql(
"select current_database()"
diff --git a/lib/sqlalchemy/testing/__init__.py b/lib/sqlalchemy/testing/__init__.py
index 87208d3f4..fd6ddf593 100644
--- a/lib/sqlalchemy/testing/__init__.py
+++ b/lib/sqlalchemy/testing/__init__.py
@@ -12,6 +12,8 @@ from .assertions import assert_raises
from .assertions import assert_raises_context_ok
from .assertions import assert_raises_message
from .assertions import assert_raises_message_context_ok
+from .assertions import assert_warns
+from .assertions import assert_warns_message
from .assertions import AssertsCompiledSQL
from .assertions import AssertsExecutionResults
from .assertions import ComparesTables
diff --git a/lib/sqlalchemy/testing/assertions.py b/lib/sqlalchemy/testing/assertions.py
index f268b6fc3..2a00f1c14 100644
--- a/lib/sqlalchemy/testing/assertions.py
+++ b/lib/sqlalchemy/testing/assertions.py
@@ -139,13 +139,15 @@ def _expect_warnings(
exc_cls,
messages,
regex=True,
+ search_msg=False,
assert_=True,
raise_on_any_unexpected=False,
+ squelch_other_warnings=False,
):
global _FILTERS, _SEEN, _EXC_CLS
- if regex:
+ if regex or search_msg:
filters = [re.compile(msg, re.I | re.S) for msg in messages]
else:
filters = list(messages)
@@ -183,19 +185,23 @@ def _expect_warnings(
exception = None
if not exception or not issubclass(exception, _EXC_CLS):
- return real_warn(msg, *arg, **kw)
+ if not squelch_other_warnings:
+ return real_warn(msg, *arg, **kw)
if not filters and not raise_on_any_unexpected:
return
for filter_ in filters:
- if (regex and filter_.match(msg)) or (
- not regex and filter_ == msg
+ if (
+ (search_msg and filter_.search(msg))
+ or (regex and filter_.match(msg))
+ or (not regex and filter_ == msg)
):
seen.discard(filter_)
break
else:
- real_warn(msg, *arg, **kw)
+ if not squelch_other_warnings:
+ real_warn(msg, *arg, **kw)
with mock.patch("warnings.warn", our_warn):
try:
@@ -343,6 +349,40 @@ def assert_raises_message(except_cls, msg, callable_, *args, **kwargs):
)
+def assert_warns(except_cls, callable_, *args, **kwargs):
+ """legacy adapter function for functions that were previously using
+ assert_raises with SAWarning or similar.
+
+ has some workarounds to accommodate the fact that the callable completes
+ with this approach rather than stopping at the exception raise.
+
+
+ """
+ with _expect_warnings(except_cls, [".*"], squelch_other_warnings=True):
+ return callable_(*args, **kwargs)
+
+
+def assert_warns_message(except_cls, msg, callable_, *args, **kwargs):
+ """legacy adapter function for functions that were previously using
+ assert_raises with SAWarning or similar.
+
+ has some workarounds to accommodate the fact that the callable completes
+ with this approach rather than stopping at the exception raise.
+
+ Also uses regex.search() to match the given message to the error string
+ rather than regex.match().
+
+ """
+ with _expect_warnings(
+ except_cls,
+ [msg],
+ search_msg=True,
+ regex=False,
+ squelch_other_warnings=True,
+ ):
+ return callable_(*args, **kwargs)
+
+
def assert_raises_message_context_ok(
except_cls, msg, callable_, *args, **kwargs
):
@@ -364,6 +404,15 @@ class _ErrorContainer:
@contextlib.contextmanager
def _expect_raises(except_cls, msg=None, check_context=False):
+ if (
+ isinstance(except_cls, type)
+ and issubclass(except_cls, Warning)
+ or isinstance(except_cls, Warning)
+ ):
+ raise TypeError(
+ "Use expect_warnings for warnings, not "
+ "expect_raises / assert_raises"
+ )
ec = _ErrorContainer()
if check_context:
are_we_already_in_a_traceback = sys.exc_info()[0]
diff --git a/lib/sqlalchemy/testing/plugin/pytestplugin.py b/lib/sqlalchemy/testing/plugin/pytestplugin.py
index 7a62ad008..2ae6730bb 100644
--- a/lib/sqlalchemy/testing/plugin/pytestplugin.py
+++ b/lib/sqlalchemy/testing/plugin/pytestplugin.py
@@ -13,16 +13,10 @@ import itertools
import operator
import os
import re
+import uuid
import pytest
-try:
- import xdist # noqa
-
- has_xdist = True
-except ImportError:
- has_xdist = False
-
def pytest_addoption(parser):
group = parser.getgroup("sqlalchemy")
@@ -75,6 +69,9 @@ def pytest_addoption(parser):
def pytest_configure(config):
+ if config.pluginmanager.hasplugin("xdist"):
+ config.pluginmanager.register(XDistHooks())
+
if hasattr(config, "workerinput"):
plugin_base.restore_important_follower_config(config.workerinput)
plugin_base.configure_follower(config.workerinput["follower_ident"])
@@ -148,10 +145,8 @@ def pytest_collection_finish(session):
collect_types.init_types_collection(filter_filename=_filter)
-if has_xdist:
- import uuid
-
- def pytest_configure_node(node):
+class XDistHooks:
+ def pytest_configure_node(self, node):
from sqlalchemy.testing import provision
from sqlalchemy.testing import asyncio
@@ -166,7 +161,7 @@ if has_xdist:
provision.create_follower_db, node.workerinput["follower_ident"]
)
- def pytest_testnodedown(node, error):
+ def pytest_testnodedown(self, node, error):
from sqlalchemy.testing import provision
from sqlalchemy.testing import asyncio
diff --git a/lib/sqlalchemy/testing/warnings.py b/lib/sqlalchemy/testing/warnings.py
index 34b23d675..1c2039602 100644
--- a/lib/sqlalchemy/testing/warnings.py
+++ b/lib/sqlalchemy/testing/warnings.py
@@ -11,8 +11,13 @@ from .. import exc as sa_exc
from ..util.langhelpers import _warnings_warn
-class SATestSuiteWarning(sa_exc.SAWarning):
- """warning for a condition detected during tests that is non-fatal"""
+class SATestSuiteWarning(Warning):
+ """warning for a condition detected during tests that is non-fatal
+
+ Currently outside of SAWarning so that we can work around tools like
+ Alembic doing the wrong thing with warnings.
+
+ """
def warn_test_suite(message):
@@ -22,28 +27,21 @@ def warn_test_suite(message):
def setup_filters():
"""Set global warning behavior for the test suite."""
+ # TODO: at this point we can use the normal pytest warnings plugin,
+ # if we decide the test suite can be linked to pytest only
+
+ origin = r"^(?:test|sqlalchemy)\..*"
+
warnings.filterwarnings(
"ignore", category=sa_exc.SAPendingDeprecationWarning
)
warnings.filterwarnings("error", category=sa_exc.SADeprecationWarning)
warnings.filterwarnings("error", category=sa_exc.SAWarning)
- warnings.filterwarnings("always", category=SATestSuiteWarning)
- # some selected deprecations...
- warnings.filterwarnings("error", category=DeprecationWarning)
- warnings.filterwarnings(
- "ignore", category=DeprecationWarning, message=r".*StopIteration"
- )
- warnings.filterwarnings(
- "ignore",
- category=DeprecationWarning,
- message=r".*inspect.get.*argspec",
- )
+ warnings.filterwarnings("always", category=SATestSuiteWarning)
warnings.filterwarnings(
- "ignore",
- category=DeprecationWarning,
- message="The loop argument is deprecated",
+ "error", category=DeprecationWarning, module=origin
)
try:
@@ -52,7 +50,7 @@ def setup_filters():
pass
else:
warnings.filterwarnings(
- "once", category=pytest.PytestDeprecationWarning
+ "once", category=pytest.PytestDeprecationWarning, module=origin
)