summaryrefslogtreecommitdiff
path: root/tests/test_emitter.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_emitter.py')
-rw-r--r--tests/test_emitter.py53
1 files changed, 53 insertions, 0 deletions
diff --git a/tests/test_emitter.py b/tests/test_emitter.py
index fed6953..698ea50 100644
--- a/tests/test_emitter.py
+++ b/tests/test_emitter.py
@@ -2,6 +2,7 @@
import test_appliance, sys, StringIO
from yaml import *
+import yaml
class TestEmitterOnCanonical(test_appliance.TestAppliance):
@@ -15,6 +16,58 @@ class TestEmitterOnCanonical(test_appliance.TestAppliance):
#print file(canonical_filename, 'rb').read()
for event in events:
emitter.emit(event)
+ data = writer.getvalue()
+ new_events = list(parse(data))
+ self.failUnlessEqual(len(events), len(new_events))
+ for event, new_event in zip(events, new_events):
+ self.failUnlessEqual(event.__class__, new_event.__class__)
TestEmitterOnCanonical.add_tests('testEmitterOnCanonical', '.canonical')
+class EventsConstructor(Constructor):
+
+ def construct_event(self, node):
+ if isinstance(node, ScalarNode):
+ mapping = {}
+ else:
+ mapping = self.construct_mapping(node)
+ class_name = str(node.tag[1:])+'Event'
+ if class_name in ['AliasEvent', 'ScalarEvent', 'SequenceEvent', 'MappingEvent']:
+ mapping.setdefault('anchor', None)
+ if class_name in ['ScalarEvent', 'SequenceEvent', 'MappingEvent']:
+ mapping.setdefault('tag', None)
+ if class_name == 'ScalarEvent':
+ mapping.setdefault('value', '')
+ value = getattr(yaml, class_name)(**mapping)
+ return value
+
+EventsConstructor.add_constructor(None, EventsConstructor.construct_event)
+
+class TestEmitter(test_appliance.TestAppliance):
+
+ def _testEmitter(self, test_name, events_filename):
+ events = load_document(file(events_filename, 'rb'), Constructor=EventsConstructor)
+ self._dump(events_filename, events)
+ writer = StringIO.StringIO()
+ emitter = Emitter(writer)
+ for event in events:
+ emitter.emit(event)
+ data = writer.getvalue()
+ new_events = list(parse(data))
+ self.failUnlessEqual(len(events), len(new_events))
+ for event, new_event in zip(events, new_events):
+ self.failUnlessEqual(event.__class__, new_event.__class__)
+
+ def _dump(self, events_filename, events):
+ writer = sys.stdout
+ emitter = Emitter(writer)
+ print "="*30
+ print "EVENTS:"
+ print file(events_filename, 'rb').read()
+ print '-'*30
+ print "OUTPUT:"
+ for event in events:
+ emitter.emit(event)
+
+TestEmitter.add_tests('testEmitter', '.events')
+