diff options
author | Tony Garnock-Jones <tonyg@kcbbs.gen.nz> | 2009-04-25 19:41:56 +0100 |
---|---|---|
committer | Tony Garnock-Jones <tonyg@kcbbs.gen.nz> | 2009-04-25 19:41:56 +0100 |
commit | 67970c9c56ebd49b57e61d50255b04fa1ac7d27d (patch) | |
tree | ec02d63a0a7b8f31b7a0c0426eeaab25f5884e1f /librabbitmq/codegen.py | |
parent | 8bf174bc0a3d682ff9c8c11435008f1adf3c288f (diff) | |
download | rabbitmq-c-github-ask-67970c9c56ebd49b57e61d50255b04fa1ac7d27d.tar.gz |
Codegen, codec
Diffstat (limited to 'librabbitmq/codegen.py')
-rw-r--r-- | librabbitmq/codegen.py | 409 |
1 files changed, 250 insertions, 159 deletions
diff --git a/librabbitmq/codegen.py b/librabbitmq/codegen.py index d73db73..00159ae 100644 --- a/librabbitmq/codegen.py +++ b/librabbitmq/codegen.py @@ -47,23 +47,6 @@ cTypeMap = { 'timestamp': 'uint64_t', } -def convertTable(d): - if len(d) == 0: - return "[]" - else: raise 'Non-empty table defaults not supported', d - -def convertBytes(x): - return "(amqp_bytes_t) { .length = %d, .bytes = \"%s\" }" % (len(x), x) - -defaultValueTypeConvMap = { - bool : lambda x: x and "1" or "0", - str : convertBytes, - int : lambda x: str(x), - float : lambda x: str(x), - dict: convertTable, - unicode: lambda x: convertBytes(x.encode("utf-8")) -} - def c_ize(s): s = s.replace('-', '_') s = s.replace(' ', '_') @@ -72,34 +55,18 @@ def c_ize(s): AmqpMethod.defName = lambda m: cConstantName(c_ize(m.klass.name) + '_' + c_ize(m.name) + "_method") AmqpMethod.structName = lambda m: "amqp_" + c_ize(m.klass.name) + '_' + c_ize(m.name) + "_t" +AmqpClass.structName = lambda c: "amqp_" + c_ize(c.name) + "_properties_t" + def cConstantName(s): return 'AMQP_' + '_'.join(re.split('[- ]', s.upper())) -class PackedMethodBitField: - def __init__(self, index): - self.index = index - self.domain = 'bit' - self.contents = [] - - def extend(self, f): - self.contents.append(f) +def cFlagName(c, f): + return cConstantName(c.name + '_' + f.name) + '_FLAG' - def count(self): - return len(self.contents) - - def full(self): - return self.count() == 8 - def genErl(spec): def cType(domain): return cTypeMap[spec.resolveDomain(domain)] - def fieldTypeList(fields): - return '[' + ', '.join([cType(f.domain) for f in fields]) + ']' - - def fieldNameList(fields): - return '[' + ', '.join([c_ize(f.name) for f in fields]) + ']' - def fieldTempList(fields): return '[' + ', '.join(['F' + str(f.index) for f in fields]) + ']' @@ -109,171 +76,297 @@ def genErl(spec): def genLookupMethodName(m): print ' case %s: return "%s";' % (m.defName(), m.defName()) - def packMethodFields(fields): - packed = [] - bitfield = None - for f in fields: - if cType(f.domain) == 'bit': - if not(bitfield) or bitfield.full(): - bitfield = PackedMethodBitField(f.index) - packed.append(bitfield) - bitfield.extend(f) - else: - bitfield = None - packed.append(f) - return packed + def genSingleDecode(prefix, cLvalue, unresolved_domain): + type = spec.resolveDomain(unresolved_domain) + if type == 'shortstr': + print prefix + "%s.len = D_8(encoded, offset);" % (cLvalue,) + print prefix + "offset++;" + print prefix + "%s.bytes = D_BYTES(encoded, offset, %s.len);" % (cLvalue, cLvalue) + print prefix + "offset += %s.len;" % (cLvalue,) + elif type == 'longstr': + print prefix + "%s.len = D_32(encoded, offset);" % (cLvalue,) + print prefix + "offset += 4;" + print prefix + "%s.bytes = D_BYTES(encoded, offset, %s.len);" % (cLvalue, cLvalue) + print prefix + "offset += %s.len;" % (cLvalue,) + elif type == 'octet': + print prefix + "%s = D_8(encoded, offset);" % (cLvalue,) + print prefix + "offset++;" + elif type == 'short': + print prefix + "%s = D_16(encoded, offset);" % (cLvalue,) + print prefix + "offset += 2;" + elif type == 'long': + print prefix + "%s = D_32(encoded, offset);" % (cLvalue,) + print prefix + "offset += 4;" + elif type == 'longlong': + print prefix + "%s = D_64(encoded, offset);" % (cLvalue,) + print prefix + "offset += 8;" + elif type == 'timestamp': + print prefix + "%s = D_64(encoded, offset);" % (cLvalue,) + print prefix + "offset += 8;" + elif type == 'bit': + raise "Can't decode bit in genSingleDecode" + elif type == 'table': + print prefix + "table_result = amqp_decode_table(encoded, pool, &(%s), &offset);" % \ + (cLvalue,) + print prefix + "if (table_result != 0) return table_result;" + else: + raise "Illegal domain in genSingleDecode", type - def methodFieldFragment(f): - type = cType(f.domain) - p = 'F' + str(f.index) + def genSingleEncode(prefix, cValue, unresolved_domain): + type = spec.resolveDomain(unresolved_domain) if type == 'shortstr': - return p+'Len:8/unsigned, '+p+':'+p+'Len/binary' + print prefix + "E_8(encoded, offset, %s.len);" % (cValue,) + print prefix + "offset++;" + print prefix + "E_BYTES(encoded, offset, %s.len, %s.bytes);" % (cValue, cValue) + print prefix + "offset += %s.len;" % (cValue,) elif type == 'longstr': - return p+'Len:32/unsigned, '+p+':'+p+'Len/binary' + print prefix + "E_32(encoded, offset, %s.len);" % (cValue,) + print prefix + "offset += 4;" + print prefix + "E_BYTES(encoded, offset, %s.len, %s.bytes);" % (cValue, cValue) + print prefix + "offset += %s.len;" % (cValue,) elif type == 'octet': - return p+':8/unsigned' - elif type == 'shortint': - return p+':16/unsigned' - elif type == 'longint': - return p+':32/unsigned' - elif type == 'longlongint': - return p+':64/unsigned' + print prefix + "E_8(encoded, offset, %s);" % (cValue,) + print prefix + "offset++;" + elif type == 'short': + print prefix + "E_16(encoded, offset, %s);" % (cValue,) + print prefix + "offset += 2;" + elif type == 'long': + print prefix + "E_32(encoded, offset, %s);" % (cValue,) + print prefix + "offset += 4;" + elif type == 'longlong': + print prefix + "E_64(encoded, offset, %s);" % (cValue,) + print prefix + "offset += 8;" elif type == 'timestamp': - return p+':64/unsigned' + print prefix + "E_64(encoded, offset, %s);" % (cValue,) + print prefix + "offset += 8;" elif type == 'bit': - return p+'Bits:8' + raise "Can't encode bit in genSingleDecode" elif type == 'table': - return p+'Len:32/unsigned, '+p+'Tab:'+p+'Len/binary' + print prefix + "table_result = amqp_encode_table(encoded, &(%s), &offset);" % \ + (cValue,) + print prefix + "if (table_result != 0) return table_result;" else: - return 'UNIMPLEMENTED' - - def genFieldPostprocessing(packed): - for f in packed: - type = cType(f.domain) - if type == 'bit': - for index in range(f.count()): - print " F%d = ((F%dBits band %d) /= 0)," % \ - (f.index + index, - f.index, - 1 << index) - elif type == 'table': - print " F%d = rabbit_binary_parser:parse_table(F%dTab)," % \ - (f.index, f.index) - else: - pass + raise "Illegal domain in genSingleEncode", type def genDecodeMethodFields(m): - packedFields = packMethodFields(m.arguments) - binaryPattern = ', '.join([methodFieldFragment(f) for f in packedFields]) - if binaryPattern: - restSeparator = ', ' - else: - restSeparator = '' - recordConstructorExpr = '#%s{%s}' % (m.structName(), fieldMapList(m.arguments)) - print "decode_method_fields(%s, <<%s>>) ->" % (m.defName(), binaryPattern) - genFieldPostprocessing(packedFields) - print " %s;" % (recordConstructorExpr,) + print " case %s: {" % (m.defName(),) + print " %s *m = (%s *) amqp_pool_alloc(pool, sizeof(%s));" % \ + (m.structName(), m.structName(), m.structName()) + bitindex = None + for f in m.arguments: + if spec.resolveDomain(f.domain) == 'bit': + if bitindex is None: + bitindex = 0 + if bitindex >= 8: + bitindex = 0 + if bitindex == 0: + print " bit_buffer = D_8(encoded, offset);" + print " offset++;" + print " m->%s = (bit_buffer & (1 << %d)) ? 1 : 0;" % \ + (c_ize(f.name), bitindex) + bitindex = bitindex + 1 + else: + bitindex = None + genSingleDecode(" ", "m->%s" % (c_ize(f.name),), f.domain) + print " *decoded = m;" + print " return 0;" + print " }" def genDecodeProperties(c): - print "decode_properties(%d, PropBin) ->" % (c.index) - print " %s = rabbit_binary_parser:parse_properties(%s, PropBin)," % \ - (fieldTempList(c.fields), fieldTypeList(c.fields)) - print " #'P_%s'{%s};" % (c_ize(c.name), fieldMapList(c.fields)) - - def genFieldPreprocessing(packed): - for f in packed: - type = cType(f.domain) - if type == 'bit': - print " F%dBits = (%s)," % \ - (f.index, - ' bor '.join(['(bitvalue(F%d) bsl %d)' % (x.index, x.index - f.index) - for x in f.contents])) - elif type == 'table': - print " F%dTab = rabbit_binary_generator:generate_table(F%d)," % (f.index, f.index) - print " F%dLen = size(F%dTab)," % (f.index, f.index) - elif type in ['shortstr', 'longstr']: - print " F%dLen = size(F%d)," % (f.index, f.index) - else: + print " case %d: {" % (c.index,) + print " %s *p = (%s *) amqp_pool_alloc(pool, sizeof(%s));" % \ + (c.structName(), c.structName(), c.structName()) + print " p->_flags = flags;" + for f in c.fields: + if spec.resolveDomain(f.domain) == 'bit': pass + else: + print " if (flags & %s) {" % (cFlagName(c, f),) + genSingleDecode(" ", "p->%s" % (c_ize(f.name),), f.domain) + print " }" + print " *decoded = p;" + print " return 0;" + print " }" def genEncodeMethodFields(m): - packedFields = packMethodFields(m.arguments) - print "encode_method_fields(#%s{%s}) ->" % (m.structName(), fieldMapList(m.arguments)) - genFieldPreprocessing(packedFields) - print " <<%s>>;" % (', '.join([methodFieldFragment(f) for f in packedFields])) + print " case %s: {" % (m.defName(),) + if m.arguments: + print " %s *m = (%s *) decoded;" % (m.structName(), m.structName()) + bitindex = None + def finishBits(): + if bitindex is not None: + print " E_8(encoded, offset, bit_buffer);" + print " offset++;" + for f in m.arguments: + if spec.resolveDomain(f.domain) == 'bit': + if bitindex is None: + bitindex = 0 + print " bit_buffer = 0;" + if bitindex >= 8: + finishBits() + print " bit_buffer = 0;" + bitindex = 0 + print " if (m->%s) { bit_buffer |= (1 << %d); }" % \ + (c_ize(f.name), bitindex) + bitindex = bitindex + 1 + else: + finishBits() + bitindex = None + genSingleEncode(" ", "m->%s" % (c_ize(f.name),), f.domain) + finishBits() + print " return offset;" + print " }" def genEncodeProperties(c): - print "encode_properties(#'P_%s'{%s}) ->" % (c_ize(c.name), fieldMapList(c.fields)) - print " rabbit_binary_generator:encode_properties(%s, %s);" % \ - (fieldTypeList(c.fields), fieldTempList(c.fields)) + print " case %d: {" % (c.index,) + if c.fields: + print " %s *p = (%s *) decoded;" % (c.structName(), c.structName()) + for f in c.fields: + if spec.resolveDomain(f.domain) == 'bit': + pass + else: + print " if (flags & %s) {" % (cFlagName(c, f),) + genSingleEncode(" ", "p->%s" % (c_ize(f.name),), f.domain) + print " }" + print " return offset;" + print " }" def genLookupException(c,v,cls): # We do this because 0.8 uses "soft error" and 8.1 uses "soft-error". mCls = c_ize(cls).upper() - if mCls == 'SOFT_ERROR': genLookupException1(c,'false') - elif mCls == 'HARD_ERROR': genLookupException1(c, 'true') + if mCls == 'SOFT_ERROR': genLookupException1(c,'AMQP_EXCEPTION_CATEGORY_CHANNEL') + elif mCls == 'HARD_ERROR': genLookupException1(c, 'AMQP_EXCEPTION_CATEGORY_CONNECTION') elif mCls == '': pass else: raise 'Unknown constant class', cls - def genLookupException1(c,hardErrorBoolStr): - n = cConstantName(c) - print 'lookup_amqp_exception(%s) -> {%s, ?%s, <<"%s">>};' % \ - (n.lower(), hardErrorBoolStr, n, n) + def genLookupException1(c, cCategory): + print ' case %s: return %s;' % (cConstantName(c), cCategory) methods = spec.allMethods() print '#include <stdlib.h>' print '#include <string.h>' print '#include <stdio.h>' + print '#include <errno.h>' + print '#include <arpa/inet.h> /* ntohl, htonl, ntohs, htons */' print + print '#include "amqp.h"' print '#include "amqp_framing.h"' + print '#include "amqp_private.h"' print """ -char const *amqp_method_name(uint32_t methodNumber) { +char const *amqp_method_name(amqp_method_number_t methodNumber) { switch (methodNumber) {""" for m in methods: genLookupMethodName(m) print """ default: return NULL; } -} -""" +}""" print """ -amqp_boolean_t amqp_method_has_content(uint32_t methodNumber) { +amqp_boolean_t amqp_method_has_content(amqp_method_number_t methodNumber) { switch (methodNumber) {""" for m in methods: if m.hasContent: print ' case %s: return 1;' % (m.defName()) print """ default: return 0; } -} -""" +}""" + print """ +int amqp_decode_method(amqp_method_number_t methodNumber, + amqp_pool_t *pool, + amqp_bytes_t encoded, + void **decoded) +{ + int offset = 0; + int table_result; + uint8_t bit_buffer; + + switch (methodNumber) {""" for m in methods: genDecodeMethodFields(m) - print "decode_method_fields(Name, BinaryFields) ->" - print " rabbit_misc:frame_error(Name, BinaryFields)." + print """ default: return -ENOENT; + } +}""" + print """ +int amqp_decode_properties(uint16_t class_id, + amqp_pool_t *pool, + amqp_bytes_t encoded, + void **decoded) +{ + int offset = 0; + int table_result; + + amqp_flags_t flags = 0; + int flagword_index = 0; + amqp_flags_t partial_flags; + + do { + partial_flags = D_16(encoded, offset); + offset += 2; + flags |= (partial_flags << (flagword_index * 16)); + } while (partial_flags & 1); + + switch (class_id) {""" for c in spec.allClasses(): genDecodeProperties(c) - print "decode_properties(ClassId, _BinaryFields) -> exit({unknown_class_id, ClassId})." + print """ default: return -ENOENT; + } +}""" + + print """ +int amqp_encode_method(amqp_method_number_t methodNumber, + void *decoded, + amqp_bytes_t encoded) +{ + int offset = 0; + int table_result; + uint8_t bit_buffer; + switch (methodNumber) {""" for m in methods: genEncodeMethodFields(m) - print "encode_method_fields(Record) -> exit({unknown_method_name, element(1, Record)})." + print """ default: return -ENOENT; + } +}""" + + print """ +int amqp_encode_properties(uint16_t class_id, + void *decoded, + amqp_bytes_t encoded) +{ + int offset = 0; + int table_result; + + /* Cheat, and get the flags out generically, relying on the + similarity of structure between classes */ + amqp_flags_t flags = * (amqp_flags_t *) decoded; /* cheating! */ + + while (flags != 0) { + amqp_flags_t remainder = flags >> 16; + uint16_t partial_flags = flags & 0xFFFE; + if (remainder != 0) { partial_flags |= 1; } + E_16(encoded, offset, partial_flags); + offset += 2; + flags = remainder; + } + switch (class_id) {""" for c in spec.allClasses(): genEncodeProperties(c) - print "encode_properties(Record) -> exit({unknown_properties_record, Record})." + print """ default: return -ENOENT; + } +}""" + print """ +int amqp_exception_category(uint16_t code) { + switch (code) {""" for (c,v,cls) in spec.constants: genLookupException(c,v,cls) - print "lookup_amqp_exception(Code) ->" - print " rabbit_log:warning(\"Unknown AMQP error code '~p'~n\", [Code])," - print " {true, ?INTERNAL_ERROR, <<\"INTERNAL_ERROR\">>}." - + print """ default: return 0; + } +}""" def genHrl(spec): def cType(domain): return cTypeMap[spec.resolveDomain(domain)] - def fieldNameList(fields): - return ', '.join([c_ize(f.name) for f in fields]) - def fieldDeclList(fields): return ''.join([" %s %s;\n" % (cType(f.domain), c_ize(f.name)) for f in fields]) @@ -282,20 +375,15 @@ def genHrl(spec): for f in fields if spec.resolveDomain(f.domain) != 'bit']) - def fieldNameListDefaults(fields): - def fillField(field): - result = c_ize(f.name) - if field.defaultvalue != None: - conv_fn = defaultValueTypeConvMap[type(field.defaultvalue)] - result += ' = ' + conv_fn(field.defaultvalue) - return result - return ', '.join([fillField(f) for f in fields]) - methods = spec.allMethods() - print "#ifndef LIBRABBITMQ_AMQP_FRAMING_H" - print "#define LIBRABBITMQ_AMQP_FRAMING_H" - print + print """#ifndef librabbitmq_amqp_framing_h +#define librabbitmq_amqp_framing_h + +#ifdef __cplusplus +extern "C" { +#endif +""" print "#define AMQP_PROTOCOL_VERSION_MAJOR %d" % (spec.major) print "#define AMQP_PROTOCOL_VERSION_MINOR %d" % (spec.minor) print "#define AMQP_PROTOCOL_PORT %d" % (spec.port) @@ -307,7 +395,7 @@ def genHrl(spec): print "/* Method field records. */" for m in methods: methodid = m.klass.index << 16 | m.index - print "#define %s ((uint32_t) 0x%.08X) /* %d, %d; %d */" % \ + print "#define %s ((amqp_method_number_t) 0x%.08X) /* %d, %d; %d */" % \ (m.defName(), methodid, m.klass.index, @@ -324,13 +412,16 @@ def genHrl(spec): shortnum = index / 16 partialindex = 15 - (index % 16) bitindex = shortnum * 16 + partialindex - print '#define %s_FLAG (1 << %d)' % (cConstantName(c.name + '_' + f.name), bitindex) + print '#define %s (1 << %d)' % (cFlagName(c, f), bitindex) index = index + 1 - print "typedef struct {\n uint32_t _flags;\n%s} %s;\n" % \ - (fieldDeclList(c.fields), \ - 'amqp_%s_properties_t' % (c_ize(c.name),)) + print "typedef struct {\n amqp_flags_t _flags;\n%s} %s;\n" % \ + (fieldDeclList(c.fields), c.structName()) + + print """#ifdef __cplusplus +} +#endif - print "#endif" +#endif""" def generateErl(specPath): genErl(AmqpSpec(specPath)) |