summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSean Reifschneider <sean@realgo.com>2023-04-18 05:57:59 -0600
committerSean Reifschneider <sean@realgo.com>2023-04-18 05:57:59 -0600
commit3c8465f96df7823ccb8b5872871d21c41a024fc0 (patch)
treed5a82c4bb197d4d8814b0b60c644550d00ffb815
parent88b83c6f8bfe056735cf026ad869b933ff8fb892 (diff)
downloadpython-memcached-3c8465f96df7823ccb8b5872871d21c41a024fc0.tar.gz
Allow keys to be encoded before use.
Ported patch in #52 from @harlowja to current branch. Added tests. For the cases where the user wants to transparently encode keys (say using urllib) before they are used further allow a encoding function to be passed in that will perform these types of activities (by default it is the identity function).
-rw-r--r--memcache.py45
-rw-r--r--tests/test_memcache.py27
2 files changed, 54 insertions, 18 deletions
diff --git a/memcache.py b/memcache.py
index 11da6c1..e3e1d85 100644
--- a/memcache.py
+++ b/memcache.py
@@ -162,7 +162,8 @@ class Client(threading.local):
pload=None, pid=None,
server_max_key_length=None, server_max_value_length=None,
dead_retry=_DEAD_RETRY, socket_timeout=_SOCKET_TIMEOUT,
- cache_cas=False, flush_on_reconnect=0, check_keys=True):
+ cache_cas=False, flush_on_reconnect=0, check_keys=True,
+ key_encoder=None):
"""Create a new Client object with the given list of servers.
@param servers: C{servers} is passed to L{set_servers}.
@@ -205,6 +206,10 @@ class Client(threading.local):
@param check_keys: (default True) If True, the key is checked
to ensure it is the correct length and composed of the right
characters.
+ @param key_encoder: (default None) If provided a functor that will
+ be called to encode keys before they are checked and used. It will
+ be expected to take one parameter (the key) and return a new encoded
+ key as a result.
"""
super(Client, self).__init__()
self.debug = debug
@@ -226,6 +231,10 @@ class Client(threading.local):
self.persistent_load = pload
self.persistent_id = pid
self.server_max_key_length = server_max_key_length
+ if key_encoder is None:
+ def key_encoder(key):
+ return key
+ self.key_encoder = key_encoder
if self.server_max_key_length is None:
self.server_max_key_length = SERVER_MAX_KEY_LENGTH
self.server_max_value_length = server_max_value_length
@@ -494,7 +503,7 @@ class Client(threading.local):
else:
headers = None
for key in server_keys[server]: # These are mangled keys
- cmd = self._encode_cmd('delete', key, headers, noreply, b'\r\n')
+ cmd = self._encode_cmd('delete', self.key_encoder(key), headers, noreply, b'\r\n')
write(cmd)
try:
server.send_cmds(b''.join(bigcmd))
@@ -532,7 +541,7 @@ class Client(threading.local):
reply.
@rtype: int
'''
- key = self._encode_key(key)
+ key = self._encode_key(self.key_encoder(key))
if self.do_check_key:
self.check_key(key)
server, key = self._get_server(key)
@@ -568,7 +577,7 @@ class Client(threading.local):
reply.
@rtype: int
'''
- key = self._encode_key(key)
+ key = self._encode_key(self.key_encoder(key))
if self.do_check_key:
self.check_key(key)
server, key = self._get_server(key)
@@ -622,7 +631,7 @@ class Client(threading.local):
@return: New value after incrementing, no None for noreply or error.
@rtype: int
"""
- return self._incrdecr("incr", key, delta, noreply)
+ return self._incrdecr("incr", self.key_encoder(key), delta, noreply)
def decr(self, key, delta=1, noreply=False):
"""Decrement value for C{key} by C{delta}
@@ -640,7 +649,7 @@ class Client(threading.local):
@return: New value after decrementing, or None for noreply or error.
@rtype: int
"""
- return self._incrdecr("decr", key, delta, noreply)
+ return self._incrdecr("decr", self.key_encoder(key), delta, noreply)
def _incrdecr(self, cmd, key, delta, noreply=False):
key = self._encode_key(key)
@@ -674,7 +683,7 @@ class Client(threading.local):
@return: Nonzero on success.
@rtype: int
'''
- return self._set("add", key, val, time, min_compress_len, noreply)
+ return self._set("add", self.key_encoder(key), val, time, min_compress_len, noreply)
def append(self, key, val, time=0, min_compress_len=0, noreply=False):
'''Append the value to the end of the existing key's value.
@@ -685,7 +694,7 @@ class Client(threading.local):
@return: Nonzero on success.
@rtype: int
'''
- return self._set("append", key, val, time, min_compress_len, noreply)
+ return self._set("append", self.key_encoder(key), val, time, min_compress_len, noreply)
def prepend(self, key, val, time=0, min_compress_len=0, noreply=False):
'''Prepend the value to the beginning of the existing key's value.
@@ -696,7 +705,7 @@ class Client(threading.local):
@return: Nonzero on success.
@rtype: int
'''
- return self._set("prepend", key, val, time, min_compress_len, noreply)
+ return self._set("prepend", self.key_encoder(key), val, time, min_compress_len, noreply)
def replace(self, key, val, time=0, min_compress_len=0, noreply=False):
'''Replace existing key with value.
@@ -707,7 +716,7 @@ class Client(threading.local):
@return: Nonzero on success.
@rtype: int
'''
- return self._set("replace", key, val, time, min_compress_len, noreply)
+ return self._set("replace", self.key_encoder(key), val, time, min_compress_len, noreply)
def set(self, key, val, time=0, min_compress_len=0, noreply=False):
'''Unconditionally sets a key to a given value in the memcache.
@@ -743,7 +752,7 @@ class Client(threading.local):
'''
if isinstance(time, timedelta):
time = int(time.total_seconds())
- return self._set("set", key, val, time, min_compress_len, noreply)
+ return self._set("set", self.key_encoder(key), val, time, min_compress_len, noreply)
def cas(self, key, val, time=0, min_compress_len=0, noreply=False):
'''Check and set (CAS)
@@ -780,7 +789,7 @@ class Client(threading.local):
@param noreply: optional parameter instructs the server to not
send the reply.
'''
- return self._set("cas", key, val, time, min_compress_len, noreply)
+ return self._set("cas", self.key_encoder(key), val, time, min_compress_len, noreply)
def _map_and_prefix_keys(self, key_iterable, key_prefix):
"""Map keys to the servers they will reside on.
@@ -807,7 +816,7 @@ class Client(threading.local):
# Ensure call to _get_server gets a Tuple as well.
serverhash, key = orig_key
- key = self._encode_key(key)
+ key = self._encode_key(self.key_encoder(key))
if not isinstance(key, six.binary_type):
# set_multi supports int / long keys.
key = str(key).encode('utf8')
@@ -818,7 +827,7 @@ class Client(threading.local):
server, key = self._get_server(
(serverhash, key_prefix + key))
else:
- key = self._encode_key(orig_key)
+ key = self._encode_key(self.key_encoder(orig_key))
if not isinstance(key, six.binary_type):
# set_multi supports int / long keys.
key = str(key).encode('utf8')
@@ -923,7 +932,7 @@ class Client(threading.local):
if store_info:
flags, len_val, val = store_info
headers = "%d %d %d" % (flags, time, len_val)
- fullcmd = self._encode_cmd('set', key, headers,
+ fullcmd = self._encode_cmd('set', self.key_encoder(key), headers,
noreply,
b'\r\n', val, b'\r\n')
write(fullcmd)
@@ -1121,14 +1130,14 @@ class Client(threading.local):
@return: The value or None.
'''
- return self._get('get', key, default)
+ return self._get('get', self.key_encoder(key), default)
def gets(self, key):
'''Retrieves a key from the memcache. Used in conjunction with 'cas'.
@return: The value or None.
'''
- return self._get('gets', key)
+ return self._get('gets', self.key_encoder(key))
def get_multi(self, keys, key_prefix=''):
'''Retrieves multiple keys from the memcache doing just one query.
@@ -1188,7 +1197,7 @@ class Client(threading.local):
self._statlog('get_multi')
server_keys, prefixed_to_orig_key = self._map_and_prefix_keys(
- keys, key_prefix)
+ [self.key_encoder(k) for k in keys], key_prefix)
# send out all requests on each server before reading anything
dead_servers = []
diff --git a/tests/test_memcache.py b/tests/test_memcache.py
index 3593e03..2258d5a 100644
--- a/tests/test_memcache.py
+++ b/tests/test_memcache.py
@@ -252,5 +252,32 @@ class TestMemcache(unittest.TestCase):
)
+class TestMemcacheEncoder(unittest.TestCase):
+ def setUp(self):
+ # TODO(): unix socket server stuff
+ servers = ["127.0.0.1:11211"]
+ self.mc = Client(servers, debug=1, key_encoder=self.encoder)
+
+ def tearDown(self):
+ self.mc.flush_all()
+ self.mc.disconnect_all()
+
+ def encoder(self, key):
+ return key.lower()
+
+ def check_setget(self, key, val, noreply=False):
+ self.mc.set(key, val, noreply=noreply)
+ newval = self.mc.get(key)
+ self.assertEqual(newval, val)
+
+ def test_setget(self):
+ self.check_setget("a_string", "some random string")
+ self.check_setget("A_String2", "some random string")
+ self.check_setget("an_integer", 42)
+ self.assertEqual("some random string", self.mc.get("A_String"))
+ self.assertEqual("some random string", self.mc.get("a_sTRing2"))
+ self.assertEqual(42, self.mc.get("An_Integer"))
+
+
if __name__ == '__main__':
unittest.main()