summaryrefslogtreecommitdiff
path: root/alembic/testing/plugin/plugin_base.py
blob: 276bc56c63fe986813b851a967538ea7796c8cad (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
"""Vendored plugin_base functions from the most recent SQLAlchemy versions.

Alembic tests need to run on older versions of SQLAlchemy that don't
necessarily have all the latest testing fixtures.

"""
from __future__ import absolute_import

import abc
import sys

from sqlalchemy.testing.plugin.plugin_base import *  # noqa
from sqlalchemy.testing.plugin.plugin_base import post
from sqlalchemy.testing.plugin.plugin_base import post_begin as sqla_post_begin
from sqlalchemy.testing.plugin.plugin_base import stop_test_class as sqla_stc

# True when the interpreter's major version is 3 or later.
py3k = sys.version_info[0] >= 3


if not py3k:
    # Python 2 has no ``abc.ABC`` convenience base, so build an equivalent
    # one using the Python 2 ``__metaclass__`` class attribute hook.
    class ABC(object):
        __metaclass__ = abc.ABCMeta

else:
    ABC = abc.ABC


def post_begin():
    """Run SQLAlchemy's ``post_begin`` hook, then de-duplicate pytest
    deprecation warnings.

    When pytest is importable, ``PytestDeprecationWarning`` is filtered with
    the ``"once"`` action so each distinct warning is reported a single time.
    When pytest is not installed, nothing beyond the SQLAlchemy hook runs.
    """
    sqla_post_begin()

    import warnings

    try:
        import pytest
    except ImportError:
        # pytest is optional here; without it there is nothing to configure.
        return

    warnings.filterwarnings(
        "once", category=pytest.PytestDeprecationWarning
    )


# override selected SQLAlchemy pytest hooks with vendored functionality
def stop_test_class(cls):
    """Run SQLAlchemy's ``stop_test_class`` hook, then verify that the
    alembic staging directory was removed.

    :param cls: the test class that just finished; passed through to the
     SQLAlchemy hook.
    :raises AssertionError: if the staging directory still exists.
    """
    sqla_stc(cls)
    import os
    from alembic.testing.env import _get_staging_directory

    staging_dir = _get_staging_directory()
    # Raise explicitly rather than using a bare ``assert``: assert statements
    # are stripped under ``python -O``, which would silently disable this
    # leak check.
    if os.path.exists(staging_dir):
        raise AssertionError(
            "staging directory %s was not cleaned up" % staging_dir
        )


def want_class(name, cls):
    """Decide whether the class ``cls`` (collected under ``name``) should be
    treated as a test class.

    A class is rejected when it is not a subclass of SQLAlchemy's
    ``fixtures.TestBase`` or when its name starts with an underscore.  When
    ``config.options.backend_only`` is set, only classes that set a truthy
    ``__backend__`` or ``__sparse_backend__`` attribute are accepted.

    :return: ``True`` to collect the class, ``False`` otherwise.
    """
    from sqlalchemy.testing import config
    from sqlalchemy.testing import fixtures

    if not issubclass(cls, fixtures.TestBase) or name.startswith("_"):
        return False

    if config.options.backend_only:
        # Only classes that opt in to backend runs survive this filter.
        return bool(
            getattr(cls, "__backend__", False)
            or getattr(cls, "__sparse_backend__", False)
        )

    return True


@post
def _init_symbols(options, file_config):
    """Instantiate the registered fixture-functions class and publish the
    instance on both SQLAlchemy's and alembic's testing config modules.

    Registered through SQLAlchemy's ``post`` decorator; presumably invoked
    after option parsing (verify against SQLAlchemy's plugin_base).
    ``options`` and ``file_config`` are accepted for the hook signature but
    are not used here.

    NOTE(review): expects ``_fixture_fn_class`` to have been set by
    ``set_fixture_functions()``; if it is still ``None`` the call below fails.
    """
    from sqlalchemy.testing import config
    from alembic.testing import fixture_functions as alembic_config

    instance = _fixture_fn_class()
    # The same instance is shared so both packages see one fixture provider.
    config._fixture_functions = instance
    alembic_config._fixture_functions = instance


class FixtureFunctions(ABC):
    """Abstract interface for the test-fixture helper implementation.

    Concrete subclasses must implement ``skip_test_exception``,
    ``combinations``, ``param_ident`` and ``fixture``.
    ``get_current_test_name`` is intentionally *not* abstract: subclasses may
    omit it, in which case calling it raises ``NotImplementedError``.
    """

    @abc.abstractmethod
    def skip_test_exception(self, *arg, **kw):
        """Abstract hook; subclasses must override."""
        raise NotImplementedError

    @abc.abstractmethod
    def combinations(self, *args, **kw):
        """Abstract hook; subclasses must override."""
        raise NotImplementedError

    @abc.abstractmethod
    def param_ident(self, *args, **kw):
        """Abstract hook; subclasses must override."""
        raise NotImplementedError

    @abc.abstractmethod
    def fixture(self, *arg, **kw):
        """Abstract hook; subclasses must override."""
        raise NotImplementedError

    def get_current_test_name(self):
        """Optional hook; raises ``NotImplementedError`` unless overridden."""
        raise NotImplementedError


# Fixture-functions class registered via set_fixture_functions(); None until
# that function is called.
_fixture_fn_class = None


def set_fixture_functions(fixture_fn_class):
    """Register ``fixture_fn_class`` in this module and in SQLAlchemy's
    ``sqlalchemy.testing.plugin.plugin_base`` module.

    Both module-level ``_fixture_fn_class`` attributes end up referring to
    the same class object.

    :param fixture_fn_class: the class to register; presumably a
     ``FixtureFunctions`` implementation (verify against callers).
    """
    global _fixture_fn_class
    from sqlalchemy.testing.plugin import plugin_base as sqla_plugin_base

    _fixture_fn_class = fixture_fn_class
    sqla_plugin_base._fixture_fn_class = fixture_fn_class