diff options
author | Sean Reifschneider <sean@realgo.com> | 2023-04-18 05:57:59 -0600 |
---|---|---|
committer | Sean Reifschneider <sean@realgo.com> | 2023-04-18 05:57:59 -0600 |
commit | 3c8465f96df7823ccb8b5872871d21c41a024fc0 (patch) | |
tree | d5a82c4bb197d4d8814b0b60c644550d00ffb815 | |
parent | 88b83c6f8bfe056735cf026ad869b933ff8fb892 (diff) | |
download | python-memcached-3c8465f96df7823ccb8b5872871d21c41a024fc0.tar.gz |
Allow keys to be encoded before use.
Ported patch in #52 from @harlowja to current branch. Added tests.
For the cases where the user wants to transparently
encode keys (say using urllib) before they are used
further allow a encoding function to be passed in that
will perform these types of activities (by default it
is the identity function).
-rw-r--r-- | memcache.py | 45 | ||||
-rw-r--r-- | tests/test_memcache.py | 27 |
2 files changed, 54 insertions, 18 deletions
diff --git a/memcache.py b/memcache.py index 11da6c1..e3e1d85 100644 --- a/memcache.py +++ b/memcache.py @@ -162,7 +162,8 @@ class Client(threading.local): pload=None, pid=None, server_max_key_length=None, server_max_value_length=None, dead_retry=_DEAD_RETRY, socket_timeout=_SOCKET_TIMEOUT, - cache_cas=False, flush_on_reconnect=0, check_keys=True): + cache_cas=False, flush_on_reconnect=0, check_keys=True, + key_encoder=None): """Create a new Client object with the given list of servers. @param servers: C{servers} is passed to L{set_servers}. @@ -205,6 +206,10 @@ class Client(threading.local): @param check_keys: (default True) If True, the key is checked to ensure it is the correct length and composed of the right characters. + @param key_encoder: (default None) If provided a functor that will + be called to encode keys before they are checked and used. It will + be expected to take one parameter (the key) and return a new encoded + key as a result. """ super(Client, self).__init__() self.debug = debug @@ -226,6 +231,10 @@ class Client(threading.local): self.persistent_load = pload self.persistent_id = pid self.server_max_key_length = server_max_key_length + if key_encoder is None: + def key_encoder(key): + return key + self.key_encoder = key_encoder if self.server_max_key_length is None: self.server_max_key_length = SERVER_MAX_KEY_LENGTH self.server_max_value_length = server_max_value_length @@ -494,7 +503,7 @@ class Client(threading.local): else: headers = None for key in server_keys[server]: # These are mangled keys - cmd = self._encode_cmd('delete', key, headers, noreply, b'\r\n') + cmd = self._encode_cmd('delete', self.key_encoder(key), headers, noreply, b'\r\n') write(cmd) try: server.send_cmds(b''.join(bigcmd)) @@ -532,7 +541,7 @@ class Client(threading.local): reply. @rtype: int ''' - key = self._encode_key(key) + key = self._encode_key(self.key_encoder(key)) if self.do_check_key: self.check_key(key) server, key = self._get_server(key) @@ -568,7 +577,7 @@ class Client(threading.local): reply. @rtype: int ''' - key = self._encode_key(key) + key = self._encode_key(self.key_encoder(key)) if self.do_check_key: self.check_key(key) server, key = self._get_server(key) @@ -622,7 +631,7 @@ class Client(threading.local): @return: New value after incrementing, no None for noreply or error. @rtype: int """ - return self._incrdecr("incr", key, delta, noreply) + return self._incrdecr("incr", self.key_encoder(key), delta, noreply) def decr(self, key, delta=1, noreply=False): """Decrement value for C{key} by C{delta} @@ -640,7 +649,7 @@ class Client(threading.local): @return: New value after decrementing, or None for noreply or error. @rtype: int """ - return self._incrdecr("decr", key, delta, noreply) + return self._incrdecr("decr", self.key_encoder(key), delta, noreply) def _incrdecr(self, cmd, key, delta, noreply=False): key = self._encode_key(key) @@ -674,7 +683,7 @@ class Client(threading.local): @return: Nonzero on success. @rtype: int ''' - return self._set("add", key, val, time, min_compress_len, noreply) + return self._set("add", self.key_encoder(key), val, time, min_compress_len, noreply) def append(self, key, val, time=0, min_compress_len=0, noreply=False): '''Append the value to the end of the existing key's value. @@ -685,7 +694,7 @@ class Client(threading.local): @return: Nonzero on success. @rtype: int ''' - return self._set("append", key, val, time, min_compress_len, noreply) + return self._set("append", self.key_encoder(key), val, time, min_compress_len, noreply) def prepend(self, key, val, time=0, min_compress_len=0, noreply=False): '''Prepend the value to the beginning of the existing key's value. @@ -696,7 +705,7 @@ class Client(threading.local): @return: Nonzero on success. @rtype: int ''' - return self._set("prepend", key, val, time, min_compress_len, noreply) + return self._set("prepend", self.key_encoder(key), val, time, min_compress_len, noreply) def replace(self, key, val, time=0, min_compress_len=0, noreply=False): '''Replace existing key with value. @@ -707,7 +716,7 @@ class Client(threading.local): @return: Nonzero on success. @rtype: int ''' - return self._set("replace", key, val, time, min_compress_len, noreply) + return self._set("replace", self.key_encoder(key), val, time, min_compress_len, noreply) def set(self, key, val, time=0, min_compress_len=0, noreply=False): '''Unconditionally sets a key to a given value in the memcache. @@ -743,7 +752,7 @@ class Client(threading.local): ''' if isinstance(time, timedelta): time = int(time.total_seconds()) - return self._set("set", key, val, time, min_compress_len, noreply) + return self._set("set", self.key_encoder(key), val, time, min_compress_len, noreply) def cas(self, key, val, time=0, min_compress_len=0, noreply=False): '''Check and set (CAS) @@ -780,7 +789,7 @@ class Client(threading.local): @param noreply: optional parameter instructs the server to not send the reply. ''' - return self._set("cas", key, val, time, min_compress_len, noreply) + return self._set("cas", self.key_encoder(key), val, time, min_compress_len, noreply) def _map_and_prefix_keys(self, key_iterable, key_prefix): """Map keys to the servers they will reside on. @@ -807,7 +816,7 @@ class Client(threading.local): # Ensure call to _get_server gets a Tuple as well. serverhash, key = orig_key - key = self._encode_key(key) + key = self._encode_key(self.key_encoder(key)) if not isinstance(key, six.binary_type): # set_multi supports int / long keys. key = str(key).encode('utf8') @@ -818,7 +827,7 @@ class Client(threading.local): server, key = self._get_server( (serverhash, key_prefix + key)) else: - key = self._encode_key(orig_key) + key = self._encode_key(self.key_encoder(orig_key)) if not isinstance(key, six.binary_type): # set_multi supports int / long keys. key = str(key).encode('utf8') @@ -923,7 +932,7 @@ class Client(threading.local): if store_info: flags, len_val, val = store_info headers = "%d %d %d" % (flags, time, len_val) - fullcmd = self._encode_cmd('set', key, headers, + fullcmd = self._encode_cmd('set', self.key_encoder(key), headers, noreply, b'\r\n', val, b'\r\n') write(fullcmd) @@ -1121,14 +1130,14 @@ class Client(threading.local): @return: The value or None. ''' - return self._get('get', key, default) + return self._get('get', self.key_encoder(key), default) def gets(self, key): '''Retrieves a key from the memcache. Used in conjunction with 'cas'. @return: The value or None. ''' - return self._get('gets', key) + return self._get('gets', self.key_encoder(key)) def get_multi(self, keys, key_prefix=''): '''Retrieves multiple keys from the memcache doing just one query. @@ -1188,7 +1197,7 @@ class Client(threading.local): self._statlog('get_multi') server_keys, prefixed_to_orig_key = self._map_and_prefix_keys( - keys, key_prefix) + [self.key_encoder(k) for k in keys], key_prefix) # send out all requests on each server before reading anything dead_servers = [] diff --git a/tests/test_memcache.py b/tests/test_memcache.py index 3593e03..2258d5a 100644 --- a/tests/test_memcache.py +++ b/tests/test_memcache.py @@ -252,5 +252,32 @@ class TestMemcache(unittest.TestCase): ) +class TestMemcacheEncoder(unittest.TestCase): + def setUp(self): + # TODO(): unix socket server stuff + servers = ["127.0.0.1:11211"] + self.mc = Client(servers, debug=1, key_encoder=self.encoder) + + def tearDown(self): + self.mc.flush_all() + self.mc.disconnect_all() + + def encoder(self, key): + return key.lower() + + def check_setget(self, key, val, noreply=False): + self.mc.set(key, val, noreply=noreply) + newval = self.mc.get(key) + self.assertEqual(newval, val) + + def test_setget(self): + self.check_setget("a_string", "some random string") + self.check_setget("A_String2", "some random string") + self.check_setget("an_integer", 42) + self.assertEqual("some random string", self.mc.get("A_String")) + self.assertEqual("some random string", self.mc.get("a_sTRing2")) + self.assertEqual(42, self.mc.get("An_Integer")) + + if __name__ == '__main__': unittest.main() |