diff options
Diffstat (limited to 'redis/client.py')
-rw-r--r-- | redis/client.py | 264 |
1 files changed, 238 insertions, 26 deletions
diff --git a/redis/client.py b/redis/client.py index d5e33ab..1274228 100644 --- a/redis/client.py +++ b/redis/client.py @@ -4,8 +4,8 @@ import socket import threading import time import warnings -from itertools import chain -from redis.exceptions import ConnectionError, ResponseError, InvalidResponse +from itertools import chain, imap +from redis.exceptions import ConnectionError, ResponseError, InvalidResponse, WatchError from redis.exceptions import RedisError, AuthenticationError @@ -207,8 +207,9 @@ class Redis(threading.local): bool ), string_keys_to_dict( - 'DECRBY HLEN INCRBY LLEN SCARD SDIFFSTORE SINTERSTORE ' - 'SUNIONSTORE ZCARD ZREMRANGEBYSCORE ZREVRANK', + 'DECRBY HLEN INCRBY LINSERT LLEN LPUSHX RPUSHX SCARD SDIFFSTORE ' + 'SINTERSTORE SUNIONSTORE ZCARD ZREMRANGEBYRANK ZREMRANGEBYSCORE ' + 'ZREVRANK', int ), string_keys_to_dict( @@ -219,11 +220,12 @@ class Redis(threading.local): string_keys_to_dict('ZSCORE ZINCRBY', float_or_none), string_keys_to_dict( 'FLUSHALL FLUSHDB LSET LTRIM MSET RENAME ' - 'SAVE SELECT SET SHUTDOWN', + 'SAVE SELECT SET SHUTDOWN WATCH UNWATCH', lambda r: r == 'OK' ), + string_keys_to_dict('BLPOP BRPOP', lambda r: r and tuple(r) or None), string_keys_to_dict('SDIFF SINTER SMEMBERS SUNION', - lambda r: set(r) + lambda r: r and set(r) or set() ), string_keys_to_dict('ZRANGE ZRANGEBYSCORE ZREVRANGE', zset_score_pairs), { @@ -241,7 +243,9 @@ class Redis(threading.local): ) # commands that should NOT pull data off the network buffer when executed - SUBSCRIPTION_COMMANDS = set(['SUBSCRIBE', 'UNSUBSCRIBE']) + SUBSCRIPTION_COMMANDS = set([ + 'SUBSCRIBE', 'UNSUBSCRIBE', 'PSUBSCRIBE', 'PUNSUBSCRIBE' + ]) def __init__(self, host='localhost', port=6379, db=0, password=None, socket_timeout=None, @@ -282,6 +286,19 @@ class Redis(threading.local): self.errors ) + def lock(self, name, timeout=None, sleep=0.1): + """ + Return a new Lock object using key ``name`` that mimics + the behavior of threading.Lock. + + If specified, ``timeout`` indicates a maximum life for the lock. + By default, it will remain locked until release() is called. + + ``sleep`` indicates the amount of time to sleep per loop iteration + when the lock is in blocking mode and another client is currently + holding the lock. + """ + return Lock(self, name, timeout=timeout, sleep=sleep) #### COMMAND EXECUTION AND PROTOCOL PARSING #### def _execute_command(self, command_name, command, **options): @@ -303,14 +320,11 @@ class Redis(threading.local): def execute_command(self, *args, **options): "Sends the command to the redis server and returns it's response" - cmd_count = len(args) - cmds = [] - for i in args: - enc_value = self.encode(i) - cmds.append('$%s\r\n%s\r\n' % (len(enc_value), enc_value)) + cmds = ['$%s\r\n%s\r\n' % (len(enc_value), enc_value) + for enc_value in imap(self.encode, args)] return self._execute_command( args[0], - '*%s\r\n%s' % (cmd_count, ''.join(cmds)), + '*%s\r\n%s' % (len(cmds), ''.join(cmds)), **options ) @@ -432,6 +446,17 @@ class Redis(threading.local): self.connection = self.get_connection( host, port, db, password, socket_timeout) + def shutdown(self): + "Shutdown the server" + if self.subscribed: + raise RedisError("Can't call 'shutdown' from a pipeline'") + try: + self.execute_command('SHUTDOWN') + except ConnectionError: + # a ConnectionError here is expected + return + raise RedisError("SHUTDOWN seems to have failed.") + #### SERVER INFORMATION #### def bgrewriteaof(self): @@ -563,7 +588,8 @@ class Redis(threading.local): def mset(self, mapping): "Sets each key in the ``mapping`` dict to its corresponding value" items = [] - [items.extend(pair) for pair in mapping.iteritems()] + for pair in mapping.iteritems(): + items.extend(pair) return self.execute_command('MSET', *items) def msetnx(self, mapping): @@ -572,7 +598,8 @@ class Redis(threading.local): none of the keys are already set """ items = [] - [items.extend(pair) for pair in mapping.iteritems()] + for pair in mapping.iteritems(): + items.extend(pair) return self.execute_command('MSETNX', *items) def move(self, name, db): @@ -657,6 +684,23 @@ class Redis(threading.local): "Returns the type of key ``name``" return self.execute_command('TYPE', name) + def watch(self, name): + """ + Watches the value at key ``name``, or None of the key doesn't exist + """ + if self.subscribed: + raise RedisError("Can't call 'watch' from a pipeline'") + + return self.execute_command('WATCH', name) + + def unwatch(self): + """ + Unwatches the value at key ``name``, or None of the key doesn't exist + """ + if self.subscribed: + raise RedisError("Can't call 'unwatch' from a pipeline'") + + return self.execute_command('UNWATCH') #### LIST COMMANDS #### def blpop(self, keys, timeout=0): @@ -670,7 +714,12 @@ class Redis(threading.local): If timeout is 0, then block indefinitely. """ - keys = list(keys) + if timeout is None: + timeout = 0 + if isinstance(keys, basestring): + keys = [keys] + else: + keys = list(keys) keys.append(timeout) return self.execute_command('BLPOP', *keys) @@ -685,7 +734,12 @@ class Redis(threading.local): If timeout is 0, then block indefinitely. """ - keys = list(keys) + if timeout is None: + timeout = 0 + if isinstance(keys, basestring): + keys = [keys] + else: + keys = list(keys) keys.append(timeout) return self.execute_command('BRPOP', *keys) @@ -697,7 +751,17 @@ class Redis(threading.local): end of the list """ return self.execute_command('LINDEX', name, index) - + + def linsert(self, name, where, refvalue, value): + """ + Insert ``value`` in list ``name`` either immediately before or after + [``where``] ``refvalue`` + + Returns the new length of the list on success or -1 if ``refvalue`` + is not in the list. + """ + return self.execute_command('LINSERT', name, where, refvalue, value) + def llen(self, name): "Return the length of the list ``name``" return self.execute_command('LLEN', name) @@ -709,6 +773,10 @@ class Redis(threading.local): def lpush(self, name, value): "Push ``value`` onto the head of the list ``name``" return self.execute_command('LPUSH', name, value) + + def lpushx(self, name, value): + "Push ``value`` onto the head of the list ``name`` if ``name`` exists" + return self.execute_command('LPUSHX', name, value) def lrange(self, name, start, end): """ @@ -785,8 +853,12 @@ class Redis(threading.local): "Push ``value`` onto the tail of the list ``name``" return self.execute_command('RPUSH', name, value) + def rpushx(self, name, value): + "Push ``value`` onto the tail of the list ``name`` if ``name`` exists" + return self.execute_command('RPUSHX', name, value) + def sort(self, name, start=None, num=None, by=None, get=None, - desc=False, alpha=False, store=None): + desc=False, alpha=False, store=None): """ Sort and return the list, set or sorted set at ``name``. @@ -819,8 +891,17 @@ class Redis(threading.local): pieces.append(start) pieces.append(num) if get is not None: - pieces.append('GET') - pieces.append(get) + # If get is a string assume we want to get a single value. + # Otherwise assume it's an interable and we want to get multiple + # values. We can't just iterate blindly because strings are + # iterable. + if isinstance(get, basestring): + pieces.append('GET') + pieces.append(get) + else: + for g in get: + pieces.append('GET') + pieces.append(g) if desc: pieces.append('DESC') if alpha: @@ -913,6 +994,9 @@ class Redis(threading.local): "Return the number of elements in the sorted set ``name``" return self.execute_command('ZCARD', name) + def zcount(self, name, min, max): + return self.execute_command('ZCOUNT', name, min, max) + def zincr(self, key, member, value=1): "This has been deprecated, use zincrby instead" warnings.warn(DeprecationWarning( @@ -925,6 +1009,12 @@ class Redis(threading.local): return self.execute_command('ZINCRBY', name, amount, value) def zinter(self, dest, keys, aggregate=None): + warnings.warn(DeprecationWarning( + "Redis.zinter has been deprecated, use Redis.zinterstore instead" + )) + return self.zinterstore(dest, keys, aggregate) + + def zinterstore(self, dest, keys, aggregate=None): """ Intersect multiple sorted sets specified by ``keys`` into a new sorted set, ``dest``. Scores in the destination will be @@ -983,10 +1073,19 @@ class Redis(threading.local): "Remove member ``value`` from sorted set ``name``" return self.execute_command('ZREM', name, value) + def zremrangebyrank(self, name, min, max): + """ + Remove all elements in the sorted set ``name`` with ranks between + ``min`` and ``max``. Values are 0-based, ordered from smallest score + to largest. Values can be negative indicating the highest scores. + Returns the number of elements removed + """ + return self.execute_command('ZREMRANGEBYRANK', name, min, max) + def zremrangebyscore(self, name, min, max): """ Remove all elements in the sorted set ``name`` with scores - between ``min`` and ``max`` + between ``min`` and ``max``. Returns the number of elements removed. """ return self.execute_command('ZREMRANGEBYSCORE', name, min, max) @@ -1017,6 +1116,12 @@ class Redis(threading.local): return self.execute_command('ZSCORE', name, value) def zunion(self, dest, keys, aggregate=None): + warnings.warn(DeprecationWarning( + "Redis.zunion has been deprecated, use Redis.zunionstore instead" + )) + return self.zunionstore(dest, keys, aggregate) + + def zunionstore(self, dest, keys, aggregate=None): """ Union multiple sorted sets specified by ``keys`` into a new sorted set, ``dest``. Scores in the destination will be @@ -1077,13 +1182,21 @@ class Redis(threading.local): """ return self.execute_command('HSET', name, key, value) + def hsetnx(self, name, key, value): + """ + Set ``key`` to ``value`` within hash ``name`` if ``key`` does not + exist. Returns 1 if HSETNX created a field, otherwise 0. + """ + return self.execute_command("HSETNX", name, key, value) + def hmset(self, name, mapping): """ Sets each key in the ``mapping`` dict to its corresponding value in the hash ``name`` """ items = [] - [items.extend(pair) for pair in mapping.iteritems()] + for pair in mapping.iteritems(): + items.extend(pair) return self.execute_command('HMSET', name, *items) def hmget(self, name, keys): @@ -1145,10 +1258,23 @@ class Redis(threading.local): "Listen for messages on channels this client has been subscribed to" while self.subscribed: r = self.parse_response('LISTEN') - message_type, channel, message = r[0], r[1], r[2] - yield (message_type, channel, message) - if message_type == 'unsubscribe' and message == 0: + if r[0] == 'pmessage': + msg = { + 'type': r[0], + 'pattern': r[1], + 'channel': r[2], + 'data': r[3] + } + else: + msg = { + 'type': r[0], + 'pattern': None, + 'channel': r[1], + 'data': r[2] + } + if r[0] == 'unsubscribe' and r[2] == 0: self.subscribed = False + yield msg class Pipeline(Redis): @@ -1217,6 +1343,10 @@ class Pipeline(Redis): _ = self.parse_response('_') # parse the EXEC. we want errors returned as items in the response response = self.parse_response('_', catch_errors=True) + + if response is None: + raise WatchError("Watched variable changed.") + if len(response) != len(commands): raise ResponseError("Wrong number of response items from " "pipeline execution") @@ -1257,3 +1387,85 @@ class Pipeline(Redis): def select(self, *args, **kwargs): raise RedisError("Cannot select a different database from a pipeline") + +class Lock(object): + """ + A shared, distributed Lock. Using Redis for locking allows the Lock + to be shared across processes and/or machines. + + It's left to the user to resolve deadlock issues and make sure + multiple clients play nicely together. + """ + + LOCK_FOREVER = 2**31+1 # 1 past max unix time + + def __init__(self, redis, name, timeout=None, sleep=0.1): + """ + Create a new Lock instnace named ``name`` using the Redis client + supplied by ``redis``. + + ``timeout`` indicates a maximum life for the lock. + By default, it will remain locked until release() is called. + + ``sleep`` indicates the amount of time to sleep per loop iteration + when the lock is in blocking mode and another client is currently + holding the lock. + + Note: If using ``timeout``, you should make sure all the hosts + that are running clients are within the same timezone and are using + a network time service like ntp. + """ + self.redis = redis + self.name = name + self.acquired_until = None + self.timeout = timeout + self.sleep = sleep + + def __enter__(self): + return self.acquire() + + def __exit__(self, exc_type, exc_value, traceback): + self.release() + + def acquire(self, blocking=True): + """ + Use Redis to hold a shared, distributed lock named ``name``. + Returns True once the lock is acquired. + + If ``blocking`` is False, always return immediately. If the lock + was acquired, return True, otherwise return False. + """ + sleep = self.sleep + timeout = self.timeout + while 1: + unixtime = int(time.time()) + if timeout: + timeout_at = unixtime + timeout + else: + timeout_at = Lock.LOCK_FOREVER + if self.redis.setnx(self.name, timeout_at): + self.acquired_until = timeout_at + return True + # We want blocking, but didn't acquire the lock + # check to see if the current lock is expired + existing = long(self.redis.get(self.name) or 1) + if existing < unixtime: + # the previous lock is expired, attempt to overwrite it + existing = long(self.redis.getset(self.name, timeout_at) or 1) + if existing < unixtime: + # we successfully acquired the lock + self.acquired_until = timeout_at + return True + if not blocking: + return False + time.sleep(sleep) + + def release(self): + "Releases the already acquired lock" + if self.acquired_until is None: + raise ValueError("Cannot release an unlocked lock") + existing = long(self.redis.get(self.name) or 1) + # if the lock time is in the future, delete the lock + if existing >= self.acquired_until: + self.redis.delete(self.name) + self.acquired_until = None |