summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRafael H. Schloming <rhs@apache.org>2009-06-04 15:57:48 +0000
committerRafael H. Schloming <rhs@apache.org>2009-06-04 15:57:48 +0000
commita1756f92f1412431bb90632b39dfd80fe41e7452 (patch)
treea5a7a4e10bf32c53904a4c4e25c1c580169c992a
parent69da96714d2ae6393c5cfc0a2a8c3b4b3f8a7b62 (diff)
downloadqpid-python-a1756f92f1412431bb90632b39dfd80fe41e7452.tar.gz
Added commit and rollback to the Session API and streamlined some test utilities.
git-svn-id: https://svn.apache.org/repos/asf/qpid/trunk@781786 13f79535-47bb-0310-9956-ffa450edef68
-rw-r--r--qpid/python/qpid/messaging.py64
-rw-r--r--qpid/python/qpid/tests/messaging.py109
2 files changed, 146 insertions, 27 deletions
diff --git a/qpid/python/qpid/messaging.py b/qpid/python/qpid/messaging.py
index 84331134a9..05e2f7c51f 100644
--- a/qpid/python/qpid/messaging.py
+++ b/qpid/python/qpid/messaging.py
@@ -147,7 +147,7 @@ class Connection(Lockable):
self._condition = Condition(self._lock)
@synchronized
- def session(self, name=None):
+ def session(self, name=None, transactional=False):
"""
Creates or retrieves the named session. If the name is omitted or
None, then a unique name is chosen based on a randomly generated
@@ -168,7 +168,7 @@ class Connection(Lockable):
if self.sessions.has_key(name):
return self.sessions[name]
else:
- ssn = Session(self, name, self.started)
+ ssn = Session(self, name, self.started, transactional=transactional)
self.sessions[name] = ssn
if self._conn is not None:
ssn._attach()
@@ -268,10 +268,11 @@ class Session(Lockable):
messages, and manage various Senders and Receivers.
"""
- def __init__(self, connection, name, started):
+ def __init__(self, connection, name, started, transactional):
self.connection = connection
self.name = name
self.started = started
+ self.transactional = transactional
self._ssn = None
self.senders = []
self.receivers = []
@@ -279,6 +280,8 @@ class Session(Lockable):
self.incoming = []
self.closed = False
self.unacked = []
+ if self.transactional:
+ self.acked = []
self._lock = RLock()
self._condition = Condition(self._lock)
self.thread = Thread(target = self.run)
@@ -294,6 +297,8 @@ class Session(Lockable):
self._ssn.invoke_lock = self._lock
self._ssn.lock = self._lock
self._ssn.condition = self._condition
+ if self.transactional:
+ self._ssn.tx_select()
for link in self.senders + self.receivers:
link._link()
@@ -414,17 +419,17 @@ class Session(Lockable):
def acknowledge(self, message=None):
"""
Acknowledge the given L{Message}. If message is None, then all
- unackednowledged messages on the session are acknowledged.
+ unacknowledged messages on the session are acknowledged.
@type message: Message
@param message: the message to acknowledge or None
"""
if message is None:
- messages = self.unacked
+ messages = self.unacked[:]
else:
messages = [message]
- ids = RangedSet(*[m._transfer_id for m in self.unacked])
+ ids = RangedSet(*[m._transfer_id for m in messages])
for range in ids:
self._ssn.receiver._completed.add_range(range)
self._ssn.channel.session_completed(self._ssn.receiver._completed)
@@ -436,6 +441,46 @@ class Session(Lockable):
self.unacked.remove(m)
except ValueError:
pass
+ if self.transactional:
+ self.acked.append(m)
+
+ @synchronized
+ def commit(self):
+ """
+ Commit outstanding transactional work. This consists of all
+ message sends and receives since the prior commit or rollback.
+ """
+ if not self.transactional:
+ raise NontransactionalSession()
+ if self._ssn is None:
+ raise Disconnected()
+ self._ssn.tx_commit(sync=True)
+ del self.acked[:]
+ self._ssn.sync()
+
+ @synchronized
+ def rollback(self):
+ """
+ Rollback outstanding transactional work. This consists of all
+ message sends and receives since the prior commit or rollback.
+ """
+ if not self.transactional:
+ raise NontransactionalSession()
+ if self._ssn is None:
+ raise Disconnected()
+
+ ids = RangedSet(*[m._transfer_id for m in self.acked + self.unacked + self.incoming])
+ for range in ids:
+ self._ssn.receiver._completed.add_range(range)
+ self._ssn.channel.session_completed(self._ssn.receiver._completed)
+ self._ssn.message_release(ids)
+ self._ssn.tx_rollback(sync=True)
+
+ del self.incoming[:]
+ del self.unacked[:]
+ del self.acked[:]
+
+ self._ssn.sync()
@synchronized
def start(self):
@@ -515,6 +560,13 @@ class Disconnected(Exception):
"""
pass
+class NontransactionalSession(Exception):
+ """
+ Exception raised when commit or rollback is attempted on a non
+ transactional session.
+ """
+ pass
+
class Sender(Lockable):
"""
diff --git a/qpid/python/qpid/tests/messaging.py b/qpid/python/qpid/tests/messaging.py
index ed02d7a27a..2346dbb050 100644
--- a/qpid/python/qpid/tests/messaging.py
+++ b/qpid/python/qpid/tests/messaging.py
@@ -40,6 +40,7 @@ class Base(Test):
return None
def setup(self):
+ self.test_id = uuid4()
self.broker = self.config.broker
self.conn = self.setup_connection()
self.ssn = self.setup_session()
@@ -50,10 +51,16 @@ class Base(Test):
if self.conn is not None and self.conn.connected():
self.conn.close()
+ def content(self, base, count = None):
+ if count is None:
+ return "%s[%s]" % (base, self.test_id)
+ else:
+ return "%s[%s, %s]" % (base, count, self.test_id)
+
def ping(self, ssn):
# send a message
sender = ssn.sender("ping-queue")
- content = "ping[%s]" % uuid4()
+ content = self.content("ping")
sender.send(content)
receiver = ssn.receiver("ping-queue")
msg = receiver.fetch(timeout=0)
@@ -61,13 +68,17 @@ class Base(Test):
assert msg.content == content
def drain(self, rcv, limit=None):
- msgs = []
+ contents = []
try:
- while limit is None or len(msgs) < limit:
- msgs.append(rcv.fetch(0))
+ while limit is None or len(contents) < limit:
+ contents.append(rcv.fetch(0).content)
except Empty:
pass
- return msgs
+ return contents
+
+ def assertEmpty(self, rcv):
+ contents = self.drain(rcv)
+ assert len(contents) == 0, "%s is supposed to be empty: %s" % (rcv, contents)
def delay(self):
d = float(self.config.defines.get("delay", "2"))
@@ -156,7 +167,7 @@ class SessionTests(Base):
assert snd is not snd2
snd2.close()
- content = "testSender[%s]" % uuid4()
+ content = self.content("testSender")
snd.send(content)
rcv = self.ssn.receiver(snd.target)
msg = rcv.fetch(0)
@@ -169,7 +180,7 @@ class SessionTests(Base):
assert rcv is not rcv2
rcv2.close()
- content = "testReceiver[%s]" % uuid4()
+ content = self.content("testReceiver")
snd = self.ssn.sender(rcv.source)
snd.send(content)
msg = rcv.fetch(0)
@@ -206,16 +217,14 @@ class SessionTests(Base):
# drain the queue, verify the messages are there and then close
# without acking
rcv = self.ssn.receiver(snd.target)
- msgs = self.drain(rcv)
- assert contents == [m.content for m in msgs]
+ assert contents == self.drain(rcv)
self.ssn.close()
# drain the queue again, verify that they are all the messages
# were requeued, and ack this time before closing
self.ssn = self.conn.session()
rcv = self.ssn.receiver("test-ack-queue")
- msgs = self.drain(rcv)
- assert contents == [m.content for m in msgs]
+ assert contents == self.drain(rcv)
self.ssn.acknowledge()
self.ssn.close()
@@ -223,8 +232,69 @@ class SessionTests(Base):
# dequeued
self.ssn = self.conn.session()
rcv = self.ssn.receiver("test-ack-queue")
- msgs = self.drain(rcv)
- assert len(msgs) == 0
+ self.assertEmpty(rcv)
+
+ def send(self, ssn, queue, base, count=1):
+ snd = ssn.sender(queue)
+ contents = []
+ for i in range(count):
+ c = self.content(base, i)
+ snd.send(c)
+ contents.append(c)
+ snd.close()
+ return contents
+
+ def testCommitSend(self):
+ txssn = self.conn.session(transactional=True)
+ contents = self.send(txssn, "test-commit-send-queue", "testCommitSend", 3)
+ rcv = self.ssn.receiver("test-commit-send-queue")
+ self.assertEmpty(rcv)
+ txssn.commit()
+ assert contents == self.drain(rcv)
+ self.ssn.acknowledge()
+
+ def testCommitAck(self):
+ txssn = self.conn.session(transactional=True)
+ txrcv = txssn.receiver("test-commit-ack-queue")
+ self.assertEmpty(txrcv)
+ contents = self.send(self.ssn, "test-commit-ack-queue", "testCommitAck", 3)
+ assert contents == self.drain(txrcv)
+ txssn.acknowledge()
+ txssn.close()
+
+ txssn = self.conn.session(transactional=True)
+ txrcv = txssn.receiver("test-commit-ack-queue")
+ assert contents == self.drain(txrcv)
+ txssn.acknowledge()
+ txssn.commit()
+ rcv = self.ssn.receiver("test-commit-ack-queue")
+ self.assertEmpty(rcv)
+ txssn.close()
+ self.assertEmpty(rcv)
+
+ def testRollbackAck(self):
+ txssn = self.conn.session(transactional=True)
+ txrcv = txssn.receiver("test-rollback-ack-queue")
+ self.assertEmpty(txrcv)
+ contents = self.send(self.ssn, "test-rollback-ack-queue", "testRollbackAck", 3)
+ assert contents == self.drain(txrcv)
+ txssn.rollback()
+ assert contents == self.drain(txrcv)
+ txssn.acknowledge()
+ txssn.rollback()
+ assert contents == self.drain(txrcv)
+ txssn.commit() # commit without ack
+ self.assertEmpty(txrcv)
+ txssn.close()
+ txssn = self.conn.session(transactional=True)
+ txrcv = txssn.receiver("test-rollback-ack-queue")
+ assert contents == self.drain(txrcv)
+ txssn.acknowledge()
+ txssn.commit()
+ rcv = self.ssn.receiver("test-rollback-ack-queue")
+ self.assertEmpty(rcv)
+ txssn.close()
+ self.assertEmpty(rcv)
def testClose(self):
self.ssn.close()
@@ -249,10 +319,7 @@ class ReceiverTests(Base):
return self.ssn.receiver("test-receiver-queue")
def send(self, base, count = None):
- if count is None:
- content = "%s[%s]" % (base, uuid4())
- else:
- content = "%s[%s, %s]" % (base, count, uuid4())
+ content = self.content(base, count)
self.snd.send(content)
return content
@@ -412,13 +479,13 @@ class SenderTests(Base):
self.ssn.acknowledge()
def testSendString(self):
- self.checkContent("testSendString[%s]" % uuid4())
+ self.checkContent(self.content("testSendString"))
def testSendList(self):
- self.checkContent(["testSendList", 1, 3.14, uuid4()])
+ self.checkContent(["testSendList", 1, 3.14, self.test_id])
def testSendMap(self):
- self.checkContent({"testSendMap": uuid4(), "pie": "blueberry", "pi": 3.14})
+ self.checkContent({"testSendMap": self.test_id, "pie": "blueberry", "pi": 3.14})
class MessageTests(Base):
@@ -506,7 +573,7 @@ class MessageEchoTests(Base):
msg = Message()
msg.to = "to-address"
msg.subject = "subject"
- msg.correlation_id = str(uuid4())
+ msg.correlation_id = str(self.test_id)
msg.properties = MessageEchoTests.TEST_MAP
msg.reply_to = "reply-address"
self.check(msg)