diff options
Diffstat (limited to 'cpp/src/tests/testlib.py')
-rw-r--r-- | cpp/src/tests/testlib.py | 39 |
1 files changed, 27 insertions, 12 deletions
diff --git a/cpp/src/tests/testlib.py b/cpp/src/tests/testlib.py index 07c4794767..398eaf96cc 100644 --- a/cpp/src/tests/testlib.py +++ b/cpp/src/tests/testlib.py @@ -21,7 +21,7 @@ # Support library for qpid python tests. # -import os, re, signal, subprocess, unittest +import os, re, signal, subprocess, time, unittest class TestBase(unittest.TestCase): """ @@ -210,6 +210,12 @@ class TestBaseCluster(TestBase): for n in range(0, numberNodes): self.createClusterNode(n, clusterName) + def waitForNodes(self, clusterName): + """Wait for all nodes to become active (ie finish cluster sync)""" + # TODO - connect to each known node in cluster + # Until this is done, wait a bit (hack) + time.sleep(1) + # --- Cluster and node status --- def getTupleList(self, clusterName = None): @@ -246,13 +252,15 @@ class TestBaseCluster(TestBase): """Get the (pid, port) tuple for the given cluster node""" return self._clusterDict[clusterName][nodeNumber] - def checkNumClusterBrokers(self, clusterName, expected = None, checkPids = True): + def checkNumClusterBrokers(self, clusterName, expected = None, checkPids = True, waitForNodes = True): """Check that the total number of brokers in the named cluster is the expected value""" if expected != None and self.getNumClusterBrokers(clusterName) != expected: raise Exception("Unexpected number of brokers in cluster %s: expected %d, found %d" % \ (clusterName, expected, self.getNumClusterBrokers(clusterName))) if checkPids: self._checkPids(clusterName) + if waitForNodes: + self.waitForNodes(clusterName) def clusterExists(self, clusterName): """ Return True if clusterName exists, False otherwise""" @@ -330,7 +338,7 @@ class TestBaseCluster(TestBase): if self.clusterExists(clusterName): raise Exception("Unable to kill cluster %s; %d nodes still exist" % (clusterName, self.getNumClusterBrokers(clusterName))) - def stopCheckAll(self, ignoreFailures = False): + def stopAllCheck(self, ignoreFailures = False): """Kill all known clusters and check that the cluster dictionary is empty""" self.stopAllClusters() self.checkNumBrokers(0) @@ -580,6 +588,7 @@ class TestBaseCluster(TestBase): self._testBaseCluster.createClusterNode(nodeNumber, self._clusterName) self._nodes.append(nodeNumber) self._testBaseCluster.checkNumClusterBrokers(self._clusterName, len(self._nodes)) + self._testBaseCluster.waitForNodes(self._clusterName) def restoreNode(self, nodeNumber): """Restore a cluster node that has been previously killed""" @@ -598,6 +607,7 @@ class TestBaseCluster(TestBase): self.restoreNode(lastNode) while len(self._deadNodes) > 0: self.restoreNode(self._deadNodes[0]) + self._testBaseCluster.waitForNodes(self._clusterName) def killNode(self, nodeNumber): """Kill a cluster node (if it is in the _nodes list).""" @@ -635,15 +645,20 @@ class TestBaseCluster(TestBase): if nm == None: nm = len(self._txMsgs[qn]) - len(self._rxMsgs[qn]) # get all remaining messages if nm > 0: - receiver = self._testBaseCluster.createReciever(nodeNumber, self._clusterName, qn, nm) - cnt = 0 - while cnt < nm: - rx = receiver.stdout.readline().strip() - if rx == "" and receiver.poll() != None: break - self._rxMsgs[qn].append(rx) - cnt = cnt + 1 + while nm > 0: + receiver = self._testBaseCluster.createReciever(nodeNumber, self._clusterName, qn, nm) + cnt = 0 + while cnt < nm: + rx = receiver.stdout.readline().strip() + if rx == "": + if receiver.poll() != None: break + elif rx not in self._rxMsgs[qn]: + self._rxMsgs[qn].append(rx) + cnt = cnt + 1 + nm = nm - cnt if wait: receiver.wait() + self._rxMsgs[qn].sort() self._lastNode = nodeNumber def receiveRemainingMsgs(self, nodeNumber = None, queueNameList = None, wait = True): @@ -670,10 +685,10 @@ class TestBaseCluster(TestBase): def finalizeTest(self): """Recover all the remaining messages on all queues, then check that all expected messages were received.""" self.receiveRemainingMsgs() - self._testBaseCluster.stopCheckAll() + self._testBaseCluster.stopAllCheck() if not self.checkMsgs(): - self._testBaseCluster.fail("Send - receive message mismatch") self.printMsgs() + self._testBaseCluster.fail("Send - receive message mismatch") def printMsgs(self, txMsgs = True, rxMsgs = True): """Print all messages transmitted and received.""" |