summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJianGuoPinterest <51248389+JianGuoPinterest@users.noreply.github.com>2019-06-07 15:29:14 -0700
committerJon Parise <jon@pinterest.com>2019-06-07 15:29:14 -0700
commitf23602c176f72cda367f618f20e0c89940431130 (patch)
treeac1a27ad3d600027f139e2123f18d3872844dd55
parent708fb55e2af6d5a58046f1ff5597f0024b5865e2 (diff)
downloadpymemcache-f23602c176f72cda367f618f20e0c89940431130.tar.gz
add encoding in client functions (#232)
-rw-r--r--pymemcache/client/base.py17
-rw-r--r--pymemcache/client/hash.py5
-rw-r--r--pymemcache/test/test_client.py105
-rw-r--r--pymemcache/test/utils.py10
4 files changed, 120 insertions, 17 deletions
diff --git a/pymemcache/client/base.py b/pymemcache/client/base.py
index 7d67bdf..eabef47 100644
--- a/pymemcache/client/base.py
+++ b/pymemcache/client/base.py
@@ -206,7 +206,8 @@ class Client(object):
socket_module=socket,
key_prefix=b'',
default_noreply=True,
- allow_unicode_keys=False):
+ allow_unicode_keys=False,
+ encoding='ascii'):
"""
Constructor.
@@ -233,6 +234,7 @@ class Client(object):
store commands (except from cas, incr, and decr, which default to
False).
allow_unicode_keys: bool, support unicode (utf8) keys
+ encoding: optional str, controls data encoding (defaults to 'ascii').
Notes:
The constructor does not make a connection to memcached. The first
@@ -254,6 +256,7 @@ class Client(object):
self.key_prefix = key_prefix
self.default_noreply = default_noreply
self.allow_unicode_keys = allow_unicode_keys
+ self.encoding = encoding
def check_key(self, key):
"""Checks key and add key_prefix."""
@@ -803,7 +806,7 @@ class Client(object):
extra += b' ' + cas
if noreply:
extra += b' noreply'
- expire = six.text_type(expire).encode('ascii')
+ expire = six.text_type(expire).encode(self.encoding)
for key, data in six.iteritems(values):
# must be able to reliably map responses back to the original order
@@ -817,15 +820,15 @@ class Client(object):
if not isinstance(data, six.binary_type):
try:
- data = six.text_type(data).encode('ascii')
+ data = six.text_type(data).encode(self.encoding)
except UnicodeEncodeError as e:
raise MemcacheIllegalInputError(
"Data values must be binary-safe: %s" % e)
cmds.append(name + b' ' + key + b' ' +
- six.text_type(flags).encode('ascii') +
+ six.text_type(flags).encode(self.encoding) +
b' ' + expire +
- b' ' + six.text_type(len(data)).encode('ascii') +
+ b' ' + six.text_type(len(data)).encode(self.encoding) +
extra + b'\r\n' + data + b'\r\n')
if self.sock is None:
@@ -921,7 +924,8 @@ class PooledClient(object):
max_pool_size=None,
lock_generator=None,
default_noreply=True,
- allow_unicode_keys=False):
+ allow_unicode_keys=False,
+ encoding='ascii'):
self.server = server
self.serializer = serializer
self.deserializer = deserializer
@@ -942,6 +946,7 @@ class PooledClient(object):
after_remove=lambda client: client.close(),
max_size=max_pool_size,
lock_generator=lock_generator)
+ self.encoding = encoding
def check_key(self, key):
"""Checks key and add key_prefix."""
diff --git a/pymemcache/client/hash.py b/pymemcache/client/hash.py
index 1be4882..cd147b2 100644
--- a/pymemcache/client/hash.py
+++ b/pymemcache/client/hash.py
@@ -34,7 +34,8 @@ class HashClient(object):
use_pooling=False,
ignore_exc=False,
allow_unicode_keys=False,
- default_noreply=True
+ default_noreply=True,
+ encoding='ascii'
):
"""
Constructor.
@@ -55,6 +56,7 @@ class HashClient(object):
attempts.
dead_timeout (float): Time in seconds before attempting to add a node
back in the pool.
+ encoding: optional str, controls data encoding (defaults to 'ascii').
Further arguments are interpreted as for :py:class:`.Client`
constructor.
@@ -93,6 +95,7 @@ class HashClient(object):
for server, port in servers:
self.add_server(server, port)
+ self.encoding = encoding
def add_server(self, server, port):
key = '%s:%s' % (server, port)
diff --git a/pymemcache/test/test_client.py b/pymemcache/test/test_client.py
index cf2cfd0..cb53c2c 100644
--- a/pymemcache/test/test_client.py
+++ b/pymemcache/test/test_client.py
@@ -131,11 +131,21 @@ class ClientTestMixin(object):
result = client.set(b'key', b'value', noreply=False)
assert result is True
+ # unit test for encoding passed in __init__()
+ client = self.make_client([b'STORED\r\n'], encoding='utf-8')
+ result = client.set(b'key', b'value', noreply=False)
+ assert result is True
+
def test_set_future(self):
client = self.make_client([b'STORED\r\n'])
result = client.set(newbytes(b'key'), newbytes(b'value'), noreply=False)
assert result is True
+ # unit test for encoding passed in __init__()
+ client = self.make_client([b'STORED\r\n'], encoding='utf-8')
+ result = client.set(newbytes(b'key'), newbytes(b'value'), noreply=False)
+ assert result is True
+
def test_set_unicode_key(self):
client = self.make_client([b''])
@@ -196,26 +206,54 @@ class ClientTestMixin(object):
result = client.set(b'key', b'value', noreply=True)
assert result is True
+ # unit test for encoding passed in __init__()
+ client = self.make_client([], encoding='utf-8')
+ result = client.set(b'key', b'value', noreply=True)
+ assert result is True
+
def test_set_many_success(self):
client = self.make_client([b'STORED\r\n'])
result = client.set_many({b'key': b'value'}, noreply=False)
assert result == []
+ # unit test for encoding passed in __init__()
+ client = self.make_client([b'STORED\r\n'], encoding='utf-8')
+ result = client.set_many({b'key': b'value'}, noreply=False)
+ assert result == []
+
def test_set_multi_success(self):
# Should just map to set_many
client = self.make_client([b'STORED\r\n'])
result = client.set_multi({b'key': b'value'}, noreply=False)
assert result == []
+ # unit test for encoding passed in __init__()
+ client = self.make_client([b'STORED\r\n'], encoding='utf-8')
+ result = client.set_multi({b'key': b'value'}, noreply=False)
+ assert result == []
+
def test_add_stored(self):
client = self.make_client([b'STORED\r', b'\n'])
result = client.add(b'key', b'value', noreply=False)
assert result is True
+ # unit test for encoding passed in __init__()
+ client = self.make_client([b'STORED\r', b'\n'], encoding='utf-8')
+ result = client.add(b'key', b'value', noreply=False)
+ assert result is True
+
def test_add_not_stored(self):
client = self.make_client([b'STORED\r', b'\n',
b'NOT_', b'STOR', b'ED', b'\r\n'])
+ client.add(b'key', b'value', noreply=False)
result = client.add(b'key', b'value', noreply=False)
+ assert result is False
+
+ # unit test for encoding passed in __init__()
+ client = self.make_client([b'STORED\r', b'\n',
+ b'NOT_', b'STOR', b'ED', b'\r\n'],
+ encoding='utf-8')
+ client.add(b'key', b'value', noreply=False)
result = client.add(b'key', b'value', noreply=False)
assert result is False
@@ -234,7 +272,7 @@ class ClientTestMixin(object):
b'STORED\r\n',
b'VALUE key 0 5\r\nvalue\r\nEND\r\n',
])
- result = client.set(b'key', b'value', noreply=False)
+ client.set(b'key', b'value', noreply=False)
result = client.get(b'key')
assert result == b'value'
@@ -253,7 +291,7 @@ class ClientTestMixin(object):
b'STORED\r\n',
b'VALUE key1 0 6\r\nvalue1\r\nEND\r\n',
])
- result = client.set(b'key1', b'value1', noreply=False)
+ client.set(b'key1', b'value1', noreply=False)
result = client.get_many([b'key1', b'key2'])
assert result == {b'key1': b'value1'}
@@ -285,7 +323,7 @@ class ClientTestMixin(object):
def test_delete_found(self):
client = self.make_client([b'STORED\r', b'\n', b'DELETED\r\n'])
- result = client.add(b'key', b'value', noreply=False)
+ client.add(b'key', b'value', noreply=False)
result = client.delete(b'key', noreply=False)
assert result is True
@@ -306,7 +344,7 @@ class ClientTestMixin(object):
def test_delete_many_found(self):
client = self.make_client([b'STORED\r', b'\n', b'DELETED\r\n'])
- result = client.add(b'key', b'value', noreply=False)
+ client.add(b'key', b'value', noreply=False)
result = client.delete_many([b'key'], noreply=False)
assert result is True
@@ -316,7 +354,7 @@ class ClientTestMixin(object):
b'DELETED\r\n',
b'NOT_FOUND\r\n'
])
- result = client.add(b'key', b'value', noreply=False)
+ client.add(b'key', b'value', noreply=False)
result = client.delete_many([b'key', b'key2'], noreply=False)
assert result is True
@@ -326,7 +364,7 @@ class ClientTestMixin(object):
b'DELETED\r\n',
b'NOT_FOUND\r\n'
])
- result = client.add(b'key', b'value', noreply=False)
+ client.add(b'key', b'value', noreply=False)
result = client.delete_multi([b'key', b'key2'], noreply=False)
assert result is True
@@ -370,26 +408,51 @@ class TestClient(ClientTestMixin, unittest.TestCase):
result = client.append(b'key', b'value', noreply=False)
assert result is True
+ # unit test for encoding passed in __init__()
+ client = self.make_client([b'STORED\r\n'], encoding='utf-8')
+ result = client.append(b'key', b'value', noreply=False)
+ assert result is True
+
def test_prepend_stored(self):
client = self.make_client([b'STORED\r\n'])
result = client.prepend(b'key', b'value', noreply=False)
assert result is True
+ # unit test for encoding passed in __init__()
+ client = self.make_client([b'STORED\r\n'], encoding='utf-8')
+ result = client.prepend(b'key', b'value', noreply=False)
+ assert result is True
+
def test_cas_stored(self):
client = self.make_client([b'STORED\r\n'])
result = client.cas(b'key', b'value', b'cas', noreply=False)
assert result is True
+ # unit test for encoding passed in __init__()
+ client = self.make_client([b'STORED\r\n'], encoding='utf-8')
+ result = client.cas(b'key', b'value', b'cas', noreply=False)
+ assert result is True
+
def test_cas_exists(self):
client = self.make_client([b'EXISTS\r\n'])
result = client.cas(b'key', b'value', b'cas', noreply=False)
assert result is False
+ # unit test for encoding passed in __init__()
+ client = self.make_client([b'EXISTS\r\n'], encoding='utf-8')
+ result = client.cas(b'key', b'value', b'cas', noreply=False)
+ assert result is False
+
def test_cas_not_found(self):
client = self.make_client([b'NOT_FOUND\r\n'])
result = client.cas(b'key', b'value', b'cas', noreply=False)
assert result is None
+ # unit test for encoding passed in __init__()
+ client = self.make_client([b'NOT_FOUND\r\n'], encoding='utf-8')
+ result = client.cas(b'key', b'value', b'cas', noreply=False)
+ assert result is None
+
def test_cr_nl_boundaries(self):
client = self.make_client([b'VALUE key1 0 6\r',
b'\nvalue1\r\n'
@@ -538,11 +601,21 @@ class TestClient(ClientTestMixin, unittest.TestCase):
result = client.replace(b'key', b'value', noreply=False)
assert result is True
+ # unit test for encoding passed in __init__()
+ client = self.make_client([b'STORED\r\n'], encoding='utf-8')
+ result = client.replace(b'key', b'value', noreply=False)
+ assert result is True
+
def test_replace_not_stored(self):
client = self.make_client([b'NOT_STORED\r\n'])
result = client.replace(b'key', b'value', noreply=False)
assert result is False
+ # unit test for encoding passed in __init__
+ client = self.make_client([b'NOT_STORED\r\n'], encoding='utf-8')
+ result = client.replace(b'key', b'value', noreply=False)
+ assert result is False
+
def test_serialization(self):
def _ser(key, value):
return json.dumps(value), 0
@@ -649,6 +722,13 @@ class TestClient(ClientTestMixin, unittest.TestCase):
assert client.sock.closed is False
assert len(client.sock.send_bufs) == 1
+ # unit test for encoding passed in __init__()
+ client = self.make_client([b'STORED\r\n'], encoding='utf-8')
+ result = client.set_many({b'key': b'value'}, noreply=False)
+ assert result == []
+ assert client.sock.closed is False
+ assert len(client.sock.send_bufs) == 1
+
def test_set_many_exception(self):
client = self.make_client([b'STORED\r\n', Exception('fail')])
@@ -661,6 +741,19 @@ class TestClient(ClientTestMixin, unittest.TestCase):
assert client.sock is None
+ # unit test for encoding passed in __init__()
+ client = self.make_client([b'STORED\r\n', Exception('fail')],
+ encoding='utf-8')
+
+ def _set():
+ client.set_many({b'key': b'value', b'other': b'value'},
+ noreply=False)
+
+ with pytest.raises(Exception):
+ _set()
+
+ assert client.sock is None
+
def test_stats(self):
client = self.make_client([b'STAT fake_stats 1\r\n', b'END\r\n'])
result = client.stats()
diff --git a/pymemcache/test/utils.py b/pymemcache/test/utils.py
index a30e98c..826e052 100644
--- a/pymemcache/test/utils.py
+++ b/pymemcache/test/utils.py
@@ -27,7 +27,8 @@ class MockMemcacheClient(object):
no_delay=False,
ignore_exc=False,
default_noreply=True,
- allow_unicode_keys=False):
+ allow_unicode_keys=False,
+ encoding='ascii'):
self._contents = {}
@@ -41,6 +42,7 @@ class MockMemcacheClient(object):
self.timeout = timeout
self.no_delay = no_delay
self.ignore_exc = ignore_exc
+ self.encoding = encoding
def get(self, key, default=None):
if not self.allow_unicode_keys:
@@ -80,15 +82,15 @@ class MockMemcacheClient(object):
if isinstance(key, six.string_types):
try:
if isinstance(key, bytes):
- key = key.decode().encode('ascii')
+ key = key.decode().encode()
else:
- key = key.encode('ascii')
+ key = key.encode(self.encoding)
except (UnicodeEncodeError, UnicodeDecodeError):
raise MemcacheIllegalInputError
if (isinstance(value, six.string_types) and
not isinstance(value, six.binary_type)):
try:
- value = value.encode('ascii')
+ value = value.encode(self.encoding)
except (UnicodeEncodeError, UnicodeDecodeError):
raise MemcacheIllegalInputError