diff options
-rw-r--r-- | .gitignore | 3 | ||||
-rw-r--r-- | .travis.yml | 36 | ||||
-rw-r--r-- | CHANGES | 68 | ||||
-rw-r--r-- | README.rst | 105 | ||||
-rw-r--r-- | RELEASE | 9 | ||||
-rw-r--r-- | benchmarks/base.py | 2 | ||||
-rw-r--r-- | benchmarks/basic_operations.py | 200 | ||||
-rw-r--r-- | benchmarks/command_packer_benchmark.py | 23 | ||||
-rw-r--r-- | build_tools/.bash_profile | 1 | ||||
-rwxr-xr-x | build_tools/bootstrap.sh (renamed from vagrant/bootstrap.sh) | 0 | ||||
-rwxr-xr-x | build_tools/build_redis.sh (renamed from vagrant/build_redis.sh) | 2 | ||||
-rwxr-xr-x | build_tools/install_redis.sh (renamed from vagrant/install_redis.sh) | 2 | ||||
-rwxr-xr-x | build_tools/install_sentinel.sh (renamed from vagrant/install_sentinel.sh) | 2 | ||||
-rw-r--r-- | build_tools/redis-configs/001-master (renamed from vagrant/redis-configs/001-master) | 3 | ||||
-rw-r--r-- | build_tools/redis-configs/002-slave (renamed from vagrant/redis-configs/002-slave) | 3 | ||||
-rwxr-xr-x | build_tools/redis_init_script (renamed from vagrant/redis_init_script) | 6 | ||||
-rwxr-xr-x | build_tools/redis_vars.sh (renamed from vagrant/redis_vars.sh) | 6 | ||||
-rw-r--r-- | build_tools/sentinel-configs/001-1 (renamed from vagrant/sentinel-configs/001-1) | 0 | ||||
-rw-r--r-- | build_tools/sentinel-configs/002-2 (renamed from vagrant/sentinel-configs/002-2) | 0 | ||||
-rw-r--r-- | build_tools/sentinel-configs/003-3 (renamed from vagrant/sentinel-configs/003-3) | 0 | ||||
-rwxr-xr-x | build_tools/sentinel_init_script (renamed from vagrant/sentinel_init_script) | 6 | ||||
-rw-r--r-- | docs/conf.py | 20 | ||||
-rw-r--r-- | docs/index.rst | 19 | ||||
-rw-r--r-- | redis/__init__.py | 2 | ||||
-rw-r--r-- | redis/_compat.py | 155 | ||||
-rwxr-xr-x | redis/client.py | 1528 | ||||
-rwxr-xr-x | redis/connection.py | 336 | ||||
-rw-r--r-- | redis/exceptions.py | 11 | ||||
-rw-r--r-- | redis/lock.py | 190 | ||||
-rw-r--r-- | redis/sentinel.py | 28 | ||||
-rw-r--r-- | setup.cfg | 9 | ||||
-rw-r--r-- | setup.py | 23 | ||||
-rw-r--r-- | tests/conftest.py | 71 | ||||
-rw-r--r-- | tests/test_commands.py | 1557 | ||||
-rw-r--r-- | tests/test_connection_pool.py | 96 | ||||
-rw-r--r-- | tests/test_encoding.py | 36 | ||||
-rw-r--r-- | tests/test_lock.py | 165 | ||||
-rw-r--r-- | tests/test_pipeline.py | 101 | ||||
-rw-r--r-- | tests/test_pubsub.py | 88 | ||||
-rw-r--r-- | tests/test_scripting.py | 41 | ||||
-rw-r--r-- | tests/test_sentinel.py | 29 | ||||
-rw-r--r-- | tox.ini | 52 | ||||
-rw-r--r-- | vagrant/.bash_profile | 1 | ||||
-rw-r--r-- | vagrant/Vagrantfile | 10 |
44 files changed, 3779 insertions, 1266 deletions
@@ -7,3 +7,6 @@ dump.rdb _build vagrant/.vagrant .python-version +.cache +.eggs +.idea
\ No newline at end of file diff --git a/.travis.yml b/.travis.yml index cf38f4b..10070e5 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,22 +1,34 @@ language: python +cache: pip python: - - "3.3" - - "3.2" - - "2.7" - - "2.6" -services: - - redis-server + - 3.6 + - 3.5 + - 3.4 + - 2.7 +before_install: + - wget http://download.redis.io/releases/redis-5.0.0.tar.gz && mkdir redis_install && tar -xvzf redis-5.0.0.tar.gz -C redis_install && cd redis_install/redis-5.0.0 && make && src/redis-server --daemonize yes && cd ../.. + - redis-cli info env: - TEST_HIREDIS=0 - TEST_HIREDIS=1 install: - pip install -e . - - "if [[ $TEST_PEP8 == '1' ]]; then pip install pep8; fi" + - "if [[ $TEST_PYCODESTYLE == '1' ]]; then pip install pycodestyle; fi" - "if [[ $TEST_HIREDIS == '1' ]]; then pip install hiredis; fi" -script: "if [[ $TEST_PEP8 == '1' ]]; then pep8 --repeat --show-source --exclude=.venv,.tox,dist,docs,build,*.egg .; else python setup.py test; fi" +script: "if [[ $TEST_PYCODESTYLE == '1' ]]; then pycodestyle --repeat --show-source --exclude=.venv,.tox,dist,docs,build,*.egg,redis_install .; else python setup.py test; fi" matrix: include: - - python: "2.7" - env: TEST_PEP8=1 - - python: "3.4" - env: TEST_PEP8=1 + - python: 2.7 + env: TEST_PYCODESTYLE=1 + - python: 3.6 + env: TEST_PYCODESTYLE=1 + # python 3.7 has to be specified manually in the matrix + # https://github.com/travis-ci/travis-ci/issues/9815 + - python: 3.7 + dist: xenial + sudo: true + env: TEST_HIREDIS=0 + - python: 3.7 + dist: xenial + sudo: true + env: TEST_HIREDIS=1 @@ -1,5 +1,65 @@ -* 2.10.4 (in development) +* UNRELEASED + * Removed support for EOL Python 2.6 and 3.3. +* 2.10.6 + * Various performance improvements. Thanks cjsimpson + * Fixed a bug with SRANDMEMBER where the behavior for `number=0` did + not match the spec. Thanks Alex Wang + * Added HSTRLEN command. Thanks Alexander Putilin + * Added the TOUCH command. Thanks Anis Jonischkeit + * Remove unnecessary calls to the server when registering Lua scripts. + Thanks Ben Greenberg + * SET's EX and PX arguments now allow values of zero. Thanks huangqiyin + * Added PUBSUB {CHANNELS, NUMPAT, NUMSUB} commands. Thanks Angus Pearson + * PubSub connections that that encounter `InterruptedError`s now + retry automatically. Thanks Carlton Gibson and Seth M. Larson + * LPUSH and RPUSH commands run on PyPy now correctly returns the number + of items of the list. Thanks Jeong YunWon + * Added support to automatically retry socket EINTR errors. Thanks + Thomas Steinacher + * PubSubWorker threads started with `run_in_thread` are now daemonized + so the thread shuts down when the running process goes away. Thanks + Keith Ainsworth + * Added support for GEO commands. Thanks Pau Freixes, Alex DeBrie and + Abraham Toriz + * Made client construction from URLs smarter. Thanks Tim Savage + * Added support for CLUSTER * commands. Thanks Andy Huang + * The RESTORE command now accepts an optional `replace` boolean. + Thanks Yoshinari Takaoka + * Attempt to connect to a new Sentinel if a TimeoutError occurs. Thanks + Bo Lopker + * Fixed a bug in the client's `__getitem__` where a KeyError would be + raised if the value returned by the server is an empty string. + Thanks Javier Candeira. + * Socket timeouts when connecting to a server are now properly raised + as TimeoutErrors. +* 2.10.5 + * Allow URL encoded parameters in Redis URLs. Characters like a "/" can + now be URL encoded and redis-py will correctly decode them. Thanks + Paul Keene. + * Added support for the WAIT command. Thanks https://github.com/eshizhan + * Better shutdown support for the PubSub Worker Thread. It now properly + cleans up the connection, unsubscribes from any channels and patterns + previously subscribed to and consumes any waiting messages on the socket. + * Added the ability to sleep for a brief period in the event of a + WatchError occuring. Thanks Joshua Harlow. + * Fixed a bug with pipeline error reporting when dealing with characters + in error messages that could not be encoded to the connection's + character set. Thanks Hendrik Muhs. + * Fixed a bug in Sentinel connections that would inadvertantly connect + to the master when the connection pool resets. Thanks + https://github.com/df3n5 * Better timeout support in Pubsub get_message. Thanks Andy Isaacson. + * Fixed a bug with the HiredisParser that would cause the parser to + get stuck in an endless loop if a specific number of bytes were + delivered from the socket. This fix also increases performance of + parsing large responses from the Redis server. + * Added support for ZREVRANGEBYLEX. + * ConnectionErrors are now raised if Redis refuses a connection due to + the maxclients limit being exceeded. Thanks Roman Karpovich. + * max_connections can now be set when instantiating client instances. + Thanks Ohad Perry. +* 2.10.4 + (skipped due to a PyPI snafu) * 2.10.3 * Fixed a bug with the bytearray support introduced in 2.10.2. Thanks Josh Owen. @@ -133,7 +193,7 @@ for the report. * Connections now call socket.shutdown() prior to socket.close() to ensure communication ends immediately per the note at - http://docs.python.org/2/library/socket.html#socket.socket.close + https://docs.python.org/2/library/socket.html#socket.socket.close Thanks to David Martin for pointing this out. * Lock checks are now based on floats rather than ints. Thanks Vitja Makarov. @@ -167,11 +227,11 @@ * Prevent DISCARD from being called if MULTI wasn't also called. Thanks Pete Aykroyd. * SREM now returns an integer indicating the number of items removed from - the set. Thanks http://github.com/ronniekk. + the set. Thanks https://github.com/ronniekk. * Fixed a bug with BGSAVE and BGREWRITEAOF response callbacks with Python3. Thanks Nathan Wan. * Added CLIENT GETNAME and CLIENT SETNAME commands. - Thanks http://github.com/bitterb. + Thanks https://github.com/bitterb. * It's now possible to use len() on a pipeline instance to determine the number of commands that will be executed. Thanks Jon Parise. * Fixed a bug in INFO's parse routine with floating point numbers. Thanks @@ -4,31 +4,36 @@ redis-py The Python interface to the Redis key-value store. .. image:: https://secure.travis-ci.org/andymccurdy/redis-py.png?branch=master - :target: http://travis-ci.org/andymccurdy/redis-py + :target: https://travis-ci.org/andymccurdy/redis-py +.. image:: https://readthedocs.org/projects/redis-py/badge/?version=latest&style=flat + :target: https://redis-py.readthedocs.io/en/latest/ +.. image:: https://badge.fury.io/py/redis.svg + :target: https://pypi.org/project/redis/ Installation ------------ redis-py requires a running Redis server. See `Redis's quickstart -<http://redis.io/topics/quickstart>`_ for installation instructions. +<https://redis.io/topics/quickstart>`_ for installation instructions. -To install redis-py, simply: - -.. code-block:: bash +redis-py can be installed using `pip` similar to other Python packages. Do not use `sudo` +with `pip`. It is usually good to work in a +`virtualenv <https://virtualenv.pypa.io/en/latest/>`_ or +`venv <https://docs.python.org/3/library/venv.html>`_ to avoid conflicts with other package +managers and Python projects. For a quick introduction see +`Python Virtual Environments in Five Minutes <https://bit.ly/py-env>`_. - $ sudo pip install redis - -or alternatively (you really should be using pip though): +To install redis-py, simply: .. code-block:: bash - $ sudo easy_install redis + $ pip install redis or from source: .. code-block:: bash - $ sudo python setup.py install + $ python setup.py install Getting Started @@ -37,18 +42,27 @@ Getting Started .. code-block:: pycon >>> import redis - >>> r = redis.StrictRedis(host='localhost', port=6379, db=0) + >>> r = redis.Redis(host='localhost', port=6379, db=0) >>> r.set('foo', 'bar') True >>> r.get('foo') 'bar' +By default, all responses are returned as `bytes` in Python 3 and `str` in +Python 2. The user is responsible for decoding to Python 3 strings or Python 2 +unicode objects. + +If **all** string responses from a client should be decoded, the user can +specify `decode_responses=True` to `Redis.__init__`. In this case, any +Redis command that returns a string type will be decoded with the `encoding` +specified. + API Reference ------------- -The `official Redis command documentation <http://redis.io/commands>`_ does a +The `official Redis command documentation <https://redis.io/commands>`_ does a great job of explaining each command in detail. redis-py exposes two client -classes that implement these commands. The StrictRedis class attempts to adhere +classes that implement these commands. The Redis class attempts to adhere to the official command syntax. There are a few exceptions: * **SELECT**: Not implemented. See the explanation in the Thread Safety section @@ -74,7 +88,7 @@ to the official command syntax. There are a few exceptions: to keep track of the cursor while iterating. Use the scan_iter/sscan_iter/hscan_iter/zscan_iter methods for this behavior. -In addition to the changes above, the Redis class, a subclass of StrictRedis, +In addition to the changes above, the Redis class, a subclass of Redis, overrides several other commands to provide backwards compatibility with older versions of redis-py: @@ -149,19 +163,12 @@ kind enough to create Python bindings. Using Hiredis can provide up to a performance increase is most noticeable when retrieving many pieces of data, such as from LRANGE or SMEMBERS operations. -Hiredis is available on PyPI, and can be installed via pip or easy_install -just like redis-py. +Hiredis is available on PyPI, and can be installed via pip just like redis-py. .. code-block:: bash $ pip install hiredis -or - -.. code-block:: bash - - $ easy_install hiredis - Response Callbacks ^^^^^^^^^^^^^^^^^^ @@ -174,7 +181,7 @@ set_response_callback method. This method accepts two arguments: a command name and the callback. Callbacks added in this manner are only valid on the instance the callback is added to. If you want to define or override a callback globally, you should make a subclass of the Redis client and add your callback -to its REDIS_CALLBACKS class dictionary. +to its RESPONSE_CALLBACKS class dictionary. Response callbacks take at least one parameter: the response from the Redis server. Keyword arguments may also be accepted in order to further control @@ -263,7 +270,7 @@ could do something like this: .. code-block:: pycon >>> with r.pipeline() as pipe: - ... while 1: + ... while True: ... try: ... # put a WATCH on the key that holds our sequence value ... pipe.watch('OUR-SEQUENCE-KEY') @@ -291,12 +298,12 @@ duration of a WATCH, care must be taken to ensure that the connection is returned to the connection pool by calling the reset() method. If the Pipeline is used as a context manager (as in the example above) reset() will be called automatically. Of course you can do this the manual way by -explicity calling reset(): +explicitly calling reset(): .. code-block:: pycon >>> pipe = r.pipeline() - >>> while 1: + >>> while True: ... try: ... pipe.watch('OUR-SEQUENCE-KEY') ... ... @@ -332,7 +339,7 @@ for new messages. Creating a `PubSub` object is easy. .. code-block:: pycon - >>> r = redis.StrictRedis(...) + >>> r = redis.Redis(...) >>> p = r.pubsub() Once a `PubSub` instance is created, channels and patterns can be subscribed @@ -444,7 +451,7 @@ application. >>> r.publish('my-channel') 1 >>> p.get_message() - {'channel': 'my-channel', data': 'my data', 'pattern': None, 'type': 'message'} + {'channel': 'my-channel', 'data': 'my data', 'pattern': None, 'type': 'message'} There are three different strategies for reading messages. @@ -519,7 +526,23 @@ cannot be delivered. When you're finished with a PubSub object, call its >>> ... >>> p.close() -LUA Scripting + +The PUBSUB set of subcommands CHANNELS, NUMSUB and NUMPAT are also +supported: + +.. code-block:: pycon + + >>> r.pubsub_channels() + ['foo', 'bar'] + >>> r.pubsub_numsub('foo', 'bar') + [('foo', 9001), ('bar', 42)] + >>> r.pubsub_numsub('baz') + [('baz', 0)] + >>> r.pubsub_numpat() + 1204 + + +Lua Scripting ^^^^^^^^^^^^^ redis-py supports the EVAL, EVALSHA, and SCRIPT commands. However, there are @@ -528,16 +551,16 @@ scenarios. Therefore, redis-py exposes a Script object that makes scripting much easier to use. To create a Script instance, use the `register_script` function on a client -instance passing the LUA code as the first argument. `register_script` returns +instance passing the Lua code as the first argument. `register_script` returns a Script instance that you can use throughout your code. -The following trivial LUA script accepts two parameters: the name of a key and +The following trivial Lua script accepts two parameters: the name of a key and a multiplier value. The script fetches the value stored in the key, multiplies it with the multiplier value and returns the result. .. code-block:: pycon - >>> r = redis.StrictRedis() + >>> r = redis.Redis() >>> lua = """ ... local value = redis.call('GET', KEYS[1]) ... value = tonumber(value) @@ -548,8 +571,8 @@ it with the multiplier value and returns the result. function. Script instances accept the following optional arguments: * **keys**: A list of key names that the script will access. This becomes the - KEYS list in LUA. -* **args**: A list of argument values. This becomes the ARGV list in LUA. + KEYS list in Lua. +* **args**: A list of argument values. This becomes the ARGV list in Lua. * **client**: A redis-py Client or Pipeline instance that will invoke the script. If client isn't specified, the client that intiially created the Script instance (the one that `register_script` was @@ -564,7 +587,7 @@ Continuing the example from above: 10 The value of key 'foo' is set to 2. When multiply is invoked, the 'foo' key is -passed to the script along with the multiplier value of 5. LUA executes the +passed to the script along with the multiplier value of 5. Lua executes the script and returns the result, 10. Script instances can be executed using a different client instance, even one @@ -572,12 +595,12 @@ that points to a completely different Redis server. .. code-block:: pycon - >>> r2 = redis.StrictRedis('redis2.example.com') + >>> r2 = redis.Redis('redis2.example.com') >>> r2.set('foo', 3) >>> multiply(keys=['foo'], args=[5], client=r2) 15 -The Script object ensures that the LUA script is loaded into Redis's script +The Script object ensures that the Lua script is loaded into Redis's script cache. In the event of a NOSCRIPT error, it will load the script and retry executing it. @@ -597,7 +620,7 @@ execution. Sentinel support ^^^^^^^^^^^^^^^^ -redis-py can be used together with `Redis Sentinel <http://redis.io/topics/sentinel>`_ +redis-py can be used together with `Redis Sentinel <https://redis.io/topics/sentinel>`_ to discover Redis nodes. You need to have at least one Sentinel daemon running in order to use redis-py's Sentinel support. @@ -625,7 +648,7 @@ operations). >>> slave.get('foo') 'bar' -The master and slave objects are normal StrictRedis instances with their +The master and slave objects are normal Redis instances with their connection pool bound to the Sentinel instance. When a Sentinel backed client attempts to establish a connection, it first queries the Sentinel servers to determine an appropriate host to connect to. If no server is found, @@ -638,7 +661,7 @@ If no slaves can be connected to, a connection will be established with the master. See `Guidelines for Redis clients with support for Redis Sentinel -<http://redis.io/topics/sentinel-clients>`_ to learn more about Redis Sentinel. +<https://redis.io/topics/sentinel-clients>`_ to learn more about Redis Sentinel. Scan Iterators ^^^^^^^^^^^^^^ @@ -662,7 +685,7 @@ Author ^^^^^^ redis-py is developed and maintained by Andy McCurdy (sedrik@gmail.com). -It can be found here: http://github.com/andymccurdy/redis-py +It can be found here: https://github.com/andymccurdy/redis-py Special thanks to: @@ -0,0 +1,9 @@ +Release Process +=============== + +1. Make sure all tests pass. +2. Make sure CHANGES is up to date. +3. Update redis.__init__.__version__ and commit +4. git tag <version-number> +5. git push --tag +6. rm dist/* && python setup.py sdist bdist_wheel && twine upload dist/* diff --git a/benchmarks/base.py b/benchmarks/base.py index a97001f..44e9341 100644 --- a/benchmarks/base.py +++ b/benchmarks/base.py @@ -21,7 +21,7 @@ class Benchmark(object): } defaults.update(kwargs) pool = redis.ConnectionPool(**kwargs) - self._client = redis.StrictRedis(connection_pool=pool) + self._client = redis.Redis(connection_pool=pool) return self._client def setup(self, **kwargs): diff --git a/benchmarks/basic_operations.py b/benchmarks/basic_operations.py new file mode 100644 index 0000000..a4b675d --- /dev/null +++ b/benchmarks/basic_operations.py @@ -0,0 +1,200 @@ +from __future__ import print_function +import redis +import time +import sys +from functools import wraps +from argparse import ArgumentParser + +if sys.version_info[0] == 3: + long = int + + +def parse_args(): + parser = ArgumentParser() + parser.add_argument('-n', + type=int, + help='Total number of requests (default 100000)', + default=100000) + parser.add_argument('-P', + type=int, + help=('Pipeline <numreq> requests.' + ' Default 1 (no pipeline).'), + default=1) + parser.add_argument('-s', + type=int, + help='Data size of SET/GET value in bytes (default 2)', + default=2) + + args = parser.parse_args() + return args + + +def run(): + args = parse_args() + r = redis.Redis() + r.flushall() + set_str(conn=r, num=args.n, pipeline_size=args.P, data_size=args.s) + set_int(conn=r, num=args.n, pipeline_size=args.P, data_size=args.s) + get_str(conn=r, num=args.n, pipeline_size=args.P, data_size=args.s) + get_int(conn=r, num=args.n, pipeline_size=args.P, data_size=args.s) + incr(conn=r, num=args.n, pipeline_size=args.P, data_size=args.s) + lpush(conn=r, num=args.n, pipeline_size=args.P, data_size=args.s) + lrange_300(conn=r, num=args.n, pipeline_size=args.P, data_size=args.s) + lpop(conn=r, num=args.n, pipeline_size=args.P, data_size=args.s) + hmset(conn=r, num=args.n, pipeline_size=args.P, data_size=args.s) + + +def timer(func): + @wraps(func) + def wrapper(*args, **kwargs): + start = time.clock() + ret = func(*args, **kwargs) + duration = time.clock() - start + if 'num' in kwargs: + count = kwargs['num'] + else: + count = args[1] + print('{} - {} Requests'.format(func.__name__, count)) + print('Duration = {}'.format(duration)) + print('Rate = {}'.format(count/duration)) + print('') + return ret + return wrapper + + +@timer +def set_str(conn, num, pipeline_size, data_size): + if pipeline_size > 1: + conn = conn.pipeline() + + format_str = '{:0<%d}' % data_size + set_data = format_str.format('a') + for i in range(num): + conn.set('set_str:%d' % i, set_data) + if pipeline_size > 1 and i % pipeline_size == 0: + conn.execute() + + if pipeline_size > 1: + conn.execute() + + +@timer +def set_int(conn, num, pipeline_size, data_size): + if pipeline_size > 1: + conn = conn.pipeline() + + format_str = '{:0<%d}' % data_size + set_data = int(format_str.format('1')) + for i in range(num): + conn.set('set_int:%d' % i, set_data) + if pipeline_size > 1 and i % pipeline_size == 0: + conn.execute() + + if pipeline_size > 1: + conn.execute() + + +@timer +def get_str(conn, num, pipeline_size, data_size): + if pipeline_size > 1: + conn = conn.pipeline() + + for i in range(num): + conn.get('set_str:%d' % i) + if pipeline_size > 1 and i % pipeline_size == 0: + conn.execute() + + if pipeline_size > 1: + conn.execute() + + +@timer +def get_int(conn, num, pipeline_size, data_size): + if pipeline_size > 1: + conn = conn.pipeline() + + for i in range(num): + conn.get('set_int:%d' % i) + if pipeline_size > 1 and i % pipeline_size == 0: + conn.execute() + + if pipeline_size > 1: + conn.execute() + + +@timer +def incr(conn, num, pipeline_size, *args, **kwargs): + if pipeline_size > 1: + conn = conn.pipeline() + + for i in range(num): + conn.incr('incr_key') + if pipeline_size > 1 and i % pipeline_size == 0: + conn.execute() + + if pipeline_size > 1: + conn.execute() + + +@timer +def lpush(conn, num, pipeline_size, data_size): + if pipeline_size > 1: + conn = conn.pipeline() + + format_str = '{:0<%d}' % data_size + set_data = int(format_str.format('1')) + for i in range(num): + conn.lpush('lpush_key', set_data) + if pipeline_size > 1 and i % pipeline_size == 0: + conn.execute() + + if pipeline_size > 1: + conn.execute() + + +@timer +def lrange_300(conn, num, pipeline_size, data_size): + if pipeline_size > 1: + conn = conn.pipeline() + + for i in range(num): + conn.lrange('lpush_key', i, i+300) + if pipeline_size > 1 and i % pipeline_size == 0: + conn.execute() + + if pipeline_size > 1: + conn.execute() + + +@timer +def lpop(conn, num, pipeline_size, data_size): + if pipeline_size > 1: + conn = conn.pipeline() + for i in range(num): + conn.lpop('lpush_key') + if pipeline_size > 1 and i % pipeline_size == 0: + conn.execute() + if pipeline_size > 1: + conn.execute() + + +@timer +def hmset(conn, num, pipeline_size, data_size): + if pipeline_size > 1: + conn = conn.pipeline() + + set_data = {'str_value': 'string', + 'int_value': 123456, + 'long_value': long(123456), + 'float_value': 123456.0} + for i in range(num): + conn.hmset('hmset_key', set_data) + if pipeline_size > 1 and i % pipeline_size == 0: + conn.execute() + + if pipeline_size > 1: + conn.execute() + + +if __name__ == '__main__': + run() diff --git a/benchmarks/command_packer_benchmark.py b/benchmarks/command_packer_benchmark.py index 13d6f97..0d69bdf 100644 --- a/benchmarks/command_packer_benchmark.py +++ b/benchmarks/command_packer_benchmark.py @@ -22,17 +22,18 @@ class StringJoiningConnection(Connection): _errno, errmsg = e.args raise ConnectionError("Error %s while writing to socket. %s." % (_errno, errmsg)) - except: + except Exception as e: self.disconnect() - raise + raise e def pack_command(self, *args): "Pack a series of arguments into a value Redis command" args_output = SYM_EMPTY.join([ - SYM_EMPTY.join((SYM_DOLLAR, b(str(len(k))), SYM_CRLF, k, SYM_CRLF)) - for k in imap(self.encode, args)]) + SYM_EMPTY.join( + (SYM_DOLLAR, str(len(k)).encode(), SYM_CRLF, k, SYM_CRLF)) + for k in imap(self.encoder.encode, args)]) output = SYM_EMPTY.join( - (SYM_STAR, b(str(len(args))), SYM_CRLF, args_output)) + (SYM_STAR, str(len(args)).encode(), SYM_CRLF, args_output)) return output @@ -54,24 +55,24 @@ class ListJoiningConnection(Connection): _errno, errmsg = e.args raise ConnectionError("Error %s while writing to socket. %s." % (_errno, errmsg)) - except: + except Exception as e: self.disconnect() - raise + raise e def pack_command(self, *args): output = [] buff = SYM_EMPTY.join( - (SYM_STAR, b(str(len(args))), SYM_CRLF)) + (SYM_STAR, str(len(args)).encode(), SYM_CRLF)) - for k in imap(self.encode, args): + for k in imap(self.encoder.encode, args): if len(buff) > 6000 or len(k) > 6000: buff = SYM_EMPTY.join( - (buff, SYM_DOLLAR, b(str(len(k))), SYM_CRLF)) + (buff, SYM_DOLLAR, str(len(k)).encode(), SYM_CRLF)) output.append(buff) output.append(k) buff = SYM_CRLF else: - buff = SYM_EMPTY.join((buff, SYM_DOLLAR, b(str(len(k))), + buff = SYM_EMPTY.join((buff, SYM_DOLLAR, str(len(k)).encode(), SYM_CRLF, k, SYM_CRLF)) output.append(buff) return output diff --git a/build_tools/.bash_profile b/build_tools/.bash_profile new file mode 100644 index 0000000..b023cf7 --- /dev/null +++ b/build_tools/.bash_profile @@ -0,0 +1 @@ +PATH=$PATH:/var/lib/redis/bin diff --git a/vagrant/bootstrap.sh b/build_tools/bootstrap.sh index a5a0d2c..a5a0d2c 100755 --- a/vagrant/bootstrap.sh +++ b/build_tools/bootstrap.sh diff --git a/vagrant/build_redis.sh b/build_tools/build_redis.sh index 728e617..379c6cc 100755 --- a/vagrant/build_redis.sh +++ b/build_tools/build_redis.sh @@ -1,6 +1,6 @@ #!/usr/bin/env bash -source /home/vagrant/redis-py/vagrant/redis_vars.sh +source /home/vagrant/redis-py/build_tools/redis_vars.sh pushd /home/vagrant diff --git a/vagrant/install_redis.sh b/build_tools/install_redis.sh index bb5f1d2..fd53a1c 100755 --- a/vagrant/install_redis.sh +++ b/build_tools/install_redis.sh @@ -1,6 +1,6 @@ #!/usr/bin/env bash -source /home/vagrant/redis-py/vagrant/redis_vars.sh +source /home/vagrant/redis-py/build_tools/redis_vars.sh for filename in `ls $VAGRANT_REDIS_CONF_DIR`; do # cuts the order prefix off of the filename, e.g. 001-master -> master diff --git a/vagrant/install_sentinel.sh b/build_tools/install_sentinel.sh index 58cd808..0597208 100755 --- a/vagrant/install_sentinel.sh +++ b/build_tools/install_sentinel.sh @@ -1,6 +1,6 @@ #!/usr/bin/env bash -source /home/vagrant/redis-py/vagrant/redis_vars.sh +source /home/vagrant/redis-py/build_tools/redis_vars.sh for filename in `ls $VAGRANT_SENTINEL_CONF_DIR`; do # cuts the order prefix off of the filename, e.g. 001-master -> master diff --git a/vagrant/redis-configs/001-master b/build_tools/redis-configs/001-master index ac069ea..8591f1a 100644 --- a/vagrant/redis-configs/001-master +++ b/build_tools/redis-configs/001-master @@ -1,7 +1,8 @@ pidfile /var/run/redis-master.pid +bind * port 6379 daemonize yes unixsocket /tmp/redis_master.sock unixsocketperm 777 dbfilename master.rdb -dir /home/vagrant/redis/backups +dir /var/lib/redis/backups diff --git a/vagrant/redis-configs/002-slave b/build_tools/redis-configs/002-slave index d5f6fdc..13eb77e 100644 --- a/vagrant/redis-configs/002-slave +++ b/build_tools/redis-configs/002-slave @@ -1,9 +1,10 @@ pidfile /var/run/redis-slave.pid +bind * port 6380 daemonize yes unixsocket /tmp/redis-slave.sock unixsocketperm 777 dbfilename slave.rdb -dir /home/vagrant/redis/backups +dir /var/lib/redis/backups slaveof 127.0.0.1 6379 diff --git a/vagrant/redis_init_script b/build_tools/redis_init_script index e8bfa08..04cb2db 100755 --- a/vagrant/redis_init_script +++ b/build_tools/redis_init_script @@ -12,10 +12,10 @@ REDISPORT={{ PORT }} PIDFILE=/var/run/{{ PROCESS_NAME }}.pid -CONF=/home/vagrant/redis/conf/{{ PROCESS_NAME }}.conf +CONF=/var/lib/redis/conf/{{ PROCESS_NAME }}.conf -EXEC=/home/vagrant/redis/bin/redis-server -CLIEXEC=/home/vagrant/redis/bin/redis-cli +EXEC=/var/lib/redis/bin/redis-server +CLIEXEC=/var/lib/redis/bin/redis-cli case "$1" in start) diff --git a/vagrant/redis_vars.sh b/build_tools/redis_vars.sh index 5a4b610..c52dd4c 100755 --- a/vagrant/redis_vars.sh +++ b/build_tools/redis_vars.sh @@ -1,13 +1,13 @@ #!/usr/bin/env bash -VAGRANT_DIR=/home/vagrant/redis-py/vagrant +VAGRANT_DIR=/home/vagrant/redis-py/build_tools VAGRANT_REDIS_CONF_DIR=$VAGRANT_DIR/redis-configs VAGRANT_SENTINEL_CONF_DIR=$VAGRANT_DIR/sentinel-configs -REDIS_VERSION=2.8.9 +REDIS_VERSION=3.2.0 REDIS_DOWNLOAD_DIR=/home/vagrant/redis-downloads REDIS_PACKAGE=redis-$REDIS_VERSION.tar.gz REDIS_BUILD_DIR=$REDIS_DOWNLOAD_DIR/redis-$REDIS_VERSION -REDIS_DIR=/home/vagrant/redis +REDIS_DIR=/var/lib/redis REDIS_BIN_DIR=$REDIS_DIR/bin REDIS_CONF_DIR=$REDIS_DIR/conf REDIS_SAVE_DIR=$REDIS_DIR/backups diff --git a/vagrant/sentinel-configs/001-1 b/build_tools/sentinel-configs/001-1 index eccc3d1..eccc3d1 100644 --- a/vagrant/sentinel-configs/001-1 +++ b/build_tools/sentinel-configs/001-1 diff --git a/vagrant/sentinel-configs/002-2 b/build_tools/sentinel-configs/002-2 index 0cd2801..0cd2801 100644 --- a/vagrant/sentinel-configs/002-2 +++ b/build_tools/sentinel-configs/002-2 diff --git a/vagrant/sentinel-configs/003-3 b/build_tools/sentinel-configs/003-3 index c7f4fcd..c7f4fcd 100644 --- a/vagrant/sentinel-configs/003-3 +++ b/build_tools/sentinel-configs/003-3 diff --git a/vagrant/sentinel_init_script b/build_tools/sentinel_init_script index ea93537..1d94804 100755 --- a/vagrant/sentinel_init_script +++ b/build_tools/sentinel_init_script @@ -12,10 +12,10 @@ SENTINELPORT={{ PORT }} PIDFILE=/var/run/{{ PROCESS_NAME }}.pid -CONF=/home/vagrant/redis/conf/{{ PROCESS_NAME }}.conf +CONF=/var/lib/redis/conf/{{ PROCESS_NAME }}.conf -EXEC=/home/vagrant/redis/bin/redis-sentinel -CLIEXEC=/home/vagrant/redis/bin/redis-cli +EXEC=/var/lib/redis/bin/redis-sentinel +CLIEXEC=/var/lib/redis/bin/redis-cli case "$1" in start) diff --git a/docs/conf.py b/docs/conf.py index 8463eaa..690be03 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -19,6 +19,7 @@ import sys # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. #sys.path.insert(0, os.path.abspath('.')) +sys.path.append(os.path.abspath(os.path.pardir)) # -- General configuration ---------------------------------------------------- @@ -27,7 +28,7 @@ import sys # Add any Sphinx extension module names here, as strings. They can be # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom ones. -extensions = [] +extensions = ['sphinx.ext.autodoc', 'sphinx.ext.doctest', 'sphinx.ext.viewcode'] # Add any paths that contain templates here, relative to this directory. templates_path = ['_templates'] @@ -43,16 +44,16 @@ master_doc = 'index' # General information about the project. project = u'redis-py' -copyright = u'2013, Andy McCurdy, Mahdi Yusuf' +copyright = u'2016, Andy McCurdy' # The version info for the project you're documenting, acts as replacement for # |version| and |release|, also used in various other places throughout the # built documents. # # The short X.Y version. -version = '2.7.2' +version = '2.10.5' # The full version, including alpha/beta/rc tags. -release = '2.7.2' +release = '2.10.5' # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. @@ -188,7 +189,7 @@ latex_elements = { # [howto/manual]). latex_documents = [ ('index', 'redis-py.tex', u'redis-py Documentation', - u'Andy McCurdy, Mahdi Yusuf', 'manual'), + u'Andy McCurdy', 'manual'), ] # The name of an image file (relative to this directory) to place at the top of @@ -218,7 +219,7 @@ latex_documents = [ # (source start file, name, description, authors, manual section). man_pages = [ ('index', 'redis-py', u'redis-py Documentation', - [u'Andy McCurdy, Mahdi Yusuf'], 1) + [u'Andy McCurdy'], 1) ] # If true, show URL addresses after external links. @@ -232,7 +233,7 @@ man_pages = [ # dir menu entry, description, category) texinfo_documents = [ ('index', 'redis-py', u'redis-py Documentation', - u'Andy McCurdy, Mahdi Yusuf', 'redis-py', + u'Andy McCurdy', 'redis-py', 'One line description of project.', 'Miscellaneous'), ] @@ -244,3 +245,8 @@ texinfo_documents = [ # How to display URL addresses: 'footnote', 'no', or 'inline'. #texinfo_show_urls = 'footnote' + +epub_title = u'redis-py' +epub_author = u'Andy McCurdy' +epub_publisher = u'Andy McCurdy' +epub_copyright = u'2011, Andy McCurdy' diff --git a/docs/index.rst b/docs/index.rst index 2394587..e441bee 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -1,22 +1,23 @@ .. redis-py documentation master file, created by - sphinx-quickstart on Fri Feb 8 00:47:08 2013. + sphinx-quickstart on Thu Jul 28 13:55:57 2011. You can adapt this file completely to your liking, but it should at least contain the root `toctree` directive. Welcome to redis-py's documentation! ==================================== -Contents: - -.. toctree:: - :maxdepth: 2 - - - Indices and tables -================== +------------------ * :ref:`genindex` * :ref:`modindex` * :ref:`search` +Contents: +--------- + +.. toctree:: + :maxdepth: 2 + +.. automodule:: redis + :members: diff --git a/redis/__init__.py b/redis/__init__.py index 3b0995d..6607155 100644 --- a/redis/__init__.py +++ b/redis/__init__.py @@ -22,7 +22,7 @@ from redis.exceptions import ( ) -__version__ = '2.10.3' +__version__ = '2.10.6' VERSION = tuple(map(int, __version__.split('.'))) __all__ = [ diff --git a/redis/_compat.py b/redis/_compat.py index 38d767d..80973b3 100644 --- a/redis/_compat.py +++ b/redis/_compat.py @@ -1,16 +1,84 @@ """Internal module for Python 2 backwards compatibility.""" +import errno import sys +# For Python older than 3.5, retry EINTR. +if sys.version_info[0] < 3 or (sys.version_info[0] == 3 and + sys.version_info[1] < 5): + # Adapted from https://bugs.python.org/review/23863/patch/14532/54418 + import socket + import time + + from select import select as _select, error as select_error + + def select(rlist, wlist, xlist, timeout): + while True: + try: + return _select(rlist, wlist, xlist, timeout) + except select_error as e: + if e.args[0] == errno.EINTR: + continue + raise + + # Wrapper for handling interruptable system calls. + def _retryable_call(s, func, *args, **kwargs): + # Some modules (SSL) use the _fileobject wrapper directly and + # implement a smaller portion of the socket interface, thus we + # need to let them continue to do so. + timeout, deadline = None, 0.0 + attempted = False + try: + timeout = s.gettimeout() + except AttributeError: + pass + + if timeout: + deadline = time.time() + timeout + + try: + while True: + if attempted and timeout: + now = time.time() + if now >= deadline: + raise socket.error(errno.EWOULDBLOCK, "timed out") + else: + # Overwrite the timeout on the socket object + # to take into account elapsed time. + s.settimeout(deadline - now) + try: + attempted = True + return func(*args, **kwargs) + except socket.error as e: + if e.args[0] == errno.EINTR: + continue + raise + finally: + # Set the existing timeout back for future + # calls. + if timeout: + s.settimeout(timeout) + + def recv(sock, *args, **kwargs): + return _retryable_call(sock, sock.recv, *args, **kwargs) + + def recv_into(sock, *args, **kwargs): + return _retryable_call(sock, sock.recv_into, *args, **kwargs) + +else: # Python 3.5 and above automatically retry EINTR + from select import select + + def recv(sock, *args, **kwargs): + return sock.recv(*args, **kwargs) + + def recv_into(sock, *args, **kwargs): + return sock.recv_into(*args, **kwargs) if sys.version_info[0] < 3: + from urllib import unquote from urlparse import parse_qs, urlparse from itertools import imap, izip from string import letters as ascii_letters from Queue import Queue - try: - from cStringIO import StringIO as BytesIO - except ImportError: - from StringIO import StringIO as BytesIO # special unicode handling for python2 to avoid UnicodeDecodeError def safe_unicode(obj, *args): @@ -22,15 +90,24 @@ if sys.version_info[0] < 3: ascii_text = str(obj).encode('string_escape') return unicode(ascii_text) - iteritems = lambda x: x.iteritems() - iterkeys = lambda x: x.iterkeys() - itervalues = lambda x: x.itervalues() - nativestr = lambda x: \ - x if isinstance(x, str) else x.encode('utf-8', 'replace') - u = lambda x: x.decode() - b = lambda x: x - next = lambda x: x.next() - byte_to_chr = lambda x: x + def iteritems(x): + return x.iteritems() + + def iterkeys(x): + return x.iterkeys() + + def itervalues(x): + return x.itervalues() + + def nativestr(x): + return x if isinstance(x, str) else x.encode('utf-8', 'replace') + + def next(x): + return x.next() + + def byte_to_chr(x): + return x + unichr = unichr xrange = xrange basestring = basestring @@ -38,19 +115,25 @@ if sys.version_info[0] < 3: bytes = str long = long else: - from urllib.parse import parse_qs, urlparse - from io import BytesIO + from urllib.parse import parse_qs, unquote, urlparse from string import ascii_letters from queue import Queue - iteritems = lambda x: iter(x.items()) - iterkeys = lambda x: iter(x.keys()) - itervalues = lambda x: iter(x.values()) - byte_to_chr = lambda x: chr(x) - nativestr = lambda x: \ - x if isinstance(x, str) else x.decode('utf-8', 'replace') - u = lambda x: x - b = lambda x: x.encode('latin-1') if not isinstance(x, bytes) else x + def iteritems(x): + return iter(x.items()) + + def iterkeys(x): + return iter(x.keys()) + + def itervalues(x): + return iter(x.values()) + + def byte_to_chr(x): + return chr(x) + + def nativestr(x): + return x if isinstance(x, str) else x.decode('utf-8', 'replace') + next = next unichr = chr imap = map @@ -64,27 +147,5 @@ else: try: # Python 3 from queue import LifoQueue, Empty, Full -except ImportError: - from Queue import Empty, Full - try: # Python 2.6 - 2.7 - from Queue import LifoQueue - except ImportError: # Python 2.5 - from Queue import Queue - # From the Python 2.7 lib. Python 2.5 already extracted the core - # methods to aid implementating different queue organisations. - - class LifoQueue(Queue): - "Override queue methods to implement a last-in first-out queue." - - def _init(self, maxsize): - self.maxsize = maxsize - self.queue = [] - - def _qsize(self, len=len): - return len(self.queue) - - def _put(self, item): - self.queue.append(item) - - def _get(self): - return self.queue.pop() +except ImportError: # Python 2 + from Queue import LifoQueue, Empty, Full diff --git a/redis/client.py b/redis/client.py index 3acfb9f..0383d14 100755 --- a/redis/client.py +++ b/redis/client.py @@ -1,4 +1,4 @@ -from __future__ import with_statement +from __future__ import unicode_literals from itertools import chain import datetime import sys @@ -6,12 +6,12 @@ import warnings import time import threading import time as mod_time -from redis._compat import (b, basestring, bytes, imap, iteritems, iterkeys, - itervalues, izip, long, nativestr, unicode, - safe_unicode) +import hashlib +from redis._compat import (basestring, bytes, imap, iteritems, iterkeys, + itervalues, izip, long, nativestr, safe_unicode) from redis.connection import (ConnectionPool, UnixDomainSocketConnection, SSLConnection, Token) -from redis.lock import Lock, LuaLock +from redis.lock import Lock from redis.exceptions import ( ConnectionError, DataError, @@ -24,17 +24,20 @@ from redis.exceptions import ( WatchError, ) -SYM_EMPTY = b('') +SYM_EMPTY = b'' +EMPTY_RESPONSE = 'EMPTY_RESPONSE' def list_or_args(keys, args): - # returns a single list combining keys and args + # returns a single new list combining keys and args try: iter(keys) # a string or bytes instance can be iterated, but indicates # keys wasn't passed as a list if isinstance(keys, (basestring, bytes)): keys = [keys] + else: + keys = list(keys) except TypeError: keys = [keys] if args: @@ -59,7 +62,8 @@ def string_keys_to_dict(key_string, callback): def dict_merge(*dicts): merged = {} - [merged.update(d) for d in dicts] + for d in dicts: + merged.update(d) return merged @@ -69,7 +73,7 @@ def parse_debug_object(response): # prefixed with a name response = nativestr(response) response = 'type:' + response - response = dict([kv.split(':') for kv in response.split()]) + response = dict(kv.split(':') for kv in response.split()) # parse some expected int values from the string response # note: this cmd isn't spec'd so these may not appear in all redis versions @@ -112,7 +116,8 @@ def parse_info(response): for line in response.splitlines(): if line and not line.startswith('#'): if line.find(':') != -1: - key, value = line.split(':', 1) + # support keys that include ':' by using rsplit + key, value = line.rsplit(':', 1) info[key] = get_value(value) else: # if the line isn't splittable, append it to the "__raw__" key @@ -180,10 +185,15 @@ def parse_sentinel_get_master(response): return response and (response[0], int(response[1])) or None -def pairs_to_dict(response): +def pairs_to_dict(response, decode_keys=False): "Create a dict given a list of key/value pairs" - it = iter(response) - return dict(izip(it, it)) + if decode_keys: + # the iter form is faster, but I don't know how to make that work + # with a nativestr() map + return dict(izip(imap(nativestr, response[::2]), response[1::2])) + else: + it = iter(response) + return dict(izip(it, it)) def pairs_to_dict_typed(response, type_info): @@ -193,7 +203,7 @@ def pairs_to_dict_typed(response, type_info): if key in type_info: try: value = type_info[key](value) - except: + except Exception: # if for some reason the value can't be coerced, just use # the string value pass @@ -206,7 +216,7 @@ def zset_score_pairs(response, **options): If ``withscores`` is specified in the options, return the response as a list of (value, score) pairs """ - if not response or not options['withscores']: + if not response or not options.get('withscores'): return response score_cast_func = options.get('score_cast_func', float) it = iter(response) @@ -218,7 +228,7 @@ def sort_return_tuples(response, **options): If ``groups`` is specified, return the response as a list of n-element tuples with n being the value found in options['groups'] """ - if not response or not options['groups']: + if not response or not options.get('groups'): return response n = options['groups'] return list(izip(*[response[i::n] for i in range(n)])) @@ -230,6 +240,58 @@ def int_or_none(response): return int(response) +def parse_stream_list(response): + if response is None: + return None + return [(r[0], pairs_to_dict(r[1])) for r in response] + + +def pairs_to_dict_with_nativestr_keys(response): + return pairs_to_dict(response, decode_keys=True) + + +def parse_list_of_dicts(response): + return list(imap(pairs_to_dict_with_nativestr_keys, response)) + + +def parse_xclaim(response, **options): + if options.get('parse_justid', False): + return response + return parse_stream_list(response) + + +def parse_xinfo_stream(response): + data = pairs_to_dict(response, decode_keys=True) + first = data['first-entry'] + data['first-entry'] = (first[0], pairs_to_dict(first[1])) + last = data['last-entry'] + data['last-entry'] = (last[0], pairs_to_dict(last[1])) + return data + + +def parse_xread(response): + if response is None: + return [] + return [[nativestr(r[0]), parse_stream_list(r[1])] for r in response] + + +def parse_xpending(response, **options): + if options.get('parse_detail', False): + return parse_xpending_range(response) + consumers = [{'name': n, 'pending': long(p)} for n, p in response[3] or []] + return { + 'pending': response[0], + 'min': response[1], + 'max': response[2], + 'consumers': consumers + } + + +def parse_xpending_range(response): + k = ('message_id', 'consumer', 'time_since_delivered', 'times_delivered') + return [dict(izip(k, r)) for r in response] + + def float_or_none(response): if response is None: return None @@ -240,10 +302,17 @@ def bool_ok(response): return nativestr(response) == 'OK' +def parse_zadd(response, **options): + if options.get('as_score'): + return float(response) + return int(response) + + def parse_client_list(response, **options): clients = [] for c in nativestr(response).splitlines(): - clients.append(dict([pair.split('=') for pair in c.split(' ')])) + # Values might contain '=' + clients.append(dict(pair.split('=', 1) for pair in c.split(' '))) return clients @@ -274,11 +343,77 @@ def parse_slowlog_get(response, **options): 'id': item[0], 'start_time': int(item[1]), 'duration': int(item[2]), - 'command': b(' ').join(item[3]) + 'command': b' '.join(item[3]) } for item in response] -class StrictRedis(object): +def parse_cluster_info(response, **options): + response = nativestr(response) + return dict(line.split(':') for line in response.splitlines() if line) + + +def _parse_node_line(line): + line_items = line.split(' ') + node_id, addr, flags, master_id, ping, pong, epoch, \ + connected = line.split(' ')[:8] + slots = [sl.split('-') for sl in line_items[8:]] + node_dict = { + 'node_id': node_id, + 'flags': flags, + 'master_id': master_id, + 'last_ping_sent': ping, + 'last_pong_rcvd': pong, + 'epoch': epoch, + 'slots': slots, + 'connected': True if connected == 'connected' else False + } + return addr, node_dict + + +def parse_cluster_nodes(response, **options): + response = nativestr(response) + raw_lines = response + if isinstance(response, basestring): + raw_lines = response.splitlines() + return dict(_parse_node_line(line) for line in raw_lines) + + +def parse_georadius_generic(response, **options): + if options['store'] or options['store_dist']: + # `store` and `store_diff` cant be combined + # with other command arguments. + return response + + if type(response) != list: + response_list = [response] + else: + response_list = response + + if not options['withdist'] and not options['withcoord']\ + and not options['withhash']: + # just a bunch of places + return [nativestr(r) for r in response_list] + + cast = { + 'withdist': float, + 'withcoord': lambda ll: (float(ll[0]), float(ll[1])), + 'withhash': int + } + + # zip all output results with each casting functino to get + # the properly native Python value. + f = [nativestr] + f += [cast[o] for o in ['withdist', 'withhash', 'withcoord'] if options[o]] + return [ + list(map(lambda fv: fv[0](fv[1]), zip(f, r))) for r in response_list + ] + + +def parse_pubsub_numsub(response, **options): + return list(zip(response[0::2], response[1::2])) + + +class Redis(object): """ Implementation of the Redis protocol. @@ -290,28 +425,32 @@ class StrictRedis(object): """ RESPONSE_CALLBACKS = dict_merge( string_keys_to_dict( - 'AUTH EXISTS EXPIRE EXPIREAT HEXISTS HMSET MOVE MSETNX PERSIST ' + 'AUTH EXPIRE EXPIREAT HEXISTS HMSET MOVE MSETNX PERSIST ' 'PSETEX RENAMENX SISMEMBER SMOVE SETEX SETNX', bool ), string_keys_to_dict( - 'BITCOUNT BITPOS DECRBY DEL GETBIT HDEL HLEN INCRBY LINSERT LLEN ' - 'LPUSHX PFADD PFCOUNT RPUSHX SADD SCARD SDIFFSTORE SETBIT ' - 'SETRANGE SINTERSTORE SREM STRLEN SUNIONSTORE ZADD ZCARD ' - 'ZLEXCOUNT ZREM ZREMRANGEBYLEX ZREMRANGEBYRANK ZREMRANGEBYSCORE', + 'BITCOUNT BITPOS DECRBY DEL EXISTS GEOADD GETBIT HDEL HLEN ' + 'HSTRLEN INCRBY LINSERT LLEN LPUSHX PFADD PFCOUNT RPUSHX SADD ' + 'SCARD SDIFFSTORE SETBIT SETRANGE SINTERSTORE SREM STRLEN ' + 'SUNIONSTORE UNLINK XACK XDEL XLEN XTRIM ZCARD ZLEXCOUNT ZREM ' + 'ZREMRANGEBYLEX ZREMRANGEBYRANK ZREMRANGEBYSCORE', int ), - string_keys_to_dict('INCRBYFLOAT HINCRBYFLOAT', float), + string_keys_to_dict( + 'INCRBYFLOAT HINCRBYFLOAT', + float + ), string_keys_to_dict( # these return OK, or int if redis-server is >=1.3.4 'LPUSH RPUSH', - lambda r: isinstance(r, long) and r or nativestr(r) == 'OK' + lambda r: isinstance(r, (long, int)) and r or nativestr(r) == 'OK' ), string_keys_to_dict('SORT', sort_return_tuples), - string_keys_to_dict('ZSCORE ZINCRBY', float_or_none), + string_keys_to_dict('ZSCORE ZINCRBY GEODIST', float_or_none), string_keys_to_dict( 'FLUSHALL FLUSHDB LSET LTRIM MSET PFMERGE RENAME ' - 'SAVE SELECT SHUTDOWN SLAVEOF WATCH UNWATCH', + 'SAVE SELECT SHUTDOWN SLAVEOF SWAPDB WATCH UNWATCH', bool_ok ), string_keys_to_dict('BLPOP BRPOP', lambda r: r and tuple(r) or None), @@ -320,26 +459,58 @@ class StrictRedis(object): lambda r: r and set(r) or set() ), string_keys_to_dict( - 'ZRANGE ZRANGEBYSCORE ZREVRANGE ZREVRANGEBYSCORE', + 'ZPOPMAX ZPOPMIN ZRANGE ZRANGEBYSCORE ZREVRANGE ZREVRANGEBYSCORE', zset_score_pairs ), + string_keys_to_dict('BZPOPMIN BZPOPMAX', \ + lambda r: r and (r[0], r[1], float(r[2])) or None), string_keys_to_dict('ZRANK ZREVRANK', int_or_none), + string_keys_to_dict('XREVRANGE XRANGE', parse_stream_list), + string_keys_to_dict('XREAD XREADGROUP', parse_xread), string_keys_to_dict('BGREWRITEAOF BGSAVE', lambda r: True), { 'CLIENT GETNAME': lambda r: r and nativestr(r), + 'CLIENT ID': int, 'CLIENT KILL': bool_ok, 'CLIENT LIST': parse_client_list, 'CLIENT SETNAME': bool_ok, + 'CLIENT UNBLOCK': lambda r: r and int(r) == 1 or False, + 'CLIENT PAUSE': bool_ok, + 'CLUSTER ADDSLOTS': bool_ok, + 'CLUSTER COUNT-FAILURE-REPORTS': lambda x: int(x), + 'CLUSTER COUNTKEYSINSLOT': lambda x: int(x), + 'CLUSTER DELSLOTS': bool_ok, + 'CLUSTER FAILOVER': bool_ok, + 'CLUSTER FORGET': bool_ok, + 'CLUSTER INFO': parse_cluster_info, + 'CLUSTER KEYSLOT': lambda x: int(x), + 'CLUSTER MEET': bool_ok, + 'CLUSTER NODES': parse_cluster_nodes, + 'CLUSTER REPLICATE': bool_ok, + 'CLUSTER RESET': bool_ok, + 'CLUSTER SAVECONFIG': bool_ok, + 'CLUSTER SET-CONFIG-EPOCH': bool_ok, + 'CLUSTER SETSLOT': bool_ok, + 'CLUSTER SLAVES': parse_cluster_nodes, 'CONFIG GET': parse_config_get, 'CONFIG RESETSTAT': bool_ok, 'CONFIG SET': bool_ok, 'DEBUG OBJECT': parse_debug_object, + 'GEOHASH': lambda r: list(map(nativestr, r)), + 'GEOPOS': lambda r: list(map(lambda ll: (float(ll[0]), + float(ll[1])) + if ll is not None else None, r)), + 'GEORADIUS': parse_georadius_generic, + 'GEORADIUSBYMEMBER': parse_georadius_generic, 'HGETALL': lambda r: r and pairs_to_dict(r) or {}, 'HSCAN': parse_hscan, 'INFO': parse_info, 'LASTSAVE': timestamp_to_datetime, + 'MEMORY PURGE': bool_ok, + 'MEMORY USAGE': int_or_none, 'OBJECT': parse_object, 'PING': lambda r: nativestr(r) == 'PONG', + 'PUBSUB NUMSUB': parse_pubsub_numsub, 'RANDOMKEY': lambda r: r and r or None, 'SCAN': parse_scan, 'SCRIPT EXISTS': lambda r: list(imap(bool, r)), @@ -360,20 +531,41 @@ class StrictRedis(object): 'SLOWLOG RESET': bool_ok, 'SSCAN': parse_scan, 'TIME': lambda x: (int(x[0]), int(x[1])), - 'ZSCAN': parse_zscan + 'XCLAIM': parse_xclaim, + 'XGROUP CREATE': bool_ok, + 'XGROUP DELCONSUMER': int, + 'XGROUP DESTROY': bool, + 'XGROUP SETID': bool_ok, + 'XINFO CONSUMERS': parse_list_of_dicts, + 'XINFO GROUPS': parse_list_of_dicts, + 'XINFO STREAM': parse_xinfo_stream, + 'XPENDING': parse_xpending, + 'ZADD': parse_zadd, + 'ZSCAN': parse_zscan, } ) @classmethod def from_url(cls, url, db=None, **kwargs): """ - Return a Redis client object configured from the given URL. + Return a Redis client object configured from the given URL For example:: redis://[:password]@localhost:6379/0 + rediss://[:password]@localhost:6379/0 unix://[:password]@/path/to/socket.sock?db=0 + Three URL schemes are supported: + + - ```redis://`` + <http://www.iana.org/assignments/uri-schemes/prov/redis>`_ creates a + normal TCP socket connection + - ```rediss://`` + <http://www.iana.org/assignments/uri-schemes/prov/rediss>`_ creates a + SSL wrapped TCP socket connection + - ``unix://`` creates a Unix Domain Socket connection + There are several ways to specify a database number. The parse function will return the first specified option: 1. A ``db`` querystring option, e.g. redis://localhost?db=0 @@ -399,7 +591,8 @@ class StrictRedis(object): charset=None, errors=None, decode_responses=False, retry_on_timeout=False, ssl=False, ssl_keyfile=None, ssl_certfile=None, - ssl_cert_reqs=None, ssl_ca_certs=None): + ssl_cert_reqs='required', ssl_ca_certs=None, + max_connections=None): if not connection_pool: if charset is not None: warnings.warn(DeprecationWarning( @@ -417,7 +610,8 @@ class StrictRedis(object): 'encoding': encoding, 'encoding_errors': encoding_errors, 'decode_responses': decode_responses, - 'retry_on_timeout': retry_on_timeout + 'retry_on_timeout': retry_on_timeout, + 'max_connections': max_connections } # based on input, setup appropriate connection args if unix_socket_path is not None: @@ -445,7 +639,6 @@ class StrictRedis(object): }) connection_pool = ConnectionPool(**kwargs) self.connection_pool = connection_pool - self._use_lua_lock = None self.response_callbacks = self.__class__.RESPONSE_CALLBACKS.copy() @@ -464,7 +657,7 @@ class StrictRedis(object): atomic, pipelines are useful for reducing the back-and-forth overhead between the client and server. """ - return StrictPipeline( + return Pipeline( self.connection_pool, self.response_callbacks, transaction, @@ -480,7 +673,7 @@ class StrictRedis(object): value_from_callable = kwargs.pop('value_from_callable', False) watch_delay = kwargs.pop('watch_delay', None) with self.pipeline(True, shard_hint) as pipe: - while 1: + while True: try: if watches: pipe.watch(*watches) @@ -538,15 +731,7 @@ class StrictRedis(object): is that these cases aren't common and as such default to using thread local storage. """ if lock_class is None: - if self._use_lua_lock is None: - # the first time .lock() is called, determine if we can use - # Lua by attempting to register the necessary scripts - try: - LuaLock.register_scripts(self) - self._use_lua_lock = True - except ResponseError: - self._use_lua_lock = False - lock_class = self._use_lua_lock and LuaLock or Lock + lock_class = Lock return lock_class(self, name, timeout=timeout, sleep=sleep, blocking_timeout=blocking_timeout, thread_local=thread_local) @@ -579,7 +764,12 @@ class StrictRedis(object): def parse_response(self, connection, command_name, **options): "Parses a response from the Redis server" - response = connection.read_response() + try: + response = connection.read_response() + except ResponseError: + if EMPTY_RESPONSE in options: + return options[EMPTY_RESPONSE] + raise if command_name in self.response_callbacks: return self.response_callbacks[command_name](response, **options) return response @@ -600,18 +790,56 @@ class StrictRedis(object): "Disconnects the client at ``address`` (ip:port)" return self.execute_command('CLIENT KILL', address) - def client_list(self): + def client_list(self, _type=None): + """ + Returns a list of currently connected clients. + If type of client specified, only that type will be returned. + :param _type: optional. one of the client types (normal, master, + replica, pubsub) + """ "Returns a list of currently connected clients" + if _type is not None: + client_types = ('normal', 'master', 'replica', 'pubsub') + if str(_type).lower() not in client_types: + raise DataError("CLIENT LIST _type must be one of %r" % ( + client_types,)) + return self.execute_command('CLIENT LIST', Token.get_token('TYPE'), + _type) return self.execute_command('CLIENT LIST') def client_getname(self): "Returns the current connection name" return self.execute_command('CLIENT GETNAME') + def client_id(self): + "Returns the current connection id" + return self.execute_command('CLIENT ID') + def client_setname(self, name): "Sets the current connection name" return self.execute_command('CLIENT SETNAME', name) + def client_unblock(self, client_id, error=False): + """ + Unblocks a connection by its client id. + If ``error`` is True, unblocks the client with a special error message. + If ``error`` is False (default), the client is unblocked using the + regular timeout mechanism. + """ + args = ['CLIENT UNBLOCK', int(client_id)] + if error: + args.append(Token.get_token('ERROR')) + return self.execute_command(*args) + + def client_pause(self, timeout): + """ + Suspend all the Redis clients for the specified amount of time + :param timeout: milliseconds to pause clients + """ + if not isinstance(timeout, (int, long)): + raise DataError("CLIENT PAUSE timeout must be an integer") + return self.execute_command('CLIENT PAUSE', str(timeout)) + def config_get(self, pattern="*"): "Return a dictionary of configuration based on the ``pattern``" return self.execute_command('CONFIG GET', pattern) @@ -640,13 +868,33 @@ class StrictRedis(object): "Echo the string back from the server" return self.execute_command('ECHO', value) - def flushall(self): - "Delete all keys in all databases on the current host" - return self.execute_command('FLUSHALL') + def flushall(self, asynchronous=False): + """ + Delete all keys in all databases on the current host. + + ``asynchronous`` indicates whether the operation is + executed asynchronously by the server. + """ + args = [] + if not asynchronous: + args.append(Token.get_token('ASYNC')) + return self.execute_command('FLUSHALL', *args) + + def flushdb(self, asynchronous=False): + """ + Delete all keys in the current database. - def flushdb(self): - "Delete all keys in the current database" - return self.execute_command('FLUSHDB') + ``asynchronous`` indicates whether the operation is + executed asynchronously by the server. + """ + args = [] + if not asynchronous: + args.append(Token.get_token('ASYNC')) + return self.execute_command('FLUSHDB', *args) + + def swapdb(self, first, second): + "Swap two databases" + return self.execute_command('SWAPDB', first, second) def info(self, section=None): """ @@ -670,10 +918,63 @@ class StrictRedis(object): """ return self.execute_command('LASTSAVE') + def migrate(self, host, port, keys, destination_db, timeout, + copy=False, replace=False, auth=None): + """ + Migrate 1 or more keys from the current Redis server to a different + server specified by the ``host``, ``port`` and ``destination_db``. + + The ``timeout``, specified in milliseconds, indicates the maximum + time the connection between the two servers can be idle before the + command is interrupted. + + If ``copy`` is True, the specified ``keys`` are NOT deleted from + the source server. + + If ``replace`` is True, this operation will overwrite the keys + on the destination server if they exist. + + If ``auth`` is specified, authenticate to the destination server with + the password provided. + """ + keys = list_or_args(keys, []) + if not keys: + raise DataError('MIGRATE requires at least one key') + pieces = [] + if copy: + pieces.append(Token.get_token('COPY')) + if replace: + pieces.append(Token.get_token('REPLACE')) + if auth: + pieces.append(Token.get_token('AUTH')) + pieces.append(auth) + pieces.append(Token.get_token('KEYS')) + pieces.extend(keys) + return self.execute_command('MIGRATE', host, port, '', destination_db, + timeout, *pieces) + def object(self, infotype, key): "Return the encoding, idletime, or refcount about the key" return self.execute_command('OBJECT', infotype, key, infotype=infotype) + def memory_usage(self, key, samples=None): + """ + Return the total memory usage for key, its value and associated + administrative overheads. + + For nested data structures, ``samples`` is the number of elements to + sample. If left unspecified, the server's default is 5. Use 0 to sample + all elements. + """ + args = [] + if isinstance(samples, int): + args.extend([Token.get_token('SAMPLES'), samples]) + return self.execute_command('MEMORY USAGE', key, *args) + + def memory_purge(self): + "Attempts to purge dirty pages for reclamation by allocator" + return self.execute_command('MEMORY PURGE') + def ping(self): "Ping the Redis server" return self.execute_command('PING') @@ -723,10 +1024,22 @@ class StrictRedis(object): "Returns a list of slaves for ``service_name``" return self.execute_command('SENTINEL SLAVES', service_name) - def shutdown(self): - "Shutdown the server" + def shutdown(self, save=False, nosave=False): + """Shutdown the Redis server. If Redis has persistence configured, + data will be flushed before shutdown. If the "save" option is set, + a data flush will be attempted even if there is no persistence + configured. If the "nosave" option is set, no data flush will be + attempted. The "save" and "nosave" options cannot both be set. + """ + if save and nosave: + raise DataError('SHUTDOWN save and nosave cannot both be set') + args = ['SHUTDOWN'] + if save: + args.append('SAVE') + if nosave: + args.append('NOSAVE') try: - self.execute_command('SHUTDOWN') + self.execute_command(*args) except ConnectionError: # a ConnectionError here is expected return @@ -739,7 +1052,8 @@ class StrictRedis(object): instance is promoted to a master instead. """ if host is None and port is None: - return self.execute_command('SLAVEOF', Token('NO'), Token('ONE')) + return self.execute_command('SLAVEOF', Token.get_token('NO'), + Token.get_token('ONE')) return self.execute_command('SLAVEOF', host, port) def slowlog_get(self, num=None): @@ -767,6 +1081,15 @@ class StrictRedis(object): """ return self.execute_command('TIME') + def wait(self, num_replicas, timeout): + """ + Redis synchronous replication + That returns the number of replicas that processed the query when + we finally have at least ``num_replicas``, or when the ``timeout`` was + reached. + """ + return self.execute_command('WAIT', num_replicas, timeout) + # BASIC KEY COMMANDS def append(self, key, value): """ @@ -787,9 +1110,16 @@ class StrictRedis(object): params.append(end) elif (start is not None and end is None) or \ (end is not None and start is None): - raise RedisError("Both start and end must be specified") + raise DataError("Both start and end must be specified") return self.execute_command('BITCOUNT', *params) + def bitfield(self, key, default_overflow=None): + """ + Return a BitFieldOperation instance to conveniently construct one or + more bitfield operations on ``key``. + """ + return BitFieldOperation(self, key, default_overflow=default_overflow) + def bitop(self, operation, dest, *keys): """ Perform a bitwise operation using ``operation`` between ``keys`` and @@ -805,7 +1135,7 @@ class StrictRedis(object): means to look at the first three bytes. """ if bit not in (0, 1): - raise RedisError('bit must be 0 or 1') + raise DataError('bit must be 0 or 1') params = [key, bit] start is not None and params.append(start) @@ -813,8 +1143,8 @@ class StrictRedis(object): if start is not None and end is not None: params.append(end) elif start is None and end is not None: - raise RedisError("start argument is not set, " - "when end is specified") + raise DataError("start argument is not set, " + "when end is specified") return self.execute_command('BITPOS', *params) def decr(self, name, amount=1): @@ -848,9 +1178,9 @@ class StrictRedis(object): """ return self.execute_command('DUMP', name) - def exists(self, name): - "Returns a boolean indicating whether key ``name`` exists" - return self.execute_command('EXISTS', name) + def exists(self, *names): + "Returns the number of ``names`` that exist" + return self.execute_command('EXISTS', *names) __contains__ = exists def expire(self, name, time): @@ -859,7 +1189,7 @@ class StrictRedis(object): can be represented by an integer or a Python timedelta object. """ if isinstance(time, datetime.timedelta): - time = time.seconds + time.days * 24 * 3600 + time = int(time.total_seconds()) return self.execute_command('EXPIRE', name, time) def expireat(self, name, when): @@ -883,7 +1213,7 @@ class StrictRedis(object): doesn't exist. """ value = self.get(name) - if value: + if value is not None: return value raise KeyError(name) @@ -938,35 +1268,31 @@ class StrictRedis(object): Returns a list of values ordered identically to ``keys`` """ args = list_or_args(keys, args) - return self.execute_command('MGET', *args) + options = {} + if not args: + options[EMPTY_RESPONSE] = [] + return self.execute_command('MGET', *args, **options) - def mset(self, *args, **kwargs): + def mset(self, mapping): """ - Sets key/values based on a mapping. Mapping can be supplied as a single - dictionary argument or as kwargs. + Sets key/values based on a mapping. Mapping is a dictionary of + key/value pairs. Both keys and values should be strings or types that + can be cast to a string via str(). """ - if args: - if len(args) != 1 or not isinstance(args[0], dict): - raise RedisError('MSET requires **kwargs or a single dict arg') - kwargs.update(args[0]) items = [] - for pair in iteritems(kwargs): + for pair in iteritems(mapping): items.extend(pair) return self.execute_command('MSET', *items) - def msetnx(self, *args, **kwargs): + def msetnx(self, mapping): """ Sets key/values based on a mapping if none of the keys are already set. - Mapping can be supplied as a single dictionary argument or as kwargs. + Mapping is a dictionary of key/value pairs. Both keys and values + should be strings or types that can be cast to a string via str(). Returns a boolean indicating if the operation was successful. """ - if args: - if len(args) != 1 or not isinstance(args[0], dict): - raise RedisError('MSETNX requires **kwargs or a single ' - 'dict arg') - kwargs.update(args[0]) items = [] - for pair in iteritems(kwargs): + for pair in iteritems(mapping): items.extend(pair) return self.execute_command('MSETNX', *items) @@ -985,8 +1311,7 @@ class StrictRedis(object): object. """ if isinstance(time, datetime.timedelta): - ms = int(time.microseconds / 1000) - time = (time.seconds + time.days * 24 * 3600) * 1000 + ms + time = int(time.total_seconds() * 1000) return self.execute_command('PEXPIRE', name, time) def pexpireat(self, name, when): @@ -1007,8 +1332,7 @@ class StrictRedis(object): timedelta object """ if isinstance(time_ms, datetime.timedelta): - ms = int(time_ms.microseconds / 1000) - time_ms = (time_ms.seconds + time_ms.days * 24 * 3600) * 1000 + ms + time_ms = int(time_ms.total_seconds() * 1000) return self.execute_command('PSETEX', name, time_ms, value) def pttl(self, name): @@ -1029,12 +1353,15 @@ class StrictRedis(object): "Rename key ``src`` to ``dst`` if ``dst`` doesn't already exist" return self.execute_command('RENAMENX', src, dst) - def restore(self, name, ttl, value): + def restore(self, name, ttl, value, replace=False): """ Create a key using the provided serialized value, previously obtained using DUMP. """ - return self.execute_command('RESTORE', name, ttl, value) + params = [name, ttl, value] + if replace: + params.append('REPLACE') + return self.execute_command('RESTORE', *params) def set(self, name, value, ex=None, px=None, nx=False, xx=False): """ @@ -1044,23 +1371,22 @@ class StrictRedis(object): ``px`` sets an expire flag on key ``name`` for ``px`` milliseconds. - ``nx`` if set to True, set the value at key ``name`` to ``value`` if it - does not already exist. + ``nx`` if set to True, set the value at key ``name`` to ``value`` only + if it does not exist. - ``xx`` if set to True, set the value at key ``name`` to ``value`` if it - already exists. + ``xx`` if set to True, set the value at key ``name`` to ``value`` only + if it already exists. """ pieces = [name, value] - if ex: + if ex is not None: pieces.append('EX') if isinstance(ex, datetime.timedelta): - ex = ex.seconds + ex.days * 24 * 3600 + ex = int(ex.total_seconds()) pieces.append(ex) - if px: + if px is not None: pieces.append('PX') if isinstance(px, datetime.timedelta): - ms = int(px.microseconds / 1000) - px = (px.seconds + px.days * 24 * 3600) * 1000 + ms + px = int(px.total_seconds() * 1000) pieces.append(px) if nx: @@ -1087,7 +1413,7 @@ class StrictRedis(object): timedelta object. """ if isinstance(time, datetime.timedelta): - time = time.seconds + time.days * 24 * 3600 + time = int(time.total_seconds()) return self.execute_command('SETEX', name, time, value) def setnx(self, name, value): @@ -1118,6 +1444,13 @@ class StrictRedis(object): """ return self.execute_command('SUBSTR', name, start, end) + def touch(self, *args): + """ + Alters the last access time of a key(s) ``*args``. A key is ignored + if it does not exist. + """ + return self.execute_command('TOUCH', *args) + def ttl(self, name): "Returns the number of seconds until the key ``name`` will expire" return self.execute_command('TTL', name) @@ -1139,6 +1472,10 @@ class StrictRedis(object): warnings.warn( DeprecationWarning('Call UNWATCH from a Pipeline object')) + def unlink(self, *names): + "Unlink one or more keys specified by ``names``" + return self.execute_command('UNLINK', *names) + # LIST COMMANDS def blpop(self, keys, timeout=0): """ @@ -1153,10 +1490,7 @@ class StrictRedis(object): """ if timeout is None: timeout = 0 - if isinstance(keys, basestring): - keys = [keys] - else: - keys = list(keys) + keys = list_or_args(keys, None) keys.append(timeout) return self.execute_command('BLPOP', *keys) @@ -1165,7 +1499,7 @@ class StrictRedis(object): RPOP a value off of the first non-empty list named in the ``keys`` list. - If none of the lists in ``keys`` has a value to LPOP, then block + If none of the lists in ``keys`` has a value to RPOP, then block for ``timeout`` seconds, or until a value gets pushed on to one of the lists. @@ -1173,10 +1507,7 @@ class StrictRedis(object): """ if timeout is None: timeout = 0 - if isinstance(keys, basestring): - keys = [keys] - else: - keys = list(keys) + keys = list_or_args(keys, None) keys.append(timeout) return self.execute_command('BRPOP', *keys) @@ -1311,14 +1642,14 @@ class StrictRedis(object): """ if (start is not None and num is None) or \ (num is not None and start is None): - raise RedisError("``start`` and ``num`` must both be specified") + raise DataError("``start`` and ``num`` must both be specified") pieces = [name] if by is not None: - pieces.append(Token('BY')) + pieces.append(Token.get_token('BY')) pieces.append(by) if start is not None and num is not None: - pieces.append(Token('LIMIT')) + pieces.append(Token.get_token('LIMIT')) pieces.append(start) pieces.append(num) if get is not None: @@ -1326,23 +1657,23 @@ class StrictRedis(object): # Otherwise assume it's an interable and we want to get multiple # values. We can't just iterate blindly because strings are # iterable. - if isinstance(get, basestring): - pieces.append(Token('GET')) + if isinstance(get, (bytes, basestring)): + pieces.append(Token.get_token('GET')) pieces.append(get) else: for g in get: - pieces.append(Token('GET')) + pieces.append(Token.get_token('GET')) pieces.append(g) if desc: - pieces.append(Token('DESC')) + pieces.append(Token.get_token('DESC')) if alpha: - pieces.append(Token('ALPHA')) + pieces.append(Token.get_token('ALPHA')) if store is not None: - pieces.append(Token('STORE')) + pieces.append(Token.get_token('STORE')) pieces.append(store) if groups: - if not get or isinstance(get, basestring) or len(get) < 2: + if not get or isinstance(get, (bytes, basestring)) or len(get) < 2: raise DataError('when using "groups" the "get" argument ' 'must be specified and contain at least ' 'two keys') @@ -1362,9 +1693,9 @@ class StrictRedis(object): """ pieces = [cursor] if match is not None: - pieces.extend([Token('MATCH'), match]) + pieces.extend([Token.get_token('MATCH'), match]) if count is not None: - pieces.extend([Token('COUNT'), count]) + pieces.extend([Token.get_token('COUNT'), count]) return self.execute_command('SCAN', *pieces) def scan_iter(self, match=None, count=None): @@ -1393,9 +1724,9 @@ class StrictRedis(object): """ pieces = [name, cursor] if match is not None: - pieces.extend([Token('MATCH'), match]) + pieces.extend([Token.get_token('MATCH'), match]) if count is not None: - pieces.extend([Token('COUNT'), count]) + pieces.extend([Token.get_token('COUNT'), count]) return self.execute_command('SSCAN', *pieces) def sscan_iter(self, name, match=None, count=None): @@ -1425,9 +1756,9 @@ class StrictRedis(object): """ pieces = [name, cursor] if match is not None: - pieces.extend([Token('MATCH'), match]) + pieces.extend([Token.get_token('MATCH'), match]) if count is not None: - pieces.extend([Token('COUNT'), count]) + pieces.extend([Token.get_token('COUNT'), count]) return self.execute_command('HSCAN', *pieces) def hscan_iter(self, name, match=None, count=None): @@ -1460,9 +1791,9 @@ class StrictRedis(object): """ pieces = [name, cursor] if match is not None: - pieces.extend([Token('MATCH'), match]) + pieces.extend([Token.get_token('MATCH'), match]) if count is not None: - pieces.extend([Token('COUNT'), count]) + pieces.extend([Token.get_token('COUNT'), count]) options = {'score_cast_func': score_cast_func} return self.execute_command('ZSCAN', *pieces, **options) @@ -1533,9 +1864,10 @@ class StrictRedis(object): "Move ``value`` from set ``src`` to set ``dst`` atomically" return self.execute_command('SMOVE', src, dst, value) - def spop(self, name): + def spop(self, name, count=None): "Remove and return a random member of set ``name``" - return self.execute_command('SPOP', name) + args = (count is not None) and [count] or [] + return self.execute_command('SPOP', name, *args) def srandmember(self, name, number=None): """ @@ -1545,7 +1877,7 @@ class StrictRedis(object): memebers of set ``name``. Note this is only available when running Redis 2.6+. """ - args = number and [number] or [] + args = (number is not None) and [number] or [] return self.execute_command('SRANDMEMBER', name, *args) def srem(self, name, *values): @@ -1565,28 +1897,375 @@ class StrictRedis(object): args = list_or_args(keys, args) return self.execute_command('SUNIONSTORE', dest, *args) + # STREAMS COMMANDS + def xack(self, name, groupname, *ids): + """ + Acknowledges the successful processing of one or more messages. + name: name of the stream. + groupname: name of the consumer group. + *ids: message ids to acknowlege. + """ + return self.execute_command('XACK', name, groupname, *ids) + + def xadd(self, name, fields, id='*', maxlen=None, approximate=True): + """ + Add to a stream. + name: name of the stream + fields: dict of field/value pairs to insert into the stream + id: Location to insert this record. By default it is appended. + maxlen: truncate old stream members beyond this size + approximate: actual stream length may be slightly more than maxlen + + """ + pieces = [] + if maxlen is not None: + if not isinstance(maxlen, (int, long)) or maxlen < 1: + raise DataError('XADD maxlen must be a positive integer') + pieces.append(Token.get_token('MAXLEN')) + if approximate: + pieces.append(Token.get_token('~')) + pieces.append(str(maxlen)) + pieces.append(id) + if not isinstance(fields, dict) or len(fields) == 0: + raise DataError('XADD fields must be a non-empty dict') + for pair in iteritems(fields): + pieces.extend(pair) + return self.execute_command('XADD', name, *pieces) + + def xclaim(self, name, groupname, consumername, min_idle_time, message_ids, + idle=None, time=None, retrycount=None, force=False, + justid=False): + """ + Changes the ownership of a pending message. + name: name of the stream. + groupname: name of the consumer group. + consumername: name of a consumer that claims the message. + min_idle_time: filter messages that were idle less than this amount of + milliseconds + message_ids: non-empty list or tuple of message IDs to claim + idle: optional. Set the idle time (last time it was delivered) of the + message in ms + time: optional integer. This is the same as idle but instead of a + relative amount of milliseconds, it sets the idle time to a specific + Unix time (in milliseconds). + retrycount: optional integer. set the retry counter to the specified + value. This counter is incremented every time a message is delivered + again. + force: optional boolean, false by default. Creates the pending message + entry in the PEL even if certain specified IDs are not already in the + PEL assigned to a different client. + justid: optional boolean, false by default. Return just an array of IDs + of messages successfully claimed, without returning the actual message + """ + if not isinstance(min_idle_time, (int, long)) or min_idle_time < 0: + raise DataError("XCLAIM min_idle_time must be a non negative " + "integer") + if not isinstance(message_ids, (list, tuple)) or not message_ids: + raise DataError("XCLAIM message_ids must be a non empty list or " + "tuple of message IDs to claim") + + kwargs = {} + pieces = [name, groupname, consumername, str(min_idle_time)] + pieces.extend(list(message_ids)) + + if idle is not None: + if not isinstance(idle, (int, long)): + raise DataError("XCLAIM idle must be an integer") + pieces.extend((Token.get_token('IDLE'), str(idle))) + if time is not None: + if not isinstance(time, (int, long)): + raise DataError("XCLAIM time must be an integer") + pieces.extend((Token.get_token('TIME'), str(time))) + if retrycount is not None: + if not isinstance(retrycount, (int, long)): + raise DataError("XCLAIM retrycount must be an integer") + pieces.extend((Token.get_token('RETRYCOUNT'), str(retrycount))) + + if force: + if not isinstance(force, bool): + raise DataError("XCLAIM force must be a boolean") + pieces.append(Token.get_token('FORCE')) + if justid: + if not isinstance(justid, bool): + raise DataError("XCLAIM justid must be a boolean") + pieces.append(Token.get_token('JUSTID')) + kwargs['parse_justid'] = True + return self.execute_command('XCLAIM', *pieces, **kwargs) + + def xdel(self, name, *ids): + """ + Deletes one or more messages from a stream. + name: name of the stream. + *ids: message ids to delete. + """ + return self.execute_command('XDEL', name, *ids) + + def xgroup_create(self, name, groupname, id='$', mkstream=False): + """ + Create a new consumer group associated with a stream. + name: name of the stream. + groupname: name of the consumer group. + id: ID of the last item in the stream to consider already delivered. + """ + pieces = ['XGROUP CREATE', name, groupname, id] + if mkstream: + pieces.append(Token.get_token('MKSTREAM')) + return self.execute_command(*pieces) + + def xgroup_delconsumer(self, name, groupname, consumername): + """ + Remove a specific consumer from a consumer group. + Returns the number of pending messages that the consumer had before it + was deleted. + name: name of the stream. + groupname: name of the consumer group. + consumername: name of consumer to delete + """ + return self.execute_command('XGROUP DELCONSUMER', name, groupname, + consumername) + + def xgroup_destroy(self, name, groupname): + """ + Destroy a consumer group. + name: name of the stream. + groupname: name of the consumer group. + """ + return self.execute_command('XGROUP DESTROY', name, groupname) + + def xgroup_setid(self, name, groupname, id): + """ + Set the consumer group last delivered ID to something else. + name: name of the stream. + groupname: name of the consumer group. + id: ID of the last item in the stream to consider already delivered. + """ + return self.execute_command('XGROUP SETID', name, groupname, id) + + def xinfo_consumers(self, name, groupname): + """ + Returns general information about the consumers in the group. + name: name of the stream. + groupname: name of the consumer group. + """ + return self.execute_command('XINFO CONSUMERS', name, groupname) + + def xinfo_groups(self, name): + """ + Returns general information about the consumer groups of the stream. + name: name of the stream. + """ + return self.execute_command('XINFO GROUPS', name) + + def xinfo_stream(self, name): + """ + Returns general information about the stream. + name: name of the stream. + """ + return self.execute_command('XINFO STREAM', name) + + def xlen(self, name): + """ + Returns the number of elements in a given stream. + """ + return self.execute_command('XLEN', name) + + def xpending(self, name, groupname): + """ + Returns information about pending messages of a group. + name: name of the stream. + groupname: name of the consumer group. + """ + return self.execute_command('XPENDING', name, groupname) + + def xpending_range(self, name, groupname, min='-', max='+', count=-1, + consumername=None): + """ + Returns information about pending messages, in a range. + name: name of the stream. + groupname: name of the consumer group. + start: first stream ID. defaults to '-', + meaning the earliest available. + finish: last stream ID. defaults to '+', + meaning the latest available. + count: if set, only return this many items, beginning with the + earliest available. + consumername: name of a consumer to filter by (optional). + """ + pieces = [name, groupname] + if min is not None or max is not None or count is not None: + if min is None or max is None or count is None: + raise DataError("XPENDING must be provided with min, max " + "and count parameters, or none of them. ") + if not isinstance(count, (int, long)) or count < -1: + raise DataError("XPENDING count must be a integer >= -1") + pieces.extend((min, max, str(count))) + if consumername is not None: + if min is None or max is None or count is None: + raise DataError("if XPENDING is provided with consumername," + " it must be provided with min, max and" + " count parameters") + pieces.append(consumername) + return self.execute_command('XPENDING', *pieces, parse_detail=True) + + def xrange(self, name, min='-', max='+', count=None): + """ + Read stream values within an interval. + name: name of the stream. + start: first stream ID. defaults to '-', + meaning the earliest available. + finish: last stream ID. defaults to '+', + meaning the latest available. + count: if set, only return this many items, beginning with the + earliest available. + """ + pieces = [min, max] + if count is not None: + if not isinstance(count, (int, long)) or count < 1: + raise DataError('XRANGE count must be a positive integer') + pieces.append(Token.get_token('COUNT')) + pieces.append(str(count)) + + return self.execute_command('XRANGE', name, *pieces) + + def xread(self, streams, count=None, block=None): + """ + Block and monitor multiple streams for new data. + streams: a dict of stream names to stream IDs, where + IDs indicate the last ID already seen. + count: if set, only return this many items, beginning with the + earliest available. + block: number of milliseconds to wait, if nothing already present. + """ + pieces = [] + if block is not None: + if not isinstance(block, (int, long)) or block < 0: + raise DataError('XREAD block must be a non-negative integer') + pieces.append(Token.get_token('BLOCK')) + pieces.append(str(block)) + if count is not None: + if not isinstance(count, (int, long)) or count < 1: + raise DataError('XREAD count must be a positive integer') + pieces.append(Token.get_token('COUNT')) + pieces.append(str(count)) + if not isinstance(streams, dict) or len(streams) == 0: + raise DataError('XREAD streams must be a non empty dict') + pieces.append(Token.get_token('STREAMS')) + keys, values = izip(*iteritems(streams)) + pieces.extend(keys) + pieces.extend(values) + return self.execute_command('XREAD', *pieces) + + def xreadgroup(self, groupname, consumername, streams, count=None, + block=None): + """ + Read from a stream via a consumer group. + groupname: name of the consumer group. + consumername: name of the requesting consumer. + streams: a dict of stream names to stream IDs, where + IDs indicate the last ID already seen. + count: if set, only return this many items, beginning with the + earliest available. + block: number of milliseconds to wait, if nothing already present. + """ + pieces = [Token.get_token('GROUP'), groupname, consumername] + if count is not None: + if not isinstance(count, (int, long)) or count < 1: + raise DataError("XREADGROUP count must be a positive integer") + pieces.append(Token.get_token("COUNT")) + pieces.append(str(count)) + if block is not None: + if not isinstance(block, (int, long)) or block < 0: + raise DataError("XREADGROUP block must be a non-negative " + "integer") + pieces.append(Token.get_token("BLOCK")) + pieces.append(str(block)) + if not isinstance(streams, dict) or len(streams) == 0: + raise DataError('XREADGROUP streams must be a non empty dict') + pieces.append(Token.get_token('STREAMS')) + pieces.extend(streams.keys()) + pieces.extend(streams.values()) + return self.execute_command('XREADGROUP', *pieces) + + def xrevrange(self, name, max='+', min='-', count=None): + """ + Read stream values within an interval, in reverse order. + name: name of the stream + start: first stream ID. defaults to '+', + meaning the latest available. + finish: last stream ID. defaults to '-', + meaning the earliest available. + count: if set, only return this many items, beginning with the + latest available. + """ + pieces = [max, min] + if count is not None: + if not isinstance(count, (int, long)) or count < 1: + raise DataError('XREVRANGE count must be a positive integer') + pieces.append(Token.get_token('COUNT')) + pieces.append(str(count)) + + return self.execute_command('XREVRANGE', name, *pieces) + + def xtrim(self, name, maxlen, approximate=True): + """ + Trims old messages from a stream. + name: name of the stream. + maxlen: truncate old stream messages beyond this size + approximate: actual stream length may be slightly more than maxlen + """ + pieces = [Token.get_token('MAXLEN')] + if approximate: + pieces.append(Token.get_token('~')) + pieces.append(maxlen) + return self.execute_command('XTRIM', name, *pieces) + # SORTED SET COMMANDS - def zadd(self, name, *args, **kwargs): + def zadd(self, name, mapping, nx=False, xx=False, ch=False, incr=False): """ - Set any number of score, element-name pairs to the key ``name``. Pairs - can be specified in two ways: + Set any number of element-name, score pairs to the key ``name``. Pairs + are specified as a dict of element-names keys to score values. + + ``nx`` forces ZADD to only create new elements and not to update + scores for elements that already exist. + + ``xx`` forces ZADD to only update scores of elements that already + exist. New elements will not be added. + + ``ch`` modifies the return value to be the numbers of elements changed. + Changed elements include new elements that were added and elements + whose scores changed. - As *args, in the form of: score1, name1, score2, name2, ... - or as **kwargs, in the form of: name1=score1, name2=score2, ... + ``incr`` modifies ZADD to behave like ZINCRBY. In this mode only a + single element/score pair can be specified and the score is the amount + the existing score will be incremented by. When using this mode the + return value of ZADD will be the new score of the element. - The following example would add four values to the 'my-key' key: - redis.zadd('my-key', 1.1, 'name1', 2.2, 'name2', name3=3.3, name4=4.4) + The return value of ZADD varies based on the mode specified. With no + options, ZADD returns the number of new elements added to the sorted + set. """ + if not mapping: + raise DataError("ZADD requires at least one element/score pair") + if nx and xx: + raise DataError("ZADD allows either 'nx' or 'xx', not both") + if incr and len(mapping) != 1: + raise DataError("ZADD option 'incr' only works when passing a " + "single element/score pair") pieces = [] - if args: - if len(args) % 2 != 0: - raise RedisError("ZADD requires an equal number of " - "values and scores") - pieces.extend(args) - for pair in iteritems(kwargs): + options = {} + if nx: + pieces.append(Token.get_token('NX')) + if xx: + pieces.append(Token.get_token('XX')) + if ch: + pieces.append(Token.get_token('CH')) + if incr: + pieces.append(Token.get_token('INCR')) + options['as_score'] = True + for pair in iteritems(mapping): pieces.append(pair[1]) pieces.append(pair[0]) - return self.execute_command('ZADD', name, *pieces) + return self.execute_command('ZADD', name, *pieces, **options) def zcard(self, name): "Return the number of elements in the sorted set ``name``" @@ -1599,7 +2278,7 @@ class StrictRedis(object): """ return self.execute_command('ZCOUNT', name, min, max) - def zincrby(self, name, value, amount=1): + def zincrby(self, name, amount, value): "Increment the score of ``value`` in sorted set ``name`` by ``amount``" return self.execute_command('ZINCRBY', name, amount, value) @@ -1618,6 +2297,62 @@ class StrictRedis(object): """ return self.execute_command('ZLEXCOUNT', name, min, max) + def zpopmax(self, name, count=None): + """ + Remove and return up to ``count`` members with the highest scores + from the sorted set ``name``. + """ + args = (count is not None) and [count] or [] + options = { + 'withscores': True + } + return self.execute_command('ZPOPMAX', name, *args, **options) + + def zpopmin(self, name, count=None): + """ + Remove and return up to ``count`` members with the lowest scores + from the sorted set ``name``. + """ + args = (count is not None) and [count] or [] + options = { + 'withscores': True + } + return self.execute_command('ZPOPMIN', name, *args, **options) + + def bzpopmax(self, keys, timeout=0): + """ + ZPOPMAX a value off of the first non-empty sorted set + named in the ``keys`` list. + + If none of the sorted sets in ``keys`` has a value to ZPOPMAX, + then block for ``timeout`` seconds, or until a member gets added + to one of the sorted sets. + + If timeout is 0, then block indefinitely. + """ + if timeout is None: + timeout = 0 + keys = list_or_args(keys, None) + keys.append(timeout) + return self.execute_command('BZPOPMAX', *keys) + + def bzpopmin(self, keys, timeout=0): + """ + ZPOPMIN a value off of the first non-empty sorted set + named in the ``keys`` list. + + If none of the sorted sets in ``keys`` has a value to ZPOPMIN, + then block for ``timeout`` seconds, or until a member gets added + to one of the sorted sets. + + If timeout is 0, then block indefinitely. + """ + if timeout is None: + timeout = 0 + keys = list_or_args(keys, None) + keys.append(timeout) + return self.execute_command('BZPOPMIN', *keys) + def zrange(self, name, start, end, desc=False, withscores=False, score_cast_func=float): """ @@ -1638,7 +2373,7 @@ class StrictRedis(object): score_cast_func) pieces = ['ZRANGE', name, start, end] if withscores: - pieces.append(Token('WITHSCORES')) + pieces.append(Token.get_token('WITHSCORES')) options = { 'withscores': withscores, 'score_cast_func': score_cast_func @@ -1655,10 +2390,26 @@ class StrictRedis(object): """ if (start is not None and num is None) or \ (num is not None and start is None): - raise RedisError("``start`` and ``num`` must both be specified") + raise DataError("``start`` and ``num`` must both be specified") pieces = ['ZRANGEBYLEX', name, min, max] if start is not None and num is not None: - pieces.extend([Token('LIMIT'), start, num]) + pieces.extend([Token.get_token('LIMIT'), start, num]) + return self.execute_command(*pieces) + + def zrevrangebylex(self, name, max, min, start=None, num=None): + """ + Return the reversed lexicographical range of values from sorted set + ``name`` between ``max`` and ``min``. + + If ``start`` and ``num`` are specified, then return a slice of the + range. + """ + if (start is not None and num is None) or \ + (num is not None and start is None): + raise DataError("``start`` and ``num`` must both be specified") + pieces = ['ZREVRANGEBYLEX', name, max, min] + if start is not None and num is not None: + pieces.extend([Token.get_token('LIMIT'), start, num]) return self.execute_command(*pieces) def zrangebyscore(self, name, min, max, start=None, num=None, @@ -1677,12 +2428,12 @@ class StrictRedis(object): """ if (start is not None and num is None) or \ (num is not None and start is None): - raise RedisError("``start`` and ``num`` must both be specified") + raise DataError("``start`` and ``num`` must both be specified") pieces = ['ZRANGEBYSCORE', name, min, max] if start is not None and num is not None: - pieces.extend([Token('LIMIT'), start, num]) + pieces.extend([Token.get_token('LIMIT'), start, num]) if withscores: - pieces.append(Token('WITHSCORES')) + pieces.append(Token.get_token('WITHSCORES')) options = { 'withscores': withscores, 'score_cast_func': score_cast_func @@ -1740,7 +2491,7 @@ class StrictRedis(object): """ pieces = ['ZREVRANGE', name, start, end] if withscores: - pieces.append(Token('WITHSCORES')) + pieces.append(Token.get_token('WITHSCORES')) options = { 'withscores': withscores, 'score_cast_func': score_cast_func @@ -1763,12 +2514,12 @@ class StrictRedis(object): """ if (start is not None and num is None) or \ (num is not None and start is None): - raise RedisError("``start`` and ``num`` must both be specified") + raise DataError("``start`` and ``num`` must both be specified") pieces = ['ZREVRANGEBYSCORE', name, max, min] if start is not None and num is not None: - pieces.extend([Token('LIMIT'), start, num]) + pieces.extend([Token.get_token('LIMIT'), start, num]) if withscores: - pieces.append(Token('WITHSCORES')) + pieces.append(Token.get_token('WITHSCORES')) options = { 'withscores': withscores, 'score_cast_func': score_cast_func @@ -1802,10 +2553,10 @@ class StrictRedis(object): weights = None pieces.extend(keys) if weights: - pieces.append(Token('WEIGHTS')) + pieces.append(Token.get_token('WEIGHTS')) pieces.extend(weights) if aggregate: - pieces.append(Token('AGGREGATE')) + pieces.append(Token.get_token('AGGREGATE')) pieces.append(aggregate) return self.execute_command(*pieces) @@ -1814,12 +2565,12 @@ class StrictRedis(object): "Adds the specified elements to the specified HyperLogLog." return self.execute_command('PFADD', name, *values) - def pfcount(self, name): + def pfcount(self, *sources): """ Return the approximated cardinality of - the set observed by the HyperLogLog at key. + the set observed by the HyperLogLog at key(s). """ - return self.execute_command('PFCOUNT', name) + return self.execute_command('PFCOUNT', *sources) def pfmerge(self, dest, *sources): "Merge N different HyperLogLogs into a single one." @@ -1895,6 +2646,13 @@ class StrictRedis(object): "Return the list of values within hash ``name``" return self.execute_command('HVALS', name) + def hstrlen(self, name, key): + """ + Return the number of bytes stored in the value of ``key`` + within hash ``name`` + """ + return self.execute_command('HSTRLEN', name, key) + def publish(self, channel, message): """ Publish ``message`` on ``channel``. @@ -1902,6 +2660,28 @@ class StrictRedis(object): """ return self.execute_command('PUBLISH', channel, message) + def pubsub_channels(self, pattern='*'): + """ + Return a list of channels that have at least one subscriber + """ + return self.execute_command('PUBSUB CHANNELS', pattern) + + def pubsub_numpat(self): + """ + Returns the number of subscriptions to patterns + """ + return self.execute_command('PUBSUB NUMPAT') + + def pubsub_numsub(self, *args): + """ + Return a list of (channel, number of subscribers) tuples + for each channel given in ``*args`` + """ + return self.execute_command('PUBSUB NUMSUB', *args) + + def cluster(self, cluster_arg, *args): + return self.execute_command('CLUSTER %s' % cluster_arg.upper(), *args) + def eval(self, script, numkeys, *keys_and_args): """ Execute the Lua ``script``, specifying the ``numkeys`` the script @@ -1954,89 +2734,137 @@ class StrictRedis(object): """ return Script(self, script) + # GEO COMMANDS + def geoadd(self, name, *values): + """ + Add the specified geospatial items to the specified key identified + by the ``name`` argument. The Geospatial items are given as ordered + members of the ``values`` argument, each item or place is formed by + the triad longitude, latitude and name. + """ + if len(values) % 3 != 0: + raise DataError("GEOADD requires places with lon, lat and name" + " values") + return self.execute_command('GEOADD', name, *values) -class Redis(StrictRedis): - """ - Provides backwards compatibility with older versions of redis-py that - changed arguments to some commands to be more Pythonic, sane, or by - accident. - """ - - # Overridden callbacks - RESPONSE_CALLBACKS = dict_merge( - StrictRedis.RESPONSE_CALLBACKS, - { - 'TTL': lambda r: r >= 0 and r or None, - 'PTTL': lambda r: r >= 0 and r or None, - } - ) + def geodist(self, name, place1, place2, unit=None): + """ + Return the distance between ``place1`` and ``place2`` members of the + ``name`` key. + The units must be one of the following : m, km mi, ft. By default + meters are used. + """ + pieces = [name, place1, place2] + if unit and unit not in ('m', 'km', 'mi', 'ft'): + raise DataError("GEODIST invalid unit") + elif unit: + pieces.append(unit) + return self.execute_command('GEODIST', *pieces) - def pipeline(self, transaction=True, shard_hint=None): + def geohash(self, name, *values): """ - Return a new pipeline object that can queue multiple commands for - later execution. ``transaction`` indicates whether all commands - should be executed atomically. Apart from making a group of operations - atomic, pipelines are useful for reducing the back-and-forth overhead - between the client and server. + Return the geo hash string for each item of ``values`` members of + the specified key identified by the ``name`` argument. """ - return Pipeline( - self.connection_pool, - self.response_callbacks, - transaction, - shard_hint) + return self.execute_command('GEOHASH', name, *values) - def setex(self, name, value, time): + def geopos(self, name, *values): """ - Set the value of key ``name`` to ``value`` that expires in ``time`` - seconds. ``time`` can be represented by an integer or a Python - timedelta object. + Return the positions of each item of ``values`` as members of + the specified key identified by the ``name`` argument. Each position + is represented by the pairs lon and lat. """ - if isinstance(time, datetime.timedelta): - time = time.seconds + time.days * 24 * 3600 - return self.execute_command('SETEX', name, time, value) + return self.execute_command('GEOPOS', name, *values) - def lrem(self, name, value, num=0): + def georadius(self, name, longitude, latitude, radius, unit=None, + withdist=False, withcoord=False, withhash=False, count=None, + sort=None, store=None, store_dist=None): """ - Remove the first ``num`` occurrences of elements equal to ``value`` - from the list stored at ``name``. + Return the members of the specified key identified by the + ``name`` argument which are within the borders of the area specified + with the ``latitude`` and ``longitude`` location and the maximum + distance from the center specified by the ``radius`` value. - The ``num`` argument influences the operation in the following ways: - num > 0: Remove elements equal to value moving from head to tail. - num < 0: Remove elements equal to value moving from tail to head. - num = 0: Remove all elements equal to value. + The units must be one of the following : m, km mi, ft. By default + + ``withdist`` indicates to return the distances of each place. + + ``withcoord`` indicates to return the latitude and longitude of + each place. + + ``withhash`` indicates to return the geohash string of each place. + + ``count`` indicates to return the number of elements up to N. + + ``sort`` indicates to return the places in a sorted way, ASC for + nearest to fairest and DESC for fairest to nearest. + + ``store`` indicates to save the places names in a sorted set named + with a specific key, each element of the destination sorted set is + populated with the score got from the original geo sorted set. + + ``store_dist`` indicates to save the places names in a sorted set + named with a specific key, instead of ``store`` the sorted set + destination score is set with the distance. """ - return self.execute_command('LREM', name, num, value) + return self._georadiusgeneric('GEORADIUS', + name, longitude, latitude, radius, + unit=unit, withdist=withdist, + withcoord=withcoord, withhash=withhash, + count=count, sort=sort, store=store, + store_dist=store_dist) - def zadd(self, name, *args, **kwargs): + def georadiusbymember(self, name, member, radius, unit=None, + withdist=False, withcoord=False, withhash=False, + count=None, sort=None, store=None, store_dist=None): + """ + This command is exactly like ``georadius`` with the sole difference + that instead of taking, as the center of the area to query, a longitude + and latitude value, it takes the name of a member already existing + inside the geospatial index represented by the sorted set. """ - NOTE: The order of arguments differs from that of the official ZADD - command. For backwards compatability, this method accepts arguments - in the form of name1, score1, name2, score2, while the official Redis - documents expects score1, name1, score2, name2. + return self._georadiusgeneric('GEORADIUSBYMEMBER', + name, member, radius, unit=unit, + withdist=withdist, withcoord=withcoord, + withhash=withhash, count=count, + sort=sort, store=store, + store_dist=store_dist) - If you're looking to use the standard syntax, consider using the - StrictRedis class. See the API Reference section of the docs for more - information. + def _georadiusgeneric(self, command, *args, **kwargs): + pieces = list(args) + if kwargs['unit'] and kwargs['unit'] not in ('m', 'km', 'mi', 'ft'): + raise DataError("GEORADIUS invalid unit") + elif kwargs['unit']: + pieces.append(kwargs['unit']) + else: + pieces.append('m',) - Set any number of element-name, score pairs to the key ``name``. Pairs - can be specified in two ways: + for token in ('withdist', 'withcoord', 'withhash'): + if kwargs[token]: + pieces.append(Token(token.upper())) - As *args, in the form of: name1, score1, name2, score2, ... - or as **kwargs, in the form of: name1=score1, name2=score2, ... + if kwargs['count']: + pieces.extend([Token('COUNT'), kwargs['count']]) + + if kwargs['sort'] and kwargs['sort'] not in ('ASC', 'DESC'): + raise DataError("GEORADIUS invalid sort") + elif kwargs['sort']: + pieces.append(Token(kwargs['sort'])) + + if kwargs['store'] and kwargs['store_dist']: + raise DataError("GEORADIUS store and store_dist cant be set" + " together") + + if kwargs['store']: + pieces.extend([Token('STORE'), kwargs['store']]) + + if kwargs['store_dist']: + pieces.extend([Token('STOREDIST'), kwargs['store_dist']]) + + return self.execute_command(command, *pieces, **kwargs) - The following example would add four values to the 'my-key' key: - redis.zadd('my-key', 'name1', 1.1, 'name2', 2.2, name3=3.3, name4=4.4) - """ - pieces = [] - if args: - if len(args) % 2 != 0: - raise RedisError("ZADD requires an equal number of " - "values and scores") - pieces.extend(reversed(args)) - for pair in iteritems(kwargs): - pieces.append(pair[1]) - pieces.append(pair[0]) - return self.execute_command('ZADD', name, *pieces) + +StrictRedis = Redis class PubSub(object): @@ -2058,13 +2886,7 @@ class PubSub(object): self.connection = None # we need to know the encoding options for this connection in order # to lookup channel and pattern names for callback handlers. - conn = connection_pool.get_connection('pubsub', shard_hint) - try: - self.encoding = conn.encoding - self.encoding_errors = conn.encoding_errors - self.decode_responses = conn.decode_responses - finally: - connection_pool.release(conn) + self.encoder = self.connection_pool.get_encoder() self.reset() def __del__(self): @@ -2096,29 +2918,14 @@ class PubSub(object): if self.channels: channels = {} for k, v in iteritems(self.channels): - if not self.decode_responses: - k = k.decode(self.encoding, self.encoding_errors) - channels[k] = v + channels[self.encoder.decode(k, force=True)] = v self.subscribe(**channels) if self.patterns: patterns = {} for k, v in iteritems(self.patterns): - if not self.decode_responses: - k = k.decode(self.encoding, self.encoding_errors) - patterns[k] = v + patterns[self.encoder.decode(k, force=True)] = v self.psubscribe(**patterns) - def encode(self, value): - """ - Encode the value so that it's identical to what we'll - read off the connection - """ - if self.decode_responses and isinstance(value, bytes): - value = value.decode(self.encoding, self.encoding_errors) - elif not self.decode_responses and isinstance(value, unicode): - value = value.encode(self.encoding, self.encoding_errors) - return value - @property def subscribed(self): "Indicates if there are subscriptions to any channels or patterns" @@ -2127,8 +2934,8 @@ class PubSub(object): def execute_command(self, *args, **kwargs): "Execute a publish/subscribe command" - # NOTE: don't parse the response in this function. it could pull a - # legitmate message off the stack if the connection is already + # NOTE: don't parse the response in this function -- it could pull a + # legitimate message off the stack if the connection is already # subscribed to one or more channels if self.connection is None: @@ -2160,10 +2967,24 @@ class PubSub(object): def parse_response(self, block=True, timeout=0): "Parse the response from a publish/subscribe command" connection = self.connection + if connection is None: + raise RuntimeError( + 'pubsub connection not set: ' + 'did you forget to call subscribe() or psubscribe()?') if not block and not connection.can_read(timeout=timeout): return None return self._execute(connection, connection.read_response) + def _normalize_keys(self, data): + """ + normalize channel/pattern names to be either bytes or strings + based on whether responses are automatically decoded. this saves us + from coercing the value for each message coming in. + """ + encode = self.encoder.encode + decode = self.encoder.decode + return {decode(encode(k)): v for k, v in iteritems(data)} + def psubscribe(self, *args, **kwargs): """ Subscribe to channel patterns. Patterns supplied as keyword arguments @@ -2174,15 +2995,13 @@ class PubSub(object): """ if args: args = list_or_args(args[0], args[1:]) - new_patterns = {} - new_patterns.update(dict.fromkeys(imap(self.encode, args))) - for pattern, handler in iteritems(kwargs): - new_patterns[self.encode(pattern)] = handler + new_patterns = dict.fromkeys(args) + new_patterns.update(kwargs) ret_val = self.execute_command('PSUBSCRIBE', *iterkeys(new_patterns)) # update the patterns dict AFTER we send the command. we don't want to # subscribe twice to these patterns, once for the command and again # for the reconnection. - self.patterns.update(new_patterns) + self.patterns.update(self._normalize_keys(new_patterns)) return ret_val def punsubscribe(self, *args): @@ -2204,15 +3023,13 @@ class PubSub(object): """ if args: args = list_or_args(args[0], args[1:]) - new_channels = {} - new_channels.update(dict.fromkeys(imap(self.encode, args))) - for channel, handler in iteritems(kwargs): - new_channels[self.encode(channel)] = handler + new_channels = dict.fromkeys(args) + new_channels.update(kwargs) ret_val = self.execute_command('SUBSCRIBE', *iterkeys(new_channels)) # update the channels dict AFTER we send the command. we don't want to # subscribe twice to these channels, once for the command and again # for the reconnection. - self.channels.update(new_channels) + self.channels.update(self._normalize_keys(new_channels)) return ret_val def unsubscribe(self, *args): @@ -2244,6 +3061,13 @@ class PubSub(object): return self.handle_message(response, ignore_subscribe_messages) return None + def ping(self, message=None): + """ + Ping the Redis server + """ + message = '' if message is None else message + return self.execute_command('PING', message) + def handle_message(self, response, ignore_subscribe_messages=False): """ Parses a pub/sub message. If the channel or pattern was subscribed to @@ -2258,6 +3082,13 @@ class PubSub(object): 'channel': response[2], 'data': response[3] } + elif message_type == 'pong': + message = { + 'type': message_type, + 'pattern': None, + 'channel': None, + 'data': response[1] + } else: message = { 'type': message_type, @@ -2288,7 +3119,7 @@ class PubSub(object): if handler: handler(message) return None - else: + elif message_type != 'pong': # this is a subscribe/unsubscribe message. ignore if we don't # want them if ignore_subscribe_messages or self.ignore_subscribe_messages: @@ -2296,7 +3127,7 @@ class PubSub(object): return message - def run_in_thread(self, sleep_time=0): + def run_in_thread(self, sleep_time=0, daemon=False): for channel, handler in iteritems(self.channels): if handler is None: raise PubSubError("Channel: '%s' has no handler registered") @@ -2304,14 +3135,15 @@ class PubSub(object): if handler is None: raise PubSubError("Pattern: '%s' has no handler registered") - thread = PubSubWorkerThread(self, sleep_time) + thread = PubSubWorkerThread(self, sleep_time, daemon=daemon) thread.start() return thread class PubSubWorkerThread(threading.Thread): - def __init__(self, pubsub, sleep_time): + def __init__(self, pubsub, sleep_time, daemon=False): super(PubSubWorkerThread, self).__init__() + self.daemon = daemon self.pubsub = pubsub self.sleep_time = sleep_time self._running = False @@ -2336,7 +3168,7 @@ class PubSubWorkerThread(threading.Thread): self.pubsub.punsubscribe() -class BasePipeline(object): +class Pipeline(Redis): """ Pipelines provide a way to transmit multiple commands to the Redis server in one transmission. This is convenient for batch processing, such as @@ -2355,7 +3187,7 @@ class BasePipeline(object): on a key of a different datatype. """ - UNWATCH_COMMANDS = set(('DISCARD', 'EXEC', 'UNWATCH')) + UNWATCH_COMMANDS = {'DISCARD', 'EXEC', 'UNWATCH'} def __init__(self, connection_pool, response_callbacks, transaction, shard_hint): @@ -2473,7 +3305,8 @@ class BasePipeline(object): def _execute_transaction(self, connection, commands, raise_on_error): cmds = chain([(('MULTI', ), {})], commands, [(('EXEC', ), {})]) - all_cmds = connection.pack_commands([args for args, _ in cmds]) + all_cmds = connection.pack_commands([args for args, options in cmds + if EMPTY_RESPONSE not in options]) connection.send_packed_command(all_cmds) errors = [] @@ -2488,12 +3321,15 @@ class BasePipeline(object): # and all the other commands for i, command in enumerate(commands): - try: - self.parse_response(connection, '_') - except ResponseError: - ex = sys.exc_info()[1] - self.annotate_exception(ex, i + 1, command[0]) - errors.append((i, ex)) + if EMPTY_RESPONSE in command[1]: + errors.append((i, command[1][EMPTY_RESPONSE])) + else: + try: + self.parse_response(connection, '_') + except ResponseError: + ex = sys.exc_info()[1] + self.annotate_exception(ex, i + 1, command[0]) + errors.append((i, ex)) # parse the EXEC. try: @@ -2556,13 +3392,13 @@ class BasePipeline(object): raise r def annotate_exception(self, exception, number, command): - cmd = safe_unicode(' ').join(imap(safe_unicode, command)) - msg = unicode('Command # %d (%s) of pipeline caused error: %s') % ( + cmd = ' '.join(imap(safe_unicode, command)) + msg = 'Command # %d (%s) of pipeline caused error: %s' % ( number, cmd, safe_unicode(exception.args[0])) exception.args = (msg,) + exception.args[1:] def parse_response(self, connection, command_name, **options): - result = StrictRedis.parse_response( + result = Redis.parse_response( self, connection, command_name, **options) if command_name in self.UNWATCH_COMMANDS: self.watching = False @@ -2577,12 +3413,11 @@ class BasePipeline(object): shas = [s.sha for s in scripts] # we can't use the normal script_* methods because they would just # get buffered in the pipeline. - exists = immediate('SCRIPT', 'EXISTS', *shas, **{'parse': 'EXISTS'}) + exists = immediate('SCRIPT EXISTS', *shas) if not all(exists): for s, exist in izip(scripts, exists): if not exist: - s.sha = immediate('SCRIPT', 'LOAD', s.script, - **{'parse': 'LOAD'}) + s.sha = immediate('SCRIPT LOAD', s.script) def execute(self, raise_on_error=True): "Execute all the commands in the current pipeline" @@ -2634,26 +3469,6 @@ class BasePipeline(object): "Unwatches all previously specified keys" return self.watching and self.execute_command('UNWATCH') or True - def script_load_for_pipeline(self, script): - "Make sure scripts are loaded prior to pipeline execution" - # we need the sha now so that Script.__call__ can use it to run - # evalsha. - if not script.sha: - script.sha = self.immediate_execute_command('SCRIPT', 'LOAD', - script.script, - **{'parse': 'LOAD'}) - self.scripts.add(script) - - -class StrictPipeline(BasePipeline, StrictRedis): - "Pipeline for the StrictRedis class" - pass - - -class Pipeline(BasePipeline, Redis): - "Pipeline for the Redis class" - pass - class Script(object): "An executable Lua script object returned by ``register_script``" @@ -2661,7 +3476,14 @@ class Script(object): def __init__(self, registered_client, script): self.registered_client = registered_client self.script = script - self.sha = '' + # Precalculate and store the SHA1 hex digest of the script. + + if isinstance(script, basestring): + # We need the encoding from the client in order to generate an + # accurate byte representation of the script + encoder = registered_client.connection_pool.get_encoder() + script = encoder.encode(script) + self.sha = hashlib.sha1(script).hexdigest() def __call__(self, keys=[], args=[], client=None): "Execute the script, passing any required ``args``" @@ -2669,13 +3491,111 @@ class Script(object): client = self.registered_client args = tuple(keys) + tuple(args) # make sure the Redis server knows about the script - if isinstance(client, BasePipeline): - # make sure this script is good to go on pipeline - client.script_load_for_pipeline(self) + if isinstance(client, Pipeline): + # Make sure the pipeline can register the script before executing. + client.scripts.add(self) try: return client.evalsha(self.sha, len(keys), *args) except NoScriptError: # Maybe the client is pointed to a differnet server than the client # that created this instance? + # Overwrite the sha just in case there was a discrepancy. self.sha = client.script_load(self.script) return client.evalsha(self.sha, len(keys), *args) + + +class BitFieldOperation(object): + """ + Command builder for BITFIELD commands. + """ + def __init__(self, client, key, default_overflow=None): + self.client = client + self.key = key + self._default_overflow = default_overflow + self.reset() + + def reset(self): + """ + Reset the state of the instance to when it was constructed + """ + self.operations = [] + self._last_overflow = 'WRAP' + self.overflow(self._default_overflow or self._last_overflow) + + def overflow(self, overflow): + """ + Update the overflow algorithm of successive INCRBY operations + :param overflow: Overflow algorithm, one of WRAP, SAT, FAIL. See the + Redis docs for descriptions of these algorithmsself. + :returns: a :py:class:`BitFieldOperation` instance. + """ + overflow = overflow.upper() + if overflow != self._last_overflow: + self._last_overflow = overflow + self.operations.append(('OVERFLOW', overflow)) + return self + + def incrby(self, fmt, offset, increment, overflow=None): + """ + Increment a bitfield by a given amount. + :param fmt: format-string for the bitfield being updated, e.g. 'u8' + for an unsigned 8-bit integer. + :param offset: offset (in number of bits). If prefixed with a + '#', this is an offset multiplier, e.g. given the arguments + fmt='u8', offset='#2', the offset will be 16. + :param int increment: value to increment the bitfield by. + :param str overflow: overflow algorithm. Defaults to WRAP, but other + acceptable values are SAT and FAIL. See the Redis docs for + descriptions of these algorithms. + :returns: a :py:class:`BitFieldOperation` instance. + """ + if overflow is not None: + self.overflow(overflow) + + self.operations.append(('INCRBY', fmt, offset, increment)) + return self + + def get(self, fmt, offset): + """ + Get the value of a given bitfield. + :param fmt: format-string for the bitfield being read, e.g. 'u8' for + an unsigned 8-bit integer. + :param offset: offset (in number of bits). If prefixed with a + '#', this is an offset multiplier, e.g. given the arguments + fmt='u8', offset='#2', the offset will be 16. + :returns: a :py:class:`BitFieldOperation` instance. + """ + self.operations.append(('GET', fmt, offset)) + return self + + def set(self, fmt, offset, value): + """ + Set the value of a given bitfield. + :param fmt: format-string for the bitfield being read, e.g. 'u8' for + an unsigned 8-bit integer. + :param offset: offset (in number of bits). If prefixed with a + '#', this is an offset multiplier, e.g. given the arguments + fmt='u8', offset='#2', the offset will be 16. + :param int value: value to set at the given position. + :returns: a :py:class:`BitFieldOperation` instance. + """ + self.operations.append(('SET', fmt, offset, value)) + return self + + @property + def command(self): + cmd = ['BITFIELD', self.key] + for ops in self.operations: + cmd.extend(ops) + return cmd + + def execute(self): + """ + Execute the operation(s) in a single BITFIELD command. The return value + is a list of values corresponding to each operation. If the client + used to create this instance was a pipeline, the list of values + will be present within the pipeline's execute. + """ + command = self.command + self.reset() + return self.client.execute_command(*command) diff --git a/redis/connection.py b/redis/connection.py index 7486da8..b38f24c 100755 --- a/redis/connection.py +++ b/redis/connection.py @@ -1,7 +1,7 @@ -from __future__ import with_statement +from __future__ import unicode_literals from distutils.version import StrictVersion from itertools import chain -from select import select +import io import os import socket import sys @@ -14,10 +14,12 @@ try: except ImportError: ssl_available = False -from redis._compat import (b, xrange, imap, byte_to_chr, unicode, bytes, long, - BytesIO, nativestr, basestring, iteritems, - LifoQueue, Empty, Full, urlparse, parse_qs) +from redis._compat import (xrange, imap, byte_to_chr, unicode, bytes, long, + nativestr, basestring, iteritems, + LifoQueue, Empty, Full, urlparse, parse_qs, + recv, recv_into, select, unquote) from redis.exceptions import ( + DataError, RedisError, ConnectionError, TimeoutError, @@ -45,16 +47,14 @@ if HIREDIS_AVAILABLE: warnings.warn(msg) HIREDIS_USE_BYTE_BUFFER = True - # only use byte buffer if hiredis supports it and the Python version - # is >= 2.7 - if not HIREDIS_SUPPORTS_BYTE_BUFFER or ( - sys.version_info[0] == 2 and sys.version_info[1] < 7): + # only use byte buffer if hiredis supports it + if not HIREDIS_SUPPORTS_BYTE_BUFFER: HIREDIS_USE_BYTE_BUFFER = False -SYM_STAR = b('*') -SYM_DOLLAR = b('$') -SYM_CRLF = b('\r\n') -SYM_EMPTY = b('') +SYM_STAR = b'*' +SYM_DOLLAR = b'$' +SYM_CRLF = b'\r\n' +SYM_EMPTY = b'' SERVER_CLOSED_CONNECTION_ERROR = "Connection closed by server." @@ -65,10 +65,27 @@ class Token(object): hard-coded arguments are wrapped in this class so we know not to apply and encoding rules on them. """ + + _cache = {} + + @classmethod + def get_token(cls, value): + "Gets a cached token object or creates a new one if not already cached" + + # Use try/except because after running for a short time most tokens + # should already be cached + try: + return cls._cache[value] + except KeyError: + token = Token(value) + cls._cache[value] = token + return token + def __init__(self, value): if isinstance(value, Token): value = value.value self.value = value + self.encoded_value = value.encode() def __repr__(self): return self.value @@ -77,9 +94,50 @@ class Token(object): return self.value +class Encoder(object): + "Encode strings to bytes and decode bytes to strings" + + def __init__(self, encoding, encoding_errors, decode_responses): + self.encoding = encoding + self.encoding_errors = encoding_errors + self.decode_responses = decode_responses + + def encode(self, value): + "Return a bytestring representation of the value" + if isinstance(value, Token): + return value.encoded_value + elif isinstance(value, bytes): + return value + elif isinstance(value, bool): + # special case bool since it is a subclass of int + raise DataError("Invalid input of type: 'bool'. Convert to a " + "byte, string or number first.") + elif isinstance(value, float): + value = repr(value).encode() + elif isinstance(value, (int, long)): + # python 2 repr() on longs is '123L', so use str() instead + value = str(value).encode() + elif not isinstance(value, basestring): + # a value we don't know how to deal with. throw an error + typename = type(value).__name__ + raise DataError("Invalid input of type: '%s'. Convert to a " + "byte, string or number first." % typename) + if isinstance(value, unicode): + value = value.encode(self.encoding, self.encoding_errors) + return value + + def decode(self, value, force=False): + "Return a unicode string from the byte representation" + if (self.decode_responses or force) and isinstance(value, bytes): + value = value.decode(self.encoding, self.encoding_errors) + return value + + class BaseParser(object): EXCEPTION_CLASSES = { - 'ERR': ResponseError, + 'ERR': { + 'max number of clients reached': ConnectionError + }, 'EXECABORT': ExecAbortError, 'LOADING': BusyLoadingError, 'NOSCRIPT': NoScriptError, @@ -91,7 +149,10 @@ class BaseParser(object): error_code = response.split(' ')[0] if error_code in self.EXCEPTION_CLASSES: response = response[len(error_code) + 1:] - return self.EXCEPTION_CLASSES[error_code](response) + exception_class = self.EXCEPTION_CLASSES[error_code] + if isinstance(exception_class, dict): + exception_class = exception_class.get(response, ResponseError) + return exception_class(response) return ResponseError(response) @@ -99,7 +160,7 @@ class SocketBuffer(object): def __init__(self, socket, socket_read_size): self._sock = socket self.socket_read_size = socket_read_size - self._buffer = BytesIO() + self._buffer = io.BytesIO() # number of bytes written to the buffer from the socket self.bytes_written = 0 # number of bytes read from the buffer @@ -117,7 +178,7 @@ class SocketBuffer(object): try: while True: - data = self._sock.recv(socket_read_size) + data = recv(self._sock, socket_read_size) # an empty string indicates the server shutdown the socket if isinstance(data, bytes) and len(data) == 0: raise socket.error(SERVER_CLOSED_CONNECTION_ERROR) @@ -179,18 +240,25 @@ class SocketBuffer(object): self.bytes_read = 0 def close(self): - self.purge() - self._buffer.close() + try: + self.purge() + self._buffer.close() + except Exception: + # issue #633 suggests the purge/close somehow raised a + # BadFileDescriptor error. Perhaps the client ran out of + # memory or something else? It's probably OK to ignore + # any error being raised from purge/close since we're + # removing the reference to the instance below. + pass self._buffer = None self._sock = None class PythonParser(BaseParser): "Plain Python parsing class" - encoding = None - def __init__(self, socket_read_size): self.socket_read_size = socket_read_size + self.encoder = None self._sock = None self._buffer = None @@ -204,8 +272,7 @@ class PythonParser(BaseParser): "Called when the socket connects" self._sock = connection._sock self._buffer = SocketBuffer(self._sock, self.socket_read_size) - if connection.decode_responses: - self.encoding = connection.encoding + self.encoder = connection.encoder def on_disconnect(self): "Called when the socket disconnects" @@ -215,7 +282,7 @@ class PythonParser(BaseParser): if self._buffer is not None: self._buffer.close() self._buffer = None - self.encoding = None + self.encoder = None def can_read(self): return self._buffer and bool(self._buffer.length) @@ -262,8 +329,8 @@ class PythonParser(BaseParser): if length == -1: return None response = [self.read_response() for i in xrange(length)] - if isinstance(response, bytes) and self.encoding: - response = response.decode(self.encoding) + if isinstance(response, bytes): + response = self.encoder.decode(response) return response @@ -294,8 +361,8 @@ class HiredisParser(BaseParser): if not HIREDIS_SUPPORTS_CALLABLE_ERRORS: kwargs['replyError'] = ResponseError - if connection.decode_responses: - kwargs['encoding'] = connection.encoding + if connection.encoder.decode_responses: + kwargs['encoding'] = connection.encoder.encoding self._reader = hiredis.Reader(**kwargs) self._next_response = False @@ -327,11 +394,11 @@ class HiredisParser(BaseParser): while response is False: try: if HIREDIS_USE_BYTE_BUFFER: - bufflen = self._sock.recv_into(self._buffer) + bufflen = recv_into(self._sock, self._buffer) if bufflen == 0: raise socket.error(SERVER_CLOSED_CONNECTION_ERROR) else: - buffer = self._sock.recv(socket_read_size) + buffer = recv(self._sock, socket_read_size) # an empty string indicates the server shutdown the socket if not isinstance(buffer, bytes) or len(buffer) == 0: raise socket.error(SERVER_CLOSED_CONNECTION_ERROR) @@ -345,15 +412,6 @@ class HiredisParser(BaseParser): self._reader.feed(self._buffer, 0, bufflen) else: self._reader.feed(buffer) - # proactively, but not conclusively, check if more data is in the - # buffer. if the data received doesn't end with \r\n, there's more. - if HIREDIS_USE_BYTE_BUFFER: - if bufflen > 2 and \ - self._buffer[bufflen - 2:bufflen] != SYM_CRLF: - continue - else: - if not buffer.endswith(SYM_CRLF): - continue response = self._reader.gets() # if an older version of hiredis is installed, we need to attempt # to convert ResponseErrors to their appropriate types. @@ -373,6 +431,7 @@ class HiredisParser(BaseParser): raise response[0] return response + if HIREDIS_AVAILABLE: DefaultParser = HiredisParser else: @@ -386,7 +445,7 @@ class Connection(object): def __init__(self, host='localhost', port=6379, db=0, password=None, socket_timeout=None, socket_connect_timeout=None, socket_keepalive=False, socket_keepalive_options=None, - retry_on_timeout=False, encoding='utf-8', + socket_type=0, retry_on_timeout=False, encoding='utf-8', encoding_errors='strict', decode_responses=False, parser_class=DefaultParser, socket_read_size=65536): self.pid = os.getpid() @@ -398,10 +457,9 @@ class Connection(object): self.socket_connect_timeout = socket_connect_timeout or socket_timeout self.socket_keepalive = socket_keepalive self.socket_keepalive_options = socket_keepalive_options or {} + self.socket_type = socket_type self.retry_on_timeout = retry_on_timeout - self.encoding = encoding - self.encoding_errors = encoding_errors - self.decode_responses = decode_responses + self.encoder = Encoder(encoding, encoding_errors, decode_responses) self._sock = None self._parser = parser_class(socket_read_size=socket_read_size) self._description_args = { @@ -410,6 +468,7 @@ class Connection(object): 'db': self.db, } self._connect_callbacks = [] + self._buffer_cutoff = 6000 def __repr__(self): return self.description_format % self._description_args @@ -432,6 +491,8 @@ class Connection(object): return try: sock = self._connect() + except socket.timeout: + raise TimeoutError("Timeout connecting to server") except socket.error: e = sys.exc_info()[1] raise ConnectionError(self._error_message(e)) @@ -455,7 +516,7 @@ class Connection(object): # ipv4/ipv6, but we want to set options prior to calling # socket.connect() err = None - for res in socket.getaddrinfo(self.host, self.port, 0, + for res in socket.getaddrinfo(self.host, self.port, self.socket_type, socket.SOCK_STREAM): family, socktype, proto, canonname, socket_address = res sock = None @@ -543,14 +604,15 @@ class Connection(object): e = sys.exc_info()[1] self.disconnect() if len(e.args) == 1: - _errno, errmsg = 'UNKNOWN', e.args[0] + errno, errmsg = 'UNKNOWN', e.args[0] else: - _errno, errmsg = e.args + errno = e.args[0] + errmsg = e.args[1] raise ConnectionError("Error %s while writing to socket. %s." % - (_errno, errmsg)) - except: + (errno, errmsg)) + except Exception as e: self.disconnect() - raise + raise e def send_command(self, *args): "Pack and send a command to the Redis server" @@ -569,29 +631,13 @@ class Connection(object): "Read the response from a previously sent command" try: response = self._parser.read_response() - except: + except Exception as e: self.disconnect() - raise + raise e if isinstance(response, ResponseError): raise response return response - def encode(self, value): - "Return a bytestring representation of the value" - if isinstance(value, Token): - return b(value.value) - elif isinstance(value, bytes): - return value - elif isinstance(value, (int, long)): - value = b(str(value)) - elif isinstance(value, float): - value = b(repr(value)) - elif not isinstance(value, basestring): - value = str(value) - if isinstance(value, unicode): - value = value.encode(self.encoding, self.encoding_errors) - return value - def pack_command(self, *args): "Pack a series of arguments into the Redis protocol" output = [] @@ -602,25 +648,27 @@ class Connection(object): # to prevent them from being encoded. command = args[0] if ' ' in command: - args = tuple([Token(s) for s in command.split(' ')]) + args[1:] + args = tuple(Token.get_token(s) + for s in command.split()) + args[1:] else: - args = (Token(command),) + args[1:] + args = (Token.get_token(command),) + args[1:] - buff = SYM_EMPTY.join( - (SYM_STAR, b(str(len(args))), SYM_CRLF)) + buff = SYM_EMPTY.join((SYM_STAR, str(len(args)).encode(), SYM_CRLF)) - for arg in imap(self.encode, args): + buffer_cutoff = self._buffer_cutoff + for arg in imap(self.encoder.encode, args): # to avoid large string mallocs, chunk the command into the # output list if we're sending large values - if len(buff) > 6000 or len(arg) > 6000: + if len(buff) > buffer_cutoff or len(arg) > buffer_cutoff: buff = SYM_EMPTY.join( - (buff, SYM_DOLLAR, b(str(len(arg))), SYM_CRLF)) + (buff, SYM_DOLLAR, str(len(arg)).encode(), SYM_CRLF)) output.append(buff) output.append(arg) buff = SYM_CRLF else: - buff = SYM_EMPTY.join((buff, SYM_DOLLAR, b(str(len(arg))), - SYM_CRLF, arg, SYM_CRLF)) + buff = SYM_EMPTY.join( + (buff, SYM_DOLLAR, str(len(arg)).encode(), + SYM_CRLF, arg, SYM_CRLF)) output.append(buff) return output @@ -629,16 +677,21 @@ class Connection(object): output = [] pieces = [] buffer_length = 0 + buffer_cutoff = self._buffer_cutoff for cmd in commands: for chunk in self.pack_command(*cmd): - pieces.append(chunk) - buffer_length += len(chunk) - - if buffer_length > 6000: - output.append(SYM_EMPTY.join(pieces)) - buffer_length = 0 - pieces = [] + chunklen = len(chunk) + if buffer_length > buffer_cutoff or chunklen > buffer_cutoff: + output.append(SYM_EMPTY.join(pieces)) + buffer_length = 0 + pieces = [] + + if chunklen > self._buffer_cutoff: + output.append(chunk) + else: + pieces.append(chunk) + buffer_length += chunklen if pieces: output.append(SYM_EMPTY.join(pieces)) @@ -648,8 +701,8 @@ class Connection(object): class SSLConnection(Connection): description_format = "SSLConnection<host=%(host)s,port=%(port)s,db=%(db)s>" - def __init__(self, ssl_keyfile=None, ssl_certfile=None, ssl_cert_reqs=None, - ssl_ca_certs=None, **kwargs): + def __init__(self, ssl_keyfile=None, ssl_certfile=None, + ssl_cert_reqs='required', ssl_ca_certs=None, **kwargs): if not ssl_available: raise RedisError("Python wasn't built with SSL support") @@ -698,9 +751,7 @@ class UnixDomainSocketConnection(Connection): self.password = password self.socket_timeout = socket_timeout self.retry_on_timeout = retry_on_timeout - self.encoding = encoding - self.encoding_errors = encoding_errors - self.decode_responses = decode_responses + self.encoder = Encoder(encoding, encoding_errors, decode_responses) self._sock = None self._parser = parser_class(socket_read_size=socket_read_size) self._description_args = { @@ -727,10 +778,30 @@ class UnixDomainSocketConnection(Connection): (exception.args[0], self.path, exception.args[1]) +FALSE_STRINGS = ('0', 'F', 'FALSE', 'N', 'NO') + + +def to_bool(value): + if value is None or value == '': + return None + if isinstance(value, basestring) and value.upper() in FALSE_STRINGS: + return False + return bool(value) + + +URL_QUERY_ARGUMENT_PARSERS = { + 'socket_timeout': float, + 'socket_connect_timeout': float, + 'socket_keepalive': to_bool, + 'retry_on_timeout': to_bool, + 'max_connections': int, +} + + class ConnectionPool(object): "Generic connection pool" @classmethod - def from_url(cls, url, db=None, **kwargs): + def from_url(cls, url, db=None, decode_components=False, **kwargs): """ Return a connection pool configured from the given URL. @@ -741,9 +812,14 @@ class ConnectionPool(object): unix://[:password]@/path/to/socket.sock?db=0 Three URL schemes are supported: - redis:// creates a normal TCP socket connection - rediss:// creates a SSL wrapped TCP socket connection - unix:// creates a Unix Domain Socket connection + + - ```redis://`` + <https://www.iana.org/assignments/uri-schemes/prov/redis>`_ creates a + normal TCP socket connection + - ```rediss://`` + <https://www.iana.org/assignments/uri-schemes/prov/rediss>`_ creates + a SSL wrapped TCP socket connection + - ``unix://`` creates a Unix Domain Socket connection There are several ways to specify a database number. The parse function will return the first specified option: @@ -754,50 +830,67 @@ class ConnectionPool(object): If none of these options are specified, db=0 is used. + The ``decode_components`` argument allows this function to work with + percent-encoded URLs. If this argument is set to ``True`` all ``%xx`` + escapes will be replaced by their single-character equivalents after + the URL has been parsed. This only applies to the ``hostname``, + ``path``, and ``password`` components. + Any additional querystring arguments and keyword arguments will be - passed along to the ConnectionPool class's initializer. In the case - of conflicting arguments, querystring arguments always win. + passed along to the ConnectionPool class's initializer. The querystring + arguments ``socket_connect_timeout`` and ``socket_timeout`` if supplied + are parsed as float values. The arguments ``socket_keepalive`` and + ``retry_on_timeout`` are parsed to boolean values that accept + True/False, Yes/No values to indicate state. Invalid types cause a + ``UserWarning`` to be raised. In the case of conflicting arguments, + querystring arguments always win. + """ - url_string = url url = urlparse(url) - qs = '' - - # in python2.6, custom URL schemes don't recognize querystring values - # they're left as part of the url.path. - if '?' in url.path and not url.query: - # chop the querystring including the ? off the end of the url - # and reparse it. - qs = url.path.split('?', 1)[1] - url = urlparse(url_string[:-(len(qs) + 1)]) - else: - qs = url.query - url_options = {} - for name, value in iteritems(parse_qs(qs)): + for name, value in iteritems(parse_qs(url.query)): if value and len(value) > 0: - url_options[name] = value[0] + parser = URL_QUERY_ARGUMENT_PARSERS.get(name) + if parser: + try: + url_options[name] = parser(value[0]) + except (TypeError, ValueError): + warnings.warn(UserWarning( + "Invalid value for `%s` in connection URL." % name + )) + else: + url_options[name] = value[0] + + if decode_components: + password = unquote(url.password) if url.password else None + path = unquote(url.path) if url.path else None + hostname = unquote(url.hostname) if url.hostname else None + else: + password = url.password + path = url.path + hostname = url.hostname # We only support redis:// and unix:// schemes. if url.scheme == 'unix': url_options.update({ - 'password': url.password, - 'path': url.path, + 'password': password, + 'path': path, 'connection_class': UnixDomainSocketConnection, }) else: url_options.update({ - 'host': url.hostname, + 'host': hostname, 'port': int(url.port or 6379), - 'password': url.password, + 'password': password, }) # If there's a path argument, use it as the db argument if a # querystring value wasn't specified - if 'db' not in url_options and url.path: + if 'db' not in url_options and path: try: - url_options['db'] = int(url.path.replace('/', '')) + url_options['db'] = int(path.replace('/', '')) except (AttributeError, ValueError): pass @@ -828,8 +921,8 @@ class ConnectionPool(object): Create a connection pool. If max_connections is set, then this object raises redis.ConnectionError when the pool's limit is reached. - By default, TCP connections are created connection_class is specified. - Use redis.UnixDomainSocketConnection for unix sockets. + By default, TCP connections are created unless connection_class is + specified. Use redis.UnixDomainSocketConnection for unix sockets. Any additional keyword arguments are passed to the constructor of connection_class. @@ -877,6 +970,15 @@ class ConnectionPool(object): self._in_use_connections.add(connection) return connection + def get_encoder(self): + "Return an encoder based on encoding settings" + kwargs = self.connection_kwargs + return Encoder( + encoding=kwargs.get('encoding', 'utf-8'), + encoding_errors=kwargs.get('encoding_errors', 'strict'), + decode_responses=kwargs.get('decode_responses', False) + ) + def make_connection(self): "Create a new connection" if self._created_connections >= self.max_connections: diff --git a/redis/exceptions.py b/redis/exceptions.py index a8518c7..44ab6f7 100644 --- a/redis/exceptions.py +++ b/redis/exceptions.py @@ -1,21 +1,10 @@ "Core exceptions raised by the Redis client" -from redis._compat import unicode class RedisError(Exception): pass -# python 2.5 doesn't implement Exception.__unicode__. Add it here to all -# our exception types -if not hasattr(RedisError, '__unicode__'): - def __unicode__(self): - if isinstance(self.args[0], unicode): - return self.args[0] - return unicode(self.args[0]) - RedisError.__unicode__ = __unicode__ - - class AuthenticationError(RedisError): pass diff --git a/redis/lock.py b/redis/lock.py index 90f0e7a..43c0813 100644 --- a/redis/lock.py +++ b/redis/lock.py @@ -1,9 +1,8 @@ import threading import time as mod_time import uuid -from redis.exceptions import LockError, WatchError +from redis.exceptions import LockError from redis.utils import dummy -from redis._compat import b class Lock(object): @@ -14,6 +13,42 @@ class Lock(object): It's left to the user to resolve deadlock issues and make sure multiple clients play nicely together. """ + + lua_release = None + lua_extend = None + + # KEYS[1] - lock name + # ARGS[1] - token + # return 1 if the lock was released, otherwise 0 + LUA_RELEASE_SCRIPT = """ + local token = redis.call('get', KEYS[1]) + if not token or token ~= ARGV[1] then + return 0 + end + redis.call('del', KEYS[1]) + return 1 + """ + + # KEYS[1] - lock name + # ARGS[1] - token + # ARGS[2] - additional milliseconds + # return 1 if the locks time was extended, otherwise 0 + LUA_EXTEND_SCRIPT = """ + local token = redis.call('get', KEYS[1]) + if not token or token ~= ARGV[1] then + return 0 + end + local expiration = redis.call('pttl', KEYS[1]) + if not expiration then + expiration = 0 + end + if expiration < 0 then + return 0 + end + redis.call('pexpire', KEYS[1], expiration + ARGV[2]) + return 1 + """ + def __init__(self, redis, name, timeout=None, sleep=0.1, blocking=True, blocking_timeout=None, thread_local=True): """ @@ -77,12 +112,22 @@ class Lock(object): self.local.token = None if self.timeout and self.sleep > self.timeout: raise LockError("'sleep' must be less than 'timeout'") + self.register_scripts() + + def register_scripts(self): + cls = self.__class__ + client = self.redis + if cls.lua_release is None: + cls.lua_release = client.register_script(cls.LUA_RELEASE_SCRIPT) + if cls.lua_extend is None: + cls.lua_extend = client.register_script(cls.LUA_EXTEND_SCRIPT) def __enter__(self): # force blocking, as otherwise the user would have to check whether # the lock was actually acquired or not. - self.acquire(blocking=True) - return self + if self.acquire(blocking=True): + return self + raise LockError("Unable to acquire lock within the time specified") def __exit__(self, exc_type, exc_value, traceback): self.release() @@ -99,7 +144,7 @@ class Lock(object): wait trying to acquire the lock. """ sleep = self.sleep - token = b(uuid.uuid1().hex) + token = uuid.uuid1().hex.encode() if blocking is None: blocking = self.blocking if blocking_timeout is None: @@ -107,7 +152,7 @@ class Lock(object): stop_trying_at = None if blocking_timeout is not None: stop_trying_at = mod_time.time() + blocking_timeout - while 1: + while True: if self.do_acquire(token): self.local.token = token return True @@ -118,14 +163,19 @@ class Lock(object): mod_time.sleep(sleep) def do_acquire(self, token): - if self.redis.setnx(self.name, token): - if self.timeout: - # convert to milliseconds - timeout = int(self.timeout * 1000) - self.redis.pexpire(self.name, timeout) + if self.timeout: + # convert to milliseconds + timeout = int(self.timeout * 1000) + else: + timeout = None + if self.redis.set(self.name, token, nx=True, px=timeout): return True return False + def locked(self): + token = self.local.token + return token and self.redis.get(self.name) == token or False + def release(self): "Releases the already acquired lock" expected_token = self.local.token @@ -135,15 +185,10 @@ class Lock(object): self.do_release(expected_token) def do_release(self, expected_token): - name = self.name - - def execute_release(pipe): - lock_value = pipe.get(name) - if lock_value != expected_token: - raise LockError("Cannot release a lock that's no longer owned") - pipe.delete(name) - - self.redis.transaction(execute_release, name) + if not bool(self.lua_release(keys=[self.name], + args=[expected_token], + client=self.redis)): + raise LockError("Cannot release a lock that's no longer owned") def extend(self, additional_time): """ @@ -159,111 +204,6 @@ class Lock(object): return self.do_extend(additional_time) def do_extend(self, additional_time): - pipe = self.redis.pipeline() - pipe.watch(self.name) - lock_value = pipe.get(self.name) - if lock_value != self.local.token: - raise LockError("Cannot extend a lock that's no longer owned") - expiration = pipe.pttl(self.name) - if expiration is None or expiration < 0: - # Redis evicted the lock key between the previous get() and now - # we'll handle this when we call pexpire() - expiration = 0 - pipe.multi() - pipe.pexpire(self.name, expiration + int(additional_time * 1000)) - - try: - response = pipe.execute() - except WatchError: - # someone else acquired the lock - raise LockError("Cannot extend a lock that's no longer owned") - if not response[0]: - # pexpire returns False if the key doesn't exist - raise LockError("Cannot extend a lock that's no longer owned") - return True - - -class LuaLock(Lock): - """ - A lock implementation that uses Lua scripts rather than pipelines - and watches. - """ - lua_acquire = None - lua_release = None - lua_extend = None - - # KEYS[1] - lock name - # ARGV[1] - token - # ARGV[2] - timeout in milliseconds - # return 1 if lock was acquired, otherwise 0 - LUA_ACQUIRE_SCRIPT = """ - if redis.call('setnx', KEYS[1], ARGV[1]) == 1 then - if ARGV[2] ~= '' then - redis.call('pexpire', KEYS[1], ARGV[2]) - end - return 1 - end - return 0 - """ - - # KEYS[1] - lock name - # ARGS[1] - token - # return 1 if the lock was released, otherwise 0 - LUA_RELEASE_SCRIPT = """ - local token = redis.call('get', KEYS[1]) - if not token or token ~= ARGV[1] then - return 0 - end - redis.call('del', KEYS[1]) - return 1 - """ - - # KEYS[1] - lock name - # ARGS[1] - token - # ARGS[2] - additional milliseconds - # return 1 if the locks time was extended, otherwise 0 - LUA_EXTEND_SCRIPT = """ - local token = redis.call('get', KEYS[1]) - if not token or token ~= ARGV[1] then - return 0 - end - local expiration = redis.call('pttl', KEYS[1]) - if not expiration then - expiration = 0 - end - if expiration < 0 then - return 0 - end - redis.call('pexpire', KEYS[1], expiration + ARGV[2]) - return 1 - """ - - def __init__(self, *args, **kwargs): - super(LuaLock, self).__init__(*args, **kwargs) - LuaLock.register_scripts(self.redis) - - @classmethod - def register_scripts(cls, redis): - if cls.lua_acquire is None: - cls.lua_acquire = redis.register_script(cls.LUA_ACQUIRE_SCRIPT) - if cls.lua_release is None: - cls.lua_release = redis.register_script(cls.LUA_RELEASE_SCRIPT) - if cls.lua_extend is None: - cls.lua_extend = redis.register_script(cls.LUA_EXTEND_SCRIPT) - - def do_acquire(self, token): - timeout = self.timeout and int(self.timeout * 1000) or '' - return bool(self.lua_acquire(keys=[self.name], - args=[token, timeout], - client=self.redis)) - - def do_release(self, expected_token): - if not bool(self.lua_release(keys=[self.name], - args=[expected_token], - client=self.redis)): - raise LockError("Cannot release a lock that's no longer owned") - - def do_extend(self, additional_time): additional_time = int(additional_time * 1000) if not bool(self.lua_extend(keys=[self.name], args=[self.local.token, additional_time], diff --git a/redis/sentinel.py b/redis/sentinel.py index 3fb89ce..9df2997 100644 --- a/redis/sentinel.py +++ b/redis/sentinel.py @@ -2,9 +2,10 @@ import os import random import weakref -from redis.client import StrictRedis +from redis.client import Redis from redis.connection import ConnectionPool, Connection -from redis.exceptions import ConnectionError, ResponseError, ReadOnlyError +from redis.exceptions import (ConnectionError, ResponseError, ReadOnlyError, + TimeoutError) from redis._compat import iteritems, nativestr, xrange @@ -170,13 +171,14 @@ class Sentinel(object): # if sentinel_kwargs isn't defined, use the socket_* options from # connection_kwargs if sentinel_kwargs is None: - sentinel_kwargs = dict([(k, v) - for k, v in iteritems(connection_kwargs) - if k.startswith('socket_') - ]) + sentinel_kwargs = { + k: v + for k, v in iteritems(connection_kwargs) + if k.startswith('socket_') + } self.sentinel_kwargs = sentinel_kwargs - self.sentinels = [StrictRedis(hostname, port, **self.sentinel_kwargs) + self.sentinels = [Redis(hostname, port, **self.sentinel_kwargs) for hostname, port in sentinels] self.min_other_sentinels = min_other_sentinels self.connection_kwargs = connection_kwargs @@ -211,7 +213,7 @@ class Sentinel(object): for sentinel_no, sentinel in enumerate(self.sentinels): try: masters = sentinel.sentinel_masters() - except ConnectionError: + except (ConnectionError, TimeoutError): continue state = masters.get(service_name) if state and self.check_master_state(state, service_name): @@ -235,14 +237,14 @@ class Sentinel(object): for sentinel in self.sentinels: try: slaves = sentinel.sentinel_slaves(service_name) - except (ConnectionError, ResponseError): + except (ConnectionError, ResponseError, TimeoutError): continue slaves = self.filter_slaves(slaves) if slaves: return slaves return [] - def master_for(self, service_name, redis_class=StrictRedis, + def master_for(self, service_name, redis_class=Redis, connection_pool_class=SentinelConnectionPool, **kwargs): """ Returns a redis client instance for the ``service_name`` master. @@ -253,7 +255,7 @@ class Sentinel(object): NOTE: If the master's address has changed, any cached connections to the old master are closed. - By default clients will be a redis.StrictRedis instance. Specify a + By default clients will be a redis.Redis instance. Specify a different class to the ``redis_class`` argument if you desire something different. @@ -270,7 +272,7 @@ class Sentinel(object): return redis_class(connection_pool=connection_pool_class( service_name, self, **connection_kwargs)) - def slave_for(self, service_name, redis_class=StrictRedis, + def slave_for(self, service_name, redis_class=Redis, connection_pool_class=SentinelConnectionPool, **kwargs): """ Returns redis client instance for the ``service_name`` slave(s). @@ -278,7 +280,7 @@ class Sentinel(object): A SentinelConnectionPool class is used to retrive the slave's address before establishing a new connection. - By default clients will be a redis.StrictRedis instance. Specify a + By default clients will be a redis.Redis instance. Specify a different class to the ``redis_class`` argument if you desire something different. diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 0000000..089b37c --- /dev/null +++ b/setup.cfg @@ -0,0 +1,9 @@ +[pycodestyle] +show-source = 1 +exclude = .venv,.tox,dist,docs,build,*.egg + +[bdist_wheel] +universal = 1 + +[metadata] +license_file = LICENSE @@ -23,7 +23,9 @@ try: except ImportError: from distutils.core import setup - PyTest = lambda x: x + + def PyTest(x): + x f = open(os.path.join(os.path.dirname(__file__), 'README.rst')) long_description = f.read() @@ -34,7 +36,7 @@ setup( version=__version__, description='Python client for Redis key-value store', long_description=long_description, - url='http://github.com/andymccurdy/redis-py', + url='https://github.com/andymccurdy/redis-py', author='Andy McCurdy', author_email='sedrik@gmail.com', maintainer='Andy McCurdy', @@ -42,7 +44,16 @@ setup( keywords=['Redis', 'key-value store'], license='MIT', packages=['redis'], - tests_require=['pytest>=2.5.0'], + python_requires=">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*", + extras_require={ + 'hiredis': [ + "hiredis>=0.1.3", + ], + }, + tests_require=[ + 'mock', + 'pytest>=2.7.0', + ], cmdclass={'test': PyTest}, classifiers=[ 'Development Status :: 5 - Production/Stable', @@ -51,11 +62,11 @@ setup( 'License :: OSI Approved :: MIT License', 'Operating System :: OS Independent', 'Programming Language :: Python', - 'Programming Language :: Python :: 2.6', + 'Programming Language :: Python :: 2', 'Programming Language :: Python :: 2.7', 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.2', - 'Programming Language :: Python :: 3.3', 'Programming Language :: Python :: 3.4', + 'Programming Language :: Python :: 3.5', + 'Programming Language :: Python :: 3.6', ] ) diff --git a/tests/conftest.py b/tests/conftest.py index bd0116b..5a43968 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,6 @@ import pytest import redis +from mock import Mock from distutils.version import StrictVersion @@ -36,11 +37,77 @@ def skip_if_server_version_lt(min_version): return pytest.mark.skipif(check, reason="") +def skip_if_server_version_gte(min_version): + check = StrictVersion(get_version()) >= StrictVersion(min_version) + return pytest.mark.skipif(check, reason="") + + @pytest.fixture() def r(request, **kwargs): return _get_client(redis.Redis, request, **kwargs) +def _gen_cluster_mock_resp(r, response): + mock_connection_pool = Mock() + connection = Mock() + response = response + connection.read_response.return_value = response + mock_connection_pool.get_connection.return_value = connection + r.connection_pool = mock_connection_pool + return r + + +@pytest.fixture() +def mock_cluster_resp_ok(request, **kwargs): + r = _get_client(redis.Redis, request, **kwargs) + return _gen_cluster_mock_resp(r, 'OK') + + +@pytest.fixture() +def mock_cluster_resp_int(request, **kwargs): + r = _get_client(redis.Redis, request, **kwargs) + return _gen_cluster_mock_resp(r, '2') + + +@pytest.fixture() +def mock_cluster_resp_info(request, **kwargs): + r = _get_client(redis.Redis, request, **kwargs) + response = ('cluster_state:ok\r\ncluster_slots_assigned:16384\r\n' + 'cluster_slots_ok:16384\r\ncluster_slots_pfail:0\r\n' + 'cluster_slots_fail:0\r\ncluster_known_nodes:7\r\n' + 'cluster_size:3\r\ncluster_current_epoch:7\r\n' + 'cluster_my_epoch:2\r\ncluster_stats_messages_sent:170262\r\n' + 'cluster_stats_messages_received:105653\r\n') + return _gen_cluster_mock_resp(r, response) + + +@pytest.fixture() +def mock_cluster_resp_nodes(request, **kwargs): + r = _get_client(redis.Redis, request, **kwargs) + response = ('c8253bae761cb1ecb2b61857d85dfe455a0fec8b 172.17.0.7:7006 ' + 'slave aa90da731f673a99617dfe930306549a09f83a6b 0 ' + '1447836263059 5 connected\n' + '9bd595fe4821a0e8d6b99d70faa660638a7612b3 172.17.0.7:7008 ' + 'master - 0 1447836264065 0 connected\n' + 'aa90da731f673a99617dfe930306549a09f83a6b 172.17.0.7:7003 ' + 'myself,master - 0 0 2 connected 5461-10922\n' + '1df047e5a594f945d82fc140be97a1452bcbf93e 172.17.0.7:7007 ' + 'slave 19efe5a631f3296fdf21a5441680f893e8cc96ec 0 ' + '1447836262556 3 connected\n' + '4ad9a12e63e8f0207025eeba2354bcf4c85e5b22 172.17.0.7:7005 ' + 'master - 0 1447836262555 7 connected 0-5460\n' + '19efe5a631f3296fdf21a5441680f893e8cc96ec 172.17.0.7:7004 ' + 'master - 0 1447836263562 3 connected 10923-16383\n' + 'fbb23ed8cfa23f17eaf27ff7d0c410492a1093d6 172.17.0.7:7002 ' + 'master,fail - 1447829446956 1447829444948 1 disconnected\n' + ) + return _gen_cluster_mock_resp(r, response) + + @pytest.fixture() -def sr(request, **kwargs): - return _get_client(redis.StrictRedis, request, **kwargs) +def mock_cluster_resp_slaves(request, **kwargs): + r = _get_client(redis.Redis, request, **kwargs) + response = ("['1df047e5a594f945d82fc140be97a1452bcbf93e 172.17.0.7:7007 " + "slave 19efe5a631f3296fdf21a5441680f893e8cc96ec 0 " + "1447836789290 3 connected']") + return _gen_cluster_mock_resp(r, response) diff --git a/tests/test_commands.py b/tests/test_commands.py index 286ea04..8a6be40 100644 --- a/tests/test_commands.py +++ b/tests/test_commands.py @@ -1,16 +1,17 @@ -from __future__ import with_statement +from __future__ import unicode_literals import binascii import datetime import pytest +import re import redis import time -from redis._compat import (unichr, u, b, ascii_letters, iteritems, iterkeys, - itervalues) +from redis._compat import (unichr, ascii_letters, iteritems, iterkeys, + itervalues, long) from redis.client import parse_info from redis import exceptions -from .conftest import skip_if_server_version_lt +from .conftest import skip_if_server_version_lt, skip_if_server_version_gte @pytest.fixture() @@ -34,6 +35,13 @@ def redis_server_time(client): return datetime.datetime.fromtimestamp(timestamp) +def get_stream_message(client, stream, message_id): + "Fetch a stream message and format it as a (message_id, fields) pair" + response = client.xrange(stream, min=message_id, max=message_id) + assert len(response) == 1 + return response[0] + + # RESPONSE CALLBACKS class TestResponseCallbacks(object): "Tests for the response callback system" @@ -59,6 +67,25 @@ class TestRedisCommands(object): assert isinstance(clients[0], dict) assert 'addr' in clients[0] + @skip_if_server_version_lt('5.0.0') + def test_client_list_type(self, r): + with pytest.raises(exceptions.RedisError): + r.client_list(_type='not a client type') + for client_type in ['normal', 'master', 'replica', 'pubsub']: + clients = r.client_list(_type=client_type) + assert isinstance(clients, list) + + @skip_if_server_version_lt('5.0.0') + def test_client_id(self, r): + assert r.client_id() > 0 + + @skip_if_server_version_lt('5.0.0') + def test_client_unblock(self, r): + myid = r.client_id() + assert not r.client_unblock(myid) + assert not r.client_unblock(myid, error=True) + assert not r.client_unblock(myid, error=False) + @skip_if_server_version_lt('2.6.9') def test_client_getname(self, r): assert r.client_getname() is None @@ -68,6 +95,20 @@ class TestRedisCommands(object): assert r.client_setname('redis_py_test') assert r.client_getname() == 'redis_py_test' + @skip_if_server_version_lt('2.6.9') + def test_client_list_after_client_setname(self, r): + r.client_setname('redis_py_test') + clients = r.client_list() + # we don't know which client ours will be + assert 'redis_py_test' in [c['name'] for c in clients] + + @skip_if_server_version_lt('2.9.50') + def test_client_pause(self, r): + assert r.client_pause(1) + assert r.client_pause(timeout=1) + with pytest.raises(exceptions.RedisError): + r.client_pause(timeout='not an integer') + def test_config_get(self, r): data = r.config_get() assert 'maxmemory' in data @@ -96,7 +137,7 @@ class TestRedisCommands(object): assert r.dbsize() == 2 def test_echo(self, r): - assert r.echo('foo bar') == b('foo bar') + assert r.echo('foo bar') == b'foo bar' def test_info(self, r): r['a'] = 'foo' @@ -112,7 +153,7 @@ class TestRedisCommands(object): r['a'] = 'foo' assert isinstance(r.object('refcount', 'a'), int) assert isinstance(r.object('idletime', 'a'), int) - assert r.object('encoding', 'a') == b('raw') + assert r.object('encoding', 'a') in (b'raw', b'embstr') assert r.object('idletime', 'invalid-key') is None def test_ping(self, r): @@ -120,20 +161,20 @@ class TestRedisCommands(object): def test_slowlog_get(self, r, slowlog): assert r.slowlog_reset() - unicode_string = unichr(3456) + u('abcd') + unichr(3421) + unicode_string = unichr(3456) + 'abcd' + unichr(3421) r.get(unicode_string) slowlog = r.slowlog_get() assert isinstance(slowlog, list) commands = [log['command'] for log in slowlog] - get_command = b(' ').join((b('GET'), unicode_string.encode('utf-8'))) + get_command = b' '.join((b'GET', unicode_string.encode('utf-8'))) assert get_command in commands - assert b('SLOWLOG RESET') in commands + assert b'SLOWLOG RESET' in commands # the order should be ['GET <uni string>', 'SLOWLOG RESET'], # but if other clients are executing commands at the same time, there # could be commands, before, between, or after, so just check that # the two we care about are in the appropriate order. - assert commands.index(get_command) < commands.index(b('SLOWLOG RESET')) + assert commands.index(get_command) < commands.index(b'SLOWLOG RESET') # make sure other attributes are typed correctly assert isinstance(slowlog[0]['start_time'], int) @@ -146,8 +187,8 @@ class TestRedisCommands(object): slowlog = r.slowlog_get(1) assert isinstance(slowlog, list) commands = [log['command'] for log in slowlog] - assert b('GET foo') not in commands - assert b('GET bar') in commands + assert b'GET foo' not in commands + assert b'GET bar' in commands def test_slowlog_length(self, r, slowlog): r.get('foo') @@ -163,9 +204,9 @@ class TestRedisCommands(object): # BASIC KEY COMMANDS def test_append(self, r): assert r.append('a', 'a1') == 2 - assert r['a'] == b('a1') + assert r['a'] == b'a1' assert r.append('a', 'a2') == 4 - assert r['a'] == b('a1a2') + assert r['a'] == b'a1a2' @skip_if_server_version_lt('2.6.0') def test_bitcount(self, r): @@ -194,7 +235,7 @@ class TestRedisCommands(object): @skip_if_server_version_lt('2.6.0') def test_bitop_not(self, r): - test_str = b('\xAA\x00\xFF\x55') + test_str = b'\xAA\x00\xFF\x55' correct = ~0xAA00FF55 & 0xFFFFFFFF r['a'] = test_str r.bitop('not', 'r', 'a') @@ -202,7 +243,7 @@ class TestRedisCommands(object): @skip_if_server_version_lt('2.6.0') def test_bitop_not_in_place(self, r): - test_str = b('\xAA\x00\xFF\x55') + test_str = b'\xAA\x00\xFF\x55' correct = ~0xAA00FF55 & 0xFFFFFFFF r['a'] = test_str r.bitop('not', 'a', 'a') @@ -210,7 +251,7 @@ class TestRedisCommands(object): @skip_if_server_version_lt('2.6.0') def test_bitop_single_string(self, r): - test_str = b('\x01\x02\xFF') + test_str = b'\x01\x02\xFF' r['a'] = test_str r.bitop('and', 'res1', 'a') r.bitop('or', 'res2', 'a') @@ -221,8 +262,8 @@ class TestRedisCommands(object): @skip_if_server_version_lt('2.6.0') def test_bitop_string_operands(self, r): - r['a'] = b('\x01\x02\xFF\xFF') - r['b'] = b('\x01\x02\xFF') + r['a'] = b'\x01\x02\xFF\xFF' + r['b'] = b'\x01\x02\xFF' r.bitop('and', 'res1', 'a', 'b') r.bitop('or', 'res2', 'a', 'b') r.bitop('xor', 'res3', 'a', 'b') @@ -233,20 +274,20 @@ class TestRedisCommands(object): @skip_if_server_version_lt('2.8.7') def test_bitpos(self, r): key = 'key:bitpos' - r.set(key, b('\xff\xf0\x00')) + r.set(key, b'\xff\xf0\x00') assert r.bitpos(key, 0) == 12 assert r.bitpos(key, 0, 2, -1) == 16 assert r.bitpos(key, 0, -2, -1) == 12 - r.set(key, b('\x00\xff\xf0')) + r.set(key, b'\x00\xff\xf0') assert r.bitpos(key, 1, 0) == 8 assert r.bitpos(key, 1, 1) == 8 - r.set(key, b('\x00\x00\x00')) + r.set(key, b'\x00\x00\x00') assert r.bitpos(key, 1) == -1 @skip_if_server_version_lt('2.8.7') def test_bitpos_wrong_arguments(self, r): key = 'key:bitpos:wrong:args' - r.set(key, b('\xff\xf0\x00')) + r.set(key, b'\xff\xf0\x00') with pytest.raises(exceptions.RedisError): r.bitpos(key, 0, end=1) == 12 with pytest.raises(exceptions.RedisError): @@ -254,11 +295,11 @@ class TestRedisCommands(object): def test_decr(self, r): assert r.decr('a') == -1 - assert r['a'] == b('-1') + assert r['a'] == b'-1' assert r.decr('a') == -2 - assert r['a'] == b('-2') + assert r['a'] == b'-2' assert r.decr('a', amount=5) == -7 - assert r['a'] == b('-7') + assert r['a'] == b'-7' def test_delete(self, r): assert r.delete('a') == 0 @@ -277,18 +318,45 @@ class TestRedisCommands(object): del r['a'] assert r.get('a') is None + @skip_if_server_version_lt('4.0.0') + def test_unlink(self, r): + assert r.unlink('a') == 0 + r['a'] = 'foo' + assert r.unlink('a') == 1 + assert r.get('a') is None + + @skip_if_server_version_lt('4.0.0') + def test_unlink_with_multiple_keys(self, r): + r['a'] = 'foo' + r['b'] = 'bar' + assert r.unlink('a', 'b') == 2 + assert r.get('a') is None + assert r.get('b') is None + @skip_if_server_version_lt('2.6.0') def test_dump_and_restore(self, r): r['a'] = 'foo' dumped = r.dump('a') del r['a'] r.restore('a', 0, dumped) - assert r['a'] == b('foo') + assert r['a'] == b'foo' + + @skip_if_server_version_lt('3.0.0') + def test_dump_and_restore_and_replace(self, r): + r['a'] = 'bar' + dumped = r.dump('a') + with pytest.raises(redis.ResponseError): + r.restore('a', 0, dumped) + + r.restore('a', 0, dumped, replace=True) + assert r['a'] == b'bar' def test_exists(self, r): - assert not r.exists('a') + assert r.exists('a') == 0 r['a'] = 'foo' - assert r.exists('a') + r['b'] = 'bar' + assert r.exists('a') == 1 + assert r.exists('a', 'b') == 2 def test_exists_contains(self, r): assert 'a' not in r @@ -301,7 +369,7 @@ class TestRedisCommands(object): assert r.expire('a', 10) assert 0 < r.ttl('a') <= 10 assert r.persist('a') - assert not r.ttl('a') + assert r.ttl('a') == -1 def test_expireat_datetime(self, r): expire_at = redis_server_time(r) + datetime.timedelta(minutes=1) @@ -323,24 +391,28 @@ class TestRedisCommands(object): def test_get_and_set(self, r): # get and set can't be tested independently of each other assert r.get('a') is None - byte_string = b('value') + byte_string = b'value' integer = 5 - unicode_string = unichr(3456) + u('abcd') + unichr(3421) + unicode_string = unichr(3456) + 'abcd' + unichr(3421) assert r.set('byte_string', byte_string) assert r.set('integer', 5) assert r.set('unicode_string', unicode_string) assert r.get('byte_string') == byte_string - assert r.get('integer') == b(str(integer)) + assert r.get('integer') == str(integer).encode() assert r.get('unicode_string').decode('utf-8') == unicode_string def test_getitem_and_setitem(self, r): r['a'] = 'bar' - assert r['a'] == b('bar') + assert r['a'] == b'bar' def test_getitem_raises_keyerror_for_missing_key(self, r): with pytest.raises(KeyError): r['a'] + def test_getitem_does_not_raise_keyerror_for_empty_string(self, r): + r['a'] = b"" + assert r['a'] == b"" + def test_get_set_bit(self, r): # no value assert not r.getbit('a', 5) @@ -359,81 +431,67 @@ class TestRedisCommands(object): def test_getrange(self, r): r['a'] = 'foo' - assert r.getrange('a', 0, 0) == b('f') - assert r.getrange('a', 0, 2) == b('foo') - assert r.getrange('a', 3, 4) == b('') + assert r.getrange('a', 0, 0) == b'f' + assert r.getrange('a', 0, 2) == b'foo' + assert r.getrange('a', 3, 4) == b'' def test_getset(self, r): assert r.getset('a', 'foo') is None - assert r.getset('a', 'bar') == b('foo') - assert r.get('a') == b('bar') + assert r.getset('a', 'bar') == b'foo' + assert r.get('a') == b'bar' def test_incr(self, r): assert r.incr('a') == 1 - assert r['a'] == b('1') + assert r['a'] == b'1' assert r.incr('a') == 2 - assert r['a'] == b('2') + assert r['a'] == b'2' assert r.incr('a', amount=5) == 7 - assert r['a'] == b('7') + assert r['a'] == b'7' def test_incrby(self, r): assert r.incrby('a') == 1 assert r.incrby('a', 4) == 5 - assert r['a'] == b('5') + assert r['a'] == b'5' @skip_if_server_version_lt('2.6.0') def test_incrbyfloat(self, r): assert r.incrbyfloat('a') == 1.0 - assert r['a'] == b('1') + assert r['a'] == b'1' assert r.incrbyfloat('a', 1.1) == 2.1 assert float(r['a']) == float(2.1) def test_keys(self, r): assert r.keys() == [] - keys_with_underscores = set([b('test_a'), b('test_b')]) - keys = keys_with_underscores.union(set([b('testc')])) + keys_with_underscores = {b'test_a', b'test_b'} + keys = keys_with_underscores.union({b'testc'}) for key in keys: r[key] = 1 assert set(r.keys(pattern='test_*')) == keys_with_underscores assert set(r.keys(pattern='test*')) == keys def test_mget(self, r): + assert r.mget([]) == [] assert r.mget(['a', 'b']) == [None, None] r['a'] = '1' r['b'] = '2' r['c'] = '3' - assert r.mget('a', 'other', 'b', 'c') == [b('1'), None, b('2'), b('3')] + assert r.mget('a', 'other', 'b', 'c') == [b'1', None, b'2', b'3'] def test_mset(self, r): - d = {'a': b('1'), 'b': b('2'), 'c': b('3')} + d = {'a': b'1', 'b': b'2', 'c': b'3'} assert r.mset(d) for k, v in iteritems(d): assert r[k] == v - def test_mset_kwargs(self, r): - d = {'a': b('1'), 'b': b('2'), 'c': b('3')} - assert r.mset(**d) - for k, v in iteritems(d): - assert r[k] == v - def test_msetnx(self, r): - d = {'a': b('1'), 'b': b('2'), 'c': b('3')} + d = {'a': b'1', 'b': b'2', 'c': b'3'} assert r.msetnx(d) - d2 = {'a': b('x'), 'd': b('4')} + d2 = {'a': b'x', 'd': b'4'} assert not r.msetnx(d2) for k, v in iteritems(d): assert r[k] == v assert r.get('d') is None - def test_msetnx_kwargs(self, r): - d = {'a': b('1'), 'b': b('2'), 'c': b('3')} - assert r.msetnx(**d) - d2 = {'a': b('x'), 'd': b('4')} - assert not r.msetnx(**d2) - for k, v in iteritems(d): - assert r[k] == v - assert r.get('d') is None - @skip_if_server_version_lt('2.6.0') def test_pexpire(self, r): assert not r.pexpire('a', 60000) @@ -441,7 +499,7 @@ class TestRedisCommands(object): assert r.pexpire('a', 60000) assert 0 < r.pttl('a') <= 60000 assert r.persist('a') - assert r.pttl('a') is None + assert r.pttl('a') == -1 @skip_if_server_version_lt('2.6.0') def test_pexpireat_datetime(self, r): @@ -466,40 +524,54 @@ class TestRedisCommands(object): @skip_if_server_version_lt('2.6.0') def test_psetex(self, r): assert r.psetex('a', 1000, 'value') - assert r['a'] == b('value') + assert r['a'] == b'value' assert 0 < r.pttl('a') <= 1000 @skip_if_server_version_lt('2.6.0') def test_psetex_timedelta(self, r): expire_at = datetime.timedelta(milliseconds=1000) assert r.psetex('a', expire_at, 'value') - assert r['a'] == b('value') + assert r['a'] == b'value' assert 0 < r.pttl('a') <= 1000 + @skip_if_server_version_lt('2.6.0') + def test_pttl(self, r): + assert not r.pexpire('a', 10000) + r['a'] = '1' + assert r.pexpire('a', 10000) + assert 0 < r.pttl('a') <= 10000 + assert r.persist('a') + assert r.pttl('a') == -1 + + @skip_if_server_version_lt('2.8.0') + def test_pttl_no_key(self, r): + "PTTL on servers 2.8 and after return -2 when the key doesn't exist" + assert r.pttl('a') == -2 + def test_randomkey(self, r): assert r.randomkey() is None for key in ('a', 'b', 'c'): r[key] = 1 - assert r.randomkey() in (b('a'), b('b'), b('c')) + assert r.randomkey() in (b'a', b'b', b'c') def test_rename(self, r): r['a'] = '1' assert r.rename('a', 'b') assert r.get('a') is None - assert r['b'] == b('1') + assert r['b'] == b'1' def test_renamenx(self, r): r['a'] = '1' r['b'] = '2' assert not r.renamenx('a', 'b') - assert r['a'] == b('1') - assert r['b'] == b('2') + assert r['a'] == b'1' + assert r['b'] == b'2' @skip_if_server_version_lt('2.6.0') def test_set_nx(self, r): assert r.set('a', '1', nx=True) assert not r.set('a', '2', nx=True) - assert r['a'] == b('1') + assert r['a'] == b'1' @skip_if_server_version_lt('2.6.0') def test_set_xx(self, r): @@ -507,12 +579,12 @@ class TestRedisCommands(object): assert r.get('a') is None r['a'] = 'bar' assert r.set('a', '2', xx=True) - assert r.get('a') == b('2') + assert r.get('a') == b'2' @skip_if_server_version_lt('2.6.0') def test_set_px(self, r): assert r.set('a', '1', px=10000) - assert r['a'] == b('1') + assert r['a'] == b'1' assert 0 < r.pttl('a') <= 10000 assert 0 < r.ttl('a') <= 10 @@ -541,22 +613,22 @@ class TestRedisCommands(object): assert 0 < r.ttl('a') <= 10 def test_setex(self, r): - assert r.setex('a', '1', 60) - assert r['a'] == b('1') + assert r.setex('a', 60, '1') + assert r['a'] == b'1' assert 0 < r.ttl('a') <= 60 def test_setnx(self, r): assert r.setnx('a', '1') - assert r['a'] == b('1') + assert r['a'] == b'1' assert not r.setnx('a', '2') - assert r['a'] == b('1') + assert r['a'] == b'1' def test_setrange(self, r): assert r.setrange('a', 5, 'foo') == 8 - assert r['a'] == b('\0\0\0\0\0foo') + assert r['a'] == b'\0\0\0\0\0foo' r['a'] = 'abcdefghijh' assert r.setrange('a', 6, '12345') == 11 - assert r['a'] == b('abcdef12345') + assert r['a'] == b'abcdef12345' def test_strlen(self, r): r['a'] = 'foo' @@ -564,74 +636,86 @@ class TestRedisCommands(object): def test_substr(self, r): r['a'] = '0123456789' - assert r.substr('a', 0) == b('0123456789') - assert r.substr('a', 2) == b('23456789') - assert r.substr('a', 3, 5) == b('345') - assert r.substr('a', 3, -2) == b('345678') + assert r.substr('a', 0) == b'0123456789' + assert r.substr('a', 2) == b'23456789' + assert r.substr('a', 3, 5) == b'345' + assert r.substr('a', 3, -2) == b'345678' + + def test_ttl(self, r): + r['a'] = '1' + assert r.expire('a', 10) + assert 0 < r.ttl('a') <= 10 + assert r.persist('a') + assert r.ttl('a') == -1 + + @skip_if_server_version_lt('2.8.0') + def test_ttl_nokey(self, r): + "TTL on servers 2.8 and after return -2 when the key doesn't exist" + assert r.ttl('a') == -2 def test_type(self, r): - assert r.type('a') == b('none') + assert r.type('a') == b'none' r['a'] = '1' - assert r.type('a') == b('string') + assert r.type('a') == b'string' del r['a'] r.lpush('a', '1') - assert r.type('a') == b('list') + assert r.type('a') == b'list' del r['a'] r.sadd('a', '1') - assert r.type('a') == b('set') + assert r.type('a') == b'set' del r['a'] - r.zadd('a', **{'1': 1}) - assert r.type('a') == b('zset') + r.zadd('a', {'1': 1}) + assert r.type('a') == b'zset' # LIST COMMANDS def test_blpop(self, r): r.rpush('a', '1', '2') r.rpush('b', '3', '4') - assert r.blpop(['b', 'a'], timeout=1) == (b('b'), b('3')) - assert r.blpop(['b', 'a'], timeout=1) == (b('b'), b('4')) - assert r.blpop(['b', 'a'], timeout=1) == (b('a'), b('1')) - assert r.blpop(['b', 'a'], timeout=1) == (b('a'), b('2')) + assert r.blpop(['b', 'a'], timeout=1) == (b'b', b'3') + assert r.blpop(['b', 'a'], timeout=1) == (b'b', b'4') + assert r.blpop(['b', 'a'], timeout=1) == (b'a', b'1') + assert r.blpop(['b', 'a'], timeout=1) == (b'a', b'2') assert r.blpop(['b', 'a'], timeout=1) is None r.rpush('c', '1') - assert r.blpop('c', timeout=1) == (b('c'), b('1')) + assert r.blpop('c', timeout=1) == (b'c', b'1') def test_brpop(self, r): r.rpush('a', '1', '2') r.rpush('b', '3', '4') - assert r.brpop(['b', 'a'], timeout=1) == (b('b'), b('4')) - assert r.brpop(['b', 'a'], timeout=1) == (b('b'), b('3')) - assert r.brpop(['b', 'a'], timeout=1) == (b('a'), b('2')) - assert r.brpop(['b', 'a'], timeout=1) == (b('a'), b('1')) + assert r.brpop(['b', 'a'], timeout=1) == (b'b', b'4') + assert r.brpop(['b', 'a'], timeout=1) == (b'b', b'3') + assert r.brpop(['b', 'a'], timeout=1) == (b'a', b'2') + assert r.brpop(['b', 'a'], timeout=1) == (b'a', b'1') assert r.brpop(['b', 'a'], timeout=1) is None r.rpush('c', '1') - assert r.brpop('c', timeout=1) == (b('c'), b('1')) + assert r.brpop('c', timeout=1) == (b'c', b'1') def test_brpoplpush(self, r): r.rpush('a', '1', '2') r.rpush('b', '3', '4') - assert r.brpoplpush('a', 'b') == b('2') - assert r.brpoplpush('a', 'b') == b('1') + assert r.brpoplpush('a', 'b') == b'2' + assert r.brpoplpush('a', 'b') == b'1' assert r.brpoplpush('a', 'b', timeout=1) is None assert r.lrange('a', 0, -1) == [] - assert r.lrange('b', 0, -1) == [b('1'), b('2'), b('3'), b('4')] + assert r.lrange('b', 0, -1) == [b'1', b'2', b'3', b'4'] def test_brpoplpush_empty_string(self, r): r.rpush('a', '') - assert r.brpoplpush('a', 'b') == b('') + assert r.brpoplpush('a', 'b') == b'' def test_lindex(self, r): r.rpush('a', '1', '2', '3') - assert r.lindex('a', '0') == b('1') - assert r.lindex('a', '1') == b('2') - assert r.lindex('a', '2') == b('3') + assert r.lindex('a', '0') == b'1' + assert r.lindex('a', '1') == b'2' + assert r.lindex('a', '2') == b'3' def test_linsert(self, r): r.rpush('a', '1', '2', '3') assert r.linsert('a', 'after', '2', '2.5') == 4 - assert r.lrange('a', 0, -1) == [b('1'), b('2'), b('2.5'), b('3')] + assert r.lrange('a', 0, -1) == [b'1', b'2', b'2.5', b'3'] assert r.linsert('a', 'before', '2', '1.5') == 5 assert r.lrange('a', 0, -1) == \ - [b('1'), b('1.5'), b('2'), b('2.5'), b('3')] + [b'1', b'1.5', b'2', b'2.5', b'3'] def test_llen(self, r): r.rpush('a', '1', '2', '3') @@ -639,74 +723,79 @@ class TestRedisCommands(object): def test_lpop(self, r): r.rpush('a', '1', '2', '3') - assert r.lpop('a') == b('1') - assert r.lpop('a') == b('2') - assert r.lpop('a') == b('3') + assert r.lpop('a') == b'1' + assert r.lpop('a') == b'2' + assert r.lpop('a') == b'3' assert r.lpop('a') is None def test_lpush(self, r): assert r.lpush('a', '1') == 1 assert r.lpush('a', '2') == 2 assert r.lpush('a', '3', '4') == 4 - assert r.lrange('a', 0, -1) == [b('4'), b('3'), b('2'), b('1')] + assert r.lrange('a', 0, -1) == [b'4', b'3', b'2', b'1'] def test_lpushx(self, r): assert r.lpushx('a', '1') == 0 assert r.lrange('a', 0, -1) == [] r.rpush('a', '1', '2', '3') assert r.lpushx('a', '4') == 4 - assert r.lrange('a', 0, -1) == [b('4'), b('1'), b('2'), b('3')] + assert r.lrange('a', 0, -1) == [b'4', b'1', b'2', b'3'] def test_lrange(self, r): r.rpush('a', '1', '2', '3', '4', '5') - assert r.lrange('a', 0, 2) == [b('1'), b('2'), b('3')] - assert r.lrange('a', 2, 10) == [b('3'), b('4'), b('5')] - assert r.lrange('a', 0, -1) == [b('1'), b('2'), b('3'), b('4'), b('5')] + assert r.lrange('a', 0, 2) == [b'1', b'2', b'3'] + assert r.lrange('a', 2, 10) == [b'3', b'4', b'5'] + assert r.lrange('a', 0, -1) == [b'1', b'2', b'3', b'4', b'5'] def test_lrem(self, r): - r.rpush('a', '1', '1', '1', '1') - assert r.lrem('a', '1', 1) == 1 - assert r.lrange('a', 0, -1) == [b('1'), b('1'), b('1')] - assert r.lrem('a', '1') == 3 - assert r.lrange('a', 0, -1) == [] + r.rpush('a', 'Z', 'b', 'Z', 'Z', 'c', 'Z', 'Z') + # remove the first 'Z' item + assert r.lrem('a', 1, 'Z') == 1 + assert r.lrange('a', 0, -1) == [b'b', b'Z', b'Z', b'c', b'Z', b'Z'] + # remove the last 2 'Z' items + assert r.lrem('a', -2, 'Z') == 2 + assert r.lrange('a', 0, -1) == [b'b', b'Z', b'Z', b'c'] + # remove all 'Z' items + assert r.lrem('a', 0, 'Z') == 2 + assert r.lrange('a', 0, -1) == [b'b', b'c'] def test_lset(self, r): r.rpush('a', '1', '2', '3') - assert r.lrange('a', 0, -1) == [b('1'), b('2'), b('3')] + assert r.lrange('a', 0, -1) == [b'1', b'2', b'3'] assert r.lset('a', 1, '4') - assert r.lrange('a', 0, 2) == [b('1'), b('4'), b('3')] + assert r.lrange('a', 0, 2) == [b'1', b'4', b'3'] def test_ltrim(self, r): r.rpush('a', '1', '2', '3') assert r.ltrim('a', 0, 1) - assert r.lrange('a', 0, -1) == [b('1'), b('2')] + assert r.lrange('a', 0, -1) == [b'1', b'2'] def test_rpop(self, r): r.rpush('a', '1', '2', '3') - assert r.rpop('a') == b('3') - assert r.rpop('a') == b('2') - assert r.rpop('a') == b('1') + assert r.rpop('a') == b'3' + assert r.rpop('a') == b'2' + assert r.rpop('a') == b'1' assert r.rpop('a') is None def test_rpoplpush(self, r): r.rpush('a', 'a1', 'a2', 'a3') r.rpush('b', 'b1', 'b2', 'b3') - assert r.rpoplpush('a', 'b') == b('a3') - assert r.lrange('a', 0, -1) == [b('a1'), b('a2')] - assert r.lrange('b', 0, -1) == [b('a3'), b('b1'), b('b2'), b('b3')] + assert r.rpoplpush('a', 'b') == b'a3' + assert r.lrange('a', 0, -1) == [b'a1', b'a2'] + assert r.lrange('b', 0, -1) == [b'a3', b'b1', b'b2', b'b3'] def test_rpush(self, r): assert r.rpush('a', '1') == 1 assert r.rpush('a', '2') == 2 assert r.rpush('a', '3', '4') == 4 - assert r.lrange('a', 0, -1) == [b('1'), b('2'), b('3'), b('4')] + assert r.lrange('a', 0, -1) == [b'1', b'2', b'3', b'4'] def test_rpushx(self, r): assert r.rpushx('a', 'b') == 0 assert r.lrange('a', 0, -1) == [] r.rpush('a', '1', '2', '3') assert r.rpushx('a', '4') == 4 - assert r.lrange('a', 0, -1) == [b('1'), b('2'), b('3'), b('4')] + assert r.lrange('a', 0, -1) == [b'1', b'2', b'3', b'4'] # SCAN COMMANDS @skip_if_server_version_lt('2.8.0') @@ -716,9 +805,9 @@ class TestRedisCommands(object): r.set('c', 3) cursor, keys = r.scan() assert cursor == 0 - assert set(keys) == set([b('a'), b('b'), b('c')]) + assert set(keys) == {b'a', b'b', b'c'} _, keys = r.scan(match='a') - assert set(keys) == set([b('a')]) + assert set(keys) == {b'a'} @skip_if_server_version_lt('2.8.0') def test_scan_iter(self, r): @@ -726,64 +815,64 @@ class TestRedisCommands(object): r.set('b', 2) r.set('c', 3) keys = list(r.scan_iter()) - assert set(keys) == set([b('a'), b('b'), b('c')]) + assert set(keys) == {b'a', b'b', b'c'} keys = list(r.scan_iter(match='a')) - assert set(keys) == set([b('a')]) + assert set(keys) == {b'a'} @skip_if_server_version_lt('2.8.0') def test_sscan(self, r): r.sadd('a', 1, 2, 3) cursor, members = r.sscan('a') assert cursor == 0 - assert set(members) == set([b('1'), b('2'), b('3')]) - _, members = r.sscan('a', match=b('1')) - assert set(members) == set([b('1')]) + assert set(members) == {b'1', b'2', b'3'} + _, members = r.sscan('a', match=b'1') + assert set(members) == {b'1'} @skip_if_server_version_lt('2.8.0') def test_sscan_iter(self, r): r.sadd('a', 1, 2, 3) members = list(r.sscan_iter('a')) - assert set(members) == set([b('1'), b('2'), b('3')]) - members = list(r.sscan_iter('a', match=b('1'))) - assert set(members) == set([b('1')]) + assert set(members) == {b'1', b'2', b'3'} + members = list(r.sscan_iter('a', match=b'1')) + assert set(members) == {b'1'} @skip_if_server_version_lt('2.8.0') def test_hscan(self, r): r.hmset('a', {'a': 1, 'b': 2, 'c': 3}) cursor, dic = r.hscan('a') assert cursor == 0 - assert dic == {b('a'): b('1'), b('b'): b('2'), b('c'): b('3')} + assert dic == {b'a': b'1', b'b': b'2', b'c': b'3'} _, dic = r.hscan('a', match='a') - assert dic == {b('a'): b('1')} + assert dic == {b'a': b'1'} @skip_if_server_version_lt('2.8.0') def test_hscan_iter(self, r): r.hmset('a', {'a': 1, 'b': 2, 'c': 3}) dic = dict(r.hscan_iter('a')) - assert dic == {b('a'): b('1'), b('b'): b('2'), b('c'): b('3')} + assert dic == {b'a': b'1', b'b': b'2', b'c': b'3'} dic = dict(r.hscan_iter('a', match='a')) - assert dic == {b('a'): b('1')} + assert dic == {b'a': b'1'} @skip_if_server_version_lt('2.8.0') def test_zscan(self, r): - r.zadd('a', 'a', 1, 'b', 2, 'c', 3) + r.zadd('a', {'a': 1, 'b': 2, 'c': 3}) cursor, pairs = r.zscan('a') assert cursor == 0 - assert set(pairs) == set([(b('a'), 1), (b('b'), 2), (b('c'), 3)]) + assert set(pairs) == {(b'a', 1), (b'b', 2), (b'c', 3)} _, pairs = r.zscan('a', match='a') - assert set(pairs) == set([(b('a'), 1)]) + assert set(pairs) == {(b'a', 1)} @skip_if_server_version_lt('2.8.0') def test_zscan_iter(self, r): - r.zadd('a', 'a', 1, 'b', 2, 'c', 3) + r.zadd('a', {'a': 1, 'b': 2, 'c': 3}) pairs = list(r.zscan_iter('a')) - assert set(pairs) == set([(b('a'), 1), (b('b'), 2), (b('c'), 3)]) + assert set(pairs) == {(b'a', 1), (b'b', 2), (b'c', 3)} pairs = list(r.zscan_iter('a', match='a')) - assert set(pairs) == set([(b('a'), 1)]) + assert set(pairs) == {(b'a', 1)} # SET COMMANDS def test_sadd(self, r): - members = set([b('1'), b('2'), b('3')]) + members = {b'1', b'2', b'3'} r.sadd('a', *members) assert r.smembers('a') == members @@ -793,23 +882,23 @@ class TestRedisCommands(object): def test_sdiff(self, r): r.sadd('a', '1', '2', '3') - assert r.sdiff('a', 'b') == set([b('1'), b('2'), b('3')]) + assert r.sdiff('a', 'b') == {b'1', b'2', b'3'} r.sadd('b', '2', '3') - assert r.sdiff('a', 'b') == set([b('1')]) + assert r.sdiff('a', 'b') == {b'1'} def test_sdiffstore(self, r): r.sadd('a', '1', '2', '3') assert r.sdiffstore('c', 'a', 'b') == 3 - assert r.smembers('c') == set([b('1'), b('2'), b('3')]) + assert r.smembers('c') == {b'1', b'2', b'3'} r.sadd('b', '2', '3') assert r.sdiffstore('c', 'a', 'b') == 1 - assert r.smembers('c') == set([b('1')]) + assert r.smembers('c') == {b'1'} def test_sinter(self, r): r.sadd('a', '1', '2', '3') assert r.sinter('a', 'b') == set() r.sadd('b', '2', '3') - assert r.sinter('a', 'b') == set([b('2'), b('3')]) + assert r.sinter('a', 'b') == {b'2', b'3'} def test_sinterstore(self, r): r.sadd('a', '1', '2', '3') @@ -817,7 +906,7 @@ class TestRedisCommands(object): assert r.smembers('c') == set() r.sadd('b', '2', '3') assert r.sinterstore('c', 'a', 'b') == 2 - assert r.smembers('c') == set([b('2'), b('3')]) + assert r.smembers('c') == {b'2', b'3'} def test_sismember(self, r): r.sadd('a', '1', '2', '3') @@ -828,30 +917,41 @@ class TestRedisCommands(object): def test_smembers(self, r): r.sadd('a', '1', '2', '3') - assert r.smembers('a') == set([b('1'), b('2'), b('3')]) + assert r.smembers('a') == {b'1', b'2', b'3'} def test_smove(self, r): r.sadd('a', 'a1', 'a2') r.sadd('b', 'b1', 'b2') assert r.smove('a', 'b', 'a1') - assert r.smembers('a') == set([b('a2')]) - assert r.smembers('b') == set([b('b1'), b('b2'), b('a1')]) + assert r.smembers('a') == {b'a2'} + assert r.smembers('b') == {b'b1', b'b2', b'a1'} def test_spop(self, r): - s = [b('1'), b('2'), b('3')] + s = [b'1', b'2', b'3'] r.sadd('a', *s) value = r.spop('a') assert value in s - assert r.smembers('a') == set(s) - set([value]) + assert r.smembers('a') == set(s) - {value} + + def test_spop_multi_value(self, r): + s = [b'1', b'2', b'3'] + r.sadd('a', *s) + values = r.spop('a', 2) + assert len(values) == 2 + + for value in values: + assert value in s + + assert r.spop('a', 1) == list(set(s) - set(values)) def test_srandmember(self, r): - s = [b('1'), b('2'), b('3')] + s = [b'1', b'2', b'3'] r.sadd('a', *s) assert r.srandmember('a') in s @skip_if_server_version_lt('2.6.0') def test_srandmember_multi_value(self, r): - s = [b('1'), b('2'), b('3')] + s = [b'1', b'2', b'3'] r.sadd('a', *s) randoms = r.srandmember('a', number=2) assert len(randoms) == 2 @@ -861,257 +961,352 @@ class TestRedisCommands(object): r.sadd('a', '1', '2', '3', '4') assert r.srem('a', '5') == 0 assert r.srem('a', '2', '4') == 2 - assert r.smembers('a') == set([b('1'), b('3')]) + assert r.smembers('a') == {b'1', b'3'} def test_sunion(self, r): r.sadd('a', '1', '2') r.sadd('b', '2', '3') - assert r.sunion('a', 'b') == set([b('1'), b('2'), b('3')]) + assert r.sunion('a', 'b') == {b'1', b'2', b'3'} def test_sunionstore(self, r): r.sadd('a', '1', '2') r.sadd('b', '2', '3') assert r.sunionstore('c', 'a', 'b') == 3 - assert r.smembers('c') == set([b('1'), b('2'), b('3')]) + assert r.smembers('c') == {b'1', b'2', b'3'} # SORTED SET COMMANDS def test_zadd(self, r): - r.zadd('a', a1=1, a2=2, a3=3) - assert r.zrange('a', 0, -1) == [b('a1'), b('a2'), b('a3')] + mapping = {'a1': 1.0, 'a2': 2.0, 'a3': 3.0} + r.zadd('a', mapping) + assert r.zrange('a', 0, -1, withscores=True) == \ + [(b'a1', 1.0), (b'a2', 2.0), (b'a3', 3.0)] + + # error cases + with pytest.raises(exceptions.DataError): + r.zadd('a', {}) + + # cannot use both nx and xx options + with pytest.raises(exceptions.DataError): + r.zadd('a', mapping, nx=True, xx=True) + + # cannot use the incr options with more than one value + with pytest.raises(exceptions.DataError): + r.zadd('a', mapping, incr=True) + + def test_zadd_nx(self, r): + assert r.zadd('a', {'a1': 1}) == 1 + assert r.zadd('a', {'a1': 99, 'a2': 2}, nx=True) == 1 + assert r.zrange('a', 0, -1, withscores=True) == \ + [(b'a1', 1.0), (b'a2', 2.0)] + + def test_zadd_xx(self, r): + assert r.zadd('a', {'a1': 1}) == 1 + assert r.zadd('a', {'a1': 99, 'a2': 2}, xx=True) == 0 + assert r.zrange('a', 0, -1, withscores=True) == \ + [(b'a1', 99.0)] + + def test_zadd_ch(self, r): + assert r.zadd('a', {'a1': 1}) == 1 + assert r.zadd('a', {'a1': 99, 'a2': 2}, ch=True) == 2 + assert r.zrange('a', 0, -1, withscores=True) == \ + [(b'a2', 2.0), (b'a1', 99.0)] + + def test_zadd_incr(self, r): + assert r.zadd('a', {'a1': 1}) == 1 + assert r.zadd('a', {'a1': 4.5}, incr=True) == 5.5 def test_zcard(self, r): - r.zadd('a', a1=1, a2=2, a3=3) + r.zadd('a', {'a1': 1, 'a2': 2, 'a3': 3}) assert r.zcard('a') == 3 def test_zcount(self, r): - r.zadd('a', a1=1, a2=2, a3=3) + r.zadd('a', {'a1': 1, 'a2': 2, 'a3': 3}) assert r.zcount('a', '-inf', '+inf') == 3 assert r.zcount('a', 1, 2) == 2 + assert r.zcount('a', '(' + str(1), 2) == 1 + assert r.zcount('a', 1, '(' + str(2)) == 1 assert r.zcount('a', 10, 20) == 0 def test_zincrby(self, r): - r.zadd('a', a1=1, a2=2, a3=3) - assert r.zincrby('a', 'a2') == 3.0 - assert r.zincrby('a', 'a3', amount=5) == 8.0 + r.zadd('a', {'a1': 1, 'a2': 2, 'a3': 3}) + assert r.zincrby('a', 1, 'a2') == 3.0 + assert r.zincrby('a', 5, 'a3') == 8.0 assert r.zscore('a', 'a2') == 3.0 assert r.zscore('a', 'a3') == 8.0 @skip_if_server_version_lt('2.8.9') def test_zlexcount(self, r): - r.zadd('a', a=0, b=0, c=0, d=0, e=0, f=0, g=0) + r.zadd('a', {'a': 0, 'b': 0, 'c': 0, 'd': 0, 'e': 0, 'f': 0, 'g': 0}) assert r.zlexcount('a', '-', '+') == 7 assert r.zlexcount('a', '[b', '[f') == 5 def test_zinterstore_sum(self, r): - r.zadd('a', a1=1, a2=1, a3=1) - r.zadd('b', a1=2, a2=2, a3=2) - r.zadd('c', a1=6, a3=5, a4=4) + r.zadd('a', {'a1': 1, 'a2': 1, 'a3': 1}) + r.zadd('b', {'a1': 2, 'a2': 2, 'a3': 2}) + r.zadd('c', {'a1': 6, 'a3': 5, 'a4': 4}) assert r.zinterstore('d', ['a', 'b', 'c']) == 2 assert r.zrange('d', 0, -1, withscores=True) == \ - [(b('a3'), 8), (b('a1'), 9)] + [(b'a3', 8), (b'a1', 9)] def test_zinterstore_max(self, r): - r.zadd('a', a1=1, a2=1, a3=1) - r.zadd('b', a1=2, a2=2, a3=2) - r.zadd('c', a1=6, a3=5, a4=4) + r.zadd('a', {'a1': 1, 'a2': 1, 'a3': 1}) + r.zadd('b', {'a1': 2, 'a2': 2, 'a3': 2}) + r.zadd('c', {'a1': 6, 'a3': 5, 'a4': 4}) assert r.zinterstore('d', ['a', 'b', 'c'], aggregate='MAX') == 2 assert r.zrange('d', 0, -1, withscores=True) == \ - [(b('a3'), 5), (b('a1'), 6)] + [(b'a3', 5), (b'a1', 6)] def test_zinterstore_min(self, r): - r.zadd('a', a1=1, a2=2, a3=3) - r.zadd('b', a1=2, a2=3, a3=5) - r.zadd('c', a1=6, a3=5, a4=4) + r.zadd('a', {'a1': 1, 'a2': 2, 'a3': 3}) + r.zadd('b', {'a1': 2, 'a2': 3, 'a3': 5}) + r.zadd('c', {'a1': 6, 'a3': 5, 'a4': 4}) assert r.zinterstore('d', ['a', 'b', 'c'], aggregate='MIN') == 2 assert r.zrange('d', 0, -1, withscores=True) == \ - [(b('a1'), 1), (b('a3'), 3)] + [(b'a1', 1), (b'a3', 3)] def test_zinterstore_with_weight(self, r): - r.zadd('a', a1=1, a2=1, a3=1) - r.zadd('b', a1=2, a2=2, a3=2) - r.zadd('c', a1=6, a3=5, a4=4) + r.zadd('a', {'a1': 1, 'a2': 1, 'a3': 1}) + r.zadd('b', {'a1': 2, 'a2': 2, 'a3': 2}) + r.zadd('c', {'a1': 6, 'a3': 5, 'a4': 4}) assert r.zinterstore('d', {'a': 1, 'b': 2, 'c': 3}) == 2 assert r.zrange('d', 0, -1, withscores=True) == \ - [(b('a3'), 20), (b('a1'), 23)] + [(b'a3', 20), (b'a1', 23)] + + @skip_if_server_version_lt('4.9.0') + def test_zpopmax(self, r): + r.zadd('a', {'a1': 1, 'a2': 2, 'a3': 3}) + assert r.zpopmax('a') == [(b'a3', 3)] + + # with count + assert r.zpopmax('a', count=2) == \ + [(b'a2', 2), (b'a1', 1)] + + @skip_if_server_version_lt('4.9.0') + def test_zpopmin(self, r): + r.zadd('a', {'a1': 1, 'a2': 2, 'a3': 3}) + assert r.zpopmin('a') == [(b'a1', 1)] + + # with count + assert r.zpopmin('a', count=2) == \ + [(b'a2', 2), (b'a3', 3)] + + @skip_if_server_version_lt('4.9.0') + def test_bzpopmax(self, r): + r.zadd('a', {'a1': 1, 'a2': 2}) + r.zadd('b', {'b1': 10, 'b2': 20}) + assert r.bzpopmax(['b', 'a'], timeout=1) == (b'b', b'b2', 20) + assert r.bzpopmax(['b', 'a'], timeout=1) == (b'b', b'b1', 10) + assert r.bzpopmax(['b', 'a'], timeout=1) == (b'a', b'a2', 2) + assert r.bzpopmax(['b', 'a'], timeout=1) == (b'a', b'a1', 1) + assert r.bzpopmax(['b', 'a'], timeout=1) is None + r.zadd('c', {'c1': 100}) + assert r.bzpopmax('c', timeout=1) == (b'c', b'c1', 100) + + @skip_if_server_version_lt('4.9.0') + def test_bzpopmin(self, r): + r.zadd('a', {'a1': 1, 'a2': 2}) + r.zadd('b', {'b1': 10, 'b2': 20}) + assert r.bzpopmin(['b', 'a'], timeout=1) == (b'b', b'b1', 10) + assert r.bzpopmin(['b', 'a'], timeout=1) == (b'b', b'b2', 20) + assert r.bzpopmin(['b', 'a'], timeout=1) == (b'a', b'a1', 1) + assert r.bzpopmin(['b', 'a'], timeout=1) == (b'a', b'a2', 2) + assert r.bzpopmin(['b', 'a'], timeout=1) is None + r.zadd('c', {'c1': 100}) + assert r.bzpopmin('c', timeout=1) == (b'c', b'c1', 100) def test_zrange(self, r): - r.zadd('a', a1=1, a2=2, a3=3) - assert r.zrange('a', 0, 1) == [b('a1'), b('a2')] - assert r.zrange('a', 1, 2) == [b('a2'), b('a3')] + r.zadd('a', {'a1': 1, 'a2': 2, 'a3': 3}) + assert r.zrange('a', 0, 1) == [b'a1', b'a2'] + assert r.zrange('a', 1, 2) == [b'a2', b'a3'] # withscores assert r.zrange('a', 0, 1, withscores=True) == \ - [(b('a1'), 1.0), (b('a2'), 2.0)] + [(b'a1', 1.0), (b'a2', 2.0)] assert r.zrange('a', 1, 2, withscores=True) == \ - [(b('a2'), 2.0), (b('a3'), 3.0)] + [(b'a2', 2.0), (b'a3', 3.0)] # custom score function assert r.zrange('a', 0, 1, withscores=True, score_cast_func=int) == \ - [(b('a1'), 1), (b('a2'), 2)] + [(b'a1', 1), (b'a2', 2)] @skip_if_server_version_lt('2.8.9') def test_zrangebylex(self, r): - r.zadd('a', a=0, b=0, c=0, d=0, e=0, f=0, g=0) - assert r.zrangebylex('a', '-', '[c') == [b('a'), b('b'), b('c')] - assert r.zrangebylex('a', '-', '(c') == [b('a'), b('b')] + r.zadd('a', {'a': 0, 'b': 0, 'c': 0, 'd': 0, 'e': 0, 'f': 0, 'g': 0}) + assert r.zrangebylex('a', '-', '[c') == [b'a', b'b', b'c'] + assert r.zrangebylex('a', '-', '(c') == [b'a', b'b'] assert r.zrangebylex('a', '[aaa', '(g') == \ - [b('b'), b('c'), b('d'), b('e'), b('f')] - assert r.zrangebylex('a', '[f', '+') == [b('f'), b('g')] - assert r.zrangebylex('a', '-', '+', start=3, num=2) == [b('d'), b('e')] + [b'b', b'c', b'd', b'e', b'f'] + assert r.zrangebylex('a', '[f', '+') == [b'f', b'g'] + assert r.zrangebylex('a', '-', '+', start=3, num=2) == [b'd', b'e'] + + @skip_if_server_version_lt('2.9.9') + def test_zrevrangebylex(self, r): + r.zadd('a', {'a': 0, 'b': 0, 'c': 0, 'd': 0, 'e': 0, 'f': 0, 'g': 0}) + assert r.zrevrangebylex('a', '[c', '-') == [b'c', b'b', b'a'] + assert r.zrevrangebylex('a', '(c', '-') == [b'b', b'a'] + assert r.zrevrangebylex('a', '(g', '[aaa') == \ + [b'f', b'e', b'd', b'c', b'b'] + assert r.zrevrangebylex('a', '+', '[f') == [b'g', b'f'] + assert r.zrevrangebylex('a', '+', '-', start=3, num=2) == \ + [b'd', b'c'] def test_zrangebyscore(self, r): - r.zadd('a', a1=1, a2=2, a3=3, a4=4, a5=5) - assert r.zrangebyscore('a', 2, 4) == [b('a2'), b('a3'), b('a4')] + r.zadd('a', {'a1': 1, 'a2': 2, 'a3': 3, 'a4': 4, 'a5': 5}) + assert r.zrangebyscore('a', 2, 4) == [b'a2', b'a3', b'a4'] # slicing with start/num assert r.zrangebyscore('a', 2, 4, start=1, num=2) == \ - [b('a3'), b('a4')] + [b'a3', b'a4'] # withscores assert r.zrangebyscore('a', 2, 4, withscores=True) == \ - [(b('a2'), 2.0), (b('a3'), 3.0), (b('a4'), 4.0)] + [(b'a2', 2.0), (b'a3', 3.0), (b'a4', 4.0)] # custom score function assert r.zrangebyscore('a', 2, 4, withscores=True, score_cast_func=int) == \ - [(b('a2'), 2), (b('a3'), 3), (b('a4'), 4)] + [(b'a2', 2), (b'a3', 3), (b'a4', 4)] def test_zrank(self, r): - r.zadd('a', a1=1, a2=2, a3=3, a4=4, a5=5) + r.zadd('a', {'a1': 1, 'a2': 2, 'a3': 3, 'a4': 4, 'a5': 5}) assert r.zrank('a', 'a1') == 0 assert r.zrank('a', 'a2') == 1 assert r.zrank('a', 'a6') is None def test_zrem(self, r): - r.zadd('a', a1=1, a2=2, a3=3) + r.zadd('a', {'a1': 1, 'a2': 2, 'a3': 3}) assert r.zrem('a', 'a2') == 1 - assert r.zrange('a', 0, -1) == [b('a1'), b('a3')] + assert r.zrange('a', 0, -1) == [b'a1', b'a3'] assert r.zrem('a', 'b') == 0 - assert r.zrange('a', 0, -1) == [b('a1'), b('a3')] + assert r.zrange('a', 0, -1) == [b'a1', b'a3'] def test_zrem_multiple_keys(self, r): - r.zadd('a', a1=1, a2=2, a3=3) + r.zadd('a', {'a1': 1, 'a2': 2, 'a3': 3}) assert r.zrem('a', 'a1', 'a2') == 2 - assert r.zrange('a', 0, 5) == [b('a3')] + assert r.zrange('a', 0, 5) == [b'a3'] @skip_if_server_version_lt('2.8.9') def test_zremrangebylex(self, r): - r.zadd('a', a=0, b=0, c=0, d=0, e=0, f=0, g=0) + r.zadd('a', {'a': 0, 'b': 0, 'c': 0, 'd': 0, 'e': 0, 'f': 0, 'g': 0}) assert r.zremrangebylex('a', '-', '[c') == 3 - assert r.zrange('a', 0, -1) == [b('d'), b('e'), b('f'), b('g')] + assert r.zrange('a', 0, -1) == [b'd', b'e', b'f', b'g'] assert r.zremrangebylex('a', '[f', '+') == 2 - assert r.zrange('a', 0, -1) == [b('d'), b('e')] + assert r.zrange('a', 0, -1) == [b'd', b'e'] assert r.zremrangebylex('a', '[h', '+') == 0 - assert r.zrange('a', 0, -1) == [b('d'), b('e')] + assert r.zrange('a', 0, -1) == [b'd', b'e'] def test_zremrangebyrank(self, r): - r.zadd('a', a1=1, a2=2, a3=3, a4=4, a5=5) + r.zadd('a', {'a1': 1, 'a2': 2, 'a3': 3, 'a4': 4, 'a5': 5}) assert r.zremrangebyrank('a', 1, 3) == 3 - assert r.zrange('a', 0, 5) == [b('a1'), b('a5')] + assert r.zrange('a', 0, 5) == [b'a1', b'a5'] def test_zremrangebyscore(self, r): - r.zadd('a', a1=1, a2=2, a3=3, a4=4, a5=5) + r.zadd('a', {'a1': 1, 'a2': 2, 'a3': 3, 'a4': 4, 'a5': 5}) assert r.zremrangebyscore('a', 2, 4) == 3 - assert r.zrange('a', 0, -1) == [b('a1'), b('a5')] + assert r.zrange('a', 0, -1) == [b'a1', b'a5'] assert r.zremrangebyscore('a', 2, 4) == 0 - assert r.zrange('a', 0, -1) == [b('a1'), b('a5')] + assert r.zrange('a', 0, -1) == [b'a1', b'a5'] def test_zrevrange(self, r): - r.zadd('a', a1=1, a2=2, a3=3) - assert r.zrevrange('a', 0, 1) == [b('a3'), b('a2')] - assert r.zrevrange('a', 1, 2) == [b('a2'), b('a1')] + r.zadd('a', {'a1': 1, 'a2': 2, 'a3': 3}) + assert r.zrevrange('a', 0, 1) == [b'a3', b'a2'] + assert r.zrevrange('a', 1, 2) == [b'a2', b'a1'] # withscores assert r.zrevrange('a', 0, 1, withscores=True) == \ - [(b('a3'), 3.0), (b('a2'), 2.0)] + [(b'a3', 3.0), (b'a2', 2.0)] assert r.zrevrange('a', 1, 2, withscores=True) == \ - [(b('a2'), 2.0), (b('a1'), 1.0)] + [(b'a2', 2.0), (b'a1', 1.0)] # custom score function assert r.zrevrange('a', 0, 1, withscores=True, score_cast_func=int) == \ - [(b('a3'), 3.0), (b('a2'), 2.0)] + [(b'a3', 3.0), (b'a2', 2.0)] def test_zrevrangebyscore(self, r): - r.zadd('a', a1=1, a2=2, a3=3, a4=4, a5=5) - assert r.zrevrangebyscore('a', 4, 2) == [b('a4'), b('a3'), b('a2')] + r.zadd('a', {'a1': 1, 'a2': 2, 'a3': 3, 'a4': 4, 'a5': 5}) + assert r.zrevrangebyscore('a', 4, 2) == [b'a4', b'a3', b'a2'] # slicing with start/num assert r.zrevrangebyscore('a', 4, 2, start=1, num=2) == \ - [b('a3'), b('a2')] + [b'a3', b'a2'] # withscores assert r.zrevrangebyscore('a', 4, 2, withscores=True) == \ - [(b('a4'), 4.0), (b('a3'), 3.0), (b('a2'), 2.0)] + [(b'a4', 4.0), (b'a3', 3.0), (b'a2', 2.0)] # custom score function assert r.zrevrangebyscore('a', 4, 2, withscores=True, score_cast_func=int) == \ - [(b('a4'), 4), (b('a3'), 3), (b('a2'), 2)] + [(b'a4', 4), (b'a3', 3), (b'a2', 2)] def test_zrevrank(self, r): - r.zadd('a', a1=1, a2=2, a3=3, a4=4, a5=5) + r.zadd('a', {'a1': 1, 'a2': 2, 'a3': 3, 'a4': 4, 'a5': 5}) assert r.zrevrank('a', 'a1') == 4 assert r.zrevrank('a', 'a2') == 3 assert r.zrevrank('a', 'a6') is None def test_zscore(self, r): - r.zadd('a', a1=1, a2=2, a3=3) + r.zadd('a', {'a1': 1, 'a2': 2, 'a3': 3}) assert r.zscore('a', 'a1') == 1.0 assert r.zscore('a', 'a2') == 2.0 assert r.zscore('a', 'a4') is None def test_zunionstore_sum(self, r): - r.zadd('a', a1=1, a2=1, a3=1) - r.zadd('b', a1=2, a2=2, a3=2) - r.zadd('c', a1=6, a3=5, a4=4) + r.zadd('a', {'a1': 1, 'a2': 1, 'a3': 1}) + r.zadd('b', {'a1': 2, 'a2': 2, 'a3': 2}) + r.zadd('c', {'a1': 6, 'a3': 5, 'a4': 4}) assert r.zunionstore('d', ['a', 'b', 'c']) == 4 assert r.zrange('d', 0, -1, withscores=True) == \ - [(b('a2'), 3), (b('a4'), 4), (b('a3'), 8), (b('a1'), 9)] + [(b'a2', 3), (b'a4', 4), (b'a3', 8), (b'a1', 9)] def test_zunionstore_max(self, r): - r.zadd('a', a1=1, a2=1, a3=1) - r.zadd('b', a1=2, a2=2, a3=2) - r.zadd('c', a1=6, a3=5, a4=4) + r.zadd('a', {'a1': 1, 'a2': 1, 'a3': 1}) + r.zadd('b', {'a1': 2, 'a2': 2, 'a3': 2}) + r.zadd('c', {'a1': 6, 'a3': 5, 'a4': 4}) assert r.zunionstore('d', ['a', 'b', 'c'], aggregate='MAX') == 4 assert r.zrange('d', 0, -1, withscores=True) == \ - [(b('a2'), 2), (b('a4'), 4), (b('a3'), 5), (b('a1'), 6)] + [(b'a2', 2), (b'a4', 4), (b'a3', 5), (b'a1', 6)] def test_zunionstore_min(self, r): - r.zadd('a', a1=1, a2=2, a3=3) - r.zadd('b', a1=2, a2=2, a3=4) - r.zadd('c', a1=6, a3=5, a4=4) + r.zadd('a', {'a1': 1, 'a2': 2, 'a3': 3}) + r.zadd('b', {'a1': 2, 'a2': 2, 'a3': 4}) + r.zadd('c', {'a1': 6, 'a3': 5, 'a4': 4}) assert r.zunionstore('d', ['a', 'b', 'c'], aggregate='MIN') == 4 assert r.zrange('d', 0, -1, withscores=True) == \ - [(b('a1'), 1), (b('a2'), 2), (b('a3'), 3), (b('a4'), 4)] + [(b'a1', 1), (b'a2', 2), (b'a3', 3), (b'a4', 4)] def test_zunionstore_with_weight(self, r): - r.zadd('a', a1=1, a2=1, a3=1) - r.zadd('b', a1=2, a2=2, a3=2) - r.zadd('c', a1=6, a3=5, a4=4) + r.zadd('a', {'a1': 1, 'a2': 1, 'a3': 1}) + r.zadd('b', {'a1': 2, 'a2': 2, 'a3': 2}) + r.zadd('c', {'a1': 6, 'a3': 5, 'a4': 4}) assert r.zunionstore('d', {'a': 1, 'b': 2, 'c': 3}) == 4 assert r.zrange('d', 0, -1, withscores=True) == \ - [(b('a2'), 5), (b('a4'), 12), (b('a3'), 20), (b('a1'), 23)] + [(b'a2', 5), (b'a4', 12), (b'a3', 20), (b'a1', 23)] # HYPERLOGLOG TESTS @skip_if_server_version_lt('2.8.9') def test_pfadd(self, r): - members = set([b('1'), b('2'), b('3')]) + members = {b'1', b'2', b'3'} assert r.pfadd('a', *members) == 1 assert r.pfadd('a', *members) == 0 assert r.pfcount('a') == len(members) @skip_if_server_version_lt('2.8.9') def test_pfcount(self, r): - members = set([b('1'), b('2'), b('3')]) + members = {b'1', b'2', b'3'} r.pfadd('a', *members) assert r.pfcount('a') == len(members) + members_b = {b'2', b'3', b'4'} + r.pfadd('b', *members_b) + assert r.pfcount('b') == len(members_b) + assert r.pfcount('a', 'b') == len(members_b.union(members)) @skip_if_server_version_lt('2.8.9') def test_pfmerge(self, r): - mema = set([b('1'), b('2'), b('3')]) - memb = set([b('2'), b('3'), b('4')]) - memc = set([b('5'), b('6'), b('7')]) + mema = {b'1', b'2', b'3'} + memb = {b'2', b'3', b'4'} + memc = {b'5', b'6', b'7'} r.pfadd('a', *mema) r.pfadd('b', *memb) r.pfadd('c', *memc) @@ -1123,17 +1318,17 @@ class TestRedisCommands(object): # HASH COMMANDS def test_hget_and_hset(self, r): r.hmset('a', {'1': 1, '2': 2, '3': 3}) - assert r.hget('a', '1') == b('1') - assert r.hget('a', '2') == b('2') - assert r.hget('a', '3') == b('3') + assert r.hget('a', '1') == b'1' + assert r.hget('a', '2') == b'2' + assert r.hget('a', '3') == b'3' # field was updated, redis returns 0 assert r.hset('a', '2', 5) == 0 - assert r.hget('a', '2') == b('5') + assert r.hget('a', '2') == b'5' # field is new, redis returns 1 assert r.hset('a', '4', 4) == 1 - assert r.hget('a', '4') == b('4') + assert r.hget('a', '4') == b'4' # key inside of hash that doesn't exist returns null value assert r.hget('a', 'b') is None @@ -1151,7 +1346,7 @@ class TestRedisCommands(object): assert not r.hexists('a', '4') def test_hgetall(self, r): - h = {b('a1'): b('1'), b('a2'): b('2'), b('a3'): b('3')} + h = {b'a1': b'1', b'a2': b'2', b'a3': b'3'} r.hmset('a', h) assert r.hgetall('a') == h @@ -1167,7 +1362,7 @@ class TestRedisCommands(object): assert r.hincrbyfloat('a', '1', 1.2) == 3.2 def test_hkeys(self, r): - h = {b('a1'): b('1'), b('a2'): b('2'), b('a3'): b('3')} + h = {b'a1': b'1', b'a2': b'2', b'a3': b'3'} r.hmset('a', h) local_keys = list(iterkeys(h)) remote_keys = r.hkeys('a') @@ -1179,49 +1374,55 @@ class TestRedisCommands(object): def test_hmget(self, r): assert r.hmset('a', {'a': 1, 'b': 2, 'c': 3}) - assert r.hmget('a', 'a', 'b', 'c') == [b('1'), b('2'), b('3')] + assert r.hmget('a', 'a', 'b', 'c') == [b'1', b'2', b'3'] def test_hmset(self, r): - h = {b('a'): b('1'), b('b'): b('2'), b('c'): b('3')} + h = {b'a': b'1', b'b': b'2', b'c': b'3'} assert r.hmset('a', h) assert r.hgetall('a') == h def test_hsetnx(self, r): # Initially set the hash field assert r.hsetnx('a', '1', 1) - assert r.hget('a', '1') == b('1') + assert r.hget('a', '1') == b'1' assert not r.hsetnx('a', '1', 2) - assert r.hget('a', '1') == b('1') + assert r.hget('a', '1') == b'1' def test_hvals(self, r): - h = {b('a1'): b('1'), b('a2'): b('2'), b('a3'): b('3')} + h = {b'a1': b'1', b'a2': b'2', b'a3': b'3'} r.hmset('a', h) local_vals = list(itervalues(h)) remote_vals = r.hvals('a') assert sorted(local_vals) == sorted(remote_vals) + @skip_if_server_version_lt('3.2.0') + def test_hstrlen(self, r): + r.hmset('a', {'1': '22', '2': '333'}) + assert r.hstrlen('a', '1') == 2 + assert r.hstrlen('a', '2') == 3 + # SORT def test_sort_basic(self, r): r.rpush('a', '3', '2', '1', '4') - assert r.sort('a') == [b('1'), b('2'), b('3'), b('4')] + assert r.sort('a') == [b'1', b'2', b'3', b'4'] def test_sort_limited(self, r): r.rpush('a', '3', '2', '1', '4') - assert r.sort('a', start=1, num=2) == [b('2'), b('3')] + assert r.sort('a', start=1, num=2) == [b'2', b'3'] def test_sort_by(self, r): r['score:1'] = 8 r['score:2'] = 3 r['score:3'] = 5 r.rpush('a', '3', '2', '1') - assert r.sort('a', by='score:*') == [b('2'), b('3'), b('1')] + assert r.sort('a', by='score:*') == [b'2', b'3', b'1'] def test_sort_get(self, r): r['user:1'] = 'u1' r['user:2'] = 'u2' r['user:3'] = 'u3' r.rpush('a', '2', '3', '1') - assert r.sort('a', get='user:*') == [b('u1'), b('u2'), b('u3')] + assert r.sort('a', get='user:*') == [b'u1', b'u2', b'u3'] def test_sort_get_multi(self, r): r['user:1'] = 'u1' @@ -1229,7 +1430,7 @@ class TestRedisCommands(object): r['user:3'] = 'u3' r.rpush('a', '2', '3', '1') assert r.sort('a', get=('user:*', '#')) == \ - [b('u1'), b('1'), b('u2'), b('2'), b('u3'), b('3')] + [b'u1', b'1', b'u2', b'2', b'u3', b'3'] def test_sort_get_groups_two(self, r): r['user:1'] = 'u1' @@ -1237,7 +1438,7 @@ class TestRedisCommands(object): r['user:3'] = 'u3' r.rpush('a', '2', '3', '1') assert r.sort('a', get=('user:*', '#'), groups=True) == \ - [(b('u1'), b('1')), (b('u2'), b('2')), (b('u3'), b('3'))] + [(b'u1', b'1'), (b'u2', b'2'), (b'u3', b'3')] def test_sort_groups_string_get(self, r): r['user:1'] = 'u1' @@ -1273,24 +1474,24 @@ class TestRedisCommands(object): r.rpush('a', '2', '3', '1') assert r.sort('a', get=('user:*', 'door:*', '#'), groups=True) == \ [ - (b('u1'), b('d1'), b('1')), - (b('u2'), b('d2'), b('2')), - (b('u3'), b('d3'), b('3')) - ] + (b'u1', b'd1', b'1'), + (b'u2', b'd2', b'2'), + (b'u3', b'd3', b'3') + ] def test_sort_desc(self, r): r.rpush('a', '2', '3', '1') - assert r.sort('a', desc=True) == [b('3'), b('2'), b('1')] + assert r.sort('a', desc=True) == [b'3', b'2', b'1'] def test_sort_alpha(self, r): r.rpush('a', 'e', 'c', 'b', 'd', 'a') assert r.sort('a', alpha=True) == \ - [b('a'), b('b'), b('c'), b('d'), b('e')] + [b'a', b'b', b'c', b'd', b'e'] def test_sort_store(self, r): r.rpush('a', '2', '3', '1') assert r.sort('a', store='sorted_values') == 3 - assert r.lrange('sorted_values', 0, -1) == [b('1'), b('2'), b('3')] + assert r.lrange('sorted_values', 0, -1) == [b'1', b'2', b'3'] def test_sort_all_options(self, r): r['user:1:username'] = 'zeus' @@ -1317,57 +1518,765 @@ class TestRedisCommands(object): store='sorted') assert num == 4 assert r.lrange('sorted', 0, 10) == \ - [b('vodka'), b('milk'), b('gin'), b('apple juice')] + [b'vodka', b'milk', b'gin', b'apple juice'] + def test_sort_issue_924(self, r): + # Tests for issue https://github.com/andymccurdy/redis-py/issues/924 + r.execute_command('SADD', 'issue#924', 1) + r.execute_command('SORT', 'issue#924') -class TestStrictCommands(object): + def test_cluster_addslots(self, mock_cluster_resp_ok): + assert mock_cluster_resp_ok.cluster('ADDSLOTS', 1) is True - def test_strict_zadd(self, sr): - sr.zadd('a', 1.0, 'a1', 2.0, 'a2', a3=3.0) - assert sr.zrange('a', 0, -1, withscores=True) == \ - [(b('a1'), 1.0), (b('a2'), 2.0), (b('a3'), 3.0)] + def test_cluster_count_failure_reports(self, mock_cluster_resp_int): + assert isinstance(mock_cluster_resp_int.cluster( + 'COUNT-FAILURE-REPORTS', 'node'), int) - def test_strict_lrem(self, sr): - sr.rpush('a', 'a1', 'a2', 'a3', 'a1') - sr.lrem('a', 0, 'a1') - assert sr.lrange('a', 0, -1) == [b('a2'), b('a3')] + def test_cluster_countkeysinslot(self, mock_cluster_resp_int): + assert isinstance(mock_cluster_resp_int.cluster( + 'COUNTKEYSINSLOT', 2), int) - def test_strict_setex(self, sr): - assert sr.setex('a', 60, '1') - assert sr['a'] == b('1') - assert 0 < sr.ttl('a') <= 60 + def test_cluster_delslots(self, mock_cluster_resp_ok): + assert mock_cluster_resp_ok.cluster('DELSLOTS', 1) is True - def test_strict_ttl(self, sr): - assert not sr.expire('a', 10) - sr['a'] = '1' - assert sr.expire('a', 10) - assert 0 < sr.ttl('a') <= 10 - assert sr.persist('a') - assert sr.ttl('a') == -1 + def test_cluster_failover(self, mock_cluster_resp_ok): + assert mock_cluster_resp_ok.cluster('FAILOVER', 1) is True - @skip_if_server_version_lt('2.6.0') - def test_strict_pttl(self, sr): - assert not sr.pexpire('a', 10000) - sr['a'] = '1' - assert sr.pexpire('a', 10000) - assert 0 < sr.pttl('a') <= 10000 - assert sr.persist('a') - assert sr.pttl('a') == -1 + def test_cluster_forget(self, mock_cluster_resp_ok): + assert mock_cluster_resp_ok.cluster('FORGET', 1) is True + + def test_cluster_info(self, mock_cluster_resp_info): + assert isinstance(mock_cluster_resp_info.cluster('info'), dict) + + def test_cluster_keyslot(self, mock_cluster_resp_int): + assert isinstance(mock_cluster_resp_int.cluster( + 'keyslot', 'asdf'), int) + + def test_cluster_meet(self, mock_cluster_resp_ok): + assert mock_cluster_resp_ok.cluster('meet', 'ip', 'port', 1) is True + + def test_cluster_nodes(self, mock_cluster_resp_nodes): + assert isinstance(mock_cluster_resp_nodes.cluster('nodes'), dict) + + def test_cluster_replicate(self, mock_cluster_resp_ok): + assert mock_cluster_resp_ok.cluster('replicate', 'nodeid') is True + + def test_cluster_reset(self, mock_cluster_resp_ok): + assert mock_cluster_resp_ok.cluster('reset', 'hard') is True + + def test_cluster_saveconfig(self, mock_cluster_resp_ok): + assert mock_cluster_resp_ok.cluster('saveconfig') is True + + def test_cluster_setslot(self, mock_cluster_resp_ok): + assert mock_cluster_resp_ok.cluster('setslot', 1, + 'IMPORTING', 'nodeid') is True + + def test_cluster_slaves(self, mock_cluster_resp_slaves): + assert isinstance(mock_cluster_resp_slaves.cluster( + 'slaves', 'nodeid'), dict) + + # GEO COMMANDS + @skip_if_server_version_lt('3.2.0') + def test_geoadd(self, r): + values = (2.1909389952632, 41.433791470673, 'place1') +\ + (2.1873744593677, 41.406342043777, 'place2') + + assert r.geoadd('barcelona', *values) == 2 + assert r.zcard('barcelona') == 2 + + @skip_if_server_version_lt('3.2.0') + def test_geoadd_invalid_params(self, r): + with pytest.raises(exceptions.RedisError): + r.geoadd('barcelona', *(1, 2)) + + @skip_if_server_version_lt('3.2.0') + def test_geodist(self, r): + values = (2.1909389952632, 41.433791470673, 'place1') +\ + (2.1873744593677, 41.406342043777, 'place2') + + assert r.geoadd('barcelona', *values) == 2 + assert r.geodist('barcelona', 'place1', 'place2') == 3067.4157 + + @skip_if_server_version_lt('3.2.0') + def test_geodist_units(self, r): + values = (2.1909389952632, 41.433791470673, 'place1') +\ + (2.1873744593677, 41.406342043777, 'place2') + + r.geoadd('barcelona', *values) + assert r.geodist('barcelona', 'place1', 'place2', 'km') == 3.0674 + + @skip_if_server_version_lt('3.2.0') + def test_geodist_missing_one_member(self, r): + values = (2.1909389952632, 41.433791470673, 'place1') + r.geoadd('barcelona', *values) + assert r.geodist('barcelona', 'place1', 'missing_member', 'km') is None + + @skip_if_server_version_lt('3.2.0') + def test_geodist_invalid_units(self, r): + with pytest.raises(exceptions.RedisError): + assert r.geodist('x', 'y', 'z', 'inches') + + @skip_if_server_version_lt('3.2.0') + def test_geohash(self, r): + values = (2.1909389952632, 41.433791470673, 'place1') +\ + (2.1873744593677, 41.406342043777, 'place2') + + r.geoadd('barcelona', *values) + assert r.geohash('barcelona', 'place1', 'place2') ==\ + ['sp3e9yg3kd0', 'sp3e9cbc3t0'] + + @skip_if_server_version_lt('3.2.0') + def test_geopos(self, r): + values = (2.1909389952632, 41.433791470673, 'place1') +\ + (2.1873744593677, 41.406342043777, 'place2') + + r.geoadd('barcelona', *values) + # redis uses 52 bits precision, hereby small errors may be introduced. + assert r.geopos('barcelona', 'place1', 'place2') ==\ + [(2.19093829393386841, 41.43379028184083523), + (2.18737632036209106, 41.40634178640635099)] + + @skip_if_server_version_lt('4.0.0') + def test_geopos_no_value(self, r): + assert r.geopos('barcelona', 'place1', 'place2') == [None, None] + + @skip_if_server_version_lt('3.2.0') + @skip_if_server_version_gte('4.0.0') + def test_old_geopos_no_value(self, r): + assert r.geopos('barcelona', 'place1', 'place2') == [] + + @skip_if_server_version_lt('3.2.0') + def test_georadius(self, r): + values = (2.1909389952632, 41.433791470673, 'place1') +\ + (2.1873744593677, 41.406342043777, 'place2') + + r.geoadd('barcelona', *values) + assert r.georadius('barcelona', 2.191, 41.433, 1000) == ['place1'] + + @skip_if_server_version_lt('3.2.0') + def test_georadius_no_values(self, r): + values = (2.1909389952632, 41.433791470673, 'place1') +\ + (2.1873744593677, 41.406342043777, 'place2') + + r.geoadd('barcelona', *values) + assert r.georadius('barcelona', 1, 2, 1000) == [] + + @skip_if_server_version_lt('3.2.0') + def test_georadius_units(self, r): + values = (2.1909389952632, 41.433791470673, 'place1') +\ + (2.1873744593677, 41.406342043777, 'place2') + + r.geoadd('barcelona', *values) + assert r.georadius('barcelona', 2.191, 41.433, 1, unit='km') ==\ + ['place1'] + + @skip_if_server_version_lt('3.2.0') + def test_georadius_with(self, r): + values = (2.1909389952632, 41.433791470673, 'place1') +\ + (2.1873744593677, 41.406342043777, 'place2') + + r.geoadd('barcelona', *values) + + # test a bunch of combinations to test the parse response + # function. + assert r.georadius('barcelona', 2.191, 41.433, 1, unit='km', + withdist=True, withcoord=True, withhash=True) ==\ + [['place1', 0.0881, 3471609698139488, + (2.19093829393386841, 41.43379028184083523)]] + + assert r.georadius('barcelona', 2.191, 41.433, 1, unit='km', + withdist=True, withcoord=True) ==\ + [['place1', 0.0881, + (2.19093829393386841, 41.43379028184083523)]] + + assert r.georadius('barcelona', 2.191, 41.433, 1, unit='km', + withhash=True, withcoord=True) ==\ + [['place1', 3471609698139488, + (2.19093829393386841, 41.43379028184083523)]] + + # test no values. + assert r.georadius('barcelona', 2, 1, 1, unit='km', + withdist=True, withcoord=True, withhash=True) == [] + + @skip_if_server_version_lt('3.2.0') + def test_georadius_count(self, r): + values = (2.1909389952632, 41.433791470673, 'place1') +\ + (2.1873744593677, 41.406342043777, 'place2') + + r.geoadd('barcelona', *values) + assert r.georadius('barcelona', 2.191, 41.433, 3000, count=1) ==\ + ['place1'] + + @skip_if_server_version_lt('3.2.0') + def test_georadius_sort(self, r): + values = (2.1909389952632, 41.433791470673, 'place1') +\ + (2.1873744593677, 41.406342043777, 'place2') + + r.geoadd('barcelona', *values) + assert r.georadius('barcelona', 2.191, 41.433, 3000, sort='ASC') ==\ + ['place1', 'place2'] + assert r.georadius('barcelona', 2.191, 41.433, 3000, sort='DESC') ==\ + ['place2', 'place1'] + + @skip_if_server_version_lt('3.2.0') + def test_georadius_store(self, r): + values = (2.1909389952632, 41.433791470673, 'place1') +\ + (2.1873744593677, 41.406342043777, 'place2') + + r.geoadd('barcelona', *values) + r.georadius('barcelona', 2.191, 41.433, 1000, store='places_barcelona') + assert r.zrange('places_barcelona', 0, -1) == [b'place1'] + + @skip_if_server_version_lt('3.2.0') + def test_georadius_store_dist(self, r): + values = (2.1909389952632, 41.433791470673, 'place1') +\ + (2.1873744593677, 41.406342043777, 'place2') + + r.geoadd('barcelona', *values) + r.georadius('barcelona', 2.191, 41.433, 1000, + store_dist='places_barcelona') + # instead of save the geo score, the distance is saved. + assert r.zscore('places_barcelona', 'place1') == 88.05060698409301 + + @skip_if_server_version_lt('3.2.0') + def test_georadiusmember(self, r): + values = (2.1909389952632, 41.433791470673, 'place1') +\ + (2.1873744593677, 41.406342043777, 'place2') + + r.geoadd('barcelona', *values) + assert r.georadiusbymember('barcelona', 'place1', 4000) ==\ + ['place2', 'place1'] + assert r.georadiusbymember('barcelona', 'place1', 10) == ['place1'] + + assert r.georadiusbymember('barcelona', 'place1', 4000, + withdist=True, withcoord=True, + withhash=True) ==\ + [['place2', 3067.4157, 3471609625421029, + (2.187376320362091, 41.40634178640635)], + ['place1', 0.0, 3471609698139488, + (2.1909382939338684, 41.433790281840835)]] + + @skip_if_server_version_lt('5.0.0') + def test_xack(self, r): + stream = 'stream' + group = 'group' + consumer = 'consumer' + # xack on a stream that doesn't exist + assert r.xack(stream, group, '0-0') == 0 + + m1 = r.xadd(stream, {'one': 'one'}) + m2 = r.xadd(stream, {'two': 'two'}) + m3 = r.xadd(stream, {'three': 'three'}) + + # xack on a group that doesn't exist + assert r.xack(stream, group, m1) == 0 + + r.xgroup_create(stream, group, 0) + r.xreadgroup(group, consumer, streams={stream: 0}) + # xack returns the number of ack'd elements + assert r.xack(stream, group, m1) == 1 + assert r.xack(stream, group, m2, m3) == 2 + + @skip_if_server_version_lt('5.0.0') + def test_xadd(self, r): + stream = 'stream' + message_id = r.xadd(stream, {'foo': 'bar'}) + assert re.match(br'[0-9]+\-[0-9]+', message_id) + + # explicit message id + message_id = b'9999999999999999999-0' + assert message_id == r.xadd(stream, {'foo': 'bar'}, id=message_id) + + # with maxlen, the list evicts the first message + r.xadd(stream, {'foo': 'bar'}, maxlen=2, approximate=False) + assert r.xlen(stream) == 2 + + @skip_if_server_version_lt('5.0.0') + def test_xclaim(self, r): + stream = 'stream' + group = 'group' + consumer1 = 'consumer1' + consumer2 = 'consumer2' + + message_id = r.xadd(stream, {'john': 'wick'}) + message = get_stream_message(r, stream, message_id) + r.xgroup_create(stream, group, 0) + + # trying to claim a message that isn't already pending doesn't + # do anything + response = r.xclaim(stream, group, consumer2, + min_idle_time=0, message_ids=(message_id,)) + assert response == [] + + # read the group as consumer1 to initially claim the messages + r.xreadgroup(group, consumer1, streams={stream: 0}) + + # claim the message as consumer2 + response = r.xclaim(stream, group, consumer2, + min_idle_time=0, message_ids=(message_id,)) + assert response[0] == message + + # reclaim the message as consumer1, but use the justid argument + # which only returns message ids + assert r.xclaim(stream, group, consumer1, + min_idle_time=0, message_ids=(message_id,), + justid=True) == [message_id] + + @skip_if_server_version_lt('5.0.0') + def test_xdel(self, r): + stream = 'stream' + + # deleting from an empty stream doesn't do anything + assert r.xdel(stream, 1) == 0 + + m1 = r.xadd(stream, {'foo': 'bar'}) + m2 = r.xadd(stream, {'foo': 'bar'}) + m3 = r.xadd(stream, {'foo': 'bar'}) + + # xdel returns the number of deleted elements + assert r.xdel(stream, m1) == 1 + assert r.xdel(stream, m2, m3) == 2 + + @skip_if_server_version_lt('5.0.0') + def test_xgroup_create(self, r): + # tests xgroup_create and xinfo_groups + stream = 'stream' + group = 'group' + r.xadd(stream, {'foo': 'bar'}) + + # no group is setup yet, no info to obtain + assert r.xinfo_groups(stream) == [] + + assert r.xgroup_create(stream, group, 0) + expected = [{ + 'name': group.encode(), + 'consumers': 0, + 'pending': 0, + 'last-delivered-id': b'0-0' + }] + assert r.xinfo_groups(stream) == expected + + @skip_if_server_version_lt('5.0.0') + def test_xgroup_create_mkstream(self, r): + # tests xgroup_create and xinfo_groups + stream = 'stream' + group = 'group' + + # an error is raised if a group is created on a stream that + # doesn't already exist + with pytest.raises(exceptions.ResponseError): + r.xgroup_create(stream, group, 0) + + # however, with mkstream=True, the underlying stream is created + # automatically + assert r.xgroup_create(stream, group, 0, mkstream=True) + expected = [{ + 'name': group.encode(), + 'consumers': 0, + 'pending': 0, + 'last-delivered-id': b'0-0' + }] + assert r.xinfo_groups(stream) == expected + + @skip_if_server_version_lt('5.0.0') + def test_xgroup_delconsumer(self, r): + stream = 'stream' + group = 'group' + consumer = 'consumer' + r.xadd(stream, {'foo': 'bar'}) + r.xadd(stream, {'foo': 'bar'}) + r.xgroup_create(stream, group, 0) + + # a consumer that hasn't yet read any messages doesn't do anything + assert r.xgroup_delconsumer(stream, group, consumer) == 0 + + # read all messages from the group + r.xreadgroup(group, consumer, streams={stream: 0}) + + # deleting the consumer should return 2 pending messages + assert r.xgroup_delconsumer(stream, group, consumer) == 2 + + @skip_if_server_version_lt('5.0.0') + def test_xgroup_destroy(self, r): + stream = 'stream' + group = 'group' + r.xadd(stream, {'foo': 'bar'}) + + # destroying a nonexistent group returns False + assert not r.xgroup_destroy(stream, group) + + r.xgroup_create(stream, group, 0) + assert r.xgroup_destroy(stream, group) + + @skip_if_server_version_lt('5.0.0') + def test_xgroup_setid(self, r): + stream = 'stream' + group = 'group' + message_id = r.xadd(stream, {'foo': 'bar'}) + + r.xgroup_create(stream, group, 0) + # advance the last_delivered_id to the message_id + r.xgroup_setid(stream, group, message_id) + expected = [{ + 'name': group.encode(), + 'consumers': 0, + 'pending': 0, + 'last-delivered-id': message_id + }] + assert r.xinfo_groups(stream) == expected + + @skip_if_server_version_lt('5.0.0') + def test_xinfo_consumers(self, r): + stream = 'stream' + group = 'group' + consumer1 = 'consumer1' + consumer2 = 'consumer2' + r.xadd(stream, {'foo': 'bar'}) + + r.xgroup_create(stream, group, 0) + r.xreadgroup(group, consumer1, streams={stream: 0}) + r.xreadgroup(group, consumer2, streams={stream: 0}) + info = r.xinfo_consumers(stream, group) + assert len(info) == 2 + expected = [ + {'name': consumer1.encode(), 'pending': 1}, + {'name': consumer2.encode(), 'pending': 0}, + ] + + # we can't determine the idle time, so just make sure it's an int + assert isinstance(info[0].pop('idle'), (int, long)) + assert isinstance(info[1].pop('idle'), (int, long)) + assert info == expected + + @skip_if_server_version_lt('5.0.0') + def test_xinfo_stream(self, r): + stream = 'stream' + m1 = r.xadd(stream, {'foo': 'bar'}) + m2 = r.xadd(stream, {'foo': 'bar'}) + info = r.xinfo_stream(stream) + + assert info['length'] == 2 + assert info['first-entry'] == get_stream_message(r, stream, m1) + assert info['last-entry'] == get_stream_message(r, stream, m2) + + @skip_if_server_version_lt('5.0.0') + def test_xlen(self, r): + stream = 'stream' + assert r.xlen(stream) == 0 + r.xadd(stream, {'foo': 'bar'}) + r.xadd(stream, {'foo': 'bar'}) + assert r.xlen(stream) == 2 + + @skip_if_server_version_lt('5.0.0') + def test_xpending(self, r): + stream = 'stream' + group = 'group' + consumer1 = 'consumer1' + consumer2 = 'consumer2' + m1 = r.xadd(stream, {'foo': 'bar'}) + m2 = r.xadd(stream, {'foo': 'bar'}) + r.xgroup_create(stream, group, 0) + + # xpending on a group that has no consumers yet + expected = { + 'pending': 0, + 'min': None, + 'max': None, + 'consumers': [] + } + assert r.xpending(stream, group) == expected + + # read 1 message from the group with each consumer + r.xreadgroup(group, consumer1, streams={stream: 0}, count=1) + r.xreadgroup(group, consumer2, streams={stream: m1}, count=1) + + expected = { + 'pending': 2, + 'min': m1, + 'max': m2, + 'consumers': [ + {'name': consumer1.encode(), 'pending': 1}, + {'name': consumer2.encode(), 'pending': 1}, + ] + } + assert r.xpending(stream, group) == expected + + @skip_if_server_version_lt('5.0.0') + def test_xpending_range(self, r): + stream = 'stream' + group = 'group' + consumer1 = 'consumer1' + consumer2 = 'consumer2' + m1 = r.xadd(stream, {'foo': 'bar'}) + m2 = r.xadd(stream, {'foo': 'bar'}) + r.xgroup_create(stream, group, 0) + + # xpending range on a group that has no consumers yet + assert r.xpending_range(stream, group) == [] + + # read 1 message from the group with each consumer + r.xreadgroup(group, consumer1, streams={stream: 0}, count=1) + r.xreadgroup(group, consumer2, streams={stream: m1}, count=1) + + response = r.xpending_range(stream, group) + assert len(response) == 2 + assert response[0]['message_id'] == m1 + assert response[0]['consumer'] == consumer1.encode() + assert response[1]['message_id'] == m2 + assert response[1]['consumer'] == consumer2.encode() + + @skip_if_server_version_lt('5.0.0') + def test_xrange(self, r): + stream = 'stream' + m1 = r.xadd(stream, {'foo': 'bar'}) + m2 = r.xadd(stream, {'foo': 'bar'}) + m3 = r.xadd(stream, {'foo': 'bar'}) + m4 = r.xadd(stream, {'foo': 'bar'}) + + def get_ids(results): + return [result[0] for result in results] + + results = r.xrange(stream, min=m1) + assert get_ids(results) == [m1, m2, m3, m4] + + results = r.xrange(stream, min=m2, max=m3) + assert get_ids(results) == [m2, m3] + + results = r.xrange(stream, max=m3) + assert get_ids(results) == [m1, m2, m3] + + results = r.xrange(stream, max=m2, count=1) + assert get_ids(results) == [m1] + + @skip_if_server_version_lt('5.0.0') + def test_xread(self, r): + stream = 'stream' + m1 = r.xadd(stream, {'foo': 'bar'}) + m2 = r.xadd(stream, {'bing': 'baz'}) + + expected = [ + [ + stream, + [ + get_stream_message(r, stream, m1), + get_stream_message(r, stream, m2), + ] + ] + ] + # xread starting at 0 returns both messages + assert r.xread(streams={stream: 0}) == expected + + expected = [ + [ + stream, + [ + get_stream_message(r, stream, m1), + ] + ] + ] + # xread starting at 0 and count=1 returns only the first message + assert r.xread(streams={stream: 0}, count=1) == expected + + expected = [ + [ + stream, + [ + get_stream_message(r, stream, m2), + ] + ] + ] + # xread starting at m1 returns only the second message + assert r.xread(streams={stream: m1}) == expected + + # xread starting at the last message returns an empty list + assert r.xread(streams={stream: m2}) == [] + + @skip_if_server_version_lt('5.0.0') + def test_xreadgroup(self, r): + stream = 'stream' + group = 'group' + consumer = 'consumer' + m1 = r.xadd(stream, {'foo': 'bar'}) + m2 = r.xadd(stream, {'bing': 'baz'}) + r.xgroup_create(stream, group, 0) + + expected = [ + [ + stream, + [ + get_stream_message(r, stream, m1), + get_stream_message(r, stream, m2), + ] + ] + ] + # xread starting at 0 returns both messages + assert r.xreadgroup(group, consumer, streams={stream: 0}) == expected + + r.xgroup_destroy(stream, group) + r.xgroup_create(stream, group, 0) + + expected = [ + [ + stream, + [ + get_stream_message(r, stream, m1), + ] + ] + ] + # xread starting at 0 and count=1 returns only the first message + assert r.xreadgroup(group, consumer, streams={stream: 0}, count=1) == \ + expected + + r.xgroup_destroy(stream, group) + r.xgroup_create(stream, group, 0) + + expected = [ + [ + stream, + [ + get_stream_message(r, stream, m2), + ] + ] + ] + # xread starting at m1 returns only the second message + assert r.xreadgroup(group, consumer, streams={stream: m1}) == expected + + r.xgroup_destroy(stream, group) + r.xgroup_create(stream, group, 0) + + # xread starting at the last message returns an empty message list + expected = [ + [ + stream, + [] + ] + ] + assert r.xreadgroup(group, consumer, streams={stream: m2}) == expected + + @skip_if_server_version_lt('5.0.0') + def test_xrevrange(self, r): + stream = 'stream' + m1 = r.xadd(stream, {'foo': 'bar'}) + m2 = r.xadd(stream, {'foo': 'bar'}) + m3 = r.xadd(stream, {'foo': 'bar'}) + m4 = r.xadd(stream, {'foo': 'bar'}) + + def get_ids(results): + return [result[0] for result in results] + + results = r.xrevrange(stream, max=m4) + assert get_ids(results) == [m4, m3, m2, m1] + + results = r.xrevrange(stream, max=m3, min=m2) + assert get_ids(results) == [m3, m2] + + results = r.xrevrange(stream, min=m3) + assert get_ids(results) == [m4, m3] + + results = r.xrevrange(stream, min=m2, count=1) + assert get_ids(results) == [m4] + + @skip_if_server_version_lt('5.0.0') + def test_xtrim(self, r): + stream = 'stream' + + # trimming an empty key doesn't do anything + assert r.xtrim(stream, 1000) == 0 + + r.xadd(stream, {'foo': 'bar'}) + r.xadd(stream, {'foo': 'bar'}) + r.xadd(stream, {'foo': 'bar'}) + r.xadd(stream, {'foo': 'bar'}) + + # trimming an amount large than the number of messages + # doesn't do anything + assert r.xtrim(stream, 5, approximate=False) == 0 + + # 1 message is trimmed + assert r.xtrim(stream, 3, approximate=False) == 1 + + def test_bitfield_operations(self, r): + # comments show affected bits + bf = r.bitfield('a') + resp = (bf + .set('u8', 8, 255) # 00000000 11111111 + .get('u8', 0) # 00000000 + .get('u4', 8) # 1111 + .get('u4', 12) # 1111 + .get('u4', 13) # 111 0 + .execute()) + assert resp == [0, 0, 15, 15, 14] + + # .set() returns the previous value... + resp = (bf + .set('u8', 4, 1) # 0000 0001 + .get('u16', 0) # 00000000 00011111 + .set('u16', 0, 0) # 00000000 00000000 + .execute()) + assert resp == [15, 31, 31] + + # incrby adds to the value + resp = (bf + .incrby('u8', 8, 254) # 00000000 11111110 + .incrby('u8', 8, 1) # 00000000 11111111 + .get('u16', 0) # 00000000 11111111 + .execute()) + assert resp == [254, 255, 255] + + # Verify overflow protection works as a method: + r.delete('a') + resp = (bf + .set('u8', 8, 254) # 00000000 11111110 + .overflow('fail') + .incrby('u8', 8, 2) # incrby 2 would overflow, None returned + .incrby('u8', 8, 1) # 00000000 11111111 + .incrby('u8', 8, 1) # incrby 1 would overflow, None returned + .get('u16', 0) # 00000000 11111111 + .execute()) + assert resp == [0, None, 255, None, 255] + + # Verify overflow protection works as arg to incrby: + r.delete('a') + resp = (bf + .set('u8', 8, 255) # 00000000 11111111 + .incrby('u8', 8, 1) # 00000000 00000000 wrap default + .set('u8', 8, 255) # 00000000 11111111 + .incrby('u8', 8, 1, 'FAIL') # 00000000 11111111 fail + .incrby('u8', 8, 1) # 00000000 11111111 still fail + .get('u16', 0) # 00000000 11111111 + .execute()) + assert resp == [0, 0, 0, None, None, 255] + + # test default default_overflow + r.delete('a') + bf = r.bitfield('a', default_overflow='FAIL') + resp = (bf + .set('u8', 8, 255) # 00000000 11111111 + .incrby('u8', 8, 1) # 00000000 11111111 fail default + .get('u16', 0) # 00000000 11111111 + .execute()) + assert resp == [0, None, 255] + + @skip_if_server_version_lt('4.0.0') + def test_memory_usage(self, r): + r.set('foo', 'bar') + assert isinstance(r.memory_usage('foo'), int) class TestBinarySave(object): + def test_binary_get_set(self, r): assert r.set(' foo bar ', '123') - assert r.get(' foo bar ') == b('123') + assert r.get(' foo bar ') == b'123' assert r.set(' foo\r\nbar\r\n ', '456') - assert r.get(' foo\r\nbar\r\n ') == b('456') + assert r.get(' foo\r\nbar\r\n ') == b'456' assert r.set(' \r\n\t\x07\x13 ', '789') - assert r.get(' \r\n\t\x07\x13 ') == b('789') + assert r.get(' \r\n\t\x07\x13 ') == b'789' assert sorted(r.keys('*')) == \ - [b(' \r\n\t\x07\x13 '), b(' foo\r\nbar\r\n '), b(' foo bar ')] + [b' \r\n\t\x07\x13 ', b' foo\r\nbar\r\n ', b' foo bar '] assert r.delete(' foo bar ') assert r.delete(' foo\r\nbar\r\n ') @@ -1375,9 +2284,9 @@ class TestBinarySave(object): def test_binary_lists(self, r): mapping = { - b('foo bar'): [b('1'), b('2'), b('3')], - b('foo\r\nbar\r\n'): [b('4'), b('5'), b('6')], - b('foo\tbar\x07'): [b('7'), b('8'), b('9')], + b'foo bar': [b'1', b'2', b'3'], + b'foo\r\nbar\r\n': [b'4', b'5', b'6'], + b'foo\tbar\x07': [b'7', b'8', b'9'], } # fill in lists for key, value in iteritems(mapping): @@ -1429,7 +2338,7 @@ class TestBinarySave(object): # load up 5MB of data into a key data = ''.join([ascii_letters] * (5000000 // len(ascii_letters))) r['a'] = data - assert r['a'] == b(data) + assert r['a'] == data.encode() def test_floating_point_encoding(self, r): """ @@ -1437,5 +2346,5 @@ class TestBinarySave(object): precision. """ timestamp = 1349673917.939762 - r.zadd('a', 'a1', timestamp) + r.zadd('a', {'a1': timestamp}) assert r.zscore('a', 'a1') == timestamp diff --git a/tests/test_connection_pool.py b/tests/test_connection_pool.py index 55ccce1..b0dec67 100644 --- a/tests/test_connection_pool.py +++ b/tests/test_connection_pool.py @@ -1,4 +1,3 @@ -from __future__ import with_statement import os import pytest import redis @@ -6,7 +5,7 @@ import time import re from threading import Thread -from redis.connection import ssl_available +from redis.connection import ssl_available, to_bool from .conftest import skip_if_server_version_lt @@ -163,6 +162,17 @@ class TestConnectionPoolURLParsing(object): 'password': None, } + def test_quoted_hostname(self): + pool = redis.ConnectionPool.from_url('redis://my %2F host %2B%3D+', + decode_components=True) + assert pool.connection_class == redis.Connection + assert pool.connection_kwargs == { + 'host': 'my / host +=+', + 'port': 6379, + 'db': 0, + 'password': None, + } + def test_port(self): pool = redis.ConnectionPool.from_url('redis://localhost:6380') assert pool.connection_class == redis.Connection @@ -183,6 +193,18 @@ class TestConnectionPoolURLParsing(object): 'password': 'mypassword', } + def test_quoted_password(self): + pool = redis.ConnectionPool.from_url( + 'redis://:%2Fmypass%2F%2B word%3D%24+@localhost', + decode_components=True) + assert pool.connection_class == redis.Connection + assert pool.connection_kwargs == { + 'host': 'localhost', + 'port': 6379, + 'db': 0, + 'password': '/mypass/+ word=$+', + } + def test_db_as_argument(self): pool = redis.ConnectionPool.from_url('redis://localhost', db='1') assert pool.connection_class == redis.Connection @@ -214,6 +236,52 @@ class TestConnectionPoolURLParsing(object): 'password': None, } + def test_extra_typed_querystring_options(self): + pool = redis.ConnectionPool.from_url( + 'redis://localhost/2?socket_timeout=20&socket_connect_timeout=10' + '&socket_keepalive=&retry_on_timeout=Yes&max_connections=10' + ) + + assert pool.connection_class == redis.Connection + assert pool.connection_kwargs == { + 'host': 'localhost', + 'port': 6379, + 'db': 2, + 'socket_timeout': 20.0, + 'socket_connect_timeout': 10.0, + 'retry_on_timeout': True, + 'password': None, + } + assert pool.max_connections == 10 + + def test_boolean_parsing(self): + for expected, value in ( + (None, None), + (None, ''), + (False, 0), (False, '0'), + (False, 'f'), (False, 'F'), (False, 'False'), + (False, 'n'), (False, 'N'), (False, 'No'), + (True, 1), (True, '1'), + (True, 'y'), (True, 'Y'), (True, 'Yes'), + ): + assert expected is to_bool(value) + + def test_invalid_extra_typed_querystring_options(self): + import warnings + with warnings.catch_warnings(record=True) as warning_log: + redis.ConnectionPool.from_url( + 'redis://localhost/2?socket_timeout=_&' + 'socket_connect_timeout=abc' + ) + # Compare the message values + assert [ + str(m.message) for m in + sorted(warning_log, key=lambda l: str(l.message)) + ] == [ + 'Invalid value for `socket_connect_timeout` in connection URL.', + 'Invalid value for `socket_timeout` in connection URL.', + ] + def test_extra_querystring_options(self): pool = redis.ConnectionPool.from_url('redis://localhost?a=1&b=2') assert pool.connection_class == redis.Connection @@ -231,7 +299,7 @@ class TestConnectionPoolURLParsing(object): assert isinstance(pool, redis.BlockingConnectionPool) def test_client_creates_connection_pool(self): - r = redis.StrictRedis.from_url('redis://myhost') + r = redis.Redis.from_url('redis://myhost') assert r.connection_pool.connection_class == redis.Connection assert r.connection_pool.connection_kwargs == { 'host': 'myhost', @@ -260,6 +328,28 @@ class TestConnectionPoolUnixSocketURLParsing(object): 'password': 'mypassword', } + def test_quoted_password(self): + pool = redis.ConnectionPool.from_url( + 'unix://:%2Fmypass%2F%2B word%3D%24+@/socket', + decode_components=True) + assert pool.connection_class == redis.UnixDomainSocketConnection + assert pool.connection_kwargs == { + 'path': '/socket', + 'db': 0, + 'password': '/mypass/+ word=$+', + } + + def test_quoted_path(self): + pool = redis.ConnectionPool.from_url( + 'unix://:mypassword@/my%2Fpath%2Fto%2F..%2F+_%2B%3D%24ocket', + decode_components=True) + assert pool.connection_class == redis.UnixDomainSocketConnection + assert pool.connection_kwargs == { + 'path': '/my/path/to/../+_+=$ocket', + 'db': 0, + 'password': 'mypassword', + } + def test_db_as_argument(self): pool = redis.ConnectionPool.from_url('unix:///socket', db=1) assert pool.connection_class == redis.UnixDomainSocketConnection diff --git a/tests/test_encoding.py b/tests/test_encoding.py index b1df0a5..283fc6e 100644 --- a/tests/test_encoding.py +++ b/tests/test_encoding.py @@ -1,24 +1,25 @@ -from __future__ import with_statement +from __future__ import unicode_literals import pytest +import redis -from redis._compat import unichr, u, unicode -from .conftest import r as _redis_client +from redis._compat import unichr, unicode +from .conftest import _get_client class TestEncoding(object): @pytest.fixture() def r(self, request): - return _redis_client(request=request, decode_responses=True) + return _get_client(redis.Redis, request=request, decode_responses=True) def test_simple_encoding(self, r): - unicode_string = unichr(3456) + u('abcd') + unichr(3421) + unicode_string = unichr(3456) + 'abcd' + unichr(3421) r['unicode-string'] = unicode_string cached_val = r['unicode-string'] assert isinstance(cached_val, unicode) assert unicode_string == cached_val def test_list_encoding(self, r): - unicode_string = unichr(3456) + u('abcd') + unichr(3421) + unicode_string = unichr(3456) + 'abcd' + unichr(3421) result = [unicode_string, unicode_string, unicode_string] r.rpush('a', *result) assert r.lrange('a', 0, -1) == result @@ -27,7 +28,28 @@ class TestEncoding(object): class TestCommandsAndTokensArentEncoded(object): @pytest.fixture() def r(self, request): - return _redis_client(request=request, charset='utf-16') + return _get_client(redis.Redis, request=request, encoding='utf-16') def test_basic_command(self, r): r.set('hello', 'world') + + +class TestInvalidUserInput(object): + def test_boolean_fails(self, r): + with pytest.raises(redis.DataError): + r.set('a', True) + + def test_none_fails(self, r): + with pytest.raises(redis.DataError): + r.set('a', None) + + def test_user_type_fails(self, r): + class Foo(object): + def __str__(self): + return 'Foo' + + def __unicode__(self): + return 'Foo' + + with pytest.raises(redis.DataError): + r.set('a', Foo()) diff --git a/tests/test_lock.py b/tests/test_lock.py index d732ae1..a6adbc2 100644 --- a/tests/test_lock.py +++ b/tests/test_lock.py @@ -1,29 +1,34 @@ -from __future__ import with_statement import pytest import time -from redis.exceptions import LockError, ResponseError -from redis.lock import Lock, LuaLock +from redis.exceptions import LockError +from redis.lock import Lock class TestLock(object): - lock_class = Lock - def get_lock(self, redis, *args, **kwargs): - kwargs['lock_class'] = self.lock_class + kwargs['lock_class'] = Lock return redis.lock(*args, **kwargs) - def test_lock(self, sr): - lock = self.get_lock(sr, 'foo') + def test_lock(self, r): + lock = self.get_lock(r, 'foo') assert lock.acquire(blocking=False) - assert sr.get('foo') == lock.local.token - assert sr.ttl('foo') == -1 + assert r.get('foo') == lock.local.token + assert r.ttl('foo') == -1 lock.release() - assert sr.get('foo') is None + assert r.get('foo') is None - def test_competing_locks(self, sr): - lock1 = self.get_lock(sr, 'foo') - lock2 = self.get_lock(sr, 'foo') + def test_locked(self, r): + lock = self.get_lock(r, 'foo') + assert lock.locked() is False + lock.acquire(blocking=False) + assert lock.locked() is True + lock.release() + assert lock.locked() is False + + def test_competing_locks(self, r): + lock1 = self.get_lock(r, 'foo') + lock2 = self.get_lock(r, 'foo') assert lock1.acquire(blocking=False) assert not lock2.acquire(blocking=False) lock1.release() @@ -31,137 +36,101 @@ class TestLock(object): assert not lock1.acquire(blocking=False) lock2.release() - def test_timeout(self, sr): - lock = self.get_lock(sr, 'foo', timeout=10) + def test_timeout(self, r): + lock = self.get_lock(r, 'foo', timeout=10) assert lock.acquire(blocking=False) - assert 8 < sr.ttl('foo') <= 10 + assert 8 < r.ttl('foo') <= 10 lock.release() - def test_float_timeout(self, sr): - lock = self.get_lock(sr, 'foo', timeout=9.5) + def test_float_timeout(self, r): + lock = self.get_lock(r, 'foo', timeout=9.5) assert lock.acquire(blocking=False) - assert 8 < sr.pttl('foo') <= 9500 + assert 8 < r.pttl('foo') <= 9500 lock.release() - def test_blocking_timeout(self, sr): - lock1 = self.get_lock(sr, 'foo') + def test_blocking_timeout(self, r): + lock1 = self.get_lock(r, 'foo') assert lock1.acquire(blocking=False) - lock2 = self.get_lock(sr, 'foo', blocking_timeout=0.2) + lock2 = self.get_lock(r, 'foo', blocking_timeout=0.2) start = time.time() assert not lock2.acquire() assert (time.time() - start) > 0.2 lock1.release() - def test_context_manager(self, sr): + def test_context_manager(self, r): # blocking_timeout prevents a deadlock if the lock can't be acquired # for some reason - with self.get_lock(sr, 'foo', blocking_timeout=0.2) as lock: - assert sr.get('foo') == lock.local.token - assert sr.get('foo') is None + with self.get_lock(r, 'foo', blocking_timeout=0.2) as lock: + assert r.get('foo') == lock.local.token + assert r.get('foo') is None - def test_high_sleep_raises_error(self, sr): + def test_context_manager_raises_when_locked_not_acquired(self, r): + r.set('foo', 'bar') + with pytest.raises(LockError): + with self.get_lock(r, 'foo', blocking_timeout=0.1): + pass + + def test_high_sleep_raises_error(self, r): "If sleep is higher than timeout, it should raise an error" with pytest.raises(LockError): - self.get_lock(sr, 'foo', timeout=1, sleep=2) + self.get_lock(r, 'foo', timeout=1, sleep=2) - def test_releasing_unlocked_lock_raises_error(self, sr): - lock = self.get_lock(sr, 'foo') + def test_releasing_unlocked_lock_raises_error(self, r): + lock = self.get_lock(r, 'foo') with pytest.raises(LockError): lock.release() - def test_releasing_lock_no_longer_owned_raises_error(self, sr): - lock = self.get_lock(sr, 'foo') + def test_releasing_lock_no_longer_owned_raises_error(self, r): + lock = self.get_lock(r, 'foo') lock.acquire(blocking=False) # manually change the token - sr.set('foo', 'a') + r.set('foo', 'a') with pytest.raises(LockError): lock.release() # even though we errored, the token is still cleared assert lock.local.token is None - def test_extend_lock(self, sr): - lock = self.get_lock(sr, 'foo', timeout=10) + def test_extend_lock(self, r): + lock = self.get_lock(r, 'foo', timeout=10) assert lock.acquire(blocking=False) - assert 8000 < sr.pttl('foo') <= 10000 + assert 8000 < r.pttl('foo') <= 10000 assert lock.extend(10) - assert 16000 < sr.pttl('foo') <= 20000 + assert 16000 < r.pttl('foo') <= 20000 lock.release() - def test_extend_lock_float(self, sr): - lock = self.get_lock(sr, 'foo', timeout=10.0) + def test_extend_lock_float(self, r): + lock = self.get_lock(r, 'foo', timeout=10.0) assert lock.acquire(blocking=False) - assert 8000 < sr.pttl('foo') <= 10000 + assert 8000 < r.pttl('foo') <= 10000 assert lock.extend(10.0) - assert 16000 < sr.pttl('foo') <= 20000 + assert 16000 < r.pttl('foo') <= 20000 lock.release() - def test_extending_unlocked_lock_raises_error(self, sr): - lock = self.get_lock(sr, 'foo', timeout=10) + def test_extending_unlocked_lock_raises_error(self, r): + lock = self.get_lock(r, 'foo', timeout=10) with pytest.raises(LockError): lock.extend(10) - def test_extending_lock_with_no_timeout_raises_error(self, sr): - lock = self.get_lock(sr, 'foo') + def test_extending_lock_with_no_timeout_raises_error(self, r): + lock = self.get_lock(r, 'foo') assert lock.acquire(blocking=False) with pytest.raises(LockError): lock.extend(10) lock.release() - def test_extending_lock_no_longer_owned_raises_error(self, sr): - lock = self.get_lock(sr, 'foo') + def test_extending_lock_no_longer_owned_raises_error(self, r): + lock = self.get_lock(r, 'foo') assert lock.acquire(blocking=False) - sr.set('foo', 'a') + r.set('foo', 'a') with pytest.raises(LockError): lock.extend(10) -class TestLuaLock(TestLock): - lock_class = LuaLock - - class TestLockClassSelection(object): - def test_lock_class_argument(self, sr): - lock = sr.lock('foo', lock_class=Lock) - assert type(lock) == Lock - lock = sr.lock('foo', lock_class=LuaLock) - assert type(lock) == LuaLock - - def test_cached_lualock_flag(self, sr): - try: - sr._use_lua_lock = True - lock = sr.lock('foo') - assert type(lock) == LuaLock - finally: - sr._use_lua_lock = None - - def test_cached_lock_flag(self, sr): - try: - sr._use_lua_lock = False - lock = sr.lock('foo') - assert type(lock) == Lock - finally: - sr._use_lua_lock = None - - def test_lua_compatible_server(self, sr, monkeypatch): - @classmethod - def mock_register(cls, redis): - return - monkeypatch.setattr(LuaLock, 'register_scripts', mock_register) - try: - lock = sr.lock('foo') - assert type(lock) == LuaLock - assert sr._use_lua_lock is True - finally: - sr._use_lua_lock = None - - def test_lua_unavailable(self, sr, monkeypatch): - @classmethod - def mock_register(cls, redis): - raise ResponseError() - monkeypatch.setattr(LuaLock, 'register_scripts', mock_register) - try: - lock = sr.lock('foo') - assert type(lock) == Lock - assert sr._use_lua_lock is False - finally: - sr._use_lua_lock = None + def test_lock_class_argument(self, r): + class MyLock(object): + def __init__(self, *args, **kwargs): + + pass + lock = r.lock('foo', lock_class=MyLock) + assert type(lock) == MyLock diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 46fc994..2e2507a 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -1,23 +1,27 @@ -from __future__ import with_statement +from __future__ import unicode_literals import pytest import redis -from redis._compat import b, u, unichr, unicode +from redis._compat import unichr, unicode class TestPipeline(object): def test_pipeline(self, r): with r.pipeline() as pipe: - pipe.set('a', 'a1').get('a').zadd('z', z1=1).zadd('z', z2=4) - pipe.zincrby('z', 'z1').zrange('z', 0, 5, withscores=True) + (pipe.set('a', 'a1') + .get('a') + .zadd('z', {'z1': 1}) + .zadd('z', {'z2': 4}) + .zincrby('z', 1, 'z1') + .zrange('z', 0, 5, withscores=True)) assert pipe.execute() == \ [ True, - b('a1'), + b'a1', True, True, 2.0, - [(b('z1'), 2.0), (b('z2'), 4)], + [(b'z1', 2.0), (b'z2', 4)], ] def test_pipeline_length(self, r): @@ -40,9 +44,9 @@ class TestPipeline(object): with r.pipeline(transaction=False) as pipe: pipe.set('a', 'a1').set('b', 'b1').set('c', 'c1') assert pipe.execute() == [True, True, True] - assert r['a'] == b('a1') - assert r['b'] == b('b1') - assert r['c'] == b('c1') + assert r['a'] == b'a1' + assert r['b'] == b'b1' + assert r['c'] == b'c1' def test_pipeline_no_transaction_watch(self, r): r['a'] = 0 @@ -70,7 +74,7 @@ class TestPipeline(object): with pytest.raises(redis.WatchError): pipe.execute() - assert r['a'] == b('bad') + assert r['a'] == b'bad' def test_exec_error_in_response(self, r): """ @@ -83,23 +87,23 @@ class TestPipeline(object): result = pipe.execute(raise_on_error=False) assert result[0] - assert r['a'] == b('1') + assert r['a'] == b'1' assert result[1] - assert r['b'] == b('2') + assert r['b'] == b'2' # we can't lpush to a key that's a string value, so this should # be a ResponseError exception assert isinstance(result[2], redis.ResponseError) - assert r['c'] == b('a') + assert r['c'] == b'a' # since this isn't a transaction, the other commands after the # error are still executed assert result[3] - assert r['d'] == b('4') + assert r['d'] == b'4' # make sure the pipe was restored to a working state assert pipe.set('z', 'zzz').execute() == [True] - assert r['z'] == b('zzz') + assert r['z'] == b'zzz' def test_exec_error_raised(self, r): r['c'] = 'a' @@ -112,7 +116,35 @@ class TestPipeline(object): # make sure the pipe was restored to a working state assert pipe.set('z', 'zzz').execute() == [True] - assert r['z'] == b('zzz') + assert r['z'] == b'zzz' + + def test_transaction_with_empty_error_command(self, r): + """ + Commands with custom EMPTY_ERROR functionality return their default + values in the pipeline no matter the raise_on_error preference + """ + for error_switch in (True, False): + with r.pipeline() as pipe: + pipe.set('a', 1).mget([]).set('c', 3) + result = pipe.execute(raise_on_error=error_switch) + + assert result[0] + assert result[1] == [] + assert result[2] + + def test_pipeline_with_empty_error_command(self, r): + """ + Commands with custom EMPTY_ERROR functionality return their default + values in the pipeline no matter the raise_on_error preference + """ + for error_switch in (True, False): + with r.pipeline(transaction=False) as pipe: + pipe.set('a', 1).mget([]).set('c', 3) + result = pipe.execute(raise_on_error=error_switch) + + assert result[0] + assert result[1] == [] + assert result[2] def test_parse_error_raised(self, r): with r.pipeline() as pipe: @@ -126,7 +158,7 @@ class TestPipeline(object): # make sure the pipe was restored to a working state assert pipe.set('z', 'zzz').execute() == [True] - assert r['z'] == b('zzz') + assert r['z'] == b'zzz' def test_watch_succeed(self, r): r['a'] = 1 @@ -137,8 +169,8 @@ class TestPipeline(object): assert pipe.watching a_value = pipe.get('a') b_value = pipe.get('b') - assert a_value == b('1') - assert b_value == b('2') + assert a_value == b'1' + assert b_value == b'2' pipe.multi() pipe.set('c', 3) @@ -169,7 +201,7 @@ class TestPipeline(object): pipe.unwatch() assert not pipe.watching pipe.get('a') - assert pipe.execute() == [b('1')] + assert pipe.execute() == [b'1'] def test_transaction_callable(self, r): r['a'] = 1 @@ -178,9 +210,9 @@ class TestPipeline(object): def my_transaction(pipe): a_value = pipe.get('a') - assert a_value in (b('1'), b('2')) + assert a_value in (b'1', b'2') b_value = pipe.get('b') - assert b_value == b('2') + assert b_value == b'2' # silly run-once code... incr's "a" so WatchError should be raised # forcing this all to run again. this should incr "a" once to "2" @@ -193,7 +225,7 @@ class TestPipeline(object): result = r.transaction(my_transaction, 'a', 'b') assert result == [True] - assert r['c'] == b('4') + assert r['c'] == b'4' def test_exec_error_in_no_transaction_pipeline(self, r): r['a'] = 1 @@ -207,10 +239,10 @@ class TestPipeline(object): assert unicode(ex.value).startswith('Command # 1 (LLEN a) of ' 'pipeline caused error: ') - assert r['a'] == b('1') + assert r['a'] == b'1' def test_exec_error_in_no_transaction_pipeline_unicode_command(self, r): - key = unichr(3456) + u('abcd') + unichr(3421) + key = unichr(3456) + 'abcd' + unichr(3421) r[key] = 1 with r.pipeline(transaction=False) as pipe: pipe.llen(key) @@ -223,4 +255,21 @@ class TestPipeline(object): 'error: ') % key assert unicode(ex.value).startswith(expected) - assert r[key] == b('1') + assert r[key] == b'1' + + def test_pipeline_with_bitfield(self, r): + with r.pipeline() as pipe: + pipe.set('a', '1') + bf = pipe.bitfield('b') + pipe2 = (bf + .set('u8', 8, 255) + .get('u8', 0) + .get('u4', 8) # 1111 + .get('u4', 12) # 1111 + .get('u4', 13) # 1110 + .execute()) + pipe.get('a') + response = pipe.execute() + + assert pipe == pipe2 + assert response == [True, [0, 0, 15, 15, 14], b'1'] diff --git a/tests/test_pubsub.py b/tests/test_pubsub.py index 5486b75..91e9e48 100644 --- a/tests/test_pubsub.py +++ b/tests/test_pubsub.py @@ -1,12 +1,13 @@ -from __future__ import with_statement +from __future__ import unicode_literals import pytest import time import redis from redis.exceptions import ConnectionError -from redis._compat import basestring, u, unichr +from redis._compat import basestring, unichr -from .conftest import r as _redis_client +from .conftest import _get_client +from .conftest import skip_if_server_version_lt def wait_for_message(pubsub, timeout=0.1, ignore_subscribe_messages=False): @@ -26,7 +27,7 @@ def make_message(type, channel, data, pattern=None): return { 'type': type, 'pattern': pattern and pattern.encode('utf-8') or None, - 'channel': channel.encode('utf-8'), + 'channel': channel and channel.encode('utf-8') or None, 'data': data.encode('utf-8') if isinstance(data, basestring) else data } @@ -39,7 +40,7 @@ def make_subscribe_test_data(pubsub, type): 'unsub_type': 'unsubscribe', 'sub_func': pubsub.subscribe, 'unsub_func': pubsub.unsubscribe, - 'keys': ['foo', 'bar', u('uni') + unichr(4456) + u('code')] + 'keys': ['foo', 'bar', 'uni' + unichr(4456) + 'code'] } elif type == 'pattern': return { @@ -48,7 +49,7 @@ def make_subscribe_test_data(pubsub, type): 'unsub_type': 'punsubscribe', 'sub_func': pubsub.psubscribe, 'unsub_func': pubsub.punsubscribe, - 'keys': ['f*', 'b*', u('uni') + unichr(4456) + u('*')] + 'keys': ['f*', 'b*', 'uni' + unichr(4456) + '*'] } assert False, 'invalid subscribe type: %s' % type @@ -266,7 +267,7 @@ class TestPubSubMessages(object): def test_unicode_channel_message_handler(self, r): p = r.pubsub(ignore_subscribe_messages=True) - channel = u('uni') + unichr(4456) + u('code') + channel = 'uni' + unichr(4456) + 'code' channels = {channel: self.message_handler} p.subscribe(**channels) assert r.publish(channel, 'test message') == 1 @@ -275,21 +276,29 @@ class TestPubSubMessages(object): def test_unicode_pattern_message_handler(self, r): p = r.pubsub(ignore_subscribe_messages=True) - pattern = u('uni') + unichr(4456) + u('*') - channel = u('uni') + unichr(4456) + u('code') + pattern = 'uni' + unichr(4456) + '*' + channel = 'uni' + unichr(4456) + 'code' p.psubscribe(**{pattern: self.message_handler}) assert r.publish(channel, 'test message') == 1 assert wait_for_message(p) is None assert self.message == make_message('pmessage', channel, 'test message', pattern=pattern) + def test_get_message_without_subscribe(self, r): + p = r.pubsub() + with pytest.raises(RuntimeError) as info: + p.get_message() + expect = ('connection not set: ' + 'did you forget to call subscribe() or psubscribe()?') + assert expect in info.exconly() + class TestPubSubAutoDecoding(object): "These tests only validate that we get unicode values back" - channel = u('uni') + unichr(4456) + u('code') - pattern = u('uni') + unichr(4456) + u('*') - data = u('abc') + unichr(4458) + u('123') + channel = 'uni' + unichr(4456) + 'code' + pattern = 'uni' + unichr(4456) + '*' + data = 'abc' + unichr(4458) + '123' def make_message(self, type, channel, data, pattern=None): return { @@ -307,7 +316,7 @@ class TestPubSubAutoDecoding(object): @pytest.fixture() def r(self, request): - return _redis_client(request=request, decode_responses=True) + return _get_client(redis.Redis, request=request, decode_responses=True) def test_channel_subscribe_unsubscribe(self, r): p = r.pubsub() @@ -357,7 +366,7 @@ class TestPubSubAutoDecoding(object): # test that we reconnected to the correct channel p.connection.disconnect() assert wait_for_message(p) is None # should reconnect - new_data = self.data + u('new data') + new_data = self.data + 'new data' r.publish(self.channel, new_data) assert wait_for_message(p) is None assert self.message == self.make_message('message', self.channel, @@ -375,7 +384,7 @@ class TestPubSubAutoDecoding(object): # test that we reconnected to the correct pattern p.connection.disconnect() assert wait_for_message(p) is None # should reconnect - new_data = self.data + u('new data') + new_data = self.data + 'new data' r.publish(self.channel, new_data) assert wait_for_message(p) is None assert self.message == self.make_message('pmessage', self.channel, @@ -390,3 +399,52 @@ class TestPubSubRedisDown(object): p = r.pubsub() with pytest.raises(ConnectionError): p.subscribe('foo') + + +class TestPubSubPubSubSubcommands(object): + + @skip_if_server_version_lt('2.8.0') + def test_pubsub_channels(self, r): + p = r.pubsub(ignore_subscribe_messages=True) + p.subscribe('foo', 'bar', 'baz', 'quux') + channels = sorted(r.pubsub_channels()) + assert channels == [b'bar', b'baz', b'foo', b'quux'] + + @skip_if_server_version_lt('2.8.0') + def test_pubsub_numsub(self, r): + p1 = r.pubsub(ignore_subscribe_messages=True) + p1.subscribe('foo', 'bar', 'baz') + p2 = r.pubsub(ignore_subscribe_messages=True) + p2.subscribe('bar', 'baz') + p3 = r.pubsub(ignore_subscribe_messages=True) + p3.subscribe('baz') + + channels = [(b'foo', 1), (b'bar', 2), (b'baz', 3)] + assert channels == r.pubsub_numsub('foo', 'bar', 'baz') + + @skip_if_server_version_lt('2.8.0') + def test_pubsub_numpat(self, r): + p = r.pubsub(ignore_subscribe_messages=True) + p.psubscribe('*oo', '*ar', 'b*z') + assert r.pubsub_numpat() == 3 + + +class TestPubSubPings(object): + + @skip_if_server_version_lt('3.0.0') + def test_send_pubsub_ping(self, r): + p = r.pubsub(ignore_subscribe_messages=True) + p.subscribe('foo') + p.ping() + assert wait_for_message(p) == make_message(type='pong', channel=None, + data='', + pattern=None) + + @skip_if_server_version_lt('3.0.0') + def test_send_pubsub_ping_message(self, r): + p = r.pubsub(ignore_subscribe_messages=True) + p.subscribe('foo') + p.ping(message='hello world') + assert wait_for_message(p) == make_message(type='pong', channel=None, + data='hello world', + pattern=None) diff --git a/tests/test_scripting.py b/tests/test_scripting.py index 2213ec6..b3d52a5 100644 --- a/tests/test_scripting.py +++ b/tests/test_scripting.py @@ -1,8 +1,7 @@ -from __future__ import with_statement +from __future__ import unicode_literals import pytest from redis import exceptions -from redis._compat import b multiply_script = """ @@ -57,44 +56,49 @@ class TestScripting(object): def test_script_object(self, r): r.set('a', 2) multiply = r.register_script(multiply_script) - assert not multiply.sha - # test evalsha fail -> script load + retry + precalculated_sha = multiply.sha + assert precalculated_sha + assert r.script_exists(multiply.sha) == [False] + # Test second evalsha block (after NoScriptError) assert multiply(keys=['a'], args=[3]) == 6 - assert multiply.sha + # At this point, the script should be loaded assert r.script_exists(multiply.sha) == [True] - # test first evalsha + # Test that the precalculated sha matches the one from redis + assert multiply.sha == precalculated_sha + # Test first evalsha block assert multiply(keys=['a'], args=[3]) == 6 def test_script_object_in_pipeline(self, r): multiply = r.register_script(multiply_script) - assert not multiply.sha + precalculated_sha = multiply.sha + assert precalculated_sha pipe = r.pipeline() pipe.set('a', 2) pipe.get('a') multiply(keys=['a'], args=[3], client=pipe) - # even though the pipeline wasn't executed yet, we made sure the - # script was loaded and got a valid sha - assert multiply.sha - assert r.script_exists(multiply.sha) == [True] + assert r.script_exists(multiply.sha) == [False] # [SET worked, GET 'a', result of multiple script] - assert pipe.execute() == [True, b('2'), 6] + assert pipe.execute() == [True, b'2', 6] + # The script should have been loaded by pipe.execute() + assert r.script_exists(multiply.sha) == [True] + # The precalculated sha should have been the correct one + assert multiply.sha == precalculated_sha # purge the script from redis's cache and re-run the pipeline - # the multiply script object knows it's sha, so it shouldn't get - # reloaded until pipe.execute() + # the multiply script should be reloaded by pipe.execute() r.script_flush() pipe = r.pipeline() pipe.set('a', 2) pipe.get('a') - assert multiply.sha multiply(keys=['a'], args=[3], client=pipe) assert r.script_exists(multiply.sha) == [False] # [SET worked, GET 'a', result of multiple script] - assert pipe.execute() == [True, b('2'), 6] + assert pipe.execute() == [True, b'2', 6] + assert r.script_exists(multiply.sha) == [True] def test_eval_msgpack_pipeline_error_in_lua(self, r): msgpack_hello = r.register_script(msgpack_hello_script) - assert not msgpack_hello.sha + assert msgpack_hello.sha pipe = r.pipeline() @@ -104,8 +108,9 @@ class TestScripting(object): msgpack_hello(args=[msgpack_message_1], client=pipe) - assert r.script_exists(msgpack_hello.sha) == [True] + assert r.script_exists(msgpack_hello.sha) == [False] assert pipe.execute()[0] == b'hello Joe' + assert r.script_exists(msgpack_hello.sha) == [True] msgpack_hello_broken = r.register_script(msgpack_hello_script_broken) diff --git a/tests/test_sentinel.py b/tests/test_sentinel.py index 0a6e98b..1081e2b 100644 --- a/tests/test_sentinel.py +++ b/tests/test_sentinel.py @@ -1,4 +1,3 @@ -from __future__ import with_statement import pytest from redis import exceptions @@ -15,10 +14,12 @@ class SentinelTestClient(object): def sentinel_masters(self): self.cluster.connection_error_if_down(self) + self.cluster.timeout_if_down(self) return {self.cluster.service_name: self.cluster.master} def sentinel_slaves(self, master_name): self.cluster.connection_error_if_down(self) + self.cluster.timeout_if_down(self) if master_name != self.cluster.service_name: return [] return self.cluster.slaves @@ -38,11 +39,16 @@ class SentinelTestCluster(object): self.service_name = service_name self.slaves = [] self.nodes_down = set() + self.nodes_timeout = set() def connection_error_if_down(self, node): if node.id in self.nodes_down: raise exceptions.ConnectionError + def timeout_if_down(self, node): + if node.id in self.nodes_timeout: + raise exceptions.TimeoutError + def client(self, host, port, **kwargs): return SentinelTestClient(self, (host, port)) @@ -50,10 +56,10 @@ class SentinelTestCluster(object): @pytest.fixture() def cluster(request): def teardown(): - redis.sentinel.StrictRedis = saved_StrictRedis + redis.sentinel.Redis = saved_Redis cluster = SentinelTestCluster() - saved_StrictRedis = redis.sentinel.StrictRedis - redis.sentinel.StrictRedis = cluster.client + saved_Redis = redis.sentinel.Redis + redis.sentinel.Redis = cluster.client request.addfinalizer(teardown) return cluster @@ -82,6 +88,15 @@ def test_discover_master_sentinel_down(cluster, sentinel): assert sentinel.sentinels[0].id == ('bar', 26379) +def test_discover_master_sentinel_timeout(cluster, sentinel): + # Put first sentinel 'foo' down + cluster.nodes_timeout.add(('foo', 26379)) + address = sentinel.discover_master('mymaster') + assert address == ('127.0.0.1', 6379) + # 'bar' is now first sentinel + assert sentinel.sentinels[0].id == ('bar', 26379) + + def test_master_min_other_sentinels(cluster): sentinel = Sentinel([('foo', 26379)], min_other_sentinels=1) # min_other_sentinels @@ -130,6 +145,12 @@ def test_discover_slaves(cluster, sentinel): cluster.nodes_down.add(('foo', 26379)) assert sentinel.discover_slaves('mymaster') == [ ('slave0', 1234), ('slave1', 1234)] + cluster.nodes_down.clear() + + # node0 -> TIMEOUT + cluster.nodes_timeout.add(('foo', 26379)) + assert sentinel.discover_slaves('mymaster') == [ + ('slave0', 1234), ('slave1', 1234)] def test_master_for(cluster, sentinel): @@ -1,45 +1,17 @@ [tox] -envlist = py26, py27, py32, py33, py34, hi26, hi27, hi32, hi33, hi34, pep8 +minversion = 1.8 +envlist = {py27,py34,py35,py36}-{plain,hiredis}, pycodestyle [testenv] -deps=pytest>=2.5.0 -commands = py.test [] - -[testenv:hi26] -basepython = python2.6 -deps = - hiredis>=0.1.3 - pytest>=2.5.0 -commands = py.test [] - -[testenv:hi27] -basepython = python2.7 -deps = - hiredis>=0.1.3 - pytest>=2.5.0 -commands = py.test [] - -[testenv:hi32] -basepython = python3.2 -deps = - hiredis>=0.1.3 - pytest>=2.5.0 -commands = py.test [] - -[testenv:hi33] -basepython = python3.3 -deps = - hiredis>=0.1.3 - pytest>=2.5.0 -commands = py.test [] - -[testenv:hi34] -basepython = python3.4 deps = - hiredis>=0.1.3 - pytest>=2.5.0 -commands = py.test [] + mock + pytest >= 2.7.0 + hiredis: hiredis >= 0.1.3 +commands = py.test {posargs} -[testenv:pep8] -deps = pep8 -commands = pep8 --repeat --show-source --exclude=.venv,.tox,dist,docs,build,*.egg . +[testenv:pycodestyle] +basepython = python3.6 +deps = pycodestyle +commands = pycodestyle +skipsdist = true +skip_install = true diff --git a/vagrant/.bash_profile b/vagrant/.bash_profile deleted file mode 100644 index e3d9bca..0000000 --- a/vagrant/.bash_profile +++ /dev/null @@ -1 +0,0 @@ -PATH=$PATH:/home/vagrant/redis/bin diff --git a/vagrant/Vagrantfile b/vagrant/Vagrantfile index 7465ccd..3ee7aee 100644 --- a/vagrant/Vagrantfile +++ b/vagrant/Vagrantfile @@ -12,11 +12,11 @@ Vagrant.configure(VAGRANTFILE_API_VERSION) do |config| config.vm.synced_folder "../", "/home/vagrant/redis-py" # install the redis server - config.vm.provision :shell, :path => "bootstrap.sh" - config.vm.provision :shell, :path => "build_redis.sh" - config.vm.provision :shell, :path => "install_redis.sh" - config.vm.provision :shell, :path => "install_sentinel.sh" - config.vm.provision :file, :source => ".bash_profile", :destination => "/home/vagrant/.bash_profile" + config.vm.provision :shell, :path => "../build_tools/bootstrap.sh" + config.vm.provision :shell, :path => "../build_tools/build_redis.sh" + config.vm.provision :shell, :path => "../build_tools/install_redis.sh" + config.vm.provision :shell, :path => "../build_tools/install_sentinel.sh" + config.vm.provision :file, :source => "../build_tools/.bash_profile", :destination => "/home/vagrant/.bash_profile" # setup forwarded ports config.vm.network "forwarded_port", guest: 6379, host: 6379 |