summaryrefslogtreecommitdiff
path: root/librabbitmq/codegen.py
diff options
context:
space:
mode:
authorDavid Wragg <david@rabbitmq.com>2010-10-21 17:49:04 +0100
committerDavid Wragg <david@rabbitmq.com>2010-10-21 17:49:04 +0100
commit2f838304acb599b104a27c5d308bba91ed00b31e (patch)
treeaa95e0ba96da79d94f9bfd951ae81cf62ef8ef93 /librabbitmq/codegen.py
parente50a94ba2fd471c61509ed3120f42d4506d99ffb (diff)
downloadrabbitmq-c-2f838304acb599b104a27c5d308bba91ed00b31e.tar.gz
Convert generated code to use the new codec helper functions
Diffstat (limited to 'librabbitmq/codegen.py')
-rw-r--r--librabbitmq/codegen.py339
1 files changed, 188 insertions, 151 deletions
diff --git a/librabbitmq/codegen.py b/librabbitmq/codegen.py
index 65b67f3..a496b9f 100644
--- a/librabbitmq/codegen.py
+++ b/librabbitmq/codegen.py
@@ -52,18 +52,162 @@ from amqp_codegen import *
import string
import re
-cTypeMap = {
- 'octet': 'uint8_t',
- 'shortstr': 'amqp_bytes_t',
- 'longstr': 'amqp_bytes_t',
- 'short': 'uint16_t',
- 'long': 'uint32_t',
- 'longlong': 'uint64_t',
- 'bit': 'amqp_boolean_t',
- 'table': 'amqp_table_t',
- 'timestamp': 'uint64_t',
+
+class Emitter(object):
+ """An object the trivially emits generated code lines.
+
+ This largely exists to be wrapped by more sophisticated emitter
+ classes.
+ """
+
+ def __init__(self, prefix):
+ self.prefix = prefix
+
+ def emit(self, line):
+ """Emit a line of generated code."""
+ print self.prefix + line
+
+
+class BitDecoder(object):
+ """An emitter object that keeps track of the state involved in
+ decoding the AMQP bit type."""
+
+ def __init__(self, emitter):
+ self.emitter = emitter
+ self.bit = 0
+
+ def emit(self, line):
+ self.bit = 0
+ self.emitter.emit(line)
+
+ def decode_bit(self, lvalue):
+ """Generate code to decode a value of the AMQP bit type into
+ the given lvalue."""
+ if self.bit == 0:
+ self.emitter.emit("if (!amqp_decode_8(encoded, &offset, &bit_buffer)) return -ERROR_BAD_AMQP_DATA;")
+
+ self.emitter.emit("%s = (bit_buffer & (1 << %d)) ? 1 : 0;"
+ % (lvalue, self.bit))
+ self.bit += 1
+ if self.bit == 8:
+ self.bit = 0
+
+
+class BitEncoder(object):
+ """An emitter object that keeps track of the state involved in
+ encoding the AMQP bit type."""
+
+ def __init__(self, emitter):
+ self.emitter = emitter
+ self.bit = 0
+
+ def flush(self):
+ """Flush the state associated with AMQP bit types."""
+ if self.bit:
+ self.emitter.emit("if (!amqp_encode_8(encoded, &offset, bit_buffer)) return -ERROR_BAD_AMQP_DATA;")
+ self.bit = 0
+
+ def emit(self, line):
+ self.flush()
+ self.emitter.emit(line)
+
+ def encode_bit(self, value):
+ """Generate code to ebcode a value of the AMQP bit type from
+ the given value."""
+ if self.bit == 0:
+ self.emitter.emit("bit_buffer = 0;")
+
+ self.emitter.emit("if (%s) bit_buffer |= (1 << %d);"
+ % (value, self.bit))
+ self.bit += 1
+ if self.bit == 8:
+ self.flush()
+
+
+class SimpleType(object):
+ """A AMQP type that corresponds to a simple scalar C value of a
+ certain width."""
+
+ def __init__(self, bits):
+ self.bits = bits
+ self.ctype = "uint%d_t" % (bits,)
+
+ def decode(self, emitter, lvalue):
+ emitter.emit("if (!amqp_decode_%d(encoded, &offset, &%s)) return -ERROR_BAD_AMQP_DATA;" % (self.bits, lvalue))
+
+ def encode(self, emitter, value):
+ emitter.emit("if (!amqp_encode_%d(encoded, &offset, %s)) return -ERROR_BAD_AMQP_DATA;" % (self.bits, value))
+
+
+class StrType(object):
+ """The AMQP shortstr or longstr types."""
+
+ def __init__(self, lenbits):
+ self.lenbits = lenbits
+ self.ctype = "amqp_bytes_t"
+
+ def decode(self, emitter, lvalue):
+ emitter.emit("{")
+ emitter.emit(" uint%d_t len;" % (self.lenbits,))
+ emitter.emit(" if (!amqp_decode_%d(encoded, &offset, &len)" % (self.lenbits,))
+ emitter.emit(" || !amqp_decode_bytes(encoded, &offset, &%s, len))" % (lvalue,))
+ emitter.emit(" return -ERROR_BAD_AMQP_DATA;")
+ emitter.emit("}")
+
+ def encode(self, emitter, value):
+ emitter.emit("if (!amqp_encode_%d(encoded, &offset, %s.len)" % (self.lenbits, value))
+ emitter.emit(" || !amqp_encode_bytes(encoded, &offset, %s))" % (value,))
+ emitter.emit(" return -ERROR_BAD_AMQP_DATA;")
+
+
+class BitType(object):
+ """The AMQP bit type."""
+
+ def __init__(self):
+ self.ctype = "amqp_boolean_t"
+
+ def decode(self, emitter, lvalue):
+ emitter.decode_bit(lvalue)
+
+ def encode(self, emitter, value):
+ emitter.encode_bit(value)
+
+
+class TableType(object):
+ """The AMQP table type."""
+
+ def __init__(self):
+ self.ctype = "amqp_table_t"
+
+ def decode(self, emitter, lvalue):
+ emitter.emit("{")
+ emitter.emit(" int res = amqp_decode_table(encoded, pool, &(%s), &offset);" % (lvalue,))
+ emitter.emit(" if (res < 0) return res;")
+ emitter.emit("}")
+
+ def encode(self, emitter, value):
+ emitter.emit("{")
+ emitter.emit(" int res = amqp_encode_table(encoded, &(%s), &offset);" % (value,))
+ emitter.emit(" if (res < 0) return res;")
+ emitter.emit("}")
+
+
+types = {
+ 'octet': SimpleType(8),
+ 'short': SimpleType(16),
+ 'long': SimpleType(32),
+ 'longlong': SimpleType(64),
+ 'shortstr': StrType(8),
+ 'longstr': StrType(32),
+ 'bit': BitType(),
+ 'table': TableType(),
+ 'timestamp': SimpleType(64),
}
+def typeFor(spec, f):
+ """Get a representation of the AMQP type of a field."""
+ return types[spec.resolveDomain(f.domain)]
+
def c_ize(s):
s = s.replace('-', '_')
s = s.replace(' ', '_')
@@ -81,9 +225,6 @@ def cFlagName(c, f):
return cConstantName(c.name + '_' + f.name) + '_FLAG'
def genErl(spec):
- def cType(domain):
- return cTypeMap[spec.resolveDomain(domain)]
-
def fieldTempList(fields):
return '[' + ', '.join(['F' + str(f.index) for f in fields]) + ']'
@@ -93,78 +234,6 @@ def genErl(spec):
def genLookupMethodName(m):
print ' case %s: return "%s";' % (m.defName(), m.defName())
- def genSingleDecode(prefix, cLvalue, unresolved_domain):
- type = spec.resolveDomain(unresolved_domain)
- if type == 'shortstr':
- print prefix + "%s.len = D_8(encoded, offset);" % (cLvalue,)
- print prefix + "offset++;"
- print prefix + "%s.bytes = D_BYTES(encoded, offset, %s.len);" % (cLvalue, cLvalue)
- print prefix + "offset += %s.len;" % (cLvalue,)
- elif type == 'longstr':
- print prefix + "%s.len = D_32(encoded, offset);" % (cLvalue,)
- print prefix + "offset += 4;"
- print prefix + "%s.bytes = D_BYTES(encoded, offset, %s.len);" % (cLvalue, cLvalue)
- print prefix + "offset += %s.len;" % (cLvalue,)
- elif type == 'octet':
- print prefix + "%s = D_8(encoded, offset);" % (cLvalue,)
- print prefix + "offset++;"
- elif type == 'short':
- print prefix + "%s = D_16(encoded, offset);" % (cLvalue,)
- print prefix + "offset += 2;"
- elif type == 'long':
- print prefix + "%s = D_32(encoded, offset);" % (cLvalue,)
- print prefix + "offset += 4;"
- elif type == 'longlong':
- print prefix + "%s = D_64(encoded, offset);" % (cLvalue,)
- print prefix + "offset += 8;"
- elif type == 'timestamp':
- print prefix + "%s = D_64(encoded, offset);" % (cLvalue,)
- print prefix + "offset += 8;"
- elif type == 'bit':
- raise "Can't decode bit in genSingleDecode"
- elif type == 'table':
- print prefix + "table_result = amqp_decode_table(encoded, pool, &(%s), &offset);" % \
- (cLvalue,)
- print prefix + "AMQP_CHECK_RESULT(table_result);"
- else:
- raise "Illegal domain in genSingleDecode", type
-
- def genSingleEncode(prefix, cValue, unresolved_domain):
- type = spec.resolveDomain(unresolved_domain)
- if type == 'shortstr':
- print prefix + "E_8(encoded, offset, %s.len);" % (cValue,)
- print prefix + "offset++;"
- print prefix + "E_BYTES(encoded, offset, %s.len, %s.bytes);" % (cValue, cValue)
- print prefix + "offset += %s.len;" % (cValue,)
- elif type == 'longstr':
- print prefix + "E_32(encoded, offset, %s.len);" % (cValue,)
- print prefix + "offset += 4;"
- print prefix + "E_BYTES(encoded, offset, %s.len, %s.bytes);" % (cValue, cValue)
- print prefix + "offset += %s.len;" % (cValue,)
- elif type == 'octet':
- print prefix + "E_8(encoded, offset, %s);" % (cValue,)
- print prefix + "offset++;"
- elif type == 'short':
- print prefix + "E_16(encoded, offset, %s);" % (cValue,)
- print prefix + "offset += 2;"
- elif type == 'long':
- print prefix + "E_32(encoded, offset, %s);" % (cValue,)
- print prefix + "offset += 4;"
- elif type == 'longlong':
- print prefix + "E_64(encoded, offset, %s);" % (cValue,)
- print prefix + "offset += 8;"
- elif type == 'timestamp':
- print prefix + "E_64(encoded, offset, %s);" % (cValue,)
- print prefix + "offset += 8;"
- elif type == 'bit':
- raise "Can't encode bit in genSingleDecode"
- elif type == 'table':
- print prefix + "table_result = amqp_encode_table(encoded, &(%s), &offset);" % \
- (cValue,)
- print prefix + "if (table_result < 0) return table_result;"
- else:
- raise "Illegal domain in genSingleEncode", type
-
def genDecodeMethodFields(m):
print " case %s: {" % (m.defName(),)
if m.arguments:
@@ -173,22 +242,11 @@ def genErl(spec):
print " if (m == NULL) { return -ERROR_NO_MEMORY; }"
else:
print " %s *m = NULL; /* no fields */" % (m.structName(),)
- bitindex = None
+
+ emitter = BitDecoder(Emitter(" "))
for f in m.arguments:
- if spec.resolveDomain(f.domain) == 'bit':
- if bitindex is None:
- bitindex = 0
- if bitindex >= 8:
- bitindex = 0
- if bitindex == 0:
- print " bit_buffer = D_8(encoded, offset);"
- print " offset++;"
- print " m->%s = (bit_buffer & (1 << %d)) ? 1 : 0;" % \
- (c_ize(f.name), bitindex)
- bitindex = bitindex + 1
- else:
- bitindex = None
- genSingleDecode(" ", "m->%s" % (c_ize(f.name),), f.domain)
+ typeFor(spec, f).decode(emitter, "m->"+c_ize(f.name))
+
print " *decoded = m;"
print " return 0;"
print " }"
@@ -199,13 +257,13 @@ def genErl(spec):
(c.structName(), c.structName(), c.structName())
print " if (p == NULL) { return -ERROR_NO_MEMORY; }"
print " p->_flags = flags;"
+
+ emitter = Emitter(" ")
for f in c.fields:
- if spec.resolveDomain(f.domain) == 'bit':
- pass
- else:
- print " if (flags & %s) {" % (cFlagName(c, f),)
- genSingleDecode(" ", "p->%s" % (c_ize(f.name),), f.domain)
- print " }"
+ emitter.emit("if (flags & %s) {" % (cFlagName(c, f),))
+ typeFor(spec, f).decode(emitter, "p->"+c_ize(f.name))
+ emitter.emit("}")
+
print " *decoded = p;"
print " return 0;"
print " }"
@@ -214,28 +272,12 @@ def genErl(spec):
print " case %s: {" % (m.defName(),)
if m.arguments:
print " %s *m = (%s *) decoded;" % (m.structName(), m.structName())
- bitindex = None
- def finishBits():
- if bitindex is not None:
- print " E_8(encoded, offset, bit_buffer);"
- print " offset++;"
+
+ emitter = BitEncoder(Emitter(" "))
for f in m.arguments:
- if spec.resolveDomain(f.domain) == 'bit':
- if bitindex is None:
- bitindex = 0
- print " bit_buffer = 0;"
- if bitindex >= 8:
- finishBits()
- print " bit_buffer = 0;"
- bitindex = 0
- print " if (m->%s) { bit_buffer |= (1 << %d); }" % \
- (c_ize(f.name), bitindex)
- bitindex = bitindex + 1
- else:
- finishBits()
- bitindex = None
- genSingleEncode(" ", "m->%s" % (c_ize(f.name),), f.domain)
- finishBits()
+ typeFor(spec, f).encode(emitter, "m->"+c_ize(f.name))
+ emitter.flush()
+
print " return offset;"
print " }"
@@ -243,13 +285,13 @@ def genErl(spec):
print " case %d: {" % (c.index,)
if c.fields:
print " %s *p = (%s *) decoded;" % (c.structName(), c.structName())
+
+ emitter = Emitter(" ")
for f in c.fields:
- if spec.resolveDomain(f.domain) == 'bit':
- pass
- else:
- print " if (flags & %s) {" % (cFlagName(c, f),)
- genSingleEncode(" ", "p->%s" % (c_ize(f.name),), f.domain)
- print " }"
+ emitter.emit(" if (flags & %s) {" % (cFlagName(c, f),))
+ typeFor(spec, f).encode(emitter, "p->"+c_ize(f.name))
+ emitter.emit("}")
+
print " return offset;"
print " }"
@@ -310,8 +352,7 @@ int amqp_decode_method(amqp_method_number_t methodNumber,
amqp_bytes_t encoded,
void **decoded)
{
- int offset = 0;
- int table_result;
+ size_t offset = 0;
uint8_t bit_buffer;
switch (methodNumber) {"""
@@ -326,16 +367,15 @@ int amqp_decode_properties(uint16_t class_id,
amqp_bytes_t encoded,
void **decoded)
{
- int offset = 0;
- int table_result;
+ size_t offset = 0;
amqp_flags_t flags = 0;
int flagword_index = 0;
- amqp_flags_t partial_flags;
+ uint16_t partial_flags;
do {
- partial_flags = D_16(encoded, offset);
- offset += 2;
+ if (!amqp_decode_16(encoded, &offset, &partial_flags))
+ return -ERROR_BAD_AMQP_DATA;
flags |= (partial_flags << (flagword_index * 16));
flagword_index++;
} while (partial_flags & 1);
@@ -351,8 +391,7 @@ int amqp_encode_method(amqp_method_number_t methodNumber,
void *decoded,
amqp_bytes_t encoded)
{
- int offset = 0;
- int table_result;
+ size_t offset = 0;
uint8_t bit_buffer;
switch (methodNumber) {"""
@@ -366,8 +405,7 @@ int amqp_encode_properties(uint16_t class_id,
void *decoded,
amqp_bytes_t encoded)
{
- int offset = 0;
- int table_result;
+ size_t offset = 0;
/* Cheat, and get the flags out generically, relying on the
similarity of structure between classes */
@@ -381,8 +419,8 @@ int amqp_encode_properties(uint16_t class_id,
amqp_flags_t remainder = remaining_flags >> 16;
uint16_t partial_flags = remaining_flags & 0xFFFE;
if (remainder != 0) { partial_flags |= 1; }
- E_16(encoded, offset, partial_flags);
- offset += 2;
+ if (!amqp_encode_16(encoded, &offset, partial_flags))
+ return -ERROR_BAD_AMQP_DATA;
remaining_flags = remainder;
} while (remaining_flags != 0);
}
@@ -394,17 +432,16 @@ int amqp_encode_properties(uint16_t class_id,
}"""
def genHrl(spec):
- def cType(domain):
- return cTypeMap[spec.resolveDomain(domain)]
-
def fieldDeclList(fields):
if fields:
- return ''.join([" %s %s;\n" % (cType(f.domain), c_ize(f.name)) for f in fields])
+ return ''.join([" %s %s;\n" % (typeFor(spec, f).ctype,
+ c_ize(f.name))
+ for f in fields])
else:
return " char dummy; /* Dummy field to avoid empty struct */\n"
def propDeclList(fields):
- return ''.join([" %s %s;\n" % (cType(f.domain), c_ize(f.name))
+ return ''.join([" %s %s;\n" % (typeFor(spec, f).ctype, c_ize(f.name))
for f in fields
if spec.resolveDomain(f.domain) != 'bit'])