diff options
Diffstat (limited to 'tests/test_replication.py')
-rw-r--r-- | tests/test_replication.py | 60 |
1 files changed, 39 insertions, 21 deletions
diff --git a/tests/test_replication.py b/tests/test_replication.py index f527edd..ca99038 100644 --- a/tests/test_replication.py +++ b/tests/test_replication.py @@ -24,8 +24,8 @@ import psycopg2 import psycopg2.extensions -from psycopg2.extras import PhysicalReplicationConnection, LogicalReplicationConnection -from psycopg2.extras import StopReplication +from psycopg2.extras import ( + PhysicalReplicationConnection, LogicalReplicationConnection, StopReplication) import testconfig from testutils import unittest @@ -70,14 +70,16 @@ class ReplicationTestCase(ConnectingTestCase): # generate some events for our replication stream def make_replication_events(self): conn = self.connect() - if conn is None: return + if conn is None: + return cur = conn.cursor() try: cur.execute("DROP TABLE dummy1") except psycopg2.ProgrammingError: conn.rollback() - cur.execute("CREATE TABLE dummy1 AS SELECT * FROM generate_series(1, 5) AS id") + cur.execute( + "CREATE TABLE dummy1 AS SELECT * FROM generate_series(1, 5) AS id") conn.commit() @@ -85,7 +87,8 @@ class ReplicationTest(ReplicationTestCase): @skip_before_postgres(9, 0) def test_physical_replication_connection(self): conn = self.repl_connect(connection_factory=PhysicalReplicationConnection) - if conn is None: return + if conn is None: + return cur = conn.cursor() cur.execute("IDENTIFY_SYSTEM") cur.fetchall() @@ -93,41 +96,49 @@ class ReplicationTest(ReplicationTestCase): @skip_before_postgres(9, 4) def test_logical_replication_connection(self): conn = self.repl_connect(connection_factory=LogicalReplicationConnection) - if conn is None: return + if conn is None: + return cur = conn.cursor() cur.execute("IDENTIFY_SYSTEM") cur.fetchall() - @skip_before_postgres(9, 4) # slots require 9.4 + @skip_before_postgres(9, 4) # slots require 9.4 def test_create_replication_slot(self): conn = self.repl_connect(connection_factory=PhysicalReplicationConnection) - if conn is None: return + if conn is None: + return cur = conn.cursor() self.create_replication_slot(cur) - self.assertRaises(psycopg2.ProgrammingError, self.create_replication_slot, cur) + self.assertRaises( + psycopg2.ProgrammingError, self.create_replication_slot, cur) - @skip_before_postgres(9, 4) # slots require 9.4 + @skip_before_postgres(9, 4) # slots require 9.4 def test_start_on_missing_replication_slot(self): conn = self.repl_connect(connection_factory=PhysicalReplicationConnection) - if conn is None: return + if conn is None: + return cur = conn.cursor() - self.assertRaises(psycopg2.ProgrammingError, cur.start_replication, self.slot) + self.assertRaises(psycopg2.ProgrammingError, + cur.start_replication, self.slot) self.create_replication_slot(cur) cur.start_replication(self.slot) - @skip_before_postgres(9, 4) # slots require 9.4 + @skip_before_postgres(9, 4) # slots require 9.4 def test_start_and_recover_from_error(self): conn = self.repl_connect(connection_factory=LogicalReplicationConnection) - if conn is None: return + if conn is None: + return cur = conn.cursor() self.create_replication_slot(cur, output_plugin='test_decoding') # try with invalid options - cur.start_replication(slot_name=self.slot, options={'invalid_param': 'value'}) + cur.start_replication( + slot_name=self.slot, options={'invalid_param': 'value'}) + def consume(msg): pass # we don't see the error from the server before we try to read the data @@ -136,10 +147,11 @@ class ReplicationTest(ReplicationTestCase): # try with correct command cur.start_replication(slot_name=self.slot) - @skip_before_postgres(9, 4) # slots require 9.4 + @skip_before_postgres(9, 4) # slots require 9.4 def test_stop_replication(self): conn = self.repl_connect(connection_factory=LogicalReplicationConnection) - if conn is None: return + if conn is None: + return cur = conn.cursor() self.create_replication_slot(cur, output_plugin='test_decoding') @@ -147,16 +159,19 @@ class ReplicationTest(ReplicationTestCase): self.make_replication_events() cur.start_replication(self.slot) + def consume(msg): raise StopReplication() self.assertRaises(StopReplication, cur.consume_stream, consume) class AsyncReplicationTest(ReplicationTestCase): - @skip_before_postgres(9, 4) # slots require 9.4 + @skip_before_postgres(9, 4) # slots require 9.4 def test_async_replication(self): - conn = self.repl_connect(connection_factory=LogicalReplicationConnection, async=1) - if conn is None: return + conn = self.repl_connect( + connection_factory=LogicalReplicationConnection, async=1) + if conn is None: + return self.wait(conn) cur = conn.cursor() @@ -169,9 +184,10 @@ class AsyncReplicationTest(ReplicationTestCase): self.make_replication_events() self.msg_count = 0 + def consume(msg): # just check the methods - log = "%s: %s" % (cur.io_timestamp, repr(msg)) + "%s: %s" % (cur.io_timestamp, repr(msg)) self.msg_count += 1 if self.msg_count > 3: @@ -193,8 +209,10 @@ class AsyncReplicationTest(ReplicationTestCase): select([cur], [], []) self.assertRaises(StopReplication, process_stream) + def test_suite(): return unittest.TestLoader().loadTestsFromName(__name__) + if __name__ == "__main__": unittest.main() |