summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/util/_py_collections.py
diff options
context:
space:
mode:
authorFederico Caselli <cfederico87@gmail.com>2023-03-14 23:17:07 +0100
committerFederico Caselli <cfederico87@gmail.com>2023-03-30 22:18:11 +0200
commita979b6dc5ebefedfd8c85f5695cc5be8882eaa29 (patch)
tree8af2f9102fa109b0fa968cada17004e3d2b41e5f /lib/sqlalchemy/util/_py_collections.py
parent77357be824095b46eb2ed3206bc555a6dacc7f30 (diff)
downloadsqlalchemy-a979b6dc5ebefedfd8c85f5695cc5be8882eaa29.tar.gz
Add missing methods to OrderedSet.
Implemented missing method ``copy`` and ``pop`` in OrderedSet class. Fixes: #9487 Change-Id: I1d2278b64939b44422e9d5857ec7d345fff53997
Diffstat (limited to 'lib/sqlalchemy/util/_py_collections.py')
-rw-r--r--lib/sqlalchemy/util/_py_collections.py37
1 files changed, 28 insertions, 9 deletions
diff --git a/lib/sqlalchemy/util/_py_collections.py b/lib/sqlalchemy/util/_py_collections.py
index 8810800c4..9962493b5 100644
--- a/lib/sqlalchemy/util/_py_collections.py
+++ b/lib/sqlalchemy/util/_py_collections.py
@@ -168,8 +168,11 @@ class OrderedSet(Set[_T]):
else:
self._list = []
- def __reduce__(self):
- return (OrderedSet, (self._list,))
+ def copy(self) -> OrderedSet[_T]:
+ cp = self.__class__()
+ cp._list = self._list.copy()
+ set.update(cp, cp._list)
+ return cp
def add(self, element: _T) -> None:
if element not in self:
@@ -180,6 +183,14 @@ class OrderedSet(Set[_T]):
super().remove(element)
self._list.remove(element)
+ def pop(self) -> _T:
+ try:
+ value = self._list.pop()
+ except IndexError:
+ raise KeyError("pop from an empty set") from None
+ super().remove(value)
+ return value
+
def insert(self, pos: int, element: _T) -> None:
if element not in self:
self._list.insert(pos, element)
@@ -220,9 +231,8 @@ class OrderedSet(Set[_T]):
return self # type: ignore
def union(self, *other: Iterable[_S]) -> OrderedSet[Union[_T, _S]]:
- result: OrderedSet[Union[_T, _S]] = self.__class__(self) # type: ignore # noqa: E501
- for o in other:
- result.update(o)
+ result: OrderedSet[Union[_T, _S]] = self.copy() # type: ignore
+ result.update(*other)
return result
def __or__(self, other: AbstractSet[_S]) -> OrderedSet[Union[_T, _S]]:
@@ -237,9 +247,17 @@ class OrderedSet(Set[_T]):
return self.intersection(other)
def symmetric_difference(self, other: Iterable[_T]) -> OrderedSet[_T]:
- other_set = other if isinstance(other, set) else set(other)
+ collection: Collection[_T]
+ if isinstance(other, set):
+ collection = other_set = other
+ elif isinstance(other, Collection):
+ collection = other
+ other_set = set(other)
+ else:
+ collection = list(other)
+ other_set = set(collection)
result = self.__class__(a for a in self if a not in other_set)
- result.update(a for a in other if a not in self)
+ result.update(a for a in collection if a not in self)
return result
def __xor__(self, other: AbstractSet[_S]) -> OrderedSet[Union[_T, _S]]:
@@ -263,9 +281,10 @@ class OrderedSet(Set[_T]):
return self
def symmetric_difference_update(self, other: Iterable[Any]) -> None:
- super().symmetric_difference_update(other)
+ collection = other if isinstance(other, Collection) else list(other)
+ super().symmetric_difference_update(collection)
self._list = [a for a in self._list if a in self]
- self._list += [a for a in other if a in self]
+ self._list += [a for a in collection if a in self]
def __ixor__(self, other: AbstractSet[_S]) -> OrderedSet[Union[_T, _S]]:
self.symmetric_difference_update(other)