summaryrefslogtreecommitdiff
path: root/redis/asyncio/lock.py
diff options
context:
space:
mode:
authorMilhan <kimmilhan@gmail.com>2022-11-09 21:22:05 +0900
committerGitHub <noreply@github.com>2022-11-09 14:22:05 +0200
commit772079fabd7453edf3788d0c31b9caf21ff5deca (patch)
tree32c6be41e88676ae2f24c4689ad651e99ac16665 /redis/asyncio/lock.py
parent1cdba6302c694e2d0aaca1e336185217fdc1c136 (diff)
downloadredis-py-772079fabd7453edf3788d0c31b9caf21ff5deca.tar.gz
Enable AsyncIO cluster mode lock (#2446)
Co-authored-by: Chayim <chayim@users.noreply.github.com>
Diffstat (limited to 'redis/asyncio/lock.py')
-rw-r--r--redis/asyncio/lock.py16
1 files changed, 12 insertions, 4 deletions
diff --git a/redis/asyncio/lock.py b/redis/asyncio/lock.py
index 8ede59b..7f45c8b 100644
--- a/redis/asyncio/lock.py
+++ b/redis/asyncio/lock.py
@@ -7,7 +7,7 @@ from typing import TYPE_CHECKING, Awaitable, Optional, Union
from redis.exceptions import LockError, LockNotOwnedError
if TYPE_CHECKING:
- from redis.asyncio import Redis
+ from redis.asyncio import Redis, RedisCluster
class Lock:
@@ -77,7 +77,7 @@ class Lock:
def __init__(
self,
- redis: "Redis",
+ redis: Union["Redis", "RedisCluster"],
name: Union[str, bytes, memoryview],
timeout: Optional[float] = None,
sleep: float = 0.1,
@@ -189,7 +189,11 @@ class Lock:
if token is None:
token = uuid.uuid1().hex.encode()
else:
- encoder = self.redis.connection_pool.get_encoder()
+ try:
+ encoder = self.redis.connection_pool.get_encoder()
+ except AttributeError:
+ # Cluster
+ encoder = self.redis.get_encoder()
token = encoder.encode(token)
if blocking is None:
blocking = self.blocking
@@ -233,7 +237,11 @@ class Lock:
# need to always compare bytes to bytes
# TODO: this can be simplified when the context manager is finished
if stored_token and not isinstance(stored_token, bytes):
- encoder = self.redis.connection_pool.get_encoder()
+ try:
+ encoder = self.redis.connection_pool.get_encoder()
+ except AttributeError:
+ # Cluster
+ encoder = self.redis.get_encoder()
stored_token = encoder.encode(stored_token)
return self.local.token is not None and stored_token == self.local.token