diff options
author | JianGuoPinterest <51248389+JianGuoPinterest@users.noreply.github.com> | 2019-06-12 15:03:53 -0700 |
---|---|---|
committer | Jon Parise <jon@pinterest.com> | 2019-06-12 15:03:53 -0700 |
commit | 7b391dd5923558542b97bc0ac5053f51cd3025b7 (patch) | |
tree | 9cc0b64f849be6c8434144712728ed526fc23968 | |
parent | 26f7c1b1371e86b0aab7f640af980a59df1ec098 (diff) | |
download | pymemcache-7b391dd5923558542b97bc0ac5053f51cd3025b7.tar.gz |
extracted _extract_value() from _fetch_cmd() (#237)
extract _extract_value() from _fetch_cmd() to support customized value extraction logic.
-rw-r--r-- | pymemcache/client/base.py | 48 | ||||
-rw-r--r-- | pymemcache/test/test_client.py | 41 |
2 files changed, 68 insertions, 21 deletions
diff --git a/pymemcache/client/base.py b/pymemcache/client/base.py index e44151b..ff5a8d0 100644 --- a/pymemcache/client/base.py +++ b/pymemcache/client/base.py @@ -762,6 +762,31 @@ class Client(object): error = line[line.find(b' ') + 1:] raise MemcacheServerError(error) + def _extract_value(self, expect_cas, line, buf, remapped_keys, + prefixed_keys): + """ + This function is abstracted from _fetch_cmd to support different ways + of value extraction. In order to use this feature, _extract_value needs + to be overriden in the subclass. + """ + if expect_cas: + _, key, flags, size, cas = line.split() + else: + try: + _, key, flags, size = line.split() + except Exception as e: + raise ValueError("Unable to parse line %s: %s" % (line, e)) + + buf, value = _readvalue(self.sock, buf, int(size)) + key = remapped_keys[key] + if self.deserializer: + value = self.deserializer(key, value, int(flags)) + + if expect_cas: + return key, (value, cas), buf + else: + return key, value, buf + def _fetch_cmd(self, name, keys, expect_cas): prefixed_keys = [self.check_key(k) for k in keys] remapped_keys = dict(zip(prefixed_keys, keys)) @@ -783,25 +808,10 @@ class Client(object): if line == b'END' or line == b'OK': return result elif line.startswith(b'VALUE'): - if expect_cas: - _, key, flags, size, cas = line.split() - else: - try: - _, key, flags, size = line.split() - except Exception as e: - raise ValueError("Unable to parse line %s: %s" - % (line, str(e))) - - buf, value = _readvalue(self.sock, buf, int(size)) - key = remapped_keys[key] - - if self.deserializer: - value = self.deserializer(key, value, int(flags)) - - if expect_cas: - result[key] = (value, cas) - else: - result[key] = value + key, value, buf = self._extract_value(expect_cas, line, buf, + remapped_keys, + prefixed_keys) + result[key] = value elif name == b'stats' and line.startswith(b'STAT'): key_value = line.split() result[key_value[1]] = key_value[2] diff --git a/pymemcache/test/test_client.py b/pymemcache/test/test_client.py index 937c486..773bac9 100644 --- a/pymemcache/test/test_client.py +++ b/pymemcache/test/test_client.py @@ -23,6 +23,7 @@ import os import mock import socket import unittest + import pytest from pymemcache.client.base import PooledClient, Client @@ -114,6 +115,13 @@ class MockSocketModule(object): return getattr(socket, name) +class CustomizedClient(Client): + + def _extract_value(self, expect_cas, line, buf, remapped_keys, + prefixed_keys): + return b'key', b'value', b'END\r\n' + + @pytest.mark.unit() class ClientTestMixin(object): def make_client(self, mock_socket_values, **kwargs): @@ -126,6 +134,16 @@ class ClientTestMixin(object): setattr, client, "sock", sock)) return client + def make_customized_client(self, mock_socket_values, **kwargs): + client = CustomizedClient(None, **kwargs) + # mock out client._connect() rather than hard-settting client.sock to + # ensure methods are checking whether self.sock is None before + # attempting to use it + sock = MockSocket(list(mock_socket_values)) + client._connect = mock.Mock(side_effect=functools.partial( + setattr, client, "sock", sock)) + return client + def test_set_success(self): client = self.make_client([b'STORED\r\n']) result = client.set(b'key', b'value', noreply=False) @@ -273,11 +291,21 @@ class ClientTestMixin(object): result = client.get(b'key') assert result is None + # Unit test for customized client (override _extract_value) + client = self.make_customized_client([b'END\r\n']) + result = client.get(b'key') + assert result is None + def test_get_not_found_default(self): client = self.make_client([b'END\r\n']) result = client.get(b'key', default='foobar') assert result == 'foobar' + # Unit test for customized client (override _extract_value) + client = self.make_customized_client([b'END\r\n']) + result = client.get(b'key', default='foobar') + assert result == 'foobar' + def test_get_found(self): client = self.make_client([ b'STORED\r\n', @@ -287,6 +315,15 @@ class ClientTestMixin(object): result = client.get(b'key') assert result == b'value' + # Unit test for customized client (override _extract_value) + client = self.make_customized_client([ + b'STORED\r\n', + b'VALUE key 0 5\r\nvalue\r\nEND\r\n', + ]) + client.set(b'key', b'value', noreply=False) + result = client.get(b'key') + assert result == b'value' + def test_get_many_none_found(self): client = self.make_client([b'END\r\n']) result = client.get_many([b'key1', b'key2']) @@ -313,8 +350,8 @@ class ClientTestMixin(object): b'VALUE key1 0 6\r\nvalue1\r\n', b'VALUE key2 0 6\r\nvalue2\r\nEND\r\n', ]) - result = client.set(b'key1', b'value1', noreply=False) - result = client.set(b'key2', b'value2', noreply=False) + client.set(b'key1', b'value1', noreply=False) + client.set(b'key2', b'value2', noreply=False) result = client.get_many([b'key1', b'key2']) assert result == {b'key1': b'value1', b'key2': b'value2'} |