From d0c180d27242a1176f98549761825210958d8a8a Mon Sep 17 00:00:00 2001 From: Martin J Date: Thu, 9 Jun 2022 21:33:43 +0200 Subject: Expand Client with a method for sending arbitrary commands. (#395) --- pymemcache/client/base.py | 74 ++++++++++++++++++++++++++++++++++++++++-- pymemcache/test/test_client.py | 57 ++++++++++++++++++++++++++++++++ 2 files changed, 129 insertions(+), 2 deletions(-) (limited to 'pymemcache') diff --git a/pymemcache/client/base.py b/pymemcache/client/base.py index 2654661..3d4234c 100644 --- a/pymemcache/client/base.py +++ b/pymemcache/client/base.py @@ -13,6 +13,7 @@ # limitations under the License. import errno +from functools import partial import platform import socket from typing import Tuple, Union @@ -867,6 +868,27 @@ class Client: raise MemcacheUnknownError("Received unexpected response: %s" % results[0]) return after + def raw_command(self, command, end_tokens="\r\n"): + """ + Sends an arbitrary command to the server and parses the response until a + specified token is encountered. + + Args: + command: str|bytes: The command to send. + end_tokens: str|bytes: The token expected at the end of the + response. If the `end_token` is not found, the client will wait + until the timeout specified in the constructor. + + Returns: + The response from the server, with the `end_token` removed. + """ + encoding = "utf8" if self.allow_unicode_keys else "ascii" + command = command.encode(encoding) if isinstance(command, str) else command + end_tokens = ( + end_tokens.encode(encoding) if isinstance(end_tokens, str) else end_tokens + ) + return self._misc_cmd([b"" + command + b"\r\n"], command, False, end_tokens)[0] + def flush_all(self, delay=0, noreply=None): """ The memcached "flush_all" command. @@ -1126,7 +1148,15 @@ class Client: self.close() raise - def _misc_cmd(self, cmds, cmd_name, noreply): + def _misc_cmd(self, cmds, cmd_name, noreply, end_tokens=None): + + # If no end_tokens have been given, just assume standard memcached + # operations, which end in "\r\n", use regular code for that. + if end_tokens: + _reader = partial(_readsegment, end_tokens=end_tokens) + else: + _reader = _readline + if self.sock is None: self._connect() @@ -1141,7 +1171,7 @@ class Client: line = None for cmd in cmds: try: - buf, line = _readline(self.sock, buf) + buf, line = _reader(self.sock, buf) except MemcacheUnexpectedCloseError: self.close() raise @@ -1396,6 +1426,10 @@ class PooledClient: with self.client_pool.get_and_release(destroy_on_fail=True) as client: client.shutdown(graceful) + def raw_command(self, command, end_tokens=b"\r\n"): + with self.client_pool.get_and_release(destroy_on_fail=True) as client: + return client.raw_command(command, end_tokens) + def __setitem__(self, key, value): self.set(key, value, noreply=True) @@ -1505,6 +1539,42 @@ def _readvalue(sock, buf, size): return buf[rlen:], b"".join(chunks) +def _readsegment(sock, buf, end_tokens): + """Read a segment from the socket. + + Read a segment from the socket, up to the first end_token sub-string/bytes, + and return that segment. + + Args: + sock: Socket object, should be connected. + buf: bytes, zero or more bytes, returned from an earlier + call to _readline, _readsegment or _readvalue (pass an empty + byte-string on the first call). + end_tokens: bytes, indicates the end of the segment, generally this is + b"\\r\\n" for memcached. + + Returns: + A tuple of (buf, line) where line is the full line read from the + socket (minus the end_tokens bytes) and buf is any trailing + characters read after the end_tokens was found (which may be an empty + bytes object). + + """ + result = bytes() + + while True: + + tokens_pos = buf.find(end_tokens) + if tokens_pos != -1: + before, after = buf[:tokens_pos], buf[tokens_pos + len(end_tokens) :] + result += before + return after, result + + buf = _recv(sock, RECV_SIZE) + if not buf: + raise MemcacheUnexpectedCloseError() + + def _recv(sock, size): """sock.recv() with retry on EINTR""" while True: diff --git a/pymemcache/test/test_client.py b/pymemcache/test/test_client.py index 40bcbf6..22f8387 100644 --- a/pymemcache/test/test_client.py +++ b/pymemcache/test/test_client.py @@ -1149,6 +1149,63 @@ class TestClient(ClientTestMixin, unittest.TestCase): with pytest.raises(MemcacheUnknownError): client.version() + def test_raw_command_default_end_tokens(self): + client = self.make_client([b"REPLY\r\n", b"REPLY\r\nLEFTOVER"]) + result = client.raw_command(b"misc") + assert result == b"REPLY" + result = client.raw_command(b"misc") + assert result == b"REPLY" + + def test_raw_command_custom_end_tokens(self): + client = self.make_client( + [ + b"REPLY\r\nEND\r\n", + b"REPLY\r\nEND\r\nLEFTOVER", + b"REPLYEND\r\nLEFTOVER", + b"REPLY\nLEFTOVER", + ] + ) + end_tokens = b"END\r\n" + result = client.raw_command(b"misc", end_tokens) + assert result == b"REPLY\r\n" + result = client.raw_command(b"misc", end_tokens) + assert result == b"REPLY\r\n" + result = client.raw_command(b"misc", end_tokens) + assert result == b"REPLY" + result = client.raw_command(b"misc", b"\n") + assert result == b"REPLY" + + def test_raw_command_missing_end_tokens(self): + client = self.make_client([b"REPLY", b"REPLY"]) + with pytest.raises(IndexError): + client.raw_command(b"misc") + with pytest.raises(IndexError): + client.raw_command(b"misc", b"END\r\n") + + def test_raw_command_empty_end_tokens(self): + client = self.make_client([b"REPLY"]) + + with pytest.raises(IndexError): + client.raw_command(b"misc", b"") + + def test_raw_command_types(self): + client = self.make_client( + [b"REPLY\r\n", b"REPLY\r\n", b"REPLY\r\nLEFTOVER", b"REPLY\r\nLEFTOVER"] + ) + assert client.raw_command("key") == b"REPLY" + assert client.raw_command(b"key") == b"REPLY" + assert client.raw_command("key") == b"REPLY" + assert client.raw_command(b"key") == b"REPLY" + + def test_send_end_token_types(self): + client = self.make_client( + [b"REPLY\r\n", b"REPLY\r\n", b"REPLY\r\nLEFTOVER", b"REPLY\r\nLEFTOVER"] + ) + assert client.raw_command("key", "\r\n") == b"REPLY" + assert client.raw_command(b"key", b"\r\n") == b"REPLY" + assert client.raw_command("key", "\r\n") == b"REPLY" + assert client.raw_command(b"key", b"\r\n") == b"REPLY" + @pytest.mark.unit() class TestClientSocketConnect(unittest.TestCase): -- cgit v1.2.1