Diffstat (limited to 'redis/client.py')
-rwxr-xr-x  redis/client.py  154
1 file changed, 88 insertions(+), 66 deletions(-)
diff --git a/redis/client.py b/redis/client.py
index b37e96d..6383e55 100755
--- a/redis/client.py
+++ b/redis/client.py
@@ -6,6 +6,7 @@ import warnings
import time
import threading
import time as mod_time
+import hashlib
from redis._compat import (b, basestring, bytes, imap, iteritems, iterkeys,
itervalues, izip, long, nativestr, unicode,
safe_unicode)
@@ -207,7 +208,7 @@ def zset_score_pairs(response, **options):
If ``withscores`` is specified in the options, return the response as
a list of (value, score) pairs
"""
- if not response or not options['withscores']:
+ if not response or not options.get('withscores'):
return response
score_cast_func = options.get('score_cast_func', float)
it = iter(response)
@@ -339,6 +340,10 @@ def parse_georadius_generic(response, **options):
]
+def parse_pubsub_numsub(response, **options):
+ return list(zip(response[0::2], response[1::2]))
+
+
class StrictRedis(object):
"""
Implementation of the Redis protocol.
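Note: ``PUBSUB NUMSUB`` replies interleave channel names and subscriber counts, and the parser added above simply pairs them up; a minimal sketch with a hypothetical reply:

    raw_reply = [b'news', 2, b'sports', 0]               # flat reply from Redis
    pairs = list(zip(raw_reply[0::2], raw_reply[1::2]))  # same pairing as above
    assert pairs == [(b'news', 2), (b'sports', 0)]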
@@ -356,9 +361,9 @@ class StrictRedis(object):
bool
),
string_keys_to_dict(
- 'BITCOUNT BITPOS DECRBY DEL GETBIT HDEL HLEN INCRBY LINSERT LLEN '
- 'LPUSHX PFADD PFCOUNT RPUSHX SADD SCARD SDIFFSTORE SETBIT '
- 'SETRANGE SINTERSTORE SREM STRLEN SUNIONSTORE ZADD ZCARD '
+ 'BITCOUNT BITPOS DECRBY DEL GETBIT HDEL HLEN HSTRLEN INCRBY '
+ 'LINSERT LLEN LPUSHX PFADD PFCOUNT RPUSHX SADD SCARD SDIFFSTORE '
+ 'SETBIT SETRANGE SINTERSTORE SREM STRLEN SUNIONSTORE ZADD ZCARD '
'ZLEXCOUNT ZREM ZREMRANGEBYLEX ZREMRANGEBYRANK ZREMRANGEBYSCORE '
'GEOADD',
int
@@ -443,10 +448,12 @@ class StrictRedis(object):
'CLUSTER SETSLOT': bool_ok,
'CLUSTER SLAVES': parse_cluster_nodes,
'GEOPOS': lambda r: list(map(lambda ll: (float(ll[0]),
- float(ll[1])), r)),
+ float(ll[1]))
+ if ll is not None else None, r)),
'GEOHASH': lambda r: list(map(nativestr, r)),
'GEORADIUS': parse_georadius_generic,
'GEORADIUSBYMEMBER': parse_georadius_generic,
+ 'PUBSUB NUMSUB': parse_pubsub_numsub,
}
)
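Note: a GEOPOS reply element is None when the queried member does not exist, and the updated callback now preserves that instead of failing on ``ll[0]``. A rough sketch with a hypothetical reply (values may be bytes or str depending on decode_responses; str is used here for brevity):

    callback = lambda r: list(map(lambda ll: (float(ll[0]), float(ll[1]))
                                  if ll is not None else None, r))
    reply = [['13.361389', '38.115556'], None]   # second member is missing
    assert callback(reply) == [(13.361389, 38.115556), None]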
@@ -1138,19 +1145,19 @@ class StrictRedis(object):
``px`` sets an expire flag on key ``name`` for ``px`` milliseconds.
- ``nx`` if set to True, set the value at key ``name`` to ``value`` if it
- does not already exist.
+ ``nx`` if set to True, set the value at key ``name`` to ``value`` only
+ if it does not exist.
- ``xx`` if set to True, set the value at key ``name`` to ``value`` if it
- already exists.
+ ``xx`` if set to True, set the value at key ``name`` to ``value`` only
+ if it already exists.
"""
pieces = [name, value]
- if ex:
+ if ex is not None:
pieces.append('EX')
if isinstance(ex, datetime.timedelta):
ex = ex.seconds + ex.days * 24 * 3600
pieces.append(ex)
- if px:
+ if px is not None:
pieces.append('PX')
if isinstance(px, datetime.timedelta):
ms = int(px.microseconds / 1000)
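Note: one practical effect of switching to ``is not None`` above is that falsy-but-present values such as ``ex=0`` are now forwarded to the server (which rejects them) instead of being silently dropped. A standalone sketch of the timedelta-to-seconds conversion used for ``EX``:

    import datetime

    ex = datetime.timedelta(minutes=5)
    seconds = ex.seconds + ex.days * 24 * 3600    # same arithmetic as above
    pieces = ['mykey', 'myvalue', 'EX', seconds]  # hypothetical SET arguments
    assert seconds == 300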
@@ -1212,6 +1219,13 @@ class StrictRedis(object):
"""
return self.execute_command('SUBSTR', name, start, end)
+ def touch(self, *args):
+ """
+ Alters the last access time of the given key(s) ``*args``. A key is ignored

+ if it does not exist.
+ """
+ return self.execute_command('TOUCH', *args)
+
def ttl(self, name):
"Returns the number of seconds until the key ``name`` will expire"
return self.execute_command('TTL', name)
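Usage sketch for the new ``touch`` wrapper (assumes a connected client ``r`` and Redis >= 3.2.1, where TOUCH was introduced); the reply is the number of keys that actually existed:

    r.set('a', 1)
    r.set('b', 2)
    assert r.touch('a', 'b', 'missing') == 2   # 'missing' is ignored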
@@ -1259,7 +1273,7 @@ class StrictRedis(object):
RPOP a value off of the first non-empty list
named in the ``keys`` list.
- If none of the lists in ``keys`` has a value to LPOP, then block
+ If none of the lists in ``keys`` has a value to RPOP, then block
for ``timeout`` seconds, or until a value gets pushed on to one
of the lists.
@@ -1639,7 +1653,7 @@ class StrictRedis(object):
members of set ``name``. Note this is only available when running
Redis 2.6+.
"""
- args = number and [number] or []
+ args = (number is not None) and [number] or []
return self.execute_command('SRANDMEMBER', name, *args)
def srem(self, name, *values):
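Note: the ``number is not None`` fix above changes one edge case: an explicit count of 0 is now sent to the server (which returns an empty list) rather than being dropped and returning a single random member. A quick sketch with a client ``r``:

    r.sadd('letters', 'a', 'b', 'c')
    assert r.srandmember('letters', 0) == []               # explicit count of 0
    assert r.srandmember('letters') in (b'a', b'b', b'c')  # no count: one member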
@@ -2005,6 +2019,13 @@ class StrictRedis(object):
"Return the list of values within hash ``name``"
return self.execute_command('HVALS', name)
+ def hstrlen(self, name, key):
+ """
+ Return the number of bytes stored in the value of ``key``
+ within hash ``name``
+ """
+ return self.execute_command('HSTRLEN', name, key)
+
def publish(self, channel, message):
"""
Publish ``message`` on ``channel``.
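Usage sketch for the new ``hstrlen`` wrapper (client ``r`` assumed, Redis >= 3.2); missing fields report a length of 0:

    r.hset('profile', 'nickname', 'redis-fan')
    assert r.hstrlen('profile', 'nickname') == 9       # len('redis-fan')
    assert r.hstrlen('profile', 'no-such-field') == 0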
@@ -2012,6 +2033,25 @@ class StrictRedis(object):
"""
return self.execute_command('PUBLISH', channel, message)
+ def pubsub_channels(self, pattern='*'):
+ """
+ Return a list of channels that have at least one subscriber
+ """
+ return self.execute_command('PUBSUB CHANNELS', pattern)
+
+ def pubsub_numpat(self):
+ """
+ Returns the number of subscriptions to patterns
+ """
+ return self.execute_command('PUBSUB NUMPAT')
+
+ def pubsub_numsub(self, *args):
+ """
+ Return a list of (channel, number of subscribers) tuples
+ for each channel given in ``*args``
+ """
+ return self.execute_command('PUBSUB NUMSUB', *args)
+
def cluster(self, cluster_arg, *args):
return self.execute_command('CLUSTER %s' % cluster_arg.upper(), *args)
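Together with parse_pubsub_numsub above, the new introspection helpers can be exercised roughly like this (client ``r`` assumed; channel names and counts are hypothetical):

    r.pubsub_channels()                  # e.g. [b'news', b'sports']
    r.pubsub_numsub('news', 'sports')    # e.g. [(b'news', 1), (b'sports', 0)]
    r.pubsub_numpat()                    # e.g. 0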
@@ -2073,10 +2113,10 @@ class StrictRedis(object):
Add the specified geospatial items to the specified key identified
by the ``name`` argument. The Geospatial items are given as ordered
members of the ``values`` argument, each item or place is formed by
- the triad latitude, longitude and name.
+ the triad longitude, latitude and name.
"""
if len(values) % 3 != 0:
- raise RedisError("GEOADD requires places with lat, lon and name"
+ raise RedisError("GEOADD requires places with lon, lat and name"
" values")
return self.execute_command('GEOADD', name, *values)
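As the corrected docstring says, GEOADD takes longitude before latitude; a small example (client ``r`` assumed, Redis >= 3.2, coordinates from the Redis documentation):

    # values are flattened triads of (longitude, latitude, member name)
    r.geoadd('Sicily', 13.361389, 38.115556, 'Palermo',
                       15.087269, 37.502669, 'Catania')
    r.geopos('Sicily', 'Palermo', 'Nonexistent')   # second entry comes back as None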
@@ -2105,7 +2145,7 @@ class StrictRedis(object):
"""
Return the positions of each item of ``values`` as members of
the specified key identified by the ``name`` argument. Each position
- is represented by the pairs lat and lon.
+ is represented by the pairs lon and lat.
"""
return self.execute_command('GEOPOS', name, *values)
@@ -2300,13 +2340,7 @@ class PubSub(object):
self.connection = None
# we need to know the encoding options for this connection in order
# to lookup channel and pattern names for callback handlers.
- conn = connection_pool.get_connection('pubsub', shard_hint)
- try:
- self.encoding = conn.encoding
- self.encoding_errors = conn.encoding_errors
- self.decode_responses = conn.decode_responses
- finally:
- connection_pool.release(conn)
+ self.encoder = self.connection_pool.get_encoder()
self.reset()
def __del__(self):
@@ -2338,29 +2372,14 @@ class PubSub(object):
if self.channels:
channels = {}
for k, v in iteritems(self.channels):
- if not self.decode_responses:
- k = k.decode(self.encoding, self.encoding_errors)
- channels[k] = v
+ channels[self.encoder.decode(k, force=True)] = v
self.subscribe(**channels)
if self.patterns:
patterns = {}
for k, v in iteritems(self.patterns):
- if not self.decode_responses:
- k = k.decode(self.encoding, self.encoding_errors)
- patterns[k] = v
+ patterns[self.encoder.decode(k, force=True)] = v
self.psubscribe(**patterns)
- def encode(self, value):
- """
- Encode the value so that it's identical to what we'll
- read off the connection
- """
- if self.decode_responses and isinstance(value, bytes):
- value = value.decode(self.encoding, self.encoding_errors)
- elif not self.decode_responses and isinstance(value, unicode):
- value = value.encode(self.encoding, self.encoding_errors)
- return value
-
@property
def subscribed(self):
"Indicates if there are subscriptions to any channels or patterns"
@@ -2410,6 +2429,16 @@ class PubSub(object):
return None
return self._execute(connection, connection.read_response)
+ def _normalize_keys(self, data):
+ """
+ normalize channel/pattern names to be either bytes or strings
+ based on whether responses are automatically decoded. this saves us
+ from coercing the value for each message coming in.
+ """
+ encode = self.encoder.encode
+ decode = self.encoder.decode
+ return dict([(decode(encode(k)), v) for k, v in iteritems(data)])
+
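Note: the encode/decode round-trip in ``_normalize_keys`` leaves channel and pattern keys in the same form (bytes or str) that incoming messages will use, so handler lookups need no per-message coercion. The encoder below is a stand-in for illustration only (Python 3 semantics for brevity), not the real connection-pool encoder:

    class FakeEncoder(object):
        "stand-in with the encode/decode pair this code relies on"
        def __init__(self, decode_responses):
            self.decode_responses = decode_responses
        def encode(self, value):
            return value.encode('utf-8') if isinstance(value, str) else value
        def decode(self, value, force=False):
            if (self.decode_responses or force) and isinstance(value, bytes):
                value = value.decode('utf-8')
            return value

    enc = FakeEncoder(decode_responses=False)
    assert enc.decode(enc.encode('news')) == b'news'   # keys stay bytes
    enc = FakeEncoder(decode_responses=True)
    assert enc.decode(enc.encode('news')) == 'news'    # keys become str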
def psubscribe(self, *args, **kwargs):
"""
Subscribe to channel patterns. Patterns supplied as keyword arguments
@@ -2420,15 +2449,13 @@ class PubSub(object):
"""
if args:
args = list_or_args(args[0], args[1:])
- new_patterns = {}
- new_patterns.update(dict.fromkeys(imap(self.encode, args)))
- for pattern, handler in iteritems(kwargs):
- new_patterns[self.encode(pattern)] = handler
+ new_patterns = dict.fromkeys(args)
+ new_patterns.update(kwargs)
ret_val = self.execute_command('PSUBSCRIBE', *iterkeys(new_patterns))
# update the patterns dict AFTER we send the command. we don't want to
# subscribe twice to these patterns, once for the command and again
# for the reconnection.
- self.patterns.update(new_patterns)
+ self.patterns.update(self._normalize_keys(new_patterns))
return ret_val
def punsubscribe(self, *args):
@@ -2450,15 +2477,13 @@ class PubSub(object):
"""
if args:
args = list_or_args(args[0], args[1:])
- new_channels = {}
- new_channels.update(dict.fromkeys(imap(self.encode, args)))
- for channel, handler in iteritems(kwargs):
- new_channels[self.encode(channel)] = handler
+ new_channels = dict.fromkeys(args)
+ new_channels.update(kwargs)
ret_val = self.execute_command('SUBSCRIBE', *iterkeys(new_channels))
# update the channels dict AFTER we send the command. we don't want to
# subscribe twice to these channels, once for the command and again
# for the reconnection.
- self.channels.update(new_channels)
+ self.channels.update(self._normalize_keys(new_channels))
return ret_val
def unsubscribe(self, *args):
@@ -2824,12 +2849,11 @@ class BasePipeline(object):
shas = [s.sha for s in scripts]
# we can't use the normal script_* methods because they would just
# get buffered in the pipeline.
- exists = immediate('SCRIPT', 'EXISTS', *shas, **{'parse': 'EXISTS'})
+ exists = immediate('SCRIPT EXISTS', *shas)
if not all(exists):
for s, exist in izip(scripts, exists):
if not exist:
- s.sha = immediate('SCRIPT', 'LOAD', s.script,
- **{'parse': 'LOAD'})
+ s.sha = immediate('SCRIPT LOAD', s.script)
def execute(self, raise_on_error=True):
"Execute all the commands in the current pipeline"
@@ -2881,16 +2905,6 @@ class BasePipeline(object):
"Unwatches all previously specified keys"
return self.watching and self.execute_command('UNWATCH') or True
- def script_load_for_pipeline(self, script):
- "Make sure scripts are loaded prior to pipeline execution"
- # we need the sha now so that Script.__call__ can use it to run
- # evalsha.
- if not script.sha:
- script.sha = self.immediate_execute_command('SCRIPT', 'LOAD',
- script.script,
- **{'parse': 'LOAD'})
- self.scripts.add(script)
-
class StrictPipeline(BasePipeline, StrictRedis):
"Pipeline for the StrictRedis class"
@@ -2908,7 +2922,14 @@ class Script(object):
def __init__(self, registered_client, script):
self.registered_client = registered_client
self.script = script
- self.sha = ''
+ # Precalculate and store the SHA1 hex digest of the script.
+
+ if isinstance(script, basestring):
+ # We need the encoding from the client in order to generate an
+ # accurate byte representation of the script
+ encoder = registered_client.connection_pool.get_encoder()
+ script = encoder.encode(script)
+ self.sha = hashlib.sha1(script).hexdigest()
def __call__(self, keys=[], args=[], client=None):
"Execute the script, passing any required ``args``"
@@ -2917,12 +2938,13 @@ class Script(object):
args = tuple(keys) + tuple(args)
# make sure the Redis server knows about the script
if isinstance(client, BasePipeline):
- # make sure this script is good to go on pipeline
- client.script_load_for_pipeline(self)
+ # Make sure the pipeline can register the script before executing.
+ client.scripts.add(self)
try:
return client.evalsha(self.sha, len(keys), *args)
except NoScriptError:
# Maybe the client is pointed to a different server than the client
# that created this instance?
+ # Overwrite the sha just in case there was a discrepancy.
self.sha = client.script_load(self.script)
return client.evalsha(self.sha, len(keys), *args)
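End to end, the Script changes are exercised through ``register_script``; a hedged usage sketch (client ``r`` assumed):

    multiply = r.register_script("return tonumber(ARGV[1]) * 2")
    assert multiply(args=[21]) == 42   # EVALSHA, loading the script on demand

    # in a pipeline, the script adds itself via client.scripts.add() and is
    # loaded by load_scripts() right before execution
    pipe = r.pipeline()
    multiply(args=[10], client=pipe)
    assert pipe.execute() == [20]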