diff options
Diffstat (limited to 'lib/sqlalchemy/util/typing.py')
-rw-r--r-- | lib/sqlalchemy/util/typing.py | 159 |
1 files changed, 148 insertions, 11 deletions
diff --git a/lib/sqlalchemy/util/typing.py b/lib/sqlalchemy/util/typing.py index 62a9f6c8a..56ea4d0e0 100644 --- a/lib/sqlalchemy/util/typing.py +++ b/lib/sqlalchemy/util/typing.py @@ -1,6 +1,10 @@ +import sys import typing from typing import Any from typing import Callable # noqa +from typing import cast +from typing import Dict +from typing import ForwardRef from typing import Generic from typing import overload from typing import Type @@ -13,21 +17,36 @@ from . import compat _T = TypeVar("_T", bound=Any) -if typing.TYPE_CHECKING or not compat.py38: - from typing_extensions import Literal # noqa F401 - from typing_extensions import Protocol # noqa F401 - from typing_extensions import TypedDict # noqa F401 +if compat.py310: + # why they took until py310 to put this in stdlib is beyond me, + # I've been wanting it since py27 + from types import NoneType else: - from typing import Literal # noqa F401 - from typing import Protocol # noqa F401 - from typing import TypedDict # noqa F401 + NoneType = type(None) # type: ignore + +if typing.TYPE_CHECKING or compat.py310: + from typing import Annotated as Annotated +else: + from typing_extensions import Annotated as Annotated # noqa F401 + +if typing.TYPE_CHECKING or compat.py38: + from typing import Literal as Literal + from typing import Protocol as Protocol + from typing import TypedDict as TypedDict +else: + from typing_extensions import Literal as Literal # noqa F401 + from typing_extensions import Protocol as Protocol # noqa F401 + from typing_extensions import TypedDict as TypedDict # noqa F401 + +# work around https://github.com/microsoft/pyright/issues/3025 +_LiteralStar = Literal["*"] if typing.TYPE_CHECKING or not compat.py310: - from typing_extensions import Concatenate # noqa F401 - from typing_extensions import ParamSpec # noqa F401 + from typing_extensions import Concatenate as Concatenate + from typing_extensions import ParamSpec as ParamSpec else: - from typing import Concatenate # noqa F401 - from typing import ParamSpec # noqa F401 + from typing import Concatenate as Concatenate # noqa F401 + from typing import ParamSpec as ParamSpec # noqa F401 class _TypeToInstance(Generic[_T]): @@ -76,3 +95,121 @@ class ReadOnlyInstanceDescriptor(Protocol[_T]): self, instance: object, owner: Any ) -> Union["ReadOnlyInstanceDescriptor[_T]", _T]: ... + + +def de_stringify_annotation( + cls: Type[Any], annotation: Union[str, Type[Any]] +) -> Union[str, Type[Any]]: + """Resolve annotations that may be string based into real objects. + + This is particularly important if a module defines "from __future__ import + annotations", as everything inside of __annotations__ is a string. We want + to at least have generic containers like ``Mapped``, ``Union``, ``List``, + etc. + + """ + + # looked at typing.get_type_hints(), looked at pydantic. We need much + # less here, and we here try to not use any private typing internals + # or construct ForwardRef objects which is documented as something + # that should be avoided. + + if ( + is_fwd_ref(annotation) + and not cast(ForwardRef, annotation).__forward_evaluated__ + ): + annotation = cast(ForwardRef, annotation).__forward_arg__ + + if isinstance(annotation, str): + base_globals: "Dict[str, Any]" = getattr( + sys.modules.get(cls.__module__, None), "__dict__", {} + ) + try: + annotation = eval(annotation, base_globals, None) + except NameError: + pass + return annotation + + +def is_fwd_ref(type_): + return isinstance(type_, ForwardRef) + + +def de_optionalize_union_types(type_): + """Given a type, filter out ``Union`` types that include ``NoneType`` + to not include the ``NoneType``. + + """ + if is_optional(type_): + typ = set(type_.__args__) + + typ.discard(NoneType) + + return make_union_type(*typ) + + else: + return type_ + + +def make_union_type(*types): + """Make a Union type. + + This is needed by :func:`.de_optionalize_union_types` which removes + ``NoneType`` from a ``Union``. + + """ + return cast(Any, Union).__getitem__(types) + + +def expand_unions(type_, include_union=False, discard_none=False): + """Return a type as as a tuple of individual types, expanding for + ``Union`` types.""" + + if is_union(type_): + typ = set(type_.__args__) + + if discard_none: + typ.discard(NoneType) + + if include_union: + return (type_,) + tuple(typ) + else: + return tuple(typ) + else: + return (type_,) + + +def is_optional(type_): + return is_origin_of( + type_, + "Optional", + "Union", + ) + + +def is_union(type_): + return is_origin_of(type_, "Union") + + +def is_origin_of(type_, *names, module=None): + """return True if the given type has an __origin__ with the given name + and optional module.""" + + origin = getattr(type_, "__origin__", None) + if origin is None: + return False + + return _get_type_name(origin) in names and ( + module is None or origin.__module__.startswith(module) + ) + + +def _get_type_name(type_): + if compat.py310: + return type_.__name__ + else: + typ_name = getattr(type_, "__name__", None) + if typ_name is None: + typ_name = getattr(type_, "_name", None) + + return typ_name |