From dca7bd40a3a5d0c0853fe2befe706e214407697b Mon Sep 17 00:00:00 2001 From: Peter van Dijk Date: Tue, 15 Nov 2016 00:40:07 +0100 Subject: Allow setting client_name during connection construction. Client instances and Connection pools now accept "client_name" as an optional argument. If supplied, all connections created will be named via CLIENT SETNAME once the connection to the server is established. --- CHANGES | 4 ++++ redis/client.py | 3 ++- redis/connection.py | 46 +++++++++++++++++++++++++++++-------------- tests/test_connection_pool.py | 45 +++++++++++++++++++++++++++++++++++------- 4 files changed, 75 insertions(+), 23 deletions(-) diff --git a/CHANGES b/CHANGES index a8eb848..7cff020 100644 --- a/CHANGES +++ b/CHANGES @@ -10,6 +10,10 @@ pipeline instances relied on __len__ for boolean evaluation which meant that pipelines with no commands on the stack would be considered False. #994 + * Client instances and Connection pools now support a 'client_name' + argument. If supplied, all connections created will call CLIENT SETNAME + as soon as the connection is opened. Thanks to @Habbie for supplying + the basis of this chanfge. #802 * 3.3.11 * Further fix for the SSLError -> TimeoutError mapping to work on obscure releases of Python 2.7. diff --git a/redis/client.py b/redis/client.py index 0486022..eb1ccf1 100755 --- a/redis/client.py +++ b/redis/client.py @@ -684,7 +684,7 @@ class Redis(object): ssl=False, ssl_keyfile=None, ssl_certfile=None, ssl_cert_reqs='required', ssl_ca_certs=None, max_connections=None, single_connection_client=False, - health_check_interval=0): + health_check_interval=0, client_name=None): if not connection_pool: if charset is not None: warnings.warn(DeprecationWarning( @@ -706,6 +706,7 @@ class Redis(object): 'retry_on_timeout': retry_on_timeout, 'max_connections': max_connections, 'health_check_interval': health_check_interval, + 'client_name': client_name } # based on input, setup appropriate connection args if unix_socket_path is not None: diff --git a/redis/connection.py b/redis/connection.py index b90cafe..9a0e12d 100755 --- a/redis/connection.py +++ b/redis/connection.py @@ -485,7 +485,6 @@ else: class Connection(object): "Manages TCP communication to and from a Redis server" - description_format = "Connection" def __init__(self, host='localhost', port=6379, db=0, username=None, password=None, socket_timeout=None, @@ -494,12 +493,13 @@ class Connection(object): retry_on_timeout=False, encoding='utf-8', encoding_errors='strict', decode_responses=False, parser_class=DefaultParser, socket_read_size=65536, - health_check_interval=0): + health_check_interval=0, client_name=None): self.pid = os.getpid() self.host = host self.port = int(port) self.db = db self.username = username + self.client_name = client_name self.password = password self.socket_timeout = socket_timeout self.socket_connect_timeout = socket_connect_timeout or socket_timeout @@ -512,16 +512,22 @@ class Connection(object): self.encoder = Encoder(encoding, encoding_errors, decode_responses) self._sock = None self._parser = parser_class(socket_read_size=socket_read_size) - self._description_args = { - 'host': self.host, - 'port': self.port, - 'db': self.db, - } self._connect_callbacks = [] self._buffer_cutoff = 6000 def __repr__(self): - return self.description_format % self._description_args + repr_args = ','.join(['%s=%s' % (k, v) for k, v in self.repr_pieces()]) + return '%s<%s>' % (self.__class__.__name__, repr_args) + + def repr_pieces(self): + pieces = [ + ('host', self.host), + ('port', self.port), + ('db', self.db) + ] + if self.client_name: + pieces.append(('client_name', self.client_name)) + return pieces def __del__(self): try: @@ -626,6 +632,12 @@ class Connection(object): if nativestr(self.read_response()) != 'OK': raise AuthenticationError('Invalid Username or Password') + # if a client_name is given, set it + if self.client_name: + self.send_command('CLIENT', 'SETNAME', self.client_name) + if nativestr(self.read_response()) != 'OK': + raise ConnectionError('Error setting client name') + # if a database is specified, switch to it if self.db: self.send_command('SELECT', self.db) @@ -785,7 +797,6 @@ class Connection(object): class SSLConnection(Connection): - description_format = "SSLConnection" def __init__(self, ssl_keyfile=None, ssl_certfile=None, ssl_cert_reqs='required', ssl_ca_certs=None, **kwargs): @@ -838,18 +849,18 @@ class SSLConnection(Connection): class UnixDomainSocketConnection(Connection): - description_format = "UnixDomainSocketConnection" def __init__(self, path='', db=0, username=None, password=None, socket_timeout=None, encoding='utf-8', encoding_errors='strict', decode_responses=False, retry_on_timeout=False, parser_class=DefaultParser, socket_read_size=65536, - health_check_interval=0): + health_check_interval=0, client_name=None): self.pid = os.getpid() self.path = path self.db = db self.username = username + self.client_name = client_name self.password = password self.socket_timeout = socket_timeout self.retry_on_timeout = retry_on_timeout @@ -858,13 +869,18 @@ class UnixDomainSocketConnection(Connection): self.encoder = Encoder(encoding, encoding_errors, decode_responses) self._sock = None self._parser = parser_class(socket_read_size=socket_read_size) - self._description_args = { - 'path': self.path, - 'db': self.db, - } self._connect_callbacks = [] self._buffer_cutoff = 6000 + def repr_pieces(self): + pieces = [ + ('path', self.path), + ('db', self.db), + ] + if self.client_name: + pieces.append(('client_name', self.client_name)) + return pieces + def _connect(self): "Create a Unix domain socket connection" sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) diff --git a/tests/test_connection_pool.py b/tests/test_connection_pool.py index e0f0822..7ebd5ff 100644 --- a/tests/test_connection_pool.py +++ b/tests/test_connection_pool.py @@ -64,17 +64,28 @@ class TestConnectionPool(object): assert c1 == c2 def test_repr_contains_db_info_tcp(self): - connection_kwargs = {'host': 'localhost', 'port': 6379, 'db': 1} + connection_kwargs = { + 'host': 'localhost', + 'port': 6379, + 'db': 1, + 'client_name': 'test-client' + } pool = self.get_pool(connection_kwargs=connection_kwargs, connection_class=redis.Connection) - expected = 'ConnectionPool>' + expected = ('ConnectionPool>') assert repr(pool) == expected def test_repr_contains_db_info_unix(self): - connection_kwargs = {'path': '/abc', 'db': 1} + connection_kwargs = { + 'path': '/abc', + 'db': 1, + 'client_name': 'test-client' + } pool = self.get_pool(connection_kwargs=connection_kwargs, connection_class=redis.UnixDomainSocketConnection) - expected = 'ConnectionPool>' + expected = ('ConnectionPool>') assert repr(pool) == expected def test_pool_equality(self): @@ -177,8 +188,14 @@ class TestBlockingConnectionPool(object): assert c1 == c2 def test_repr_contains_db_info_tcp(self): - pool = redis.ConnectionPool(host='localhost', port=6379, db=0) - expected = 'ConnectionPool>' + pool = redis.ConnectionPool( + host='localhost', + port=6379, + db=0, + client_name='test-client' + ) + expected = ('ConnectionPool>') assert repr(pool) == expected def test_repr_contains_db_info_unix(self): @@ -186,8 +203,10 @@ class TestBlockingConnectionPool(object): connection_class=redis.UnixDomainSocketConnection, path='abc', db=0, + client_name='test-client' ) - expected = 'ConnectionPool>' + expected = ('ConnectionPool>') assert repr(pool) == expected @@ -364,6 +383,12 @@ class TestConnectionPoolURLParsing(object): ): assert expected is to_bool(value) + def test_client_name_in_querystring(self): + pool = redis.ConnectionPool.from_url( + 'redis://location?client_name=test-client' + ) + assert pool.connection_kwargs['client_name'] == 'test-client' + def test_invalid_extra_typed_querystring_options(self): import warnings with warnings.catch_warnings(record=True) as warning_log: @@ -502,6 +527,12 @@ class TestConnectionPoolUnixSocketURLParsing(object): 'password': None, } + def test_client_name_in_querystring(self): + pool = redis.ConnectionPool.from_url( + 'redis://location?client_name=test-client' + ) + assert pool.connection_kwargs['client_name'] == 'test-client' + def test_extra_querystring_options(self): pool = redis.ConnectionPool.from_url('unix:///socket?a=1&b=2') assert pool.connection_class == redis.UnixDomainSocketConnection -- cgit v1.2.1