summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/util/_py_collections.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/util/_py_collections.py')
-rw-r--r--lib/sqlalchemy/util/_py_collections.py168
1 files changed, 119 insertions, 49 deletions
diff --git a/lib/sqlalchemy/util/_py_collections.py b/lib/sqlalchemy/util/_py_collections.py
index ff61f6ca9..a4e4b8b5d 100644
--- a/lib/sqlalchemy/util/_py_collections.py
+++ b/lib/sqlalchemy/util/_py_collections.py
@@ -1,17 +1,52 @@
from itertools import filterfalse
+from typing import Any
+from typing import Dict
+from typing import Generic
+from typing import Iterable
+from typing import Iterator
+from typing import List
+from typing import NoReturn
+from typing import Optional
+from typing import Set
+from typing import TypeVar
+
+_T = TypeVar("_T", bound=Any)
+_KT = TypeVar("_KT", bound=Any)
+_VT = TypeVar("_VT", bound=Any)
class ImmutableContainer:
- def _immutable(self, *arg, **kw):
+ def _immutable(self, *arg: Any, **kw: Any) -> NoReturn:
raise TypeError("%s object is immutable" % self.__class__.__name__)
- __delitem__ = __setitem__ = __setattr__ = _immutable
+ def __delitem__(self, key: Any) -> NoReturn:
+ self._immutable()
+ def __setitem__(self, key: Any, value: Any) -> NoReturn:
+ self._immutable()
-class immutabledict(ImmutableContainer, dict):
+ def __setattr__(self, key: str, value: Any) -> NoReturn:
+ self._immutable()
- clear = pop = popitem = setdefault = update = ImmutableContainer._immutable
+class ImmutableDictBase(ImmutableContainer, Dict[_KT, _VT]):
+ def clear(self) -> NoReturn:
+ self._immutable()
+
+ def pop(self, key: Any, default: Optional[Any] = None) -> NoReturn:
+ self._immutable()
+
+ def popitem(self) -> NoReturn:
+ self._immutable()
+
+ def setdefault(self, key: Any, default: Optional[Any] = None) -> NoReturn:
+ self._immutable()
+
+ def update(self, *arg: Any, **kw: Any) -> NoReturn:
+ self._immutable()
+
+
+class immutabledict(ImmutableDictBase[_KT, _VT]):
def __new__(cls, *args):
new = dict.__new__(cls)
dict.__init__(new, *args)
@@ -41,7 +76,7 @@ class immutabledict(ImmutableContainer, dict):
dict.__init__(new, self)
if __d:
dict.update(new, __d)
- dict.update(new, kw)
+ dict.update(new, kw) # type: ignore
return new
def merge_with(self, *dicts):
@@ -61,110 +96,145 @@ class immutabledict(ImmutableContainer, dict):
return "immutabledict(%s)" % dict.__repr__(self)
-class OrderedSet(set):
+class OrderedSet(Generic[_T]):
+ __slots__ = ("_list", "_set", "__weakref__")
+
+ _list: List[_T]
+ _set: Set[_T]
+
def __init__(self, d=None):
- set.__init__(self)
if d is not None:
self._list = unique_list(d)
- set.update(self, self._list)
+ self._set = set(self._list)
else:
self._list = []
+ self._set = set()
+
+ def __reduce__(self):
+ return (OrderedSet, (self._list,))
- def add(self, element):
+ def add(self, element: _T) -> None:
if element not in self:
self._list.append(element)
- set.add(self, element)
+ self._set.add(element)
- def remove(self, element):
- set.remove(self, element)
+ def remove(self, element: _T) -> None:
+ self._set.remove(element)
self._list.remove(element)
- def insert(self, pos, element):
+ def insert(self, pos: int, element: _T) -> None:
if element not in self:
self._list.insert(pos, element)
- set.add(self, element)
+ self._set.add(element)
- def discard(self, element):
+ def discard(self, element: _T) -> None:
if element in self:
self._list.remove(element)
- set.remove(self, element)
+ self._set.remove(element)
- def clear(self):
- set.clear(self)
+ def clear(self) -> None:
+ self._set.clear()
self._list = []
- def __getitem__(self, key):
+ def __len__(self) -> int:
+ return len(self._set)
+
+ def __eq__(self, other):
+ if not isinstance(other, OrderedSet):
+ return self._set == other
+ else:
+ return self._set == other._set
+
+ def __ne__(self, other):
+ if not isinstance(other, OrderedSet):
+ return self._set != other
+ else:
+ return self._set != other._set
+
+ def __contains__(self, element: Any) -> bool:
+ return element in self._set
+
+ def __getitem__(self, key: int) -> _T:
return self._list[key]
- def __iter__(self):
+ def __iter__(self) -> Iterator[_T]:
return iter(self._list)
- def __add__(self, other):
+ def __add__(self, other: Iterator[_T]) -> "OrderedSet[_T]":
return self.union(other)
- def __repr__(self):
+ def __repr__(self) -> str:
return "%s(%r)" % (self.__class__.__name__, self._list)
__str__ = __repr__
- def update(self, iterable):
- for e in iterable:
- if e not in self:
- self._list.append(e)
- set.add(self, e)
- return self
+ def update(self, *iterables: Iterable[_T]) -> None:
+ for iterable in iterables:
+ for e in iterable:
+ if e not in self:
+ self._list.append(e)
+ self._set.add(e)
- __ior__ = update
+ def __ior__(self, other: Iterable[_T]) -> "OrderedSet[_T]":
+ self.update(other)
+ return self
- def union(self, other):
+ def union(self, other: Iterable[_T]) -> "OrderedSet[_T]":
result = self.__class__(self)
result.update(other)
return result
- __or__ = union
+ def __or__(self, other: Iterable[_T]) -> "OrderedSet[_T]":
+ return self.union(other)
- def intersection(self, other):
+ def intersection(self, other: Iterable[_T]) -> "OrderedSet[_T]":
other = other if isinstance(other, set) else set(other)
return self.__class__(a for a in self if a in other)
- __and__ = intersection
+ def __and__(self, other: Iterable[_T]) -> "OrderedSet[_T]":
+ return self.intersection(other)
- def symmetric_difference(self, other):
+ def symmetric_difference(self, other: Iterable[_T]) -> "OrderedSet[_T]":
other_set = other if isinstance(other, set) else set(other)
result = self.__class__(a for a in self if a not in other_set)
result.update(a for a in other if a not in self)
return result
- __xor__ = symmetric_difference
+ def __xor__(self, other: Iterable[_T]) -> "OrderedSet[_T]":
+ return self.symmetric_difference(other)
- def difference(self, other):
+ def difference(self, other: Iterable[_T]) -> "OrderedSet[_T]":
other = other if isinstance(other, set) else set(other)
return self.__class__(a for a in self if a not in other)
- __sub__ = difference
+ def __sub__(self, other: Iterable[_T]) -> "OrderedSet[_T]":
+ return self.difference(other)
- def intersection_update(self, other):
+ def intersection_update(self, other: Iterable[_T]) -> None:
other = other if isinstance(other, set) else set(other)
- set.intersection_update(self, other)
+ self._set.intersection_update(other)
self._list = [a for a in self._list if a in other]
- return self
- __iand__ = intersection_update
+ def __iand__(self, other: Iterable[_T]) -> "OrderedSet[_T]":
+ self.intersection_update(other)
+ return self
- def symmetric_difference_update(self, other):
- set.symmetric_difference_update(self, other)
+ def symmetric_difference_update(self, other: Iterable[_T]) -> None:
+ self._set.symmetric_difference_update(other)
self._list = [a for a in self._list if a in self]
self._list += [a for a in other if a in self]
- return self
- __ixor__ = symmetric_difference_update
+ def __ixor__(self, other: Iterable[_T]) -> "OrderedSet[_T]":
+ self.symmetric_difference_update(other)
+ return self
- def difference_update(self, other):
- set.difference_update(self, other)
+ def difference_update(self, other: Iterable[_T]) -> None:
+ self._set.difference_update(other)
self._list = [a for a in self._list if a in self]
- return self
- __isub__ = difference_update
+ def __isub__(self, other: Iterable[_T]) -> "OrderedSet[_T]":
+ self.difference_update(other)
+ return self
class IdentitySet: