diff options
-rw-r--r-- | kazoo/recipe/lock.py | 142 | ||||
-rw-r--r-- | kazoo/tests/test_lock.py | 21 |
2 files changed, 93 insertions, 70 deletions
diff --git a/kazoo/recipe/lock.py b/kazoo/recipe/lock.py index 982c12e..7722a97 100644 --- a/kazoo/recipe/lock.py +++ b/kazoo/recipe/lock.py @@ -14,6 +14,7 @@ changes and re-act appropriately. In the event that a and/or the lease has been lost. """ +import re import sys try: @@ -83,9 +84,7 @@ class Lock(object): # sequence number. Involved in read/write locks. _EXCLUDE_NAMES = ["__lock__"] - def __init__( - self, client, path, identifier=None, additional_lock_patterns=() - ): + def __init__(self, client, path, identifier=None, extra_lock_patterns=()): """Create a Kazoo lock. :param client: A :class:`~kazoo.client.KazooClient` instance. @@ -93,25 +92,30 @@ class Lock(object): :param identifier: Name to use for this lock contender. This can be useful for querying to see who the current lock contenders are. - :param additional_lock_patterns: Strings that will be used to - identify other znode in the path - that should be considered contenders - for this lock. - Use this for cross-implementation - compatibility. + :param extra_lock_patterns: Strings that will be used to + identify other znode in the path + that should be considered contenders + for this lock. + Use this for cross-implementation + compatibility. .. versionadded:: 2.7.1 - The additional_lock_patterns option. + The extra_lock_patterns option. """ self.client = client self.path = path self._exclude_names = set( - self._EXCLUDE_NAMES + list(additional_lock_patterns) + self._EXCLUDE_NAMES + list(extra_lock_patterns) + ) + self._contenders_re = re.compile( + r"(?:{patterns})(-?\d{{10}})$".format( + patterns="|".join(self._exclude_names) + ) ) # some data is written to the node. this can be queried via # contenders() to see who is contending for the lock - self.data = str(identifier or "").encode('utf-8') + self.data = str(identifier or "").encode("utf-8") self.node = None self.wake_event = client.handler.event_object() @@ -186,6 +190,7 @@ class Lock(object): return False if not locked: # Lock acquire doesn't take a timeout, so simulate it... + # XXX: This is not true in Py3 >= 3.2 try: locked = retry(_acquire_lock) except RetryFailedError: @@ -255,18 +260,8 @@ class Lock(object): if self.cancelled: raise CancelledError() - children = self._get_sorted_children() - - try: - our_index = children.index(node) - except ValueError: # pragma: nocover - # somehow we aren't in the children -- probably we are - # recovering from a session failure and our ephemeral - # node was removed - raise ForceRetryError() - - predecessor = self.predecessor(children, our_index) - if not predecessor: + predecessor = self._get_predecessor(node) + if predecessor is None: return True if not blocking: @@ -289,36 +284,44 @@ class Lock(object): finally: self.client.remove_listener(self._watch_session) - def predecessor(self, children, index): - for c in reversed(children[:index]): - if any(n in c for n in self._exclude_names): - return c - return None - def _watch_predecessor(self, event): self.wake_event.set() - def _get_sorted_children(self): + def _get_predecessor(self, node): + """returns `node`'s predecessor or None + + Note: This handle the case where the current lock is not a contender + (e.g. rlock), this and also edge cases where the lock's ephemeral node + is gone. + """ children = self.client.get_children(self.path) + found_self = False + # Filter out the contenders using the computed regex + contender_matches = [] + for child in children: + match = self._contenders_re.search(child) + if match is not None: + contender_matches.append(match) + if child == node: + # Remember the node's match object so we can short circuit + # below. + found_self = match + + if found_self is False: # pragma: nocover + # somehow we aren't in the childrens -- probably we are + # recovering from a session failure and our ephemeral + # node was removed. + raise ForceRetryError() + + predecessor = None + # Sort the contenders using the sequence number extracted by the regex, + # then extract the original string. + for match in sorted(contender_matches, key=lambda m: m.groups()): + if match is found_self: + break + predecessor = match.string - # Node names are prefixed by a type: strip the prefix first, which may - # be one of multiple values in case of a read-write lock, and return - # only the sequence number (as a string since it is padded and will - # sort correctly anyway). - # - # In some cases, the lock path may contain nodes with other prefixes - # (eg. in case of a lease), just sort them last ('~' sorts after all - # ASCII digits). - def _seq(c): - for name in self._exclude_names: - idx = c.find(name) - if idx != -1: - return c[idx + len(name):] - # Sort unknown node names eg. "lease_holder" last. - return '~' - - children.sort(key=_seq) - return children + return predecessor def _find_node(self): children = self.client.get_children(self.path) @@ -369,16 +372,37 @@ class Lock(object): if not self.assured_path: self._ensure_path() - children = self._get_sorted_children() - - contenders = [] + children = self.client.get_children(self.path) + # We want all contenders, including self (this is especially important + # for r/w locks). This is similar to the logic of `_get_predecessor` + # except we include our own pattern. + all_contenders_re = re.compile( + r"(?:{patterns})(-?\d{{10}})$".format( + patterns="|".join(self._exclude_names | {self._NODE_NAME}) + ) + ) + # Filter out the contenders using the computed regex + contender_matches = [] for child in children: + match = all_contenders_re.search(child) + if match is not None: + contender_matches.append(match) + # Sort the contenders using the sequence number extracted by the regex, + # then extract the original string. + contender_nodes = [ + match.string + for match in sorted(contender_matches, key=lambda m: m.groups()) + ] + # Retrieve all the contender nodes data (preserving order). + contenders = [] + for node in contender_nodes: try: - data, stat = self.client.get(self.path + "/" + child) + data, stat = self.client.get(self.path + "/" + node) if data is not None: - contenders.append(data.decode('utf-8')) + contenders.append(data.decode("utf-8")) except NoNodeError: # pragma: nocover pass + return contenders def __enter__(self): @@ -508,12 +532,12 @@ class Semaphore(object): # some data is written to the node. this can be queried via # contenders() to see who is contending for the lock - self.data = str(identifier or "").encode('utf-8') + self.data = str(identifier or "").encode("utf-8") self.max_leases = max_leases self.wake_event = client.handler.event_object() self.create_path = self.path + "/" + uuid.uuid4().hex - self.lock_path = path + '-' + '__lock__' + self.lock_path = path + "-" + "__lock__" self.is_acquired = False self.assured_path = False self.cancelled = False @@ -526,7 +550,7 @@ class Semaphore(object): # node did already exist data, _ = self.client.get(self.path) try: - leases = int(data.decode('utf-8')) + leases = int(data.decode("utf-8")) except (ValueError, TypeError): # ignore non-numeric data, maybe the node data is used # for other purposes @@ -538,7 +562,7 @@ class Semaphore(object): % (leases, self.max_leases) ) else: - self.client.set(self.path, str(self.max_leases).encode('utf-8')) + self.client.set(self.path, str(self.max_leases).encode("utf-8")) def cancel(self): """Cancel a pending semaphore acquire.""" @@ -702,7 +726,7 @@ class Semaphore(object): for child in children: try: data, stat = self.client.get(self.path + "/" + child) - lease_holders.append(data.decode('utf-8')) + lease_holders.append(data.decode("utf-8")) except NoNodeError: # pragma: nocover pass return lease_holders diff --git a/kazoo/tests/test_lock.py b/kazoo/tests/test_lock.py index 33691e4..0e16949 100644 --- a/kazoo/tests/test_lock.py +++ b/kazoo/tests/test_lock.py @@ -434,7 +434,6 @@ class KazooLockTests(KazooTestCase): # and that it's still not reentrant. gotten = lock.acquire(blocking=False) assert gotten is False - # Test that a second client we can share the same read lock client2 = self._get_client() client2.start() @@ -444,7 +443,6 @@ class KazooLockTests(KazooTestCase): assert lock2.is_acquired is True gotten = lock2.acquire(blocking=False) assert gotten is False - # Test that a writer is unable to share it client3 = self._get_client() client3.start() @@ -741,24 +739,25 @@ class TestSemaphore(KazooTestCase): class TestSequence(unittest.TestCase): - def test_get_sorted_children(self): + def test_get_predecessor(self): + """Validate selection of predecessors. + """ goLock = "_c_8eb60557ba51e0da67eefc47467d3f34-lock-0000000031" pyLock = "514e5a831836450cb1a56c741e990fd8__lock__0000000032" children = ["hello", goLock, "world", pyLock] client = mock.MagicMock() client.get_children.return_value = children lock = Lock(client, "test") - sorted_children = lock._get_sorted_children() - assert len(sorted_children) == 4 - assert sorted_children[0] == pyLock + assert lock._get_predecessor(pyLock) is None - def test_get_sorted_children_go(self): + def test_get_predecessor_go(self): + """Test selection of predecessor when instructed to consider go-zk + locks. + """ goLock = "_c_8eb60557ba51e0da67eefc47467d3f34-lock-0000000031" pyLock = "514e5a831836450cb1a56c741e990fd8__lock__0000000032" children = ["hello", goLock, "world", pyLock] client = mock.MagicMock() client.get_children.return_value = children - lock = Lock(client, "test", additional_lock_patterns=["-lock-"]) - sorted_children = lock._get_sorted_children() - assert len(sorted_children) == 4 - assert sorted_children[0] == goLock + lock = Lock(client, "test", extra_lock_patterns=["-lock-"]) + assert lock._get_predecessor(pyLock) == goLock |