summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/util/_py_collections.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/util/_py_collections.py')
-rw-r--r--lib/sqlalchemy/util/_py_collections.py106
1 files changed, 47 insertions, 59 deletions
diff --git a/lib/sqlalchemy/util/_py_collections.py b/lib/sqlalchemy/util/_py_collections.py
index a4e4b8b5d..7914507cd 100644
--- a/lib/sqlalchemy/util/_py_collections.py
+++ b/lib/sqlalchemy/util/_py_collections.py
@@ -1,7 +1,8 @@
from itertools import filterfalse
+from typing import AbstractSet
from typing import Any
+from typing import cast
from typing import Dict
-from typing import Generic
from typing import Iterable
from typing import Iterator
from typing import List
@@ -9,6 +10,7 @@ from typing import NoReturn
from typing import Optional
from typing import Set
from typing import TypeVar
+from typing import Union
_T = TypeVar("_T", bound=Any)
_KT = TypeVar("_KT", bound=Any)
@@ -96,19 +98,20 @@ class immutabledict(ImmutableDictBase[_KT, _VT]):
return "immutabledict(%s)" % dict.__repr__(self)
-class OrderedSet(Generic[_T]):
- __slots__ = ("_list", "_set", "__weakref__")
+_S = TypeVar("_S", bound=Any)
+
+
+class OrderedSet(Set[_T]):
+ __slots__ = ("_list",)
_list: List[_T]
- _set: Set[_T]
def __init__(self, d=None):
if d is not None:
self._list = unique_list(d)
- self._set = set(self._list)
+ super().update(self._list)
else:
self._list = []
- self._set = set()
def __reduce__(self):
return (OrderedSet, (self._list,))
@@ -116,44 +119,26 @@ class OrderedSet(Generic[_T]):
def add(self, element: _T) -> None:
if element not in self:
self._list.append(element)
- self._set.add(element)
+ super().add(element)
def remove(self, element: _T) -> None:
- self._set.remove(element)
+ super().remove(element)
self._list.remove(element)
def insert(self, pos: int, element: _T) -> None:
if element not in self:
self._list.insert(pos, element)
- self._set.add(element)
+ super().add(element)
def discard(self, element: _T) -> None:
if element in self:
self._list.remove(element)
- self._set.remove(element)
+ super().remove(element)
def clear(self) -> None:
- self._set.clear()
+ super().clear()
self._list = []
- def __len__(self) -> int:
- return len(self._set)
-
- def __eq__(self, other):
- if not isinstance(other, OrderedSet):
- return self._set == other
- else:
- return self._set == other._set
-
- def __ne__(self, other):
- if not isinstance(other, OrderedSet):
- return self._set != other
- else:
- return self._set != other._set
-
- def __contains__(self, element: Any) -> bool:
- return element in self._set
-
def __getitem__(self, key: int) -> _T:
return self._list[key]
@@ -173,25 +158,27 @@ class OrderedSet(Generic[_T]):
for e in iterable:
if e not in self:
self._list.append(e)
- self._set.add(e)
+ super().add(e)
- def __ior__(self, other: Iterable[_T]) -> "OrderedSet[_T]":
- self.update(other)
- return self
+ def __ior__(self, other: AbstractSet[_S]) -> "OrderedSet[Union[_T, _S]]":
+ self.update(other) # type: ignore
+ return self # type: ignore
- def union(self, other: Iterable[_T]) -> "OrderedSet[_T]":
- result = self.__class__(self)
- result.update(other)
+ def union(self, *other: Iterable[_S]) -> "OrderedSet[Union[_T, _S]]":
+ result: "OrderedSet[Union[_T, _S]]" = self.__class__(self) # type: ignore # noqa E501
+ for o in other:
+ result.update(o)
return result
- def __or__(self, other: Iterable[_T]) -> "OrderedSet[_T]":
+ def __or__(self, other: AbstractSet[_S]) -> "OrderedSet[Union[_T, _S]]":
return self.union(other)
- def intersection(self, other: Iterable[_T]) -> "OrderedSet[_T]":
- other = other if isinstance(other, set) else set(other)
- return self.__class__(a for a in self if a in other)
+ def intersection(self, *other: Iterable[Any]) -> "OrderedSet[_T]":
+ other_set: Set[Any] = set()
+ other_set.update(*other)
+ return self.__class__(a for a in self if a in other_set)
- def __and__(self, other: Iterable[_T]) -> "OrderedSet[_T]":
+ def __and__(self, other: AbstractSet[object]) -> "OrderedSet[_T]":
return self.intersection(other)
def symmetric_difference(self, other: Iterable[_T]) -> "OrderedSet[_T]":
@@ -200,39 +187,40 @@ class OrderedSet(Generic[_T]):
result.update(a for a in other if a not in self)
return result
- def __xor__(self, other: Iterable[_T]) -> "OrderedSet[_T]":
- return self.symmetric_difference(other)
+ def __xor__(self, other: AbstractSet[_S]) -> "OrderedSet[Union[_T, _S]]":
+ return cast("OrderedSet[Union[_T, _S]]", self).symmetric_difference(
+ other
+ )
- def difference(self, other: Iterable[_T]) -> "OrderedSet[_T]":
- other = other if isinstance(other, set) else set(other)
- return self.__class__(a for a in self if a not in other)
+ def difference(self, *other: Iterable[Any]) -> "OrderedSet[_T]":
+ other_set = super().difference(*other)
+ return self.__class__(a for a in self._list if a in other_set)
- def __sub__(self, other: Iterable[_T]) -> "OrderedSet[_T]":
+ def __sub__(self, other: AbstractSet[_T | None]) -> "OrderedSet[_T]":
return self.difference(other)
- def intersection_update(self, other: Iterable[_T]) -> None:
- other = other if isinstance(other, set) else set(other)
- self._set.intersection_update(other)
- self._list = [a for a in self._list if a in other]
+ def intersection_update(self, *other: Iterable[Any]) -> None:
+ super().intersection_update(*other)
+ self._list = [a for a in self._list if a in self]
- def __iand__(self, other: Iterable[_T]) -> "OrderedSet[_T]":
+ def __iand__(self, other: AbstractSet[object]) -> "OrderedSet[_T]":
self.intersection_update(other)
return self
- def symmetric_difference_update(self, other: Iterable[_T]) -> None:
- self._set.symmetric_difference_update(other)
+ def symmetric_difference_update(self, other: Iterable[Any]) -> None:
+ super().symmetric_difference_update(other)
self._list = [a for a in self._list if a in self]
self._list += [a for a in other if a in self]
- def __ixor__(self, other: Iterable[_T]) -> "OrderedSet[_T]":
+ def __ixor__(self, other: AbstractSet[_S]) -> "OrderedSet[Union[_T, _S]]":
self.symmetric_difference_update(other)
- return self
+ return cast("OrderedSet[Union[_T, _S]]", self)
- def difference_update(self, other: Iterable[_T]) -> None:
- self._set.difference_update(other)
+ def difference_update(self, *other: Iterable[Any]) -> None:
+ super().difference_update(*other)
self._list = [a for a in self._list if a in self]
- def __isub__(self, other: Iterable[_T]) -> "OrderedSet[_T]":
+ def __isub__(self, other: AbstractSet[_T | None]) -> "OrderedSet[_T]":
self.difference_update(other)
return self