summaryrefslogtreecommitdiff
path: root/alembic/testing/assertions.py
blob: 1c24066b808f57a51f57b3933e506dad4cefa9fc (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
from __future__ import annotations

import contextlib
import re
import sys
from typing import Any
from typing import Dict

from sqlalchemy import exc as sa_exc
from sqlalchemy.engine import default
from sqlalchemy.testing.assertions import _expect_warnings
from sqlalchemy.testing.assertions import eq_  # noqa
from sqlalchemy.testing.assertions import is_  # noqa
from sqlalchemy.testing.assertions import is_false  # noqa
from sqlalchemy.testing.assertions import is_not_  # noqa
from sqlalchemy.testing.assertions import is_true  # noqa
from sqlalchemy.testing.assertions import ne_  # noqa
from sqlalchemy.util import decorator

from ..util import sqla_compat


def _assert_proper_exception_context(exception):
    """assert that any exception we're catching does not have a __context__
    without a __cause__, and that __suppress_context__ is never set.

    Python 3 will report nested as exceptions as "during the handling of
    error X, error Y occurred". That's not what we want to do.  we want
    these exceptions in a cause chain.

    """

    if (
        exception.__context__ is not exception.__cause__
        and not exception.__suppress_context__
    ):
        assert False, (
            "Exception %r was correctly raised but did not set a cause, "
            "within context %r as its cause."
            % (exception, exception.__context__)
        )


def assert_raises(except_cls, callable_, *args, **kw):
    """Call ``callable_(*args, **kw)`` and assert it raises ``except_cls``.

    Context checking is enabled and no message pattern is applied.  Returns
    the caught exception.
    """
    return _assert_raises(
        except_cls,
        callable_,
        args,
        kw,
        msg=None,
        check_context=True,
    )


def assert_raises_context_ok(except_cls, callable_, *args, **kw):
    """Like ``assert_raises`` but with exception-context checking disabled.

    Returns the caught exception.
    """
    return _assert_raises(
        except_cls,
        callable_,
        args,
        kw,
        msg=None,
        check_context=False,
    )


def assert_raises_message(except_cls, msg, callable_, *args, **kwargs):
    """Assert ``callable_(*args, **kwargs)`` raises ``except_cls`` matching ``msg``.

    ``msg`` is a regular expression searched in the exception's string form.
    Context checking is enabled.  Returns the caught exception.
    """
    return _assert_raises(
        except_cls,
        callable_,
        args,
        kwargs,
        msg=msg,
        check_context=True,
    )


def assert_raises_message_context_ok(
    except_cls, msg, callable_, *args, **kwargs
):
    """Like ``assert_raises_message`` but without exception-context checks.

    Returns the caught exception.
    """
    return _assert_raises(
        except_cls,
        callable_,
        args,
        kwargs,
        msg=msg,
        check_context=False,
    )


def _assert_raises(
    except_cls, callable_, args, kwargs, msg=None, check_context=False
):
    """Run ``callable_(*args, **kwargs)`` under ``_expect_raises``.

    ``msg`` and ``check_context`` are forwarded unchanged.  Returns the
    exception instance that was caught.
    """
    expectation = _expect_raises(except_cls, msg, check_context)
    with expectation as holder:
        callable_(*args, **kwargs)
    return holder.error


class _ErrorContainer:
    error: Any = None


@contextlib.contextmanager
def _expect_raises(except_cls, msg=None, check_context=False):
    """Context manager asserting that its body raises ``except_cls``.

    Yields an ``_ErrorContainer`` whose ``error`` attribute receives the
    caught exception, which is then suppressed.

    * If ``msg`` is given, it is searched with ``re.search`` in
      ``str(error)``.
    * If ``check_context`` is true and no exception was already being
      handled on entry, the error is checked with
      ``_assert_proper_exception_context``.
    * The error text is printed as UTF-8 bytes.

    Exceptions of any other type propagate unchanged.  If the body raises
    nothing, an AssertionError is raised.
    """
    container = _ErrorContainer()
    # Exception type being handled when the block was entered (only looked
    # up when context checking is requested).
    outer_exc_type = sys.exc_info()[0] if check_context else None
    raised = False
    try:
        yield container
    except except_cls as caught:
        container.error = caught
        raised = True
        if msg is not None:
            assert re.search(msg, str(caught), re.UNICODE), (
                f"{msg} !~ {caught}"
            )
        if check_context and not outer_exc_type:
            _assert_proper_exception_context(caught)
        print(str(caught).encode("utf-8"))

    # Checked outside the try so this works even when except_cls is
    # AssertionError.
    assert raised, "Callable did not raise an exception"


def expect_raises(except_cls, check_context=True):
    """Return a context manager asserting its body raises ``except_cls``.

    No message pattern is applied; context checking defaults to on.
    """
    return _expect_raises(except_cls, msg=None, check_context=check_context)


def expect_raises_message(except_cls, msg, check_context=True):
    """Return a context manager asserting ``except_cls`` matching ``msg``.

    ``msg`` is a regular expression searched in the exception's string
    form; context checking defaults to on.
    """
    return _expect_raises(
        except_cls,
        msg,
        check_context=check_context,
    )


def eq_ignore_whitespace(a, b, msg=None):
    """Assert that strings ``a`` and ``b`` are equal modulo whitespace layout.

    Both strings are normalized before comparison:

    * all leading whitespace is removed;
    * every newline is removed;
    * runs of two or more spaces are collapsed to a single space.

    No other whitespace normalization is performed (e.g. tabs are compared
    literally).

    :param msg: optional assertion message; defaults to ``"%r != %r"`` of
      the two normalized strings.
    :raises AssertionError: if the normalized strings differ.
    """

    def normalize(text):
        # Greedy ``^\s+``: the previous lazy ``^\s+?`` could only ever match
        # a single leading whitespace character, so indented input such as
        # "\n    SELECT" kept a stray leading space after normalization.
        text = re.sub(r"^\s+|\n", "", text)
        return re.sub(r" {2,}", " ", text)

    a = normalize(a)
    b = normalize(b)

    assert a == b, msg or "%r != %r" % (a, b)


# Module-level dict with untyped keys and values.
# NOTE(review): nothing in the visible part of this module reads or writes
# it; presumably a registry/cache used elsewhere — confirm before removing.
_dialect_mods: Dict[Any, Any] = {}


def _get_dialect(name):
    """Return a dialect instance for the backend ``name``.

    ``None`` or ``"default"`` yields a plain ``DefaultDialect``.  Any other
    name is resolved via ``sqla_compat._create_url(name).get_dialect()`` and
    instantiated, with two backend-specific tweaks:

    * ``postgresql``: ``implicit_returning`` is forced on;
    * ``mssql``: ``legacy_schema_aliasing`` is turned off.
    """
    if name is None or name == "default":
        return default.DefaultDialect()

    dialect_cls = sqla_compat._create_url(name).get_dialect()
    dialect = dialect_cls()

    if name == "postgresql":
        dialect.implicit_returning = True
    elif name == "mssql":
        dialect.legacy_schema_aliasing = False

    return dialect


def expect_warnings(*messages, **kw):
    """Return a context manager that expects one or more warnings.

    With no ``messages``, matching warnings are squelched.  Otherwise each
    message is a string expression matched against warnings via regex, and
    non-matching warnings are sent through.  This wrapper matches against
    the base ``Warning`` category and forwards ``**kw`` unchanged to
    SQLAlchemy's ``_expect_warnings``.

    This "expect" form **asserts** that the warnings were in fact seen.

    Note that the test suite sets SAWarning warnings to raise exceptions.
    """
    warning_category = Warning
    return _expect_warnings(warning_category, messages, **kw)


def emits_python_deprecation_warning(*messages):
    """Decorator that runs the wrapped function while expecting
    ``DeprecationWarning`` warnings matching ``messages``.

    The wrapped call runs inside ``_expect_warnings(..., assert_=False)``,
    so this decorator does **not** assert that the warnings were in fact
    seen.

    :param messages: string expressions matched against warnings via regex.
    """

    @decorator
    def decorate(fn, *args, **kw):
        # Pass ``messages`` as the single message-sequence argument, the
        # same way expect_warnings() and the other helpers here do.  The
        # former ``*messages`` spelling unpacked the strings positionally,
        # so the first bare string landed in the message-sequence slot and
        # any further ones spilled into the following positional parameters.
        with _expect_warnings(DeprecationWarning, messages, assert_=False):
            return fn(*args, **kw)

    return decorate


def expect_sqlalchemy_deprecated(*messages, **kw):
    """Expect ``SADeprecationWarning`` warnings matching ``messages``.

    Keyword arguments are forwarded unchanged to ``_expect_warnings``.
    """
    category = sa_exc.SADeprecationWarning
    return _expect_warnings(category, messages, **kw)


def expect_sqlalchemy_deprecated_20(*messages, **kw):
    """Expect ``RemovedIn20Warning`` warnings matching ``messages``.

    Keyword arguments are forwarded unchanged to ``_expect_warnings``.
    """
    category = sa_exc.RemovedIn20Warning
    return _expect_warnings(category, messages, **kw)