summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/util/typing.py
diff options
context:
space:
mode:
authormike bayer <mike_mp@zzzcomputing.com>2022-11-22 16:07:49 +0000
committerGerrit Code Review <gerrit@ci3.zzzcomputing.com>2022-11-22 16:07:49 +0000
commitcbc42334b15ad7cef03b3887be76f4bc2c0ff3ee (patch)
tree76b84c1510b93fa121548b461ac308239a37db7f /lib/sqlalchemy/util/typing.py
parent447249e8628ff849758c1a9cdf822ae060b7cb8b (diff)
parent509ffeedefca1ad0ad8e29c6c3410d270fb3d2b9 (diff)
downloadsqlalchemy-cbc42334b15ad7cef03b3887be76f4bc2c0ff3ee.tar.gz
Merge "fix optionalized forms for dict[]" into main
Diffstat (limited to 'lib/sqlalchemy/util/typing.py')
-rw-r--r--lib/sqlalchemy/util/typing.py70
1 files changed, 61 insertions, 9 deletions
diff --git a/lib/sqlalchemy/util/typing.py b/lib/sqlalchemy/util/typing.py
index 20ad148f8..dcbc15825 100644
--- a/lib/sqlalchemy/util/typing.py
+++ b/lib/sqlalchemy/util/typing.py
@@ -70,12 +70,37 @@ NoneFwd = ForwardRef("None")
typing_get_args = get_args
typing_get_origin = get_origin
+
# copied from TypeShed, required in order to implement
# MutableMapping.update()
_AnnotationScanType = Union[Type[Any], str, ForwardRef]
+class ArgsTypeProcotol(Protocol):
+ """protocol for types that have ``__args__``
+
+ there's no public interface for this AFAIK
+
+ """
+
+ __args__: Tuple[_AnnotationScanType, ...]
+
+
+class GenericProtocol(Protocol[_T]):
+ """protocol for generic types.
+
+ this since Python.typing _GenericAlias is private
+
+ """
+
+ __args__: Tuple[_AnnotationScanType, ...]
+ __origin__: Type[_T]
+
+ def copy_with(self, params: Tuple[_AnnotationScanType, ...]) -> Type[_T]:
+ ...
+
+
class SupportsKeysAndGetItem(Protocol[_KT, _VT_co]):
def keys(self) -> Iterable[_KT]:
...
@@ -93,6 +118,7 @@ def de_stringify_annotation(
annotation: _AnnotationScanType,
originating_module: str,
str_cleanup_fn: Optional[Callable[[str, str], str]] = None,
+ include_generic: bool = False,
) -> Type[Any]:
"""Resolve annotations that may be string based into real objects.
@@ -119,6 +145,20 @@ def de_stringify_annotation(
annotation = str_cleanup_fn(annotation, originating_module)
annotation = eval_expression(annotation, originating_module)
+
+ if include_generic and is_generic(annotation):
+ elements = tuple(
+ de_stringify_annotation(
+ cls,
+ elem,
+ originating_module,
+ str_cleanup_fn=str_cleanup_fn,
+ include_generic=include_generic,
+ )
+ for elem in annotation.__args__
+ )
+
+ return annotation.copy_with(elements)
return annotation # type: ignore
@@ -174,7 +214,7 @@ def resolve_name_to_real_class_name(name: str, module_name: str) -> str:
def de_stringify_union_elements(
cls: Type[Any],
- annotation: _AnnotationScanType,
+ annotation: ArgsTypeProcotol,
originating_module: str,
str_cleanup_fn: Optional[Callable[[str, str], str]] = None,
) -> Type[Any]:
@@ -183,7 +223,7 @@ def de_stringify_union_elements(
de_stringify_annotation(
cls, anno, originating_module, str_cleanup_fn
)
- for anno in annotation.__args__ # type: ignore
+ for anno in annotation.__args__
]
)
@@ -192,8 +232,19 @@ def is_pep593(type_: Optional[_AnnotationScanType]) -> bool:
return type_ is not None and typing_get_origin(type_) is Annotated
-def is_fwd_ref(type_: _AnnotationScanType) -> bool:
- return isinstance(type_, ForwardRef)
+def is_generic(type_: _AnnotationScanType) -> TypeGuard[GenericProtocol[Any]]:
+ return hasattr(type_, "__args__") and hasattr(type_, "__origin__")
+
+
+def is_fwd_ref(
+ type_: _AnnotationScanType, check_generic: bool = False
+) -> bool:
+ if isinstance(type_, ForwardRef):
+ return True
+ elif check_generic and is_generic(type_):
+ return any(is_fwd_ref(arg, True) for arg in type_.__args__)
+ else:
+ return False
@overload
@@ -220,11 +271,12 @@ def de_optionalize_union_types(
to not include the ``NoneType``.
"""
+
if is_fwd_ref(type_):
return de_optionalize_fwd_ref_union_types(cast(ForwardRef, type_))
elif is_optional(type_):
- typ = set(type_.__args__) # type: ignore
+ typ = set(type_.__args__)
typ.discard(NoneType)
typ.discard(NoneFwd)
@@ -289,14 +341,14 @@ def expand_unions(
typ.discard(NoneType)
if include_union:
- return (type_,) + tuple(typ)
+ return (type_,) + tuple(typ) # type: ignore
else:
- return tuple(typ)
+ return tuple(typ) # type: ignore
else:
return (type_,)
-def is_optional(type_: Any) -> bool:
+def is_optional(type_: Any) -> TypeGuard[ArgsTypeProcotol]:
return is_origin_of(
type_,
"Optional",
@@ -309,7 +361,7 @@ def is_optional_union(type_: Any) -> bool:
return is_optional(type_) and NoneType in typing_get_args(type_)
-def is_union(type_: Any) -> bool:
+def is_union(type_: Any) -> TypeGuard[ArgsTypeProcotol]:
return is_origin_of(type_, "Union")