diff options
-rw-r--r-- | .github/ISSUE_TEMPLATE.md | 8 | ||||
-rw-r--r-- | .github/PULL_REQUEST_TEMPLATE.md | 15 | ||||
-rw-r--r-- | .travis.yml | 17 | ||||
-rw-r--r-- | CHANGES | 55 | ||||
-rw-r--r-- | README.rst | 17 | ||||
-rw-r--r-- | redis/__init__.py | 2 | ||||
-rw-r--r-- | redis/_compat.py | 2 | ||||
-rwxr-xr-x | redis/client.py | 141 | ||||
-rwxr-xr-x | redis/connection.py | 35 | ||||
-rw-r--r-- | redis/exceptions.py | 5 | ||||
-rw-r--r-- | redis/lock.py | 70 | ||||
-rw-r--r-- | setup.py | 32 | ||||
-rw-r--r-- | tests/conftest.py | 15 | ||||
-rw-r--r-- | tests/test_commands.py | 178 | ||||
-rw-r--r-- | tests/test_connection_pool.py | 4 | ||||
-rw-r--r-- | tests/test_lock.py | 76 | ||||
-rw-r--r-- | tests/test_pubsub.py | 42 | ||||
-rw-r--r-- | tox.ini | 7 |
18 files changed, 577 insertions, 144 deletions
diff --git a/.github/ISSUE_TEMPLATE.md b/.github/ISSUE_TEMPLATE.md new file mode 100644 index 0000000..7323c14 --- /dev/null +++ b/.github/ISSUE_TEMPLATE.md @@ -0,0 +1,8 @@ +Thanks for wanting to report an issue you've found in redis-py. Please delete this text and fill in the template below. +It is of course not always possible to reduce your code to a small test case, but it's highly appreciated to have as much data as possible. Thank you! + +**Version**: What redis-py and what redis version is the issue happening on? + +**Platform**: What platform / version? (For example Python 3.5.1 on Windows 7 / Ubuntu 15.10 / Azure) + +**Description**: Description of your issue, stack traces from errors and code that reproduces the issue diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md new file mode 100644 index 0000000..798be44 --- /dev/null +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -0,0 +1,15 @@ +### Pull Request check-list + +_Please make sure to review and check all of these items:_ + +- [ ] Does `$ python setup test` pass with this change (including linting)? +- [ ] Does travis tests pass with this change (enable it first in your forked repo and wait for the travis build to finish)? +- [ ] Is the new or changed code fully tested? +- [ ] Is a documentation update included (if this change modifies existing APIs, or introduces new ones)? + +_NOTE: these things are not required to open a PR and can be done +afterwards / while the PR is open._ + +### Description of change + +_Please provide a description of the change here._ diff --git a/.travis.yml b/.travis.yml index 10070e5..1002dbc 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,12 +1,15 @@ +dist: xenial +sudo: false language: python cache: pip python: + - 3.7 - 3.6 - 3.5 - 3.4 - 2.7 before_install: - - wget http://download.redis.io/releases/redis-5.0.0.tar.gz && mkdir redis_install && tar -xvzf redis-5.0.0.tar.gz -C redis_install && cd redis_install/redis-5.0.0 && make && src/redis-server --daemonize yes && cd ../.. + - wget http://download.redis.io/releases/redis-5.0.3.tar.gz && mkdir redis_install && tar -xvzf redis-5.0.3.tar.gz -C redis_install && cd redis_install/redis-5.0.3 && make && src/redis-server --daemonize yes && cd ../.. - redis-cli info env: - TEST_HIREDIS=0 @@ -20,15 +23,5 @@ matrix: include: - python: 2.7 env: TEST_PYCODESTYLE=1 - - python: 3.6 - env: TEST_PYCODESTYLE=1 - # python 3.7 has to be specified manually in the matrix - # https://github.com/travis-ci/travis-ci/issues/9815 - - python: 3.7 - dist: xenial - sudo: true - env: TEST_HIREDIS=0 - python: 3.7 - dist: xenial - sudo: true - env: TEST_HIREDIS=1 + env: TEST_PYCODESTYLE=1 @@ -1,3 +1,52 @@ +* 3.1.0 (in development) + * Connection URLs must have one of the following schemes: + redis://, rediss://, unix://. Thanks @jdupl123. #961/#969 + * Fixed an issue with retry_on_timeout logic that caused some TimeoutErrors + to be retried. Thanks Aaron Yang. #1022/#1023 + * Added support for SNI for SSL. Thanks @oridistor and Roey Prat. #1087 + * Fixed ConnectionPool repr for pools with no connections. Thanks + Cody Scott. #1043/#995 + * Fixed GEOHASH to return a None value when specifying a place that + doesn't exist on the server. Thanks @guybe7. #1126 + * Fixed XREADGROUP to return an empty dictionary for messages that + have been deleted but still exist in the unacknowledged queue. Thanks + @xeizmendi. #1116 + * Added an owned method to Lock objects. owned returns a boolean + indicating whether the current lock instance still owns the lock. + Thanks Dave Johansen. #1112 + * Allow lock.acquire() to accept an optional token argument. If + provided, the token argument is used as the unique value used to claim + the lock. Thankd Dave Johansen. #1112 + * Added a reacquire method to Lock objects. reaquire attempts to renew + the lock such that the timeout is extended to the same value that the + lock was initially acquired with. Thanks Ihor Kalnytskyi. #1014 + * Stream names found within XREAD and XREADGROUP responses now properly + respect the decode_responses flag. + * XPENDING_RANGE now requires the user the specify the min, max and + count arguments. Newer versions of Redis prevent ount from being + infinite so it's left to the user to specify these values explicitly. + * ZADD now returns None when xx=True and incr=True and an element + is specified that doesn't exist in the sorted set. This matches + what the server returns in this case. #1084 + * Added client_kill_filter that accepts various filters to identify + and kill clients. Thanks Theofanis Despoudis. #1098 + * Fixed a race condition that occurred when unsubscribing and + resubscribing to the same channel or pattern in rapid succession. + Thanks Marcin RaczyĆski. #764 + * Added a LockNotOwnedError that is raised when trying to extend or + release a lock that is no longer owned. This is a subclass of LockError + so previous code should continue to work as expected. Thanks Joshua + Harlow. #1095 + * Fixed a bug in GEORADIUS that forced decoding of places without + respecting the decode_responses option. Thanks Bo Bayles. #1082 +* 3.0.1 + * Fixed regression with UnixDomainSocketConnection caused by 3.0.0. + Thanks Jyrki Muukkonen + * Fixed an issue with the new asynchronous flag on flushdb and flushall. + Thanks rogeryen + * Updated Lock.locked() method to indicate whether *any* process has + acquired the lock, not just the current one. This is in line with + the behavior of threading.Lock. Thanks Alan Justino da Silva * 3.0.0 BACKWARDS INCOMPATIBLE CHANGES * When using a Lock as a context manager and the lock fails to be acquired @@ -27,6 +76,9 @@ * ZADD now requires all element names/scores be specified in a single dictionary argument named mapping. This was required to allow the NX, XX, CH and INCR options to be specified. + * ssl_cert_reqs now has a default value of 'required' by default. This + should make connecting to a remote Redis server over SSL more secure. + Thanks u2mejc * Removed support for EOL Python 2.6 and 3.3. Thanks jdufresne OTHER CHANGES * Added missing DECRBY command. Thanks derek-dchu @@ -68,9 +120,6 @@ * Added a ping() method to pubsub objects. Thanks krishan-carbon * Fixed a bug with keys in the INFO dict that contained ':' symbols. Thanks mzalimeni - * ssl_cert_reqs now has a default value of 'required' by default. This - should make connecting to a remote Redis server over SSL more secure. - Thanks u2mejc * Fixed the select system call retry compatibility with Python 2.x. Thanks lddubeau * max_connections is now a valid querystring argument for creating @@ -99,6 +99,23 @@ use any of the following commands: not exist) +SSL Connections +^^^^^^^^^^^^^^^ + +redis-py 3.0 changes the default value of the `ssl_cert_reqs` option from +`None` to `'required'`. See +`Issue 1016 <https://github.com/andymccurdy/redis-py/issues/1016>`_. This +change enforces hostname validation when accepting a cert from a remote SSL +terminator. If the terminator doesn't properly set the hostname on the cert +this will cause redis-py 3.0 to raise a ConnectionError. + +This check can be disabled by setting `ssl_cert_reqs` to `None`. Note that +doing so removes the security check. Do so at your own risk. + +It has been reported that SSL certs received from AWS ElastiCache do not have +proper hostnames and turning off hostname verification is currently required. + + MSET, MSETNX and ZADD ^^^^^^^^^^^^^^^^^^^^^ diff --git a/redis/__init__.py b/redis/__init__.py index 99b095a..dc78f73 100644 --- a/redis/__init__.py +++ b/redis/__init__.py @@ -29,7 +29,7 @@ def int_or_str(value): return value -__version__ = '3.0.0.post1' +__version__ = '3.0.1' VERSION = tuple(map(int_or_str, __version__.split('.'))) __all__ = [ diff --git a/redis/_compat.py b/redis/_compat.py index 80973b3..c9213a6 100644 --- a/redis/_compat.py +++ b/redis/_compat.py @@ -112,7 +112,6 @@ if sys.version_info[0] < 3: xrange = xrange basestring = basestring unicode = unicode - bytes = str long = long else: from urllib.parse import parse_qs, unquote, urlparse @@ -142,7 +141,6 @@ else: basestring = str unicode = str safe_unicode = str - bytes = bytes long = int try: # Python 3 diff --git a/redis/client.py b/redis/client.py index d62e20e..24120d2 100755 --- a/redis/client.py +++ b/redis/client.py @@ -7,7 +7,7 @@ import time import threading import time as mod_time import hashlib -from redis._compat import (basestring, bytes, imap, iteritems, iterkeys, +from redis._compat import (basestring, imap, iteritems, iterkeys, itervalues, izip, long, nativestr, safe_unicode) from redis.connection import (ConnectionPool, UnixDomainSocketConnection, SSLConnection, Token) @@ -116,8 +116,12 @@ def parse_info(response): for line in response.splitlines(): if line and not line.startswith('#'): if line.find(':') != -1: - # support keys that include ':' by using rsplit - key, value = line.rsplit(':', 1) + # Split, the info fields keys and values. + # Note that the value may contain ':'. but the 'host:' + # pseudo-command is the only case where the key contains ':' + key, value = line.split(':', 1) + if key == 'cmdstat_host': + key, value = line.rsplit(':', 1) info[key] = get_value(value) else: # if the line isn't splittable, append it to the "__raw__" key @@ -187,6 +191,8 @@ def parse_sentinel_get_master(response): def pairs_to_dict(response, decode_keys=False): "Create a dict given a list of key/value pairs" + if response is None: + return {} if decode_keys: # the iter form is faster, but I don't know how to make that work # with a nativestr() map @@ -240,6 +246,12 @@ def int_or_none(response): return int(response) +def nativestr_or_none(response): + if response is None: + return None + return nativestr(response) + + def parse_stream_list(response): if response is None: return None @@ -272,7 +284,7 @@ def parse_xinfo_stream(response): def parse_xread(response): if response is None: return [] - return [[nativestr(r[0]), parse_stream_list(r[1])] for r in response] + return [[r[0], parse_stream_list(r[1])] for r in response] def parse_xpending(response, **options): @@ -303,6 +315,8 @@ def bool_ok(response): def parse_zadd(response, **options): + if response is None: + return None if options.get('as_score'): return float(response) return int(response) @@ -392,7 +406,7 @@ def parse_georadius_generic(response, **options): if not options['withdist'] and not options['withcoord']\ and not options['withhash']: # just a bunch of places - return [nativestr(r) for r in response_list] + return response_list cast = { 'withdist': float, @@ -402,7 +416,7 @@ def parse_georadius_generic(response, **options): # zip all output results with each casting functino to get # the properly native Python value. - f = [nativestr] + f = [lambda x: x] f += [cast[o] for o in ['withdist', 'withhash', 'withcoord'] if options[o]] return [ list(map(lambda fv: fv[0](fv[1]), zip(f, r))) for r in response_list @@ -413,6 +427,12 @@ def parse_pubsub_numsub(response, **options): return list(zip(response[0::2], response[1::2])) +def parse_client_kill(response, **options): + if isinstance(response, (long, int)): + return int(response) + return nativestr(response) == 'OK' + + class Redis(object): """ Implementation of the Redis protocol. @@ -471,7 +491,7 @@ class Redis(object): { 'CLIENT GETNAME': lambda r: r and nativestr(r), 'CLIENT ID': int, - 'CLIENT KILL': bool_ok, + 'CLIENT KILL': parse_client_kill, 'CLIENT LIST': parse_client_list, 'CLIENT SETNAME': bool_ok, 'CLIENT UNBLOCK': lambda r: r and int(r) == 1 or False, @@ -496,7 +516,7 @@ class Redis(object): 'CONFIG RESETSTAT': bool_ok, 'CONFIG SET': bool_ok, 'DEBUG OBJECT': parse_debug_object, - 'GEOHASH': lambda r: list(map(nativestr, r)), + 'GEOHASH': lambda r: list(map(nativestr_or_none, r)), 'GEOPOS': lambda r: list(map(lambda ll: (float(ll[0]), float(ll[1])) if ll is not None else None, r)), @@ -755,7 +775,8 @@ class Redis(object): return self.parse_response(connection, command_name, **options) except (ConnectionError, TimeoutError) as e: connection.disconnect() - if not connection.retry_on_timeout and isinstance(e, TimeoutError): + if not (connection.retry_on_timeout and + isinstance(e, TimeoutError)): raise connection.send_command(*args) return self.parse_response(connection, command_name, **options) @@ -790,6 +811,42 @@ class Redis(object): "Disconnects the client at ``address`` (ip:port)" return self.execute_command('CLIENT KILL', address) + def client_kill_filter(self, _id=None, _type=None, addr=None, skipme=None): + """ + Disconnects client(s) using a variety of filter options + :param id: Kills a client by its unique ID field + :param type: Kills a client by type where type is one of 'normal', + 'master', 'slave' or 'pubsub' + :param addr: Kills a client by its 'address:port' + :param skipme: If True, then the client calling the command + will not get killed even if it is identified by one of the filter + options. If skipme is not provided, the server defaults to skipme=True + """ + args = [] + if _type is not None: + client_types = ('normal', 'master', 'slave', 'pubsub') + if str(_type).lower() not in client_types: + raise DataError("CLIENT KILL type must be one of %r" % ( + client_types,)) + args.extend((Token.get_token('TYPE'), _type)) + if skipme is not None: + if not isinstance(skipme, bool): + raise DataError("CLIENT KILL skipme must be a bool") + if skipme: + args.extend((Token.get_token('SKIPME'), + Token.get_token('YES'))) + else: + args.extend((Token.get_token('SKIPME'), + Token.get_token('NO'))) + if _id is not None: + args.extend((Token.get_token('ID'), _id)) + if addr is not None: + args.extend((Token.get_token('ADDR'), addr)) + if not args: + raise DataError("CLIENT KILL <filter> <value> ... ... <filter> " + "<value> must specify at least one filter") + return self.execute_command('CLIENT KILL', *args) + def client_list(self, _type=None): """ Returns a list of currently connected clients. @@ -2075,18 +2132,15 @@ class Redis(object): """ return self.execute_command('XPENDING', name, groupname) - def xpending_range(self, name, groupname, min='-', max='+', count=-1, + def xpending_range(self, name, groupname, min, max, count, consumername=None): """ Returns information about pending messages, in a range. name: name of the stream. groupname: name of the consumer group. - start: first stream ID. defaults to '-', - meaning the earliest available. - finish: last stream ID. defaults to '+', - meaning the latest available. - count: if set, only return this many items, beginning with the - earliest available. + min: minimum stream ID. + max: maximum stream ID. + count: number of messages to return consumername: name of a consumer to filter by (optional). """ pieces = [name, groupname] @@ -2154,7 +2208,7 @@ class Redis(object): return self.execute_command('XREAD', *pieces) def xreadgroup(self, groupname, consumername, streams, count=None, - block=None): + block=None, noack=False): """ Read from a stream via a consumer group. groupname: name of the consumer group. @@ -2164,6 +2218,7 @@ class Redis(object): count: if set, only return this many items, beginning with the earliest available. block: number of milliseconds to wait, if nothing already present. + noack: do not add messages to the PEL """ pieces = [Token.get_token('GROUP'), groupname, consumername] if count is not None: @@ -2177,6 +2232,8 @@ class Redis(object): "integer") pieces.append(Token.get_token("BLOCK")) pieces.append(str(block)) + if noack: + pieces.append(Token.get_token("NOACK")) if not isinstance(streams, dict) or len(streams) == 0: raise DataError('XREADGROUP streams must be a non empty dict') pieces.append(Token.get_token('STREAMS')) @@ -2903,7 +2960,9 @@ class PubSub(object): self.connection_pool.release(self.connection) self.connection = None self.channels = {} + self.pending_unsubscribe_channels = set() self.patterns = {} + self.pending_unsubscribe_patterns = set() def close(self): self.reset() @@ -2913,6 +2972,8 @@ class PubSub(object): # NOTE: for python3, we can't pass bytestrings as keyword arguments # so we need to decode channel/pattern names back to unicode strings # before passing them to [p]subscribe. + self.pending_unsubscribe_channels.clear() + self.pending_unsubscribe_patterns.clear() if self.channels: channels = {} for k, v in iteritems(self.channels): @@ -2952,7 +3013,8 @@ class PubSub(object): return command(*args) except (ConnectionError, TimeoutError) as e: connection.disconnect() - if not connection.retry_on_timeout and isinstance(e, TimeoutError): + if not (connection.retry_on_timeout and + isinstance(e, TimeoutError)): raise # Connect manually here. If the Redis server is down, this will # fail and raise a ConnectionError as desired. @@ -2999,17 +3061,25 @@ class PubSub(object): # update the patterns dict AFTER we send the command. we don't want to # subscribe twice to these patterns, once for the command and again # for the reconnection. - self.patterns.update(self._normalize_keys(new_patterns)) + new_patterns = self._normalize_keys(new_patterns) + self.patterns.update(new_patterns) + self.pending_unsubscribe_patterns.difference_update(new_patterns) return ret_val def punsubscribe(self, *args): """ - Unsubscribe from the supplied patterns. If empy, unsubscribe from + Unsubscribe from the supplied patterns. If empty, unsubscribe from all patterns. """ if args: args = list_or_args(args[0], args[1:]) - return self.execute_command('PUNSUBSCRIBE', *args) + retval = self.execute_command('PUNSUBSCRIBE', *args) + if args: + patterns = self._normalize_keys(dict.fromkeys(args)) + else: + patterns = self.patterns + self.pending_unsubscribe_patterns.update(patterns) + return retval def subscribe(self, *args, **kwargs): """ @@ -3027,7 +3097,9 @@ class PubSub(object): # update the channels dict AFTER we send the command. we don't want to # subscribe twice to these channels, once for the command and again # for the reconnection. - self.channels.update(self._normalize_keys(new_channels)) + new_channels = self._normalize_keys(new_channels) + self.channels.update(new_channels) + self.pending_unsubscribe_channels.difference_update(new_channels) return ret_val def unsubscribe(self, *args): @@ -3037,7 +3109,13 @@ class PubSub(object): """ if args: args = list_or_args(args[0], args[1:]) - return self.execute_command('UNSUBSCRIBE', *args) + retval = self.execute_command('UNSUBSCRIBE', *args) + if args: + channels = self._normalize_keys(dict.fromkeys(args)) + else: + channels = self.channels + self.pending_unsubscribe_channels.update(channels) + return retval def listen(self): "Listen for messages on channels this client has been subscribed to" @@ -3094,22 +3172,21 @@ class PubSub(object): 'channel': response[1], 'data': response[2] } - # if this is an unsubscribe message, remove it from memory if message_type in self.UNSUBSCRIBE_MESSAGE_TYPES: - subscribed_dict = None if message_type == 'punsubscribe': - subscribed_dict = self.patterns + pattern = response[1] + if pattern in self.pending_unsubscribe_patterns: + self.pending_unsubscribe_patterns.remove(pattern) + self.patterns.pop(pattern, None) else: - subscribed_dict = self.channels - try: - del subscribed_dict[message['channel']] - except KeyError: - pass + channel = response[1] + if channel in self.pending_unsubscribe_channels: + self.pending_unsubscribe_channels.remove(channel) + self.channels.pop(channel, None) if message_type in self.PUBLISH_MESSAGE_TYPES: # if there's a message handler, invoke it - handler = None if message_type == 'pmessage': handler = self.patterns.get(message['pattern'], None) else: diff --git a/redis/connection.py b/redis/connection.py index 9b949c5..0d1c394 100755 --- a/redis/connection.py +++ b/redis/connection.py @@ -14,7 +14,7 @@ try: except ImportError: ssl_available = False -from redis._compat import (xrange, imap, byte_to_chr, unicode, bytes, long, +from redis._compat import (xrange, imap, byte_to_chr, unicode, long, nativestr, basestring, iteritems, LifoQueue, Empty, Full, urlparse, parse_qs, recv, recv_into, select, unquote) @@ -729,11 +729,24 @@ class SSLConnection(Connection): def _connect(self): "Wrap the socket with SSL support" sock = super(SSLConnection, self)._connect() - sock = ssl.wrap_socket(sock, - cert_reqs=self.cert_reqs, - keyfile=self.keyfile, - certfile=self.certfile, - ca_certs=self.ca_certs) + if hasattr(ssl, "create_default_context"): + context = ssl.create_default_context() + context.check_hostname = False + context.verify_mode = self.cert_reqs + if self.certfile and self.keyfile: + context.load_cert_chain(certfile=self.certfile, + keyfile=self.keyfile) + if self.ca_certs: + context.load_verify_locations(self.ca_certs) + sock = context.wrap_socket(sock, server_hostname=self.host) + else: + # In case this code runs in a version which is older than 2.7.9, + # we want to fall back to old code + sock = ssl.wrap_socket(sock, + cert_reqs=self.cert_reqs, + keyfile=self.keyfile, + certfile=self.certfile, + ca_certs=self.ca_certs) return sock @@ -872,7 +885,7 @@ class ConnectionPool(object): path = url.path hostname = url.hostname - # We only support redis:// and unix:// schemes. + # We only support redis://, rediss:// and unix:// schemes. if url.scheme == 'unix': url_options.update({ 'password': password, @@ -880,7 +893,7 @@ class ConnectionPool(object): 'connection_class': UnixDomainSocketConnection, }) - else: + elif url.scheme in ('redis', 'rediss'): url_options.update({ 'host': hostname, 'port': int(url.port or 6379), @@ -897,6 +910,10 @@ class ConnectionPool(object): if url.scheme == 'rediss': url_options['connection_class'] = SSLConnection + else: + valid_schemes = ', '.join(('redis://', 'rediss://', 'unix://')) + raise ValueError('Redis URL must specify one of the following' + 'schemes (%s)' % valid_schemes) # last shot at the db value url_options['db'] = int(url_options.get('db', db or 0)) @@ -941,7 +958,7 @@ class ConnectionPool(object): def __repr__(self): return "%s<%s>" % ( type(self).__name__, - self.connection_class.description_format % self.connection_kwargs, + repr(self.connection_class(**self.connection_kwargs)), ) def reset(self): diff --git a/redis/exceptions.py b/redis/exceptions.py index 44ab6f7..35bfbe5 100644 --- a/redis/exceptions.py +++ b/redis/exceptions.py @@ -58,3 +58,8 @@ class LockError(RedisError, ValueError): # NOTE: For backwards compatability, this class derives from ValueError. # This was originally chosen to behave like threading.Lock. pass + + +class LockNotOwnedError(LockError): + "Error trying to extend or release a lock that is (no longer) owned" + pass diff --git a/redis/lock.py b/redis/lock.py index c6fd01d..8d481d7 100644 --- a/redis/lock.py +++ b/redis/lock.py @@ -1,7 +1,7 @@ import threading import time as mod_time import uuid -from redis.exceptions import LockError +from redis.exceptions import LockError, LockNotOwnedError from redis.utils import dummy @@ -16,6 +16,7 @@ class Lock(object): lua_release = None lua_extend = None + lua_reacquire = None # KEYS[1] - lock name # ARGS[1] - token @@ -49,6 +50,19 @@ class Lock(object): return 1 """ + # KEYS[1] - lock name + # ARGS[1] - token + # ARGS[2] - milliseconds + # return 1 if the locks time was reacquired, otherwise 0 + LUA_REACQUIRE_SCRIPT = """ + local token = redis.call('get', KEYS[1]) + if not token or token ~= ARGV[1] then + return 0 + end + redis.call('pexpire', KEYS[1], ARGV[2]) + return 1 + """ + def __init__(self, redis, name, timeout=None, sleep=0.1, blocking=True, blocking_timeout=None, thread_local=True): """ @@ -121,6 +135,9 @@ class Lock(object): cls.lua_release = client.register_script(cls.LUA_RELEASE_SCRIPT) if cls.lua_extend is None: cls.lua_extend = client.register_script(cls.LUA_EXTEND_SCRIPT) + if cls.lua_reacquire is None: + cls.lua_reacquire = \ + client.register_script(cls.LUA_REACQUIRE_SCRIPT) def __enter__(self): # force blocking, as otherwise the user would have to check whether @@ -132,7 +149,7 @@ class Lock(object): def __exit__(self, exc_type, exc_value, traceback): self.release() - def acquire(self, blocking=None, blocking_timeout=None): + def acquire(self, blocking=None, blocking_timeout=None, token=None): """ Use Redis to hold a shared, distributed lock named ``name``. Returns True once the lock is acquired. @@ -142,9 +159,18 @@ class Lock(object): ``blocking_timeout`` specifies the maximum number of seconds to wait trying to acquire the lock. + + ``token`` specifies the token value to be used. If provided, token + must be a bytes object or a string that can be encoded to a bytes + object with the default encoding. If a token isn't specified, a UUID + will be generated. """ sleep = self.sleep - token = uuid.uuid1().hex.encode() + if token is None: + token = uuid.uuid1().hex.encode() + else: + encoder = self.redis.connection_pool.get_encoder() + token = encoder.encode(token) if blocking is None: blocking = self.blocking if blocking_timeout is None: @@ -178,6 +204,19 @@ class Lock(object): """ return self.redis.get(self.name) is not None + def owned(self): + """ + Returns True if this key is locked by this lock, otherwise False. + """ + stored_token = self.redis.get(self.name) + # need to always compare bytes to bytes + # TODO: this can be simplified when the context manager is finished + if stored_token and not isinstance(stored_token, bytes): + encoder = self.redis.connection_pool.get_encoder() + stored_token = encoder.encode(stored_token) + return self.local.token is not None and \ + stored_token == self.local.token + def release(self): "Releases the already acquired lock" expected_token = self.local.token @@ -190,7 +229,8 @@ class Lock(object): if not bool(self.lua_release(keys=[self.name], args=[expected_token], client=self.redis)): - raise LockError("Cannot release a lock that's no longer owned") + raise LockNotOwnedError("Cannot release a lock" + " that's no longer owned") def extend(self, additional_time): """ @@ -210,5 +250,25 @@ class Lock(object): if not bool(self.lua_extend(keys=[self.name], args=[self.local.token, additional_time], client=self.redis)): - raise LockError("Cannot extend a lock that's no longer owned") + raise LockNotOwnedError("Cannot extend a lock that's" + " no longer owned") + return True + + def reacquire(self): + """ + Resets a TTL of an already acquired lock back to a timeout value. + """ + if self.local.token is None: + raise LockError("Cannot reacquire an unlocked lock") + if self.timeout is None: + raise LockError("Cannot reacquire a lock with no timeout") + return self.do_reacquire() + + def do_reacquire(self): + timeout = int(self.timeout * 1000) + if not bool(self.lua_reacquire(keys=[self.name], + args=[self.local.token, timeout], + client=self.redis)): + raise LockNotOwnedError("Cannot reacquire a lock that's" + " no longer owned") return True @@ -1,31 +1,24 @@ #!/usr/bin/env python import os import sys +from setuptools import setup +from setuptools.command.test import test as TestCommand from redis import __version__ -try: - from setuptools import setup - from setuptools.command.test import test as TestCommand - class PyTest(TestCommand): - def finalize_options(self): - TestCommand.finalize_options(self) - self.test_args = [] - self.test_suite = True +class PyTest(TestCommand): + def finalize_options(self): + TestCommand.finalize_options(self) + self.test_args = [] + self.test_suite = True - def run_tests(self): - # import here, because outside the eggs aren't loaded - import pytest - errno = pytest.main(self.test_args) - sys.exit(errno) + def run_tests(self): + # import here, because outside the eggs aren't loaded + import pytest + errno = pytest.main(self.test_args) + sys.exit(errno) -except ImportError: - - from distutils.core import setup - - def PyTest(x): - x f = open(os.path.join(os.path.dirname(__file__), 'README.rst')) long_description = f.read() @@ -69,5 +62,6 @@ setup( 'Programming Language :: Python :: 3.4', 'Programming Language :: Python :: 3.5', 'Programming Language :: Python :: 3.6', + 'Programming Language :: Python :: 3.7', ] ) diff --git a/tests/conftest.py b/tests/conftest.py index 7b1335c..bb4c9a9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -26,7 +26,12 @@ def _get_client(cls, request=None, **kwargs): client.flushdb() if request: def teardown(): - client.flushdb() + try: + client.flushdb() + except redis.ConnectionError: + # handle cases where a test disconnected a client + # just manually retry the flushdb + client.flushdb() client.connection_pool.disconnect() request.addfinalizer(teardown) return client @@ -54,6 +59,14 @@ def r(request, **kwargs): return _get_client(redis.Redis, request, **kwargs) +@pytest.fixture() +def r2(request, **kwargs): + return [ + _get_client(redis.Redis, request, **kwargs), + _get_client(redis.Redis, request, **kwargs), + ] + + def _gen_cluster_mock_resp(r, response): mock_connection_pool = Mock() connection = Mock() diff --git a/tests/test_commands.py b/tests/test_commands.py index 4ba96b3..33f78d5 100644 --- a/tests/test_commands.py +++ b/tests/test_commands.py @@ -97,6 +97,59 @@ class TestRedisCommands(object): assert r.client_getname() == 'redis_py_test' @skip_if_server_version_lt('2.6.9') + def test_client_kill(self, r, r2): + r.client_setname('redis-py-c1') + r2[0].client_setname('redis-py-c2') + r2[1].client_setname('redis-py-c3') + test_clients = [client for client in r.client_list() + if client.get('name') + in ['redis-py-c1', 'redis-py-c2', 'redis-py-c3']] + assert len(test_clients) == 3 + + resp = r.client_kill(test_clients[1].get('addr')) + assert isinstance(resp, bool) and resp is True + + test_clients = [client for client in r.client_list() + if client.get('name') + in ['redis-py-c1', 'redis-py-c2', 'redis-py-c3']] + assert len(test_clients) == 2 + + @skip_if_server_version_lt('2.8.12') + def test_client_kill_filter_invalid_params(self, r): + # empty + with pytest.raises(exceptions.DataError): + r.client_kill_filter() + + # invalid skipme + with pytest.raises(exceptions.DataError): + r.client_kill_filter(skipme="yeah") + + # invalid type + with pytest.raises(exceptions.DataError): + r.client_kill_filter(_type="caster") + + @skip_if_server_version_lt('2.8.12') + def test_client_kill_filter(self, r, r2): + r.client_setname('redis-py-c1') + r2[0].client_setname('redis-py-c2') + r2[1].client_setname('redis-py-c3') + test_clients = [client for client in r.client_list() + if client.get('name') + in ['redis-py-c1', 'redis-py-c2', 'redis-py-c3']] + assert len(test_clients) == 3 + + resp = r.client_kill_filter(_id=test_clients[1].get('id')) + assert isinstance(resp, int) and resp == 1 + + resp = r.client_kill_filter(addr=test_clients[2].get('addr')) + assert isinstance(resp, int) and resp == 1 + + test_clients = [client for client in r.client_list() + if client.get('name') + in ['redis-py-c1', 'redis-py-c2', 'redis-py-c3']] + assert len(test_clients) == 1 + + @skip_if_server_version_lt('2.6.9') def test_client_list_after_client_setname(self, r): r.client_setname('redis_py_test') clients = r.client_list() @@ -1021,6 +1074,12 @@ class TestRedisCommands(object): assert r.zadd('a', {'a1': 1}) == 1 assert r.zadd('a', {'a1': 4.5}, incr=True) == 5.5 + def test_zadd_incr_with_xx(self, r): + # this asks zadd to incr 'a1' only if it exists, but it clearly + # doesn't. Redis returns a null value in this case and so should + # redis-py + assert r.zadd('a', {'a1': 1}, xx=True, incr=True) is None + def test_zcard(self, r): r.zadd('a', {'a1': 1, 'a2': 2, 'a3': 3}) assert r.zcard('a') == 3 @@ -1628,8 +1687,8 @@ class TestRedisCommands(object): (2.1873744593677, 41.406342043777, 'place2') r.geoadd('barcelona', *values) - assert r.geohash('barcelona', 'place1', 'place2') ==\ - ['sp3e9yg3kd0', 'sp3e9cbc3t0'] + assert r.geohash('barcelona', 'place1', 'place2', 'place3') ==\ + ['sp3e9yg3kd0', 'sp3e9cbc3t0', None] @skip_unless_arch_bits(64) @skip_if_server_version_lt('3.2.0') @@ -1655,10 +1714,11 @@ class TestRedisCommands(object): @skip_if_server_version_lt('3.2.0') def test_georadius(self, r): values = (2.1909389952632, 41.433791470673, 'place1') +\ - (2.1873744593677, 41.406342043777, 'place2') + (2.1873744593677, 41.406342043777, b'\x80place2') r.geoadd('barcelona', *values) - assert r.georadius('barcelona', 2.191, 41.433, 1000) == ['place1'] + assert r.georadius('barcelona', 2.191, 41.433, 1000) == [b'place1'] + assert r.georadius('barcelona', 2.187, 41.406, 1000) == [b'\x80place2'] @skip_if_server_version_lt('3.2.0') def test_georadius_no_values(self, r): @@ -1675,7 +1735,7 @@ class TestRedisCommands(object): r.geoadd('barcelona', *values) assert r.georadius('barcelona', 2.191, 41.433, 1, unit='km') ==\ - ['place1'] + [b'place1'] @skip_unless_arch_bits(64) @skip_if_server_version_lt('3.2.0') @@ -1689,17 +1749,17 @@ class TestRedisCommands(object): # function. assert r.georadius('barcelona', 2.191, 41.433, 1, unit='km', withdist=True, withcoord=True, withhash=True) ==\ - [['place1', 0.0881, 3471609698139488, + [[b'place1', 0.0881, 3471609698139488, (2.19093829393386841, 41.43379028184083523)]] assert r.georadius('barcelona', 2.191, 41.433, 1, unit='km', withdist=True, withcoord=True) ==\ - [['place1', 0.0881, + [[b'place1', 0.0881, (2.19093829393386841, 41.43379028184083523)]] assert r.georadius('barcelona', 2.191, 41.433, 1, unit='km', withhash=True, withcoord=True) ==\ - [['place1', 3471609698139488, + [[b'place1', 3471609698139488, (2.19093829393386841, 41.43379028184083523)]] # test no values. @@ -1713,7 +1773,7 @@ class TestRedisCommands(object): r.geoadd('barcelona', *values) assert r.georadius('barcelona', 2.191, 41.433, 3000, count=1) ==\ - ['place1'] + [b'place1'] @skip_if_server_version_lt('3.2.0') def test_georadius_sort(self, r): @@ -1722,9 +1782,9 @@ class TestRedisCommands(object): r.geoadd('barcelona', *values) assert r.georadius('barcelona', 2.191, 41.433, 3000, sort='ASC') ==\ - ['place1', 'place2'] + [b'place1', b'place2'] assert r.georadius('barcelona', 2.191, 41.433, 3000, sort='DESC') ==\ - ['place2', 'place1'] + [b'place2', b'place1'] @skip_if_server_version_lt('3.2.0') def test_georadius_store(self, r): @@ -1751,19 +1811,19 @@ class TestRedisCommands(object): @skip_if_server_version_lt('3.2.0') def test_georadiusmember(self, r): values = (2.1909389952632, 41.433791470673, 'place1') +\ - (2.1873744593677, 41.406342043777, 'place2') + (2.1873744593677, 41.406342043777, b'\x80place2') r.geoadd('barcelona', *values) assert r.georadiusbymember('barcelona', 'place1', 4000) ==\ - ['place2', 'place1'] - assert r.georadiusbymember('barcelona', 'place1', 10) == ['place1'] + [b'\x80place2', b'place1'] + assert r.georadiusbymember('barcelona', 'place1', 10) == [b'place1'] assert r.georadiusbymember('barcelona', 'place1', 4000, withdist=True, withcoord=True, withhash=True) ==\ - [['place2', 3067.4157, 3471609625421029, + [[b'\x80place2', 3067.4157, 3471609625421029, (2.187376320362091, 41.40634178640635)], - ['place1', 0.0, 3471609698139488, + [b'place1', 0.0, 3471609698139488, (2.1909382939338684, 41.433790281840835)]] @skip_if_server_version_lt('5.0.0') @@ -1782,7 +1842,7 @@ class TestRedisCommands(object): assert r.xack(stream, group, m1) == 0 r.xgroup_create(stream, group, 0) - r.xreadgroup(group, consumer, streams={stream: 0}) + r.xreadgroup(group, consumer, streams={stream: '>'}) # xack returns the number of ack'd elements assert r.xack(stream, group, m1) == 1 assert r.xack(stream, group, m2, m3) == 2 @@ -1819,7 +1879,7 @@ class TestRedisCommands(object): assert response == [] # read the group as consumer1 to initially claim the messages - r.xreadgroup(group, consumer1, streams={stream: 0}) + r.xreadgroup(group, consumer1, streams={stream: '>'}) # claim the message as consumer2 response = r.xclaim(stream, group, consumer2, @@ -1901,7 +1961,7 @@ class TestRedisCommands(object): assert r.xgroup_delconsumer(stream, group, consumer) == 0 # read all messages from the group - r.xreadgroup(group, consumer, streams={stream: 0}) + r.xreadgroup(group, consumer, streams={stream: '>'}) # deleting the consumer should return 2 pending messages assert r.xgroup_delconsumer(stream, group, consumer) == 2 @@ -1942,15 +2002,17 @@ class TestRedisCommands(object): consumer1 = 'consumer1' consumer2 = 'consumer2' r.xadd(stream, {'foo': 'bar'}) + r.xadd(stream, {'foo': 'bar'}) + r.xadd(stream, {'foo': 'bar'}) r.xgroup_create(stream, group, 0) - r.xreadgroup(group, consumer1, streams={stream: 0}) - r.xreadgroup(group, consumer2, streams={stream: 0}) + r.xreadgroup(group, consumer1, streams={stream: '>'}, count=1) + r.xreadgroup(group, consumer2, streams={stream: '>'}) info = r.xinfo_consumers(stream, group) assert len(info) == 2 expected = [ {'name': consumer1.encode(), 'pending': 1}, - {'name': consumer2.encode(), 'pending': 0}, + {'name': consumer2.encode(), 'pending': 2}, ] # we can't determine the idle time, so just make sure it's an int @@ -1997,8 +2059,8 @@ class TestRedisCommands(object): assert r.xpending(stream, group) == expected # read 1 message from the group with each consumer - r.xreadgroup(group, consumer1, streams={stream: 0}, count=1) - r.xreadgroup(group, consumer2, streams={stream: m1}, count=1) + r.xreadgroup(group, consumer1, streams={stream: '>'}, count=1) + r.xreadgroup(group, consumer2, streams={stream: '>'}, count=1) expected = { 'pending': 2, @@ -2022,13 +2084,14 @@ class TestRedisCommands(object): r.xgroup_create(stream, group, 0) # xpending range on a group that has no consumers yet - assert r.xpending_range(stream, group) == [] + assert r.xpending_range(stream, group, min='-', max='+', count=5) == [] # read 1 message from the group with each consumer - r.xreadgroup(group, consumer1, streams={stream: 0}, count=1) - r.xreadgroup(group, consumer2, streams={stream: m1}, count=1) + r.xreadgroup(group, consumer1, streams={stream: '>'}, count=1) + r.xreadgroup(group, consumer2, streams={stream: '>'}, count=1) - response = r.xpending_range(stream, group) + response = r.xpending_range(stream, group, + min='-', max='+', count=5) assert len(response) == 2 assert response[0]['message_id'] == m1 assert response[0]['consumer'] == consumer1.encode() @@ -2066,7 +2129,7 @@ class TestRedisCommands(object): expected = [ [ - stream, + stream.encode(), [ get_stream_message(r, stream, m1), get_stream_message(r, stream, m2), @@ -2078,7 +2141,7 @@ class TestRedisCommands(object): expected = [ [ - stream, + stream.encode(), [ get_stream_message(r, stream, m1), ] @@ -2089,7 +2152,7 @@ class TestRedisCommands(object): expected = [ [ - stream, + stream.encode(), [ get_stream_message(r, stream, m2), ] @@ -2112,7 +2175,7 @@ class TestRedisCommands(object): expected = [ [ - stream, + stream.encode(), [ get_stream_message(r, stream, m1), get_stream_message(r, stream, m2), @@ -2120,48 +2183,57 @@ class TestRedisCommands(object): ] ] # xread starting at 0 returns both messages - assert r.xreadgroup(group, consumer, streams={stream: 0}) == expected + assert r.xreadgroup(group, consumer, streams={stream: '>'}) == expected r.xgroup_destroy(stream, group) r.xgroup_create(stream, group, 0) expected = [ [ - stream, + stream.encode(), [ get_stream_message(r, stream, m1), ] ] ] - # xread starting at 0 and count=1 returns only the first message - assert r.xreadgroup(group, consumer, streams={stream: 0}, count=1) == \ - expected + # xread with count=1 returns only the first message + assert r.xreadgroup(group, consumer, + streams={stream: '>'}, count=1) == expected r.xgroup_destroy(stream, group) - r.xgroup_create(stream, group, 0) - expected = [ - [ - stream, - [ - get_stream_message(r, stream, m2), - ] - ] - ] - # xread starting at m1 returns only the second message - assert r.xreadgroup(group, consumer, streams={stream: m1}) == expected + # create the group using $ as the last id meaning subsequent reads + # will only find messages added after this + r.xgroup_create(stream, group, '$') + expected = [] + # xread starting after the last message returns an empty message list + assert r.xreadgroup(group, consumer, streams={stream: '>'}) == expected + + # xreadgroup with noack does not have any items in the PEL r.xgroup_destroy(stream, group) - r.xgroup_create(stream, group, 0) + r.xgroup_create(stream, group, '0') + assert len(r.xreadgroup(group, consumer, streams={stream: '>'}, + noack=True)[0][1]) == 2 + # now there should be nothing pending + assert len(r.xreadgroup(group, consumer, + streams={stream: '0'})[0][1]) == 0 - # xread starting at the last message returns an empty message list + r.xgroup_destroy(stream, group) + r.xgroup_create(stream, group, '0') + # delete all the messages in the stream expected = [ [ - stream, - [] + stream.encode(), + [ + (m1, {}), + (m2, {}), + ] ] ] - assert r.xreadgroup(group, consumer, streams={stream: m2}) == expected + r.xreadgroup(group, consumer, streams={stream: '>'}) + r.xtrim(stream, 0) + assert r.xreadgroup(group, consumer, streams={stream: '0'}) == expected @skip_if_server_version_lt('5.0.0') def test_xrevrange(self, r): diff --git a/tests/test_connection_pool.py b/tests/test_connection_pool.py index b0dec67..ca56a76 100644 --- a/tests/test_connection_pool.py +++ b/tests/test_connection_pool.py @@ -308,6 +308,10 @@ class TestConnectionPoolURLParsing(object): 'password': None, } + def test_invalid_scheme_raises_error(self): + with pytest.raises(ValueError): + redis.ConnectionPool.from_url('localhost') + class TestConnectionPoolUnixSocketURLParsing(object): def test_defaults(self): diff --git a/tests/test_lock.py b/tests/test_lock.py index a6adbc2..f92afec 100644 --- a/tests/test_lock.py +++ b/tests/test_lock.py @@ -1,11 +1,17 @@ import pytest import time -from redis.exceptions import LockError +from redis.exceptions import LockError, LockNotOwnedError +from redis.client import Redis from redis.lock import Lock +from .conftest import _get_client class TestLock(object): + @pytest.fixture() + def r_decoded(self, request): + return _get_client(Redis, request=request, decode_responses=True) + def get_lock(self, redis, *args, **kwargs): kwargs['lock_class'] = Lock return redis.lock(*args, **kwargs) @@ -18,6 +24,16 @@ class TestLock(object): lock.release() assert r.get('foo') is None + def test_lock_token(self, r): + lock = self.get_lock(r, 'foo') + assert lock.acquire(blocking=False, token='test') + assert r.get('foo') == b'test' + assert lock.local.token == b'test' + assert r.ttl('foo') == -1 + lock.release() + assert r.get('foo') is None + assert lock.local.token is None + def test_locked(self, r): lock = self.get_lock(r, 'foo') assert lock.locked() is False @@ -26,6 +42,30 @@ class TestLock(object): lock.release() assert lock.locked() is False + def _test_owned(self, client): + lock = self.get_lock(client, 'foo') + assert lock.owned() is False + lock.acquire(blocking=False) + assert lock.owned() is True + lock.release() + assert lock.owned() is False + + lock2 = self.get_lock(client, 'foo') + assert lock.owned() is False + assert lock2.owned() is False + lock2.acquire(blocking=False) + assert lock.owned() is False + assert lock2.owned() is True + lock2.release() + assert lock.owned() is False + assert lock2.owned() is False + + def test_owned(self, r): + self._test_owned(r) + + def test_owned_with_decoded_responses(self, r_decoded): + self._test_owned(r_decoded) + def test_competing_locks(self, r): lock1 = self.get_lock(r, 'foo') lock2 = self.get_lock(r, 'foo') @@ -85,7 +125,7 @@ class TestLock(object): lock.acquire(blocking=False) # manually change the token r.set('foo', 'a') - with pytest.raises(LockError): + with pytest.raises(LockNotOwnedError): lock.release() # even though we errored, the token is still cleared assert lock.local.token is None @@ -119,12 +159,40 @@ class TestLock(object): lock.release() def test_extending_lock_no_longer_owned_raises_error(self, r): - lock = self.get_lock(r, 'foo') + lock = self.get_lock(r, 'foo', timeout=10) assert lock.acquire(blocking=False) r.set('foo', 'a') - with pytest.raises(LockError): + with pytest.raises(LockNotOwnedError): lock.extend(10) + def test_reacquire_lock(self, r): + lock = self.get_lock(r, 'foo', timeout=10) + assert lock.acquire(blocking=False) + assert r.pexpire('foo', 5000) + assert r.pttl('foo') <= 5000 + assert lock.reacquire() + assert 8000 < r.pttl('foo') <= 10000 + lock.release() + + def test_reacquiring_unlocked_lock_raises_error(self, r): + lock = self.get_lock(r, 'foo', timeout=10) + with pytest.raises(LockError): + lock.reacquire() + + def test_reacquiring_lock_with_no_timeout_raises_error(self, r): + lock = self.get_lock(r, 'foo') + assert lock.acquire(blocking=False) + with pytest.raises(LockError): + lock.reacquire() + lock.release() + + def test_reacquiring_lock_no_longer_owned_raises_error(self, r): + lock = self.get_lock(r, 'foo', timeout=10) + assert lock.acquire(blocking=False) + r.set('foo', 'a') + with pytest.raises(LockNotOwnedError): + lock.reacquire() + class TestLockClassSelection(object): def test_lock_class_argument(self, r): diff --git a/tests/test_pubsub.py b/tests/test_pubsub.py index 91e9e48..fc91abf 100644 --- a/tests/test_pubsub.py +++ b/tests/test_pubsub.py @@ -212,6 +212,48 @@ class TestPubSubSubscribeUnsubscribe(object): assert message is None assert p.subscribed is False + def test_sub_unsub_resub_channels(self, r): + kwargs = make_subscribe_test_data(r.pubsub(), 'channel') + self._test_sub_unsub_resub(**kwargs) + + def test_sub_unsub_resub_patterns(self, r): + kwargs = make_subscribe_test_data(r.pubsub(), 'pattern') + self._test_sub_unsub_resub(**kwargs) + + def _test_sub_unsub_resub(self, p, sub_type, unsub_type, sub_func, + unsub_func, keys): + # https://github.com/andymccurdy/redis-py/issues/764 + key = keys[0] + sub_func(key) + unsub_func(key) + sub_func(key) + assert p.subscribed is True + assert wait_for_message(p) == make_message(sub_type, key, 1) + assert wait_for_message(p) == make_message(unsub_type, key, 0) + assert wait_for_message(p) == make_message(sub_type, key, 1) + assert p.subscribed is True + + def test_sub_unsub_all_resub_channels(self, r): + kwargs = make_subscribe_test_data(r.pubsub(), 'channel') + self._test_sub_unsub_all_resub(**kwargs) + + def test_sub_unsub_all_resub_patterns(self, r): + kwargs = make_subscribe_test_data(r.pubsub(), 'pattern') + self._test_sub_unsub_all_resub(**kwargs) + + def _test_sub_unsub_all_resub(self, p, sub_type, unsub_type, sub_func, + unsub_func, keys): + # https://github.com/andymccurdy/redis-py/issues/764 + key = keys[0] + sub_func(key) + unsub_func() + sub_func(key) + assert p.subscribed is True + assert wait_for_message(p) == make_message(sub_type, key, 1) + assert wait_for_message(p) == make_message(unsub_type, key, 0) + assert wait_for_message(p) == make_message(sub_type, key, 1) + assert p.subscribed is True + class TestPubSubMessages(object): def setup_method(self, method): @@ -1,12 +1,13 @@ [tox] -minversion = 1.8 -envlist = {py27,py34,py35,py36}-{plain,hiredis}, pycodestyle +minversion = 2.4 +envlist = {py27,py34,py35,py36,py37}-{plain,hiredis}, pycodestyle [testenv] deps = mock pytest >= 2.7.0 - hiredis: hiredis >= 0.1.3 +extras = + hiredis: hiredis commands = py.test {posargs} [testenv:pycodestyle] |