summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/util
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/util')
-rw-r--r--lib/sqlalchemy/util/__init__.py1
-rw-r--r--lib/sqlalchemy/util/_collections.py35
-rw-r--r--lib/sqlalchemy/util/_py_collections.py76
-rw-r--r--lib/sqlalchemy/util/langhelpers.py173
-rw-r--r--lib/sqlalchemy/util/preloaded.py2
-rw-r--r--lib/sqlalchemy/util/typing.py10
6 files changed, 156 insertions, 141 deletions
diff --git a/lib/sqlalchemy/util/__init__.py b/lib/sqlalchemy/util/__init__.py
index c0c2e7dfb..6d41231d9 100644
--- a/lib/sqlalchemy/util/__init__.py
+++ b/lib/sqlalchemy/util/__init__.py
@@ -96,6 +96,7 @@ from .langhelpers import dictlike_iteritems as dictlike_iteritems
from .langhelpers import duck_type_collection as duck_type_collection
from .langhelpers import ellipses_string as ellipses_string
from .langhelpers import EnsureKWArg as EnsureKWArg
+from .langhelpers import FastIntFlag as FastIntFlag
from .langhelpers import format_argspec_init as format_argspec_init
from .langhelpers import format_argspec_plus as format_argspec_plus
from .langhelpers import generic_fn_descriptor as generic_fn_descriptor
diff --git a/lib/sqlalchemy/util/_collections.py b/lib/sqlalchemy/util/_collections.py
index eb5b16b65..eea76f60b 100644
--- a/lib/sqlalchemy/util/_collections.py
+++ b/lib/sqlalchemy/util/_collections.py
@@ -22,6 +22,7 @@ from typing import Generic
from typing import Iterable
from typing import Iterator
from typing import List
+from typing import Mapping
from typing import Optional
from typing import overload
from typing import Set
@@ -123,7 +124,7 @@ def merge_lists_w_ordering(a, b):
return result
-def coerce_to_immutabledict(d):
+def coerce_to_immutabledict(d: Mapping[_KT, _VT]) -> immutabledict[_KT, _VT]:
if not d:
return EMPTY_DICT
elif isinstance(d, immutabledict):
@@ -161,6 +162,8 @@ class FacadeDict(ImmutableDictBase[_KT, _VT]):
_DT = TypeVar("_DT", bound=Any)
+_F = TypeVar("_F", bound=Any)
+
class Properties(Generic[_T]):
"""Provide a __getattr__/__setattr__ interface over a dict."""
@@ -169,7 +172,7 @@ class Properties(Generic[_T]):
_data: Dict[str, _T]
- def __init__(self, data):
+ def __init__(self, data: Dict[str, _T]):
object.__setattr__(self, "_data", data)
def __len__(self) -> int:
@@ -178,30 +181,30 @@ class Properties(Generic[_T]):
def __iter__(self) -> Iterator[_T]:
return iter(list(self._data.values()))
- def __dir__(self):
+ def __dir__(self) -> List[str]:
return dir(super(Properties, self)) + [
str(k) for k in self._data.keys()
]
- def __add__(self, other):
- return list(self) + list(other)
+ def __add__(self, other: Properties[_F]) -> List[Union[_T, _F]]:
+ return list(self) + list(other) # type: ignore
- def __setitem__(self, key, obj):
+ def __setitem__(self, key: str, obj: _T) -> None:
self._data[key] = obj
def __getitem__(self, key: str) -> _T:
return self._data[key]
- def __delitem__(self, key):
+ def __delitem__(self, key: str) -> None:
del self._data[key]
- def __setattr__(self, key, obj):
+ def __setattr__(self, key: str, obj: _T) -> None:
self._data[key] = obj
- def __getstate__(self):
+ def __getstate__(self) -> Dict[str, Any]:
return {"_data": self._data}
- def __setstate__(self, state):
+ def __setstate__(self, state: Dict[str, Any]) -> None:
object.__setattr__(self, "_data", state["_data"])
def __getattr__(self, key: str) -> _T:
@@ -213,12 +216,12 @@ class Properties(Generic[_T]):
def __contains__(self, key: str) -> bool:
return key in self._data
- def as_readonly(self) -> "ReadOnlyProperties[_T]":
+ def as_readonly(self) -> ReadOnlyProperties[_T]:
"""Return an immutable proxy for this :class:`.Properties`."""
return ReadOnlyProperties(self._data)
- def update(self, value):
+ def update(self, value: Dict[str, _T]) -> None:
self._data.update(value)
@overload
@@ -249,7 +252,7 @@ class Properties(Generic[_T]):
def has_key(self, key: str) -> bool:
return key in self._data
- def clear(self):
+ def clear(self) -> None:
self._data.clear()
@@ -318,7 +321,7 @@ class WeakSequence:
class OrderedIdentitySet(IdentitySet):
- def __init__(self, iterable=None):
+ def __init__(self, iterable: Optional[Iterable[Any]] = None):
IdentitySet.__init__(self)
self._members = OrderedDict()
if iterable:
@@ -615,7 +618,9 @@ class ScopedRegistry(Generic[_T]):
scopefunc: _ScopeFuncType
registry: Any
- def __init__(self, createfunc, scopefunc):
+ def __init__(
+ self, createfunc: Callable[[], _T], scopefunc: Callable[[], Any]
+ ):
"""Construct a new :class:`.ScopedRegistry`.
:param createfunc: A creation function that will generate
diff --git a/lib/sqlalchemy/util/_py_collections.py b/lib/sqlalchemy/util/_py_collections.py
index d649a0bea..725f6930e 100644
--- a/lib/sqlalchemy/util/_py_collections.py
+++ b/lib/sqlalchemy/util/_py_collections.py
@@ -263,52 +263,54 @@ class IdentitySet:
"""
- def __init__(self, iterable=None):
+ _members: Dict[int, Any]
+
+ def __init__(self, iterable: Optional[Iterable[Any]] = None):
self._members = dict()
if iterable:
self.update(iterable)
- def add(self, value):
+ def add(self, value: Any) -> None:
self._members[id(value)] = value
- def __contains__(self, value):
+ def __contains__(self, value: Any) -> bool:
return id(value) in self._members
- def remove(self, value):
+ def remove(self, value: Any) -> None:
del self._members[id(value)]
- def discard(self, value):
+ def discard(self, value: Any) -> None:
try:
self.remove(value)
except KeyError:
pass
- def pop(self):
+ def pop(self) -> Any:
try:
pair = self._members.popitem()
return pair[1]
except KeyError:
raise KeyError("pop from an empty set")
- def clear(self):
+ def clear(self) -> None:
self._members.clear()
- def __cmp__(self, other):
+ def __cmp__(self, other: Any) -> NoReturn:
raise TypeError("cannot compare sets using cmp()")
- def __eq__(self, other):
+ def __eq__(self, other: Any) -> bool:
if isinstance(other, IdentitySet):
return self._members == other._members
else:
return False
- def __ne__(self, other):
+ def __ne__(self, other: Any) -> bool:
if isinstance(other, IdentitySet):
return self._members != other._members
else:
return True
- def issubset(self, iterable):
+ def issubset(self, iterable: Iterable[Any]) -> bool:
if isinstance(iterable, self.__class__):
other = iterable
else:
@@ -322,17 +324,17 @@ class IdentitySet:
return False
return True
- def __le__(self, other):
+ def __le__(self, other: Any) -> bool:
if not isinstance(other, IdentitySet):
return NotImplemented
return self.issubset(other)
- def __lt__(self, other):
+ def __lt__(self, other: Any) -> bool:
if not isinstance(other, IdentitySet):
return NotImplemented
return len(self) < len(other) and self.issubset(other)
- def issuperset(self, iterable):
+ def issuperset(self, iterable: Iterable[Any]) -> bool:
if isinstance(iterable, self.__class__):
other = iterable
else:
@@ -347,38 +349,38 @@ class IdentitySet:
return False
return True
- def __ge__(self, other):
+ def __ge__(self, other: Any) -> bool:
if not isinstance(other, IdentitySet):
return NotImplemented
return self.issuperset(other)
- def __gt__(self, other):
+ def __gt__(self, other: Any) -> bool:
if not isinstance(other, IdentitySet):
return NotImplemented
return len(self) > len(other) and self.issuperset(other)
- def union(self, iterable):
+ def union(self, iterable: Iterable[Any]) -> IdentitySet:
result = self.__class__()
members = self._members
result._members.update(members)
result._members.update((id(obj), obj) for obj in iterable)
return result
- def __or__(self, other):
+ def __or__(self, other: Any) -> IdentitySet:
if not isinstance(other, IdentitySet):
return NotImplemented
return self.union(other)
- def update(self, iterable):
+ def update(self, iterable: Iterable[Any]) -> None:
self._members.update((id(obj), obj) for obj in iterable)
- def __ior__(self, other):
+ def __ior__(self, other: Any) -> IdentitySet:
if not isinstance(other, IdentitySet):
return NotImplemented
self.update(other)
return self
- def difference(self, iterable):
+ def difference(self, iterable: Iterable[Any]) -> IdentitySet:
result = self.__new__(self.__class__)
other: Collection[Any]
@@ -391,21 +393,21 @@ class IdentitySet:
}
return result
- def __sub__(self, other):
+ def __sub__(self, other: IdentitySet) -> IdentitySet:
if not isinstance(other, IdentitySet):
return NotImplemented
return self.difference(other)
- def difference_update(self, iterable):
+ def difference_update(self, iterable: Iterable[Any]) -> None:
self._members = self.difference(iterable)._members
- def __isub__(self, other):
+ def __isub__(self, other: IdentitySet) -> IdentitySet:
if not isinstance(other, IdentitySet):
return NotImplemented
self.difference_update(other)
return self
- def intersection(self, iterable):
+ def intersection(self, iterable: Iterable[Any]) -> IdentitySet:
result = self.__new__(self.__class__)
other: Collection[Any]
@@ -419,21 +421,21 @@ class IdentitySet:
}
return result
- def __and__(self, other):
+ def __and__(self, other: IdentitySet) -> IdentitySet:
if not isinstance(other, IdentitySet):
return NotImplemented
return self.intersection(other)
- def intersection_update(self, iterable):
+ def intersection_update(self, iterable: Iterable[Any]) -> None:
self._members = self.intersection(iterable)._members
- def __iand__(self, other):
+ def __iand__(self, other: IdentitySet) -> IdentitySet:
if not isinstance(other, IdentitySet):
return NotImplemented
self.intersection_update(other)
return self
- def symmetric_difference(self, iterable):
+ def symmetric_difference(self, iterable: Iterable[Any]) -> IdentitySet:
result = self.__new__(self.__class__)
if isinstance(iterable, self.__class__):
other = iterable._members
@@ -447,37 +449,37 @@ class IdentitySet:
)
return result
- def __xor__(self, other):
+ def __xor__(self, other: IdentitySet) -> IdentitySet:
if not isinstance(other, IdentitySet):
return NotImplemented
return self.symmetric_difference(other)
- def symmetric_difference_update(self, iterable):
+ def symmetric_difference_update(self, iterable: Iterable[Any]) -> None:
self._members = self.symmetric_difference(iterable)._members
- def __ixor__(self, other):
+ def __ixor__(self, other: IdentitySet) -> IdentitySet:
if not isinstance(other, IdentitySet):
return NotImplemented
self.symmetric_difference(other)
return self
- def copy(self):
+ def copy(self) -> IdentitySet:
result = self.__new__(self.__class__)
result._members = self._members.copy()
return result
__copy__ = copy
- def __len__(self):
+ def __len__(self) -> int:
return len(self._members)
- def __iter__(self):
+ def __iter__(self) -> Iterator[Any]:
return iter(self._members.values())
- def __hash__(self):
+ def __hash__(self) -> NoReturn:
raise TypeError("set objects are unhashable")
- def __repr__(self):
+ def __repr__(self) -> str:
return "%s(%r)" % (type(self).__name__, list(self._members.values()))
diff --git a/lib/sqlalchemy/util/langhelpers.py b/lib/sqlalchemy/util/langhelpers.py
index 3e89c72bb..2cb9c45d6 100644
--- a/lib/sqlalchemy/util/langhelpers.py
+++ b/lib/sqlalchemy/util/langhelpers.py
@@ -12,6 +12,7 @@ modules, classes, hierarchies, attributes, functions, and methods.
from __future__ import annotations
import collections
+import enum
from functools import update_wrapper
import hashlib
import inspect
@@ -671,13 +672,13 @@ def format_argspec_init(method, grouped=True):
def create_proxy_methods(
- target_cls,
- target_cls_sphinx_name,
- proxy_cls_sphinx_name,
- classmethods=(),
- methods=(),
- attributes=(),
-):
+ target_cls: Type[Any],
+ target_cls_sphinx_name: str,
+ proxy_cls_sphinx_name: str,
+ classmethods: Sequence[str] = (),
+ methods: Sequence[str] = (),
+ attributes: Sequence[str] = (),
+) -> Callable[[_T], _T]:
"""A class decorator indicating attributes should refer to a proxy
class.
@@ -1539,24 +1540,50 @@ class hybridmethod(Generic[_T]):
return self
-class _symbol(int):
+class symbol(int):
+ """A constant symbol.
+
+ >>> symbol('foo') is symbol('foo')
+ True
+ >>> symbol('foo')
+ <symbol 'foo>
+
+ A slight refinement of the MAGICCOOKIE=object() pattern. The primary
+ advantage of symbol() is its repr(). They are also singletons.
+
+ Repeated calls of symbol('name') will all return the same instance.
+
+ In SQLAlchemy 2.0, symbol() is used for the implementation of
+ ``_FastIntFlag``, but otherwise should be mostly replaced by
+ ``enum.Enum`` and variants.
+
+
+ """
+
name: str
+ symbols: Dict[str, symbol] = {}
+ _lock = threading.Lock()
+
def __new__(
cls,
name: str,
doc: Optional[str] = None,
canonical: Optional[int] = None,
- ) -> "_symbol":
- """Construct a new named symbol."""
- assert isinstance(name, str)
- if canonical is None:
- canonical = hash(name)
- v = int.__new__(_symbol, canonical)
- v.name = name
- if doc:
- v.__doc__ = doc
- return v
+ ) -> symbol:
+ with cls._lock:
+ sym = cls.symbols.get(name)
+ if sym is None:
+ assert isinstance(name, str)
+ if canonical is None:
+ canonical = hash(name)
+ sym = int.__new__(symbol, canonical)
+ sym.name = name
+ if doc:
+ sym.__doc__ = doc
+
+ cls.symbols[name] = sym
+ return sym
def __reduce__(self):
return symbol, (self.name, "x", int(self))
@@ -1565,90 +1592,60 @@ class _symbol(int):
return repr(self)
def __repr__(self):
- return "symbol(%r)" % self.name
+ return f"symbol({self.name!r})"
-_symbol.__name__ = "symbol"
+class _IntFlagMeta(type):
+ def __init__(
+ cls,
+ classname: str,
+ bases: Tuple[Type[Any], ...],
+ dict_: Dict[str, Any],
+ **kw: Any,
+ ) -> None:
+ items: List[symbol]
+ cls._items = items = []
+ for k, v in dict_.items():
+ if isinstance(v, int):
+ sym = symbol(k, canonical=v)
+ elif not k.startswith("_"):
+ raise TypeError("Expected integer values for IntFlag")
+ else:
+ continue
+ setattr(cls, k, sym)
+ items.append(sym)
+ def __iter__(self) -> Iterator[symbol]:
+ return iter(self._items)
-class symbol:
- """A constant symbol.
- >>> symbol('foo') is symbol('foo')
- True
- >>> symbol('foo')
- <symbol 'foo>
+class _FastIntFlag(metaclass=_IntFlagMeta):
+ """An 'IntFlag' copycat that isn't slow when performing bitwise
+ operations.
- A slight refinement of the MAGICCOOKIE=object() pattern. The primary
- advantage of symbol() is its repr(). They are also singletons.
+ the ``FastIntFlag`` class will return ``enum.IntFlag`` under TYPE_CHECKING
+ and ``_FastIntFlag`` otherwise.
- Repeated calls of symbol('name') will all return the same instance.
+ """
- The optional ``doc`` argument assigns to ``__doc__``. This
- is strictly so that Sphinx autoattr picks up the docstring we want
- (it doesn't appear to pick up the in-module docstring if the datamember
- is in a different module - autoattribute also blows up completely).
- If Sphinx fixes/improves this then we would no longer need
- ``doc`` here.
- """
+if TYPE_CHECKING:
+ from enum import IntFlag
- symbols: Dict[str, "_symbol"] = {}
- _lock = threading.Lock()
+ FastIntFlag = IntFlag
+else:
+ FastIntFlag = _FastIntFlag
- def __new__( # type: ignore[misc]
- cls,
- name: str,
- doc: Optional[str] = None,
- canonical: Optional[int] = None,
- ) -> _symbol:
- with cls._lock:
- sym = cls.symbols.get(name)
- if sym is None:
- cls.symbols[name] = sym = _symbol(name, doc, canonical)
- return sym
- @classmethod
- def parse_user_argument(
- cls, arg, choices, name, resolve_symbol_names=False
- ):
- """Given a user parameter, parse the parameter into a chosen symbol.
-
- The user argument can be a string name that matches the name of a
- symbol, or the symbol object itself, or any number of alternate choices
- such as True/False/ None etc.
-
- :param arg: the user argument.
- :param choices: dictionary of symbol object to list of possible
- entries.
- :param name: name of the argument. Used in an :class:`.ArgumentError`
- that is raised if the parameter doesn't match any available argument.
- :param resolve_symbol_names: include the name of each symbol as a valid
- entry.
-
- """
- # note using hash lookup is tricky here because symbol's `__hash__`
- # is its int value which we don't want included in the lookup
- # explicitly, so we iterate and compare each.
- for sym, choice in choices.items():
- if arg is sym:
- return sym
- elif resolve_symbol_names and arg == sym.name:
- return sym
- elif arg in choice:
- return sym
-
- if arg is None:
- return None
-
- raise exc.ArgumentError("Invalid value for '%s': %r" % (name, arg))
+_E = TypeVar("_E", bound=enum.Enum)
def parse_user_argument_for_enum(
arg: Any,
- choices: Dict[_T, List[Any]],
+ choices: Dict[_E, List[Any]],
name: str,
-) -> Optional[_T]:
+ resolve_symbol_names: bool = False,
+) -> Optional[_E]:
"""Given a user parameter, parse the parameter into a chosen value
from a list of choice objects, typically Enum values.
@@ -1663,18 +1660,18 @@ def parse_user_argument_for_enum(
that is raised if the parameter doesn't match any available argument.
"""
- # TODO: use whatever built in thing Enum provides for this,
- # if applicable
for enum_value, choice in choices.items():
if arg is enum_value:
return enum_value
+ elif resolve_symbol_names and arg == enum_value.name:
+ return enum_value
elif arg in choice:
return enum_value
if arg is None:
return None
- raise exc.ArgumentError("Invalid value for '%s': %r" % (name, arg))
+ raise exc.ArgumentError(f"Invalid value for '{name}': {arg!r}")
_creation_order = 1
diff --git a/lib/sqlalchemy/util/preloaded.py b/lib/sqlalchemy/util/preloaded.py
index c861c83b3..907c51064 100644
--- a/lib/sqlalchemy/util/preloaded.py
+++ b/lib/sqlalchemy/util/preloaded.py
@@ -23,6 +23,8 @@ _FN = TypeVar("_FN", bound=Callable[..., Any])
if TYPE_CHECKING:
from sqlalchemy.engine import default as engine_default
+ from sqlalchemy.orm import session as orm_session
+ from sqlalchemy.orm import util as orm_util
from sqlalchemy.sql import dml as sql_dml
from sqlalchemy.sql import util as sql_util
diff --git a/lib/sqlalchemy/util/typing.py b/lib/sqlalchemy/util/typing.py
index df54017da..dd574f3b0 100644
--- a/lib/sqlalchemy/util/typing.py
+++ b/lib/sqlalchemy/util/typing.py
@@ -3,10 +3,12 @@ from __future__ import annotations
import sys
import typing
from typing import Any
+from typing import Callable
from typing import cast
from typing import Dict
from typing import ForwardRef
from typing import Iterable
+from typing import Optional
from typing import Tuple
from typing import Type
from typing import TypeVar
@@ -82,7 +84,9 @@ else:
def de_stringify_annotation(
- cls: Type[Any], annotation: Union[str, Type[Any]]
+ cls: Type[Any],
+ annotation: Union[str, Type[Any]],
+ str_cleanup_fn: Optional[Callable[[str], str]] = None,
) -> Union[str, Type[Any]]:
"""Resolve annotations that may be string based into real objects.
@@ -105,9 +109,13 @@ def de_stringify_annotation(
annotation = cast(ForwardRef, annotation).__forward_arg__
if isinstance(annotation, str):
+ if str_cleanup_fn:
+ annotation = str_cleanup_fn(annotation)
+
base_globals: "Dict[str, Any]" = getattr(
sys.modules.get(cls.__module__, None), "__dict__", {}
)
+
try:
annotation = eval(annotation, base_globals, None)
except NameError: