diff options
Diffstat (limited to 'redis/commands')
-rw-r--r-- | redis/commands/__init__.py | 6 | ||||
-rw-r--r-- | redis/commands/core.py | 1829 | ||||
-rw-r--r-- | redis/commands/sentinel.py | 24 |
3 files changed, 1256 insertions, 603 deletions
diff --git a/redis/commands/__init__.py b/redis/commands/__init__.py index 07fa7f1..b9dd0b7 100644 --- a/redis/commands/__init__.py +++ b/redis/commands/__init__.py @@ -1,15 +1,17 @@ from .cluster import RedisClusterCommands -from .core import CoreCommands +from .core import AsyncCoreCommands, CoreCommands from .helpers import list_or_args from .parser import CommandsParser from .redismodules import RedisModuleCommands -from .sentinel import SentinelCommands +from .sentinel import AsyncSentinelCommands, SentinelCommands __all__ = [ "RedisClusterCommands", "CommandsParser", + "AsyncCoreCommands", "CoreCommands", "list_or_args", "RedisModuleCommands", + "AsyncSentinelCommands", "SentinelCommands", ] diff --git a/redis/commands/core.py b/redis/commands/core.py index 1e221c9..80bc55f 100644 --- a/redis/commands/core.py +++ b/redis/commands/core.py @@ -1,21 +1,64 @@ +# from __future__ import annotations + import datetime import hashlib import time import warnings -from typing import List, Optional, Union - +from typing import ( + TYPE_CHECKING, + Any, + AsyncIterator, + Awaitable, + Callable, + Dict, + Iterable, + Iterator, + List, + Mapping, + Optional, + Sequence, + Tuple, + Union, +) + +from redis.compat import Literal from redis.exceptions import ConnectionError, DataError, NoScriptError, RedisError +from redis.typing import ( + AbsExpiryT, + AnyKeyT, + BitfieldOffsetT, + ChannelT, + CommandsProtocol, + ConsumerT, + EncodableT, + ExpiryT, + FieldT, + GroupT, + KeysT, + KeyT, + PatternT, + ScriptTextT, + StreamIdT, + TimeoutSecT, + ZScoreBoundT, +) from .helpers import list_or_args +if TYPE_CHECKING: + from redis.asyncio.client import Redis as AsyncRedis + from redis.client import Redis + +ResponseT = Union[Awaitable, Any] -class ACLCommands: + +class ACLCommands(CommandsProtocol): """ Redis Access Control List (ACL) commands. see: https://redis.io/topics/acl """ - def acl_cat(self, category=None, **kwargs): + def acl_cat(self, category: Union[str, None] = None, **kwargs) -> ResponseT: """ Returns a list of categories or commands within a category. @@ -25,7 +68,7 @@ class ACLCommands: For more information check https://redis.io/commands/acl-cat """ - pieces = [category] if category else [] + pieces: list[EncodableT] = [category] if category else [] return self.execute_command("ACL CAT", *pieces, **kwargs) def acl_dryrun(self, username, *args, **kwargs): @@ -36,7 +79,7 @@ class ACLCommands: """ return self.execute_command("ACL DRYRUN", username, *args, **kwargs) - def acl_deluser(self, *username, **kwargs): + def acl_deluser(self, *username: str, **kwargs) -> ResponseT: """ Delete the ACL for the specified ``username``s @@ -44,7 +87,7 @@ class ACLCommands: """ return self.execute_command("ACL DELUSER", *username, **kwargs) - def acl_genpass(self, bits=None, **kwargs): + def acl_genpass(self, bits: Union[int, None] = None, **kwargs) -> ResponseT: """Generate a random password value. If ``bits`` is supplied then use this number of bits, rounded to the next multiple of 4. @@ -62,7 +105,7 @@ class ACLCommands: ) return self.execute_command("ACL GENPASS", *pieces, **kwargs) - def acl_getuser(self, username, **kwargs): + def acl_getuser(self, username: str, **kwargs) -> ResponseT: """ Get the ACL details for the specified ``username``. @@ -72,7 +115,7 @@ class ACLCommands: """ return self.execute_command("ACL GETUSER", username, **kwargs) - def acl_help(self, **kwargs): + def acl_help(self, **kwargs) -> ResponseT: """The ACL HELP command returns helpful text describing the different subcommands. @@ -80,7 +123,7 @@ class ACLCommands: """ return self.execute_command("ACL HELP", **kwargs) - def acl_list(self, **kwargs): + def acl_list(self, **kwargs) -> ResponseT: """ Return a list of all ACLs on the server @@ -88,7 +131,7 @@ class ACLCommands: """ return self.execute_command("ACL LIST", **kwargs) - def acl_log(self, count=None, **kwargs): + def acl_log(self, count: Union[int, None] = None, **kwargs) -> ResponseT: """ Get ACL logs as a list. :param int count: Get logs[0:count]. @@ -104,7 +147,7 @@ class ACLCommands: return self.execute_command("ACL LOG", *args, **kwargs) - def acl_log_reset(self, **kwargs): + def acl_log_reset(self, **kwargs) -> ResponseT: """ Reset ACL logs. :rtype: Boolean. @@ -114,7 +157,7 @@ class ACLCommands: args = [b"RESET"] return self.execute_command("ACL LOG", *args, **kwargs) - def acl_load(self, **kwargs): + def acl_load(self, **kwargs) -> ResponseT: """ Load ACL rules from the configured ``aclfile``. @@ -125,7 +168,7 @@ class ACLCommands: """ return self.execute_command("ACL LOAD", **kwargs) - def acl_save(self, **kwargs): + def acl_save(self, **kwargs) -> ResponseT: """ Save ACL rules to the configured ``aclfile``. @@ -138,19 +181,19 @@ class ACLCommands: def acl_setuser( self, - username, - enabled=False, - nopass=False, - passwords=None, - hashed_passwords=None, - categories=None, - commands=None, - keys=None, - reset=False, - reset_keys=False, - reset_passwords=False, + username: str, + enabled: bool = False, + nopass: bool = False, + passwords: Union[str, Iterable[str], None] = None, + hashed_passwords: Union[str, Iterable[str], None] = None, + categories: Union[Iterable[str], None] = None, + commands: Union[Iterable[str], None] = None, + keys: Union[Iterable[KeyT], None] = None, + reset: bool = False, + reset_keys: bool = False, + reset_passwords: bool = False, **kwargs, - ): + ) -> ResponseT: """ Create or update an ACL user. @@ -213,7 +256,7 @@ class ACLCommands: For more information check https://redis.io/commands/acl-setuser """ encoder = self.get_encoder() - pieces = [username] + pieces: list[str | bytes] = [username] if reset: pieces.append(b"reset") @@ -303,14 +346,14 @@ class ACLCommands: return self.execute_command("ACL SETUSER", *pieces, **kwargs) - def acl_users(self, **kwargs): + def acl_users(self, **kwargs) -> ResponseT: """Returns a list of all registered users on the server. For more information check https://redis.io/commands/acl-users """ return self.execute_command("ACL USERS", **kwargs) - def acl_whoami(self, **kwargs): + def acl_whoami(self, **kwargs) -> ResponseT: """Get the username for the current connection For more information check https://redis.io/commands/acl-whoami @@ -318,7 +361,10 @@ class ACLCommands: return self.execute_command("ACL WHOAMI", **kwargs) -class ManagementCommands: +AsyncACLCommands = ACLCommands + + +class ManagementCommands(CommandsProtocol): """ Redis management commands """ @@ -339,7 +385,7 @@ class ManagementCommands: """ return self.execute_command("BGREWRITEAOF", **kwargs) - def bgsave(self, schedule=True, **kwargs): + def bgsave(self, schedule: bool = True, **kwargs) -> ResponseT: """ Tell the Redis server to save its data to disk. Unlike save(), this method is asynchronous and returns immediately. @@ -351,7 +397,7 @@ class ManagementCommands: pieces.append("SCHEDULE") return self.execute_command("BGSAVE", *pieces, **kwargs) - def role(self): + def role(self) -> ResponseT: """ Provide information on the role of a Redis instance in the context of replication, by returning if the instance @@ -361,7 +407,7 @@ class ManagementCommands: """ return self.execute_command("ROLE") - def client_kill(self, address, **kwargs): + def client_kill(self, address: str, **kwargs) -> ResponseT: """Disconnects the client at ``address`` (ip:port) For more information check https://redis.io/commands/client-kill @@ -370,18 +416,18 @@ class ManagementCommands: def client_kill_filter( self, - _id=None, - _type=None, - addr=None, - skipme=None, - laddr=None, - user=None, + _id: Union[str, None] = None, + _type: Union[str, None] = None, + addr: Union[str, None] = None, + skipme: Union[bool, None] = None, + laddr: Union[bool, None] = None, + user: str = None, **kwargs, - ): + ) -> ResponseT: """ Disconnects client(s) using a variety of filter options - :param id: Kills a client by its unique ID field - :param type: Kills a client by type where type is one of 'normal', + :param _id: Kills a client by its unique ID field + :param _type: Kills a client by type where type is one of 'normal', 'master', 'slave' or 'pubsub' :param addr: Kills a client by its 'address:port' :param skipme: If True, then the client calling the command @@ -418,7 +464,7 @@ class ManagementCommands: ) return self.execute_command("CLIENT KILL", *args, **kwargs) - def client_info(self, **kwargs): + def client_info(self, **kwargs) -> ResponseT: """ Returns information and statistics about the current client connection. @@ -427,7 +473,12 @@ class ManagementCommands: """ return self.execute_command("CLIENT INFO", **kwargs) - def client_list(self, _type=None, client_id=[], **kwargs): + def client_list( + self, + _type: Union[str, None] = None, + client_id: List[EncodableT] = [], + **kwargs, + ) -> ResponseT: """ Returns a list of currently connected clients. If type of client specified, only that type will be returned. @@ -446,12 +497,12 @@ class ManagementCommands: args.append(_type) if not isinstance(client_id, list): raise DataError("client_id must be a list") - if client_id != []: + if client_id: args.append(b"ID") args.append(" ".join(client_id)) return self.execute_command("CLIENT LIST", *args, **kwargs) - def client_getname(self, **kwargs): + def client_getname(self, **kwargs) -> ResponseT: """ Returns the current connection name @@ -459,7 +510,7 @@ class ManagementCommands: """ return self.execute_command("CLIENT GETNAME", **kwargs) - def client_getredir(self, **kwargs): + def client_getredir(self, **kwargs) -> ResponseT: """ Returns the ID (an integer) of the client to whom we are redirecting tracking notifications. @@ -468,7 +519,11 @@ class ManagementCommands: """ return self.execute_command("CLIENT GETREDIR", **kwargs) - def client_reply(self, reply, **kwargs): + def client_reply( + self, + reply: Union[Literal["ON"], Literal["OFF"], Literal["SKIP"]], + **kwargs, + ) -> ResponseT: """ Enable and disable redis server replies. ``reply`` Must be ON OFF or SKIP, @@ -489,7 +544,7 @@ class ManagementCommands: raise DataError(f"CLIENT REPLY must be one of {replies!r}") return self.execute_command("CLIENT REPLY", reply, **kwargs) - def client_id(self, **kwargs): + def client_id(self, **kwargs) -> ResponseT: """ Returns the current connection id @@ -499,13 +554,13 @@ class ManagementCommands: def client_tracking_on( self, - clientid=None, - prefix=[], - bcast=False, - optin=False, - optout=False, - noloop=False, - ): + clientid: Union[int, None] = None, + prefix: Sequence[KeyT] = [], + bcast: bool = False, + optin: bool = False, + optout: bool = False, + noloop: bool = False, + ) -> ResponseT: """ Turn on the tracking mode. For more information about the options look at client_tracking func. @@ -518,13 +573,13 @@ class ManagementCommands: def client_tracking_off( self, - clientid=None, - prefix=[], - bcast=False, - optin=False, - optout=False, - noloop=False, - ): + clientid: Union[int, None] = None, + prefix: Sequence[KeyT] = [], + bcast: bool = False, + optin: bool = False, + optout: bool = False, + noloop: bool = False, + ) -> ResponseT: """ Turn off the tracking mode. For more information about the options look at client_tracking func. @@ -537,15 +592,15 @@ class ManagementCommands: def client_tracking( self, - on=True, - clientid=None, - prefix=[], - bcast=False, - optin=False, - optout=False, - noloop=False, + on: bool = True, + clientid: Union[int, None] = None, + prefix: Sequence[KeyT] = [], + bcast: bool = False, + optin: bool = False, + optout: bool = False, + noloop: bool = False, **kwargs, - ): + ) -> ResponseT: """ Enables the tracking feature of the Redis server, that is used for server assisted client side caching. @@ -595,7 +650,7 @@ class ManagementCommands: return self.execute_command("CLIENT TRACKING", *pieces) - def client_trackinginfo(self, **kwargs): + def client_trackinginfo(self, **kwargs) -> ResponseT: """ Returns the information about the current client connection's use of the server assisted client side cache. @@ -604,7 +659,7 @@ class ManagementCommands: """ return self.execute_command("CLIENT TRACKINGINFO", **kwargs) - def client_setname(self, name, **kwargs): + def client_setname(self, name: str, **kwargs) -> ResponseT: """ Sets the current connection name @@ -612,7 +667,12 @@ class ManagementCommands: """ return self.execute_command("CLIENT SETNAME", name, **kwargs) - def client_unblock(self, client_id, error=False, **kwargs): + def client_unblock( + self, + client_id: int, + error: bool = False, + **kwargs, + ) -> ResponseT: """ Unblocks a connection by its client id. If ``error`` is True, unblocks the client with a special error message. @@ -626,7 +686,7 @@ class ManagementCommands: args.append(b"ERROR") return self.execute_command(*args, **kwargs) - def client_pause(self, timeout, all=True, **kwargs): + def client_pause(self, timeout: int, all: bool = True, **kwargs) -> ResponseT: """ Suspend all the Redis clients for the specified amount of time :param timeout: milliseconds to pause clients @@ -649,7 +709,7 @@ class ManagementCommands: args.append("WRITE") return self.execute_command(*args, **kwargs) - def client_unpause(self, **kwargs): + def client_unpause(self, **kwargs) -> ResponseT: """ Unpause all redis clients @@ -673,15 +733,15 @@ class ManagementCommands: """ return self.execute_command("COMMAND", **kwargs) - def command_info(self, **kwargs): + def command_info(self, **kwargs) -> None: raise NotImplementedError( "COMMAND INFO is intentionally not implemented in the client." ) - def command_count(self, **kwargs): + def command_count(self, **kwargs) -> ResponseT: return self.execute_command("COMMAND COUNT", **kwargs) - def config_get(self, pattern="*", **kwargs): + def config_get(self, pattern: PatternT = "*", **kwargs) -> ResponseT: """ Return a dictionary of configuration based on the ``pattern`` @@ -689,14 +749,14 @@ class ManagementCommands: """ return self.execute_command("CONFIG GET", pattern, **kwargs) - def config_set(self, name, value, **kwargs): + def config_set(self, name: KeyT, value: EncodableT, **kwargs) -> ResponseT: """Set config item ``name`` with ``value`` For more information check https://redis.io/commands/config-set """ return self.execute_command("CONFIG SET", name, value, **kwargs) - def config_resetstat(self, **kwargs): + def config_resetstat(self, **kwargs) -> ResponseT: """ Reset runtime statistics @@ -704,7 +764,7 @@ class ManagementCommands: """ return self.execute_command("CONFIG RESETSTAT", **kwargs) - def config_rewrite(self, **kwargs): + def config_rewrite(self, **kwargs) -> ResponseT: """ Rewrite config file with the minimal change to reflect running config. @@ -712,7 +772,7 @@ class ManagementCommands: """ return self.execute_command("CONFIG REWRITE", **kwargs) - def dbsize(self, **kwargs): + def dbsize(self, **kwargs) -> ResponseT: """ Returns the number of keys in the current database @@ -720,7 +780,7 @@ class ManagementCommands: """ return self.execute_command("DBSIZE", **kwargs) - def debug_object(self, key, **kwargs): + def debug_object(self, key: KeyT, **kwargs) -> ResponseT: """ Returns version specific meta information about a given key @@ -728,7 +788,7 @@ class ManagementCommands: """ return self.execute_command("DEBUG OBJECT", key, **kwargs) - def debug_segfault(self, **kwargs): + def debug_segfault(self, **kwargs) -> None: raise NotImplementedError( """ DEBUG SEGFAULT is intentionally not implemented in the client. @@ -737,7 +797,7 @@ class ManagementCommands: """ ) - def echo(self, value, **kwargs): + def echo(self, value: EncodableT, **kwargs) -> ResponseT: """ Echo the string back from the server @@ -745,7 +805,7 @@ class ManagementCommands: """ return self.execute_command("ECHO", value, **kwargs) - def flushall(self, asynchronous=False, **kwargs): + def flushall(self, asynchronous: bool = False, **kwargs) -> ResponseT: """ Delete all keys in all databases on the current host. @@ -759,7 +819,7 @@ class ManagementCommands: args.append(b"ASYNC") return self.execute_command("FLUSHALL", *args, **kwargs) - def flushdb(self, asynchronous=False, **kwargs): + def flushdb(self, asynchronous: bool = False, **kwargs) -> ResponseT: """ Delete all keys in the current database. @@ -773,7 +833,7 @@ class ManagementCommands: args.append(b"ASYNC") return self.execute_command("FLUSHDB", *args, **kwargs) - def sync(self): + def sync(self) -> ResponseT: """ Initiates a replication stream from the master. @@ -785,7 +845,7 @@ class ManagementCommands: options[NEVER_DECODE] = [] return self.execute_command("SYNC", **options) - def psync(self, replicationid, offset): + def psync(self, replicationid: str, offset: int): """ Initiates a replication stream from the master. Newer version for `sync`. @@ -798,7 +858,7 @@ class ManagementCommands: options[NEVER_DECODE] = [] return self.execute_command("PSYNC", replicationid, offset, **options) - def swapdb(self, first, second, **kwargs): + def swapdb(self, first: int, second: int, **kwargs) -> ResponseT: """ Swap two databases @@ -806,14 +866,14 @@ class ManagementCommands: """ return self.execute_command("SWAPDB", first, second, **kwargs) - def select(self, index, **kwargs): + def select(self, index: int, **kwargs) -> ResponseT: """Select the Redis logical database at index. See: https://redis.io/commands/select """ return self.execute_command("SELECT", index, **kwargs) - def info(self, section=None, **kwargs): + def info(self, section: Union[str, None] = None, **kwargs) -> ResponseT: """ Returns a dictionary containing information about the Redis server @@ -830,7 +890,7 @@ class ManagementCommands: else: return self.execute_command("INFO", section, **kwargs) - def lastsave(self, **kwargs): + def lastsave(self, **kwargs) -> ResponseT: """ Return a Python datetime object representing the last time the Redis database was saved to disk @@ -839,7 +899,7 @@ class ManagementCommands: """ return self.execute_command("LASTSAVE", **kwargs) - def lolwut(self, *version_numbers, **kwargs): + def lolwut(self, *version_numbers: Union[str, float], **kwargs) -> ResponseT: """ Get the Redis version and a piece of generative computer art @@ -850,7 +910,7 @@ class ManagementCommands: else: return self.execute_command("LOLWUT", **kwargs) - def reset(self): + def reset(self) -> ResponseT: """Perform a full reset on the connection's server side contenxt. See: https://redis.io/commands/reset @@ -859,16 +919,16 @@ class ManagementCommands: def migrate( self, - host, - port, - keys, - destination_db, - timeout, - copy=False, - replace=False, - auth=None, + host: str, + port: int, + keys: KeysT, + destination_db: int, + timeout: int, + copy: bool = False, + replace: bool = False, + auth: Union[str, None] = None, **kwargs, - ): + ) -> ResponseT: """ Migrate 1 or more keys from the current Redis server to a different server specified by the ``host``, ``port`` and ``destination_db``. @@ -905,7 +965,7 @@ class ManagementCommands: "MIGRATE", host, port, "", destination_db, timeout, *pieces, **kwargs ) - def object(self, infotype, key, **kwargs): + def object(self, infotype: str, key: KeyT, **kwargs) -> ResponseT: """ Return the encoding, idletime, or refcount about the key """ @@ -913,7 +973,7 @@ class ManagementCommands: "OBJECT", infotype, key, infotype=infotype, **kwargs ) - def memory_doctor(self, **kwargs): + def memory_doctor(self, **kwargs) -> None: raise NotImplementedError( """ MEMORY DOCTOR is intentionally not implemented in the client. @@ -922,7 +982,7 @@ class ManagementCommands: """ ) - def memory_help(self, **kwargs): + def memory_help(self, **kwargs) -> None: raise NotImplementedError( """ MEMORY HELP is intentionally not implemented in the client. @@ -931,7 +991,7 @@ class ManagementCommands: """ ) - def memory_stats(self, **kwargs): + def memory_stats(self, **kwargs) -> ResponseT: """ Return a dictionary of memory stats @@ -939,7 +999,7 @@ class ManagementCommands: """ return self.execute_command("MEMORY STATS", **kwargs) - def memory_malloc_stats(self, **kwargs): + def memory_malloc_stats(self, **kwargs) -> ResponseT: """ Return an internal statistics report from the memory allocator. @@ -947,7 +1007,9 @@ class ManagementCommands: """ return self.execute_command("MEMORY MALLOC-STATS", **kwargs) - def memory_usage(self, key, samples=None, **kwargs): + def memory_usage( + self, key: KeyT, samples: Union[int, None] = None, **kwargs + ) -> ResponseT: """ Return the total memory usage for key, its value and associated administrative overheads. @@ -963,7 +1025,7 @@ class ManagementCommands: args.extend([b"SAMPLES", samples]) return self.execute_command("MEMORY USAGE", key, *args, **kwargs) - def memory_purge(self, **kwargs): + def memory_purge(self, **kwargs) -> ResponseT: """ Attempts to purge dirty pages for reclamation by allocator @@ -971,7 +1033,7 @@ class ManagementCommands: """ return self.execute_command("MEMORY PURGE", **kwargs) - def ping(self, **kwargs): + def ping(self, **kwargs) -> ResponseT: """ Ping the Redis server @@ -979,7 +1041,7 @@ class ManagementCommands: """ return self.execute_command("PING", **kwargs) - def quit(self, **kwargs): + def quit(self, **kwargs) -> ResponseT: """ Ask the server to close the connection. @@ -987,7 +1049,7 @@ class ManagementCommands: """ return self.execute_command("QUIT", **kwargs) - def replicaof(self, *args, **kwargs): + def replicaof(self, *args, **kwargs) -> ResponseT: """ Update the replication settings of a redis replica, on the fly. Examples of valid arguments include: @@ -998,7 +1060,7 @@ class ManagementCommands: """ return self.execute_command("REPLICAOF", *args, **kwargs) - def save(self, **kwargs): + def save(self, **kwargs) -> ResponseT: """ Tell the Redis server to save its data to disk, blocking until the save is complete @@ -1007,7 +1069,7 @@ class ManagementCommands: """ return self.execute_command("SAVE", **kwargs) - def shutdown(self, save=False, nosave=False, **kwargs): + def shutdown(self, save: bool = False, nosave: bool = False, **kwargs) -> None: """Shutdown the Redis server. If Redis has persistence configured, data will be flushed before shutdown. If the "save" option is set, a data flush will be attempted even if there is no persistence @@ -1030,7 +1092,9 @@ class ManagementCommands: return raise RedisError("SHUTDOWN seems to have failed.") - def slaveof(self, host=None, port=None, **kwargs): + def slaveof( + self, host: Union[str, None] = None, port: Union[int, None] = None, **kwargs + ) -> ResponseT: """ Set the server to be a replicated slave of the instance identified by the ``host`` and ``port``. If called without arguments, the @@ -1042,7 +1106,7 @@ class ManagementCommands: return self.execute_command("SLAVEOF", b"NO", b"ONE", **kwargs) return self.execute_command("SLAVEOF", host, port, **kwargs) - def slowlog_get(self, num=None, **kwargs): + def slowlog_get(self, num: Union[int, None] = None, **kwargs) -> ResponseT: """ Get the entries from the slowlog. If ``num`` is specified, get the most recent ``num`` items. @@ -1059,7 +1123,7 @@ class ManagementCommands: kwargs[NEVER_DECODE] = [] return self.execute_command(*args, **kwargs) - def slowlog_len(self, **kwargs): + def slowlog_len(self, **kwargs) -> ResponseT: """ Get the number of items in the slowlog @@ -1067,7 +1131,7 @@ class ManagementCommands: """ return self.execute_command("SLOWLOG LEN", **kwargs) - def slowlog_reset(self, **kwargs): + def slowlog_reset(self, **kwargs) -> ResponseT: """ Remove all items in the slowlog @@ -1075,7 +1139,7 @@ class ManagementCommands: """ return self.execute_command("SLOWLOG RESET", **kwargs) - def time(self, **kwargs): + def time(self, **kwargs) -> ResponseT: """ Returns the server time as a 2-item tuple of ints: (seconds since epoch, microseconds into this second). @@ -1084,7 +1148,7 @@ class ManagementCommands: """ return self.execute_command("TIME", **kwargs) - def wait(self, num_replicas, timeout, **kwargs): + def wait(self, num_replicas: int, timeout: int, **kwargs) -> ResponseT: """ Redis synchronous replication That returns the number of replicas that processed the query when @@ -1114,12 +1178,166 @@ class ManagementCommands: ) -class BasicKeyCommands: +AsyncManagementCommands = ManagementCommands + + +class AsyncManagementCommands(ManagementCommands): + async def command_info(self, **kwargs) -> None: + return super().command_info(**kwargs) + + async def debug_segfault(self, **kwargs) -> None: + return super().debug_segfault(**kwargs) + + async def memory_doctor(self, **kwargs) -> None: + return super().memory_doctor(**kwargs) + + async def memory_help(self, **kwargs) -> None: + return super().memory_help(**kwargs) + + async def shutdown( + self, save: bool = False, nosave: bool = False, **kwargs + ) -> None: + """Shutdown the Redis server. If Redis has persistence configured, + data will be flushed before shutdown. If the "save" option is set, + a data flush will be attempted even if there is no persistence + configured. If the "nosave" option is set, no data flush will be + attempted. The "save" and "nosave" options cannot both be set. + + For more information check https://redis.io/commands/shutdown + """ + if save and nosave: + raise DataError("SHUTDOWN save and nosave cannot both be set") + args = ["SHUTDOWN"] + if save: + args.append("SAVE") + if nosave: + args.append("NOSAVE") + try: + await self.execute_command(*args, **kwargs) + except ConnectionError: + # a ConnectionError here is expected + return + raise RedisError("SHUTDOWN seems to have failed.") + + +class BitFieldOperation: + """ + Command builder for BITFIELD commands. + """ + + def __init__( + self, + client: Union["Redis", "AsyncRedis"], + key: str, + default_overflow: Union[str, None] = None, + ): + self.client = client + self.key = key + self._default_overflow = default_overflow + # for typing purposes, run the following in constructor and in reset() + self.operations: list[tuple[EncodableT, ...]] = [] + self._last_overflow = "WRAP" + self.reset() + + def reset(self): + """ + Reset the state of the instance to when it was constructed + """ + self.operations = [] + self._last_overflow = "WRAP" + self.overflow(self._default_overflow or self._last_overflow) + + def overflow(self, overflow: str): + """ + Update the overflow algorithm of successive INCRBY operations + :param overflow: Overflow algorithm, one of WRAP, SAT, FAIL. See the + Redis docs for descriptions of these algorithmsself. + :returns: a :py:class:`BitFieldOperation` instance. + """ + overflow = overflow.upper() + if overflow != self._last_overflow: + self._last_overflow = overflow + self.operations.append(("OVERFLOW", overflow)) + return self + + def incrby( + self, + fmt: str, + offset: BitfieldOffsetT, + increment: int, + overflow: Union[str, None] = None, + ): + """ + Increment a bitfield by a given amount. + :param fmt: format-string for the bitfield being updated, e.g. 'u8' + for an unsigned 8-bit integer. + :param offset: offset (in number of bits). If prefixed with a + '#', this is an offset multiplier, e.g. given the arguments + fmt='u8', offset='#2', the offset will be 16. + :param int increment: value to increment the bitfield by. + :param str overflow: overflow algorithm. Defaults to WRAP, but other + acceptable values are SAT and FAIL. See the Redis docs for + descriptions of these algorithms. + :returns: a :py:class:`BitFieldOperation` instance. + """ + if overflow is not None: + self.overflow(overflow) + + self.operations.append(("INCRBY", fmt, offset, increment)) + return self + + def get(self, fmt: str, offset: BitfieldOffsetT): + """ + Get the value of a given bitfield. + :param fmt: format-string for the bitfield being read, e.g. 'u8' for + an unsigned 8-bit integer. + :param offset: offset (in number of bits). If prefixed with a + '#', this is an offset multiplier, e.g. given the arguments + fmt='u8', offset='#2', the offset will be 16. + :returns: a :py:class:`BitFieldOperation` instance. + """ + self.operations.append(("GET", fmt, offset)) + return self + + def set(self, fmt: str, offset: BitfieldOffsetT, value: int): + """ + Set the value of a given bitfield. + :param fmt: format-string for the bitfield being read, e.g. 'u8' for + an unsigned 8-bit integer. + :param offset: offset (in number of bits). If prefixed with a + '#', this is an offset multiplier, e.g. given the arguments + fmt='u8', offset='#2', the offset will be 16. + :param int value: value to set at the given position. + :returns: a :py:class:`BitFieldOperation` instance. + """ + self.operations.append(("SET", fmt, offset, value)) + return self + + @property + def command(self): + cmd = ["BITFIELD", self.key] + for ops in self.operations: + cmd.extend(ops) + return cmd + + def execute(self) -> ResponseT: + """ + Execute the operation(s) in a single BITFIELD command. The return value + is a list of values corresponding to each operation. If the client + used to create this instance was a pipeline, the list of values + will be present within the pipeline's execute. + """ + command = self.command + self.reset() + return self.client.execute_command(*command) + + +class BasicKeyCommands(CommandsProtocol): """ Redis basic key-based commands """ - def append(self, key, value): + def append(self, key: KeyT, value: EncodableT) -> ResponseT: """ Appends the string ``value`` to the value at ``key``. If ``key`` doesn't already exist, create it with a value of ``value``. @@ -1129,7 +1347,12 @@ class BasicKeyCommands: """ return self.execute_command("APPEND", key, value) - def bitcount(self, key, start=None, end=None): + def bitcount( + self, + key: KeyT, + start: Union[int, None] = None, + end: Union[int, None] = None, + ) -> ResponseT: """ Returns the count of set bits in the value of ``key``. Optional ``start`` and ``end`` parameters indicate which bytes to consider @@ -1144,7 +1367,11 @@ class BasicKeyCommands: raise DataError("Both start and end must be specified") return self.execute_command("BITCOUNT", *params) - def bitfield(self, key, default_overflow=None): + def bitfield( + self: Union["Redis", "AsyncRedis"], + key: KeyT, + default_overflow: Union[str, None] = None, + ) -> BitFieldOperation: """ Return a BitFieldOperation instance to conveniently construct one or more bitfield operations on ``key``. @@ -1153,7 +1380,12 @@ class BasicKeyCommands: """ return BitFieldOperation(self, key, default_overflow=default_overflow) - def bitop(self, operation, dest, *keys): + def bitop( + self, + operation: str, + dest: KeyT, + *keys: KeyT, + ) -> ResponseT: """ Perform a bitwise operation using ``operation`` between ``keys`` and store the result in ``dest``. @@ -1162,7 +1394,13 @@ class BasicKeyCommands: """ return self.execute_command("BITOP", operation, dest, *keys) - def bitpos(self, key, bit, start=None, end=None): + def bitpos( + self, + key: KeyT, + bit: int, + start: Union[int, None] = None, + end: Union[int, None] = None, + ) -> ResponseT: """ Return the position of the first bit set to 1 or 0 in a string. ``start`` and ``end`` defines search range. The range is interpreted @@ -1183,7 +1421,13 @@ class BasicKeyCommands: raise DataError("start argument is not set, " "when end is specified") return self.execute_command("BITPOS", *params) - def copy(self, source, destination, destination_db=None, replace=False): + def copy( + self, + source: str, + destination: str, + destination_db: Union[str, None] = None, + replace: bool = False, + ) -> ResponseT: """ Copy the value stored in the ``source`` key to the ``destination`` key. @@ -1203,7 +1447,7 @@ class BasicKeyCommands: params.append("REPLACE") return self.execute_command("COPY", *params) - def decrby(self, name, amount=1): + def decrby(self, name: KeyT, amount: int = 1) -> ResponseT: """ Decrements the value of ``key`` by ``amount``. If no key exists, the value will be initialized as 0 - ``amount`` @@ -1214,16 +1458,16 @@ class BasicKeyCommands: decr = decrby - def delete(self, *names): + def delete(self, *names: KeyT) -> ResponseT: """ Delete one or more keys specified by ``names`` """ return self.execute_command("DEL", *names) - def __delitem__(self, name): + def __delitem__(self, name: KeyT): self.delete(name) - def dump(self, name): + def dump(self, name: KeyT) -> ResponseT: """ Return a serialized version of the value stored at the specified key. If key does not exist a nil bulk reply is returned. @@ -1236,7 +1480,7 @@ class BasicKeyCommands: options[NEVER_DECODE] = [] return self.execute_command("DUMP", name, **options) - def exists(self, *names): + def exists(self, *names: KeyT) -> ResponseT: """ Returns the number of ``names`` that exist @@ -1246,7 +1490,7 @@ class BasicKeyCommands: __contains__ = exists - def expire(self, name, time): + def expire(self, name: KeyT, time: ExpiryT) -> ResponseT: """ Set an expire flag on key ``name`` for ``time`` seconds. ``time`` can be represented by an integer or a Python timedelta object. @@ -1257,7 +1501,7 @@ class BasicKeyCommands: time = int(time.total_seconds()) return self.execute_command("EXPIRE", name, time) - def expireat(self, name, when): + def expireat(self, name: KeyT, when: AbsExpiryT) -> ResponseT: """ Set an expire flag on key ``name``. ``when`` can be represented as an integer indicating unix time or a Python datetime object. @@ -1268,7 +1512,7 @@ class BasicKeyCommands: when = int(time.mktime(when.timetuple())) return self.execute_command("EXPIREAT", name, when) - def get(self, name): + def get(self, name: KeyT) -> ResponseT: """ Return the value at key ``name``, or None if the key doesn't exist @@ -1276,7 +1520,7 @@ class BasicKeyCommands: """ return self.execute_command("GET", name) - def getdel(self, name): + def getdel(self, name: KeyT) -> ResponseT: """ Get the value at key ``name`` and delete the key. This command is similar to GET, except for the fact that it also deletes @@ -1287,7 +1531,15 @@ class BasicKeyCommands: """ return self.execute_command("GETDEL", name) - def getex(self, name, ex=None, px=None, exat=None, pxat=None, persist=False): + def getex( + self, + name: KeyT, + ex: Union[ExpiryT, None] = None, + px: Union[ExpiryT, None] = None, + exat: Union[AbsExpiryT, None] = None, + pxat: Union[AbsExpiryT, None] = None, + persist: bool = False, + ) -> ResponseT: """ Get the value of key and optionally set its expiration. GETEX is similar to GET, but is a write command with @@ -1316,7 +1568,7 @@ class BasicKeyCommands: "and ``persist`` are mutually exclusive." ) - pieces = [] + pieces: list[EncodableT] = [] # similar to set command if ex is not None: pieces.append("EX") @@ -1346,7 +1598,7 @@ class BasicKeyCommands: return self.execute_command("GETEX", name, *pieces) - def __getitem__(self, name): + def __getitem__(self, name: KeyT): """ Return the value at key ``name``, raises a KeyError if the key doesn't exist. @@ -1356,7 +1608,7 @@ class BasicKeyCommands: return value raise KeyError(name) - def getbit(self, name, offset): + def getbit(self, name: KeyT, offset: int) -> ResponseT: """ Returns a boolean indicating the value of ``offset`` in ``name`` @@ -1364,7 +1616,7 @@ class BasicKeyCommands: """ return self.execute_command("GETBIT", name, offset) - def getrange(self, key, start, end): + def getrange(self, key: KeyT, start: int, end: int) -> ResponseT: """ Returns the substring of the string value stored at ``key``, determined by the offsets ``start`` and ``end`` (both are inclusive) @@ -1373,7 +1625,7 @@ class BasicKeyCommands: """ return self.execute_command("GETRANGE", key, start, end) - def getset(self, name, value): + def getset(self, name: KeyT, value: EncodableT) -> ResponseT: """ Sets the value at key ``name`` to ``value`` and returns the old value at key ``name`` atomically. @@ -1385,7 +1637,7 @@ class BasicKeyCommands: """ return self.execute_command("GETSET", name, value) - def incrby(self, name, amount=1): + def incrby(self, name: KeyT, amount: int = 1) -> ResponseT: """ Increments the value of ``key`` by ``amount``. If no key exists, the value will be initialized as ``amount`` @@ -1396,7 +1648,7 @@ class BasicKeyCommands: incr = incrby - def incrbyfloat(self, name, amount=1.0): + def incrbyfloat(self, name: KeyT, amount: float = 1.0) -> ResponseT: """ Increments the value at key ``name`` by floating ``amount``. If no key exists, the value will be initialized as ``amount`` @@ -1405,7 +1657,7 @@ class BasicKeyCommands: """ return self.execute_command("INCRBYFLOAT", name, amount) - def keys(self, pattern="*", **kwargs): + def keys(self, pattern: PatternT = "*", **kwargs) -> ResponseT: """ Returns a list of keys matching ``pattern`` @@ -1413,7 +1665,13 @@ class BasicKeyCommands: """ return self.execute_command("KEYS", pattern, **kwargs) - def lmove(self, first_list, second_list, src="LEFT", dest="RIGHT"): + def lmove( + self, + first_list: str, + second_list: str, + src: str = "LEFT", + dest: str = "RIGHT", + ) -> ResponseT: """ Atomically returns and removes the first/last element of a list, pushing it as the first/last element on the destination list. @@ -1424,7 +1682,14 @@ class BasicKeyCommands: params = [first_list, second_list, src, dest] return self.execute_command("LMOVE", *params) - def blmove(self, first_list, second_list, timeout, src="LEFT", dest="RIGHT"): + def blmove( + self, + first_list: str, + second_list: str, + timeout: int, + src: str = "LEFT", + dest: str = "RIGHT", + ) -> ResponseT: """ Blocking version of lmove. @@ -1433,7 +1698,7 @@ class BasicKeyCommands: params = [first_list, second_list, src, dest, timeout] return self.execute_command("BLMOVE", *params) - def mget(self, keys, *args): + def mget(self, keys: KeysT, *args: EncodableT) -> ResponseT: """ Returns a list of values ordered identically to ``keys`` @@ -1447,7 +1712,7 @@ class BasicKeyCommands: options[EMPTY_RESPONSE] = [] return self.execute_command("MGET", *args, **options) - def mset(self, mapping): + def mset(self, mapping: Mapping[AnyKeyT, EncodableT]) -> ResponseT: """ Sets key/values based on a mapping. Mapping is a dictionary of key/value pairs. Both keys and values should be strings or types that @@ -1460,7 +1725,7 @@ class BasicKeyCommands: items.extend(pair) return self.execute_command("MSET", *items) - def msetnx(self, mapping): + def msetnx(self, mapping: Mapping[AnyKeyT, EncodableT]) -> ResponseT: """ Sets key/values based on a mapping if none of the keys are already set. Mapping is a dictionary of key/value pairs. Both keys and values @@ -1474,7 +1739,7 @@ class BasicKeyCommands: items.extend(pair) return self.execute_command("MSETNX", *items) - def move(self, name, db): + def move(self, name: KeyT, db: int) -> ResponseT: """ Moves the key ``name`` to a different Redis database ``db`` @@ -1482,7 +1747,7 @@ class BasicKeyCommands: """ return self.execute_command("MOVE", name, db) - def persist(self, name): + def persist(self, name: KeyT) -> ResponseT: """ Removes an expiration on ``name`` @@ -1490,7 +1755,7 @@ class BasicKeyCommands: """ return self.execute_command("PERSIST", name) - def pexpire(self, name, time): + def pexpire(self, name: KeyT, time: ExpiryT) -> ResponseT: """ Set an expire flag on key ``name`` for ``time`` milliseconds. ``time`` can be represented by an integer or a Python timedelta @@ -1502,7 +1767,7 @@ class BasicKeyCommands: time = int(time.total_seconds() * 1000) return self.execute_command("PEXPIRE", name, time) - def pexpireat(self, name, when): + def pexpireat(self, name: KeyT, when: AbsExpiryT) -> ResponseT: """ Set an expire flag on key ``name``. ``when`` can be represented as an integer representing unix time in milliseconds (unix time * 1000) @@ -1515,7 +1780,12 @@ class BasicKeyCommands: when = int(time.mktime(when.timetuple())) * 1000 + ms return self.execute_command("PEXPIREAT", name, when) - def psetex(self, name, time_ms, value): + def psetex( + self, + name: KeyT, + time_ms: ExpiryT, + value: EncodableT, + ): """ Set the value of key ``name`` to ``value`` that expires in ``time_ms`` milliseconds. ``time_ms`` can be represented by an integer or a Python @@ -1527,7 +1797,7 @@ class BasicKeyCommands: time_ms = int(time_ms.total_seconds() * 1000) return self.execute_command("PSETEX", name, time_ms, value) - def pttl(self, name): + def pttl(self, name: KeyT) -> ResponseT: """ Returns the number of milliseconds until the key ``name`` will expire @@ -1535,7 +1805,12 @@ class BasicKeyCommands: """ return self.execute_command("PTTL", name) - def hrandfield(self, key, count=None, withvalues=False): + def hrandfield( + self, + key: str, + count: int = None, + withvalues: bool = False, + ) -> ResponseT: """ Return a random field from the hash value stored at key. @@ -1557,7 +1832,7 @@ class BasicKeyCommands: return self.execute_command("HRANDFIELD", key, *params) - def randomkey(self, **kwargs): + def randomkey(self, **kwargs) -> ResponseT: """ Returns the name of a random key @@ -1565,7 +1840,7 @@ class BasicKeyCommands: """ return self.execute_command("RANDOMKEY", **kwargs) - def rename(self, src, dst): + def rename(self, src: KeyT, dst: KeyT) -> ResponseT: """ Rename key ``src`` to ``dst`` @@ -1573,7 +1848,7 @@ class BasicKeyCommands: """ return self.execute_command("RENAME", src, dst) - def renamenx(self, src, dst): + def renamenx(self, src: KeyT, dst: KeyT): """ Rename key ``src`` to ``dst`` if ``dst`` doesn't already exist @@ -1583,14 +1858,14 @@ class BasicKeyCommands: def restore( self, - name, - ttl, - value, - replace=False, - absttl=False, - idletime=None, - frequency=None, - ): + name: KeyT, + ttl: float, + value: EncodableT, + replace: bool = False, + absttl: bool = False, + idletime: Union[int, None] = None, + frequency: Union[int, None] = None, + ) -> ResponseT: """ Create a key using the provided serialized value, previously obtained using DUMP. @@ -1633,17 +1908,17 @@ class BasicKeyCommands: def set( self, - name, - value, - ex=None, - px=None, - nx=False, - xx=False, - keepttl=False, - get=False, - exat=None, - pxat=None, - ): + name: KeyT, + value: EncodableT, + ex: Union[ExpiryT, None] = None, + px: Union[ExpiryT, None] = None, + nx: bool = False, + xx: bool = False, + keepttl: bool = False, + get: bool = False, + exat: Union[AbsExpiryT, None] = None, + pxat: Union[AbsExpiryT, None] = None, + ) -> ResponseT: """ Set the value at key ``name`` to ``value`` @@ -1672,7 +1947,7 @@ class BasicKeyCommands: For more information check https://redis.io/commands/set """ - pieces = [name, value] + pieces: list[EncodableT] = [name, value] options = {} if ex is not None: pieces.append("EX") @@ -1716,10 +1991,10 @@ class BasicKeyCommands: return self.execute_command("SET", *pieces, **options) - def __setitem__(self, name, value): + def __setitem__(self, name: KeyT, value: EncodableT): self.set(name, value) - def setbit(self, name, offset, value): + def setbit(self, name: KeyT, offset: int, value: int) -> ResponseT: """ Flag the ``offset`` in ``name`` as ``value``. Returns a boolean indicating the previous value of ``offset``. @@ -1729,7 +2004,7 @@ class BasicKeyCommands: value = value and 1 or 0 return self.execute_command("SETBIT", name, offset, value) - def setex(self, name, time, value): + def setex(self, name: KeyT, time: ExpiryT, value: EncodableT) -> ResponseT: """ Set the value of key ``name`` to ``value`` that expires in ``time`` seconds. ``time`` can be represented by an integer or a Python @@ -1741,7 +2016,7 @@ class BasicKeyCommands: time = int(time.total_seconds()) return self.execute_command("SETEX", name, time, value) - def setnx(self, name, value): + def setnx(self, name: KeyT, value: EncodableT) -> ResponseT: """ Set the value of key ``name`` to ``value`` if key doesn't exist @@ -1749,7 +2024,12 @@ class BasicKeyCommands: """ return self.execute_command("SETNX", name, value) - def setrange(self, name, offset, value): + def setrange( + self, + name: KeyT, + offset: int, + value: EncodableT, + ) -> ResponseT: """ Overwrite bytes in the value of ``name`` starting at ``offset`` with ``value``. If ``offset`` plus the length of ``value`` exceeds the @@ -1766,16 +2046,16 @@ class BasicKeyCommands: def stralgo( self, - algo, - value1, - value2, - specific_argument="strings", - len=False, - idx=False, - minmatchlen=None, - withmatchlen=False, + algo: Literal["LCS"], + value1: KeyT, + value2: KeyT, + specific_argument: Union[Literal["strings"], Literal["keys"]] = "strings", + len: bool = False, + idx: bool = False, + minmatchlen: Union[int, None] = None, + withmatchlen: bool = False, **kwargs, - ): + ) -> ResponseT: """ Implements complex algorithms that operate on strings. Right now the only algorithm implemented is the LCS algorithm @@ -1805,7 +2085,7 @@ class BasicKeyCommands: if len and idx: raise DataError("len and idx cannot be provided together.") - pieces = [algo, specific_argument.upper(), value1, value2] + pieces: list[EncodableT] = [algo, specific_argument.upper(), value1, value2] if len: pieces.append(b"LEN") if idx: @@ -1828,7 +2108,7 @@ class BasicKeyCommands: **kwargs, ) - def strlen(self, name): + def strlen(self, name: KeyT) -> ResponseT: """ Return the number of bytes stored in the value of ``name`` @@ -1836,14 +2116,14 @@ class BasicKeyCommands: """ return self.execute_command("STRLEN", name) - def substr(self, name, start, end=-1): + def substr(self, name: KeyT, start: int, end: int = -1) -> ResponseT: """ Return a substring of the string at key ``name``. ``start`` and ``end`` are 0-based integers specifying the portion of the string to return. """ return self.execute_command("SUBSTR", name, start, end) - def touch(self, *args): + def touch(self, *args: KeyT) -> ResponseT: """ Alters the last access time of a key(s) ``*args``. A key is ignored if it does not exist. @@ -1852,7 +2132,7 @@ class BasicKeyCommands: """ return self.execute_command("TOUCH", *args) - def ttl(self, name): + def ttl(self, name: KeyT) -> ResponseT: """ Returns the number of seconds until the key ``name`` will expire @@ -1860,7 +2140,7 @@ class BasicKeyCommands: """ return self.execute_command("TTL", name) - def type(self, name): + def type(self, name: KeyT) -> ResponseT: """ Returns the type of key ``name`` @@ -1868,7 +2148,7 @@ class BasicKeyCommands: """ return self.execute_command("TYPE", name) - def watch(self, *names): + def watch(self, *names: KeyT) -> None: """ Watches the values at keys ``names``, or None if the key doesn't exist @@ -1876,7 +2156,7 @@ class BasicKeyCommands: """ warnings.warn(DeprecationWarning("Call WATCH from a Pipeline object")) - def unwatch(self): + def unwatch(self) -> None: """ Unwatches the value at key ``name``, or None of the key doesn't exist @@ -1884,7 +2164,7 @@ class BasicKeyCommands: """ warnings.warn(DeprecationWarning("Call UNWATCH from a Pipeline object")) - def unlink(self, *names): + def unlink(self, *names: KeyT) -> ResponseT: """ Unlink one or more keys specified by ``names`` @@ -1922,7 +2202,27 @@ class BasicKeyCommands: return self.execute_command("LCS", *pieces) -class ListCommands: +class AsyncBasicKeyCommands(BasicKeyCommands): + def __delitem__(self, name: KeyT): + raise TypeError("Async Redis client does not support class deletion") + + def __contains__(self, name: KeyT): + raise TypeError("Async Redis client does not support class inclusion") + + def __getitem__(self, name: KeyT): + raise TypeError("Async Redis client does not support class retrieval") + + def __setitem__(self, name: KeyT, value: EncodableT): + raise TypeError("Async Redis client does not support class assignment") + + async def watch(self, *names: KeyT) -> None: + return super().watch(*names) + + async def unwatch(self) -> None: + return super().unwatch() + + +class ListCommands(CommandsProtocol): """ Redis commands for List data type. see: https://redis.io/topics/data-types#lists @@ -2204,7 +2504,7 @@ class ListCommands: For more information check https://redis.io/commands/lpos """ - pieces = [name, value] + pieces: list[EncodableT] = [name, value] if rank is not None: pieces.extend(["RANK", rank]) @@ -2256,7 +2556,7 @@ class ListCommands: if (start is not None and num is None) or (num is not None and start is None): raise DataError("``start`` and ``num`` must both be specified") - pieces = [name] + pieces: list[EncodableT] = [name] if by is not None: pieces.extend([b"BY", by]) if start is not None and num is not None: @@ -2289,13 +2589,23 @@ class ListCommands: return self.execute_command("SORT", *pieces, **options) -class ScanCommands: +AsyncListCommands = ListCommands + + +class ScanCommands(CommandsProtocol): """ Redis SCAN commands. see: https://redis.io/commands/scan """ - def scan(self, cursor=0, match=None, count=None, _type=None, **kwargs): + def scan( + self, + cursor: int = 0, + match: Union[PatternT, None] = None, + count: Union[int, None] = None, + _type: Union[str, None] = None, + **kwargs, + ) -> ResponseT: """ Incrementally return lists of key names. Also return a cursor indicating the scan position. @@ -2312,7 +2622,7 @@ class ScanCommands: For more information check https://redis.io/commands/scan """ - pieces = [cursor] + pieces: list[EncodableT] = [cursor] if match is not None: pieces.extend([b"MATCH", match]) if count is not None: @@ -2321,7 +2631,13 @@ class ScanCommands: pieces.extend([b"TYPE", _type]) return self.execute_command("SCAN", *pieces, **kwargs) - def scan_iter(self, match=None, count=None, _type=None, **kwargs): + def scan_iter( + self, + match: Union[PatternT, None] = None, + count: Union[int, None] = None, + _type: Union[str, None] = None, + **kwargs, + ) -> Iterator: """ Make an iterator using the SCAN command so that the client doesn't need to remember the cursor position. @@ -2343,7 +2659,13 @@ class ScanCommands: ) yield from data - def sscan(self, name, cursor=0, match=None, count=None): + def sscan( + self, + name: KeyT, + cursor: int = 0, + match: Union[PatternT, None] = None, + count: Union[int, None] = None, + ) -> ResponseT: """ Incrementally return lists of elements in a set. Also return a cursor indicating the scan position. @@ -2354,14 +2676,19 @@ class ScanCommands: For more information check https://redis.io/commands/sscan """ - pieces = [name, cursor] + pieces: list[EncodableT] = [name, cursor] if match is not None: pieces.extend([b"MATCH", match]) if count is not None: pieces.extend([b"COUNT", count]) return self.execute_command("SSCAN", *pieces) - def sscan_iter(self, name, match=None, count=None): + def sscan_iter( + self, + name: KeyT, + match: Union[PatternT, None] = None, + count: Union[int, None] = None, + ) -> Iterator: """ Make an iterator using the SSCAN command so that the client doesn't need to remember the cursor position. @@ -2375,7 +2702,13 @@ class ScanCommands: cursor, data = self.sscan(name, cursor=cursor, match=match, count=count) yield from data - def hscan(self, name, cursor=0, match=None, count=None): + def hscan( + self, + name: KeyT, + cursor: int = 0, + match: Union[PatternT, None] = None, + count: Union[int, None] = None, + ) -> ResponseT: """ Incrementally return key/value slices in a hash. Also return a cursor indicating the scan position. @@ -2386,14 +2719,19 @@ class ScanCommands: For more information check https://redis.io/commands/hscan """ - pieces = [name, cursor] + pieces: list[EncodableT] = [name, cursor] if match is not None: pieces.extend([b"MATCH", match]) if count is not None: pieces.extend([b"COUNT", count]) return self.execute_command("HSCAN", *pieces) - def hscan_iter(self, name, match=None, count=None): + def hscan_iter( + self, + name: str, + match: Union[PatternT, None] = None, + count: Union[int, None] = None, + ) -> Iterator: """ Make an iterator using the HSCAN command so that the client doesn't need to remember the cursor position. @@ -2407,7 +2745,14 @@ class ScanCommands: cursor, data = self.hscan(name, cursor=cursor, match=match, count=count) yield from data.items() - def zscan(self, name, cursor=0, match=None, count=None, score_cast_func=float): + def zscan( + self, + name: KeyT, + cursor: int = 0, + match: Union[PatternT, None] = None, + count: Union[int, None] = None, + score_cast_func: Union[type, Callable] = float, + ) -> ResponseT: """ Incrementally return lists of elements in a sorted set. Also return a cursor indicating the scan position. @@ -2428,7 +2773,13 @@ class ScanCommands: options = {"score_cast_func": score_cast_func} return self.execute_command("ZSCAN", *pieces, **options) - def zscan_iter(self, name, match=None, count=None, score_cast_func=float): + def zscan_iter( + self, + name: KeyT, + match: Union[PatternT, None] = None, + count: Union[int, None] = None, + score_cast_func: Union[type, Callable] = float, + ) -> Iterator: """ Make an iterator using the ZSCAN command so that the client doesn't need to remember the cursor position. @@ -2451,7 +2802,111 @@ class ScanCommands: yield from data -class SetCommands: +class AsyncScanCommands(ScanCommands): + async def scan_iter( + self, + match: Union[PatternT, None] = None, + count: Union[int, None] = None, + _type: Union[str, None] = None, + **kwargs, + ) -> AsyncIterator: + """ + Make an iterator using the SCAN command so that the client doesn't + need to remember the cursor position. + + ``match`` allows for filtering the keys by pattern + + ``count`` provides a hint to Redis about the number of keys to + return per batch. + + ``_type`` filters the returned values by a particular Redis type. + Stock Redis instances allow for the following types: + HASH, LIST, SET, STREAM, STRING, ZSET + Additionally, Redis modules can expose other types as well. + """ + cursor = "0" + while cursor != 0: + cursor, data = await self.scan( + cursor=cursor, match=match, count=count, _type=_type, **kwargs + ) + for d in data: + yield d + + async def sscan_iter( + self, + name: KeyT, + match: Union[PatternT, None] = None, + count: Union[int, None] = None, + ) -> AsyncIterator: + """ + Make an iterator using the SSCAN command so that the client doesn't + need to remember the cursor position. + + ``match`` allows for filtering the keys by pattern + + ``count`` allows for hint the minimum number of returns + """ + cursor = "0" + while cursor != 0: + cursor, data = await self.sscan( + name, cursor=cursor, match=match, count=count + ) + for d in data: + yield d + + async def hscan_iter( + self, + name: str, + match: Union[PatternT, None] = None, + count: Union[int, None] = None, + ) -> AsyncIterator: + """ + Make an iterator using the HSCAN command so that the client doesn't + need to remember the cursor position. + + ``match`` allows for filtering the keys by pattern + + ``count`` allows for hint the minimum number of returns + """ + cursor = "0" + while cursor != 0: + cursor, data = await self.hscan( + name, cursor=cursor, match=match, count=count + ) + for it in data.items(): + yield it + + async def zscan_iter( + self, + name: KeyT, + match: Union[PatternT, None] = None, + count: Union[int, None] = None, + score_cast_func: Union[type, Callable] = float, + ) -> AsyncIterator: + """ + Make an iterator using the ZSCAN command so that the client doesn't + need to remember the cursor position. + + ``match`` allows for filtering the keys by pattern + + ``count`` allows for hint the minimum number of returns + + ``score_cast_func`` a callable used to cast the score return value + """ + cursor = "0" + while cursor != 0: + cursor, data = await self.zscan( + name, + cursor=cursor, + match=match, + count=count, + score_cast_func=score_cast_func, + ) + for d in data: + yield d + + +class SetCommands(CommandsProtocol): """ Redis commands for Set data type. see: https://redis.io/topics/data-types#sets @@ -2612,13 +3067,21 @@ class SetCommands: return self.execute_command("SUNIONSTORE", dest, *args) -class StreamCommands: +AsyncSetCommands = SetCommands + + +class StreamCommands(CommandsProtocol): """ Redis commands for Stream data type. see: https://redis.io/topics/streams-intro """ - def xack(self, name, groupname, *ids): + def xack( + self, + name: KeyT, + groupname: GroupT, + *ids: StreamIdT, + ) -> ResponseT: """ Acknowledges the successful processing of one or more messages. name: name of the stream. @@ -2631,15 +3094,15 @@ class StreamCommands: def xadd( self, - name, - fields, - id="*", - maxlen=None, - approximate=True, - nomkstream=False, - minid=None, - limit=None, - ): + name: KeyT, + fields: Dict[FieldT, EncodableT], + id: StreamIdT = "*", + maxlen: Union[int, None] = None, + approximate: bool = True, + nomkstream: bool = False, + minid: Union[StreamIdT, None] = None, + limit: Union[int, None] = None, + ) -> ResponseT: """ Add to a stream. name: name of the stream @@ -2655,7 +3118,7 @@ class StreamCommands: For more information check https://redis.io/commands/xadd """ - pieces = [] + pieces: list[EncodableT] = [] if maxlen is not None and minid is not None: raise DataError( "Only one of ```maxlen``` or ```minid``` " "may be specified" @@ -2686,14 +3149,14 @@ class StreamCommands: def xautoclaim( self, - name, - groupname, - consumername, - min_idle_time, - start_id=0, - count=None, - justid=False, - ): + name: KeyT, + groupname: GroupT, + consumername: ConsumerT, + min_idle_time: int, + start_id: int = 0, + count: Union[int, None] = None, + justid: bool = False, + ) -> ResponseT: """ Transfers ownership of pending stream entries that match the specified criteria. Conceptually, equivalent to calling XPENDING and then XCLAIM, @@ -2737,17 +3200,17 @@ class StreamCommands: def xclaim( self, - name, - groupname, - consumername, - min_idle_time, - message_ids, - idle=None, - time=None, - retrycount=None, - force=False, - justid=False, - ): + name: KeyT, + groupname: GroupT, + consumername: ConsumerT, + min_idle_time: int, + message_ids: [List[StreamIdT], Tuple[StreamIdT]], + idle: Union[int, None] = None, + time: Union[int, None] = None, + retrycount: Union[int, None] = None, + force: bool = False, + justid: bool = False, + ) -> ResponseT: """ Changes the ownership of a pending message. name: name of the stream. @@ -2781,7 +3244,7 @@ class StreamCommands: ) kwargs = {} - pieces = [name, groupname, consumername, str(min_idle_time)] + pieces: list[EncodableT] = [name, groupname, consumername, str(min_idle_time)] pieces.extend(list(message_ids)) if idle is not None: @@ -2808,7 +3271,7 @@ class StreamCommands: kwargs["parse_justid"] = True return self.execute_command("XCLAIM", *pieces, **kwargs) - def xdel(self, name, *ids): + def xdel(self, name: KeyT, *ids: StreamIdT) -> ResponseT: """ Deletes one or more messages from a stream. name: name of the stream. @@ -2818,7 +3281,13 @@ class StreamCommands: """ return self.execute_command("XDEL", name, *ids) - def xgroup_create(self, name, groupname, id="$", mkstream=False): + def xgroup_create( + self, + name: KeyT, + groupname: GroupT, + id: StreamIdT = "$", + mkstream: bool = False, + ) -> ResponseT: """ Create a new consumer group associated with a stream. name: name of the stream. @@ -2827,12 +3296,17 @@ class StreamCommands: For more information check https://redis.io/commands/xgroup-create """ - pieces = ["XGROUP CREATE", name, groupname, id] + pieces: list[EncodableT] = ["XGROUP CREATE", name, groupname, id] if mkstream: pieces.append(b"MKSTREAM") return self.execute_command(*pieces) - def xgroup_delconsumer(self, name, groupname, consumername): + def xgroup_delconsumer( + self, + name: KeyT, + groupname: GroupT, + consumername: ConsumerT, + ) -> ResponseT: """ Remove a specific consumer from a consumer group. Returns the number of pending messages that the consumer had before it @@ -2845,7 +3319,7 @@ class StreamCommands: """ return self.execute_command("XGROUP DELCONSUMER", name, groupname, consumername) - def xgroup_destroy(self, name, groupname): + def xgroup_destroy(self, name: KeyT, groupname: GroupT) -> ResponseT: """ Destroy a consumer group. name: name of the stream. @@ -2855,7 +3329,12 @@ class StreamCommands: """ return self.execute_command("XGROUP DESTROY", name, groupname) - def xgroup_createconsumer(self, name, groupname, consumername): + def xgroup_createconsumer( + self, + name: KeyT, + groupname: GroupT, + consumername: ConsumerT, + ) -> ResponseT: """ Consumers in a consumer group are auto-created every time a new consumer name is mentioned by some command. @@ -2870,7 +3349,12 @@ class StreamCommands: "XGROUP CREATECONSUMER", name, groupname, consumername ) - def xgroup_setid(self, name, groupname, id): + def xgroup_setid( + self, + name: KeyT, + groupname: GroupT, + id: StreamIdT, + ) -> ResponseT: """ Set the consumer group last delivered ID to something else. name: name of the stream. @@ -2881,7 +3365,7 @@ class StreamCommands: """ return self.execute_command("XGROUP SETID", name, groupname, id) - def xinfo_consumers(self, name, groupname): + def xinfo_consumers(self, name: KeyT, groupname: GroupT) -> ResponseT: """ Returns general information about the consumers in the group. name: name of the stream. @@ -2891,7 +3375,7 @@ class StreamCommands: """ return self.execute_command("XINFO CONSUMERS", name, groupname) - def xinfo_groups(self, name): + def xinfo_groups(self, name: KeyT) -> ResponseT: """ Returns general information about the consumer groups of the stream. name: name of the stream. @@ -2900,7 +3384,7 @@ class StreamCommands: """ return self.execute_command("XINFO GROUPS", name) - def xinfo_stream(self, name, full=False): + def xinfo_stream(self, name: KeyT, full: bool = False) -> ResponseT: """ Returns general information about the stream. name: name of the stream. @@ -2915,7 +3399,7 @@ class StreamCommands: options = {"full": full} return self.execute_command("XINFO STREAM", *pieces, **options) - def xlen(self, name): + def xlen(self, name: KeyT) -> ResponseT: """ Returns the number of elements in a given stream. @@ -2923,7 +3407,7 @@ class StreamCommands: """ return self.execute_command("XLEN", name) - def xpending(self, name, groupname): + def xpending(self, name: KeyT, groupname: GroupT) -> ResponseT: """ Returns information about pending messages of a group. name: name of the stream. @@ -2935,14 +3419,14 @@ class StreamCommands: def xpending_range( self, - name, - groupname, - idle=None, - min=None, - max=None, - count=None, - consumername=None, - ): + name: KeyT, + groupname: GroupT, + min: StreamIdT, + max: StreamIdT, + count: int, + consumername: Union[ConsumerT, None] = None, + idle: Union[int, None] = None, + ) -> ResponseT: """ Returns information about pending messages, in a range. @@ -2990,7 +3474,13 @@ class StreamCommands: return self.execute_command("XPENDING", *pieces, parse_detail=True) - def xrange(self, name, min="-", max="+", count=None): + def xrange( + self, + name: KeyT, + min: StreamIdT = "-", + max: StreamIdT = "+", + count: Union[int, None] = None, + ) -> ResponseT: """ Read stream values within an interval. name: name of the stream. @@ -3012,7 +3502,12 @@ class StreamCommands: return self.execute_command("XRANGE", name, *pieces) - def xread(self, streams, count=None, block=None): + def xread( + self, + streams: Dict[KeyT, StreamIdT], + count: Union[int, None] = None, + block: Union[int, None] = None, + ) -> ResponseT: """ Block and monitor multiple streams for new data. streams: a dict of stream names to stream IDs, where @@ -3043,8 +3538,14 @@ class StreamCommands: return self.execute_command("XREAD", *pieces) def xreadgroup( - self, groupname, consumername, streams, count=None, block=None, noack=False - ): + self, + groupname: str, + consumername: str, + streams: Dict[KeyT, StreamIdT], + count: Union[int, None] = None, + block: Union[int, None] = None, + noack: bool = False, + ) -> ResponseT: """ Read from a stream via a consumer group. groupname: name of the consumer group. @@ -3058,7 +3559,7 @@ class StreamCommands: For more information check https://redis.io/commands/xreadgroup """ - pieces = [b"GROUP", groupname, consumername] + pieces: list[EncodableT] = [b"GROUP", groupname, consumername] if count is not None: if not isinstance(count, int) or count < 1: raise DataError("XREADGROUP count must be a positive integer") @@ -3078,7 +3579,13 @@ class StreamCommands: pieces.extend(streams.values()) return self.execute_command("XREADGROUP", *pieces) - def xrevrange(self, name, max="+", min="-", count=None): + def xrevrange( + self, + name: KeyT, + max: StreamIdT = "+", + min: StreamIdT = "-", + count: Union[int, None] = None, + ) -> ResponseT: """ Read stream values within an interval, in reverse order. name: name of the stream @@ -3091,7 +3598,7 @@ class StreamCommands: For more information check https://redis.io/commands/xrevrange """ - pieces = [max, min] + pieces: list[EncodableT] = [max, min] if count is not None: if not isinstance(count, int) or count < 1: raise DataError("XREVRANGE count must be a positive integer") @@ -3100,7 +3607,14 @@ class StreamCommands: return self.execute_command("XREVRANGE", name, *pieces) - def xtrim(self, name, maxlen=None, approximate=True, minid=None, limit=None): + def xtrim( + self, + name: KeyT, + maxlen: int, + approximate: bool = True, + minid: Union[StreamIdT, None] = None, + limit: Union[int, None] = None, + ) -> ResponseT: """ Trims old messages from a stream. name: name of the stream. @@ -3113,7 +3627,7 @@ class StreamCommands: For more information check https://redis.io/commands/xtrim """ - pieces = [] + pieces: list[EncodableT] = [] if maxlen is not None and minid is not None: raise DataError("Only one of ``maxlen`` or ``minid`` " "may be specified") @@ -3134,15 +3648,26 @@ class StreamCommands: return self.execute_command("XTRIM", name, *pieces) -class SortedSetCommands: +AsyncStreamCommands = StreamCommands + + +class SortedSetCommands(CommandsProtocol): """ Redis commands for Sorted Sets data type. see: https://redis.io/topics/data-types-intro#redis-sorted-sets """ def zadd( - self, name, mapping, nx=False, xx=False, ch=False, incr=False, gt=None, lt=None - ): + self, + name: KeyT, + mapping: Mapping[AnyKeyT, EncodableT], + nx: bool = False, + xx: bool = False, + ch: bool = False, + incr: bool = False, + gt: bool = None, + lt: bool = None, + ) -> ResponseT: """ Set any number of element-name, score pairs to the key ``name``. Pairs are specified as a dict of element-names keys to score values. @@ -3188,7 +3713,7 @@ class SortedSetCommands: if nx is True and (gt is not None or lt is not None): raise DataError("Only one of 'nx', 'lt', or 'gr' may be defined.") - pieces = [] + pieces: list[EncodableT] = [] options = {} if nx: pieces.append(b"NX") @@ -3208,7 +3733,7 @@ class SortedSetCommands: pieces.append(pair[0]) return self.execute_command("ZADD", name, *pieces, **options) - def zcard(self, name): + def zcard(self, name: KeyT): """ Return the number of elements in the sorted set ``name`` @@ -3216,7 +3741,7 @@ class SortedSetCommands: """ return self.execute_command("ZCARD", name) - def zcount(self, name, min, max): + def zcount(self, name: KeyT, min: ZScoreBoundT, max: ZScoreBoundT) -> ResponseT: """ Returns the number of elements in the sorted set at key ``name`` with a score between ``min`` and ``max``. @@ -3225,7 +3750,7 @@ class SortedSetCommands: """ return self.execute_command("ZCOUNT", name, min, max) - def zdiff(self, keys, withscores=False): + def zdiff(self, keys: KeysT, withscores: bool = False) -> ResponseT: """ Returns the difference between the first and all successive input sorted sets provided in ``keys``. @@ -3237,7 +3762,7 @@ class SortedSetCommands: pieces.append("WITHSCORES") return self.execute_command("ZDIFF", *pieces) - def zdiffstore(self, dest, keys): + def zdiffstore(self, dest: KeyT, keys: KeysT) -> ResponseT: """ Computes the difference between the first and all successive input sorted sets provided in ``keys`` and stores the result in ``dest``. @@ -3247,7 +3772,12 @@ class SortedSetCommands: pieces = [len(keys), *keys] return self.execute_command("ZDIFFSTORE", dest, *pieces) - def zincrby(self, name, amount, value): + def zincrby( + self, + name: KeyT, + amount: float, + value: EncodableT, + ) -> ResponseT: """ Increment the score of ``value`` in sorted set ``name`` by ``amount`` @@ -3255,7 +3785,12 @@ class SortedSetCommands: """ return self.execute_command("ZINCRBY", name, amount, value) - def zinter(self, keys, aggregate=None, withscores=False): + def zinter( + self, + keys: KeysT, + aggregate: Union[str, None] = None, + withscores: bool = False, + ) -> ResponseT: """ Return the intersect of multiple sorted sets specified by ``keys``. With the ``aggregate`` option, it is possible to specify how the @@ -3269,7 +3804,12 @@ class SortedSetCommands: """ return self._zaggregate("ZINTER", None, keys, aggregate, withscores=withscores) - def zinterstore(self, dest, keys, aggregate=None): + def zinterstore( + self, + dest: KeyT, + keys: Union[Sequence[KeyT], Mapping[AnyKeyT, float]], + aggregate: Union[str, None] = None, + ) -> ResponseT: """ Intersect multiple sorted sets specified by ``keys`` into a new sorted set, ``dest``. Scores in the destination will be aggregated @@ -3305,7 +3845,11 @@ class SortedSetCommands: """ return self.execute_command("ZLEXCOUNT", name, min, max) - def zpopmax(self, name, count=None): + def zpopmax( + self, + name: KeyT, + count: Union[int, None] = None, + ) -> ResponseT: """ Remove and return up to ``count`` members with the highest scores from the sorted set ``name``. @@ -3316,7 +3860,11 @@ class SortedSetCommands: options = {"withscores": True} return self.execute_command("ZPOPMAX", name, *args, **options) - def zpopmin(self, name, count=None): + def zpopmin( + self, + name: KeyT, + count: Union[int, None] = None, + ) -> ResponseT: """ Remove and return up to ``count`` members with the lowest scores from the sorted set ``name``. @@ -3327,7 +3875,12 @@ class SortedSetCommands: options = {"withscores": True} return self.execute_command("ZPOPMIN", name, *args, **options) - def zrandmember(self, key, count=None, withscores=False): + def zrandmember( + self, + key: KeyT, + count: int = None, + withscores: bool = False, + ) -> ResponseT: """ Return a random element from the sorted set value stored at key. @@ -3351,7 +3904,7 @@ class SortedSetCommands: return self.execute_command("ZRANDMEMBER", key, *params) - def bzpopmax(self, keys, timeout=0): + def bzpopmax(self, keys: KeysT, timeout: TimeoutSecT = 0) -> ResponseT: """ ZPOPMAX a value off of the first non-empty sorted set named in the ``keys`` list. @@ -3370,7 +3923,7 @@ class SortedSetCommands: keys.append(timeout) return self.execute_command("BZPOPMAX", *keys) - def bzpopmin(self, keys, timeout=0): + def bzpopmin(self, keys: KeysT, timeout: TimeoutSecT = 0) -> ResponseT: """ ZPOPMIN a value off of the first non-empty sorted set named in the ``keys`` list. @@ -3385,7 +3938,7 @@ class SortedSetCommands: """ if timeout is None: timeout = 0 - keys = list_or_args(keys, None) + keys: list[EncodableT] = list_or_args(keys, None) keys.append(timeout) return self.execute_command("BZPOPMIN", *keys) @@ -3449,18 +4002,18 @@ class SortedSetCommands: def _zrange( self, command, - dest, - name, - start, - end, - desc=False, - byscore=False, - bylex=False, - withscores=False, - score_cast_func=float, - offset=None, - num=None, - ): + dest: Union[KeyT, None], + name: KeyT, + start: int, + end: int, + desc: bool = False, + byscore: bool = False, + bylex: bool = False, + withscores: bool = False, + score_cast_func: Union[type, Callable, None] = float, + offset: Union[int, None] = None, + num: Union[int, None] = None, + ) -> ResponseT: if byscore and bylex: raise DataError( "``byscore`` and ``bylex`` can not be " "specified together." @@ -3490,17 +4043,17 @@ class SortedSetCommands: def zrange( self, - name, - start, - end, - desc=False, - withscores=False, - score_cast_func=float, - byscore=False, - bylex=False, - offset=None, - num=None, - ): + name: KeyT, + start: int, + end: int, + desc: bool = False, + withscores: bool = False, + score_cast_func: Union[type, Callable] = float, + byscore: bool = False, + bylex: bool = False, + offset: int = None, + num: int = None, + ) -> ResponseT: """ Return a range of values from sorted set ``name`` between ``start`` and ``end`` sorted in ascending order. @@ -3549,7 +4102,14 @@ class SortedSetCommands: num, ) - def zrevrange(self, name, start, end, withscores=False, score_cast_func=float): + def zrevrange( + self, + name: KeyT, + start: int, + end: int, + withscores: bool = False, + score_cast_func: Union[type, Callable] = float, + ) -> ResponseT: """ Return a range of values from sorted set ``name`` between ``start`` and ``end`` sorted in descending order. @@ -3571,16 +4131,16 @@ class SortedSetCommands: def zrangestore( self, - dest, - name, - start, - end, - byscore=False, - bylex=False, - desc=False, - offset=None, - num=None, - ): + dest: KeyT, + name: KeyT, + start: int, + end: int, + byscore: bool = False, + bylex: bool = False, + desc: bool = False, + offset: Union[int, None] = None, + num: Union[int, None] = None, + ) -> ResponseT: """ Stores in ``dest`` the result of a range of values from sorted set ``name`` between ``start`` and ``end`` sorted in ascending order. @@ -3619,7 +4179,14 @@ class SortedSetCommands: num, ) - def zrangebylex(self, name, min, max, start=None, num=None): + def zrangebylex( + self, + name: KeyT, + min: EncodableT, + max: EncodableT, + start: Union[int, None] = None, + num: Union[int, None] = None, + ) -> ResponseT: """ Return the lexicographical range of values from sorted set ``name`` between ``min`` and ``max``. @@ -3636,7 +4203,14 @@ class SortedSetCommands: pieces.extend([b"LIMIT", start, num]) return self.execute_command(*pieces) - def zrevrangebylex(self, name, max, min, start=None, num=None): + def zrevrangebylex( + self, + name: KeyT, + max: EncodableT, + min: EncodableT, + start: Union[int, None] = None, + num: Union[int, None] = None, + ) -> ResponseT: """ Return the reversed lexicographical range of values from sorted set ``name`` between ``max`` and ``min``. @@ -3655,14 +4229,14 @@ class SortedSetCommands: def zrangebyscore( self, - name, - min, - max, - start=None, - num=None, - withscores=False, - score_cast_func=float, - ): + name: KeyT, + min: ZScoreBoundT, + max: ZScoreBoundT, + start: Union[int, None] = None, + num: Union[int, None] = None, + withscores: bool = False, + score_cast_func: Union[type, Callable] = float, + ) -> ResponseT: """ Return a range of values from the sorted set ``name`` with scores between ``min`` and ``max``. @@ -3689,13 +4263,13 @@ class SortedSetCommands: def zrevrangebyscore( self, - name, - max, - min, - start=None, - num=None, - withscores=False, - score_cast_func=float, + name: KeyT, + max: ZScoreBoundT, + min: ZScoreBoundT, + start: Union[int, None] = None, + num: Union[int, None] = None, + withscores: bool = False, + score_cast_func: Union[type, Callable] = float, ): """ Return a range of values from the sorted set ``name`` with scores @@ -3721,7 +4295,7 @@ class SortedSetCommands: options = {"withscores": withscores, "score_cast_func": score_cast_func} return self.execute_command(*pieces, **options) - def zrank(self, name, value): + def zrank(self, name: KeyT, value: EncodableT) -> ResponseT: """ Returns a 0-based value indicating the rank of ``value`` in sorted set ``name`` @@ -3730,7 +4304,7 @@ class SortedSetCommands: """ return self.execute_command("ZRANK", name, value) - def zrem(self, name, *values): + def zrem(self, name: KeyT, *values: EncodableT) -> ResponseT: """ Remove member ``values`` from sorted set ``name`` @@ -3738,7 +4312,7 @@ class SortedSetCommands: """ return self.execute_command("ZREM", name, *values) - def zremrangebylex(self, name, min, max): + def zremrangebylex(self, name: KeyT, min: EncodableT, max: EncodableT) -> ResponseT: """ Remove all elements in the sorted set ``name`` between the lexicographical range specified by ``min`` and ``max``. @@ -3749,7 +4323,7 @@ class SortedSetCommands: """ return self.execute_command("ZREMRANGEBYLEX", name, min, max) - def zremrangebyrank(self, name, min, max): + def zremrangebyrank(self, name: KeyT, min: int, max: int) -> ResponseT: """ Remove all elements in the sorted set ``name`` with ranks between ``min`` and ``max``. Values are 0-based, ordered from smallest score @@ -3760,7 +4334,9 @@ class SortedSetCommands: """ return self.execute_command("ZREMRANGEBYRANK", name, min, max) - def zremrangebyscore(self, name, min, max): + def zremrangebyscore( + self, name: KeyT, min: ZScoreBoundT, max: ZScoreBoundT + ) -> ResponseT: """ Remove all elements in the sorted set ``name`` with scores between ``min`` and ``max``. Returns the number of elements removed. @@ -3769,7 +4345,7 @@ class SortedSetCommands: """ return self.execute_command("ZREMRANGEBYSCORE", name, min, max) - def zrevrank(self, name, value): + def zrevrank(self, name: KeyT, value: EncodableT) -> ResponseT: """ Returns a 0-based value indicating the descending rank of ``value`` in sorted set ``name`` @@ -3778,7 +4354,7 @@ class SortedSetCommands: """ return self.execute_command("ZREVRANK", name, value) - def zscore(self, name, value): + def zscore(self, name: KeyT, value: EncodableT) -> ResponseT: """ Return the score of element ``value`` in sorted set ``name`` @@ -3786,7 +4362,12 @@ class SortedSetCommands: """ return self.execute_command("ZSCORE", name, value) - def zunion(self, keys, aggregate=None, withscores=False): + def zunion( + self, + keys: Union[Sequence[KeyT], Mapping[AnyKeyT, float]], + aggregate: Union[str, None] = None, + withscores: bool = False, + ) -> ResponseT: """ Return the union of multiple sorted sets specified by ``keys``. ``keys`` can be provided as dictionary of keys and their weights. @@ -3797,7 +4378,12 @@ class SortedSetCommands: """ return self._zaggregate("ZUNION", None, keys, aggregate, withscores=withscores) - def zunionstore(self, dest, keys, aggregate=None): + def zunionstore( + self, + dest: KeyT, + keys: Union[Sequence[KeyT], Mapping[AnyKeyT, float]], + aggregate: Union[str, None] = None, + ) -> ResponseT: """ Union multiple sorted sets specified by ``keys`` into a new sorted set, ``dest``. Scores in the destination will be @@ -3807,7 +4393,11 @@ class SortedSetCommands: """ return self._zaggregate("ZUNIONSTORE", dest, keys, aggregate) - def zmscore(self, key, members): + def zmscore( + self, + key: KeyT, + members: List[str], + ) -> ResponseT: """ Returns the scores associated with the specified members in the sorted set stored at key. @@ -3823,8 +4413,15 @@ class SortedSetCommands: pieces = [key] + members return self.execute_command("ZMSCORE", *pieces) - def _zaggregate(self, command, dest, keys, aggregate=None, **options): - pieces = [command] + def _zaggregate( + self, + command: str, + dest: Union[KeyT, None], + keys: Union[Sequence[KeyT], Mapping[AnyKeyT, float]], + aggregate: Union[str, None] = None, + **options, + ) -> ResponseT: + pieces: list[EncodableT] = [command] if dest is not None: pieces.append(dest) pieces.append(len(keys)) @@ -3847,13 +4444,16 @@ class SortedSetCommands: return self.execute_command(*pieces, **options) -class HyperlogCommands: +AsyncSortedSetCommands = SortedSetCommands + + +class HyperlogCommands(CommandsProtocol): """ Redis commands of HyperLogLogs data type. see: https://redis.io/topics/data-types-intro#hyperloglogs """ - def pfadd(self, name, *values): + def pfadd(self, name: KeyT, *values: EncodableT) -> ResponseT: """ Adds the specified elements to the specified HyperLogLog. @@ -3861,7 +4461,7 @@ class HyperlogCommands: """ return self.execute_command("PFADD", name, *values) - def pfcount(self, *sources): + def pfcount(self, *sources: KeyT) -> ResponseT: """ Return the approximated cardinality of the set observed by the HyperLogLog at key(s). @@ -3870,7 +4470,7 @@ class HyperlogCommands: """ return self.execute_command("PFCOUNT", *sources) - def pfmerge(self, dest, *sources): + def pfmerge(self, dest: KeyT, *sources: KeyT) -> ResponseT: """ Merge N different HyperLogLogs into a single one. @@ -3879,7 +4479,10 @@ class HyperlogCommands: return self.execute_command("PFMERGE", dest, *sources) -class HashCommands: +AsyncHyperlogCommands = HyperlogCommands + + +class HashCommands(CommandsProtocol): """ Redis commands for Hash data type. see: https://redis.io/topics/data-types-intro#redis-hashes @@ -4031,13 +4634,106 @@ class HashCommands: return self.execute_command("HSTRLEN", name, key) -class PubSubCommands: +AsyncHashCommands = HashCommands + + +class Script: + """ + An executable Lua script object returned by ``register_script`` + """ + + def __init__(self, registered_client: "Redis", script: ScriptTextT): + self.registered_client = registered_client + self.script = script + # Precalculate and store the SHA1 hex digest of the script. + + if isinstance(script, str): + # We need the encoding from the client in order to generate an + # accurate byte representation of the script + encoder = registered_client.connection_pool.get_encoder() + script = encoder.encode(script) + self.sha = hashlib.sha1(script).hexdigest() + + def __call__( + self, + keys: Union[Sequence[KeyT], None] = None, + args: Union[Iterable[EncodableT], None] = None, + client: Union["Redis", None] = None, + ): + """Execute the script, passing any required ``args``""" + keys = keys or [] + args = args or [] + if client is None: + client = self.registered_client + args = tuple(keys) + tuple(args) + # make sure the Redis server knows about the script + from redis.client import Pipeline + + if isinstance(client, Pipeline): + # Make sure the pipeline can register the script before executing. + client.scripts.add(self) + try: + return client.evalsha(self.sha, len(keys), *args) + except NoScriptError: + # Maybe the client is pointed to a different server than the client + # that created this instance? + # Overwrite the sha just in case there was a discrepancy. + self.sha = client.script_load(self.script) + return client.evalsha(self.sha, len(keys), *args) + + +class AsyncScript: + """ + An executable Lua script object returned by ``register_script`` + """ + + def __init__(self, registered_client: "AsyncRedis", script: ScriptTextT): + self.registered_client = registered_client + self.script = script + # Precalculate and store the SHA1 hex digest of the script. + + if isinstance(script, str): + # We need the encoding from the client in order to generate an + # accurate byte representation of the script + encoder = registered_client.connection_pool.get_encoder() + script = encoder.encode(script) + self.sha = hashlib.sha1(script).hexdigest() + + async def __call__( + self, + keys: Union[Sequence[KeyT], None] = None, + args: Union[Iterable[EncodableT], None] = None, + client: Union["AsyncRedis", None] = None, + ): + """Execute the script, passing any required ``args``""" + keys = keys or [] + args = args or [] + if client is None: + client = self.registered_client + args = tuple(keys) + tuple(args) + # make sure the Redis server knows about the script + from redis.asyncio.client import Pipeline + + if isinstance(client, Pipeline): + # Make sure the pipeline can register the script before executing. + client.scripts.add(self) + try: + return await client.evalsha(self.sha, len(keys), *args) + except NoScriptError: + # Maybe the client is pointed to a different server than the client + # that created this instance? + # Overwrite the sha just in case there was a discrepancy. + self.sha = await client.script_load(self.script) + return await client.evalsha(self.sha, len(keys), *args) + + +class PubSubCommands(CommandsProtocol): """ Redis PubSub commands. see https://redis.io/topics/pubsub """ - def publish(self, channel, message, **kwargs): + def publish(self, channel: ChannelT, message: EncodableT, **kwargs) -> ResponseT: """ Publish ``message`` on ``channel``. Returns the number of subscribers the message was delivered to. @@ -4046,7 +4742,7 @@ class PubSubCommands: """ return self.execute_command("PUBLISH", channel, message, **kwargs) - def pubsub_channels(self, pattern="*", **kwargs): + def pubsub_channels(self, pattern: PatternT = "*", **kwargs) -> ResponseT: """ Return a list of channels that have at least one subscriber @@ -4054,7 +4750,7 @@ class PubSubCommands: """ return self.execute_command("PUBSUB CHANNELS", pattern, **kwargs) - def pubsub_numpat(self, **kwargs): + def pubsub_numpat(self, **kwargs) -> ResponseT: """ Returns the number of subscriptions to patterns @@ -4062,7 +4758,7 @@ class PubSubCommands: """ return self.execute_command("PUBSUB NUMPAT", **kwargs) - def pubsub_numsub(self, *args, **kwargs): + def pubsub_numsub(self, *args: ChannelT, **kwargs) -> ResponseT: """ Return a list of (channel, number of subscribers) tuples for each channel given in ``*args`` @@ -4072,7 +4768,10 @@ class PubSubCommands: return self.execute_command("PUBSUB NUMSUB", *args, **kwargs) -class ScriptCommands: +AsyncPubSubCommands = PubSubCommands + + +class ScriptCommands(CommandsProtocol): """ Redis Lua script commands. see: https://redis.com/ebook/part-3-next-steps/chapter-11-scripting-redis-with-lua/ @@ -4140,7 +4839,7 @@ class ScriptCommands: """ return self._evalsha("EVALSHA_RO", sha, numkeys, *keys_and_args) - def script_exists(self, *args): + def script_exists(self, *args: str): """ Check if a script exists in the script cache by specifying the SHAs of each script as ``args``. Returns a list of boolean values indicating if @@ -4150,12 +4849,14 @@ class ScriptCommands: """ return self.execute_command("SCRIPT EXISTS", *args) - def script_debug(self, *args): + def script_debug(self, *args) -> None: raise NotImplementedError( "SCRIPT DEBUG is intentionally not implemented in the client." ) - def script_flush(self, sync_type=None): + def script_flush( + self, sync_type: Union[Literal["SYNC"], Literal["ASYNC"]] = None + ) -> ResponseT: """Flush all scripts from the script cache. ``sync_type`` is by default SYNC (synchronous) but it can also be ASYNC. @@ -4175,7 +4876,7 @@ class ScriptCommands: pieces = [sync_type] return self.execute_command("SCRIPT FLUSH", *pieces) - def script_kill(self): + def script_kill(self) -> ResponseT: """ Kill the currently executing Lua script @@ -4183,7 +4884,7 @@ class ScriptCommands: """ return self.execute_command("SCRIPT KILL") - def script_load(self, script): + def script_load(self, script: ScriptTextT) -> ResponseT: """ Load a Lua ``script`` into the script cache. Returns the SHA. @@ -4191,7 +4892,7 @@ class ScriptCommands: """ return self.execute_command("SCRIPT LOAD", script) - def register_script(self, script): + def register_script(self: "Redis", script: ScriptTextT) -> Script: """ Register a Lua ``script`` specifying the ``keys`` it will touch. Returns a Script object that is callable and hides the complexity of @@ -4201,13 +4902,34 @@ class ScriptCommands: return Script(self, script) -class GeoCommands: +class AsyncScriptCommands(ScriptCommands): + async def script_debug(self, *args) -> None: + return super().script_debug() + + def register_script(self: "AsyncRedis", script: ScriptTextT) -> AsyncScript: + """ + Register a Lua ``script`` specifying the ``keys`` it will touch. + Returns a Script object that is callable and hides the complexity of + deal with scripts, keys, and shas. This is the preferred way to work + with Lua scripts. + """ + return AsyncScript(self, script) + + +class GeoCommands(CommandsProtocol): """ Redis Geospatial commands. see: https://redis.com/redis-best-practices/indexing-patterns/geospatial/ """ - def geoadd(self, name, values, nx=False, xx=False, ch=False): + def geoadd( + self, + name: KeyT, + values: Sequence[EncodableT], + nx: bool = False, + xx: bool = False, + ch: bool = False, + ) -> ResponseT: """ Add the specified geospatial items to the specified key identified by the ``name`` argument. The Geospatial items are given as ordered @@ -4242,7 +4964,13 @@ class GeoCommands: pieces.extend(values) return self.execute_command("GEOADD", *pieces) - def geodist(self, name, place1, place2, unit=None): + def geodist( + self, + name: KeyT, + place1: FieldT, + place2: FieldT, + unit: Union[str, None] = None, + ) -> ResponseT: """ Return the distance between ``place1`` and ``place2`` members of the ``name`` key. @@ -4251,14 +4979,14 @@ class GeoCommands: For more information check https://redis.io/commands/geodist """ - pieces = [name, place1, place2] + pieces: list[EncodableT] = [name, place1, place2] if unit and unit not in ("m", "km", "mi", "ft"): raise DataError("GEODIST invalid unit") elif unit: pieces.append(unit) return self.execute_command("GEODIST", *pieces) - def geohash(self, name, *values): + def geohash(self, name: KeyT, *values: FieldT) -> ResponseT: """ Return the geo hash string for each item of ``values`` members of the specified key identified by the ``name`` argument. @@ -4267,7 +4995,7 @@ class GeoCommands: """ return self.execute_command("GEOHASH", name, *values) - def geopos(self, name, *values): + def geopos(self, name: KeyT, *values: FieldT) -> ResponseT: """ Return the positions of each item of ``values`` as members of the specified key identified by the ``name`` argument. Each position @@ -4279,20 +5007,20 @@ class GeoCommands: def georadius( self, - name, - longitude, - latitude, - radius, - unit=None, - withdist=False, - withcoord=False, - withhash=False, - count=None, - sort=None, - store=None, - store_dist=None, - any=False, - ): + name: KeyT, + longitude: float, + latitude: float, + radius: float, + unit: Union[str, None] = None, + withdist: bool = False, + withcoord: bool = False, + withhash: bool = False, + count: Union[int, None] = None, + sort: Union[str, None] = None, + store: Union[KeyT, None] = None, + store_dist: Union[KeyT, None] = None, + any: bool = False, + ) -> ResponseT: """ Return the members of the specified key identified by the ``name`` argument which are within the borders of the area specified @@ -4342,19 +5070,19 @@ class GeoCommands: def georadiusbymember( self, - name, - member, - radius, - unit=None, - withdist=False, - withcoord=False, - withhash=False, - count=None, - sort=None, - store=None, - store_dist=None, - any=False, - ): + name: KeyT, + member: FieldT, + radius: float, + unit: Union[str, None] = None, + withdist: bool = False, + withcoord: bool = False, + withhash: bool = False, + count: Union[int, None] = None, + sort: Union[str, None] = None, + store: Union[KeyT, None] = None, + store_dist: Union[KeyT, None] = None, + any: bool = False, + ) -> ResponseT: """ This command is exactly like ``georadius`` with the sole difference that instead of taking, as the center of the area to query, a longitude @@ -4379,7 +5107,12 @@ class GeoCommands: any=any, ) - def _georadiusgeneric(self, command, *args, **kwargs): + def _georadiusgeneric( + self, + command: str, + *args: EncodableT, + **kwargs: Union[EncodableT, None], + ) -> ResponseT: pieces = list(args) if kwargs["unit"] and kwargs["unit"] not in ("m", "km", "mi", "ft"): raise DataError("GEORADIUS invalid unit") @@ -4427,21 +5160,21 @@ class GeoCommands: def geosearch( self, - name, - member=None, - longitude=None, - latitude=None, - unit="m", - radius=None, - width=None, - height=None, - sort=None, - count=None, - any=False, - withcoord=False, - withdist=False, - withhash=False, - ): + name: KeyT, + member: Union[FieldT, None] = None, + longitude: Union[float, None] = None, + latitude: Union[float, None] = None, + unit: str = "m", + radius: Union[float, None] = None, + width: Union[float, None] = None, + height: Union[float, None] = None, + sort: Union[str, None] = None, + count: Union[int, None] = None, + any: bool = False, + withcoord: bool = False, + withdist: bool = False, + withhash: bool = False, + ) -> ResponseT: """ Return the members of specified key identified by the ``name`` argument, which are within the borders of the @@ -4499,20 +5232,20 @@ class GeoCommands: def geosearchstore( self, - dest, - name, - member=None, - longitude=None, - latitude=None, - unit="m", - radius=None, - width=None, - height=None, - sort=None, - count=None, - any=False, - storedist=False, - ): + dest: KeyT, + name: KeyT, + member: Union[FieldT, None] = None, + longitude: Union[float, None] = None, + latitude: Union[float, None] = None, + unit: str = "m", + radius: Union[float, None] = None, + width: Union[float, None] = None, + height: Union[float, None] = None, + sort: Union[str, None] = None, + count: Union[int, None] = None, + any: bool = False, + storedist: bool = False, + ) -> ResponseT: """ This command is like GEOSEARCH, but stores the result in ``dest``. By default, it stores the results in the destination @@ -4544,7 +5277,12 @@ class GeoCommands: store_dist=storedist, ) - def _geosearchgeneric(self, command, *args, **kwargs): + def _geosearchgeneric( + self, + command: str, + *args: EncodableT, + **kwargs: Union[EncodableT, None], + ) -> ResponseT: pieces = list(args) # FROMMEMBER or FROMLONLAT @@ -4609,13 +5347,16 @@ class GeoCommands: return self.execute_command(command, *pieces, **kwargs) -class ModuleCommands: +AsyncGeoCommands = GeoCommands + + +class ModuleCommands(CommandsProtocol): """ Redis Module commands. see: https://redis.io/topics/modules-intro """ - def module_load(self, path, *args): + def module_load(self, path, *args) -> ResponseT: """ Loads the module from ``path``. Passes all ``*args`` to the module, during loading. @@ -4625,7 +5366,7 @@ class ModuleCommands: """ return self.execute_command("MODULE LOAD", path, *args) - def module_unload(self, name): + def module_unload(self, name) -> ResponseT: """ Unloads the module ``name``. Raises ``ModuleError`` if ``name`` is not in loaded modules. @@ -4634,7 +5375,7 @@ class ModuleCommands: """ return self.execute_command("MODULE UNLOAD", name) - def module_list(self): + def module_list(self) -> ResponseT: """ Returns a list of dictionaries containing the name and version of all loaded modules. @@ -4643,166 +5384,35 @@ class ModuleCommands: """ return self.execute_command("MODULE LIST") - def command_info(self): + def command_info(self) -> None: raise NotImplementedError( "COMMAND INFO is intentionally not implemented in the client." ) - def command_count(self): + def command_count(self) -> ResponseT: return self.execute_command("COMMAND COUNT") - def command_getkeys(self, *args): + def command_getkeys(self, *args) -> ResponseT: return self.execute_command("COMMAND GETKEYS", *args) - def command(self): + def command(self) -> ResponseT: return self.execute_command("COMMAND") -class Script: - """ - An executable Lua script object returned by ``register_script`` - """ - - def __init__(self, registered_client, script): - self.registered_client = registered_client - self.script = script - # Precalculate and store the SHA1 hex digest of the script. - - if isinstance(script, str): - # We need the encoding from the client in order to generate an - # accurate byte representation of the script - encoder = registered_client.connection_pool.get_encoder() - script = encoder.encode(script) - self.sha = hashlib.sha1(script).hexdigest() - - def __call__(self, keys=[], args=[], client=None): - "Execute the script, passing any required ``args``" - if client is None: - client = self.registered_client - args = tuple(keys) + tuple(args) - # make sure the Redis server knows about the script - from redis.client import Pipeline - - if isinstance(client, Pipeline): - # Make sure the pipeline can register the script before executing. - client.scripts.add(self) - try: - return client.evalsha(self.sha, len(keys), *args) - except NoScriptError: - # Maybe the client is pointed to a different server than the client - # that created this instance? - # Overwrite the sha just in case there was a discrepancy. - self.sha = client.script_load(self.script) - return client.evalsha(self.sha, len(keys), *args) +class AsyncModuleCommands(ModuleCommands): + async def command_info(self) -> None: + return super().command_info() -class BitFieldOperation: - """ - Command builder for BITFIELD commands. - """ - - def __init__(self, client, key, default_overflow=None): - self.client = client - self.key = key - self._default_overflow = default_overflow - self.reset() - - def reset(self): - """ - Reset the state of the instance to when it was constructed - """ - self.operations = [] - self._last_overflow = "WRAP" - self.overflow(self._default_overflow or self._last_overflow) - - def overflow(self, overflow): - """ - Update the overflow algorithm of successive INCRBY operations - :param overflow: Overflow algorithm, one of WRAP, SAT, FAIL. See the - Redis docs for descriptions of these algorithmsself. - :returns: a :py:class:`BitFieldOperation` instance. - """ - overflow = overflow.upper() - if overflow != self._last_overflow: - self._last_overflow = overflow - self.operations.append(("OVERFLOW", overflow)) - return self - - def incrby(self, fmt, offset, increment, overflow=None): - """ - Increment a bitfield by a given amount. - :param fmt: format-string for the bitfield being updated, e.g. 'u8' - for an unsigned 8-bit integer. - :param offset: offset (in number of bits). If prefixed with a - '#', this is an offset multiplier, e.g. given the arguments - fmt='u8', offset='#2', the offset will be 16. - :param int increment: value to increment the bitfield by. - :param str overflow: overflow algorithm. Defaults to WRAP, but other - acceptable values are SAT and FAIL. See the Redis docs for - descriptions of these algorithms. - :returns: a :py:class:`BitFieldOperation` instance. - """ - if overflow is not None: - self.overflow(overflow) - - self.operations.append(("INCRBY", fmt, offset, increment)) - return self - - def get(self, fmt, offset): - """ - Get the value of a given bitfield. - :param fmt: format-string for the bitfield being read, e.g. 'u8' for - an unsigned 8-bit integer. - :param offset: offset (in number of bits). If prefixed with a - '#', this is an offset multiplier, e.g. given the arguments - fmt='u8', offset='#2', the offset will be 16. - :returns: a :py:class:`BitFieldOperation` instance. - """ - self.operations.append(("GET", fmt, offset)) - return self - - def set(self, fmt, offset, value): - """ - Set the value of a given bitfield. - :param fmt: format-string for the bitfield being read, e.g. 'u8' for - an unsigned 8-bit integer. - :param offset: offset (in number of bits). If prefixed with a - '#', this is an offset multiplier, e.g. given the arguments - fmt='u8', offset='#2', the offset will be 16. - :param int value: value to set at the given position. - :returns: a :py:class:`BitFieldOperation` instance. - """ - self.operations.append(("SET", fmt, offset, value)) - return self - - @property - def command(self): - cmd = ["BITFIELD", self.key] - for ops in self.operations: - cmd.extend(ops) - return cmd - - def execute(self): - """ - Execute the operation(s) in a single BITFIELD command. The return value - is a list of values corresponding to each operation. If the client - used to create this instance was a pipeline, the list of values - will be present within the pipeline's execute. - """ - command = self.command - self.reset() - return self.client.execute_command(*command) - - -class ClusterCommands: +class ClusterCommands(CommandsProtocol): """ Class for Redis Cluster commands """ - def cluster(self, cluster_arg, *args, **kwargs): + def cluster(self, cluster_arg, *args, **kwargs) -> ResponseT: return self.execute_command(f"CLUSTER {cluster_arg.upper()}", *args, **kwargs) - def readwrite(self, **kwargs): + def readwrite(self, **kwargs) -> ResponseT: """ Disables read queries for a connection to a Redis Cluster slave node. @@ -4810,7 +5420,7 @@ class ClusterCommands: """ return self.execute_command("READWRITE", **kwargs) - def readonly(self, **kwargs): + def readonly(self, **kwargs) -> ResponseT: """ Enables read queries for a connection to a Redis Cluster replica node. @@ -4819,6 +5429,9 @@ class ClusterCommands: return self.execute_command("READONLY", **kwargs) +AsyncClusterCommands = ClusterCommands + + class DataAccessCommands( BasicKeyCommands, HyperlogCommands, @@ -4832,7 +5445,24 @@ class DataAccessCommands( ): """ A class containing all of the implemented data access redis commands. - This class is to be used as a mixin. + This class is to be used as a mixin for synchronous Redis clients. + """ + + +class AsyncDataAccessCommands( + AsyncBasicKeyCommands, + AsyncHyperlogCommands, + AsyncHashCommands, + AsyncGeoCommands, + AsyncListCommands, + AsyncScanCommands, + AsyncSetCommands, + AsyncStreamCommands, + AsyncSortedSetCommands, +): + """ + A class containing all of the implemented data access redis commands. + This class is to be used as a mixin for asynchronous Redis clients. """ @@ -4847,5 +5477,20 @@ class CoreCommands( ): """ A class containing all of the implemented redis commands. This class is - to be used as a mixin. + to be used as a mixin for synchronous Redis clients. + """ + + +class AsyncCoreCommands( + AsyncACLCommands, + AsyncClusterCommands, + AsyncDataAccessCommands, + AsyncManagementCommands, + AsyncModuleCommands, + AsyncPubSubCommands, + AsyncScriptCommands, +): + """ + A class containing all of the implemented redis commands. This class is + to be used as a mixin for asynchronous Redis clients. """ diff --git a/redis/commands/sentinel.py b/redis/commands/sentinel.py index a9b06c2..e054ec6 100644 --- a/redis/commands/sentinel.py +++ b/redis/commands/sentinel.py @@ -8,39 +8,39 @@ class SentinelCommands: """ def sentinel(self, *args): - "Redis Sentinel's SENTINEL command." + """Redis Sentinel's SENTINEL command.""" warnings.warn(DeprecationWarning("Use the individual sentinel_* methods")) def sentinel_get_master_addr_by_name(self, service_name): - "Returns a (host, port) pair for the given ``service_name``" + """Returns a (host, port) pair for the given ``service_name``""" return self.execute_command("SENTINEL GET-MASTER-ADDR-BY-NAME", service_name) def sentinel_master(self, service_name): - "Returns a dictionary containing the specified masters state." + """Returns a dictionary containing the specified masters state.""" return self.execute_command("SENTINEL MASTER", service_name) def sentinel_masters(self): - "Returns a list of dictionaries containing each master's state." + """Returns a list of dictionaries containing each master's state.""" return self.execute_command("SENTINEL MASTERS") def sentinel_monitor(self, name, ip, port, quorum): - "Add a new master to Sentinel to be monitored" + """Add a new master to Sentinel to be monitored""" return self.execute_command("SENTINEL MONITOR", name, ip, port, quorum) def sentinel_remove(self, name): - "Remove a master from Sentinel's monitoring" + """Remove a master from Sentinel's monitoring""" return self.execute_command("SENTINEL REMOVE", name) def sentinel_sentinels(self, service_name): - "Returns a list of sentinels for ``service_name``" + """Returns a list of sentinels for ``service_name``""" return self.execute_command("SENTINEL SENTINELS", service_name) def sentinel_set(self, name, option, value): - "Set Sentinel monitoring parameters for a given master" + """Set Sentinel monitoring parameters for a given master""" return self.execute_command("SENTINEL SET", name, option, value) def sentinel_slaves(self, service_name): - "Returns a list of slaves for ``service_name``" + """Returns a list of slaves for ``service_name``""" return self.execute_command("SENTINEL SLAVES", service_name) def sentinel_reset(self, pattern): @@ -91,3 +91,9 @@ class SentinelCommands: completely missing. """ return self.execute_command("SENTINEL FLUSHCONFIG") + + +class AsyncSentinelCommands(SentinelCommands): + async def sentinel(self, *args) -> None: + """Redis Sentinel's SENTINEL command.""" + super().sentinel(*args) |