summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--_test/roundtrip.py49
-rw-r--r--_test/test_issues.py11
-rw-r--r--emitter.py4
-rw-r--r--main.py5
-rw-r--r--parser.py9
5 files changed, 78 insertions, 0 deletions
diff --git a/_test/roundtrip.py b/_test/roundtrip.py
index ee430a3..090b9e0 100644
--- a/_test/roundtrip.py
+++ b/_test/roundtrip.py
@@ -151,6 +151,55 @@ def round_trip(
return data
+def na_round_trip(
+ inp,
+ outp=None,
+ extra=None,
+ intermediate=None,
+ indent=None,
+ top_level_colon_align=None,
+ prefix_colon=None,
+ preserve_quotes=None,
+ explicit_start=None,
+ explicit_end=None,
+ version=None,
+ dump_data=None,
+):
+ """
+ inp: input string to parse
+ outp: expected output (equals input if not specified)
+ """
+ if outp is None:
+ outp = inp
+ if version is not None:
+ version = version
+ doutp = dedent(outp)
+ if extra is not None:
+ doutp += extra
+ yaml = YAML()
+ yaml.preserve_quotes = preserve_quotes
+ yaml.scalar_after_indicator = False # newline after every directives end
+ data = yaml.load(inp)
+ if dump_data:
+ print('data', data)
+ if intermediate is not None:
+ if isinstance(intermediate, dict):
+ for k, v in intermediate.items():
+ if data[k] != v:
+ print('{0!r} <> {1!r}'.format(data[k], v))
+ raise ValueError
+ yaml.indent = indent
+ yaml.top_level_colon_align = top_level_colon_align
+ yaml.prefix_colon = prefix_colon
+ yaml.explicit_start = explicit_start
+ yaml.explicit_end = explicit_end
+ res = yaml.dump(data, compare=doutp)
+ #if res != doutp:
+ # diff(doutp, res, 'input string')
+ #print('\nroundtrip data:\n', res, sep="")
+ #assert res == doutp
+
+
def YAML(**kw):
import ruamel.yaml # NOQA
diff --git a/_test/test_issues.py b/_test/test_issues.py
index 3692f61..9b301a9 100644
--- a/_test/test_issues.py
+++ b/_test/test_issues.py
@@ -8,6 +8,7 @@ import pytest # NOQA
from roundtrip import (
round_trip,
+ na_round_trip,
round_trip_load,
round_trip_dump,
dedent,
@@ -850,6 +851,16 @@ class TestIssues:
match='while scanning a directive'):
yaml.load(inp)
+ def test_issue_304(self):
+ inp = """
+ %YAML 1.2
+ %TAG ! tag:example.com,2019:
+ ---
+ !foo null
+ ...
+ """
+ d = na_round_trip(inp) # NOQA
+
# @pytest.mark.xfail(strict=True, reason='bla bla', raises=AssertionError)
# def test_issue_ xxx(self):
# inp = """
diff --git a/emitter.py b/emitter.py
index 14277fc..7cdebcd 100644
--- a/emitter.py
+++ b/emitter.py
@@ -204,6 +204,8 @@ class Emitter(object):
self.analysis = None # type: Any
self.style = None # type: Any
+ self.scalar_after_indicator = True # write a scalar on the same line as `---`
+
@property
def stream(self):
# type: () -> Any
@@ -408,6 +410,8 @@ class Emitter(object):
and self.sequence_context
):
self.sequence_context = False
+ if root and isinstance(self.event, ScalarEvent) and not self.scalar_after_indicator:
+ self.write_indent()
self.process_tag()
if isinstance(self.event, ScalarEvent):
# nprint('@', self.indention, self.no_newline, self.column)
diff --git a/main.py b/main.py
index 184c7ad..0785c95 100644
--- a/main.py
+++ b/main.py
@@ -161,6 +161,7 @@ class YAML(object):
self.tags = None
self.default_style = None
self.top_level_block_style_scalar_no_indent_error_1_1 = False
+ self.scalar_after_indicator = None # directives end indicator with single scalar document
# [a, b: 1, c: {d: 2}] vs. [a, {b: 1}, {c: {d: 2}}]
self.brace_single_entry_mapping_in_flow_sequence = False
@@ -518,12 +519,16 @@ class YAML(object):
self.Serializer = ruamel.yaml.serializer.Serializer
self.emitter.stream = stream
self.emitter.top_level_colon_align = tlca
+ if self.scalar_after_indicator is not None:
+ self.emitter.scalar_after_indicator = self.scalar_after_indicator
return self.serializer, self.representer, self.emitter
if self.Serializer is not None:
# cannot set serializer with CEmitter
self.Emitter = ruamel.yaml.emitter.Emitter
self.emitter.stream = stream
self.emitter.top_level_colon_align = tlca
+ if self.scalar_after_indicator is not None:
+ self.emitter.scalar_after_indicator = self.scalar_after_indicator
return self.serializer, self.representer, self.emitter
# C routines
diff --git a/parser.py b/parser.py
index 9793b42..566d13a 100644
--- a/parser.py
+++ b/parser.py
@@ -219,6 +219,9 @@ class Parser(object):
)
token = self.scanner.get_token()
end_mark = token.end_mark
+ # if self.loader is not None and \
+ # end_mark.line != self.scanner.peek_token().start_mark.line:
+ # self.loader.scalar_after_indicator = False
event = DocumentStartEvent(
start_mark, end_mark, explicit=True, version=version, tags=tags
) # type: Any
@@ -295,6 +298,12 @@ class Parser(object):
value = self.yaml_version, self.tag_handles.copy() # type: Any
else:
value = self.yaml_version, None
+ if self.loader is not None and hasattr(self.loader, 'tags'):
+ self.loader.version = self.yaml_version
+ if self.loader.tags is None:
+ self.loader.tags = {}
+ for k in self.tag_handles:
+ self.loader.tags[k] = self.tag_handles[k]
for key in self.DEFAULT_TAGS:
if key not in self.tag_handles:
self.tag_handles[key] = self.DEFAULT_TAGS[key]