summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJianGuoPinterest <51248389+JianGuoPinterest@users.noreply.github.com>2019-06-12 15:03:53 -0700
committerJon Parise <jon@pinterest.com>2019-06-12 15:03:53 -0700
commit7b391dd5923558542b97bc0ac5053f51cd3025b7 (patch)
tree9cc0b64f849be6c8434144712728ed526fc23968
parent26f7c1b1371e86b0aab7f640af980a59df1ec098 (diff)
downloadpymemcache-7b391dd5923558542b97bc0ac5053f51cd3025b7.tar.gz
extracted _extract_value() from _fetch_cmd() (#237)
extract _extract_value() from _fetch_cmd() to support customized value extraction logic.
-rw-r--r--pymemcache/client/base.py48
-rw-r--r--pymemcache/test/test_client.py41
2 files changed, 68 insertions, 21 deletions
diff --git a/pymemcache/client/base.py b/pymemcache/client/base.py
index e44151b..ff5a8d0 100644
--- a/pymemcache/client/base.py
+++ b/pymemcache/client/base.py
@@ -762,6 +762,31 @@ class Client(object):
error = line[line.find(b' ') + 1:]
raise MemcacheServerError(error)
+ def _extract_value(self, expect_cas, line, buf, remapped_keys,
+ prefixed_keys):
+ """
+ This function is abstracted from _fetch_cmd to support different ways
+ of value extraction. In order to use this feature, _extract_value needs
+ to be overriden in the subclass.
+ """
+ if expect_cas:
+ _, key, flags, size, cas = line.split()
+ else:
+ try:
+ _, key, flags, size = line.split()
+ except Exception as e:
+ raise ValueError("Unable to parse line %s: %s" % (line, e))
+
+ buf, value = _readvalue(self.sock, buf, int(size))
+ key = remapped_keys[key]
+ if self.deserializer:
+ value = self.deserializer(key, value, int(flags))
+
+ if expect_cas:
+ return key, (value, cas), buf
+ else:
+ return key, value, buf
+
def _fetch_cmd(self, name, keys, expect_cas):
prefixed_keys = [self.check_key(k) for k in keys]
remapped_keys = dict(zip(prefixed_keys, keys))
@@ -783,25 +808,10 @@ class Client(object):
if line == b'END' or line == b'OK':
return result
elif line.startswith(b'VALUE'):
- if expect_cas:
- _, key, flags, size, cas = line.split()
- else:
- try:
- _, key, flags, size = line.split()
- except Exception as e:
- raise ValueError("Unable to parse line %s: %s"
- % (line, str(e)))
-
- buf, value = _readvalue(self.sock, buf, int(size))
- key = remapped_keys[key]
-
- if self.deserializer:
- value = self.deserializer(key, value, int(flags))
-
- if expect_cas:
- result[key] = (value, cas)
- else:
- result[key] = value
+ key, value, buf = self._extract_value(expect_cas, line, buf,
+ remapped_keys,
+ prefixed_keys)
+ result[key] = value
elif name == b'stats' and line.startswith(b'STAT'):
key_value = line.split()
result[key_value[1]] = key_value[2]
diff --git a/pymemcache/test/test_client.py b/pymemcache/test/test_client.py
index 937c486..773bac9 100644
--- a/pymemcache/test/test_client.py
+++ b/pymemcache/test/test_client.py
@@ -23,6 +23,7 @@ import os
import mock
import socket
import unittest
+
import pytest
from pymemcache.client.base import PooledClient, Client
@@ -114,6 +115,13 @@ class MockSocketModule(object):
return getattr(socket, name)
+class CustomizedClient(Client):
+
+ def _extract_value(self, expect_cas, line, buf, remapped_keys,
+ prefixed_keys):
+ return b'key', b'value', b'END\r\n'
+
+
@pytest.mark.unit()
class ClientTestMixin(object):
def make_client(self, mock_socket_values, **kwargs):
@@ -126,6 +134,16 @@ class ClientTestMixin(object):
setattr, client, "sock", sock))
return client
+ def make_customized_client(self, mock_socket_values, **kwargs):
+ client = CustomizedClient(None, **kwargs)
+ # mock out client._connect() rather than hard-settting client.sock to
+ # ensure methods are checking whether self.sock is None before
+ # attempting to use it
+ sock = MockSocket(list(mock_socket_values))
+ client._connect = mock.Mock(side_effect=functools.partial(
+ setattr, client, "sock", sock))
+ return client
+
def test_set_success(self):
client = self.make_client([b'STORED\r\n'])
result = client.set(b'key', b'value', noreply=False)
@@ -273,11 +291,21 @@ class ClientTestMixin(object):
result = client.get(b'key')
assert result is None
+ # Unit test for customized client (override _extract_value)
+ client = self.make_customized_client([b'END\r\n'])
+ result = client.get(b'key')
+ assert result is None
+
def test_get_not_found_default(self):
client = self.make_client([b'END\r\n'])
result = client.get(b'key', default='foobar')
assert result == 'foobar'
+ # Unit test for customized client (override _extract_value)
+ client = self.make_customized_client([b'END\r\n'])
+ result = client.get(b'key', default='foobar')
+ assert result == 'foobar'
+
def test_get_found(self):
client = self.make_client([
b'STORED\r\n',
@@ -287,6 +315,15 @@ class ClientTestMixin(object):
result = client.get(b'key')
assert result == b'value'
+ # Unit test for customized client (override _extract_value)
+ client = self.make_customized_client([
+ b'STORED\r\n',
+ b'VALUE key 0 5\r\nvalue\r\nEND\r\n',
+ ])
+ client.set(b'key', b'value', noreply=False)
+ result = client.get(b'key')
+ assert result == b'value'
+
def test_get_many_none_found(self):
client = self.make_client([b'END\r\n'])
result = client.get_many([b'key1', b'key2'])
@@ -313,8 +350,8 @@ class ClientTestMixin(object):
b'VALUE key1 0 6\r\nvalue1\r\n',
b'VALUE key2 0 6\r\nvalue2\r\nEND\r\n',
])
- result = client.set(b'key1', b'value1', noreply=False)
- result = client.set(b'key2', b'value2', noreply=False)
+ client.set(b'key1', b'value1', noreply=False)
+ client.set(b'key2', b'value2', noreply=False)
result = client.get_many([b'key1', b'key2'])
assert result == {b'key1': b'value1', b'key2': b'value2'}