summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/util
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/util')
-rw-r--r--lib/sqlalchemy/util/__init__.py258
-rw-r--r--lib/sqlalchemy/util/_collections.py88
-rw-r--r--lib/sqlalchemy/util/compat.py1
-rw-r--r--lib/sqlalchemy/util/concurrency.py18
-rw-r--r--lib/sqlalchemy/util/deprecations.py6
-rw-r--r--lib/sqlalchemy/util/langhelpers.py25
-rw-r--r--lib/sqlalchemy/util/typing.py159
7 files changed, 395 insertions, 160 deletions
diff --git a/lib/sqlalchemy/util/__init__.py b/lib/sqlalchemy/util/__init__.py
index 91d15aae0..85bbca20f 100644
--- a/lib/sqlalchemy/util/__init__.py
+++ b/lib/sqlalchemy/util/__init__.py
@@ -6,131 +6,135 @@
# the MIT License: https://www.opensource.org/licenses/mit-license.php
-from collections import defaultdict
-from functools import partial
-from functools import update_wrapper
+from collections import defaultdict as defaultdict
+from functools import partial as partial
+from functools import update_wrapper as update_wrapper
-from ._collections import coerce_generator_arg
-from ._collections import coerce_to_immutabledict
-from ._collections import column_dict
-from ._collections import column_set
-from ._collections import EMPTY_DICT
-from ._collections import EMPTY_SET
-from ._collections import FacadeDict
-from ._collections import flatten_iterator
-from ._collections import has_dupes
-from ._collections import has_intersection
-from ._collections import IdentitySet
-from ._collections import ImmutableContainer
-from ._collections import immutabledict
-from ._collections import ImmutableProperties
-from ._collections import LRUCache
-from ._collections import ordered_column_set
-from ._collections import OrderedDict
-from ._collections import OrderedIdentitySet
-from ._collections import OrderedProperties
-from ._collections import OrderedSet
-from ._collections import PopulateDict
-from ._collections import Properties
-from ._collections import ScopedRegistry
-from ._collections import sort_dictionary
-from ._collections import ThreadLocalRegistry
-from ._collections import to_column_set
-from ._collections import to_list
-from ._collections import to_set
-from ._collections import unique_list
-from ._collections import UniqueAppender
-from ._collections import update_copy
-from ._collections import WeakPopulateDict
-from ._collections import WeakSequence
-from ._preloaded import preload_module
-from ._preloaded import preloaded
-from .compat import arm
-from .compat import b
-from .compat import b64decode
-from .compat import b64encode
-from .compat import cmp
-from .compat import cpython
-from .compat import dataclass_fields
-from .compat import decode_backslashreplace
-from .compat import dottedgetter
-from .compat import has_refcount_gc
-from .compat import inspect_getfullargspec
-from .compat import local_dataclass_fields
-from .compat import next
-from .compat import osx
-from .compat import py38
-from .compat import py39
-from .compat import pypy
-from .compat import win32
-from .concurrency import asyncio
-from .concurrency import await_fallback
-from .concurrency import await_only
-from .concurrency import greenlet_spawn
-from .concurrency import is_exit_exception
-from .deprecations import became_legacy_20
-from .deprecations import deprecated
-from .deprecations import deprecated_cls
-from .deprecations import deprecated_params
-from .deprecations import deprecated_property
-from .deprecations import inject_docstring_text
-from .deprecations import moved_20
-from .deprecations import warn_deprecated
-from .langhelpers import add_parameter_text
-from .langhelpers import as_interface
-from .langhelpers import asbool
-from .langhelpers import asint
-from .langhelpers import assert_arg_type
-from .langhelpers import attrsetter
-from .langhelpers import bool_or_str
-from .langhelpers import chop_traceback
-from .langhelpers import class_hierarchy
-from .langhelpers import classproperty
-from .langhelpers import clsname_as_plain_name
-from .langhelpers import coerce_kw_type
-from .langhelpers import constructor_copy
-from .langhelpers import constructor_key
-from .langhelpers import counter
-from .langhelpers import create_proxy_methods
-from .langhelpers import decode_slice
-from .langhelpers import decorator
-from .langhelpers import dictlike_iteritems
-from .langhelpers import duck_type_collection
-from .langhelpers import ellipses_string
-from .langhelpers import EnsureKWArg
-from .langhelpers import format_argspec_init
-from .langhelpers import format_argspec_plus
-from .langhelpers import generic_repr
-from .langhelpers import get_callable_argspec
-from .langhelpers import get_cls_kwargs
-from .langhelpers import get_func_kwargs
-from .langhelpers import getargspec_init
-from .langhelpers import has_compiled_ext
-from .langhelpers import HasMemoized
-from .langhelpers import hybridmethod
-from .langhelpers import hybridproperty
-from .langhelpers import iterate_attributes
-from .langhelpers import map_bits
-from .langhelpers import md5_hex
-from .langhelpers import memoized_instancemethod
-from .langhelpers import memoized_property
-from .langhelpers import MemoizedSlots
-from .langhelpers import method_is_overridden
-from .langhelpers import methods_equivalent
-from .langhelpers import monkeypatch_proxied_specials
-from .langhelpers import NoneType
-from .langhelpers import only_once
-from .langhelpers import PluginLoader
-from .langhelpers import portable_instancemethod
-from .langhelpers import quoted_token_parser
-from .langhelpers import safe_reraise
-from .langhelpers import set_creation_order
-from .langhelpers import string_or_unprintable
-from .langhelpers import symbol
-from .langhelpers import TypingOnly
-from .langhelpers import unbound_method_to_callable
-from .langhelpers import walk_subclasses
-from .langhelpers import warn
-from .langhelpers import warn_exception
-from .langhelpers import warn_limited
-from .langhelpers import wrap_callable
+from ._collections import coerce_generator_arg as coerce_generator_arg
+from ._collections import coerce_to_immutabledict as coerce_to_immutabledict
+from ._collections import column_dict as column_dict
+from ._collections import column_set as column_set
+from ._collections import EMPTY_DICT as EMPTY_DICT
+from ._collections import EMPTY_SET as EMPTY_SET
+from ._collections import FacadeDict as FacadeDict
+from ._collections import flatten_iterator as flatten_iterator
+from ._collections import has_dupes as has_dupes
+from ._collections import has_intersection as has_intersection
+from ._collections import IdentitySet as IdentitySet
+from ._collections import ImmutableContainer as ImmutableContainer
+from ._collections import immutabledict as immutabledict
+from ._collections import ImmutableProperties as ImmutableProperties
+from ._collections import LRUCache as LRUCache
+from ._collections import merge_lists_w_ordering as merge_lists_w_ordering
+from ._collections import ordered_column_set as ordered_column_set
+from ._collections import OrderedDict as OrderedDict
+from ._collections import OrderedIdentitySet as OrderedIdentitySet
+from ._collections import OrderedProperties as OrderedProperties
+from ._collections import OrderedSet as OrderedSet
+from ._collections import PopulateDict as PopulateDict
+from ._collections import Properties as Properties
+from ._collections import ScopedRegistry as ScopedRegistry
+from ._collections import sort_dictionary as sort_dictionary
+from ._collections import ThreadLocalRegistry as ThreadLocalRegistry
+from ._collections import to_column_set as to_column_set
+from ._collections import to_list as to_list
+from ._collections import to_set as to_set
+from ._collections import unique_list as unique_list
+from ._collections import UniqueAppender as UniqueAppender
+from ._collections import update_copy as update_copy
+from ._collections import WeakPopulateDict as WeakPopulateDict
+from ._collections import WeakSequence as WeakSequence
+from ._preloaded import preload_module as preload_module
+from ._preloaded import preloaded as preloaded
+from .compat import arm as arm
+from .compat import b as b
+from .compat import b64decode as b64decode
+from .compat import b64encode as b64encode
+from .compat import cmp as cmp
+from .compat import cpython as cpython
+from .compat import dataclass_fields as dataclass_fields
+from .compat import decode_backslashreplace as decode_backslashreplace
+from .compat import dottedgetter as dottedgetter
+from .compat import has_refcount_gc as has_refcount_gc
+from .compat import inspect_getfullargspec as inspect_getfullargspec
+from .compat import local_dataclass_fields as local_dataclass_fields
+from .compat import osx as osx
+from .compat import py38 as py38
+from .compat import py39 as py39
+from .compat import pypy as pypy
+from .compat import win32 as win32
+from .concurrency import await_fallback as await_fallback
+from .concurrency import await_only as await_only
+from .concurrency import greenlet_spawn as greenlet_spawn
+from .concurrency import is_exit_exception as is_exit_exception
+from .deprecations import became_legacy_20 as became_legacy_20
+from .deprecations import deprecated as deprecated
+from .deprecations import deprecated_cls as deprecated_cls
+from .deprecations import deprecated_params as deprecated_params
+from .deprecations import deprecated_property as deprecated_property
+from .deprecations import moved_20 as moved_20
+from .deprecations import warn_deprecated as warn_deprecated
+from .langhelpers import add_parameter_text as add_parameter_text
+from .langhelpers import as_interface as as_interface
+from .langhelpers import asbool as asbool
+from .langhelpers import asint as asint
+from .langhelpers import assert_arg_type as assert_arg_type
+from .langhelpers import attrsetter as attrsetter
+from .langhelpers import bool_or_str as bool_or_str
+from .langhelpers import chop_traceback as chop_traceback
+from .langhelpers import class_hierarchy as class_hierarchy
+from .langhelpers import classproperty as classproperty
+from .langhelpers import clsname_as_plain_name as clsname_as_plain_name
+from .langhelpers import coerce_kw_type as coerce_kw_type
+from .langhelpers import constructor_copy as constructor_copy
+from .langhelpers import constructor_key as constructor_key
+from .langhelpers import counter as counter
+from .langhelpers import create_proxy_methods as create_proxy_methods
+from .langhelpers import decode_slice as decode_slice
+from .langhelpers import decorator as decorator
+from .langhelpers import dictlike_iteritems as dictlike_iteritems
+from .langhelpers import duck_type_collection as duck_type_collection
+from .langhelpers import ellipses_string as ellipses_string
+from .langhelpers import EnsureKWArg as EnsureKWArg
+from .langhelpers import format_argspec_init as format_argspec_init
+from .langhelpers import format_argspec_plus as format_argspec_plus
+from .langhelpers import generic_repr as generic_repr
+from .langhelpers import get_annotations as get_annotations
+from .langhelpers import get_callable_argspec as get_callable_argspec
+from .langhelpers import get_cls_kwargs as get_cls_kwargs
+from .langhelpers import get_func_kwargs as get_func_kwargs
+from .langhelpers import getargspec_init as getargspec_init
+from .langhelpers import has_compiled_ext as has_compiled_ext
+from .langhelpers import HasMemoized as HasMemoized
+from .langhelpers import hybridmethod as hybridmethod
+from .langhelpers import hybridproperty as hybridproperty
+from .langhelpers import inject_docstring_text as inject_docstring_text
+from .langhelpers import iterate_attributes as iterate_attributes
+from .langhelpers import map_bits as map_bits
+from .langhelpers import md5_hex as md5_hex
+from .langhelpers import memoized_instancemethod as memoized_instancemethod
+from .langhelpers import memoized_property as memoized_property
+from .langhelpers import MemoizedSlots as MemoizedSlots
+from .langhelpers import method_is_overridden as method_is_overridden
+from .langhelpers import methods_equivalent as methods_equivalent
+from .langhelpers import (
+ monkeypatch_proxied_specials as monkeypatch_proxied_specials,
+)
+from .langhelpers import NoneType as NoneType
+from .langhelpers import only_once as only_once
+from .langhelpers import PluginLoader as PluginLoader
+from .langhelpers import portable_instancemethod as portable_instancemethod
+from .langhelpers import quoted_token_parser as quoted_token_parser
+from .langhelpers import safe_reraise as safe_reraise
+from .langhelpers import set_creation_order as set_creation_order
+from .langhelpers import string_or_unprintable as string_or_unprintable
+from .langhelpers import symbol as symbol
+from .langhelpers import TypingOnly as TypingOnly
+from .langhelpers import (
+ unbound_method_to_callable as unbound_method_to_callable,
+)
+from .langhelpers import walk_subclasses as walk_subclasses
+from .langhelpers import warn as warn
+from .langhelpers import warn_exception as warn_exception
+from .langhelpers import warn_limited as warn_limited
+from .langhelpers import wrap_callable as wrap_callable
diff --git a/lib/sqlalchemy/util/_collections.py b/lib/sqlalchemy/util/_collections.py
index 3e4ef1310..850986802 100644
--- a/lib/sqlalchemy/util/_collections.py
+++ b/lib/sqlalchemy/util/_collections.py
@@ -34,19 +34,27 @@ from ._has_cy import HAS_CYEXTENSION
from .typing import Literal
if typing.TYPE_CHECKING or not HAS_CYEXTENSION:
- from ._py_collections import immutabledict
- from ._py_collections import IdentitySet
- from ._py_collections import ImmutableContainer
- from ._py_collections import ImmutableDictBase
- from ._py_collections import OrderedSet
- from ._py_collections import unique_list # noqa
+ from ._py_collections import immutabledict as immutabledict
+ from ._py_collections import IdentitySet as IdentitySet
+ from ._py_collections import ImmutableContainer as ImmutableContainer
+ from ._py_collections import ImmutableDictBase as ImmutableDictBase
+ from ._py_collections import OrderedSet as OrderedSet
+ from ._py_collections import unique_list as unique_list
else:
- from sqlalchemy.cyextension.immutabledict import ImmutableContainer
- from sqlalchemy.cyextension.immutabledict import ImmutableDictBase
- from sqlalchemy.cyextension.immutabledict import immutabledict
- from sqlalchemy.cyextension.collections import IdentitySet
- from sqlalchemy.cyextension.collections import OrderedSet
- from sqlalchemy.cyextension.collections import unique_list # noqa
+ from sqlalchemy.cyextension.immutabledict import (
+ ImmutableContainer as ImmutableContainer,
+ )
+ from sqlalchemy.cyextension.immutabledict import (
+ ImmutableDictBase as ImmutableDictBase,
+ )
+ from sqlalchemy.cyextension.immutabledict import (
+ immutabledict as immutabledict,
+ )
+ from sqlalchemy.cyextension.collections import IdentitySet as IdentitySet
+ from sqlalchemy.cyextension.collections import OrderedSet as OrderedSet
+ from sqlalchemy.cyextension.collections import ( # noqa
+ unique_list as unique_list,
+ )
_T = TypeVar("_T", bound=Any)
@@ -57,6 +65,62 @@ _VT = TypeVar("_VT", bound=Any)
EMPTY_SET: FrozenSet[Any] = frozenset()
+def merge_lists_w_ordering(a, b):
+ """merge two lists, maintaining ordering as much as possible.
+
+ this is to reconcile vars(cls) with cls.__annotations__.
+
+ Example::
+
+ >>> a = ['__tablename__', 'id', 'x', 'created_at']
+ >>> b = ['id', 'name', 'data', 'y', 'created_at']
+ >>> merge_lists_w_ordering(a, b)
+ ['__tablename__', 'id', 'name', 'data', 'y', 'x', 'created_at']
+
+ This is not necessarily the ordering that things had on the class,
+ in this case the class is::
+
+ class User(Base):
+ __tablename__ = "users"
+
+ id: Mapped[int] = mapped_column(primary_key=True)
+ name: Mapped[str]
+ data: Mapped[Optional[str]]
+ x = Column(Integer)
+ y: Mapped[int]
+ created_at: Mapped[datetime.datetime] = mapped_column()
+
+ But things are *mostly* ordered.
+
+ The algorithm could also be done by creating a partial ordering for
+ all items in both lists and then using topological_sort(), but that
+ is too much overhead.
+
+ Background on how I came up with this is at:
+ https://gist.github.com/zzzeek/89de958cf0803d148e74861bd682ebae
+
+ """
+ overlap = set(a).intersection(b)
+
+ result = []
+
+ current, other = iter(a), iter(b)
+
+ while True:
+ for element in current:
+ if element in overlap:
+ overlap.discard(element)
+ other, current = current, other
+ break
+
+ result.append(element)
+ else:
+ result.extend(other)
+ break
+
+ return result
+
+
def coerce_to_immutabledict(d):
if not d:
return EMPTY_DICT
diff --git a/lib/sqlalchemy/util/compat.py b/lib/sqlalchemy/util/compat.py
index 0f4befbb1..62cffa556 100644
--- a/lib/sqlalchemy/util/compat.py
+++ b/lib/sqlalchemy/util/compat.py
@@ -39,7 +39,6 @@ arm = "aarch" in platform.machine().lower()
has_refcount_gc = bool(cpython)
dottedgetter = operator.attrgetter
-next = next # noqa
class FullArgSpec(typing.NamedTuple):
diff --git a/lib/sqlalchemy/util/concurrency.py b/lib/sqlalchemy/util/concurrency.py
index 57ef23006..6b94a2294 100644
--- a/lib/sqlalchemy/util/concurrency.py
+++ b/lib/sqlalchemy/util/concurrency.py
@@ -16,15 +16,17 @@ except ImportError as e:
pass
else:
have_greenlet = True
- from ._concurrency_py3k import await_only
- from ._concurrency_py3k import await_fallback
- from ._concurrency_py3k import greenlet_spawn
- from ._concurrency_py3k import is_exit_exception
- from ._concurrency_py3k import AsyncAdaptedLock
- from ._concurrency_py3k import _util_async_run # noqa F401
+ from ._concurrency_py3k import await_only as await_only
+ from ._concurrency_py3k import await_fallback as await_fallback
+ from ._concurrency_py3k import greenlet_spawn as greenlet_spawn
+ from ._concurrency_py3k import is_exit_exception as is_exit_exception
+ from ._concurrency_py3k import AsyncAdaptedLock as AsyncAdaptedLock
from ._concurrency_py3k import (
- _util_async_run_coroutine_function,
- ) # noqa F401, E501
+ _util_async_run as _util_async_run,
+ ) # noqa F401
+ from ._concurrency_py3k import (
+ _util_async_run_coroutine_function as _util_async_run_coroutine_function, # noqa F401, E501
+ )
if not have_greenlet:
diff --git a/lib/sqlalchemy/util/deprecations.py b/lib/sqlalchemy/util/deprecations.py
index 565cbafe2..7c2586166 100644
--- a/lib/sqlalchemy/util/deprecations.py
+++ b/lib/sqlalchemy/util/deprecations.py
@@ -13,6 +13,7 @@ from typing import Any
from typing import Callable
from typing import cast
from typing import Optional
+from typing import Tuple
from typing import TypeVar
from . import compat
@@ -209,7 +210,10 @@ def became_legacy_20(api_name, alternative=None, **kw):
return deprecated("2.0", message=message, warning=warning_cls, **kw)
-def deprecated_params(**specs):
+_C = TypeVar("_C", bound=Callable[..., Any])
+
+
+def deprecated_params(**specs: Tuple[str, str]) -> Callable[[_C], _C]:
"""Decorates a function to warn on use of certain parameters.
e.g. ::
diff --git a/lib/sqlalchemy/util/langhelpers.py b/lib/sqlalchemy/util/langhelpers.py
index 9401c249f..ed879894d 100644
--- a/lib/sqlalchemy/util/langhelpers.py
+++ b/lib/sqlalchemy/util/langhelpers.py
@@ -30,6 +30,7 @@ from typing import FrozenSet
from typing import Generic
from typing import Iterator
from typing import List
+from typing import Mapping
from typing import Optional
from typing import overload
from typing import Sequence
@@ -54,6 +55,30 @@ _HP = TypeVar("_HP", bound="hybridproperty")
_HM = TypeVar("_HM", bound="hybridmethod")
+if compat.py310:
+
+ def get_annotations(obj: Any) -> Mapping[str, Any]:
+ return inspect.get_annotations(obj)
+
+else:
+
+ def get_annotations(obj: Any) -> Mapping[str, Any]:
+ # it's been observed that cls.__annotations__ can be non present.
+ # it's not clear what causes this, running under tox py37/38 it
+ # happens, running straight pytest it doesnt
+
+ # https://docs.python.org/3/howto/annotations.html#annotations-howto
+ if isinstance(obj, type):
+ ann = obj.__dict__.get("__annotations__", None)
+ else:
+ ann = getattr(obj, "__annotations__", None)
+
+ if ann is None:
+ return _collections.EMPTY_DICT
+ else:
+ return cast("Mapping[str, Any]", ann)
+
+
def md5_hex(x: Any) -> str:
x = x.encode("utf-8")
m = hashlib.md5()
diff --git a/lib/sqlalchemy/util/typing.py b/lib/sqlalchemy/util/typing.py
index 62a9f6c8a..56ea4d0e0 100644
--- a/lib/sqlalchemy/util/typing.py
+++ b/lib/sqlalchemy/util/typing.py
@@ -1,6 +1,10 @@
+import sys
import typing
from typing import Any
from typing import Callable # noqa
+from typing import cast
+from typing import Dict
+from typing import ForwardRef
from typing import Generic
from typing import overload
from typing import Type
@@ -13,21 +17,36 @@ from . import compat
_T = TypeVar("_T", bound=Any)
-if typing.TYPE_CHECKING or not compat.py38:
- from typing_extensions import Literal # noqa F401
- from typing_extensions import Protocol # noqa F401
- from typing_extensions import TypedDict # noqa F401
+if compat.py310:
+ # why they took until py310 to put this in stdlib is beyond me,
+ # I've been wanting it since py27
+ from types import NoneType
else:
- from typing import Literal # noqa F401
- from typing import Protocol # noqa F401
- from typing import TypedDict # noqa F401
+ NoneType = type(None) # type: ignore
+
+if typing.TYPE_CHECKING or compat.py310:
+ from typing import Annotated as Annotated
+else:
+ from typing_extensions import Annotated as Annotated # noqa F401
+
+if typing.TYPE_CHECKING or compat.py38:
+ from typing import Literal as Literal
+ from typing import Protocol as Protocol
+ from typing import TypedDict as TypedDict
+else:
+ from typing_extensions import Literal as Literal # noqa F401
+ from typing_extensions import Protocol as Protocol # noqa F401
+ from typing_extensions import TypedDict as TypedDict # noqa F401
+
+# work around https://github.com/microsoft/pyright/issues/3025
+_LiteralStar = Literal["*"]
if typing.TYPE_CHECKING or not compat.py310:
- from typing_extensions import Concatenate # noqa F401
- from typing_extensions import ParamSpec # noqa F401
+ from typing_extensions import Concatenate as Concatenate
+ from typing_extensions import ParamSpec as ParamSpec
else:
- from typing import Concatenate # noqa F401
- from typing import ParamSpec # noqa F401
+ from typing import Concatenate as Concatenate # noqa F401
+ from typing import ParamSpec as ParamSpec # noqa F401
class _TypeToInstance(Generic[_T]):
@@ -76,3 +95,121 @@ class ReadOnlyInstanceDescriptor(Protocol[_T]):
self, instance: object, owner: Any
) -> Union["ReadOnlyInstanceDescriptor[_T]", _T]:
...
+
+
+def de_stringify_annotation(
+ cls: Type[Any], annotation: Union[str, Type[Any]]
+) -> Union[str, Type[Any]]:
+ """Resolve annotations that may be string based into real objects.
+
+ This is particularly important if a module defines "from __future__ import
+ annotations", as everything inside of __annotations__ is a string. We want
+ to at least have generic containers like ``Mapped``, ``Union``, ``List``,
+ etc.
+
+ """
+
+ # looked at typing.get_type_hints(), looked at pydantic. We need much
+ # less here, and we here try to not use any private typing internals
+ # or construct ForwardRef objects which is documented as something
+ # that should be avoided.
+
+ if (
+ is_fwd_ref(annotation)
+ and not cast(ForwardRef, annotation).__forward_evaluated__
+ ):
+ annotation = cast(ForwardRef, annotation).__forward_arg__
+
+ if isinstance(annotation, str):
+ base_globals: "Dict[str, Any]" = getattr(
+ sys.modules.get(cls.__module__, None), "__dict__", {}
+ )
+ try:
+ annotation = eval(annotation, base_globals, None)
+ except NameError:
+ pass
+ return annotation
+
+
+def is_fwd_ref(type_):
+ return isinstance(type_, ForwardRef)
+
+
+def de_optionalize_union_types(type_):
+ """Given a type, filter out ``Union`` types that include ``NoneType``
+ to not include the ``NoneType``.
+
+ """
+ if is_optional(type_):
+ typ = set(type_.__args__)
+
+ typ.discard(NoneType)
+
+ return make_union_type(*typ)
+
+ else:
+ return type_
+
+
+def make_union_type(*types):
+ """Make a Union type.
+
+ This is needed by :func:`.de_optionalize_union_types` which removes
+ ``NoneType`` from a ``Union``.
+
+ """
+ return cast(Any, Union).__getitem__(types)
+
+
+def expand_unions(type_, include_union=False, discard_none=False):
+ """Return a type as as a tuple of individual types, expanding for
+ ``Union`` types."""
+
+ if is_union(type_):
+ typ = set(type_.__args__)
+
+ if discard_none:
+ typ.discard(NoneType)
+
+ if include_union:
+ return (type_,) + tuple(typ)
+ else:
+ return tuple(typ)
+ else:
+ return (type_,)
+
+
+def is_optional(type_):
+ return is_origin_of(
+ type_,
+ "Optional",
+ "Union",
+ )
+
+
+def is_union(type_):
+ return is_origin_of(type_, "Union")
+
+
+def is_origin_of(type_, *names, module=None):
+ """return True if the given type has an __origin__ with the given name
+ and optional module."""
+
+ origin = getattr(type_, "__origin__", None)
+ if origin is None:
+ return False
+
+ return _get_type_name(origin) in names and (
+ module is None or origin.__module__.startswith(module)
+ )
+
+
+def _get_type_name(type_):
+ if compat.py310:
+ return type_.__name__
+ else:
+ typ_name = getattr(type_, "__name__", None)
+ if typ_name is None:
+ typ_name = getattr(type_, "_name", None)
+
+ return typ_name