diff options
author | Andy McCurdy <andy@andymccurdy.com> | 2018-12-28 14:16:24 -0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-12-28 14:16:24 -0800 |
commit | af02bb1d5b8bfe00f6dc1b691404af233399bf5d (patch) | |
tree | 5d9eb4b533984c108ca2d2b0138a1c52aaf579c8 | |
parent | e4764ec1f61eb83b1b80a4360235f08d6d391c1e (diff) | |
parent | ff120df78ccd85d6e2e2938ee02d1eb831676724 (diff) | |
download | redis-py-af02bb1d5b8bfe00f6dc1b691404af233399bf5d.tar.gz |
Merge pull request #1014 from ikalnytskyi/lock-reset
Add `.reacquire()` method to Lock
-rw-r--r-- | redis/lock.py | 36 | ||||
-rw-r--r-- | tests/test_lock.py | 28 |
2 files changed, 64 insertions, 0 deletions
diff --git a/redis/lock.py b/redis/lock.py index 5524d7a..2bb7794 100644 --- a/redis/lock.py +++ b/redis/lock.py @@ -16,6 +16,7 @@ class Lock(object): lua_release = None lua_extend = None + lua_reacquire = None # KEYS[1] - lock name # ARGS[1] - token @@ -49,6 +50,19 @@ class Lock(object): return 1 """ + # KEYS[1] - lock name + # ARGS[1] - token + # ARGS[2] - milliseconds + # return 1 if the locks time was reacquired, otherwise 0 + LUA_REACQUIRE_SCRIPT = """ + local token = redis.call('get', KEYS[1]) + if not token or token ~= ARGV[1] then + return 0 + end + redis.call('pexpire', KEYS[1], ARGV[2]) + return 1 + """ + def __init__(self, redis, name, timeout=None, sleep=0.1, blocking=True, blocking_timeout=None, thread_local=True): """ @@ -121,6 +135,9 @@ class Lock(object): cls.lua_release = client.register_script(cls.LUA_RELEASE_SCRIPT) if cls.lua_extend is None: cls.lua_extend = client.register_script(cls.LUA_EXTEND_SCRIPT) + if cls.lua_reacquire is None: + cls.lua_reacquire = \ + client.register_script(cls.LUA_REACQUIRE_SCRIPT) def __enter__(self): # force blocking, as otherwise the user would have to check whether @@ -214,3 +231,22 @@ class Lock(object): raise LockNotOwnedError("Cannot extend a lock that's" " no longer owned") return True + + def reacquire(self): + """ + Resets a TTL of an already acquired lock back to a timeout value. + """ + if self.local.token is None: + raise LockError("Cannot reacquire an unlocked lock") + if self.timeout is None: + raise LockError("Cannot reacquire a lock with no timeout") + return self.do_reacquire() + + def do_reacquire(self): + timeout = int(self.timeout * 1000) + if not bool(self.lua_reacquire(keys=[self.name], + args=[self.local.token, timeout], + client=self.redis)): + raise LockNotOwnedError("Cannot reacquire a lock that's" + " no longer owned") + return True diff --git a/tests/test_lock.py b/tests/test_lock.py index ea45379..945fe2f 100644 --- a/tests/test_lock.py +++ b/tests/test_lock.py @@ -125,6 +125,34 @@ class TestLock(object): with pytest.raises(LockNotOwnedError): lock.extend(10) + def test_reacquire_lock(self, r): + lock = self.get_lock(r, 'foo', timeout=10) + assert lock.acquire(blocking=False) + assert r.pexpire('foo', 5000) + assert r.pttl('foo') <= 5000 + assert lock.reacquire() + assert 8000 < r.pttl('foo') <= 10000 + lock.release() + + def test_reacquiring_unlocked_lock_raises_error(self, r): + lock = self.get_lock(r, 'foo', timeout=10) + with pytest.raises(LockError): + lock.reacquire() + + def test_reacquiring_lock_with_no_timeout_raises_error(self, r): + lock = self.get_lock(r, 'foo') + assert lock.acquire(blocking=False) + with pytest.raises(LockError): + lock.reacquire() + lock.release() + + def test_reacquiring_lock_no_longer_owned_raises_error(self, r): + lock = self.get_lock(r, 'foo', timeout=10) + assert lock.acquire(blocking=False) + r.set('foo', 'a') + with pytest.raises(LockNotOwnedError): + lock.reacquire() + class TestLockClassSelection(object): def test_lock_class_argument(self, r): |