diff options
Diffstat (limited to 'tests/test_async_keyword.py')
-rwxr-xr-x | tests/test_async_keyword.py | 217 |
1 files changed, 217 insertions, 0 deletions
diff --git a/tests/test_async_keyword.py b/tests/test_async_keyword.py new file mode 100755 index 0000000..3b20e8d --- /dev/null +++ b/tests/test_async_keyword.py @@ -0,0 +1,217 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +# test_async_keyword.py - test for objects using 'async' as attribute/param +# +# Copyright (C) 2017 Daniele Varrazzo <daniele.varrazzo@gmail.com> +# +# psycopg2 is free software: you can redistribute it and/or modify it +# under the terms of the GNU Lesser General Public License as published +# by the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# In addition, as a special exception, the copyright holders give +# permission to link this program with the OpenSSL library (or with +# modified versions of OpenSSL that use the same license as OpenSSL), +# and distribute linked combinations including the two. +# +# You must obey the GNU Lesser General Public License in all respects for +# all of the code used other than OpenSSL. +# +# psycopg2 is distributed in the hope that it will be useful, but WITHOUT +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public +# License for more details. + +import psycopg2 +from psycopg2 import extras + +from testconfig import dsn +from testutils import (ConnectingTestCase, unittest, skip_before_postgres, + assertDsnEqual) +from test_replication import ReplicationTestCase, skip_repl_if_green +from psycopg2.extras import LogicalReplicationConnection, StopReplication + + +class AsyncTests(ConnectingTestCase): + def setUp(self): + ConnectingTestCase.setUp(self) + + self.sync_conn = self.conn + self.conn = self.connect(async=True) + + self.wait(self.conn) + + curs = self.conn.cursor() + curs.execute(''' + CREATE TEMPORARY TABLE table1 ( + id int PRIMARY KEY + )''') + self.wait(curs) + + def test_connection_setup(self): + cur = self.conn.cursor() + sync_cur = self.sync_conn.cursor() + del cur, sync_cur + + self.assert_(self.conn.async) + self.assert_(not self.sync_conn.async) + + # the async connection should be in isolevel 0 + self.assertEquals(self.conn.isolation_level, 0) + + # check other properties to be found on the connection + self.assert_(self.conn.server_version) + self.assert_(self.conn.protocol_version in (2, 3)) + self.assert_(self.conn.encoding in psycopg2.extensions.encodings) + + def test_async_subclass(self): + class MyConn(psycopg2.extensions.connection): + def __init__(self, dsn, async=0): + psycopg2.extensions.connection.__init__(self, dsn, async=async) + + conn = self.connect(connection_factory=MyConn, async=True) + self.assert_(isinstance(conn, MyConn)) + self.assert_(conn.async) + conn.close() + + def test_async_connection_error_message(self): + try: + cnn = psycopg2.connect('dbname=thisdatabasedoesntexist', async=True) + self.wait(cnn) + except psycopg2.Error, e: + self.assertNotEqual(str(e), "asynchronous connection failed", + "connection error reason lost") + else: + self.fail("no exception raised") + + +class CancelTests(ConnectingTestCase): + def setUp(self): + ConnectingTestCase.setUp(self) + + cur = self.conn.cursor() + cur.execute(''' + CREATE TEMPORARY TABLE table1 ( + id int PRIMARY KEY + )''') + self.conn.commit() + + @skip_before_postgres(8, 2) + def test_async_cancel(self): + async_conn = psycopg2.connect(dsn, async=True) + self.assertRaises(psycopg2.OperationalError, async_conn.cancel) + extras.wait_select(async_conn) + cur = async_conn.cursor() + cur.execute("select pg_sleep(10000)") + self.assertTrue(async_conn.isexecuting()) + async_conn.cancel() + self.assertRaises(psycopg2.extensions.QueryCanceledError, + extras.wait_select, async_conn) + cur.execute("select 1") + extras.wait_select(async_conn) + self.assertEqual(cur.fetchall(), [(1, )]) + + def test_async_connection_cancel(self): + async_conn = psycopg2.connect(dsn, async=True) + async_conn.close() + self.assertTrue(async_conn.closed) + + +class ConnectTestCase(unittest.TestCase): + def setUp(self): + self.args = None + + def connect_stub(dsn, connection_factory=None, async=False): + self.args = (dsn, connection_factory, async) + + self._connect_orig = psycopg2._connect + psycopg2._connect = connect_stub + + def tearDown(self): + psycopg2._connect = self._connect_orig + + def test_there_has_to_be_something(self): + self.assertRaises(TypeError, psycopg2.connect) + self.assertRaises(TypeError, psycopg2.connect, + connection_factory=lambda dsn, async=False: None) + self.assertRaises(TypeError, psycopg2.connect, + async=True) + + def test_factory(self): + def f(dsn, async=False): + pass + + psycopg2.connect(database='foo', host='baz', connection_factory=f) + assertDsnEqual(self, self.args[0], 'dbname=foo host=baz') + self.assertEqual(self.args[1], f) + self.assertEqual(self.args[2], False) + + psycopg2.connect("dbname=foo host=baz", connection_factory=f) + assertDsnEqual(self, self.args[0], 'dbname=foo host=baz') + self.assertEqual(self.args[1], f) + self.assertEqual(self.args[2], False) + + def test_async(self): + psycopg2.connect(database='foo', host='baz', async=1) + assertDsnEqual(self, self.args[0], 'dbname=foo host=baz') + self.assertEqual(self.args[1], None) + self.assert_(self.args[2]) + + psycopg2.connect("dbname=foo host=baz", async=True) + assertDsnEqual(self, self.args[0], 'dbname=foo host=baz') + self.assertEqual(self.args[1], None) + self.assert_(self.args[2]) + + +class AsyncReplicationTest(ReplicationTestCase): + @skip_before_postgres(9, 4) # slots require 9.4 + @skip_repl_if_green + def test_async_replication(self): + conn = self.repl_connect( + connection_factory=LogicalReplicationConnection, async=1) + if conn is None: + return + + cur = conn.cursor() + + self.create_replication_slot(cur, output_plugin='test_decoding') + self.wait(cur) + + cur.start_replication(self.slot) + self.wait(cur) + + self.make_replication_events() + + self.msg_count = 0 + + def consume(msg): + # just check the methods + "%s: %s" % (cur.io_timestamp, repr(msg)) + + self.msg_count += 1 + if self.msg_count > 3: + cur.send_feedback(reply=True) + raise StopReplication() + + cur.send_feedback(flush_lsn=msg.data_start) + + # cannot be used in asynchronous mode + self.assertRaises(psycopg2.ProgrammingError, cur.consume_stream, consume) + + def process_stream(): + from select import select + while True: + msg = cur.read_message() + if msg: + consume(msg) + else: + select([cur], [], []) + self.assertRaises(StopReplication, process_stream) + + +def test_suite(): + return unittest.TestLoader().loadTestsFromName(__name__) + +if __name__ == "__main__": + unittest.main() |