summaryrefslogtreecommitdiff
path: root/tests/test_emitter.py
blob: 698ea50b07bedae74dd122b9a448af744a8aa3b5 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73

import test_appliance, sys, StringIO

from yaml import *
import yaml

class TestEmitterOnCanonical(test_appliance.TestAppliance):
    """Round-trip the ``.canonical`` test files through Parser and Emitter."""

    def _testEmitterOnCanonical(self, test_name, canonical_filename):
        """Parse *canonical_filename*, re-emit its events, then parse the output.

        The original and re-parsed event streams must have the same length
        and the same event class at every position.  Only the classes are
        compared, not the event attributes.
        """
        stream = open(canonical_filename, 'rb')
        try:
            # Materialise every event before closing the file: the Reader
            # may pull from the stream lazily while the Parser is iterated.
            events = list(Parser(Scanner(Reader(stream))))
        finally:
            stream.close()
        writer = StringIO.StringIO()
        emitter = Emitter(writer)
        for event in events:
            emitter.emit(event)
        data = writer.getvalue()
        new_events = list(parse(data))
        self.failUnlessEqual(len(events), len(new_events))
        for event, new_event in zip(events, new_events):
            self.failUnlessEqual(event.__class__, new_event.__class__)

# Register _testEmitterOnCanonical against the '.canonical' test data files
# (the generated test names come from test_appliance.add_tests).
TestEmitterOnCanonical.add_tests('testEmitterOnCanonical', '.canonical')

class EventsConstructor(Constructor):
    """Constructor that turns a YAML description of events into Event objects.

    A node's tag, minus its first character, names the event class to build:
    for example a node tagged ``!Scalar`` becomes a ``yaml.ScalarEvent``.
    """

    def construct_event(self, node):
        """Build and return the yaml Event instance described by *node*.

        Scalar nodes contribute no keyword arguments.  Any other node is
        converted with ``construct_mapping`` and its entries are passed as
        keyword arguments.  Missing ``anchor``/``tag``/``value`` keywords are
        filled with defaults for the event classes that take them.
        """
        if not isinstance(node, ScalarNode):
            kwargs = self.construct_mapping(node)
        else:
            kwargs = {}
        # Strip the leading tag character (e.g. '!') and add the suffix.
        event_name = str(node.tag[1:]) + 'Event'

        takes_anchor = event_name in ('AliasEvent', 'ScalarEvent',
                                      'SequenceEvent', 'MappingEvent')
        # Every anchored event except AliasEvent also takes a tag.
        takes_tag = takes_anchor and event_name != 'AliasEvent'

        if takes_anchor:
            kwargs.setdefault('anchor', None)
        if takes_tag:
            kwargs.setdefault('tag', None)
        if event_name == 'ScalarEvent':
            kwargs.setdefault('value', '')

        event_class = getattr(yaml, event_name)
        return event_class(**kwargs)

# Register construct_event under the tag key None, presumably as the fallback
# used for every tag in the .events files (confirm against Constructor).
EventsConstructor.add_constructor(None, EventsConstructor.construct_event)

class TestEmitter(test_appliance.TestAppliance):
    """Round-trip hand-written event streams (``.events`` files) through Emitter."""

    # Set to True to echo each .events file and its emitted YAML to stdout
    # while debugging; off by default so normal test runs stay quiet.
    dump_output = False

    def _testEmitter(self, test_name, events_filename):
        """Emit the events described in *events_filename* and re-parse them.

        The events are loaded with EventsConstructor, emitted to an in-memory
        buffer, and parsed back.  The re-parsed stream must have the same
        number of events, with the same class at every position.
        """
        stream = open(events_filename, 'rb')
        try:
            events = load_document(stream, Constructor=EventsConstructor)
        finally:
            stream.close()
        if self.dump_output:
            self._dump(events_filename, events)
        writer = StringIO.StringIO()
        emitter = Emitter(writer)
        for event in events:
            emitter.emit(event)
        data = writer.getvalue()
        new_events = list(parse(data))
        self.failUnlessEqual(len(events), len(new_events))
        for event, new_event in zip(events, new_events):
            self.failUnlessEqual(event.__class__, new_event.__class__)

    def _dump(self, events_filename, events):
        """Debug aid: write the raw .events file and the emitted YAML to stdout."""
        stream = open(events_filename, 'rb')
        try:
            source = stream.read()
        finally:
            stream.close()
        out = sys.stdout
        out.write("=" * 30 + "\n")
        out.write("EVENTS:\n")
        out.write(source)
        out.write("\n")
        out.write("-" * 30 + "\n")
        out.write("OUTPUT:\n")
        emitter = Emitter(out)
        for event in events:
            emitter.emit(event)

# Register _testEmitter against the '.events' test data files
# (the generated test names come from test_appliance.add_tests).
TestEmitter.add_tests('testEmitter', '.events')