summaryrefslogtreecommitdiff
path: root/qpid/python/qpid/messaging/driver.py
diff options
context:
space:
mode:
Diffstat (limited to 'qpid/python/qpid/messaging/driver.py')
-rw-r--r--qpid/python/qpid/messaging/driver.py1330
1 files changed, 1330 insertions, 0 deletions
diff --git a/qpid/python/qpid/messaging/driver.py b/qpid/python/qpid/messaging/driver.py
new file mode 100644
index 0000000000..7c21388213
--- /dev/null
+++ b/qpid/python/qpid/messaging/driver.py
@@ -0,0 +1,1330 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+import socket, struct, sys, time
+from logging import getLogger, DEBUG
+from qpid import compat
+from qpid import sasl
+from qpid.concurrency import synchronized
+from qpid.datatypes import RangedSet, Serial
+from qpid.framing import OpEncoder, SegmentEncoder, FrameEncoder, \
+ FrameDecoder, SegmentDecoder, OpDecoder
+from qpid.messaging import address, transports
+from qpid.messaging.constants import UNLIMITED, REJECTED, RELEASED
+from qpid.messaging.exceptions import *
+from qpid.messaging.message import get_codec, Disposition, Message
+from qpid.ops import *
+from qpid.selector import Selector
+from qpid.util import URL, default
+from qpid.validator import And, Context, List, Map, Types, Values
+from threading import Condition, Thread
+
+log = getLogger("qpid.messaging")
+rawlog = getLogger("qpid.messaging.io.raw")
+opslog = getLogger("qpid.messaging.io.ops")
+
+def addr2reply_to(addr):
+ name, subject, options = address.parse(addr)
+ if options:
+ type = options.get("node", {}).get("type")
+ else:
+ type = None
+
+ if type == "topic":
+ return ReplyTo(name, subject)
+ else:
+ return ReplyTo(None, name)
+
+def reply_to2addr(reply_to):
+ if reply_to.exchange in (None, ""):
+ return reply_to.routing_key
+ elif reply_to.routing_key is None:
+ return "%s; {node: {type: topic}}" % reply_to.exchange
+ else:
+ return "%s/%s; {node: {type: topic}}" % (reply_to.exchange, reply_to.routing_key)
+
+class Attachment:
+
+ def __init__(self, target):
+ self.target = target
+
+# XXX
+
+DURABLE_DEFAULT=False
+
+# XXX
+
+class Pattern:
+ """
+ The pattern filter matches the supplied wildcard pattern against a
+ message subject.
+ """
+
+ def __init__(self, value):
+ self.value = value
+
+ # XXX: this should become part of the driver
+ def _bind(self, sst, exchange, queue):
+ from qpid.ops import ExchangeBind
+
+ sst.write_cmd(ExchangeBind(exchange=exchange, queue=queue,
+ binding_key=self.value.replace("*", "#")))
+
+SUBJECT_DEFAULTS = {
+ "topic": "#"
+ }
+
+# XXX
+ppid = 0
+try:
+ ppid = os.getppid()
+except:
+ pass
+
+CLIENT_PROPERTIES = {"product": "qpid python client",
+ "version": "development",
+ "platform": os.name,
+ "qpid.client_process": os.path.basename(sys.argv[0]),
+ "qpid.client_pid": os.getpid(),
+ "qpid.client_ppid": ppid}
+
+def noop(): pass
+def sync_noop(): pass
+
+class SessionState:
+
+ def __init__(self, driver, session, name, channel):
+ self.driver = driver
+ self.session = session
+ self.name = name
+ self.channel = channel
+ self.detached = False
+ self.committing = False
+ self.aborting = False
+
+ # sender state
+ self.sent = Serial(0)
+ self.acknowledged = RangedSet()
+ self.actions = {}
+ self.min_completion = self.sent
+ self.max_completion = self.sent
+ self.results = {}
+ self.need_sync = False
+
+ # receiver state
+ self.received = None
+ self.executed = RangedSet()
+
+ # XXX: need to periodically exchange completion/known_completion
+
+ self.destinations = {}
+
+ def write_query(self, query, handler):
+ id = self.sent
+ self.write_cmd(query, lambda: handler(self.results.pop(id)))
+
+ def apply_overrides(self, cmd, overrides):
+ for k, v in overrides.items():
+ cmd[k.replace('-', '_')] = v
+
+ def write_cmd(self, cmd, action=noop, overrides=None, sync=True):
+ if overrides:
+ self.apply_overrides(cmd, overrides)
+
+ if action != noop:
+ cmd.sync = sync
+ if self.detached:
+ raise Exception("detached")
+ cmd.id = self.sent
+ self.sent += 1
+ self.actions[cmd.id] = action
+ self.max_completion = cmd.id
+ self.write_op(cmd)
+ self.need_sync = not cmd.sync
+
+ def write_cmds(self, cmds, action=noop):
+ if cmds:
+ for cmd in cmds[:-1]:
+ self.write_cmd(cmd)
+ self.write_cmd(cmds[-1], action)
+ else:
+ action()
+
+ def write_op(self, op):
+ op.channel = self.channel
+ self.driver.write_op(op)
+
+POLICIES = Values("always", "sender", "receiver", "never")
+RELIABILITY = Values("unreliable", "at-most-once", "at-least-once",
+ "exactly-once")
+
+DECLARE = Map({}, restricted=False)
+BINDINGS = List(Map({
+ "exchange": Types(basestring),
+ "queue": Types(basestring),
+ "key": Types(basestring),
+ "arguments": Map({}, restricted=False)
+ }))
+
+COMMON_OPTS = {
+ "create": POLICIES,
+ "delete": POLICIES,
+ "assert": POLICIES,
+ "node": Map({
+ "type": Values("queue", "topic"),
+ "durable": Types(bool),
+ "x-declare": DECLARE,
+ "x-bindings": BINDINGS
+ }),
+ "link": Map({
+ "name": Types(basestring),
+ "durable": Types(bool),
+ "reliability": RELIABILITY,
+ "x-declare": DECLARE,
+ "x-bindings": BINDINGS,
+ "x-subscribe": Map({}, restricted=False)
+ })
+ }
+
+RECEIVE_MODES = Values("browse", "consume")
+
+SOURCE_OPTS = COMMON_OPTS.copy()
+SOURCE_OPTS.update({
+ "mode": RECEIVE_MODES
+ })
+
+TARGET_OPTS = COMMON_OPTS.copy()
+
+class LinkIn:
+
+ ADDR_NAME = "source"
+ DIR_NAME = "receiver"
+ VALIDATOR = Map(SOURCE_OPTS)
+
+ def init_link(self, sst, rcv, _rcv):
+ _rcv.destination = str(rcv.id)
+ sst.destinations[_rcv.destination] = _rcv
+ _rcv.draining = False
+ _rcv.bytes_open = False
+ _rcv.on_unlink = []
+
+ def do_link(self, sst, rcv, _rcv, type, subtype, action):
+ link_opts = _rcv.options.get("link", {})
+ reliability = link_opts.get("reliability", "at-least-once")
+ declare = link_opts.get("x-declare", {})
+ subscribe = link_opts.get("x-subscribe", {})
+ acq_mode = acquire_mode.pre_acquired
+ if reliability in ("unreliable", "at-most-once"):
+ rcv._accept_mode = accept_mode.none
+ else:
+ rcv._accept_mode = accept_mode.explicit
+
+ if type == "topic":
+ default_name = "%s.%s" % (rcv.session.name, _rcv.destination)
+ _rcv._queue = link_opts.get("name", default_name)
+ sst.write_cmd(QueueDeclare(queue=_rcv._queue,
+ durable=link_opts.get("durable", False),
+ exclusive=True,
+ auto_delete=(reliability == "unreliable")),
+ overrides=declare)
+ _rcv.on_unlink = [QueueDelete(_rcv._queue)]
+ subject = _rcv.subject or SUBJECT_DEFAULTS.get(subtype)
+ bindings = get_bindings(link_opts, _rcv._queue, _rcv.name, subject)
+ if not bindings:
+ sst.write_cmd(ExchangeBind(_rcv._queue, _rcv.name, subject))
+
+ elif type == "queue":
+ _rcv._queue = _rcv.name
+ if _rcv.options.get("mode", "consume") == "browse":
+ acq_mode = acquire_mode.not_acquired
+ bindings = get_bindings(link_opts, queue=_rcv._queue)
+
+
+ sst.write_cmds(bindings)
+ sst.write_cmd(MessageSubscribe(queue=_rcv._queue,
+ destination=_rcv.destination,
+ acquire_mode = acq_mode,
+ accept_mode = rcv._accept_mode),
+ overrides=subscribe)
+ sst.write_cmd(MessageSetFlowMode(_rcv.destination, flow_mode.credit), action)
+
+ def do_unlink(self, sst, rcv, _rcv, action=noop):
+ link_opts = _rcv.options.get("link", {})
+ reliability = link_opts.get("reliability")
+ cmds = [MessageCancel(_rcv.destination)]
+ cmds.extend(_rcv.on_unlink)
+ sst.write_cmds(cmds, action)
+
+ def del_link(self, sst, rcv, _rcv):
+ del sst.destinations[_rcv.destination]
+
+class LinkOut:
+
+ ADDR_NAME = "target"
+ DIR_NAME = "sender"
+ VALIDATOR = Map(TARGET_OPTS)
+
+ def init_link(self, sst, snd, _snd):
+ _snd.closing = False
+ _snd.pre_ack = False
+
+ def do_link(self, sst, snd, _snd, type, subtype, action):
+ link_opts = _snd.options.get("link", {})
+ reliability = link_opts.get("reliability", "at-least-once")
+ _snd.pre_ack = reliability in ("unreliable", "at-most-once")
+ if type == "topic":
+ _snd._exchange = _snd.name
+ _snd._routing_key = _snd.subject
+ bindings = get_bindings(link_opts, exchange=_snd.name, key=_snd.subject)
+ elif type == "queue":
+ _snd._exchange = ""
+ _snd._routing_key = _snd.name
+ bindings = get_bindings(link_opts, queue=_snd.name)
+ sst.write_cmds(bindings, action)
+
+ def do_unlink(self, sst, snd, _snd, action=noop):
+ action()
+
+ def del_link(self, sst, snd, _snd):
+ pass
+
+class Cache:
+
+ def __init__(self, ttl):
+ self.ttl = ttl
+ self.entries = {}
+
+ def __setitem__(self, key, value):
+ self.entries[key] = time.time(), value
+
+ def __getitem__(self, key):
+ tstamp, value = self.entries[key]
+ if time.time() - tstamp >= self.ttl:
+ del self.entries[key]
+ raise KeyError(key)
+ else:
+ return value
+
+ def __delitem__(self, key):
+ del self.entries[key]
+
+# XXX
+HEADER="!4s4B"
+
+EMPTY_DP = DeliveryProperties()
+EMPTY_MP = MessageProperties()
+
+SUBJECT = "qpid.subject"
+
+CLOSED = "CLOSED"
+READ_ONLY = "READ_ONLY"
+WRITE_ONLY = "WRITE_ONLY"
+OPEN = "OPEN"
+
+class Driver:
+
+ def __init__(self, connection):
+ self.connection = connection
+ self.log_id = "%x" % id(self.connection)
+ self._lock = self.connection._lock
+
+ self._selector = Selector.default()
+ self._attempts = 0
+ self._delay = self.connection.reconnect_interval_min
+ self._reconnect_log = self.connection.reconnect_log
+ self._host = 0
+ self._retrying = False
+ self._next_retry = None
+ self._transport = None
+
+ self._timeout = None
+
+ self.engine = None
+
+ def _next_host(self):
+ urls = [URL(u) for u in self.connection.reconnect_urls]
+ hosts = [(self.connection.host, default(self.connection.port, 5672))] + \
+ [(u.host, default(u.port, 5672)) for u in urls]
+ if self._host >= len(hosts):
+ self._host = 0
+ result = hosts[self._host]
+ if self._host == 0:
+ self._attempts += 1
+ self._host = self._host + 1
+ return result
+
+ def _num_hosts(self):
+ return len(self.connection.reconnect_urls) + 1
+
+ @synchronized
+ def wakeup(self):
+ self.dispatch()
+ self._selector.wakeup()
+
+ def start(self):
+ self._selector.register(self)
+
+ def stop(self):
+ self._selector.unregister(self)
+ if self._transport:
+ self.st_closed()
+
+ def fileno(self):
+ return self._transport.fileno()
+
+ @synchronized
+ def reading(self):
+ return self._transport is not None and \
+ self._transport.reading(True)
+
+ @synchronized
+ def writing(self):
+ return self._transport is not None and \
+ self._transport.writing(self.engine.pending())
+
+ @synchronized
+ def timing(self):
+ return self._timeout
+
+ @synchronized
+ def readable(self):
+ try:
+ data = self._transport.recv(64*1024)
+ if data is None:
+ return
+ elif data:
+ rawlog.debug("READ[%s]: %r", self.log_id, data)
+ self.engine.write(data)
+ else:
+ self.close_engine()
+ except socket.error, e:
+ self.close_engine(ConnectionError(text=str(e)))
+
+ self.update_status()
+
+ self._notify()
+
+ def _notify(self):
+ if self.connection.error:
+ self.connection._condition.gc()
+ self.connection._waiter.notifyAll()
+
+ def close_engine(self, e=None):
+ if e is None:
+ e = ConnectionError(text="connection aborted")
+
+ if (self.connection.reconnect and
+ (self.connection.reconnect_limit is None or
+ self.connection.reconnect_limit <= 0 or
+ self._attempts <= self.connection.reconnect_limit)):
+ if self._host < self._num_hosts():
+ delay = 0
+ else:
+ delay = self._delay
+ self._delay = min(2*self._delay,
+ self.connection.reconnect_interval_max)
+ self._next_retry = time.time() + delay
+ if self._reconnect_log:
+ log.warn("recoverable error[attempt %s]: %s" % (self._attempts, e))
+ if delay > 0:
+ log.warn("sleeping %s seconds" % delay)
+ self._retrying = True
+ self.engine.close()
+ else:
+ self.engine.close(e)
+
+ self.schedule()
+
+ def update_status(self):
+ status = self.engine.status()
+ return getattr(self, "st_%s" % status.lower())()
+
+ def st_closed(self):
+ # XXX: this log statement seems to sometimes hit when the socket is not connected
+ # XXX: rawlog.debug("CLOSE[%s]: %s", self.log_id, self._socket.getpeername())
+ self._transport.close()
+ self._transport = None
+ self.engine = None
+ return True
+
+ def st_open(self):
+ return False
+
+ @synchronized
+ def writeable(self):
+ notify = False
+ try:
+ n = self._transport.send(self.engine.peek())
+ if n == 0: return
+ sent = self.engine.read(n)
+ rawlog.debug("SENT[%s]: %r", self.log_id, sent)
+ except socket.error, e:
+ self.close_engine(e)
+ notify = True
+
+ if self.update_status() or notify:
+ self._notify()
+
+ @synchronized
+ def timeout(self):
+ self.dispatch()
+ self._notify()
+ self.schedule()
+
+ def schedule(self):
+ times = []
+ if self.connection.heartbeat:
+ times.append(time.time() + self.connection.heartbeat)
+ if self._next_retry:
+ times.append(self._next_retry)
+ if times:
+ self._timeout = min(times)
+ else:
+ self._timeout = None
+
+ def dispatch(self):
+ try:
+ if self._transport is None:
+ if self.connection._connected and not self.connection.error:
+ self.connect()
+ else:
+ self.engine.dispatch()
+ except HeartbeatTimeout, e:
+ self.close_engine(e)
+ except:
+ # XXX: Does socket get leaked if this occurs?
+ msg = compat.format_exc()
+ self.connection.error = InternalError(text=msg)
+
+ def connect(self):
+ if self._retrying and time.time() < self._next_retry:
+ return
+
+ try:
+ # XXX: should make this non blocking
+ host, port = self._next_host()
+ if self._retrying and self._reconnect_log:
+ log.warn("trying: %s:%s", host, port)
+ self.engine = Engine(self.connection)
+ self.engine.open()
+ rawlog.debug("OPEN[%s]: %s:%s", self.log_id, host, port)
+ trans = transports.TRANSPORTS.get(self.connection.transport)
+ if trans:
+ self._transport = trans(self.connection, host, port)
+ else:
+ raise ConnectError("no such transport: %s" % self.connection.transport)
+ if self._retrying and self._reconnect_log:
+ log.warn("reconnect succeeded: %s:%s", host, port)
+ self._next_retry = None
+ self._attempts = 0
+ self._host = 0
+ self._delay = self.connection.reconnect_interval_min
+ self._retrying = False
+ self.schedule()
+ except socket.error, e:
+ self.close_engine(ConnectError(text=str(e)))
+
+DEFAULT_DISPOSITION = Disposition(None)
+
+def get_bindings(opts, queue=None, exchange=None, key=None):
+ bindings = opts.get("x-bindings", [])
+ cmds = []
+ for b in bindings:
+ exchange = b.get("exchange", exchange)
+ queue = b.get("queue", queue)
+ key = b.get("key", key)
+ args = b.get("arguments", {})
+ cmds.append(ExchangeBind(queue, exchange, key, args))
+ return cmds
+
+CONNECTION_ERRS = {
+ # anythong not here (i.e. everything right now) will default to
+ # connection error
+ }
+
+SESSION_ERRS = {
+ # anything not here will default to session error
+ error_code.unauthorized_access: UnauthorizedAccess,
+ error_code.not_found: NotFound,
+ error_code.resource_locked: ReceiverError,
+ error_code.resource_limit_exceeded: TargetCapacityExceeded,
+ error_code.internal_error: ServerError
+ }
+
+class Engine:
+
+ def __init__(self, connection):
+ self.connection = connection
+ self.log_id = "%x" % id(self.connection)
+ self._closing = False
+ self._connected = False
+ self._attachments = {}
+
+ self._in = LinkIn()
+ self._out = LinkOut()
+
+ self._channel_max = 65536
+ self._channels = 0
+ self._sessions = {}
+
+ self.address_cache = Cache(self.connection.address_ttl)
+
+ self._status = CLOSED
+ self._buf = ""
+ self._hdr = ""
+ self._last_in = None
+ self._last_out = None
+ self._op_enc = OpEncoder()
+ self._seg_enc = SegmentEncoder()
+ self._frame_enc = FrameEncoder()
+ self._frame_dec = FrameDecoder()
+ self._seg_dec = SegmentDecoder()
+ self._op_dec = OpDecoder()
+
+ self._sasl = sasl.Client()
+ if self.connection.username:
+ self._sasl.setAttr("username", self.connection.username)
+ if self.connection.password:
+ self._sasl.setAttr("password", self.connection.password)
+ if self.connection.host:
+ self._sasl.setAttr("host", self.connection.host)
+ self._sasl.setAttr("service", self.connection.sasl_service)
+ if self.connection.sasl_min_ssf is not None:
+ self._sasl.setAttr("minssf", self.connection.sasl_min_ssf)
+ if self.connection.sasl_max_ssf is not None:
+ self._sasl.setAttr("maxssf", self.connection.sasl_max_ssf)
+ self._sasl.init()
+ self._sasl_encode = False
+ self._sasl_decode = False
+
+ def _reset(self):
+ self.connection._transport_connected = False
+
+ for ssn in self.connection.sessions.values():
+ for m in ssn.acked + ssn.unacked + ssn.incoming:
+ m._transfer_id = None
+ for snd in ssn.senders:
+ snd.linked = False
+ for rcv in ssn.receivers:
+ rcv.impending = rcv.received
+ rcv.linked = False
+
+ def status(self):
+ return self._status
+
+ def write(self, data):
+ self._last_in = time.time()
+ try:
+ if self._sasl_decode:
+ data = self._sasl.decode(data)
+
+ if len(self._hdr) < 8:
+ r = 8 - len(self._hdr)
+ self._hdr += data[:r]
+ data = data[r:]
+
+ if len(self._hdr) == 8:
+ self.do_header(self._hdr)
+
+ self._frame_dec.write(data)
+ self._seg_dec.write(*self._frame_dec.read())
+ self._op_dec.write(*self._seg_dec.read())
+ for op in self._op_dec.read():
+ self.assign_id(op)
+ opslog.debug("RCVD[%s]: %r", self.log_id, op)
+ op.dispatch(self)
+ self.dispatch()
+ except MessagingError, e:
+ self.close(e)
+ except:
+ self.close(InternalError(text=compat.format_exc()))
+
+ def close(self, e=None):
+ self._reset()
+ if e:
+ self.connection.error = e
+ self._status = CLOSED
+
+ def assign_id(self, op):
+ if isinstance(op, Command):
+ sst = self.get_sst(op)
+ op.id = sst.received
+ sst.received += 1
+
+ def pending(self):
+ return len(self._buf)
+
+ def read(self, n):
+ result = self._buf[:n]
+ self._buf = self._buf[n:]
+ return result
+
+ def peek(self):
+ return self._buf
+
+ def write_op(self, op):
+ opslog.debug("SENT[%s]: %r", self.log_id, op)
+ self._op_enc.write(op)
+ self._seg_enc.write(*self._op_enc.read())
+ self._frame_enc.write(*self._seg_enc.read())
+ bytes = self._frame_enc.read()
+ if self._sasl_encode:
+ bytes = self._sasl.encode(bytes)
+ self._buf += bytes
+ self._last_out = time.time()
+
+ def do_header(self, hdr):
+ cli_major = 0; cli_minor = 10
+ magic, _, _, major, minor = struct.unpack(HEADER, hdr)
+ if major != cli_major or minor != cli_minor:
+ raise VersionError(text="client: %s-%s, server: %s-%s" %
+ (cli_major, cli_minor, major, minor))
+
+ def do_connection_start(self, start):
+ if self.connection.sasl_mechanisms:
+ permitted = self.connection.sasl_mechanisms.split()
+ mechs = [m for m in start.mechanisms if m in permitted]
+ else:
+ mechs = start.mechanisms
+ try:
+ mech, initial = self._sasl.start(" ".join(mechs))
+ except sasl.SASLError, e:
+ raise AuthenticationFailure(text=str(e))
+ self.write_op(ConnectionStartOk(client_properties=CLIENT_PROPERTIES,
+ mechanism=mech, response=initial))
+
+ def do_connection_secure(self, secure):
+ resp = self._sasl.step(secure.challenge)
+ self.write_op(ConnectionSecureOk(response=resp))
+
+ def do_connection_tune(self, tune):
+ # XXX: is heartbeat protocol specific?
+ if tune.channel_max is not None:
+ self.channel_max = tune.channel_max
+ self.write_op(ConnectionTuneOk(heartbeat=self.connection.heartbeat,
+ channel_max=self.channel_max))
+ self.write_op(ConnectionOpen())
+ self._sasl_encode = True
+
+ def do_connection_open_ok(self, open_ok):
+ self.connection.auth_username = self._sasl.auth_username()
+ self._connected = True
+ self._sasl_decode = True
+ self.connection._transport_connected = True
+
+ def do_connection_heartbeat(self, hrt):
+ pass
+
+ def do_connection_close(self, close):
+ self.write_op(ConnectionCloseOk())
+ if close.reply_code != close_code.normal:
+ exc = CONNECTION_ERRS.get(close.reply_code, ConnectionError)
+ self.connection.error = exc(close.reply_code, close.reply_text)
+ # XXX: should we do a half shutdown on the socket here?
+ # XXX: we really need to test this, we may end up reporting a
+ # connection abort after this, if we were to do a shutdown on read
+ # and stop reading, then we wouldn't report the abort, that's
+ # probably the right thing to do
+
+ def do_connection_close_ok(self, close_ok):
+ self.close()
+
+ def do_session_attached(self, atc):
+ pass
+
+ def do_session_command_point(self, cp):
+ sst = self.get_sst(cp)
+ sst.received = cp.command_id
+
+ def do_session_completed(self, sc):
+ sst = self.get_sst(sc)
+ for r in sc.commands:
+ sst.acknowledged.add(r.lower, r.upper)
+
+ if not sc.commands.empty():
+ while sst.min_completion in sc.commands:
+ if sst.actions.has_key(sst.min_completion):
+ sst.actions.pop(sst.min_completion)()
+ sst.min_completion += 1
+
+ def session_known_completed(self, kcmp):
+ sst = self.get_sst(kcmp)
+ executed = RangedSet()
+ for e in sst.executed.ranges:
+ for ke in kcmp.ranges:
+ if e.lower in ke and e.upper in ke:
+ break
+ else:
+ executed.add_range(e)
+ sst.executed = completed
+
+ def do_session_flush(self, sf):
+ sst = self.get_sst(sf)
+ if sf.expected:
+ if sst.received is None:
+ exp = None
+ else:
+ exp = RangedSet(sst.received)
+ sst.write_op(SessionExpected(exp))
+ if sf.confirmed:
+ sst.write_op(SessionConfirmed(sst.executed))
+ if sf.completed:
+ sst.write_op(SessionCompleted(sst.executed))
+
+ def do_session_request_timeout(self, rt):
+ sst = self.get_sst(rt)
+ sst.write_op(SessionTimeout(timeout=0))
+
+ def do_execution_result(self, er):
+ sst = self.get_sst(er)
+ sst.results[er.command_id] = er.value
+ sst.executed.add(er.id)
+
+ def do_execution_exception(self, ex):
+ sst = self.get_sst(ex)
+ exc = SESSION_ERRS.get(ex.error_code, SessionError)
+ sst.session.error = exc(ex.error_code, ex.description)
+
+ def dispatch(self):
+ if not self.connection._connected and not self._closing and self._status != CLOSED:
+ self.disconnect()
+
+ if self._connected and not self._closing:
+ for ssn in self.connection.sessions.values():
+ self.attach(ssn)
+ self.process(ssn)
+
+ if self.connection.heartbeat and self._status != CLOSED:
+ now = time.time()
+ if self._last_in is not None and \
+ now - self._last_in > 2*self.connection.heartbeat:
+ raise HeartbeatTimeout(text="heartbeat timeout")
+ if self._last_out is None or now - self._last_out >= self.connection.heartbeat/2.0:
+ self.write_op(ConnectionHeartbeat())
+
+ def open(self):
+ self._reset()
+ self._status = OPEN
+ self._buf += struct.pack(HEADER, "AMQP", 1, 1, 0, 10)
+
+ def disconnect(self):
+ self.write_op(ConnectionClose(close_code.normal))
+ self._closing = True
+
+ def attach(self, ssn):
+ if ssn.closed: return
+ sst = self._attachments.get(ssn)
+ if sst is None:
+ for i in xrange(0, self.channel_max):
+ if not self._sessions.has_key(i):
+ ch = i
+ break
+ else:
+ raise RuntimeError("all channels used")
+ sst = SessionState(self, ssn, ssn.name, ch)
+ sst.write_op(SessionAttach(name=ssn.name))
+ sst.write_op(SessionCommandPoint(sst.sent, 0))
+ sst.outgoing_idx = 0
+ sst.acked = []
+ sst.acked_idx = 0
+ if ssn.transactional:
+ sst.write_cmd(TxSelect())
+ self._attachments[ssn] = sst
+ self._sessions[sst.channel] = sst
+
+ for snd in ssn.senders:
+ self.link(snd, self._out, snd.target)
+ for rcv in ssn.receivers:
+ self.link(rcv, self._in, rcv.source)
+
+ if sst is not None and ssn.closing and not sst.detached:
+ sst.detached = True
+ sst.write_op(SessionDetach(name=ssn.name))
+
+ def get_sst(self, op):
+ return self._sessions[op.channel]
+
+ def do_session_detached(self, dtc):
+ sst = self._sessions.pop(dtc.channel)
+ ssn = sst.session
+ del self._attachments[ssn]
+ ssn.closed = True
+
+ def do_session_detach(self, dtc):
+ sst = self.get_sst(dtc)
+ sst.write_op(SessionDetached(name=dtc.name))
+ self.do_session_detached(dtc)
+
+ def link(self, lnk, dir, addr):
+ sst = self._attachments.get(lnk.session)
+ _lnk = self._attachments.get(lnk)
+
+ if _lnk is None and not lnk.closed:
+ _lnk = Attachment(lnk)
+ _lnk.closing = False
+ dir.init_link(sst, lnk, _lnk)
+
+ err = self.parse_address(_lnk, dir, addr) or self.validate_options(_lnk, dir)
+ if err:
+ lnk.error = err
+ lnk.closed = True
+ return
+
+ def linked():
+ lnk.linked = True
+
+ def resolved(type, subtype):
+ dir.do_link(sst, lnk, _lnk, type, subtype, linked)
+
+ self.resolve_declare(sst, _lnk, dir.DIR_NAME, resolved)
+ self._attachments[lnk] = _lnk
+
+ if lnk.linked and lnk.closing and not lnk.closed:
+ if not _lnk.closing:
+ def unlinked():
+ dir.del_link(sst, lnk, _lnk)
+ del self._attachments[lnk]
+ lnk.closed = True
+ if _lnk.options.get("delete") in ("always", dir.DIR_NAME):
+ dir.do_unlink(sst, lnk, _lnk)
+ self.delete(sst, _lnk.name, unlinked)
+ else:
+ dir.do_unlink(sst, lnk, _lnk, unlinked)
+ _lnk.closing = True
+ elif not lnk.linked and lnk.closing and not lnk.closed:
+ if lnk.error: lnk.closed = True
+
+ def parse_address(self, lnk, dir, addr):
+ if addr is None:
+ return MalformedAddress(text="%s is None" % dir.ADDR_NAME)
+ else:
+ try:
+ lnk.name, lnk.subject, lnk.options = address.parse(addr)
+ # XXX: subject
+ if lnk.options is None:
+ lnk.options = {}
+ except address.LexError, e:
+ return MalformedAddress(text=str(e))
+ except address.ParseError, e:
+ return MalformedAddress(text=str(e))
+
+ def validate_options(self, lnk, dir):
+ ctx = Context()
+ err = dir.VALIDATOR.validate(lnk.options, ctx)
+ if err: return InvalidOption(text="error in options: %s" % err)
+
+ def resolve_declare(self, sst, lnk, dir, action):
+ declare = lnk.options.get("create") in ("always", dir)
+ assrt = lnk.options.get("assert") in ("always", dir)
+ def do_resolved(type, subtype):
+ err = None
+ if type is None:
+ if declare:
+ err = self.declare(sst, lnk, action)
+ else:
+ err = NotFound(text="no such queue: %s" % lnk.name)
+ else:
+ if assrt:
+ expected = lnk.options.get("node", {}).get("type")
+ if expected and type != expected:
+ err = AssertionFailed(text="expected %s, got %s" % (expected, type))
+ if err is None:
+ action(type, subtype)
+
+ if err:
+ tgt = lnk.target
+ tgt.error = err
+ del self._attachments[tgt]
+ tgt.closed = True
+ return
+ self.resolve(sst, lnk.name, do_resolved, force=declare)
+
+ def resolve(self, sst, name, action, force=False):
+ if not force:
+ try:
+ type, subtype = self.address_cache[name]
+ action(type, subtype)
+ return
+ except KeyError:
+ pass
+
+ args = []
+ def do_result(r):
+ args.append(r)
+ def do_action(r):
+ do_result(r)
+ er, qr = args
+ if er.not_found and not qr.queue:
+ type, subtype = None, None
+ elif qr.queue:
+ type, subtype = "queue", None
+ else:
+ type, subtype = "topic", er.type
+ if type is not None:
+ self.address_cache[name] = (type, subtype)
+ action(type, subtype)
+ sst.write_query(ExchangeQuery(name), do_result)
+ sst.write_query(QueueQuery(name), do_action)
+
+ def declare(self, sst, lnk, action):
+ name = lnk.name
+ props = lnk.options.get("node", {})
+ durable = props.get("durable", DURABLE_DEFAULT)
+ type = props.get("type", "queue")
+ declare = props.get("x-declare", {})
+
+ if type == "topic":
+ cmd = ExchangeDeclare(exchange=name, durable=durable)
+ bindings = get_bindings(props, exchange=name)
+ elif type == "queue":
+ cmd = QueueDeclare(queue=name, durable=durable)
+ bindings = get_bindings(props, queue=name)
+ else:
+ raise ValueError(type)
+
+ sst.apply_overrides(cmd, declare)
+
+ if type == "topic":
+ if cmd.type is None:
+ cmd.type = "topic"
+ subtype = cmd.type
+ else:
+ subtype = None
+
+ cmds = [cmd]
+ cmds.extend(bindings)
+
+ def declared():
+ self.address_cache[name] = (type, subtype)
+ action(type, subtype)
+
+ sst.write_cmds(cmds, declared)
+
+ def delete(self, sst, name, action):
+ def deleted():
+ del self.address_cache[name]
+ action()
+
+ def do_delete(type, subtype):
+ if type == "topic":
+ sst.write_cmd(ExchangeDelete(name), deleted)
+ elif type == "queue":
+ sst.write_cmd(QueueDelete(name), deleted)
+ elif type is None:
+ action()
+ else:
+ raise ValueError(type)
+ self.resolve(sst, name, do_delete, force=True)
+
+ def process(self, ssn):
+ if ssn.closed or ssn.closing: return
+
+ sst = self._attachments[ssn]
+
+ while sst.outgoing_idx < len(ssn.outgoing):
+ msg = ssn.outgoing[sst.outgoing_idx]
+ snd = msg._sender
+ # XXX: should check for sender error here
+ _snd = self._attachments.get(snd)
+ if _snd and snd.linked:
+ self.send(snd, msg)
+ sst.outgoing_idx += 1
+ else:
+ break
+
+ for snd in ssn.senders:
+ # XXX: should included snd.acked in this
+ if snd.synced >= snd.queued and sst.need_sync:
+ sst.write_cmd(ExecutionSync(), sync_noop)
+
+ for rcv in ssn.receivers:
+ self.process_receiver(rcv)
+
+ if ssn.acked:
+ messages = ssn.acked[sst.acked_idx:]
+ if messages:
+ ids = RangedSet()
+
+ disposed = [(DEFAULT_DISPOSITION, [])]
+ acked = []
+ for m in messages:
+ # XXX: we're ignoring acks that get lost when disconnected,
+ # could we deal this via some message-id based purge?
+ if m._transfer_id is None:
+ acked.append(m)
+ continue
+ ids.add(m._transfer_id)
+ if m._receiver._accept_mode is accept_mode.explicit:
+ disp = m._disposition or DEFAULT_DISPOSITION
+ last, msgs = disposed[-1]
+ if disp.type is last.type and disp.options == last.options:
+ msgs.append(m)
+ else:
+ disposed.append((disp, [m]))
+ else:
+ acked.append(m)
+
+ for range in ids:
+ sst.executed.add_range(range)
+ sst.write_op(SessionCompleted(sst.executed))
+
+ def ack_acker(msgs):
+ def ack_ack():
+ for m in msgs:
+ ssn.acked.remove(m)
+ sst.acked_idx -= 1
+ # XXX: should this check accept_mode too?
+ if not ssn.transactional:
+ sst.acked.remove(m)
+ return ack_ack
+
+ for disp, msgs in disposed:
+ if not msgs: continue
+ if disp.type is None:
+ op = MessageAccept
+ elif disp.type is RELEASED:
+ op = MessageRelease
+ elif disp.type is REJECTED:
+ op = MessageReject
+ sst.write_cmd(op(RangedSet(*[m._transfer_id for m in msgs]),
+ **disp.options),
+ ack_acker(msgs))
+ if log.isEnabledFor(DEBUG):
+ for m in msgs:
+ log.debug("SACK[%s]: %s, %s", ssn.log_id, m, m._disposition)
+
+ sst.acked.extend(messages)
+ sst.acked_idx += len(messages)
+ ack_acker(acked)()
+
+ if ssn.committing and not sst.committing:
+ def commit_ok():
+ del sst.acked[:]
+ ssn.committing = False
+ ssn.committed = True
+ ssn.aborting = False
+ ssn.aborted = False
+ sst.committing = False
+ sst.write_cmd(TxCommit(), commit_ok)
+ sst.committing = True
+
+ if ssn.aborting and not sst.aborting:
+ sst.aborting = True
+ def do_rb():
+ messages = sst.acked + ssn.unacked + ssn.incoming
+ ids = RangedSet(*[m._transfer_id for m in messages])
+ for range in ids:
+ sst.executed.add_range(range)
+ sst.write_op(SessionCompleted(sst.executed))
+ sst.write_cmd(MessageRelease(ids, True))
+ sst.write_cmd(TxRollback(), do_rb_ok)
+
+ def do_rb_ok():
+ del ssn.incoming[:]
+ del ssn.unacked[:]
+ del sst.acked[:]
+
+ for rcv in ssn.receivers:
+ rcv.impending = rcv.received
+ rcv.returned = rcv.received
+ # XXX: do we need to update granted here as well?
+
+ for rcv in ssn.receivers:
+ self.process_receiver(rcv)
+
+ ssn.aborting = False
+ ssn.aborted = True
+ ssn.committing = False
+ ssn.committed = False
+ sst.aborting = False
+
+ for rcv in ssn.receivers:
+ _rcv = self._attachments[rcv]
+ sst.write_cmd(MessageStop(_rcv.destination))
+ sst.write_cmd(ExecutionSync(), do_rb)
+
+ def grant(self, rcv):
+ sst = self._attachments[rcv.session]
+ _rcv = self._attachments.get(rcv)
+ if _rcv is None or not rcv.linked or _rcv.closing or _rcv.draining:
+ return
+
+ if rcv.granted is UNLIMITED:
+ if rcv.impending is UNLIMITED:
+ delta = 0
+ else:
+ delta = UNLIMITED
+ elif rcv.impending is UNLIMITED:
+ delta = -1
+ else:
+ delta = max(rcv.granted, rcv.received) - rcv.impending
+
+ if delta is UNLIMITED:
+ if not _rcv.bytes_open:
+ sst.write_cmd(MessageFlow(_rcv.destination, credit_unit.byte, UNLIMITED.value))
+ _rcv.bytes_open = True
+ sst.write_cmd(MessageFlow(_rcv.destination, credit_unit.message, UNLIMITED.value))
+ rcv.impending = UNLIMITED
+ elif delta > 0:
+ if not _rcv.bytes_open:
+ sst.write_cmd(MessageFlow(_rcv.destination, credit_unit.byte, UNLIMITED.value))
+ _rcv.bytes_open = True
+ sst.write_cmd(MessageFlow(_rcv.destination, credit_unit.message, delta))
+ rcv.impending += delta
+ elif delta < 0 and not rcv.draining:
+ _rcv.draining = True
+ def do_stop():
+ rcv.impending = rcv.received
+ _rcv.draining = False
+ _rcv.bytes_open = False
+ self.grant(rcv)
+ sst.write_cmd(MessageStop(_rcv.destination), do_stop)
+
+ if rcv.draining:
+ _rcv.draining = True
+ def do_flush():
+ rcv.impending = rcv.received
+ rcv.granted = rcv.impending
+ _rcv.draining = False
+ _rcv.bytes_open = False
+ rcv.draining = False
+ sst.write_cmd(MessageFlush(_rcv.destination), do_flush)
+
+
+ def process_receiver(self, rcv):
+ if rcv.closed: return
+ self.grant(rcv)
+
+ def send(self, snd, msg):
+ sst = self._attachments[snd.session]
+ _snd = self._attachments[snd]
+
+ if msg.subject is None or _snd._exchange == "":
+ rk = _snd._routing_key
+ else:
+ rk = msg.subject
+
+ if msg.subject is None:
+ subject = _snd.subject
+ else:
+ subject = msg.subject
+
+ # XXX: do we need to query to figure out how to create the reply-to interoperably?
+ if msg.reply_to:
+ rt = addr2reply_to(msg.reply_to)
+ else:
+ rt = None
+ content_encoding = msg.properties.get("x-amqp-0-10.content-encoding")
+ dp = DeliveryProperties(routing_key=rk)
+ mp = MessageProperties(message_id=msg.id,
+ user_id=msg.user_id,
+ reply_to=rt,
+ correlation_id=msg.correlation_id,
+ app_id = msg.properties.get("x-amqp-0-10.app-id"),
+ content_type=msg.content_type,
+ content_encoding=content_encoding,
+ application_headers=msg.properties)
+ if subject is not None:
+ if mp.application_headers is None:
+ mp.application_headers = {}
+ mp.application_headers[SUBJECT] = subject
+ if msg.durable is not None:
+ if msg.durable:
+ dp.delivery_mode = delivery_mode.persistent
+ else:
+ dp.delivery_mode = delivery_mode.non_persistent
+ if msg.priority is not None:
+ dp.priority = msg.priority
+ if msg.ttl is not None:
+ dp.ttl = long(msg.ttl*1000)
+ enc, dec = get_codec(msg.content_type)
+ body = enc(msg.content)
+
+ # XXX: this is not safe for out of order, can this be triggered by pre_ack?
+ def msg_acked():
+ # XXX: should we log the ack somehow too?
+ snd.acked += 1
+ m = snd.session.outgoing.pop(0)
+ sst.outgoing_idx -= 1
+ log.debug("RACK[%s]: %s", sst.session.log_id, msg)
+ assert msg == m
+
+ xfr = MessageTransfer(destination=_snd._exchange, headers=(dp, mp),
+ payload=body)
+
+ if _snd.pre_ack:
+ sst.write_cmd(xfr)
+ else:
+ sst.write_cmd(xfr, msg_acked, sync=msg._sync)
+
+ log.debug("SENT[%s]: %s", sst.session.log_id, msg)
+
+ if _snd.pre_ack:
+ msg_acked()
+
+ def do_message_transfer(self, xfr):
+ sst = self.get_sst(xfr)
+ ssn = sst.session
+
+ msg = self._decode(xfr)
+ rcv = sst.destinations[xfr.destination].target
+ msg._receiver = rcv
+ if rcv.impending is not UNLIMITED:
+ assert rcv.received < rcv.impending, "%s, %s" % (rcv.received, rcv.impending)
+ rcv.received += 1
+ log.debug("RCVD[%s]: %s", ssn.log_id, msg)
+ ssn.incoming.append(msg)
+
+ def _decode(self, xfr):
+ dp = EMPTY_DP
+ mp = EMPTY_MP
+
+ for h in xfr.headers:
+ if isinstance(h, DeliveryProperties):
+ dp = h
+ elif isinstance(h, MessageProperties):
+ mp = h
+
+ ap = mp.application_headers
+ enc, dec = get_codec(mp.content_type)
+ content = dec(xfr.payload)
+ msg = Message(content)
+ msg.id = mp.message_id
+ if ap is not None:
+ msg.subject = ap.get(SUBJECT)
+ msg.user_id = mp.user_id
+ if mp.reply_to is not None:
+ msg.reply_to = reply_to2addr(mp.reply_to)
+ msg.correlation_id = mp.correlation_id
+ if dp.delivery_mode is not None:
+ msg.durable = dp.delivery_mode == delivery_mode.persistent
+ msg.priority = dp.priority
+ if dp.ttl is not None:
+ msg.ttl = dp.ttl/1000.0
+ msg.redelivered = dp.redelivered
+ msg.properties = mp.application_headers or {}
+ if mp.app_id is not None:
+ msg.properties["x-amqp-0-10.app-id"] = mp.app_id
+ if mp.content_encoding is not None:
+ msg.properties["x-amqp-0-10.content-encoding"] = mp.content_encoding
+ if dp.routing_key is not None:
+ msg.properties["x-amqp-0-10.routing-key"] = dp.routing_key
+ msg.content_type = mp.content_type
+ msg._transfer_id = xfr.id
+ return msg