summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAndy McCurdy <andy@andymccurdy.com>2018-12-28 14:16:24 -0800
committerGitHub <noreply@github.com>2018-12-28 14:16:24 -0800
commitaf02bb1d5b8bfe00f6dc1b691404af233399bf5d (patch)
tree5d9eb4b533984c108ca2d2b0138a1c52aaf579c8
parente4764ec1f61eb83b1b80a4360235f08d6d391c1e (diff)
parentff120df78ccd85d6e2e2938ee02d1eb831676724 (diff)
downloadredis-py-af02bb1d5b8bfe00f6dc1b691404af233399bf5d.tar.gz
Merge pull request #1014 from ikalnytskyi/lock-reset
Add `.reacquire()` method to Lock
-rw-r--r--redis/lock.py36
-rw-r--r--tests/test_lock.py28
2 files changed, 64 insertions, 0 deletions
diff --git a/redis/lock.py b/redis/lock.py
index 5524d7a..2bb7794 100644
--- a/redis/lock.py
+++ b/redis/lock.py
@@ -16,6 +16,7 @@ class Lock(object):
lua_release = None
lua_extend = None
+ lua_reacquire = None
# KEYS[1] - lock name
# ARGS[1] - token
@@ -49,6 +50,19 @@ class Lock(object):
return 1
"""
+ # KEYS[1] - lock name
+ # ARGS[1] - token
+ # ARGS[2] - milliseconds
+ # return 1 if the locks time was reacquired, otherwise 0
+ LUA_REACQUIRE_SCRIPT = """
+ local token = redis.call('get', KEYS[1])
+ if not token or token ~= ARGV[1] then
+ return 0
+ end
+ redis.call('pexpire', KEYS[1], ARGV[2])
+ return 1
+ """
+
def __init__(self, redis, name, timeout=None, sleep=0.1,
blocking=True, blocking_timeout=None, thread_local=True):
"""
@@ -121,6 +135,9 @@ class Lock(object):
cls.lua_release = client.register_script(cls.LUA_RELEASE_SCRIPT)
if cls.lua_extend is None:
cls.lua_extend = client.register_script(cls.LUA_EXTEND_SCRIPT)
+ if cls.lua_reacquire is None:
+ cls.lua_reacquire = \
+ client.register_script(cls.LUA_REACQUIRE_SCRIPT)
def __enter__(self):
# force blocking, as otherwise the user would have to check whether
@@ -214,3 +231,22 @@ class Lock(object):
raise LockNotOwnedError("Cannot extend a lock that's"
" no longer owned")
return True
+
+ def reacquire(self):
+ """
+ Resets a TTL of an already acquired lock back to a timeout value.
+ """
+ if self.local.token is None:
+ raise LockError("Cannot reacquire an unlocked lock")
+ if self.timeout is None:
+ raise LockError("Cannot reacquire a lock with no timeout")
+ return self.do_reacquire()
+
+ def do_reacquire(self):
+ timeout = int(self.timeout * 1000)
+ if not bool(self.lua_reacquire(keys=[self.name],
+ args=[self.local.token, timeout],
+ client=self.redis)):
+ raise LockNotOwnedError("Cannot reacquire a lock that's"
+ " no longer owned")
+ return True
diff --git a/tests/test_lock.py b/tests/test_lock.py
index ea45379..945fe2f 100644
--- a/tests/test_lock.py
+++ b/tests/test_lock.py
@@ -125,6 +125,34 @@ class TestLock(object):
with pytest.raises(LockNotOwnedError):
lock.extend(10)
+ def test_reacquire_lock(self, r):
+ lock = self.get_lock(r, 'foo', timeout=10)
+ assert lock.acquire(blocking=False)
+ assert r.pexpire('foo', 5000)
+ assert r.pttl('foo') <= 5000
+ assert lock.reacquire()
+ assert 8000 < r.pttl('foo') <= 10000
+ lock.release()
+
+ def test_reacquiring_unlocked_lock_raises_error(self, r):
+ lock = self.get_lock(r, 'foo', timeout=10)
+ with pytest.raises(LockError):
+ lock.reacquire()
+
+ def test_reacquiring_lock_with_no_timeout_raises_error(self, r):
+ lock = self.get_lock(r, 'foo')
+ assert lock.acquire(blocking=False)
+ with pytest.raises(LockError):
+ lock.reacquire()
+ lock.release()
+
+ def test_reacquiring_lock_no_longer_owned_raises_error(self, r):
+ lock = self.get_lock(r, 'foo', timeout=10)
+ assert lock.acquire(blocking=False)
+ r.set('foo', 'a')
+ with pytest.raises(LockNotOwnedError):
+ lock.reacquire()
+
class TestLockClassSelection(object):
def test_lock_class_argument(self, r):