summaryrefslogtreecommitdiff
path: root/tests/test_asyncio/test_scripting.py
blob: 406ab208e28495644032cf5865b63cbf7f56cea9 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
import sys

import pytest

if sys.version_info[0:2] == (3, 6):
    import pytest as pytest_asyncio

    pytestmark = pytest.mark.asyncio
else:
    import pytest_asyncio

from redis import exceptions
from tests.conftest import skip_if_server_version_lt

# Lua: GET KEYS[1], convert it with tonumber(), and return it multiplied by
# ARGV[1] (Lua coerces the numeric-string argument in the multiplication).
multiply_script = """
local value = redis.call('GET', KEYS[1])
value = tonumber(value)
return value * ARGV[1]"""

# Lua: decode the MessagePack map passed as ARGV[1] and return
# "hello " followed by its 'name' field.
msgpack_hello_script = """
local message = cmsgpack.unpack(ARGV[1])
local name = message['name']
return "hello " .. name
"""
# Deliberately broken variant of msgpack_hello_script: it stores the field in
# `names` but then reads the undefined global `name`, so the script fails at
# runtime. test_eval_msgpack_pipeline_error_in_lua relies on that failure --
# do not "fix" this script.
msgpack_hello_script_broken = """
local message = cmsgpack.unpack(ARGV[1])
local names = message['name']
return "hello " .. name
"""


