summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCaselIT <cfederico87@gmail.com>2021-08-23 20:54:00 +0200
committerCaselIT <cfederico87@gmail.com>2021-08-23 21:26:32 +0200
commit7fd48061bbf893002b2d7a624b313b09ee0a9700 (patch)
tree6620421f730280b84883ad524bbbd98aea5d99a4
parent18136dd42a0820fbecacea8cb0e7f47b002ce68a (diff)
downloadalembic-7fd48061bbf893002b2d7a624b313b09ee0a9700.tar.gz
avoid importing ForwardRef that's py3.9+ only
Change-Id: I76654f10e208d618e21ab0c884cb0abede4d6177
-rw-r--r--alembic/operations/base.py8
-rw-r--r--alembic/util/compat.py3
-rw-r--r--tests/requirements.py20
-rw-r--r--tests/test_stubs.py15
-rw-r--r--tools/write_pyi.py3
5 files changed, 34 insertions, 15 deletions
diff --git a/alembic/operations/base.py b/alembic/operations/base.py
index 21f7a85..59bbfc4 100644
--- a/alembic/operations/base.py
+++ b/alembic/operations/base.py
@@ -1,8 +1,8 @@
from contextlib import contextmanager
+import re
import textwrap
from typing import Any
from typing import Callable
-from typing import ForwardRef # noqa
from typing import Iterator
from typing import List # noqa
from typing import Optional
@@ -136,6 +136,12 @@ class Operations(util.ModuleClsProxy):
formatvalue=lambda x: "=" + x,
)
+ args = re.sub(
+ r'[_]?ForwardRef\(([\'"].+?[\'"])\)',
+ lambda m: m.group(1),
+ args,
+ )
+
func_text = textwrap.dedent(
"""\
def %(name)s%(args)s:
diff --git a/alembic/util/compat.py b/alembic/util/compat.py
index b87f8a6..dae98f4 100644
--- a/alembic/util/compat.py
+++ b/alembic/util/compat.py
@@ -1,11 +1,14 @@
import io
import os
+import sys
from sqlalchemy.util import inspect_getfullargspec # noqa
from sqlalchemy.util.compat import inspect_formatargspec # noqa
is_posix = os.name == "posix"
+py39 = sys.version_info >= (3, 9)
+
string_types = (str,)
binary_type = bytes
diff --git a/tests/requirements.py b/tests/requirements.py
index 04497f5..011c620 100644
--- a/tests/requirements.py
+++ b/tests/requirements.py
@@ -2,6 +2,7 @@ from sqlalchemy import text
from alembic.testing import exclusions
from alembic.testing.requirements import SuiteRequirements
+from alembic.util import compat
from alembic.util import sqla_compat
@@ -301,3 +302,22 @@ class DefaultRequirements(SuiteRequirements):
return exclusions.only_if(
lambda config: not getattr(config.db, "_is_future", False)
)
+
+ @property
+ def stubs_test(self):
+ def requirements():
+ try:
+ import black # noqa
+ import zimports # noqa
+
+ return False
+ except Exception:
+ return True
+
+ imports = exclusions.skip_if(
+ requirements, "black and zimports are required for this test"
+ )
+ version = exclusions.only_if(
+ lambda _: compat.py39, "python 3.9 is required"
+ )
+ return imports + version
diff --git a/tests/test_stubs.py b/tests/test_stubs.py
index c5186c3..efb1a9d 100644
--- a/tests/test_stubs.py
+++ b/tests/test_stubs.py
@@ -4,22 +4,11 @@ import sys
import alembic
from alembic.testing import eq_
-from alembic.testing import skip_if
from alembic.testing import TestBase
_home = Path(__file__).parent.parent
-def requirements():
- try:
- import black # noqa
- import zimports # noqa
-
- return False
- except Exception:
- return True
-
-
def run_command(file):
res = subprocess.run(
[
@@ -37,14 +26,14 @@ def run_command(file):
class TestStubFiles(TestBase):
- @skip_if(requirements, "black and zimports are required for this test")
+ __requires__ = ("stubs_test",)
+
def test_op_pyi(self):
res = run_command("op")
generated = res.stdout
expected = Path(alembic.__file__).parent / "op.pyi"
eq_(generated, expected.read_text())
- @skip_if(requirements, "black and zimports are required for this test")
def test_context_pyi(self):
res = run_command("context")
generated = res.stdout
diff --git a/tools/write_pyi.py b/tools/write_pyi.py
index 6dfed09..234b06f 100644
--- a/tools/write_pyi.py
+++ b/tools/write_pyi.py
@@ -50,7 +50,8 @@ def generate_pyi_for_proxy(
printer = PythonPrinter(buf)
printer.writeline(
- f"# ### this file stubs are generated by {progname} - do not edit ###"
+ f"# ### this file stubs are generated by {progname} "
+ "- do not edit ###"
)
for line in imports:
buf.write(line + "\n")