diff options
-rw-r--r-- | .gitignore | 2 | ||||
-rw-r--r-- | .travis.yml | 3 | ||||
-rw-r--r-- | README.rst | 18 | ||||
-rw-r--r-- | redis/_compat.py | 14 | ||||
-rwxr-xr-x | redis/client.py | 154 | ||||
-rwxr-xr-x | redis/connection.py | 88 | ||||
-rw-r--r-- | setup.py | 7 | ||||
-rw-r--r-- | tests/conftest.py | 5 | ||||
-rw-r--r-- | tests/test_commands.py | 17 | ||||
-rw-r--r-- | tests/test_encoding.py | 2 | ||||
-rw-r--r-- | tests/test_pubsub.py | 31 | ||||
-rw-r--r-- | tests/test_scripting.py | 34 | ||||
-rw-r--r-- | tox.ini | 2 |
13 files changed, 254 insertions, 123 deletions
@@ -7,3 +7,5 @@ dump.rdb _build vagrant/.vagrant .python-version +.cache +.eggs diff --git a/.travis.yml b/.travis.yml index d645f4d..2b6bd3d 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,5 +1,6 @@ language: python python: + - "3.6" - "3.5" - "3.4" - "3.3" @@ -19,5 +20,5 @@ matrix: include: - python: "2.7" env: TEST_PEP8=1 - - python: "3.5" + - python: "3.6" env: TEST_PEP8=1 @@ -174,7 +174,7 @@ set_response_callback method. This method accepts two arguments: a command name and the callback. Callbacks added in this manner are only valid on the instance the callback is added to. If you want to define or override a callback globally, you should make a subclass of the Redis client and add your callback -to its REDIS_CALLBACKS class dictionary. +to its RESPONSE_CALLBACKS class dictionary. Response callbacks take at least one parameter: the response from the Redis server. Keyword arguments may also be accepted in order to further control @@ -519,6 +519,22 @@ cannot be delivered. When you're finished with a PubSub object, call its >>> ... >>> p.close() + +The PUBSUB set of subcommands CHANNELS, NUMSUB and NUMPAT are also +supported: + +.. code-block:: pycon + + >>> r.pubsub_channels() + ['foo', 'bar'] + >>> r.pubsub_numsub('foo', 'bar') + [('foo', 9001), ('bar', 42)] + >>> r.pubsub_numsub('baz') + [('baz', 0)] + >>> r.pubsub_numpat() + 1204 + + LUA Scripting ^^^^^^^^^^^^^ diff --git a/redis/_compat.py b/redis/_compat.py index 4713d01..307f3cc 100644 --- a/redis/_compat.py +++ b/redis/_compat.py @@ -1,6 +1,12 @@ """Internal module for Python 2 backwards compatibility.""" +import errno import sys +try: + InterruptedError = InterruptedError +except: + InterruptedError = OSError + # For Python older than 3.5, retry EINTR. if sys.version_info[0] < 3 or (sys.version_info[0] == 3 and sys.version_info[1] < 5): @@ -15,8 +21,12 @@ if sys.version_info[0] < 3 or (sys.version_info[0] == 3 and while True: try: return _select(rlist, wlist, xlist, timeout) - except InterruptedError: - continue + except InterruptedError as e: + # Python 2 does not define InterruptedError, instead + # try to catch an OSError with errno == EINTR == 4. + if getattr(e, 'errno', None) == getattr(errno, 'EINTR', 4): + continue + raise # Wrapper for handling interruptable system calls. def _retryable_call(s, func, *args, **kwargs): diff --git a/redis/client.py b/redis/client.py index b37e96d..6383e55 100755 --- a/redis/client.py +++ b/redis/client.py @@ -6,6 +6,7 @@ import warnings import time import threading import time as mod_time +import hashlib from redis._compat import (b, basestring, bytes, imap, iteritems, iterkeys, itervalues, izip, long, nativestr, unicode, safe_unicode) @@ -207,7 +208,7 @@ def zset_score_pairs(response, **options): If ``withscores`` is specified in the options, return the response as a list of (value, score) pairs """ - if not response or not options['withscores']: + if not response or not options.get('withscores'): return response score_cast_func = options.get('score_cast_func', float) it = iter(response) @@ -339,6 +340,10 @@ def parse_georadius_generic(response, **options): ] +def parse_pubsub_numsub(response, **options): + return list(zip(response[0::2], response[1::2])) + + class StrictRedis(object): """ Implementation of the Redis protocol. @@ -356,9 +361,9 @@ class StrictRedis(object): bool ), string_keys_to_dict( - 'BITCOUNT BITPOS DECRBY DEL GETBIT HDEL HLEN INCRBY LINSERT LLEN ' - 'LPUSHX PFADD PFCOUNT RPUSHX SADD SCARD SDIFFSTORE SETBIT ' - 'SETRANGE SINTERSTORE SREM STRLEN SUNIONSTORE ZADD ZCARD ' + 'BITCOUNT BITPOS DECRBY DEL GETBIT HDEL HLEN HSTRLEN INCRBY ' + 'LINSERT LLEN LPUSHX PFADD PFCOUNT RPUSHX SADD SCARD SDIFFSTORE ' + 'SETBIT SETRANGE SINTERSTORE SREM STRLEN SUNIONSTORE ZADD ZCARD ' 'ZLEXCOUNT ZREM ZREMRANGEBYLEX ZREMRANGEBYRANK ZREMRANGEBYSCORE ' 'GEOADD', int @@ -443,10 +448,12 @@ class StrictRedis(object): 'CLUSTER SETSLOT': bool_ok, 'CLUSTER SLAVES': parse_cluster_nodes, 'GEOPOS': lambda r: list(map(lambda ll: (float(ll[0]), - float(ll[1])), r)), + float(ll[1])) + if ll is not None else None, r)), 'GEOHASH': lambda r: list(map(nativestr, r)), 'GEORADIUS': parse_georadius_generic, 'GEORADIUSBYMEMBER': parse_georadius_generic, + 'PUBSUB NUMSUB': parse_pubsub_numsub, } ) @@ -1138,19 +1145,19 @@ class StrictRedis(object): ``px`` sets an expire flag on key ``name`` for ``px`` milliseconds. - ``nx`` if set to True, set the value at key ``name`` to ``value`` if it - does not already exist. + ``nx`` if set to True, set the value at key ``name`` to ``value`` only + if it does not exist. - ``xx`` if set to True, set the value at key ``name`` to ``value`` if it - already exists. + ``xx`` if set to True, set the value at key ``name`` to ``value`` only + if it already exists. """ pieces = [name, value] - if ex: + if ex is not None: pieces.append('EX') if isinstance(ex, datetime.timedelta): ex = ex.seconds + ex.days * 24 * 3600 pieces.append(ex) - if px: + if px is not None: pieces.append('PX') if isinstance(px, datetime.timedelta): ms = int(px.microseconds / 1000) @@ -1212,6 +1219,13 @@ class StrictRedis(object): """ return self.execute_command('SUBSTR', name, start, end) + def touch(self, *args): + """ + Alters the last access time of a key(s) ``*args``. A key is ignored + if it does not exist. + """ + return self.execute_command('TOUCH', *args) + def ttl(self, name): "Returns the number of seconds until the key ``name`` will expire" return self.execute_command('TTL', name) @@ -1259,7 +1273,7 @@ class StrictRedis(object): RPOP a value off of the first non-empty list named in the ``keys`` list. - If none of the lists in ``keys`` has a value to LPOP, then block + If none of the lists in ``keys`` has a value to RPOP, then block for ``timeout`` seconds, or until a value gets pushed on to one of the lists. @@ -1639,7 +1653,7 @@ class StrictRedis(object): memebers of set ``name``. Note this is only available when running Redis 2.6+. """ - args = number and [number] or [] + args = (number is not None) and [number] or [] return self.execute_command('SRANDMEMBER', name, *args) def srem(self, name, *values): @@ -2005,6 +2019,13 @@ class StrictRedis(object): "Return the list of values within hash ``name``" return self.execute_command('HVALS', name) + def hstrlen(self, name, key): + """ + Return the number of bytes stored in the value of ``key`` + within hash ``name`` + """ + return self.execute_command('HSTRLEN', name, key) + def publish(self, channel, message): """ Publish ``message`` on ``channel``. @@ -2012,6 +2033,25 @@ class StrictRedis(object): """ return self.execute_command('PUBLISH', channel, message) + def pubsub_channels(self, pattern='*'): + """ + Return a list of channels that have at least one subscriber + """ + return self.execute_command('PUBSUB CHANNELS', pattern) + + def pubsub_numpat(self): + """ + Returns the number of subscriptions to patterns + """ + return self.execute_command('PUBSUB NUMPAT') + + def pubsub_numsub(self, *args): + """ + Return a list of (channel, number of subscribers) tuples + for each channel given in ``*args`` + """ + return self.execute_command('PUBSUB NUMSUB', *args) + def cluster(self, cluster_arg, *args): return self.execute_command('CLUSTER %s' % cluster_arg.upper(), *args) @@ -2073,10 +2113,10 @@ class StrictRedis(object): Add the specified geospatial items to the specified key identified by the ``name`` argument. The Geospatial items are given as ordered members of the ``values`` argument, each item or place is formed by - the triad latitude, longitude and name. + the triad longitude, latitude and name. """ if len(values) % 3 != 0: - raise RedisError("GEOADD requires places with lat, lon and name" + raise RedisError("GEOADD requires places with lon, lat and name" " values") return self.execute_command('GEOADD', name, *values) @@ -2105,7 +2145,7 @@ class StrictRedis(object): """ Return the positions of each item of ``values`` as members of the specified key identified by the ``name``argument. Each position - is represented by the pairs lat and lon. + is represented by the pairs lon and lat. """ return self.execute_command('GEOPOS', name, *values) @@ -2300,13 +2340,7 @@ class PubSub(object): self.connection = None # we need to know the encoding options for this connection in order # to lookup channel and pattern names for callback handlers. - conn = connection_pool.get_connection('pubsub', shard_hint) - try: - self.encoding = conn.encoding - self.encoding_errors = conn.encoding_errors - self.decode_responses = conn.decode_responses - finally: - connection_pool.release(conn) + self.encoder = self.connection_pool.get_encoder() self.reset() def __del__(self): @@ -2338,29 +2372,14 @@ class PubSub(object): if self.channels: channels = {} for k, v in iteritems(self.channels): - if not self.decode_responses: - k = k.decode(self.encoding, self.encoding_errors) - channels[k] = v + channels[self.encoder.decode(k, force=True)] = v self.subscribe(**channels) if self.patterns: patterns = {} for k, v in iteritems(self.patterns): - if not self.decode_responses: - k = k.decode(self.encoding, self.encoding_errors) - patterns[k] = v + patterns[self.encoder.decode(k, force=True)] = v self.psubscribe(**patterns) - def encode(self, value): - """ - Encode the value so that it's identical to what we'll - read off the connection - """ - if self.decode_responses and isinstance(value, bytes): - value = value.decode(self.encoding, self.encoding_errors) - elif not self.decode_responses and isinstance(value, unicode): - value = value.encode(self.encoding, self.encoding_errors) - return value - @property def subscribed(self): "Indicates if there are subscriptions to any channels or patterns" @@ -2410,6 +2429,16 @@ class PubSub(object): return None return self._execute(connection, connection.read_response) + def _normalize_keys(self, data): + """ + normalize channel/pattern names to be either bytes or strings + based on whether responses are automatically decoded. this saves us + from coercing the value for each message coming in. + """ + encode = self.encoder.encode + decode = self.encoder.decode + return dict([(decode(encode(k)), v) for k, v in iteritems(data)]) + def psubscribe(self, *args, **kwargs): """ Subscribe to channel patterns. Patterns supplied as keyword arguments @@ -2420,15 +2449,13 @@ class PubSub(object): """ if args: args = list_or_args(args[0], args[1:]) - new_patterns = {} - new_patterns.update(dict.fromkeys(imap(self.encode, args))) - for pattern, handler in iteritems(kwargs): - new_patterns[self.encode(pattern)] = handler + new_patterns = dict.fromkeys(args) + new_patterns.update(kwargs) ret_val = self.execute_command('PSUBSCRIBE', *iterkeys(new_patterns)) # update the patterns dict AFTER we send the command. we don't want to # subscribe twice to these patterns, once for the command and again # for the reconnection. - self.patterns.update(new_patterns) + self.patterns.update(self._normalize_keys(new_patterns)) return ret_val def punsubscribe(self, *args): @@ -2450,15 +2477,13 @@ class PubSub(object): """ if args: args = list_or_args(args[0], args[1:]) - new_channels = {} - new_channels.update(dict.fromkeys(imap(self.encode, args))) - for channel, handler in iteritems(kwargs): - new_channels[self.encode(channel)] = handler + new_channels = dict.fromkeys(args) + new_channels.update(kwargs) ret_val = self.execute_command('SUBSCRIBE', *iterkeys(new_channels)) # update the channels dict AFTER we send the command. we don't want to # subscribe twice to these channels, once for the command and again # for the reconnection. - self.channels.update(new_channels) + self.channels.update(self._normalize_keys(new_channels)) return ret_val def unsubscribe(self, *args): @@ -2824,12 +2849,11 @@ class BasePipeline(object): shas = [s.sha for s in scripts] # we can't use the normal script_* methods because they would just # get buffered in the pipeline. - exists = immediate('SCRIPT', 'EXISTS', *shas, **{'parse': 'EXISTS'}) + exists = immediate('SCRIPT EXISTS', *shas) if not all(exists): for s, exist in izip(scripts, exists): if not exist: - s.sha = immediate('SCRIPT', 'LOAD', s.script, - **{'parse': 'LOAD'}) + s.sha = immediate('SCRIPT LOAD', s.script) def execute(self, raise_on_error=True): "Execute all the commands in the current pipeline" @@ -2881,16 +2905,6 @@ class BasePipeline(object): "Unwatches all previously specified keys" return self.watching and self.execute_command('UNWATCH') or True - def script_load_for_pipeline(self, script): - "Make sure scripts are loaded prior to pipeline execution" - # we need the sha now so that Script.__call__ can use it to run - # evalsha. - if not script.sha: - script.sha = self.immediate_execute_command('SCRIPT', 'LOAD', - script.script, - **{'parse': 'LOAD'}) - self.scripts.add(script) - class StrictPipeline(BasePipeline, StrictRedis): "Pipeline for the StrictRedis class" @@ -2908,7 +2922,14 @@ class Script(object): def __init__(self, registered_client, script): self.registered_client = registered_client self.script = script - self.sha = '' + # Precalculate and store the SHA1 hex digest of the script. + + if isinstance(script, basestring): + # We need the encoding from the client in order to generate an + # accurate byte representation of the script + encoder = registered_client.connection_pool.get_encoder() + script = encoder.encode(script) + self.sha = hashlib.sha1(script).hexdigest() def __call__(self, keys=[], args=[], client=None): "Execute the script, passing any required ``args``" @@ -2917,12 +2938,13 @@ class Script(object): args = tuple(keys) + tuple(args) # make sure the Redis server knows about the script if isinstance(client, BasePipeline): - # make sure this script is good to go on pipeline - client.script_load_for_pipeline(self) + # Make sure the pipeline can register the script before executing. + client.scripts.add(self) try: return client.evalsha(self.sha, len(keys), *args) except NoScriptError: # Maybe the client is pointed to a differnet server than the client # that created this instance? + # Overwrite the sha just in case there was a discrepancy. self.sha = client.script_load(self.script) return client.evalsha(self.sha, len(keys), *args) diff --git a/redis/connection.py b/redis/connection.py index ed051a3..c8f0513 100755 --- a/redis/connection.py +++ b/redis/connection.py @@ -94,6 +94,38 @@ class Token(object): return self.value +class Encoder(object): + "Encode strings to bytes and decode bytes to strings" + + def __init__(self, encoding, encoding_errors, decode_responses): + self.encoding = encoding + self.encoding_errors = encoding_errors + self.decode_responses = decode_responses + + def encode(self, value): + "Return a bytestring representation of the value" + if isinstance(value, Token): + return value.encoded_value + elif isinstance(value, bytes): + return value + elif isinstance(value, (int, long)): + value = b(str(value)) + elif isinstance(value, float): + value = b(repr(value)) + elif not isinstance(value, basestring): + # an object we don't know how to deal with. default to unicode() + value = unicode(value) + if isinstance(value, unicode): + value = value.encode(self.encoding, self.encoding_errors) + return value + + def decode(self, value, force=False): + "Return a unicode string from the byte representation" + if (self.decode_responses or force) and isinstance(value, bytes): + value = value.decode(self.encoding, self.encoding_errors) + return value + + class BaseParser(object): EXCEPTION_CLASSES = { 'ERR': { @@ -217,10 +249,9 @@ class SocketBuffer(object): class PythonParser(BaseParser): "Plain Python parsing class" - encoding = None - def __init__(self, socket_read_size): self.socket_read_size = socket_read_size + self.encoder = None self._sock = None self._buffer = None @@ -234,8 +265,7 @@ class PythonParser(BaseParser): "Called when the socket connects" self._sock = connection._sock self._buffer = SocketBuffer(self._sock, self.socket_read_size) - if connection.decode_responses: - self.encoding = connection.encoding + self.encoder = connection.encoder def on_disconnect(self): "Called when the socket disconnects" @@ -245,7 +275,7 @@ class PythonParser(BaseParser): if self._buffer is not None: self._buffer.close() self._buffer = None - self.encoding = None + self.encoder = None def can_read(self): return self._buffer and bool(self._buffer.length) @@ -292,8 +322,8 @@ class PythonParser(BaseParser): if length == -1: return None response = [self.read_response() for i in xrange(length)] - if isinstance(response, bytes) and self.encoding: - response = response.decode(self.encoding) + if isinstance(response, bytes): + response = self.encoder.decode(response) return response @@ -324,8 +354,8 @@ class HiredisParser(BaseParser): if not HIREDIS_SUPPORTS_CALLABLE_ERRORS: kwargs['replyError'] = ResponseError - if connection.decode_responses: - kwargs['encoding'] = connection.encoding + if connection.encoder.decode_responses: + kwargs['encoding'] = connection.encoder.encoding self._reader = hiredis.Reader(**kwargs) self._next_response = False @@ -394,6 +424,7 @@ class HiredisParser(BaseParser): raise response[0] return response + if HIREDIS_AVAILABLE: DefaultParser = HiredisParser else: @@ -420,9 +451,7 @@ class Connection(object): self.socket_keepalive = socket_keepalive self.socket_keepalive_options = socket_keepalive_options or {} self.retry_on_timeout = retry_on_timeout - self.encoding = encoding - self.encoding_errors = encoding_errors - self.decode_responses = decode_responses + self.encoder = Encoder(encoding, encoding_errors, decode_responses) self._sock = None self._parser = parser_class(socket_read_size=socket_read_size) self._description_args = { @@ -600,22 +629,6 @@ class Connection(object): raise response return response - def encode(self, value): - "Return a bytestring representation of the value" - if isinstance(value, Token): - return value.encoded_value - elif isinstance(value, bytes): - return value - elif isinstance(value, (int, long)): - value = b(str(value)) - elif isinstance(value, float): - value = b(repr(value)) - elif not isinstance(value, basestring): - value = unicode(value) - if isinstance(value, unicode): - value = value.encode(self.encoding, self.encoding_errors) - return value - def pack_command(self, *args): "Pack a series of arguments into the Redis protocol" output = [] @@ -634,7 +647,7 @@ class Connection(object): buff = SYM_EMPTY.join( (SYM_STAR, b(str(len(args))), SYM_CRLF)) - for arg in imap(self.encode, args): + for arg in imap(self.encoder.encode, args): # to avoid large string mallocs, chunk the command into the # output list if we're sending large values if len(buff) > 6000 or len(arg) > 6000: @@ -723,9 +736,7 @@ class UnixDomainSocketConnection(Connection): self.password = password self.socket_timeout = socket_timeout self.retry_on_timeout = retry_on_timeout - self.encoding = encoding - self.encoding_errors = encoding_errors - self.decode_responses = decode_responses + self.encoder = Encoder(encoding, encoding_errors, decode_responses) self._sock = None self._parser = parser_class(socket_read_size=socket_read_size) self._description_args = { @@ -906,8 +917,8 @@ class ConnectionPool(object): Create a connection pool. If max_connections is set, then this object raises redis.ConnectionError when the pool's limit is reached. - By default, TCP connections are created connection_class is specified. - Use redis.UnixDomainSocketConnection for unix sockets. + By default, TCP connections are created unless connection_class is + specified. Use redis.UnixDomainSocketConnection for unix sockets. Any additional keyword arguments are passed to the constructor of connection_class. @@ -955,6 +966,15 @@ class ConnectionPool(object): self._in_use_connections.add(connection) return connection + def get_encoder(self): + "Return an encoder based on encoding settings" + kwargs = self.connection_kwargs + return Encoder( + encoding=kwargs.get('encoding', 'utf-8'), + encoding_errors=kwargs.get('encoding_errors', 'strict'), + decode_responses=kwargs.get('decode_responses', False) + ) + def make_connection(self): "Create a new connection" if self._created_connections >= self.max_connections: @@ -44,7 +44,10 @@ setup( keywords=['Redis', 'key-value store'], license='MIT', packages=['redis'], - tests_require=['pytest>=2.5.0'], + tests_require=[ + 'mock', + 'pytest>=2.5.0', + ], cmdclass={'test': PyTest}, classifiers=[ 'Development Status :: 5 - Production/Stable', @@ -53,11 +56,13 @@ setup( 'License :: OSI Approved :: MIT License', 'Operating System :: OS Independent', 'Programming Language :: Python', + 'Programming Language :: Python :: 2', 'Programming Language :: Python :: 2.6', 'Programming Language :: Python :: 2.7', 'Programming Language :: Python :: 3', 'Programming Language :: Python :: 3.3', 'Programming Language :: Python :: 3.4', 'Programming Language :: Python :: 3.5', + 'Programming Language :: Python :: 3.6', ] ) diff --git a/tests/conftest.py b/tests/conftest.py index d7b2b14..006697f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -37,6 +37,11 @@ def skip_if_server_version_lt(min_version): return pytest.mark.skipif(check, reason="") +def skip_if_server_version_gte(min_version): + check = StrictVersion(get_version()) >= StrictVersion(min_version) + return pytest.mark.skipif(check, reason="") + + @pytest.fixture() def r(request, **kwargs): return _get_client(redis.Redis, request, **kwargs) diff --git a/tests/test_commands.py b/tests/test_commands.py index 51c6fd4..04aeca4 100644 --- a/tests/test_commands.py +++ b/tests/test_commands.py @@ -10,7 +10,7 @@ from redis._compat import (unichr, u, b, ascii_letters, iteritems, iterkeys, from redis.client import parse_info from redis import exceptions -from .conftest import skip_if_server_version_lt +from .conftest import skip_if_server_version_lt, skip_if_server_version_gte @pytest.fixture() @@ -1231,6 +1231,12 @@ class TestRedisCommands(object): remote_vals = r.hvals('a') assert sorted(local_vals) == sorted(remote_vals) + @skip_if_server_version_lt('3.2.0') + def test_hstrlen(self, r): + r.hmset('a', {'1': '22', '2': '333'}) + assert r.hstrlen('a', '1') == 2 + assert r.hstrlen('a', '2') == 3 + # SORT def test_sort_basic(self, r): r.rpush('a', '3', '2', '1', '4') @@ -1455,6 +1461,15 @@ class TestRedisCommands(object): [(2.19093829393386841, 41.43379028184083523), (2.18737632036209106, 41.40634178640635099)] + @skip_if_server_version_lt('4.0.0') + def test_geopos_no_value(self, r): + assert r.geopos('barcelona', 'place1', 'place2') == [None, None] + + @skip_if_server_version_lt('3.2.0') + @skip_if_server_version_gte('4.0.0') + def test_old_geopos_no_value(self, r): + assert r.geopos('barcelona', 'place1', 'place2') == [] + @skip_if_server_version_lt('3.2.0') def test_georadius(self, r): values = (2.1909389952632, 41.433791470673, 'place1') +\ diff --git a/tests/test_encoding.py b/tests/test_encoding.py index f0c67be..8b5bf5a 100644 --- a/tests/test_encoding.py +++ b/tests/test_encoding.py @@ -34,7 +34,7 @@ class TestEncoding(object): class TestCommandsAndTokensArentEncoded(object): @pytest.fixture() def r(self, request): - return _redis_client(request=request, charset='utf-16') + return _redis_client(request=request, encoding='utf-16') def test_basic_command(self, r): r.set('hello', 'world') diff --git a/tests/test_pubsub.py b/tests/test_pubsub.py index dacac83..a240248 100644 --- a/tests/test_pubsub.py +++ b/tests/test_pubsub.py @@ -4,9 +4,10 @@ import time import redis from redis.exceptions import ConnectionError -from redis._compat import basestring, u, unichr +from redis._compat import basestring, u, unichr, b from .conftest import r as _redis_client +from .conftest import skip_if_server_version_lt def wait_for_message(pubsub, timeout=0.1, ignore_subscribe_messages=False): @@ -398,3 +399,31 @@ class TestPubSubRedisDown(object): p = r.pubsub() with pytest.raises(ConnectionError): p.subscribe('foo') + + +class TestPubSubPubSubSubcommands(object): + + @skip_if_server_version_lt('2.8.0') + def test_pubsub_channels(self, r): + p = r.pubsub(ignore_subscribe_messages=True) + p.subscribe('foo', 'bar', 'baz', 'quux') + channels = sorted(r.pubsub_channels()) + assert channels == [b('bar'), b('baz'), b('foo'), b('quux')] + + @skip_if_server_version_lt('2.8.0') + def test_pubsub_numsub(self, r): + p1 = r.pubsub(ignore_subscribe_messages=True) + p1.subscribe('foo', 'bar', 'baz') + p2 = r.pubsub(ignore_subscribe_messages=True) + p2.subscribe('bar', 'baz') + p3 = r.pubsub(ignore_subscribe_messages=True) + p3.subscribe('baz') + + channels = [(b('foo'), 1), (b('bar'), 2), (b('baz'), 3)] + assert channels == r.pubsub_numsub('foo', 'bar', 'baz') + + @skip_if_server_version_lt('2.8.0') + def test_pubsub_numpat(self, r): + p = r.pubsub(ignore_subscribe_messages=True) + p.psubscribe('*oo', '*ar', 'b*z') + assert r.pubsub_numpat() == 3 diff --git a/tests/test_scripting.py b/tests/test_scripting.py index 2213ec6..9131de8 100644 --- a/tests/test_scripting.py +++ b/tests/test_scripting.py @@ -57,44 +57,49 @@ class TestScripting(object): def test_script_object(self, r): r.set('a', 2) multiply = r.register_script(multiply_script) - assert not multiply.sha - # test evalsha fail -> script load + retry + precalculated_sha = multiply.sha + assert precalculated_sha + assert r.script_exists(multiply.sha) == [False] + # Test second evalsha block (after NoScriptError) assert multiply(keys=['a'], args=[3]) == 6 - assert multiply.sha + # At this point, the script should be loaded assert r.script_exists(multiply.sha) == [True] - # test first evalsha + # Test that the precalculated sha matches the one from redis + assert multiply.sha == precalculated_sha + # Test first evalsha block assert multiply(keys=['a'], args=[3]) == 6 def test_script_object_in_pipeline(self, r): multiply = r.register_script(multiply_script) - assert not multiply.sha + precalculated_sha = multiply.sha + assert precalculated_sha pipe = r.pipeline() pipe.set('a', 2) pipe.get('a') multiply(keys=['a'], args=[3], client=pipe) - # even though the pipeline wasn't executed yet, we made sure the - # script was loaded and got a valid sha - assert multiply.sha - assert r.script_exists(multiply.sha) == [True] + assert r.script_exists(multiply.sha) == [False] # [SET worked, GET 'a', result of multiple script] assert pipe.execute() == [True, b('2'), 6] + # The script should have been loaded by pipe.execute() + assert r.script_exists(multiply.sha) == [True] + # The precalculated sha should have been the correct one + assert multiply.sha == precalculated_sha # purge the script from redis's cache and re-run the pipeline - # the multiply script object knows it's sha, so it shouldn't get - # reloaded until pipe.execute() + # the multiply script should be reloaded by pipe.execute() r.script_flush() pipe = r.pipeline() pipe.set('a', 2) pipe.get('a') - assert multiply.sha multiply(keys=['a'], args=[3], client=pipe) assert r.script_exists(multiply.sha) == [False] # [SET worked, GET 'a', result of multiple script] assert pipe.execute() == [True, b('2'), 6] + assert r.script_exists(multiply.sha) == [True] def test_eval_msgpack_pipeline_error_in_lua(self, r): msgpack_hello = r.register_script(msgpack_hello_script) - assert not msgpack_hello.sha + assert msgpack_hello.sha pipe = r.pipeline() @@ -104,8 +109,9 @@ class TestScripting(object): msgpack_hello(args=[msgpack_message_1], client=pipe) - assert r.script_exists(msgpack_hello.sha) == [True] + assert r.script_exists(msgpack_hello.sha) == [False] assert pipe.execute()[0] == b'hello Joe' + assert r.script_exists(msgpack_hello.sha) == [True] msgpack_hello_broken = r.register_script(msgpack_hello_script_broken) @@ -1,6 +1,6 @@ [tox] minversion = 1.8 -envlist = {py26,py27,py32,py33,py34,py35}-{plain,hiredis}, pep8 +envlist = {py26,py27,py32,py33,py34,py35,py36}-{plain,hiredis}, pep8 [testenv] deps = |