diff options
Diffstat (limited to 'python/qpid/connection.py')
-rw-r--r-- | python/qpid/connection.py | 566 |
1 files changed, 132 insertions, 434 deletions
diff --git a/python/qpid/connection.py b/python/qpid/connection.py index eafad7067a..dc72cd9cb8 100644 --- a/python/qpid/connection.py +++ b/python/qpid/connection.py @@ -17,467 +17,165 @@ # under the License. # -""" -A Connection class containing socket code that uses the spec metadata -to read and write Frame objects. This could be used by a client, -server, or even a proxy implementation. -""" +import datatypes, session +from threading import Thread, Condition, RLock +from util import wait +from framer import Closed +from assembler import Assembler, Segment +from codec010 import StringCodec +from session import Session +from invoker import Invoker +from spec010 import Control, Command +from exceptions import * +from logging import getLogger +import delegates -import socket, codec, logging, qpid -from cStringIO import StringIO -from spec import load -from codec import EOF +class ChannelBusy(Exception): pass -class SockIO: +class ChannelsBusy(Exception): pass - def __init__(self, sock): - self.sock = sock +class SessionBusy(Exception): pass - def write(self, buf): -# print "OUT: %r" % buf - self.sock.sendall(buf) +def client(*args): + return delegates.Client(*args) - def read(self, n): - data = "" - while len(data) < n: - try: - s = self.sock.recv(n - len(data)) - except socket.error: - break - if len(s) == 0: - break -# print "IN: %r" % s - data += s - return data - - def flush(self): - pass - - def close(self): - self.sock.shutdown(socket.SHUT_RDWR) +def server(*args): + return delegates.Server(*args) -def connect(host, port): - sock = socket.socket() - sock.connect((host, port)) - sock.setblocking(1) - return SockIO(sock) +class Connection(Assembler): -def listen(host, port, predicate = lambda: True): - sock = socket.socket() - sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - sock.bind((host, port)) - sock.listen(5) - while predicate(): - s, a = sock.accept() - yield SockIO(s) - -class Connection: - - def __init__(self, io, spec): - self.codec = codec.Codec(io, spec) + def __init__(self, sock, spec, delegate=client): + Assembler.__init__(self, sock) self.spec = spec - self.FRAME_END = self.spec.constants.byname["frame_end"].id - self.write = getattr(self, "write_%s_%s" % (self.spec.major, self.spec.minor)) - self.read = getattr(self, "read_%s_%s" % (self.spec.major, self.spec.minor)) - - def flush(self): - self.codec.flush() - - INIT="!4s4B" + self.track = self.spec["track"] - def init(self): - self.codec.pack(Connection.INIT, "AMQP", 1, 1, self.spec.major, - self.spec.minor) + self.lock = RLock() + self.attached = {} + self.sessions = {} - def tini(self): - self.codec.unpack(Connection.INIT) + self.condition = Condition() + self.opened = False - def write_8_0(self, frame): - c = self.codec - c.encode_octet(self.spec.constants.byname[frame.type].id) - c.encode_short(frame.channel) - body = StringIO() - enc = codec.Codec(body, self.spec) - frame.encode(enc) - enc.flush() - c.encode_longstr(body.getvalue()) - c.encode_octet(self.FRAME_END) + self.thread = Thread(target=self.run) + self.thread.setDaemon(True) - def read_8_0(self): - c = self.codec - type = self.spec.constants.byid[c.decode_octet()].name - channel = c.decode_short() - body = c.decode_longstr() - dec = codec.Codec(StringIO(body), self.spec) - frame = Frame.DECODERS[type].decode(self.spec, dec, len(body)) - frame.channel = channel - end = c.decode_octet() - if end != self.FRAME_END: - garbage = "" - while end != self.FRAME_END: - garbage += chr(end) - end = c.decode_octet() - raise "frame error: expected %r, got %r" % (self.FRAME_END, garbage) - return frame + self.channel_max = 65535 - def write_0_10(self, frame): - c = self.codec - flags = 0 - if frame.bof: flags |= 0x08 - if frame.eof: flags |= 0x04 - if frame.bos: flags |= 0x02 - if frame.eos: flags |= 0x01 + self.delegate = delegate(self) - c.encode_octet(flags) # TODO: currently fixed at ver=0, B=E=b=e=1 - c.encode_octet(self.spec.constants.byname[frame.type].id) - body = StringIO() - enc = codec.Codec(body, self.spec) - frame.encode(enc) - enc.flush() - frame_size = len(body.getvalue()) + 12 # TODO: Magic number (frame header size) - c.encode_short(frame_size) - c.encode_octet(0) # Reserved - c.encode_octet(frame.subchannel & 0x0f) - c.encode_short(frame.channel) - c.encode_long(0) # Reserved - c.write(body.getvalue()) - c.encode_octet(self.FRAME_END) - - def read_0_10(self): - c = self.codec - flags = c.decode_octet() # TODO: currently ignoring flags - framing_version = (flags & 0xc0) >> 6 - if framing_version != 0: - raise "frame error: unknown framing version" - type = self.spec.constants.byid[c.decode_octet()].name - frame_size = c.decode_short() - if frame_size < 12: # TODO: Magic number (frame header size) - raise "frame error: frame size too small" - reserved1 = c.decode_octet() - field = c.decode_octet() - subchannel = field & 0x0f - channel = c.decode_short() - reserved2 = c.decode_long() # TODO: reserved maybe need to ensure 0 - if (flags & 0x30) != 0 or reserved1 != 0 or (field & 0xf0) != 0: - raise "frame error: reserved bits not all zero" - body_size = frame_size - 12 # TODO: Magic number (frame header size) - body = c.read(body_size) - dec = codec.Codec(StringIO(body), self.spec) + def attach(self, name, ch, delegate, force=False): + self.lock.acquire() try: - frame = Frame.DECODERS[type].decode(self.spec, dec, len(body)) - except EOF: - raise "truncated frame body: %r" % body - frame.channel = channel - frame.subchannel = subchannel - end = c.decode_octet() - if end != self.FRAME_END: - garbage = "" - while end != self.FRAME_END: - garbage += chr(end) - end = c.decode_octet() - raise "frame error: expected %r, got %r" % (self.FRAME_END, garbage) - return frame - - def write_99_0(self, frame): - self.write_0_10(frame) - - def read_99_0(self): - return self.read_0_10() - -class Frame: - - DECODERS = {} - - class __metaclass__(type): - - def __new__(cls, name, bases, dict): - for attr in ("encode", "decode", "type"): - if not dict.has_key(attr): - raise TypeError("%s must define %s" % (name, attr)) - dict["decode"] = staticmethod(dict["decode"]) - if dict.has_key("__init__"): - __init__ = dict["__init__"] - def init(self, *args, **kwargs): - args = list(args) - self.init(args, kwargs) - __init__(self, *args, **kwargs) - dict["__init__"] = init - t = type.__new__(cls, name, bases, dict) - if t.type != None: - Frame.DECODERS[t.type] = t - return t - - type = None - - def init(self, args, kwargs): - self.channel = kwargs.pop("channel", 0) - self.subchannel = kwargs.pop("subchannel", 0) - self.bos = True - self.eos = True - self.bof = True - self.eof = True - - def encode(self, enc): abstract - - def decode(spec, dec, size): abstract - -class Method(Frame): - - type = "frame_method" - - def __init__(self, method, args): - if len(args) != len(method.fields): - argspec = ["%s: %s" % (f.name, f.type) - for f in method.fields] - raise TypeError("%s.%s expecting (%s), got %s" % - (method.klass.name, method.name, ", ".join(argspec), - args)) - self.method = method - self.method_type = method - self.args = args - self.eof = not method.content - - def encode(self, c): - version = (c.spec.major, c.spec.minor) - if version == (0, 10) or version == (99, 0): - c.encode_octet(self.method.klass.id) - c.encode_octet(self.method.id) - else: - c.encode_short(self.method.klass.id) - c.encode_short(self.method.id) - for field, arg in zip(self.method.fields, self.args): - c.encode(field.type, arg) - - def decode(spec, c, size): - version = (c.spec.major, c.spec.minor) - if version == (0, 10) or version == (99, 0): - klass = spec.classes.byid[c.decode_octet()] - meth = klass.methods.byid[c.decode_octet()] - else: - klass = spec.classes.byid[c.decode_short()] - meth = klass.methods.byid[c.decode_short()] - args = tuple([c.decode(f.type) for f in meth.fields]) - return Method(meth, args) - - def __str__(self): - return "[%s] %s %s" % (self.channel, self.method, - ", ".join([str(a) for a in self.args])) - -class Request(Frame): - - type = "frame_request" - - def __init__(self, id, response_mark, method): - self.id = id - self.response_mark = response_mark - self.method = method - self.method_type = method.method_type - self.args = method.args - - def encode(self, enc): - enc.encode_longlong(self.id) - enc.encode_longlong(self.response_mark) - # reserved - enc.encode_long(0) - self.method.encode(enc) - - def decode(spec, dec, size): - id = dec.decode_longlong() - mark = dec.decode_longlong() - # reserved - dec.decode_long() - method = Method.decode(spec, dec, size - 20) - return Request(id, mark, method) - - def __str__(self): - return "[%s] Request(%s) %s" % (self.channel, self.id, self.method) - -class Response(Frame): - - type = "frame_response" - - def __init__(self, id, request_id, batch_offset, method): - self.id = id - self.request_id = request_id - self.batch_offset = batch_offset - self.method = method - self.method_type = method.method_type - self.args = method.args - - def encode(self, enc): - enc.encode_longlong(self.id) - enc.encode_longlong(self.request_id) - enc.encode_long(self.batch_offset) - self.method.encode(enc) - - def decode(spec, dec, size): - id = dec.decode_longlong() - request_id = dec.decode_longlong() - batch_offset = dec.decode_long() - method = Method.decode(spec, dec, size - 20) - return Response(id, request_id, batch_offset, method) - - def __str__(self): - return "[%s] Response(%s,%s,%s) %s" % (self.channel, self.id, self.request_id, self.batch_offset, self.method) - -def uses_struct_encoding(spec): - return (spec.major == 0 and spec.minor == 10) or (spec.major == 99 and spec.minor == 0) - -class Header(Frame): - - type = "frame_header" - - def __init__(self, klass, weight, size, properties): - self.klass = klass - self.weight = weight - self.size = size - self.properties = properties - self.eof = size == 0 - self.bof = False - - def __getitem__(self, name): - return self.properties[name] - - def __setitem__(self, name, value): - self.properties[name] = value - - def __delitem__(self, name): - del self.properties[name] - - def encode(self, c): - if uses_struct_encoding(c.spec): - self.encode_structs(c) - else: - self.encode_legacy(c) - - def encode_structs(self, c): - # XXX - structs = [qpid.Struct(c.spec.domains.byname["delivery_properties"].type), - qpid.Struct(c.spec.domains.byname["message_properties"].type)] - - # XXX - props = self.properties.copy() - for k in self.properties: - for s in structs: - if s.exists(k): - s.set(k, props.pop(k)) - if props: - raise TypeError("no such property: %s" % (", ".join(props))) - - # message properties store the content-length now, and weight is - # deprecated - if self.size != None: - structs[1].content_length = self.size - - for s in structs: - c.encode_long_struct(s) - - def encode_legacy(self, c): - c.encode_short(self.klass.id) - c.encode_short(self.weight) - c.encode_longlong(self.size) - - # property flags - nprops = len(self.klass.fields) - flags = 0 - for i in range(nprops): - f = self.klass.fields.items[i] - flags <<= 1 - if self.properties.get(f.name) != None: - flags |= 1 - # the last bit indicates more flags - if i > 0 and (i % 15) == 0: - flags <<= 1 - if nprops > (i + 1): - flags |= 1 - c.encode_short(flags) - flags = 0 - flags <<= ((16 - (nprops % 15)) % 16) - c.encode_short(flags) - - # properties - for f in self.klass.fields: - v = self.properties.get(f.name) - if v != None: - c.encode(f.type, v) - - def decode(spec, c, size): - if uses_struct_encoding(spec): - return Header.decode_structs(spec, c, size) + ssn = self.attached.get(ch.id) + if ssn is not None: + if ssn.name != name: + raise ChannelBusy(ch, ssn) + else: + ssn = self.sessions.get(name) + if ssn is None: + ssn = Session(name, self.spec, delegate=delegate) + self.sessions[name] = ssn + elif ssn.channel is not None: + if force: + del self.attached[ssn.channel.id] + ssn.channel = None + else: + raise SessionBusy(ssn) + self.attached[ch.id] = ssn + ssn.channel = ch + ch.session = ssn + return ssn + finally: + self.lock.release() + + def detach(self, name, ch): + self.lock.acquire() + try: + self.attached.pop(ch.id, None) + ssn = self.sessions.pop(name, None) + if ssn is not None: + ssn.channel = None + return ssn + finally: + self.lock.release() + + def __channel(self): + # XXX: ch 0? + for i in xrange(self.channel_max): + if not self.attached.has_key(i): + return i else: - return Header.decode_legacy(spec, c, size) + raise ChannelsBusy() - @staticmethod - def decode_structs(spec, c, size): - structs = [] - start = c.nread - while c.nread - start < size: - structs.append(c.decode_long_struct()) - - # XXX - props = {} - length = None - for s in structs: - for f in s.type.fields: - if s.has(f.name): - props[f.name] = s.get(f.name) - if f.name == "content_length": - length = s.get(f.name) - return Header(None, 0, length, props) - - @staticmethod - def decode_legacy(spec, c, size): - klass = spec.classes.byid[c.decode_short()] - weight = c.decode_short() - size = c.decode_longlong() - - # property flags - bits = [] + def session(self, name, timeout=None, delegate=session.client): + self.lock.acquire() + try: + ch = Channel(self, self.__channel()) + ssn = self.attach(name, ch, delegate) + ssn.channel.session_attach(name) + if wait(ssn.condition, lambda: ssn.channel is not None, timeout): + return ssn + else: + self.detach(name, ch) + raise Timeout() + finally: + self.lock.release() + + def start(self, timeout=None): + self.delegate.start() + self.thread.start() + if not wait(self.condition, lambda: self.opened, timeout): + raise Timeout() + + def run(self): + # XXX: we don't really have a good way to exit this loop without + # getting the other end to kill the socket while True: - flags = c.decode_short() - for i in range(15, 0, -1): - if flags >> i & 0x1 != 0: - bits.append(True) - else: - bits.append(False) - if flags & 0x1 == 0: + try: + seg = self.read_segment() + except Closed: break + self.delegate.received(seg) - # properties - properties = {} - for b, f in zip(bits, klass.fields): - if b: - # Note: decode returns a unicode u'' string but only - # plain '' strings can be used as keywords so we need to - # stringify the names. - properties[str(f.name)] = c.decode(f.type) - return Header(klass, weight, size, properties) + def close(self, timeout=None): + if not self.opened: return + Channel(self, 0).connection_close(200) + if not wait(self.condition, lambda: not self.opened, timeout): + raise Timeout() + self.thread.join(timeout=timeout) def __str__(self): - return "%s %s %s %s" % (self.klass, self.weight, self.size, - self.properties) + return "%s:%s" % self.sock.getsockname() + + def __repr__(self): + return str(self) -class Body(Frame): +log = getLogger("qpid.io.ctl") - type = "frame_body" +class Channel(Invoker): - def __init__(self, content): - self.content = content - self.eof = True - self.bof = False + def __init__(self, connection, id): + self.connection = connection + self.id = id + self.session = None - def encode(self, enc): - enc.write(self.content) + def resolve_method(self, name): + inst = self.connection.spec.instructions.get(name) + if inst is not None and isinstance(inst, Control): + return inst + else: + return None - def decode(spec, dec, size): - return Body(dec.read(size)) + def invoke(self, type, args, kwargs): + ctl = type.new(args, kwargs) + sc = StringCodec(self.connection.spec) + sc.write_control(ctl) + self.connection.write_segment(Segment(True, True, type.segment_type, + type.track, self.id, sc.encoded)) + log.debug("SENT %s", ctl) def __str__(self): - return "Body(%r)" % self.content + return "%s[%s]" % (self.connection, self.id) -# TODO: -# OOB_METHOD = "frame_oob_method" -# OOB_HEADER = "frame_oob_header" -# OOB_BODY = "frame_oob_body" -# TRACE = "frame_trace" -# HEARTBEAT = "frame_heartbeat" + def __repr__(self): + return str(self) |