diff options
-rw-r--r-- | CHANGES | 30 | ||||
-rw-r--r-- | README.rst | 2 | ||||
-rw-r--r-- | redis/__init__.py | 2 | ||||
-rwxr-xr-x | redis/client.py | 123 | ||||
-rwxr-xr-x | redis/connection.py | 84 | ||||
-rw-r--r-- | setup.py | 5 | ||||
-rw-r--r-- | tests/conftest.py | 5 | ||||
-rw-r--r-- | tests/test_commands.py | 27 | ||||
-rw-r--r-- | tests/test_scripting.py | 34 |
9 files changed, 200 insertions, 112 deletions
@@ -1,4 +1,32 @@ -* 2.10.6 (in development) +* 2.10.6 + * Various performance improvements. Thanks cjsimpson + * Fixed a bug with SRANDMEMBER where + * Added HSTRLEN command. Thanks Alexander Putilin + * Added the TOUCH command. Thanks Anis Jonischkeit + * Remove unnecessary calls to the server when registering Lua scripts. + Thanks Ben Greenberg + * SET's EX and PX arguments now allow values of zero. Thanks huangqiyin + * Added PUBSUB {CHANNELS, NUMPAT, NUMSUB} commands. Thanks Angus Pearson + * PubSub connections that that encounter `InterruptedError`s now + retry automatically. Thanks Carlton Gibson and Seth M. Larson + * LPUSH and RPUSH commands run on PyPy now correctly returns the number + of items of the list. Thanks Jeong YunWon + * Added support to automatically retry socket EINTR errors. Thanks + Thomas Steinacher + * PubSubWorker threads started with `run_in_thread` are now daemonized + so the thread shuts down when the running process goes away. Thanks + Keith Ainsworth + * Added support for GEO commands. Thanks Pau Freixes, Alex DeBrie and + Abraham Toriz + * Made client construction from URLs smarter. Thanks Tim Savage + * Added support for CLUSTER * commands. Thanks Andy Huang + * The RESTORE command now accepts an optional `replace` boolean. + Thanks Yoshinari Takaoka + * Attempt to connect to a new Sentinel if a TimeoutError occurs. Thanks + Bo Lopker + * Fixed a bug in the client's `__getitem__` where a KeyError would be + raised if the value returned by the server is an empty string. + Thanks Javier Candeira. * Socket timeouts when connecting to a server are now properly raised as TimeoutErrors. * 2.10.5 @@ -183,7 +183,7 @@ set_response_callback method. This method accepts two arguments: a command name and the callback. Callbacks added in this manner are only valid on the instance the callback is added to. If you want to define or override a callback globally, you should make a subclass of the Redis client and add your callback -to its REDIS_CALLBACKS class dictionary. +to its RESPONSE_CALLBACKS class dictionary. Response callbacks take at least one parameter: the response from the Redis server. Keyword arguments may also be accepted in order to further control diff --git a/redis/__init__.py b/redis/__init__.py index 11281ba..6607155 100644 --- a/redis/__init__.py +++ b/redis/__init__.py @@ -22,7 +22,7 @@ from redis.exceptions import ( ) -__version__ = '2.10.5' +__version__ = '2.10.6' VERSION = tuple(map(int, __version__.split('.'))) __all__ = [ diff --git a/redis/client.py b/redis/client.py index 92095de..e3b5ed7 100755 --- a/redis/client.py +++ b/redis/client.py @@ -6,6 +6,7 @@ import warnings import time import threading import time as mod_time +import hashlib from redis._compat import (b, basestring, bytes, imap, iteritems, iterkeys, itervalues, izip, long, nativestr, unicode, safe_unicode) @@ -244,7 +245,8 @@ def bool_ok(response): def parse_client_list(response, **options): clients = [] for c in nativestr(response).splitlines(): - clients.append(dict([pair.split('=') for pair in c.split(' ')])) + # Values might contain '=' + clients.append(dict([pair.split('=', 1) for pair in c.split(' ')])) return clients @@ -360,9 +362,9 @@ class StrictRedis(object): bool ), string_keys_to_dict( - 'BITCOUNT BITPOS DECRBY DEL GETBIT HDEL HLEN INCRBY LINSERT LLEN ' - 'LPUSHX PFADD PFCOUNT RPUSHX SADD SCARD SDIFFSTORE SETBIT ' - 'SETRANGE SINTERSTORE SREM STRLEN SUNIONSTORE ZADD ZCARD ' + 'BITCOUNT BITPOS DECRBY DEL GETBIT HDEL HLEN HSTRLEN INCRBY ' + 'LINSERT LLEN LPUSHX PFADD PFCOUNT RPUSHX SADD SCARD SDIFFSTORE ' + 'SETBIT SETRANGE SINTERSTORE SREM STRLEN SUNIONSTORE ZADD ZCARD ' 'ZLEXCOUNT ZREM ZREMRANGEBYLEX ZREMRANGEBYRANK ZREMRANGEBYSCORE ' 'GEOADD', int @@ -447,7 +449,8 @@ class StrictRedis(object): 'CLUSTER SETSLOT': bool_ok, 'CLUSTER SLAVES': parse_cluster_nodes, 'GEOPOS': lambda r: list(map(lambda ll: (float(ll[0]), - float(ll[1])), r)), + float(ll[1])) + if ll is not None else None, r)), 'GEOHASH': lambda r: list(map(nativestr, r)), 'GEORADIUS': parse_georadius_generic, 'GEORADIUSBYMEMBER': parse_georadius_generic, @@ -1143,11 +1146,11 @@ class StrictRedis(object): ``px`` sets an expire flag on key ``name`` for ``px`` milliseconds. - ``nx`` if set to True, set the value at key ``name`` to ``value`` if it - does not already exist. + ``nx`` if set to True, set the value at key ``name`` to ``value`` only + if it does not exist. - ``xx`` if set to True, set the value at key ``name`` to ``value`` if it - already exists. + ``xx`` if set to True, set the value at key ``name`` to ``value`` only + if it already exists. """ pieces = [name, value] if ex is not None: @@ -1217,6 +1220,13 @@ class StrictRedis(object): """ return self.execute_command('SUBSTR', name, start, end) + def touch(self, *args): + """ + Alters the last access time of a key(s) ``*args``. A key is ignored + if it does not exist. + """ + return self.execute_command('TOUCH', *args) + def ttl(self, name): "Returns the number of seconds until the key ``name`` will expire" return self.execute_command('TTL', name) @@ -1264,7 +1274,7 @@ class StrictRedis(object): RPOP a value off of the first non-empty list named in the ``keys`` list. - If none of the lists in ``keys`` has a value to LPOP, then block + If none of the lists in ``keys`` has a value to RPOP, then block for ``timeout`` seconds, or until a value gets pushed on to one of the lists. @@ -2010,6 +2020,13 @@ class StrictRedis(object): "Return the list of values within hash ``name``" return self.execute_command('HVALS', name) + def hstrlen(self, name, key): + """ + Return the number of bytes stored in the value of ``key`` + within hash ``name`` + """ + return self.execute_command('HSTRLEN', name, key) + def publish(self, channel, message): """ Publish ``message`` on ``channel``. @@ -2100,7 +2117,7 @@ class StrictRedis(object): the triad longitude, latitude and name. """ if len(values) % 3 != 0: - raise RedisError("GEOADD requires places with lat, lon and name" + raise RedisError("GEOADD requires places with lon, lat and name" " values") return self.execute_command('GEOADD', name, *values) @@ -2129,7 +2146,7 @@ class StrictRedis(object): """ Return the positions of each item of ``values`` as members of the specified key identified by the ``name``argument. Each position - is represented by the pairs lat and lon. + is represented by the pairs lon and lat. """ return self.execute_command('GEOPOS', name, *values) @@ -2324,13 +2341,7 @@ class PubSub(object): self.connection = None # we need to know the encoding options for this connection in order # to lookup channel and pattern names for callback handlers. - conn = connection_pool.get_connection('pubsub', shard_hint) - try: - self.encoding = conn.encoding - self.encoding_errors = conn.encoding_errors - self.decode_responses = conn.decode_responses - finally: - connection_pool.release(conn) + self.encoder = self.connection_pool.get_encoder() self.reset() def __del__(self): @@ -2362,29 +2373,14 @@ class PubSub(object): if self.channels: channels = {} for k, v in iteritems(self.channels): - if not self.decode_responses: - k = k.decode(self.encoding, self.encoding_errors) - channels[k] = v + channels[self.encoder.decode(k, force=True)] = v self.subscribe(**channels) if self.patterns: patterns = {} for k, v in iteritems(self.patterns): - if not self.decode_responses: - k = k.decode(self.encoding, self.encoding_errors) - patterns[k] = v + patterns[self.encoder.decode(k, force=True)] = v self.psubscribe(**patterns) - def encode(self, value): - """ - Encode the value so that it's identical to what we'll - read off the connection - """ - if self.decode_responses and isinstance(value, bytes): - value = value.decode(self.encoding, self.encoding_errors) - elif not self.decode_responses and isinstance(value, unicode): - value = value.encode(self.encoding, self.encoding_errors) - return value - @property def subscribed(self): "Indicates if there are subscriptions to any channels or patterns" @@ -2434,6 +2430,16 @@ class PubSub(object): return None return self._execute(connection, connection.read_response) + def _normalize_keys(self, data): + """ + normalize channel/pattern names to be either bytes or strings + based on whether responses are automatically decoded. this saves us + from coercing the value for each message coming in. + """ + encode = self.encoder.encode + decode = self.encoder.decode + return dict([(decode(encode(k)), v) for k, v in iteritems(data)]) + def psubscribe(self, *args, **kwargs): """ Subscribe to channel patterns. Patterns supplied as keyword arguments @@ -2444,15 +2450,13 @@ class PubSub(object): """ if args: args = list_or_args(args[0], args[1:]) - new_patterns = {} - new_patterns.update(dict.fromkeys(imap(self.encode, args))) - for pattern, handler in iteritems(kwargs): - new_patterns[self.encode(pattern)] = handler + new_patterns = dict.fromkeys(args) + new_patterns.update(kwargs) ret_val = self.execute_command('PSUBSCRIBE', *iterkeys(new_patterns)) # update the patterns dict AFTER we send the command. we don't want to # subscribe twice to these patterns, once for the command and again # for the reconnection. - self.patterns.update(new_patterns) + self.patterns.update(self._normalize_keys(new_patterns)) return ret_val def punsubscribe(self, *args): @@ -2474,15 +2478,13 @@ class PubSub(object): """ if args: args = list_or_args(args[0], args[1:]) - new_channels = {} - new_channels.update(dict.fromkeys(imap(self.encode, args))) - for channel, handler in iteritems(kwargs): - new_channels[self.encode(channel)] = handler + new_channels = dict.fromkeys(args) + new_channels.update(kwargs) ret_val = self.execute_command('SUBSCRIBE', *iterkeys(new_channels)) # update the channels dict AFTER we send the command. we don't want to # subscribe twice to these channels, once for the command and again # for the reconnection. - self.channels.update(new_channels) + self.channels.update(self._normalize_keys(new_channels)) return ret_val def unsubscribe(self, *args): @@ -2848,12 +2850,11 @@ class BasePipeline(object): shas = [s.sha for s in scripts] # we can't use the normal script_* methods because they would just # get buffered in the pipeline. - exists = immediate('SCRIPT', 'EXISTS', *shas, **{'parse': 'EXISTS'}) + exists = immediate('SCRIPT EXISTS', *shas) if not all(exists): for s, exist in izip(scripts, exists): if not exist: - s.sha = immediate('SCRIPT', 'LOAD', s.script, - **{'parse': 'LOAD'}) + s.sha = immediate('SCRIPT LOAD', s.script) def execute(self, raise_on_error=True): "Execute all the commands in the current pipeline" @@ -2905,16 +2906,6 @@ class BasePipeline(object): "Unwatches all previously specified keys" return self.watching and self.execute_command('UNWATCH') or True - def script_load_for_pipeline(self, script): - "Make sure scripts are loaded prior to pipeline execution" - # we need the sha now so that Script.__call__ can use it to run - # evalsha. - if not script.sha: - script.sha = self.immediate_execute_command('SCRIPT', 'LOAD', - script.script, - **{'parse': 'LOAD'}) - self.scripts.add(script) - class StrictPipeline(BasePipeline, StrictRedis): "Pipeline for the StrictRedis class" @@ -2932,7 +2923,14 @@ class Script(object): def __init__(self, registered_client, script): self.registered_client = registered_client self.script = script - self.sha = '' + # Precalculate and store the SHA1 hex digest of the script. + + if isinstance(script, basestring): + # We need the encoding from the client in order to generate an + # accurate byte representation of the script + encoder = registered_client.connection_pool.get_encoder() + script = encoder.encode(script) + self.sha = hashlib.sha1(script).hexdigest() def __call__(self, keys=[], args=[], client=None): "Execute the script, passing any required ``args``" @@ -2941,12 +2939,13 @@ class Script(object): args = tuple(keys) + tuple(args) # make sure the Redis server knows about the script if isinstance(client, BasePipeline): - # make sure this script is good to go on pipeline - client.script_load_for_pipeline(self) + # Make sure the pipeline can register the script before executing. + client.scripts.add(self) try: return client.evalsha(self.sha, len(keys), *args) except NoScriptError: # Maybe the client is pointed to a differnet server than the client # that created this instance? + # Overwrite the sha just in case there was a discrepancy. self.sha = client.script_load(self.script) return client.evalsha(self.sha, len(keys), *args) diff --git a/redis/connection.py b/redis/connection.py index 728a627..c8f0513 100755 --- a/redis/connection.py +++ b/redis/connection.py @@ -94,6 +94,38 @@ class Token(object): return self.value +class Encoder(object): + "Encode strings to bytes and decode bytes to strings" + + def __init__(self, encoding, encoding_errors, decode_responses): + self.encoding = encoding + self.encoding_errors = encoding_errors + self.decode_responses = decode_responses + + def encode(self, value): + "Return a bytestring representation of the value" + if isinstance(value, Token): + return value.encoded_value + elif isinstance(value, bytes): + return value + elif isinstance(value, (int, long)): + value = b(str(value)) + elif isinstance(value, float): + value = b(repr(value)) + elif not isinstance(value, basestring): + # an object we don't know how to deal with. default to unicode() + value = unicode(value) + if isinstance(value, unicode): + value = value.encode(self.encoding, self.encoding_errors) + return value + + def decode(self, value, force=False): + "Return a unicode string from the byte representation" + if (self.decode_responses or force) and isinstance(value, bytes): + value = value.decode(self.encoding, self.encoding_errors) + return value + + class BaseParser(object): EXCEPTION_CLASSES = { 'ERR': { @@ -217,10 +249,9 @@ class SocketBuffer(object): class PythonParser(BaseParser): "Plain Python parsing class" - encoding = None - def __init__(self, socket_read_size): self.socket_read_size = socket_read_size + self.encoder = None self._sock = None self._buffer = None @@ -234,8 +265,7 @@ class PythonParser(BaseParser): "Called when the socket connects" self._sock = connection._sock self._buffer = SocketBuffer(self._sock, self.socket_read_size) - if connection.decode_responses: - self.encoding = connection.encoding + self.encoder = connection.encoder def on_disconnect(self): "Called when the socket disconnects" @@ -245,7 +275,7 @@ class PythonParser(BaseParser): if self._buffer is not None: self._buffer.close() self._buffer = None - self.encoding = None + self.encoder = None def can_read(self): return self._buffer and bool(self._buffer.length) @@ -292,8 +322,8 @@ class PythonParser(BaseParser): if length == -1: return None response = [self.read_response() for i in xrange(length)] - if isinstance(response, bytes) and self.encoding: - response = response.decode(self.encoding) + if isinstance(response, bytes): + response = self.encoder.decode(response) return response @@ -324,8 +354,8 @@ class HiredisParser(BaseParser): if not HIREDIS_SUPPORTS_CALLABLE_ERRORS: kwargs['replyError'] = ResponseError - if connection.decode_responses: - kwargs['encoding'] = connection.encoding + if connection.encoder.decode_responses: + kwargs['encoding'] = connection.encoder.encoding self._reader = hiredis.Reader(**kwargs) self._next_response = False @@ -394,6 +424,7 @@ class HiredisParser(BaseParser): raise response[0] return response + if HIREDIS_AVAILABLE: DefaultParser = HiredisParser else: @@ -420,9 +451,7 @@ class Connection(object): self.socket_keepalive = socket_keepalive self.socket_keepalive_options = socket_keepalive_options or {} self.retry_on_timeout = retry_on_timeout - self.encoding = encoding - self.encoding_errors = encoding_errors - self.decode_responses = decode_responses + self.encoder = Encoder(encoding, encoding_errors, decode_responses) self._sock = None self._parser = parser_class(socket_read_size=socket_read_size) self._description_args = { @@ -600,22 +629,6 @@ class Connection(object): raise response return response - def encode(self, value): - "Return a bytestring representation of the value" - if isinstance(value, Token): - return value.encoded_value - elif isinstance(value, bytes): - return value - elif isinstance(value, (int, long)): - value = b(str(value)) - elif isinstance(value, float): - value = b(repr(value)) - elif not isinstance(value, basestring): - value = unicode(value) - if isinstance(value, unicode): - value = value.encode(self.encoding, self.encoding_errors) - return value - def pack_command(self, *args): "Pack a series of arguments into the Redis protocol" output = [] @@ -634,7 +647,7 @@ class Connection(object): buff = SYM_EMPTY.join( (SYM_STAR, b(str(len(args))), SYM_CRLF)) - for arg in imap(self.encode, args): + for arg in imap(self.encoder.encode, args): # to avoid large string mallocs, chunk the command into the # output list if we're sending large values if len(buff) > 6000 or len(arg) > 6000: @@ -723,9 +736,7 @@ class UnixDomainSocketConnection(Connection): self.password = password self.socket_timeout = socket_timeout self.retry_on_timeout = retry_on_timeout - self.encoding = encoding - self.encoding_errors = encoding_errors - self.decode_responses = decode_responses + self.encoder = Encoder(encoding, encoding_errors, decode_responses) self._sock = None self._parser = parser_class(socket_read_size=socket_read_size) self._description_args = { @@ -955,6 +966,15 @@ class ConnectionPool(object): self._in_use_connections.add(connection) return connection + def get_encoder(self): + "Return an encoder based on encoding settings" + kwargs = self.connection_kwargs + return Encoder( + encoding=kwargs.get('encoding', 'utf-8'), + encoding_errors=kwargs.get('encoding_errors', 'strict'), + decode_responses=kwargs.get('decode_responses', False) + ) + def make_connection(self): "Create a new connection" if self._created_connections >= self.max_connections: @@ -44,6 +44,11 @@ setup( keywords=['Redis', 'key-value store'], license='MIT', packages=['redis'], + extras_require={ + 'hiredis': [ + "hiredis>=0.1.3", + ], + }, tests_require=[ 'mock', 'pytest>=2.5.0', diff --git a/tests/conftest.py b/tests/conftest.py index d7b2b14..006697f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -37,6 +37,11 @@ def skip_if_server_version_lt(min_version): return pytest.mark.skipif(check, reason="") +def skip_if_server_version_gte(min_version): + check = StrictVersion(get_version()) >= StrictVersion(min_version) + return pytest.mark.skipif(check, reason="") + + @pytest.fixture() def r(request, **kwargs): return _get_client(redis.Redis, request, **kwargs) diff --git a/tests/test_commands.py b/tests/test_commands.py index 246ef65..ed46298 100644 --- a/tests/test_commands.py +++ b/tests/test_commands.py @@ -10,7 +10,7 @@ from redis._compat import (unichr, u, b, ascii_letters, iteritems, iterkeys, from redis.client import parse_info from redis import exceptions -from .conftest import skip_if_server_version_lt +from .conftest import skip_if_server_version_lt, skip_if_server_version_gte @pytest.fixture() @@ -68,6 +68,14 @@ class TestRedisCommands(object): assert r.client_setname('redis_py_test') assert r.client_getname() == 'redis_py_test' + @skip_if_server_version_lt('2.6.9') + def test_client_list_after_client_setname(self, r): + r.client_setname('cl=i=ent') + clients = r.client_list() + assert isinstance(clients[0], dict) + assert 'name' in clients[0] + assert clients[0]['name'] == 'cl=i=ent' + def test_config_get(self, r): data = r.config_get() assert 'maxmemory' in data @@ -901,6 +909,8 @@ class TestRedisCommands(object): r.zadd('a', a1=1, a2=2, a3=3) assert r.zcount('a', '-inf', '+inf') == 3 assert r.zcount('a', 1, 2) == 2 + assert r.zcount('a', '(' + str(1), 2) == 1 + assert r.zcount('a', 1, '(' + str(2)) == 1 assert r.zcount('a', 10, 20) == 0 def test_zincrby(self, r): @@ -1229,6 +1239,12 @@ class TestRedisCommands(object): remote_vals = r.hvals('a') assert sorted(local_vals) == sorted(remote_vals) + @skip_if_server_version_lt('3.2.0') + def test_hstrlen(self, r): + r.hmset('a', {'1': '22', '2': '333'}) + assert r.hstrlen('a', '1') == 2 + assert r.hstrlen('a', '2') == 3 + # SORT def test_sort_basic(self, r): r.rpush('a', '3', '2', '1', '4') @@ -1453,6 +1469,15 @@ class TestRedisCommands(object): [(2.19093829393386841, 41.43379028184083523), (2.18737632036209106, 41.40634178640635099)] + @skip_if_server_version_lt('4.0.0') + def test_geopos_no_value(self, r): + assert r.geopos('barcelona', 'place1', 'place2') == [None, None] + + @skip_if_server_version_lt('3.2.0') + @skip_if_server_version_gte('4.0.0') + def test_old_geopos_no_value(self, r): + assert r.geopos('barcelona', 'place1', 'place2') == [] + @skip_if_server_version_lt('3.2.0') def test_georadius(self, r): values = (2.1909389952632, 41.433791470673, 'place1') +\ diff --git a/tests/test_scripting.py b/tests/test_scripting.py index 2213ec6..9131de8 100644 --- a/tests/test_scripting.py +++ b/tests/test_scripting.py @@ -57,44 +57,49 @@ class TestScripting(object): def test_script_object(self, r): r.set('a', 2) multiply = r.register_script(multiply_script) - assert not multiply.sha - # test evalsha fail -> script load + retry + precalculated_sha = multiply.sha + assert precalculated_sha + assert r.script_exists(multiply.sha) == [False] + # Test second evalsha block (after NoScriptError) assert multiply(keys=['a'], args=[3]) == 6 - assert multiply.sha + # At this point, the script should be loaded assert r.script_exists(multiply.sha) == [True] - # test first evalsha + # Test that the precalculated sha matches the one from redis + assert multiply.sha == precalculated_sha + # Test first evalsha block assert multiply(keys=['a'], args=[3]) == 6 def test_script_object_in_pipeline(self, r): multiply = r.register_script(multiply_script) - assert not multiply.sha + precalculated_sha = multiply.sha + assert precalculated_sha pipe = r.pipeline() pipe.set('a', 2) pipe.get('a') multiply(keys=['a'], args=[3], client=pipe) - # even though the pipeline wasn't executed yet, we made sure the - # script was loaded and got a valid sha - assert multiply.sha - assert r.script_exists(multiply.sha) == [True] + assert r.script_exists(multiply.sha) == [False] # [SET worked, GET 'a', result of multiple script] assert pipe.execute() == [True, b('2'), 6] + # The script should have been loaded by pipe.execute() + assert r.script_exists(multiply.sha) == [True] + # The precalculated sha should have been the correct one + assert multiply.sha == precalculated_sha # purge the script from redis's cache and re-run the pipeline - # the multiply script object knows it's sha, so it shouldn't get - # reloaded until pipe.execute() + # the multiply script should be reloaded by pipe.execute() r.script_flush() pipe = r.pipeline() pipe.set('a', 2) pipe.get('a') - assert multiply.sha multiply(keys=['a'], args=[3], client=pipe) assert r.script_exists(multiply.sha) == [False] # [SET worked, GET 'a', result of multiple script] assert pipe.execute() == [True, b('2'), 6] + assert r.script_exists(multiply.sha) == [True] def test_eval_msgpack_pipeline_error_in_lua(self, r): msgpack_hello = r.register_script(msgpack_hello_script) - assert not msgpack_hello.sha + assert msgpack_hello.sha pipe = r.pipeline() @@ -104,8 +109,9 @@ class TestScripting(object): msgpack_hello(args=[msgpack_message_1], client=pipe) - assert r.script_exists(msgpack_hello.sha) == [True] + assert r.script_exists(msgpack_hello.sha) == [False] assert pipe.execute()[0] == b'hello Joe' + assert r.script_exists(msgpack_hello.sha) == [True] msgpack_hello_broken = r.register_script(msgpack_hello_script_broken) |