diff options
author | David Wragg <david@rabbitmq.com> | 2010-10-21 17:49:04 +0100 |
---|---|---|
committer | David Wragg <david@rabbitmq.com> | 2010-10-21 17:49:04 +0100 |
commit | 2f838304acb599b104a27c5d308bba91ed00b31e (patch) | |
tree | aa95e0ba96da79d94f9bfd951ae81cf62ef8ef93 /librabbitmq/codegen.py | |
parent | e50a94ba2fd471c61509ed3120f42d4506d99ffb (diff) | |
download | rabbitmq-c-github-ask-2f838304acb599b104a27c5d308bba91ed00b31e.tar.gz |
Convert generated code to use the new codec helper functions
Diffstat (limited to 'librabbitmq/codegen.py')
-rw-r--r-- | librabbitmq/codegen.py | 339 |
1 files changed, 188 insertions, 151 deletions
diff --git a/librabbitmq/codegen.py b/librabbitmq/codegen.py index 65b67f3..a496b9f 100644 --- a/librabbitmq/codegen.py +++ b/librabbitmq/codegen.py @@ -52,18 +52,162 @@ from amqp_codegen import * import string import re -cTypeMap = { - 'octet': 'uint8_t', - 'shortstr': 'amqp_bytes_t', - 'longstr': 'amqp_bytes_t', - 'short': 'uint16_t', - 'long': 'uint32_t', - 'longlong': 'uint64_t', - 'bit': 'amqp_boolean_t', - 'table': 'amqp_table_t', - 'timestamp': 'uint64_t', + +class Emitter(object): + """An object the trivially emits generated code lines. + + This largely exists to be wrapped by more sophisticated emitter + classes. + """ + + def __init__(self, prefix): + self.prefix = prefix + + def emit(self, line): + """Emit a line of generated code.""" + print self.prefix + line + + +class BitDecoder(object): + """An emitter object that keeps track of the state involved in + decoding the AMQP bit type.""" + + def __init__(self, emitter): + self.emitter = emitter + self.bit = 0 + + def emit(self, line): + self.bit = 0 + self.emitter.emit(line) + + def decode_bit(self, lvalue): + """Generate code to decode a value of the AMQP bit type into + the given lvalue.""" + if self.bit == 0: + self.emitter.emit("if (!amqp_decode_8(encoded, &offset, &bit_buffer)) return -ERROR_BAD_AMQP_DATA;") + + self.emitter.emit("%s = (bit_buffer & (1 << %d)) ? 1 : 0;" + % (lvalue, self.bit)) + self.bit += 1 + if self.bit == 8: + self.bit = 0 + + +class BitEncoder(object): + """An emitter object that keeps track of the state involved in + encoding the AMQP bit type.""" + + def __init__(self, emitter): + self.emitter = emitter + self.bit = 0 + + def flush(self): + """Flush the state associated with AMQP bit types.""" + if self.bit: + self.emitter.emit("if (!amqp_encode_8(encoded, &offset, bit_buffer)) return -ERROR_BAD_AMQP_DATA;") + self.bit = 0 + + def emit(self, line): + self.flush() + self.emitter.emit(line) + + def encode_bit(self, value): + """Generate code to ebcode a value of the AMQP bit type from + the given value.""" + if self.bit == 0: + self.emitter.emit("bit_buffer = 0;") + + self.emitter.emit("if (%s) bit_buffer |= (1 << %d);" + % (value, self.bit)) + self.bit += 1 + if self.bit == 8: + self.flush() + + +class SimpleType(object): + """A AMQP type that corresponds to a simple scalar C value of a + certain width.""" + + def __init__(self, bits): + self.bits = bits + self.ctype = "uint%d_t" % (bits,) + + def decode(self, emitter, lvalue): + emitter.emit("if (!amqp_decode_%d(encoded, &offset, &%s)) return -ERROR_BAD_AMQP_DATA;" % (self.bits, lvalue)) + + def encode(self, emitter, value): + emitter.emit("if (!amqp_encode_%d(encoded, &offset, %s)) return -ERROR_BAD_AMQP_DATA;" % (self.bits, value)) + + +class StrType(object): + """The AMQP shortstr or longstr types.""" + + def __init__(self, lenbits): + self.lenbits = lenbits + self.ctype = "amqp_bytes_t" + + def decode(self, emitter, lvalue): + emitter.emit("{") + emitter.emit(" uint%d_t len;" % (self.lenbits,)) + emitter.emit(" if (!amqp_decode_%d(encoded, &offset, &len)" % (self.lenbits,)) + emitter.emit(" || !amqp_decode_bytes(encoded, &offset, &%s, len))" % (lvalue,)) + emitter.emit(" return -ERROR_BAD_AMQP_DATA;") + emitter.emit("}") + + def encode(self, emitter, value): + emitter.emit("if (!amqp_encode_%d(encoded, &offset, %s.len)" % (self.lenbits, value)) + emitter.emit(" || !amqp_encode_bytes(encoded, &offset, %s))" % (value,)) + emitter.emit(" return -ERROR_BAD_AMQP_DATA;") + + +class BitType(object): + """The AMQP bit type.""" + + def __init__(self): + self.ctype = "amqp_boolean_t" + + def decode(self, emitter, lvalue): + emitter.decode_bit(lvalue) + + def encode(self, emitter, value): + emitter.encode_bit(value) + + +class TableType(object): + """The AMQP table type.""" + + def __init__(self): + self.ctype = "amqp_table_t" + + def decode(self, emitter, lvalue): + emitter.emit("{") + emitter.emit(" int res = amqp_decode_table(encoded, pool, &(%s), &offset);" % (lvalue,)) + emitter.emit(" if (res < 0) return res;") + emitter.emit("}") + + def encode(self, emitter, value): + emitter.emit("{") + emitter.emit(" int res = amqp_encode_table(encoded, &(%s), &offset);" % (value,)) + emitter.emit(" if (res < 0) return res;") + emitter.emit("}") + + +types = { + 'octet': SimpleType(8), + 'short': SimpleType(16), + 'long': SimpleType(32), + 'longlong': SimpleType(64), + 'shortstr': StrType(8), + 'longstr': StrType(32), + 'bit': BitType(), + 'table': TableType(), + 'timestamp': SimpleType(64), } +def typeFor(spec, f): + """Get a representation of the AMQP type of a field.""" + return types[spec.resolveDomain(f.domain)] + def c_ize(s): s = s.replace('-', '_') s = s.replace(' ', '_') @@ -81,9 +225,6 @@ def cFlagName(c, f): return cConstantName(c.name + '_' + f.name) + '_FLAG' def genErl(spec): - def cType(domain): - return cTypeMap[spec.resolveDomain(domain)] - def fieldTempList(fields): return '[' + ', '.join(['F' + str(f.index) for f in fields]) + ']' @@ -93,78 +234,6 @@ def genErl(spec): def genLookupMethodName(m): print ' case %s: return "%s";' % (m.defName(), m.defName()) - def genSingleDecode(prefix, cLvalue, unresolved_domain): - type = spec.resolveDomain(unresolved_domain) - if type == 'shortstr': - print prefix + "%s.len = D_8(encoded, offset);" % (cLvalue,) - print prefix + "offset++;" - print prefix + "%s.bytes = D_BYTES(encoded, offset, %s.len);" % (cLvalue, cLvalue) - print prefix + "offset += %s.len;" % (cLvalue,) - elif type == 'longstr': - print prefix + "%s.len = D_32(encoded, offset);" % (cLvalue,) - print prefix + "offset += 4;" - print prefix + "%s.bytes = D_BYTES(encoded, offset, %s.len);" % (cLvalue, cLvalue) - print prefix + "offset += %s.len;" % (cLvalue,) - elif type == 'octet': - print prefix + "%s = D_8(encoded, offset);" % (cLvalue,) - print prefix + "offset++;" - elif type == 'short': - print prefix + "%s = D_16(encoded, offset);" % (cLvalue,) - print prefix + "offset += 2;" - elif type == 'long': - print prefix + "%s = D_32(encoded, offset);" % (cLvalue,) - print prefix + "offset += 4;" - elif type == 'longlong': - print prefix + "%s = D_64(encoded, offset);" % (cLvalue,) - print prefix + "offset += 8;" - elif type == 'timestamp': - print prefix + "%s = D_64(encoded, offset);" % (cLvalue,) - print prefix + "offset += 8;" - elif type == 'bit': - raise "Can't decode bit in genSingleDecode" - elif type == 'table': - print prefix + "table_result = amqp_decode_table(encoded, pool, &(%s), &offset);" % \ - (cLvalue,) - print prefix + "AMQP_CHECK_RESULT(table_result);" - else: - raise "Illegal domain in genSingleDecode", type - - def genSingleEncode(prefix, cValue, unresolved_domain): - type = spec.resolveDomain(unresolved_domain) - if type == 'shortstr': - print prefix + "E_8(encoded, offset, %s.len);" % (cValue,) - print prefix + "offset++;" - print prefix + "E_BYTES(encoded, offset, %s.len, %s.bytes);" % (cValue, cValue) - print prefix + "offset += %s.len;" % (cValue,) - elif type == 'longstr': - print prefix + "E_32(encoded, offset, %s.len);" % (cValue,) - print prefix + "offset += 4;" - print prefix + "E_BYTES(encoded, offset, %s.len, %s.bytes);" % (cValue, cValue) - print prefix + "offset += %s.len;" % (cValue,) - elif type == 'octet': - print prefix + "E_8(encoded, offset, %s);" % (cValue,) - print prefix + "offset++;" - elif type == 'short': - print prefix + "E_16(encoded, offset, %s);" % (cValue,) - print prefix + "offset += 2;" - elif type == 'long': - print prefix + "E_32(encoded, offset, %s);" % (cValue,) - print prefix + "offset += 4;" - elif type == 'longlong': - print prefix + "E_64(encoded, offset, %s);" % (cValue,) - print prefix + "offset += 8;" - elif type == 'timestamp': - print prefix + "E_64(encoded, offset, %s);" % (cValue,) - print prefix + "offset += 8;" - elif type == 'bit': - raise "Can't encode bit in genSingleDecode" - elif type == 'table': - print prefix + "table_result = amqp_encode_table(encoded, &(%s), &offset);" % \ - (cValue,) - print prefix + "if (table_result < 0) return table_result;" - else: - raise "Illegal domain in genSingleEncode", type - def genDecodeMethodFields(m): print " case %s: {" % (m.defName(),) if m.arguments: @@ -173,22 +242,11 @@ def genErl(spec): print " if (m == NULL) { return -ERROR_NO_MEMORY; }" else: print " %s *m = NULL; /* no fields */" % (m.structName(),) - bitindex = None + + emitter = BitDecoder(Emitter(" ")) for f in m.arguments: - if spec.resolveDomain(f.domain) == 'bit': - if bitindex is None: - bitindex = 0 - if bitindex >= 8: - bitindex = 0 - if bitindex == 0: - print " bit_buffer = D_8(encoded, offset);" - print " offset++;" - print " m->%s = (bit_buffer & (1 << %d)) ? 1 : 0;" % \ - (c_ize(f.name), bitindex) - bitindex = bitindex + 1 - else: - bitindex = None - genSingleDecode(" ", "m->%s" % (c_ize(f.name),), f.domain) + typeFor(spec, f).decode(emitter, "m->"+c_ize(f.name)) + print " *decoded = m;" print " return 0;" print " }" @@ -199,13 +257,13 @@ def genErl(spec): (c.structName(), c.structName(), c.structName()) print " if (p == NULL) { return -ERROR_NO_MEMORY; }" print " p->_flags = flags;" + + emitter = Emitter(" ") for f in c.fields: - if spec.resolveDomain(f.domain) == 'bit': - pass - else: - print " if (flags & %s) {" % (cFlagName(c, f),) - genSingleDecode(" ", "p->%s" % (c_ize(f.name),), f.domain) - print " }" + emitter.emit("if (flags & %s) {" % (cFlagName(c, f),)) + typeFor(spec, f).decode(emitter, "p->"+c_ize(f.name)) + emitter.emit("}") + print " *decoded = p;" print " return 0;" print " }" @@ -214,28 +272,12 @@ def genErl(spec): print " case %s: {" % (m.defName(),) if m.arguments: print " %s *m = (%s *) decoded;" % (m.structName(), m.structName()) - bitindex = None - def finishBits(): - if bitindex is not None: - print " E_8(encoded, offset, bit_buffer);" - print " offset++;" + + emitter = BitEncoder(Emitter(" ")) for f in m.arguments: - if spec.resolveDomain(f.domain) == 'bit': - if bitindex is None: - bitindex = 0 - print " bit_buffer = 0;" - if bitindex >= 8: - finishBits() - print " bit_buffer = 0;" - bitindex = 0 - print " if (m->%s) { bit_buffer |= (1 << %d); }" % \ - (c_ize(f.name), bitindex) - bitindex = bitindex + 1 - else: - finishBits() - bitindex = None - genSingleEncode(" ", "m->%s" % (c_ize(f.name),), f.domain) - finishBits() + typeFor(spec, f).encode(emitter, "m->"+c_ize(f.name)) + emitter.flush() + print " return offset;" print " }" @@ -243,13 +285,13 @@ def genErl(spec): print " case %d: {" % (c.index,) if c.fields: print " %s *p = (%s *) decoded;" % (c.structName(), c.structName()) + + emitter = Emitter(" ") for f in c.fields: - if spec.resolveDomain(f.domain) == 'bit': - pass - else: - print " if (flags & %s) {" % (cFlagName(c, f),) - genSingleEncode(" ", "p->%s" % (c_ize(f.name),), f.domain) - print " }" + emitter.emit(" if (flags & %s) {" % (cFlagName(c, f),)) + typeFor(spec, f).encode(emitter, "p->"+c_ize(f.name)) + emitter.emit("}") + print " return offset;" print " }" @@ -310,8 +352,7 @@ int amqp_decode_method(amqp_method_number_t methodNumber, amqp_bytes_t encoded, void **decoded) { - int offset = 0; - int table_result; + size_t offset = 0; uint8_t bit_buffer; switch (methodNumber) {""" @@ -326,16 +367,15 @@ int amqp_decode_properties(uint16_t class_id, amqp_bytes_t encoded, void **decoded) { - int offset = 0; - int table_result; + size_t offset = 0; amqp_flags_t flags = 0; int flagword_index = 0; - amqp_flags_t partial_flags; + uint16_t partial_flags; do { - partial_flags = D_16(encoded, offset); - offset += 2; + if (!amqp_decode_16(encoded, &offset, &partial_flags)) + return -ERROR_BAD_AMQP_DATA; flags |= (partial_flags << (flagword_index * 16)); flagword_index++; } while (partial_flags & 1); @@ -351,8 +391,7 @@ int amqp_encode_method(amqp_method_number_t methodNumber, void *decoded, amqp_bytes_t encoded) { - int offset = 0; - int table_result; + size_t offset = 0; uint8_t bit_buffer; switch (methodNumber) {""" @@ -366,8 +405,7 @@ int amqp_encode_properties(uint16_t class_id, void *decoded, amqp_bytes_t encoded) { - int offset = 0; - int table_result; + size_t offset = 0; /* Cheat, and get the flags out generically, relying on the similarity of structure between classes */ @@ -381,8 +419,8 @@ int amqp_encode_properties(uint16_t class_id, amqp_flags_t remainder = remaining_flags >> 16; uint16_t partial_flags = remaining_flags & 0xFFFE; if (remainder != 0) { partial_flags |= 1; } - E_16(encoded, offset, partial_flags); - offset += 2; + if (!amqp_encode_16(encoded, &offset, partial_flags)) + return -ERROR_BAD_AMQP_DATA; remaining_flags = remainder; } while (remaining_flags != 0); } @@ -394,17 +432,16 @@ int amqp_encode_properties(uint16_t class_id, }""" def genHrl(spec): - def cType(domain): - return cTypeMap[spec.resolveDomain(domain)] - def fieldDeclList(fields): if fields: - return ''.join([" %s %s;\n" % (cType(f.domain), c_ize(f.name)) for f in fields]) + return ''.join([" %s %s;\n" % (typeFor(spec, f).ctype, + c_ize(f.name)) + for f in fields]) else: return " char dummy; /* Dummy field to avoid empty struct */\n" def propDeclList(fields): - return ''.join([" %s %s;\n" % (cType(f.domain), c_ize(f.name)) + return ''.join([" %s %s;\n" % (typeFor(spec, f).ctype, c_ize(f.name)) for f in fields if spec.resolveDomain(f.domain) != 'bit']) |