summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/util/_collections.py
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2020-05-03 20:27:24 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2020-05-03 20:33:17 -0400
commit35552e88ca798b809c7391bae11890c1557a3dd2 (patch)
treef1a689be9de573d6ba885a0921c5b123fbaa8ff5 /lib/sqlalchemy/util/_collections.py
parent00d40775b73cf94fa5d1b765dac1e600e93e172f (diff)
downloadsqlalchemy-35552e88ca798b809c7391bae11890c1557a3dd2.tar.gz
Don't apply sets or similar to objects in IdentitySet
Modified the internal "identity set" implementation, which is a set that hashes objects on their id() rather than their hash values, to not actually call the ``__hash__()`` method of the objects, which are typically user-mapped objects. Some methods were calling this method as a side effect of the implementation. Fixes: #5304 Change-Id: I0ed8762f47622215a54dcad9f210377b1becf8e8
Diffstat (limited to 'lib/sqlalchemy/util/_collections.py')
-rw-r--r--lib/sqlalchemy/util/_collections.py71
1 files changed, 27 insertions, 44 deletions
diff --git a/lib/sqlalchemy/util/_collections.py b/lib/sqlalchemy/util/_collections.py
index b21eb44cf..10d80fc98 100644
--- a/lib/sqlalchemy/util/_collections.py
+++ b/lib/sqlalchemy/util/_collections.py
@@ -363,13 +363,10 @@ class IdentitySet(object):
"""
- _working_set = set
-
def __init__(self, iterable=None):
self._members = dict()
if iterable:
- for o in iterable:
- self.add(o)
+ self.update(iterable)
def add(self, value):
self._members[id(value)] = value
@@ -412,7 +409,7 @@ class IdentitySet(object):
return True
def issubset(self, iterable):
- other = type(self)(iterable)
+ other = self.__class__(iterable)
if len(self) > len(other):
return False
@@ -433,7 +430,7 @@ class IdentitySet(object):
return len(self) < len(other) and self.issubset(other)
def issuperset(self, iterable):
- other = type(self)(iterable)
+ other = self.__class__(iterable)
if len(self) < len(other):
return False
@@ -455,11 +452,10 @@ class IdentitySet(object):
return len(self) > len(other) and self.issuperset(other)
def union(self, iterable):
- result = type(self)()
- # testlib.pragma exempt:__hash__
- members = self._member_id_tuples()
- other = _iter_id(iterable)
- result._members.update(self._working_set(members).union(other))
+ result = self.__class__()
+ members = self._members
+ result._members.update(members)
+ result._members.update((id(obj), obj) for obj in iterable)
return result
def __or__(self, other):
@@ -468,7 +464,7 @@ class IdentitySet(object):
return self.union(other)
def update(self, iterable):
- self._members = self.union(iterable)._members
+ self._members.update((id(obj), obj) for obj in iterable)
def __ior__(self, other):
if not isinstance(other, IdentitySet):
@@ -477,11 +473,12 @@ class IdentitySet(object):
return self
def difference(self, iterable):
- result = type(self)()
- # testlib.pragma exempt:__hash__
- members = self._member_id_tuples()
- other = _iter_id(iterable)
- result._members.update(self._working_set(members).difference(other))
+ result = self.__class__()
+ members = self._members
+ other = {id(obj) for obj in iterable}
+ result._members.update(
+ ((k, v) for k, v in members.items() if k not in other)
+ )
return result
def __sub__(self, other):
@@ -499,11 +496,12 @@ class IdentitySet(object):
return self
def intersection(self, iterable):
- result = type(self)()
- # testlib.pragma exempt:__hash__
- members = self._member_id_tuples()
- other = _iter_id(iterable)
- result._members.update(self._working_set(members).intersection(other))
+ result = self.__class__()
+ members = self._members
+ other = {id(obj) for obj in iterable}
+ result._members.update(
+ (k, v) for k, v in members.items() if k in other
+ )
return result
def __and__(self, other):
@@ -521,18 +519,17 @@ class IdentitySet(object):
return self
def symmetric_difference(self, iterable):
- result = type(self)()
- # testlib.pragma exempt:__hash__
- members = self._member_id_tuples()
- other = _iter_id(iterable)
+ result = self.__class__()
+ members = self._members
+ other = {id(obj): obj for obj in iterable}
result._members.update(
- self._working_set(members).symmetric_difference(other)
+ ((k, v) for k, v in members.items() if k not in other)
+ )
+ result._members.update(
+ ((k, v) for k, v in other.items() if k not in members)
)
return result
- def _member_id_tuples(self):
- return ((id(v), v) for v in self._members.values())
-
def __xor__(self, other):
if not isinstance(other, IdentitySet):
return NotImplemented
@@ -600,13 +597,6 @@ class WeakSequence(object):
class OrderedIdentitySet(IdentitySet):
- class _working_set(OrderedSet):
- # a testing pragma: exempt the OIDS working set from the test suite's
- # "never call the user's __hash__" assertions. this is a big hammer,
- # but it's safe here: IDS operates on (id, instance) tuples in the
- # working set.
- __sa_hash_exempt__ = True
-
def __init__(self, iterable=None):
IdentitySet.__init__(self)
self._members = OrderedDict()
@@ -942,13 +932,6 @@ class ThreadLocalRegistry(ScopedRegistry):
pass
-def _iter_id(iterable):
- """Generator: ((id(o), o) for o in iterable)."""
-
- for item in iterable:
- yield id(item), item
-
-
def has_dupes(sequence, target):
"""Given a sequence and search object, return True if there's more
than one, False if zero or one of them.