summaryrefslogtreecommitdiff
path: root/alembic/testing/util.py
blob: e65597d9ac7579be3bdf16de63d7697c98510ad0 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
# testing/util.py
# Copyright (C) 2005-2019 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
from __future__ import annotations

import re
import types
from typing import Union

from sqlalchemy.util import inspect_getfullargspec


def flag_combinations(*combinations):
    """Build a ``@testing.combinations()`` decorator from boolean-style
    keyword dictionaries.

    Every dictionary passed in describes one test combination.  The union
    of all dictionary keys, in sorted order, becomes the argument names;
    a key missing from a given dictionary is treated as ``False``.  The
    identifier of each combination is the underscore-joined list of the
    keys whose value is truthy in that combination.

    E.g.::

        @testing.flag_combinations(
            dict(lazy=False, passive=False),
            dict(lazy=True, passive=False),
            dict(lazy=False, passive=True),
            dict(lazy=False, passive=True, raiseload=True),
        )


    would result in::

        @testing.combinations(
            ('', False, False, False),
            ('lazy', True, False, False),
            ('passive', False, True, False),
            ('passive_raiseload', False, True, True),
            id_='iaaa',
            argnames='lazy,passive,raiseload'
        )

    """
    from sqlalchemy.testing import config

    # Iterating a dict yields its keys, so the union collects every flag
    # name used by any combination.
    flag_names = sorted(set().union(*combinations))

    rows = []
    for combo in combinations:
        values = tuple(combo.get(name, False) for name in flag_names)
        # The identifier lists only the flags that are "on" for this row.
        ident = "_".join(
            name for name, value in zip(flag_names, values) if value
        )
        rows.append((ident,) + values)

    # "i" marks the leading identifier element; one "a" per flag argument.
    id_spec = "i" + "a" * len(flag_names)

    return config.combinations(
        *rows, id_=id_spec, argnames=",".join(flag_names)
    )


def resolve_lambda(__fn, **kw):
    """Evaluate a lambda against a supplied namespace and return the result.

    Keyword arguments whose names match the arguments of ``__fn`` are
    removed from ``kw`` and passed to the call as keyword arguments.  The
    remaining entries of ``kw`` are layered over a copy of the function's
    globals, so that free names inside the lambda body resolve to them.

    This is used so that we can have module-level fixtures that
    refer to instance-level variables using lambdas.

    A ``KeyError`` is raised if an argument named in the signature of
    ``__fn`` is not present in ``kw``.

    """

    arg_names = inspect_getfullargspec(__fn)[0]

    call_kwargs = {}
    for name in arg_names:
        call_kwargs[name] = kw.pop(name)

    # Copy of the original globals with the leftover keywords overlaid;
    # the function's own module namespace is not modified.
    namespace = {**__fn.__globals__, **kw}

    # Re-create the function from the same code object so that global
    # lookups go through the patched namespace.
    rebound = types.FunctionType(__fn.__code__, namespace)
    return rebound(**call_kwargs)


def metadata_fixture(ddl="function"):
    """Return a decorator that turns ``fn(self, metadata)`` into a fixture
    backed by a fresh ``MetaData``.

    :param ddl: passed as the ``scope`` argument to
     ``fixture_functions.fixture()``.

    The generated fixture creates a new ``MetaData``, stores it on
    ``self.metadata``, calls the decorated function with it, runs
    ``create_all()`` against ``config.db`` and yields the function's
    return value.  ``drop_all()`` is always run afterwards.

    """

    from sqlalchemy.testing import config
    from . import fixture_functions

    def decorate(fn):
        def run_ddl(self):
            from sqlalchemy import schema

            metadata = schema.MetaData()
            self.metadata = metadata
            try:
                fixture_value = fn(self, metadata)
                metadata.create_all(config.db)
                # TODO:
                # somehow get a per-function dml erase fixture here
                yield fixture_value
            finally:
                # runs on normal teardown as well as when setup fails
                metadata.drop_all(config.db)

        make_fixture = fixture_functions.fixture(scope=ddl)
        return make_fixture(run_ddl)

    return decorate


def _safe_int(value: str) -> Union[int, str]:
    try:
        return int(value)
    except:
        return value


def testing_engine(url=None, options=None, future=False):
    """Create an engine through ``sqlalchemy.testing.engines.testing_engine``.

    :param url: passed through positionally to the SQLAlchemy function.
    :param options: passed through positionally to the SQLAlchemy function.
    :param future: when falsy, the value of ``future_engine`` on
     ``config._current.options`` is used instead (``False`` if that
     attribute is absent).

    The ``future`` keyword is forwarded only when the installed SQLAlchemy
    version compares lower than ``(2,)`` and the resolved flag is truthy;
    otherwise no extra keyword arguments are passed.

    """
    from sqlalchemy.testing import config
    from sqlalchemy.testing.engines import (
        testing_engine as _sqla_testing_engine,
    )
    from sqlalchemy import __version__ as sqla_version

    # Numeric tokens become ints; pre-release tokens such as "b1" stay
    # strings.
    version_parts = tuple(
        _safe_int(token)
        for token in re.findall(r"(\d+|[abc]\d)", sqla_version)
    )
    is_sqla_1x = version_parts < (2,)

    if not future:
        future = getattr(config._current.options, "future_engine", False)

    engine_kw = {}
    if is_sqla_1x and future:
        engine_kw["future"] = future

    return _sqla_testing_engine(url, options, **engine_kw)