From 7424e6557a3c3ab88c5446e04f39c3748b97cad6 Mon Sep 17 00:00:00 2001
From: dvora-h <67596500+dvora-h@users.noreply.github.com>
Date: Mon, 21 Nov 2022 14:37:44 +0200
Subject: Cherry-pick for 4.3.5 (#2441)

* fix is_connected (#2278)

* fix: workaround asyncio bug on connection reset by peer (#2259)

Fixes #2237

* Fix crash: key expire while search (#2270)

* fix expire while search

* sleep

* async_cluster: fix concurrent pipeline (#2280)

- each pipeline should create separate stacks for each node

* async_cluster: fix max_connections/ssl & improve args (#2217)

* async_cluster: fix max_connections/ssl & improve args

- set proper connection_class if ssl = True
- pass max_connections/connection_class to ClusterNode
- recreate startup_nodes to properly initialize
- pass parser_class to Connection instead of changing it in on_connect
- only pass redis_connect_func if read_from_replicas = True
- add connection_error_retry_attempts parameter
- skip is_connected check in acquire_connection as it is already checked in send_packed_command

BREAKING:
- RedisCluster args except host & port are kw-only now
- RedisCluster will no longer accept unknown arguments
- RedisCluster will no longer accept url as an argument. Use RedisCluster.from_url
- RedisCluster.require_full_coverage defaults to True
- ClusterNode args except host, port, & server_type are kw-only now

* async_cluster: remove kw-only requirement from client

Co-authored-by: dvora-h <67596500+dvora-h@users.noreply.github.com>

* More cherry picks to support latest module release

* Fix KeyError in async cluster - initialize before execute multi key commands (#2439)

* Fix KeyError in async cluster

* link to issue

* typo

* fixing lint

* fix json test

Co-authored-by: Mehdi ABAAKOUK
Co-authored-by: Utkarsh Gupta
Co-authored-by: Chayim I.
Kirshen --- redis/asyncio/cluster.py | 395 ++++++++++---------- redis/asyncio/connection.py | 13 +- redis/cluster.py | 4 +- redis/commands/bf/__init__.py | 9 +- redis/commands/bf/commands.py | 103 ++++-- redis/commands/bf/info.py | 22 +- redis/commands/cluster.py | 19 + redis/commands/graph/__init__.py | 129 ++++++- redis/commands/graph/commands.py | 146 ++++++-- redis/commands/graph/query_result.py | 395 +++++++++++++++----- redis/commands/json/commands.py | 11 +- redis/commands/search/__init__.py | 2 +- redis/commands/search/aggregation.py | 6 +- redis/commands/search/commands.py | 12 +- redis/commands/search/field.py | 12 +- redis/commands/search/querystring.py | 3 + redis/commands/search/result.py | 2 +- redis/commands/timeseries/commands.py | 668 ++++++++++++++++++++-------------- redis/commands/timeseries/info.py | 18 +- redis/exceptions.py | 4 + redis/utils.py | 28 ++ tests/test_asyncio/test_bloom.py | 140 +++++-- tests/test_asyncio/test_cluster.py | 227 ++++++++++-- tests/test_asyncio/test_graph.py | 503 +++++++++++++++++++++++++ tests/test_asyncio/test_json.py | 4 +- tests/test_asyncio/test_search.py | 283 ++++++-------- tests/test_asyncio/test_timeseries.py | 5 +- tests/test_bloom.py | 143 ++++++-- tests/test_graph.py | 51 ++- tests/test_json.py | 5 +- tests/test_search.py | 349 +++++++----------- tests/test_timeseries.py | 243 ++++++++++++- 32 files changed, 2821 insertions(+), 1133 deletions(-) create mode 100644 tests/test_asyncio/test_graph.py diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index 2894004..df0c17d 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -17,7 +17,13 @@ from typing import ( ) from redis.asyncio.client import ResponseCallbackT -from redis.asyncio.connection import Connection, DefaultParser, Encoder, parse_url +from redis.asyncio.connection import ( + Connection, + DefaultParser, + Encoder, + SSLConnection, + parse_url, +) from redis.asyncio.parser import CommandsParser from redis.client import EMPTY_RESPONSE, NEVER_DECODE, AbstractRedis from redis.cluster import ( @@ -42,6 +48,7 @@ from redis.exceptions import ( ConnectionError, DataError, MasterDownError, + MaxConnectionsError, MovedError, RedisClusterException, ResponseError, @@ -56,44 +63,17 @@ TargetNodesT = TypeVar( "TargetNodesT", str, "ClusterNode", List["ClusterNode"], Dict[Any, "ClusterNode"] ) -CONNECTION_ALLOWED_KEYS = ( - "client_name", - "db", - "decode_responses", - "encoder_class", - "encoding", - "encoding_errors", - "health_check_interval", - "parser_class", - "password", - "redis_connect_func", - "retry", - "retry_on_timeout", - "socket_connect_timeout", - "socket_keepalive", - "socket_keepalive_options", - "socket_read_size", - "socket_timeout", - "socket_type", - "username", -) - - -def cleanup_kwargs(**kwargs: Any) -> Dict[str, Any]: - """Remove unsupported or disabled keys from kwargs.""" - return {k: v for k, v in kwargs.items() if k in CONNECTION_ALLOWED_KEYS} - class ClusterParser(DefaultParser): EXCEPTION_CLASSES = dict_merge( DefaultParser.EXCEPTION_CLASSES, { "ASK": AskError, - "TRYAGAIN": TryAgainError, - "MOVED": MovedError, "CLUSTERDOWN": ClusterDownError, "CROSSSLOT": ClusterCrossSlotError, "MASTERDOWN": MasterDownError, + "MOVED": MovedError, + "TRYAGAIN": TryAgainError, }, ) @@ -104,7 +84,6 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand Pass one of parameters: - - `url` - `host` & `port` - `startup_nodes` @@ -128,9 +107,6 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, 
AsyncRedisClusterCommand | Port used if **host** is provided :param startup_nodes: | :class:`~.ClusterNode` to used as a startup node - :param cluster_error_retry_attempts: - | Retry command execution attempts when encountering :class:`~.ClusterDownError` - or :class:`~.ConnectionError` :param require_full_coverage: | When set to ``False``: the client will not require a full coverage of the slots. However, if not all slots are covered, and at least one node has @@ -141,6 +117,10 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand thrown. | See: https://redis.io/docs/manual/scaling/#redis-cluster-configuration-parameters + :param read_from_replicas: + | Enable read from replicas in READONLY mode. You can read possibly stale data. + When set to true, read commands will be assigned between the primary and + its replications in a Round-Robin manner. :param reinitialize_steps: | Specifies the number of MOVED errors that need to occur before reinitializing the whole cluster topology. If a MOVED error occurs and the cluster does not @@ -149,23 +129,27 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand To reinitialize the cluster on every MOVED error, set reinitialize_steps to 1. To avoid reinitializing the cluster on moved errors, set reinitialize_steps to 0. - :param read_from_replicas: - | Enable read from replicas in READONLY mode. You can read possibly stale data. - When set to true, read commands will be assigned between the primary and - its replications in a Round-Robin manner. - :param url: - | See :meth:`.from_url` - :param kwargs: - | Extra arguments that will be passed to the - :class:`~redis.asyncio.connection.Connection` instances when created + :param cluster_error_retry_attempts: + | Number of times to retry before raising an error when :class:`~.TimeoutError` + or :class:`~.ConnectionError` or :class:`~.ClusterDownError` are encountered + :param connection_error_retry_attempts: + | Number of times to retry before reinitializing when :class:`~.TimeoutError` + or :class:`~.ConnectionError` are encountered + :param max_connections: + | Maximum number of connections per node. If there are no free connections & the + maximum number of connections are already created, a + :class:`~.MaxConnectionsError` is raised. This error may be retried as defined + by :attr:`connection_error_retry_attempts` + + | Rest of the arguments will be passed to the + :class:`~redis.asyncio.connection.Connection` instances when created :raises RedisClusterException: - if any arguments are invalid. Eg: + if any arguments are invalid or unknown. Eg: - - db kwarg - - db != 0 in url - - unix socket connection - - none of host & url & startup_nodes were provided + - `db` != 0 or None + - `path` argument for unix socket connection + - none of the `host`/`port` & `startup_nodes` were provided """ @@ -178,7 +162,6 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand redis://[[username]:[password]]@localhost:6379/0 rediss://[[username]:[password]]@localhost:6379/0 - unix://[[username]:[password]]@/path/to/socket.sock?db=0 Three URL schemes are supported: @@ -186,32 +169,22 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand - `rediss://` creates a SSL wrapped TCP socket connection. See more at: - - ``unix://``: creates a Unix Domain Socket connection. 
- - The username, password, hostname, path and all querystring values - are passed through urllib.parse.unquote in order to replace any - percent-encoded values with their corresponding characters. - - There are several ways to specify a database number. The first value - found will be used: - 1. A ``db`` querystring option, e.g. redis://localhost?db=0 - 2. If using the redis:// or rediss:// schemes, the path argument - of the url, e.g. redis://localhost/0 - 3. A ``db`` keyword argument to this function. - - If none of these options are specified, the default db=0 is used. - - All querystring options are cast to their appropriate Python types. - Boolean arguments can be specified with string values "True"/"False" - or "Yes"/"No". Values that cannot be properly cast cause a - ``ValueError`` to be raised. Once parsed, the querystring arguments and - keyword arguments are passed to :class:`~redis.asyncio.connection.Connection` - when created. In the case of conflicting arguments, querystring - arguments always win. + The username, password, hostname, path and all querystring values are passed + through ``urllib.parse.unquote`` in order to replace any percent-encoded values + with their corresponding characters. + All querystring options are cast to their appropriate Python types. Boolean + arguments can be specified with string values "True"/"False" or "Yes"/"No". + Values that cannot be properly cast cause a ``ValueError`` to be raised. Once + parsed, the querystring arguments and keyword arguments are passed to + :class:`~redis.asyncio.connection.Connection` when created. + In the case of conflicting arguments, querystring arguments are used. """ - return cls(url=url, **kwargs) + kwargs.update(parse_url(url)) + if kwargs.pop("connection_class", None) is SSLConnection: + kwargs["ssl"] = True + return cls(**kwargs) __slots__ = ( "_initialize", @@ -219,6 +192,7 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand "cluster_error_retry_attempts", "command_flags", "commands_parser", + "connection_error_retry_attempts", "connection_kwargs", "encoder", "node_flags", @@ -233,87 +207,131 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand def __init__( self, host: Optional[str] = None, - port: int = 6379, + port: Union[str, int] = 6379, + # Cluster related kwargs startup_nodes: Optional[List["ClusterNode"]] = None, - require_full_coverage: bool = False, + require_full_coverage: bool = True, read_from_replicas: bool = False, - cluster_error_retry_attempts: int = 3, reinitialize_steps: int = 10, - url: Optional[str] = None, - **kwargs: Any, + cluster_error_retry_attempts: int = 3, + connection_error_retry_attempts: int = 5, + max_connections: int = 2**31, + # Client related kwargs + db: Union[str, int] = 0, + path: Optional[str] = None, + username: Optional[str] = None, + password: Optional[str] = None, + client_name: Optional[str] = None, + # Encoding related kwargs + encoding: str = "utf-8", + encoding_errors: str = "strict", + decode_responses: bool = False, + # Connection related kwargs + health_check_interval: float = 0, + socket_connect_timeout: Optional[float] = None, + socket_keepalive: bool = False, + socket_keepalive_options: Optional[Mapping[int, Union[int, bytes]]] = None, + socket_timeout: Optional[float] = None, + # SSL related kwargs + ssl: bool = False, + ssl_ca_certs: Optional[str] = None, + ssl_ca_data: Optional[str] = None, + ssl_cert_reqs: str = "required", + ssl_certfile: Optional[str] = None, + ssl_check_hostname: bool 
= False, + ssl_keyfile: Optional[str] = None, ) -> None: - if not startup_nodes: - startup_nodes = [] + if db: + raise RedisClusterException( + "Argument 'db' must be 0 or None in cluster mode" + ) - if "db" in kwargs: - # Argument 'db' is not possible to use in cluster mode + if path: raise RedisClusterException( - "Argument 'db' is not possible to use in cluster mode" + "Unix domain socket is not supported in cluster mode" ) - # Get the startup node(s) - if url: - url_options = parse_url(url) - if "path" in url_options: - raise RedisClusterException( - "RedisCluster does not currently support Unix Domain " - "Socket connections" - ) - if "db" in url_options and url_options["db"] != 0: - # Argument 'db' is not possible to use in cluster mode - raise RedisClusterException( - "A ``db`` querystring option can only be 0 in cluster mode" - ) - kwargs.update(url_options) - host = kwargs.get("host") - port = kwargs.get("port", port) - elif (not host or not port) and not startup_nodes: - # No startup node was provided + if (not host or not port) and not startup_nodes: raise RedisClusterException( - "RedisCluster requires at least one node to discover the " - "cluster. Please provide one of the followings:\n" - "1. host and port, for example:\n" - " RedisCluster(host='localhost', port=6379)\n" - "2. list of startup nodes, for example:\n" - " RedisCluster(startup_nodes=[ClusterNode('localhost', 6379)," - " ClusterNode('localhost', 6378)])" + "RedisCluster requires at least one node to discover the cluster.\n" + "Please provide one of the following or use RedisCluster.from_url:\n" + ' - host and port: RedisCluster(host="localhost", port=6379)\n' + " - startup_nodes: RedisCluster(startup_nodes=[" + 'ClusterNode("localhost", 6379), ClusterNode("localhost", 6380)])' + ) + + kwargs: Dict[str, Any] = { + "max_connections": max_connections, + "connection_class": Connection, + "parser_class": ClusterParser, + # Client related kwargs + "username": username, + "password": password, + "client_name": client_name, + # Encoding related kwargs + "encoding": encoding, + "encoding_errors": encoding_errors, + "decode_responses": decode_responses, + # Connection related kwargs + "health_check_interval": health_check_interval, + "socket_connect_timeout": socket_connect_timeout, + "socket_keepalive": socket_keepalive, + "socket_keepalive_options": socket_keepalive_options, + "socket_timeout": socket_timeout, + } + + if ssl: + # SSL related kwargs + kwargs.update( + { + "connection_class": SSLConnection, + "ssl_ca_certs": ssl_ca_certs, + "ssl_ca_data": ssl_ca_data, + "ssl_cert_reqs": ssl_cert_reqs, + "ssl_certfile": ssl_certfile, + "ssl_check_hostname": ssl_check_hostname, + "ssl_keyfile": ssl_keyfile, + } ) - # Update the connection arguments - # Whenever a new connection is established, RedisCluster's on_connect - # method should be run - kwargs["redis_connect_func"] = self.on_connect - self.connection_kwargs = kwargs = cleanup_kwargs(**kwargs) - self.response_callbacks = kwargs[ - "response_callbacks" - ] = self.__class__.RESPONSE_CALLBACKS.copy() + if read_from_replicas: + # Call our on_connect function to configure READONLY mode + kwargs["redis_connect_func"] = self.on_connect + + kwargs["response_callbacks"] = self.__class__.RESPONSE_CALLBACKS.copy() + self.connection_kwargs = kwargs + + if startup_nodes: + passed_nodes = [] + for node in startup_nodes: + passed_nodes.append( + ClusterNode(node.host, node.port, **self.connection_kwargs) + ) + startup_nodes = passed_nodes + else: + startup_nodes = [] if host and 
port: startup_nodes.append(ClusterNode(host, port, **self.connection_kwargs)) - self.nodes_manager = NodesManager( - startup_nodes=startup_nodes, - require_full_coverage=require_full_coverage, - **self.connection_kwargs, - ) - self.encoder = Encoder( - kwargs.get("encoding", "utf-8"), - kwargs.get("encoding_errors", "strict"), - kwargs.get("decode_responses", False), - ) - self.cluster_error_retry_attempts = cluster_error_retry_attempts + self.nodes_manager = NodesManager(startup_nodes, require_full_coverage, kwargs) + self.encoder = Encoder(encoding, encoding_errors, decode_responses) self.read_from_replicas = read_from_replicas self.reinitialize_steps = reinitialize_steps + self.cluster_error_retry_attempts = cluster_error_retry_attempts + self.connection_error_retry_attempts = connection_error_retry_attempts self.reinitialize_counter = 0 self.commands_parser = CommandsParser() self.node_flags = self.__class__.NODE_FLAGS.copy() self.command_flags = self.__class__.COMMAND_FLAGS.copy() + self.response_callbacks = kwargs["response_callbacks"] self.result_callbacks = self.__class__.RESULT_CALLBACKS.copy() self.result_callbacks[ "CLUSTER SLOTS" ] = lambda cmd, res, **kwargs: parse_cluster_slots( list(res.values())[0], **kwargs ) + self._initialize = True self._lock = asyncio.Lock() @@ -365,18 +383,16 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand ... async def on_connect(self, connection: Connection) -> None: - connection.set_parser(ClusterParser) await connection.on_connect() - if self.read_from_replicas: - # Sending READONLY command to server to configure connection as - # readonly. Since each cluster node may change its server type due - # to a failover, we should establish a READONLY connection - # regardless of the server type. If this is a primary connection, - # READONLY would not affect executing write commands. - await connection.send_command("READONLY") - if str_if_bytes(await connection.read_response_without_lock()) != "OK": - raise ConnectionError("READONLY command failed") + # Sending READONLY command to server to configure connection as + # readonly. Since each cluster node may change its server type due + # to a failover, we should establish a READONLY connection + # regardless of the server type. If this is a primary connection, + # READONLY would not affect executing write commands. + await connection.send_command("READONLY") + if str_if_bytes(await connection.read_response_without_lock()) != "OK": + raise ConnectionError("READONLY command failed") def get_nodes(self) -> List["ClusterNode"]: """Get all nodes of the cluster.""" @@ -436,12 +452,12 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand slot_cache = self.nodes_manager.slots_cache.get(slot) if not slot_cache: raise SlotNotCoveredError(f'Slot "{slot}" is not covered by the cluster.') - if replica and len(self.nodes_manager.slots_cache[slot]) < 2: - return None - elif replica: + + if replica: + if len(self.nodes_manager.slots_cache[slot]) < 2: + return None node_idx = 1 else: - # primary node_idx = 0 return slot_cache[node_idx] @@ -638,14 +654,14 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand command, dict(zip(keys, values)), **kwargs ) return dict(zip(keys, values)) - except BaseException as e: + except Exception as e: if type(e) in self.__class__.ERRORS_ALLOW_RETRY: # The nodes and slots cache were reinitialized. # Try again with the new cluster setup. exception = e else: # All other errors should be raised. 
- raise e + raise # If it fails the configured number of times then raise exception back # to caller of this method @@ -678,19 +694,30 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand return await target_node.execute_command(*args, **kwargs) except BusyLoadingError: raise - except (ConnectionError, TimeoutError): - # Give the node 0.25 seconds to get back up and retry again - # with same node and configuration. After 5 attempts then try - # to reinitialize the cluster and see if the nodes - # configuration has changed or not + except (ConnectionError, TimeoutError) as e: + # Give the node 0.25 seconds to get back up and retry again with the + # same node and configuration. After the defined number of attempts, try + # to reinitialize the cluster and try again. connection_error_retry_counter += 1 - if connection_error_retry_counter < 5: + if ( + connection_error_retry_counter + < self.connection_error_retry_attempts + ): await asyncio.sleep(0.25) else: + if isinstance(e, MaxConnectionsError): + raise # Hard force of reinitialize of the node/slots setup # and try again with the new setup await self.close() raise + except ClusterDownError: + # ClusterDownError can occur during a failover and to get + # self-healed, we will try to reinitialize the cluster layout + # and retry executing the command + await self.close() + await asyncio.sleep(0.25) + raise except MovedError as e: # First, we will try to patch the slots/nodes cache with the # redirected node output and try again. If MovedError exceeds @@ -711,19 +738,12 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand else: self.nodes_manager._moved_exception = e moved = True - except TryAgainError: - if ttl < self.RedisClusterRequestTTL / 2: - await asyncio.sleep(0.05) except AskError as e: redirect_addr = get_node_name(host=e.host, port=e.port) asking = True - except ClusterDownError: - # ClusterDownError can occur during a failover and to get - # self-healed, we will try to reinitialize the cluster layout - # and retry executing the command - await asyncio.sleep(0.25) - await self.close() - raise + except TryAgainError: + if ttl < self.RedisClusterRequestTTL / 2: + await asyncio.sleep(0.05) raise ClusterError("TTL exhausted.") @@ -755,7 +775,6 @@ class ClusterNode: """ __slots__ = ( - "_command_stack", "_connections", "_free", "connection_class", @@ -771,8 +790,9 @@ class ClusterNode: def __init__( self, host: str, - port: int, + port: Union[str, int], server_type: Optional[str] = None, + *, max_connections: int = 2**31, connection_class: Type[Connection] = Connection, **connection_kwargs: Any, @@ -790,13 +810,10 @@ class ClusterNode: self.max_connections = max_connections self.connection_class = connection_class self.connection_kwargs = connection_kwargs - self.response_callbacks = connection_kwargs.pop( - "response_callbacks", RedisCluster.RESPONSE_CALLBACKS - ) + self.response_callbacks = connection_kwargs.pop("response_callbacks", {}) self._connections: List[Connection] = [] self._free: Deque[Connection] = collections.deque(maxlen=self.max_connections) - self._command_stack: List["PipelineCommand"] = [] def __repr__(self) -> str: return ( @@ -836,21 +853,15 @@ class ClusterNode: raise exc def acquire_connection(self) -> Connection: - if self._free: - for _ in range(len(self._free)): - connection = self._free.popleft() - if connection.is_connected: - return connection - self._free.append(connection) - + try: return self._free.popleft() + except IndexError: + if 
len(self._connections) < self.max_connections: + connection = self.connection_class(**self.connection_kwargs) + self._connections.append(connection) + return connection - if len(self._connections) < self.max_connections: - connection = self.connection_class(**self.connection_kwargs) - self._connections.append(connection) - return connection - - raise ConnectionError("Too many connections") + raise MaxConnectionsError() async def parse_response( self, connection: Connection, command: str, **kwargs: Any @@ -887,18 +898,18 @@ class ClusterNode: # Release connection self._free.append(connection) - async def execute_pipeline(self) -> bool: + async def execute_pipeline(self, commands: List["PipelineCommand"]) -> bool: # Acquire connection connection = self.acquire_connection() # Execute command await connection.send_packed_command( - connection.pack_commands(cmd.args for cmd in self._command_stack), False + connection.pack_commands(cmd.args for cmd in commands), False ) # Read responses ret = False - for cmd in self._command_stack: + for cmd in commands: try: cmd.result = await self.parse_response( connection, cmd.args[0], **cmd.kwargs @@ -928,12 +939,12 @@ class NodesManager: def __init__( self, startup_nodes: List["ClusterNode"], - require_full_coverage: bool = False, - **kwargs: Any, + require_full_coverage: bool, + connection_kwargs: Dict[str, Any], ) -> None: self.startup_nodes = {node.name: node for node in startup_nodes} self.require_full_coverage = require_full_coverage - self.connection_kwargs = kwargs + self.connection_kwargs = connection_kwargs self.default_node: "ClusterNode" = None self.nodes_cache: Dict[str, "ClusterNode"] = {} @@ -1052,6 +1063,7 @@ class NodesManager: disagreements = [] startup_nodes_reachable = False fully_covered = False + exception = None for startup_node in self.startup_nodes.values(): try: # Make sure cluster mode is enabled on this node @@ -1063,7 +1075,8 @@ class NodesManager: ) cluster_slots = await startup_node.execute_command("CLUSTER SLOTS") startup_nodes_reachable = True - except (ConnectionError, TimeoutError): + except (ConnectionError, TimeoutError) as e: + exception = e continue except ResponseError as e: # Isn't a cluster connection, so it won't parse these @@ -1164,7 +1177,7 @@ class NodesManager: raise RedisClusterException( "Redis Cluster cannot be connected. Please provide at least " "one reachable node. " - ) + ) from exception # Check if the slots are not fully covered if not fully_covered and self.require_full_coverage: @@ -1329,7 +1342,7 @@ class ClusterPipeline(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterComm await asyncio.sleep(0.25) else: # All other errors should be raised. 
- raise e + raise # If it fails the configured number of times then raise an exception raise exception @@ -1365,12 +1378,14 @@ class ClusterPipeline(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterComm node = target_nodes[0] if node.name not in nodes: - nodes[node.name] = node - node._command_stack = [] - node._command_stack.append(cmd) + nodes[node.name] = (node, []) + nodes[node.name][1].append(cmd) errors = await asyncio.gather( - *(asyncio.ensure_future(node.execute_pipeline()) for node in nodes.values()) + *( + asyncio.ensure_future(node[0].execute_pipeline(node[1])) + for node in nodes.values() + ) ) if any(errors): diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index 35536fc..7f9a0c7 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -685,7 +685,7 @@ class Connection: @property def is_connected(self): - return self._reader and self._writer + return self._reader is not None and self._writer is not None def register_connect_callback(self, callback): self._connect_callbacks.append(weakref.WeakMethod(callback)) @@ -767,7 +767,16 @@ class Connection: def _error_message(self, exception): # args for socket.error can either be (errno, "message") # or just "message" - if len(exception.args) == 1: + if not exception.args: + # asyncio has a bug where on Connection reset by peer, the + # exception is not instanciated, so args is empty. This is the + # workaround. + # See: https://github.com/redis/redis-py/issues/2237 + # See: https://github.com/python/cpython/issues/94061 + return ( + f"Error connecting to {self.host}:{self.port}. Connection reset by peer" + ) + elif len(exception.args) == 1: return f"Error connecting to {self.host}:{self.port}. {exception.args[0]}." else: return ( diff --git a/redis/cluster.py b/redis/cluster.py index 6034e96..5f3e303 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -6,7 +6,7 @@ import sys import threading import time from collections import OrderedDict -from typing import Any, Callable, Dict, Tuple +from typing import Any, Callable, Dict, Tuple, Union from redis.client import CaseInsensitiveDict, PubSub, Redis, parse_scan from redis.commands import READ_COMMANDS, CommandsParser, RedisClusterCommands @@ -42,7 +42,7 @@ from redis.utils import ( log = logging.getLogger(__name__) -def get_node_name(host: str, port: int) -> str: +def get_node_name(host: str, port: Union[str, int]) -> str: return f"{host}:{port}" diff --git a/redis/commands/bf/__init__.py b/redis/commands/bf/__init__.py index d62d8a0..a448def 100644 --- a/redis/commands/bf/__init__.py +++ b/redis/commands/bf/__init__.py @@ -165,11 +165,16 @@ class TDigestBloom(TDigestCommands, AbstractBloom): # TDIGEST_RESET: bool_ok, # TDIGEST_ADD: spaceHolder, # TDIGEST_MERGE: spaceHolder, - TDIGEST_CDF: float, - TDIGEST_QUANTILE: float, + TDIGEST_CDF: parse_to_list, + TDIGEST_QUANTILE: parse_to_list, TDIGEST_MIN: float, TDIGEST_MAX: float, + TDIGEST_TRIMMED_MEAN: float, TDIGEST_INFO: TDigestInfo, + TDIGEST_RANK: parse_to_list, + TDIGEST_REVRANK: parse_to_list, + TDIGEST_BYRANK: parse_to_list, + TDIGEST_BYREVRANK: parse_to_list, } self.client = client diff --git a/redis/commands/bf/commands.py b/redis/commands/bf/commands.py index baf0130..f6bf281 100644 --- a/redis/commands/bf/commands.py +++ b/redis/commands/bf/commands.py @@ -1,6 +1,6 @@ from redis.client import NEVER_DECODE from redis.exceptions import ModuleError -from redis.utils import HIREDIS_AVAILABLE +from redis.utils import HIREDIS_AVAILABLE, deprecated_function BF_RESERVE = "BF.RESERVE" 
BF_ADD = "BF.ADD" @@ -49,6 +49,11 @@ TDIGEST_QUANTILE = "TDIGEST.QUANTILE" TDIGEST_MIN = "TDIGEST.MIN" TDIGEST_MAX = "TDIGEST.MAX" TDIGEST_INFO = "TDIGEST.INFO" +TDIGEST_TRIMMED_MEAN = "TDIGEST.TRIMMED_MEAN" +TDIGEST_RANK = "TDIGEST.RANK" +TDIGEST_REVRANK = "TDIGEST.REVRANK" +TDIGEST_BYRANK = "TDIGEST.BYRANK" +TDIGEST_BYREVRANK = "TDIGEST.BYREVRANK" class BFCommands: @@ -67,6 +72,8 @@ class BFCommands: self.append_no_scale(params, noScale) return self.execute_command(BF_RESERVE, *params) + reserve = create + def add(self, key, item): """ Add to a Bloom Filter `key` an `item`. @@ -176,6 +183,8 @@ class CFCommands: self.append_max_iterations(params, max_iterations) return self.execute_command(CF_RESERVE, *params) + reserve = create + def add(self, key, item): """ Add an `item` to a Cuckoo Filter `key`. @@ -316,6 +325,7 @@ class TOPKCommands: """ # noqa return self.execute_command(TOPK_QUERY, key, *items) + @deprecated_function(version="4.4.0", reason="deprecated since redisbloom 2.4.0") def count(self, key, *items): """ Return count for one `item` or more from `key`. @@ -344,12 +354,12 @@ class TOPKCommands: class TDigestCommands: - def create(self, key, compression): + def create(self, key, compression=100): """ Allocate the memory and initialize the t-digest. For more information see `TDIGEST.CREATE `_. """ # noqa - return self.execute_command(TDIGEST_CREATE, key, compression) + return self.execute_command(TDIGEST_CREATE, key, "COMPRESSION", compression) def reset(self, key): """ @@ -358,26 +368,30 @@ class TDigestCommands: """ # noqa return self.execute_command(TDIGEST_RESET, key) - def add(self, key, values, weights): + def add(self, key, values): """ - Add one or more samples (value with weight) to a sketch `key`. - Both `values` and `weights` are lists. - For more information see `TDIGEST.ADD `_. + Adds one or more observations to a t-digest sketch `key`. - Example: - - >>> tdigestadd('A', [1500.0], [1.0]) + For more information see `TDIGEST.ADD `_. """ # noqa - params = [key] - self.append_values_and_weights(params, values, weights) - return self.execute_command(TDIGEST_ADD, *params) + return self.execute_command(TDIGEST_ADD, key, *values) - def merge(self, toKey, fromKey): + def merge(self, destination_key, num_keys, *keys, compression=None, override=False): """ - Merge all of the values from 'fromKey' to 'toKey' sketch. + Merges all of the values from `keys` to 'destination-key' sketch. + It is mandatory to provide the `num_keys` before passing the input keys and + the other (optional) arguments. + If `destination_key` already exists its values are merged with the input keys. + If you wish to override the destination key contents use the `OVERRIDE` parameter. + For more information see `TDIGEST.MERGE `_. """ # noqa - return self.execute_command(TDIGEST_MERGE, toKey, fromKey) + params = [destination_key, num_keys, *keys] + if compression is not None: + params.extend(["COMPRESSION", compression]) + if override: + params.append("OVERRIDE") + return self.execute_command(TDIGEST_MERGE, *params) def min(self, key): """ @@ -393,20 +407,21 @@ class TDigestCommands: """ # noqa return self.execute_command(TDIGEST_MAX, key) - def quantile(self, key, quantile): + def quantile(self, key, quantile, *quantiles): """ - Return double value estimate of the cutoff such that a specified fraction of the data - added to this TDigest would be less than or equal to the cutoff. 
+ Returns estimates of one or more cutoffs such that a specified fraction of the + observations added to this t-digest would be less than or equal to each of the + specified cutoffs. (Multiple quantiles can be returned with one call) For more information see `TDIGEST.QUANTILE `_. """ # noqa - return self.execute_command(TDIGEST_QUANTILE, key, quantile) + return self.execute_command(TDIGEST_QUANTILE, key, quantile, *quantiles) - def cdf(self, key, value): + def cdf(self, key, value, *values): """ Return double fraction of all points added which are <= value. For more information see `TDIGEST.CDF `_. """ # noqa - return self.execute_command(TDIGEST_CDF, key, value) + return self.execute_command(TDIGEST_CDF, key, value, *values) def info(self, key): """ @@ -416,6 +431,50 @@ class TDigestCommands: """ # noqa return self.execute_command(TDIGEST_INFO, key) + def trimmed_mean(self, key, low_cut_quantile, high_cut_quantile): + """ + Return mean value from the sketch, excluding observation values outside + the low and high cutoff quantiles. + For more information see `TDIGEST.TRIMMED_MEAN `_. + """ # noqa + return self.execute_command( + TDIGEST_TRIMMED_MEAN, key, low_cut_quantile, high_cut_quantile + ) + + def rank(self, key, value, *values): + """ + Retrieve the estimated rank of value (the number of observations in the sketch + that are smaller than value + half the number of observations that are equal to value). + + For more information see `TDIGEST.RANK `_. + """ # noqa + return self.execute_command(TDIGEST_RANK, key, value, *values) + + def revrank(self, key, value, *values): + """ + Retrieve the estimated rank of value (the number of observations in the sketch + that are larger than value + half the number of observations that are equal to value). + + For more information see `TDIGEST.REVRANK `_. + """ # noqa + return self.execute_command(TDIGEST_REVRANK, key, value, *values) + + def byrank(self, key, rank, *ranks): + """ + Retrieve an estimation of the value with the given rank. + + For more information see `TDIGEST.BY_RANK `_. + """ # noqa + return self.execute_command(TDIGEST_BYRANK, key, rank, *ranks) + + def byrevrank(self, key, rank, *ranks): + """ + Retrieve an estimation of the value with the given reverse rank. + + For more information see `TDIGEST.BY_REVRANK `_. 
+ """ # noqa + return self.execute_command(TDIGEST_BYREVRANK, key, rank, *ranks) + class CMSCommands: """Count-Min Sketch Commands""" diff --git a/redis/commands/bf/info.py b/redis/commands/bf/info.py index 24c5419..c526e6c 100644 --- a/redis/commands/bf/info.py +++ b/redis/commands/bf/info.py @@ -68,18 +68,20 @@ class TopKInfo(object): class TDigestInfo(object): compression = None capacity = None - mergedNodes = None - unmergedNodes = None - mergedWeight = None - unmergedWeight = None - totalCompressions = None + merged_nodes = None + unmerged_nodes = None + merged_weight = None + unmerged_weight = None + total_compressions = None + memory_usage = None def __init__(self, args): response = dict(zip(map(nativestr, args[::2]), args[1::2])) self.compression = response["Compression"] self.capacity = response["Capacity"] - self.mergedNodes = response["Merged nodes"] - self.unmergedNodes = response["Unmerged nodes"] - self.mergedWeight = response["Merged weight"] - self.unmergedWeight = response["Unmerged weight"] - self.totalCompressions = response["Total compressions"] + self.merged_nodes = response["Merged nodes"] + self.unmerged_nodes = response["Unmerged nodes"] + self.merged_weight = response["Merged weight"] + self.unmerged_weight = response["Unmerged weight"] + self.total_compressions = response["Total compressions"] + self.memory_usage = response["Memory usage"] diff --git a/redis/commands/cluster.py b/redis/commands/cluster.py index a1060d2..68b66fa 100644 --- a/redis/commands/cluster.py +++ b/redis/commands/cluster.py @@ -316,6 +316,25 @@ class AsyncClusterMultiKeyCommands(ClusterMultiKeyCommands): # Sum up the reply from each command return sum(await self._execute_pipeline_by_slot(command, slots_to_keys)) + async def _execute_pipeline_by_slot( + self, command: str, slots_to_args: Mapping[int, Iterable[EncodableT]] + ) -> List[Any]: + if self._initialize: + await self.initialize() + read_from_replicas = self.read_from_replicas and command in READ_COMMANDS + pipe = self.pipeline() + [ + pipe.execute_command( + command, + *slot_args, + target_nodes=[ + self.nodes_manager.get_node_from_slot(slot, read_from_replicas) + ], + ) + for slot, slot_args in slots_to_args.items() + ] + return await pipe.execute() + class ClusterManagementCommands(ManagementCommands): """ diff --git a/redis/commands/graph/__init__.py b/redis/commands/graph/__init__.py index 3736195..a882dd5 100644 --- a/redis/commands/graph/__init__.py +++ b/redis/commands/graph/__init__.py @@ -1,9 +1,13 @@ from ..helpers import quote_string, random_string, stringify_param_value -from .commands import GraphCommands +from .commands import AsyncGraphCommands, GraphCommands from .edge import Edge # noqa from .node import Node # noqa from .path import Path # noqa +DB_LABELS = "DB.LABELS" +DB_RAELATIONSHIPTYPES = "DB.RELATIONSHIPTYPES" +DB_PROPERTYKEYS = "DB.PROPERTYKEYS" + class Graph(GraphCommands): """ @@ -44,25 +48,19 @@ class Graph(GraphCommands): lbls = self.labels() # Unpack data. - self._labels = [None] * len(lbls) - for i, l in enumerate(lbls): - self._labels[i] = l[0] + self._labels = [l[0] for _, l in enumerate(lbls)] def _refresh_relations(self): rels = self.relationship_types() # Unpack data. - self._relationship_types = [None] * len(rels) - for i, r in enumerate(rels): - self._relationship_types[i] = r[0] + self._relationship_types = [r[0] for _, r in enumerate(rels)] def _refresh_attributes(self): props = self.property_keys() # Unpack data. 
- self._properties = [None] * len(props) - for i, p in enumerate(props): - self._properties[i] = p[0] + self._properties = [p[0] for _, p in enumerate(props)] def get_label(self, idx): """ @@ -108,12 +106,12 @@ class Graph(GraphCommands): The index of the property """ try: - propertie = self._properties[idx] + p = self._properties[idx] except IndexError: # Refresh properties. self._refresh_attributes() - propertie = self._properties[idx] - return propertie + p = self._properties[idx] + return p def add_node(self, node): """ @@ -133,6 +131,8 @@ class Graph(GraphCommands): self.edges.append(edge) def _build_params_header(self, params): + if params is None: + return "" if not isinstance(params, dict): raise TypeError("'params' must be a dict") # Header starts with "CYPHER" @@ -147,16 +147,109 @@ class Graph(GraphCommands): q = f"CALL {procedure}({','.join(args)})" y = kwagrs.get("y", None) - if y: - q += f" YIELD {','.join(y)}" + if y is not None: + q += f"YIELD {','.join(y)}" return self.query(q, read_only=read_only) def labels(self): - return self.call_procedure("db.labels", read_only=True).result_set + return self.call_procedure(DB_LABELS, read_only=True).result_set def relationship_types(self): - return self.call_procedure("db.relationshipTypes", read_only=True).result_set + return self.call_procedure(DB_RAELATIONSHIPTYPES, read_only=True).result_set def property_keys(self): - return self.call_procedure("db.propertyKeys", read_only=True).result_set + return self.call_procedure(DB_PROPERTYKEYS, read_only=True).result_set + + +class AsyncGraph(Graph, AsyncGraphCommands): + """Async version for Graph""" + + async def _refresh_labels(self): + lbls = await self.labels() + + # Unpack data. + self._labels = [l[0] for _, l in enumerate(lbls)] + + async def _refresh_attributes(self): + props = await self.property_keys() + + # Unpack data. + self._properties = [p[0] for _, p in enumerate(props)] + + async def _refresh_relations(self): + rels = await self.relationship_types() + + # Unpack data. + self._relationship_types = [r[0] for _, r in enumerate(rels)] + + async def get_label(self, idx): + """ + Returns a label by it's index + + Args: + + idx: + The index of the label + """ + try: + label = self._labels[idx] + except IndexError: + # Refresh labels. + await self._refresh_labels() + label = self._labels[idx] + return label + + async def get_property(self, idx): + """ + Returns a property by it's index + + Args: + + idx: + The index of the property + """ + try: + p = self._properties[idx] + except IndexError: + # Refresh properties. + await self._refresh_attributes() + p = self._properties[idx] + return p + + async def get_relation(self, idx): + """ + Returns a relationship type by it's index + + Args: + + idx: + The index of the relation + """ + try: + relationship_type = self._relationship_types[idx] + except IndexError: + # Refresh relationship types. 
+ await self._refresh_relations() + relationship_type = self._relationship_types[idx] + return relationship_type + + async def call_procedure(self, procedure, *args, read_only=False, **kwagrs): + args = [quote_string(arg) for arg in args] + q = f"CALL {procedure}({','.join(args)})" + + y = kwagrs.get("y", None) + if y is not None: + f"YIELD {','.join(y)}" + return await self.query(q, read_only=read_only) + + async def labels(self): + return ((await self.call_procedure(DB_LABELS, read_only=True))).result_set + + async def property_keys(self): + return (await self.call_procedure(DB_PROPERTYKEYS, read_only=True)).result_set + + async def relationship_types(self): + return ( + await self.call_procedure(DB_RAELATIONSHIPTYPES, read_only=True) + ).result_set diff --git a/redis/commands/graph/commands.py b/redis/commands/graph/commands.py index fe4224b..762ab42 100644 --- a/redis/commands/graph/commands.py +++ b/redis/commands/graph/commands.py @@ -3,7 +3,16 @@ from redis.exceptions import ResponseError from .exceptions import VersionMismatchException from .execution_plan import ExecutionPlan -from .query_result import QueryResult +from .query_result import AsyncQueryResult, QueryResult + +PROFILE_CMD = "GRAPH.PROFILE" +RO_QUERY_CMD = "GRAPH.RO_QUERY" +QUERY_CMD = "GRAPH.QUERY" +DELETE_CMD = "GRAPH.DELETE" +SLOWLOG_CMD = "GRAPH.SLOWLOG" +CONFIG_CMD = "GRAPH.CONFIG" +LIST_CMD = "GRAPH.LIST" +EXPLAIN_CMD = "GRAPH.EXPLAIN" class GraphCommands: @@ -52,33 +61,28 @@ class GraphCommands: query = q # handle query parameters - if params is not None: - query = self._build_params_header(params) + query + query = self._build_params_header(params) + query # construct query command # ask for compact result-set format # specify known graph version if profile: - cmd = "GRAPH.PROFILE" + cmd = PROFILE_CMD else: - cmd = "GRAPH.RO_QUERY" if read_only else "GRAPH.QUERY" + cmd = RO_QUERY_CMD if read_only else QUERY_CMD command = [cmd, self.name, query, "--compact"] # include timeout is specified - if timeout: - if not isinstance(timeout, int): - raise Exception("Timeout argument must be a positive integer") - command += ["timeout", timeout] + if isinstance(timeout, int): + command.extend(["timeout", timeout]) + elif timeout is not None: + raise Exception("Timeout argument must be a positive integer") # issue query try: response = self.execute_command(*command) return QueryResult(self, response, profile) except ResponseError as e: - if "wrong number of arguments" in str(e): - print( - "Note: RedisGraph Python requires server version 2.2.8 or above" - ) # noqa if "unknown command" in str(e) and read_only: # `GRAPH.RO_QUERY` is unavailable in older versions. return self.query(q, params, timeout, read_only=False) @@ -106,7 +110,7 @@ class GraphCommands: For more information see `DELETE `_. # noqa """ self._clear_schema() - return self.execute_command("GRAPH.DELETE", self.name) + return self.execute_command(DELETE_CMD, self.name) # declared here, to override the built in redis.db.flush() def flush(self): @@ -146,7 +150,7 @@ class GraphCommands: 3. The issued query. 4. The amount of time needed for its execution, in milliseconds. 
""" - return self.execute_command("GRAPH.SLOWLOG", self.name) + return self.execute_command(SLOWLOG_CMD, self.name) def config(self, name, value=None, set=False): """ @@ -170,14 +174,14 @@ class GraphCommands: raise DataError( "``value`` can be provided only when ``set`` is True" ) # noqa - return self.execute_command("GRAPH.CONFIG", *params) + return self.execute_command(CONFIG_CMD, *params) def list_keys(self): """ Lists all graph keys in the keyspace. For more information see `GRAPH.LIST `_. # noqa """ - return self.execute_command("GRAPH.LIST") + return self.execute_command(LIST_CMD) def execution_plan(self, query, params=None): """ @@ -188,10 +192,9 @@ class GraphCommands: query: the query that will be executed params: query parameters """ - if params is not None: - query = self._build_params_header(params) + query + query = self._build_params_header(params) + query - plan = self.execute_command("GRAPH.EXPLAIN", self.name, query) + plan = self.execute_command(EXPLAIN_CMD, self.name, query) if isinstance(plan[0], bytes): plan = [b.decode() for b in plan] return "\n".join(plan) @@ -206,8 +209,105 @@ class GraphCommands: query: the query that will be executed params: query parameters """ - if params is not None: - query = self._build_params_header(params) + query + query = self._build_params_header(params) + query + + plan = self.execute_command(EXPLAIN_CMD, self.name, query) + return ExecutionPlan(plan) + + +class AsyncGraphCommands(GraphCommands): + async def query(self, q, params=None, timeout=None, read_only=False, profile=False): + """ + Executes a query against the graph. + For more information see `GRAPH.QUERY `_. # noqa + + Args: + + q : str + The query. + params : dict + Query parameters. + timeout : int + Maximum runtime for read queries in milliseconds. + read_only : bool + Executes a readonly query if set to True. + profile : bool + Return details on results produced by and time + spent in each operation. + """ + + # maintain original 'q' + query = q + + # handle query parameters + query = self._build_params_header(params) + query + + # construct query command + # ask for compact result-set format + # specify known graph version + if profile: + cmd = PROFILE_CMD + else: + cmd = RO_QUERY_CMD if read_only else QUERY_CMD + command = [cmd, self.name, query, "--compact"] + + # include timeout is specified + if isinstance(timeout, int): + command.extend(["timeout", timeout]) + elif timeout is not None: + raise Exception("Timeout argument must be a positive integer") + + # issue query + try: + response = await self.execute_command(*command) + return await AsyncQueryResult().initialize(self, response, profile) + except ResponseError as e: + if "unknown command" in str(e) and read_only: + # `GRAPH.RO_QUERY` is unavailable in older versions. + return await self.query(q, params, timeout, read_only=False) + raise e + except VersionMismatchException as e: + # client view over the graph schema is out of sync + # set client version and refresh local schema + self.version = e.version + self._refresh_schema() + # re-issue query + return await self.query(q, params, timeout, read_only) + + async def execution_plan(self, query, params=None): + """ + Get the execution plan for given query, + GRAPH.EXPLAIN returns an array of operations. 
+ + Args: + query: the query that will be executed + params: query parameters + """ + query = self._build_params_header(params) + query - plan = self.execute_command("GRAPH.EXPLAIN", self.name, query) + plan = await self.execute_command(EXPLAIN_CMD, self.name, query) + if isinstance(plan[0], bytes): + plan = [b.decode() for b in plan] + return "\n".join(plan) + + async def explain(self, query, params=None): + """ + Get the execution plan for given query, + GRAPH.EXPLAIN returns ExecutionPlan object. + + Args: + query: the query that will be executed + params: query parameters + """ + query = self._build_params_header(params) + query + + plan = await self.execute_command(EXPLAIN_CMD, self.name, query) return ExecutionPlan(plan) + + async def flush(self): + """ + Commit the graph and reset the edges and the nodes to zero length. + """ + await self.commit() + self.nodes = {} + self.edges = [] diff --git a/redis/commands/graph/query_result.py b/redis/commands/graph/query_result.py index 644ac5a..7c7f58b 100644 --- a/redis/commands/graph/query_result.py +++ b/redis/commands/graph/query_result.py @@ -1,4 +1,6 @@ +import sys from collections import OrderedDict +from distutils.util import strtobool # from prettytable import PrettyTable from redis import ResponseError @@ -9,10 +11,12 @@ from .node import Node from .path import Path LABELS_ADDED = "Labels added" +LABELS_REMOVED = "Labels removed" NODES_CREATED = "Nodes created" NODES_DELETED = "Nodes deleted" RELATIONSHIPS_DELETED = "Relationships deleted" PROPERTIES_SET = "Properties set" +PROPERTIES_REMOVED = "Properties removed" RELATIONSHIPS_CREATED = "Relationships created" INDICES_CREATED = "Indices created" INDICES_DELETED = "Indices deleted" @@ -21,8 +25,10 @@ INTERNAL_EXECUTION_TIME = "internal execution time" STATS = [ LABELS_ADDED, + LABELS_REMOVED, NODES_CREATED, PROPERTIES_SET, + PROPERTIES_REMOVED, RELATIONSHIPS_CREATED, NODES_DELETED, RELATIONSHIPS_DELETED, @@ -86,6 +92,9 @@ class QueryResult: self.parse_results(response) def _check_for_errors(self, response): + """ + Check if the response contains an error. + """ if isinstance(response[0], ResponseError): error = response[0] if str(error) == "version mismatch": @@ -99,6 +108,9 @@ class QueryResult: raise response[-1] def parse_results(self, raw_result_set): + """ + Parse the query execution result returned from the server. + """ self.header = self.parse_header(raw_result_set) # Empty header. @@ -108,6 +120,9 @@ class QueryResult: self.result_set = self.parse_records(raw_result_set) def parse_statistics(self, raw_statistics): + """ + Parse the statistics returned in the response. + """ self.statistics = {} # decode statistics @@ -121,31 +136,31 @@ class QueryResult: self.statistics[s] = v def parse_header(self, raw_result_set): + """ + Parse the header of the result. + """ # An array of column name/column type pairs. header = raw_result_set[0] return header def parse_records(self, raw_result_set): - records = [] - result_set = raw_result_set[1] - for row in result_set: - record = [] - for idx, cell in enumerate(row): - if self.header[idx][0] == ResultSetColumnTypes.COLUMN_SCALAR: # noqa - record.append(self.parse_scalar(cell)) - elif self.header[idx][0] == ResultSetColumnTypes.COLUMN_NODE: # noqa - record.append(self.parse_node(cell)) - elif ( - self.header[idx][0] == ResultSetColumnTypes.COLUMN_RELATION - ): # noqa - record.append(self.parse_edge(cell)) - else: - print("Unknown column type.\n") - records.append(record) + """ + Parses the result set and returns a list of records. 
+ """ + records = [ + [ + self.parse_record_types[self.header[idx][0]](cell) + for idx, cell in enumerate(row) + ] + for row in raw_result_set[1] + ] return records def parse_entity_properties(self, props): + """ + Parse node / edge properties. + """ # [[name, value type, value] X N] properties = {} for prop in props: @@ -156,6 +171,9 @@ class QueryResult: return properties def parse_string(self, cell): + """ + Parse the cell as a string. + """ if isinstance(cell, bytes): return cell.decode() elif not isinstance(cell, str): @@ -164,6 +182,9 @@ class QueryResult: return cell def parse_node(self, cell): + """ + Parse the cell to a node. + """ # Node ID (integer), # [label string offset (integer)], # [[name, value type, value] X N] @@ -178,6 +199,9 @@ class QueryResult: return Node(node_id=node_id, label=labels, properties=properties) def parse_edge(self, cell): + """ + Parse the cell to an edge. + """ # Edge ID (integer), # reltype string offset (integer), # src node ID offset (integer), @@ -194,11 +218,17 @@ class QueryResult: ) def parse_path(self, cell): + """ + Parse the cell to a path. + """ nodes = self.parse_scalar(cell[0]) edges = self.parse_scalar(cell[1]) return Path(nodes, edges) def parse_map(self, cell): + """ + Parse the cell as a map. + """ m = OrderedDict() n_entries = len(cell) @@ -212,6 +242,9 @@ class QueryResult: return m def parse_point(self, cell): + """ + Parse the cell to point. + """ p = {} # A point is received an array of the form: [latitude, longitude] # It is returned as a map of the form: {"latitude": latitude, "longitude": longitude} # noqa @@ -219,94 +252,63 @@ class QueryResult: p["longitude"] = float(cell[1]) return p - def parse_scalar(self, cell): - scalar_type = int(cell[0]) - value = cell[1] - scalar = None - - if scalar_type == ResultSetScalarTypes.VALUE_NULL: - scalar = None - - elif scalar_type == ResultSetScalarTypes.VALUE_STRING: - scalar = self.parse_string(value) - - elif scalar_type == ResultSetScalarTypes.VALUE_INTEGER: - scalar = int(value) - - elif scalar_type == ResultSetScalarTypes.VALUE_BOOLEAN: - value = value.decode() if isinstance(value, bytes) else value - if value == "true": - scalar = True - elif value == "false": - scalar = False - else: - print("Unknown boolean type\n") - - elif scalar_type == ResultSetScalarTypes.VALUE_DOUBLE: - scalar = float(value) - - elif scalar_type == ResultSetScalarTypes.VALUE_ARRAY: - # array variable is introduced only for readability - scalar = array = value - for i in range(len(array)): - scalar[i] = self.parse_scalar(array[i]) + def parse_null(self, cell): + """ + Parse a null value. + """ + return None - elif scalar_type == ResultSetScalarTypes.VALUE_NODE: - scalar = self.parse_node(value) + def parse_integer(self, cell): + """ + Parse the integer value from the cell. + """ + return int(cell) - elif scalar_type == ResultSetScalarTypes.VALUE_EDGE: - scalar = self.parse_edge(value) + def parse_boolean(self, value): + """ + Parse the cell value as a boolean. + """ + value = value.decode() if isinstance(value, bytes) else value + try: + scalar = True if strtobool(value) else False + except ValueError: + sys.stderr.write("unknown boolean type\n") + scalar = None + return scalar - elif scalar_type == ResultSetScalarTypes.VALUE_PATH: - scalar = self.parse_path(value) + def parse_double(self, cell): + """ + Parse the cell as a double. 
+ """ + return float(cell) - elif scalar_type == ResultSetScalarTypes.VALUE_MAP: - scalar = self.parse_map(value) + def parse_array(self, value): + """ + Parse an array of values. + """ + scalar = [self.parse_scalar(value[i]) for i in range(len(value))] + return scalar - elif scalar_type == ResultSetScalarTypes.VALUE_POINT: - scalar = self.parse_point(value) + def parse_unknown(self, cell): + """ + Parse a cell of unknown type. + """ + sys.stderr.write("Unknown type\n") + return None - elif scalar_type == ResultSetScalarTypes.VALUE_UNKNOWN: - print("Unknown scalar type\n") + def parse_scalar(self, cell): + """ + Parse a scalar value from a cell in the result set. + """ + scalar_type = int(cell[0]) + value = cell[1] + scalar = self.parse_scalar_types[scalar_type](value) return scalar def parse_profile(self, response): self.result_set = [x[0 : x.index(",")].strip() for x in response] - # """Prints the data from the query response: - # 1. First row result_set contains the columns names. - # Thus the first row in PrettyTable will contain the - # columns. - # 2. The row after that will contain the data returned, - # or 'No Data returned' if there is none. - # 3. Prints the statistics of the query. - # """ - - # def pretty_print(self): - # if not self.is_empty(): - # header = [col[1] for col in self.header] - # tbl = PrettyTable(header) - - # for row in self.result_set: - # record = [] - # for idx, cell in enumerate(row): - # if type(cell) is Node: - # record.append(cell.to_string()) - # elif type(cell) is Edge: - # record.append(cell.to_string()) - # else: - # record.append(cell) - # tbl.add_row(record) - - # if len(self.result_set) == 0: - # tbl.add_row(['No data returned.']) - - # print(str(tbl) + '\n') - - # for stat in self.statistics: - # print("%s %s" % (stat, self.statistics[stat])) - def is_empty(self): return len(self.result_set) == 0 @@ -323,40 +325,249 @@ class QueryResult: @property def labels_added(self): + """Returns the number of labels added in the query""" return self._get_stat(LABELS_ADDED) + @property + def labels_removed(self): + """Returns the number of labels removed in the query""" + return self._get_stat(LABELS_REMOVED) + @property def nodes_created(self): + """Returns the number of nodes created in the query""" return self._get_stat(NODES_CREATED) @property def nodes_deleted(self): + """Returns the number of nodes deleted in the query""" return self._get_stat(NODES_DELETED) @property def properties_set(self): + """Returns the number of properties set in the query""" return self._get_stat(PROPERTIES_SET) + @property + def properties_removed(self): + """Returns the number of properties removed in the query""" + return self._get_stat(PROPERTIES_REMOVED) + @property def relationships_created(self): + """Returns the number of relationships created in the query""" return self._get_stat(RELATIONSHIPS_CREATED) @property def relationships_deleted(self): + """Returns the number of relationships deleted in the query""" return self._get_stat(RELATIONSHIPS_DELETED) @property def indices_created(self): + """Returns the number of indices created in the query""" return self._get_stat(INDICES_CREATED) @property def indices_deleted(self): + """Returns the number of indices deleted in the query""" return self._get_stat(INDICES_DELETED) @property def cached_execution(self): + """Returns whether or not the query execution plan was cached""" return self._get_stat(CACHED_EXECUTION) == 1 @property def run_time_ms(self): + """Returns the server execution time of the query""" return 
self._get_stat(INTERNAL_EXECUTION_TIME) + + @property + def parse_scalar_types(self): + return { + ResultSetScalarTypes.VALUE_NULL: self.parse_null, + ResultSetScalarTypes.VALUE_STRING: self.parse_string, + ResultSetScalarTypes.VALUE_INTEGER: self.parse_integer, + ResultSetScalarTypes.VALUE_BOOLEAN: self.parse_boolean, + ResultSetScalarTypes.VALUE_DOUBLE: self.parse_double, + ResultSetScalarTypes.VALUE_ARRAY: self.parse_array, + ResultSetScalarTypes.VALUE_NODE: self.parse_node, + ResultSetScalarTypes.VALUE_EDGE: self.parse_edge, + ResultSetScalarTypes.VALUE_PATH: self.parse_path, + ResultSetScalarTypes.VALUE_MAP: self.parse_map, + ResultSetScalarTypes.VALUE_POINT: self.parse_point, + ResultSetScalarTypes.VALUE_UNKNOWN: self.parse_unknown, + } + + @property + def parse_record_types(self): + return { + ResultSetColumnTypes.COLUMN_SCALAR: self.parse_scalar, + ResultSetColumnTypes.COLUMN_NODE: self.parse_node, + ResultSetColumnTypes.COLUMN_RELATION: self.parse_edge, + ResultSetColumnTypes.COLUMN_UNKNOWN: self.parse_unknown, + } + + +class AsyncQueryResult(QueryResult): + """ + Async version for the QueryResult class - a class that + represents a result of the query operation. + """ + + def __init__(self): + """ + To init the class you must call self.initialize() + """ + pass + + async def initialize(self, graph, response, profile=False): + """ + Initializes the class. + Args: + + graph: + The graph on which the query was executed. + response: + The response from the server. + profile: + A boolean indicating if the query command was "GRAPH.PROFILE" + """ + self.graph = graph + self.header = [] + self.result_set = [] + + # in case of an error an exception will be raised + self._check_for_errors(response) + + if len(response) == 1: + self.parse_statistics(response[0]) + elif profile: + self.parse_profile(response) + else: + # start by parsing statistics, matches the one we have + self.parse_statistics(response[-1]) # Last element. + await self.parse_results(response) + + return self + + async def parse_node(self, cell): + """ + Parses a node from the cell. + """ + # Node ID (integer), + # [label string offset (integer)], + # [[name, value type, value] X N] + + labels = None + if len(cell[1]) > 0: + labels = [] + for inner_label in cell[1]: + labels.append(await self.graph.get_label(inner_label)) + properties = await self.parse_entity_properties(cell[2]) + node_id = int(cell[0]) + return Node(node_id=node_id, label=labels, properties=properties) + + async def parse_scalar(self, cell): + """ + Parses a scalar value from the server response. + """ + scalar_type = int(cell[0]) + value = cell[1] + try: + scalar = await self.parse_scalar_types[scalar_type](value) + except TypeError: + # Not all of the functions are async + scalar = self.parse_scalar_types[scalar_type](value) + + return scalar + + async def parse_records(self, raw_result_set): + """ + Parses the result set and returns a list of records. + """ + records = [] + for row in raw_result_set[1]: + record = [ + await self.parse_record_types[self.header[idx][0]](cell) + for idx, cell in enumerate(row) + ] + records.append(record) + + return records + + async def parse_results(self, raw_result_set): + """ + Parse the query execution result returned from the server. + """ + self.header = self.parse_header(raw_result_set) + + # Empty header. + if len(self.header) == 0: + return + + self.result_set = await self.parse_records(raw_result_set) + + async def parse_entity_properties(self, props): + """ + Parse node / edge properties. 
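AsyncQueryResult reuses the same handler tables, but only some handlers are coroutines (those that need graph metadata such as labels or property names), so its parse_scalar awaits the handler and falls back to a plain call when that raises TypeError. A purely illustrative way to express the same mixed sync/async dispatch with the standard library (an alternative, not the patch's code):

import asyncio
import inspect

async def call_handler(handler, value):
    # Await only if the handler actually returned an awaitable.
    result = handler(value)
    if inspect.isawaitable(result):
        result = await result
    return result

async def demo():
    async def double(x):                   # async handler
        return x * 2
    print(await call_handler(int, "7"))    # sync handler -> 7
    print(await call_handler(double, 3))   # async handler -> 6

asyncio.run(demo())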
+ """ + # [[name, value type, value] X N] + properties = {} + for prop in props: + prop_name = await self.graph.get_property(prop[0]) + prop_value = await self.parse_scalar(prop[1:]) + properties[prop_name] = prop_value + + return properties + + async def parse_edge(self, cell): + """ + Parse the cell to an edge. + """ + # Edge ID (integer), + # reltype string offset (integer), + # src node ID offset (integer), + # dest node ID offset (integer), + # [[name, value, value type] X N] + + edge_id = int(cell[0]) + relation = await self.graph.get_relation(cell[1]) + src_node_id = int(cell[2]) + dest_node_id = int(cell[3]) + properties = await self.parse_entity_properties(cell[4]) + return Edge( + src_node_id, relation, dest_node_id, edge_id=edge_id, properties=properties + ) + + async def parse_path(self, cell): + """ + Parse the cell to a path. + """ + nodes = await self.parse_scalar(cell[0]) + edges = await self.parse_scalar(cell[1]) + return Path(nodes, edges) + + async def parse_map(self, cell): + """ + Parse the cell to a map. + """ + m = OrderedDict() + n_entries = len(cell) + + # A map is an array of key value pairs. + # 1. key (string) + # 2. array: (value type, value) + for i in range(0, n_entries, 2): + key = self.parse_string(cell[i]) + m[key] = await self.parse_scalar(cell[i + 1]) + + return m + + async def parse_array(self, value): + """ + Parse array value. + """ + scalar = [await self.parse_scalar(value[i]) for i in range(len(value))] + return scalar diff --git a/redis/commands/json/commands.py b/redis/commands/json/commands.py index 9391c2a..7fd4039 100644 --- a/redis/commands/json/commands.py +++ b/redis/commands/json/commands.py @@ -2,9 +2,8 @@ import os from json import JSONDecodeError, loads from typing import Dict, List, Optional, Union -from deprecated import deprecated - from redis.exceptions import DataError +from redis.utils import deprecated_function from ._util import JsonType from .decoders import decode_dict_keys @@ -137,7 +136,7 @@ class JSONCommands: "JSON.NUMINCRBY", name, str(path), self._encode(number) ) - @deprecated(version="4.0.0", reason="deprecated since redisjson 1.0.0") + @deprecated_function(version="4.0.0", reason="deprecated since redisjson 1.0.0") def nummultby(self, name: str, path: str, number: int) -> str: """Multiply the numeric (integer or floating point) JSON value under ``path`` at key ``name`` with the provided ``number``. @@ -368,19 +367,19 @@ class JSONCommands: pieces.append(str(path)) return self.execute_command("JSON.DEBUG", *pieces) - @deprecated( + @deprecated_function( version="4.0.0", reason="redisjson-py supported this, call get directly." ) def jsonget(self, *args, **kwargs): return self.get(*args, **kwargs) - @deprecated( + @deprecated_function( version="4.0.0", reason="redisjson-py supported this, call get directly." ) def jsonmget(self, *args, **kwargs): return self.mget(*args, **kwargs) - @deprecated( + @deprecated_function( version="4.0.0", reason="redisjson-py supported this, call get directly." 
) def jsonset(self, *args, **kwargs): diff --git a/redis/commands/search/__init__.py b/redis/commands/search/__init__.py index 923711b..70e9c27 100644 --- a/redis/commands/search/__init__.py +++ b/redis/commands/search/__init__.py @@ -167,5 +167,5 @@ class Pipeline(SearchCommands, redis.client.Pipeline): """Pipeline for the module.""" -class AsyncPipeline(AsyncSearchCommands, AsyncioPipeline): +class AsyncPipeline(AsyncSearchCommands, AsyncioPipeline, Pipeline): """AsyncPipeline for the module.""" diff --git a/redis/commands/search/aggregation.py b/redis/commands/search/aggregation.py index 061e69c..a171fa1 100644 --- a/redis/commands/search/aggregation.py +++ b/redis/commands/search/aggregation.py @@ -104,7 +104,6 @@ class AggregateRequest: self._aggregateplan = [] self._loadfields = [] self._loadall = False - self._limit = Limit() self._max = 0 self._with_schema = False self._verbatim = False @@ -211,7 +210,8 @@ class AggregateRequest: `sort_by()` instead. """ - self._limit = Limit(offset, num) + _limit = Limit(offset, num) + self._aggregateplan.extend(_limit.build_args()) return self def sort_by(self, *fields, **kwargs): @@ -323,8 +323,6 @@ class AggregateRequest: ret.extend(self._aggregateplan) - ret += self._limit.build_args() - return ret diff --git a/redis/commands/search/commands.py b/redis/commands/search/commands.py index 0121436..f02805e 100644 --- a/redis/commands/search/commands.py +++ b/redis/commands/search/commands.py @@ -3,6 +3,7 @@ import time from typing import Dict, Optional, Union from redis.client import Pipeline +from redis.utils import deprecated_function from ..helpers import parse_to_dict from ._util import to_string @@ -20,6 +21,7 @@ SEARCH_CMD = "FT.SEARCH" ADD_CMD = "FT.ADD" ADDHASH_CMD = "FT.ADDHASH" DROP_CMD = "FT.DROP" +DROPINDEX_CMD = "FT.DROPINDEX" EXPLAIN_CMD = "FT.EXPLAIN" EXPLAINCLI_CMD = "FT.EXPLAINCLI" DEL_CMD = "FT.DEL" @@ -170,8 +172,8 @@ class SearchCommands: For more information see `FT.DROPINDEX `_. """ # noqa - keep_str = "" if delete_documents else "KEEPDOCS" - return self.execute_command(DROP_CMD, self.index_name, keep_str) + delete_str = "DD" if delete_documents else "" + return self.execute_command(DROPINDEX_CMD, self.index_name, delete_str) def _add_document( self, @@ -235,6 +237,9 @@ class SearchCommands: return self.execute_command(*args) + @deprecated_function( + version="2.0.0", reason="deprecated since redisearch 2.0, call hset instead" + ) def add_document( self, doc_id, @@ -288,6 +293,9 @@ class SearchCommands: **fields, ) + @deprecated_function( + version="2.0.0", reason="deprecated since redisearch 2.0, call hset instead" + ) def add_document_hash(self, doc_id, score=1.0, language=None, replace=False): """ Add a hash document to the index. 
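With the aggregation change above, LIMIT is written into the aggregate plan at the point where .limit() is called instead of always being appended at the end, so its position relative to GROUPBY/SORTBY now matters. A hedged sketch; the index name "idx" and the field names are illustrative and assume a RediSearch-enabled server:

import redis
from redis.commands.search import reducers
from redis.commands.search.aggregation import AggregateRequest

r = redis.Redis()
req = (
    AggregateRequest("*")
    .group_by("@category", reducers.count().alias("count"))
    .sort_by("@count")
    .limit(0, 10)   # emitted here, after SORTBY, limiting the sorted output
)
res = r.ft("idx").aggregate(req)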
diff --git a/redis/commands/search/field.py b/redis/commands/search/field.py index 89ed973..6f31ce1 100644 --- a/redis/commands/search/field.py +++ b/redis/commands/search/field.py @@ -64,6 +64,7 @@ class TextField(Field): weight: float = 1.0, no_stem: bool = False, phonetic_matcher: str = None, + withsuffixtrie: bool = False, **kwargs, ): Field.__init__(self, name, args=[Field.TEXT, Field.WEIGHT, weight], **kwargs) @@ -78,6 +79,8 @@ class TextField(Field): ]: Field.append_arg(self, self.PHONETIC) Field.append_arg(self, phonetic_matcher) + if withsuffixtrie: + Field.append_arg(self, "WITHSUFFIXTRIE") class NumericField(Field): @@ -108,11 +111,18 @@ class TagField(Field): CASESENSITIVE = "CASESENSITIVE" def __init__( - self, name: str, separator: str = ",", case_sensitive: bool = False, **kwargs + self, + name: str, + separator: str = ",", + case_sensitive: bool = False, + withsuffixtrie: bool = False, + **kwargs, ): args = [Field.TAG, self.SEPARATOR, separator] if case_sensitive: args.append(self.CASESENSITIVE) + if withsuffixtrie: + args.append("WITHSUFFIXTRIE") Field.__init__(self, name, args=args, **kwargs) diff --git a/redis/commands/search/querystring.py b/redis/commands/search/querystring.py index 1da0387..3ff1320 100644 --- a/redis/commands/search/querystring.py +++ b/redis/commands/search/querystring.py @@ -132,6 +132,9 @@ class GeoValue(Value): self.radius = radius self.unit = unit + def to_string(self): + return f"[{self.lon} {self.lat} {self.radius} {self.unit}]" + class Node: def __init__(self, *children, **kwparams): diff --git a/redis/commands/search/result.py b/redis/commands/search/result.py index 5f4aca6..451bf89 100644 --- a/redis/commands/search/result.py +++ b/redis/commands/search/result.py @@ -38,7 +38,7 @@ class Result: score = float(res[i + 1]) if with_scores else None fields = {} - if hascontent: + if hascontent and res[i + fields_offset] is not None: fields = ( dict( dict( diff --git a/redis/commands/timeseries/commands.py b/redis/commands/timeseries/commands.py index 3a30c24..13e3cdf 100644 --- a/redis/commands/timeseries/commands.py +++ b/redis/commands/timeseries/commands.py @@ -1,4 +1,7 @@ +from typing import Dict, List, Optional, Tuple, Union + from redis.exceptions import DataError +from redis.typing import KeyT, Number ADD_CMD = "TS.ADD" ALTER_CMD = "TS.ALTER" @@ -22,7 +25,15 @@ REVRANGE_CMD = "TS.REVRANGE" class TimeSeriesCommands: """RedisTimeSeries Commands.""" - def create(self, key, **kwargs): + def create( + self, + key: KeyT, + retention_msecs: Optional[int] = None, + uncompressed: Optional[bool] = False, + labels: Optional[Dict[str, str]] = None, + chunk_size: Optional[int] = None, + duplicate_policy: Optional[str] = None, + ): """ Create a new time-series. @@ -31,40 +42,26 @@ class TimeSeriesCommands: key: time-series key retention_msecs: - Maximum age for samples compared to last event time (in milliseconds). + Maximum age for samples compared to highest reported timestamp (in milliseconds). If None or 0 is passed then the series is not trimmed at all. uncompressed: - Since RedisTimeSeries v1.2, both timestamps and values are - compressed by default. - Adding this flag will keep data in an uncompressed form. - Compression not only saves - memory but usually improve performance due to lower number - of memory accesses. + Changes data storage from compressed (by default) to uncompressed labels: Set of label-value pairs that represent metadata labels of the key. 
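The field.py hunk above adds a withsuffixtrie flag to TextField and TagField, which appends WITHSUFFIXTRIE so suffix and contains queries on those fields can be served from a suffix trie. A hedged sketch; the index name, prefix, and field names are illustrative and RediSearch 2.6+ is assumed:

import redis
from redis.commands.search.field import TagField, TextField
from redis.commands.search.indexDefinition import IndexDefinition

r = redis.Redis()
r.ft("idx").create_index(
    (
        TextField("title", withsuffixtrie=True),
        TagField("tags", separator=";", withsuffixtrie=True),
    ),
    definition=IndexDefinition(prefix=["doc:"]),
)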
chunk_size: - Each time-serie uses chunks of memory of fixed size for - time series samples. - You can alter the default TSDB chunk size by passing the - chunk_size argument (in Bytes). + Memory size, in bytes, allocated for each data chunk. + Must be a multiple of 8 in the range [128 .. 1048576]. duplicate_policy: - Since RedisTimeSeries v1.4 you can specify the duplicate sample policy - ( Configure what to do on duplicate sample. ) + Policy for handling multiple samples with identical timestamps. Can be one of: - 'block': an error will occur for any out of order sample. - 'first': ignore the new value. - 'last': override with latest value. - 'min': only override if the value is lower than the existing value. - 'max': only override if the value is higher than the existing value. - When this is not set, the server-wide default will be used. - For more information: https://oss.redis.com/redistimeseries/commands/#tscreate + For more information: https://redis.io/commands/ts.create/ """ # noqa - retention_msecs = kwargs.get("retention_msecs", None) - uncompressed = kwargs.get("uncompressed", False) - labels = kwargs.get("labels", {}) - chunk_size = kwargs.get("chunk_size", None) - duplicate_policy = kwargs.get("duplicate_policy", None) params = [key] self._append_retention(params, retention_msecs) self._append_uncompressed(params, uncompressed) @@ -74,29 +71,62 @@ class TimeSeriesCommands: return self.execute_command(CREATE_CMD, *params) - def alter(self, key, **kwargs): + def alter( + self, + key: KeyT, + retention_msecs: Optional[int] = None, + labels: Optional[Dict[str, str]] = None, + chunk_size: Optional[int] = None, + duplicate_policy: Optional[str] = None, + ): """ - Update the retention, labels of an existing key. - For more information see + Update the retention, chunk size, duplicate policy, and labels of an existing + time series. + + Args: - The parameters are the same as TS.CREATE. + key: + time-series key + retention_msecs: + Maximum retention period, compared to maximal existing timestamp (in milliseconds). + If None or 0 is passed then the series is not trimmed at all. + labels: + Set of label-value pairs that represent metadata labels of the key. + chunk_size: + Memory size, in bytes, allocated for each data chunk. + Must be a multiple of 8 in the range [128 .. 1048576]. + duplicate_policy: + Policy for handling multiple samples with identical timestamps. + Can be one of: + - 'block': an error will occur for any out of order sample. + - 'first': ignore the new value. + - 'last': override with latest value. + - 'min': only override if the value is lower than the existing value. + - 'max': only override if the value is higher than the existing value. 
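A hedged usage sketch for the explicit keyword arguments TS.CREATE now accepts (previously collected through **kwargs); the key name, labels, and values are illustrative and RedisTimeSeries is assumed to be loaded:

import redis

r = redis.Redis()
r.ts().create(
    "sensor:1:temp",
    retention_msecs=86_400_000,       # keep one day of samples
    labels={"sensor": "1", "metric": "temp"},
    chunk_size=128,                   # bytes; a multiple of 8 in [128, 1048576]
    duplicate_policy="last",          # overwrite on identical timestamps
)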
- For more information: https://oss.redis.com/redistimeseries/commands/#tsalter + For more information: https://redis.io/commands/ts.alter/ """ # noqa - retention_msecs = kwargs.get("retention_msecs", None) - labels = kwargs.get("labels", {}) - duplicate_policy = kwargs.get("duplicate_policy", None) params = [key] self._append_retention(params, retention_msecs) + self._append_chunk_size(params, chunk_size) self._append_duplicate_policy(params, ALTER_CMD, duplicate_policy) self._append_labels(params, labels) return self.execute_command(ALTER_CMD, *params) - def add(self, key, timestamp, value, **kwargs): + def add( + self, + key: KeyT, + timestamp: Union[int, str], + value: Number, + retention_msecs: Optional[int] = None, + uncompressed: Optional[bool] = False, + labels: Optional[Dict[str, str]] = None, + chunk_size: Optional[int] = None, + duplicate_policy: Optional[str] = None, + ): """ - Append (or create and append) a new sample to the series. - For more information see + Append (or create and append) a new sample to a time series. Args: @@ -107,35 +137,26 @@ class TimeSeriesCommands: value: Numeric data value of the sample retention_msecs: - Maximum age for samples compared to last event time (in milliseconds). + Maximum retention period, compared to maximal existing timestamp (in milliseconds). If None or 0 is passed then the series is not trimmed at all. uncompressed: - Since RedisTimeSeries v1.2, both timestamps and values are compressed by default. - Adding this flag will keep data in an uncompressed form. Compression not only saves - memory but usually improve performance due to lower number of memory accesses. + Changes data storage from compressed (by default) to uncompressed labels: Set of label-value pairs that represent metadata labels of the key. chunk_size: - Each time-serie uses chunks of memory of fixed size for time series samples. - You can alter the default TSDB chunk size by passing the chunk_size argument (in Bytes). + Memory size, in bytes, allocated for each data chunk. + Must be a multiple of 8 in the range [128 .. 1048576]. duplicate_policy: - Since RedisTimeSeries v1.4 you can specify the duplicate sample policy - (Configure what to do on duplicate sample). + Policy for handling multiple samples with identical timestamps. Can be one of: - 'block': an error will occur for any out of order sample. - 'first': ignore the new value. - 'last': override with latest value. - 'min': only override if the value is lower than the existing value. - 'max': only override if the value is higher than the existing value. - When this is not set, the server-wide default will be used. - For more information: https://oss.redis.com/redistimeseries/master/commands/#tsadd + For more information: https://redis.io/commands/ts.add/ """ # noqa - retention_msecs = kwargs.get("retention_msecs", None) - uncompressed = kwargs.get("uncompressed", False) - labels = kwargs.get("labels", {}) - chunk_size = kwargs.get("chunk_size", None) - duplicate_policy = kwargs.get("duplicate_policy", None) params = [key, timestamp, value] self._append_retention(params, retention_msecs) self._append_uncompressed(params, uncompressed) @@ -145,28 +166,34 @@ class TimeSeriesCommands: return self.execute_command(ADD_CMD, *params) - def madd(self, ktv_tuples): + def madd(self, ktv_tuples: List[Tuple[KeyT, Union[int, str], Number]]): """ Append (or create and append) a new `value` to series `key` with `timestamp`. Expects a list of `tuples` as (`key`,`timestamp`, `value`). 
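TS.ADD and TS.MADD get the same treatment: explicit, typed parameters instead of **kwargs, and madd now flattens its (key, timestamp, value) tuples with extend. A hedged sketch with illustrative keys and values ("*" asks the server to timestamp the sample with its own clock):

import redis

r = redis.Redis()                     # assumes RedisTimeSeries is loaded
r.ts().add("temp:1", "*", 21.5, duplicate_policy="last")
r.ts().madd([("temp:2", 1000, 21.7), ("temp:3", 1000, 19.2)])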
Return value is an array with timestamps of insertions. - For more information: https://oss.redis.com/redistimeseries/master/commands/#tsmadd + For more information: https://redis.io/commands/ts.madd/ """ # noqa params = [] for ktv in ktv_tuples: - for item in ktv: - params.append(item) + params.extend(ktv) return self.execute_command(MADD_CMD, *params) - def incrby(self, key, value, **kwargs): + def incrby( + self, + key: KeyT, + value: Number, + timestamp: Optional[Union[int, str]] = None, + retention_msecs: Optional[int] = None, + uncompressed: Optional[bool] = False, + labels: Optional[Dict[str, str]] = None, + chunk_size: Optional[int] = None, + ): """ - Increment (or create an time-series and increment) the latest - sample's of a series. - This command can be used as a counter or gauge that automatically gets - history as a time series. + Increment (or create an time-series and increment) the latest sample's of a series. + This command can be used as a counter or gauge that automatically gets history as a time series. Args: @@ -175,27 +202,19 @@ class TimeSeriesCommands: value: Numeric data value of the sample timestamp: - Timestamp of the sample. None can be used for automatic timestamp (using the system clock). + Timestamp of the sample. * can be used for automatic timestamp (using the system clock). retention_msecs: Maximum age for samples compared to last event time (in milliseconds). If None or 0 is passed then the series is not trimmed at all. uncompressed: - Since RedisTimeSeries v1.2, both timestamps and values are compressed by default. - Adding this flag will keep data in an uncompressed form. Compression not only saves - memory but usually improve performance due to lower number of memory accesses. + Changes data storage from compressed (by default) to uncompressed labels: Set of label-value pairs that represent metadata labels of the key. chunk_size: - Each time-series uses chunks of memory of fixed size for time series samples. - You can alter the default TSDB chunk size by passing the chunk_size argument (in Bytes). + Memory size, in bytes, allocated for each data chunk. - For more information: https://oss.redis.com/redistimeseries/master/commands/#tsincrbytsdecrby + For more information: https://redis.io/commands/ts.incrby/ """ # noqa - timestamp = kwargs.get("timestamp", None) - retention_msecs = kwargs.get("retention_msecs", None) - uncompressed = kwargs.get("uncompressed", False) - labels = kwargs.get("labels", {}) - chunk_size = kwargs.get("chunk_size", None) params = [key, value] self._append_timestamp(params, timestamp) self._append_retention(params, retention_msecs) @@ -205,12 +224,19 @@ class TimeSeriesCommands: return self.execute_command(INCRBY_CMD, *params) - def decrby(self, key, value, **kwargs): + def decrby( + self, + key: KeyT, + value: Number, + timestamp: Optional[Union[int, str]] = None, + retention_msecs: Optional[int] = None, + uncompressed: Optional[bool] = False, + labels: Optional[Dict[str, str]] = None, + chunk_size: Optional[int] = None, + ): """ - Decrement (or create an time-series and decrement) the - latest sample's of a series. - This command can be used as a counter or gauge that - automatically gets history as a time series. + Decrement (or create an time-series and decrement) the latest sample's of a series. + This command can be used as a counter or gauge that automatically gets history as a time series. Args: @@ -219,31 +245,19 @@ class TimeSeriesCommands: value: Numeric data value of the sample timestamp: - Timestamp of the sample. 
None can be used for automatic - timestamp (using the system clock). + Timestamp of the sample. * can be used for automatic timestamp (using the system clock). retention_msecs: Maximum age for samples compared to last event time (in milliseconds). If None or 0 is passed then the series is not trimmed at all. uncompressed: - Since RedisTimeSeries v1.2, both timestamps and values are - compressed by default. - Adding this flag will keep data in an uncompressed form. - Compression not only saves - memory but usually improve performance due to lower number - of memory accesses. + Changes data storage from compressed (by default) to uncompressed labels: Set of label-value pairs that represent metadata labels of the key. chunk_size: - Each time-series uses chunks of memory of fixed size for time series samples. - You can alter the default TSDB chunk size by passing the chunk_size argument (in Bytes). + Memory size, in bytes, allocated for each data chunk. - For more information: https://oss.redis.com/redistimeseries/master/commands/#tsincrbytsdecrby + For more information: https://redis.io/commands/ts.decrby/ """ # noqa - timestamp = kwargs.get("timestamp", None) - retention_msecs = kwargs.get("retention_msecs", None) - uncompressed = kwargs.get("uncompressed", False) - labels = kwargs.get("labels", {}) - chunk_size = kwargs.get("chunk_size", None) params = [key, value] self._append_timestamp(params, timestamp) self._append_retention(params, retention_msecs) @@ -253,14 +267,9 @@ class TimeSeriesCommands: return self.execute_command(DECRBY_CMD, *params) - def delete(self, key, from_time, to_time): + def delete(self, key: KeyT, from_time: int, to_time: int): """ - Delete data points for a given timeseries and interval range - in the form of start and end delete timestamps. - The given timestamp interval is closed (inclusive), meaning start - and end data points will also be deleted. - Return the count for deleted items. - For more information see + Delete all samples between two timestamps for a given time series. Args: @@ -271,68 +280,98 @@ class TimeSeriesCommands: to_time: End timestamp for the range deletion. - For more information: https://oss.redis.com/redistimeseries/master/commands/#tsdel + For more information: https://redis.io/commands/ts.del/ """ # noqa return self.execute_command(DEL_CMD, key, from_time, to_time) - def createrule(self, source_key, dest_key, aggregation_type, bucket_size_msec): + def createrule( + self, + source_key: KeyT, + dest_key: KeyT, + aggregation_type: str, + bucket_size_msec: int, + align_timestamp: Optional[int] = None, + ): """ Create a compaction rule from values added to `source_key` into `dest_key`. - Aggregating for `bucket_size_msec` where an `aggregation_type` can be - [`avg`, `sum`, `min`, `max`, `range`, `count`, `first`, `last`, - `std.p`, `std.s`, `var.p`, `var.s`] - For more information: https://oss.redis.com/redistimeseries/master/commands/#tscreaterule + Args: + + source_key: + Key name for source time series + dest_key: + Key name for destination (compacted) time series + aggregation_type: + Aggregation type: One of the following: + [`avg`, `sum`, `min`, `max`, `range`, `count`, `first`, `last`, `std.p`, + `std.s`, `var.p`, `var.s`, `twa`] + bucket_size_msec: + Duration of each bucket, in milliseconds + align_timestamp: + Assure that there is a bucket that starts at exactly align_timestamp and + align all other buckets accordingly. 
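The createrule signature above now exposes align_timestamp, which pins one bucket boundary of the compaction to the given timestamp. A hedged sketch of an hourly average compaction; the key names are illustrative and RedisTimeSeries 1.8+ is assumed:

import redis

r = redis.Redis()
r.ts().create("temp:raw")
r.ts().create("temp:hourly")
r.ts().createrule(
    "temp:raw",
    "temp:hourly",
    aggregation_type="avg",
    bucket_size_msec=3_600_000,
    align_timestamp=0,
)
# r.ts().deleterule("temp:raw", "temp:hourly") would remove the rule again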
+ + For more information: https://redis.io/commands/ts.createrule/ """ # noqa params = [source_key, dest_key] self._append_aggregation(params, aggregation_type, bucket_size_msec) + if align_timestamp is not None: + params.append(align_timestamp) return self.execute_command(CREATERULE_CMD, *params) - def deleterule(self, source_key, dest_key): + def deleterule(self, source_key: KeyT, dest_key: KeyT): """ - Delete a compaction rule. - For more information see + Delete a compaction rule from `source_key` to `dest_key`.. - For more information: https://oss.redis.com/redistimeseries/master/commands/#tsdeleterule + For more information: https://redis.io/commands/ts.deleterule/ """ # noqa return self.execute_command(DELETERULE_CMD, source_key, dest_key) def __range_params( self, - key, - from_time, - to_time, - count, - aggregation_type, - bucket_size_msec, - filter_by_ts, - filter_by_min_value, - filter_by_max_value, - align, + key: KeyT, + from_time: Union[int, str], + to_time: Union[int, str], + count: Optional[int], + aggregation_type: Optional[str], + bucket_size_msec: Optional[int], + filter_by_ts: Optional[List[int]], + filter_by_min_value: Optional[int], + filter_by_max_value: Optional[int], + align: Optional[Union[int, str]], + latest: Optional[bool], + bucket_timestamp: Optional[str], + empty: Optional[bool], ): """Create TS.RANGE and TS.REVRANGE arguments.""" params = [key, from_time, to_time] + self._append_latest(params, latest) self._append_filer_by_ts(params, filter_by_ts) self._append_filer_by_value(params, filter_by_min_value, filter_by_max_value) self._append_count(params, count) self._append_align(params, align) self._append_aggregation(params, aggregation_type, bucket_size_msec) + self._append_bucket_timestamp(params, bucket_timestamp) + self._append_empty(params, empty) return params def range( self, - key, - from_time, - to_time, - count=None, - aggregation_type=None, - bucket_size_msec=0, - filter_by_ts=None, - filter_by_min_value=None, - filter_by_max_value=None, - align=None, + key: KeyT, + from_time: Union[int, str], + to_time: Union[int, str], + count: Optional[int] = None, + aggregation_type: Optional[str] = None, + bucket_size_msec: Optional[int] = 0, + filter_by_ts: Optional[List[int]] = None, + filter_by_min_value: Optional[int] = None, + filter_by_max_value: Optional[int] = None, + align: Optional[Union[int, str]] = None, + latest: Optional[bool] = False, + bucket_timestamp: Optional[str] = None, + empty: Optional[bool] = False, ): """ Query a range in forward direction for a specific time-serie. @@ -342,31 +381,34 @@ class TimeSeriesCommands: key: Key name for timeseries. from_time: - Start timestamp for the range query. - can be used to express - the minimum possible timestamp (0). + Start timestamp for the range query. - can be used to express the minimum possible timestamp (0). to_time: - End timestamp for range query, + can be used to express the - maximum possible timestamp. + End timestamp for range query, + can be used to express the maximum possible timestamp. count: - Optional maximum number of returned results. + Limits the number of returned samples. aggregation_type: - Optional aggregation type. Can be one of - [`avg`, `sum`, `min`, `max`, `range`, `count`, - `first`, `last`, `std.p`, `std.s`, `var.p`, `var.s`] + Optional aggregation type. Can be one of [`avg`, `sum`, `min`, `max`, + `range`, `count`, `first`, `last`, `std.p`, `std.s`, `var.p`, `var.s`, `twa`] bucket_size_msec: Time bucket for aggregation in milliseconds. 
filter_by_ts: List of timestamps to filter the result by specific timestamps. filter_by_min_value: - Filter result by minimum value (must mention also filter - by_max_value). + Filter result by minimum value (must mention also filter by_max_value). filter_by_max_value: - Filter result by maximum value (must mention also filter - by_min_value). + Filter result by maximum value (must mention also filter by_min_value). align: Timestamp for alignment control for aggregation. - - For more information: https://oss.redis.com/redistimeseries/master/commands/#tsrangetsrevrange + latest: + Used when a time series is a compaction, reports the compacted value of the + latest possibly partial bucket + bucket_timestamp: + Controls how bucket timestamps are reported. Can be one of [`-`, `low`, `+`, + `high`, `~`, `mid`]. + empty: + Reports aggregations for empty buckets. + + For more information: https://redis.io/commands/ts.range/ """ # noqa params = self.__range_params( key, @@ -379,21 +421,27 @@ class TimeSeriesCommands: filter_by_min_value, filter_by_max_value, align, + latest, + bucket_timestamp, + empty, ) return self.execute_command(RANGE_CMD, *params) def revrange( self, - key, - from_time, - to_time, - count=None, - aggregation_type=None, - bucket_size_msec=0, - filter_by_ts=None, - filter_by_min_value=None, - filter_by_max_value=None, - align=None, + key: KeyT, + from_time: Union[int, str], + to_time: Union[int, str], + count: Optional[int] = None, + aggregation_type: Optional[str] = None, + bucket_size_msec: Optional[int] = 0, + filter_by_ts: Optional[List[int]] = None, + filter_by_min_value: Optional[int] = None, + filter_by_max_value: Optional[int] = None, + align: Optional[Union[int, str]] = None, + latest: Optional[bool] = False, + bucket_timestamp: Optional[str] = None, + empty: Optional[bool] = False, ): """ Query a range in reverse direction for a specific time-series. @@ -409,10 +457,10 @@ class TimeSeriesCommands: to_time: End timestamp for range query, + can be used to express the maximum possible timestamp. count: - Optional maximum number of returned results. + Limits the number of returned samples. aggregation_type: - Optional aggregation type. Can be one of [`avg`, `sum`, `min`, `max`, `range`, `count`, - `first`, `last`, `std.p`, `std.s`, `var.p`, `var.s`] + Optional aggregation type. Can be one of [`avg`, `sum`, `min`, `max`, + `range`, `count`, `first`, `last`, `std.p`, `std.s`, `var.p`, `var.s`, `twa`] bucket_size_msec: Time bucket for aggregation in milliseconds. filter_by_ts: @@ -423,8 +471,16 @@ class TimeSeriesCommands: Filter result by maximum value (must mention also filter_by_min_value). align: Timestamp for alignment control for aggregation. - - For more information: https://oss.redis.com/redistimeseries/master/commands/#tsrangetsrevrange + latest: + Used when a time series is a compaction, reports the compacted value of the + latest possibly partial bucket + bucket_timestamp: + Controls how bucket timestamps are reported. Can be one of [`-`, `low`, `+`, + `high`, `~`, `mid`]. + empty: + Reports aggregations for empty buckets. 
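A hedged sketch of the TS.RANGE options added above: LATEST (meaningful when the queried series is a compaction), BUCKETTIMESTAMP, and EMPTY. The key name and bucket size are illustrative; RedisTimeSeries 1.8+ is assumed:

import redis

r = redis.Redis()
samples = r.ts().range(
    "temp:hourly", "-", "+",
    aggregation_type="max",
    bucket_size_msec=3_600_000,
    latest=True,                      # include the still-open compaction bucket
    bucket_timestamp="mid",           # report the middle of each bucket
    empty=True,                       # also report empty buckets
)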
+ + For more information: https://redis.io/commands/ts.revrange/ """ # noqa params = self.__range_params( key, @@ -437,34 +493,43 @@ class TimeSeriesCommands: filter_by_min_value, filter_by_max_value, align, + latest, + bucket_timestamp, + empty, ) return self.execute_command(REVRANGE_CMD, *params) def __mrange_params( self, - aggregation_type, - bucket_size_msec, - count, - filters, - from_time, - to_time, - with_labels, - filter_by_ts, - filter_by_min_value, - filter_by_max_value, - groupby, - reduce, - select_labels, - align, + aggregation_type: Optional[str], + bucket_size_msec: Optional[int], + count: Optional[int], + filters: List[str], + from_time: Union[int, str], + to_time: Union[int, str], + with_labels: Optional[bool], + filter_by_ts: Optional[List[int]], + filter_by_min_value: Optional[int], + filter_by_max_value: Optional[int], + groupby: Optional[str], + reduce: Optional[str], + select_labels: Optional[List[str]], + align: Optional[Union[int, str]], + latest: Optional[bool], + bucket_timestamp: Optional[str], + empty: Optional[bool], ): """Create TS.MRANGE and TS.MREVRANGE arguments.""" params = [from_time, to_time] + self._append_latest(params, latest) self._append_filer_by_ts(params, filter_by_ts) self._append_filer_by_value(params, filter_by_min_value, filter_by_max_value) + self._append_with_labels(params, with_labels, select_labels) self._append_count(params, count) self._append_align(params, align) self._append_aggregation(params, aggregation_type, bucket_size_msec) - self._append_with_labels(params, with_labels, select_labels) + self._append_bucket_timestamp(params, bucket_timestamp) + self._append_empty(params, empty) params.extend(["FILTER"]) params += filters self._append_groupby_reduce(params, groupby, reduce) @@ -472,20 +537,23 @@ class TimeSeriesCommands: def mrange( self, - from_time, - to_time, - filters, - count=None, - aggregation_type=None, - bucket_size_msec=0, - with_labels=False, - filter_by_ts=None, - filter_by_min_value=None, - filter_by_max_value=None, - groupby=None, - reduce=None, - select_labels=None, - align=None, + from_time: Union[int, str], + to_time: Union[int, str], + filters: List[str], + count: Optional[int] = None, + aggregation_type: Optional[str] = None, + bucket_size_msec: Optional[int] = 0, + with_labels: Optional[bool] = False, + filter_by_ts: Optional[List[int]] = None, + filter_by_min_value: Optional[int] = None, + filter_by_max_value: Optional[int] = None, + groupby: Optional[str] = None, + reduce: Optional[str] = None, + select_labels: Optional[List[str]] = None, + align: Optional[Union[int, str]] = None, + latest: Optional[bool] = False, + bucket_timestamp: Optional[str] = None, + empty: Optional[bool] = False, ): """ Query a range across multiple time-series by filters in forward direction. @@ -493,46 +561,45 @@ class TimeSeriesCommands: Args: from_time: - Start timestamp for the range query. `-` can be used to - express the minimum possible timestamp (0). + Start timestamp for the range query. `-` can be used to express the minimum possible timestamp (0). to_time: - End timestamp for range query, `+` can be used to express - the maximum possible timestamp. + End timestamp for range query, `+` can be used to express the maximum possible timestamp. filters: filter to match the time-series labels. count: - Optional maximum number of returned results. + Limits the number of returned samples. aggregation_type: - Optional aggregation type. 
Can be one of - [`avg`, `sum`, `min`, `max`, `range`, `count`, - `first`, `last`, `std.p`, `std.s`, `var.p`, `var.s`] + Optional aggregation type. Can be one of [`avg`, `sum`, `min`, `max`, + `range`, `count`, `first`, `last`, `std.p`, `std.s`, `var.p`, `var.s`, `twa`] bucket_size_msec: Time bucket for aggregation in milliseconds. with_labels: - Include in the reply the label-value pairs that represent metadata - labels of the time-series. - If this argument is not set, by default, an empty Array will be - replied on the labels array position. + Include in the reply all label-value pairs representing metadata labels of the time series. filter_by_ts: List of timestamps to filter the result by specific timestamps. filter_by_min_value: - Filter result by minimum value (must mention also - filter_by_max_value). + Filter result by minimum value (must mention also filter_by_max_value). filter_by_max_value: - Filter result by maximum value (must mention also - filter_by_min_value). + Filter result by maximum value (must mention also filter_by_min_value). groupby: Grouping by fields the results (must mention also reduce). reduce: - Applying reducer functions on each group. Can be one - of [`sum`, `min`, `max`]. + Applying reducer functions on each group. Can be one of [`avg` `sum`, `min`, + `max`, `range`, `count`, `std.p`, `std.s`, `var.p`, `var.s`]. select_labels: - Include in the reply only a subset of the key-value - pair labels of a series. + Include in the reply only a subset of the key-value pair labels of a series. align: Timestamp for alignment control for aggregation. - - For more information: https://oss.redis.com/redistimeseries/master/commands/#tsmrangetsmrevrange + latest: + Used when a time series is a compaction, reports the compacted + value of the latest possibly partial bucket + bucket_timestamp: + Controls how bucket timestamps are reported. Can be one of [`-`, `low`, `+`, + `high`, `~`, `mid`]. + empty: + Reports aggregations for empty buckets. + + For more information: https://redis.io/commands/ts.mrange/ """ # noqa params = self.__mrange_params( aggregation_type, @@ -549,26 +616,32 @@ class TimeSeriesCommands: reduce, select_labels, align, + latest, + bucket_timestamp, + empty, ) return self.execute_command(MRANGE_CMD, *params) def mrevrange( self, - from_time, - to_time, - filters, - count=None, - aggregation_type=None, - bucket_size_msec=0, - with_labels=False, - filter_by_ts=None, - filter_by_min_value=None, - filter_by_max_value=None, - groupby=None, - reduce=None, - select_labels=None, - align=None, + from_time: Union[int, str], + to_time: Union[int, str], + filters: List[str], + count: Optional[int] = None, + aggregation_type: Optional[str] = None, + bucket_size_msec: Optional[int] = 0, + with_labels: Optional[bool] = False, + filter_by_ts: Optional[List[int]] = None, + filter_by_min_value: Optional[int] = None, + filter_by_max_value: Optional[int] = None, + groupby: Optional[str] = None, + reduce: Optional[str] = None, + select_labels: Optional[List[str]] = None, + align: Optional[Union[int, str]] = None, + latest: Optional[bool] = False, + bucket_timestamp: Optional[str] = None, + empty: Optional[bool] = False, ): """ Query a range across multiple time-series by filters in reverse direction. @@ -576,48 +649,45 @@ class TimeSeriesCommands: Args: from_time: - Start timestamp for the range query. - can be used to express - the minimum possible timestamp (0). + Start timestamp for the range query. - can be used to express the minimum possible timestamp (0). 
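A hedged sketch of TS.MRANGE with a label filter plus GROUPBY/REDUCE, whose reducer list now also covers avg, range, count and the std/var variants. The label and filter names are illustrative; RedisTimeSeries is assumed to be loaded:

import redis

r = redis.Redis()
res = r.ts().mrange(
    "-", "+",
    filters=["metric=temp"],
    groupby="sensor",
    reduce="avg",
    latest=True,
)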
to_time: - End timestamp for range query, + can be used to express - the maximum possible timestamp. + End timestamp for range query, + can be used to express the maximum possible timestamp. filters: Filter to match the time-series labels. count: - Optional maximum number of returned results. + Limits the number of returned samples. aggregation_type: - Optional aggregation type. Can be one of - [`avg`, `sum`, `min`, `max`, `range`, `count`, - `first`, `last`, `std.p`, `std.s`, `var.p`, `var.s`] + Optional aggregation type. Can be one of [`avg`, `sum`, `min`, `max`, + `range`, `count`, `first`, `last`, `std.p`, `std.s`, `var.p`, `var.s`, `twa`] bucket_size_msec: Time bucket for aggregation in milliseconds. with_labels: - Include in the reply the label-value pairs that represent - metadata labels - of the time-series. - If this argument is not set, by default, an empty Array - will be replied - on the labels array position. + Include in the reply all label-value pairs representing metadata labels of the time series. filter_by_ts: List of timestamps to filter the result by specific timestamps. filter_by_min_value: - Filter result by minimum value (must mention also filter - by_max_value). + Filter result by minimum value (must mention also filter_by_max_value). filter_by_max_value: - Filter result by maximum value (must mention also filter - by_min_value). + Filter result by maximum value (must mention also filter_by_min_value). groupby: Grouping by fields the results (must mention also reduce). reduce: - Applying reducer functions on each group. Can be one - of [`sum`, `min`, `max`]. + Applying reducer functions on each group. Can be one of [`avg` `sum`, `min`, + `max`, `range`, `count`, `std.p`, `std.s`, `var.p`, `var.s`]. select_labels: - Include in the reply only a subset of the key-value pair - labels of a series. + Include in the reply only a subset of the key-value pair labels of a series. align: Timestamp for alignment control for aggregation. - - For more information: https://oss.redis.com/redistimeseries/master/commands/#tsmrangetsmrevrange + latest: + Used when a time series is a compaction, reports the compacted + value of the latest possibly partial bucket + bucket_timestamp: + Controls how bucket timestamps are reported. Can be one of [`-`, `low`, `+`, + `high`, `~`, `mid`]. + empty: + Reports aggregations for empty buckets. + + For more information: https://redis.io/commands/ts.mrevrange/ """ # noqa params = self.__mrange_params( aggregation_type, @@ -634,54 +704,85 @@ class TimeSeriesCommands: reduce, select_labels, align, + latest, + bucket_timestamp, + empty, ) return self.execute_command(MREVRANGE_CMD, *params) - def get(self, key): + def get(self, key: KeyT, latest: Optional[bool] = False): """# noqa Get the last sample of `key`. + `latest` used when a time series is a compaction, reports the compacted + value of the latest (possibly partial) bucket - For more information: https://oss.redis.com/redistimeseries/master/commands/#tsget + For more information: https://redis.io/commands/ts.get/ """ # noqa - return self.execute_command(GET_CMD, key) + params = [key] + self._append_latest(params, latest) + return self.execute_command(GET_CMD, *params) - def mget(self, filters, with_labels=False): + def mget( + self, + filters: List[str], + with_labels: Optional[bool] = False, + select_labels: Optional[List[str]] = None, + latest: Optional[bool] = False, + ): """# noqa Get the last samples matching the specific `filter`. 
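TS.GET and TS.MGET gain the same LATEST flag, and MGET additionally exposes SELECTED_LABELS (mutually exclusive with with_labels). A hedged sketch with illustrative keys, filters, and labels:

import redis

r = redis.Redis()                     # assumes RedisTimeSeries is loaded
last_sample = r.ts().get("temp:hourly", latest=True)
last_per_series = r.ts().mget(["metric=temp"], select_labels=["sensor"], latest=True)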
- For more information: https://oss.redis.com/redistimeseries/master/commands/#tsmget + Args: + + filters: + Filter to match the time-series labels. + with_labels: + Include in the reply all label-value pairs representing metadata + labels of the time series. + select_labels: + Include in the reply only a subset of the key-value pair labels of a series. + latest: + Used when a time series is a compaction, reports the compacted + value of the latest possibly partial bucket + + For more information: https://redis.io/commands/ts.mget/ """ # noqa params = [] - self._append_with_labels(params, with_labels) + self._append_latest(params, latest) + self._append_with_labels(params, with_labels, select_labels) params.extend(["FILTER"]) params += filters return self.execute_command(MGET_CMD, *params) - def info(self, key): + def info(self, key: KeyT): """# noqa Get information of `key`. - For more information: https://oss.redis.com/redistimeseries/master/commands/#tsinfo + For more information: https://redis.io/commands/ts.info/ """ # noqa return self.execute_command(INFO_CMD, key) - def queryindex(self, filters): + def queryindex(self, filters: List[str]): """# noqa - Get all the keys matching the `filter` list. + Get all time series keys matching the `filter` list. - For more information: https://oss.redis.com/redistimeseries/master/commands/#tsqueryindex + For more information: https://redis.io/commands/ts.queryindex/ """ # noq return self.execute_command(QUERYINDEX_CMD, *filters) @staticmethod - def _append_uncompressed(params, uncompressed): + def _append_uncompressed(params: List[str], uncompressed: Optional[bool]): """Append UNCOMPRESSED tag to params.""" if uncompressed: params.extend(["UNCOMPRESSED"]) @staticmethod - def _append_with_labels(params, with_labels, select_labels=None): + def _append_with_labels( + params: List[str], + with_labels: Optional[bool], + select_labels: Optional[List[str]], + ): """Append labels behavior to params.""" if with_labels and select_labels: raise DataError( @@ -694,19 +795,21 @@ class TimeSeriesCommands: params.extend(["SELECTED_LABELS", *select_labels]) @staticmethod - def _append_groupby_reduce(params, groupby, reduce): + def _append_groupby_reduce( + params: List[str], groupby: Optional[str], reduce: Optional[str] + ): """Append GROUPBY REDUCE property to params.""" if groupby is not None and reduce is not None: params.extend(["GROUPBY", groupby, "REDUCE", reduce.upper()]) @staticmethod - def _append_retention(params, retention): + def _append_retention(params: List[str], retention: Optional[int]): """Append RETENTION property to params.""" if retention is not None: params.extend(["RETENTION", retention]) @staticmethod - def _append_labels(params, labels): + def _append_labels(params: List[str], labels: Optional[List[str]]): """Append LABELS property to params.""" if labels: params.append("LABELS") @@ -714,38 +817,43 @@ class TimeSeriesCommands: params.extend([k, v]) @staticmethod - def _append_count(params, count): + def _append_count(params: List[str], count: Optional[int]): """Append COUNT property to params.""" if count is not None: params.extend(["COUNT", count]) @staticmethod - def _append_timestamp(params, timestamp): + def _append_timestamp(params: List[str], timestamp: Optional[int]): """Append TIMESTAMP property to params.""" if timestamp is not None: params.extend(["TIMESTAMP", timestamp]) @staticmethod - def _append_align(params, align): + def _append_align(params: List[str], align: Optional[Union[int, str]]): """Append ALIGN property to 
params.""" if align is not None: params.extend(["ALIGN", align]) @staticmethod - def _append_aggregation(params, aggregation_type, bucket_size_msec): + def _append_aggregation( + params: List[str], + aggregation_type: Optional[str], + bucket_size_msec: Optional[int], + ): """Append AGGREGATION property to params.""" if aggregation_type is not None: - params.append("AGGREGATION") - params.extend([aggregation_type, bucket_size_msec]) + params.extend(["AGGREGATION", aggregation_type, bucket_size_msec]) @staticmethod - def _append_chunk_size(params, chunk_size): + def _append_chunk_size(params: List[str], chunk_size: Optional[int]): """Append CHUNK_SIZE property to params.""" if chunk_size is not None: params.extend(["CHUNK_SIZE", chunk_size]) @staticmethod - def _append_duplicate_policy(params, command, duplicate_policy): + def _append_duplicate_policy( + params: List[str], command: Optional[str], duplicate_policy: Optional[str] + ): """Append DUPLICATE_POLICY property to params on CREATE and ON_DUPLICATE on ADD. """ @@ -756,13 +864,33 @@ class TimeSeriesCommands: params.extend(["DUPLICATE_POLICY", duplicate_policy]) @staticmethod - def _append_filer_by_ts(params, ts_list): + def _append_filer_by_ts(params: List[str], ts_list: Optional[List[int]]): """Append FILTER_BY_TS property to params.""" if ts_list is not None: params.extend(["FILTER_BY_TS", *ts_list]) @staticmethod - def _append_filer_by_value(params, min_value, max_value): + def _append_filer_by_value( + params: List[str], min_value: Optional[int], max_value: Optional[int] + ): """Append FILTER_BY_VALUE property to params.""" if min_value is not None and max_value is not None: params.extend(["FILTER_BY_VALUE", min_value, max_value]) + + @staticmethod + def _append_latest(params: List[str], latest: Optional[bool]): + """Append LATEST property to params.""" + if latest: + params.append("LATEST") + + @staticmethod + def _append_bucket_timestamp(params: List[str], bucket_timestamp: Optional[str]): + """Append BUCKET_TIMESTAMP property to params.""" + if bucket_timestamp is not None: + params.extend(["BUCKETTIMESTAMP", bucket_timestamp]) + + @staticmethod + def _append_empty(params: List[str], empty: Optional[bool]): + """Append EMPTY property to params.""" + if empty: + params.append("EMPTY") diff --git a/redis/commands/timeseries/info.py b/redis/commands/timeseries/info.py index fba7f09..65f3baa 100644 --- a/redis/commands/timeseries/info.py +++ b/redis/commands/timeseries/info.py @@ -60,15 +60,15 @@ class TSInfo: https://oss.redis.com/redistimeseries/configuration/#duplicate_policy """ response = dict(zip(map(nativestr, args[::2]), args[1::2])) - self.rules = response["rules"] - self.source_key = response["sourceKey"] - self.chunk_count = response["chunkCount"] - self.memory_usage = response["memoryUsage"] - self.total_samples = response["totalSamples"] - self.labels = list_to_dict(response["labels"]) - self.retention_msecs = response["retentionTime"] - self.lastTimeStamp = response["lastTimestamp"] - self.first_time_stamp = response["firstTimestamp"] + self.rules = response.get("rules") + self.source_key = response.get("sourceKey") + self.chunk_count = response.get("chunkCount") + self.memory_usage = response.get("memoryUsage") + self.total_samples = response.get("totalSamples") + self.labels = list_to_dict(response.get("labels")) + self.retention_msecs = response.get("retentionTime") + self.last_timestamp = response.get("lastTimestamp") + self.first_timestamp = response.get("firstTimestamp") if "maxSamplesPerChunk" in response: 
self.max_samples_per_chunk = response["maxSamplesPerChunk"] self.chunk_size = ( diff --git a/redis/exceptions.py b/redis/exceptions.py index d18b354..8a8bf42 100644 --- a/redis/exceptions.py +++ b/redis/exceptions.py @@ -199,3 +199,7 @@ class SlotNotCoveredError(RedisClusterException): """ pass + + +class MaxConnectionsError(ConnectionError): + ... diff --git a/redis/utils.py b/redis/utils.py index 0c34e1e..56165ea 100644 --- a/redis/utils.py +++ b/redis/utils.py @@ -1,4 +1,5 @@ from contextlib import contextmanager +from functools import wraps from typing import Any, Dict, Mapping, Union try: @@ -79,3 +80,30 @@ def merge_result(command, res): result.add(value) return list(result) + + +def warn_deprecated(name, reason="", version="", stacklevel=2): + import warnings + + msg = f"Call to deprecated {name}." + if reason: + msg += f" ({reason})" + if version: + msg += f" -- Deprecated since version {version}." + warnings.warn(msg, category=DeprecationWarning, stacklevel=stacklevel) + + +def deprecated_function(reason="", version="", name=None): + """ + Decorator to mark a function as deprecated. + """ + + def decorator(func): + @wraps(func) + def wrapper(*args, **kwargs): + warn_deprecated(name or func.__name__, reason, version, stacklevel=3) + return func(*args, **kwargs) + + return wrapper + + return decorator diff --git a/tests/test_asyncio/test_bloom.py b/tests/test_asyncio/test_bloom.py index feb98cc..cac989f 100644 --- a/tests/test_asyncio/test_bloom.py +++ b/tests/test_asyncio/test_bloom.py @@ -1,10 +1,11 @@ +from math import inf + import pytest import redis.asyncio as redis from redis.exceptions import ModuleError, RedisError from redis.utils import HIREDIS_AVAILABLE - -pytestmark = pytest.mark.asyncio +from tests.conftest import skip_ifmodversion_lt def intlist(obj): @@ -91,7 +92,7 @@ async def test_bf_scandump_and_loadchunk(modclient: redis.Redis): res += rv == x assert res < 5 - do_verify() + await do_verify() cmds = [] if HIREDIS_AVAILABLE: with pytest.raises(ModuleError): @@ -120,7 +121,7 @@ async def test_bf_scandump_and_loadchunk(modclient: redis.Redis): cur_info = await modclient.bf().execute_command("bf.debug", "myBloom") assert prev_info == cur_info - do_verify() + await do_verify() await modclient.bf().client.delete("myBloom") await modclient.bf().create("myBloom", "0.0001", "10000000") @@ -264,9 +265,10 @@ async def test_topk(modclient: redis.Redis): assert [1, 1, 0, 0, 1, 0, 0] == await modclient.topk().query( "topk", "A", "B", "C", "D", "E", "F", "G" ) - assert [4, 3, 2, 3, 3, 0, 1] == await modclient.topk().count( - "topk", "A", "B", "C", "D", "E", "F", "G" - ) + with pytest.deprecated_call(): + assert [4, 3, 2, 3, 3, 0, 1] == await modclient.topk().count( + "topk", "A", "B", "C", "D", "E", "F", "G" + ) # test full list assert await modclient.topk().reserve("topklist", 3, 50, 3, 0.9) @@ -308,9 +310,10 @@ async def test_topk_incrby(modclient: redis.Redis): ) res = await modclient.topk().incrby("topk", ["42", "xyzzy"], [8, 4]) assert [None, "bar"] == res - assert [3, 6, 10, 4, 0] == await modclient.topk().count( - "topk", "bar", "baz", "42", "xyzzy", 4 - ) + with pytest.deprecated_call(): + assert [3, 6, 10, 4, 0] == await modclient.topk().count( + "topk", "bar", "baz", "42", "xyzzy", 4 + ) # region Test T-Digest @@ -321,11 +324,11 @@ async def test_tdigest_reset(modclient: redis.Redis): # reset on empty histogram assert await modclient.tdigest().reset("tDigest") # insert data-points into sketch - assert await modclient.tdigest().add("tDigest", list(range(10)), [1.0] * 10) 
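The new redis.utils.deprecated_function decorator shown above replaces the external `deprecated` package for the JSON and search helpers touched earlier in this patch. A hedged usage sketch with a hypothetical function:

from redis.utils import deprecated_function

@deprecated_function(version="4.0.0", reason="use new_helper instead")
def old_helper(x):
    return x + 1

old_helper(1)  # emits DeprecationWarning: "Call to deprecated old_helper. (use new_helper instead) -- Deprecated since version 4.0.0."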
+ assert await modclient.tdigest().add("tDigest", list(range(10))) assert await modclient.tdigest().reset("tDigest") # assert we have 0 unmerged nodes - assert 0 == (await modclient.tdigest().info("tDigest")).unmergedNodes + assert 0 == (await modclient.tdigest().info("tDigest")).unmerged_nodes @pytest.mark.redismod @@ -334,14 +337,24 @@ async def test_tdigest_merge(modclient: redis.Redis): assert await modclient.tdigest().create("to-tDigest", 10) assert await modclient.tdigest().create("from-tDigest", 10) # insert data-points into sketch - assert await modclient.tdigest().add("from-tDigest", [1.0] * 10, [1.0] * 10) - assert await modclient.tdigest().add("to-tDigest", [2.0] * 10, [10.0] * 10) + assert await modclient.tdigest().add("from-tDigest", [1.0] * 10) + assert await modclient.tdigest().add("to-tDigest", [2.0] * 10) # merge from-tdigest into to-tdigest - assert await modclient.tdigest().merge("to-tDigest", "from-tDigest") + assert await modclient.tdigest().merge("to-tDigest", 1, "from-tDigest") # we should now have 110 weight on to-histogram info = await modclient.tdigest().info("to-tDigest") - total_weight_to = float(info.mergedWeight) + float(info.unmergedWeight) - assert 110 == total_weight_to + total_weight_to = float(info.merged_weight) + float(info.unmerged_weight) + assert 20.0 == total_weight_to + # test override + assert await modclient.tdigest().create("from-override", 10) + assert await modclient.tdigest().create("from-override-2", 10) + assert await modclient.tdigest().add("from-override", [3.0] * 10) + assert await modclient.tdigest().add("from-override-2", [4.0] * 10) + assert await modclient.tdigest().merge( + "to-tDigest", 2, "from-override", "from-override-2", override=True + ) + assert 3.0 == await modclient.tdigest().min("to-tDigest") + assert 4.0 == await modclient.tdigest().max("to-tDigest") @pytest.mark.redismod @@ -349,7 +362,7 @@ async def test_tdigest_merge(modclient: redis.Redis): async def test_tdigest_min_and_max(modclient: redis.Redis): assert await modclient.tdigest().create("tDigest", 100) # insert data-points into sketch - assert await modclient.tdigest().add("tDigest", [1, 2, 3], [1.0] * 3) + assert await modclient.tdigest().add("tDigest", [1, 2, 3]) # min/max assert 3 == await modclient.tdigest().max("tDigest") assert 1 == await modclient.tdigest().min("tDigest") @@ -357,22 +370,31 @@ async def test_tdigest_min_and_max(modclient: redis.Redis): @pytest.mark.redismod @pytest.mark.experimental +@skip_ifmodversion_lt("2.4.0", "bf") async def test_tdigest_quantile(modclient: redis.Redis): assert await modclient.tdigest().create("tDigest", 500) # insert data-points into sketch assert await modclient.tdigest().add( - "tDigest", list([x * 0.01 for x in range(1, 10000)]), [1.0] * 10000 + "tDigest", list([x * 0.01 for x in range(1, 10000)]) ) # assert min min/max have same result as quantile 0 and 1 - assert await modclient.tdigest().max( - "tDigest" - ) == await modclient.tdigest().quantile("tDigest", 1.0) - assert await modclient.tdigest().min( - "tDigest" - ) == await modclient.tdigest().quantile("tDigest", 0.0) + assert ( + await modclient.tdigest().max("tDigest") + == (await modclient.tdigest().quantile("tDigest", 1))[0] + ) + assert ( + await modclient.tdigest().min("tDigest") + == (await modclient.tdigest().quantile("tDigest", 0.0))[0] + ) + + assert 1.0 == round((await modclient.tdigest().quantile("tDigest", 0.01))[0], 2) + assert 99.0 == round((await modclient.tdigest().quantile("tDigest", 0.99))[0], 2) - assert 1.0 == round(await 
modclient.tdigest().quantile("tDigest", 0.01), 2) - assert 99.0 == round(await modclient.tdigest().quantile("tDigest", 0.99), 2) + # test multiple quantiles + assert await modclient.tdigest().create("t-digest", 100) + assert await modclient.tdigest().add("t-digest", [1, 2, 3, 4, 5]) + res = await modclient.tdigest().quantile("t-digest", 0.5, 0.8) + assert [3.0, 5.0] == res @pytest.mark.redismod @@ -380,9 +402,67 @@ async def test_tdigest_quantile(modclient: redis.Redis): async def test_tdigest_cdf(modclient: redis.Redis): assert await modclient.tdigest().create("tDigest", 100) # insert data-points into sketch - assert await modclient.tdigest().add("tDigest", list(range(1, 10)), [1.0] * 10) - assert 0.1 == round(await modclient.tdigest().cdf("tDigest", 1.0), 1) - assert 0.9 == round(await modclient.tdigest().cdf("tDigest", 9.0), 1) + assert await modclient.tdigest().add("tDigest", list(range(1, 10))) + assert 0.1 == round((await modclient.tdigest().cdf("tDigest", 1.0))[0], 1) + assert 0.9 == round((await modclient.tdigest().cdf("tDigest", 9.0))[0], 1) + res = await modclient.tdigest().cdf("tDigest", 1.0, 9.0) + assert [0.1, 0.9] == [round(x, 1) for x in res] + + +@pytest.mark.redismod +@pytest.mark.experimental +@skip_ifmodversion_lt("2.4.0", "bf") +async def test_tdigest_trimmed_mean(modclient: redis.Redis): + assert await modclient.tdigest().create("tDigest", 100) + # insert data-points into sketch + assert await modclient.tdigest().add("tDigest", list(range(1, 10))) + assert 5 == await modclient.tdigest().trimmed_mean("tDigest", 0.1, 0.9) + assert 4.5 == await modclient.tdigest().trimmed_mean("tDigest", 0.4, 0.5) + + +@pytest.mark.redismod +@pytest.mark.experimental +async def test_tdigest_rank(modclient: redis.Redis): + assert await modclient.tdigest().create("t-digest", 500) + assert await modclient.tdigest().add("t-digest", list(range(0, 20))) + assert -1 == (await modclient.tdigest().rank("t-digest", -1))[0] + assert 0 == (await modclient.tdigest().rank("t-digest", 0))[0] + assert 10 == (await modclient.tdigest().rank("t-digest", 10))[0] + assert [-1, 20, 9] == await modclient.tdigest().rank("t-digest", -20, 20, 9) + + +@pytest.mark.redismod +@pytest.mark.experimental +async def test_tdigest_revrank(modclient: redis.Redis): + assert await modclient.tdigest().create("t-digest", 500) + assert await modclient.tdigest().add("t-digest", list(range(0, 20))) + assert -1 == (await modclient.tdigest().revrank("t-digest", 20))[0] + assert 19 == (await modclient.tdigest().revrank("t-digest", 0))[0] + assert [-1, 19, 9] == await modclient.tdigest().revrank("t-digest", 21, 0, 10) + + +@pytest.mark.redismod +@pytest.mark.experimental +async def test_tdigest_byrank(modclient: redis.Redis): + assert await modclient.tdigest().create("t-digest", 500) + assert await modclient.tdigest().add("t-digest", list(range(1, 11))) + assert 1 == (await modclient.tdigest().byrank("t-digest", 0))[0] + assert 10 == (await modclient.tdigest().byrank("t-digest", 9))[0] + assert (await modclient.tdigest().byrank("t-digest", 100))[0] == inf + with pytest.raises(redis.ResponseError): + (await modclient.tdigest().byrank("t-digest", -1))[0] + + +@pytest.mark.redismod +@pytest.mark.experimental +async def test_tdigest_byrevrank(modclient: redis.Redis): + assert await modclient.tdigest().create("t-digest", 500) + assert await modclient.tdigest().add("t-digest", list(range(1, 11))) + assert 10 == (await modclient.tdigest().byrevrank("t-digest", 0))[0] + assert 1 == (await modclient.tdigest().byrevrank("t-digest", 9))[0] + 
assert (await modclient.tdigest().byrevrank("t-digest", 100))[0] == -inf + with pytest.raises(redis.ResponseError): + (await modclient.tdigest().byrevrank("t-digest", -1))[0] # @pytest.mark.redismod diff --git a/tests/test_asyncio/test_cluster.py b/tests/test_asyncio/test_cluster.py index f4ea5cd..7e1049e 100644 --- a/tests/test_asyncio/test_cluster.py +++ b/tests/test_asyncio/test_cluster.py @@ -1,9 +1,11 @@ import asyncio import binascii import datetime +import os import sys import warnings -from typing import Any, Callable, Dict, List, Optional, Type, Union +from typing import Any, Awaitable, Callable, Dict, List, Optional, Type, Union +from urllib.parse import urlparse import pytest @@ -14,10 +16,10 @@ if sys.version_info[0:2] == (3, 6): else: import pytest_asyncio -from _pytest.fixtures import FixtureRequest, SubRequest +from _pytest.fixtures import FixtureRequest -from redis.asyncio import Connection, RedisCluster -from redis.asyncio.cluster import ClusterNode, NodesManager +from redis.asyncio.cluster import ClusterNode, NodesManager, RedisCluster +from redis.asyncio.connection import Connection, SSLConnection from redis.asyncio.parser import CommandsParser from redis.cluster import PIPELINE_BLOCKED_COMMANDS, PRIMARY, REPLICA, get_node_name from redis.crc import REDIS_CLUSTER_HASH_SLOTS, key_slot @@ -26,6 +28,7 @@ from redis.exceptions import ( ClusterDownError, ConnectionError, DataError, + MaxConnectionsError, MovedError, NoPermissionError, RedisClusterException, @@ -39,7 +42,8 @@ from tests.conftest import ( skip_unless_arch_bits, ) -pytestmark = pytest.mark.asyncio +pytestmark = pytest.mark.onlycluster + default_host = "127.0.0.1" default_port = 7000 @@ -50,7 +54,7 @@ default_cluster_slots = [ @pytest_asyncio.fixture() -async def slowlog(request: SubRequest, r: RedisCluster) -> None: +async def slowlog(r: RedisCluster) -> None: """ Set the slowlog threshold to 0, and the max length to 128. This will force every @@ -146,7 +150,7 @@ def mock_all_nodes_resp(rc: RedisCluster, response: Any) -> RedisCluster: async def moved_redirection_helper( - request: FixtureRequest, create_redis: Callable, failover: bool = False + create_redis: Callable[..., RedisCluster], failover: bool = False ) -> None: """ Test that the client handles MOVED response after a failover. 
@@ -202,7 +206,6 @@ async def moved_redirection_helper( assert prev_primary.server_type == REPLICA -@pytest.mark.onlycluster class TestRedisClusterObj: """ Tests for the RedisCluster class @@ -237,10 +240,18 @@ class TestRedisClusterObj: await cluster.close() - startup_nodes = [ClusterNode("127.0.0.1", 16379)] - async with RedisCluster(startup_nodes=startup_nodes) as rc: + startup_node = ClusterNode("127.0.0.1", 16379) + async with RedisCluster(startup_nodes=[startup_node], client_name="test") as rc: assert await rc.set("A", 1) assert await rc.get("A") == b"1" + assert all( + [ + name == "test" + for name in ( + await rc.client_getname(target_nodes=rc.ALL_NODES) + ).values() + ] + ) async def test_empty_startup_nodes(self) -> None: """ @@ -253,18 +264,43 @@ class TestRedisClusterObj: "RedisCluster requires at least one node to discover the " "cluster" ), str_if_bytes(ex.value) - async def test_from_url(self, r: RedisCluster) -> None: - redis_url = f"redis://{default_host}:{default_port}/0" - with mock.patch.object(RedisCluster, "from_url") as from_url: + async def test_from_url(self, request: FixtureRequest) -> None: + url = request.config.getoption("--redis-url") - async def from_url_mocked(_url, **_kwargs): - return await get_mocked_redis_client(url=_url, **_kwargs) + async with RedisCluster.from_url(url) as rc: + await rc.set("a", 1) + await rc.get("a") == 1 - from_url.side_effect = from_url_mocked - cluster = await RedisCluster.from_url(redis_url) - assert cluster.get_node(host=default_host, port=default_port) is not None + rc = RedisCluster.from_url("rediss://localhost:16379") + assert rc.connection_kwargs["connection_class"] is SSLConnection - await cluster.close() + async def test_max_connections( + self, create_redis: Callable[..., RedisCluster] + ) -> None: + rc = await create_redis(cls=RedisCluster, max_connections=10) + for node in rc.get_nodes(): + assert node.max_connections == 10 + + with mock.patch.object( + Connection, "read_response_without_lock" + ) as read_response_without_lock: + + async def read_response_without_lock_mocked( + *args: Any, **kwargs: Any + ) -> None: + await asyncio.sleep(10) + + read_response_without_lock.side_effect = read_response_without_lock_mocked + + with pytest.raises(MaxConnectionsError): + await asyncio.gather( + *( + rc.ping(target_nodes=RedisCluster.DEFAULT_NODE) + for _ in range(11) + ) + ) + + await rc.close() async def test_execute_command_errors(self, r: RedisCluster) -> None: """ @@ -373,23 +409,23 @@ class TestRedisClusterObj: assert await r.execute_command("SET", "foo", "bar") == "MOCK_OK" async def test_moved_redirection( - self, request: FixtureRequest, create_redis: Callable + self, create_redis: Callable[..., RedisCluster] ) -> None: """ Test that the client handles MOVED response. """ - await moved_redirection_helper(request, create_redis, failover=False) + await moved_redirection_helper(create_redis, failover=False) async def test_moved_redirection_after_failover( - self, request: FixtureRequest, create_redis: Callable + self, create_redis: Callable[..., RedisCluster] ) -> None: """ Test that the client handles MOVED response after a failover. 
""" - await moved_redirection_helper(request, create_redis, failover=True) + await moved_redirection_helper(create_redis, failover=True) async def test_refresh_using_specific_nodes( - self, request: FixtureRequest, create_redis: Callable + self, create_redis: Callable[..., RedisCluster] ) -> None: """ Test making calls on specific nodes when the cluster has failed over to @@ -691,7 +727,6 @@ class TestRedisClusterObj: await rc.close() -@pytest.mark.onlycluster class TestClusterRedisCommands: """ Tests for RedisCluster unique commands @@ -777,6 +812,15 @@ class TestClusterRedisCommands: await asyncio.sleep(0.1) assert await r.unlink(*d.keys()) == 0 + async def test_initialize_before_execute_multi_key_command( + self, request: FixtureRequest + ) -> None: + # Test for issue https://github.com/redis/redis-py/issues/2437 + url = request.config.getoption("--redis-url") + r = RedisCluster.from_url(url) + assert 0 == await r.exists("a", "b", "c") + await r.close() + @skip_if_redis_enterprise() async def test_cluster_myid(self, r: RedisCluster) -> None: node = r.get_random_node() @@ -1918,7 +1962,7 @@ class TestClusterRedisCommands: @skip_if_server_version_lt("6.0.0") @skip_if_redis_enterprise() async def test_acl_log( - self, r: RedisCluster, request: FixtureRequest, create_redis: Callable + self, r: RedisCluster, create_redis: Callable[..., RedisCluster] ) -> None: key = "{cache}:" node = r.get_node_from_key(key) @@ -1963,7 +2007,6 @@ class TestClusterRedisCommands: await user_client.close() -@pytest.mark.onlycluster class TestNodesManager: """ Tests for the NodesManager class @@ -2095,7 +2138,7 @@ class TestNodesManager: specified """ with pytest.raises(RedisClusterException): - await NodesManager([]).initialize() + await NodesManager([], False, {}).initialize() async def test_wrong_startup_nodes_type(self) -> None: """ @@ -2103,11 +2146,9 @@ class TestNodesManager: fail """ with pytest.raises(RedisClusterException): - await NodesManager({}).initialize() + await NodesManager({}, False, {}).initialize() - async def test_init_slots_cache_slots_collision( - self, request: FixtureRequest - ) -> None: + async def test_init_slots_cache_slots_collision(self) -> None: """ Test that if 2 nodes do not agree on the same slots setup it should raise an error. In this test both nodes will say that the first @@ -2236,7 +2277,6 @@ class TestNodesManager: assert rc.get_node(host=default_host, port=7002) is not None -@pytest.mark.onlycluster class TestClusterPipeline: """Tests for the ClusterPipeline class.""" @@ -2476,3 +2516,124 @@ class TestClusterPipeline: executed_on_replica = True break assert executed_on_replica + + async def test_can_run_concurrent_pipelines(self, r: RedisCluster) -> None: + """Test that the pipeline can be used concurrently.""" + await asyncio.gather( + *(self.test_redis_cluster_pipeline(r) for i in range(100)), + *(self.test_multi_key_operation_with_a_single_slot(r) for i in range(100)), + *(self.test_multi_key_operation_with_multi_slots(r) for i in range(100)), + ) + + +@pytest.mark.ssl +class TestSSL: + """ + Tests for SSL connections. + + This relies on the --redis-ssl-url for building the client and connecting to the + appropriate port. 
+ """ + + ROOT = os.path.join(os.path.dirname(__file__), "../..") + CERT_DIR = os.path.abspath(os.path.join(ROOT, "docker", "stunnel", "keys")) + if not os.path.isdir(CERT_DIR): # github actions package validation case + CERT_DIR = os.path.abspath( + os.path.join(ROOT, "..", "docker", "stunnel", "keys") + ) + if not os.path.isdir(CERT_DIR): + raise IOError(f"No SSL certificates found. They should be in {CERT_DIR}") + + SERVER_CERT = os.path.join(CERT_DIR, "server-cert.pem") + SERVER_KEY = os.path.join(CERT_DIR, "server-key.pem") + + @pytest_asyncio.fixture() + def create_client(self, request: FixtureRequest) -> Callable[..., RedisCluster]: + ssl_url = request.config.option.redis_ssl_url + ssl_host, ssl_port = urlparse(ssl_url)[1].split(":") + + async def _create_client(mocked: bool = True, **kwargs: Any) -> RedisCluster: + if mocked: + with mock.patch.object( + ClusterNode, "execute_command", autospec=True + ) as execute_command_mock: + + async def execute_command(self, *args, **kwargs): + if args[0] == "INFO": + return {"cluster_enabled": True} + if args[0] == "CLUSTER SLOTS": + return [[0, 16383, [ssl_host, ssl_port, "ssl_node"]]] + if args[0] == "COMMAND": + return { + "ping": { + "name": "ping", + "arity": -1, + "flags": ["stale", "fast"], + "first_key_pos": 0, + "last_key_pos": 0, + "step_count": 0, + } + } + raise NotImplementedError() + + execute_command_mock.side_effect = execute_command + + rc = await RedisCluster(host=ssl_host, port=ssl_port, **kwargs) + + assert len(rc.get_nodes()) == 1 + node = rc.get_default_node() + assert node.port == int(ssl_port) + return rc + + return await RedisCluster(host=ssl_host, port=ssl_port, **kwargs) + + return _create_client + + async def test_ssl_connection_without_ssl( + self, create_client: Callable[..., Awaitable[RedisCluster]] + ) -> None: + with pytest.raises(RedisClusterException) as e: + await create_client(mocked=False, ssl=False) + e = e.value.__cause__ + assert "Connection closed by server" in str(e) + + async def test_ssl_with_invalid_cert( + self, create_client: Callable[..., Awaitable[RedisCluster]] + ) -> None: + with pytest.raises(RedisClusterException) as e: + await create_client(mocked=False, ssl=True) + e = e.value.__cause__.__context__ + assert "SSL: CERTIFICATE_VERIFY_FAILED" in str(e) + + async def test_ssl_connection( + self, create_client: Callable[..., Awaitable[RedisCluster]] + ) -> None: + async with await create_client(ssl=True, ssl_cert_reqs="none") as rc: + assert await rc.ping() + + async def test_validating_self_signed_certificate( + self, create_client: Callable[..., Awaitable[RedisCluster]] + ) -> None: + async with await create_client( + ssl=True, + ssl_ca_certs=self.SERVER_CERT, + ssl_cert_reqs="required", + ssl_certfile=self.SERVER_CERT, + ssl_keyfile=self.SERVER_KEY, + ) as rc: + assert await rc.ping() + + async def test_validating_self_signed_string_certificate( + self, create_client: Callable[..., Awaitable[RedisCluster]] + ) -> None: + with open(self.SERVER_CERT) as f: + cert_data = f.read() + + async with await create_client( + ssl=True, + ssl_ca_data=cert_data, + ssl_cert_reqs="required", + ssl_certfile=self.SERVER_CERT, + ssl_keyfile=self.SERVER_KEY, + ) as rc: + assert await rc.ping() diff --git a/tests/test_asyncio/test_graph.py b/tests/test_asyncio/test_graph.py new file mode 100644 index 0000000..0c011bd --- /dev/null +++ b/tests/test_asyncio/test_graph.py @@ -0,0 +1,503 @@ +import pytest + +import redis.asyncio as redis +from redis.commands.graph import Edge, Node, Path +from 
redis.commands.graph.execution_plan import Operation +from redis.exceptions import ResponseError +from tests.conftest import skip_if_redis_enterprise + + +@pytest.mark.redismod +async def test_bulk(modclient): + with pytest.raises(NotImplementedError): + await modclient.graph().bulk() + await modclient.graph().bulk(foo="bar!") + + +@pytest.mark.redismod +async def test_graph_creation(modclient: redis.Redis): + graph = modclient.graph() + + john = Node( + label="person", + properties={ + "name": "John Doe", + "age": 33, + "gender": "male", + "status": "single", + }, + ) + graph.add_node(john) + japan = Node(label="country", properties={"name": "Japan"}) + + graph.add_node(japan) + edge = Edge(john, "visited", japan, properties={"purpose": "pleasure"}) + graph.add_edge(edge) + + await graph.commit() + + query = ( + 'MATCH (p:person)-[v:visited {purpose:"pleasure"}]->(c:country) ' + "RETURN p, v, c" + ) + + result = await graph.query(query) + + person = result.result_set[0][0] + visit = result.result_set[0][1] + country = result.result_set[0][2] + + assert person == john + assert visit.properties == edge.properties + assert country == japan + + query = """RETURN [1, 2.3, "4", true, false, null]""" + result = await graph.query(query) + assert [1, 2.3, "4", True, False, None] == result.result_set[0][0] + + # All done, remove graph. + await graph.delete() + + +@pytest.mark.redismod +async def test_array_functions(modclient: redis.Redis): + graph = modclient.graph() + + query = """CREATE (p:person{name:'a',age:32, array:[0,1,2]})""" + await graph.query(query) + + query = """WITH [0,1,2] as x return x""" + result = await graph.query(query) + assert [0, 1, 2] == result.result_set[0][0] + + query = """MATCH(n) return collect(n)""" + result = await graph.query(query) + + a = Node( + node_id=0, + label="person", + properties={"name": "a", "age": 32, "array": [0, 1, 2]}, + ) + + assert [a] == result.result_set[0][0] + + +@pytest.mark.redismod +async def test_path(modclient: redis.Redis): + node0 = Node(node_id=0, label="L1") + node1 = Node(node_id=1, label="L1") + edge01 = Edge(node0, "R1", node1, edge_id=0, properties={"value": 1}) + + graph = modclient.graph() + graph.add_node(node0) + graph.add_node(node1) + graph.add_edge(edge01) + await graph.flush() + + path01 = Path.new_empty_path().add_node(node0).add_edge(edge01).add_node(node1) + expected_results = [[path01]] + + query = "MATCH p=(:L1)-[:R1]->(:L1) RETURN p ORDER BY p" + result = await graph.query(query) + assert expected_results == result.result_set + + +@pytest.mark.redismod +async def test_param(modclient: redis.Redis): + params = [1, 2.3, "str", True, False, None, [0, 1, 2]] + query = "RETURN $param" + for param in params: + result = await modclient.graph().query(query, {"param": param}) + expected_results = [[param]] + assert expected_results == result.result_set + + +@pytest.mark.redismod +async def test_map(modclient: redis.Redis): + query = "RETURN {a:1, b:'str', c:NULL, d:[1,2,3], e:True, f:{x:1, y:2}}" + + actual = (await modclient.graph().query(query)).result_set[0][0] + expected = { + "a": 1, + "b": "str", + "c": None, + "d": [1, 2, 3], + "e": True, + "f": {"x": 1, "y": 2}, + } + + assert actual == expected + + +@pytest.mark.redismod +async def test_point(modclient: redis.Redis): + query = "RETURN point({latitude: 32.070794860, longitude: 34.820751118})" + expected_lat = 32.070794860 + expected_lon = 34.820751118 + actual = (await modclient.graph().query(query)).result_set[0][0] + assert abs(actual["latitude"] - expected_lat) < 
0.001 + assert abs(actual["longitude"] - expected_lon) < 0.001 + + query = "RETURN point({latitude: 32, longitude: 34.0})" + expected_lat = 32 + expected_lon = 34 + actual = (await modclient.graph().query(query)).result_set[0][0] + assert abs(actual["latitude"] - expected_lat) < 0.001 + assert abs(actual["longitude"] - expected_lon) < 0.001 + + +@pytest.mark.redismod +async def test_index_response(modclient: redis.Redis): + result_set = await modclient.graph().query("CREATE INDEX ON :person(age)") + assert 1 == result_set.indices_created + + result_set = await modclient.graph().query("CREATE INDEX ON :person(age)") + assert 0 == result_set.indices_created + + result_set = await modclient.graph().query("DROP INDEX ON :person(age)") + assert 1 == result_set.indices_deleted + + with pytest.raises(ResponseError): + await modclient.graph().query("DROP INDEX ON :person(age)") + + +@pytest.mark.redismod +async def test_stringify_query_result(modclient: redis.Redis): + graph = modclient.graph() + + john = Node( + alias="a", + label="person", + properties={ + "name": "John Doe", + "age": 33, + "gender": "male", + "status": "single", + }, + ) + graph.add_node(john) + + japan = Node(alias="b", label="country", properties={"name": "Japan"}) + graph.add_node(japan) + + edge = Edge(john, "visited", japan, properties={"purpose": "pleasure"}) + graph.add_edge(edge) + + assert ( + str(john) + == """(a:person{age:33,gender:"male",name:"John Doe",status:"single"})""" # noqa + ) + assert ( + str(edge) + == """(a:person{age:33,gender:"male",name:"John Doe",status:"single"})""" # noqa + + """-[:visited{purpose:"pleasure"}]->""" + + """(b:country{name:"Japan"})""" + ) + assert str(japan) == """(b:country{name:"Japan"})""" + + await graph.commit() + + query = """MATCH (p:person)-[v:visited {purpose:"pleasure"}]->(c:country) + RETURN p, v, c""" + + result = await graph.query(query) + person = result.result_set[0][0] + visit = result.result_set[0][1] + country = result.result_set[0][2] + + assert ( + str(person) + == """(:person{age:33,gender:"male",name:"John Doe",status:"single"})""" # noqa + ) + assert str(visit) == """()-[:visited{purpose:"pleasure"}]->()""" + assert str(country) == """(:country{name:"Japan"})""" + + await graph.delete() + + +@pytest.mark.redismod +async def test_optional_match(modclient: redis.Redis): + # Build a graph of form (a)-[R]->(b) + node0 = Node(node_id=0, label="L1", properties={"value": "a"}) + node1 = Node(node_id=1, label="L1", properties={"value": "b"}) + + edge01 = Edge(node0, "R", node1, edge_id=0) + + graph = modclient.graph() + graph.add_node(node0) + graph.add_node(node1) + graph.add_edge(edge01) + await graph.flush() + + # Issue a query that collects all outgoing edges from both nodes + # (the second has none) + query = """MATCH (a) OPTIONAL MATCH (a)-[e]->(b) RETURN a, e, b ORDER BY a.value""" # noqa + expected_results = [[node0, edge01, node1], [node1, None, None]] + + result = await graph.query(query) + assert expected_results == result.result_set + + await graph.delete() + + +@pytest.mark.redismod +async def test_cached_execution(modclient: redis.Redis): + await modclient.graph().query("CREATE ()") + + uncached_result = await modclient.graph().query( + "MATCH (n) RETURN n, $param", {"param": [0]} + ) + assert uncached_result.cached_execution is False + + # loop to make sure the query is cached on each thread on server + for x in range(0, 64): + cached_result = await modclient.graph().query( + "MATCH (n) RETURN n, $param", {"param": [0]} + ) + assert 
uncached_result.result_set == cached_result.result_set + + # should be cached on all threads by now + assert cached_result.cached_execution + + +@pytest.mark.redismod +async def test_slowlog(modclient: redis.Redis): + create_query = """CREATE (:Rider {name:'Valentino Rossi'})-[:rides]->(:Team {name:'Yamaha'}), + (:Rider {name:'Dani Pedrosa'})-[:rides]->(:Team {name:'Honda'}), + (:Rider {name:'Andrea Dovizioso'})-[:rides]->(:Team {name:'Ducati'})""" + await modclient.graph().query(create_query) + + results = await modclient.graph().slowlog() + assert results[0][1] == "GRAPH.QUERY" + assert results[0][2] == create_query + + +@pytest.mark.redismod +async def test_query_timeout(modclient: redis.Redis): + # Build a sample graph with 1000 nodes. + await modclient.graph().query("UNWIND range(0,1000) as val CREATE ({v: val})") + # Issue a long-running query with a 1-millisecond timeout. + with pytest.raises(ResponseError): + await modclient.graph().query("MATCH (a), (b), (c), (d) RETURN *", timeout=1) + assert False is False + + with pytest.raises(Exception): + await modclient.graph().query("RETURN 1", timeout="str") + assert False is False + + +@pytest.mark.redismod +async def test_read_only_query(modclient: redis.Redis): + with pytest.raises(Exception): + # Issue a write query, specifying read-only true, + # this call should fail. + await modclient.graph().query("CREATE (p:person {name:'a'})", read_only=True) + assert False is False + + +@pytest.mark.redismod +async def test_profile(modclient: redis.Redis): + q = """UNWIND range(1, 3) AS x CREATE (p:Person {v:x})""" + profile = (await modclient.graph().profile(q)).result_set + assert "Create | Records produced: 3" in profile + assert "Unwind | Records produced: 3" in profile + + q = "MATCH (p:Person) WHERE p.v > 1 RETURN p" + profile = (await modclient.graph().profile(q)).result_set + assert "Results | Records produced: 2" in profile + assert "Project | Records produced: 2" in profile + assert "Filter | Records produced: 2" in profile + assert "Node By Label Scan | (p:Person) | Records produced: 3" in profile + + +@pytest.mark.redismod +@skip_if_redis_enterprise() +async def test_config(modclient: redis.Redis): + config_name = "RESULTSET_SIZE" + config_value = 3 + + # Set configuration + response = await modclient.graph().config(config_name, config_value, set=True) + assert response == "OK" + + # Make sure config been updated. + response = await modclient.graph().config(config_name, set=False) + expected_response = [config_name, config_value] + assert response == expected_response + + config_name = "QUERY_MEM_CAPACITY" + config_value = 1 << 20 # 1MB + + # Set configuration + response = await modclient.graph().config(config_name, config_value, set=True) + assert response == "OK" + + # Make sure config been updated. 
+ response = await modclient.graph().config(config_name, set=False) + expected_response = [config_name, config_value] + assert response == expected_response + + # reset to default + await modclient.graph().config("QUERY_MEM_CAPACITY", 0, set=True) + await modclient.graph().config("RESULTSET_SIZE", -100, set=True) + + +@pytest.mark.redismod +@pytest.mark.onlynoncluster +async def test_list_keys(modclient: redis.Redis): + result = await modclient.graph().list_keys() + assert result == [] + + await modclient.graph("G").query("CREATE (n)") + result = await modclient.graph().list_keys() + assert result == ["G"] + + await modclient.graph("X").query("CREATE (m)") + result = await modclient.graph().list_keys() + assert result == ["G", "X"] + + await modclient.delete("G") + await modclient.rename("X", "Z") + result = await modclient.graph().list_keys() + assert result == ["Z"] + + await modclient.delete("Z") + result = await modclient.graph().list_keys() + assert result == [] + + +@pytest.mark.redismod +async def test_multi_label(modclient: redis.Redis): + redis_graph = modclient.graph("g") + + node = Node(label=["l", "ll"]) + redis_graph.add_node(node) + await redis_graph.commit() + + query = "MATCH (n) RETURN n" + result = await redis_graph.query(query) + result_node = result.result_set[0][0] + assert result_node == node + + try: + Node(label=1) + assert False + except AssertionError: + assert True + + try: + Node(label=["l", 1]) + assert False + except AssertionError: + assert True + + +@pytest.mark.redismod +async def test_execution_plan(modclient: redis.Redis): + redis_graph = modclient.graph("execution_plan") + create_query = """CREATE (:Rider {name:'Valentino Rossi'})-[:rides]->(:Team {name:'Yamaha'}), + (:Rider {name:'Dani Pedrosa'})-[:rides]->(:Team {name:'Honda'}), + (:Rider {name:'Andrea Dovizioso'})-[:rides]->(:Team {name:'Ducati'})""" + await redis_graph.query(create_query) + + result = await redis_graph.execution_plan( + "MATCH (r:Rider)-[:rides]->(t:Team) WHERE t.name = $name RETURN r.name, t.name, $params", # noqa + {"name": "Yehuda"}, + ) + expected = "Results\n Project\n Conditional Traverse | (t:Team)->(r:Rider)\n Filter\n Node By Label Scan | (t:Team)" # noqa + assert result == expected + + await redis_graph.delete() + + +@pytest.mark.redismod +async def test_explain(modclient: redis.Redis): + redis_graph = modclient.graph("execution_plan") + # graph creation / population + create_query = """CREATE +(:Rider {name:'Valentino Rossi'})-[:rides]->(:Team {name:'Yamaha'}), +(:Rider {name:'Dani Pedrosa'})-[:rides]->(:Team {name:'Honda'}), +(:Rider {name:'Andrea Dovizioso'})-[:rides]->(:Team {name:'Ducati'})""" + await redis_graph.query(create_query) + + result = await redis_graph.explain( + """MATCH (r:Rider)-[:rides]->(t:Team) +WHERE t.name = $name +RETURN r.name, t.name +UNION +MATCH (r:Rider)-[:rides]->(t:Team) +WHERE t.name = $name +RETURN r.name, t.name""", + {"name": "Yamaha"}, + ) + expected = """\ +Results +Distinct + Join + Project + Conditional Traverse | (t:Team)->(r:Rider) + Filter + Node By Label Scan | (t:Team) + Project + Conditional Traverse | (t:Team)->(r:Rider) + Filter + Node By Label Scan | (t:Team)""" + assert str(result).replace(" ", "").replace("\n", "") == expected.replace( + " ", "" + ).replace("\n", "") + + expected = Operation("Results").append_child( + Operation("Distinct").append_child( + Operation("Join") + .append_child( + Operation("Project").append_child( + Operation( + "Conditional Traverse", "(t:Team)->(r:Rider)" + ).append_child( + 
Operation("Filter").append_child( + Operation("Node By Label Scan", "(t:Team)") + ) + ) + ) + ) + .append_child( + Operation("Project").append_child( + Operation( + "Conditional Traverse", "(t:Team)->(r:Rider)" + ).append_child( + Operation("Filter").append_child( + Operation("Node By Label Scan", "(t:Team)") + ) + ) + ) + ) + ) + ) + + assert result.structured_plan == expected + + result = await redis_graph.explain( + """MATCH (r:Rider), (t:Team) + RETURN r.name, t.name""" + ) + expected = """\ +Results +Project + Cartesian Product + Node By Label Scan | (r:Rider) + Node By Label Scan | (t:Team)""" + assert str(result).replace(" ", "").replace("\n", "") == expected.replace( + " ", "" + ).replace("\n", "") + + expected = Operation("Results").append_child( + Operation("Project").append_child( + Operation("Cartesian Product") + .append_child(Operation("Node By Label Scan")) + .append_child(Operation("Node By Label Scan")) + ) + ) + + assert result.structured_plan == expected + + await redis_graph.delete() diff --git a/tests/test_asyncio/test_json.py b/tests/test_asyncio/test_json.py index a045dd7..b8854d2 100644 --- a/tests/test_asyncio/test_json.py +++ b/tests/test_asyncio/test_json.py @@ -5,8 +5,6 @@ from redis import exceptions from redis.commands.json.path import Path from tests.conftest import skip_ifmodversion_lt -pytestmark = pytest.mark.asyncio - @pytest.mark.redismod async def test_json_setbinarykey(modclient: redis.Redis): @@ -819,7 +817,7 @@ async def test_objlen_dollar(modclient: redis.Redis): }, ) # Test multi - assert await modclient.json().objlen("doc1", "$..a") == [2, None, 1] + assert await modclient.json().objlen("doc1", "$..a") == [None, 2, 1] # Test single assert await modclient.json().objlen("doc1", "$.nested1.a") == [2] diff --git a/tests/test_asyncio/test_search.py b/tests/test_asyncio/test_search.py index 5aaa56f..88c80d3 100644 --- a/tests/test_asyncio/test_search.py +++ b/tests/test_asyncio/test_search.py @@ -16,10 +16,7 @@ from redis.commands.search.indexDefinition import IndexDefinition from redis.commands.search.query import GeoFilter, NumericFilter, Query from redis.commands.search.result import Result from redis.commands.search.suggestion import Suggestion -from tests.conftest import skip_ifmodversion_lt - -pytestmark = pytest.mark.asyncio - +from tests.conftest import skip_if_redis_enterprise, skip_ifmodversion_lt WILL_PLAY_TEXT = os.path.abspath( os.path.join(os.path.dirname(__file__), "testdata", "will_play_text.csv.bz2") @@ -88,7 +85,7 @@ async def createIndex(modclient, num_docs=100, definition=None): assert 50 == indexer.chunk_size for key, doc in chapters.items(): - await indexer.add_document(key, **doc) + await indexer.client.client.hset(key, mapping=doc) await indexer.commit() @@ -192,7 +189,7 @@ async def test_client(modclient: redis.Redis): assert 167 == (await modclient.ft().search(Query("henry king").slop(100))).total # test delete document - await modclient.ft().add_document("doc-5ghs2", play="Death of a Salesman") + await modclient.hset("doc-5ghs2", mapping={"play": "Death of a Salesman"}) res = await modclient.ft().search(Query("death of a salesman")) assert 1 == res.total @@ -201,36 +198,19 @@ async def test_client(modclient: redis.Redis): assert 0 == res.total assert 0 == await modclient.ft().delete_document("doc-5ghs2") - await modclient.ft().add_document("doc-5ghs2", play="Death of a Salesman") + await modclient.hset("doc-5ghs2", mapping={"play": "Death of a Salesman"}) res = await modclient.ft().search(Query("death of a salesman")) assert 1 
== res.total await modclient.ft().delete_document("doc-5ghs2") -@pytest.mark.redismod -@skip_ifmodversion_lt("2.2.0", "search") -async def test_payloads(modclient: redis.Redis): - await modclient.ft().create_index((TextField("txt"),)) - - await modclient.ft().add_document("doc1", payload="foo baz", txt="foo bar") - await modclient.ft().add_document("doc2", txt="foo bar") - - q = Query("foo bar").with_payloads() - res = await modclient.ft().search(q) - assert 2 == res.total - assert "doc1" == res.docs[0].id - assert "doc2" == res.docs[1].id - assert "foo baz" == res.docs[0].payload - assert res.docs[1].payload is None - - @pytest.mark.redismod @pytest.mark.onlynoncluster async def test_scores(modclient: redis.Redis): await modclient.ft().create_index((TextField("txt"),)) - await modclient.ft().add_document("doc1", txt="foo baz") - await modclient.ft().add_document("doc2", txt="foo bar") + await modclient.hset("doc1", mapping={"txt": "foo baz"}) + await modclient.hset("doc2", mapping={"txt": "foo bar"}) q = Query("foo ~bar").with_scores() res = await modclient.ft().search(q) @@ -242,35 +222,12 @@ async def test_scores(modclient: redis.Redis): # self.assertEqual(0.2, res.docs[1].score) -@pytest.mark.redismod -async def test_replace(modclient: redis.Redis): - await modclient.ft().create_index((TextField("txt"),)) - - await modclient.ft().add_document("doc1", txt="foo bar") - await modclient.ft().add_document("doc2", txt="foo bar") - await waitForIndex(modclient, "idx") - - res = await modclient.ft().search("foo bar") - assert 2 == res.total - await ( - modclient.ft().add_document("doc1", replace=True, txt="this is a replaced doc") - ) - - res = await modclient.ft().search("foo bar") - assert 1 == res.total - assert "doc2" == res.docs[0].id - - res = await modclient.ft().search("replaced doc") - assert 1 == res.total - assert "doc1" == res.docs[0].id - - @pytest.mark.redismod async def test_stopwords(modclient: redis.Redis): stopwords = ["foo", "bar", "baz"] await modclient.ft().create_index((TextField("txt"),), stopwords=stopwords) - await modclient.ft().add_document("doc1", txt="foo bar") - await modclient.ft().add_document("doc2", txt="hello world") + await modclient.hset("doc1", mapping={"txt": "foo bar"}) + await modclient.hset("doc2", mapping={"txt": "hello world"}) await waitForIndex(modclient, "idx") q1 = Query("foo bar").no_content() @@ -288,11 +245,13 @@ async def test_filters(modclient: redis.Redis): ) ) await ( - modclient.ft().add_document( - "doc1", txt="foo bar", num=3.141, loc="-0.441,51.458" + modclient.hset( + "doc1", mapping={"txt": "foo bar", "num": 3.141, "loc": "-0.441,51.458"} ) ) - await modclient.ft().add_document("doc2", txt="foo baz", num=2, loc="-0.1,51.2") + await ( + modclient.hset("doc2", mapping={"txt": "foo baz", "num": 2, "loc": "-0.1,51.2"}) + ) await waitForIndex(modclient, "idx") # Test numerical filter @@ -324,17 +283,6 @@ async def test_filters(modclient: redis.Redis): assert ["doc1", "doc2"] == res -@pytest.mark.redismod -async def test_payloads_with_no_content(modclient: redis.Redis): - await modclient.ft().create_index((TextField("txt"),)) - await modclient.ft().add_document("doc1", payload="foo baz", txt="foo bar") - await modclient.ft().add_document("doc2", payload="foo baz2", txt="foo bar") - - q = Query("foo bar").with_payloads().no_content() - res = await modclient.ft().search(q) - assert 2 == len(res.docs) - - @pytest.mark.redismod async def test_sort_by(modclient: redis.Redis): await ( @@ -342,9 +290,9 @@ async def test_sort_by(modclient: 
redis.Redis): (TextField("txt"), NumericField("num", sortable=True)) ) ) - await modclient.ft().add_document("doc1", txt="foo bar", num=1) - await modclient.ft().add_document("doc2", txt="foo baz", num=2) - await modclient.ft().add_document("doc3", txt="foo qux", num=3) + await modclient.hset("doc1", mapping={"txt": "foo bar", "num": 1}) + await modclient.hset("doc2", mapping={"txt": "foo baz", "num": 2}) + await modclient.hset("doc3", mapping={"txt": "foo qux", "num": 3}) # Test sort q1 = Query("foo").sort_by("num", asc=True).no_content() @@ -388,10 +336,12 @@ async def test_example(modclient: redis.Redis): ) # Indexing a document - await modclient.ft().add_document( + await modclient.hset( "doc1", - title="RediSearch", - body="Redisearch impements a search engine on top of redis", + mapping={ + "title": "RediSearch", + "body": "Redisearch impements a search engine on top of redis", + }, ) # Searching with complex parameters: @@ -464,11 +414,13 @@ async def test_no_index(modclient: redis.Redis): ) ) - await modclient.ft().add_document( - "doc1", field="aaa", text="1", numeric="1", geo="1,1", tag="1" + await modclient.hset( + "doc1", + mapping={"field": "aaa", "text": "1", "numeric": "1", "geo": "1,1", "tag": "1"}, ) - await modclient.ft().add_document( - "doc2", field="aab", text="2", numeric="2", geo="2,2", tag="2" + await modclient.hset( + "doc2", + mapping={"field": "aab", "text": "2", "numeric": "2", "geo": "2,2", "tag": "2"}, ) await waitForIndex(modclient, "idx") @@ -505,53 +457,6 @@ async def test_no_index(modclient: redis.Redis): TagField("name", no_index=True, sortable=False) -@pytest.mark.redismod -async def test_partial(modclient: redis.Redis): - await ( - modclient.ft().create_index((TextField("f1"), TextField("f2"), TextField("f3"))) - ) - await modclient.ft().add_document("doc1", f1="f1_val", f2="f2_val") - await modclient.ft().add_document("doc2", f1="f1_val", f2="f2_val") - await modclient.ft().add_document("doc1", f3="f3_val", partial=True) - await modclient.ft().add_document("doc2", f3="f3_val", replace=True) - await waitForIndex(modclient, "idx") - - # Search for f3 value. All documents should have it - res = await modclient.ft().search("@f3:f3_val") - assert 2 == res.total - - # Only the document updated with PARTIAL should still have f1 and f2 values - res = await modclient.ft().search("@f3:f3_val @f2:f2_val @f1:f1_val") - assert 1 == res.total - - -@pytest.mark.redismod -async def test_no_create(modclient: redis.Redis): - await ( - modclient.ft().create_index((TextField("f1"), TextField("f2"), TextField("f3"))) - ) - await modclient.ft().add_document("doc1", f1="f1_val", f2="f2_val") - await modclient.ft().add_document("doc2", f1="f1_val", f2="f2_val") - await modclient.ft().add_document("doc1", f3="f3_val", no_create=True) - await modclient.ft().add_document("doc2", f3="f3_val", no_create=True, partial=True) - await waitForIndex(modclient, "idx") - - # Search for f3 value. 
All documents should have it - res = await modclient.ft().search("@f3:f3_val") - assert 2 == res.total - - # Only the document updated with PARTIAL should still have f1 and f2 values - res = await modclient.ft().search("@f3:f3_val @f2:f2_val @f1:f1_val") - assert 1 == res.total - - with pytest.raises(redis.ResponseError): - await ( - modclient.ft().add_document( - "doc3", f2="f2_val", f3="f3_val", no_create=True - ) - ) - - @pytest.mark.redismod async def test_explain(modclient: redis.Redis): await ( @@ -643,11 +548,11 @@ async def test_alias_basic(modclient: redis.Redis): index1 = getClient(modclient).ft("testAlias") await index1.create_index((TextField("txt"),)) - await index1.add_document("doc1", txt="text goes here") + await index1.client.hset("doc1", mapping={"txt": "text goes here"}) index2 = getClient(modclient).ft("testAlias2") await index2.create_index((TextField("txt"),)) - await index2.add_document("doc2", txt="text goes here") + await index2.client.hset("doc2", mapping={"txt": "text goes here"}) # add the actual alias and check await index1.aliasadd("myalias") @@ -677,8 +582,8 @@ async def test_tags(modclient: redis.Redis): tags = "foo,foo bar,hello;world" tags2 = "soba,ramen" - await modclient.ft().add_document("doc1", txt="fooz barz", tags=tags) - await modclient.ft().add_document("doc2", txt="noodles", tags=tags2) + await modclient.hset("doc1", mapping={"txt": "fooz barz", "tags": tags}) + await modclient.hset("doc2", mapping={"txt": "noodles", "tags": tags2}) await waitForIndex(modclient, "idx") q = Query("@tags:{foo}") @@ -721,8 +626,8 @@ async def test_alter_schema_add(modclient: redis.Redis): await modclient.ft().alter_schema_add(TextField("body")) # Indexing a document - await modclient.ft().add_document( - "doc1", title="MyTitle", body="Some content only in the body" + await modclient.hset( + "doc1", mapping={"title": "MyTitle", "body": "Some content only in the body"} ) # Searching with parameter only in the body (the added field) @@ -738,11 +643,11 @@ async def test_spell_check(modclient: redis.Redis): await modclient.ft().create_index((TextField("f1"), TextField("f2"))) await ( - modclient.ft().add_document( - "doc1", f1="some valid content", f2="this is sample text" + modclient.hset( + "doc1", mapping={"f1": "some valid content", "f2": "this is sample text"} ) ) - await modclient.ft().add_document("doc2", f1="very important", f2="lorem ipsum") + await modclient.hset("doc2", mapping={"f1": "very important", "f2": "lorem ipsum"}) await waitForIndex(modclient, "idx") # test spellcheck @@ -796,8 +701,8 @@ async def test_dict_operations(modclient: redis.Redis): @pytest.mark.redismod async def test_phonetic_matcher(modclient: redis.Redis): await modclient.ft().create_index((TextField("name"),)) - await modclient.ft().add_document("doc1", name="Jon") - await modclient.ft().add_document("doc2", name="John") + await modclient.hset("doc1", mapping={"name": "Jon"}) + await modclient.hset("doc2", mapping={"name": "John"}) res = await modclient.ft().search(Query("Jon")) assert 1 == len(res.docs) @@ -807,8 +712,8 @@ async def test_phonetic_matcher(modclient: redis.Redis): await modclient.flushdb() await modclient.ft().create_index((TextField("name", phonetic_matcher="dm:en"),)) - await modclient.ft().add_document("doc1", name="Jon") - await modclient.ft().add_document("doc2", name="John") + await modclient.hset("doc1", mapping={"name": "Jon"}) + await modclient.hset("doc2", mapping={"name": "John"}) res = await modclient.ft().search(Query("Jon")) assert 2 == len(res.docs) @@ 
-820,12 +725,14 @@ async def test_phonetic_matcher(modclient: redis.Redis): async def test_scorer(modclient: redis.Redis): await modclient.ft().create_index((TextField("description"),)) - await modclient.ft().add_document( - "doc1", description="The quick brown fox jumps over the lazy dog" + await modclient.hset( + "doc1", mapping={"description": "The quick brown fox jumps over the lazy dog"} ) - await modclient.ft().add_document( + await modclient.hset( "doc2", - description="Quick alice was beginning to get very tired of sitting by her quick sister on the bank, and of having nothing to do.", # noqa + mapping={ + "description": "Quick alice was beginning to get very tired of sitting by her quick sister on the bank, and of having nothing to do." # noqa + }, ) # default scorer is TFIDF @@ -854,19 +761,19 @@ async def test_get(modclient: redis.Redis): assert [None] == await modclient.ft().get("doc1") assert [None, None] == await modclient.ft().get("doc2", "doc1") - await modclient.ft().add_document( - "doc1", f1="some valid content dd1", f2="this is sample text ff1" + await modclient.hset( + "doc1", mapping={"f1": "some valid content dd1", "f2": "this is sample text f1"} ) - await modclient.ft().add_document( - "doc2", f1="some valid content dd2", f2="this is sample text ff2" + await modclient.hset( + "doc2", mapping={"f1": "some valid content dd2", "f2": "this is sample text f2"} ) assert [ - ["f1", "some valid content dd2", "f2", "this is sample text ff2"] + ["f1", "some valid content dd2", "f2", "this is sample text f2"] ] == await modclient.ft().get("doc2") assert [ - ["f1", "some valid content dd1", "f2", "this is sample text ff1"], - ["f1", "some valid content dd2", "f2", "this is sample text ff2"], + ["f1", "some valid content dd1", "f2", "this is sample text f1"], + ["f1", "some valid content dd2", "f2", "this is sample text f2"], ] == await modclient.ft().get("doc1", "doc2") @@ -897,26 +804,32 @@ async def test_aggregations_groupby(modclient: redis.Redis): ) # Indexing a document - await modclient.ft().add_document( + await modclient.hset( "search", - title="RediSearch", - body="Redisearch impements a search engine on top of redis", - parent="redis", - random_num=10, - ) - await modclient.ft().add_document( + mapping={ + "title": "RediSearch", + "body": "Redisearch impements a search engine on top of redis", + "parent": "redis", + "random_num": 10, + }, + ) + await modclient.hset( "ai", - title="RedisAI", - body="RedisAI executes Deep Learning/Machine Learning models and managing their data.", # noqa - parent="redis", - random_num=3, - ) - await modclient.ft().add_document( + mapping={ + "title": "RedisAI", + "body": "RedisAI executes Deep Learning/Machine Learning models and managing their data.", # noqa + "parent": "redis", + "random_num": 3, + }, + ) + await modclient.hset( "json", - title="RedisJson", - body="RedisJSON implements ECMA-404 The JSON Data Interchange Standard as a native data type.", # noqa - parent="redis", - random_num=8, + mapping={ + "title": "RedisJson", + "body": "RedisJSON implements ECMA-404 The JSON Data Interchange Standard as a native data type.", # noqa + "parent": "redis", + "random_num": 8, + }, ) req = aggregations.AggregateRequest("redis").group_by("@parent", reducers.count()) @@ -995,7 +908,7 @@ async def test_aggregations_groupby(modclient: redis.Redis): res = (await modclient.ft().aggregate(req)).rows[0] assert res[1] == "redis" - assert res[3] == ["RediSearch", "RedisAI", "RedisJson"] + assert set(res[3]) == {"RediSearch", "RedisAI", 
"RedisJson"} req = aggregations.AggregateRequest("redis").group_by( "@parent", reducers.first_value("@title").alias("first") @@ -1046,3 +959,45 @@ async def test_aggregations_sort_by_and_limit(modclient: redis.Redis): res = await modclient.ft().aggregate(req) assert len(res.rows) == 1 assert res.rows[0] == ["t1", "b"] + + +@pytest.mark.redismod +@pytest.mark.experimental +async def test_withsuffixtrie(modclient: redis.Redis): + # create index + assert await modclient.ft().create_index((TextField("txt"),)) + await waitForIndex(modclient, getattr(modclient.ft(), "index_name", "idx")) + info = await modclient.ft().info() + assert "WITHSUFFIXTRIE" not in info["attributes"][0] + assert await modclient.ft().dropindex("idx") + + # create withsuffixtrie index (text field) + assert await modclient.ft().create_index((TextField("t", withsuffixtrie=True))) + await waitForIndex(modclient, getattr(modclient.ft(), "index_name", "idx")) + info = await modclient.ft().info() + assert "WITHSUFFIXTRIE" in info["attributes"][0] + assert await modclient.ft().dropindex("idx") + + # create withsuffixtrie index (tag field) + assert await modclient.ft().create_index((TagField("t", withsuffixtrie=True))) + await waitForIndex(modclient, getattr(modclient.ft(), "index_name", "idx")) + info = await modclient.ft().info() + assert "WITHSUFFIXTRIE" in info["attributes"][0] + + +@pytest.mark.redismod +@skip_if_redis_enterprise() +async def test_search_commands_in_pipeline(modclient: redis.Redis): + p = await modclient.ft().pipeline() + p.create_index((TextField("txt"),)) + p.hset("doc1", mapping={"txt": "foo bar"}) + p.hset("doc2", mapping={"txt": "foo bar"}) + q = Query("foo bar").with_payloads() + await p.search(q) + res = await p.execute() + assert res[:3] == ["OK", True, True] + assert 2 == res[3][0] + assert "doc1" == res[3][1] + assert "doc2" == res[3][4] + assert res[3][5] is None + assert res[3][3] == res[3][6] == ["txt", "foo bar"] diff --git a/tests/test_asyncio/test_timeseries.py b/tests/test_asyncio/test_timeseries.py index ac2807f..a710993 100644 --- a/tests/test_asyncio/test_timeseries.py +++ b/tests/test_asyncio/test_timeseries.py @@ -6,8 +6,6 @@ import pytest import redis.asyncio as redis from tests.conftest import skip_ifmodversion_lt -pytestmark = pytest.mark.asyncio - @pytest.mark.redismod async def test_create(modclient: redis.Redis): @@ -242,6 +240,9 @@ async def test_range_advanced(modclient: redis.Redis): assert [(0, 5.0), (5, 6.0)] == await modclient.ts().range( 1, 0, 10, aggregation_type="count", bucket_size_msec=10, align=5 ) + assert [(0, 2.55), (10, 3.0)] == await modclient.ts().range( + 1, 0, 10, aggregation_type="twa", bucket_size_msec=10 + ) @pytest.mark.redismod diff --git a/tests/test_bloom.py b/tests/test_bloom.py index 1f8201c..13921bf 100644 --- a/tests/test_bloom.py +++ b/tests/test_bloom.py @@ -1,9 +1,13 @@ +from math import inf + import pytest import redis.commands.bf from redis.exceptions import ModuleError, RedisError from redis.utils import HIREDIS_AVAILABLE +from .conftest import skip_ifmodversion_lt + def intlist(obj): return [int(v) for v in obj] @@ -36,6 +40,21 @@ def test_create(client): assert client.topk().reserve("topk", 5, 100, 5, 0.9) +@pytest.mark.redismod +def test_bf_reserve(client): + """Testing BF.RESERVE""" + assert client.bf().reserve("bloom", 0.01, 1000) + assert client.bf().reserve("bloom_e", 0.01, 1000, expansion=1) + assert client.bf().reserve("bloom_ns", 0.01, 1000, noScale=True) + assert client.cf().reserve("cuckoo", 1000) + assert 
client.cf().reserve("cuckoo_e", 1000, expansion=1) + assert client.cf().reserve("cuckoo_bs", 1000, bucket_size=4) + assert client.cf().reserve("cuckoo_mi", 1000, max_iterations=10) + assert client.cms().initbydim("cmsDim", 100, 5) + assert client.cms().initbyprob("cmsProb", 0.01, 0.01) + assert client.topk().reserve("topk", 5, 100, 5, 0.9) + + @pytest.mark.redismod @pytest.mark.experimental def test_tdigest_create(client): @@ -263,9 +282,10 @@ def test_topk(client): assert [1, 1, 0, 0, 1, 0, 0] == client.topk().query( "topk", "A", "B", "C", "D", "E", "F", "G" ) - assert [4, 3, 2, 3, 3, 0, 1] == client.topk().count( - "topk", "A", "B", "C", "D", "E", "F", "G" - ) + with pytest.deprecated_call(): + assert [4, 3, 2, 3, 3, 0, 1] == client.topk().count( + "topk", "A", "B", "C", "D", "E", "F", "G" + ) # test full list assert client.topk().reserve("topklist", 3, 50, 3, 0.9) @@ -305,9 +325,10 @@ def test_topk_incrby(client): "topk", ["bar", "baz", "42"], [3, 6, 2] ) assert [None, "bar"] == client.topk().incrby("topk", ["42", "xyzzy"], [8, 4]) - assert [3, 6, 10, 4, 0] == client.topk().count( - "topk", "bar", "baz", "42", "xyzzy", 4 - ) + with pytest.deprecated_call(): + assert [3, 6, 10, 4, 0] == client.topk().count( + "topk", "bar", "baz", "42", "xyzzy", 4 + ) # region Test T-Digest @@ -318,11 +339,11 @@ def test_tdigest_reset(client): # reset on empty histogram assert client.tdigest().reset("tDigest") # insert data-points into sketch - assert client.tdigest().add("tDigest", list(range(10)), [1.0] * 10) + assert client.tdigest().add("tDigest", list(range(10))) assert client.tdigest().reset("tDigest") # assert we have 0 unmerged nodes - assert 0 == client.tdigest().info("tDigest").unmergedNodes + assert 0 == client.tdigest().info("tDigest").unmerged_nodes @pytest.mark.redismod @@ -331,14 +352,24 @@ def test_tdigest_merge(client): assert client.tdigest().create("to-tDigest", 10) assert client.tdigest().create("from-tDigest", 10) # insert data-points into sketch - assert client.tdigest().add("from-tDigest", [1.0] * 10, [1.0] * 10) - assert client.tdigest().add("to-tDigest", [2.0] * 10, [10.0] * 10) + assert client.tdigest().add("from-tDigest", [1.0] * 10) + assert client.tdigest().add("to-tDigest", [2.0] * 10) # merge from-tdigest into to-tdigest - assert client.tdigest().merge("to-tDigest", "from-tDigest") + assert client.tdigest().merge("to-tDigest", 1, "from-tDigest") # we should now have 110 weight on to-histogram info = client.tdigest().info("to-tDigest") - total_weight_to = float(info.mergedWeight) + float(info.unmergedWeight) - assert 110 == total_weight_to + total_weight_to = float(info.merged_weight) + float(info.unmerged_weight) + assert 20 == total_weight_to + # test override + assert client.tdigest().create("from-override", 10) + assert client.tdigest().create("from-override-2", 10) + assert client.tdigest().add("from-override", [3.0] * 10) + assert client.tdigest().add("from-override-2", [4.0] * 10) + assert client.tdigest().merge( + "to-tDigest", 2, "from-override", "from-override-2", override=True + ) + assert 3.0 == client.tdigest().min("to-tDigest") + assert 4.0 == client.tdigest().max("to-tDigest") @pytest.mark.redismod @@ -346,7 +377,7 @@ def test_tdigest_merge(client): def test_tdigest_min_and_max(client): assert client.tdigest().create("tDigest", 100) # insert data-points into sketch - assert client.tdigest().add("tDigest", [1, 2, 3], [1.0] * 3) + assert client.tdigest().add("tDigest", [1, 2, 3]) # min/max assert 3 == client.tdigest().max("tDigest") assert 1 == 
client.tdigest().min("tDigest") @@ -354,18 +385,24 @@ def test_tdigest_min_and_max(client): @pytest.mark.redismod @pytest.mark.experimental +@skip_ifmodversion_lt("2.4.0", "bf") def test_tdigest_quantile(client): assert client.tdigest().create("tDigest", 500) # insert data-points into sketch - assert client.tdigest().add( - "tDigest", list([x * 0.01 for x in range(1, 10000)]), [1.0] * 10000 - ) + assert client.tdigest().add("tDigest", list([x * 0.01 for x in range(1, 10000)])) # assert min min/max have same result as quantile 0 and 1 - assert client.tdigest().max("tDigest") == client.tdigest().quantile("tDigest", 1.0) - assert client.tdigest().min("tDigest") == client.tdigest().quantile("tDigest", 0.0) + res = client.tdigest().quantile("tDigest", 1.0) + assert client.tdigest().max("tDigest") == res[0] + res = client.tdigest().quantile("tDigest", 0.0) + assert client.tdigest().min("tDigest") == res[0] - assert 1.0 == round(client.tdigest().quantile("tDigest", 0.01), 2) - assert 99.0 == round(client.tdigest().quantile("tDigest", 0.99), 2) + assert 1.0 == round(client.tdigest().quantile("tDigest", 0.01)[0], 2) + assert 99.0 == round(client.tdigest().quantile("tDigest", 0.99)[0], 2) + + # test multiple quantiles + assert client.tdigest().create("t-digest", 100) + assert client.tdigest().add("t-digest", [1, 2, 3, 4, 5]) + assert [3.0, 5.0] == client.tdigest().quantile("t-digest", 0.5, 0.8) @pytest.mark.redismod @@ -373,9 +410,67 @@ def test_tdigest_quantile(client): def test_tdigest_cdf(client): assert client.tdigest().create("tDigest", 100) # insert data-points into sketch - assert client.tdigest().add("tDigest", list(range(1, 10)), [1.0] * 10) - assert 0.1 == round(client.tdigest().cdf("tDigest", 1.0), 1) - assert 0.9 == round(client.tdigest().cdf("tDigest", 9.0), 1) + assert client.tdigest().add("tDigest", list(range(1, 10))) + assert 0.1 == round(client.tdigest().cdf("tDigest", 1.0)[0], 1) + assert 0.9 == round(client.tdigest().cdf("tDigest", 9.0)[0], 1) + res = client.tdigest().cdf("tDigest", 1.0, 9.0) + assert [0.1, 0.9] == [round(x, 1) for x in res] + + +@pytest.mark.redismod +@pytest.mark.experimental +@skip_ifmodversion_lt("2.4.0", "bf") +def test_tdigest_trimmed_mean(client): + assert client.tdigest().create("tDigest", 100) + # insert data-points into sketch + assert client.tdigest().add("tDigest", list(range(1, 10))) + assert 5 == client.tdigest().trimmed_mean("tDigest", 0.1, 0.9) + assert 4.5 == client.tdigest().trimmed_mean("tDigest", 0.4, 0.5) + + +@pytest.mark.redismod +@pytest.mark.experimental +def test_tdigest_rank(client): + assert client.tdigest().create("t-digest", 500) + assert client.tdigest().add("t-digest", list(range(0, 20))) + assert -1 == client.tdigest().rank("t-digest", -1)[0] + assert 0 == client.tdigest().rank("t-digest", 0)[0] + assert 10 == client.tdigest().rank("t-digest", 10)[0] + assert [-1, 20, 9] == client.tdigest().rank("t-digest", -20, 20, 9) + + +@pytest.mark.redismod +@pytest.mark.experimental +def test_tdigest_revrank(client): + assert client.tdigest().create("t-digest", 500) + assert client.tdigest().add("t-digest", list(range(0, 20))) + assert -1 == client.tdigest().revrank("t-digest", 20)[0] + assert 19 == client.tdigest().revrank("t-digest", 0)[0] + assert [-1, 19, 9] == client.tdigest().revrank("t-digest", 21, 0, 10) + + +@pytest.mark.redismod +@pytest.mark.experimental +def test_tdigest_byrank(client): + assert client.tdigest().create("t-digest", 500) + assert client.tdigest().add("t-digest", list(range(1, 11))) + assert 1 == 
client.tdigest().byrank("t-digest", 0)[0] + assert 10 == client.tdigest().byrank("t-digest", 9)[0] + assert client.tdigest().byrank("t-digest", 100)[0] == inf + with pytest.raises(redis.ResponseError): + client.tdigest().byrank("t-digest", -1)[0] + + +@pytest.mark.redismod +@pytest.mark.experimental +def test_tdigest_byrevrank(client): + assert client.tdigest().create("t-digest", 500) + assert client.tdigest().add("t-digest", list(range(1, 11))) + assert 10 == client.tdigest().byrevrank("t-digest", 0)[0] + assert 1 == client.tdigest().byrevrank("t-digest", 9)[0] + assert client.tdigest().byrevrank("t-digest", 100)[0] == -inf + with pytest.raises(redis.ResponseError): + client.tdigest().byrevrank("t-digest", -1)[0] # @pytest.mark.redismod diff --git a/tests/test_graph.py b/tests/test_graph.py index 76f8794..a46c336 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -1,7 +1,24 @@ +from unittest.mock import patch + import pytest from redis.commands.graph import Edge, Node, Path from redis.commands.graph.execution_plan import Operation +from redis.commands.graph.query_result import ( + CACHED_EXECUTION, + INDICES_CREATED, + INDICES_DELETED, + INTERNAL_EXECUTION_TIME, + LABELS_ADDED, + LABELS_REMOVED, + NODES_CREATED, + NODES_DELETED, + PROPERTIES_REMOVED, + PROPERTIES_SET, + RELATIONSHIPS_CREATED, + RELATIONSHIPS_DELETED, + QueryResult, +) from redis.exceptions import ResponseError from tests.conftest import skip_if_redis_enterprise @@ -349,11 +366,11 @@ def test_list_keys(client): result = client.graph().list_keys() assert result == [] - client.execute_command("GRAPH.EXPLAIN", "G", "RETURN 1") + client.graph("G").query("CREATE (n)") result = client.graph().list_keys() assert result == ["G"] - client.execute_command("GRAPH.EXPLAIN", "X", "RETURN 1") + client.graph("X").query("CREATE (m)") result = client.graph().list_keys() assert result == ["G", "X"] @@ -575,3 +592,33 @@ Project assert result.structured_plan == expected redis_graph.delete() + + +@pytest.mark.redismod +def test_resultset_statistics(client): + with patch.object(target=QueryResult, attribute="_get_stat") as mock_get_stats: + result = client.graph().query("RETURN 1") + result.labels_added + mock_get_stats.assert_called_with(LABELS_ADDED) + result.labels_removed + mock_get_stats.assert_called_with(LABELS_REMOVED) + result.nodes_created + mock_get_stats.assert_called_with(NODES_CREATED) + result.nodes_deleted + mock_get_stats.assert_called_with(NODES_DELETED) + result.properties_set + mock_get_stats.assert_called_with(PROPERTIES_SET) + result.properties_removed + mock_get_stats.assert_called_with(PROPERTIES_REMOVED) + result.relationships_created + mock_get_stats.assert_called_with(RELATIONSHIPS_CREATED) + result.relationships_deleted + mock_get_stats.assert_called_with(RELATIONSHIPS_DELETED) + result.indices_created + mock_get_stats.assert_called_with(INDICES_CREATED) + result.indices_deleted + mock_get_stats.assert_called_with(INDICES_DELETED) + result.cached_execution + mock_get_stats.assert_called_with(CACHED_EXECUTION) + result.run_time_ms + mock_get_stats.assert_called_with(INTERNAL_EXECUTION_TIME) diff --git a/tests/test_json.py b/tests/test_json.py index 1cc448c..676683d 100644 --- a/tests/test_json.py +++ b/tests/test_json.py @@ -824,7 +824,7 @@ def test_objlen_dollar(client): }, ) # Test multi - assert client.json().objlen("doc1", "$..a") == [2, None, 1] + assert client.json().objlen("doc1", "$..a") == [None, 2, 1] # Test single assert client.json().objlen("doc1", "$.nested1.a") == [2] @@ -1411,7 +1411,8 @@ def 
test_set_path(client): with open(jsonfile, "w+") as fp: fp.write(json.dumps({"hello": "world"})) - open(nojsonfile, "a+").write("hello") + with open(nojsonfile, "a+") as fp: + fp.write("hello") result = {jsonfile: True, nojsonfile: False} assert client.json().set_path(Path.root_path(), root) == result diff --git a/tests/test_search.py b/tests/test_search.py index f0a1190..08c76f7 100644 --- a/tests/test_search.py +++ b/tests/test_search.py @@ -93,7 +93,7 @@ def createIndex(client, num_docs=100, definition=None): assert 50 == indexer.chunk_size for key, doc in chapters.items(): - indexer.add_document(key, **doc) + indexer.client.client.hset(key, mapping=doc) indexer.commit() @@ -196,7 +196,7 @@ def test_client(client): assert 167 == client.ft().search(Query("henry king").slop(100)).total # test delete document - client.ft().add_document("doc-5ghs2", play="Death of a Salesman") + client.hset("doc-5ghs2", mapping={"play": "Death of a Salesman"}) res = client.ft().search(Query("death of a salesman")) assert 1 == res.total @@ -205,36 +205,19 @@ def test_client(client): assert 0 == res.total assert 0 == client.ft().delete_document("doc-5ghs2") - client.ft().add_document("doc-5ghs2", play="Death of a Salesman") + client.hset("doc-5ghs2", mapping={"play": "Death of a Salesman"}) res = client.ft().search(Query("death of a salesman")) assert 1 == res.total client.ft().delete_document("doc-5ghs2") -@pytest.mark.redismod -@skip_ifmodversion_lt("2.2.0", "search") -def test_payloads(client): - client.ft().create_index((TextField("txt"),)) - - client.ft().add_document("doc1", payload="foo baz", txt="foo bar") - client.ft().add_document("doc2", txt="foo bar") - - q = Query("foo bar").with_payloads() - res = client.ft().search(q) - assert 2 == res.total - assert "doc1" == res.docs[0].id - assert "doc2" == res.docs[1].id - assert "foo baz" == res.docs[0].payload - assert res.docs[1].payload is None - - @pytest.mark.redismod @pytest.mark.onlynoncluster def test_scores(client): client.ft().create_index((TextField("txt"),)) - client.ft().add_document("doc1", txt="foo baz") - client.ft().add_document("doc2", txt="foo bar") + client.hset("doc1", mapping={"txt": "foo baz"}) + client.hset("doc2", mapping={"txt": "foo bar"}) q = Query("foo ~bar").with_scores() res = client.ft().search(q) @@ -246,32 +229,11 @@ def test_scores(client): # self.assertEqual(0.2, res.docs[1].score) -@pytest.mark.redismod -def test_replace(client): - client.ft().create_index((TextField("txt"),)) - - client.ft().add_document("doc1", txt="foo bar") - client.ft().add_document("doc2", txt="foo bar") - waitForIndex(client, getattr(client.ft(), "index_name", "idx")) - - res = client.ft().search("foo bar") - assert 2 == res.total - client.ft().add_document("doc1", replace=True, txt="this is a replaced doc") - - res = client.ft().search("foo bar") - assert 1 == res.total - assert "doc2" == res.docs[0].id - - res = client.ft().search("replaced doc") - assert 1 == res.total - assert "doc1" == res.docs[0].id - - @pytest.mark.redismod def test_stopwords(client): client.ft().create_index((TextField("txt"),), stopwords=["foo", "bar", "baz"]) - client.ft().add_document("doc1", txt="foo bar") - client.ft().add_document("doc2", txt="hello world") + client.hset("doc1", mapping={"txt": "foo bar"}) + client.hset("doc2", mapping={"txt": "hello world"}) waitForIndex(client, getattr(client.ft(), "index_name", "idx")) q1 = Query("foo bar").no_content() @@ -284,8 +246,10 @@ def test_stopwords(client): @pytest.mark.redismod def test_filters(client): 
client.ft().create_index((TextField("txt"), NumericField("num"), GeoField("loc"))) - client.ft().add_document("doc1", txt="foo bar", num=3.141, loc="-0.441,51.458") - client.ft().add_document("doc2", txt="foo baz", num=2, loc="-0.1,51.2") + client.hset( + "doc1", mapping={"txt": "foo bar", "num": 3.141, "loc": "-0.441,51.458"} + ) + client.hset("doc2", mapping={"txt": "foo baz", "num": 2, "loc": "-0.1,51.2"}) waitForIndex(client, getattr(client.ft(), "index_name", "idx")) # Test numerical filter @@ -317,23 +281,12 @@ def test_filters(client): assert ["doc1", "doc2"] == res -@pytest.mark.redismod -def test_payloads_with_no_content(client): - client.ft().create_index((TextField("txt"),)) - client.ft().add_document("doc1", payload="foo baz", txt="foo bar") - client.ft().add_document("doc2", payload="foo baz2", txt="foo bar") - - q = Query("foo bar").with_payloads().no_content() - res = client.ft().search(q) - assert 2 == len(res.docs) - - @pytest.mark.redismod def test_sort_by(client): client.ft().create_index((TextField("txt"), NumericField("num", sortable=True))) - client.ft().add_document("doc1", txt="foo bar", num=1) - client.ft().add_document("doc2", txt="foo baz", num=2) - client.ft().add_document("doc3", txt="foo qux", num=3) + client.hset("doc1", mapping={"txt": "foo bar", "num": 1}) + client.hset("doc2", mapping={"txt": "foo baz", "num": 2}) + client.hset("doc3", mapping={"txt": "foo qux", "num": 3}) # Test sort q1 = Query("foo").sort_by("num", asc=True).no_content() @@ -375,10 +328,12 @@ def test_example(client): client.ft().create_index((TextField("title", weight=5.0), TextField("body"))) # Indexing a document - client.ft().add_document( + client.hset( "doc1", - title="RediSearch", - body="Redisearch impements a search engine on top of redis", + mapping={ + "title": "RediSearch", + "body": "Redisearch impements a search engine on top of redis", + }, ) # Searching with complex parameters: @@ -450,11 +405,13 @@ def test_no_index(client): ) ) - client.ft().add_document( - "doc1", field="aaa", text="1", numeric="1", geo="1,1", tag="1" + client.hset( + "doc1", + mapping={"field": "aaa", "text": "1", "numeric": "1", "geo": "1,1", "tag": "1"}, ) - client.ft().add_document( - "doc2", field="aab", text="2", numeric="2", geo="2,2", tag="2" + client.hset( + "doc2", + mapping={"field": "aab", "text": "2", "numeric": "2", "geo": "2,2", "tag": "2"}, ) waitForIndex(client, getattr(client.ft(), "index_name", "idx")) @@ -491,45 +448,6 @@ def test_no_index(client): TagField("name", no_index=True, sortable=False) -@pytest.mark.redismod -def test_partial(client): - client.ft().create_index((TextField("f1"), TextField("f2"), TextField("f3"))) - client.ft().add_document("doc1", f1="f1_val", f2="f2_val") - client.ft().add_document("doc2", f1="f1_val", f2="f2_val") - client.ft().add_document("doc1", f3="f3_val", partial=True) - client.ft().add_document("doc2", f3="f3_val", replace=True) - waitForIndex(client, getattr(client.ft(), "index_name", "idx")) - - # Search for f3 value. 
All documents should have it - res = client.ft().search("@f3:f3_val") - assert 2 == res.total - - # Only the document updated with PARTIAL should still have f1 and f2 values - res = client.ft().search("@f3:f3_val @f2:f2_val @f1:f1_val") - assert 1 == res.total - - -@pytest.mark.redismod -def test_no_create(client): - client.ft().create_index((TextField("f1"), TextField("f2"), TextField("f3"))) - client.ft().add_document("doc1", f1="f1_val", f2="f2_val") - client.ft().add_document("doc2", f1="f1_val", f2="f2_val") - client.ft().add_document("doc1", f3="f3_val", no_create=True) - client.ft().add_document("doc2", f3="f3_val", no_create=True, partial=True) - waitForIndex(client, getattr(client.ft(), "index_name", "idx")) - - # Search for f3 value. All documents should have it - res = client.ft().search("@f3:f3_val") - assert 2 == res.total - - # Only the document updated with PARTIAL should still have f1 and f2 values - res = client.ft().search("@f3:f3_val @f2:f2_val @f1:f1_val") - assert 1 == res.total - - with pytest.raises(redis.ResponseError): - client.ft().add_document("doc3", f2="f2_val", f3="f3_val", no_create=True) - - @pytest.mark.redismod def test_explain(client): client.ft().create_index((TextField("f1"), TextField("f2"), TextField("f3"))) @@ -618,11 +536,11 @@ def test_alias_basic(client): index1 = getClient(client).ft("testAlias") index1.create_index((TextField("txt"),)) - index1.add_document("doc1", txt="text goes here") + index1.client.hset("doc1", mapping={"txt": "text goes here"}) index2 = getClient(client).ft("testAlias2") index2.create_index((TextField("txt"),)) - index2.add_document("doc2", txt="text goes here") + index2.client.hset("doc2", mapping={"txt": "text goes here"}) # add the actual alias and check index1.aliasadd("myalias") @@ -646,36 +564,6 @@ def test_alias_basic(client): _ = alias_client2.search("*").docs[0] -@pytest.mark.redismod -def test_tags(client): - client.ft().create_index((TextField("txt"), TagField("tags"))) - tags = "foo,foo bar,hello;world" - tags2 = "soba,ramen" - - client.ft().add_document("doc1", txt="fooz barz", tags=tags) - client.ft().add_document("doc2", txt="noodles", tags=tags2) - waitForIndex(client, getattr(client.ft(), "index_name", "idx")) - - q = Query("@tags:{foo}") - res = client.ft().search(q) - assert 1 == res.total - - q = Query("@tags:{foo bar}") - res = client.ft().search(q) - assert 1 == res.total - - q = Query("@tags:{foo\\ bar}") - res = client.ft().search(q) - assert 1 == res.total - - q = Query("@tags:{hello\\;world}") - res = client.ft().search(q) - assert 1 == res.total - - q2 = client.ft().tagvals("tags") - assert (tags.split(",") + tags2.split(",")).sort() == q2.sort() - - @pytest.mark.redismod def test_textfield_sortable_nostem(client): # Creating the index definition with sortable and no_stem @@ -696,8 +584,8 @@ def test_alter_schema_add(client): client.ft().alter_schema_add(TextField("body")) # Indexing a document - client.ft().add_document( - "doc1", title="MyTitle", body="Some content only in the body" + client.hset( + "doc1", mapping={"title": "MyTitle", "body": "Some content only in the body"} ) # Searching with parameter only in the body (the added field) @@ -712,8 +600,10 @@ def test_alter_schema_add(client): def test_spell_check(client): client.ft().create_index((TextField("f1"), TextField("f2"))) - client.ft().add_document("doc1", f1="some valid content", f2="this is sample text") - client.ft().add_document("doc2", f1="very important", f2="lorem ipsum") + client.hset( + "doc1", mapping={"f1": "some valid 
content", "f2": "this is sample text"} + ) + client.hset("doc2", mapping={"f1": "very important", "f2": "lorem ipsum"}) waitForIndex(client, getattr(client.ft(), "index_name", "idx")) # test spellcheck @@ -767,8 +657,8 @@ def test_dict_operations(client): @pytest.mark.redismod def test_phonetic_matcher(client): client.ft().create_index((TextField("name"),)) - client.ft().add_document("doc1", name="Jon") - client.ft().add_document("doc2", name="John") + client.hset("doc1", mapping={"name": "Jon"}) + client.hset("doc2", mapping={"name": "John"}) res = client.ft().search(Query("Jon")) assert 1 == len(res.docs) @@ -778,8 +668,8 @@ def test_phonetic_matcher(client): client.flushdb() client.ft().create_index((TextField("name", phonetic_matcher="dm:en"),)) - client.ft().add_document("doc1", name="Jon") - client.ft().add_document("doc2", name="John") + client.hset("doc1", mapping={"name": "Jon"}) + client.hset("doc2", mapping={"name": "John"}) res = client.ft().search(Query("Jon")) assert 2 == len(res.docs) @@ -791,12 +681,14 @@ def test_phonetic_matcher(client): def test_scorer(client): client.ft().create_index((TextField("description"),)) - client.ft().add_document( - "doc1", description="The quick brown fox jumps over the lazy dog" + client.hset( + "doc1", mapping={"description": "The quick brown fox jumps over the lazy dog"} ) - client.ft().add_document( + client.hset( "doc2", - description="Quick alice was beginning to get very tired of sitting by her quick sister on the bank, and of having nothing to do.", # noqa + mapping={ + "description": "Quick alice was beginning to get very tired of sitting by her quick sister on the bank, and of having nothing to do." # noqa + }, ) # default scorer is TFIDF @@ -823,19 +715,19 @@ def test_get(client): assert [None] == client.ft().get("doc1") assert [None, None] == client.ft().get("doc2", "doc1") - client.ft().add_document( - "doc1", f1="some valid content dd1", f2="this is sample text ff1" + client.hset( + "doc1", mapping={"f1": "some valid content dd1", "f2": "this is sample text f1"} ) - client.ft().add_document( - "doc2", f1="some valid content dd2", f2="this is sample text ff2" + client.hset( + "doc2", mapping={"f1": "some valid content dd2", "f2": "this is sample text f2"} ) assert [ - ["f1", "some valid content dd2", "f2", "this is sample text ff2"] + ["f1", "some valid content dd2", "f2", "this is sample text f2"] ] == client.ft().get("doc2") assert [ - ["f1", "some valid content dd1", "f2", "this is sample text ff1"], - ["f1", "some valid content dd2", "f2", "this is sample text ff2"], + ["f1", "some valid content dd1", "f2", "this is sample text f1"], + ["f1", "some valid content dd2", "f2", "this is sample text f2"], ] == client.ft().get("doc1", "doc2") @@ -866,26 +758,32 @@ def test_aggregations_groupby(client): ) # Indexing a document - client.ft().add_document( + client.hset( "search", - title="RediSearch", - body="Redisearch impements a search engine on top of redis", - parent="redis", - random_num=10, + mapping={ + "title": "RediSearch", + "body": "Redisearch impements a search engine on top of redis", + "parent": "redis", + "random_num": 10, + }, ) - client.ft().add_document( + client.hset( "ai", - title="RedisAI", - body="RedisAI executes Deep Learning/Machine Learning models and managing their data.", # noqa - parent="redis", - random_num=3, + mapping={ + "title": "RedisAI", + "body": "RedisAI executes Deep Learning/Machine Learning models and managing their data.", # noqa + "parent": "redis", + "random_num": 3, + }, ) - 
client.ft().add_document( + client.hset( "json", - title="RedisJson", - body="RedisJSON implements ECMA-404 The JSON Data Interchange Standard as a native data type.", # noqa - parent="redis", - random_num=8, + mapping={ + "title": "RedisJson", + "body": "RedisJSON implements ECMA-404 The JSON Data Interchange Standard as a native data type.", # noqa + "parent": "redis", + "random_num": 8, + }, ) req = aggregations.AggregateRequest("redis").group_by("@parent", reducers.count()) @@ -965,7 +863,7 @@ def test_aggregations_groupby(client): res = client.ft().aggregate(req).rows[0] assert res[1] == "redis" - assert res[3] == ["RediSearch", "RedisAI", "RedisJson"] + assert set(res[3]) == {"RediSearch", "RedisAI", "RedisJson"} req = aggregations.AggregateRequest("redis").group_by( "@parent", reducers.first_value("@title").alias("first") @@ -1134,7 +1032,7 @@ def test_index_definition(client): @pytest.mark.redismod @pytest.mark.onlynoncluster @skip_if_redis_enterprise() -def testExpire(client): +def test_expire(client): client.ft().create_index((TextField("txt", sortable=True),), temporary=4) ttl = client.execute_command("ft.debug", "TTL", "idx") assert ttl > 2 @@ -1143,20 +1041,9 @@ def testExpire(client): ttl = client.execute_command("ft.debug", "TTL", "idx") time.sleep(0.01) - # add document - should reset the ttl - client.ft().add_document("doc", txt="foo bar", text="this is a simple test") - ttl = client.execute_command("ft.debug", "TTL", "idx") - assert ttl > 2 - try: - while True: - ttl = client.execute_command("ft.debug", "TTL", "idx") - time.sleep(0.5) - except redis.exceptions.ResponseError: - assert ttl == 0 - @pytest.mark.redismod -def testSkipInitialScan(client): +def test_skip_initial_scan(client): client.hset("doc1", "foo", "bar") q = Query("@foo:bar") @@ -1165,23 +1052,23 @@ def testSkipInitialScan(client): @pytest.mark.redismod -def testSummarizeDisabled_nooffset(client): +def test_summarize_disabled_nooffset(client): client.ft().create_index((TextField("txt"),), no_term_offsets=True) - client.ft().add_document("doc1", txt="foo bar") + client.hset("doc1", mapping={"txt": "foo bar"}) with pytest.raises(Exception): client.ft().search(Query("foo").summarize(fields=["txt"])) @pytest.mark.redismod -def testSummarizeDisabled_nohl(client): +def test_summarize_disabled_nohl(client): client.ft().create_index((TextField("txt"),), no_highlight=True) - client.ft().add_document("doc1", txt="foo bar") + client.hset("doc1", mapping={"txt": "foo bar"}) with pytest.raises(Exception): client.ft().search(Query("foo").summarize(fields=["txt"])) @pytest.mark.redismod -def testMaxTextFields(client): +def test_max_text_fields(client): # Creating the index definition client.ft().create_index((TextField("f0"),)) for x in range(1, 32): @@ -1337,10 +1224,10 @@ def test_synupdate(client): ) client.ft().synupdate("id1", True, "boy", "child", "offspring") - client.ft().add_document("doc1", title="he is a baby", body="this is a test") + client.hset("doc1", mapping={"title": "he is a baby", "body": "this is a test"}) client.ft().synupdate("id1", True, "baby") - client.ft().add_document("doc2", title="he is another baby", body="another test") + client.hset("doc2", mapping={"title": "he is another baby", "body": "another test"}) res = client.ft().search(Query("child").expander("SYNONYM")) assert res.docs[0].id == "doc2" @@ -1448,9 +1335,9 @@ def test_json_with_jsonpath(client): assert res.docs[0].id == "doc:1" assert res.docs[0].json == '{"prod:name":"RediSearch"}' - # query for an unsupported field fails + # query 
for an unsupported field res = client.ft().search("@name_unsupported:RediSearch") - assert res.total == 0 + assert res.total == 1 # return of a supported field succeeds res = client.ft().search(Query("@name:RediSearch").return_field("name")) @@ -1458,13 +1345,6 @@ def test_json_with_jsonpath(client): assert res.docs[0].id == "doc:1" assert res.docs[0].name == "RediSearch" - # return of an unsupported field fails - res = client.ft().search(Query("@name:RediSearch").return_field("name_unsupported")) - assert res.total == 1 - assert res.docs[0].id == "doc:1" - with pytest.raises(Exception): - res.docs[0].name_unsupported - @pytest.mark.redismod @pytest.mark.onlynoncluster @@ -1586,9 +1466,9 @@ def test_text_params(modclient): modclient.flushdb() modclient.ft().create_index((TextField("name"),)) - modclient.ft().add_document("doc1", name="Alice") - modclient.ft().add_document("doc2", name="Bob") - modclient.ft().add_document("doc3", name="Carol") + modclient.hset("doc1", mapping={"name": "Alice"}) + modclient.hset("doc2", mapping={"name": "Bob"}) + modclient.hset("doc3", mapping={"name": "Carol"}) params_dict = {"name1": "Alice", "name2": "Bob"} q = Query("@name:($name1 | $name2 )").dialect(2) @@ -1604,9 +1484,9 @@ def test_numeric_params(modclient): modclient.flushdb() modclient.ft().create_index((NumericField("numval"),)) - modclient.ft().add_document("doc1", numval=101) - modclient.ft().add_document("doc2", numval=102) - modclient.ft().add_document("doc3", numval=103) + modclient.hset("doc1", mapping={"numval": 101}) + modclient.hset("doc2", mapping={"numval": 102}) + modclient.hset("doc3", mapping={"numval": 103}) params_dict = {"min": 101, "max": 102} q = Query("@numval:[$min $max]").dialect(2) @@ -1623,9 +1503,9 @@ def test_geo_params(modclient): modclient.flushdb() modclient.ft().create_index((GeoField("g"))) - modclient.ft().add_document("doc1", g="29.69465, 34.95126") - modclient.ft().add_document("doc2", g="29.69350, 34.94737") - modclient.ft().add_document("doc3", g="29.68746, 34.94882") + modclient.hset("doc1", mapping={"g": "29.69465, 34.95126"}) + modclient.hset("doc2", mapping={"g": "29.69350, 34.94737"}) + modclient.hset("doc3", mapping={"g": "29.68746, 34.94882"}) params_dict = {"lat": "34.95126", "lon": "29.69465", "radius": 1000, "units": "km"} q = Query("@g:[$lon $lat $radius $units]").dialect(2) @@ -1641,16 +1521,15 @@ def test_geo_params(modclient): def test_search_commands_in_pipeline(client): p = client.ft().pipeline() p.create_index((TextField("txt"),)) - p.add_document("doc1", payload="foo baz", txt="foo bar") - p.add_document("doc2", txt="foo bar") + p.hset("doc1", mapping={"txt": "foo bar"}) + p.hset("doc2", mapping={"txt": "foo bar"}) q = Query("foo bar").with_payloads() p.search(q) res = p.execute() - assert res[:3] == ["OK", "OK", "OK"] + assert res[:3] == ["OK", True, True] assert 2 == res[3][0] assert "doc1" == res[3][1] assert "doc2" == res[3][4] - assert "foo baz" == res[3][2] assert res[3][5] is None assert res[3][3] == res[3][6] == ["txt", "foo bar"] @@ -1698,3 +1577,41 @@ def test_dialect(modclient: redis.Redis): with pytest.raises(redis.ResponseError) as err: modclient.ft().explain(Query("@title:(@num:[0 10])").dialect(2)) assert "Syntax error" in str(err) + + +@pytest.mark.redismod +def test_expire_while_search(modclient: redis.Redis): + modclient.ft().create_index((TextField("txt"),)) + modclient.hset("hset:1", "txt", "a") + modclient.hset("hset:2", "txt", "b") + modclient.hset("hset:3", "txt", "c") + assert 3 == modclient.ft().search(Query("*")).total 
+ modclient.pexpire("hset:2", 300) + for _ in range(500): + modclient.ft().search(Query("*")).docs[1] + time.sleep(1) + assert 2 == modclient.ft().search(Query("*")).total + + +@pytest.mark.redismod +@pytest.mark.experimental +def test_withsuffixtrie(modclient: redis.Redis): + # create index + assert modclient.ft().create_index((TextField("txt"),)) + waitForIndex(modclient, getattr(modclient.ft(), "index_name", "idx")) + info = modclient.ft().info() + assert "WITHSUFFIXTRIE" not in info["attributes"][0] + assert modclient.ft().dropindex("idx") + + # create withsuffixtrie index (text fiels) + assert modclient.ft().create_index((TextField("t", withsuffixtrie=True))) + waitForIndex(modclient, getattr(modclient.ft(), "index_name", "idx")) + info = modclient.ft().info() + assert "WITHSUFFIXTRIE" in info["attributes"][0] + assert modclient.ft().dropindex("idx") + + # create withsuffixtrie index (tag field) + assert modclient.ft().create_index((TagField("t", withsuffixtrie=True))) + waitForIndex(modclient, getattr(modclient.ft(), "index_name", "idx")) + info = modclient.ft().info() + assert "WITHSUFFIXTRIE" in info["attributes"][0] diff --git a/tests/test_timeseries.py b/tests/test_timeseries.py index 7d42147..6ced535 100644 --- a/tests/test_timeseries.py +++ b/tests/test_timeseries.py @@ -1,8 +1,11 @@ +import math import time from time import sleep import pytest +import redis + from .conftest import skip_ifmodversion_lt @@ -230,6 +233,84 @@ def test_range_advanced(client): assert [(0, 5.0), (5, 6.0)] == client.ts().range( 1, 0, 10, aggregation_type="count", bucket_size_msec=10, align=5 ) + assert [(0, 2.55), (10, 3.0)] == client.ts().range( + 1, 0, 10, aggregation_type="twa", bucket_size_msec=10 + ) + + +@pytest.mark.redismod +@skip_ifmodversion_lt("1.8.0", "timeseries") +def test_range_latest(client: redis.Redis): + timeseries = client.ts() + timeseries.create("t1") + timeseries.create("t2") + timeseries.createrule("t1", "t2", aggregation_type="sum", bucket_size_msec=10) + timeseries.add("t1", 1, 1) + timeseries.add("t1", 2, 3) + timeseries.add("t1", 11, 7) + timeseries.add("t1", 13, 1) + res = timeseries.range("t1", 0, 20) + assert res == [(1, 1.0), (2, 3.0), (11, 7.0), (13, 1.0)] + res = timeseries.range("t2", 0, 10) + assert res == [(0, 4.0)] + res = timeseries.range("t2", 0, 10, latest=True) + assert res == [(0, 4.0), (10, 8.0)] + res = timeseries.range("t2", 0, 9, latest=True) + assert res == [(0, 4.0)] + + +@pytest.mark.redismod +@skip_ifmodversion_lt("1.8.0", "timeseries") +def test_range_bucket_timestamp(client: redis.Redis): + timeseries = client.ts() + timeseries.create("t1") + timeseries.add("t1", 15, 1) + timeseries.add("t1", 17, 4) + timeseries.add("t1", 51, 3) + timeseries.add("t1", 73, 5) + timeseries.add("t1", 75, 3) + assert [(10, 4.0), (50, 3.0), (70, 5.0)] == timeseries.range( + "t1", 0, 100, align=0, aggregation_type="max", bucket_size_msec=10 + ) + assert [(20, 4.0), (60, 3.0), (80, 5.0)] == timeseries.range( + "t1", + 0, + 100, + align=0, + aggregation_type="max", + bucket_size_msec=10, + bucket_timestamp="+", + ) + + +@pytest.mark.redismod +@skip_ifmodversion_lt("1.8.0", "timeseries") +def test_range_empty(client: redis.Redis): + timeseries = client.ts() + timeseries.create("t1") + timeseries.add("t1", 15, 1) + timeseries.add("t1", 17, 4) + timeseries.add("t1", 51, 3) + timeseries.add("t1", 73, 5) + timeseries.add("t1", 75, 3) + assert [(10, 4.0), (50, 3.0), (70, 5.0)] == timeseries.range( + "t1", 0, 100, align=0, aggregation_type="max", bucket_size_msec=10 + ) + res = 
timeseries.range( + "t1", 0, 100, align=0, aggregation_type="max", bucket_size_msec=10, empty=True + ) + for i in range(len(res)): + if math.isnan(res[i][1]): + res[i] = (res[i][0], None) + assert [ + (10, 4.0), + (20, None), + (30, None), + (40, None), + (50, 3.0), + (60, None), + (70, 5.0), + ] == res @pytest.mark.redismod @@ -262,11 +343,87 @@ def test_rev_range(client): assert [(1, 10.0), (0, 1.0)] == client.ts().revrange( 1, 0, 10, aggregation_type="count", bucket_size_msec=10, align=1 ) + assert [(10, 3.0), (0, 2.55)] == client.ts().revrange( + 1, 0, 10, aggregation_type="twa", bucket_size_msec=10 + ) + + +@pytest.mark.redismod +@skip_ifmodversion_lt("1.8.0", "timeseries") +def test_revrange_latest(client: redis.Redis): + timeseries = client.ts() + timeseries.create("t1") + timeseries.create("t2") + timeseries.createrule("t1", "t2", aggregation_type="sum", bucket_size_msec=10) + timeseries.add("t1", 1, 1) + timeseries.add("t1", 2, 3) + timeseries.add("t1", 11, 7) + timeseries.add("t1", 13, 1) + res = timeseries.revrange("t2", 0, 10) + assert res == [(0, 4.0)] + res = timeseries.revrange("t2", 0, 10, latest=True) + assert res == [(10, 8.0), (0, 4.0)] + res = timeseries.revrange("t2", 0, 9, latest=True) + assert res == [(0, 4.0)] + + +@pytest.mark.redismod +@skip_ifmodversion_lt("1.8.0", "timeseries") +def test_revrange_bucket_timestamp(client: redis.Redis): + timeseries = client.ts() + timeseries.create("t1") + timeseries.add("t1", 15, 1) + timeseries.add("t1", 17, 4) + timeseries.add("t1", 51, 3) + timeseries.add("t1", 73, 5) + timeseries.add("t1", 75, 3) + assert [(70, 5.0), (50, 3.0), (10, 4.0)] == timeseries.revrange( + "t1", 0, 100, align=0, aggregation_type="max", bucket_size_msec=10 + ) + assert [(20, 4.0), (60, 3.0), (80, 5.0)] == timeseries.range( + "t1", + 0, + 100, + align=0, + aggregation_type="max", + bucket_size_msec=10, + bucket_timestamp="+", + ) + + +@pytest.mark.redismod +@skip_ifmodversion_lt("1.8.0", "timeseries") +def test_revrange_empty(client: redis.Redis): + timeseries = client.ts() + timeseries.create("t1") + timeseries.add("t1", 15, 1) + timeseries.add("t1", 17, 4) + timeseries.add("t1", 51, 3) + timeseries.add("t1", 73, 5) + timeseries.add("t1", 75, 3) + assert [(70, 5.0), (50, 3.0), (10, 4.0)] == timeseries.revrange( + "t1", 0, 100, align=0, aggregation_type="max", bucket_size_msec=10 + ) + res = timeseries.revrange( + "t1", 0, 100, align=0, aggregation_type="max", bucket_size_msec=10, empty=True + ) + for i in range(len(res)): + if math.isnan(res[i][1]): + res[i] = (res[i][0], None) + assert [ + (70, 5.0), + (60, None), + (50, 3.0), + (40, None), + (30, None), + (20, None), + (10, 4.0), + ] == res @pytest.mark.redismod @pytest.mark.onlynoncluster -def testMultiRange(client): +def test_mrange(client): client.ts().create(1, labels={"Test": "This", "team": "ny"}) client.ts().create(2, labels={"Test": "This", "Taste": "That", "team": "sf"}) for i in range(100): @@ -351,6 +508,31 @@ def test_multi_range_advanced(client): assert [(0, 5.0), (5, 6.0)] == res[0]["1"][1] +@pytest.mark.redismod +@pytest.mark.onlynoncluster +@skip_ifmodversion_lt("1.8.0", "timeseries") +def test_mrange_latest(client: redis.Redis): + timeseries = client.ts() + timeseries.create("t1") + timeseries.create("t2", labels={"is_compaction": "true"}) + timeseries.create("t3") + timeseries.create("t4", labels={"is_compaction": "true"}) + timeseries.createrule("t1", "t2", aggregation_type="sum", bucket_size_msec=10) + timeseries.createrule("t3", "t4", aggregation_type="sum", bucket_size_msec=10) 
+ timeseries.add("t1", 1, 1) + timeseries.add("t1", 2, 3) + timeseries.add("t1", 11, 7) + timeseries.add("t1", 13, 1) + timeseries.add("t3", 1, 1) + timeseries.add("t3", 2, 3) + timeseries.add("t3", 11, 7) + timeseries.add("t3", 13, 1) + assert client.ts().mrange(0, 10, filters=["is_compaction=true"], latest=True) == [ + {"t2": [{}, [(0, 4.0), (10, 8.0)]]}, + {"t4": [{}, [(0, 4.0), (10, 8.0)]]}, + ] + + @pytest.mark.redismod @pytest.mark.onlynoncluster @skip_ifmodversion_lt("99.99.99", "timeseries") @@ -434,6 +616,30 @@ def test_multi_reverse_range(client): assert [(1, 10.0), (0, 1.0)] == res[0]["1"][1] +@pytest.mark.redismod +@pytest.mark.onlynoncluster +@skip_ifmodversion_lt("1.8.0", "timeseries") +def test_mrevrange_latest(client: redis.Redis): + timeseries = client.ts() + timeseries.create("t1") + timeseries.create("t2", labels={"is_compaction": "true"}) + timeseries.create("t3") + timeseries.create("t4", labels={"is_compaction": "true"}) + timeseries.createrule("t1", "t2", aggregation_type="sum", bucket_size_msec=10) + timeseries.createrule("t3", "t4", aggregation_type="sum", bucket_size_msec=10) + timeseries.add("t1", 1, 1) + timeseries.add("t1", 2, 3) + timeseries.add("t1", 11, 7) + timeseries.add("t1", 13, 1) + timeseries.add("t3", 1, 1) + timeseries.add("t3", 2, 3) + timeseries.add("t3", 11, 7) + timeseries.add("t3", 13, 1) + assert client.ts().mrevrange( + 0, 10, filters=["is_compaction=true"], latest=True + ) == [{"t2": [{}, [(10, 8.0), (0, 4.0)]]}, {"t4": [{}, [(10, 8.0), (0, 4.0)]]}] + + @pytest.mark.redismod def test_get(client): name = "test" @@ -445,6 +651,21 @@ def test_get(client): assert 4 == client.ts().get(name)[1] +@pytest.mark.redismod +@skip_ifmodversion_lt("1.8.0", "timeseries") +def test_get_latest(client: redis.Redis): + timeseries = client.ts() + timeseries.create("t1") + timeseries.create("t2") + timeseries.createrule("t1", "t2", aggregation_type="sum", bucket_size_msec=10) + timeseries.add("t1", 1, 1) + timeseries.add("t1", 2, 3) + timeseries.add("t1", 11, 7) + timeseries.add("t1", 13, 1) + assert (0, 4.0) == timeseries.get("t2") + assert (10, 8.0) == timeseries.get("t2", latest=True) + + @pytest.mark.redismod @pytest.mark.onlynoncluster def test_mget(client): @@ -467,6 +688,24 @@ def test_mget(client): assert {"Taste": "That", "Test": "This"} == res[0]["2"][0] +@pytest.mark.redismod +@pytest.mark.onlynoncluster +@skip_ifmodversion_lt("1.8.0", "timeseries") +def test_mget_latest(client: redis.Redis): + timeseries = client.ts() + timeseries.create("t1") + timeseries.create("t2", labels={"is_compaction": "true"}) + timeseries.createrule("t1", "t2", aggregation_type="sum", bucket_size_msec=10) + timeseries.add("t1", 1, 1) + timeseries.add("t1", 2, 3) + timeseries.add("t1", 11, 7) + timeseries.add("t1", 13, 1) + assert timeseries.mget(filters=["is_compaction=true"]) == [{"t2": [{}, 0, 4.0]}] + assert [{"t2": [{}, 10, 8.0]}] == timeseries.mget( + filters=["is_compaction=true"], latest=True + ) + + @pytest.mark.redismod def test_info(client): client.ts().create(1, retention_msecs=5, labels={"currentLabel": "currentData"}) @@ -506,7 +745,7 @@ def test_pipeline(client): pipeline.execute() info = client.ts().info("with_pipeline") - assert info.lastTimeStamp == 99 + assert info.last_timestamp == 99 assert info.total_samples == 100 assert client.ts().get("with_pipeline")[1] == 99 * 1.1 -- cgit v1.2.1
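Note on the API shapes exercised above: the updated tests reflect several reworked module commands — TDIGEST.ADD no longer takes per-value weights, TDIGEST.MERGE now takes the number of source keys (with an optional override), quantile/CDF/rank replies are lists (one entry per requested argument), info objects expose snake_case fields, RediSearch's client-side add_document helpers are gone in favour of writing hashes directly, and RedisTimeSeries 1.8 adds latest/empty/bucket_timestamp options. The following is a minimal sketch of those call shapes only, assuming a local Redis Stack instance with RedisBloom >= 2.4 and RedisTimeSeries >= 1.8 and this redis-py release; key and index names are illustrative, not part of the patch.

    import redis
    from redis.commands.search.field import TextField

    r = redis.Redis(decode_responses=True)

    # TDIGEST.ADD takes a flat list of values (no per-value weights), and
    # TDIGEST.MERGE takes the number of source keys before the keys themselves.
    r.tdigest().create("td-src", 100)
    r.tdigest().create("td-dst", 100)
    r.tdigest().add("td-src", [1.0, 2.0, 3.0])
    r.tdigest().merge("td-dst", 1, "td-src")

    # Quantile-style commands return a list, one value per requested quantile.
    p50, p90 = r.tdigest().quantile("td-src", 0.5, 0.9)

    # Info objects use snake_case attributes (merged_weight, unmerged_weight, ...).
    info = r.tdigest().info("td-dst")
    total_weight = float(info.merged_weight) + float(info.unmerged_weight)

    # RediSearch: documents are indexed by writing hashes directly; the removed
    # add_document helper is replaced by a plain HSET on the indexed key.
    r.ft().create_index((TextField("txt"),))
    r.hset("doc1", mapping={"txt": "hello world"})

    # RedisTimeSeries 1.8: latest=True makes GET/RANGE-style calls report the
    # compaction bucket that has not been closed yet.
    r.ts().create("raw")
    r.ts().create("agg")
    r.ts().createrule("raw", "agg", aggregation_type="sum", bucket_size_msec=10)
    for ts, val in [(1, 1), (2, 3), (11, 7), (13, 1)]:
        r.ts().add("raw", ts, val)
    print(r.ts().get("agg", latest=True))  # includes the still-open bucket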