diff options
Diffstat (limited to 'lib/sqlalchemy/util')
-rw-r--r-- | lib/sqlalchemy/util/_collections.py | 4 | ||||
-rw-r--r-- | lib/sqlalchemy/util/compat.py | 9 | ||||
-rw-r--r-- | lib/sqlalchemy/util/langhelpers.py | 32 | ||||
-rw-r--r-- | lib/sqlalchemy/util/preloaded.py | 4 | ||||
-rw-r--r-- | lib/sqlalchemy/util/topological.py | 35 | ||||
-rw-r--r-- | lib/sqlalchemy/util/typing.py | 77 |
6 files changed, 132 insertions, 29 deletions
diff --git a/lib/sqlalchemy/util/_collections.py b/lib/sqlalchemy/util/_collections.py index 7150dedcf..54be2e4e5 100644 --- a/lib/sqlalchemy/util/_collections.py +++ b/lib/sqlalchemy/util/_collections.py @@ -71,7 +71,7 @@ _T_co = TypeVar("_T_co", covariant=True) EMPTY_SET: FrozenSet[Any] = frozenset() -def merge_lists_w_ordering(a, b): +def merge_lists_w_ordering(a: List[Any], b: List[Any]) -> List[Any]: """merge two lists, maintaining ordering as much as possible. this is to reconcile vars(cls) with cls.__annotations__. @@ -450,7 +450,7 @@ def to_set(x): return x -def to_column_set(x): +def to_column_set(x: Any) -> Set[Any]: if x is None: return column_set() if not isinstance(x, column_set): diff --git a/lib/sqlalchemy/util/compat.py b/lib/sqlalchemy/util/compat.py index 24fa0f3e3..adbbf143f 100644 --- a/lib/sqlalchemy/util/compat.py +++ b/lib/sqlalchemy/util/compat.py @@ -20,11 +20,14 @@ import typing from typing import Any from typing import Callable from typing import Dict +from typing import Iterable from typing import List from typing import Mapping from typing import Optional from typing import Sequence +from typing import Set from typing import Tuple +from typing import Type py311 = sys.version_info >= (3, 11) @@ -225,7 +228,7 @@ def inspect_formatargspec( return result -def dataclass_fields(cls): +def dataclass_fields(cls: Type[Any]) -> Iterable[dataclasses.Field[Any]]: """Return a sequence of all dataclasses.Field objects associated with a class.""" @@ -235,12 +238,12 @@ def dataclass_fields(cls): return [] -def local_dataclass_fields(cls): +def local_dataclass_fields(cls: Type[Any]) -> Iterable[dataclasses.Field[Any]]: """Return a sequence of all dataclasses.Field objects associated with a class, excluding those that originate from a superclass.""" if dataclasses.is_dataclass(cls): - super_fields = set() + super_fields: Set[dataclasses.Field[Any]] = set() for sup in cls.__bases__: super_fields.update(dataclass_fields(sup)) return [f for f in dataclasses.fields(cls) if f not in super_fields] diff --git a/lib/sqlalchemy/util/langhelpers.py b/lib/sqlalchemy/util/langhelpers.py index 24c66bfa4..e54f33475 100644 --- a/lib/sqlalchemy/util/langhelpers.py +++ b/lib/sqlalchemy/util/langhelpers.py @@ -266,13 +266,31 @@ def decorator(target: Callable[..., Any]) -> Callable[[_Fn], _Fn]: metadata: Dict[str, Optional[str]] = dict(target=targ_name, fn=fn_name) metadata.update(format_argspec_plus(spec, grouped=False)) metadata["name"] = fn.__name__ - code = ( - """\ + + # look for __ positional arguments. This is a convention in + # SQLAlchemy that arguments should be passed positionally + # rather than as keyword + # arguments. note that apply_pos doesn't currently work in all cases + # such as when a kw-only indicator "*" is present, which is why + # we limit the use of this to just that case we can detect. As we add + # more kinds of methods that use @decorator, things may have to + # be further improved in this area + if "__" in repr(spec[0]): + code = ( + """\ +def %(name)s%(grouped_args)s: + return %(target)s(%(fn)s, %(apply_pos)s) +""" + % metadata + ) + else: + code = ( + """\ def %(name)s%(grouped_args)s: return %(target)s(%(fn)s, %(apply_kw)s) """ - % metadata - ) + % metadata + ) env.update({targ_name: target, fn_name: fn, "__name__": fn.__module__}) decorated = cast( @@ -1235,10 +1253,10 @@ class HasMemoized: return result @classmethod - def memoized_instancemethod(cls, fn: Any) -> Any: + def memoized_instancemethod(cls, fn: _F) -> _F: """Decorate a method memoize its return value.""" - def oneshot(self, *args, **kw): + def oneshot(self: Any, *args: Any, **kw: Any) -> Any: result = fn(self, *args, **kw) def memo(*a, **kw): @@ -1250,7 +1268,7 @@ class HasMemoized: self._memoized_keys |= {fn.__name__} return result - return update_wrapper(oneshot, fn) + return update_wrapper(oneshot, fn) # type: ignore if TYPE_CHECKING: diff --git a/lib/sqlalchemy/util/preloaded.py b/lib/sqlalchemy/util/preloaded.py index fce3cd3b0..67394c9a3 100644 --- a/lib/sqlalchemy/util/preloaded.py +++ b/lib/sqlalchemy/util/preloaded.py @@ -25,8 +25,12 @@ _FN = TypeVar("_FN", bound=Callable[..., Any]) if TYPE_CHECKING: from sqlalchemy.engine import default as engine_default # noqa + from sqlalchemy.orm import clsregistry as orm_clsregistry # noqa + from sqlalchemy.orm import decl_api as orm_decl_api # noqa + from sqlalchemy.orm import properties as orm_properties # noqa from sqlalchemy.orm import relationships as orm_relationships # noqa from sqlalchemy.orm import session as orm_session # noqa + from sqlalchemy.orm import state as orm_state # noqa from sqlalchemy.orm import util as orm_util # noqa from sqlalchemy.sql import dml as sql_dml # noqa from sqlalchemy.sql import functions as sql_functions # noqa diff --git a/lib/sqlalchemy/util/topological.py b/lib/sqlalchemy/util/topological.py index 37297103e..24e478b57 100644 --- a/lib/sqlalchemy/util/topological.py +++ b/lib/sqlalchemy/util/topological.py @@ -4,21 +4,33 @@ # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php -# mypy: allow-untyped-defs, allow-untyped-calls """Topological sorting algorithms.""" from __future__ import annotations +from typing import Any +from typing import DefaultDict +from typing import Iterable +from typing import Iterator +from typing import Sequence +from typing import Set +from typing import Tuple +from typing import TypeVar + from .. import util from ..exc import CircularDependencyError +_T = TypeVar("_T", bound=Any) + __all__ = ["sort", "sort_as_subsets", "find_cycles"] -def sort_as_subsets(tuples, allitems): +def sort_as_subsets( + tuples: Iterable[Tuple[_T, _T]], allitems: Iterable[_T] +) -> Iterator[Sequence[_T]]: - edges = util.defaultdict(set) + edges: DefaultDict[_T, Set[_T]] = util.defaultdict(set) for parent, child in tuples: edges[child].add(parent) @@ -43,7 +55,11 @@ def sort_as_subsets(tuples, allitems): yield output -def sort(tuples, allitems, deterministic_order=True): +def sort( + tuples: Iterable[Tuple[_T, _T]], + allitems: Iterable[_T], + deterministic_order: bool = True, +) -> Iterator[_T]: """sort the given list of items by dependency. 'tuples' is a list of tuples representing a partial ordering. @@ -59,11 +75,14 @@ def sort(tuples, allitems, deterministic_order=True): yield s -def find_cycles(tuples, allitems): +def find_cycles( + tuples: Iterable[Tuple[_T, _T]], + allitems: Iterable[_T], +) -> Set[_T]: # adapted from: # https://neopythonic.blogspot.com/2009/01/detecting-cycles-in-directed-graph.html - edges = util.defaultdict(set) + edges: DefaultDict[_T, Set[_T]] = util.defaultdict(set) for parent, child in tuples: edges[parent].add(child) nodes_to_test = set(edges) @@ -99,5 +118,5 @@ def find_cycles(tuples, allitems): return output -def _gen_edges(edges): - return set([(right, left) for left in edges for right in edges[left]]) +def _gen_edges(edges: DefaultDict[_T, Set[_T]]) -> Set[Tuple[_T, _T]]: + return {(right, left) for left in edges for right in edges[left]} diff --git a/lib/sqlalchemy/util/typing.py b/lib/sqlalchemy/util/typing.py index ebcae28a7..44e26f609 100644 --- a/lib/sqlalchemy/util/typing.py +++ b/lib/sqlalchemy/util/typing.py @@ -11,7 +11,9 @@ from typing import Dict from typing import ForwardRef from typing import Generic from typing import Iterable +from typing import NoReturn from typing import Optional +from typing import overload from typing import Tuple from typing import Type from typing import TypeVar @@ -33,7 +35,7 @@ Self = TypeVar("Self", bound=Any) if compat.py310: # why they took until py310 to put this in stdlib is beyond me, # I've been wanting it since py27 - from types import NoneType + from types import NoneType as NoneType else: NoneType = type(None) # type: ignore @@ -68,6 +70,8 @@ else: # copied from TypeShed, required in order to implement # MutableMapping.update() +_AnnotationScanType = Union[Type[Any], str] + class SupportsKeysAndGetItem(Protocol[_KT, _VT_co]): def keys(self) -> Iterable[_KT]: @@ -90,9 +94,9 @@ else: def de_stringify_annotation( cls: Type[Any], - annotation: Union[str, Type[Any]], + annotation: _AnnotationScanType, str_cleanup_fn: Optional[Callable[[str], str]] = None, -) -> Union[str, Type[Any]]: +) -> Type[Any]: """Resolve annotations that may be string based into real objects. This is particularly important if a module defines "from __future__ import @@ -125,20 +129,32 @@ def de_stringify_annotation( annotation = eval(annotation, base_globals, None) except NameError: pass - return annotation + return annotation # type: ignore -def is_fwd_ref(type_): +def is_fwd_ref(type_: _AnnotationScanType) -> bool: return isinstance(type_, ForwardRef) -def de_optionalize_union_types(type_): +@overload +def de_optionalize_union_types(type_: str) -> str: + ... + + +@overload +def de_optionalize_union_types(type_: Type[Any]) -> Type[Any]: + ... + + +def de_optionalize_union_types( + type_: _AnnotationScanType, +) -> _AnnotationScanType: """Given a type, filter out ``Union`` types that include ``NoneType`` to not include the ``NoneType``. """ if is_optional(type_): - typ = set(type_.__args__) + typ = set(type_.__args__) # type: ignore typ.discard(NoneType) @@ -148,14 +164,14 @@ def de_optionalize_union_types(type_): return type_ -def make_union_type(*types): +def make_union_type(*types: _AnnotationScanType) -> Type[Any]: """Make a Union type. This is needed by :func:`.de_optionalize_union_types` which removes ``NoneType`` from a ``Union``. """ - return cast(Any, Union).__getitem__(types) + return cast(Any, Union).__getitem__(types) # type: ignore def expand_unions( @@ -251,4 +267,47 @@ class DescriptorReference(Generic[_DESC]): ... +_DESC_co = TypeVar("_DESC_co", bound=DescriptorProto, covariant=True) + + +class RODescriptorReference(Generic[_DESC_co]): + """a descriptor that refers to a descriptor. + + same as :class:`.DescriptorReference` but is read-only, so that subclasses + can define a subtype as the generically contained element + + """ + + def __get__(self, instance: object, owner: Any) -> _DESC_co: + ... + + def __set__(self, instance: Any, value: Any) -> NoReturn: + ... + + def __delete__(self, instance: Any) -> NoReturn: + ... + + +_FN = TypeVar("_FN", bound=Optional[Callable[..., Any]]) + + +class CallableReference(Generic[_FN]): + """a descriptor that refers to a callable. + + works around mypy's limitation of not allowing callables assigned + as instance variables + + + """ + + def __get__(self, instance: object, owner: Any) -> _FN: + ... + + def __set__(self, instance: Any, value: _FN) -> None: + ... + + def __delete__(self, instance: Any) -> None: + ... + + # $def ro_descriptor_reference(fn: Callable[]) |