summaryrefslogtreecommitdiff
path: root/_test/roundtrip.py
blob: ac41e7b0b8d747ebcf4fb90b8031f154a17ee6e5 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117

from __future__ import print_function

"""
helper routines for testing round trip of commented YAML data
"""
import textwrap

import ruamel.yaml
from ruamel.yaml.compat import StringIO, BytesIO  # NOQA


def dedent(data):
    """Strip a blank first line from `data`, then remove common indentation.

    If the text up to the first newline consists only of whitespace (or is
    empty), that line and its newline are dropped. Input without any newline
    is left intact before being handed to textwrap.dedent.
    """
    first_line, newline, remainder = data.partition('\n')
    # `newline` is empty when data holds no '\n'; in that case nothing is cut
    if newline and all(char.isspace() for char in first_line):
        data = remainder
    return textwrap.dedent(data)


def round_trip_load(inp, preserve_quotes=None, version=None):
    """Dedent `inp` and load it through ruamel.yaml's RoundTripLoader.

    preserve_quotes and version are forwarded unchanged to ruamel.yaml.load;
    the result of that call is returned as-is.
    """
    text = dedent(inp)
    load_options = {
        'Loader': ruamel.yaml.RoundTripLoader,
        'preserve_quotes': preserve_quotes,
        'version': version,
    }
    return ruamel.yaml.load(text, **load_options)


def round_trip_load_all(inp, preserve_quotes=None, version=None):
    """Dedent `inp` and load every document in it via the RoundTripLoader.

    preserve_quotes and version are forwarded unchanged to
    ruamel.yaml.load_all; whatever that call returns is handed back untouched.
    """
    text = dedent(inp)
    load_options = {
        'Loader': ruamel.yaml.RoundTripLoader,
        'preserve_quotes': preserve_quotes,
        'version': version,
    }
    return ruamel.yaml.load_all(text, **load_options)


def round_trip_dump(data, indent=None, block_seq_indent=None, top_level_colon_align=None,
                    prefix_colon=None, explicit_start=None, explicit_end=None, version=None):
    """Dump `data` with ruamel.yaml.round_trip_dump and return its result.

    Every keyword argument is passed on under the same name, so None values
    reach ruamel.yaml unchanged.
    """
    dump_options = {
        'indent': indent,
        'block_seq_indent': block_seq_indent,
        'top_level_colon_align': top_level_colon_align,
        'prefix_colon': prefix_colon,
        'explicit_start': explicit_start,
        'explicit_end': explicit_end,
        'version': version,
    }
    return ruamel.yaml.round_trip_dump(data, **dump_options)


def round_trip(inp, outp=None, extra=None, intermediate=None, indent=None,
               block_seq_indent=None, top_level_colon_align=None, prefix_colon=None,
               preserve_quotes=None,
               explicit_start=None, explicit_end=None,
               version=None):
    """
    Load `inp`, dump it again and assert the dump matches the expected text.

    inp:           input string to parse (dedented before loading)
    outp:          expected output (equals input if not specified), dedented
    extra:         text appended to the dedented expected output, if given
    intermediate:  optional dict; each key is looked up in the loaded data and
                   must equal the associated value, otherwise the mismatch is
                   printed and ValueError is raised
    preserve_quotes: forwarded to round_trip_load
    all remaining keyword arguments are forwarded to round_trip_dump

    The loaded data is dumped twice (the same object, not a reload of the
    first dump); each dump is printed and asserted equal to the expectation.
    """
    expected = dedent(inp if outp is None else outp)
    if extra is not None:
        expected += extra

    data = round_trip_load(inp, preserve_quotes=preserve_quotes)

    # only dict-valued `intermediate` is checked; any other value is ignored
    if intermediate is not None and isinstance(intermediate, dict):
        for key, wanted in intermediate.items():
            if data[key] != wanted:
                print('{0!r} <> {1!r}'.format(data[key], wanted))
                raise ValueError

    # both dumps use exactly the same options
    dump_options = dict(
        indent=indent,
        block_seq_indent=block_seq_indent,
        top_level_colon_align=top_level_colon_align,
        prefix_colon=prefix_colon,
        explicit_start=explicit_start,
        explicit_end=explicit_end,
        version=version,
    )

    first_dump = round_trip_dump(data, **dump_options)
    print('roundtrip data:\n', first_dump, sep='')
    assert first_dump == expected

    second_dump = round_trip_dump(data, **dump_options)
    print('roundtrip second round data:\n', second_dump, sep='')
    assert second_dump == expected


class YAML(ruamel.yaml.YAML):
    """ruamel.yaml.YAML subclass for tests.

    load() auto-dedents string parameters; dump() either writes to a given
    stream or compares the dumped text against an expected string.
    """

    def load(self, stream):
        """Load `stream`, dedenting it first when it is a str.

        A single leading newline is dropped before dedenting, so a
        triple-quoted literal may start on the line after its opening quotes.
        Anything that is not a str is passed to the base class unchanged.
        """
        if isinstance(stream, str):
            if stream and stream[0] == '\n':
                stream = stream[1:]
            stream = textwrap.dedent(stream)
        return ruamel.yaml.YAML.load(self, stream)

    def dump(self, data, **kw):
        """Dump `data` to a stream, or check the dump against expected text.

        Exactly one of the keyword arguments `stream` and `compare` must be
        given (asserted):

        stream:   dump to that stream via the base class and return its result
        compare:  expected output; it is dedented and a single leading newline
                  is removed, then the dump is captured in memory, printed and
                  asserted equal to it

        Any other keyword argument is forwarded to the base class dump.
        """
        assert ('stream' in kw) ^ ('compare' in kw)
        if 'stream' in kw:
            # the base method is called through the class, so `self` has to be
            # passed explicitly; without it `data` would be bound as `self`
            return ruamel.yaml.YAML.dump(self, data, **kw)
        lkw = kw.copy()
        expected = textwrap.dedent(lkw.pop('compare'))
        if expected and expected[0] == '\n':
            expected = expected[1:]
        lkw['stream'] = st = StringIO() if self.encoding is None else BytesIO()
        ruamel.yaml.YAML.dump(self, data, **lkw)
        res = st.getvalue()
        # a BytesIO buffer yields bytes, which on Python 3 never compare equal
        # to the text expectation; decode so the comparison is meaningful
        if isinstance(res, bytes) and not isinstance(expected, bytes):
            res = res.decode(self.encoding)
        print(res)
        assert res == expected