summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/util/langhelpers.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/util/langhelpers.py')
-rw-r--r--lib/sqlalchemy/util/langhelpers.py38
1 files changed, 32 insertions, 6 deletions
diff --git a/lib/sqlalchemy/util/langhelpers.py b/lib/sqlalchemy/util/langhelpers.py
index ca64296c1..93caa0ee5 100644
--- a/lib/sqlalchemy/util/langhelpers.py
+++ b/lib/sqlalchemy/util/langhelpers.py
@@ -20,6 +20,7 @@ import re
import sys
import textwrap
import types
+import typing
from typing import Any
from typing import Callable
from typing import Generic
@@ -31,10 +32,10 @@ import warnings
from . import _collections
from . import compat
+from . import typing as compat_typing
from .. import exc
_T = TypeVar("_T")
-_MP = TypeVar("_MP", bound="memoized_property[Any]")
def md5_hex(x):
@@ -166,7 +167,13 @@ def map_bits(fn, n):
n ^= b
-def decorator(target):
+_Fn = typing.TypeVar("_Fn", bound=typing.Callable)
+_Args = compat_typing.ParamSpec("_Args")
+
+
+def decorator(
+ target: typing.Callable[compat_typing.Concatenate[_Fn, _Args], typing.Any]
+) -> _Fn:
"""A signature-matching decorator factory."""
def decorate(fn):
@@ -198,7 +205,7 @@ def %(name)s%(grouped_args)s:
decorated.__wrapped__ = fn
return update_wrapper(decorated, fn)
- return update_wrapper(decorate, target)
+ return typing.cast(_Fn, update_wrapper(decorate, target))
def _update_argspec_defaults_into_env(spec, env):
@@ -227,7 +234,16 @@ def _exec_code_in_env(code, env, fn_name):
return env[fn_name]
-def public_factory(target, location, class_location=None):
+_TE = TypeVar("_TE")
+
+_P = compat_typing.ParamSpec("_P")
+
+
+def public_factory(
+ target: typing.Callable[_P, _TE],
+ location: str,
+ class_location: Optional[str] = None,
+) -> typing.Callable[_P, _TE]:
"""Produce a wrapping function for the given cls or classmethod.
Rationale here is so that the __init__ method of the
@@ -273,6 +289,7 @@ def %(name)s%(grouped_args)s:
"__name__": callable_.__module__,
}
exec(code, env)
+
decorated = env[location_name]
if hasattr(fn, "_linked_to"):
@@ -1077,6 +1094,11 @@ def as_interface(obj, cls=None, methods=None, required=None):
)
+Selfmemoized_property = TypeVar(
+ "Selfmemoized_property", bound="memoized_property[Any]"
+)
+
+
class memoized_property(Generic[_T]):
"""A read-only @property that is only evaluated once."""
@@ -1090,14 +1112,18 @@ class memoized_property(Generic[_T]):
self.__name__ = fget.__name__
@overload
- def __get__(self: _MP, obj: None, cls: Any) -> _MP:
+ def __get__(
+ self: Selfmemoized_property, obj: None, cls: Any
+ ) -> Selfmemoized_property:
...
@overload
def __get__(self, obj: Any, cls: Any) -> _T:
...
- def __get__(self: _MP, obj: Any, cls: Any) -> Union[_MP, _T]:
+ def __get__(
+ self: Selfmemoized_property, obj: Any, cls: Any
+ ) -> Union[Selfmemoized_property, _T]:
if obj is None:
return self
obj.__dict__[self.__name__] = result = self.fget(obj)