diff options
-rw-r--r-- | redis/client.py | 199 | ||||
-rw-r--r-- | redis/connection.py | 30 |
2 files changed, 167 insertions, 62 deletions
diff --git a/redis/client.py b/redis/client.py index 352e98a..5f3920d 100644 --- a/redis/client.py +++ b/redis/client.py @@ -3,6 +3,7 @@ from itertools import chain, starmap import datetime import sys import warnings +import threading import time as mod_time from redis._compat import (b, izip, imap, iteritems, iterkeys, itervalues, basestring, long, nativestr, urlparse, bytes) @@ -10,11 +11,12 @@ from redis.connection import ConnectionPool, UnixDomainSocketConnection from redis.exceptions import ( ConnectionError, DataError, + ExecAbortError, + NoScriptError, + PubSubError, RedisError, ResponseError, WatchError, - NoScriptError, - ExecAbortError, ) SYM_EMPTY = b('') @@ -1536,12 +1538,16 @@ class PubSub(object): until a message arrives on one of the subscribed channels. That message will be returned and it's safe to start listening again. """ - def __init__(self, connection_pool, shard_hint=None): + MESSAGE_TYPES = ('message', 'pmessage') + + def __init__(self, connection_pool, shard_hint=None, + suppress_subscribe_messages=False): self.connection_pool = connection_pool self.shard_hint = shard_hint + self.suppress_subscribe_messages = suppress_subscribe_messages self.connection = None - self.channels = set() - self.patterns = set() + self.channels = {} + self.patterns = {} self.subscription_count = 0 self.subscribe_commands = set( ('subscribe', 'psubscribe', 'unsubscribe', 'punsubscribe') @@ -1580,8 +1586,11 @@ class PubSub(object): self.shard_hint ) connection = self.connection + self._execute(connection, connection.send_command, *args) + + def _execute(self, connection, command, *args): try: - connection.send_command(*args) + return command(*args) except ConnectionError: connection.disconnect() # Connect manually here. If the Redis server is down, this will @@ -1589,15 +1598,16 @@ class PubSub(object): connection.connect() # resubscribe to all channels and patterns before # resending the current command - for channel in self.channels: - self.subscribe(channel) - for pattern in self.patterns: - self.psubscribe(pattern) - connection.send_command(*args) + self.subscribe(**self.channels) + self.psubscribe(**self.patterns) + return command(*args) - def parse_response(self): + def parse_response(self, block=True): "Parse the response from a publish/subscribe command" - response = self.connection.read_response() + connection = self.connection + if not block and not connection.can_read(): + return None + response = self._execute(connection, connection.read_response) if nativestr(response[0]) in self.subscribe_commands: self.subscription_count = response[2] # if we've just unsubscribed from the remaining channels, @@ -1606,70 +1616,139 @@ class PubSub(object): self.reset() return response - def psubscribe(self, patterns): - "Subscribe to all channels matching any pattern in ``patterns``" - if isinstance(patterns, basestring): - patterns = [patterns] - for pattern in patterns: - self.patterns.add(pattern) - return self.execute_command('PSUBSCRIBE', *patterns) + def psubscribe(self, *args, **kwargs): + """ + Subscribe to channel patterns. Patterns supplied as keyword arguments + expect a pattern name as the key and a callable as the value. A + pattern's callable will be invoked automatically when a message is + received on that pattern rather than producing a message via + ``listen()``. + """ + if args: + args = list_or_args(args[0], args[1:]) + self.patterns.update(dict.fromkeys(args)) + self.patterns.update(kwargs) + return self.execute_command('PSUBSCRIBE', *iterkeys(self.patterns)) - def punsubscribe(self, patterns=[]): + def punsubscribe(self, *args): """ - Unsubscribe from any channel matching any pattern in ``patterns``. - If empty, unsubscribe from all channels. + Unsubscribe from the supplied patterns. If empy, unsubscribe from + all patterns. """ - if isinstance(patterns, basestring): - patterns = [patterns] - for pattern in patterns: + if args: + args = list_or_args(args[0], args[1:]) + for pattern in args: try: - self.patterns.remove(pattern) + del self.patterns[pattern] except KeyError: pass - return self.execute_command('PUNSUBSCRIBE', *patterns) + return self.execute_command('PUNSUBSCRIBE', *args) - def subscribe(self, channels): - "Subscribe to ``channels``, waiting for messages to be published" - if isinstance(channels, basestring): - channels = [channels] - for channel in channels: - self.channels.add(channel) - return self.execute_command('SUBSCRIBE', *channels) + def subscribe(self, *args, **kwargs): + """ + Subscribe to channels. Channels supplied as keyword arguments expect + a channel name as the key and a callable as the value. A channel's + callable will be invoked automatically when a message is received on + that channel rather than producing a message via ``listen()``. + """ + if args: + args = list_or_args(args[0], args[1:]) + self.channels.update(dict.fromkeys(args)) + self.channels.update(kwargs) + return self.execute_command('SUBSCRIBE', *iterkeys(self.channels)) - def unsubscribe(self, channels=[]): + def unsubscribe(self, *args): """ - Unsubscribe from ``channels``. If empty, unsubscribe - from all channels + Unsubscribe from the supplied channels. If empty, unsubscribe from + all channels """ - if isinstance(channels, basestring): - channels = [channels] - for channel in channels: + if args: + args = list_or_args(args[0], args[1:]) + for channel in args: try: - self.channels.remove(channel) + del self.channels[channel] except KeyError: pass - return self.execute_command('UNSUBSCRIBE', *channels) + return self.execute_command('UNSUBSCRIBE', *args) def listen(self): "Listen for messages on channels this client has been subscribed to" while self.subscription_count or self.channels or self.patterns: - r = self.parse_response() - msg_type = nativestr(r[0]) - if msg_type == 'pmessage': - msg = { - 'type': msg_type, - 'pattern': nativestr(r[1]), - 'channel': nativestr(r[2]), - 'data': r[3] - } - else: - msg = { - 'type': msg_type, - 'pattern': None, - 'channel': nativestr(r[1]), - 'data': r[2] - } - yield msg + response = self.handle_message(self.parse_response(block=True)) + if response is not None: + yield response + + def get_message(self): + "Get the next message if one is available, otherwise None" + response = self.parse_response(block=False) + return response and self.handle_message(response) or None + + def handle_message(self, response, suppress_subscribe_messages=False): + """ + Parses a pub/sub message. If the channel or pattern was subscribed to + with a message handler, the handler is invoked instead of a parsed + message being returned. + """ + msg_type = nativestr(response[0]) + handler = None + + # optionally suppress subscribe/unsubscribe messages + s = suppress_subscribe_messages or self.suppress_subscribe_messages + if s and msg_type not in self.MESSAGE_TYPES: + return None + + if msg_type == 'pmessage': + msg = { + 'type': msg_type, + 'pattern': nativestr(response[1]), + 'channel': nativestr(response[2]), + 'data': response[3] + } + handler = self.patterns.get(msg['pattern'], None) + else: + msg = { + 'type': msg_type, + 'pattern': None, + 'channel': nativestr(response[1]), + 'data': response[2] + } + handler = self.channels.get(msg['channel'], None) + + if handler: + handler(msg) + return None + else: + return msg + + def run_in_thread(self, sleep_time=0): + for channel, handler in iteritems(self.channels): + if handler is None: + raise PubSubError("Channel: '%s' has no handler registered") + for pattern, handler in iteritems(self.patterns): + if handler is None: + raise PubSubError("Pattern: '%s' has no handler registered") + pubsub = self + + class WorkerThread(threading.Thread): + def __init__(self, *args, **kwargs): + super(WorkerThread, self).__init__(*args, **kwargs) + self._running = False + + def run(self): + if self._running: + return + self._running = True + while self._running: + pubsub.get_message() + mod_time.sleep(sleep_time) + + def stop(self): + self._running = False + self.join() + + thread = WorkerThread() + thread.start() + return thread class BasePipeline(object): diff --git a/redis/connection.py b/redis/connection.py index e138d4c..d949fef 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -1,8 +1,8 @@ -from itertools import chain import os import socket import sys - +from itertools import chain +from select import select from redis._compat import (b, xrange, imap, byte_to_chr, unicode, bytes, long, BytesIO, nativestr, basestring, LifoQueue, Empty, Full) @@ -62,6 +62,9 @@ class PythonParser(object): self._fp.close() self._fp = None + def can_read(self): + return self._fp and self._fp._rbuf.tell() + def read(self, length=None): """ Read a line from the socket if no length is specified, @@ -146,6 +149,7 @@ class HiredisParser(object): def __init__(self): if not hiredis_available: raise RedisError("Hiredis is not installed") + self._next_response = False def __del__(self): try: @@ -167,9 +171,24 @@ class HiredisParser(object): self._sock = None self._reader = None + def can_read(self): + if not self._reader: + raise ConnectionError("Socket closed on remote end") + + if self._next_response is False: + self._next_response = self._reader.gets() + return self._next_response is not False + def read_response(self): if not self._reader: raise ConnectionError("Socket closed on remote end") + + # _next_response might be cached from a can_read() call + if self._next_response is not False: + response = self._next_response + self._next_response = False + return response + response = self._reader.gets() while response is False: try: @@ -298,6 +317,13 @@ class Connection(object): "Pack and send a command to the Redis server" self.send_packed_command(self.pack_command(*args)) + def can_read(self): + "Poll the socket to see if there's data that can be read." + sock = self._sock + if not sock: + return False + return bool(select([sock], [], [], 0)[0]) or self._parser.can_read() + def read_response(self): "Read the response from a previously sent command" try: |