summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/util/typing.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/util/typing.py')
-rw-r--r--lib/sqlalchemy/util/typing.py11
1 files changed, 10 insertions, 1 deletions
diff --git a/lib/sqlalchemy/util/typing.py b/lib/sqlalchemy/util/typing.py
index 0c8e5a633..b1ef87db1 100644
--- a/lib/sqlalchemy/util/typing.py
+++ b/lib/sqlalchemy/util/typing.py
@@ -74,7 +74,7 @@ typing_get_origin = get_origin
# copied from TypeShed, required in order to implement
# MutableMapping.update()
-_AnnotationScanType = Union[Type[Any], str, ForwardRef]
+_AnnotationScanType = Union[Type[Any], str, ForwardRef, "GenericProtocol[Any]"]
class ArgsTypeProcotol(Protocol):
@@ -236,6 +236,15 @@ def is_generic(type_: _AnnotationScanType) -> TypeGuard[GenericProtocol[Any]]:
return hasattr(type_, "__args__") and hasattr(type_, "__origin__")
+def flatten_generic(
+ type_: Union[GenericProtocol[Any], Type[Any]]
+) -> Type[Any]:
+ if is_generic(type_):
+ return type_.__origin__
+ else:
+ return cast("Type[Any]", type_)
+
+
def is_fwd_ref(
type_: _AnnotationScanType, check_generic: bool = False
) -> bool: