-rw-r--r--   CHANGES                    30
-rw-r--r--   README.rst                  2
-rw-r--r--   redis/__init__.py           2
-rwxr-xr-x   redis/client.py           123
-rwxr-xr-x   redis/connection.py        84
-rw-r--r--   setup.py                    5
-rw-r--r--   tests/conftest.py           5
-rw-r--r--   tests/test_commands.py     27
-rw-r--r--   tests/test_scripting.py    34
9 files changed, 200 insertions, 112 deletions
diff --git a/CHANGES b/CHANGES
index 2a26c26..d048ea5 100644
--- a/CHANGES
+++ b/CHANGES
@@ -1,4 +1,32 @@
-* 2.10.6 (in development)
+* 2.10.6
+ * Various performance improvements. Thanks cjsimpson
+ * Fixed a bug with SRANDMEMBER where
+ * Added HSTRLEN command. Thanks Alexander Putilin
+ * Added the TOUCH command. Thanks Anis Jonischkeit
+ * Remove unnecessary calls to the server when registering Lua scripts.
+ Thanks Ben Greenberg
+ * SET's EX and PX arguments now allow values of zero. Thanks huangqiyin
+ * Added PUBSUB {CHANNELS, NUMPAT, NUMSUB} commands. Thanks Angus Pearson
+ * PubSub connections that encounter `InterruptedError`s now
+ retry automatically. Thanks Carlton Gibson and Seth M. Larson
+ * LPUSH and RPUSH commands run on PyPy now correctly return the number
+ of items in the list. Thanks Jeong YunWon
+ * Added support to automatically retry socket EINTR errors. Thanks
+ Thomas Steinacher
+ * PubSubWorker threads started with `run_in_thread` are now daemonized
+ so the thread shuts down when the running process goes away. Thanks
+ Keith Ainsworth
+ * Added support for GEO commands. Thanks Pau Freixes, Alex DeBrie and
+ Abraham Toriz
+ * Made client construction from URLs smarter. Thanks Tim Savage
+ * Added support for CLUSTER * commands. Thanks Andy Huang
+ * The RESTORE command now accepts an optional `replace` boolean.
+ Thanks Yoshinari Takaoka
+ * Attempt to connect to a new Sentinel if a TimeoutError occurs. Thanks
+ Bo Lopker
+ * Fixed a bug in the client's `__getitem__` where a KeyError would be
+ raised if the value returned by the server is an empty string.
+ Thanks Javier Candeira.
* Socket timeouts when connecting to a server are now properly raised
as TimeoutErrors.
* 2.10.5
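For quick orientation, a minimal usage sketch of two of the commands added in this release, TOUCH and HSTRLEN. It assumes a local Redis server recent enough to support them (roughly 3.2 or newer); the key names are invented for the example:

    import redis

    r = redis.StrictRedis(host='localhost', port=6379)

    # HSTRLEN: length in bytes of a single hash field's value
    r.hset('user:1', 'name', 'alice')
    print(r.hstrlen('user:1', 'name'))       # 5

    # TOUCH: update last-access time; returns how many of the keys exist
    r.set('page:1', 'cached body')
    print(r.touch('page:1', 'missing-key'))  # 1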
diff --git a/README.rst b/README.rst
index 307224a..97604f8 100644
--- a/README.rst
+++ b/README.rst
@@ -183,7 +183,7 @@ set_response_callback method. This method accepts two arguments: a command
name and the callback. Callbacks added in this manner are only valid on the
instance the callback is added to. If you want to define or override a callback
globally, you should make a subclass of the Redis client and add your callback
-to its REDIS_CALLBACKS class dictionary.
+to its RESPONSE_CALLBACKS class dictionary.
Response callbacks take at least one parameter: the response from the Redis
server. Keyword arguments may also be accepted in order to further control
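To make the corrected passage concrete, a short sketch of overriding a callback globally via the RESPONSE_CALLBACKS class dictionary. The upper-casing GET callback is invented purely for illustration:

    import redis

    def upper_get(response, **options):
        # Example callback: upper-case every GET reply.
        return response.upper() if response is not None else None

    class MyRedis(redis.StrictRedis):
        RESPONSE_CALLBACKS = dict(redis.StrictRedis.RESPONSE_CALLBACKS,
                                  GET=upper_get)

    r = MyRedis(decode_responses=True)
    r.set('greeting', 'hello')
    print(r.get('greeting'))  # 'HELLO'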
diff --git a/redis/__init__.py b/redis/__init__.py
index 11281ba..6607155 100644
--- a/redis/__init__.py
+++ b/redis/__init__.py
@@ -22,7 +22,7 @@ from redis.exceptions import (
)
-__version__ = '2.10.5'
+__version__ = '2.10.6'
VERSION = tuple(map(int, __version__.split('.')))
__all__ = [
diff --git a/redis/client.py b/redis/client.py
index 92095de..e3b5ed7 100755
--- a/redis/client.py
+++ b/redis/client.py
@@ -6,6 +6,7 @@ import warnings
import time
import threading
import time as mod_time
+import hashlib
from redis._compat import (b, basestring, bytes, imap, iteritems, iterkeys,
itervalues, izip, long, nativestr, unicode,
safe_unicode)
@@ -244,7 +245,8 @@ def bool_ok(response):
def parse_client_list(response, **options):
clients = []
for c in nativestr(response).splitlines():
- clients.append(dict([pair.split('=') for pair in c.split(' ')]))
+ # Values might contain '='
+ clients.append(dict([pair.split('=', 1) for pair in c.split(' ')]))
return clients
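The one-argument change to split is easiest to see in isolation; no server is needed, and the sample CLIENT LIST line below is fabricated:

    line = 'id=3 addr=127.0.0.1:57475 name=cl=i=ent db=0'

    # Splitting on the first '=' only keeps everything after it as the value,
    # so client names that themselves contain '=' parse correctly.
    parsed = dict(pair.split('=', 1) for pair in line.split(' '))
    print(parsed['name'])  # 'cl=i=ent'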
@@ -360,9 +362,9 @@ class StrictRedis(object):
bool
),
string_keys_to_dict(
- 'BITCOUNT BITPOS DECRBY DEL GETBIT HDEL HLEN INCRBY LINSERT LLEN '
- 'LPUSHX PFADD PFCOUNT RPUSHX SADD SCARD SDIFFSTORE SETBIT '
- 'SETRANGE SINTERSTORE SREM STRLEN SUNIONSTORE ZADD ZCARD '
+ 'BITCOUNT BITPOS DECRBY DEL GETBIT HDEL HLEN HSTRLEN INCRBY '
+ 'LINSERT LLEN LPUSHX PFADD PFCOUNT RPUSHX SADD SCARD SDIFFSTORE '
+ 'SETBIT SETRANGE SINTERSTORE SREM STRLEN SUNIONSTORE ZADD ZCARD '
'ZLEXCOUNT ZREM ZREMRANGEBYLEX ZREMRANGEBYRANK ZREMRANGEBYSCORE '
'GEOADD',
int
@@ -447,7 +449,8 @@ class StrictRedis(object):
'CLUSTER SETSLOT': bool_ok,
'CLUSTER SLAVES': parse_cluster_nodes,
'GEOPOS': lambda r: list(map(lambda ll: (float(ll[0]),
- float(ll[1])), r)),
+ float(ll[1]))
+ if ll is not None else None, r)),
'GEOHASH': lambda r: list(map(nativestr, r)),
'GEORADIUS': parse_georadius_generic,
'GEORADIUSBYMEMBER': parse_georadius_generic,
@@ -1143,11 +1146,11 @@ class StrictRedis(object):
``px`` sets an expire flag on key ``name`` for ``px`` milliseconds.
- ``nx`` if set to True, set the value at key ``name`` to ``value`` if it
- does not already exist.
+ ``nx`` if set to True, set the value at key ``name`` to ``value`` only
+ if it does not exist.
- ``xx`` if set to True, set the value at key ``name`` to ``value`` if it
- already exists.
+ ``xx`` if set to True, set the value at key ``name`` to ``value`` only
+ if it already exists.
"""
pieces = [name, value]
if ex is not None:
@@ -1217,6 +1220,13 @@ class StrictRedis(object):
"""
return self.execute_command('SUBSTR', name, start, end)
+ def touch(self, *args):
+ """
+ Alters the last access time of the given keys ``*args``. A key is ignored
+ if it does not exist.
+ """
+ return self.execute_command('TOUCH', *args)
+
def ttl(self, name):
"Returns the number of seconds until the key ``name`` will expire"
return self.execute_command('TTL', name)
@@ -1264,7 +1274,7 @@ class StrictRedis(object):
RPOP a value off of the first non-empty list
named in the ``keys`` list.
- If none of the lists in ``keys`` has a value to LPOP, then block
+ If none of the lists in ``keys`` has a value to RPOP, then block
for ``timeout`` seconds, or until a value gets pushed on to one
of the lists.
@@ -2010,6 +2020,13 @@ class StrictRedis(object):
"Return the list of values within hash ``name``"
return self.execute_command('HVALS', name)
+ def hstrlen(self, name, key):
+ """
+ Return the number of bytes stored in the value of ``key``
+ within hash ``name``
+ """
+ return self.execute_command('HSTRLEN', name, key)
+
def publish(self, channel, message):
"""
Publish ``message`` on ``channel``.
@@ -2100,7 +2117,7 @@ class StrictRedis(object):
the triad longitude, latitude and name.
"""
if len(values) % 3 != 0:
- raise RedisError("GEOADD requires places with lat, lon and name"
+ raise RedisError("GEOADD requires places with lon, lat and name"
" values")
return self.execute_command('GEOADD', name, *values)
@@ -2129,7 +2146,7 @@ class StrictRedis(object):
"""
Return the positions of each item of ``values`` as members of
the specified key identified by the ``name`` argument. Each position
- is represented by the pairs lat and lon.
+ is represented by the pairs lon and lat.
"""
return self.execute_command('GEOPOS', name, *values)
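A brief sketch of the corrected coordinate order: GEOADD takes longitude before latitude, and GEOPOS hands back (lon, lat) tuples. It assumes a server with GEO support (3.2+); the key and place names are made up:

    import redis

    r = redis.StrictRedis()

    # longitude first, then latitude, then the member name
    r.geoadd('barcelona', 2.1909, 41.4337, 'place1')
    print(r.geopos('barcelona', 'place1'))   # [(2.19..., 41.43...)]
    print(r.geopos('barcelona', 'nowhere'))  # typically [None] for a member never added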
@@ -2324,13 +2341,7 @@ class PubSub(object):
self.connection = None
# we need to know the encoding options for this connection in order
# to lookup channel and pattern names for callback handlers.
- conn = connection_pool.get_connection('pubsub', shard_hint)
- try:
- self.encoding = conn.encoding
- self.encoding_errors = conn.encoding_errors
- self.decode_responses = conn.decode_responses
- finally:
- connection_pool.release(conn)
+ self.encoder = self.connection_pool.get_encoder()
self.reset()
def __del__(self):
@@ -2362,29 +2373,14 @@ class PubSub(object):
if self.channels:
channels = {}
for k, v in iteritems(self.channels):
- if not self.decode_responses:
- k = k.decode(self.encoding, self.encoding_errors)
- channels[k] = v
+ channels[self.encoder.decode(k, force=True)] = v
self.subscribe(**channels)
if self.patterns:
patterns = {}
for k, v in iteritems(self.patterns):
- if not self.decode_responses:
- k = k.decode(self.encoding, self.encoding_errors)
- patterns[k] = v
+ patterns[self.encoder.decode(k, force=True)] = v
self.psubscribe(**patterns)
- def encode(self, value):
- """
- Encode the value so that it's identical to what we'll
- read off the connection
- """
- if self.decode_responses and isinstance(value, bytes):
- value = value.decode(self.encoding, self.encoding_errors)
- elif not self.decode_responses and isinstance(value, unicode):
- value = value.encode(self.encoding, self.encoding_errors)
- return value
-
@property
def subscribed(self):
"Indicates if there are subscriptions to any channels or patterns"
@@ -2434,6 +2430,16 @@ class PubSub(object):
return None
return self._execute(connection, connection.read_response)
+ def _normalize_keys(self, data):
+ """
+ normalize channel/pattern names to be either bytes or strings
+ based on whether responses are automatically decoded. this saves us
+ from coercing the value for each message coming in.
+ """
+ encode = self.encoder.encode
+ decode = self.encoder.decode
+ return dict([(decode(encode(k)), v) for k, v in iteritems(data)])
+
def psubscribe(self, *args, **kwargs):
"""
Subscribe to channel patterns. Patterns supplied as keyword arguments
@@ -2444,15 +2450,13 @@ class PubSub(object):
"""
if args:
args = list_or_args(args[0], args[1:])
- new_patterns = {}
- new_patterns.update(dict.fromkeys(imap(self.encode, args)))
- for pattern, handler in iteritems(kwargs):
- new_patterns[self.encode(pattern)] = handler
+ new_patterns = dict.fromkeys(args)
+ new_patterns.update(kwargs)
ret_val = self.execute_command('PSUBSCRIBE', *iterkeys(new_patterns))
# update the patterns dict AFTER we send the command. we don't want to
# subscribe twice to these patterns, once for the command and again
# for the reconnection.
- self.patterns.update(new_patterns)
+ self.patterns.update(self._normalize_keys(new_patterns))
return ret_val
def punsubscribe(self, *args):
@@ -2474,15 +2478,13 @@ class PubSub(object):
"""
if args:
args = list_or_args(args[0], args[1:])
- new_channels = {}
- new_channels.update(dict.fromkeys(imap(self.encode, args)))
- for channel, handler in iteritems(kwargs):
- new_channels[self.encode(channel)] = handler
+ new_channels = dict.fromkeys(args)
+ new_channels.update(kwargs)
ret_val = self.execute_command('SUBSCRIBE', *iterkeys(new_channels))
# update the channels dict AFTER we send the command. we don't want to
# subscribe twice to these channels, once for the command and again
# for the reconnection.
- self.channels.update(new_channels)
+ self.channels.update(self._normalize_keys(new_channels))
return ret_val
def unsubscribe(self, *args):
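For context, a short PubSub sketch showing what the normalization above preserves: the keys stored in self.channels match the form (bytes or str) that incoming message names will have, so handlers can be looked up without re-coercing every message. The channel name and handler are invented, and in practice get_message may need to be polled until the published message arrives:

    import redis

    r = redis.StrictRedis(decode_responses=True)
    p = r.pubsub()

    def on_news(message):
        # message['channel'] is a str here because decode_responses=True,
        # matching the normalized key stored for the handler.
        print(message['data'])

    p.subscribe(news=on_news)
    p.get_message()                        # subscribe confirmation
    r.publish('news', 'hello subscribers')
    p.get_message(timeout=1)               # dispatches to on_news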
@@ -2848,12 +2850,11 @@ class BasePipeline(object):
shas = [s.sha for s in scripts]
# we can't use the normal script_* methods because they would just
# get buffered in the pipeline.
- exists = immediate('SCRIPT', 'EXISTS', *shas, **{'parse': 'EXISTS'})
+ exists = immediate('SCRIPT EXISTS', *shas)
if not all(exists):
for s, exist in izip(scripts, exists):
if not exist:
- s.sha = immediate('SCRIPT', 'LOAD', s.script,
- **{'parse': 'LOAD'})
+ s.sha = immediate('SCRIPT LOAD', s.script)
def execute(self, raise_on_error=True):
"Execute all the commands in the current pipeline"
@@ -2905,16 +2906,6 @@ class BasePipeline(object):
"Unwatches all previously specified keys"
return self.watching and self.execute_command('UNWATCH') or True
- def script_load_for_pipeline(self, script):
- "Make sure scripts are loaded prior to pipeline execution"
- # we need the sha now so that Script.__call__ can use it to run
- # evalsha.
- if not script.sha:
- script.sha = self.immediate_execute_command('SCRIPT', 'LOAD',
- script.script,
- **{'parse': 'LOAD'})
- self.scripts.add(script)
-
class StrictPipeline(BasePipeline, StrictRedis):
"Pipeline for the StrictRedis class"
@@ -2932,7 +2923,14 @@ class Script(object):
def __init__(self, registered_client, script):
self.registered_client = registered_client
self.script = script
- self.sha = ''
+ # Precalculate and store the SHA1 hex digest of the script.
+
+ if isinstance(script, basestring):
+ # We need the encoding from the client in order to generate an
+ # accurate byte representation of the script
+ encoder = registered_client.connection_pool.get_encoder()
+ script = encoder.encode(script)
+ self.sha = hashlib.sha1(script).hexdigest()
def __call__(self, keys=[], args=[], client=None):
"Execute the script, passing any required ``args``"
@@ -2941,12 +2939,13 @@ class Script(object):
args = tuple(keys) + tuple(args)
# make sure the Redis server knows about the script
if isinstance(client, BasePipeline):
- # make sure this script is good to go on pipeline
- client.script_load_for_pipeline(self)
+ # Make sure the pipeline can register the script before executing.
+ client.scripts.add(self)
try:
return client.evalsha(self.sha, len(keys), *args)
except NoScriptError:
# Maybe the client is pointed to a different server than the client
# that created this instance?
+ # Overwrite the sha just in case there was a discrepancy.
self.sha = client.script_load(self.script)
return client.evalsha(self.sha, len(keys), *args)
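Because the digest is now computed locally with hashlib, a registered script exposes its sha before any command is sent; the first EVALSHA that fails with NoScriptError triggers a SCRIPT LOAD and a retry, as above. A small sketch, with the Lua snippet chosen only as an example:

    import redis

    r = redis.StrictRedis()
    double = r.register_script("return tonumber(ARGV[1]) * 2")

    # Available immediately; no SCRIPT LOAD has happened yet.
    print(double.sha)

    # First call tries EVALSHA, loads the script on NoScriptError, retries.
    print(double(args=[21]))  # 42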
diff --git a/redis/connection.py b/redis/connection.py
index 728a627..c8f0513 100755
--- a/redis/connection.py
+++ b/redis/connection.py
@@ -94,6 +94,38 @@ class Token(object):
return self.value
+class Encoder(object):
+ "Encode strings to bytes and decode bytes to strings"
+
+ def __init__(self, encoding, encoding_errors, decode_responses):
+ self.encoding = encoding
+ self.encoding_errors = encoding_errors
+ self.decode_responses = decode_responses
+
+ def encode(self, value):
+ "Return a bytestring representation of the value"
+ if isinstance(value, Token):
+ return value.encoded_value
+ elif isinstance(value, bytes):
+ return value
+ elif isinstance(value, (int, long)):
+ value = b(str(value))
+ elif isinstance(value, float):
+ value = b(repr(value))
+ elif not isinstance(value, basestring):
+ # an object we don't know how to deal with. default to unicode()
+ value = unicode(value)
+ if isinstance(value, unicode):
+ value = value.encode(self.encoding, self.encoding_errors)
+ return value
+
+ def decode(self, value, force=False):
+ "Return a unicode string from the byte representation"
+ if (self.decode_responses or force) and isinstance(value, bytes):
+ value = value.decode(self.encoding, self.encoding_errors)
+ return value
+
+
class BaseParser(object):
EXCEPTION_CLASSES = {
'ERR': {
@@ -217,10 +249,9 @@ class SocketBuffer(object):
class PythonParser(BaseParser):
"Plain Python parsing class"
- encoding = None
-
def __init__(self, socket_read_size):
self.socket_read_size = socket_read_size
+ self.encoder = None
self._sock = None
self._buffer = None
@@ -234,8 +265,7 @@ class PythonParser(BaseParser):
"Called when the socket connects"
self._sock = connection._sock
self._buffer = SocketBuffer(self._sock, self.socket_read_size)
- if connection.decode_responses:
- self.encoding = connection.encoding
+ self.encoder = connection.encoder
def on_disconnect(self):
"Called when the socket disconnects"
@@ -245,7 +275,7 @@ class PythonParser(BaseParser):
if self._buffer is not None:
self._buffer.close()
self._buffer = None
- self.encoding = None
+ self.encoder = None
def can_read(self):
return self._buffer and bool(self._buffer.length)
@@ -292,8 +322,8 @@ class PythonParser(BaseParser):
if length == -1:
return None
response = [self.read_response() for i in xrange(length)]
- if isinstance(response, bytes) and self.encoding:
- response = response.decode(self.encoding)
+ if isinstance(response, bytes):
+ response = self.encoder.decode(response)
return response
@@ -324,8 +354,8 @@ class HiredisParser(BaseParser):
if not HIREDIS_SUPPORTS_CALLABLE_ERRORS:
kwargs['replyError'] = ResponseError
- if connection.decode_responses:
- kwargs['encoding'] = connection.encoding
+ if connection.encoder.decode_responses:
+ kwargs['encoding'] = connection.encoder.encoding
self._reader = hiredis.Reader(**kwargs)
self._next_response = False
@@ -394,6 +424,7 @@ class HiredisParser(BaseParser):
raise response[0]
return response
+
if HIREDIS_AVAILABLE:
DefaultParser = HiredisParser
else:
@@ -420,9 +451,7 @@ class Connection(object):
self.socket_keepalive = socket_keepalive
self.socket_keepalive_options = socket_keepalive_options or {}
self.retry_on_timeout = retry_on_timeout
- self.encoding = encoding
- self.encoding_errors = encoding_errors
- self.decode_responses = decode_responses
+ self.encoder = Encoder(encoding, encoding_errors, decode_responses)
self._sock = None
self._parser = parser_class(socket_read_size=socket_read_size)
self._description_args = {
@@ -600,22 +629,6 @@ class Connection(object):
raise response
return response
- def encode(self, value):
- "Return a bytestring representation of the value"
- if isinstance(value, Token):
- return value.encoded_value
- elif isinstance(value, bytes):
- return value
- elif isinstance(value, (int, long)):
- value = b(str(value))
- elif isinstance(value, float):
- value = b(repr(value))
- elif not isinstance(value, basestring):
- value = unicode(value)
- if isinstance(value, unicode):
- value = value.encode(self.encoding, self.encoding_errors)
- return value
-
def pack_command(self, *args):
"Pack a series of arguments into the Redis protocol"
output = []
@@ -634,7 +647,7 @@ class Connection(object):
buff = SYM_EMPTY.join(
(SYM_STAR, b(str(len(args))), SYM_CRLF))
- for arg in imap(self.encode, args):
+ for arg in imap(self.encoder.encode, args):
# to avoid large string mallocs, chunk the command into the
# output list if we're sending large values
if len(buff) > 6000 or len(arg) > 6000:
@@ -723,9 +736,7 @@ class UnixDomainSocketConnection(Connection):
self.password = password
self.socket_timeout = socket_timeout
self.retry_on_timeout = retry_on_timeout
- self.encoding = encoding
- self.encoding_errors = encoding_errors
- self.decode_responses = decode_responses
+ self.encoder = Encoder(encoding, encoding_errors, decode_responses)
self._sock = None
self._parser = parser_class(socket_read_size=socket_read_size)
self._description_args = {
@@ -955,6 +966,15 @@ class ConnectionPool(object):
self._in_use_connections.add(connection)
return connection
+ def get_encoder(self):
+ "Return an encoder based on encoding settings"
+ kwargs = self.connection_kwargs
+ return Encoder(
+ encoding=kwargs.get('encoding', 'utf-8'),
+ encoding_errors=kwargs.get('encoding_errors', 'strict'),
+ decode_responses=kwargs.get('decode_responses', False)
+ )
+
def make_connection(self):
"Create a new connection"
if self._created_connections >= self.max_connections:
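A quick sketch of the new Encoder used directly, constructed the same way ConnectionPool.get_encoder builds it from the pool's connection kwargs; the sample values are only illustrative:

    from redis.connection import Encoder

    enc = Encoder(encoding='utf-8', encoding_errors='strict',
                  decode_responses=False)

    print(enc.encode(u'caf\xe9'))             # b'caf\xc3\xa9'
    print(enc.encode(42))                     # b'42'
    print(enc.decode(b'value'))               # b'value' (decoding disabled)
    print(enc.decode(b'value', force=True))   # u'value'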
diff --git a/setup.py b/setup.py
index 9c51d5a..0e8f07d 100644
--- a/setup.py
+++ b/setup.py
@@ -44,6 +44,11 @@ setup(
keywords=['Redis', 'key-value store'],
license='MIT',
packages=['redis'],
+ extras_require={
+ 'hiredis': [
+ "hiredis>=0.1.3",
+ ],
+ },
tests_require=[
'mock',
'pytest>=2.5.0',
diff --git a/tests/conftest.py b/tests/conftest.py
index d7b2b14..006697f 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -37,6 +37,11 @@ def skip_if_server_version_lt(min_version):
return pytest.mark.skipif(check, reason="")
+def skip_if_server_version_gte(min_version):
+ check = StrictVersion(get_version()) >= StrictVersion(min_version)
+ return pytest.mark.skipif(check, reason="")
+
+
@pytest.fixture()
def r(request, **kwargs):
return _get_client(redis.Redis, request, **kwargs)
diff --git a/tests/test_commands.py b/tests/test_commands.py
index 246ef65..ed46298 100644
--- a/tests/test_commands.py
+++ b/tests/test_commands.py
@@ -10,7 +10,7 @@ from redis._compat import (unichr, u, b, ascii_letters, iteritems, iterkeys,
from redis.client import parse_info
from redis import exceptions
-from .conftest import skip_if_server_version_lt
+from .conftest import skip_if_server_version_lt, skip_if_server_version_gte
@pytest.fixture()
@@ -68,6 +68,14 @@ class TestRedisCommands(object):
assert r.client_setname('redis_py_test')
assert r.client_getname() == 'redis_py_test'
+ @skip_if_server_version_lt('2.6.9')
+ def test_client_list_after_client_setname(self, r):
+ r.client_setname('cl=i=ent')
+ clients = r.client_list()
+ assert isinstance(clients[0], dict)
+ assert 'name' in clients[0]
+ assert clients[0]['name'] == 'cl=i=ent'
+
def test_config_get(self, r):
data = r.config_get()
assert 'maxmemory' in data
@@ -901,6 +909,8 @@ class TestRedisCommands(object):
r.zadd('a', a1=1, a2=2, a3=3)
assert r.zcount('a', '-inf', '+inf') == 3
assert r.zcount('a', 1, 2) == 2
+ assert r.zcount('a', '(' + str(1), 2) == 1
+ assert r.zcount('a', 1, '(' + str(2)) == 1
assert r.zcount('a', 10, 20) == 0
def test_zincrby(self, r):
@@ -1229,6 +1239,12 @@ class TestRedisCommands(object):
remote_vals = r.hvals('a')
assert sorted(local_vals) == sorted(remote_vals)
+ @skip_if_server_version_lt('3.2.0')
+ def test_hstrlen(self, r):
+ r.hmset('a', {'1': '22', '2': '333'})
+ assert r.hstrlen('a', '1') == 2
+ assert r.hstrlen('a', '2') == 3
+
# SORT
def test_sort_basic(self, r):
r.rpush('a', '3', '2', '1', '4')
@@ -1453,6 +1469,15 @@ class TestRedisCommands(object):
[(2.19093829393386841, 41.43379028184083523),
(2.18737632036209106, 41.40634178640635099)]
+ @skip_if_server_version_lt('4.0.0')
+ def test_geopos_no_value(self, r):
+ assert r.geopos('barcelona', 'place1', 'place2') == [None, None]
+
+ @skip_if_server_version_lt('3.2.0')
+ @skip_if_server_version_gte('4.0.0')
+ def test_old_geopos_no_value(self, r):
+ assert r.geopos('barcelona', 'place1', 'place2') == []
+
@skip_if_server_version_lt('3.2.0')
def test_georadius(self, r):
values = (2.1909389952632, 41.433791470673, 'place1') +\
diff --git a/tests/test_scripting.py b/tests/test_scripting.py
index 2213ec6..9131de8 100644
--- a/tests/test_scripting.py
+++ b/tests/test_scripting.py
@@ -57,44 +57,49 @@ class TestScripting(object):
def test_script_object(self, r):
r.set('a', 2)
multiply = r.register_script(multiply_script)
- assert not multiply.sha
- # test evalsha fail -> script load + retry
+ precalculated_sha = multiply.sha
+ assert precalculated_sha
+ assert r.script_exists(multiply.sha) == [False]
+ # Test second evalsha block (after NoScriptError)
assert multiply(keys=['a'], args=[3]) == 6
- assert multiply.sha
+ # At this point, the script should be loaded
assert r.script_exists(multiply.sha) == [True]
- # test first evalsha
+ # Test that the precalculated sha matches the one from redis
+ assert multiply.sha == precalculated_sha
+ # Test first evalsha block
assert multiply(keys=['a'], args=[3]) == 6
def test_script_object_in_pipeline(self, r):
multiply = r.register_script(multiply_script)
- assert not multiply.sha
+ precalculated_sha = multiply.sha
+ assert precalculated_sha
pipe = r.pipeline()
pipe.set('a', 2)
pipe.get('a')
multiply(keys=['a'], args=[3], client=pipe)
- # even though the pipeline wasn't executed yet, we made sure the
- # script was loaded and got a valid sha
- assert multiply.sha
- assert r.script_exists(multiply.sha) == [True]
+ assert r.script_exists(multiply.sha) == [False]
# [SET worked, GET 'a', result of multiply script]
assert pipe.execute() == [True, b('2'), 6]
+ # The script should have been loaded by pipe.execute()
+ assert r.script_exists(multiply.sha) == [True]
+ # The precalculated sha should have been the correct one
+ assert multiply.sha == precalculated_sha
# purge the script from redis's cache and re-run the pipeline
- # the multiply script object knows it's sha, so it shouldn't get
- # reloaded until pipe.execute()
+ # the multiply script should be reloaded by pipe.execute()
r.script_flush()
pipe = r.pipeline()
pipe.set('a', 2)
pipe.get('a')
- assert multiply.sha
multiply(keys=['a'], args=[3], client=pipe)
assert r.script_exists(multiply.sha) == [False]
# [SET worked, GET 'a', result of multiply script]
assert pipe.execute() == [True, b('2'), 6]
+ assert r.script_exists(multiply.sha) == [True]
def test_eval_msgpack_pipeline_error_in_lua(self, r):
msgpack_hello = r.register_script(msgpack_hello_script)
- assert not msgpack_hello.sha
+ assert msgpack_hello.sha
pipe = r.pipeline()
@@ -104,8 +109,9 @@ class TestScripting(object):
msgpack_hello(args=[msgpack_message_1], client=pipe)
- assert r.script_exists(msgpack_hello.sha) == [True]
+ assert r.script_exists(msgpack_hello.sha) == [False]
assert pipe.execute()[0] == b'hello Joe'
+ assert r.script_exists(msgpack_hello.sha) == [True]
msgpack_hello_broken = r.register_script(msgpack_hello_script_broken)