summaryrefslogtreecommitdiff
path: root/alembic/testing/util.py
blob: a82690d12655bd815560883cf0b438b945a87763 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
# testing/util.py
# Copyright (C) 2005-2019 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
from __future__ import annotations

import re
import types
from typing import Union


def flag_combinations(*combinations):
    """A facade around @testing.combinations() oriented towards boolean
    keyword-based arguments.

    Basically generates a nice looking identifier based on the keywords
    and also sets up the argument names.

    E.g.::

        @testing.flag_combinations(
            dict(lazy=False, passive=False),
            dict(lazy=True, passive=False),
            dict(lazy=False, passive=True),
            dict(lazy=False, passive=True, raiseload=True),
        )


    would result in::

        @testing.combinations(
            ('', False, False, False),
            ('lazy', True, False, False),
            ('lazy_passive', True, True, False),
            ('lazy_passive', True, True, True),
            id_='iaaa',
            argnames='lazy,passive,raiseload'
        )

    """
    from sqlalchemy.testing import config

    keys = set()

    for d in combinations:
        keys.update(d)

    keys = sorted(keys)

    return config.combinations(
        *[
            ("_".join(k for k in keys if d.get(k, False)),)
            + tuple(d.get(k, False) for k in keys)
            for d in combinations
        ],
        id_="i" + ("a" * len(keys)),
        argnames=",".join(keys),
    )


def resolve_lambda(__fn, **kw):
    """Given a no-arg lambda and a namespace, return a new lambda that
    has all the values filled in.

    This is used so that we can have module-level fixtures that
    refer to instance-level variables using lambdas.

    """

    glb = dict(__fn.__globals__)
    glb.update(kw)
    new_fn = types.FunctionType(__fn.__code__, glb)
    return new_fn()


def metadata_fixture(ddl="function"):
    """Provide MetaData for a pytest fixture."""

    from sqlalchemy.testing import config
    from . import fixture_functions

    def decorate(fn):
        def run_ddl(self):
            from sqlalchemy import schema

            metadata = self.metadata = schema.MetaData()
            try:
                result = fn(self, metadata)
                metadata.create_all(config.db)
                # TODO:
                # somehow get a per-function dml erase fixture here
                yield result
            finally:
                metadata.drop_all(config.db)

        return fixture_functions.fixture(scope=ddl)(run_ddl)

    return decorate


def _safe_int(value: str) -> Union[int, str]:
    try:
        return int(value)
    except:
        return value


def testing_engine(url=None, options=None, future=False):
    from sqlalchemy.testing import config
    from sqlalchemy.testing.engines import testing_engine
    from sqlalchemy import __version__

    _vers = tuple(
        [_safe_int(x) for x in re.findall(r"(\d+|[abc]\d)", __version__)]
    )
    sqla_1x = _vers < (2,)

    if not future:
        future = getattr(config._current.options, "future_engine", False)

    if sqla_1x:
        kw = {"future": future} if future else {}
    else:
        kw = {}
    return testing_engine(url, options, **kw)