diff options
Diffstat (limited to 'lib/sqlalchemy/util/typing.py')
-rw-r--r-- | lib/sqlalchemy/util/typing.py | 11 |
1 files changed, 10 insertions, 1 deletions
diff --git a/lib/sqlalchemy/util/typing.py b/lib/sqlalchemy/util/typing.py index 0c8e5a633..b1ef87db1 100644 --- a/lib/sqlalchemy/util/typing.py +++ b/lib/sqlalchemy/util/typing.py @@ -74,7 +74,7 @@ typing_get_origin = get_origin # copied from TypeShed, required in order to implement # MutableMapping.update() -_AnnotationScanType = Union[Type[Any], str, ForwardRef] +_AnnotationScanType = Union[Type[Any], str, ForwardRef, "GenericProtocol[Any]"] class ArgsTypeProcotol(Protocol): @@ -236,6 +236,15 @@ def is_generic(type_: _AnnotationScanType) -> TypeGuard[GenericProtocol[Any]]: return hasattr(type_, "__args__") and hasattr(type_, "__origin__") +def flatten_generic( + type_: Union[GenericProtocol[Any], Type[Any]] +) -> Type[Any]: + if is_generic(type_): + return type_.__origin__ + else: + return cast("Type[Any]", type_) + + def is_fwd_ref( type_: _AnnotationScanType, check_generic: bool = False ) -> bool: |