summaryrefslogtreecommitdiff
path: root/redis/asyncio/lock.py
diff options
context:
space:
mode:
authorAndrew Chen Wang <60190294+Andrew-Chen-Wang@users.noreply.github.com>2022-02-22 05:29:55 -0500
committerGitHub <noreply@github.com>2022-02-22 12:29:55 +0200
commitd56baeb683fc1935cfa343fa2eeb0fa9bd955283 (patch)
tree47357a74bf1d1428cfbcf0d8b2c781f1f971cf77 /redis/asyncio/lock.py
parente3c989d93e914e6502bd5a72f15ded49a135c5be (diff)
downloadredis-py-d56baeb683fc1935cfa343fa2eeb0fa9bd955283.tar.gz
Add Async Support (#1899)
Co-authored-by: Chayim I. Kirshen <c@kirshen.com> Co-authored-by: dvora-h <dvora.heller@redis.com>
Diffstat (limited to 'redis/asyncio/lock.py')
-rw-r--r--redis/asyncio/lock.py311
1 files changed, 311 insertions, 0 deletions
diff --git a/redis/asyncio/lock.py b/redis/asyncio/lock.py
new file mode 100644
index 0000000..d486132
--- /dev/null
+++ b/redis/asyncio/lock.py
@@ -0,0 +1,311 @@
+import asyncio
+import sys
+import threading
+import uuid
+from types import SimpleNamespace
+from typing import TYPE_CHECKING, Awaitable, NoReturn, Optional, Union
+
+from redis.exceptions import LockError, LockNotOwnedError
+
+if TYPE_CHECKING:
+ from redis.asyncio import Redis
+
+
+class Lock:
+ """
+ A shared, distributed Lock. Using Redis for locking allows the Lock
+ to be shared across processes and/or machines.
+
+ It's left to the user to resolve deadlock issues and make sure
+ multiple clients play nicely together.
+ """
+
+ lua_release = None
+ lua_extend = None
+ lua_reacquire = None
+
+ # KEYS[1] - lock name
+ # ARGV[1] - token
+ # return 1 if the lock was released, otherwise 0
+ LUA_RELEASE_SCRIPT = """
+ local token = redis.call('get', KEYS[1])
+ if not token or token ~= ARGV[1] then
+ return 0
+ end
+ redis.call('del', KEYS[1])
+ return 1
+ """
+
+ # KEYS[1] - lock name
+ # ARGV[1] - token
+ # ARGV[2] - additional milliseconds
+ # ARGV[3] - "0" if the additional time should be added to the lock's
+ # existing ttl or "1" if the existing ttl should be replaced
+ # return 1 if the locks time was extended, otherwise 0
+ LUA_EXTEND_SCRIPT = """
+ local token = redis.call('get', KEYS[1])
+ if not token or token ~= ARGV[1] then
+ return 0
+ end
+ local expiration = redis.call('pttl', KEYS[1])
+ if not expiration then
+ expiration = 0
+ end
+ if expiration < 0 then
+ return 0
+ end
+
+ local newttl = ARGV[2]
+ if ARGV[3] == "0" then
+ newttl = ARGV[2] + expiration
+ end
+ redis.call('pexpire', KEYS[1], newttl)
+ return 1
+ """
+
+ # KEYS[1] - lock name
+ # ARGV[1] - token
+ # ARGV[2] - milliseconds
+ # return 1 if the locks time was reacquired, otherwise 0
+ LUA_REACQUIRE_SCRIPT = """
+ local token = redis.call('get', KEYS[1])
+ if not token or token ~= ARGV[1] then
+ return 0
+ end
+ redis.call('pexpire', KEYS[1], ARGV[2])
+ return 1
+ """
+
+ def __init__(
+ self,
+ redis: "Redis",
+ name: Union[str, bytes, memoryview],
+ timeout: Optional[float] = None,
+ sleep: float = 0.1,
+ blocking: bool = True,
+ blocking_timeout: Optional[float] = None,
+ thread_local: bool = True,
+ ):
+ """
+ Create a new Lock instance named ``name`` using the Redis client
+ supplied by ``redis``.
+
+ ``timeout`` indicates a maximum life for the lock in seconds.
+ By default, it will remain locked until release() is called.
+ ``timeout`` can be specified as a float or integer, both representing
+ the number of seconds to wait.
+
+ ``sleep`` indicates the amount of time to sleep in seconds per loop
+ iteration when the lock is in blocking mode and another client is
+ currently holding the lock.
+
+ ``blocking`` indicates whether calling ``acquire`` should block until
+ the lock has been acquired or to fail immediately, causing ``acquire``
+ to return False and the lock not being acquired. Defaults to True.
+ Note this value can be overridden by passing a ``blocking``
+ argument to ``acquire``.
+
+ ``blocking_timeout`` indicates the maximum amount of time in seconds to
+ spend trying to acquire the lock. A value of ``None`` indicates
+ continue trying forever. ``blocking_timeout`` can be specified as a
+ float or integer, both representing the number of seconds to wait.
+
+ ``thread_local`` indicates whether the lock token is placed in
+ thread-local storage. By default, the token is placed in thread local
+ storage so that a thread only sees its token, not a token set by
+ another thread. Consider the following timeline:
+
+ time: 0, thread-1 acquires `my-lock`, with a timeout of 5 seconds.
+ thread-1 sets the token to "abc"
+ time: 1, thread-2 blocks trying to acquire `my-lock` using the
+ Lock instance.
+ time: 5, thread-1 has not yet completed. redis expires the lock
+ key.
+ time: 5, thread-2 acquired `my-lock` now that it's available.
+ thread-2 sets the token to "xyz"
+ time: 6, thread-1 finishes its work and calls release(). if the
+ token is *not* stored in thread local storage, then
+ thread-1 would see the token value as "xyz" and would be
+ able to successfully release the thread-2's lock.
+
+ In some use cases it's necessary to disable thread local storage. For
+ example, if you have code where one thread acquires a lock and passes
+ that lock instance to a worker thread to release later. If thread
+ local storage isn't disabled in this case, the worker thread won't see
+ the token set by the thread that acquired the lock. Our assumption
+ is that these cases aren't common and as such default to using
+ thread local storage.
+ """
+ self.redis = redis
+ self.name = name
+ self.timeout = timeout
+ self.sleep = sleep
+ self.blocking = blocking
+ self.blocking_timeout = blocking_timeout
+ self.thread_local = bool(thread_local)
+ self.local = threading.local() if self.thread_local else SimpleNamespace()
+ self.local.token = None
+ self.register_scripts()
+
+ def register_scripts(self):
+ cls = self.__class__
+ client = self.redis
+ if cls.lua_release is None:
+ cls.lua_release = client.register_script(cls.LUA_RELEASE_SCRIPT)
+ if cls.lua_extend is None:
+ cls.lua_extend = client.register_script(cls.LUA_EXTEND_SCRIPT)
+ if cls.lua_reacquire is None:
+ cls.lua_reacquire = client.register_script(cls.LUA_REACQUIRE_SCRIPT)
+
+ async def __aenter__(self):
+ if await self.acquire():
+ return self
+ raise LockError("Unable to acquire lock within the time specified")
+
+ async def __aexit__(self, exc_type, exc_value, traceback):
+ await self.release()
+
+ async def acquire(
+ self,
+ blocking: Optional[bool] = None,
+ blocking_timeout: Optional[float] = None,
+ token: Optional[Union[str, bytes]] = None,
+ ):
+ """
+ Use Redis to hold a shared, distributed lock named ``name``.
+ Returns True once the lock is acquired.
+
+ If ``blocking`` is False, always return immediately. If the lock
+ was acquired, return True, otherwise return False.
+
+ ``blocking_timeout`` specifies the maximum number of seconds to
+ wait trying to acquire the lock.
+
+ ``token`` specifies the token value to be used. If provided, token
+ must be a bytes object or a string that can be encoded to a bytes
+ object with the default encoding. If a token isn't specified, a UUID
+ will be generated.
+ """
+ if sys.version_info[0:2] != (3, 6):
+ loop = asyncio.get_running_loop()
+ else:
+ loop = asyncio.get_event_loop()
+
+ sleep = self.sleep
+ if token is None:
+ token = uuid.uuid1().hex.encode()
+ else:
+ encoder = self.redis.connection_pool.get_encoder()
+ token = encoder.encode(token)
+ if blocking is None:
+ blocking = self.blocking
+ if blocking_timeout is None:
+ blocking_timeout = self.blocking_timeout
+ stop_trying_at = None
+ if blocking_timeout is not None:
+ stop_trying_at = loop.time() + blocking_timeout
+ while True:
+ if await self.do_acquire(token):
+ self.local.token = token
+ return True
+ if not blocking:
+ return False
+ next_try_at = loop.time() + sleep
+ if stop_trying_at is not None and next_try_at > stop_trying_at:
+ return False
+ await asyncio.sleep(sleep)
+
+ async def do_acquire(self, token: Union[str, bytes]) -> bool:
+ if self.timeout:
+ # convert to milliseconds
+ timeout = int(self.timeout * 1000)
+ else:
+ timeout = None
+ if await self.redis.set(self.name, token, nx=True, px=timeout):
+ return True
+ return False
+
+ async def locked(self) -> bool:
+ """
+ Returns True if this key is locked by any process, otherwise False.
+ """
+ return await self.redis.get(self.name) is not None
+
+ async def owned(self) -> bool:
+ """
+ Returns True if this key is locked by this lock, otherwise False.
+ """
+ stored_token = await self.redis.get(self.name)
+ # need to always compare bytes to bytes
+ # TODO: this can be simplified when the context manager is finished
+ if stored_token and not isinstance(stored_token, bytes):
+ encoder = self.redis.connection_pool.get_encoder()
+ stored_token = encoder.encode(stored_token)
+ return self.local.token is not None and stored_token == self.local.token
+
+ def release(self) -> Awaitable[NoReturn]:
+ """Releases the already acquired lock"""
+ expected_token = self.local.token
+ if expected_token is None:
+ raise LockError("Cannot release an unlocked lock")
+ self.local.token = None
+ return self.do_release(expected_token)
+
+ async def do_release(self, expected_token: bytes):
+ if not bool(
+ await self.lua_release(
+ keys=[self.name], args=[expected_token], client=self.redis
+ )
+ ):
+ raise LockNotOwnedError("Cannot release a lock" " that's no longer owned")
+
+ def extend(
+ self, additional_time: float, replace_ttl: bool = False
+ ) -> Awaitable[bool]:
+ """
+ Adds more time to an already acquired lock.
+
+ ``additional_time`` can be specified as an integer or a float, both
+ representing the number of seconds to add.
+
+ ``replace_ttl`` if False (the default), add `additional_time` to
+ the lock's existing ttl. If True, replace the lock's ttl with
+ `additional_time`.
+ """
+ if self.local.token is None:
+ raise LockError("Cannot extend an unlocked lock")
+ if self.timeout is None:
+ raise LockError("Cannot extend a lock with no timeout")
+ return self.do_extend(additional_time, replace_ttl)
+
+ async def do_extend(self, additional_time, replace_ttl) -> bool:
+ additional_time = int(additional_time * 1000)
+ if not bool(
+ await self.lua_extend(
+ keys=[self.name],
+ args=[self.local.token, additional_time, replace_ttl and "1" or "0"],
+ client=self.redis,
+ )
+ ):
+ raise LockNotOwnedError("Cannot extend a lock that's" " no longer owned")
+ return True
+
+ def reacquire(self) -> Awaitable[bool]:
+ """
+ Resets a TTL of an already acquired lock back to a timeout value.
+ """
+ if self.local.token is None:
+ raise LockError("Cannot reacquire an unlocked lock")
+ if self.timeout is None:
+ raise LockError("Cannot reacquire a lock with no timeout")
+ return self.do_reacquire()
+
+ async def do_reacquire(self) -> bool:
+ timeout = int(self.timeout * 1000)
+ if not bool(
+ await self.lua_reacquire(
+ keys=[self.name], args=[self.local.token, timeout], client=self.redis
+ )
+ ):
+ raise LockNotOwnedError("Cannot reacquire a lock that's" " no longer owned")
+ return True