summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorandy <andy@whiskeymedia.com>2013-05-25 10:51:01 -0400
committerandy <andy@whiskeymedia.com>2013-05-25 10:51:01 -0400
commitf69b9095fa10bc02ad2eb4fd251204316b869bc7 (patch)
tree859adad10403bf1bf3822f57a3c6624dda944d7f
parentd36f7532afc8dfc61d52bda09d7b3f8cd0ab2aaf (diff)
downloadredis-py-f69b9095fa10bc02ad2eb4fd251204316b869bc7.tar.gz
refactored pubsub. needs tests
-rw-r--r--redis/client.py199
-rw-r--r--redis/connection.py30
2 files changed, 167 insertions, 62 deletions
diff --git a/redis/client.py b/redis/client.py
index 352e98a..5f3920d 100644
--- a/redis/client.py
+++ b/redis/client.py
@@ -3,6 +3,7 @@ from itertools import chain, starmap
import datetime
import sys
import warnings
+import threading
import time as mod_time
from redis._compat import (b, izip, imap, iteritems, iterkeys, itervalues,
basestring, long, nativestr, urlparse, bytes)
@@ -10,11 +11,12 @@ from redis.connection import ConnectionPool, UnixDomainSocketConnection
from redis.exceptions import (
ConnectionError,
DataError,
+ ExecAbortError,
+ NoScriptError,
+ PubSubError,
RedisError,
ResponseError,
WatchError,
- NoScriptError,
- ExecAbortError,
)
SYM_EMPTY = b('')
@@ -1536,12 +1538,16 @@ class PubSub(object):
until a message arrives on one of the subscribed channels. That message
will be returned and it's safe to start listening again.
"""
- def __init__(self, connection_pool, shard_hint=None):
+ MESSAGE_TYPES = ('message', 'pmessage')
+
+ def __init__(self, connection_pool, shard_hint=None,
+ suppress_subscribe_messages=False):
self.connection_pool = connection_pool
self.shard_hint = shard_hint
+ self.suppress_subscribe_messages = suppress_subscribe_messages
self.connection = None
- self.channels = set()
- self.patterns = set()
+ self.channels = {}
+ self.patterns = {}
self.subscription_count = 0
self.subscribe_commands = set(
('subscribe', 'psubscribe', 'unsubscribe', 'punsubscribe')
@@ -1580,8 +1586,11 @@ class PubSub(object):
self.shard_hint
)
connection = self.connection
+ self._execute(connection, connection.send_command, *args)
+
+ def _execute(self, connection, command, *args):
try:
- connection.send_command(*args)
+ return command(*args)
except ConnectionError:
connection.disconnect()
# Connect manually here. If the Redis server is down, this will
@@ -1589,15 +1598,16 @@ class PubSub(object):
connection.connect()
# resubscribe to all channels and patterns before
# resending the current command
- for channel in self.channels:
- self.subscribe(channel)
- for pattern in self.patterns:
- self.psubscribe(pattern)
- connection.send_command(*args)
+ self.subscribe(**self.channels)
+ self.psubscribe(**self.patterns)
+ return command(*args)
- def parse_response(self):
+ def parse_response(self, block=True):
"Parse the response from a publish/subscribe command"
- response = self.connection.read_response()
+ connection = self.connection
+ if not block and not connection.can_read():
+ return None
+ response = self._execute(connection, connection.read_response)
if nativestr(response[0]) in self.subscribe_commands:
self.subscription_count = response[2]
# if we've just unsubscribed from the remaining channels,
@@ -1606,70 +1616,139 @@ class PubSub(object):
self.reset()
return response
- def psubscribe(self, patterns):
- "Subscribe to all channels matching any pattern in ``patterns``"
- if isinstance(patterns, basestring):
- patterns = [patterns]
- for pattern in patterns:
- self.patterns.add(pattern)
- return self.execute_command('PSUBSCRIBE', *patterns)
+ def psubscribe(self, *args, **kwargs):
+ """
+ Subscribe to channel patterns. Patterns supplied as keyword arguments
+ expect a pattern name as the key and a callable as the value. A
+ pattern's callable will be invoked automatically when a message is
+ received on that pattern rather than producing a message via
+ ``listen()``.
+ """
+ if args:
+ args = list_or_args(args[0], args[1:])
+ self.patterns.update(dict.fromkeys(args))
+ self.patterns.update(kwargs)
+ return self.execute_command('PSUBSCRIBE', *iterkeys(self.patterns))
- def punsubscribe(self, patterns=[]):
+ def punsubscribe(self, *args):
"""
- Unsubscribe from any channel matching any pattern in ``patterns``.
- If empty, unsubscribe from all channels.
+ Unsubscribe from the supplied patterns. If empy, unsubscribe from
+ all patterns.
"""
- if isinstance(patterns, basestring):
- patterns = [patterns]
- for pattern in patterns:
+ if args:
+ args = list_or_args(args[0], args[1:])
+ for pattern in args:
try:
- self.patterns.remove(pattern)
+ del self.patterns[pattern]
except KeyError:
pass
- return self.execute_command('PUNSUBSCRIBE', *patterns)
+ return self.execute_command('PUNSUBSCRIBE', *args)
- def subscribe(self, channels):
- "Subscribe to ``channels``, waiting for messages to be published"
- if isinstance(channels, basestring):
- channels = [channels]
- for channel in channels:
- self.channels.add(channel)
- return self.execute_command('SUBSCRIBE', *channels)
+ def subscribe(self, *args, **kwargs):
+ """
+ Subscribe to channels. Channels supplied as keyword arguments expect
+ a channel name as the key and a callable as the value. A channel's
+ callable will be invoked automatically when a message is received on
+ that channel rather than producing a message via ``listen()``.
+ """
+ if args:
+ args = list_or_args(args[0], args[1:])
+ self.channels.update(dict.fromkeys(args))
+ self.channels.update(kwargs)
+ return self.execute_command('SUBSCRIBE', *iterkeys(self.channels))
- def unsubscribe(self, channels=[]):
+ def unsubscribe(self, *args):
"""
- Unsubscribe from ``channels``. If empty, unsubscribe
- from all channels
+ Unsubscribe from the supplied channels. If empty, unsubscribe from
+ all channels
"""
- if isinstance(channels, basestring):
- channels = [channels]
- for channel in channels:
+ if args:
+ args = list_or_args(args[0], args[1:])
+ for channel in args:
try:
- self.channels.remove(channel)
+ del self.channels[channel]
except KeyError:
pass
- return self.execute_command('UNSUBSCRIBE', *channels)
+ return self.execute_command('UNSUBSCRIBE', *args)
def listen(self):
"Listen for messages on channels this client has been subscribed to"
while self.subscription_count or self.channels or self.patterns:
- r = self.parse_response()
- msg_type = nativestr(r[0])
- if msg_type == 'pmessage':
- msg = {
- 'type': msg_type,
- 'pattern': nativestr(r[1]),
- 'channel': nativestr(r[2]),
- 'data': r[3]
- }
- else:
- msg = {
- 'type': msg_type,
- 'pattern': None,
- 'channel': nativestr(r[1]),
- 'data': r[2]
- }
- yield msg
+ response = self.handle_message(self.parse_response(block=True))
+ if response is not None:
+ yield response
+
+ def get_message(self):
+ "Get the next message if one is available, otherwise None"
+ response = self.parse_response(block=False)
+ return response and self.handle_message(response) or None
+
+ def handle_message(self, response, suppress_subscribe_messages=False):
+ """
+ Parses a pub/sub message. If the channel or pattern was subscribed to
+ with a message handler, the handler is invoked instead of a parsed
+ message being returned.
+ """
+ msg_type = nativestr(response[0])
+ handler = None
+
+ # optionally suppress subscribe/unsubscribe messages
+ s = suppress_subscribe_messages or self.suppress_subscribe_messages
+ if s and msg_type not in self.MESSAGE_TYPES:
+ return None
+
+ if msg_type == 'pmessage':
+ msg = {
+ 'type': msg_type,
+ 'pattern': nativestr(response[1]),
+ 'channel': nativestr(response[2]),
+ 'data': response[3]
+ }
+ handler = self.patterns.get(msg['pattern'], None)
+ else:
+ msg = {
+ 'type': msg_type,
+ 'pattern': None,
+ 'channel': nativestr(response[1]),
+ 'data': response[2]
+ }
+ handler = self.channels.get(msg['channel'], None)
+
+ if handler:
+ handler(msg)
+ return None
+ else:
+ return msg
+
+ def run_in_thread(self, sleep_time=0):
+ for channel, handler in iteritems(self.channels):
+ if handler is None:
+ raise PubSubError("Channel: '%s' has no handler registered")
+ for pattern, handler in iteritems(self.patterns):
+ if handler is None:
+ raise PubSubError("Pattern: '%s' has no handler registered")
+ pubsub = self
+
+ class WorkerThread(threading.Thread):
+ def __init__(self, *args, **kwargs):
+ super(WorkerThread, self).__init__(*args, **kwargs)
+ self._running = False
+
+ def run(self):
+ if self._running:
+ return
+ self._running = True
+ while self._running:
+ pubsub.get_message()
+ mod_time.sleep(sleep_time)
+
+ def stop(self):
+ self._running = False
+ self.join()
+
+ thread = WorkerThread()
+ thread.start()
+ return thread
class BasePipeline(object):
diff --git a/redis/connection.py b/redis/connection.py
index e138d4c..d949fef 100644
--- a/redis/connection.py
+++ b/redis/connection.py
@@ -1,8 +1,8 @@
-from itertools import chain
import os
import socket
import sys
-
+from itertools import chain
+from select import select
from redis._compat import (b, xrange, imap, byte_to_chr, unicode, bytes, long,
BytesIO, nativestr, basestring,
LifoQueue, Empty, Full)
@@ -62,6 +62,9 @@ class PythonParser(object):
self._fp.close()
self._fp = None
+ def can_read(self):
+ return self._fp and self._fp._rbuf.tell()
+
def read(self, length=None):
"""
Read a line from the socket if no length is specified,
@@ -146,6 +149,7 @@ class HiredisParser(object):
def __init__(self):
if not hiredis_available:
raise RedisError("Hiredis is not installed")
+ self._next_response = False
def __del__(self):
try:
@@ -167,9 +171,24 @@ class HiredisParser(object):
self._sock = None
self._reader = None
+ def can_read(self):
+ if not self._reader:
+ raise ConnectionError("Socket closed on remote end")
+
+ if self._next_response is False:
+ self._next_response = self._reader.gets()
+ return self._next_response is not False
+
def read_response(self):
if not self._reader:
raise ConnectionError("Socket closed on remote end")
+
+ # _next_response might be cached from a can_read() call
+ if self._next_response is not False:
+ response = self._next_response
+ self._next_response = False
+ return response
+
response = self._reader.gets()
while response is False:
try:
@@ -298,6 +317,13 @@ class Connection(object):
"Pack and send a command to the Redis server"
self.send_packed_command(self.pack_command(*args))
+ def can_read(self):
+ "Poll the socket to see if there's data that can be read."
+ sock = self._sock
+ if not sock:
+ return False
+ return bool(select([sock], [], [], 0)[0]) or self._parser.can_read()
+
def read_response(self):
"Read the response from a previously sent command"
try: