diff options
author | Tim Savage <tim.savage@poweredbypenguins.org> | 2016-03-29 18:15:34 +1100 |
---|---|---|
committer | Tim Savage <tim.savage@poweredbypenguins.org> | 2016-03-29 18:15:34 +1100 |
commit | 8e6f655d069fd3b31de63ae7f9ef967a7bf6de14 (patch) | |
tree | 14510ac419a08d6ee2e966592498027e381e32d8 | |
parent | b40875d553ab6d6db69e64eef134e5fac652b033 (diff) | |
download | redis-py-8e6f655d069fd3b31de63ae7f9ef967a7bf6de14.tar.gz |
Extend ConnectionPool.to_url to parse querystring arguments to correct type.
Previously if a value for socket_timeout was supplied as part fo the URL an error would be raised when a socket was created with an invalid type, this change fixes that by parsing `socket_timeout`, `socket_connect_timeout` to float values.
In addition the boolean values `socket_keepalive` and `retry_on_timeout` are parsed to bool types taking into account the usage of True/False, Yes/No strings.
-rwxr-xr-x | redis/connection.py | 33 | ||||
-rw-r--r-- | tests/test_connection_pool.py | 39 |
2 files changed, 68 insertions, 4 deletions
diff --git a/redis/connection.py b/redis/connection.py index 91730ea..fb90e93 100755 --- a/redis/connection.py +++ b/redis/connection.py @@ -735,6 +735,22 @@ class UnixDomainSocketConnection(Connection): (exception.args[0], self.path, exception.args[1]) +def to_bool(value): + if value is None or value == '': + return None + if isinstance(value, basestring) and value.upper() in ('0', 'F', 'FALSE', 'N', 'NO'): + return False + return bool(value) + + +URL_QUERY_PARAMETER_TYPES = { + 'socket_timeout': float, + 'socket_connect_timeout': float, + 'socket_keepalive': to_bool, + 'retry_on_timeout': to_bool +} + + class ConnectionPool(object): "Generic connection pool" @classmethod @@ -769,8 +785,13 @@ class ConnectionPool(object): ``path``, and ``password`` components. Any additional querystring arguments and keyword arguments will be - passed along to the ConnectionPool class's initializer. In the case - of conflicting arguments, querystring arguments always win. + passed along to the ConnectionPool class's initializer. The querystring + arguments ``socket_connect_timeout`` and ``socket_timeout`` if supplied + are parsed as float values. The arguments ``socket_keepalive`` and + ``retry_on_timeout`` are parsed to boolean values that accept + True/False, Yes/No values to indicate state. Invalid types cause a + ``UserWarning`` to be raised. In the case of conflicting arguments, + querystring arguments always win. """ url_string = url url = urlparse(url) @@ -790,7 +811,13 @@ class ConnectionPool(object): for name, value in iteritems(parse_qs(qs)): if value and len(value) > 0: - url_options[name] = value[0] + if name in URL_QUERY_PARAMETER_TYPES: + try: + url_options[name] = URL_QUERY_PARAMETER_TYPES[name](value[0]) + except (TypeError, ValueError): + warnings.warn(UserWarning("Invalid value for `%s` in connection URL." % name)) + else: + url_options[name] = value[0] if decode_components: password = unquote(url.password) if url.password else None diff --git a/tests/test_connection_pool.py b/tests/test_connection_pool.py index 6b2478a..c36177f 100644 --- a/tests/test_connection_pool.py +++ b/tests/test_connection_pool.py @@ -6,7 +6,7 @@ import time import re from threading import Thread -from redis.connection import ssl_available +from redis.connection import ssl_available, to_bool from .conftest import skip_if_server_version_lt @@ -237,6 +237,43 @@ class TestConnectionPoolURLParsing(object): 'password': None, } + def test_extra_typed_querystring_options(self): + pool = redis.ConnectionPool.from_url( + 'redis://localhost/2?socket_timeout=20&socket_connect_timeout=10&socket_keepalive=&retry_on_timeout=Yes') + + assert pool.connection_class == redis.Connection + assert pool.connection_kwargs == { + 'host': 'localhost', + 'port': 6379, + 'db': 2, + 'socket_timeout': 20.0, + 'socket_connect_timeout': 10.0, + 'retry_on_timeout': True, + 'password': None, + } + + def test_boolean_parsing(self): + for expected, value in ( + (None, None), + (None, ''), + (False, 0), (False, '0'), + (False, 'f'), (False, 'F'), (False, 'False'), + (False, 'n'), (False, 'N'), (False, 'No'), + (True, 1), (True, '1'), + (True, 'y'), (True, 'Y'), (True, 'Yes'), + ): + assert expected is to_bool(value) + + def test_invalid_extra_typed_querystring_options(self): + import warnings + with warnings.catch_warnings(record=True) as warning_log: + redis.ConnectionPool.from_url('redis://localhost/2?socket_timeout=_&socket_connect_timeout=abc') + # Compare the message values + assert [str(m.message) for m in sorted(warning_log, key=lambda l: str(l.message))] == [ + 'Invalid value for `socket_connect_timeout` in connection URL.', + 'Invalid value for `socket_timeout` in connection URL.', + ] + def test_extra_querystring_options(self): pool = redis.ConnectionPool.from_url('redis://localhost?a=1&b=2') assert pool.connection_class == redis.Connection |