diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2022-01-24 18:13:05 -0500 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2022-01-25 11:18:55 -0500 |
commit | 5aee5fe12afdeb4569e588344f00aa56c9250215 (patch) | |
tree | 2319b83f15a0139bb8f9c60697c40a440c386b5f /lib/sqlalchemy/util/_py_collections.py | |
parent | ff1ab665cb1694b85085680d1a02c7c11fa2a6d4 (diff) | |
download | sqlalchemy-5aee5fe12afdeb4569e588344f00aa56c9250215.tar.gz |
restore set-as-superclass for OrderedSet
OrderedSet again subclasses set, spent some time
with the stubs at
https://github.com/python/typeshed/blob/master/stdlib/builtins.pyi#L887
to more deeply understand what they are doing here
so that we can type check fully.
Change-Id: Iec9b5ab43befd30e1f2c5cc40e59ab852dd28e75
Diffstat (limited to 'lib/sqlalchemy/util/_py_collections.py')
-rw-r--r-- | lib/sqlalchemy/util/_py_collections.py | 106 |
1 files changed, 47 insertions, 59 deletions
diff --git a/lib/sqlalchemy/util/_py_collections.py b/lib/sqlalchemy/util/_py_collections.py index a4e4b8b5d..7914507cd 100644 --- a/lib/sqlalchemy/util/_py_collections.py +++ b/lib/sqlalchemy/util/_py_collections.py @@ -1,7 +1,8 @@ from itertools import filterfalse +from typing import AbstractSet from typing import Any +from typing import cast from typing import Dict -from typing import Generic from typing import Iterable from typing import Iterator from typing import List @@ -9,6 +10,7 @@ from typing import NoReturn from typing import Optional from typing import Set from typing import TypeVar +from typing import Union _T = TypeVar("_T", bound=Any) _KT = TypeVar("_KT", bound=Any) @@ -96,19 +98,20 @@ class immutabledict(ImmutableDictBase[_KT, _VT]): return "immutabledict(%s)" % dict.__repr__(self) -class OrderedSet(Generic[_T]): - __slots__ = ("_list", "_set", "__weakref__") +_S = TypeVar("_S", bound=Any) + + +class OrderedSet(Set[_T]): + __slots__ = ("_list",) _list: List[_T] - _set: Set[_T] def __init__(self, d=None): if d is not None: self._list = unique_list(d) - self._set = set(self._list) + super().update(self._list) else: self._list = [] - self._set = set() def __reduce__(self): return (OrderedSet, (self._list,)) @@ -116,44 +119,26 @@ class OrderedSet(Generic[_T]): def add(self, element: _T) -> None: if element not in self: self._list.append(element) - self._set.add(element) + super().add(element) def remove(self, element: _T) -> None: - self._set.remove(element) + super().remove(element) self._list.remove(element) def insert(self, pos: int, element: _T) -> None: if element not in self: self._list.insert(pos, element) - self._set.add(element) + super().add(element) def discard(self, element: _T) -> None: if element in self: self._list.remove(element) - self._set.remove(element) + super().remove(element) def clear(self) -> None: - self._set.clear() + super().clear() self._list = [] - def __len__(self) -> int: - return len(self._set) - - def __eq__(self, other): - if not isinstance(other, OrderedSet): - return self._set == other - else: - return self._set == other._set - - def __ne__(self, other): - if not isinstance(other, OrderedSet): - return self._set != other - else: - return self._set != other._set - - def __contains__(self, element: Any) -> bool: - return element in self._set - def __getitem__(self, key: int) -> _T: return self._list[key] @@ -173,25 +158,27 @@ class OrderedSet(Generic[_T]): for e in iterable: if e not in self: self._list.append(e) - self._set.add(e) + super().add(e) - def __ior__(self, other: Iterable[_T]) -> "OrderedSet[_T]": - self.update(other) - return self + def __ior__(self, other: AbstractSet[_S]) -> "OrderedSet[Union[_T, _S]]": + self.update(other) # type: ignore + return self # type: ignore - def union(self, other: Iterable[_T]) -> "OrderedSet[_T]": - result = self.__class__(self) - result.update(other) + def union(self, *other: Iterable[_S]) -> "OrderedSet[Union[_T, _S]]": + result: "OrderedSet[Union[_T, _S]]" = self.__class__(self) # type: ignore # noqa E501 + for o in other: + result.update(o) return result - def __or__(self, other: Iterable[_T]) -> "OrderedSet[_T]": + def __or__(self, other: AbstractSet[_S]) -> "OrderedSet[Union[_T, _S]]": return self.union(other) - def intersection(self, other: Iterable[_T]) -> "OrderedSet[_T]": - other = other if isinstance(other, set) else set(other) - return self.__class__(a for a in self if a in other) + def intersection(self, *other: Iterable[Any]) -> "OrderedSet[_T]": + other_set: Set[Any] = set() + other_set.update(*other) + return self.__class__(a for a in self if a in other_set) - def __and__(self, other: Iterable[_T]) -> "OrderedSet[_T]": + def __and__(self, other: AbstractSet[object]) -> "OrderedSet[_T]": return self.intersection(other) def symmetric_difference(self, other: Iterable[_T]) -> "OrderedSet[_T]": @@ -200,39 +187,40 @@ class OrderedSet(Generic[_T]): result.update(a for a in other if a not in self) return result - def __xor__(self, other: Iterable[_T]) -> "OrderedSet[_T]": - return self.symmetric_difference(other) + def __xor__(self, other: AbstractSet[_S]) -> "OrderedSet[Union[_T, _S]]": + return cast("OrderedSet[Union[_T, _S]]", self).symmetric_difference( + other + ) - def difference(self, other: Iterable[_T]) -> "OrderedSet[_T]": - other = other if isinstance(other, set) else set(other) - return self.__class__(a for a in self if a not in other) + def difference(self, *other: Iterable[Any]) -> "OrderedSet[_T]": + other_set = super().difference(*other) + return self.__class__(a for a in self._list if a in other_set) - def __sub__(self, other: Iterable[_T]) -> "OrderedSet[_T]": + def __sub__(self, other: AbstractSet[_T | None]) -> "OrderedSet[_T]": return self.difference(other) - def intersection_update(self, other: Iterable[_T]) -> None: - other = other if isinstance(other, set) else set(other) - self._set.intersection_update(other) - self._list = [a for a in self._list if a in other] + def intersection_update(self, *other: Iterable[Any]) -> None: + super().intersection_update(*other) + self._list = [a for a in self._list if a in self] - def __iand__(self, other: Iterable[_T]) -> "OrderedSet[_T]": + def __iand__(self, other: AbstractSet[object]) -> "OrderedSet[_T]": self.intersection_update(other) return self - def symmetric_difference_update(self, other: Iterable[_T]) -> None: - self._set.symmetric_difference_update(other) + def symmetric_difference_update(self, other: Iterable[Any]) -> None: + super().symmetric_difference_update(other) self._list = [a for a in self._list if a in self] self._list += [a for a in other if a in self] - def __ixor__(self, other: Iterable[_T]) -> "OrderedSet[_T]": + def __ixor__(self, other: AbstractSet[_S]) -> "OrderedSet[Union[_T, _S]]": self.symmetric_difference_update(other) - return self + return cast("OrderedSet[Union[_T, _S]]", self) - def difference_update(self, other: Iterable[_T]) -> None: - self._set.difference_update(other) + def difference_update(self, *other: Iterable[Any]) -> None: + super().difference_update(*other) self._list = [a for a in self._list if a in self] - def __isub__(self, other: Iterable[_T]) -> "OrderedSet[_T]": + def __isub__(self, other: AbstractSet[_T | None]) -> "OrderedSet[_T]": self.difference_update(other) return self |