diff options
Diffstat (limited to 'tests/test_asyncio')
-rw-r--r-- | tests/test_asyncio/mocks.py | 51 | ||||
-rw-r--r-- | tests/test_asyncio/test_connection.py | 48 |
2 files changed, 92 insertions, 7 deletions
diff --git a/tests/test_asyncio/mocks.py b/tests/test_asyncio/mocks.py new file mode 100644 index 0000000..89bd9c0 --- /dev/null +++ b/tests/test_asyncio/mocks.py @@ -0,0 +1,51 @@ +import asyncio + +# Helper Mocking classes for the tests. + + +class MockStream: + """ + A class simulating an asyncio input buffer, optionally raising a + special exception every other read. + """ + + class TestError(BaseException): + pass + + def __init__(self, data, interrupt_every=0): + self.data = data + self.counter = 0 + self.pos = 0 + self.interrupt_every = interrupt_every + + def tick(self): + self.counter += 1 + if not self.interrupt_every: + return + if (self.counter % self.interrupt_every) == 0: + raise self.TestError() + + async def read(self, want): + self.tick() + want = 5 + result = self.data[self.pos : self.pos + want] + self.pos += len(result) + return result + + async def readline(self): + self.tick() + find = self.data.find(b"\n", self.pos) + if find >= 0: + result = self.data[self.pos : find + 1] + else: + result = self.data[self.pos :] + self.pos += len(result) + return result + + async def readexactly(self, length): + self.tick() + result = self.data[self.pos : self.pos + length] + if len(result) < length: + raise asyncio.IncompleteReadError(result, None) + self.pos += len(result) + return result diff --git a/tests/test_asyncio/test_connection.py b/tests/test_asyncio/test_connection.py index 6bf0034..bf59dbe 100644 --- a/tests/test_asyncio/test_connection.py +++ b/tests/test_asyncio/test_connection.py @@ -5,7 +5,9 @@ from unittest.mock import patch import pytest +import redis from redis.asyncio.connection import ( + BaseParser, Connection, PythonParser, UnixDomainSocketConnection, @@ -16,6 +18,7 @@ from redis.exceptions import ConnectionError, InvalidResponse, TimeoutError from tests.conftest import skip_if_server_version_lt from .compat import mock +from .mocks import MockStream @pytest.mark.onlynoncluster @@ -23,16 +26,19 @@ async def test_invalid_response(create_redis): r = await create_redis(single_connection_client=True) raw = b"x" + fake_stream = MockStream(raw + b"\r\n") - parser: "PythonParser" = r.connection._parser - if not isinstance(parser, PythonParser): - pytest.skip("PythonParser only") - stream_mock = mock.Mock(parser._stream) - stream_mock.readline.return_value = raw + b"\r\n" - with mock.patch.object(parser, "_stream", stream_mock): + parser: BaseParser = r.connection._parser + with mock.patch.object(parser, "_stream", fake_stream): with pytest.raises(InvalidResponse) as cm: await parser.read_response() - assert str(cm.value) == f"Protocol Error: {raw!r}" + if isinstance(parser, PythonParser): + assert str(cm.value) == f"Protocol Error: {raw!r}" + else: + assert ( + str(cm.value) == f'Protocol error, got "{raw.decode()}" as reply type byte' + ) + await r.connection.disconnect() @skip_if_server_version_lt("4.0.0") @@ -112,3 +118,31 @@ async def test_connect_timeout_error_without_retry(): await conn.connect() assert conn._connect.call_count == 1 assert str(e.value) == "Timeout connecting to server" + + +@pytest.mark.onlynoncluster +async def test_connection_parse_response_resume(r: redis.Redis): + """ + This test verifies that the Connection parser, + be that PythonParser or HiredisParser, + can be interrupted at IO time and then resume parsing. + """ + conn = Connection(**r.connection_pool.connection_kwargs) + await conn.connect() + message = ( + b"*3\r\n$7\r\nmessage\r\n$8\r\nchannel1\r\n" + b"$25\r\nhi\r\nthere\r\n+how\r\nare\r\nyou\r\n" + ) + + conn._parser._stream = MockStream(message, interrupt_every=2) + for i in range(100): + try: + response = await conn.read_response() + break + except MockStream.TestError: + pass + + else: + pytest.fail("didn't receive a response") + assert response + assert i > 0 |