From 3310fe415a4c3a6ec9a9fcaf9a9fa24808219429 Mon Sep 17 00:00:00 2001 From: Cody-G Date: Sat, 8 Feb 2020 13:19:17 -0800 Subject: Support memoryview encoding/decoding as a no-op This allows memoryview instances to be passed to Redis command args that expect strings or bytes. The memoryview instance is sent directly to the socket such that there are zero copies made of the underlying data during command packing. Fixes #1265 Fixes #1285 --- CHANGES | 5 ++++- redis/connection.py | 25 +++++++++++++++---------- tests/test_encoding.py | 47 ++++++++++++++++++++++++++++++++++++++++++++++- tests/test_pipeline.py | 10 ++++++++++ 4 files changed, 75 insertions(+), 12 deletions(-) diff --git a/CHANGES b/CHANGES index 5cbc7ca..97a53ed 100644 --- a/CHANGES +++ b/CHANGES @@ -10,7 +10,10 @@ * Optimized Lock's blocking_timeout and sleep. If the lock cannot be acquired and the sleep value would cause the loop to sleep beyond blocking_timeout, fail immediately. Thanks @clslgrnc. #1263 - + * Added support for passing Python memoryviews to Redis command args that + expect strings or bytes. The memoryview instance is sent directly to + the socket such that there are zero copies made of the underlying data + during command packing. Thanks @Cody-G. #1265, #1285 * 3.4.1 * Move the username argument in the Redis and Connection classes to the end of the argument list. This helps those poor souls that specify all diff --git a/redis/connection.py b/redis/connection.py index f54aace..c61517e 100755 --- a/redis/connection.py +++ b/redis/connection.py @@ -94,7 +94,7 @@ SENTINEL = object() class Encoder(object): - "Encode strings to bytes and decode bytes to strings" + "Encode strings to bytes-like and decode bytes-like to strings" def __init__(self, encoding, encoding_errors, decode_responses): self.encoding = encoding @@ -102,8 +102,8 @@ class Encoder(object): self.decode_responses = decode_responses def encode(self, value): - "Return a bytestring representation of the value" - if isinstance(value, bytes): + "Return a bytestring or bytes-like representation of the value" + if isinstance(value, (bytes, memoryview)): return value elif isinstance(value, bool): # special case bool since it is a subclass of int @@ -124,9 +124,12 @@ class Encoder(object): return value def decode(self, value, force=False): - "Return a unicode string from the byte representation" - if (self.decode_responses or force) and isinstance(value, bytes): - value = value.decode(self.encoding, self.encoding_errors) + "Return a unicode string from the bytes-like representation" + if self.decode_responses or force: + if isinstance(value, memoryview): + value = value.tobytes() + if isinstance(value, bytes): + value = value.decode(self.encoding, self.encoding_errors) return value @@ -767,9 +770,10 @@ class Connection(object): buffer_cutoff = self._buffer_cutoff for arg in imap(self.encoder.encode, args): # to avoid large string mallocs, chunk the command into the - # output list if we're sending large values + # output list if we're sending large values or memoryviews arg_length = len(arg) - if len(buff) > buffer_cutoff or arg_length > buffer_cutoff: + if (len(buff) > buffer_cutoff or arg_length > buffer_cutoff + or isinstance(arg, memoryview)): buff = SYM_EMPTY.join( (buff, SYM_DOLLAR, str(arg_length).encode(), SYM_CRLF)) output.append(buff) @@ -792,12 +796,13 @@ class Connection(object): for cmd in commands: for chunk in self.pack_command(*cmd): chunklen = len(chunk) - if buffer_length > buffer_cutoff or chunklen > buffer_cutoff: + if (buffer_length > buffer_cutoff or chunklen > buffer_cutoff + or isinstance(chunk, memoryview)): output.append(SYM_EMPTY.join(pieces)) buffer_length = 0 pieces = [] - if chunklen > self._buffer_cutoff: + if chunklen > buffer_cutoff or isinstance(chunk, memoryview): output.append(chunk) else: pieces.append(chunk) diff --git a/tests/test_encoding.py b/tests/test_encoding.py index 3f43006..ea7db7f 100644 --- a/tests/test_encoding.py +++ b/tests/test_encoding.py @@ -3,6 +3,7 @@ import pytest import redis from redis._compat import unichr, unicode +from redis.connection import Connection from .conftest import _get_client @@ -11,13 +12,45 @@ class TestEncoding(object): def r(self, request): return _get_client(redis.Redis, request=request, decode_responses=True) - def test_simple_encoding(self, r): + @pytest.fixture() + def r_no_decode(self, request): + return _get_client( + redis.Redis, + request=request, + decode_responses=False, + ) + + def test_simple_encoding(self, r_no_decode): + unicode_string = unichr(3456) + 'abcd' + unichr(3421) + r_no_decode['unicode-string'] = unicode_string.encode('utf-8') + cached_val = r_no_decode['unicode-string'] + assert isinstance(cached_val, bytes) + assert unicode_string == cached_val.decode('utf-8') + + def test_simple_encoding_and_decoding(self, r): unicode_string = unichr(3456) + 'abcd' + unichr(3421) r['unicode-string'] = unicode_string cached_val = r['unicode-string'] assert isinstance(cached_val, unicode) assert unicode_string == cached_val + def test_memoryview_encoding(self, r_no_decode): + unicode_string = unichr(3456) + 'abcd' + unichr(3421) + unicode_string_view = memoryview(unicode_string.encode('utf-8')) + r_no_decode['unicode-string-memoryview'] = unicode_string_view + cached_val = r_no_decode['unicode-string-memoryview'] + # The cached value won't be a memoryview because it's a copy from Redis + assert isinstance(cached_val, bytes) + assert unicode_string == cached_val.decode('utf-8') + + def test_memoryview_encoding_and_decoding(self, r): + unicode_string = unichr(3456) + 'abcd' + unichr(3421) + unicode_string_view = memoryview(unicode_string.encode('utf-8')) + r['unicode-string-memoryview'] = unicode_string_view + cached_val = r['unicode-string-memoryview'] + assert isinstance(cached_val, unicode) + assert unicode_string == cached_val + def test_list_encoding(self, r): unicode_string = unichr(3456) + 'abcd' + unichr(3421) result = [unicode_string, unicode_string, unicode_string] @@ -39,6 +72,18 @@ class TestEncodingErrors(object): assert r.get('a') == 'foo\ufffd' +class TestMemoryviewsAreNotPacked(object): + def test_memoryviews_are_not_packed(self): + c = Connection() + arg = memoryview(b'some_arg') + arg_list = ['SOME_COMMAND', arg] + cmd = c.pack_command(*arg_list) + assert cmd[1] is arg + cmds = c.pack_commands([arg_list, arg_list]) + assert cmds[1] is arg + assert cmds[3] is arg + + class TestCommandsAreNotEncoded(object): @pytest.fixture() def r(self, request): diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 867666b..828b989 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -29,6 +29,16 @@ class TestPipeline(object): [(b'z1', 2.0), (b'z2', 4)], ] + def test_pipeline_memoryview(self, r): + with r.pipeline() as pipe: + (pipe.set('a', memoryview(b'a1')) + .get('a')) + assert pipe.execute() == \ + [ + True, + b'a1', + ] + def test_pipeline_length(self, r): with r.pipeline() as pipe: # Initially empty. -- cgit v1.2.1