summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--CHANGES5
-rwxr-xr-xredis/connection.py25
-rw-r--r--tests/test_encoding.py47
-rw-r--r--tests/test_pipeline.py10
4 files changed, 75 insertions, 12 deletions
diff --git a/CHANGES b/CHANGES
index 5cbc7ca..97a53ed 100644
--- a/CHANGES
+++ b/CHANGES
@@ -10,7 +10,10 @@
* Optimized Lock's blocking_timeout and sleep. If the lock cannot be
acquired and the sleep value would cause the loop to sleep beyond
blocking_timeout, fail immediately. Thanks @clslgrnc. #1263
-
+ * Added support for passing Python memoryviews to Redis command args that
+ expect strings or bytes. The memoryview instance is sent directly to
+ the socket such that there are zero copies made of the underlying data
+ during command packing. Thanks @Cody-G. #1265, #1285
* 3.4.1
* Move the username argument in the Redis and Connection classes to the
end of the argument list. This helps those poor souls that specify all
diff --git a/redis/connection.py b/redis/connection.py
index f54aace..c61517e 100755
--- a/redis/connection.py
+++ b/redis/connection.py
@@ -94,7 +94,7 @@ SENTINEL = object()
class Encoder(object):
- "Encode strings to bytes and decode bytes to strings"
+ "Encode strings to bytes-like and decode bytes-like to strings"
def __init__(self, encoding, encoding_errors, decode_responses):
self.encoding = encoding
@@ -102,8 +102,8 @@ class Encoder(object):
self.decode_responses = decode_responses
def encode(self, value):
- "Return a bytestring representation of the value"
- if isinstance(value, bytes):
+ "Return a bytestring or bytes-like representation of the value"
+ if isinstance(value, (bytes, memoryview)):
return value
elif isinstance(value, bool):
# special case bool since it is a subclass of int
@@ -124,9 +124,12 @@ class Encoder(object):
return value
def decode(self, value, force=False):
- "Return a unicode string from the byte representation"
- if (self.decode_responses or force) and isinstance(value, bytes):
- value = value.decode(self.encoding, self.encoding_errors)
+ "Return a unicode string from the bytes-like representation"
+ if self.decode_responses or force:
+ if isinstance(value, memoryview):
+ value = value.tobytes()
+ if isinstance(value, bytes):
+ value = value.decode(self.encoding, self.encoding_errors)
return value
@@ -767,9 +770,10 @@ class Connection(object):
buffer_cutoff = self._buffer_cutoff
for arg in imap(self.encoder.encode, args):
# to avoid large string mallocs, chunk the command into the
- # output list if we're sending large values
+ # output list if we're sending large values or memoryviews
arg_length = len(arg)
- if len(buff) > buffer_cutoff or arg_length > buffer_cutoff:
+ if (len(buff) > buffer_cutoff or arg_length > buffer_cutoff
+ or isinstance(arg, memoryview)):
buff = SYM_EMPTY.join(
(buff, SYM_DOLLAR, str(arg_length).encode(), SYM_CRLF))
output.append(buff)
@@ -792,12 +796,13 @@ class Connection(object):
for cmd in commands:
for chunk in self.pack_command(*cmd):
chunklen = len(chunk)
- if buffer_length > buffer_cutoff or chunklen > buffer_cutoff:
+ if (buffer_length > buffer_cutoff or chunklen > buffer_cutoff
+ or isinstance(chunk, memoryview)):
output.append(SYM_EMPTY.join(pieces))
buffer_length = 0
pieces = []
- if chunklen > self._buffer_cutoff:
+ if chunklen > buffer_cutoff or isinstance(chunk, memoryview):
output.append(chunk)
else:
pieces.append(chunk)
diff --git a/tests/test_encoding.py b/tests/test_encoding.py
index 3f43006..ea7db7f 100644
--- a/tests/test_encoding.py
+++ b/tests/test_encoding.py
@@ -3,6 +3,7 @@ import pytest
import redis
from redis._compat import unichr, unicode
+from redis.connection import Connection
from .conftest import _get_client
@@ -11,13 +12,45 @@ class TestEncoding(object):
def r(self, request):
return _get_client(redis.Redis, request=request, decode_responses=True)
- def test_simple_encoding(self, r):
+ @pytest.fixture()
+ def r_no_decode(self, request):
+ return _get_client(
+ redis.Redis,
+ request=request,
+ decode_responses=False,
+ )
+
+ def test_simple_encoding(self, r_no_decode):
+ unicode_string = unichr(3456) + 'abcd' + unichr(3421)
+ r_no_decode['unicode-string'] = unicode_string.encode('utf-8')
+ cached_val = r_no_decode['unicode-string']
+ assert isinstance(cached_val, bytes)
+ assert unicode_string == cached_val.decode('utf-8')
+
+ def test_simple_encoding_and_decoding(self, r):
unicode_string = unichr(3456) + 'abcd' + unichr(3421)
r['unicode-string'] = unicode_string
cached_val = r['unicode-string']
assert isinstance(cached_val, unicode)
assert unicode_string == cached_val
+ def test_memoryview_encoding(self, r_no_decode):
+ unicode_string = unichr(3456) + 'abcd' + unichr(3421)
+ unicode_string_view = memoryview(unicode_string.encode('utf-8'))
+ r_no_decode['unicode-string-memoryview'] = unicode_string_view
+ cached_val = r_no_decode['unicode-string-memoryview']
+ # The cached value won't be a memoryview because it's a copy from Redis
+ assert isinstance(cached_val, bytes)
+ assert unicode_string == cached_val.decode('utf-8')
+
+ def test_memoryview_encoding_and_decoding(self, r):
+ unicode_string = unichr(3456) + 'abcd' + unichr(3421)
+ unicode_string_view = memoryview(unicode_string.encode('utf-8'))
+ r['unicode-string-memoryview'] = unicode_string_view
+ cached_val = r['unicode-string-memoryview']
+ assert isinstance(cached_val, unicode)
+ assert unicode_string == cached_val
+
def test_list_encoding(self, r):
unicode_string = unichr(3456) + 'abcd' + unichr(3421)
result = [unicode_string, unicode_string, unicode_string]
@@ -39,6 +72,18 @@ class TestEncodingErrors(object):
assert r.get('a') == 'foo\ufffd'
+class TestMemoryviewsAreNotPacked(object):
+ def test_memoryviews_are_not_packed(self):
+ c = Connection()
+ arg = memoryview(b'some_arg')
+ arg_list = ['SOME_COMMAND', arg]
+ cmd = c.pack_command(*arg_list)
+ assert cmd[1] is arg
+ cmds = c.pack_commands([arg_list, arg_list])
+ assert cmds[1] is arg
+ assert cmds[3] is arg
+
+
class TestCommandsAreNotEncoded(object):
@pytest.fixture()
def r(self, request):
diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py
index 867666b..828b989 100644
--- a/tests/test_pipeline.py
+++ b/tests/test_pipeline.py
@@ -29,6 +29,16 @@ class TestPipeline(object):
[(b'z1', 2.0), (b'z2', 4)],
]
+ def test_pipeline_memoryview(self, r):
+ with r.pipeline() as pipe:
+ (pipe.set('a', memoryview(b'a1'))
+ .get('a'))
+ assert pipe.execute() == \
+ [
+ True,
+ b'a1',
+ ]
+
def test_pipeline_length(self, r):
with r.pipeline() as pipe:
# Initially empty.