diff options
Diffstat (limited to 'kafka')
-rw-r--r-- | kafka/protocol/api.py | 32 | ||||
-rw-r--r-- | kafka/protocol/struct.py | 6 |
2 files changed, 37 insertions, 1 deletions
diff --git a/kafka/protocol/api.py b/kafka/protocol/api.py index efaf63e..64276fc 100644 --- a/kafka/protocol/api.py +++ b/kafka/protocol/api.py @@ -3,7 +3,7 @@ from __future__ import absolute_import import abc from kafka.protocol.struct import Struct -from kafka.protocol.types import Int16, Int32, String, Schema +from kafka.protocol.types import Int16, Int32, String, Schema, Array class RequestHeader(Struct): @@ -47,6 +47,9 @@ class Request(Struct): """Override this method if an api request does not always generate a response""" return True + def to_object(self): + return _to_object(self.SCHEMA, self) + class Response(Struct): __metaclass__ = abc.ABCMeta @@ -65,3 +68,30 @@ class Response(Struct): def SCHEMA(self): """An instance of Schema() representing the response structure""" pass + + def to_object(self): + return _to_object(self.SCHEMA, self) + + +def _to_object(schema, data): + obj = {} + for idx, (name, _type) in enumerate(zip(schema.names, schema.fields)): + if isinstance(data, Struct): + val = data.get_item(name) + else: + val = data[idx] + + if isinstance(_type, Schema): + obj[name] = _to_object(_type, val) + elif isinstance(_type, Array): + if isinstance(_type.array_of, (Array, Schema)): + obj[name] = [ + _to_object(_type.array_of, x) + for x in val + ] + else: + obj[name] = val + else: + obj[name] = val + + return obj diff --git a/kafka/protocol/struct.py b/kafka/protocol/struct.py index 693e2a2..e9da6e6 100644 --- a/kafka/protocol/struct.py +++ b/kafka/protocol/struct.py @@ -30,6 +30,7 @@ class Struct(AbstractType): # causes instances to "leak" to garbage self.encode = WeakMethod(self._encode_self) + @classmethod def encode(cls, item): # pylint: disable=E0202 bits = [] @@ -48,6 +49,11 @@ class Struct(AbstractType): data = BytesIO(data) return cls(*[field.decode(data) for field in cls.SCHEMA.fields]) + def get_item(self, name): + if name not in self.SCHEMA.names: + raise KeyError("%s is not in the schema" % name) + return self.__dict__[name] + def __repr__(self): key_vals = [] for name, field in zip(self.SCHEMA.names, self.SCHEMA.fields): |