diff options
author | JianGuoPinterest <51248389+JianGuoPinterest@users.noreply.github.com> | 2019-06-07 15:29:14 -0700 |
---|---|---|
committer | Jon Parise <jon@pinterest.com> | 2019-06-07 15:29:14 -0700 |
commit | f23602c176f72cda367f618f20e0c89940431130 (patch) | |
tree | ac1a27ad3d600027f139e2123f18d3872844dd55 | |
parent | 708fb55e2af6d5a58046f1ff5597f0024b5865e2 (diff) | |
download | pymemcache-f23602c176f72cda367f618f20e0c89940431130.tar.gz |
add encoding in client functions (#232)
-rw-r--r-- | pymemcache/client/base.py | 17 | ||||
-rw-r--r-- | pymemcache/client/hash.py | 5 | ||||
-rw-r--r-- | pymemcache/test/test_client.py | 105 | ||||
-rw-r--r-- | pymemcache/test/utils.py | 10 |
4 files changed, 120 insertions, 17 deletions
diff --git a/pymemcache/client/base.py b/pymemcache/client/base.py index 7d67bdf..eabef47 100644 --- a/pymemcache/client/base.py +++ b/pymemcache/client/base.py @@ -206,7 +206,8 @@ class Client(object): socket_module=socket, key_prefix=b'', default_noreply=True, - allow_unicode_keys=False): + allow_unicode_keys=False, + encoding='ascii'): """ Constructor. @@ -233,6 +234,7 @@ class Client(object): store commands (except from cas, incr, and decr, which default to False). allow_unicode_keys: bool, support unicode (utf8) keys + encoding: optional str, controls data encoding (defaults to 'ascii'). Notes: The constructor does not make a connection to memcached. The first @@ -254,6 +256,7 @@ class Client(object): self.key_prefix = key_prefix self.default_noreply = default_noreply self.allow_unicode_keys = allow_unicode_keys + self.encoding = encoding def check_key(self, key): """Checks key and add key_prefix.""" @@ -803,7 +806,7 @@ class Client(object): extra += b' ' + cas if noreply: extra += b' noreply' - expire = six.text_type(expire).encode('ascii') + expire = six.text_type(expire).encode(self.encoding) for key, data in six.iteritems(values): # must be able to reliably map responses back to the original order @@ -817,15 +820,15 @@ class Client(object): if not isinstance(data, six.binary_type): try: - data = six.text_type(data).encode('ascii') + data = six.text_type(data).encode(self.encoding) except UnicodeEncodeError as e: raise MemcacheIllegalInputError( "Data values must be binary-safe: %s" % e) cmds.append(name + b' ' + key + b' ' + - six.text_type(flags).encode('ascii') + + six.text_type(flags).encode(self.encoding) + b' ' + expire + - b' ' + six.text_type(len(data)).encode('ascii') + + b' ' + six.text_type(len(data)).encode(self.encoding) + extra + b'\r\n' + data + b'\r\n') if self.sock is None: @@ -921,7 +924,8 @@ class PooledClient(object): max_pool_size=None, lock_generator=None, default_noreply=True, - allow_unicode_keys=False): + allow_unicode_keys=False, + encoding='ascii'): self.server = server self.serializer = serializer self.deserializer = deserializer @@ -942,6 +946,7 @@ class PooledClient(object): after_remove=lambda client: client.close(), max_size=max_pool_size, lock_generator=lock_generator) + self.encoding = encoding def check_key(self, key): """Checks key and add key_prefix.""" diff --git a/pymemcache/client/hash.py b/pymemcache/client/hash.py index 1be4882..cd147b2 100644 --- a/pymemcache/client/hash.py +++ b/pymemcache/client/hash.py @@ -34,7 +34,8 @@ class HashClient(object): use_pooling=False, ignore_exc=False, allow_unicode_keys=False, - default_noreply=True + default_noreply=True, + encoding='ascii' ): """ Constructor. @@ -55,6 +56,7 @@ class HashClient(object): attempts. dead_timeout (float): Time in seconds before attempting to add a node back in the pool. + encoding: optional str, controls data encoding (defaults to 'ascii'). Further arguments are interpreted as for :py:class:`.Client` constructor. @@ -93,6 +95,7 @@ class HashClient(object): for server, port in servers: self.add_server(server, port) + self.encoding = encoding def add_server(self, server, port): key = '%s:%s' % (server, port) diff --git a/pymemcache/test/test_client.py b/pymemcache/test/test_client.py index cf2cfd0..cb53c2c 100644 --- a/pymemcache/test/test_client.py +++ b/pymemcache/test/test_client.py @@ -131,11 +131,21 @@ class ClientTestMixin(object): result = client.set(b'key', b'value', noreply=False) assert result is True + # unit test for encoding passed in __init__() + client = self.make_client([b'STORED\r\n'], encoding='utf-8') + result = client.set(b'key', b'value', noreply=False) + assert result is True + def test_set_future(self): client = self.make_client([b'STORED\r\n']) result = client.set(newbytes(b'key'), newbytes(b'value'), noreply=False) assert result is True + # unit test for encoding passed in __init__() + client = self.make_client([b'STORED\r\n'], encoding='utf-8') + result = client.set(newbytes(b'key'), newbytes(b'value'), noreply=False) + assert result is True + def test_set_unicode_key(self): client = self.make_client([b'']) @@ -196,26 +206,54 @@ class ClientTestMixin(object): result = client.set(b'key', b'value', noreply=True) assert result is True + # unit test for encoding passed in __init__() + client = self.make_client([], encoding='utf-8') + result = client.set(b'key', b'value', noreply=True) + assert result is True + def test_set_many_success(self): client = self.make_client([b'STORED\r\n']) result = client.set_many({b'key': b'value'}, noreply=False) assert result == [] + # unit test for encoding passed in __init__() + client = self.make_client([b'STORED\r\n'], encoding='utf-8') + result = client.set_many({b'key': b'value'}, noreply=False) + assert result == [] + def test_set_multi_success(self): # Should just map to set_many client = self.make_client([b'STORED\r\n']) result = client.set_multi({b'key': b'value'}, noreply=False) assert result == [] + # unit test for encoding passed in __init__() + client = self.make_client([b'STORED\r\n'], encoding='utf-8') + result = client.set_multi({b'key': b'value'}, noreply=False) + assert result == [] + def test_add_stored(self): client = self.make_client([b'STORED\r', b'\n']) result = client.add(b'key', b'value', noreply=False) assert result is True + # unit test for encoding passed in __init__() + client = self.make_client([b'STORED\r', b'\n'], encoding='utf-8') + result = client.add(b'key', b'value', noreply=False) + assert result is True + def test_add_not_stored(self): client = self.make_client([b'STORED\r', b'\n', b'NOT_', b'STOR', b'ED', b'\r\n']) + client.add(b'key', b'value', noreply=False) result = client.add(b'key', b'value', noreply=False) + assert result is False + + # unit test for encoding passed in __init__() + client = self.make_client([b'STORED\r', b'\n', + b'NOT_', b'STOR', b'ED', b'\r\n'], + encoding='utf-8') + client.add(b'key', b'value', noreply=False) result = client.add(b'key', b'value', noreply=False) assert result is False @@ -234,7 +272,7 @@ class ClientTestMixin(object): b'STORED\r\n', b'VALUE key 0 5\r\nvalue\r\nEND\r\n', ]) - result = client.set(b'key', b'value', noreply=False) + client.set(b'key', b'value', noreply=False) result = client.get(b'key') assert result == b'value' @@ -253,7 +291,7 @@ class ClientTestMixin(object): b'STORED\r\n', b'VALUE key1 0 6\r\nvalue1\r\nEND\r\n', ]) - result = client.set(b'key1', b'value1', noreply=False) + client.set(b'key1', b'value1', noreply=False) result = client.get_many([b'key1', b'key2']) assert result == {b'key1': b'value1'} @@ -285,7 +323,7 @@ class ClientTestMixin(object): def test_delete_found(self): client = self.make_client([b'STORED\r', b'\n', b'DELETED\r\n']) - result = client.add(b'key', b'value', noreply=False) + client.add(b'key', b'value', noreply=False) result = client.delete(b'key', noreply=False) assert result is True @@ -306,7 +344,7 @@ class ClientTestMixin(object): def test_delete_many_found(self): client = self.make_client([b'STORED\r', b'\n', b'DELETED\r\n']) - result = client.add(b'key', b'value', noreply=False) + client.add(b'key', b'value', noreply=False) result = client.delete_many([b'key'], noreply=False) assert result is True @@ -316,7 +354,7 @@ class ClientTestMixin(object): b'DELETED\r\n', b'NOT_FOUND\r\n' ]) - result = client.add(b'key', b'value', noreply=False) + client.add(b'key', b'value', noreply=False) result = client.delete_many([b'key', b'key2'], noreply=False) assert result is True @@ -326,7 +364,7 @@ class ClientTestMixin(object): b'DELETED\r\n', b'NOT_FOUND\r\n' ]) - result = client.add(b'key', b'value', noreply=False) + client.add(b'key', b'value', noreply=False) result = client.delete_multi([b'key', b'key2'], noreply=False) assert result is True @@ -370,26 +408,51 @@ class TestClient(ClientTestMixin, unittest.TestCase): result = client.append(b'key', b'value', noreply=False) assert result is True + # unit test for encoding passed in __init__() + client = self.make_client([b'STORED\r\n'], encoding='utf-8') + result = client.append(b'key', b'value', noreply=False) + assert result is True + def test_prepend_stored(self): client = self.make_client([b'STORED\r\n']) result = client.prepend(b'key', b'value', noreply=False) assert result is True + # unit test for encoding passed in __init__() + client = self.make_client([b'STORED\r\n'], encoding='utf-8') + result = client.prepend(b'key', b'value', noreply=False) + assert result is True + def test_cas_stored(self): client = self.make_client([b'STORED\r\n']) result = client.cas(b'key', b'value', b'cas', noreply=False) assert result is True + # unit test for encoding passed in __init__() + client = self.make_client([b'STORED\r\n'], encoding='utf-8') + result = client.cas(b'key', b'value', b'cas', noreply=False) + assert result is True + def test_cas_exists(self): client = self.make_client([b'EXISTS\r\n']) result = client.cas(b'key', b'value', b'cas', noreply=False) assert result is False + # unit test for encoding passed in __init__() + client = self.make_client([b'EXISTS\r\n'], encoding='utf-8') + result = client.cas(b'key', b'value', b'cas', noreply=False) + assert result is False + def test_cas_not_found(self): client = self.make_client([b'NOT_FOUND\r\n']) result = client.cas(b'key', b'value', b'cas', noreply=False) assert result is None + # unit test for encoding passed in __init__() + client = self.make_client([b'NOT_FOUND\r\n'], encoding='utf-8') + result = client.cas(b'key', b'value', b'cas', noreply=False) + assert result is None + def test_cr_nl_boundaries(self): client = self.make_client([b'VALUE key1 0 6\r', b'\nvalue1\r\n' @@ -538,11 +601,21 @@ class TestClient(ClientTestMixin, unittest.TestCase): result = client.replace(b'key', b'value', noreply=False) assert result is True + # unit test for encoding passed in __init__() + client = self.make_client([b'STORED\r\n'], encoding='utf-8') + result = client.replace(b'key', b'value', noreply=False) + assert result is True + def test_replace_not_stored(self): client = self.make_client([b'NOT_STORED\r\n']) result = client.replace(b'key', b'value', noreply=False) assert result is False + # unit test for encoding passed in __init__ + client = self.make_client([b'NOT_STORED\r\n'], encoding='utf-8') + result = client.replace(b'key', b'value', noreply=False) + assert result is False + def test_serialization(self): def _ser(key, value): return json.dumps(value), 0 @@ -649,6 +722,13 @@ class TestClient(ClientTestMixin, unittest.TestCase): assert client.sock.closed is False assert len(client.sock.send_bufs) == 1 + # unit test for encoding passed in __init__() + client = self.make_client([b'STORED\r\n'], encoding='utf-8') + result = client.set_many({b'key': b'value'}, noreply=False) + assert result == [] + assert client.sock.closed is False + assert len(client.sock.send_bufs) == 1 + def test_set_many_exception(self): client = self.make_client([b'STORED\r\n', Exception('fail')]) @@ -661,6 +741,19 @@ class TestClient(ClientTestMixin, unittest.TestCase): assert client.sock is None + # unit test for encoding passed in __init__() + client = self.make_client([b'STORED\r\n', Exception('fail')], + encoding='utf-8') + + def _set(): + client.set_many({b'key': b'value', b'other': b'value'}, + noreply=False) + + with pytest.raises(Exception): + _set() + + assert client.sock is None + def test_stats(self): client = self.make_client([b'STAT fake_stats 1\r\n', b'END\r\n']) result = client.stats() diff --git a/pymemcache/test/utils.py b/pymemcache/test/utils.py index a30e98c..826e052 100644 --- a/pymemcache/test/utils.py +++ b/pymemcache/test/utils.py @@ -27,7 +27,8 @@ class MockMemcacheClient(object): no_delay=False, ignore_exc=False, default_noreply=True, - allow_unicode_keys=False): + allow_unicode_keys=False, + encoding='ascii'): self._contents = {} @@ -41,6 +42,7 @@ class MockMemcacheClient(object): self.timeout = timeout self.no_delay = no_delay self.ignore_exc = ignore_exc + self.encoding = encoding def get(self, key, default=None): if not self.allow_unicode_keys: @@ -80,15 +82,15 @@ class MockMemcacheClient(object): if isinstance(key, six.string_types): try: if isinstance(key, bytes): - key = key.decode().encode('ascii') + key = key.decode().encode() else: - key = key.encode('ascii') + key = key.encode(self.encoding) except (UnicodeEncodeError, UnicodeDecodeError): raise MemcacheIllegalInputError if (isinstance(value, six.string_types) and not isinstance(value, six.binary_type)): try: - value = value.encode('ascii') + value = value.encode(self.encoding) except (UnicodeEncodeError, UnicodeDecodeError): raise MemcacheIllegalInputError |