diff options
Diffstat (limited to 'tests/test_async.py')
-rwxr-xr-x | tests/test_async.py | 42 |
1 files changed, 30 insertions, 12 deletions
diff --git a/tests/test_async.py b/tests/test_async.py index 91a5392..e1b50cb 100755 --- a/tests/test_async.py +++ b/tests/test_async.py @@ -4,6 +4,7 @@ import unittest import psycopg2 from psycopg2 import extensions +import time import select import StringIO @@ -48,7 +49,10 @@ class AsyncTests(unittest.TestCase): self.sync_conn.close() self.conn.close() - def wait(self, pollable): + def wait(self, cur_or_conn): + pollable = cur_or_conn + if not hasattr(pollable, 'poll'): + pollable = cur_or_conn.connection while True: state = pollable.poll() if state == psycopg2.extensions.POLL_OK: @@ -301,7 +305,7 @@ class AsyncTests(unittest.TestCase): curs = self.conn.cursor() for mb in 1, 5, 10, 20, 50: size = mb * 1024 * 1024 - stub = PollableStub(curs) + stub = PollableStub(self.conn) curs.execute("select %s;", ('x' * size,)) self.wait(stub) self.assertEqual(size, len(curs.fetchone()[0])) @@ -312,19 +316,33 @@ class AsyncTests(unittest.TestCase): def test_sync_poll(self): cur = self.sync_conn.cursor() - # polling a sync cursor works - cur.poll() + cur.execute("select 1") + # polling with a sync query works + cur.connection.poll() + self.assertEquals(cur.fetchone()[0], 1) - def test_async_poll_wrong_cursor(self): - cur1 = self.conn.cursor() - cur2 = self.conn.cursor() - cur1.execute("select 1") + def test_notify(self): + cur = self.conn.cursor() + sync_cur = self.sync_conn.cursor() - # polling a cursor that's not currently executing is an error - self.assertRaises(psycopg2.ProgrammingError, cur2.poll) + sync_cur.execute("listen test_notify") + self.sync_conn.commit() + cur.execute("notify test_notify") + self.wait(cur) - self.wait(cur1) - self.assertEquals(cur1.fetchone()[0], 1) + self.assertEquals(self.sync_conn.notifies, []) + + pid = self.conn.get_backend_pid() + for _ in range(5): + self.wait(self.sync_conn) + if not self.sync_conn.notifies: + time.sleep(0.5) + continue + self.assertEquals(len(self.sync_conn.notifies), 1) + self.assertEquals(self.sync_conn.notifies.pop(), + (pid, "test_notify")) + return + self.fail("No NOTIFY in 2.5 seconds") def test_async_fetch_wrong_cursor(self): cur1 = self.conn.cursor() |