@pytest.mark.onlynoncluster
class TestScripting:
    """Async-client tests for EVAL/EVALSHA, SCRIPT subcommands and Script objects."""

    @pytest_asyncio.fixture
    async def r(self, create_redis):
        """Yield a client and flush the Lua script cache on teardown.

        Flushing after every test keeps scripts loaded by one test from making
        ``script_exists`` checks in a later test pass spuriously.
        """
        redis = await create_redis()
        yield redis
        await redis.script_flush()

    @pytest.mark.asyncio(forbid_global_loop=True)
    async def test_eval(self, r):
        """EVAL runs the script source directly and returns its result."""
        await r.flushdb()
        await r.set("a", 2)
        # 2 * 3 == 6
        assert await r.eval(multiply_script, 1, "a", 3) == 6

    @pytest.mark.asyncio(forbid_global_loop=True)
    @skip_if_server_version_lt("6.2.0")
    async def test_script_flush(self, r):
        """SCRIPT FLUSH works with ASYNC, SYNC and no mode; other modes raise."""
        await r.set("a", 2)
        sha = await r.script_load(multiply_script)
        await r.script_flush("ASYNC")
        assert await r.script_exists(sha) == [False]

        await r.set("a", 2)
        sha = await r.script_load(multiply_script)
        await r.script_flush("SYNC")
        assert await r.script_exists(sha) == [False]

        await r.set("a", 2)
        sha = await r.script_load(multiply_script)
        await r.script_flush()
        assert await r.script_exists(sha) == [False]

        # Keep the setup outside ``pytest.raises`` so that only the flush call
        # with an unsupported mode can satisfy the expected DataError.
        await r.set("a", 2)
        await r.script_load(multiply_script)
        with pytest.raises(exceptions.DataError):
            await r.script_flush("NOTREAL")

    @pytest.mark.asyncio(forbid_global_loop=True)
    async def test_evalsha(self, r):
        """EVALSHA runs a previously loaded script by its SHA1 digest."""
        await r.set("a", 2)
        sha = await r.script_load(multiply_script)
        # 2 * 3 == 6
        assert await r.evalsha(sha, 1, "a", 3) == 6

    @pytest.mark.asyncio(forbid_global_loop=True)
    async def test_evalsha_script_not_loaded(self, r):
        """EVALSHA on a flushed script raises NoScriptError."""
        await r.set("a", 2)
        sha = await r.script_load(multiply_script)
        # remove the script from Redis's cache
        await r.script_flush()
        with pytest.raises(exceptions.NoScriptError):
            await r.evalsha(sha, 1, "a", 3)

    @pytest.mark.asyncio(forbid_global_loop=True)
    async def test_script_loading(self, r):
        """SCRIPT EXISTS reflects SCRIPT LOAD and SCRIPT FLUSH."""
        # get the sha, then clear the cache
        sha = await r.script_load(multiply_script)
        await r.script_flush()
        assert await r.script_exists(sha) == [False]
        await r.script_load(multiply_script)
        assert await r.script_exists(sha) == [True]

    @pytest.mark.asyncio(forbid_global_loop=True)
    async def test_script_object(self, r):
        """A registered Script loads itself on first call and reuses its SHA."""
        await r.script_flush()
        await r.set("a", 2)
        multiply = r.register_script(multiply_script)
        precalculated_sha = multiply.sha
        assert precalculated_sha
        assert await r.script_exists(multiply.sha) == [False]
        # The script is not cached yet, so this call goes through the
        # second evalsha block (after NoScriptError).
        assert await multiply(keys=["a"], args=[3]) == 6
        # At this point, the script should be loaded
        assert await r.script_exists(multiply.sha) == [True]
        # Test that the precalculated sha matches the one from redis
        assert multiply.sha == precalculated_sha
        # Test first evalsha block (script already cached)
        assert await multiply(keys=["a"], args=[3]) == 6

    @pytest.mark.asyncio(forbid_global_loop=True)
    async def test_script_object_in_pipeline(self, r):
        """A Script queued on a pipeline is (re)loaded by pipe.execute()."""
        await r.script_flush()
        multiply = r.register_script(multiply_script)
        precalculated_sha = multiply.sha
        assert precalculated_sha
        pipe = r.pipeline()
        pipe.set("a", 2)
        pipe.get("a")
        await multiply(keys=["a"], args=[3], client=pipe)
        assert await r.script_exists(multiply.sha) == [False]
        # [SET worked, GET 'a', result of multiply script]
        assert await pipe.execute() == [True, b"2", 6]
        # The script should have been loaded by pipe.execute()
        assert await r.script_exists(multiply.sha) == [True]
        # The precalculated sha should have been the correct one
        assert multiply.sha == precalculated_sha

        # purge the script from redis's cache and re-run the pipeline
        # the multiply script should be reloaded by pipe.execute()
        await r.script_flush()
        pipe = r.pipeline()
        pipe.set("a", 2)
        pipe.get("a")
        await multiply(keys=["a"], args=[3], client=pipe)
        assert await r.script_exists(multiply.sha) == [False]
        # [SET worked, GET 'a', result of multiply script]
        assert await pipe.execute() == [True, b"2", 6]
        assert await r.script_exists(multiply.sha) == [True]

    @pytest.mark.asyncio(forbid_global_loop=True)
    async def test_eval_msgpack_pipeline_error_in_lua(self, r):
        """A Lua runtime error in a pipelined script surfaces as ResponseError."""
        msgpack_hello = r.register_script(msgpack_hello_script)
        assert msgpack_hello.sha

        pipe = r.pipeline()

        # avoiding a dependency to msgpack, this is the output of
        # msgpack.dumps({"name": "Joe"})
        msgpack_message_1 = b"\x81\xa4name\xa3Joe"

        await msgpack_hello(args=[msgpack_message_1], client=pipe)

        assert await r.script_exists(msgpack_hello.sha) == [False]
        assert (await pipe.execute())[0] == b"hello Joe"
        assert await r.script_exists(msgpack_hello.sha) == [True]

        msgpack_hello_broken = r.register_script(msgpack_hello_script_broken)

        await msgpack_hello_broken(args=[msgpack_message_1], client=pipe)
        with pytest.raises(exceptions.ResponseError) as excinfo:
            await pipe.execute()
        # pytest.raises also matches subclasses of ResponseError; pin the
        # exact class so only the plain script-error response passes.
        assert excinfo.type == exceptions.ResponseError