summaryrefslogtreecommitdiff
path: root/tests/test_asyncio
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_asyncio')
-rw-r--r--tests/test_asyncio/mocks.py51
-rw-r--r--tests/test_asyncio/test_connection.py48
2 files changed, 92 insertions, 7 deletions
diff --git a/tests/test_asyncio/mocks.py b/tests/test_asyncio/mocks.py
new file mode 100644
index 0000000..89bd9c0
--- /dev/null
+++ b/tests/test_asyncio/mocks.py
@@ -0,0 +1,51 @@
+import asyncio
+
+# Helper Mocking classes for the tests.
+
+
+class MockStream:
+ """
+ A class simulating an asyncio input buffer, optionally raising a
+ special exception every other read.
+ """
+
+ class TestError(BaseException):
+ pass
+
+ def __init__(self, data, interrupt_every=0):
+ self.data = data
+ self.counter = 0
+ self.pos = 0
+ self.interrupt_every = interrupt_every
+
+ def tick(self):
+ self.counter += 1
+ if not self.interrupt_every:
+ return
+ if (self.counter % self.interrupt_every) == 0:
+ raise self.TestError()
+
+ async def read(self, want):
+ self.tick()
+ want = 5
+ result = self.data[self.pos : self.pos + want]
+ self.pos += len(result)
+ return result
+
+ async def readline(self):
+ self.tick()
+ find = self.data.find(b"\n", self.pos)
+ if find >= 0:
+ result = self.data[self.pos : find + 1]
+ else:
+ result = self.data[self.pos :]
+ self.pos += len(result)
+ return result
+
+ async def readexactly(self, length):
+ self.tick()
+ result = self.data[self.pos : self.pos + length]
+ if len(result) < length:
+ raise asyncio.IncompleteReadError(result, None)
+ self.pos += len(result)
+ return result
diff --git a/tests/test_asyncio/test_connection.py b/tests/test_asyncio/test_connection.py
index 6bf0034..bf59dbe 100644
--- a/tests/test_asyncio/test_connection.py
+++ b/tests/test_asyncio/test_connection.py
@@ -5,7 +5,9 @@ from unittest.mock import patch
import pytest
+import redis
from redis.asyncio.connection import (
+ BaseParser,
Connection,
PythonParser,
UnixDomainSocketConnection,
@@ -16,6 +18,7 @@ from redis.exceptions import ConnectionError, InvalidResponse, TimeoutError
from tests.conftest import skip_if_server_version_lt
from .compat import mock
+from .mocks import MockStream
@pytest.mark.onlynoncluster
@@ -23,16 +26,19 @@ async def test_invalid_response(create_redis):
r = await create_redis(single_connection_client=True)
raw = b"x"
+ fake_stream = MockStream(raw + b"\r\n")
- parser: "PythonParser" = r.connection._parser
- if not isinstance(parser, PythonParser):
- pytest.skip("PythonParser only")
- stream_mock = mock.Mock(parser._stream)
- stream_mock.readline.return_value = raw + b"\r\n"
- with mock.patch.object(parser, "_stream", stream_mock):
+ parser: BaseParser = r.connection._parser
+ with mock.patch.object(parser, "_stream", fake_stream):
with pytest.raises(InvalidResponse) as cm:
await parser.read_response()
- assert str(cm.value) == f"Protocol Error: {raw!r}"
+ if isinstance(parser, PythonParser):
+ assert str(cm.value) == f"Protocol Error: {raw!r}"
+ else:
+ assert (
+ str(cm.value) == f'Protocol error, got "{raw.decode()}" as reply type byte'
+ )
+ await r.connection.disconnect()
@skip_if_server_version_lt("4.0.0")
@@ -112,3 +118,31 @@ async def test_connect_timeout_error_without_retry():
await conn.connect()
assert conn._connect.call_count == 1
assert str(e.value) == "Timeout connecting to server"
+
+
+@pytest.mark.onlynoncluster
+async def test_connection_parse_response_resume(r: redis.Redis):
+ """
+ This test verifies that the Connection parser,
+ be that PythonParser or HiredisParser,
+ can be interrupted at IO time and then resume parsing.
+ """
+ conn = Connection(**r.connection_pool.connection_kwargs)
+ await conn.connect()
+ message = (
+ b"*3\r\n$7\r\nmessage\r\n$8\r\nchannel1\r\n"
+ b"$25\r\nhi\r\nthere\r\n+how\r\nare\r\nyou\r\n"
+ )
+
+ conn._parser._stream = MockStream(message, interrupt_every=2)
+ for i in range(100):
+ try:
+ response = await conn.read_response()
+ break
+ except MockStream.TestError:
+ pass
+
+ else:
+ pytest.fail("didn't receive a response")
+ assert response
+ assert i > 0