summaryrefslogtreecommitdiff
path: root/tests/test_green.py
blob: e56ce586ff19d95670ec2ebf69fafe7bacc1a721 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
#!/usr/bin/env python

# test_green.py - unit test for async wait callback
#
# Copyright (C) 2010-2019 Daniele Varrazzo  <daniele.varrazzo@gmail.com>
# Copyright (C) 2020 The Psycopg Team
#
# psycopg2 is free software: you can redistribute it and/or modify it
# under the terms of the GNU Lesser General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# In addition, as a special exception, the copyright holders give
# permission to link this program with the OpenSSL library (or with
# modified versions of OpenSSL that use the same license as OpenSSL),
# and distribute linked combinations including the two.
#
# You must obey the GNU Lesser General Public License in all respects for
# all of the code used other than OpenSSL.
#
# psycopg2 is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
# FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Lesser General Public
# License for more details.

import select
import unittest
import warnings

import psycopg2
import psycopg2.extensions
import psycopg2.extras
from psycopg2.extensions import POLL_OK, POLL_READ, POLL_WRITE

from .testutils import ConnectingTestCase, skip_before_postgres, slow


class ConnectionStub(object):
    """Wrap a connection and keep a log of the states `poll()` returns.

    Tests install a wait callback that drives the stub instead of the real
    connection, then inspect `polls` to see how the callback progressed.
    """
    def __init__(self, conn):
        self.conn = conn
        # Every value returned by poll(), in the order it was returned.
        self.polls = []

    def fileno(self):
        """Return the file descriptor of the wrapped connection."""
        wrapped = self.conn
        return wrapped.fileno()

    def poll(self):
        """Poll the wrapped connection, record the resulting state, return it."""
        state = self.conn.poll()
        self.polls.append(state)
        return state


