diff options
Diffstat (limited to 'redis/client.py')
-rwxr-xr-x | redis/client.py | 154 |
1 files changed, 88 insertions, 66 deletions
diff --git a/redis/client.py b/redis/client.py index b37e96d..6383e55 100755 --- a/redis/client.py +++ b/redis/client.py @@ -6,6 +6,7 @@ import warnings import time import threading import time as mod_time +import hashlib from redis._compat import (b, basestring, bytes, imap, iteritems, iterkeys, itervalues, izip, long, nativestr, unicode, safe_unicode) @@ -207,7 +208,7 @@ def zset_score_pairs(response, **options): If ``withscores`` is specified in the options, return the response as a list of (value, score) pairs """ - if not response or not options['withscores']: + if not response or not options.get('withscores'): return response score_cast_func = options.get('score_cast_func', float) it = iter(response) @@ -339,6 +340,10 @@ def parse_georadius_generic(response, **options): ] +def parse_pubsub_numsub(response, **options): + return list(zip(response[0::2], response[1::2])) + + class StrictRedis(object): """ Implementation of the Redis protocol. @@ -356,9 +361,9 @@ class StrictRedis(object): bool ), string_keys_to_dict( - 'BITCOUNT BITPOS DECRBY DEL GETBIT HDEL HLEN INCRBY LINSERT LLEN ' - 'LPUSHX PFADD PFCOUNT RPUSHX SADD SCARD SDIFFSTORE SETBIT ' - 'SETRANGE SINTERSTORE SREM STRLEN SUNIONSTORE ZADD ZCARD ' + 'BITCOUNT BITPOS DECRBY DEL GETBIT HDEL HLEN HSTRLEN INCRBY ' + 'LINSERT LLEN LPUSHX PFADD PFCOUNT RPUSHX SADD SCARD SDIFFSTORE ' + 'SETBIT SETRANGE SINTERSTORE SREM STRLEN SUNIONSTORE ZADD ZCARD ' 'ZLEXCOUNT ZREM ZREMRANGEBYLEX ZREMRANGEBYRANK ZREMRANGEBYSCORE ' 'GEOADD', int @@ -443,10 +448,12 @@ class StrictRedis(object): 'CLUSTER SETSLOT': bool_ok, 'CLUSTER SLAVES': parse_cluster_nodes, 'GEOPOS': lambda r: list(map(lambda ll: (float(ll[0]), - float(ll[1])), r)), + float(ll[1])) + if ll is not None else None, r)), 'GEOHASH': lambda r: list(map(nativestr, r)), 'GEORADIUS': parse_georadius_generic, 'GEORADIUSBYMEMBER': parse_georadius_generic, + 'PUBSUB NUMSUB': parse_pubsub_numsub, } ) @@ -1138,19 +1145,19 @@ class StrictRedis(object): ``px`` sets an expire flag on key ``name`` for ``px`` milliseconds. - ``nx`` if set to True, set the value at key ``name`` to ``value`` if it - does not already exist. + ``nx`` if set to True, set the value at key ``name`` to ``value`` only + if it does not exist. - ``xx`` if set to True, set the value at key ``name`` to ``value`` if it - already exists. + ``xx`` if set to True, set the value at key ``name`` to ``value`` only + if it already exists. """ pieces = [name, value] - if ex: + if ex is not None: pieces.append('EX') if isinstance(ex, datetime.timedelta): ex = ex.seconds + ex.days * 24 * 3600 pieces.append(ex) - if px: + if px is not None: pieces.append('PX') if isinstance(px, datetime.timedelta): ms = int(px.microseconds / 1000) @@ -1212,6 +1219,13 @@ class StrictRedis(object): """ return self.execute_command('SUBSTR', name, start, end) + def touch(self, *args): + """ + Alters the last access time of a key(s) ``*args``. A key is ignored + if it does not exist. + """ + return self.execute_command('TOUCH', *args) + def ttl(self, name): "Returns the number of seconds until the key ``name`` will expire" return self.execute_command('TTL', name) @@ -1259,7 +1273,7 @@ class StrictRedis(object): RPOP a value off of the first non-empty list named in the ``keys`` list. - If none of the lists in ``keys`` has a value to LPOP, then block + If none of the lists in ``keys`` has a value to RPOP, then block for ``timeout`` seconds, or until a value gets pushed on to one of the lists. @@ -1639,7 +1653,7 @@ class StrictRedis(object): memebers of set ``name``. Note this is only available when running Redis 2.6+. """ - args = number and [number] or [] + args = (number is not None) and [number] or [] return self.execute_command('SRANDMEMBER', name, *args) def srem(self, name, *values): @@ -2005,6 +2019,13 @@ class StrictRedis(object): "Return the list of values within hash ``name``" return self.execute_command('HVALS', name) + def hstrlen(self, name, key): + """ + Return the number of bytes stored in the value of ``key`` + within hash ``name`` + """ + return self.execute_command('HSTRLEN', name, key) + def publish(self, channel, message): """ Publish ``message`` on ``channel``. @@ -2012,6 +2033,25 @@ class StrictRedis(object): """ return self.execute_command('PUBLISH', channel, message) + def pubsub_channels(self, pattern='*'): + """ + Return a list of channels that have at least one subscriber + """ + return self.execute_command('PUBSUB CHANNELS', pattern) + + def pubsub_numpat(self): + """ + Returns the number of subscriptions to patterns + """ + return self.execute_command('PUBSUB NUMPAT') + + def pubsub_numsub(self, *args): + """ + Return a list of (channel, number of subscribers) tuples + for each channel given in ``*args`` + """ + return self.execute_command('PUBSUB NUMSUB', *args) + def cluster(self, cluster_arg, *args): return self.execute_command('CLUSTER %s' % cluster_arg.upper(), *args) @@ -2073,10 +2113,10 @@ class StrictRedis(object): Add the specified geospatial items to the specified key identified by the ``name`` argument. The Geospatial items are given as ordered members of the ``values`` argument, each item or place is formed by - the triad latitude, longitude and name. + the triad longitude, latitude and name. """ if len(values) % 3 != 0: - raise RedisError("GEOADD requires places with lat, lon and name" + raise RedisError("GEOADD requires places with lon, lat and name" " values") return self.execute_command('GEOADD', name, *values) @@ -2105,7 +2145,7 @@ class StrictRedis(object): """ Return the positions of each item of ``values`` as members of the specified key identified by the ``name``argument. Each position - is represented by the pairs lat and lon. + is represented by the pairs lon and lat. """ return self.execute_command('GEOPOS', name, *values) @@ -2300,13 +2340,7 @@ class PubSub(object): self.connection = None # we need to know the encoding options for this connection in order # to lookup channel and pattern names for callback handlers. - conn = connection_pool.get_connection('pubsub', shard_hint) - try: - self.encoding = conn.encoding - self.encoding_errors = conn.encoding_errors - self.decode_responses = conn.decode_responses - finally: - connection_pool.release(conn) + self.encoder = self.connection_pool.get_encoder() self.reset() def __del__(self): @@ -2338,29 +2372,14 @@ class PubSub(object): if self.channels: channels = {} for k, v in iteritems(self.channels): - if not self.decode_responses: - k = k.decode(self.encoding, self.encoding_errors) - channels[k] = v + channels[self.encoder.decode(k, force=True)] = v self.subscribe(**channels) if self.patterns: patterns = {} for k, v in iteritems(self.patterns): - if not self.decode_responses: - k = k.decode(self.encoding, self.encoding_errors) - patterns[k] = v + patterns[self.encoder.decode(k, force=True)] = v self.psubscribe(**patterns) - def encode(self, value): - """ - Encode the value so that it's identical to what we'll - read off the connection - """ - if self.decode_responses and isinstance(value, bytes): - value = value.decode(self.encoding, self.encoding_errors) - elif not self.decode_responses and isinstance(value, unicode): - value = value.encode(self.encoding, self.encoding_errors) - return value - @property def subscribed(self): "Indicates if there are subscriptions to any channels or patterns" @@ -2410,6 +2429,16 @@ class PubSub(object): return None return self._execute(connection, connection.read_response) + def _normalize_keys(self, data): + """ + normalize channel/pattern names to be either bytes or strings + based on whether responses are automatically decoded. this saves us + from coercing the value for each message coming in. + """ + encode = self.encoder.encode + decode = self.encoder.decode + return dict([(decode(encode(k)), v) for k, v in iteritems(data)]) + def psubscribe(self, *args, **kwargs): """ Subscribe to channel patterns. Patterns supplied as keyword arguments @@ -2420,15 +2449,13 @@ class PubSub(object): """ if args: args = list_or_args(args[0], args[1:]) - new_patterns = {} - new_patterns.update(dict.fromkeys(imap(self.encode, args))) - for pattern, handler in iteritems(kwargs): - new_patterns[self.encode(pattern)] = handler + new_patterns = dict.fromkeys(args) + new_patterns.update(kwargs) ret_val = self.execute_command('PSUBSCRIBE', *iterkeys(new_patterns)) # update the patterns dict AFTER we send the command. we don't want to # subscribe twice to these patterns, once for the command and again # for the reconnection. - self.patterns.update(new_patterns) + self.patterns.update(self._normalize_keys(new_patterns)) return ret_val def punsubscribe(self, *args): @@ -2450,15 +2477,13 @@ class PubSub(object): """ if args: args = list_or_args(args[0], args[1:]) - new_channels = {} - new_channels.update(dict.fromkeys(imap(self.encode, args))) - for channel, handler in iteritems(kwargs): - new_channels[self.encode(channel)] = handler + new_channels = dict.fromkeys(args) + new_channels.update(kwargs) ret_val = self.execute_command('SUBSCRIBE', *iterkeys(new_channels)) # update the channels dict AFTER we send the command. we don't want to # subscribe twice to these channels, once for the command and again # for the reconnection. - self.channels.update(new_channels) + self.channels.update(self._normalize_keys(new_channels)) return ret_val def unsubscribe(self, *args): @@ -2824,12 +2849,11 @@ class BasePipeline(object): shas = [s.sha for s in scripts] # we can't use the normal script_* methods because they would just # get buffered in the pipeline. - exists = immediate('SCRIPT', 'EXISTS', *shas, **{'parse': 'EXISTS'}) + exists = immediate('SCRIPT EXISTS', *shas) if not all(exists): for s, exist in izip(scripts, exists): if not exist: - s.sha = immediate('SCRIPT', 'LOAD', s.script, - **{'parse': 'LOAD'}) + s.sha = immediate('SCRIPT LOAD', s.script) def execute(self, raise_on_error=True): "Execute all the commands in the current pipeline" @@ -2881,16 +2905,6 @@ class BasePipeline(object): "Unwatches all previously specified keys" return self.watching and self.execute_command('UNWATCH') or True - def script_load_for_pipeline(self, script): - "Make sure scripts are loaded prior to pipeline execution" - # we need the sha now so that Script.__call__ can use it to run - # evalsha. - if not script.sha: - script.sha = self.immediate_execute_command('SCRIPT', 'LOAD', - script.script, - **{'parse': 'LOAD'}) - self.scripts.add(script) - class StrictPipeline(BasePipeline, StrictRedis): "Pipeline for the StrictRedis class" @@ -2908,7 +2922,14 @@ class Script(object): def __init__(self, registered_client, script): self.registered_client = registered_client self.script = script - self.sha = '' + # Precalculate and store the SHA1 hex digest of the script. + + if isinstance(script, basestring): + # We need the encoding from the client in order to generate an + # accurate byte representation of the script + encoder = registered_client.connection_pool.get_encoder() + script = encoder.encode(script) + self.sha = hashlib.sha1(script).hexdigest() def __call__(self, keys=[], args=[], client=None): "Execute the script, passing any required ``args``" @@ -2917,12 +2938,13 @@ class Script(object): args = tuple(keys) + tuple(args) # make sure the Redis server knows about the script if isinstance(client, BasePipeline): - # make sure this script is good to go on pipeline - client.script_load_for_pipeline(self) + # Make sure the pipeline can register the script before executing. + client.scripts.add(self) try: return client.evalsha(self.sha, len(keys), *args) except NoScriptError: # Maybe the client is pointed to a differnet server than the client # that created this instance? + # Overwrite the sha just in case there was a discrepancy. self.sha = client.script_load(self.script) return client.evalsha(self.sha, len(keys), *args) |