diff options
-rw-r--r-- | alembic/operations/base.py | 8 | ||||
-rw-r--r-- | alembic/util/compat.py | 3 | ||||
-rw-r--r-- | tests/requirements.py | 20 | ||||
-rw-r--r-- | tests/test_stubs.py | 15 | ||||
-rw-r--r-- | tools/write_pyi.py | 3 |
5 files changed, 34 insertions, 15 deletions
diff --git a/alembic/operations/base.py b/alembic/operations/base.py index 21f7a85..59bbfc4 100644 --- a/alembic/operations/base.py +++ b/alembic/operations/base.py @@ -1,8 +1,8 @@ from contextlib import contextmanager +import re import textwrap from typing import Any from typing import Callable -from typing import ForwardRef # noqa from typing import Iterator from typing import List # noqa from typing import Optional @@ -136,6 +136,12 @@ class Operations(util.ModuleClsProxy): formatvalue=lambda x: "=" + x, ) + args = re.sub( + r'[_]?ForwardRef\(([\'"].+?[\'"])\)', + lambda m: m.group(1), + args, + ) + func_text = textwrap.dedent( """\ def %(name)s%(args)s: diff --git a/alembic/util/compat.py b/alembic/util/compat.py index b87f8a6..dae98f4 100644 --- a/alembic/util/compat.py +++ b/alembic/util/compat.py @@ -1,11 +1,14 @@ import io import os +import sys from sqlalchemy.util import inspect_getfullargspec # noqa from sqlalchemy.util.compat import inspect_formatargspec # noqa is_posix = os.name == "posix" +py39 = sys.version_info >= (3, 9) + string_types = (str,) binary_type = bytes diff --git a/tests/requirements.py b/tests/requirements.py index 04497f5..011c620 100644 --- a/tests/requirements.py +++ b/tests/requirements.py @@ -2,6 +2,7 @@ from sqlalchemy import text from alembic.testing import exclusions from alembic.testing.requirements import SuiteRequirements +from alembic.util import compat from alembic.util import sqla_compat @@ -301,3 +302,22 @@ class DefaultRequirements(SuiteRequirements): return exclusions.only_if( lambda config: not getattr(config.db, "_is_future", False) ) + + @property + def stubs_test(self): + def requirements(): + try: + import black # noqa + import zimports # noqa + + return False + except Exception: + return True + + imports = exclusions.skip_if( + requirements, "black and zimports are required for this test" + ) + version = exclusions.only_if( + lambda _: compat.py39, "python 3.9 is required" + ) + return imports + version diff --git a/tests/test_stubs.py b/tests/test_stubs.py index c5186c3..efb1a9d 100644 --- a/tests/test_stubs.py +++ b/tests/test_stubs.py @@ -4,22 +4,11 @@ import sys import alembic from alembic.testing import eq_ -from alembic.testing import skip_if from alembic.testing import TestBase _home = Path(__file__).parent.parent -def requirements(): - try: - import black # noqa - import zimports # noqa - - return False - except Exception: - return True - - def run_command(file): res = subprocess.run( [ @@ -37,14 +26,14 @@ def run_command(file): class TestStubFiles(TestBase): - @skip_if(requirements, "black and zimports are required for this test") + __requires__ = ("stubs_test",) + def test_op_pyi(self): res = run_command("op") generated = res.stdout expected = Path(alembic.__file__).parent / "op.pyi" eq_(generated, expected.read_text()) - @skip_if(requirements, "black and zimports are required for this test") def test_context_pyi(self): res = run_command("context") generated = res.stdout diff --git a/tools/write_pyi.py b/tools/write_pyi.py index 6dfed09..234b06f 100644 --- a/tools/write_pyi.py +++ b/tools/write_pyi.py @@ -50,7 +50,8 @@ def generate_pyi_for_proxy( printer = PythonPrinter(buf) printer.writeline( - f"# ### this file stubs are generated by {progname} - do not edit ###" + f"# ### this file stubs are generated by {progname} " + "- do not edit ###" ) for line in imports: buf.write(line + "\n") |