summaryrefslogtreecommitdiff
path: root/librabbitmq/codegen.py
diff options
context:
space:
mode:
authorTony Garnock-Jones <tonyg@kcbbs.gen.nz>2009-04-25 19:41:56 +0100
committerTony Garnock-Jones <tonyg@kcbbs.gen.nz>2009-04-25 19:41:56 +0100
commit67970c9c56ebd49b57e61d50255b04fa1ac7d27d (patch)
treeec02d63a0a7b8f31b7a0c0426eeaab25f5884e1f /librabbitmq/codegen.py
parent8bf174bc0a3d682ff9c8c11435008f1adf3c288f (diff)
downloadrabbitmq-c-github-ask-67970c9c56ebd49b57e61d50255b04fa1ac7d27d.tar.gz
Codegen, codec
Diffstat (limited to 'librabbitmq/codegen.py')
-rw-r--r--librabbitmq/codegen.py409
1 files changed, 250 insertions, 159 deletions
diff --git a/librabbitmq/codegen.py b/librabbitmq/codegen.py
index d73db73..00159ae 100644
--- a/librabbitmq/codegen.py
+++ b/librabbitmq/codegen.py
@@ -47,23 +47,6 @@ cTypeMap = {
'timestamp': 'uint64_t',
}
-def convertTable(d):
- if len(d) == 0:
- return "[]"
- else: raise 'Non-empty table defaults not supported', d
-
-def convertBytes(x):
- return "(amqp_bytes_t) { .length = %d, .bytes = \"%s\" }" % (len(x), x)
-
-defaultValueTypeConvMap = {
- bool : lambda x: x and "1" or "0",
- str : convertBytes,
- int : lambda x: str(x),
- float : lambda x: str(x),
- dict: convertTable,
- unicode: lambda x: convertBytes(x.encode("utf-8"))
-}
-
def c_ize(s):
s = s.replace('-', '_')
s = s.replace(' ', '_')
@@ -72,34 +55,18 @@ def c_ize(s):
AmqpMethod.defName = lambda m: cConstantName(c_ize(m.klass.name) + '_' + c_ize(m.name) + "_method")
AmqpMethod.structName = lambda m: "amqp_" + c_ize(m.klass.name) + '_' + c_ize(m.name) + "_t"
+AmqpClass.structName = lambda c: "amqp_" + c_ize(c.name) + "_properties_t"
+
def cConstantName(s):
return 'AMQP_' + '_'.join(re.split('[- ]', s.upper()))
-class PackedMethodBitField:
- def __init__(self, index):
- self.index = index
- self.domain = 'bit'
- self.contents = []
-
- def extend(self, f):
- self.contents.append(f)
+def cFlagName(c, f):
+ return cConstantName(c.name + '_' + f.name) + '_FLAG'
- def count(self):
- return len(self.contents)
-
- def full(self):
- return self.count() == 8
-
def genErl(spec):
def cType(domain):
return cTypeMap[spec.resolveDomain(domain)]
- def fieldTypeList(fields):
- return '[' + ', '.join([cType(f.domain) for f in fields]) + ']'
-
- def fieldNameList(fields):
- return '[' + ', '.join([c_ize(f.name) for f in fields]) + ']'
-
def fieldTempList(fields):
return '[' + ', '.join(['F' + str(f.index) for f in fields]) + ']'
@@ -109,171 +76,297 @@ def genErl(spec):
def genLookupMethodName(m):
print ' case %s: return "%s";' % (m.defName(), m.defName())
- def packMethodFields(fields):
- packed = []
- bitfield = None
- for f in fields:
- if cType(f.domain) == 'bit':
- if not(bitfield) or bitfield.full():
- bitfield = PackedMethodBitField(f.index)
- packed.append(bitfield)
- bitfield.extend(f)
- else:
- bitfield = None
- packed.append(f)
- return packed
+ def genSingleDecode(prefix, cLvalue, unresolved_domain):
+ type = spec.resolveDomain(unresolved_domain)
+ if type == 'shortstr':
+ print prefix + "%s.len = D_8(encoded, offset);" % (cLvalue,)
+ print prefix + "offset++;"
+ print prefix + "%s.bytes = D_BYTES(encoded, offset, %s.len);" % (cLvalue, cLvalue)
+ print prefix + "offset += %s.len;" % (cLvalue,)
+ elif type == 'longstr':
+ print prefix + "%s.len = D_32(encoded, offset);" % (cLvalue,)
+ print prefix + "offset += 4;"
+ print prefix + "%s.bytes = D_BYTES(encoded, offset, %s.len);" % (cLvalue, cLvalue)
+ print prefix + "offset += %s.len;" % (cLvalue,)
+ elif type == 'octet':
+ print prefix + "%s = D_8(encoded, offset);" % (cLvalue,)
+ print prefix + "offset++;"
+ elif type == 'short':
+ print prefix + "%s = D_16(encoded, offset);" % (cLvalue,)
+ print prefix + "offset += 2;"
+ elif type == 'long':
+ print prefix + "%s = D_32(encoded, offset);" % (cLvalue,)
+ print prefix + "offset += 4;"
+ elif type == 'longlong':
+ print prefix + "%s = D_64(encoded, offset);" % (cLvalue,)
+ print prefix + "offset += 8;"
+ elif type == 'timestamp':
+ print prefix + "%s = D_64(encoded, offset);" % (cLvalue,)
+ print prefix + "offset += 8;"
+ elif type == 'bit':
+ raise "Can't decode bit in genSingleDecode"
+ elif type == 'table':
+ print prefix + "table_result = amqp_decode_table(encoded, pool, &(%s), &offset);" % \
+ (cLvalue,)
+ print prefix + "if (table_result != 0) return table_result;"
+ else:
+ raise "Illegal domain in genSingleDecode", type
- def methodFieldFragment(f):
- type = cType(f.domain)
- p = 'F' + str(f.index)
+ def genSingleEncode(prefix, cValue, unresolved_domain):
+ type = spec.resolveDomain(unresolved_domain)
if type == 'shortstr':
- return p+'Len:8/unsigned, '+p+':'+p+'Len/binary'
+ print prefix + "E_8(encoded, offset, %s.len);" % (cValue,)
+ print prefix + "offset++;"
+ print prefix + "E_BYTES(encoded, offset, %s.len, %s.bytes);" % (cValue, cValue)
+ print prefix + "offset += %s.len;" % (cValue,)
elif type == 'longstr':
- return p+'Len:32/unsigned, '+p+':'+p+'Len/binary'
+ print prefix + "E_32(encoded, offset, %s.len);" % (cValue,)
+ print prefix + "offset += 4;"
+ print prefix + "E_BYTES(encoded, offset, %s.len, %s.bytes);" % (cValue, cValue)
+ print prefix + "offset += %s.len;" % (cValue,)
elif type == 'octet':
- return p+':8/unsigned'
- elif type == 'shortint':
- return p+':16/unsigned'
- elif type == 'longint':
- return p+':32/unsigned'
- elif type == 'longlongint':
- return p+':64/unsigned'
+ print prefix + "E_8(encoded, offset, %s);" % (cValue,)
+ print prefix + "offset++;"
+ elif type == 'short':
+ print prefix + "E_16(encoded, offset, %s);" % (cValue,)
+ print prefix + "offset += 2;"
+ elif type == 'long':
+ print prefix + "E_32(encoded, offset, %s);" % (cValue,)
+ print prefix + "offset += 4;"
+ elif type == 'longlong':
+ print prefix + "E_64(encoded, offset, %s);" % (cValue,)
+ print prefix + "offset += 8;"
elif type == 'timestamp':
- return p+':64/unsigned'
+ print prefix + "E_64(encoded, offset, %s);" % (cValue,)
+ print prefix + "offset += 8;"
elif type == 'bit':
- return p+'Bits:8'
+ raise "Can't encode bit in genSingleDecode"
elif type == 'table':
- return p+'Len:32/unsigned, '+p+'Tab:'+p+'Len/binary'
+ print prefix + "table_result = amqp_encode_table(encoded, &(%s), &offset);" % \
+ (cValue,)
+ print prefix + "if (table_result != 0) return table_result;"
else:
- return 'UNIMPLEMENTED'
-
- def genFieldPostprocessing(packed):
- for f in packed:
- type = cType(f.domain)
- if type == 'bit':
- for index in range(f.count()):
- print " F%d = ((F%dBits band %d) /= 0)," % \
- (f.index + index,
- f.index,
- 1 << index)
- elif type == 'table':
- print " F%d = rabbit_binary_parser:parse_table(F%dTab)," % \
- (f.index, f.index)
- else:
- pass
+ raise "Illegal domain in genSingleEncode", type
def genDecodeMethodFields(m):
- packedFields = packMethodFields(m.arguments)
- binaryPattern = ', '.join([methodFieldFragment(f) for f in packedFields])
- if binaryPattern:
- restSeparator = ', '
- else:
- restSeparator = ''
- recordConstructorExpr = '#%s{%s}' % (m.structName(), fieldMapList(m.arguments))
- print "decode_method_fields(%s, <<%s>>) ->" % (m.defName(), binaryPattern)
- genFieldPostprocessing(packedFields)
- print " %s;" % (recordConstructorExpr,)
+ print " case %s: {" % (m.defName(),)
+ print " %s *m = (%s *) amqp_pool_alloc(pool, sizeof(%s));" % \
+ (m.structName(), m.structName(), m.structName())
+ bitindex = None
+ for f in m.arguments:
+ if spec.resolveDomain(f.domain) == 'bit':
+ if bitindex is None:
+ bitindex = 0
+ if bitindex >= 8:
+ bitindex = 0
+ if bitindex == 0:
+ print " bit_buffer = D_8(encoded, offset);"
+ print " offset++;"
+ print " m->%s = (bit_buffer & (1 << %d)) ? 1 : 0;" % \
+ (c_ize(f.name), bitindex)
+ bitindex = bitindex + 1
+ else:
+ bitindex = None
+ genSingleDecode(" ", "m->%s" % (c_ize(f.name),), f.domain)
+ print " *decoded = m;"
+ print " return 0;"
+ print " }"
def genDecodeProperties(c):
- print "decode_properties(%d, PropBin) ->" % (c.index)
- print " %s = rabbit_binary_parser:parse_properties(%s, PropBin)," % \
- (fieldTempList(c.fields), fieldTypeList(c.fields))
- print " #'P_%s'{%s};" % (c_ize(c.name), fieldMapList(c.fields))
-
- def genFieldPreprocessing(packed):
- for f in packed:
- type = cType(f.domain)
- if type == 'bit':
- print " F%dBits = (%s)," % \
- (f.index,
- ' bor '.join(['(bitvalue(F%d) bsl %d)' % (x.index, x.index - f.index)
- for x in f.contents]))
- elif type == 'table':
- print " F%dTab = rabbit_binary_generator:generate_table(F%d)," % (f.index, f.index)
- print " F%dLen = size(F%dTab)," % (f.index, f.index)
- elif type in ['shortstr', 'longstr']:
- print " F%dLen = size(F%d)," % (f.index, f.index)
- else:
+ print " case %d: {" % (c.index,)
+ print " %s *p = (%s *) amqp_pool_alloc(pool, sizeof(%s));" % \
+ (c.structName(), c.structName(), c.structName())
+ print " p->_flags = flags;"
+ for f in c.fields:
+ if spec.resolveDomain(f.domain) == 'bit':
pass
+ else:
+ print " if (flags & %s) {" % (cFlagName(c, f),)
+ genSingleDecode(" ", "p->%s" % (c_ize(f.name),), f.domain)
+ print " }"
+ print " *decoded = p;"
+ print " return 0;"
+ print " }"
def genEncodeMethodFields(m):
- packedFields = packMethodFields(m.arguments)
- print "encode_method_fields(#%s{%s}) ->" % (m.structName(), fieldMapList(m.arguments))
- genFieldPreprocessing(packedFields)
- print " <<%s>>;" % (', '.join([methodFieldFragment(f) for f in packedFields]))
+ print " case %s: {" % (m.defName(),)
+ if m.arguments:
+ print " %s *m = (%s *) decoded;" % (m.structName(), m.structName())
+ bitindex = None
+ def finishBits():
+ if bitindex is not None:
+ print " E_8(encoded, offset, bit_buffer);"
+ print " offset++;"
+ for f in m.arguments:
+ if spec.resolveDomain(f.domain) == 'bit':
+ if bitindex is None:
+ bitindex = 0
+ print " bit_buffer = 0;"
+ if bitindex >= 8:
+ finishBits()
+ print " bit_buffer = 0;"
+ bitindex = 0
+ print " if (m->%s) { bit_buffer |= (1 << %d); }" % \
+ (c_ize(f.name), bitindex)
+ bitindex = bitindex + 1
+ else:
+ finishBits()
+ bitindex = None
+ genSingleEncode(" ", "m->%s" % (c_ize(f.name),), f.domain)
+ finishBits()
+ print " return offset;"
+ print " }"
def genEncodeProperties(c):
- print "encode_properties(#'P_%s'{%s}) ->" % (c_ize(c.name), fieldMapList(c.fields))
- print " rabbit_binary_generator:encode_properties(%s, %s);" % \
- (fieldTypeList(c.fields), fieldTempList(c.fields))
+ print " case %d: {" % (c.index,)
+ if c.fields:
+ print " %s *p = (%s *) decoded;" % (c.structName(), c.structName())
+ for f in c.fields:
+ if spec.resolveDomain(f.domain) == 'bit':
+ pass
+ else:
+ print " if (flags & %s) {" % (cFlagName(c, f),)
+ genSingleEncode(" ", "p->%s" % (c_ize(f.name),), f.domain)
+ print " }"
+ print " return offset;"
+ print " }"
def genLookupException(c,v,cls):
# We do this because 0.8 uses "soft error" and 8.1 uses "soft-error".
mCls = c_ize(cls).upper()
- if mCls == 'SOFT_ERROR': genLookupException1(c,'false')
- elif mCls == 'HARD_ERROR': genLookupException1(c, 'true')
+ if mCls == 'SOFT_ERROR': genLookupException1(c,'AMQP_EXCEPTION_CATEGORY_CHANNEL')
+ elif mCls == 'HARD_ERROR': genLookupException1(c, 'AMQP_EXCEPTION_CATEGORY_CONNECTION')
elif mCls == '': pass
else: raise 'Unknown constant class', cls
- def genLookupException1(c,hardErrorBoolStr):
- n = cConstantName(c)
- print 'lookup_amqp_exception(%s) -> {%s, ?%s, <<"%s">>};' % \
- (n.lower(), hardErrorBoolStr, n, n)
+ def genLookupException1(c, cCategory):
+ print ' case %s: return %s;' % (cConstantName(c), cCategory)
methods = spec.allMethods()
print '#include <stdlib.h>'
print '#include <string.h>'
print '#include <stdio.h>'
+ print '#include <errno.h>'
+ print '#include <arpa/inet.h> /* ntohl, htonl, ntohs, htons */'
print
+ print '#include "amqp.h"'
print '#include "amqp_framing.h"'
+ print '#include "amqp_private.h"'
print """
-char const *amqp_method_name(uint32_t methodNumber) {
+char const *amqp_method_name(amqp_method_number_t methodNumber) {
switch (methodNumber) {"""
for m in methods: genLookupMethodName(m)
print """ default: return NULL;
}
-}
-"""
+}"""
print """
-amqp_boolean_t amqp_method_has_content(uint32_t methodNumber) {
+amqp_boolean_t amqp_method_has_content(amqp_method_number_t methodNumber) {
switch (methodNumber) {"""
for m in methods:
if m.hasContent:
print ' case %s: return 1;' % (m.defName())
print """ default: return 0;
}
-}
-"""
+}"""
+ print """
+int amqp_decode_method(amqp_method_number_t methodNumber,
+ amqp_pool_t *pool,
+ amqp_bytes_t encoded,
+ void **decoded)
+{
+ int offset = 0;
+ int table_result;
+ uint8_t bit_buffer;
+
+ switch (methodNumber) {"""
for m in methods: genDecodeMethodFields(m)
- print "decode_method_fields(Name, BinaryFields) ->"
- print " rabbit_misc:frame_error(Name, BinaryFields)."
+ print """ default: return -ENOENT;
+ }
+}"""
+ print """
+int amqp_decode_properties(uint16_t class_id,
+ amqp_pool_t *pool,
+ amqp_bytes_t encoded,
+ void **decoded)
+{
+ int offset = 0;
+ int table_result;
+
+ amqp_flags_t flags = 0;
+ int flagword_index = 0;
+ amqp_flags_t partial_flags;
+
+ do {
+ partial_flags = D_16(encoded, offset);
+ offset += 2;
+ flags |= (partial_flags << (flagword_index * 16));
+ } while (partial_flags & 1);
+
+ switch (class_id) {"""
for c in spec.allClasses(): genDecodeProperties(c)
- print "decode_properties(ClassId, _BinaryFields) -> exit({unknown_class_id, ClassId})."
+ print """ default: return -ENOENT;
+ }
+}"""
+
+ print """
+int amqp_encode_method(amqp_method_number_t methodNumber,
+ void *decoded,
+ amqp_bytes_t encoded)
+{
+ int offset = 0;
+ int table_result;
+ uint8_t bit_buffer;
+ switch (methodNumber) {"""
for m in methods: genEncodeMethodFields(m)
- print "encode_method_fields(Record) -> exit({unknown_method_name, element(1, Record)})."
+ print """ default: return -ENOENT;
+ }
+}"""
+
+ print """
+int amqp_encode_properties(uint16_t class_id,
+ void *decoded,
+ amqp_bytes_t encoded)
+{
+ int offset = 0;
+ int table_result;
+
+ /* Cheat, and get the flags out generically, relying on the
+ similarity of structure between classes */
+ amqp_flags_t flags = * (amqp_flags_t *) decoded; /* cheating! */
+
+ while (flags != 0) {
+ amqp_flags_t remainder = flags >> 16;
+ uint16_t partial_flags = flags & 0xFFFE;
+ if (remainder != 0) { partial_flags |= 1; }
+ E_16(encoded, offset, partial_flags);
+ offset += 2;
+ flags = remainder;
+ }
+ switch (class_id) {"""
for c in spec.allClasses(): genEncodeProperties(c)
- print "encode_properties(Record) -> exit({unknown_properties_record, Record})."
+ print """ default: return -ENOENT;
+ }
+}"""
+ print """
+int amqp_exception_category(uint16_t code) {
+ switch (code) {"""
for (c,v,cls) in spec.constants: genLookupException(c,v,cls)
- print "lookup_amqp_exception(Code) ->"
- print " rabbit_log:warning(\"Unknown AMQP error code '~p'~n\", [Code]),"
- print " {true, ?INTERNAL_ERROR, <<\"INTERNAL_ERROR\">>}."
-
+ print """ default: return 0;
+ }
+}"""
def genHrl(spec):
def cType(domain):
return cTypeMap[spec.resolveDomain(domain)]
- def fieldNameList(fields):
- return ', '.join([c_ize(f.name) for f in fields])
-
def fieldDeclList(fields):
return ''.join([" %s %s;\n" % (cType(f.domain), c_ize(f.name)) for f in fields])
@@ -282,20 +375,15 @@ def genHrl(spec):
for f in fields
if spec.resolveDomain(f.domain) != 'bit'])
- def fieldNameListDefaults(fields):
- def fillField(field):
- result = c_ize(f.name)
- if field.defaultvalue != None:
- conv_fn = defaultValueTypeConvMap[type(field.defaultvalue)]
- result += ' = ' + conv_fn(field.defaultvalue)
- return result
- return ', '.join([fillField(f) for f in fields])
-
methods = spec.allMethods()
- print "#ifndef LIBRABBITMQ_AMQP_FRAMING_H"
- print "#define LIBRABBITMQ_AMQP_FRAMING_H"
- print
+ print """#ifndef librabbitmq_amqp_framing_h
+#define librabbitmq_amqp_framing_h
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+"""
print "#define AMQP_PROTOCOL_VERSION_MAJOR %d" % (spec.major)
print "#define AMQP_PROTOCOL_VERSION_MINOR %d" % (spec.minor)
print "#define AMQP_PROTOCOL_PORT %d" % (spec.port)
@@ -307,7 +395,7 @@ def genHrl(spec):
print "/* Method field records. */"
for m in methods:
methodid = m.klass.index << 16 | m.index
- print "#define %s ((uint32_t) 0x%.08X) /* %d, %d; %d */" % \
+ print "#define %s ((amqp_method_number_t) 0x%.08X) /* %d, %d; %d */" % \
(m.defName(),
methodid,
m.klass.index,
@@ -324,13 +412,16 @@ def genHrl(spec):
shortnum = index / 16
partialindex = 15 - (index % 16)
bitindex = shortnum * 16 + partialindex
- print '#define %s_FLAG (1 << %d)' % (cConstantName(c.name + '_' + f.name), bitindex)
+ print '#define %s (1 << %d)' % (cFlagName(c, f), bitindex)
index = index + 1
- print "typedef struct {\n uint32_t _flags;\n%s} %s;\n" % \
- (fieldDeclList(c.fields), \
- 'amqp_%s_properties_t' % (c_ize(c.name),))
+ print "typedef struct {\n amqp_flags_t _flags;\n%s} %s;\n" % \
+ (fieldDeclList(c.fields), c.structName())
+
+ print """#ifdef __cplusplus
+}
+#endif
- print "#endif"
+#endif"""
def generateErl(specPath):
genErl(AmqpSpec(specPath))