diff options
author | mike bayer <mike_mp@zzzcomputing.com> | 2022-11-22 16:07:49 +0000 |
---|---|---|
committer | Gerrit Code Review <gerrit@ci3.zzzcomputing.com> | 2022-11-22 16:07:49 +0000 |
commit | cbc42334b15ad7cef03b3887be76f4bc2c0ff3ee (patch) | |
tree | 76b84c1510b93fa121548b461ac308239a37db7f /lib/sqlalchemy/util/typing.py | |
parent | 447249e8628ff849758c1a9cdf822ae060b7cb8b (diff) | |
parent | 509ffeedefca1ad0ad8e29c6c3410d270fb3d2b9 (diff) | |
download | sqlalchemy-cbc42334b15ad7cef03b3887be76f4bc2c0ff3ee.tar.gz |
Merge "fix optionalized forms for dict[]" into main
Diffstat (limited to 'lib/sqlalchemy/util/typing.py')
-rw-r--r-- | lib/sqlalchemy/util/typing.py | 70 |
1 files changed, 61 insertions, 9 deletions
diff --git a/lib/sqlalchemy/util/typing.py b/lib/sqlalchemy/util/typing.py index 20ad148f8..dcbc15825 100644 --- a/lib/sqlalchemy/util/typing.py +++ b/lib/sqlalchemy/util/typing.py @@ -70,12 +70,37 @@ NoneFwd = ForwardRef("None") typing_get_args = get_args typing_get_origin = get_origin + # copied from TypeShed, required in order to implement # MutableMapping.update() _AnnotationScanType = Union[Type[Any], str, ForwardRef] +class ArgsTypeProcotol(Protocol): + """protocol for types that have ``__args__`` + + there's no public interface for this AFAIK + + """ + + __args__: Tuple[_AnnotationScanType, ...] + + +class GenericProtocol(Protocol[_T]): + """protocol for generic types. + + this since Python.typing _GenericAlias is private + + """ + + __args__: Tuple[_AnnotationScanType, ...] + __origin__: Type[_T] + + def copy_with(self, params: Tuple[_AnnotationScanType, ...]) -> Type[_T]: + ... + + class SupportsKeysAndGetItem(Protocol[_KT, _VT_co]): def keys(self) -> Iterable[_KT]: ... @@ -93,6 +118,7 @@ def de_stringify_annotation( annotation: _AnnotationScanType, originating_module: str, str_cleanup_fn: Optional[Callable[[str, str], str]] = None, + include_generic: bool = False, ) -> Type[Any]: """Resolve annotations that may be string based into real objects. @@ -119,6 +145,20 @@ def de_stringify_annotation( annotation = str_cleanup_fn(annotation, originating_module) annotation = eval_expression(annotation, originating_module) + + if include_generic and is_generic(annotation): + elements = tuple( + de_stringify_annotation( + cls, + elem, + originating_module, + str_cleanup_fn=str_cleanup_fn, + include_generic=include_generic, + ) + for elem in annotation.__args__ + ) + + return annotation.copy_with(elements) return annotation # type: ignore @@ -174,7 +214,7 @@ def resolve_name_to_real_class_name(name: str, module_name: str) -> str: def de_stringify_union_elements( cls: Type[Any], - annotation: _AnnotationScanType, + annotation: ArgsTypeProcotol, originating_module: str, str_cleanup_fn: Optional[Callable[[str, str], str]] = None, ) -> Type[Any]: @@ -183,7 +223,7 @@ def de_stringify_union_elements( de_stringify_annotation( cls, anno, originating_module, str_cleanup_fn ) - for anno in annotation.__args__ # type: ignore + for anno in annotation.__args__ ] ) @@ -192,8 +232,19 @@ def is_pep593(type_: Optional[_AnnotationScanType]) -> bool: return type_ is not None and typing_get_origin(type_) is Annotated -def is_fwd_ref(type_: _AnnotationScanType) -> bool: - return isinstance(type_, ForwardRef) +def is_generic(type_: _AnnotationScanType) -> TypeGuard[GenericProtocol[Any]]: + return hasattr(type_, "__args__") and hasattr(type_, "__origin__") + + +def is_fwd_ref( + type_: _AnnotationScanType, check_generic: bool = False +) -> bool: + if isinstance(type_, ForwardRef): + return True + elif check_generic and is_generic(type_): + return any(is_fwd_ref(arg, True) for arg in type_.__args__) + else: + return False @overload @@ -220,11 +271,12 @@ def de_optionalize_union_types( to not include the ``NoneType``. """ + if is_fwd_ref(type_): return de_optionalize_fwd_ref_union_types(cast(ForwardRef, type_)) elif is_optional(type_): - typ = set(type_.__args__) # type: ignore + typ = set(type_.__args__) typ.discard(NoneType) typ.discard(NoneFwd) @@ -289,14 +341,14 @@ def expand_unions( typ.discard(NoneType) if include_union: - return (type_,) + tuple(typ) + return (type_,) + tuple(typ) # type: ignore else: - return tuple(typ) + return tuple(typ) # type: ignore else: return (type_,) -def is_optional(type_: Any) -> bool: +def is_optional(type_: Any) -> TypeGuard[ArgsTypeProcotol]: return is_origin_of( type_, "Optional", @@ -309,7 +361,7 @@ def is_optional_union(type_: Any) -> bool: return is_optional(type_) and NoneType in typing_get_args(type_) -def is_union(type_: Any) -> bool: +def is_union(type_: Any) -> TypeGuard[ArgsTypeProcotol]: return is_origin_of(type_, "Union") |