diff options
Diffstat (limited to 'lib/sqlalchemy/util/_py_collections.py')
-rw-r--r-- | lib/sqlalchemy/util/_py_collections.py | 168 |
1 files changed, 119 insertions, 49 deletions
diff --git a/lib/sqlalchemy/util/_py_collections.py b/lib/sqlalchemy/util/_py_collections.py index ff61f6ca9..a4e4b8b5d 100644 --- a/lib/sqlalchemy/util/_py_collections.py +++ b/lib/sqlalchemy/util/_py_collections.py @@ -1,17 +1,52 @@ from itertools import filterfalse +from typing import Any +from typing import Dict +from typing import Generic +from typing import Iterable +from typing import Iterator +from typing import List +from typing import NoReturn +from typing import Optional +from typing import Set +from typing import TypeVar + +_T = TypeVar("_T", bound=Any) +_KT = TypeVar("_KT", bound=Any) +_VT = TypeVar("_VT", bound=Any) class ImmutableContainer: - def _immutable(self, *arg, **kw): + def _immutable(self, *arg: Any, **kw: Any) -> NoReturn: raise TypeError("%s object is immutable" % self.__class__.__name__) - __delitem__ = __setitem__ = __setattr__ = _immutable + def __delitem__(self, key: Any) -> NoReturn: + self._immutable() + def __setitem__(self, key: Any, value: Any) -> NoReturn: + self._immutable() -class immutabledict(ImmutableContainer, dict): + def __setattr__(self, key: str, value: Any) -> NoReturn: + self._immutable() - clear = pop = popitem = setdefault = update = ImmutableContainer._immutable +class ImmutableDictBase(ImmutableContainer, Dict[_KT, _VT]): + def clear(self) -> NoReturn: + self._immutable() + + def pop(self, key: Any, default: Optional[Any] = None) -> NoReturn: + self._immutable() + + def popitem(self) -> NoReturn: + self._immutable() + + def setdefault(self, key: Any, default: Optional[Any] = None) -> NoReturn: + self._immutable() + + def update(self, *arg: Any, **kw: Any) -> NoReturn: + self._immutable() + + +class immutabledict(ImmutableDictBase[_KT, _VT]): def __new__(cls, *args): new = dict.__new__(cls) dict.__init__(new, *args) @@ -41,7 +76,7 @@ class immutabledict(ImmutableContainer, dict): dict.__init__(new, self) if __d: dict.update(new, __d) - dict.update(new, kw) + dict.update(new, kw) # type: ignore return new def merge_with(self, *dicts): @@ -61,110 +96,145 @@ class immutabledict(ImmutableContainer, dict): return "immutabledict(%s)" % dict.__repr__(self) -class OrderedSet(set): +class OrderedSet(Generic[_T]): + __slots__ = ("_list", "_set", "__weakref__") + + _list: List[_T] + _set: Set[_T] + def __init__(self, d=None): - set.__init__(self) if d is not None: self._list = unique_list(d) - set.update(self, self._list) + self._set = set(self._list) else: self._list = [] + self._set = set() + + def __reduce__(self): + return (OrderedSet, (self._list,)) - def add(self, element): + def add(self, element: _T) -> None: if element not in self: self._list.append(element) - set.add(self, element) + self._set.add(element) - def remove(self, element): - set.remove(self, element) + def remove(self, element: _T) -> None: + self._set.remove(element) self._list.remove(element) - def insert(self, pos, element): + def insert(self, pos: int, element: _T) -> None: if element not in self: self._list.insert(pos, element) - set.add(self, element) + self._set.add(element) - def discard(self, element): + def discard(self, element: _T) -> None: if element in self: self._list.remove(element) - set.remove(self, element) + self._set.remove(element) - def clear(self): - set.clear(self) + def clear(self) -> None: + self._set.clear() self._list = [] - def __getitem__(self, key): + def __len__(self) -> int: + return len(self._set) + + def __eq__(self, other): + if not isinstance(other, OrderedSet): + return self._set == other + else: + return self._set == other._set + + def __ne__(self, other): + if not isinstance(other, OrderedSet): + return self._set != other + else: + return self._set != other._set + + def __contains__(self, element: Any) -> bool: + return element in self._set + + def __getitem__(self, key: int) -> _T: return self._list[key] - def __iter__(self): + def __iter__(self) -> Iterator[_T]: return iter(self._list) - def __add__(self, other): + def __add__(self, other: Iterator[_T]) -> "OrderedSet[_T]": return self.union(other) - def __repr__(self): + def __repr__(self) -> str: return "%s(%r)" % (self.__class__.__name__, self._list) __str__ = __repr__ - def update(self, iterable): - for e in iterable: - if e not in self: - self._list.append(e) - set.add(self, e) - return self + def update(self, *iterables: Iterable[_T]) -> None: + for iterable in iterables: + for e in iterable: + if e not in self: + self._list.append(e) + self._set.add(e) - __ior__ = update + def __ior__(self, other: Iterable[_T]) -> "OrderedSet[_T]": + self.update(other) + return self - def union(self, other): + def union(self, other: Iterable[_T]) -> "OrderedSet[_T]": result = self.__class__(self) result.update(other) return result - __or__ = union + def __or__(self, other: Iterable[_T]) -> "OrderedSet[_T]": + return self.union(other) - def intersection(self, other): + def intersection(self, other: Iterable[_T]) -> "OrderedSet[_T]": other = other if isinstance(other, set) else set(other) return self.__class__(a for a in self if a in other) - __and__ = intersection + def __and__(self, other: Iterable[_T]) -> "OrderedSet[_T]": + return self.intersection(other) - def symmetric_difference(self, other): + def symmetric_difference(self, other: Iterable[_T]) -> "OrderedSet[_T]": other_set = other if isinstance(other, set) else set(other) result = self.__class__(a for a in self if a not in other_set) result.update(a for a in other if a not in self) return result - __xor__ = symmetric_difference + def __xor__(self, other: Iterable[_T]) -> "OrderedSet[_T]": + return self.symmetric_difference(other) - def difference(self, other): + def difference(self, other: Iterable[_T]) -> "OrderedSet[_T]": other = other if isinstance(other, set) else set(other) return self.__class__(a for a in self if a not in other) - __sub__ = difference + def __sub__(self, other: Iterable[_T]) -> "OrderedSet[_T]": + return self.difference(other) - def intersection_update(self, other): + def intersection_update(self, other: Iterable[_T]) -> None: other = other if isinstance(other, set) else set(other) - set.intersection_update(self, other) + self._set.intersection_update(other) self._list = [a for a in self._list if a in other] - return self - __iand__ = intersection_update + def __iand__(self, other: Iterable[_T]) -> "OrderedSet[_T]": + self.intersection_update(other) + return self - def symmetric_difference_update(self, other): - set.symmetric_difference_update(self, other) + def symmetric_difference_update(self, other: Iterable[_T]) -> None: + self._set.symmetric_difference_update(other) self._list = [a for a in self._list if a in self] self._list += [a for a in other if a in self] - return self - __ixor__ = symmetric_difference_update + def __ixor__(self, other: Iterable[_T]) -> "OrderedSet[_T]": + self.symmetric_difference_update(other) + return self - def difference_update(self, other): - set.difference_update(self, other) + def difference_update(self, other: Iterable[_T]) -> None: + self._set.difference_update(other) self._list = [a for a in self._list if a in self] - return self - __isub__ = difference_update + def __isub__(self, other: Iterable[_T]) -> "OrderedSet[_T]": + self.difference_update(other) + return self class IdentitySet: |