diff options
-rw-r--r-- | CHANGES | 5 | ||||
-rwxr-xr-x | redis/connection.py | 25 | ||||
-rw-r--r-- | tests/test_encoding.py | 47 | ||||
-rw-r--r-- | tests/test_pipeline.py | 10 |
4 files changed, 75 insertions, 12 deletions
@@ -10,7 +10,10 @@ * Optimized Lock's blocking_timeout and sleep. If the lock cannot be acquired and the sleep value would cause the loop to sleep beyond blocking_timeout, fail immediately. Thanks @clslgrnc. #1263 - + * Added support for passing Python memoryviews to Redis command args that + expect strings or bytes. The memoryview instance is sent directly to + the socket such that there are zero copies made of the underlying data + during command packing. Thanks @Cody-G. #1265, #1285 * 3.4.1 * Move the username argument in the Redis and Connection classes to the end of the argument list. This helps those poor souls that specify all diff --git a/redis/connection.py b/redis/connection.py index f54aace..c61517e 100755 --- a/redis/connection.py +++ b/redis/connection.py @@ -94,7 +94,7 @@ SENTINEL = object() class Encoder(object): - "Encode strings to bytes and decode bytes to strings" + "Encode strings to bytes-like and decode bytes-like to strings" def __init__(self, encoding, encoding_errors, decode_responses): self.encoding = encoding @@ -102,8 +102,8 @@ class Encoder(object): self.decode_responses = decode_responses def encode(self, value): - "Return a bytestring representation of the value" - if isinstance(value, bytes): + "Return a bytestring or bytes-like representation of the value" + if isinstance(value, (bytes, memoryview)): return value elif isinstance(value, bool): # special case bool since it is a subclass of int @@ -124,9 +124,12 @@ class Encoder(object): return value def decode(self, value, force=False): - "Return a unicode string from the byte representation" - if (self.decode_responses or force) and isinstance(value, bytes): - value = value.decode(self.encoding, self.encoding_errors) + "Return a unicode string from the bytes-like representation" + if self.decode_responses or force: + if isinstance(value, memoryview): + value = value.tobytes() + if isinstance(value, bytes): + value = value.decode(self.encoding, self.encoding_errors) return value @@ -767,9 +770,10 @@ class Connection(object): buffer_cutoff = self._buffer_cutoff for arg in imap(self.encoder.encode, args): # to avoid large string mallocs, chunk the command into the - # output list if we're sending large values + # output list if we're sending large values or memoryviews arg_length = len(arg) - if len(buff) > buffer_cutoff or arg_length > buffer_cutoff: + if (len(buff) > buffer_cutoff or arg_length > buffer_cutoff + or isinstance(arg, memoryview)): buff = SYM_EMPTY.join( (buff, SYM_DOLLAR, str(arg_length).encode(), SYM_CRLF)) output.append(buff) @@ -792,12 +796,13 @@ class Connection(object): for cmd in commands: for chunk in self.pack_command(*cmd): chunklen = len(chunk) - if buffer_length > buffer_cutoff or chunklen > buffer_cutoff: + if (buffer_length > buffer_cutoff or chunklen > buffer_cutoff + or isinstance(chunk, memoryview)): output.append(SYM_EMPTY.join(pieces)) buffer_length = 0 pieces = [] - if chunklen > self._buffer_cutoff: + if chunklen > buffer_cutoff or isinstance(chunk, memoryview): output.append(chunk) else: pieces.append(chunk) diff --git a/tests/test_encoding.py b/tests/test_encoding.py index 3f43006..ea7db7f 100644 --- a/tests/test_encoding.py +++ b/tests/test_encoding.py @@ -3,6 +3,7 @@ import pytest import redis from redis._compat import unichr, unicode +from redis.connection import Connection from .conftest import _get_client @@ -11,13 +12,45 @@ class TestEncoding(object): def r(self, request): return _get_client(redis.Redis, request=request, decode_responses=True) - def test_simple_encoding(self, r): + @pytest.fixture() + def r_no_decode(self, request): + return _get_client( + redis.Redis, + request=request, + decode_responses=False, + ) + + def test_simple_encoding(self, r_no_decode): + unicode_string = unichr(3456) + 'abcd' + unichr(3421) + r_no_decode['unicode-string'] = unicode_string.encode('utf-8') + cached_val = r_no_decode['unicode-string'] + assert isinstance(cached_val, bytes) + assert unicode_string == cached_val.decode('utf-8') + + def test_simple_encoding_and_decoding(self, r): unicode_string = unichr(3456) + 'abcd' + unichr(3421) r['unicode-string'] = unicode_string cached_val = r['unicode-string'] assert isinstance(cached_val, unicode) assert unicode_string == cached_val + def test_memoryview_encoding(self, r_no_decode): + unicode_string = unichr(3456) + 'abcd' + unichr(3421) + unicode_string_view = memoryview(unicode_string.encode('utf-8')) + r_no_decode['unicode-string-memoryview'] = unicode_string_view + cached_val = r_no_decode['unicode-string-memoryview'] + # The cached value won't be a memoryview because it's a copy from Redis + assert isinstance(cached_val, bytes) + assert unicode_string == cached_val.decode('utf-8') + + def test_memoryview_encoding_and_decoding(self, r): + unicode_string = unichr(3456) + 'abcd' + unichr(3421) + unicode_string_view = memoryview(unicode_string.encode('utf-8')) + r['unicode-string-memoryview'] = unicode_string_view + cached_val = r['unicode-string-memoryview'] + assert isinstance(cached_val, unicode) + assert unicode_string == cached_val + def test_list_encoding(self, r): unicode_string = unichr(3456) + 'abcd' + unichr(3421) result = [unicode_string, unicode_string, unicode_string] @@ -39,6 +72,18 @@ class TestEncodingErrors(object): assert r.get('a') == 'foo\ufffd' +class TestMemoryviewsAreNotPacked(object): + def test_memoryviews_are_not_packed(self): + c = Connection() + arg = memoryview(b'some_arg') + arg_list = ['SOME_COMMAND', arg] + cmd = c.pack_command(*arg_list) + assert cmd[1] is arg + cmds = c.pack_commands([arg_list, arg_list]) + assert cmds[1] is arg + assert cmds[3] is arg + + class TestCommandsAreNotEncoded(object): @pytest.fixture() def r(self, request): diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 867666b..828b989 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -29,6 +29,16 @@ class TestPipeline(object): [(b'z1', 2.0), (b'z2', 4)], ] + def test_pipeline_memoryview(self, r): + with r.pipeline() as pipe: + (pipe.set('a', memoryview(b'a1')) + .get('a')) + assert pipe.execute() == \ + [ + True, + b'a1', + ] + def test_pipeline_length(self, r): with r.pipeline() as pipe: # Initially empty. |