diff options
Diffstat (limited to 'lib/sqlalchemy/util')
| -rw-r--r-- | lib/sqlalchemy/util/__init__.py | 1 | ||||
| -rw-r--r-- | lib/sqlalchemy/util/_collections.py | 35 | ||||
| -rw-r--r-- | lib/sqlalchemy/util/_py_collections.py | 76 | ||||
| -rw-r--r-- | lib/sqlalchemy/util/langhelpers.py | 173 | ||||
| -rw-r--r-- | lib/sqlalchemy/util/preloaded.py | 2 | ||||
| -rw-r--r-- | lib/sqlalchemy/util/typing.py | 10 |
6 files changed, 156 insertions, 141 deletions
diff --git a/lib/sqlalchemy/util/__init__.py b/lib/sqlalchemy/util/__init__.py index c0c2e7dfb..6d41231d9 100644 --- a/lib/sqlalchemy/util/__init__.py +++ b/lib/sqlalchemy/util/__init__.py @@ -96,6 +96,7 @@ from .langhelpers import dictlike_iteritems as dictlike_iteritems from .langhelpers import duck_type_collection as duck_type_collection from .langhelpers import ellipses_string as ellipses_string from .langhelpers import EnsureKWArg as EnsureKWArg +from .langhelpers import FastIntFlag as FastIntFlag from .langhelpers import format_argspec_init as format_argspec_init from .langhelpers import format_argspec_plus as format_argspec_plus from .langhelpers import generic_fn_descriptor as generic_fn_descriptor diff --git a/lib/sqlalchemy/util/_collections.py b/lib/sqlalchemy/util/_collections.py index eb5b16b65..eea76f60b 100644 --- a/lib/sqlalchemy/util/_collections.py +++ b/lib/sqlalchemy/util/_collections.py @@ -22,6 +22,7 @@ from typing import Generic from typing import Iterable from typing import Iterator from typing import List +from typing import Mapping from typing import Optional from typing import overload from typing import Set @@ -123,7 +124,7 @@ def merge_lists_w_ordering(a, b): return result -def coerce_to_immutabledict(d): +def coerce_to_immutabledict(d: Mapping[_KT, _VT]) -> immutabledict[_KT, _VT]: if not d: return EMPTY_DICT elif isinstance(d, immutabledict): @@ -161,6 +162,8 @@ class FacadeDict(ImmutableDictBase[_KT, _VT]): _DT = TypeVar("_DT", bound=Any) +_F = TypeVar("_F", bound=Any) + class Properties(Generic[_T]): """Provide a __getattr__/__setattr__ interface over a dict.""" @@ -169,7 +172,7 @@ class Properties(Generic[_T]): _data: Dict[str, _T] - def __init__(self, data): + def __init__(self, data: Dict[str, _T]): object.__setattr__(self, "_data", data) def __len__(self) -> int: @@ -178,30 +181,30 @@ class Properties(Generic[_T]): def __iter__(self) -> Iterator[_T]: return iter(list(self._data.values())) - def __dir__(self): + def __dir__(self) -> List[str]: return dir(super(Properties, self)) + [ str(k) for k in self._data.keys() ] - def __add__(self, other): - return list(self) + list(other) + def __add__(self, other: Properties[_F]) -> List[Union[_T, _F]]: + return list(self) + list(other) # type: ignore - def __setitem__(self, key, obj): + def __setitem__(self, key: str, obj: _T) -> None: self._data[key] = obj def __getitem__(self, key: str) -> _T: return self._data[key] - def __delitem__(self, key): + def __delitem__(self, key: str) -> None: del self._data[key] - def __setattr__(self, key, obj): + def __setattr__(self, key: str, obj: _T) -> None: self._data[key] = obj - def __getstate__(self): + def __getstate__(self) -> Dict[str, Any]: return {"_data": self._data} - def __setstate__(self, state): + def __setstate__(self, state: Dict[str, Any]) -> None: object.__setattr__(self, "_data", state["_data"]) def __getattr__(self, key: str) -> _T: @@ -213,12 +216,12 @@ class Properties(Generic[_T]): def __contains__(self, key: str) -> bool: return key in self._data - def as_readonly(self) -> "ReadOnlyProperties[_T]": + def as_readonly(self) -> ReadOnlyProperties[_T]: """Return an immutable proxy for this :class:`.Properties`.""" return ReadOnlyProperties(self._data) - def update(self, value): + def update(self, value: Dict[str, _T]) -> None: self._data.update(value) @overload @@ -249,7 +252,7 @@ class Properties(Generic[_T]): def has_key(self, key: str) -> bool: return key in self._data - def clear(self): + def clear(self) -> None: self._data.clear() @@ -318,7 +321,7 @@ class WeakSequence: class OrderedIdentitySet(IdentitySet): - def __init__(self, iterable=None): + def __init__(self, iterable: Optional[Iterable[Any]] = None): IdentitySet.__init__(self) self._members = OrderedDict() if iterable: @@ -615,7 +618,9 @@ class ScopedRegistry(Generic[_T]): scopefunc: _ScopeFuncType registry: Any - def __init__(self, createfunc, scopefunc): + def __init__( + self, createfunc: Callable[[], _T], scopefunc: Callable[[], Any] + ): """Construct a new :class:`.ScopedRegistry`. :param createfunc: A creation function that will generate diff --git a/lib/sqlalchemy/util/_py_collections.py b/lib/sqlalchemy/util/_py_collections.py index d649a0bea..725f6930e 100644 --- a/lib/sqlalchemy/util/_py_collections.py +++ b/lib/sqlalchemy/util/_py_collections.py @@ -263,52 +263,54 @@ class IdentitySet: """ - def __init__(self, iterable=None): + _members: Dict[int, Any] + + def __init__(self, iterable: Optional[Iterable[Any]] = None): self._members = dict() if iterable: self.update(iterable) - def add(self, value): + def add(self, value: Any) -> None: self._members[id(value)] = value - def __contains__(self, value): + def __contains__(self, value: Any) -> bool: return id(value) in self._members - def remove(self, value): + def remove(self, value: Any) -> None: del self._members[id(value)] - def discard(self, value): + def discard(self, value: Any) -> None: try: self.remove(value) except KeyError: pass - def pop(self): + def pop(self) -> Any: try: pair = self._members.popitem() return pair[1] except KeyError: raise KeyError("pop from an empty set") - def clear(self): + def clear(self) -> None: self._members.clear() - def __cmp__(self, other): + def __cmp__(self, other: Any) -> NoReturn: raise TypeError("cannot compare sets using cmp()") - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: if isinstance(other, IdentitySet): return self._members == other._members else: return False - def __ne__(self, other): + def __ne__(self, other: Any) -> bool: if isinstance(other, IdentitySet): return self._members != other._members else: return True - def issubset(self, iterable): + def issubset(self, iterable: Iterable[Any]) -> bool: if isinstance(iterable, self.__class__): other = iterable else: @@ -322,17 +324,17 @@ class IdentitySet: return False return True - def __le__(self, other): + def __le__(self, other: Any) -> bool: if not isinstance(other, IdentitySet): return NotImplemented return self.issubset(other) - def __lt__(self, other): + def __lt__(self, other: Any) -> bool: if not isinstance(other, IdentitySet): return NotImplemented return len(self) < len(other) and self.issubset(other) - def issuperset(self, iterable): + def issuperset(self, iterable: Iterable[Any]) -> bool: if isinstance(iterable, self.__class__): other = iterable else: @@ -347,38 +349,38 @@ class IdentitySet: return False return True - def __ge__(self, other): + def __ge__(self, other: Any) -> bool: if not isinstance(other, IdentitySet): return NotImplemented return self.issuperset(other) - def __gt__(self, other): + def __gt__(self, other: Any) -> bool: if not isinstance(other, IdentitySet): return NotImplemented return len(self) > len(other) and self.issuperset(other) - def union(self, iterable): + def union(self, iterable: Iterable[Any]) -> IdentitySet: result = self.__class__() members = self._members result._members.update(members) result._members.update((id(obj), obj) for obj in iterable) return result - def __or__(self, other): + def __or__(self, other: Any) -> IdentitySet: if not isinstance(other, IdentitySet): return NotImplemented return self.union(other) - def update(self, iterable): + def update(self, iterable: Iterable[Any]) -> None: self._members.update((id(obj), obj) for obj in iterable) - def __ior__(self, other): + def __ior__(self, other: Any) -> IdentitySet: if not isinstance(other, IdentitySet): return NotImplemented self.update(other) return self - def difference(self, iterable): + def difference(self, iterable: Iterable[Any]) -> IdentitySet: result = self.__new__(self.__class__) other: Collection[Any] @@ -391,21 +393,21 @@ class IdentitySet: } return result - def __sub__(self, other): + def __sub__(self, other: IdentitySet) -> IdentitySet: if not isinstance(other, IdentitySet): return NotImplemented return self.difference(other) - def difference_update(self, iterable): + def difference_update(self, iterable: Iterable[Any]) -> None: self._members = self.difference(iterable)._members - def __isub__(self, other): + def __isub__(self, other: IdentitySet) -> IdentitySet: if not isinstance(other, IdentitySet): return NotImplemented self.difference_update(other) return self - def intersection(self, iterable): + def intersection(self, iterable: Iterable[Any]) -> IdentitySet: result = self.__new__(self.__class__) other: Collection[Any] @@ -419,21 +421,21 @@ class IdentitySet: } return result - def __and__(self, other): + def __and__(self, other: IdentitySet) -> IdentitySet: if not isinstance(other, IdentitySet): return NotImplemented return self.intersection(other) - def intersection_update(self, iterable): + def intersection_update(self, iterable: Iterable[Any]) -> None: self._members = self.intersection(iterable)._members - def __iand__(self, other): + def __iand__(self, other: IdentitySet) -> IdentitySet: if not isinstance(other, IdentitySet): return NotImplemented self.intersection_update(other) return self - def symmetric_difference(self, iterable): + def symmetric_difference(self, iterable: Iterable[Any]) -> IdentitySet: result = self.__new__(self.__class__) if isinstance(iterable, self.__class__): other = iterable._members @@ -447,37 +449,37 @@ class IdentitySet: ) return result - def __xor__(self, other): + def __xor__(self, other: IdentitySet) -> IdentitySet: if not isinstance(other, IdentitySet): return NotImplemented return self.symmetric_difference(other) - def symmetric_difference_update(self, iterable): + def symmetric_difference_update(self, iterable: Iterable[Any]) -> None: self._members = self.symmetric_difference(iterable)._members - def __ixor__(self, other): + def __ixor__(self, other: IdentitySet) -> IdentitySet: if not isinstance(other, IdentitySet): return NotImplemented self.symmetric_difference(other) return self - def copy(self): + def copy(self) -> IdentitySet: result = self.__new__(self.__class__) result._members = self._members.copy() return result __copy__ = copy - def __len__(self): + def __len__(self) -> int: return len(self._members) - def __iter__(self): + def __iter__(self) -> Iterator[Any]: return iter(self._members.values()) - def __hash__(self): + def __hash__(self) -> NoReturn: raise TypeError("set objects are unhashable") - def __repr__(self): + def __repr__(self) -> str: return "%s(%r)" % (type(self).__name__, list(self._members.values())) diff --git a/lib/sqlalchemy/util/langhelpers.py b/lib/sqlalchemy/util/langhelpers.py index 3e89c72bb..2cb9c45d6 100644 --- a/lib/sqlalchemy/util/langhelpers.py +++ b/lib/sqlalchemy/util/langhelpers.py @@ -12,6 +12,7 @@ modules, classes, hierarchies, attributes, functions, and methods. from __future__ import annotations import collections +import enum from functools import update_wrapper import hashlib import inspect @@ -671,13 +672,13 @@ def format_argspec_init(method, grouped=True): def create_proxy_methods( - target_cls, - target_cls_sphinx_name, - proxy_cls_sphinx_name, - classmethods=(), - methods=(), - attributes=(), -): + target_cls: Type[Any], + target_cls_sphinx_name: str, + proxy_cls_sphinx_name: str, + classmethods: Sequence[str] = (), + methods: Sequence[str] = (), + attributes: Sequence[str] = (), +) -> Callable[[_T], _T]: """A class decorator indicating attributes should refer to a proxy class. @@ -1539,24 +1540,50 @@ class hybridmethod(Generic[_T]): return self -class _symbol(int): +class symbol(int): + """A constant symbol. + + >>> symbol('foo') is symbol('foo') + True + >>> symbol('foo') + <symbol 'foo> + + A slight refinement of the MAGICCOOKIE=object() pattern. The primary + advantage of symbol() is its repr(). They are also singletons. + + Repeated calls of symbol('name') will all return the same instance. + + In SQLAlchemy 2.0, symbol() is used for the implementation of + ``_FastIntFlag``, but otherwise should be mostly replaced by + ``enum.Enum`` and variants. + + + """ + name: str + symbols: Dict[str, symbol] = {} + _lock = threading.Lock() + def __new__( cls, name: str, doc: Optional[str] = None, canonical: Optional[int] = None, - ) -> "_symbol": - """Construct a new named symbol.""" - assert isinstance(name, str) - if canonical is None: - canonical = hash(name) - v = int.__new__(_symbol, canonical) - v.name = name - if doc: - v.__doc__ = doc - return v + ) -> symbol: + with cls._lock: + sym = cls.symbols.get(name) + if sym is None: + assert isinstance(name, str) + if canonical is None: + canonical = hash(name) + sym = int.__new__(symbol, canonical) + sym.name = name + if doc: + sym.__doc__ = doc + + cls.symbols[name] = sym + return sym def __reduce__(self): return symbol, (self.name, "x", int(self)) @@ -1565,90 +1592,60 @@ class _symbol(int): return repr(self) def __repr__(self): - return "symbol(%r)" % self.name + return f"symbol({self.name!r})" -_symbol.__name__ = "symbol" +class _IntFlagMeta(type): + def __init__( + cls, + classname: str, + bases: Tuple[Type[Any], ...], + dict_: Dict[str, Any], + **kw: Any, + ) -> None: + items: List[symbol] + cls._items = items = [] + for k, v in dict_.items(): + if isinstance(v, int): + sym = symbol(k, canonical=v) + elif not k.startswith("_"): + raise TypeError("Expected integer values for IntFlag") + else: + continue + setattr(cls, k, sym) + items.append(sym) + def __iter__(self) -> Iterator[symbol]: + return iter(self._items) -class symbol: - """A constant symbol. - >>> symbol('foo') is symbol('foo') - True - >>> symbol('foo') - <symbol 'foo> +class _FastIntFlag(metaclass=_IntFlagMeta): + """An 'IntFlag' copycat that isn't slow when performing bitwise + operations. - A slight refinement of the MAGICCOOKIE=object() pattern. The primary - advantage of symbol() is its repr(). They are also singletons. + the ``FastIntFlag`` class will return ``enum.IntFlag`` under TYPE_CHECKING + and ``_FastIntFlag`` otherwise. - Repeated calls of symbol('name') will all return the same instance. + """ - The optional ``doc`` argument assigns to ``__doc__``. This - is strictly so that Sphinx autoattr picks up the docstring we want - (it doesn't appear to pick up the in-module docstring if the datamember - is in a different module - autoattribute also blows up completely). - If Sphinx fixes/improves this then we would no longer need - ``doc`` here. - """ +if TYPE_CHECKING: + from enum import IntFlag - symbols: Dict[str, "_symbol"] = {} - _lock = threading.Lock() + FastIntFlag = IntFlag +else: + FastIntFlag = _FastIntFlag - def __new__( # type: ignore[misc] - cls, - name: str, - doc: Optional[str] = None, - canonical: Optional[int] = None, - ) -> _symbol: - with cls._lock: - sym = cls.symbols.get(name) - if sym is None: - cls.symbols[name] = sym = _symbol(name, doc, canonical) - return sym - @classmethod - def parse_user_argument( - cls, arg, choices, name, resolve_symbol_names=False - ): - """Given a user parameter, parse the parameter into a chosen symbol. - - The user argument can be a string name that matches the name of a - symbol, or the symbol object itself, or any number of alternate choices - such as True/False/ None etc. - - :param arg: the user argument. - :param choices: dictionary of symbol object to list of possible - entries. - :param name: name of the argument. Used in an :class:`.ArgumentError` - that is raised if the parameter doesn't match any available argument. - :param resolve_symbol_names: include the name of each symbol as a valid - entry. - - """ - # note using hash lookup is tricky here because symbol's `__hash__` - # is its int value which we don't want included in the lookup - # explicitly, so we iterate and compare each. - for sym, choice in choices.items(): - if arg is sym: - return sym - elif resolve_symbol_names and arg == sym.name: - return sym - elif arg in choice: - return sym - - if arg is None: - return None - - raise exc.ArgumentError("Invalid value for '%s': %r" % (name, arg)) +_E = TypeVar("_E", bound=enum.Enum) def parse_user_argument_for_enum( arg: Any, - choices: Dict[_T, List[Any]], + choices: Dict[_E, List[Any]], name: str, -) -> Optional[_T]: + resolve_symbol_names: bool = False, +) -> Optional[_E]: """Given a user parameter, parse the parameter into a chosen value from a list of choice objects, typically Enum values. @@ -1663,18 +1660,18 @@ def parse_user_argument_for_enum( that is raised if the parameter doesn't match any available argument. """ - # TODO: use whatever built in thing Enum provides for this, - # if applicable for enum_value, choice in choices.items(): if arg is enum_value: return enum_value + elif resolve_symbol_names and arg == enum_value.name: + return enum_value elif arg in choice: return enum_value if arg is None: return None - raise exc.ArgumentError("Invalid value for '%s': %r" % (name, arg)) + raise exc.ArgumentError(f"Invalid value for '{name}': {arg!r}") _creation_order = 1 diff --git a/lib/sqlalchemy/util/preloaded.py b/lib/sqlalchemy/util/preloaded.py index c861c83b3..907c51064 100644 --- a/lib/sqlalchemy/util/preloaded.py +++ b/lib/sqlalchemy/util/preloaded.py @@ -23,6 +23,8 @@ _FN = TypeVar("_FN", bound=Callable[..., Any]) if TYPE_CHECKING: from sqlalchemy.engine import default as engine_default + from sqlalchemy.orm import session as orm_session + from sqlalchemy.orm import util as orm_util from sqlalchemy.sql import dml as sql_dml from sqlalchemy.sql import util as sql_util diff --git a/lib/sqlalchemy/util/typing.py b/lib/sqlalchemy/util/typing.py index df54017da..dd574f3b0 100644 --- a/lib/sqlalchemy/util/typing.py +++ b/lib/sqlalchemy/util/typing.py @@ -3,10 +3,12 @@ from __future__ import annotations import sys import typing from typing import Any +from typing import Callable from typing import cast from typing import Dict from typing import ForwardRef from typing import Iterable +from typing import Optional from typing import Tuple from typing import Type from typing import TypeVar @@ -82,7 +84,9 @@ else: def de_stringify_annotation( - cls: Type[Any], annotation: Union[str, Type[Any]] + cls: Type[Any], + annotation: Union[str, Type[Any]], + str_cleanup_fn: Optional[Callable[[str], str]] = None, ) -> Union[str, Type[Any]]: """Resolve annotations that may be string based into real objects. @@ -105,9 +109,13 @@ def de_stringify_annotation( annotation = cast(ForwardRef, annotation).__forward_arg__ if isinstance(annotation, str): + if str_cleanup_fn: + annotation = str_cleanup_fn(annotation) + base_globals: "Dict[str, Any]" = getattr( sys.modules.get(cls.__module__, None), "__dict__", {} ) + try: annotation = eval(annotation, base_globals, None) except NameError: |
