diff options
author | Federico Caselli <cfederico87@gmail.com> | 2023-03-14 23:17:07 +0100 |
---|---|---|
committer | Federico Caselli <cfederico87@gmail.com> | 2023-03-30 22:18:11 +0200 |
commit | a979b6dc5ebefedfd8c85f5695cc5be8882eaa29 (patch) | |
tree | 8af2f9102fa109b0fa968cada17004e3d2b41e5f /lib/sqlalchemy/util/_py_collections.py | |
parent | 77357be824095b46eb2ed3206bc555a6dacc7f30 (diff) | |
download | sqlalchemy-a979b6dc5ebefedfd8c85f5695cc5be8882eaa29.tar.gz |
Add missing methods to OrderedSet.
Implemented missing method ``copy`` and ``pop`` in OrderedSet class.
Fixes: #9487
Change-Id: I1d2278b64939b44422e9d5857ec7d345fff53997
Diffstat (limited to 'lib/sqlalchemy/util/_py_collections.py')
-rw-r--r-- | lib/sqlalchemy/util/_py_collections.py | 37 |
1 files changed, 28 insertions, 9 deletions
diff --git a/lib/sqlalchemy/util/_py_collections.py b/lib/sqlalchemy/util/_py_collections.py index 8810800c4..9962493b5 100644 --- a/lib/sqlalchemy/util/_py_collections.py +++ b/lib/sqlalchemy/util/_py_collections.py @@ -168,8 +168,11 @@ class OrderedSet(Set[_T]): else: self._list = [] - def __reduce__(self): - return (OrderedSet, (self._list,)) + def copy(self) -> OrderedSet[_T]: + cp = self.__class__() + cp._list = self._list.copy() + set.update(cp, cp._list) + return cp def add(self, element: _T) -> None: if element not in self: @@ -180,6 +183,14 @@ class OrderedSet(Set[_T]): super().remove(element) self._list.remove(element) + def pop(self) -> _T: + try: + value = self._list.pop() + except IndexError: + raise KeyError("pop from an empty set") from None + super().remove(value) + return value + def insert(self, pos: int, element: _T) -> None: if element not in self: self._list.insert(pos, element) @@ -220,9 +231,8 @@ class OrderedSet(Set[_T]): return self # type: ignore def union(self, *other: Iterable[_S]) -> OrderedSet[Union[_T, _S]]: - result: OrderedSet[Union[_T, _S]] = self.__class__(self) # type: ignore # noqa: E501 - for o in other: - result.update(o) + result: OrderedSet[Union[_T, _S]] = self.copy() # type: ignore + result.update(*other) return result def __or__(self, other: AbstractSet[_S]) -> OrderedSet[Union[_T, _S]]: @@ -237,9 +247,17 @@ class OrderedSet(Set[_T]): return self.intersection(other) def symmetric_difference(self, other: Iterable[_T]) -> OrderedSet[_T]: - other_set = other if isinstance(other, set) else set(other) + collection: Collection[_T] + if isinstance(other, set): + collection = other_set = other + elif isinstance(other, Collection): + collection = other + other_set = set(other) + else: + collection = list(other) + other_set = set(collection) result = self.__class__(a for a in self if a not in other_set) - result.update(a for a in other if a not in self) + result.update(a for a in collection if a not in self) return result def __xor__(self, other: AbstractSet[_S]) -> OrderedSet[Union[_T, _S]]: @@ -263,9 +281,10 @@ class OrderedSet(Set[_T]): return self def symmetric_difference_update(self, other: Iterable[Any]) -> None: - super().symmetric_difference_update(other) + collection = other if isinstance(other, Collection) else list(other) + super().symmetric_difference_update(collection) self._list = [a for a in self._list if a in self] - self._list += [a for a in other if a in self] + self._list += [a for a in collection if a in self] def __ixor__(self, other: AbstractSet[_S]) -> OrderedSet[Union[_T, _S]]: self.symmetric_difference_update(other) |