From a3cfded93afa2a65908f05ac251b18d77fa84dd2 Mon Sep 17 00:00:00 2001 From: Andy McCurdy Date: Wed, 2 Jan 2019 14:13:44 -0800 Subject: Lock objects now support specifying token values and ownership checking Lock.acquire() can now be provided a token. If provided, this value will be used as the value stored in Redis to hold the lock. Lock.owned() returns a boolean indicating whether the lock is owned by the current instance. --- CHANGES | 6 ++++++ redis/lock.py | 26 ++++++++++++++++++++++++-- tests/test_lock.py | 40 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 70 insertions(+), 2 deletions(-) diff --git a/CHANGES b/CHANGES index 968b21e..25af3a3 100644 --- a/CHANGES +++ b/CHANGES @@ -1,4 +1,10 @@ * 3.1.0 (in development) + * Added an owned method to Lock objects. owned returns a boolean + indicating whether the current lock instance still owns the lock. + Thanks Dave Johansen. #1112 + * Allow lock.acquire() to accept an optional token argument. If + provided, the token argument is used as the unique value used to claim + the lock. Thankd Dave Johansen. #1112 * Added a reacquire method to Lock objects. reaquire attempts to renew the lock such that the timeout is extended to the same value that the lock was initially acquired with. Thanks Ihor Kalnytskyi. #1014 diff --git a/redis/lock.py b/redis/lock.py index 2bb7794..8d481d7 100644 --- a/redis/lock.py +++ b/redis/lock.py @@ -149,7 +149,7 @@ class Lock(object): def __exit__(self, exc_type, exc_value, traceback): self.release() - def acquire(self, blocking=None, blocking_timeout=None): + def acquire(self, blocking=None, blocking_timeout=None, token=None): """ Use Redis to hold a shared, distributed lock named ``name``. Returns True once the lock is acquired. @@ -159,9 +159,18 @@ class Lock(object): ``blocking_timeout`` specifies the maximum number of seconds to wait trying to acquire the lock. + + ``token`` specifies the token value to be used. If provided, token + must be a bytes object or a string that can be encoded to a bytes + object with the default encoding. If a token isn't specified, a UUID + will be generated. """ sleep = self.sleep - token = uuid.uuid1().hex.encode() + if token is None: + token = uuid.uuid1().hex.encode() + else: + encoder = self.redis.connection_pool.get_encoder() + token = encoder.encode(token) if blocking is None: blocking = self.blocking if blocking_timeout is None: @@ -195,6 +204,19 @@ class Lock(object): """ return self.redis.get(self.name) is not None + def owned(self): + """ + Returns True if this key is locked by this lock, otherwise False. + """ + stored_token = self.redis.get(self.name) + # need to always compare bytes to bytes + # TODO: this can be simplified when the context manager is finished + if stored_token and not isinstance(stored_token, bytes): + encoder = self.redis.connection_pool.get_encoder() + stored_token = encoder.encode(stored_token) + return self.local.token is not None and \ + stored_token == self.local.token + def release(self): "Releases the already acquired lock" expected_token = self.local.token diff --git a/tests/test_lock.py b/tests/test_lock.py index 945fe2f..f92afec 100644 --- a/tests/test_lock.py +++ b/tests/test_lock.py @@ -2,10 +2,16 @@ import pytest import time from redis.exceptions import LockError, LockNotOwnedError +from redis.client import Redis from redis.lock import Lock +from .conftest import _get_client class TestLock(object): + @pytest.fixture() + def r_decoded(self, request): + return _get_client(Redis, request=request, decode_responses=True) + def get_lock(self, redis, *args, **kwargs): kwargs['lock_class'] = Lock return redis.lock(*args, **kwargs) @@ -18,6 +24,16 @@ class TestLock(object): lock.release() assert r.get('foo') is None + def test_lock_token(self, r): + lock = self.get_lock(r, 'foo') + assert lock.acquire(blocking=False, token='test') + assert r.get('foo') == b'test' + assert lock.local.token == b'test' + assert r.ttl('foo') == -1 + lock.release() + assert r.get('foo') is None + assert lock.local.token is None + def test_locked(self, r): lock = self.get_lock(r, 'foo') assert lock.locked() is False @@ -26,6 +42,30 @@ class TestLock(object): lock.release() assert lock.locked() is False + def _test_owned(self, client): + lock = self.get_lock(client, 'foo') + assert lock.owned() is False + lock.acquire(blocking=False) + assert lock.owned() is True + lock.release() + assert lock.owned() is False + + lock2 = self.get_lock(client, 'foo') + assert lock.owned() is False + assert lock2.owned() is False + lock2.acquire(blocking=False) + assert lock.owned() is False + assert lock2.owned() is True + lock2.release() + assert lock.owned() is False + assert lock2.owned() is False + + def test_owned(self, r): + self._test_owned(r) + + def test_owned_with_decoded_responses(self, r_decoded): + self._test_owned(r_decoded) + def test_competing_locks(self, r): lock1 = self.get_lock(r, 'foo') lock2 = self.get_lock(r, 'foo') -- cgit v1.2.1