path: root/redis/client.py
author     Konstantin Merenkov <kmerenkov@yandex-team.ru>   2010-11-02 16:06:27 +0300
committer  Konstantin Merenkov <kmerenkov@yandex-team.ru>   2010-11-02 16:06:27 +0300
commit     36c73d49a6255b1c3f86a257d6bbc97ba898c77d (patch)
tree       3982987c05cbd952d720797a4b1f28c72f0c83fd /redis/client.py
parent     ab623e03ca84d22313af06744a1f887ed436b9f0 (diff)
parent     6eeee751a90779d65d45165279dc21dae399ac7a (diff)
download   redis-py-36c73d49a6255b1c3f86a257d6bbc97ba898c77d.tar.gz
Merge branch 'master' of http://github.com/andymccurdy/redis-py
Diffstat (limited to 'redis/client.py')
-rw-r--r--   redis/client.py   264
1 file changed, 238 insertions, 26 deletions
diff --git a/redis/client.py b/redis/client.py
index d5e33ab..1274228 100644
--- a/redis/client.py
+++ b/redis/client.py
@@ -4,8 +4,8 @@ import socket
import threading
import time
import warnings
-from itertools import chain
-from redis.exceptions import ConnectionError, ResponseError, InvalidResponse
+from itertools import chain, imap
+from redis.exceptions import ConnectionError, ResponseError, InvalidResponse, WatchError
from redis.exceptions import RedisError, AuthenticationError
@@ -207,8 +207,9 @@ class Redis(threading.local):
bool
),
string_keys_to_dict(
- 'DECRBY HLEN INCRBY LLEN SCARD SDIFFSTORE SINTERSTORE '
- 'SUNIONSTORE ZCARD ZREMRANGEBYSCORE ZREVRANK',
+ 'DECRBY HLEN INCRBY LINSERT LLEN LPUSHX RPUSHX SCARD SDIFFSTORE '
+ 'SINTERSTORE SUNIONSTORE ZCARD ZREMRANGEBYRANK ZREMRANGEBYSCORE '
+ 'ZREVRANK',
int
),
string_keys_to_dict(
@@ -219,11 +220,12 @@ class Redis(threading.local):
string_keys_to_dict('ZSCORE ZINCRBY', float_or_none),
string_keys_to_dict(
'FLUSHALL FLUSHDB LSET LTRIM MSET RENAME '
- 'SAVE SELECT SET SHUTDOWN',
+ 'SAVE SELECT SET SHUTDOWN WATCH UNWATCH',
lambda r: r == 'OK'
),
+ string_keys_to_dict('BLPOP BRPOP', lambda r: r and tuple(r) or None),
string_keys_to_dict('SDIFF SINTER SMEMBERS SUNION',
- lambda r: set(r)
+ lambda r: r and set(r) or set()
),
string_keys_to_dict('ZRANGE ZRANGEBYSCORE ZREVRANGE', zset_score_pairs),
{
@@ -241,7 +243,9 @@ class Redis(threading.local):
)
# commands that should NOT pull data off the network buffer when executed
- SUBSCRIPTION_COMMANDS = set(['SUBSCRIBE', 'UNSUBSCRIBE'])
+ SUBSCRIPTION_COMMANDS = set([
+ 'SUBSCRIBE', 'UNSUBSCRIBE', 'PSUBSCRIBE', 'PUNSUBSCRIBE'
+ ])
def __init__(self, host='localhost', port=6379,
db=0, password=None, socket_timeout=None,
@@ -282,6 +286,19 @@ class Redis(threading.local):
self.errors
)
+ def lock(self, name, timeout=None, sleep=0.1):
+ """
+ Return a new Lock object using key ``name`` that mimics
+ the behavior of threading.Lock.
+
+ If specified, ``timeout`` indicates a maximum life for the lock.
+ By default, it will remain locked until release() is called.
+
+ ``sleep`` indicates the amount of time to sleep per loop iteration
+ when the lock is in blocking mode and another client is currently
+ holding the lock.
+ """
+ return Lock(self, name, timeout=timeout, sleep=sleep)
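A minimal usage sketch for this new helper; the key name, timeout value, and do_work() are illustrative, not part of this change:

    # Acquire a distributed lock around a critical section. The Lock class
    # added at the bottom of this file defines __enter__/__exit__, so it
    # can be used as a context manager.
    r = Redis(host='localhost', port=6379, db=0)
    with r.lock('my-critical-section', timeout=10, sleep=0.1):
        do_work()  # hypothetical application function run while the lock is held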
#### COMMAND EXECUTION AND PROTOCOL PARSING ####
def _execute_command(self, command_name, command, **options):
@@ -303,14 +320,11 @@ class Redis(threading.local):
def execute_command(self, *args, **options):
"Sends the command to the redis server and returns it's response"
- cmd_count = len(args)
- cmds = []
- for i in args:
- enc_value = self.encode(i)
- cmds.append('$%s\r\n%s\r\n' % (len(enc_value), enc_value))
+ cmds = ['$%s\r\n%s\r\n' % (len(enc_value), enc_value)
+ for enc_value in imap(self.encode, args)]
return self._execute_command(
args[0],
- '*%s\r\n%s' % (cmd_count, ''.join(cmds)),
+ '*%s\r\n%s' % (len(cmds), ''.join(cmds)),
**options
)
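For context, a sketch of the wire format the rewritten execute_command builds, following the Redis multi-bulk request protocol; 'SET foo bar' is just an example command:

    # Each argument becomes '$<length>\r\n<value>\r\n', and the whole
    # request is prefixed with '*<argument count>\r\n'.
    args = ('SET', 'foo', 'bar')
    cmds = ['$%s\r\n%s\r\n' % (len(a), a) for a in args]
    wire = '*%s\r\n%s' % (len(cmds), ''.join(cmds))
    assert wire == '*3\r\n$3\r\nSET\r\n$3\r\nfoo\r\n$3\r\nbar\r\n'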
@@ -432,6 +446,17 @@ class Redis(threading.local):
self.connection = self.get_connection(
host, port, db, password, socket_timeout)
+ def shutdown(self):
+ "Shutdown the server"
+ if self.subscribed:
+ raise RedisError("Can't call 'shutdown' from a pipeline'")
+ try:
+ self.execute_command('SHUTDOWN')
+ except ConnectionError:
+ # a ConnectionError here is expected
+ return
+ raise RedisError("SHUTDOWN seems to have failed.")
+
#### SERVER INFORMATION ####
def bgrewriteaof(self):
@@ -563,7 +588,8 @@ class Redis(threading.local):
def mset(self, mapping):
"Sets each key in the ``mapping`` dict to its corresponding value"
items = []
- [items.extend(pair) for pair in mapping.iteritems()]
+ for pair in mapping.iteritems():
+ items.extend(pair)
return self.execute_command('MSET', *items)
def msetnx(self, mapping):
@@ -572,7 +598,8 @@ class Redis(threading.local):
none of the keys are already set
"""
items = []
- [items.extend(pair) for pair in mapping.iteritems()]
+ for pair in mapping.iteritems():
+ items.extend(pair)
return self.execute_command('MSETNX', *items)
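Both MSET variants flatten the mapping into alternating key/value arguments; a small illustration with arbitrary keys and values:

    # {'a': '1', 'b': '2'} flattens to ['a', '1', 'b', '2'], so the
    # command sent is MSET a 1 b 2 (order follows dict iteration order).
    mapping = {'a': '1', 'b': '2'}
    items = []
    for pair in mapping.iteritems():  # Python 2, as in the surrounding code
        items.extend(pair)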
def move(self, name, db):
@@ -657,6 +684,23 @@ class Redis(threading.local):
"Returns the type of key ``name``"
return self.execute_command('TYPE', name)
+ def watch(self, name):
+ """
+ Marks the value at key ``name`` as watched; a subsequent
+ MULTI/EXEC transaction will abort if the key changes
+ """
+ if self.subscribed:
+ raise RedisError("Can't call 'watch' from a pipeline'")
+
+ return self.execute_command('WATCH', name)
+
+ def unwatch(self):
+ """
+ Flushes all keys previously watched by this connection
+ """
+ if self.subscribed:
+ raise RedisError("Can't call 'unwatch' from a pipeline'")
+
+ return self.execute_command('UNWATCH')
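A hedged sketch of how the new WATCH/UNWATCH commands are meant to be used ahead of a MULTI/EXEC transaction, assuming ``r`` is a connected Redis instance; the key name is illustrative (see the pipeline sketch further down for the full retry pattern):

    # Mark a key so a later transaction aborts if it changes; drop the
    # watch if we decide not to proceed. Both return True on Redis's +OK.
    r.watch('inventory:42')
    if int(r.get('inventory:42') or 0) <= 0:
        r.unwatch()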
#### LIST COMMANDS ####
def blpop(self, keys, timeout=0):
@@ -670,7 +714,12 @@ class Redis(threading.local):
If timeout is 0, then block indefinitely.
"""
- keys = list(keys)
+ if timeout is None:
+ timeout = 0
+ if isinstance(keys, basestring):
+ keys = [keys]
+ else:
+ keys = list(keys)
keys.append(timeout)
return self.execute_command('BLPOP', *keys)
@@ -685,7 +734,12 @@ class Redis(threading.local):
If timeout is 0, then block indefinitely.
"""
- keys = list(keys)
+ if timeout is None:
+ timeout = 0
+ if isinstance(keys, basestring):
+ keys = [keys]
+ else:
+ keys = list(keys)
keys.append(timeout)
return self.execute_command('BRPOP', *keys)
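With these changes BLPOP/BRPOP accept either a single key name or an iterable of names, and the new reply callback turns the result into a tuple (or None on timeout); queue names below are illustrative and ``r`` is assumed to be a connected client:

    # Block for up to 5 seconds waiting for an item from either queue.
    item = r.blpop(['queue:high', 'queue:low'], timeout=5)
    if item is not None:
        key, value = item   # (source key, popped value)
    # A bare string key is also accepted now:
    item = r.brpop('queue:low', timeout=1)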
@@ -697,7 +751,17 @@ class Redis(threading.local):
end of the list
"""
return self.execute_command('LINDEX', name, index)
-
+
+ def linsert(self, name, where, refvalue, value):
+ """
+ Insert ``value`` in list ``name`` either immediately before or after
+ [``where``] ``refvalue``
+
+ Returns the new length of the list on success or -1 if ``refvalue``
+ is not in the list.
+ """
+ return self.execute_command('LINSERT', name, where, refvalue, value)
+
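An illustrative call sequence for the new linsert method, assuming ``r`` is a connected client; the list name and values are arbitrary:

    r.rpush('mylist', 'a')
    r.rpush('mylist', 'c')
    r.linsert('mylist', 'BEFORE', 'c', 'b')       # returns 3, the new list length
    r.linsert('mylist', 'AFTER', 'missing', 'x')  # returns -1: refvalue absent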
def llen(self, name):
"Return the length of the list ``name``"
return self.execute_command('LLEN', name)
@@ -709,6 +773,10 @@ class Redis(threading.local):
def lpush(self, name, value):
"Push ``value`` onto the head of the list ``name``"
return self.execute_command('LPUSH', name, value)
+
+ def lpushx(self, name, value):
+ "Push ``value`` onto the head of the list ``name`` if ``name`` exists"
+ return self.execute_command('LPUSHX', name, value)
def lrange(self, name, start, end):
"""
@@ -785,8 +853,12 @@ class Redis(threading.local):
"Push ``value`` onto the tail of the list ``name``"
return self.execute_command('RPUSH', name, value)
+ def rpushx(self, name, value):
+ "Push ``value`` onto the tail of the list ``name`` if ``name`` exists"
+ return self.execute_command('RPUSHX', name, value)
+
def sort(self, name, start=None, num=None, by=None, get=None,
- desc=False, alpha=False, store=None):
+ desc=False, alpha=False, store=None):
"""
Sort and return the list, set or sorted set at ``name``.
@@ -819,8 +891,17 @@ class Redis(threading.local):
pieces.append(start)
pieces.append(num)
if get is not None:
- pieces.append('GET')
- pieces.append(get)
+ # If get is a string, assume we want to get a single value.
+ # Otherwise assume it's an iterable and we want to get multiple
+ # values. We can't just iterate blindly because strings are
+ # iterable.
+ if isinstance(get, basestring):
+ pieces.append('GET')
+ pieces.append(get)
+ else:
+ for g in get:
+ pieces.append('GET')
+ pieces.append(g)
if desc:
pieces.append('DESC')
if alpha:
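A hedged example of the two GET forms the comment above distinguishes, assuming ``r`` is a connected client; the key patterns are illustrative:

    # A single pattern may be passed as a string...
    r.sort('mylist', get='weight_*')
    # ...or several patterns as any non-string iterable; each adds its
    # own 'GET <pattern>' pair to the SORT command.
    r.sort('mylist', get=['weight_*', '#'], by='weight_*')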
@@ -913,6 +994,9 @@ class Redis(threading.local):
"Return the number of elements in the sorted set ``name``"
return self.execute_command('ZCARD', name)
+ def zcount(self, name, min, max):
+ """
+ Returns the number of elements in the sorted set ``name`` with
+ a score between ``min`` and ``max``
+ """
+ return self.execute_command('ZCOUNT', name, min, max)
+
def zincr(self, key, member, value=1):
"This has been deprecated, use zincrby instead"
warnings.warn(DeprecationWarning(
@@ -925,6 +1009,12 @@ class Redis(threading.local):
return self.execute_command('ZINCRBY', name, amount, value)
def zinter(self, dest, keys, aggregate=None):
+ warnings.warn(DeprecationWarning(
+ "Redis.zinter has been deprecated, use Redis.zinterstore instead"
+ ))
+ return self.zinterstore(dest, keys, aggregate)
+
+ def zinterstore(self, dest, keys, aggregate=None):
"""
Intersect multiple sorted sets specified by ``keys`` into
a new sorted set, ``dest``. Scores in the destination will be
@@ -983,10 +1073,19 @@ class Redis(threading.local):
"Remove member ``value`` from sorted set ``name``"
return self.execute_command('ZREM', name, value)
+ def zremrangebyrank(self, name, min, max):
+ """
+ Remove all elements in the sorted set ``name`` with ranks between
+ ``min`` and ``max``. Values are 0-based, ordered from smallest score
+ to largest. Values can be negative, indicating the highest scores.
+ Returns the number of elements removed
+ """
+ return self.execute_command('ZREMRANGEBYRANK', name, min, max)
+
def zremrangebyscore(self, name, min, max):
"""
Remove all elements in the sorted set ``name`` with scores
- between ``min`` and ``max``
+ between ``min`` and ``max``. Returns the number of elements removed.
"""
return self.execute_command('ZREMRANGEBYSCORE', name, min, max)
@@ -1017,6 +1116,12 @@ class Redis(threading.local):
return self.execute_command('ZSCORE', name, value)
def zunion(self, dest, keys, aggregate=None):
+ warnings.warn(DeprecationWarning(
+ "Redis.zunion has been deprecated, use Redis.zunionstore instead"
+ ))
+ return self.zunionstore(dest, keys, aggregate)
+
+ def zunionstore(self, dest, keys, aggregate=None):
"""
Union multiple sorted sets specified by ``keys`` into
a new sorted set, ``dest``. Scores in the destination will be
@@ -1077,13 +1182,21 @@ class Redis(threading.local):
"""
return self.execute_command('HSET', name, key, value)
+ def hsetnx(self, name, key, value):
+ """
+ Set ``key`` to ``value`` within hash ``name`` if ``key`` does not
+ exist. Returns 1 if HSETNX created a field, otherwise 0.
+ """
+ return self.execute_command("HSETNX", name, key, value)
+
def hmset(self, name, mapping):
"""
Sets each key in the ``mapping`` dict to its corresponding value
in the hash ``name``
"""
items = []
- [items.extend(pair) for pair in mapping.iteritems()]
+ for pair in mapping.iteritems():
+ items.extend(pair)
return self.execute_command('HMSET', name, *items)
def hmget(self, name, keys):
@@ -1145,10 +1258,23 @@ class Redis(threading.local):
"Listen for messages on channels this client has been subscribed to"
while self.subscribed:
r = self.parse_response('LISTEN')
- message_type, channel, message = r[0], r[1], r[2]
- yield (message_type, channel, message)
- if message_type == 'unsubscribe' and message == 0:
+ if r[0] == 'pmessage':
+ msg = {
+ 'type': r[0],
+ 'pattern': r[1],
+ 'channel': r[2],
+ 'data': r[3]
+ }
+ else:
+ msg = {
+ 'type': r[0],
+ 'pattern': None,
+ 'channel': r[1],
+ 'data': r[2]
+ }
+ if r[0] == 'unsubscribe' and r[2] == 0:
self.subscribed = False
+ yield msg
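Messages yielded by listen() are now dicts instead of 3-tuples. A hedged consumer sketch, assuming ``r`` is a connected client and that it exposes a psubscribe() method matching the new PSUBSCRIBE handling (channel names are illustrative):

    r.subscribe('news')
    r.psubscribe('events.*')
    for msg in r.listen():
        # 'pattern' is None for plain SUBSCRIBE traffic and carries the
        # matching glob for PSUBSCRIBE ('pmessage') traffic.
        if msg['type'] in ('message', 'pmessage'):
            print msg['channel'], msg['data']   # Python 2 print statement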
class Pipeline(Redis):
@@ -1217,6 +1343,10 @@ class Pipeline(Redis):
_ = self.parse_response('_')
# parse the EXEC. we want errors returned as items in the response
response = self.parse_response('_', catch_errors=True)
+
+ if response is None:
+ raise WatchError("Watched variable changed.")
+
if len(response) != len(commands):
raise ResponseError("Wrong number of response items from "
"pipeline execution")
@@ -1257,3 +1387,85 @@ class Pipeline(Redis):
def select(self, *args, **kwargs):
raise RedisError("Cannot select a different database from a pipeline")
+
+class Lock(object):
+ """
+ A shared, distributed Lock. Using Redis for locking allows the Lock
+ to be shared across processes and/or machines.
+
+ It's left to the user to resolve deadlock issues and make sure
+ multiple clients play nicely together.
+ """
+
+ LOCK_FOREVER = 2**31+1 # 1 past max unix time
+
+ def __init__(self, redis, name, timeout=None, sleep=0.1):
+ """
+ Create a new Lock instance named ``name`` using the Redis client
+ supplied by ``redis``.
+
+ ``timeout`` indicates a maximum life for the lock.
+ By default, it will remain locked until release() is called.
+
+ ``sleep`` indicates the amount of time to sleep per loop iteration
+ when the lock is in blocking mode and another client is currently
+ holding the lock.
+
+ Note: If using ``timeout``, you should make sure all the hosts
+ that are running clients are within the same timezone and are using
+ a network time service like ntp.
+ """
+ self.redis = redis
+ self.name = name
+ self.acquired_until = None
+ self.timeout = timeout
+ self.sleep = sleep
+
+ def __enter__(self):
+ return self.acquire()
+
+ def __exit__(self, exc_type, exc_value, traceback):
+ self.release()
+
+ def acquire(self, blocking=True):
+ """
+ Use Redis to hold a shared, distributed lock named ``name``.
+ Returns True once the lock is acquired.
+
+ If ``blocking`` is False, always return immediately. If the lock
+ was acquired, return True, otherwise return False.
+ """
+ sleep = self.sleep
+ timeout = self.timeout
+ while 1:
+ unixtime = int(time.time())
+ if timeout:
+ timeout_at = unixtime + timeout
+ else:
+ timeout_at = Lock.LOCK_FOREVER
+ if self.redis.setnx(self.name, timeout_at):
+ self.acquired_until = timeout_at
+ return True
+ # We want blocking, but didn't acquire the lock
+ # check to see if the current lock is expired
+ existing = long(self.redis.get(self.name) or 1)
+ if existing < unixtime:
+ # the previous lock is expired, attempt to overwrite it
+ existing = long(self.redis.getset(self.name, timeout_at) or 1)
+ if existing < unixtime:
+ # we successfully acquired the lock
+ self.acquired_until = timeout_at
+ return True
+ if not blocking:
+ return False
+ time.sleep(sleep)
+
+ def release(self):
+ "Releases the already acquired lock"
+ if self.acquired_until is None:
+ raise ValueError("Cannot release an unlocked lock")
+ existing = long(self.redis.get(self.name) or 1)
+ # if the lock time is in the future, delete the lock
+ if existing >= self.acquired_until:
+ self.redis.delete(self.name)
+ self.acquired_until = None
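Finally, a hedged end-to-end sketch of using the Lock class without blocking, assuming ``r`` is a connected client; the lock name, timeout, and build_report() are illustrative:

    lock = r.lock('report-builder', timeout=60)
    if lock.acquire(blocking=False):
        try:
            build_report()   # hypothetical application function
        finally:
            lock.release()
    else:
        print 'another worker holds the lock'   # Python 2 print statement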