diff options
Diffstat (limited to 'redis/connection.py')
-rwxr-xr-x | redis/connection.py | 429 |
1 files changed, 244 insertions, 185 deletions
diff --git a/redis/connection.py b/redis/connection.py index ef3a667..d13fe65 100755 --- a/redis/connection.py +++ b/redis/connection.py @@ -1,8 +1,3 @@ -from packaging.version import Version -from itertools import chain -from time import time -from queue import LifoQueue, Empty, Full -from urllib.parse import parse_qs, unquote, urlparse import copy import errno import io @@ -10,6 +5,12 @@ import os import socket import threading import weakref +from itertools import chain +from queue import Empty, Full, LifoQueue +from time import time +from urllib.parse import parse_qs, unquote, urlparse + +from packaging.version import Version from redis.backoff import NoBackoff from redis.exceptions import ( @@ -21,20 +22,20 @@ from redis.exceptions import ( DataError, ExecAbortError, InvalidResponse, + ModuleError, NoPermissionError, NoScriptError, ReadOnlyError, RedisError, ResponseError, TimeoutError, - ModuleError, ) - from redis.retry import Retry from redis.utils import HIREDIS_AVAILABLE, str_if_bytes try: import ssl + ssl_available = True except ImportError: ssl_available = False @@ -44,7 +45,7 @@ NONBLOCKING_EXCEPTION_ERROR_NUMBERS = { } if ssl_available: - if hasattr(ssl, 'SSLWantReadError'): + if hasattr(ssl, "SSLWantReadError"): NONBLOCKING_EXCEPTION_ERROR_NUMBERS[ssl.SSLWantReadError] = 2 NONBLOCKING_EXCEPTION_ERROR_NUMBERS[ssl.SSLWantWriteError] = 2 else: @@ -56,34 +57,31 @@ if HIREDIS_AVAILABLE: import hiredis hiredis_version = Version(hiredis.__version__) - HIREDIS_SUPPORTS_CALLABLE_ERRORS = \ - hiredis_version >= Version('0.1.3') - HIREDIS_SUPPORTS_BYTE_BUFFER = \ - hiredis_version >= Version('0.1.4') - HIREDIS_SUPPORTS_ENCODING_ERRORS = \ - hiredis_version >= Version('1.0.0') + HIREDIS_SUPPORTS_CALLABLE_ERRORS = hiredis_version >= Version("0.1.3") + HIREDIS_SUPPORTS_BYTE_BUFFER = hiredis_version >= Version("0.1.4") + HIREDIS_SUPPORTS_ENCODING_ERRORS = hiredis_version >= Version("1.0.0") HIREDIS_USE_BYTE_BUFFER = True # only use byte buffer if hiredis supports it if not HIREDIS_SUPPORTS_BYTE_BUFFER: HIREDIS_USE_BYTE_BUFFER = False -SYM_STAR = b'*' -SYM_DOLLAR = b'$' -SYM_CRLF = b'\r\n' -SYM_EMPTY = b'' +SYM_STAR = b"*" +SYM_DOLLAR = b"$" +SYM_CRLF = b"\r\n" +SYM_EMPTY = b"" SERVER_CLOSED_CONNECTION_ERROR = "Connection closed by server." SENTINEL = object() -MODULE_LOAD_ERROR = 'Error loading the extension. ' \ - 'Please check the server logs.' -NO_SUCH_MODULE_ERROR = 'Error unloading module: no such module with that name' -MODULE_UNLOAD_NOT_POSSIBLE_ERROR = 'Error unloading module: operation not ' \ - 'possible.' -MODULE_EXPORTS_DATA_TYPES_ERROR = "Error unloading module: the module " \ - "exports one or more module-side data " \ - "types, can't unload" +MODULE_LOAD_ERROR = "Error loading the extension. " "Please check the server logs." +NO_SUCH_MODULE_ERROR = "Error unloading module: no such module with that name" +MODULE_UNLOAD_NOT_POSSIBLE_ERROR = "Error unloading module: operation not " "possible." +MODULE_EXPORTS_DATA_TYPES_ERROR = ( + "Error unloading module: the module " + "exports one or more module-side data " + "types, can't unload" +) class Encoder: @@ -100,15 +98,19 @@ class Encoder: return value elif isinstance(value, bool): # special case bool since it is a subclass of int - raise DataError("Invalid input of type: 'bool'. Convert to a " - "bytes, string, int or float first.") + raise DataError( + "Invalid input of type: 'bool'. Convert to a " + "bytes, string, int or float first." + ) elif isinstance(value, (int, float)): value = repr(value).encode() elif not isinstance(value, str): # a value we don't know how to deal with. throw an error typename = type(value).__name__ - raise DataError(f"Invalid input of type: '{typename}'. " - f"Convert to a bytes, string, int or float first.") + raise DataError( + f"Invalid input of type: '{typename}'. " + f"Convert to a bytes, string, int or float first." + ) if isinstance(value, str): value = value.encode(self.encoding, self.encoding_errors) return value @@ -125,36 +127,36 @@ class Encoder: class BaseParser: EXCEPTION_CLASSES = { - 'ERR': { - 'max number of clients reached': ConnectionError, - 'Client sent AUTH, but no password is set': AuthenticationError, - 'invalid password': AuthenticationError, + "ERR": { + "max number of clients reached": ConnectionError, + "Client sent AUTH, but no password is set": AuthenticationError, + "invalid password": AuthenticationError, # some Redis server versions report invalid command syntax # in lowercase - 'wrong number of arguments for \'auth\' command': - AuthenticationWrongNumberOfArgsError, + "wrong number of arguments " + "for 'auth' command": AuthenticationWrongNumberOfArgsError, # some Redis server versions report invalid command syntax # in uppercase - 'wrong number of arguments for \'AUTH\' command': - AuthenticationWrongNumberOfArgsError, + "wrong number of arguments " + "for 'AUTH' command": AuthenticationWrongNumberOfArgsError, MODULE_LOAD_ERROR: ModuleError, MODULE_EXPORTS_DATA_TYPES_ERROR: ModuleError, NO_SUCH_MODULE_ERROR: ModuleError, MODULE_UNLOAD_NOT_POSSIBLE_ERROR: ModuleError, }, - 'EXECABORT': ExecAbortError, - 'LOADING': BusyLoadingError, - 'NOSCRIPT': NoScriptError, - 'READONLY': ReadOnlyError, - 'NOAUTH': AuthenticationError, - 'NOPERM': NoPermissionError, + "EXECABORT": ExecAbortError, + "LOADING": BusyLoadingError, + "NOSCRIPT": NoScriptError, + "READONLY": ReadOnlyError, + "NOAUTH": AuthenticationError, + "NOPERM": NoPermissionError, } def parse_error(self, response): "Parse an error response" - error_code = response.split(' ')[0] + error_code = response.split(" ")[0] if error_code in self.EXCEPTION_CLASSES: - response = response[len(error_code) + 1:] + response = response[len(error_code) + 1 :] exception_class = self.EXCEPTION_CLASSES[error_code] if isinstance(exception_class, dict): exception_class = exception_class.get(response, ResponseError) @@ -177,8 +179,7 @@ class SocketBuffer: def length(self): return self.bytes_written - self.bytes_read - def _read_from_socket(self, length=None, timeout=SENTINEL, - raise_on_timeout=True): + def _read_from_socket(self, length=None, timeout=SENTINEL, raise_on_timeout=True): sock = self._sock socket_read_size = self.socket_read_size buf = self._buffer @@ -220,9 +221,9 @@ class SocketBuffer: sock.settimeout(self.socket_timeout) def can_read(self, timeout): - return bool(self.length) or \ - self._read_from_socket(timeout=timeout, - raise_on_timeout=False) + return bool(self.length) or self._read_from_socket( + timeout=timeout, raise_on_timeout=False + ) def read(self, length): length = length + 2 # make sure to read the \r\n terminator @@ -283,6 +284,7 @@ class SocketBuffer: class PythonParser(BaseParser): "Plain Python parsing class" + def __init__(self, socket_read_size): self.socket_read_size = socket_read_size self.encoder = None @@ -298,9 +300,9 @@ class PythonParser(BaseParser): def on_connect(self, connection): "Called when the socket connects" self._sock = connection._sock - self._buffer = SocketBuffer(self._sock, - self.socket_read_size, - connection.socket_timeout) + self._buffer = SocketBuffer( + self._sock, self.socket_read_size, connection.socket_timeout + ) self.encoder = connection.encoder def on_disconnect(self): @@ -321,12 +323,12 @@ class PythonParser(BaseParser): byte, response = raw[:1], raw[1:] - if byte not in (b'-', b'+', b':', b'$', b'*'): + if byte not in (b"-", b"+", b":", b"$", b"*"): raise InvalidResponse(f"Protocol Error: {raw!r}") # server returned an error - if byte == b'-': - response = response.decode('utf-8', errors='replace') + if byte == b"-": + response = response.decode("utf-8", errors="replace") error = self.parse_error(response) # if the error is a ConnectionError, raise immediately so the user # is notified @@ -338,24 +340,26 @@ class PythonParser(BaseParser): # necessary, so just return the exception instance here. return error # single value - elif byte == b'+': + elif byte == b"+": pass # int value - elif byte == b':': + elif byte == b":": response = int(response) # bulk response - elif byte == b'$': + elif byte == b"$": length = int(response) if length == -1: return None response = self._buffer.read(length) # multi-bulk response - elif byte == b'*': + elif byte == b"*": length = int(response) if length == -1: return None - response = [self.read_response(disable_decoding=disable_decoding) - for i in range(length)] + response = [ + self.read_response(disable_decoding=disable_decoding) + for i in range(length) + ] if isinstance(response, bytes) and disable_decoding is False: response = self.encoder.decode(response) return response @@ -363,6 +367,7 @@ class PythonParser(BaseParser): class HiredisParser(BaseParser): "Parser class for connections using Hiredis" + def __init__(self, socket_read_size): if not HIREDIS_AVAILABLE: raise RedisError("Hiredis is not installed") @@ -381,18 +386,18 @@ class HiredisParser(BaseParser): self._sock = connection._sock self._socket_timeout = connection.socket_timeout kwargs = { - 'protocolError': InvalidResponse, - 'replyError': self.parse_error, + "protocolError": InvalidResponse, + "replyError": self.parse_error, } # hiredis < 0.1.3 doesn't support functions that create exceptions if not HIREDIS_SUPPORTS_CALLABLE_ERRORS: - kwargs['replyError'] = ResponseError + kwargs["replyError"] = ResponseError if connection.encoder.decode_responses: - kwargs['encoding'] = connection.encoder.encoding + kwargs["encoding"] = connection.encoder.encoding if HIREDIS_SUPPORTS_ENCODING_ERRORS: - kwargs['errors'] = connection.encoder.encoding_errors + kwargs["errors"] = connection.encoder.encoding_errors self._reader = hiredis.Reader(**kwargs) self._next_response = False @@ -408,8 +413,7 @@ class HiredisParser(BaseParser): if self._next_response is False: self._next_response = self._reader.gets() if self._next_response is False: - return self.read_from_socket(timeout=timeout, - raise_on_timeout=False) + return self.read_from_socket(timeout=timeout, raise_on_timeout=False) return True def read_from_socket(self, timeout=SENTINEL, raise_on_timeout=True): @@ -468,16 +472,22 @@ class HiredisParser(BaseParser): if not HIREDIS_SUPPORTS_CALLABLE_ERRORS: if isinstance(response, ResponseError): response = self.parse_error(response.args[0]) - elif isinstance(response, list) and response and \ - isinstance(response[0], ResponseError): + elif ( + isinstance(response, list) + and response + and isinstance(response[0], ResponseError) + ): response[0] = self.parse_error(response[0].args[0]) # if the response is a ConnectionError or the response is a list and # the first item is a ConnectionError, raise it as something bad # happened if isinstance(response, ConnectionError): raise response - elif isinstance(response, list) and response and \ - isinstance(response[0], ConnectionError): + elif ( + isinstance(response, list) + and response + and isinstance(response[0], ConnectionError) + ): raise response[0] return response @@ -491,14 +501,29 @@ else: class Connection: "Manages TCP communication to and from a Redis server" - def __init__(self, host='localhost', port=6379, db=0, password=None, - socket_timeout=None, socket_connect_timeout=None, - socket_keepalive=False, socket_keepalive_options=None, - socket_type=0, retry_on_timeout=False, encoding='utf-8', - encoding_errors='strict', decode_responses=False, - parser_class=DefaultParser, socket_read_size=65536, - health_check_interval=0, client_name=None, username=None, - retry=None, redis_connect_func=None): + def __init__( + self, + host="localhost", + port=6379, + db=0, + password=None, + socket_timeout=None, + socket_connect_timeout=None, + socket_keepalive=False, + socket_keepalive_options=None, + socket_type=0, + retry_on_timeout=False, + encoding="utf-8", + encoding_errors="strict", + decode_responses=False, + parser_class=DefaultParser, + socket_read_size=65536, + health_check_interval=0, + client_name=None, + username=None, + retry=None, + redis_connect_func=None, + ): """ Initialize a new Connection. To specify a retry policy, first set `retry_on_timeout` to `True` @@ -536,17 +561,13 @@ class Connection: self._buffer_cutoff = 6000 def __repr__(self): - repr_args = ','.join([f'{k}={v}' for k, v in self.repr_pieces()]) - return f'{self.__class__.__name__}<{repr_args}>' + repr_args = ",".join([f"{k}={v}" for k, v in self.repr_pieces()]) + return f"{self.__class__.__name__}<{repr_args}>" def repr_pieces(self): - pieces = [ - ('host', self.host), - ('port', self.port), - ('db', self.db) - ] + pieces = [("host", self.host), ("port", self.port), ("db", self.db)] if self.client_name: - pieces.append(('client_name', self.client_name)) + pieces.append(("client_name", self.client_name)) return pieces def __del__(self): @@ -606,8 +627,9 @@ class Connection: # ipv4/ipv6, but we want to set options prior to calling # socket.connect() err = None - for res in socket.getaddrinfo(self.host, self.port, self.socket_type, - socket.SOCK_STREAM): + for res in socket.getaddrinfo( + self.host, self.port, self.socket_type, socket.SOCK_STREAM + ): family, socktype, proto, canonname, socket_address = res sock = None try: @@ -658,12 +680,12 @@ class Connection: # if username and/or password are set, authenticate if self.username or self.password: if self.username: - auth_args = (self.username, self.password or '') + auth_args = (self.username, self.password or "") else: auth_args = (self.password,) # avoid checking health here -- PING will fail if we try # to check the health prior to the AUTH - self.send_command('AUTH', *auth_args, check_health=False) + self.send_command("AUTH", *auth_args, check_health=False) try: auth_response = self.read_response() @@ -672,23 +694,23 @@ class Connection: # server seems to be < 6.0.0 which expects a single password # arg. retry auth with just the password. # https://github.com/andymccurdy/redis-py/issues/1274 - self.send_command('AUTH', self.password, check_health=False) + self.send_command("AUTH", self.password, check_health=False) auth_response = self.read_response() - if str_if_bytes(auth_response) != 'OK': - raise AuthenticationError('Invalid Username or Password') + if str_if_bytes(auth_response) != "OK": + raise AuthenticationError("Invalid Username or Password") # if a client_name is given, set it if self.client_name: - self.send_command('CLIENT', 'SETNAME', self.client_name) - if str_if_bytes(self.read_response()) != 'OK': - raise ConnectionError('Error setting client name') + self.send_command("CLIENT", "SETNAME", self.client_name) + if str_if_bytes(self.read_response()) != "OK": + raise ConnectionError("Error setting client name") # if a database is specified, switch to it if self.db: - self.send_command('SELECT', self.db) - if str_if_bytes(self.read_response()) != 'OK': - raise ConnectionError('Invalid Database') + self.send_command("SELECT", self.db) + if str_if_bytes(self.read_response()) != "OK": + raise ConnectionError("Invalid Database") def disconnect(self): "Disconnects from the Redis server" @@ -705,9 +727,9 @@ class Connection: def _send_ping(self): """Send PING, expect PONG in return""" - self.send_command('PING', check_health=False) - if str_if_bytes(self.read_response()) != 'PONG': - raise ConnectionError('Bad response from PING health check') + self.send_command("PING", check_health=False) + if str_if_bytes(self.read_response()) != "PONG": + raise ConnectionError("Bad response from PING health check") def _ping_failed(self, error): """Function to call when PING fails""" @@ -736,7 +758,7 @@ class Connection: except OSError as e: self.disconnect() if len(e.args) == 1: - errno, errmsg = 'UNKNOWN', e.args[0] + errno, errmsg = "UNKNOWN", e.args[0] else: errno = e.args[0] errmsg = e.args[1] @@ -747,8 +769,9 @@ class Connection: def send_command(self, *args, **kwargs): """Pack and send a command to the Redis server""" - self.send_packed_command(self.pack_command(*args), - check_health=kwargs.get('check_health', True)) + self.send_packed_command( + self.pack_command(*args), check_health=kwargs.get("check_health", True) + ) def can_read(self, timeout=0): """Poll the socket to see if there's data that can be read.""" @@ -760,17 +783,15 @@ class Connection: def read_response(self, disable_decoding=False): """Read the response from a previously sent command""" try: - response = self._parser.read_response( - disable_decoding=disable_decoding - ) + response = self._parser.read_response(disable_decoding=disable_decoding) except socket.timeout: self.disconnect() raise TimeoutError(f"Timeout reading from {self.host}:{self.port}") except OSError as e: self.disconnect() raise ConnectionError( - f"Error while reading from {self.host}:{self.port}" - f" : {e.args}") + f"Error while reading from {self.host}:{self.port}" f" : {e.args}" + ) except BaseException: self.disconnect() raise @@ -792,7 +813,7 @@ class Connection: # not encoded. if isinstance(args[0], str): args = tuple(args[0].encode().split()) + args[1:] - elif b' ' in args[0]: + elif b" " in args[0]: args = tuple(args[0].split()) + args[1:] buff = SYM_EMPTY.join((SYM_STAR, str(len(args)).encode(), SYM_CRLF)) @@ -802,17 +823,28 @@ class Connection: # to avoid large string mallocs, chunk the command into the # output list if we're sending large values or memoryviews arg_length = len(arg) - if (len(buff) > buffer_cutoff or arg_length > buffer_cutoff - or isinstance(arg, memoryview)): + if ( + len(buff) > buffer_cutoff + or arg_length > buffer_cutoff + or isinstance(arg, memoryview) + ): buff = SYM_EMPTY.join( - (buff, SYM_DOLLAR, str(arg_length).encode(), SYM_CRLF)) + (buff, SYM_DOLLAR, str(arg_length).encode(), SYM_CRLF) + ) output.append(buff) output.append(arg) buff = SYM_CRLF else: buff = SYM_EMPTY.join( - (buff, SYM_DOLLAR, str(arg_length).encode(), - SYM_CRLF, arg, SYM_CRLF)) + ( + buff, + SYM_DOLLAR, + str(arg_length).encode(), + SYM_CRLF, + arg, + SYM_CRLF, + ) + ) output.append(buff) return output @@ -826,8 +858,11 @@ class Connection: for cmd in commands: for chunk in self.pack_command(*cmd): chunklen = len(chunk) - if (buffer_length > buffer_cutoff or chunklen > buffer_cutoff - or isinstance(chunk, memoryview)): + if ( + buffer_length > buffer_cutoff + or chunklen > buffer_cutoff + or isinstance(chunk, memoryview) + ): output.append(SYM_EMPTY.join(pieces)) buffer_length = 0 pieces = [] @@ -844,10 +879,15 @@ class Connection: class SSLConnection(Connection): - - def __init__(self, ssl_keyfile=None, ssl_certfile=None, - ssl_cert_reqs='required', ssl_ca_certs=None, - ssl_check_hostname=False, **kwargs): + def __init__( + self, + ssl_keyfile=None, + ssl_certfile=None, + ssl_cert_reqs="required", + ssl_ca_certs=None, + ssl_check_hostname=False, + **kwargs, + ): if not ssl_available: raise RedisError("Python wasn't built with SSL support") @@ -859,13 +899,14 @@ class SSLConnection(Connection): ssl_cert_reqs = ssl.CERT_NONE elif isinstance(ssl_cert_reqs, str): CERT_REQS = { - 'none': ssl.CERT_NONE, - 'optional': ssl.CERT_OPTIONAL, - 'required': ssl.CERT_REQUIRED + "none": ssl.CERT_NONE, + "optional": ssl.CERT_OPTIONAL, + "required": ssl.CERT_REQUIRED, } if ssl_cert_reqs not in CERT_REQS: raise RedisError( - f"Invalid SSL Certificate Requirements Flag: {ssl_cert_reqs}") + f"Invalid SSL Certificate Requirements Flag: {ssl_cert_reqs}" + ) ssl_cert_reqs = CERT_REQS[ssl_cert_reqs] self.cert_reqs = ssl_cert_reqs self.ca_certs = ssl_ca_certs @@ -878,22 +919,30 @@ class SSLConnection(Connection): context.check_hostname = self.check_hostname context.verify_mode = self.cert_reqs if self.certfile and self.keyfile: - context.load_cert_chain(certfile=self.certfile, - keyfile=self.keyfile) + context.load_cert_chain(certfile=self.certfile, keyfile=self.keyfile) if self.ca_certs: context.load_verify_locations(self.ca_certs) return context.wrap_socket(sock, server_hostname=self.host) class UnixDomainSocketConnection(Connection): - - def __init__(self, path='', db=0, username=None, password=None, - socket_timeout=None, encoding='utf-8', - encoding_errors='strict', decode_responses=False, - retry_on_timeout=False, - parser_class=DefaultParser, socket_read_size=65536, - health_check_interval=0, client_name=None, - retry=None): + def __init__( + self, + path="", + db=0, + username=None, + password=None, + socket_timeout=None, + encoding="utf-8", + encoding_errors="strict", + decode_responses=False, + retry_on_timeout=False, + parser_class=DefaultParser, + socket_read_size=65536, + health_check_interval=0, + client_name=None, + retry=None, + ): """ Initialize a new UnixDomainSocketConnection. To specify a retry policy, first set `retry_on_timeout` to `True` @@ -926,11 +975,11 @@ class UnixDomainSocketConnection(Connection): def repr_pieces(self): pieces = [ - ('path', self.path), - ('db', self.db), + ("path", self.path), + ("db", self.db), ] if self.client_name: - pieces.append(('client_name', self.client_name)) + pieces.append(("client_name", self.client_name)) return pieces def _connect(self): @@ -952,11 +1001,11 @@ class UnixDomainSocketConnection(Connection): ) -FALSE_STRINGS = ('0', 'F', 'FALSE', 'N', 'NO') +FALSE_STRINGS = ("0", "F", "FALSE", "N", "NO") def to_bool(value): - if value is None or value == '': + if value is None or value == "": return None if isinstance(value, str) and value.upper() in FALSE_STRINGS: return False @@ -964,14 +1013,14 @@ def to_bool(value): URL_QUERY_ARGUMENT_PARSERS = { - 'db': int, - 'socket_timeout': float, - 'socket_connect_timeout': float, - 'socket_keepalive': to_bool, - 'retry_on_timeout': to_bool, - 'max_connections': int, - 'health_check_interval': int, - 'ssl_check_hostname': to_bool, + "db": int, + "socket_timeout": float, + "socket_connect_timeout": float, + "socket_keepalive": to_bool, + "retry_on_timeout": to_bool, + "max_connections": int, + "health_check_interval": int, + "ssl_check_hostname": to_bool, } @@ -987,42 +1036,42 @@ def parse_url(url): try: kwargs[name] = parser(value) except (TypeError, ValueError): - raise ValueError( - f"Invalid value for `{name}` in connection URL." - ) + raise ValueError(f"Invalid value for `{name}` in connection URL.") else: kwargs[name] = value if url.username: - kwargs['username'] = unquote(url.username) + kwargs["username"] = unquote(url.username) if url.password: - kwargs['password'] = unquote(url.password) + kwargs["password"] = unquote(url.password) # We only support redis://, rediss:// and unix:// schemes. - if url.scheme == 'unix': + if url.scheme == "unix": if url.path: - kwargs['path'] = unquote(url.path) - kwargs['connection_class'] = UnixDomainSocketConnection + kwargs["path"] = unquote(url.path) + kwargs["connection_class"] = UnixDomainSocketConnection - elif url.scheme in ('redis', 'rediss'): + elif url.scheme in ("redis", "rediss"): if url.hostname: - kwargs['host'] = unquote(url.hostname) + kwargs["host"] = unquote(url.hostname) if url.port: - kwargs['port'] = int(url.port) + kwargs["port"] = int(url.port) # If there's a path argument, use it as the db argument if a # querystring value wasn't specified - if url.path and 'db' not in kwargs: + if url.path and "db" not in kwargs: try: - kwargs['db'] = int(unquote(url.path).replace('/', '')) + kwargs["db"] = int(unquote(url.path).replace("/", "")) except (AttributeError, ValueError): pass - if url.scheme == 'rediss': - kwargs['connection_class'] = SSLConnection + if url.scheme == "rediss": + kwargs["connection_class"] = SSLConnection else: - raise ValueError('Redis URL must specify one of the following ' - 'schemes (redis://, rediss://, unix://)') + raise ValueError( + "Redis URL must specify one of the following " + "schemes (redis://, rediss://, unix://)" + ) return kwargs @@ -1040,6 +1089,7 @@ class ConnectionPool: Any additional keyword arguments are passed to the constructor of ``connection_class``. """ + @classmethod def from_url(cls, url, **kwargs): """ @@ -1084,8 +1134,9 @@ class ConnectionPool: kwargs.update(url_options) return cls(**kwargs) - def __init__(self, connection_class=Connection, max_connections=None, - **connection_kwargs): + def __init__( + self, connection_class=Connection, max_connections=None, **connection_kwargs + ): max_connections = max_connections or 2 ** 31 if not isinstance(max_connections, int) or max_connections < 0: raise ValueError('"max_connections" must be a positive integer') @@ -1194,12 +1245,12 @@ class ConnectionPool: # closed. either way, reconnect and verify everything is good. try: if connection.can_read(): - raise ConnectionError('Connection has data') + raise ConnectionError("Connection has data") except ConnectionError: connection.disconnect() connection.connect() if connection.can_read(): - raise ConnectionError('Connection not ready') + raise ConnectionError("Connection not ready") except BaseException: # release the connection back to the pool so that we don't # leak it @@ -1212,9 +1263,9 @@ class ConnectionPool: "Return an encoder based on encoding settings" kwargs = self.connection_kwargs return Encoder( - encoding=kwargs.get('encoding', 'utf-8'), - encoding_errors=kwargs.get('encoding_errors', 'strict'), - decode_responses=kwargs.get('decode_responses', False) + encoding=kwargs.get("encoding", "utf-8"), + encoding_errors=kwargs.get("encoding_errors", "strict"), + decode_responses=kwargs.get("decode_responses", False), ) def make_connection(self): @@ -1259,8 +1310,9 @@ class ConnectionPool: self._checkpid() with self._lock: if inuse_connections: - connections = chain(self._available_connections, - self._in_use_connections) + connections = chain( + self._available_connections, self._in_use_connections + ) else: connections = self._available_connections @@ -1301,16 +1353,23 @@ class BlockingConnectionPool(ConnectionPool): >>> # not available. >>> pool = BlockingConnectionPool(timeout=5) """ - def __init__(self, max_connections=50, timeout=20, - connection_class=Connection, queue_class=LifoQueue, - **connection_kwargs): + + def __init__( + self, + max_connections=50, + timeout=20, + connection_class=Connection, + queue_class=LifoQueue, + **connection_kwargs, + ): self.queue_class = queue_class self.timeout = timeout super().__init__( connection_class=connection_class, max_connections=max_connections, - **connection_kwargs) + **connection_kwargs, + ) def reset(self): # Create and fill up a thread safe queue with ``None`` values. @@ -1381,12 +1440,12 @@ class BlockingConnectionPool(ConnectionPool): # closed. either way, reconnect and verify everything is good. try: if connection.can_read(): - raise ConnectionError('Connection has data') + raise ConnectionError("Connection has data") except ConnectionError: connection.disconnect() connection.connect() if connection.can_read(): - raise ConnectionError('Connection not ready') + raise ConnectionError("Connection not ready") except BaseException: # release the connection back to the pool so that we don't leak it self.release(connection) |