#!/usr/bin/env python
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#

"""
Utility code to translate between Python objects and AMQP-encoded data
fields.

The unit tests for this module are located in tests/codec.py.
"""

import re, qpid, spec
from cStringIO import StringIO
from struct import *
from reference import ReferenceId


class EOF(Exception):
    """Raised by Codec.read() when a positive-length read returns no data."""


# Maps field-table type names derived from the spec's constants (with the
# "field_table_" prefix stripped) to the codec type names whose
# encode_<name>/decode_<name> methods handle them.
TYPE_ALIASES = {
    "long_string": "longstr",
    "unsigned_int": "long"
}


class Codec:
    """
    Encodes and decodes AMQP primitive types to and from a byte stream.

    Consecutive bit fields are buffered and packed into octets: pending
    outgoing bits are written before any other data, and pending incoming
    bits are discarded whenever other data is read.
    """

    def __init__(self, stream, spec):
        """
        Create a codec over *stream* for the protocol definition *spec*.

        stream -- file-like object; read() is used when decoding, write()
                  and flush() when encoding.
        spec   -- loaded protocol spec; its constants whose klass is
                  "field-table-type" define the field-table type codes.
        """
        self.stream = stream
        self.spec = spec
        # Byte counters; decode_table() uses nread to detect the end of a
        # table's payload.
        self.nwrote = 0
        self.nread = 0
        # Pending bit fields (see encode_bit/decode_bit).
        self.incoming_bits = []
        self.outgoing_bits = []
        # Field-table type code <-> codec type name.
        self.types = {}
        self.codes = {}
        # Python class -> codec type name used for field-table values.
        # NOTE(review): this class defines no encode_void/encode_sequence,
        # so None/list/tuple values cannot be encoded here unless provided
        # elsewhere -- confirm.
        self.encodings = {
            basestring: "longstr",
            int: "long",
            long: "long",
            None.__class__: "void",
            list: "sequence",
            tuple: "sequence",
            dict: "table",
        }

        for constant in self.spec.constants:
            if constant.klass == "field-table-type":
                type = constant.name.replace("field_table_", "")
                self.typecode(constant.id, TYPE_ALIASES.get(type, type))

        # Fall back to string and integer table types when the spec
        # declares no field-table types at all.
        if not self.types:
            self.typecode(ord('S'), "longstr")
            self.typecode(ord('I'), "long")

    def typecode(self, code, type):
        """Register the two-way mapping between type *code* and *type*."""
        self.types[code] = type
        self.codes[type] = code

    def resolve(self, klass):
        """
        Return the codec type name registered for *klass*.

        If *klass* is not registered, its base classes are searched
        depth-first. Returns None when no match is found.
        """
        if klass in self.encodings:
            return self.encodings[klass]
        for base in klass.__bases__:
            result = self.resolve(base)
            if result is not None:
                return result
        return None

    def read(self, n):
        """
        Read up to *n* bytes from the stream and return them.

        Any pending incoming bits are discarded first. Raises EOF if *n*
        is positive and the stream returns no data.
        """
        self.clearbits()
        data = self.stream.read(n)
        if n > 0 and len(data) == 0:
            raise EOF()
        # NOTE(review): a short but non-empty read is returned as-is;
        # unpack() will then fail inside struct.unpack.
        self.nread += len(data)
        return data

    def write(self, s):
        """Write *s* to the stream, after flushing any pending bits."""
        self.flushbits()
        self.stream.write(s)
        self.nwrote += len(s)

    def flush(self):
        """Write any pending bits, then flush the underlying stream."""
        self.flushbits()
        self.stream.flush()

    def flushbits(self):
        """
        Pack pending outgoing bits into octets and write them.

        Bits are packed least-significant first, eight per octet; the last
        octet is zero-padded.
        """
        if self.outgoing_bits:
            octets = []
            for index, bit in enumerate(self.outgoing_bits):
                position = index % 8
                if position == 0:
                    octets.append(0)
                if bit:
                    octets[-1] |= 1 << position
            # Clear before writing: encode_octet() goes through write(),
            # which calls flushbits() again.
            del self.outgoing_bits[:]
            for octet in octets:
                self.encode_octet(octet)

    def clearbits(self):
        """Discard any unread bits left over from the last bit octet."""
        if self.incoming_bits:
            self.incoming_bits = []

    def pack(self, fmt, *args):
        """Pack *args* with struct format *fmt* and write the result."""
        self.write(pack(fmt, *args))

    def unpack(self, fmt):
        """
        Read calcsize(fmt) bytes and unpack them with struct format *fmt*.

        Returns the single value for one-item formats, else the tuple.
        """
        size = calcsize(fmt)
        data = self.read(size)
        values = unpack(fmt, data)
        if len(values) == 1:
            return values[0]
        return values

    def encode(self, type, value):
        """
        Encode *value* as *type*.

        *type* is either a spec.Struct or a codec type name, which is
        dispatched to the matching encode_<type> method.
        """
        if isinstance(type, spec.Struct):
            self.encode_struct(type, value)
        else:
            getattr(self, "encode_" + type)(value)

    def decode(self, type):
        """
        Decode and return a value of *type*.

        *type* is either a spec.Struct or a codec type name, which is
        dispatched to the matching decode_<type> method.
        """
        if isinstance(type, spec.Struct):
            return self.decode_struct(type)
        return getattr(self, "decode_" + type)()

    def encode_bit(self, o):
        """Queue the truth value of *o* as one bit; see flushbits()."""
        self.outgoing_bits.append(bool(o))

    def decode_bit(self):
        """
        Decode one bit, reading a fresh octet when none are pending.

        Bits are returned least-significant first.
        """
        if not self.incoming_bits:
            octet = self.decode_octet()
            self.incoming_bits = [bool((octet >> i) & 1) for i in range(8)]
        return self.incoming_bits.pop(0)

    def encode_octet(self, o):
        """Encode an unsigned 8-bit value; raises ValueError outside [0,255]."""
        if o < 0 or o > 255:
            raise ValueError('Valid range of octet is [0,255]')
        self.pack("!B", o)

    def decode_octet(self):
        """Decode an unsigned 8-bit value."""
        return self.unpack("!B")

    def encode_short(self, o):
        """
        Encode an unsigned 16-bit value in network byte order.

        Raises ValueError outside [0,65535].
        """
        if o < 0 or o > 65535:
            raise ValueError('Valid range of short int is [0,65535]: %s' % o)
        self.pack("!H", o)

    def decode_short(self):
        """Decode an unsigned 16-bit value in network byte order."""
        return self.unpack("!H")

    def encode_long(self, o):
        """
        Encode an unsigned 32-bit value in network byte order.

        Raises ValueError outside [0,4294967295].
        """
        # Both bounds are checked explicitly because on 64-bit platforms
        # struct.pack may not reject values that are too large.
        if o < 0 or o > 4294967295:
            raise ValueError('Valid range of long int is [0,4294967295]')
        self.pack("!L", o)

    def decode_long(self):
        """Decode an unsigned 32-bit value in network byte order."""
        return self.unpack("!L")

    def encode_signed_long(self, o):
        """Encode a signed 64-bit value ("!q") in network byte order."""
        self.pack("!q", o)

    def decode_signed_long(self):
        """Decode a signed 64-bit value ("!q") in network byte order."""
        return self.unpack("!q")

    def encode_longlong(self, o):
        """Encode an unsigned 64-bit value in network byte order."""
        self.pack("!Q", o)

    def decode_longlong(self):
        """Decode an unsigned 64-bit value in network byte order."""
        return self.unpack("!Q")

    def enc_str(self, fmt, s):
        """Write len(s) packed with struct format *fmt*, followed by *s*."""
        self.pack(fmt, len(s))
        self.write(s)

    def dec_str(self, fmt):
        """Read a length packed with struct format *fmt*, then that many bytes."""
        size = self.unpack(fmt)
        return self.read(size)

    def encode_shortstr(self, s):
        """
        Encode *s* with a 1-octet length prefix.

        Raises ValueError if *s* is longer than 255.
        """
        if len(s) > 255:
            raise ValueError('Short strings are limited to 255 octets')
        self.enc_str("!B", s)

    def decode_shortstr(self):
        """Decode a string with a 1-octet length prefix."""
        return self.dec_str("!B")

    def encode_longstr(self, s):
        """
        Encode *s* with a 4-octet length prefix.

        A dict is encoded as a field table via encode_table().
        """
        if isinstance(s, dict):
            self.encode_table(s)
        else:
            self.enc_str("!L", s)

    def decode_longstr(self):
        """Decode a string with a 4-octet length prefix."""
        return self.dec_str("!L")

    def encode_table(self, tbl):
        """
        Encode the dict *tbl* as a field table.

        Wire layout: a 4-octet byte length, then for each entry a shortstr
        key, a type-code octet and the encoded value. Each value's codec
        type comes from resolve() on its class. A falsy *tbl* (None or {})
        encodes as an empty table.

        Raises ValueError for a key longer than 128 under spec 0-8, or for
        a value whose class has no registered encoding.
        """
        enc = StringIO()
        codec = Codec(enc, self.spec)
        if tbl:
            for key, value in tbl.items():
                if self.spec.major == 8 and self.spec.minor == 0 and len(key) > 128:
                    raise ValueError("field table key too long: '%s'" % key)
                type = self.resolve(value.__class__)
                if type is None:
                    # Format with %r: concatenating a str with a class
                    # object would raise TypeError instead of ValueError.
                    raise ValueError("no encoding for: %r" % (value.__class__,))
                codec.encode_shortstr(key)
                # NOTE(review): KeyError here if the spec registered no type
                # code for this codec type (e.g. "void" or "sequence").
                codec.encode_octet(self.codes[type])
                codec.encode(type, value)
        s = enc.getvalue()
        self.encode_long(len(s))
        self.write(s)

    def _decode_typed_value(self, code):
        """
        Decode one value whose field-table type code is *code*.

        Codes registered in self.types are decoded with the matching
        decode_ method. Any other code is returned as raw bytes:
        - fixed-width codes: width(code) bytes;
        - variable-width codes: a width(code)-octet length, then that many
          bytes.
        """
        if code in self.types:
            return self.decode(self.types[code])
        w = width(code)
        if fixed(code):
            return self.read(w)
        return self.read(self.dec_num(w))

    def decode_table(self):
        """
        Decode a field table written by encode_table() and return a dict.

        Values with unregistered type codes are returned as raw bytes.
        """
        size = self.decode_long()
        start = self.nread
        result = {}
        # nread advances as entries are decoded; stop once the declared
        # payload size has been consumed.
        while self.nread - start < size:
            key = self.decode_shortstr()
            code = self.decode_octet()
            result[key] = self._decode_typed_value(code)
        return result

    def encode_timestamp(self, t):
        """Encode a timestamp as an unsigned 64-bit value."""
        self.encode_longlong(t)

    def decode_timestamp(self):
        """Decode a timestamp encoded as an unsigned 64-bit value."""
        return self.decode_longlong()

    def encode_content(self, s):
        """
        Encode a content field.

        The field is an octet flag followed by a longstr:
        - flag 1: *s* is a ReferenceId and the longstr holds s.id;
        - flag 0: *s* is inline data.
        """
        if isinstance(s, ReferenceId):
            self.encode_octet(1)
            self.encode_longstr(s.id)
        else:
            self.encode_octet(0)
            self.encode_longstr(s)

    def decode_content(self):
        """
        Decode a content field written by encode_content().

        Returns a string for inline data (flag 0), otherwise a ReferenceId.
        """
        type = self.decode_octet()
        if type == 0:
            return self.decode_longstr()
        return ReferenceId(self.decode_longstr())

    # new domains for 0-10:

    def encode_rfc1982_long(self, s):
        """Encode a serial number as an unsigned 32-bit value."""
        self.encode_long(s)

    def decode_rfc1982_long(self):
        """Decode a serial number encoded as an unsigned 32-bit value."""
        return self.decode_long()

    def encode_rfc1982_long_set(self, s):
        """
        Encode the sequence *s* of serial numbers.

        Layout: a 2-octet byte count (4 * len(s)), then each element as an
        unsigned 32-bit value.
        """
        self.encode_short(len(s) * 4)
        for i in s:
            self.encode_long(i)

    def decode_rfc1982_long_set(self):
        """Decode a list written by encode_rfc1982_long_set()."""
        count = self.decode_short() // 4
        return [self.decode_long() for _ in range(count)]

    def encode_uuid(self, s):
        """
        Encode *s* as exactly 16 raw bytes.

        struct's "16s" format NUL-pads shorter input and truncates longer
        input.
        """
        self.pack("16s", s)

    def decode_uuid(self):
        """Decode 16 raw bytes."""
        return self.unpack("16s")

    def enc_num(self, width, n):
        """
        Encode *n* as an unsigned integer *width* octets wide.

        *width* must be 1, 2 or 4, matching dec_num(); any other width
        raises ValueError.
        """
        if width == 1:
            self.encode_octet(n)
        elif width == 2:
            self.encode_short(n)
        elif width == 4:
            # encode_long() writes 4 octets, which is the width dec_num()
            # expects.
            self.encode_long(n)
        else:
            raise ValueError("invalid width: %s" % width)

    def dec_num(self, width):
        """
        Decode an unsigned integer *width* octets wide.

        *width* must be 1, 2 or 4; any other width raises ValueError.
        """
        if width == 1:
            return self.decode_octet()
        elif width == 2:
            return self.decode_short()
        elif width == 4:
            return self.decode_long()
        else:
            raise ValueError("invalid width: %s" % width)

    def encode_struct(self, type, s):
        """
        Encode struct *s* of struct type *type*.

        A size-prefixed form (body preceded by a type.size-octet length)
        exists but is disabled by the "if False" guard, so only the body is
        written.
        """
        if False and type.size:
            enc = StringIO()
            codec = Codec(enc, self.spec)
            codec.encode_struct_body(type, s)
            codec.flush()
            body = enc.getvalue()
            self.enc_num(type.size, len(body))
            self.write(body)
        else:
            self.encode_struct_body(type, s)

    def decode_struct(self, type):
        """
        Decode a struct of struct type *type*.

        The size-prefixed branch is disabled by the "if False" guard,
        mirroring encode_struct().
        """
        if False and type.size:
            size = self.dec_num(type.size)
            if size == 0:
                return None
        return self.decode_struct_body(type)

    def encode_struct_body(self, type, s):
        """
        Encode the fields of struct *s* according to struct type *type*.

        If *s* is None, every field is encoded as absent.

        Layout:
        - one flag bit per field, in declaration order: the value itself
          for "bit" fields, a presence flag for all others;
        - zero bits padding the flags to 8 * type.pack bits;
        - the values of the present non-bit fields.

        Ends with flush(), which also flushes the underlying stream.
        """
        reserved = 8 * type.pack - len(type.fields)
        assert reserved >= 0
        for f in type.fields:
            if s is None:
                self.encode_bit(False)
            elif f.type == "bit":
                self.encode_bit(s.get(f.name))
            else:
                self.encode_bit(s.has(f.name))
        for _ in range(reserved):
            self.encode_bit(False)
        # The first value written flushes the queued flag bits ahead of it.
        for f in type.fields:
            if f.type != "bit" and s is not None and s.has(f.name):
                self.encode(f.type, s.get(f.name))
        self.flush()

    def decode_struct_body(self, type):
        """
        Decode a struct body written by encode_struct_body().

        Returns a qpid.Struct. Raises ValueError if any reserved padding
        bit is set.
        """
        reserved = 8 * type.pack - len(type.fields)
        assert reserved >= 0
        s = qpid.Struct(type)
        # First pass, flag bits:
        # - "bit" fields carry their value directly;
        # - for other fields, a set flag records presence with a None
        #   placeholder (presumably s.has() then reports the field --
        #   confirm in qpid.Struct).
        for f in type.fields:
            if f.type == "bit":
                s.set(f.name, self.decode_bit())
            elif self.decode_bit():
                s.set(f.name, None)
        for _ in range(reserved):
            if self.decode_bit():
                raise ValueError("expecting reserved flag")
        # Second pass: values of the present non-bit fields, in order.
        for f in type.fields:
            if f.type != "bit" and s.has(f.name):
                s.set(f.name, self.decode(f.type))
        # Drop any bits left in the final flag octet.
        self.clearbits()
        return s

    def encode_long_struct(self, s):
        """
        Encode struct *s* as a longstr.

        The longstr payload is the 2-octet struct code (s.type.type)
        followed by the struct body.
        """
        enc = StringIO()
        codec = Codec(enc, self.spec)
        type = s.type
        codec.encode_short(type.type)
        codec.encode_struct_body(type, s)
        self.encode_longstr(enc.getvalue())

    def decode_long_struct(self):
        """
        Decode a struct written by encode_long_struct().

        The struct type is looked up in self.spec.structs by its 2-octet
        code.
        """
        codec = Codec(StringIO(self.decode_longstr()), self.spec)
        type = self.spec.structs[codec.decode_short()]
        return codec.decode_struct_body(type)

    def decode_array(self):
        """
        Decode an array and return its elements as a list.

        Layout:
        - a 4-octet byte size;
        - one type-code octet shared by all elements;
        - a 4-octet element count;
        - the elements.
        """
        self.decode_long()  # byte size: read to advance the stream, unused
        code = self.decode_octet()
        count = self.decode_long()
        return [self._decode_typed_value(code) for _ in range(count)]


def fixed(code):
    """
    Return True if type *code* denotes a fixed-width value.

    Only codes whose top two bits are 0b10 (0x80-0xbf) are variable width.
    """
    return (code >> 6) != 2


def width(code):
    """
    Return the width in octets associated with type *code*.

    For fixed-width codes this is the size of the value itself. For
    variable-width codes (0x80-0xbf) it is the size of the length prefix
    that precedes the value. Raises ValueError for codes with no defined
    width.
    """
    if code >= 192:
        # 0xc0-0xff: decimal and special types, selected by bits 4-5.
        decsel = (code >> 4) & 3
        if decsel == 0:
            return 5
        elif decsel == 1:
            return 9
        elif decsel == 3:
            return 0
        raise ValueError(code)
    elif code >= 128:
        # 0x80-0xbf: bits 4-5 select a 1-, 2- or 4-octet length prefix.
        lenlen = (code >> 4) & 3
        if lenlen == 3:
            raise ValueError(code)
        return 2 ** lenlen
    else:
        # 0x00-0x7f: bits 4-6 give log2 of the width, so 0x00-0x0f is
        # 1 octet, 0x10-0x1f 2 octets, ..., 0x70-0x7f 128 octets.
        return 2 ** ((code >> 4) & 7)