Diffstat (limited to 'redis/client.py')

 redis/client.py | 1528 ++++++++++++++++++++++++++++++++++++++++----------
 1 file changed, 1224 insertions(+), 304 deletions(-)
diff --git a/redis/client.py b/redis/client.py index 3acfb9f..0383d14 100755 --- a/redis/client.py +++ b/redis/client.py @@ -1,4 +1,4 @@ -from __future__ import with_statement +from __future__ import unicode_literals from itertools import chain import datetime import sys @@ -6,12 +6,12 @@ import warnings import time import threading import time as mod_time -from redis._compat import (b, basestring, bytes, imap, iteritems, iterkeys, - itervalues, izip, long, nativestr, unicode, - safe_unicode) +import hashlib +from redis._compat import (basestring, bytes, imap, iteritems, iterkeys, + itervalues, izip, long, nativestr, safe_unicode) from redis.connection import (ConnectionPool, UnixDomainSocketConnection, SSLConnection, Token) -from redis.lock import Lock, LuaLock +from redis.lock import Lock from redis.exceptions import ( ConnectionError, DataError, @@ -24,17 +24,20 @@ from redis.exceptions import ( WatchError, ) -SYM_EMPTY = b('') +SYM_EMPTY = b'' +EMPTY_RESPONSE = 'EMPTY_RESPONSE' def list_or_args(keys, args): - # returns a single list combining keys and args + # returns a single new list combining keys and args try: iter(keys) # a string or bytes instance can be iterated, but indicates # keys wasn't passed as a list if isinstance(keys, (basestring, bytes)): keys = [keys] + else: + keys = list(keys) except TypeError: keys = [keys] if args: @@ -59,7 +62,8 @@ def string_keys_to_dict(key_string, callback): def dict_merge(*dicts): merged = {} - [merged.update(d) for d in dicts] + for d in dicts: + merged.update(d) return merged @@ -69,7 +73,7 @@ def parse_debug_object(response): # prefixed with a name response = nativestr(response) response = 'type:' + response - response = dict([kv.split(':') for kv in response.split()]) + response = dict(kv.split(':') for kv in response.split()) # parse some expected int values from the string response # note: this cmd isn't spec'd so these may not appear in all redis versions @@ -112,7 +116,8 @@ def parse_info(response): for line in response.splitlines(): if line and not line.startswith('#'): if line.find(':') != -1: - key, value = line.split(':', 1) + # support keys that include ':' by using rsplit + key, value = line.rsplit(':', 1) info[key] = get_value(value) else: # if the line isn't splittable, append it to the "__raw__" key @@ -180,10 +185,15 @@ def parse_sentinel_get_master(response): return response and (response[0], int(response[1])) or None -def pairs_to_dict(response): +def pairs_to_dict(response, decode_keys=False): "Create a dict given a list of key/value pairs" - it = iter(response) - return dict(izip(it, it)) + if decode_keys: + # the iter form is faster, but I don't know how to make that work + # with a nativestr() map + return dict(izip(imap(nativestr, response[::2]), response[1::2])) + else: + it = iter(response) + return dict(izip(it, it)) def pairs_to_dict_typed(response, type_info): @@ -193,7 +203,7 @@ def pairs_to_dict_typed(response, type_info): if key in type_info: try: value = type_info[key](value) - except: + except Exception: # if for some reason the value can't be coerced, just use # the string value pass @@ -206,7 +216,7 @@ def zset_score_pairs(response, **options): If ``withscores`` is specified in the options, return the response as a list of (value, score) pairs """ - if not response or not options['withscores']: + if not response or not options.get('withscores'): return response score_cast_func = options.get('score_cast_func', float) it = iter(response) @@ -218,7 +228,7 @@ def sort_return_tuples(response, **options): If 
``groups`` is specified, return the response as a list of n-element tuples with n being the value found in options['groups'] """ - if not response or not options['groups']: + if not response or not options.get('groups'): return response n = options['groups'] return list(izip(*[response[i::n] for i in range(n)])) @@ -230,6 +240,58 @@ def int_or_none(response): return int(response) +def parse_stream_list(response): + if response is None: + return None + return [(r[0], pairs_to_dict(r[1])) for r in response] + + +def pairs_to_dict_with_nativestr_keys(response): + return pairs_to_dict(response, decode_keys=True) + + +def parse_list_of_dicts(response): + return list(imap(pairs_to_dict_with_nativestr_keys, response)) + + +def parse_xclaim(response, **options): + if options.get('parse_justid', False): + return response + return parse_stream_list(response) + + +def parse_xinfo_stream(response): + data = pairs_to_dict(response, decode_keys=True) + first = data['first-entry'] + data['first-entry'] = (first[0], pairs_to_dict(first[1])) + last = data['last-entry'] + data['last-entry'] = (last[0], pairs_to_dict(last[1])) + return data + + +def parse_xread(response): + if response is None: + return [] + return [[nativestr(r[0]), parse_stream_list(r[1])] for r in response] + + +def parse_xpending(response, **options): + if options.get('parse_detail', False): + return parse_xpending_range(response) + consumers = [{'name': n, 'pending': long(p)} for n, p in response[3] or []] + return { + 'pending': response[0], + 'min': response[1], + 'max': response[2], + 'consumers': consumers + } + + +def parse_xpending_range(response): + k = ('message_id', 'consumer', 'time_since_delivered', 'times_delivered') + return [dict(izip(k, r)) for r in response] + + def float_or_none(response): if response is None: return None @@ -240,10 +302,17 @@ def bool_ok(response): return nativestr(response) == 'OK' +def parse_zadd(response, **options): + if options.get('as_score'): + return float(response) + return int(response) + + def parse_client_list(response, **options): clients = [] for c in nativestr(response).splitlines(): - clients.append(dict([pair.split('=') for pair in c.split(' ')])) + # Values might contain '=' + clients.append(dict(pair.split('=', 1) for pair in c.split(' '))) return clients @@ -274,11 +343,77 @@ def parse_slowlog_get(response, **options): 'id': item[0], 'start_time': int(item[1]), 'duration': int(item[2]), - 'command': b(' ').join(item[3]) + 'command': b' '.join(item[3]) } for item in response] -class StrictRedis(object): +def parse_cluster_info(response, **options): + response = nativestr(response) + return dict(line.split(':') for line in response.splitlines() if line) + + +def _parse_node_line(line): + line_items = line.split(' ') + node_id, addr, flags, master_id, ping, pong, epoch, \ + connected = line.split(' ')[:8] + slots = [sl.split('-') for sl in line_items[8:]] + node_dict = { + 'node_id': node_id, + 'flags': flags, + 'master_id': master_id, + 'last_ping_sent': ping, + 'last_pong_rcvd': pong, + 'epoch': epoch, + 'slots': slots, + 'connected': True if connected == 'connected' else False + } + return addr, node_dict + + +def parse_cluster_nodes(response, **options): + response = nativestr(response) + raw_lines = response + if isinstance(response, basestring): + raw_lines = response.splitlines() + return dict(_parse_node_line(line) for line in raw_lines) + + +def parse_georadius_generic(response, **options): + if options['store'] or options['store_dist']: + # `store` and `store_diff` cant be 
combined + # with other command arguments. + return response + + if type(response) != list: + response_list = [response] + else: + response_list = response + + if not options['withdist'] and not options['withcoord']\ + and not options['withhash']: + # just a bunch of places + return [nativestr(r) for r in response_list] + + cast = { + 'withdist': float, + 'withcoord': lambda ll: (float(ll[0]), float(ll[1])), + 'withhash': int + } + + # zip all output results with each casting functino to get + # the properly native Python value. + f = [nativestr] + f += [cast[o] for o in ['withdist', 'withhash', 'withcoord'] if options[o]] + return [ + list(map(lambda fv: fv[0](fv[1]), zip(f, r))) for r in response_list + ] + + +def parse_pubsub_numsub(response, **options): + return list(zip(response[0::2], response[1::2])) + + +class Redis(object): """ Implementation of the Redis protocol. @@ -290,28 +425,32 @@ class StrictRedis(object): """ RESPONSE_CALLBACKS = dict_merge( string_keys_to_dict( - 'AUTH EXISTS EXPIRE EXPIREAT HEXISTS HMSET MOVE MSETNX PERSIST ' + 'AUTH EXPIRE EXPIREAT HEXISTS HMSET MOVE MSETNX PERSIST ' 'PSETEX RENAMENX SISMEMBER SMOVE SETEX SETNX', bool ), string_keys_to_dict( - 'BITCOUNT BITPOS DECRBY DEL GETBIT HDEL HLEN INCRBY LINSERT LLEN ' - 'LPUSHX PFADD PFCOUNT RPUSHX SADD SCARD SDIFFSTORE SETBIT ' - 'SETRANGE SINTERSTORE SREM STRLEN SUNIONSTORE ZADD ZCARD ' - 'ZLEXCOUNT ZREM ZREMRANGEBYLEX ZREMRANGEBYRANK ZREMRANGEBYSCORE', + 'BITCOUNT BITPOS DECRBY DEL EXISTS GEOADD GETBIT HDEL HLEN ' + 'HSTRLEN INCRBY LINSERT LLEN LPUSHX PFADD PFCOUNT RPUSHX SADD ' + 'SCARD SDIFFSTORE SETBIT SETRANGE SINTERSTORE SREM STRLEN ' + 'SUNIONSTORE UNLINK XACK XDEL XLEN XTRIM ZCARD ZLEXCOUNT ZREM ' + 'ZREMRANGEBYLEX ZREMRANGEBYRANK ZREMRANGEBYSCORE', int ), - string_keys_to_dict('INCRBYFLOAT HINCRBYFLOAT', float), + string_keys_to_dict( + 'INCRBYFLOAT HINCRBYFLOAT', + float + ), string_keys_to_dict( # these return OK, or int if redis-server is >=1.3.4 'LPUSH RPUSH', - lambda r: isinstance(r, long) and r or nativestr(r) == 'OK' + lambda r: isinstance(r, (long, int)) and r or nativestr(r) == 'OK' ), string_keys_to_dict('SORT', sort_return_tuples), - string_keys_to_dict('ZSCORE ZINCRBY', float_or_none), + string_keys_to_dict('ZSCORE ZINCRBY GEODIST', float_or_none), string_keys_to_dict( 'FLUSHALL FLUSHDB LSET LTRIM MSET PFMERGE RENAME ' - 'SAVE SELECT SHUTDOWN SLAVEOF WATCH UNWATCH', + 'SAVE SELECT SHUTDOWN SLAVEOF SWAPDB WATCH UNWATCH', bool_ok ), string_keys_to_dict('BLPOP BRPOP', lambda r: r and tuple(r) or None), @@ -320,26 +459,58 @@ class StrictRedis(object): lambda r: r and set(r) or set() ), string_keys_to_dict( - 'ZRANGE ZRANGEBYSCORE ZREVRANGE ZREVRANGEBYSCORE', + 'ZPOPMAX ZPOPMIN ZRANGE ZRANGEBYSCORE ZREVRANGE ZREVRANGEBYSCORE', zset_score_pairs ), + string_keys_to_dict('BZPOPMIN BZPOPMAX', \ + lambda r: r and (r[0], r[1], float(r[2])) or None), string_keys_to_dict('ZRANK ZREVRANK', int_or_none), + string_keys_to_dict('XREVRANGE XRANGE', parse_stream_list), + string_keys_to_dict('XREAD XREADGROUP', parse_xread), string_keys_to_dict('BGREWRITEAOF BGSAVE', lambda r: True), { 'CLIENT GETNAME': lambda r: r and nativestr(r), + 'CLIENT ID': int, 'CLIENT KILL': bool_ok, 'CLIENT LIST': parse_client_list, 'CLIENT SETNAME': bool_ok, + 'CLIENT UNBLOCK': lambda r: r and int(r) == 1 or False, + 'CLIENT PAUSE': bool_ok, + 'CLUSTER ADDSLOTS': bool_ok, + 'CLUSTER COUNT-FAILURE-REPORTS': lambda x: int(x), + 'CLUSTER COUNTKEYSINSLOT': lambda x: int(x), + 'CLUSTER DELSLOTS': bool_ok, + 'CLUSTER FAILOVER': bool_ok, 
+ 'CLUSTER FORGET': bool_ok, + 'CLUSTER INFO': parse_cluster_info, + 'CLUSTER KEYSLOT': lambda x: int(x), + 'CLUSTER MEET': bool_ok, + 'CLUSTER NODES': parse_cluster_nodes, + 'CLUSTER REPLICATE': bool_ok, + 'CLUSTER RESET': bool_ok, + 'CLUSTER SAVECONFIG': bool_ok, + 'CLUSTER SET-CONFIG-EPOCH': bool_ok, + 'CLUSTER SETSLOT': bool_ok, + 'CLUSTER SLAVES': parse_cluster_nodes, 'CONFIG GET': parse_config_get, 'CONFIG RESETSTAT': bool_ok, 'CONFIG SET': bool_ok, 'DEBUG OBJECT': parse_debug_object, + 'GEOHASH': lambda r: list(map(nativestr, r)), + 'GEOPOS': lambda r: list(map(lambda ll: (float(ll[0]), + float(ll[1])) + if ll is not None else None, r)), + 'GEORADIUS': parse_georadius_generic, + 'GEORADIUSBYMEMBER': parse_georadius_generic, 'HGETALL': lambda r: r and pairs_to_dict(r) or {}, 'HSCAN': parse_hscan, 'INFO': parse_info, 'LASTSAVE': timestamp_to_datetime, + 'MEMORY PURGE': bool_ok, + 'MEMORY USAGE': int_or_none, 'OBJECT': parse_object, 'PING': lambda r: nativestr(r) == 'PONG', + 'PUBSUB NUMSUB': parse_pubsub_numsub, 'RANDOMKEY': lambda r: r and r or None, 'SCAN': parse_scan, 'SCRIPT EXISTS': lambda r: list(imap(bool, r)), @@ -360,20 +531,41 @@ class StrictRedis(object): 'SLOWLOG RESET': bool_ok, 'SSCAN': parse_scan, 'TIME': lambda x: (int(x[0]), int(x[1])), - 'ZSCAN': parse_zscan + 'XCLAIM': parse_xclaim, + 'XGROUP CREATE': bool_ok, + 'XGROUP DELCONSUMER': int, + 'XGROUP DESTROY': bool, + 'XGROUP SETID': bool_ok, + 'XINFO CONSUMERS': parse_list_of_dicts, + 'XINFO GROUPS': parse_list_of_dicts, + 'XINFO STREAM': parse_xinfo_stream, + 'XPENDING': parse_xpending, + 'ZADD': parse_zadd, + 'ZSCAN': parse_zscan, } ) @classmethod def from_url(cls, url, db=None, **kwargs): """ - Return a Redis client object configured from the given URL. + Return a Redis client object configured from the given URL For example:: redis://[:password]@localhost:6379/0 + rediss://[:password]@localhost:6379/0 unix://[:password]@/path/to/socket.sock?db=0 + Three URL schemes are supported: + + - ```redis://`` + <http://www.iana.org/assignments/uri-schemes/prov/redis>`_ creates a + normal TCP socket connection + - ```rediss://`` + <http://www.iana.org/assignments/uri-schemes/prov/rediss>`_ creates a + SSL wrapped TCP socket connection + - ``unix://`` creates a Unix Domain Socket connection + There are several ways to specify a database number. The parse function will return the first specified option: 1. A ``db`` querystring option, e.g. 
redis://localhost?db=0 @@ -399,7 +591,8 @@ class StrictRedis(object): charset=None, errors=None, decode_responses=False, retry_on_timeout=False, ssl=False, ssl_keyfile=None, ssl_certfile=None, - ssl_cert_reqs=None, ssl_ca_certs=None): + ssl_cert_reqs='required', ssl_ca_certs=None, + max_connections=None): if not connection_pool: if charset is not None: warnings.warn(DeprecationWarning( @@ -417,7 +610,8 @@ class StrictRedis(object): 'encoding': encoding, 'encoding_errors': encoding_errors, 'decode_responses': decode_responses, - 'retry_on_timeout': retry_on_timeout + 'retry_on_timeout': retry_on_timeout, + 'max_connections': max_connections } # based on input, setup appropriate connection args if unix_socket_path is not None: @@ -445,7 +639,6 @@ class StrictRedis(object): }) connection_pool = ConnectionPool(**kwargs) self.connection_pool = connection_pool - self._use_lua_lock = None self.response_callbacks = self.__class__.RESPONSE_CALLBACKS.copy() @@ -464,7 +657,7 @@ class StrictRedis(object): atomic, pipelines are useful for reducing the back-and-forth overhead between the client and server. """ - return StrictPipeline( + return Pipeline( self.connection_pool, self.response_callbacks, transaction, @@ -480,7 +673,7 @@ class StrictRedis(object): value_from_callable = kwargs.pop('value_from_callable', False) watch_delay = kwargs.pop('watch_delay', None) with self.pipeline(True, shard_hint) as pipe: - while 1: + while True: try: if watches: pipe.watch(*watches) @@ -538,15 +731,7 @@ class StrictRedis(object): is that these cases aren't common and as such default to using thread local storage. """ if lock_class is None: - if self._use_lua_lock is None: - # the first time .lock() is called, determine if we can use - # Lua by attempting to register the necessary scripts - try: - LuaLock.register_scripts(self) - self._use_lua_lock = True - except ResponseError: - self._use_lua_lock = False - lock_class = self._use_lua_lock and LuaLock or Lock + lock_class = Lock return lock_class(self, name, timeout=timeout, sleep=sleep, blocking_timeout=blocking_timeout, thread_local=thread_local) @@ -579,7 +764,12 @@ class StrictRedis(object): def parse_response(self, connection, command_name, **options): "Parses a response from the Redis server" - response = connection.read_response() + try: + response = connection.read_response() + except ResponseError: + if EMPTY_RESPONSE in options: + return options[EMPTY_RESPONSE] + raise if command_name in self.response_callbacks: return self.response_callbacks[command_name](response, **options) return response @@ -600,18 +790,56 @@ class StrictRedis(object): "Disconnects the client at ``address`` (ip:port)" return self.execute_command('CLIENT KILL', address) - def client_list(self): + def client_list(self, _type=None): + """ + Returns a list of currently connected clients. + If type of client specified, only that type will be returned. + :param _type: optional. 
one of the client types (normal, master, + replica, pubsub) + """ "Returns a list of currently connected clients" + if _type is not None: + client_types = ('normal', 'master', 'replica', 'pubsub') + if str(_type).lower() not in client_types: + raise DataError("CLIENT LIST _type must be one of %r" % ( + client_types,)) + return self.execute_command('CLIENT LIST', Token.get_token('TYPE'), + _type) return self.execute_command('CLIENT LIST') def client_getname(self): "Returns the current connection name" return self.execute_command('CLIENT GETNAME') + def client_id(self): + "Returns the current connection id" + return self.execute_command('CLIENT ID') + def client_setname(self, name): "Sets the current connection name" return self.execute_command('CLIENT SETNAME', name) + def client_unblock(self, client_id, error=False): + """ + Unblocks a connection by its client id. + If ``error`` is True, unblocks the client with a special error message. + If ``error`` is False (default), the client is unblocked using the + regular timeout mechanism. + """ + args = ['CLIENT UNBLOCK', int(client_id)] + if error: + args.append(Token.get_token('ERROR')) + return self.execute_command(*args) + + def client_pause(self, timeout): + """ + Suspend all the Redis clients for the specified amount of time + :param timeout: milliseconds to pause clients + """ + if not isinstance(timeout, (int, long)): + raise DataError("CLIENT PAUSE timeout must be an integer") + return self.execute_command('CLIENT PAUSE', str(timeout)) + def config_get(self, pattern="*"): "Return a dictionary of configuration based on the ``pattern``" return self.execute_command('CONFIG GET', pattern) @@ -640,13 +868,33 @@ class StrictRedis(object): "Echo the string back from the server" return self.execute_command('ECHO', value) - def flushall(self): - "Delete all keys in all databases on the current host" - return self.execute_command('FLUSHALL') + def flushall(self, asynchronous=False): + """ + Delete all keys in all databases on the current host. + + ``asynchronous`` indicates whether the operation is + executed asynchronously by the server. + """ + args = [] + if asynchronous: + args.append(Token.get_token('ASYNC')) + return self.execute_command('FLUSHALL', *args) + + def flushdb(self, asynchronous=False): + """ + Delete all keys in the current database. - def flushdb(self): - "Delete all keys in the current database" - return self.execute_command('FLUSHDB') + + ``asynchronous`` indicates whether the operation is + executed asynchronously by the server. + """ + args = [] + if asynchronous: + args.append(Token.get_token('ASYNC')) + return self.execute_command('FLUSHDB', *args) + + def swapdb(self, first, second): + "Swap two databases" + return self.execute_command('SWAPDB', first, second) def info(self, section=None): """ @@ -670,10 +918,63 @@ class StrictRedis(object): """ return self.execute_command('LASTSAVE') + def migrate(self, host, port, keys, destination_db, timeout, + copy=False, replace=False, auth=None): + """ + Migrate 1 or more keys from the current Redis server to a different + server specified by the ``host``, ``port`` and ``destination_db``. + + The ``timeout``, specified in milliseconds, indicates the maximum + time the connection between the two servers can be idle before the + command is interrupted. + + If ``copy`` is True, the specified ``keys`` are NOT deleted from + the source server. + + If ``replace`` is True, this operation will overwrite the keys + on the destination server if they exist.
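A minimal usage sketch of the client-management calls added above (CLIENT LIST TYPE, CLIENT ID, CLIENT PAUSE); it assumes a redis-py build containing this change and a server at localhost:6379 recent enough to support these commands:

    import redis

    r = redis.Redis(host='localhost', port=6379)
    # filter connections by type: normal, master, replica or pubsub
    replicas = r.client_list(_type='replica')
    # id of this connection; another client could pass it to client_unblock()
    my_id = r.client_id()
    # stop processing commands from all clients for 100 milliseconds
    r.client_pause(100)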
+ + If ``auth`` is specified, authenticate to the destination server with + the password provided. + """ + keys = list_or_args(keys, []) + if not keys: + raise DataError('MIGRATE requires at least one key') + pieces = [] + if copy: + pieces.append(Token.get_token('COPY')) + if replace: + pieces.append(Token.get_token('REPLACE')) + if auth: + pieces.append(Token.get_token('AUTH')) + pieces.append(auth) + pieces.append(Token.get_token('KEYS')) + pieces.extend(keys) + return self.execute_command('MIGRATE', host, port, '', destination_db, + timeout, *pieces) + def object(self, infotype, key): "Return the encoding, idletime, or refcount about the key" return self.execute_command('OBJECT', infotype, key, infotype=infotype) + def memory_usage(self, key, samples=None): + """ + Return the total memory usage for key, its value and associated + administrative overheads. + + For nested data structures, ``samples`` is the number of elements to + sample. If left unspecified, the server's default is 5. Use 0 to sample + all elements. + """ + args = [] + if isinstance(samples, int): + args.extend([Token.get_token('SAMPLES'), samples]) + return self.execute_command('MEMORY USAGE', key, *args) + + def memory_purge(self): + "Attempts to purge dirty pages for reclamation by allocator" + return self.execute_command('MEMORY PURGE') + def ping(self): "Ping the Redis server" return self.execute_command('PING') @@ -723,10 +1024,22 @@ class StrictRedis(object): "Returns a list of slaves for ``service_name``" return self.execute_command('SENTINEL SLAVES', service_name) - def shutdown(self): - "Shutdown the server" + def shutdown(self, save=False, nosave=False): + """Shutdown the Redis server. If Redis has persistence configured, + data will be flushed before shutdown. If the "save" option is set, + a data flush will be attempted even if there is no persistence + configured. If the "nosave" option is set, no data flush will be + attempted. The "save" and "nosave" options cannot both be set. + """ + if save and nosave: + raise DataError('SHUTDOWN save and nosave cannot both be set') + args = ['SHUTDOWN'] + if save: + args.append('SAVE') + if nosave: + args.append('NOSAVE') try: - self.execute_command('SHUTDOWN') + self.execute_command(*args) except ConnectionError: # a ConnectionError here is expected return @@ -739,7 +1052,8 @@ class StrictRedis(object): instance is promoted to a master instead. """ if host is None and port is None: - return self.execute_command('SLAVEOF', Token('NO'), Token('ONE')) + return self.execute_command('SLAVEOF', Token.get_token('NO'), + Token.get_token('ONE')) return self.execute_command('SLAVEOF', host, port) def slowlog_get(self, num=None): @@ -767,6 +1081,15 @@ class StrictRedis(object): """ return self.execute_command('TIME') + def wait(self, num_replicas, timeout): + """ + Redis synchronous replication + That returns the number of replicas that processed the query when + we finally have at least ``num_replicas``, or when the ``timeout`` was + reached. 
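To illustrate the introspection and replication helpers above, a short sketch; the key name and replica count are arbitrary placeholders:

    import redis

    r = redis.Redis()
    r.set('config:blob', 'x' * 1024)
    # approximate memory footprint of the key; SAMPLES defaults to 5 server-side
    print(r.memory_usage('config:blob'))
    # block until 1 replica acknowledges preceding writes, or 500 ms elapse
    print(r.wait(1, 500))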
+ """ + return self.execute_command('WAIT', num_replicas, timeout) + # BASIC KEY COMMANDS def append(self, key, value): """ @@ -787,9 +1110,16 @@ class StrictRedis(object): params.append(end) elif (start is not None and end is None) or \ (end is not None and start is None): - raise RedisError("Both start and end must be specified") + raise DataError("Both start and end must be specified") return self.execute_command('BITCOUNT', *params) + def bitfield(self, key, default_overflow=None): + """ + Return a BitFieldOperation instance to conveniently construct one or + more bitfield operations on ``key``. + """ + return BitFieldOperation(self, key, default_overflow=default_overflow) + def bitop(self, operation, dest, *keys): """ Perform a bitwise operation using ``operation`` between ``keys`` and @@ -805,7 +1135,7 @@ class StrictRedis(object): means to look at the first three bytes. """ if bit not in (0, 1): - raise RedisError('bit must be 0 or 1') + raise DataError('bit must be 0 or 1') params = [key, bit] start is not None and params.append(start) @@ -813,8 +1143,8 @@ class StrictRedis(object): if start is not None and end is not None: params.append(end) elif start is None and end is not None: - raise RedisError("start argument is not set, " - "when end is specified") + raise DataError("start argument is not set, " + "when end is specified") return self.execute_command('BITPOS', *params) def decr(self, name, amount=1): @@ -848,9 +1178,9 @@ class StrictRedis(object): """ return self.execute_command('DUMP', name) - def exists(self, name): - "Returns a boolean indicating whether key ``name`` exists" - return self.execute_command('EXISTS', name) + def exists(self, *names): + "Returns the number of ``names`` that exist" + return self.execute_command('EXISTS', *names) __contains__ = exists def expire(self, name, time): @@ -859,7 +1189,7 @@ class StrictRedis(object): can be represented by an integer or a Python timedelta object. """ if isinstance(time, datetime.timedelta): - time = time.seconds + time.days * 24 * 3600 + time = int(time.total_seconds()) return self.execute_command('EXPIRE', name, time) def expireat(self, name, when): @@ -883,7 +1213,7 @@ class StrictRedis(object): doesn't exist. """ value = self.get(name) - if value: + if value is not None: return value raise KeyError(name) @@ -938,35 +1268,31 @@ class StrictRedis(object): Returns a list of values ordered identically to ``keys`` """ args = list_or_args(keys, args) - return self.execute_command('MGET', *args) + options = {} + if not args: + options[EMPTY_RESPONSE] = [] + return self.execute_command('MGET', *args, **options) - def mset(self, *args, **kwargs): + def mset(self, mapping): """ - Sets key/values based on a mapping. Mapping can be supplied as a single - dictionary argument or as kwargs. + Sets key/values based on a mapping. Mapping is a dictionary of + key/value pairs. Both keys and values should be strings or types that + can be cast to a string via str(). """ - if args: - if len(args) != 1 or not isinstance(args[0], dict): - raise RedisError('MSET requires **kwargs or a single dict arg') - kwargs.update(args[0]) items = [] - for pair in iteritems(kwargs): + for pair in iteritems(mapping): items.extend(pair) return self.execute_command('MSET', *items) - def msetnx(self, *args, **kwargs): + def msetnx(self, mapping): """ Sets key/values based on a mapping if none of the keys are already set. - Mapping can be supplied as a single dictionary argument or as kwargs. + Mapping is a dictionary of key/value pairs. 
Both keys and values + should be strings or types that can be cast to a string via str(). Returns a boolean indicating if the operation was successful. """ - if args: - if len(args) != 1 or not isinstance(args[0], dict): - raise RedisError('MSETNX requires **kwargs or a single ' - 'dict arg') - kwargs.update(args[0]) items = [] - for pair in iteritems(kwargs): + for pair in iteritems(mapping): items.extend(pair) return self.execute_command('MSETNX', *items) @@ -985,8 +1311,7 @@ class StrictRedis(object): object. """ if isinstance(time, datetime.timedelta): - ms = int(time.microseconds / 1000) - time = (time.seconds + time.days * 24 * 3600) * 1000 + ms + time = int(time.total_seconds() * 1000) return self.execute_command('PEXPIRE', name, time) def pexpireat(self, name, when): @@ -1007,8 +1332,7 @@ class StrictRedis(object): timedelta object """ if isinstance(time_ms, datetime.timedelta): - ms = int(time_ms.microseconds / 1000) - time_ms = (time_ms.seconds + time_ms.days * 24 * 3600) * 1000 + ms + time_ms = int(time_ms.total_seconds() * 1000) return self.execute_command('PSETEX', name, time_ms, value) def pttl(self, name): @@ -1029,12 +1353,15 @@ class StrictRedis(object): "Rename key ``src`` to ``dst`` if ``dst`` doesn't already exist" return self.execute_command('RENAMENX', src, dst) - def restore(self, name, ttl, value): + def restore(self, name, ttl, value, replace=False): """ Create a key using the provided serialized value, previously obtained using DUMP. """ - return self.execute_command('RESTORE', name, ttl, value) + params = [name, ttl, value] + if replace: + params.append('REPLACE') + return self.execute_command('RESTORE', *params) def set(self, name, value, ex=None, px=None, nx=False, xx=False): """ @@ -1044,23 +1371,22 @@ class StrictRedis(object): ``px`` sets an expire flag on key ``name`` for ``px`` milliseconds. - ``nx`` if set to True, set the value at key ``name`` to ``value`` if it - does not already exist. + ``nx`` if set to True, set the value at key ``name`` to ``value`` only + if it does not exist. - ``xx`` if set to True, set the value at key ``name`` to ``value`` if it - already exists. + ``xx`` if set to True, set the value at key ``name`` to ``value`` only + if it already exists. """ pieces = [name, value] - if ex: + if ex is not None: pieces.append('EX') if isinstance(ex, datetime.timedelta): - ex = ex.seconds + ex.days * 24 * 3600 + ex = int(ex.total_seconds()) pieces.append(ex) - if px: + if px is not None: pieces.append('PX') if isinstance(px, datetime.timedelta): - ms = int(px.microseconds / 1000) - px = (px.seconds + px.days * 24 * 3600) * 1000 + ms + px = int(px.total_seconds() * 1000) pieces.append(px) if nx: @@ -1087,7 +1413,7 @@ class StrictRedis(object): timedelta object. """ if isinstance(time, datetime.timedelta): - time = time.seconds + time.days * 24 * 3600 + time = int(time.total_seconds()) return self.execute_command('SETEX', name, time, value) def setnx(self, name, value): @@ -1118,6 +1444,13 @@ class StrictRedis(object): """ return self.execute_command('SUBSTR', name, start, end) + def touch(self, *args): + """ + Alters the last access time of a key(s) ``*args``. A key is ignored + if it does not exist. 
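Since mset/msetnx above now take a single dict rather than *args/**kwargs, and expiry arguments are converted with total_seconds(), a brief sketch with invented keys:

    import datetime
    import redis

    r = redis.Redis()
    # mset/msetnx now take one mapping; the old *args/**kwargs forms are gone
    r.mset({'a': '1', 'b': '2'})
    # timedeltas are converted via total_seconds(), so sub-day values survive
    r.set('session', 'tok', ex=datetime.timedelta(minutes=5), nx=True)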
+ """ + return self.execute_command('TOUCH', *args) + def ttl(self, name): "Returns the number of seconds until the key ``name`` will expire" return self.execute_command('TTL', name) @@ -1139,6 +1472,10 @@ class StrictRedis(object): warnings.warn( DeprecationWarning('Call UNWATCH from a Pipeline object')) + def unlink(self, *names): + "Unlink one or more keys specified by ``names``" + return self.execute_command('UNLINK', *names) + # LIST COMMANDS def blpop(self, keys, timeout=0): """ @@ -1153,10 +1490,7 @@ class StrictRedis(object): """ if timeout is None: timeout = 0 - if isinstance(keys, basestring): - keys = [keys] - else: - keys = list(keys) + keys = list_or_args(keys, None) keys.append(timeout) return self.execute_command('BLPOP', *keys) @@ -1165,7 +1499,7 @@ class StrictRedis(object): RPOP a value off of the first non-empty list named in the ``keys`` list. - If none of the lists in ``keys`` has a value to LPOP, then block + If none of the lists in ``keys`` has a value to RPOP, then block for ``timeout`` seconds, or until a value gets pushed on to one of the lists. @@ -1173,10 +1507,7 @@ class StrictRedis(object): """ if timeout is None: timeout = 0 - if isinstance(keys, basestring): - keys = [keys] - else: - keys = list(keys) + keys = list_or_args(keys, None) keys.append(timeout) return self.execute_command('BRPOP', *keys) @@ -1311,14 +1642,14 @@ class StrictRedis(object): """ if (start is not None and num is None) or \ (num is not None and start is None): - raise RedisError("``start`` and ``num`` must both be specified") + raise DataError("``start`` and ``num`` must both be specified") pieces = [name] if by is not None: - pieces.append(Token('BY')) + pieces.append(Token.get_token('BY')) pieces.append(by) if start is not None and num is not None: - pieces.append(Token('LIMIT')) + pieces.append(Token.get_token('LIMIT')) pieces.append(start) pieces.append(num) if get is not None: @@ -1326,23 +1657,23 @@ class StrictRedis(object): # Otherwise assume it's an interable and we want to get multiple # values. We can't just iterate blindly because strings are # iterable. 
- if isinstance(get, basestring): - pieces.append(Token('GET')) + if isinstance(get, (bytes, basestring)): + pieces.append(Token.get_token('GET')) pieces.append(get) else: for g in get: - pieces.append(Token('GET')) + pieces.append(Token.get_token('GET')) pieces.append(g) if desc: - pieces.append(Token('DESC')) + pieces.append(Token.get_token('DESC')) if alpha: - pieces.append(Token('ALPHA')) + pieces.append(Token.get_token('ALPHA')) if store is not None: - pieces.append(Token('STORE')) + pieces.append(Token.get_token('STORE')) pieces.append(store) if groups: - if not get or isinstance(get, basestring) or len(get) < 2: + if not get or isinstance(get, (bytes, basestring)) or len(get) < 2: raise DataError('when using "groups" the "get" argument ' 'must be specified and contain at least ' 'two keys') @@ -1362,9 +1693,9 @@ class StrictRedis(object): """ pieces = [cursor] if match is not None: - pieces.extend([Token('MATCH'), match]) + pieces.extend([Token.get_token('MATCH'), match]) if count is not None: - pieces.extend([Token('COUNT'), count]) + pieces.extend([Token.get_token('COUNT'), count]) return self.execute_command('SCAN', *pieces) def scan_iter(self, match=None, count=None): @@ -1393,9 +1724,9 @@ class StrictRedis(object): """ pieces = [name, cursor] if match is not None: - pieces.extend([Token('MATCH'), match]) + pieces.extend([Token.get_token('MATCH'), match]) if count is not None: - pieces.extend([Token('COUNT'), count]) + pieces.extend([Token.get_token('COUNT'), count]) return self.execute_command('SSCAN', *pieces) def sscan_iter(self, name, match=None, count=None): @@ -1425,9 +1756,9 @@ class StrictRedis(object): """ pieces = [name, cursor] if match is not None: - pieces.extend([Token('MATCH'), match]) + pieces.extend([Token.get_token('MATCH'), match]) if count is not None: - pieces.extend([Token('COUNT'), count]) + pieces.extend([Token.get_token('COUNT'), count]) return self.execute_command('HSCAN', *pieces) def hscan_iter(self, name, match=None, count=None): @@ -1460,9 +1791,9 @@ class StrictRedis(object): """ pieces = [name, cursor] if match is not None: - pieces.extend([Token('MATCH'), match]) + pieces.extend([Token.get_token('MATCH'), match]) if count is not None: - pieces.extend([Token('COUNT'), count]) + pieces.extend([Token.get_token('COUNT'), count]) options = {'score_cast_func': score_cast_func} return self.execute_command('ZSCAN', *pieces, **options) @@ -1533,9 +1864,10 @@ class StrictRedis(object): "Move ``value`` from set ``src`` to set ``dst`` atomically" return self.execute_command('SMOVE', src, dst, value) - def spop(self, name): + def spop(self, name, count=None): "Remove and return a random member of set ``name``" - return self.execute_command('SPOP', name) + args = (count is not None) and [count] or [] + return self.execute_command('SPOP', name, *args) def srandmember(self, name, number=None): """ @@ -1545,7 +1877,7 @@ class StrictRedis(object): memebers of set ``name``. Note this is only available when running Redis 2.6+. """ - args = number and [number] or [] + args = (number is not None) and [number] or [] return self.execute_command('SRANDMEMBER', name, *args) def srem(self, name, *values): @@ -1565,28 +1897,375 @@ class StrictRedis(object): args = list_or_args(keys, args) return self.execute_command('SUNIONSTORE', dest, *args) + # STREAMS COMMANDS + def xack(self, name, groupname, *ids): + """ + Acknowledges the successful processing of one or more messages. + name: name of the stream. + groupname: name of the consumer group. 
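The scan helpers above all accept MATCH/COUNT options, and SPOP gained an optional count; for example (pattern and key names are illustrative):

    import redis

    r = redis.Redis()
    # scan_iter hides cursor bookkeeping behind MATCH/COUNT options
    for key in r.scan_iter(match='user:*', count=500):
        print(key)
    r.sadd('tags', 'a', 'b', 'c')
    # SPOP now accepts a count: remove and return up to two random members
    print(r.spop('tags', 2))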
+ *ids: message ids to acknowlege. + """ + return self.execute_command('XACK', name, groupname, *ids) + + def xadd(self, name, fields, id='*', maxlen=None, approximate=True): + """ + Add to a stream. + name: name of the stream + fields: dict of field/value pairs to insert into the stream + id: Location to insert this record. By default it is appended. + maxlen: truncate old stream members beyond this size + approximate: actual stream length may be slightly more than maxlen + + """ + pieces = [] + if maxlen is not None: + if not isinstance(maxlen, (int, long)) or maxlen < 1: + raise DataError('XADD maxlen must be a positive integer') + pieces.append(Token.get_token('MAXLEN')) + if approximate: + pieces.append(Token.get_token('~')) + pieces.append(str(maxlen)) + pieces.append(id) + if not isinstance(fields, dict) or len(fields) == 0: + raise DataError('XADD fields must be a non-empty dict') + for pair in iteritems(fields): + pieces.extend(pair) + return self.execute_command('XADD', name, *pieces) + + def xclaim(self, name, groupname, consumername, min_idle_time, message_ids, + idle=None, time=None, retrycount=None, force=False, + justid=False): + """ + Changes the ownership of a pending message. + name: name of the stream. + groupname: name of the consumer group. + consumername: name of a consumer that claims the message. + min_idle_time: filter messages that were idle less than this amount of + milliseconds + message_ids: non-empty list or tuple of message IDs to claim + idle: optional. Set the idle time (last time it was delivered) of the + message in ms + time: optional integer. This is the same as idle but instead of a + relative amount of milliseconds, it sets the idle time to a specific + Unix time (in milliseconds). + retrycount: optional integer. set the retry counter to the specified + value. This counter is incremented every time a message is delivered + again. + force: optional boolean, false by default. Creates the pending message + entry in the PEL even if certain specified IDs are not already in the + PEL assigned to a different client. + justid: optional boolean, false by default. 
Return just an array of IDs + of messages successfully claimed, without returning the actual message + """ + if not isinstance(min_idle_time, (int, long)) or min_idle_time < 0: + raise DataError("XCLAIM min_idle_time must be a non negative " + "integer") + if not isinstance(message_ids, (list, tuple)) or not message_ids: + raise DataError("XCLAIM message_ids must be a non empty list or " + "tuple of message IDs to claim") + + kwargs = {} + pieces = [name, groupname, consumername, str(min_idle_time)] + pieces.extend(list(message_ids)) + + if idle is not None: + if not isinstance(idle, (int, long)): + raise DataError("XCLAIM idle must be an integer") + pieces.extend((Token.get_token('IDLE'), str(idle))) + if time is not None: + if not isinstance(time, (int, long)): + raise DataError("XCLAIM time must be an integer") + pieces.extend((Token.get_token('TIME'), str(time))) + if retrycount is not None: + if not isinstance(retrycount, (int, long)): + raise DataError("XCLAIM retrycount must be an integer") + pieces.extend((Token.get_token('RETRYCOUNT'), str(retrycount))) + + if force: + if not isinstance(force, bool): + raise DataError("XCLAIM force must be a boolean") + pieces.append(Token.get_token('FORCE')) + if justid: + if not isinstance(justid, bool): + raise DataError("XCLAIM justid must be a boolean") + pieces.append(Token.get_token('JUSTID')) + kwargs['parse_justid'] = True + return self.execute_command('XCLAIM', *pieces, **kwargs) + + def xdel(self, name, *ids): + """ + Deletes one or more messages from a stream. + name: name of the stream. + *ids: message ids to delete. + """ + return self.execute_command('XDEL', name, *ids) + + def xgroup_create(self, name, groupname, id='$', mkstream=False): + """ + Create a new consumer group associated with a stream. + name: name of the stream. + groupname: name of the consumer group. + id: ID of the last item in the stream to consider already delivered. + """ + pieces = ['XGROUP CREATE', name, groupname, id] + if mkstream: + pieces.append(Token.get_token('MKSTREAM')) + return self.execute_command(*pieces) + + def xgroup_delconsumer(self, name, groupname, consumername): + """ + Remove a specific consumer from a consumer group. + Returns the number of pending messages that the consumer had before it + was deleted. + name: name of the stream. + groupname: name of the consumer group. + consumername: name of consumer to delete + """ + return self.execute_command('XGROUP DELCONSUMER', name, groupname, + consumername) + + def xgroup_destroy(self, name, groupname): + """ + Destroy a consumer group. + name: name of the stream. + groupname: name of the consumer group. + """ + return self.execute_command('XGROUP DESTROY', name, groupname) + + def xgroup_setid(self, name, groupname, id): + """ + Set the consumer group last delivered ID to something else. + name: name of the stream. + groupname: name of the consumer group. + id: ID of the last item in the stream to consider already delivered. + """ + return self.execute_command('XGROUP SETID', name, groupname, id) + + def xinfo_consumers(self, name, groupname): + """ + Returns general information about the consumers in the group. + name: name of the stream. + groupname: name of the consumer group. + """ + return self.execute_command('XINFO CONSUMERS', name, groupname) + + def xinfo_groups(self, name): + """ + Returns general information about the consumer groups of the stream. + name: name of the stream. 
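A compact walk-through of the stream and consumer-group commands defined above; the stream, group and consumer names are made up for the sketch:

    import redis

    r = redis.Redis()
    # append an entry, capping the stream at roughly 10000 entries
    msg_id = r.xadd('events', {'type': 'click'}, maxlen=10000)
    # create a consumer group that sees the stream from the beginning
    r.xgroup_create('events', 'workers', id='0')
    # reassign a message idle for over a minute to another consumer
    r.xclaim('events', 'workers', 'worker-2',
             min_idle_time=60000, message_ids=[msg_id])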
+ """ + return self.execute_command('XINFO GROUPS', name) + + def xinfo_stream(self, name): + """ + Returns general information about the stream. + name: name of the stream. + """ + return self.execute_command('XINFO STREAM', name) + + def xlen(self, name): + """ + Returns the number of elements in a given stream. + """ + return self.execute_command('XLEN', name) + + def xpending(self, name, groupname): + """ + Returns information about pending messages of a group. + name: name of the stream. + groupname: name of the consumer group. + """ + return self.execute_command('XPENDING', name, groupname) + + def xpending_range(self, name, groupname, min='-', max='+', count=-1, + consumername=None): + """ + Returns information about pending messages, in a range. + name: name of the stream. + groupname: name of the consumer group. + start: first stream ID. defaults to '-', + meaning the earliest available. + finish: last stream ID. defaults to '+', + meaning the latest available. + count: if set, only return this many items, beginning with the + earliest available. + consumername: name of a consumer to filter by (optional). + """ + pieces = [name, groupname] + if min is not None or max is not None or count is not None: + if min is None or max is None or count is None: + raise DataError("XPENDING must be provided with min, max " + "and count parameters, or none of them. ") + if not isinstance(count, (int, long)) or count < -1: + raise DataError("XPENDING count must be a integer >= -1") + pieces.extend((min, max, str(count))) + if consumername is not None: + if min is None or max is None or count is None: + raise DataError("if XPENDING is provided with consumername," + " it must be provided with min, max and" + " count parameters") + pieces.append(consumername) + return self.execute_command('XPENDING', *pieces, parse_detail=True) + + def xrange(self, name, min='-', max='+', count=None): + """ + Read stream values within an interval. + name: name of the stream. + start: first stream ID. defaults to '-', + meaning the earliest available. + finish: last stream ID. defaults to '+', + meaning the latest available. + count: if set, only return this many items, beginning with the + earliest available. + """ + pieces = [min, max] + if count is not None: + if not isinstance(count, (int, long)) or count < 1: + raise DataError('XRANGE count must be a positive integer') + pieces.append(Token.get_token('COUNT')) + pieces.append(str(count)) + + return self.execute_command('XRANGE', name, *pieces) + + def xread(self, streams, count=None, block=None): + """ + Block and monitor multiple streams for new data. + streams: a dict of stream names to stream IDs, where + IDs indicate the last ID already seen. + count: if set, only return this many items, beginning with the + earliest available. + block: number of milliseconds to wait, if nothing already present. 
+ """ + pieces = [] + if block is not None: + if not isinstance(block, (int, long)) or block < 0: + raise DataError('XREAD block must be a non-negative integer') + pieces.append(Token.get_token('BLOCK')) + pieces.append(str(block)) + if count is not None: + if not isinstance(count, (int, long)) or count < 1: + raise DataError('XREAD count must be a positive integer') + pieces.append(Token.get_token('COUNT')) + pieces.append(str(count)) + if not isinstance(streams, dict) or len(streams) == 0: + raise DataError('XREAD streams must be a non empty dict') + pieces.append(Token.get_token('STREAMS')) + keys, values = izip(*iteritems(streams)) + pieces.extend(keys) + pieces.extend(values) + return self.execute_command('XREAD', *pieces) + + def xreadgroup(self, groupname, consumername, streams, count=None, + block=None): + """ + Read from a stream via a consumer group. + groupname: name of the consumer group. + consumername: name of the requesting consumer. + streams: a dict of stream names to stream IDs, where + IDs indicate the last ID already seen. + count: if set, only return this many items, beginning with the + earliest available. + block: number of milliseconds to wait, if nothing already present. + """ + pieces = [Token.get_token('GROUP'), groupname, consumername] + if count is not None: + if not isinstance(count, (int, long)) or count < 1: + raise DataError("XREADGROUP count must be a positive integer") + pieces.append(Token.get_token("COUNT")) + pieces.append(str(count)) + if block is not None: + if not isinstance(block, (int, long)) or block < 0: + raise DataError("XREADGROUP block must be a non-negative " + "integer") + pieces.append(Token.get_token("BLOCK")) + pieces.append(str(block)) + if not isinstance(streams, dict) or len(streams) == 0: + raise DataError('XREADGROUP streams must be a non empty dict') + pieces.append(Token.get_token('STREAMS')) + pieces.extend(streams.keys()) + pieces.extend(streams.values()) + return self.execute_command('XREADGROUP', *pieces) + + def xrevrange(self, name, max='+', min='-', count=None): + """ + Read stream values within an interval, in reverse order. + name: name of the stream + start: first stream ID. defaults to '+', + meaning the latest available. + finish: last stream ID. defaults to '-', + meaning the earliest available. + count: if set, only return this many items, beginning with the + latest available. + """ + pieces = [max, min] + if count is not None: + if not isinstance(count, (int, long)) or count < 1: + raise DataError('XREVRANGE count must be a positive integer') + pieces.append(Token.get_token('COUNT')) + pieces.append(str(count)) + + return self.execute_command('XREVRANGE', name, *pieces) + + def xtrim(self, name, maxlen, approximate=True): + """ + Trims old messages from a stream. + name: name of the stream. + maxlen: truncate old stream messages beyond this size + approximate: actual stream length may be slightly more than maxlen + """ + pieces = [Token.get_token('MAXLEN')] + if approximate: + pieces.append(Token.get_token('~')) + pieces.append(maxlen) + return self.execute_command('XTRIM', name, *pieces) + # SORTED SET COMMANDS - def zadd(self, name, *args, **kwargs): + def zadd(self, name, mapping, nx=False, xx=False, ch=False, incr=False): """ - Set any number of score, element-name pairs to the key ``name``. Pairs - can be specified in two ways: + Set any number of element-name, score pairs to the key ``name``. Pairs + are specified as a dict of element-names keys to score values. 
+ + ``nx`` forces ZADD to only create new elements and not to update + scores for elements that already exist. + + ``xx`` forces ZADD to only update scores of elements that already + exist. New elements will not be added. + + ``ch`` modifies the return value to be the numbers of elements changed. + Changed elements include new elements that were added and elements + whose scores changed. - As *args, in the form of: score1, name1, score2, name2, ... - or as **kwargs, in the form of: name1=score1, name2=score2, ... + ``incr`` modifies ZADD to behave like ZINCRBY. In this mode only a + single element/score pair can be specified and the score is the amount + the existing score will be incremented by. When using this mode the + return value of ZADD will be the new score of the element. - The following example would add four values to the 'my-key' key: - redis.zadd('my-key', 1.1, 'name1', 2.2, 'name2', name3=3.3, name4=4.4) + The return value of ZADD varies based on the mode specified. With no + options, ZADD returns the number of new elements added to the sorted + set. """ + if not mapping: + raise DataError("ZADD requires at least one element/score pair") + if nx and xx: + raise DataError("ZADD allows either 'nx' or 'xx', not both") + if incr and len(mapping) != 1: + raise DataError("ZADD option 'incr' only works when passing a " + "single element/score pair") pieces = [] - if args: - if len(args) % 2 != 0: - raise RedisError("ZADD requires an equal number of " - "values and scores") - pieces.extend(args) - for pair in iteritems(kwargs): + options = {} + if nx: + pieces.append(Token.get_token('NX')) + if xx: + pieces.append(Token.get_token('XX')) + if ch: + pieces.append(Token.get_token('CH')) + if incr: + pieces.append(Token.get_token('INCR')) + options['as_score'] = True + for pair in iteritems(mapping): pieces.append(pair[1]) pieces.append(pair[0]) - return self.execute_command('ZADD', name, *pieces) + return self.execute_command('ZADD', name, *pieces, **options) def zcard(self, name): "Return the number of elements in the sorted set ``name``" @@ -1599,7 +2278,7 @@ class StrictRedis(object): """ return self.execute_command('ZCOUNT', name, min, max) - def zincrby(self, name, value, amount=1): + def zincrby(self, name, amount, value): "Increment the score of ``value`` in sorted set ``name`` by ``amount``" return self.execute_command('ZINCRBY', name, amount, value) @@ -1618,6 +2297,62 @@ class StrictRedis(object): """ return self.execute_command('ZLEXCOUNT', name, min, max) + def zpopmax(self, name, count=None): + """ + Remove and return up to ``count`` members with the highest scores + from the sorted set ``name``. + """ + args = (count is not None) and [count] or [] + options = { + 'withscores': True + } + return self.execute_command('ZPOPMAX', name, *args, **options) + + def zpopmin(self, name, count=None): + """ + Remove and return up to ``count`` members with the lowest scores + from the sorted set ``name``. + """ + args = (count is not None) and [count] or [] + options = { + 'withscores': True + } + return self.execute_command('ZPOPMIN', name, *args, **options) + + def bzpopmax(self, keys, timeout=0): + """ + ZPOPMAX a value off of the first non-empty sorted set + named in the ``keys`` list. + + If none of the sorted sets in ``keys`` has a value to ZPOPMAX, + then block for ``timeout`` seconds, or until a member gets added + to one of the sorted sets. + + If timeout is 0, then block indefinitely. 
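The reworked ZADD above takes a mapping plus the NX/XX/CH/INCR flags; for instance, with a made-up sorted set:

    import redis

    r = redis.Redis()
    # one dict, many members; returns the number of *new* members added
    r.zadd('scores', {'alice': 10, 'bob': 15})
    # incr=True allows exactly one pair and returns the member's new score
    new_score = r.zadd('scores', {'alice': 5}, incr=True)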
+ """ + if timeout is None: + timeout = 0 + keys = list_or_args(keys, None) + keys.append(timeout) + return self.execute_command('BZPOPMAX', *keys) + + def bzpopmin(self, keys, timeout=0): + """ + ZPOPMIN a value off of the first non-empty sorted set + named in the ``keys`` list. + + If none of the sorted sets in ``keys`` has a value to ZPOPMIN, + then block for ``timeout`` seconds, or until a member gets added + to one of the sorted sets. + + If timeout is 0, then block indefinitely. + """ + if timeout is None: + timeout = 0 + keys = list_or_args(keys, None) + keys.append(timeout) + return self.execute_command('BZPOPMIN', *keys) + def zrange(self, name, start, end, desc=False, withscores=False, score_cast_func=float): """ @@ -1638,7 +2373,7 @@ class StrictRedis(object): score_cast_func) pieces = ['ZRANGE', name, start, end] if withscores: - pieces.append(Token('WITHSCORES')) + pieces.append(Token.get_token('WITHSCORES')) options = { 'withscores': withscores, 'score_cast_func': score_cast_func @@ -1655,10 +2390,26 @@ class StrictRedis(object): """ if (start is not None and num is None) or \ (num is not None and start is None): - raise RedisError("``start`` and ``num`` must both be specified") + raise DataError("``start`` and ``num`` must both be specified") pieces = ['ZRANGEBYLEX', name, min, max] if start is not None and num is not None: - pieces.extend([Token('LIMIT'), start, num]) + pieces.extend([Token.get_token('LIMIT'), start, num]) + return self.execute_command(*pieces) + + def zrevrangebylex(self, name, max, min, start=None, num=None): + """ + Return the reversed lexicographical range of values from sorted set + ``name`` between ``max`` and ``min``. + + If ``start`` and ``num`` are specified, then return a slice of the + range. + """ + if (start is not None and num is None) or \ + (num is not None and start is None): + raise DataError("``start`` and ``num`` must both be specified") + pieces = ['ZREVRANGEBYLEX', name, max, min] + if start is not None and num is not None: + pieces.extend([Token.get_token('LIMIT'), start, num]) return self.execute_command(*pieces) def zrangebyscore(self, name, min, max, start=None, num=None, @@ -1677,12 +2428,12 @@ class StrictRedis(object): """ if (start is not None and num is None) or \ (num is not None and start is None): - raise RedisError("``start`` and ``num`` must both be specified") + raise DataError("``start`` and ``num`` must both be specified") pieces = ['ZRANGEBYSCORE', name, min, max] if start is not None and num is not None: - pieces.extend([Token('LIMIT'), start, num]) + pieces.extend([Token.get_token('LIMIT'), start, num]) if withscores: - pieces.append(Token('WITHSCORES')) + pieces.append(Token.get_token('WITHSCORES')) options = { 'withscores': withscores, 'score_cast_func': score_cast_func @@ -1740,7 +2491,7 @@ class StrictRedis(object): """ pieces = ['ZREVRANGE', name, start, end] if withscores: - pieces.append(Token('WITHSCORES')) + pieces.append(Token.get_token('WITHSCORES')) options = { 'withscores': withscores, 'score_cast_func': score_cast_func @@ -1763,12 +2514,12 @@ class StrictRedis(object): """ if (start is not None and num is None) or \ (num is not None and start is None): - raise RedisError("``start`` and ``num`` must both be specified") + raise DataError("``start`` and ``num`` must both be specified") pieces = ['ZREVRANGEBYSCORE', name, max, min] if start is not None and num is not None: - pieces.extend([Token('LIMIT'), start, num]) + pieces.extend([Token.get_token('LIMIT'), start, num]) if withscores: - 
pieces.append(Token('WITHSCORES')) + pieces.append(Token.get_token('WITHSCORES')) options = { 'withscores': withscores, 'score_cast_func': score_cast_func @@ -1802,10 +2553,10 @@ class StrictRedis(object): weights = None pieces.extend(keys) if weights: - pieces.append(Token('WEIGHTS')) + pieces.append(Token.get_token('WEIGHTS')) pieces.extend(weights) if aggregate: - pieces.append(Token('AGGREGATE')) + pieces.append(Token.get_token('AGGREGATE')) pieces.append(aggregate) return self.execute_command(*pieces) @@ -1814,12 +2565,12 @@ class StrictRedis(object): "Adds the specified elements to the specified HyperLogLog." return self.execute_command('PFADD', name, *values) - def pfcount(self, name): + def pfcount(self, *sources): """ Return the approximated cardinality of - the set observed by the HyperLogLog at key. + the set observed by the HyperLogLog at key(s). """ - return self.execute_command('PFCOUNT', name) + return self.execute_command('PFCOUNT', *sources) def pfmerge(self, dest, *sources): "Merge N different HyperLogLogs into a single one." @@ -1895,6 +2646,13 @@ class StrictRedis(object): "Return the list of values within hash ``name``" return self.execute_command('HVALS', name) + def hstrlen(self, name, key): + """ + Return the number of bytes stored in the value of ``key`` + within hash ``name`` + """ + return self.execute_command('HSTRLEN', name, key) + def publish(self, channel, message): """ Publish ``message`` on ``channel``. @@ -1902,6 +2660,28 @@ class StrictRedis(object): """ return self.execute_command('PUBLISH', channel, message) + def pubsub_channels(self, pattern='*'): + """ + Return a list of channels that have at least one subscriber + """ + return self.execute_command('PUBSUB CHANNELS', pattern) + + def pubsub_numpat(self): + """ + Returns the number of subscriptions to patterns + """ + return self.execute_command('PUBSUB NUMPAT') + + def pubsub_numsub(self, *args): + """ + Return a list of (channel, number of subscribers) tuples + for each channel given in ``*args`` + """ + return self.execute_command('PUBSUB NUMSUB', *args) + + def cluster(self, cluster_arg, *args): + return self.execute_command('CLUSTER %s' % cluster_arg.upper(), *args) + def eval(self, script, numkeys, *keys_and_args): """ Execute the Lua ``script``, specifying the ``numkeys`` the script @@ -1954,89 +2734,137 @@ class StrictRedis(object): """ return Script(self, script) + # GEO COMMANDS + def geoadd(self, name, *values): + """ + Add the specified geospatial items to the specified key identified + by the ``name`` argument. The Geospatial items are given as ordered + members of the ``values`` argument, each item or place is formed by + the triad longitude, latitude and name. + """ + if len(values) % 3 != 0: + raise DataError("GEOADD requires places with lon, lat and name" + " values") + return self.execute_command('GEOADD', name, *values) -class Redis(StrictRedis): - """ - Provides backwards compatibility with older versions of redis-py that - changed arguments to some commands to be more Pythonic, sane, or by - accident. - """ - - # Overridden callbacks - RESPONSE_CALLBACKS = dict_merge( - StrictRedis.RESPONSE_CALLBACKS, - { - 'TTL': lambda r: r >= 0 and r or None, - 'PTTL': lambda r: r >= 0 and r or None, - } - ) + def geodist(self, name, place1, place2, unit=None): + """ + Return the distance between ``place1`` and ``place2`` members of the + ``name`` key. + The units must be one of the following : m, km mi, ft. By default + meters are used. 
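GEOADD above takes flat (longitude, latitude, name) triples and GEODIST accepts an optional unit; a sketch using the canonical Redis example coordinates for Palermo and Catania:

    import redis

    r = redis.Redis()
    r.geoadd('Sicily', 13.361389, 38.115556, 'Palermo',
             15.087269, 37.502669, 'Catania')
    # distance between members in kilometres (metres if no unit is given)
    print(r.geodist('Sicily', 'Palermo', 'Catania', unit='km'))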
+ """ + pieces = [name, place1, place2] + if unit and unit not in ('m', 'km', 'mi', 'ft'): + raise DataError("GEODIST invalid unit") + elif unit: + pieces.append(unit) + return self.execute_command('GEODIST', *pieces) - def pipeline(self, transaction=True, shard_hint=None): + def geohash(self, name, *values): """ - Return a new pipeline object that can queue multiple commands for - later execution. ``transaction`` indicates whether all commands - should be executed atomically. Apart from making a group of operations - atomic, pipelines are useful for reducing the back-and-forth overhead - between the client and server. + Return the geo hash string for each item of ``values`` members of + the specified key identified by the ``name`` argument. """ - return Pipeline( - self.connection_pool, - self.response_callbacks, - transaction, - shard_hint) + return self.execute_command('GEOHASH', name, *values) - def setex(self, name, value, time): + def geopos(self, name, *values): """ - Set the value of key ``name`` to ``value`` that expires in ``time`` - seconds. ``time`` can be represented by an integer or a Python - timedelta object. + Return the positions of each item of ``values`` as members of + the specified key identified by the ``name`` argument. Each position + is represented by the pairs lon and lat. """ - if isinstance(time, datetime.timedelta): - time = time.seconds + time.days * 24 * 3600 - return self.execute_command('SETEX', name, time, value) + return self.execute_command('GEOPOS', name, *values) - def lrem(self, name, value, num=0): + def georadius(self, name, longitude, latitude, radius, unit=None, + withdist=False, withcoord=False, withhash=False, count=None, + sort=None, store=None, store_dist=None): """ - Remove the first ``num`` occurrences of elements equal to ``value`` - from the list stored at ``name``. + Return the members of the specified key identified by the + ``name`` argument which are within the borders of the area specified + with the ``latitude`` and ``longitude`` location and the maximum + distance from the center specified by the ``radius`` value. - The ``num`` argument influences the operation in the following ways: - num > 0: Remove elements equal to value moving from head to tail. - num < 0: Remove elements equal to value moving from tail to head. - num = 0: Remove all elements equal to value. + The units must be one of the following : m, km mi, ft. By default + + ``withdist`` indicates to return the distances of each place. + + ``withcoord`` indicates to return the latitude and longitude of + each place. + + ``withhash`` indicates to return the geohash string of each place. + + ``count`` indicates to return the number of elements up to N. + + ``sort`` indicates to return the places in a sorted way, ASC for + nearest to fairest and DESC for fairest to nearest. + + ``store`` indicates to save the places names in a sorted set named + with a specific key, each element of the destination sorted set is + populated with the score got from the original geo sorted set. + + ``store_dist`` indicates to save the places names in a sorted set + named with a specific key, instead of ``store`` the sorted set + destination score is set with the distance. 
""" - return self.execute_command('LREM', name, num, value) + return self._georadiusgeneric('GEORADIUS', + name, longitude, latitude, radius, + unit=unit, withdist=withdist, + withcoord=withcoord, withhash=withhash, + count=count, sort=sort, store=store, + store_dist=store_dist) - def zadd(self, name, *args, **kwargs): + def georadiusbymember(self, name, member, radius, unit=None, + withdist=False, withcoord=False, withhash=False, + count=None, sort=None, store=None, store_dist=None): + """ + This command is exactly like ``georadius`` with the sole difference + that instead of taking, as the center of the area to query, a longitude + and latitude value, it takes the name of a member already existing + inside the geospatial index represented by the sorted set. """ - NOTE: The order of arguments differs from that of the official ZADD - command. For backwards compatability, this method accepts arguments - in the form of name1, score1, name2, score2, while the official Redis - documents expects score1, name1, score2, name2. + return self._georadiusgeneric('GEORADIUSBYMEMBER', + name, member, radius, unit=unit, + withdist=withdist, withcoord=withcoord, + withhash=withhash, count=count, + sort=sort, store=store, + store_dist=store_dist) - If you're looking to use the standard syntax, consider using the - StrictRedis class. See the API Reference section of the docs for more - information. + def _georadiusgeneric(self, command, *args, **kwargs): + pieces = list(args) + if kwargs['unit'] and kwargs['unit'] not in ('m', 'km', 'mi', 'ft'): + raise DataError("GEORADIUS invalid unit") + elif kwargs['unit']: + pieces.append(kwargs['unit']) + else: + pieces.append('m',) - Set any number of element-name, score pairs to the key ``name``. Pairs - can be specified in two ways: + for token in ('withdist', 'withcoord', 'withhash'): + if kwargs[token]: + pieces.append(Token(token.upper())) - As *args, in the form of: name1, score1, name2, score2, ... - or as **kwargs, in the form of: name1=score1, name2=score2, ... + if kwargs['count']: + pieces.extend([Token('COUNT'), kwargs['count']]) + + if kwargs['sort'] and kwargs['sort'] not in ('ASC', 'DESC'): + raise DataError("GEORADIUS invalid sort") + elif kwargs['sort']: + pieces.append(Token(kwargs['sort'])) + + if kwargs['store'] and kwargs['store_dist']: + raise DataError("GEORADIUS store and store_dist cant be set" + " together") + + if kwargs['store']: + pieces.extend([Token('STORE'), kwargs['store']]) + + if kwargs['store_dist']: + pieces.extend([Token('STOREDIST'), kwargs['store_dist']]) + + return self.execute_command(command, *pieces, **kwargs) - The following example would add four values to the 'my-key' key: - redis.zadd('my-key', 'name1', 1.1, 'name2', 2.2, name3=3.3, name4=4.4) - """ - pieces = [] - if args: - if len(args) % 2 != 0: - raise RedisError("ZADD requires an equal number of " - "values and scores") - pieces.extend(reversed(args)) - for pair in iteritems(kwargs): - pieces.append(pair[1]) - pieces.append(pair[0]) - return self.execute_command('ZADD', name, *pieces) + +StrictRedis = Redis class PubSub(object): @@ -2058,13 +2886,7 @@ class PubSub(object): self.connection = None # we need to know the encoding options for this connection in order # to lookup channel and pattern names for callback handlers. 
- conn = connection_pool.get_connection('pubsub', shard_hint) - try: - self.encoding = conn.encoding - self.encoding_errors = conn.encoding_errors - self.decode_responses = conn.decode_responses - finally: - connection_pool.release(conn) + self.encoder = self.connection_pool.get_encoder() self.reset() def __del__(self): @@ -2096,29 +2918,14 @@ class PubSub(object): if self.channels: channels = {} for k, v in iteritems(self.channels): - if not self.decode_responses: - k = k.decode(self.encoding, self.encoding_errors) - channels[k] = v + channels[self.encoder.decode(k, force=True)] = v self.subscribe(**channels) if self.patterns: patterns = {} for k, v in iteritems(self.patterns): - if not self.decode_responses: - k = k.decode(self.encoding, self.encoding_errors) - patterns[k] = v + patterns[self.encoder.decode(k, force=True)] = v self.psubscribe(**patterns) - def encode(self, value): - """ - Encode the value so that it's identical to what we'll - read off the connection - """ - if self.decode_responses and isinstance(value, bytes): - value = value.decode(self.encoding, self.encoding_errors) - elif not self.decode_responses and isinstance(value, unicode): - value = value.encode(self.encoding, self.encoding_errors) - return value - @property def subscribed(self): "Indicates if there are subscriptions to any channels or patterns" @@ -2127,8 +2934,8 @@ class PubSub(object): def execute_command(self, *args, **kwargs): "Execute a publish/subscribe command" - # NOTE: don't parse the response in this function. it could pull a - # legitmate message off the stack if the connection is already + # NOTE: don't parse the response in this function -- it could pull a + # legitimate message off the stack if the connection is already # subscribed to one or more channels if self.connection is None: @@ -2160,10 +2967,24 @@ class PubSub(object): def parse_response(self, block=True, timeout=0): "Parse the response from a publish/subscribe command" connection = self.connection + if connection is None: + raise RuntimeError( + 'pubsub connection not set: ' + 'did you forget to call subscribe() or psubscribe()?') if not block and not connection.can_read(timeout=timeout): return None return self._execute(connection, connection.read_response) + def _normalize_keys(self, data): + """ + normalize channel/pattern names to be either bytes or strings + based on whether responses are automatically decoded. this saves us + from coercing the value for each message coming in. + """ + encode = self.encoder.encode + decode = self.encoder.decode + return {decode(encode(k)): v for k, v in iteritems(data)} + def psubscribe(self, *args, **kwargs): """ Subscribe to channel patterns. Patterns supplied as keyword arguments @@ -2174,15 +2995,13 @@ class PubSub(object): """ if args: args = list_or_args(args[0], args[1:]) - new_patterns = {} - new_patterns.update(dict.fromkeys(imap(self.encode, args))) - for pattern, handler in iteritems(kwargs): - new_patterns[self.encode(pattern)] = handler + new_patterns = dict.fromkeys(args) + new_patterns.update(kwargs) ret_val = self.execute_command('PSUBSCRIBE', *iterkeys(new_patterns)) # update the patterns dict AFTER we send the command. we don't want to # subscribe twice to these patterns, once for the command and again # for the reconnection. 
- self.patterns.update(new_patterns) + self.patterns.update(self._normalize_keys(new_patterns)) return ret_val def punsubscribe(self, *args): @@ -2204,15 +3023,13 @@ class PubSub(object): """ if args: args = list_or_args(args[0], args[1:]) - new_channels = {} - new_channels.update(dict.fromkeys(imap(self.encode, args))) - for channel, handler in iteritems(kwargs): - new_channels[self.encode(channel)] = handler + new_channels = dict.fromkeys(args) + new_channels.update(kwargs) ret_val = self.execute_command('SUBSCRIBE', *iterkeys(new_channels)) # update the channels dict AFTER we send the command. we don't want to # subscribe twice to these channels, once for the command and again # for the reconnection. - self.channels.update(new_channels) + self.channels.update(self._normalize_keys(new_channels)) return ret_val def unsubscribe(self, *args): @@ -2244,6 +3061,13 @@ class PubSub(object): return self.handle_message(response, ignore_subscribe_messages) return None + def ping(self, message=None): + """ + Ping the Redis server + """ + message = '' if message is None else message + return self.execute_command('PING', message) + def handle_message(self, response, ignore_subscribe_messages=False): """ Parses a pub/sub message. If the channel or pattern was subscribed to @@ -2258,6 +3082,13 @@ class PubSub(object): 'channel': response[2], 'data': response[3] } + elif message_type == 'pong': + message = { + 'type': message_type, + 'pattern': None, + 'channel': None, + 'data': response[1] + } else: message = { 'type': message_type, @@ -2288,7 +3119,7 @@ class PubSub(object): if handler: handler(message) return None - else: + elif message_type != 'pong': # this is a subscribe/unsubscribe message. ignore if we don't # want them if ignore_subscribe_messages or self.ignore_subscribe_messages: @@ -2296,7 +3127,7 @@ class PubSub(object): return message - def run_in_thread(self, sleep_time=0): + def run_in_thread(self, sleep_time=0, daemon=False): for channel, handler in iteritems(self.channels): if handler is None: raise PubSubError("Channel: '%s' has no handler registered") @@ -2304,14 +3135,15 @@ class PubSub(object): if handler is None: raise PubSubError("Pattern: '%s' has no handler registered") - thread = PubSubWorkerThread(self, sleep_time) + thread = PubSubWorkerThread(self, sleep_time, daemon=daemon) thread.start() return thread class PubSubWorkerThread(threading.Thread): - def __init__(self, pubsub, sleep_time): + def __init__(self, pubsub, sleep_time, daemon=False): super(PubSubWorkerThread, self).__init__() + self.daemon = daemon self.pubsub = pubsub self.sleep_time = sleep_time self._running = False @@ -2336,7 +3168,7 @@ class PubSubWorkerThread(threading.Thread): self.pubsub.punsubscribe() -class BasePipeline(object): +class Pipeline(Redis): """ Pipelines provide a way to transmit multiple commands to the Redis server in one transmission. This is convenient for batch processing, such as @@ -2355,7 +3187,7 @@ class BasePipeline(object): on a key of a different datatype. 
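 
     A minimal usage sketch (an illustrative example; 'foo' and 'bar' are
     arbitrary names and a reachable server is assumed):
 
         pipe = Redis().pipeline()
         pipe.set('foo', 'bar')
         pipe.get('foo')
         pipe.execute()    # -> [True, b'bar']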
""" - UNWATCH_COMMANDS = set(('DISCARD', 'EXEC', 'UNWATCH')) + UNWATCH_COMMANDS = {'DISCARD', 'EXEC', 'UNWATCH'} def __init__(self, connection_pool, response_callbacks, transaction, shard_hint): @@ -2473,7 +3305,8 @@ class BasePipeline(object): def _execute_transaction(self, connection, commands, raise_on_error): cmds = chain([(('MULTI', ), {})], commands, [(('EXEC', ), {})]) - all_cmds = connection.pack_commands([args for args, _ in cmds]) + all_cmds = connection.pack_commands([args for args, options in cmds + if EMPTY_RESPONSE not in options]) connection.send_packed_command(all_cmds) errors = [] @@ -2488,12 +3321,15 @@ class BasePipeline(object): # and all the other commands for i, command in enumerate(commands): - try: - self.parse_response(connection, '_') - except ResponseError: - ex = sys.exc_info()[1] - self.annotate_exception(ex, i + 1, command[0]) - errors.append((i, ex)) + if EMPTY_RESPONSE in command[1]: + errors.append((i, command[1][EMPTY_RESPONSE])) + else: + try: + self.parse_response(connection, '_') + except ResponseError: + ex = sys.exc_info()[1] + self.annotate_exception(ex, i + 1, command[0]) + errors.append((i, ex)) # parse the EXEC. try: @@ -2556,13 +3392,13 @@ class BasePipeline(object): raise r def annotate_exception(self, exception, number, command): - cmd = safe_unicode(' ').join(imap(safe_unicode, command)) - msg = unicode('Command # %d (%s) of pipeline caused error: %s') % ( + cmd = ' '.join(imap(safe_unicode, command)) + msg = 'Command # %d (%s) of pipeline caused error: %s' % ( number, cmd, safe_unicode(exception.args[0])) exception.args = (msg,) + exception.args[1:] def parse_response(self, connection, command_name, **options): - result = StrictRedis.parse_response( + result = Redis.parse_response( self, connection, command_name, **options) if command_name in self.UNWATCH_COMMANDS: self.watching = False @@ -2577,12 +3413,11 @@ class BasePipeline(object): shas = [s.sha for s in scripts] # we can't use the normal script_* methods because they would just # get buffered in the pipeline. - exists = immediate('SCRIPT', 'EXISTS', *shas, **{'parse': 'EXISTS'}) + exists = immediate('SCRIPT EXISTS', *shas) if not all(exists): for s, exist in izip(scripts, exists): if not exist: - s.sha = immediate('SCRIPT', 'LOAD', s.script, - **{'parse': 'LOAD'}) + s.sha = immediate('SCRIPT LOAD', s.script) def execute(self, raise_on_error=True): "Execute all the commands in the current pipeline" @@ -2634,26 +3469,6 @@ class BasePipeline(object): "Unwatches all previously specified keys" return self.watching and self.execute_command('UNWATCH') or True - def script_load_for_pipeline(self, script): - "Make sure scripts are loaded prior to pipeline execution" - # we need the sha now so that Script.__call__ can use it to run - # evalsha. - if not script.sha: - script.sha = self.immediate_execute_command('SCRIPT', 'LOAD', - script.script, - **{'parse': 'LOAD'}) - self.scripts.add(script) - - -class StrictPipeline(BasePipeline, StrictRedis): - "Pipeline for the StrictRedis class" - pass - - -class Pipeline(BasePipeline, Redis): - "Pipeline for the Redis class" - pass - class Script(object): "An executable Lua script object returned by ``register_script``" @@ -2661,7 +3476,14 @@ class Script(object): def __init__(self, registered_client, script): self.registered_client = registered_client self.script = script - self.sha = '' + # Precalculate and store the SHA1 hex digest of the script. 
+
+        if isinstance(script, basestring):
+            # We need the encoding from the client in order to generate an
+            # accurate byte representation of the script
+            encoder = registered_client.connection_pool.get_encoder()
+            script = encoder.encode(script)
+        self.sha = hashlib.sha1(script).hexdigest()
 
     def __call__(self, keys=[], args=[], client=None):
         "Execute the script, passing any required ``args``"
         if client is None:
             client = self.registered_client
         args = tuple(keys) + tuple(args)
         # make sure the Redis server knows about the script
-        if isinstance(client, BasePipeline):
-            # make sure this script is good to go on pipeline
-            client.script_load_for_pipeline(self)
+        if isinstance(client, Pipeline):
+            # Make sure the pipeline can register the script before executing.
+            client.scripts.add(self)
         try:
             return client.evalsha(self.sha, len(keys), *args)
         except NoScriptError:
             # Maybe the client is pointed to a different server than the
             # client that created this instance?
+            # Overwrite the sha just in case there was a discrepancy.
             self.sha = client.script_load(self.script)
             return client.evalsha(self.sha, len(keys), *args)
+
+
+class BitFieldOperation(object):
+    """
+    Command builder for BITFIELD commands.
+    """
+    def __init__(self, client, key, default_overflow=None):
+        self.client = client
+        self.key = key
+        self._default_overflow = default_overflow
+        self.reset()
+
+    def reset(self):
+        """
+        Reset the state of the instance to when it was constructed
+        """
+        self.operations = []
+        self._last_overflow = 'WRAP'
+        self.overflow(self._default_overflow or self._last_overflow)
+
+    def overflow(self, overflow):
+        """
+        Update the overflow algorithm of successive INCRBY operations
+        :param overflow: Overflow algorithm, one of WRAP, SAT, FAIL. See the
+            Redis docs for descriptions of these algorithms.
+        :returns: a :py:class:`BitFieldOperation` instance.
+        """
+        overflow = overflow.upper()
+        if overflow != self._last_overflow:
+            self._last_overflow = overflow
+            self.operations.append(('OVERFLOW', overflow))
+        return self
+
+    def incrby(self, fmt, offset, increment, overflow=None):
+        """
+        Increment a bitfield by a given amount.
+        :param fmt: format-string for the bitfield being updated, e.g. 'u8'
+            for an unsigned 8-bit integer.
+        :param offset: offset (in number of bits). If prefixed with a
+            '#', this is an offset multiplier, e.g. given the arguments
+            fmt='u8', offset='#2', the offset will be 16.
+        :param int increment: value to increment the bitfield by.
+        :param str overflow: overflow algorithm. Defaults to WRAP, but other
+            acceptable values are SAT and FAIL. See the Redis docs for
+            descriptions of these algorithms.
+        :returns: a :py:class:`BitFieldOperation` instance.
+        """
+        if overflow is not None:
+            self.overflow(overflow)
+
+        self.operations.append(('INCRBY', fmt, offset, increment))
+        return self
+
+    def get(self, fmt, offset):
+        """
+        Get the value of a given bitfield.
+        :param fmt: format-string for the bitfield being read, e.g. 'u8' for
+            an unsigned 8-bit integer.
+        :param offset: offset (in number of bits). If prefixed with a
+            '#', this is an offset multiplier, e.g. given the arguments
+            fmt='u8', offset='#2', the offset will be 16.
+        :returns: a :py:class:`BitFieldOperation` instance.
+        """
+        self.operations.append(('GET', fmt, offset))
+        return self
+
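A chaining sketch for the builder above; 'stats' is an illustrative key name,
and the replies assume a fresh key (each u8 counter starts at 0) on a server
with BITFIELD support (Redis 3.2 or later):

    >>> bf = BitFieldOperation(r, 'stats')
    >>> bf.incrby('u8', '#0', 1).get('u8', '#0').execute()
    [1, 1]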
+    def set(self, fmt, offset, value):
+        """
+        Set the value of a given bitfield.
+        :param fmt: format-string for the bitfield being updated, e.g. 'u8'
+            for an unsigned 8-bit integer.
+        :param offset: offset (in number of bits). If prefixed with a
+            '#', this is an offset multiplier, e.g. given the arguments
+            fmt='u8', offset='#2', the offset will be 16.
+        :param int value: value to set at the given position.
+        :returns: a :py:class:`BitFieldOperation` instance.
+        """
+        self.operations.append(('SET', fmt, offset, value))
+        return self
+
+    @property
+    def command(self):
+        cmd = ['BITFIELD', self.key]
+        for ops in self.operations:
+            cmd.extend(ops)
+        return cmd
+
+    def execute(self):
+        """
+        Execute the operation(s) in a single BITFIELD command. The return
+        value is a list of values corresponding to each operation. If the
+        client used to create this instance was a pipeline, the results
+        will instead be returned by the pipeline's execute() call.
+        """
+        command = self.command
+        self.reset()
+        return self.client.execute_command(*command)
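Finally, a sketch of the pipeline behavior described in execute() above;
'flags' is an illustrative, previously unset key. The buffered call returns
the pipeline rather than the BITFIELD replies, which only appear in the
pipeline's execute() result:

    >>> pipe = r.pipeline()
    >>> bf = BitFieldOperation(pipe, 'flags', default_overflow='SAT')
    >>> _ = bf.set('u8', 0, 255).incrby('u8', 0, 10).execute()  # buffered
    >>> pipe.execute()
    [[0, 255]]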