summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/util
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/util')
-rw-r--r--lib/sqlalchemy/util/_collections.py4
-rw-r--r--lib/sqlalchemy/util/compat.py9
-rw-r--r--lib/sqlalchemy/util/langhelpers.py32
-rw-r--r--lib/sqlalchemy/util/preloaded.py4
-rw-r--r--lib/sqlalchemy/util/topological.py35
-rw-r--r--lib/sqlalchemy/util/typing.py77
6 files changed, 132 insertions, 29 deletions
diff --git a/lib/sqlalchemy/util/_collections.py b/lib/sqlalchemy/util/_collections.py
index 7150dedcf..54be2e4e5 100644
--- a/lib/sqlalchemy/util/_collections.py
+++ b/lib/sqlalchemy/util/_collections.py
@@ -71,7 +71,7 @@ _T_co = TypeVar("_T_co", covariant=True)
EMPTY_SET: FrozenSet[Any] = frozenset()
-def merge_lists_w_ordering(a, b):
+def merge_lists_w_ordering(a: List[Any], b: List[Any]) -> List[Any]:
"""merge two lists, maintaining ordering as much as possible.
this is to reconcile vars(cls) with cls.__annotations__.
@@ -450,7 +450,7 @@ def to_set(x):
return x
-def to_column_set(x):
+def to_column_set(x: Any) -> Set[Any]:
if x is None:
return column_set()
if not isinstance(x, column_set):
diff --git a/lib/sqlalchemy/util/compat.py b/lib/sqlalchemy/util/compat.py
index 24fa0f3e3..adbbf143f 100644
--- a/lib/sqlalchemy/util/compat.py
+++ b/lib/sqlalchemy/util/compat.py
@@ -20,11 +20,14 @@ import typing
from typing import Any
from typing import Callable
from typing import Dict
+from typing import Iterable
from typing import List
from typing import Mapping
from typing import Optional
from typing import Sequence
+from typing import Set
from typing import Tuple
+from typing import Type
py311 = sys.version_info >= (3, 11)
@@ -225,7 +228,7 @@ def inspect_formatargspec(
return result
-def dataclass_fields(cls):
+def dataclass_fields(cls: Type[Any]) -> Iterable[dataclasses.Field[Any]]:
"""Return a sequence of all dataclasses.Field objects associated
with a class."""
@@ -235,12 +238,12 @@ def dataclass_fields(cls):
return []
-def local_dataclass_fields(cls):
+def local_dataclass_fields(cls: Type[Any]) -> Iterable[dataclasses.Field[Any]]:
"""Return a sequence of all dataclasses.Field objects associated with
a class, excluding those that originate from a superclass."""
if dataclasses.is_dataclass(cls):
- super_fields = set()
+ super_fields: Set[dataclasses.Field[Any]] = set()
for sup in cls.__bases__:
super_fields.update(dataclass_fields(sup))
return [f for f in dataclasses.fields(cls) if f not in super_fields]
diff --git a/lib/sqlalchemy/util/langhelpers.py b/lib/sqlalchemy/util/langhelpers.py
index 24c66bfa4..e54f33475 100644
--- a/lib/sqlalchemy/util/langhelpers.py
+++ b/lib/sqlalchemy/util/langhelpers.py
@@ -266,13 +266,31 @@ def decorator(target: Callable[..., Any]) -> Callable[[_Fn], _Fn]:
metadata: Dict[str, Optional[str]] = dict(target=targ_name, fn=fn_name)
metadata.update(format_argspec_plus(spec, grouped=False))
metadata["name"] = fn.__name__
- code = (
- """\
+
+ # look for __ positional arguments. This is a convention in
+ # SQLAlchemy that arguments should be passed positionally
+ # rather than as keyword
+ # arguments. note that apply_pos doesn't currently work in all cases
+ # such as when a kw-only indicator "*" is present, which is why
+ # we limit the use of this to just that case we can detect. As we add
+ # more kinds of methods that use @decorator, things may have to
+ # be further improved in this area
+ if "__" in repr(spec[0]):
+ code = (
+ """\
+def %(name)s%(grouped_args)s:
+ return %(target)s(%(fn)s, %(apply_pos)s)
+"""
+ % metadata
+ )
+ else:
+ code = (
+ """\
def %(name)s%(grouped_args)s:
return %(target)s(%(fn)s, %(apply_kw)s)
"""
- % metadata
- )
+ % metadata
+ )
env.update({targ_name: target, fn_name: fn, "__name__": fn.__module__})
decorated = cast(
@@ -1235,10 +1253,10 @@ class HasMemoized:
return result
@classmethod
- def memoized_instancemethod(cls, fn: Any) -> Any:
+ def memoized_instancemethod(cls, fn: _F) -> _F:
"""Decorate a method memoize its return value."""
- def oneshot(self, *args, **kw):
+ def oneshot(self: Any, *args: Any, **kw: Any) -> Any:
result = fn(self, *args, **kw)
def memo(*a, **kw):
@@ -1250,7 +1268,7 @@ class HasMemoized:
self._memoized_keys |= {fn.__name__}
return result
- return update_wrapper(oneshot, fn)
+ return update_wrapper(oneshot, fn) # type: ignore
if TYPE_CHECKING:
diff --git a/lib/sqlalchemy/util/preloaded.py b/lib/sqlalchemy/util/preloaded.py
index fce3cd3b0..67394c9a3 100644
--- a/lib/sqlalchemy/util/preloaded.py
+++ b/lib/sqlalchemy/util/preloaded.py
@@ -25,8 +25,12 @@ _FN = TypeVar("_FN", bound=Callable[..., Any])
if TYPE_CHECKING:
from sqlalchemy.engine import default as engine_default # noqa
+ from sqlalchemy.orm import clsregistry as orm_clsregistry # noqa
+ from sqlalchemy.orm import decl_api as orm_decl_api # noqa
+ from sqlalchemy.orm import properties as orm_properties # noqa
from sqlalchemy.orm import relationships as orm_relationships # noqa
from sqlalchemy.orm import session as orm_session # noqa
+ from sqlalchemy.orm import state as orm_state # noqa
from sqlalchemy.orm import util as orm_util # noqa
from sqlalchemy.sql import dml as sql_dml # noqa
from sqlalchemy.sql import functions as sql_functions # noqa
diff --git a/lib/sqlalchemy/util/topological.py b/lib/sqlalchemy/util/topological.py
index 37297103e..24e478b57 100644
--- a/lib/sqlalchemy/util/topological.py
+++ b/lib/sqlalchemy/util/topological.py
@@ -4,21 +4,33 @@
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
-# mypy: allow-untyped-defs, allow-untyped-calls
"""Topological sorting algorithms."""
from __future__ import annotations
+from typing import Any
+from typing import DefaultDict
+from typing import Iterable
+from typing import Iterator
+from typing import Sequence
+from typing import Set
+from typing import Tuple
+from typing import TypeVar
+
from .. import util
from ..exc import CircularDependencyError
+_T = TypeVar("_T", bound=Any)
+
__all__ = ["sort", "sort_as_subsets", "find_cycles"]
-def sort_as_subsets(tuples, allitems):
+def sort_as_subsets(
+ tuples: Iterable[Tuple[_T, _T]], allitems: Iterable[_T]
+) -> Iterator[Sequence[_T]]:
- edges = util.defaultdict(set)
+ edges: DefaultDict[_T, Set[_T]] = util.defaultdict(set)
for parent, child in tuples:
edges[child].add(parent)
@@ -43,7 +55,11 @@ def sort_as_subsets(tuples, allitems):
yield output
-def sort(tuples, allitems, deterministic_order=True):
+def sort(
+ tuples: Iterable[Tuple[_T, _T]],
+ allitems: Iterable[_T],
+ deterministic_order: bool = True,
+) -> Iterator[_T]:
"""sort the given list of items by dependency.
'tuples' is a list of tuples representing a partial ordering.
@@ -59,11 +75,14 @@ def sort(tuples, allitems, deterministic_order=True):
yield s
-def find_cycles(tuples, allitems):
+def find_cycles(
+ tuples: Iterable[Tuple[_T, _T]],
+ allitems: Iterable[_T],
+) -> Set[_T]:
# adapted from:
# https://neopythonic.blogspot.com/2009/01/detecting-cycles-in-directed-graph.html
- edges = util.defaultdict(set)
+ edges: DefaultDict[_T, Set[_T]] = util.defaultdict(set)
for parent, child in tuples:
edges[parent].add(child)
nodes_to_test = set(edges)
@@ -99,5 +118,5 @@ def find_cycles(tuples, allitems):
return output
-def _gen_edges(edges):
- return set([(right, left) for left in edges for right in edges[left]])
+def _gen_edges(edges: DefaultDict[_T, Set[_T]]) -> Set[Tuple[_T, _T]]:
+ return {(right, left) for left in edges for right in edges[left]}
diff --git a/lib/sqlalchemy/util/typing.py b/lib/sqlalchemy/util/typing.py
index ebcae28a7..44e26f609 100644
--- a/lib/sqlalchemy/util/typing.py
+++ b/lib/sqlalchemy/util/typing.py
@@ -11,7 +11,9 @@ from typing import Dict
from typing import ForwardRef
from typing import Generic
from typing import Iterable
+from typing import NoReturn
from typing import Optional
+from typing import overload
from typing import Tuple
from typing import Type
from typing import TypeVar
@@ -33,7 +35,7 @@ Self = TypeVar("Self", bound=Any)
if compat.py310:
# why they took until py310 to put this in stdlib is beyond me,
# I've been wanting it since py27
- from types import NoneType
+ from types import NoneType as NoneType
else:
NoneType = type(None) # type: ignore
@@ -68,6 +70,8 @@ else:
# copied from TypeShed, required in order to implement
# MutableMapping.update()
+_AnnotationScanType = Union[Type[Any], str]
+
class SupportsKeysAndGetItem(Protocol[_KT, _VT_co]):
def keys(self) -> Iterable[_KT]:
@@ -90,9 +94,9 @@ else:
def de_stringify_annotation(
cls: Type[Any],
- annotation: Union[str, Type[Any]],
+ annotation: _AnnotationScanType,
str_cleanup_fn: Optional[Callable[[str], str]] = None,
-) -> Union[str, Type[Any]]:
+) -> Type[Any]:
"""Resolve annotations that may be string based into real objects.
This is particularly important if a module defines "from __future__ import
@@ -125,20 +129,32 @@ def de_stringify_annotation(
annotation = eval(annotation, base_globals, None)
except NameError:
pass
- return annotation
+ return annotation # type: ignore
-def is_fwd_ref(type_):
+def is_fwd_ref(type_: _AnnotationScanType) -> bool:
return isinstance(type_, ForwardRef)
-def de_optionalize_union_types(type_):
+@overload
+def de_optionalize_union_types(type_: str) -> str:
+ ...
+
+
+@overload
+def de_optionalize_union_types(type_: Type[Any]) -> Type[Any]:
+ ...
+
+
+def de_optionalize_union_types(
+ type_: _AnnotationScanType,
+) -> _AnnotationScanType:
"""Given a type, filter out ``Union`` types that include ``NoneType``
to not include the ``NoneType``.
"""
if is_optional(type_):
- typ = set(type_.__args__)
+ typ = set(type_.__args__) # type: ignore
typ.discard(NoneType)
@@ -148,14 +164,14 @@ def de_optionalize_union_types(type_):
return type_
-def make_union_type(*types):
+def make_union_type(*types: _AnnotationScanType) -> Type[Any]:
"""Make a Union type.
This is needed by :func:`.de_optionalize_union_types` which removes
``NoneType`` from a ``Union``.
"""
- return cast(Any, Union).__getitem__(types)
+ return cast(Any, Union).__getitem__(types) # type: ignore
def expand_unions(
@@ -251,4 +267,47 @@ class DescriptorReference(Generic[_DESC]):
...
+_DESC_co = TypeVar("_DESC_co", bound=DescriptorProto, covariant=True)
+
+
+class RODescriptorReference(Generic[_DESC_co]):
+ """a descriptor that refers to a descriptor.
+
+ same as :class:`.DescriptorReference` but is read-only, so that subclasses
+ can define a subtype as the generically contained element
+
+ """
+
+ def __get__(self, instance: object, owner: Any) -> _DESC_co:
+ ...
+
+ def __set__(self, instance: Any, value: Any) -> NoReturn:
+ ...
+
+ def __delete__(self, instance: Any) -> NoReturn:
+ ...
+
+
+_FN = TypeVar("_FN", bound=Optional[Callable[..., Any]])
+
+
+class CallableReference(Generic[_FN]):
+ """a descriptor that refers to a callable.
+
+ works around mypy's limitation of not allowing callables assigned
+ as instance variables
+
+
+ """
+
+ def __get__(self, instance: object, owner: Any) -> _FN:
+ ...
+
+ def __set__(self, instance: Any, value: _FN) -> None:
+ ...
+
+ def __delete__(self, instance: Any) -> None:
+ ...
+
+
# $def ro_descriptor_reference(fn: Callable[])