class GreenTestCase(ConnectingTestCase):
    """Tests for queries executed through a green (wait callback) loop.

    Each test runs with `psycopg2.extras.wait_select` installed as the global
    wait callback; the previously installed callback is restored when the
    test finishes.
    """
    def setUp(self):
        self._cb = psycopg2.extensions.get_wait_callback()
        psycopg2.extensions.set_wait_callback(psycopg2.extras.wait_select)
        # The wait callback is process-global state. Restoring it from a
        # cleanup (rather than only from tearDown) guarantees it is reset even
        # if the base setUp() raises, in which case tearDown() is skipped.
        # Cleanups run after tearDown(), preserving the original order.
        self.addCleanup(psycopg2.extensions.set_wait_callback, self._cb)
        ConnectingTestCase.setUp(self)

    def tearDown(self):
        # The original wait callback is restored by the cleanup registered in
        # setUp(), which unittest runs after this method returns.
        ConnectingTestCase.tearDown(self)

    def set_stub_wait_callback(self, conn, cb=None):
        """Install a wait callback that drives a `ConnectionStub` of *conn*.

        :param conn: the connection to wrap.
        :param cb: the wait function to run (default
            `psycopg2.extras.wait_select`). It receives the stub, not the
            real connection.
        :return: the stub, whose `polls` list records every poll() state.
        """
        stub = ConnectionStub(conn)
        psycopg2.extensions.set_wait_callback(
            lambda conn: (cb or psycopg2.extras.wait_select)(stub))
        return stub

    @slow
    def test_flush_on_write(self):
        # a very large query requires a flush loop to be sent to the backend
        conn = self.conn
        stub = self.set_stub_wait_callback(conn)
        curs = conn.cursor()
        for mb in 1, 5, 10, 20, 50:
            size = mb * 1024 * 1024
            del stub.polls[:]
            curs.execute("select %s;", ('x' * size,))
            self.assertEqual(size, len(curs.fetchone()[0]))
            # Repeated POLL_WRITE states show the query could not be sent in
            # a single go, i.e. the flush loop was exercised.
            if stub.polls.count(POLL_WRITE) > 1:
                return

        # This is more a testing glitch than an error: it happens
        # on high load on linux: probably because the kernel has more
        # buffers ready. A warning may be useful during development,
        # but an error is bad during regression testing.
        warnings.warn("sending a large query didn't trigger block on write.")

    def test_error_in_callback(self):
        # behaviour changed after issue #113: if there is an error in the
        # callback for the moment we don't have a way to reset the connection
        # without blocking (ticket #113) so just close it.
        conn = self.conn
        curs = conn.cursor()
        curs.execute("select 1")  # have a BEGIN
        curs.fetchone()

        # now try to do something that will fail in the callback
        psycopg2.extensions.set_wait_callback(lambda conn: 1 // 0)
        self.assertRaises(ZeroDivisionError, curs.execute, "select 2")

        self.assertTrue(conn.closed)

    def test_dont_freak_out(self):
        # if there is an error in a green query, don't freak out and close
        # the connection
        conn = self.conn
        curs = conn.cursor()
        self.assertRaises(psycopg2.ProgrammingError,
            curs.execute, "select the unselectable")

        # check that the connection is left in a usable state
        self.assertFalse(conn.closed)
        conn.rollback()
        curs.execute("select 1")
        self.assertEqual(curs.fetchone()[0], 1)

    @skip_before_postgres(8, 2)
    def test_copy_no_hang(self):
        # COPY ... TO STDOUT under a wait callback must raise, not hang.
        cur = self.conn.cursor()
        self.assertRaises(psycopg2.ProgrammingError,
            cur.execute, "copy (select 1) to stdout")

    @slow
    @skip_before_postgres(9, 0)
    def test_non_block_after_notification(self):
        def wait(conn):
            # `conn` is the ConnectionStub passed by set_stub_wait_callback:
            # it exposes only poll() and fileno(), not the DB-API exception
            # attributes of a real connection, so raise via the module.
            while 1:
                state = conn.poll()
                if state == POLL_OK:
                    break
                elif state == POLL_READ:
                    select.select([conn.fileno()], [], [], 0.1)
                elif state == POLL_WRITE:
                    select.select([], [conn.fileno()], [], 0.1)
                else:
                    raise psycopg2.OperationalError(
                        "bad state from poll: %s" % state)

        stub = self.set_stub_wait_callback(self.conn, wait)
        cur = self.conn.cursor()
        cur.execute("""
            select 1;
            do $$
                begin
                    raise notice 'hello';
                end
            $$ language plpgsql;
            select pg_sleep(1);
            """)

        # With a 0.1s select timeout and a 1s server-side sleep, a poll()
        # that doesn't block after the notice returns POLL_READ many times.
        polls = stub.polls.count(POLL_READ)
        self.assertGreater(polls, 8)


class CallbackErrorTestCase(ConnectingTestCase):
    """Tests for errors raised by the wait callback at different stages.

    `crappy_callback` is installed as the global wait callback and raises
    `ZeroDivisionError` after a configurable number of poll iterations
    (`self.to_error`); the tests sweep that number to inject the failure at
    each stage of connection and query execution.
    """
    def setUp(self):
        # crappy_callback() reads self.to_error, so it must exist before the
        # callback is installed and before anything can trigger a wait.
        self.to_error = None
        self._cb = psycopg2.extensions.get_wait_callback()
        psycopg2.extensions.set_wait_callback(self.crappy_callback)
        # Restore the process-global callback even if the base setUp()
        # raises (tearDown() is skipped then); cleanups run after tearDown().
        self.addCleanup(psycopg2.extensions.set_wait_callback, self._cb)
        ConnectingTestCase.setUp(self)

    def tearDown(self):
        # The original wait callback is restored by the cleanup registered in
        # setUp(), which unittest runs after this method returns.
        ConnectingTestCase.tearDown(self)

    def crappy_callback(self, conn):
        """Green callback failing after `self.to_error` poll iterations.

        If `self.to_error` is not None it is decremented before every poll
        attempt (the count carries over between calls), and
        `ZeroDivisionError` is raised once it drops to zero or below: a value
        of n >= 1 lets n - 1 polls run. With `self.to_error` set to None the
        callback behaves as a plain select()-based wait loop.
        """
        while True:
            if self.to_error is not None:
                self.to_error -= 1
                if self.to_error <= 0:
                    raise ZeroDivisionError("I accidentally the connection")
            try:
                state = conn.poll()
                if state == POLL_OK:
                    break
                elif state == POLL_READ:
                    select.select([conn.fileno()], [], [])
                elif state == POLL_WRITE:
                    select.select([], [conn.fileno()], [])
                else:
                    raise conn.OperationalError("bad state from poll: %s" % state)
            except KeyboardInterrupt:
                conn.cancel()
                # the loop will be broken by a server error
                continue

    def test_errors_on_connection(self):
        # Test error propagation in the different stages of the connection
        for i in range(100):
            self.to_error = i
            try:
                self.connect()
            except ZeroDivisionError:
                pass
            else:
                # We managed to connect
                return

        self.fail("you should have had a success or an error by now")

    def test_errors_on_query(self):
        # Inject the failure at each stage of a plain cursor query, using a
        # fresh connection established without errors each time.
        for i in range(100):
            self.to_error = None
            cnn = self.connect()
            cur = cnn.cursor()
            self.to_error = i
            try:
                cur.execute("select 1")
                cur.fetchone()
            except ZeroDivisionError:
                pass
            else:
                # The query completed
                return

        self.fail("you should have had a success or an error by now")

    def test_errors_named_cursor(self):
        # Same as test_errors_on_query, but through a named cursor.
        for i in range(100):
            self.to_error = None
            cnn = self.connect()
            cur = cnn.cursor('foo')
            self.to_error = i
            try:
                cur.execute("select 1")
                cur.fetchone()
            except ZeroDivisionError:
                pass
            else:
                # The query completed
                return

        self.fail("you should have had a success or an error by now")


def test_suite():
    """Collect every test case defined in this module into a suite."""
    loader = unittest.TestLoader()
    suite = loader.loadTestsFromName(__name__)
    return suite


if __name__ == "__main__":
    unittest.main()