diff options
author | Andy McCurdy <andy@andymccurdy.com> | 2018-11-14 21:48:05 -0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-11-14 21:48:05 -0800 |
commit | d70319312460d2016153a75fcc20c7623cc612e2 (patch) | |
tree | b8dcf3a66d99ee624c2eea1eb4ca185a8799cb32 /redis/client.py | |
parent | 1e62bef251d834e9a2830d608377449b6540f804 (diff) | |
parent | a32a01cf7b76ae0c8f711b639df6719d58393970 (diff) | |
download | redis-py-d70319312460d2016153a75fcc20c7623cc612e2.tar.gz |
Merge branch 'master' into fix/no-interruptederror-on-python-2.7
Diffstat (limited to 'redis/client.py')
-rwxr-xr-x | redis/client.py | 1157 |
1 files changed, 898 insertions, 259 deletions
diff --git a/redis/client.py b/redis/client.py index 79e94d0..0a46ce2 100755 --- a/redis/client.py +++ b/redis/client.py @@ -1,4 +1,4 @@ -from __future__ import with_statement +from __future__ import unicode_literals from itertools import chain import datetime import sys @@ -7,12 +7,11 @@ import time import threading import time as mod_time import hashlib -from redis._compat import (b, basestring, bytes, imap, iteritems, iterkeys, - itervalues, izip, long, nativestr, unicode, - safe_unicode) +from redis._compat import (basestring, bytes, imap, iteritems, iterkeys, + itervalues, izip, long, nativestr, safe_unicode) from redis.connection import (ConnectionPool, UnixDomainSocketConnection, SSLConnection, Token) -from redis.lock import Lock, LuaLock +from redis.lock import Lock from redis.exceptions import ( ConnectionError, DataError, @@ -25,17 +24,20 @@ from redis.exceptions import ( WatchError, ) -SYM_EMPTY = b('') +SYM_EMPTY = b'' +EMPTY_RESPONSE = 'EMPTY_RESPONSE' def list_or_args(keys, args): - # returns a single list combining keys and args + # returns a single new list combining keys and args try: iter(keys) # a string or bytes instance can be iterated, but indicates # keys wasn't passed as a list if isinstance(keys, (basestring, bytes)): keys = [keys] + else: + keys = list(keys) except TypeError: keys = [keys] if args: @@ -71,7 +73,7 @@ def parse_debug_object(response): # prefixed with a name response = nativestr(response) response = 'type:' + response - response = dict([kv.split(':') for kv in response.split()]) + response = dict(kv.split(':') for kv in response.split()) # parse some expected int values from the string response # note: this cmd isn't spec'd so these may not appear in all redis versions @@ -114,7 +116,8 @@ def parse_info(response): for line in response.splitlines(): if line and not line.startswith('#'): if line.find(':') != -1: - key, value = line.split(':', 1) + # support keys that include ':' by using rsplit + key, value = line.rsplit(':', 1) info[key] = get_value(value) else: # if the line isn't splittable, append it to the "__raw__" key @@ -182,10 +185,15 @@ def parse_sentinel_get_master(response): return response and (response[0], int(response[1])) or None -def pairs_to_dict(response): +def pairs_to_dict(response, decode_keys=False): "Create a dict given a list of key/value pairs" - it = iter(response) - return dict(izip(it, it)) + if decode_keys: + # the iter form is faster, but I don't know how to make that work + # with a nativestr() map + return dict(izip(imap(nativestr, response[::2]), response[1::2])) + else: + it = iter(response) + return dict(izip(it, it)) def pairs_to_dict_typed(response, type_info): @@ -195,7 +203,7 @@ def pairs_to_dict_typed(response, type_info): if key in type_info: try: value = type_info[key](value) - except: + except Exception: # if for some reason the value can't be coerced, just use # the string value pass @@ -220,7 +228,7 @@ def sort_return_tuples(response, **options): If ``groups`` is specified, return the response as a list of n-element tuples with n being the value found in options['groups'] """ - if not response or not options['groups']: + if not response or not options.get('groups'): return response n = options['groups'] return list(izip(*[response[i::n] for i in range(n)])) @@ -232,6 +240,58 @@ def int_or_none(response): return int(response) +def parse_stream_list(response): + if response is None: + return None + return [(r[0], pairs_to_dict(r[1])) for r in response] + + +def pairs_to_dict_with_nativestr_keys(response): + return pairs_to_dict(response, decode_keys=True) + + +def parse_list_of_dicts(response): + return list(imap(pairs_to_dict_with_nativestr_keys, response)) + + +def parse_xclaim(response, **options): + if options.get('parse_justid', False): + return response + return parse_stream_list(response) + + +def parse_xinfo_stream(response): + data = pairs_to_dict(response, decode_keys=True) + first = data['first-entry'] + data['first-entry'] = (first[0], pairs_to_dict(first[1])) + last = data['last-entry'] + data['last-entry'] = (last[0], pairs_to_dict(last[1])) + return data + + +def parse_xread(response): + if response is None: + return [] + return [[nativestr(r[0]), parse_stream_list(r[1])] for r in response] + + +def parse_xpending(response, **options): + if options.get('parse_detail', False): + return parse_xpending_range(response) + consumers = [{'name': n, 'pending': long(p)} for n, p in response[3] or []] + return { + 'pending': response[0], + 'min': response[1], + 'max': response[2], + 'consumers': consumers + } + + +def parse_xpending_range(response): + k = ('message_id', 'consumer', 'time_since_delivered', 'times_delivered') + return [dict(izip(k, r)) for r in response] + + def float_or_none(response): if response is None: return None @@ -242,11 +302,17 @@ def bool_ok(response): return nativestr(response) == 'OK' +def parse_zadd(response, **options): + if options.get('as_score'): + return float(response) + return int(response) + + def parse_client_list(response, **options): clients = [] for c in nativestr(response).splitlines(): # Values might contain '=' - clients.append(dict([pair.split('=', 1) for pair in c.split(' ')])) + clients.append(dict(pair.split('=', 1) for pair in c.split(' '))) return clients @@ -277,12 +343,13 @@ def parse_slowlog_get(response, **options): 'id': item[0], 'start_time': int(item[1]), 'duration': int(item[2]), - 'command': b(' ').join(item[3]) + 'command': b' '.join(item[3]) } for item in response] def parse_cluster_info(response, **options): - return dict([line.split(':') for line in response.splitlines() if line]) + response = nativestr(response) + return dict(line.split(':') for line in response.splitlines() if line) def _parse_node_line(line): @@ -304,10 +371,11 @@ def _parse_node_line(line): def parse_cluster_nodes(response, **options): + response = nativestr(response) raw_lines = response if isinstance(response, basestring): raw_lines = response.splitlines() - return dict([_parse_node_line(line) for line in raw_lines]) + return dict(_parse_node_line(line) for line in raw_lines) def parse_georadius_generic(response, **options): @@ -345,7 +413,7 @@ def parse_pubsub_numsub(response, **options): return list(zip(response[0::2], response[1::2])) -class StrictRedis(object): +class Redis(object): """ Implementation of the Redis protocol. @@ -357,20 +425,20 @@ class StrictRedis(object): """ RESPONSE_CALLBACKS = dict_merge( string_keys_to_dict( - 'AUTH EXISTS EXPIRE EXPIREAT HEXISTS HMSET MOVE MSETNX PERSIST ' + 'AUTH EXPIRE EXPIREAT HEXISTS HMSET MOVE MSETNX PERSIST ' 'PSETEX RENAMENX SISMEMBER SMOVE SETEX SETNX', bool ), string_keys_to_dict( - 'BITCOUNT BITPOS DECRBY DEL GETBIT HDEL HLEN HSTRLEN INCRBY ' - 'LINSERT LLEN LPUSHX PFADD PFCOUNT RPUSHX SADD SCARD SDIFFSTORE ' - 'SETBIT SETRANGE SINTERSTORE SREM STRLEN SUNIONSTORE ZADD ZCARD ' - 'ZLEXCOUNT ZREM ZREMRANGEBYLEX ZREMRANGEBYRANK ZREMRANGEBYSCORE ' - 'GEOADD', + 'BITCOUNT BITPOS DECRBY DEL EXISTS GEOADD GETBIT HDEL HLEN ' + 'HSTRLEN INCRBY LINSERT LLEN LPUSHX PFADD PFCOUNT RPUSHX SADD ' + 'SCARD SDIFFSTORE SETBIT SETRANGE SINTERSTORE SREM STRLEN ' + 'SUNIONSTORE UNLINK XACK XDEL XLEN XTRIM ZCARD ZLEXCOUNT ZREM ' + 'ZREMRANGEBYLEX ZREMRANGEBYRANK ZREMRANGEBYSCORE', int ), string_keys_to_dict( - 'INCRBYFLOAT HINCRBYFLOAT GEODIST', + 'INCRBYFLOAT HINCRBYFLOAT', float ), string_keys_to_dict( @@ -379,10 +447,10 @@ class StrictRedis(object): lambda r: isinstance(r, (long, int)) and r or nativestr(r) == 'OK' ), string_keys_to_dict('SORT', sort_return_tuples), - string_keys_to_dict('ZSCORE ZINCRBY', float_or_none), + string_keys_to_dict('ZSCORE ZINCRBY GEODIST', float_or_none), string_keys_to_dict( 'FLUSHALL FLUSHDB LSET LTRIM MSET PFMERGE RENAME ' - 'SAVE SELECT SHUTDOWN SLAVEOF WATCH UNWATCH', + 'SAVE SELECT SHUTDOWN SLAVEOF SWAPDB WATCH UNWATCH', bool_ok ), string_keys_to_dict('BLPOP BRPOP', lambda r: r and tuple(r) or None), @@ -391,26 +459,58 @@ class StrictRedis(object): lambda r: r and set(r) or set() ), string_keys_to_dict( - 'ZRANGE ZRANGEBYSCORE ZREVRANGE ZREVRANGEBYSCORE', + 'ZPOPMAX ZPOPMIN ZRANGE ZRANGEBYSCORE ZREVRANGE ZREVRANGEBYSCORE', zset_score_pairs ), + string_keys_to_dict('BZPOPMIN BZPOPMAX', \ + lambda r: r and (r[0], r[1], float(r[2])) or None), string_keys_to_dict('ZRANK ZREVRANK', int_or_none), + string_keys_to_dict('XREVRANGE XRANGE', parse_stream_list), + string_keys_to_dict('XREAD XREADGROUP', parse_xread), string_keys_to_dict('BGREWRITEAOF BGSAVE', lambda r: True), { 'CLIENT GETNAME': lambda r: r and nativestr(r), + 'CLIENT ID': int, 'CLIENT KILL': bool_ok, 'CLIENT LIST': parse_client_list, 'CLIENT SETNAME': bool_ok, + 'CLIENT UNBLOCK': lambda r: r and int(r) == 1 or False, + 'CLIENT PAUSE': bool_ok, + 'CLUSTER ADDSLOTS': bool_ok, + 'CLUSTER COUNT-FAILURE-REPORTS': lambda x: int(x), + 'CLUSTER COUNTKEYSINSLOT': lambda x: int(x), + 'CLUSTER DELSLOTS': bool_ok, + 'CLUSTER FAILOVER': bool_ok, + 'CLUSTER FORGET': bool_ok, + 'CLUSTER INFO': parse_cluster_info, + 'CLUSTER KEYSLOT': lambda x: int(x), + 'CLUSTER MEET': bool_ok, + 'CLUSTER NODES': parse_cluster_nodes, + 'CLUSTER REPLICATE': bool_ok, + 'CLUSTER RESET': bool_ok, + 'CLUSTER SAVECONFIG': bool_ok, + 'CLUSTER SET-CONFIG-EPOCH': bool_ok, + 'CLUSTER SETSLOT': bool_ok, + 'CLUSTER SLAVES': parse_cluster_nodes, 'CONFIG GET': parse_config_get, 'CONFIG RESETSTAT': bool_ok, 'CONFIG SET': bool_ok, 'DEBUG OBJECT': parse_debug_object, + 'GEOHASH': lambda r: list(map(nativestr, r)), + 'GEOPOS': lambda r: list(map(lambda ll: (float(ll[0]), + float(ll[1])) + if ll is not None else None, r)), + 'GEORADIUS': parse_georadius_generic, + 'GEORADIUSBYMEMBER': parse_georadius_generic, 'HGETALL': lambda r: r and pairs_to_dict(r) or {}, 'HSCAN': parse_hscan, 'INFO': parse_info, 'LASTSAVE': timestamp_to_datetime, + 'MEMORY PURGE': bool_ok, + 'MEMORY USAGE': int_or_none, 'OBJECT': parse_object, 'PING': lambda r: nativestr(r) == 'PONG', + 'PUBSUB NUMSUB': parse_pubsub_numsub, 'RANDOMKEY': lambda r: r and r or None, 'SCAN': parse_scan, 'SCRIPT EXISTS': lambda r: list(imap(bool, r)), @@ -431,46 +531,41 @@ class StrictRedis(object): 'SLOWLOG RESET': bool_ok, 'SSCAN': parse_scan, 'TIME': lambda x: (int(x[0]), int(x[1])), + 'XCLAIM': parse_xclaim, + 'XGROUP CREATE': bool_ok, + 'XGROUP DELCONSUMER': int, + 'XGROUP DESTROY': bool, + 'XGROUP SETID': bool_ok, + 'XINFO CONSUMERS': parse_list_of_dicts, + 'XINFO GROUPS': parse_list_of_dicts, + 'XINFO STREAM': parse_xinfo_stream, + 'XPENDING': parse_xpending, + 'ZADD': parse_zadd, 'ZSCAN': parse_zscan, - 'CLUSTER ADDSLOTS': bool_ok, - 'CLUSTER COUNT-FAILURE-REPORTS': lambda x: int(x), - 'CLUSTER COUNTKEYSINSLOT': lambda x: int(x), - 'CLUSTER DELSLOTS': bool_ok, - 'CLUSTER FAILOVER': bool_ok, - 'CLUSTER FORGET': bool_ok, - 'CLUSTER INFO': parse_cluster_info, - 'CLUSTER KEYSLOT': lambda x: int(x), - 'CLUSTER MEET': bool_ok, - 'CLUSTER NODES': parse_cluster_nodes, - 'CLUSTER REPLICATE': bool_ok, - 'CLUSTER RESET': bool_ok, - 'CLUSTER SAVECONFIG': bool_ok, - 'CLUSTER SET-CONFIG-EPOCH': bool_ok, - 'CLUSTER SETSLOT': bool_ok, - 'CLUSTER SLAVES': parse_cluster_nodes, - 'GEOPOS': lambda r: list(map(lambda ll: (float(ll[0]), - float(ll[1])) - if ll is not None else None, r)), - 'GEOHASH': lambda r: list(map(nativestr, r)), - 'GEORADIUS': parse_georadius_generic, - 'GEORADIUSBYMEMBER': parse_georadius_generic, - 'PUBSUB NUMSUB': parse_pubsub_numsub, } ) @classmethod def from_url(cls, url, db=None, **kwargs): """ - Return a Redis client object configured from the given URL, which must - use either `the ``redis://`` scheme - <http://www.iana.org/assignments/uri-schemes/prov/redis>`_ for RESP - connections or the ``unix://`` scheme for Unix domain sockets. + Return a Redis client object configured from the given URL For example:: redis://[:password]@localhost:6379/0 + rediss://[:password]@localhost:6379/0 unix://[:password]@/path/to/socket.sock?db=0 + Three URL schemes are supported: + + - ```redis://`` + <http://www.iana.org/assignments/uri-schemes/prov/redis>`_ creates a + normal TCP socket connection + - ```rediss://`` + <http://www.iana.org/assignments/uri-schemes/prov/rediss>`_ creates a + SSL wrapped TCP socket connection + - ``unix://`` creates a Unix Domain Socket connection + There are several ways to specify a database number. The parse function will return the first specified option: 1. A ``db`` querystring option, e.g. redis://localhost?db=0 @@ -496,7 +591,7 @@ class StrictRedis(object): charset=None, errors=None, decode_responses=False, retry_on_timeout=False, ssl=False, ssl_keyfile=None, ssl_certfile=None, - ssl_cert_reqs=None, ssl_ca_certs=None, + ssl_cert_reqs='required', ssl_ca_certs=None, max_connections=None): if not connection_pool: if charset is not None: @@ -544,7 +639,6 @@ class StrictRedis(object): }) connection_pool = ConnectionPool(**kwargs) self.connection_pool = connection_pool - self._use_lua_lock = None self.response_callbacks = self.__class__.RESPONSE_CALLBACKS.copy() @@ -563,7 +657,7 @@ class StrictRedis(object): atomic, pipelines are useful for reducing the back-and-forth overhead between the client and server. """ - return StrictPipeline( + return Pipeline( self.connection_pool, self.response_callbacks, transaction, @@ -579,7 +673,7 @@ class StrictRedis(object): value_from_callable = kwargs.pop('value_from_callable', False) watch_delay = kwargs.pop('watch_delay', None) with self.pipeline(True, shard_hint) as pipe: - while 1: + while True: try: if watches: pipe.watch(*watches) @@ -637,15 +731,7 @@ class StrictRedis(object): is that these cases aren't common and as such default to using thread local storage. """ if lock_class is None: - if self._use_lua_lock is None: - # the first time .lock() is called, determine if we can use - # Lua by attempting to register the necessary scripts - try: - LuaLock.register_scripts(self) - self._use_lua_lock = True - except ResponseError: - self._use_lua_lock = False - lock_class = self._use_lua_lock and LuaLock or Lock + lock_class = Lock return lock_class(self, name, timeout=timeout, sleep=sleep, blocking_timeout=blocking_timeout, thread_local=thread_local) @@ -678,7 +764,12 @@ class StrictRedis(object): def parse_response(self, connection, command_name, **options): "Parses a response from the Redis server" - response = connection.read_response() + try: + response = connection.read_response() + except ResponseError: + if EMPTY_RESPONSE in options: + return options[EMPTY_RESPONSE] + raise if command_name in self.response_callbacks: return self.response_callbacks[command_name](response, **options) return response @@ -699,18 +790,56 @@ class StrictRedis(object): "Disconnects the client at ``address`` (ip:port)" return self.execute_command('CLIENT KILL', address) - def client_list(self): + def client_list(self, _type=None): + """ + Returns a list of currently connected clients. + If type of client specified, only that type will be returned. + :param _type: optional. one of the client types (normal, master, + replica, pubsub) + """ "Returns a list of currently connected clients" + if _type is not None: + client_types = ('normal', 'master', 'replica', 'pubsub') + if str(_type).lower() not in client_types: + raise DataError("CLIENT LIST _type must be one of %r" % ( + client_types,)) + return self.execute_command('CLIENT LIST', Token.get_token('TYPE'), + _type) return self.execute_command('CLIENT LIST') def client_getname(self): "Returns the current connection name" return self.execute_command('CLIENT GETNAME') + def client_id(self): + "Returns the current connection id" + return self.execute_command('CLIENT ID') + def client_setname(self, name): "Sets the current connection name" return self.execute_command('CLIENT SETNAME', name) + def client_unblock(self, client_id, error=False): + """ + Unblocks a connection by its client id. + If ``error`` is True, unblocks the client with a special error message. + If ``error`` is False (default), the client is unblocked using the + regular timeout mechanism. + """ + args = ['CLIENT UNBLOCK', int(client_id)] + if error: + args.append(Token.get_token('ERROR')) + return self.execute_command(*args) + + def client_pause(self, timeout): + """ + Suspend all the Redis clients for the specified amount of time + :param timeout: milliseconds to pause clients + """ + if not isinstance(timeout, (int, long)): + raise DataError("CLIENT PAUSE timeout must be an integer") + return self.execute_command('CLIENT PAUSE', str(timeout)) + def config_get(self, pattern="*"): "Return a dictionary of configuration based on the ``pattern``" return self.execute_command('CONFIG GET', pattern) @@ -739,13 +868,33 @@ class StrictRedis(object): "Echo the string back from the server" return self.execute_command('ECHO', value) - def flushall(self): - "Delete all keys in all databases on the current host" - return self.execute_command('FLUSHALL') + def flushall(self, asynchronous=False): + """ + Delete all keys in all databases on the current host. + + ``asynchronous`` indicates whether the operation is + executed asynchronously by the server. + """ + args = [] + if not asynchronous: + args.append(Token.get_token('ASYNC')) + return self.execute_command('FLUSHALL', *args) + + def flushdb(self, asynchronous=False): + """ + Delete all keys in the current database. - def flushdb(self): - "Delete all keys in the current database" - return self.execute_command('FLUSHDB') + ``asynchronous`` indicates whether the operation is + executed asynchronously by the server. + """ + args = [] + if not asynchronous: + args.append(Token.get_token('ASYNC')) + return self.execute_command('FLUSHDB', *args) + + def swapdb(self, first, second): + "Swap two databases" + return self.execute_command('SWAPDB', first, second) def info(self, section=None): """ @@ -769,10 +918,63 @@ class StrictRedis(object): """ return self.execute_command('LASTSAVE') + def migrate(self, host, port, keys, destination_db, timeout, + copy=False, replace=False, auth=None): + """ + Migrate 1 or more keys from the current Redis server to a different + server specified by the ``host``, ``port`` and ``destination_db``. + + The ``timeout``, specified in milliseconds, indicates the maximum + time the connection between the two servers can be idle before the + command is interrupted. + + If ``copy`` is True, the specified ``keys`` are NOT deleted from + the source server. + + If ``replace`` is True, this operation will overwrite the keys + on the destination server if they exist. + + If ``auth`` is specified, authenticate to the destination server with + the password provided. + """ + keys = list_or_args(keys, []) + if not keys: + raise DataError('MIGRATE requires at least one key') + pieces = [] + if copy: + pieces.append(Token.get_token('COPY')) + if replace: + pieces.append(Token.get_token('REPLACE')) + if auth: + pieces.append(Token.get_token('AUTH')) + pieces.append(auth) + pieces.append(Token.get_token('KEYS')) + pieces.extend(keys) + return self.execute_command('MIGRATE', host, port, '', destination_db, + timeout, *pieces) + def object(self, infotype, key): "Return the encoding, idletime, or refcount about the key" return self.execute_command('OBJECT', infotype, key, infotype=infotype) + def memory_usage(self, key, samples=None): + """ + Return the total memory usage for key, its value and associated + administrative overheads. + + For nested data structures, ``samples`` is the number of elements to + sample. If left unspecified, the server's default is 5. Use 0 to sample + all elements. + """ + args = [] + if isinstance(samples, int): + args.extend([Token.get_token('SAMPLES'), samples]) + return self.execute_command('MEMORY USAGE', key, *args) + + def memory_purge(self): + "Attempts to purge dirty pages for reclamation by allocator" + return self.execute_command('MEMORY PURGE') + def ping(self): "Ping the Redis server" return self.execute_command('PING') @@ -822,10 +1024,22 @@ class StrictRedis(object): "Returns a list of slaves for ``service_name``" return self.execute_command('SENTINEL SLAVES', service_name) - def shutdown(self): - "Shutdown the server" + def shutdown(self, save=False, nosave=False): + """Shutdown the Redis server. If Redis has persistence configured, + data will be flushed before shutdown. If the "save" option is set, + a data flush will be attempted even if there is no persistence + configured. If the "nosave" option is set, no data flush will be + attempted. The "save" and "nosave" options cannot both be set. + """ + if save and nosave: + raise DataError('SHUTDOWN save and nosave cannot both be set') + args = ['SHUTDOWN'] + if save: + args.append('SAVE') + if nosave: + args.append('NOSAVE') try: - self.execute_command('SHUTDOWN') + self.execute_command(*args) except ConnectionError: # a ConnectionError here is expected return @@ -896,9 +1110,16 @@ class StrictRedis(object): params.append(end) elif (start is not None and end is None) or \ (end is not None and start is None): - raise RedisError("Both start and end must be specified") + raise DataError("Both start and end must be specified") return self.execute_command('BITCOUNT', *params) + def bitfield(self, key, default_overflow=None): + """ + Return a BitFieldOperation instance to conveniently construct one or + more bitfield operations on ``key``. + """ + return BitFieldOperation(self, key, default_overflow=default_overflow) + def bitop(self, operation, dest, *keys): """ Perform a bitwise operation using ``operation`` between ``keys`` and @@ -914,7 +1135,7 @@ class StrictRedis(object): means to look at the first three bytes. """ if bit not in (0, 1): - raise RedisError('bit must be 0 or 1') + raise DataError('bit must be 0 or 1') params = [key, bit] start is not None and params.append(start) @@ -922,8 +1143,8 @@ class StrictRedis(object): if start is not None and end is not None: params.append(end) elif start is None and end is not None: - raise RedisError("start argument is not set, " - "when end is specified") + raise DataError("start argument is not set, " + "when end is specified") return self.execute_command('BITPOS', *params) def decr(self, name, amount=1): @@ -947,9 +1168,9 @@ class StrictRedis(object): """ return self.execute_command('DUMP', name) - def exists(self, name): - "Returns a boolean indicating whether key ``name`` exists" - return self.execute_command('EXISTS', name) + def exists(self, *names): + "Returns the number of ``names`` that exist" + return self.execute_command('EXISTS', *names) __contains__ = exists def expire(self, name, time): @@ -958,7 +1179,7 @@ class StrictRedis(object): can be represented by an integer or a Python timedelta object. """ if isinstance(time, datetime.timedelta): - time = time.seconds + time.days * 24 * 3600 + time = int(time.total_seconds()) return self.execute_command('EXPIRE', name, time) def expireat(self, name, when): @@ -1037,35 +1258,31 @@ class StrictRedis(object): Returns a list of values ordered identically to ``keys`` """ args = list_or_args(keys, args) - return self.execute_command('MGET', *args) + options = {} + if not args: + options[EMPTY_RESPONSE] = [] + return self.execute_command('MGET', *args, **options) - def mset(self, *args, **kwargs): + def mset(self, mapping): """ - Sets key/values based on a mapping. Mapping can be supplied as a single - dictionary argument or as kwargs. + Sets key/values based on a mapping. Mapping is a dictionary of + key/value pairs. Both keys and values should be strings or types that + can be cast to a string via str(). """ - if args: - if len(args) != 1 or not isinstance(args[0], dict): - raise RedisError('MSET requires **kwargs or a single dict arg') - kwargs.update(args[0]) items = [] - for pair in iteritems(kwargs): + for pair in iteritems(mapping): items.extend(pair) return self.execute_command('MSET', *items) - def msetnx(self, *args, **kwargs): + def msetnx(self, mapping): """ Sets key/values based on a mapping if none of the keys are already set. - Mapping can be supplied as a single dictionary argument or as kwargs. + Mapping is a dictionary of key/value pairs. Both keys and values + should be strings or types that can be cast to a string via str(). Returns a boolean indicating if the operation was successful. """ - if args: - if len(args) != 1 or not isinstance(args[0], dict): - raise RedisError('MSETNX requires **kwargs or a single ' - 'dict arg') - kwargs.update(args[0]) items = [] - for pair in iteritems(kwargs): + for pair in iteritems(mapping): items.extend(pair) return self.execute_command('MSETNX', *items) @@ -1084,8 +1301,7 @@ class StrictRedis(object): object. """ if isinstance(time, datetime.timedelta): - ms = int(time.microseconds / 1000) - time = (time.seconds + time.days * 24 * 3600) * 1000 + ms + time = int(time.total_seconds() * 1000) return self.execute_command('PEXPIRE', name, time) def pexpireat(self, name, when): @@ -1106,8 +1322,7 @@ class StrictRedis(object): timedelta object """ if isinstance(time_ms, datetime.timedelta): - ms = int(time_ms.microseconds / 1000) - time_ms = (time_ms.seconds + time_ms.days * 24 * 3600) * 1000 + ms + time_ms = int(time_ms.total_seconds() * 1000) return self.execute_command('PSETEX', name, time_ms, value) def pttl(self, name): @@ -1156,13 +1371,12 @@ class StrictRedis(object): if ex is not None: pieces.append('EX') if isinstance(ex, datetime.timedelta): - ex = ex.seconds + ex.days * 24 * 3600 + ex = int(ex.total_seconds()) pieces.append(ex) if px is not None: pieces.append('PX') if isinstance(px, datetime.timedelta): - ms = int(px.microseconds / 1000) - px = (px.seconds + px.days * 24 * 3600) * 1000 + ms + px = int(px.total_seconds() * 1000) pieces.append(px) if nx: @@ -1189,7 +1403,7 @@ class StrictRedis(object): timedelta object. """ if isinstance(time, datetime.timedelta): - time = time.seconds + time.days * 24 * 3600 + time = int(time.total_seconds()) return self.execute_command('SETEX', name, time, value) def setnx(self, name, value): @@ -1248,6 +1462,10 @@ class StrictRedis(object): warnings.warn( DeprecationWarning('Call UNWATCH from a Pipeline object')) + def unlink(self, *names): + "Unlink one or more keys specified by ``names``" + return self.execute_command('UNLINK', *names) + # LIST COMMANDS def blpop(self, keys, timeout=0): """ @@ -1262,10 +1480,7 @@ class StrictRedis(object): """ if timeout is None: timeout = 0 - if isinstance(keys, basestring): - keys = [keys] - else: - keys = list(keys) + keys = list_or_args(keys, None) keys.append(timeout) return self.execute_command('BLPOP', *keys) @@ -1282,10 +1497,7 @@ class StrictRedis(object): """ if timeout is None: timeout = 0 - if isinstance(keys, basestring): - keys = [keys] - else: - keys = list(keys) + keys = list_or_args(keys, None) keys.append(timeout) return self.execute_command('BRPOP', *keys) @@ -1420,7 +1632,7 @@ class StrictRedis(object): """ if (start is not None and num is None) or \ (num is not None and start is None): - raise RedisError("``start`` and ``num`` must both be specified") + raise DataError("``start`` and ``num`` must both be specified") pieces = [name] if by is not None: @@ -1435,7 +1647,7 @@ class StrictRedis(object): # Otherwise assume it's an interable and we want to get multiple # values. We can't just iterate blindly because strings are # iterable. - if isinstance(get, basestring): + if isinstance(get, (bytes, basestring)): pieces.append(Token.get_token('GET')) pieces.append(get) else: @@ -1451,7 +1663,7 @@ class StrictRedis(object): pieces.append(store) if groups: - if not get or isinstance(get, basestring) or len(get) < 2: + if not get or isinstance(get, (bytes, basestring)) or len(get) < 2: raise DataError('when using "groups" the "get" argument ' 'must be specified and contain at least ' 'two keys') @@ -1675,28 +1887,375 @@ class StrictRedis(object): args = list_or_args(keys, args) return self.execute_command('SUNIONSTORE', dest, *args) + # STREAMS COMMANDS + def xack(self, name, groupname, *ids): + """ + Acknowledges the successful processing of one or more messages. + name: name of the stream. + groupname: name of the consumer group. + *ids: message ids to acknowlege. + """ + return self.execute_command('XACK', name, groupname, *ids) + + def xadd(self, name, fields, id='*', maxlen=None, approximate=True): + """ + Add to a stream. + name: name of the stream + fields: dict of field/value pairs to insert into the stream + id: Location to insert this record. By default it is appended. + maxlen: truncate old stream members beyond this size + approximate: actual stream length may be slightly more than maxlen + + """ + pieces = [] + if maxlen is not None: + if not isinstance(maxlen, (int, long)) or maxlen < 1: + raise DataError('XADD maxlen must be a positive integer') + pieces.append(Token.get_token('MAXLEN')) + if approximate: + pieces.append(Token.get_token('~')) + pieces.append(str(maxlen)) + pieces.append(id) + if not isinstance(fields, dict) or len(fields) == 0: + raise DataError('XADD fields must be a non-empty dict') + for pair in iteritems(fields): + pieces.extend(pair) + return self.execute_command('XADD', name, *pieces) + + def xclaim(self, name, groupname, consumername, min_idle_time, message_ids, + idle=None, time=None, retrycount=None, force=False, + justid=False): + """ + Changes the ownership of a pending message. + name: name of the stream. + groupname: name of the consumer group. + consumername: name of a consumer that claims the message. + min_idle_time: filter messages that were idle less than this amount of + milliseconds + message_ids: non-empty list or tuple of message IDs to claim + idle: optional. Set the idle time (last time it was delivered) of the + message in ms + time: optional integer. This is the same as idle but instead of a + relative amount of milliseconds, it sets the idle time to a specific + Unix time (in milliseconds). + retrycount: optional integer. set the retry counter to the specified + value. This counter is incremented every time a message is delivered + again. + force: optional boolean, false by default. Creates the pending message + entry in the PEL even if certain specified IDs are not already in the + PEL assigned to a different client. + justid: optional boolean, false by default. Return just an array of IDs + of messages successfully claimed, without returning the actual message + """ + if not isinstance(min_idle_time, (int, long)) or min_idle_time < 0: + raise DataError("XCLAIM min_idle_time must be a non negative " + "integer") + if not isinstance(message_ids, (list, tuple)) or not message_ids: + raise DataError("XCLAIM message_ids must be a non empty list or " + "tuple of message IDs to claim") + + kwargs = {} + pieces = [name, groupname, consumername, str(min_idle_time)] + pieces.extend(list(message_ids)) + + if idle is not None: + if not isinstance(idle, (int, long)): + raise DataError("XCLAIM idle must be an integer") + pieces.extend((Token.get_token('IDLE'), str(idle))) + if time is not None: + if not isinstance(time, (int, long)): + raise DataError("XCLAIM time must be an integer") + pieces.extend((Token.get_token('TIME'), str(time))) + if retrycount is not None: + if not isinstance(retrycount, (int, long)): + raise DataError("XCLAIM retrycount must be an integer") + pieces.extend((Token.get_token('RETRYCOUNT'), str(retrycount))) + + if force: + if not isinstance(force, bool): + raise DataError("XCLAIM force must be a boolean") + pieces.append(Token.get_token('FORCE')) + if justid: + if not isinstance(justid, bool): + raise DataError("XCLAIM justid must be a boolean") + pieces.append(Token.get_token('JUSTID')) + kwargs['parse_justid'] = True + return self.execute_command('XCLAIM', *pieces, **kwargs) + + def xdel(self, name, *ids): + """ + Deletes one or more messages from a stream. + name: name of the stream. + *ids: message ids to delete. + """ + return self.execute_command('XDEL', name, *ids) + + def xgroup_create(self, name, groupname, id='$', mkstream=False): + """ + Create a new consumer group associated with a stream. + name: name of the stream. + groupname: name of the consumer group. + id: ID of the last item in the stream to consider already delivered. + """ + pieces = ['XGROUP CREATE', name, groupname, id] + if mkstream: + pieces.append(Token.get_token('MKSTREAM')) + return self.execute_command(*pieces) + + def xgroup_delconsumer(self, name, groupname, consumername): + """ + Remove a specific consumer from a consumer group. + Returns the number of pending messages that the consumer had before it + was deleted. + name: name of the stream. + groupname: name of the consumer group. + consumername: name of consumer to delete + """ + return self.execute_command('XGROUP DELCONSUMER', name, groupname, + consumername) + + def xgroup_destroy(self, name, groupname): + """ + Destroy a consumer group. + name: name of the stream. + groupname: name of the consumer group. + """ + return self.execute_command('XGROUP DESTROY', name, groupname) + + def xgroup_setid(self, name, groupname, id): + """ + Set the consumer group last delivered ID to something else. + name: name of the stream. + groupname: name of the consumer group. + id: ID of the last item in the stream to consider already delivered. + """ + return self.execute_command('XGROUP SETID', name, groupname, id) + + def xinfo_consumers(self, name, groupname): + """ + Returns general information about the consumers in the group. + name: name of the stream. + groupname: name of the consumer group. + """ + return self.execute_command('XINFO CONSUMERS', name, groupname) + + def xinfo_groups(self, name): + """ + Returns general information about the consumer groups of the stream. + name: name of the stream. + """ + return self.execute_command('XINFO GROUPS', name) + + def xinfo_stream(self, name): + """ + Returns general information about the stream. + name: name of the stream. + """ + return self.execute_command('XINFO STREAM', name) + + def xlen(self, name): + """ + Returns the number of elements in a given stream. + """ + return self.execute_command('XLEN', name) + + def xpending(self, name, groupname): + """ + Returns information about pending messages of a group. + name: name of the stream. + groupname: name of the consumer group. + """ + return self.execute_command('XPENDING', name, groupname) + + def xpending_range(self, name, groupname, min='-', max='+', count=-1, + consumername=None): + """ + Returns information about pending messages, in a range. + name: name of the stream. + groupname: name of the consumer group. + start: first stream ID. defaults to '-', + meaning the earliest available. + finish: last stream ID. defaults to '+', + meaning the latest available. + count: if set, only return this many items, beginning with the + earliest available. + consumername: name of a consumer to filter by (optional). + """ + pieces = [name, groupname] + if min is not None or max is not None or count is not None: + if min is None or max is None or count is None: + raise DataError("XPENDING must be provided with min, max " + "and count parameters, or none of them. ") + if not isinstance(count, (int, long)) or count < -1: + raise DataError("XPENDING count must be a integer >= -1") + pieces.extend((min, max, str(count))) + if consumername is not None: + if min is None or max is None or count is None: + raise DataError("if XPENDING is provided with consumername," + " it must be provided with min, max and" + " count parameters") + pieces.append(consumername) + return self.execute_command('XPENDING', *pieces, parse_detail=True) + + def xrange(self, name, min='-', max='+', count=None): + """ + Read stream values within an interval. + name: name of the stream. + start: first stream ID. defaults to '-', + meaning the earliest available. + finish: last stream ID. defaults to '+', + meaning the latest available. + count: if set, only return this many items, beginning with the + earliest available. + """ + pieces = [min, max] + if count is not None: + if not isinstance(count, (int, long)) or count < 1: + raise DataError('XRANGE count must be a positive integer') + pieces.append(Token.get_token('COUNT')) + pieces.append(str(count)) + + return self.execute_command('XRANGE', name, *pieces) + + def xread(self, streams, count=None, block=None): + """ + Block and monitor multiple streams for new data. + streams: a dict of stream names to stream IDs, where + IDs indicate the last ID already seen. + count: if set, only return this many items, beginning with the + earliest available. + block: number of milliseconds to wait, if nothing already present. + """ + pieces = [] + if block is not None: + if not isinstance(block, (int, long)) or block < 0: + raise DataError('XREAD block must be a non-negative integer') + pieces.append(Token.get_token('BLOCK')) + pieces.append(str(block)) + if count is not None: + if not isinstance(count, (int, long)) or count < 1: + raise DataError('XREAD count must be a positive integer') + pieces.append(Token.get_token('COUNT')) + pieces.append(str(count)) + if not isinstance(streams, dict) or len(streams) == 0: + raise DataError('XREAD streams must be a non empty dict') + pieces.append(Token.get_token('STREAMS')) + keys, values = izip(*iteritems(streams)) + pieces.extend(keys) + pieces.extend(values) + return self.execute_command('XREAD', *pieces) + + def xreadgroup(self, groupname, consumername, streams, count=None, + block=None): + """ + Read from a stream via a consumer group. + groupname: name of the consumer group. + consumername: name of the requesting consumer. + streams: a dict of stream names to stream IDs, where + IDs indicate the last ID already seen. + count: if set, only return this many items, beginning with the + earliest available. + block: number of milliseconds to wait, if nothing already present. + """ + pieces = [Token.get_token('GROUP'), groupname, consumername] + if count is not None: + if not isinstance(count, (int, long)) or count < 1: + raise DataError("XREADGROUP count must be a positive integer") + pieces.append(Token.get_token("COUNT")) + pieces.append(str(count)) + if block is not None: + if not isinstance(block, (int, long)) or block < 0: + raise DataError("XREADGROUP block must be a non-negative " + "integer") + pieces.append(Token.get_token("BLOCK")) + pieces.append(str(block)) + if not isinstance(streams, dict) or len(streams) == 0: + raise DataError('XREADGROUP streams must be a non empty dict') + pieces.append(Token.get_token('STREAMS')) + pieces.extend(streams.keys()) + pieces.extend(streams.values()) + return self.execute_command('XREADGROUP', *pieces) + + def xrevrange(self, name, max='+', min='-', count=None): + """ + Read stream values within an interval, in reverse order. + name: name of the stream + start: first stream ID. defaults to '+', + meaning the latest available. + finish: last stream ID. defaults to '-', + meaning the earliest available. + count: if set, only return this many items, beginning with the + latest available. + """ + pieces = [max, min] + if count is not None: + if not isinstance(count, (int, long)) or count < 1: + raise DataError('XREVRANGE count must be a positive integer') + pieces.append(Token.get_token('COUNT')) + pieces.append(str(count)) + + return self.execute_command('XREVRANGE', name, *pieces) + + def xtrim(self, name, maxlen, approximate=True): + """ + Trims old messages from a stream. + name: name of the stream. + maxlen: truncate old stream messages beyond this size + approximate: actual stream length may be slightly more than maxlen + """ + pieces = [Token.get_token('MAXLEN')] + if approximate: + pieces.append(Token.get_token('~')) + pieces.append(maxlen) + return self.execute_command('XTRIM', name, *pieces) + # SORTED SET COMMANDS - def zadd(self, name, *args, **kwargs): + def zadd(self, name, mapping, nx=False, xx=False, ch=False, incr=False): """ - Set any number of score, element-name pairs to the key ``name``. Pairs - can be specified in two ways: + Set any number of element-name, score pairs to the key ``name``. Pairs + are specified as a dict of element-names keys to score values. + + ``nx`` forces ZADD to only create new elements and not to update + scores for elements that already exist. + + ``xx`` forces ZADD to only update scores of elements that already + exist. New elements will not be added. - As *args, in the form of: score1, name1, score2, name2, ... - or as **kwargs, in the form of: name1=score1, name2=score2, ... + ``ch`` modifies the return value to be the numbers of elements changed. + Changed elements include new elements that were added and elements + whose scores changed. - The following example would add four values to the 'my-key' key: - redis.zadd('my-key', 1.1, 'name1', 2.2, 'name2', name3=3.3, name4=4.4) + ``incr`` modifies ZADD to behave like ZINCRBY. In this mode only a + single element/score pair can be specified and the score is the amount + the existing score will be incremented by. When using this mode the + return value of ZADD will be the new score of the element. + + The return value of ZADD varies based on the mode specified. With no + options, ZADD returns the number of new elements added to the sorted + set. """ + if not mapping: + raise DataError("ZADD requires at least one element/score pair") + if nx and xx: + raise DataError("ZADD allows either 'nx' or 'xx', not both") + if incr and len(mapping) != 1: + raise DataError("ZADD option 'incr' only works when passing a " + "single element/score pair") pieces = [] - if args: - if len(args) % 2 != 0: - raise RedisError("ZADD requires an equal number of " - "values and scores") - pieces.extend(args) - for pair in iteritems(kwargs): + options = {} + if nx: + pieces.append(Token.get_token('NX')) + if xx: + pieces.append(Token.get_token('XX')) + if ch: + pieces.append(Token.get_token('CH')) + if incr: + pieces.append(Token.get_token('INCR')) + options['as_score'] = True + for pair in iteritems(mapping): pieces.append(pair[1]) pieces.append(pair[0]) - return self.execute_command('ZADD', name, *pieces) + return self.execute_command('ZADD', name, *pieces, **options) def zcard(self, name): "Return the number of elements in the sorted set ``name``" @@ -1709,7 +2268,7 @@ class StrictRedis(object): """ return self.execute_command('ZCOUNT', name, min, max) - def zincrby(self, name, value, amount=1): + def zincrby(self, name, amount, value): "Increment the score of ``value`` in sorted set ``name`` by ``amount``" return self.execute_command('ZINCRBY', name, amount, value) @@ -1728,6 +2287,62 @@ class StrictRedis(object): """ return self.execute_command('ZLEXCOUNT', name, min, max) + def zpopmax(self, name, count=None): + """ + Remove and return up to ``count`` members with the highest scores + from the sorted set ``name``. + """ + args = (count is not None) and [count] or [] + options = { + 'withscores': True + } + return self.execute_command('ZPOPMAX', name, *args, **options) + + def zpopmin(self, name, count=None): + """ + Remove and return up to ``count`` members with the lowest scores + from the sorted set ``name``. + """ + args = (count is not None) and [count] or [] + options = { + 'withscores': True + } + return self.execute_command('ZPOPMIN', name, *args, **options) + + def bzpopmax(self, keys, timeout=0): + """ + ZPOPMAX a value off of the first non-empty sorted set + named in the ``keys`` list. + + If none of the sorted sets in ``keys`` has a value to ZPOPMAX, + then block for ``timeout`` seconds, or until a member gets added + to one of the sorted sets. + + If timeout is 0, then block indefinitely. + """ + if timeout is None: + timeout = 0 + keys = list_or_args(keys, None) + keys.append(timeout) + return self.execute_command('BZPOPMAX', *keys) + + def bzpopmin(self, keys, timeout=0): + """ + ZPOPMIN a value off of the first non-empty sorted set + named in the ``keys`` list. + + If none of the sorted sets in ``keys`` has a value to ZPOPMIN, + then block for ``timeout`` seconds, or until a member gets added + to one of the sorted sets. + + If timeout is 0, then block indefinitely. + """ + if timeout is None: + timeout = 0 + keys = list_or_args(keys, None) + keys.append(timeout) + return self.execute_command('BZPOPMIN', *keys) + def zrange(self, name, start, end, desc=False, withscores=False, score_cast_func=float): """ @@ -1765,7 +2380,7 @@ class StrictRedis(object): """ if (start is not None and num is None) or \ (num is not None and start is None): - raise RedisError("``start`` and ``num`` must both be specified") + raise DataError("``start`` and ``num`` must both be specified") pieces = ['ZRANGEBYLEX', name, min, max] if start is not None and num is not None: pieces.extend([Token.get_token('LIMIT'), start, num]) @@ -1781,7 +2396,7 @@ class StrictRedis(object): """ if (start is not None and num is None) or \ (num is not None and start is None): - raise RedisError("``start`` and ``num`` must both be specified") + raise DataError("``start`` and ``num`` must both be specified") pieces = ['ZREVRANGEBYLEX', name, max, min] if start is not None and num is not None: pieces.extend([Token.get_token('LIMIT'), start, num]) @@ -1803,7 +2418,7 @@ class StrictRedis(object): """ if (start is not None and num is None) or \ (num is not None and start is None): - raise RedisError("``start`` and ``num`` must both be specified") + raise DataError("``start`` and ``num`` must both be specified") pieces = ['ZRANGEBYSCORE', name, min, max] if start is not None and num is not None: pieces.extend([Token.get_token('LIMIT'), start, num]) @@ -1889,7 +2504,7 @@ class StrictRedis(object): """ if (start is not None and num is None) or \ (num is not None and start is None): - raise RedisError("``start`` and ``num`` must both be specified") + raise DataError("``start`` and ``num`` must both be specified") pieces = ['ZREVRANGEBYSCORE', name, max, min] if start is not None and num is not None: pieces.extend([Token.get_token('LIMIT'), start, num]) @@ -2118,8 +2733,8 @@ class StrictRedis(object): the triad longitude, latitude and name. """ if len(values) % 3 != 0: - raise RedisError("GEOADD requires places with lon, lat and name" - " values") + raise DataError("GEOADD requires places with lon, lat and name" + " values") return self.execute_command('GEOADD', name, *values) def geodist(self, name, place1, place2, unit=None): @@ -2131,7 +2746,7 @@ class StrictRedis(object): """ pieces = [name, place1, place2] if unit and unit not in ('m', 'km', 'mi', 'ft'): - raise RedisError("GEODIST invalid unit") + raise DataError("GEODIST invalid unit") elif unit: pieces.append(unit) return self.execute_command('GEODIST', *pieces) @@ -2139,14 +2754,14 @@ class StrictRedis(object): def geohash(self, name, *values): """ Return the geo hash string for each item of ``values`` members of - the specified key identified by the ``name``argument. + the specified key identified by the ``name`` argument. """ return self.execute_command('GEOHASH', name, *values) def geopos(self, name, *values): """ Return the positions of each item of ``values`` as members of - the specified key identified by the ``name``argument. Each position + the specified key identified by the ``name`` argument. Each position is represented by the pairs lon and lat. """ return self.execute_command('GEOPOS', name, *values) @@ -2208,7 +2823,7 @@ class StrictRedis(object): def _georadiusgeneric(self, command, *args, **kwargs): pieces = list(args) if kwargs['unit'] and kwargs['unit'] not in ('m', 'km', 'mi', 'ft'): - raise RedisError("GEORADIUS invalid unit") + raise DataError("GEORADIUS invalid unit") elif kwargs['unit']: pieces.append(kwargs['unit']) else: @@ -2222,13 +2837,13 @@ class StrictRedis(object): pieces.extend([Token('COUNT'), kwargs['count']]) if kwargs['sort'] and kwargs['sort'] not in ('ASC', 'DESC'): - raise RedisError("GEORADIUS invalid sort") + raise DataError("GEORADIUS invalid sort") elif kwargs['sort']: pieces.append(Token(kwargs['sort'])) if kwargs['store'] and kwargs['store_dist']: - raise RedisError("GEORADIUS store and store_dist cant be set" - " together") + raise DataError("GEORADIUS store and store_dist cant be set" + " together") if kwargs['store']: pieces.extend([Token('STORE'), kwargs['store']]) @@ -2239,88 +2854,7 @@ class StrictRedis(object): return self.execute_command(command, *pieces, **kwargs) -class Redis(StrictRedis): - """ - Provides backwards compatibility with older versions of redis-py that - changed arguments to some commands to be more Pythonic, sane, or by - accident. - """ - - # Overridden callbacks - RESPONSE_CALLBACKS = dict_merge( - StrictRedis.RESPONSE_CALLBACKS, - { - 'TTL': lambda r: r >= 0 and r or None, - 'PTTL': lambda r: r >= 0 and r or None, - } - ) - - def pipeline(self, transaction=True, shard_hint=None): - """ - Return a new pipeline object that can queue multiple commands for - later execution. ``transaction`` indicates whether all commands - should be executed atomically. Apart from making a group of operations - atomic, pipelines are useful for reducing the back-and-forth overhead - between the client and server. - """ - return Pipeline( - self.connection_pool, - self.response_callbacks, - transaction, - shard_hint) - - def setex(self, name, value, time): - """ - Set the value of key ``name`` to ``value`` that expires in ``time`` - seconds. ``time`` can be represented by an integer or a Python - timedelta object. - """ - if isinstance(time, datetime.timedelta): - time = time.seconds + time.days * 24 * 3600 - return self.execute_command('SETEX', name, time, value) - - def lrem(self, name, value, num=0): - """ - Remove the first ``num`` occurrences of elements equal to ``value`` - from the list stored at ``name``. - - The ``num`` argument influences the operation in the following ways: - num > 0: Remove elements equal to value moving from head to tail. - num < 0: Remove elements equal to value moving from tail to head. - num = 0: Remove all elements equal to value. - """ - return self.execute_command('LREM', name, num, value) - - def zadd(self, name, *args, **kwargs): - """ - NOTE: The order of arguments differs from that of the official ZADD - command. For backwards compatability, this method accepts arguments - in the form of name1, score1, name2, score2, while the official Redis - documents expects score1, name1, score2, name2. - - If you're looking to use the standard syntax, consider using the - StrictRedis class. See the API Reference section of the docs for more - information. - - Set any number of element-name, score pairs to the key ``name``. Pairs - can be specified in two ways: - - As *args, in the form of: name1, score1, name2, score2, ... - or as **kwargs, in the form of: name1=score1, name2=score2, ... - - The following example would add four values to the 'my-key' key: - redis.zadd('my-key', 'name1', 1.1, 'name2', 2.2, name3=3.3, name4=4.4) - """ - pieces = [] - if args: - if len(args) % 2 != 0: - raise RedisError("ZADD requires an equal number of " - "values and scores") - pieces.extend(reversed(args)) - for pair in iteritems(kwargs): - pieces.append(pair[1]) - pieces.append(pair[0]) - return self.execute_command('ZADD', name, *pieces) +StrictRedis = Redis class PubSub(object): @@ -2439,7 +2973,7 @@ class PubSub(object): """ encode = self.encoder.encode decode = self.encoder.decode - return dict([(decode(encode(k)), v) for k, v in iteritems(data)]) + return {decode(encode(k)): v for k, v in iteritems(data)} def psubscribe(self, *args, **kwargs): """ @@ -2517,6 +3051,13 @@ class PubSub(object): return self.handle_message(response, ignore_subscribe_messages) return None + def ping(self, message=None): + """ + Ping the Redis server + """ + message = '' if message is None else message + return self.execute_command('PING', message) + def handle_message(self, response, ignore_subscribe_messages=False): """ Parses a pub/sub message. If the channel or pattern was subscribed to @@ -2531,6 +3072,13 @@ class PubSub(object): 'channel': response[2], 'data': response[3] } + elif message_type == 'pong': + message = { + 'type': message_type, + 'pattern': None, + 'channel': None, + 'data': response[1] + } else: message = { 'type': message_type, @@ -2561,7 +3109,7 @@ class PubSub(object): if handler: handler(message) return None - else: + elif message_type != 'pong': # this is a subscribe/unsubscribe message. ignore if we don't # want them if ignore_subscribe_messages or self.ignore_subscribe_messages: @@ -2610,7 +3158,7 @@ class PubSubWorkerThread(threading.Thread): self.pubsub.punsubscribe() -class BasePipeline(object): +class Pipeline(Redis): """ Pipelines provide a way to transmit multiple commands to the Redis server in one transmission. This is convenient for batch processing, such as @@ -2629,7 +3177,7 @@ class BasePipeline(object): on a key of a different datatype. """ - UNWATCH_COMMANDS = set(('DISCARD', 'EXEC', 'UNWATCH')) + UNWATCH_COMMANDS = {'DISCARD', 'EXEC', 'UNWATCH'} def __init__(self, connection_pool, response_callbacks, transaction, shard_hint): @@ -2747,7 +3295,8 @@ class BasePipeline(object): def _execute_transaction(self, connection, commands, raise_on_error): cmds = chain([(('MULTI', ), {})], commands, [(('EXEC', ), {})]) - all_cmds = connection.pack_commands([args for args, _ in cmds]) + all_cmds = connection.pack_commands([args for args, options in cmds + if EMPTY_RESPONSE not in options]) connection.send_packed_command(all_cmds) errors = [] @@ -2762,12 +3311,15 @@ class BasePipeline(object): # and all the other commands for i, command in enumerate(commands): - try: - self.parse_response(connection, '_') - except ResponseError: - ex = sys.exc_info()[1] - self.annotate_exception(ex, i + 1, command[0]) - errors.append((i, ex)) + if EMPTY_RESPONSE in command[1]: + errors.append((i, command[1][EMPTY_RESPONSE])) + else: + try: + self.parse_response(connection, '_') + except ResponseError: + ex = sys.exc_info()[1] + self.annotate_exception(ex, i + 1, command[0]) + errors.append((i, ex)) # parse the EXEC. try: @@ -2830,13 +3382,13 @@ class BasePipeline(object): raise r def annotate_exception(self, exception, number, command): - cmd = safe_unicode(' ').join(imap(safe_unicode, command)) - msg = unicode('Command # %d (%s) of pipeline caused error: %s') % ( + cmd = ' '.join(imap(safe_unicode, command)) + msg = 'Command # %d (%s) of pipeline caused error: %s' % ( number, cmd, safe_unicode(exception.args[0])) exception.args = (msg,) + exception.args[1:] def parse_response(self, connection, command_name, **options): - result = StrictRedis.parse_response( + result = Redis.parse_response( self, connection, command_name, **options) if command_name in self.UNWATCH_COMMANDS: self.watching = False @@ -2908,16 +3460,6 @@ class BasePipeline(object): return self.watching and self.execute_command('UNWATCH') or True -class StrictPipeline(BasePipeline, StrictRedis): - "Pipeline for the StrictRedis class" - pass - - -class Pipeline(BasePipeline, Redis): - "Pipeline for the Redis class" - pass - - class Script(object): "An executable Lua script object returned by ``register_script``" @@ -2939,7 +3481,7 @@ class Script(object): client = self.registered_client args = tuple(keys) + tuple(args) # make sure the Redis server knows about the script - if isinstance(client, BasePipeline): + if isinstance(client, Pipeline): # Make sure the pipeline can register the script before executing. client.scripts.add(self) try: @@ -2950,3 +3492,100 @@ class Script(object): # Overwrite the sha just in case there was a discrepancy. self.sha = client.script_load(self.script) return client.evalsha(self.sha, len(keys), *args) + + +class BitFieldOperation(object): + """ + Command builder for BITFIELD commands. + """ + def __init__(self, client, key, default_overflow=None): + self.client = client + self.key = key + self._default_overflow = default_overflow + self.reset() + + def reset(self): + """ + Reset the state of the instance to when it was constructed + """ + self.operations = [] + self._last_overflow = 'WRAP' + self.overflow(self._default_overflow or self._last_overflow) + + def overflow(self, overflow): + """ + Update the overflow algorithm of successive INCRBY operations + :param overflow: Overflow algorithm, one of WRAP, SAT, FAIL. See the + Redis docs for descriptions of these algorithmsself. + :returns: a :py:class:`BitFieldOperation` instance. + """ + overflow = overflow.upper() + if overflow != self._last_overflow: + self._last_overflow = overflow + self.operations.append(('OVERFLOW', overflow)) + return self + + def incrby(self, fmt, offset, increment, overflow=None): + """ + Increment a bitfield by a given amount. + :param fmt: format-string for the bitfield being updated, e.g. 'u8' + for an unsigned 8-bit integer. + :param offset: offset (in number of bits). If prefixed with a + '#', this is an offset multiplier, e.g. given the arguments + fmt='u8', offset='#2', the offset will be 16. + :param int increment: value to increment the bitfield by. + :param str overflow: overflow algorithm. Defaults to WRAP, but other + acceptable values are SAT and FAIL. See the Redis docs for + descriptions of these algorithms. + :returns: a :py:class:`BitFieldOperation` instance. + """ + if overflow is not None: + self.overflow(overflow) + + self.operations.append(('INCRBY', fmt, offset, increment)) + return self + + def get(self, fmt, offset): + """ + Get the value of a given bitfield. + :param fmt: format-string for the bitfield being read, e.g. 'u8' for + an unsigned 8-bit integer. + :param offset: offset (in number of bits). If prefixed with a + '#', this is an offset multiplier, e.g. given the arguments + fmt='u8', offset='#2', the offset will be 16. + :returns: a :py:class:`BitFieldOperation` instance. + """ + self.operations.append(('GET', fmt, offset)) + return self + + def set(self, fmt, offset, value): + """ + Set the value of a given bitfield. + :param fmt: format-string for the bitfield being read, e.g. 'u8' for + an unsigned 8-bit integer. + :param offset: offset (in number of bits). If prefixed with a + '#', this is an offset multiplier, e.g. given the arguments + fmt='u8', offset='#2', the offset will be 16. + :param int value: value to set at the given position. + :returns: a :py:class:`BitFieldOperation` instance. + """ + self.operations.append(('SET', fmt, offset, value)) + return self + + @property + def command(self): + cmd = ['BITFIELD', self.key] + for ops in self.operations: + cmd.extend(ops) + return cmd + + def execute(self): + """ + Execute the operation(s) in a single BITFIELD command. The return value + is a list of values corresponding to each operation. If the client + used to create this instance was a pipeline, the list of values + will be present within the pipeline's execute. + """ + command = self.command + self.reset() + return self.client.execute_command(*command) |