From 9ac44a0873d51d63150b0f1dc1d009b206577a29 Mon Sep 17 00:00:00 2001 From: Anthon van der Neut Date: Tue, 21 Mar 2017 17:18:18 +0100 Subject: update for mypy --strict, prepare de-inheritance (Loader/Dumper) --- .hgignore | 2 +- Makefile | 18 +- README.rst | 8 + __init__.py | 6 +- _test/lib/test_emitter.py | 3 + _test/test_add_xxx.py | 107 +++++++ _test/test_yamlobject.py | 38 +++ comments.py | 8 +- compat.py | 19 +- composer.py | 78 +++-- constructor.py | 100 +++--- cyaml.py | 55 +++- dumper.py | 52 +++- emitter.py | 170 +++++++--- error.py | 18 +- events.py | 11 + loader.py | 57 ++-- main.py | 241 ++++++++++---- nodes.py | 7 + parser.py | 295 ++++++++++-------- reader.py | 40 ++- representer.py | 122 ++++++-- resolver.py | 77 +++-- scalarstring.py | 17 +- scanner.py | 778 ++++++++++++++++++++++++---------------------- serializer.py | 80 +++-- timestamp.py | 9 +- tokens.py | 17 + tox.ini | 2 +- util.py | 4 +- 30 files changed, 1565 insertions(+), 874 deletions(-) create mode 100644 _test/test_add_xxx.py create mode 100644 _test/test_yamlobject.py diff --git a/.hgignore b/.hgignore index da8328e..0a6f71f 100644 --- a/.hgignore +++ b/.hgignore @@ -26,4 +26,4 @@ cmd TODO.rst _doc/_build .dcw_alt.yml -tags_for_scalarstrings +try_* diff --git a/Makefile b/Makefile index 27b7895..fb69498 100644 --- a/Makefile +++ b/Makefile @@ -11,7 +11,7 @@ gen_win_whl: python2 setup.py bdist_wheel --plat-name win_amd64 python3 setup.py bdist_wheel --plat-name win32 python3 setup.py bdist_wheel --plat-name win_amd64 - #@python make_win_whl.py dist/$(PKGNAME)-$(VERSION)-*-none-any.whl + # @python make_win_whl.py dist/$(PKGNAME)-$(VERSION)-*-none-any.whl clean: clean_common find . -name "*py.class" -exec rm {} + @@ -20,15 +20,27 @@ cython: ext/_yaml.c ext/_yaml.c: ext/_yaml.pyx cd ext; cython _yaml.pyx - + ls-l: ls -l dist/*$(VERSION)* +pytest: + py.test _test/*.py + +MYPYSRC:=$(shell ls -1 *.py | grep -Ev "^(setup.py|.*_flymake.py)$$" | sed 's|^|ruamel/yaml/|') +MYPYOPT:=--py2 --strict + mypy: cd ..; mypy --strict --no-warn-unused-ignores yaml/*.py +# sleep to give time to flymake*.py to disappear mypy2: - cd ../.. ; mypy --py2 --strict --no-strict-boolean --no-warn-unused-ignores ruamel/yaml/*.py + cd ../.. ; mypy $(MYPYOPT) $(MYPYSRC) + +mypy2single: + @echo 'mypy *.py' + @cd ../.. ; mypy $(MYPYOPT) $(MYPYSRC) | fgrep -v ordereddict/__init | grep . +# @echo 'mypy ' $(MYPYSRC) #tstvenv: testvenv testsetup testtest # diff --git a/README.rst b/README.rst index 8b990b0..50f32e2 100644 --- a/README.rst +++ b/README.rst @@ -18,6 +18,14 @@ ChangeLog .. should insert NEXT: at the beginning of line for next key +0.14.0 (2017-03-21): + - updates for mypy --strict + - preparation for moving away from inheritance in Loader and Dumper, calls from e.g. + the Representer to the Serializer.serialize() are now done via the attribute + .serializer.serialize(). Usage of .serialize() outside of Serializer will be + deprecated soon + - some extra tests on main.py functions + 0.13.14 (2017-02-12): - fix for issue 97: clipped block scalar followed by empty lines and comment would result in two CommentTokens of which the first was dropped. diff --git a/__init__.py b/__init__.py index 64b2f5e..b255c7a 100644 --- a/__init__.py +++ b/__init__.py @@ -10,8 +10,8 @@ from typing import Dict, Any # NOQA _package_data = dict( full_package_name='ruamel.yaml', - version_info=(0, 13, 15), - __version__='0.13.15', + version_info=(0, 14, 0), + __version__='0.14.0', author='Anthon van der Neut', author_email='a.van.der.neut@ruamel.eu', description='ruamel.yaml is a YAML parser/emitter that supports roundtrip preservation of comments, seq/map flow style, and map key order', # NOQA @@ -51,7 +51,7 @@ _package_data = dict( read_the_docs='yaml', many_linux='libyaml-devel', supported=[(2, 7), (3, 3)], # minimum -) # type: Dict[Any, Any] +) # type: Dict[Any, Any] version_info = _package_data['version_info'] diff --git a/_test/lib/test_emitter.py b/_test/lib/test_emitter.py index 1158854..7fff498 100644 --- a/_test/lib/test_emitter.py +++ b/_test/lib/test_emitter.py @@ -95,6 +95,9 @@ class EventsLoader(yaml.Loader): value = getattr(yaml, class_name)(**mapping) return value +# if Loader is not a composite, add this function +# EventsLoader.add_constructor = yaml.constructor.Constructor.add_constructor + EventsLoader.add_constructor(None, EventsLoader.construct_event) diff --git a/_test/test_add_xxx.py b/_test/test_add_xxx.py new file mode 100644 index 0000000..c283b44 --- /dev/null +++ b/_test/test_add_xxx.py @@ -0,0 +1,107 @@ +# coding: utf-8 + +import re +import pytest # NOQA +import ruamel.yaml + +from roundtrip import round_trip, dedent, round_trip_load, round_trip_dump # NOQA + + +# from PyYAML docs +class Dice(tuple): + def __new__(cls, a, b): + return tuple.__new__(cls, [a, b]) + + def __repr__(self): + return "Dice(%s,%s)" % self + + +def dice_constructor(loader, node): + value = loader.construct_scalar(node) + a, b = map(int, value.split('d')) + return Dice(a, b) + + +def dice_representer(dumper, data): + return dumper.represent_scalar(u'!dice', u'{}d{}'.format(*data)) + + +def test_dice_constructor(): + ruamel.yaml.add_constructor(u'!dice', dice_constructor) + data = ruamel.yaml.load('initial hit points: !dice 8d4', Loader=ruamel.yaml.Loader) + assert str(data) == "{'initial hit points': Dice(8,4)}" + + +def test_dice_constructor_with_loader(): + ruamel.yaml.add_constructor(u'!dice', dice_constructor, Loader=ruamel.yaml.Loader) + data = ruamel.yaml.load('initial hit points: !dice 8d4', Loader=ruamel.yaml.Loader) + assert str(data) == "{'initial hit points': Dice(8,4)}" + + +def test_dice_representer(): + ruamel.yaml.add_representer(Dice, dice_representer) + assert ruamel.yaml.dump(dict(gold=Dice(10, 6)), default_flow_style=False) == \ + "gold: !dice '10d6'\n" + + +def test_dice_implicit_resolver(): + pattern = re.compile(r'^\d+d\d+$') + ruamel.yaml.add_implicit_resolver(u'!dice', pattern) + assert ruamel.yaml.dump(dict(treasure=Dice(10, 20)), default_flow_style=False) == \ + 'treasure: 10d20\n' + assert ruamel.yaml.load('damage: 5d10', Loader=ruamel.yaml.Loader) == \ + dict(damage=Dice(5, 10)) + + +class Obj1(dict): + def __init__(self, suffix): + self._suffix = suffix + self._node = None + + def add_node(self, n): + self._node = n + + def __repr__(self): + return 'Obj1(%s->%s)' % (self._suffix, self.items()) + + def dump(self): + return repr(self._node) + + +class YAMLObj1(object): + yaml_tag = u'!obj:' + + @classmethod + def from_yaml(cls, loader, suffix, node): + obj1 = Obj1(suffix) + if isinstance(node, ruamel.yaml.MappingNode): + obj1.add_node(loader.construct_mapping(node)) + else: + raise NotImplementedError + return obj1 + + @classmethod + def to_yaml(cls, dumper, data): + return dumper.represent_scalar(cls.yaml_tag + data._suffix, data.dump()) + + +def test_yaml_obj(): + ruamel.yaml.add_representer(Obj1, YAMLObj1.to_yaml) + ruamel.yaml.add_multi_constructor(YAMLObj1.yaml_tag, YAMLObj1.from_yaml) + x = ruamel.yaml.load('!obj:x.2\na: 1', Loader=ruamel.yaml.Loader) + print(x) + assert ruamel.yaml.dump(x) == '''!obj:x.2 "{'a': 1}"\n''' + + +def test_yaml_obj_with_loader_and_dumper(): + ruamel.yaml.add_representer(Obj1, YAMLObj1.to_yaml, Dumper=ruamel.yaml.Dumper) + ruamel.yaml.add_multi_constructor(YAMLObj1.yaml_tag, YAMLObj1.from_yaml, + Loader=ruamel.yaml.Loader) + x = ruamel.yaml.load('!obj:x.2\na: 1', Loader=ruamel.yaml.Loader) + # x = ruamel.yaml.load('!obj:x.2\na: 1') + print(x) + assert ruamel.yaml.dump(x) == '''!obj:x.2 "{'a': 1}"\n''' + + +# ToDo use nullege to search add_multi_representer and add_path_resolver +# and add some test code diff --git a/_test/test_yamlobject.py b/_test/test_yamlobject.py new file mode 100644 index 0000000..a5e06d8 --- /dev/null +++ b/_test/test_yamlobject.py @@ -0,0 +1,38 @@ +# coding: utf-8 + +import pytest # NOQA +import ruamel.yaml + +from roundtrip import round_trip, dedent, round_trip_load, round_trip_dump # NOQA + + +class Monster(ruamel.yaml.YAMLObject): + yaml_tag = u'!Monster' + + def __init__(self, name, hp, ac, attacks): + self.name = name + self.hp = hp + self.ac = ac + self.attacks = attacks + + def __repr__(self): + return "%s(name=%r, hp=%r, ac=%r, attacks=%r)" % ( + self.__class__.__name__, self.name, self.hp, self.ac, self.attacks) + + +def test_monster(): + data = ruamel.yaml.load(dedent("""\ + --- !Monster + name: Cave spider + hp: [2,6] # 2d6 + ac: 16 + attacks: [BITE, HURT] + """), Loader=ruamel.yaml.Loader) + # normal dump, keys will be sorted + assert ruamel.yaml.dump(data) == dedent("""\ + !Monster + ac: 16 + attacks: [BITE, HURT] + hp: [2, 6] + name: Cave spider + """) diff --git a/comments.py b/comments.py index d77d72a..1aa6222 100644 --- a/comments.py +++ b/comments.py @@ -11,7 +11,7 @@ a separate base from typing import Any, Dict, Optional, List, Union # NOQA import copy -from collections import MutableSet, Sized, Set # type: ignore +from collections import MutableSet, Sized, Set from ruamel.yaml.compat import ordereddict, PY2 @@ -45,7 +45,7 @@ class Comment(object): def __str__(self): # type: () -> str - if self._end: + if bool(self._end): end = ',\n end=' + str(self._end) else: end = '' @@ -875,11 +875,11 @@ class CommentedSet(MutableSet, CommentedMap): __slots__ = Comment.attrib, 'odict', def __init__(self, values=None): - # type: (Optional[Any]) -> None + # type: (Any) -> None self.odict = ordereddict() MutableSet.__init__(self) if values is not None: - self |= values + self |= values # type: ignore def add(self, value): # type: (Any) -> None diff --git a/compat.py b/compat.py index d3529d4..54db14d 100644 --- a/compat.py +++ b/compat.py @@ -8,11 +8,11 @@ import sys import os import types -from typing import Any, Dict, Optional, List, Union, BinaryIO, IO, Text # NOQA +from typing import Any, Dict, Optional, List, Union, BinaryIO, IO, Text, Tuple # NOQA try: - from ruamel.ordereddict import ordereddict # type: ignore + from ruamel.ordereddict import ordereddict except: try: from collections import OrderedDict @@ -68,9 +68,9 @@ else: return unicode(s) if PY3: - string_types = str, - integer_types = int, - class_types = type, + string_types = str + integer_types = int + class_types = type text_type = str binary_type = bytes @@ -81,7 +81,7 @@ if PY3: BytesIO = io.BytesIO else: - string_types = basestring, + string_types = basestring integer_types = (int, long) class_types = (type, types.ClassType) text_type = unicode @@ -94,8 +94,11 @@ else: import cStringIO BytesIO = cStringIO.StringIO +# StreamType = Union[BinaryIO, IO[str], IO[unicode], StringIO] StreamType = Union[BinaryIO, IO[str], StringIO] + StreamTextType = Union[Text, StreamType] +VersionType = Union[List[int], str, Tuple[int, int]] if PY3: builtins_module = 'builtins' @@ -115,7 +118,7 @@ DBG_NODE = 4 _debug = None # type: Union[None, int] -if _debug: +if bool(_debug): class ObjectCounter(object): def __init__(self): # type: () -> None @@ -151,7 +154,7 @@ def dbg(val=None): def nprint(*args, **kw): # type: (Any, Any) -> None - if dbg: + if bool(dbg): print(*args, **kw) # char checkers following production rules diff --git a/composer.py b/composer.py index bbaa62f..9208bdc 100644 --- a/composer.py +++ b/composer.py @@ -26,73 +26,87 @@ class ComposerError(MarkedYAMLError): class Composer(object): - def __init__(self): - # type: () -> None + def __init__(self, loader=None): + # type: (Any) -> None + self.loader = loader + if self.loader is not None: + self.loader._composer = self self.anchors = {} # type: Dict[Any, Any] + @property + def parser(self): + # type: () -> Any + return self.loader._parser + + @property + def resolver(self): + # type: () -> Any + # assert self.loader._resolver is not None + return self.loader._resolver + def check_node(self): # type: () -> Any # Drop the STREAM-START event. - if self.check_event(StreamStartEvent): - self.get_event() + if self.parser.check_event(StreamStartEvent): + self.parser.get_event() # If there are more documents available? - return not self.check_event(StreamEndEvent) + return not self.parser.check_event(StreamEndEvent) def get_node(self): # type: () -> Any # Get the root node of the next document. - if not self.check_event(StreamEndEvent): + if not self.parser.check_event(StreamEndEvent): return self.compose_document() def get_single_node(self): # type: () -> Any # Drop the STREAM-START event. - self.get_event() + self.parser.get_event() # Compose a document if the stream is not empty. document = None - if not self.check_event(StreamEndEvent): + if not self.parser.check_event(StreamEndEvent): document = self.compose_document() # Ensure that the stream contains no more documents. - if not self.check_event(StreamEndEvent): - event = self.get_event() + if not self.parser.check_event(StreamEndEvent): + event = self.parser.get_event() raise ComposerError( "expected a single document in the stream", document.start_mark, "but found another document", event.start_mark) # Drop the STREAM-END event. - self.get_event() + self.parser.get_event() return document def compose_document(self): # type: (Any) -> Any # Drop the DOCUMENT-START event. - self.get_event() + self.parser.get_event() # Compose the root node. node = self.compose_node(None, None) # Drop the DOCUMENT-END event. - self.get_event() + self.parser.get_event() self.anchors = {} return node def compose_node(self, parent, index): # type: (Any, Any) -> Any - if self.check_event(AliasEvent): - event = self.get_event() + if self.parser.check_event(AliasEvent): + event = self.parser.get_event() alias = event.anchor if alias not in self.anchors: raise ComposerError( None, None, "found undefined alias %r" % utf8(alias), event.start_mark) return self.anchors[alias] - event = self.peek_event() + event = self.parser.peek_event() anchor = event.anchor if anchor is not None: # have an anchor if anchor in self.anchors: @@ -104,22 +118,22 @@ class Composer(object): "{}".format( (anchor), self.anchors[anchor].start_mark, event.start_mark) warnings.warn(ws, ReusedAnchorWarning) - self.descend_resolver(parent, index) - if self.check_event(ScalarEvent): + self.resolver.descend_resolver(parent, index) + if self.parser.check_event(ScalarEvent): node = self.compose_scalar_node(anchor) - elif self.check_event(SequenceStartEvent): + elif self.parser.check_event(SequenceStartEvent): node = self.compose_sequence_node(anchor) - elif self.check_event(MappingStartEvent): + elif self.parser.check_event(MappingStartEvent): node = self.compose_mapping_node(anchor) - self.ascend_resolver() + self.resolver.ascend_resolver() return node def compose_scalar_node(self, anchor): # type: (Any) -> Any - event = self.get_event() + event = self.parser.get_event() tag = event.tag if tag is None or tag == u'!': - tag = self.resolve(ScalarNode, event.value, event.implicit) + tag = self.resolver.resolve(ScalarNode, event.value, event.implicit) node = ScalarNode(tag, event.value, event.start_mark, event.end_mark, style=event.style, comment=event.comment) @@ -129,10 +143,10 @@ class Composer(object): def compose_sequence_node(self, anchor): # type: (Any) -> Any - start_event = self.get_event() + start_event = self.parser.get_event() tag = start_event.tag if tag is None or tag == u'!': - tag = self.resolve(SequenceNode, None, start_event.implicit) + tag = self.resolver.resolve(SequenceNode, None, start_event.implicit) node = SequenceNode(tag, [], start_event.start_mark, None, flow_style=start_event.flow_style, @@ -140,10 +154,10 @@ class Composer(object): if anchor is not None: self.anchors[anchor] = node index = 0 - while not self.check_event(SequenceEndEvent): + while not self.parser.check_event(SequenceEndEvent): node.value.append(self.compose_node(node, index)) index += 1 - end_event = self.get_event() + end_event = self.parser.get_event() if node.flow_style is True and end_event.comment is not None: if node.comment is not None: print('Warning: unexpected end_event commment in sequence ' @@ -155,18 +169,18 @@ class Composer(object): def compose_mapping_node(self, anchor): # type: (Any) -> Any - start_event = self.get_event() + start_event = self.parser.get_event() tag = start_event.tag if tag is None or tag == u'!': - tag = self.resolve(MappingNode, None, start_event.implicit) + tag = self.resolver.resolve(MappingNode, None, start_event.implicit) node = MappingNode(tag, [], start_event.start_mark, None, flow_style=start_event.flow_style, comment=start_event.comment, anchor=anchor) if anchor is not None: self.anchors[anchor] = node - while not self.check_event(MappingEndEvent): - # key_event = self.peek_event() + while not self.parser.check_event(MappingEndEvent): + # key_event = self.parser.peek_event() item_key = self.compose_node(node, None) # if item_key in node.value: # raise ComposerError("while composing a mapping", @@ -175,7 +189,7 @@ class Composer(object): item_value = self.compose_node(node, item_key) # node.value[item_key] = item_value node.value.append((item_key, item_value)) - end_event = self.get_event() + end_event = self.parser.get_event() if node.flow_style is True and end_event.comment is not None: node.comment = end_event.comment node.end_mark = end_event.end_mark diff --git a/constructor.py b/constructor.py index 2bcb3bd..6797d95 100644 --- a/constructor.py +++ b/constructor.py @@ -10,7 +10,7 @@ import re import sys import types -from typing import Any, Dict, List # NOQA +from typing import Any, Dict, List, Set, Generator # NOQA from ruamel.yaml.error import (MarkedYAMLError) @@ -39,29 +39,49 @@ class BaseConstructor(object): yaml_constructors = {} # type: Dict[Any, Any] yaml_multi_constructors = {} # type: Dict[Any, Any] - def __init__(self, preserve_quotes=None): - # type: (bool) -> None + def __init__(self, preserve_quotes=None, loader=None): + # type: (bool, Any) -> None + self.loader = loader + if self.loader is not None: + self.loader._constructor = self + self.loader = loader self.constructed_objects = {} # type: Dict[Any, Any] self.recursive_objects = {} # type: Dict[Any, Any] self.state_generators = [] # type: List[Any] self.deep_construct = False self._preserve_quotes = preserve_quotes + @property + def composer(self): + # type: () -> Any + try: + return self.loader._composer + except AttributeError: + print('slt', type(self)) + print('slc', self.loader._composer) + print(dir(self)) + raise + + @property + def resolver(self): + # type: () -> Any + return self.loader._resolver + def check_data(self): # type: () -> Any # If there are more documents available? - return self.check_node() + return self.composer.check_node() def get_data(self): # type: () -> Any # Construct and return the next document. - if self.check_node(): - return self.construct_document(self.get_node()) + if self.composer.check_node(): + return self.construct_document(self.composer.get_node()) def get_single_data(self): # type: () -> Any # Ensure that the stream contains a single document and construct it. - node = self.get_single_node() + node = self.composer.get_single_node() if node is not None: return self.construct_document(node) return None @@ -69,7 +89,7 @@ class BaseConstructor(object): def construct_document(self, node): # type: (Any) -> Any data = self.construct_object(node) - while self.state_generators: + while bool(self.state_generators): state_generators = self.state_generators self.state_generators = [] for generator in state_generators: @@ -112,20 +132,20 @@ class BaseConstructor(object): elif None in self.yaml_constructors: constructor = self.yaml_constructors[None] elif isinstance(node, ScalarNode): - constructor = self.__class__.construct_scalar + constructor = self.__class__.construct_scalar # type: ignore elif isinstance(node, SequenceNode): - constructor = self.__class__.construct_sequence + constructor = self.__class__.construct_sequence # type: ignore elif isinstance(node, MappingNode): - constructor = self.__class__.construct_mapping + constructor = self.__class__.construct_mapping # type: ignore if tag_suffix is None: data = constructor(self, node) else: data = constructor(self, tag_suffix, node) if isinstance(data, types.GeneratorType): generator = data - data = next(generator) + data = next(generator) # type: ignore if self.deep_construct: - for dummy in generator: + for dummy in generator: # type: ignore pass else: self.state_generators.append(generator) @@ -172,7 +192,7 @@ class BaseConstructor(object): # keys can be list -> deep key = self.construct_object(key_node, deep=True) # lists are not hashable, but tuples are - if not isinstance(key, collections.Hashable): + if not isinstance(key, collections.Hashable): # type: ignore if isinstance(key, list): key = tuple(key) if PY2: @@ -238,7 +258,7 @@ class SafeConstructor(BaseConstructor): by inserting keys from the merge dict/list of dicts if not yet available in this node """ - merge = [] + merge = [] # type: List[Any] index = 0 while index < len(node.value): key_node, value_node = node.value[index] @@ -272,7 +292,7 @@ class SafeConstructor(BaseConstructor): index += 1 else: index += 1 - if merge: + if bool(merge): node.value = merge + node.value def construct_mapping(self, node, deep=False): @@ -321,9 +341,9 @@ class SafeConstructor(BaseConstructor): return sign*int(value_s[2:], 16) elif value_s.startswith('0o'): return sign*int(value_s[2:], 8) - elif self.processing_version != (1, 2) and value_s[0] == '0': + elif self.resolver.processing_version != (1, 2) and value_s[0] == '0': return sign*int(value_s, 8) - elif self.processing_version != (1, 2) and ':' in value_s: + elif self.resolver.processing_version != (1, 2) and ':' in value_s: digits = [int(part) for part in value_s.split(':')] digits.reverse() base = 1 @@ -438,7 +458,7 @@ class SafeConstructor(BaseConstructor): delta = -delta data = datetime.datetime(year, month, day, hour, minute, second, fraction) - if delta: + if delta: # type: ignore data -= delta return data @@ -499,7 +519,7 @@ class SafeConstructor(BaseConstructor): def construct_yaml_set(self, node): # type: (Any) -> Any - data = set() + data = set() # type: Set[Any] yield data value = self.construct_mapping(node) data.update(value) @@ -516,13 +536,13 @@ class SafeConstructor(BaseConstructor): def construct_yaml_seq(self, node): # type: (Any) -> Any - data = [] + data = [] # type: List[Any] yield data data.extend(self.construct_sequence(node)) def construct_yaml_map(self, node): # type: (Any) -> Any - data = {} + data = {} # type: Dict[Any, Any] yield data value = self.construct_mapping(node) data.update(value) @@ -597,6 +617,10 @@ SafeConstructor.add_constructor( SafeConstructor.add_constructor( None, SafeConstructor.construct_undefined) +if PY2: + class classobj: + pass + class Constructor(SafeConstructor): @@ -702,10 +726,6 @@ class Constructor(SafeConstructor): node.start_mark) return self.find_python_module(suffix, node.start_mark) - if PY2: - class classobj: - pass - def make_python_instance(self, suffix, node, args=None, kwds=None, newobj=False): # type: (Any, Any, Any, Any, bool) -> Any @@ -720,9 +740,9 @@ class Constructor(SafeConstructor): else: return cls(*args, **kwds) else: - if newobj and isinstance(cls, type(self.classobj)) \ + if newobj and isinstance(cls, type(classobj)) \ and not args and not kwds: - instance = self.classobj() + instance = classobj() instance.__class__ = cls return instance elif newobj and isinstance(cls, type): @@ -772,7 +792,7 @@ class Constructor(SafeConstructor): args = self.construct_sequence(node, deep=True) kwds = {} # type: Dict[Any, Any] state = {} # type: Dict[Any, Any] - listitems = [] # List[Any] + listitems = [] # type: List[Any] dictitems = {} # type: Dict[Any, Any] else: value = self.construct_mapping(node, deep=True) @@ -782,11 +802,11 @@ class Constructor(SafeConstructor): listitems = value.get('listitems', []) dictitems = value.get('dictitems', {}) instance = self.make_python_instance(suffix, node, args, kwds, newobj) - if state: + if bool(state): self.set_python_instance_state(instance, state) - if listitems: + if bool(listitems): instance.extend(listitems) - if dictitems: + if bool(dictitems): for key in dictitems: instance[key] = dictitems[key] return instance @@ -880,7 +900,7 @@ class RoundTripConstructor(SafeConstructor): if node.style == '|' and isinstance(node.value, text_type): return PreservedScalarString(node.value) - elif self._preserve_quotes and isinstance(node.value, text_type): + elif bool(self._preserve_quotes) and isinstance(node.value, text_type): if node.style == "'": return SingleQuotedScalarString(node.value) if node.style == '"': @@ -914,7 +934,7 @@ class RoundTripConstructor(SafeConstructor): seqtyp._yaml_add_comment(node.comment[:2]) if len(node.comment) > 2: seqtyp.yaml_end_comment_extend(node.comment[2], clear=True) - if node.anchor: + if node.anchor: # type: ignore from ruamel.yaml.serializer import templated_id if not templated_id(node.anchor): seqtyp.yaml_set_anchor(node.anchor) @@ -993,7 +1013,7 @@ class RoundTripConstructor(SafeConstructor): # type: () -> None pass - def construct_mapping(self, node, maptyp, deep=False): + def construct_mapping(self, node, maptyp=None, deep=False): # type: ignore # type: (Any, Any, bool) -> Any if not isinstance(node, MappingNode): raise ConstructorError( @@ -1006,7 +1026,7 @@ class RoundTripConstructor(SafeConstructor): maptyp._yaml_add_comment(node.comment[:2]) if len(node.comment) > 2: maptyp.yaml_end_comment_extend(node.comment[2], clear=True) - if node.anchor: + if node.anchor: # type: ignore from ruamel.yaml.serializer import templated_id if not templated_id(node.anchor): maptyp.yaml_set_anchor(node.anchor) @@ -1015,7 +1035,7 @@ class RoundTripConstructor(SafeConstructor): # keys can be list -> deep key = self.construct_object(key_node, deep=True) # lists are not hashable, but tuples are - if not isinstance(key, collections.Hashable): + if not isinstance(key, collections.Hashable): # type: ignore if isinstance(key, list): key_a = CommentedKeySeq(key) if key_node.flow_style is True: @@ -1072,7 +1092,7 @@ class RoundTripConstructor(SafeConstructor): typ._yaml_add_comment(node.comment[:2]) if len(node.comment) > 2: typ.yaml_end_comment_extend(node.comment[2], clear=True) - if node.anchor: + if node.anchor: # type: ignore from ruamel.yaml.serializer import templated_id if not templated_id(node.anchor): typ.yaml_set_anchor(node.anchor) @@ -1080,7 +1100,7 @@ class RoundTripConstructor(SafeConstructor): # keys can be list -> deep key = self.construct_object(key_node, deep=True) # lists are not hashable, but tuples are - if not isinstance(key, collections.Hashable): + if not isinstance(key, collections.Hashable): # type: ignore if isinstance(key, list): key = tuple(key) if PY2: @@ -1229,7 +1249,7 @@ class RoundTripConstructor(SafeConstructor): delta = datetime.timedelta(hours=tz_hour, minutes=tz_minute) if values['tz_sign'] == '-': delta = -delta - if delta: + if delta: # type: ignore dt = datetime.datetime(year, month, day, hour, minute) dt -= delta data = TimeStamp(dt.year, dt.month, dt.day, dt.hour, dt.minute, diff --git a/cyaml.py b/cyaml.py index bff3289..2223a5b 100644 --- a/cyaml.py +++ b/cyaml.py @@ -2,6 +2,8 @@ from __future__ import absolute_import +from typing import Any, Union # NOQA + from _ruamel_yaml import CParser, CEmitter # type: ignore from ruamel.yaml.constructor import Constructor, BaseConstructor, SafeConstructor @@ -9,32 +11,52 @@ from ruamel.yaml.serializer import Serializer from ruamel.yaml.representer import Representer, SafeRepresenter, BaseRepresenter from ruamel.yaml.resolver import Resolver, BaseResolver +from ruamel.yaml.compat import StreamTextType, StreamType, VersionType # NOQA + __all__ = ['CBaseLoader', 'CSafeLoader', 'CLoader', 'CBaseDumper', 'CSafeDumper', 'CDumper'] -class CBaseLoader(CParser, BaseConstructor, BaseResolver): +# this includes some hacks to solve the usage of resolver by lower level +# parts of the parser + +class CBaseLoader(CParser, BaseConstructor, BaseResolver): # type: ignore def __init__(self, stream, version=None, preserve_quotes=None): + # type: (StreamTextType, VersionType, bool) -> None CParser.__init__(self, stream) - BaseConstructor.__init__(self) - BaseResolver.__init__(self) + self._parser = self._composer = self + BaseConstructor.__init__(self, loader=self) + BaseResolver.__init__(self, loadumper=self) + # self.descend_resolver = self._resolver.descend_resolver + # self.ascend_resolver = self._resolver.ascend_resolver + # self.resolve = self._resolver.resolve -class CSafeLoader(CParser, SafeConstructor, Resolver): +class CSafeLoader(CParser, SafeConstructor, Resolver): # type: ignore def __init__(self, stream, version=None, preserve_quotes=None): + # type: (StreamTextType, VersionType, bool) -> None CParser.__init__(self, stream) - SafeConstructor.__init__(self) - Resolver.__init__(self) + self._parser = self._composer = self + SafeConstructor.__init__(self, loader=self) + Resolver.__init__(self, loadumper=self) + # self.descend_resolver = self._resolver.descend_resolver + # self.ascend_resolver = self._resolver.ascend_resolver + # self.resolve = self._resolver.resolve -class CLoader(CParser, Constructor, Resolver): +class CLoader(CParser, Constructor, Resolver): # type: ignore def __init__(self, stream, version=None, preserve_quotes=None): + # type: (StreamTextType, VersionType, bool) -> None CParser.__init__(self, stream) - Constructor.__init__(self) - Resolver.__init__(self) + self._parser = self._composer = self + Constructor.__init__(self, loader=self) + Resolver.__init__(self, loadumper=self) + # self.descend_resolver = self._resolver.descend_resolver + # self.ascend_resolver = self._resolver.ascend_resolver + # self.resolve = self._resolver.resolve -class CBaseDumper(CEmitter, BaseRepresenter, BaseResolver): +class CBaseDumper(CEmitter, BaseRepresenter, BaseResolver): # type: ignore def __init__(self, stream, default_style=None, default_flow_style=None, canonical=None, indent=None, width=None, @@ -42,18 +64,19 @@ class CBaseDumper(CEmitter, BaseRepresenter, BaseResolver): encoding=None, explicit_start=None, explicit_end=None, version=None, tags=None, block_seq_indent=None, top_level_colon_align=None, prefix_colon=None): + # type: (StreamType, Any, Any, Any, bool, Union[None, int], Union[None, int], bool, Any, Any, Union[None, bool], Union[None, bool], Any, Any, Any, Any, Any) -> None # NOQA CEmitter.__init__(self, stream, canonical=canonical, indent=indent, width=width, encoding=encoding, allow_unicode=allow_unicode, line_break=line_break, explicit_start=explicit_start, explicit_end=explicit_end, version=version, tags=tags) - Representer.__init__(self, default_style=default_style, - default_flow_style=default_flow_style) - Resolver.__init__(self) + BaseRepresenter.__init__(self, default_style=default_style, + default_flow_style=default_flow_style, dumper=self) + BaseResolver.__init__(self, loadumper=self) -class CSafeDumper(CEmitter, SafeRepresenter, Resolver): +class CSafeDumper(CEmitter, SafeRepresenter, Resolver): # type: ignore def __init__(self, stream, default_style=None, default_flow_style=None, canonical=None, indent=None, width=None, @@ -61,6 +84,7 @@ class CSafeDumper(CEmitter, SafeRepresenter, Resolver): encoding=None, explicit_start=None, explicit_end=None, version=None, tags=None, block_seq_indent=None, top_level_colon_align=None, prefix_colon=None): + # type: (StreamType, Any, Any, Any, bool, Union[None, int], Union[None, int], bool, Any, Any, Union[None, bool], Union[None, bool], Any, Any, Any, Any, Any) -> None # NOQA CEmitter.__init__(self, stream, canonical=canonical, indent=indent, width=width, encoding=encoding, allow_unicode=allow_unicode, line_break=line_break, @@ -72,7 +96,7 @@ class CSafeDumper(CEmitter, SafeRepresenter, Resolver): Resolver.__init__(self) -class CDumper(CEmitter, Serializer, Representer, Resolver): +class CDumper(CEmitter, Serializer, Representer, Resolver): # type: ignore def __init__(self, stream, default_style=None, default_flow_style=None, canonical=None, indent=None, width=None, @@ -80,6 +104,7 @@ class CDumper(CEmitter, Serializer, Representer, Resolver): encoding=None, explicit_start=None, explicit_end=None, version=None, tags=None, block_seq_indent=None, top_level_colon_align=None, prefix_colon=None): + # type: (StreamType, Any, Any, Any, bool, Union[None, int], Union[None, int], bool, Any, Any, Union[None, bool], Union[None, bool], Any, Any, Any, Any, Any) -> None # NOQA CEmitter.__init__(self, stream, canonical=canonical, indent=indent, width=width, encoding=encoding, allow_unicode=allow_unicode, line_break=line_break, diff --git a/dumper.py b/dumper.py index 72b8106..139265e 100644 --- a/dumper.py +++ b/dumper.py @@ -2,6 +2,10 @@ from __future__ import absolute_import +from typing import Any, Dict, List, Union # NOQA + +from ruamel.yaml.compat import StreamType, VersionType # NOQA + from ruamel.yaml.emitter import Emitter from ruamel.yaml.serializer import Serializer from ruamel.yaml.representer import Representer, SafeRepresenter, BaseRepresenter, \ @@ -19,17 +23,21 @@ class BaseDumper(Emitter, Serializer, BaseRepresenter, BaseResolver): encoding=None, explicit_start=None, explicit_end=None, version=None, tags=None, block_seq_indent=None, top_level_colon_align=None, prefix_colon=None): + # type: (Any, StreamType, Any, Any, bool, Union[None, int], Union[None, int], bool, Any, Any, Union[None, bool], Union[None, bool], Any, Any, Any, Any, Any) -> None # NOQA Emitter.__init__(self, stream, canonical=canonical, indent=indent, width=width, allow_unicode=allow_unicode, line_break=line_break, - block_seq_indent=block_seq_indent) + block_seq_indent=block_seq_indent, + dumper=self) Serializer.__init__(self, encoding=encoding, explicit_start=explicit_start, explicit_end=explicit_end, - version=version, tags=tags) + version=version, tags=tags, + dumper=self) BaseRepresenter.__init__(self, default_style=default_style, - default_flow_style=default_flow_style) - BaseResolver.__init__(self) + default_flow_style=default_flow_style, + dumper=self) + BaseResolver.__init__(self, loadumper=self) class SafeDumper(Emitter, Serializer, SafeRepresenter, Resolver): @@ -40,17 +48,21 @@ class SafeDumper(Emitter, Serializer, SafeRepresenter, Resolver): encoding=None, explicit_start=None, explicit_end=None, version=None, tags=None, block_seq_indent=None, top_level_colon_align=None, prefix_colon=None): + # type: (Any, StreamType, Any, Any, bool, Union[None, int], Union[None, int], bool, Any, Any, Union[None, bool], Union[None, bool], Any, Any, Any, Any, Any) -> None # NOQA Emitter.__init__(self, stream, canonical=canonical, indent=indent, width=width, allow_unicode=allow_unicode, line_break=line_break, - block_seq_indent=block_seq_indent) + block_seq_indent=block_seq_indent, + dumper=self) Serializer.__init__(self, encoding=encoding, explicit_start=explicit_start, explicit_end=explicit_end, - version=version, tags=tags) + version=version, tags=tags, + dumper=self) SafeRepresenter.__init__(self, default_style=default_style, - default_flow_style=default_flow_style) - Resolver.__init__(self) + default_flow_style=default_flow_style, + dumper=self) + Resolver.__init__(self, loadumper=self) class Dumper(Emitter, Serializer, Representer, Resolver): @@ -61,17 +73,21 @@ class Dumper(Emitter, Serializer, Representer, Resolver): encoding=None, explicit_start=None, explicit_end=None, version=None, tags=None, block_seq_indent=None, top_level_colon_align=None, prefix_colon=None): + # type: (Any, StreamType, Any, Any, bool, Union[None, int], Union[None, int], bool, Any, Any, Union[None, bool], Union[None, bool], Any, Any, Any, Any, Any) -> None # NOQA Emitter.__init__(self, stream, canonical=canonical, indent=indent, width=width, allow_unicode=allow_unicode, line_break=line_break, - block_seq_indent=block_seq_indent) + block_seq_indent=block_seq_indent, + dumper=self) Serializer.__init__(self, encoding=encoding, explicit_start=explicit_start, explicit_end=explicit_end, - version=version, tags=tags) + version=version, tags=tags, + dumper=self) Representer.__init__(self, default_style=default_style, - default_flow_style=default_flow_style) - Resolver.__init__(self) + default_flow_style=default_flow_style, + dumper=self) + Resolver.__init__(self, loadumper=self) class RoundTripDumper(Emitter, Serializer, RoundTripRepresenter, VersionedResolver): @@ -82,16 +98,20 @@ class RoundTripDumper(Emitter, Serializer, RoundTripRepresenter, VersionedResolv encoding=None, explicit_start=None, explicit_end=None, version=None, tags=None, block_seq_indent=None, top_level_colon_align=None, prefix_colon=None): + # type: (Any, StreamType, Any, bool, Union[None, int], Union[None, int], bool, Any, Any, Union[None, bool], Union[None, bool], Any, Any, Any, Any, Any) -> None # NOQA Emitter.__init__(self, stream, canonical=canonical, indent=indent, width=width, allow_unicode=allow_unicode, line_break=line_break, block_seq_indent=block_seq_indent, top_level_colon_align=top_level_colon_align, - prefix_colon=prefix_colon) + prefix_colon=prefix_colon, + dumper=self) Serializer.__init__(self, encoding=encoding, explicit_start=explicit_start, explicit_end=explicit_end, - version=version, tags=tags) + version=version, tags=tags, + dumper=self) RoundTripRepresenter.__init__(self, default_style=default_style, - default_flow_style=default_flow_style) - VersionedResolver.__init__(self) + default_flow_style=default_flow_style, + dumper=self) + VersionedResolver.__init__(self, loader=self) diff --git a/emitter.py b/emitter.py index 4feb6b2..7a2a5a2 100644 --- a/emitter.py +++ b/emitter.py @@ -10,10 +10,14 @@ from __future__ import print_function # sequence ::= SEQUENCE-START node* SEQUENCE-END # mapping ::= MAPPING-START (node node)* MAPPING-END +from typing import Any, Dict, List, Union, Text # NOQA + + from ruamel.yaml.error import YAMLError from ruamel.yaml.events import * # NOQA from ruamel.yaml.compat import utf8, text_type, PY2, nprint, dbg, DBG_EVENT, \ check_anchorname_char +from ruamel.yaml.compat import StreamType # NOQA __all__ = ['Emitter', 'EmitterError'] @@ -27,6 +31,7 @@ class ScalarAnalysis(object): allow_flow_plain, allow_block_plain, allow_single_quoted, allow_double_quoted, allow_block): + # type: (Any, Any, Any, bool, bool, bool, bool, bool) -> None self.scalar = scalar self.empty = empty self.multiline = multiline @@ -47,26 +52,29 @@ class Emitter(object): def __init__(self, stream, canonical=None, indent=None, width=None, allow_unicode=None, line_break=None, block_seq_indent=None, - top_level_colon_align=None, prefix_colon=None): - + top_level_colon_align=None, prefix_colon=None, dumper=None): + # type: (StreamType, Any, int, int, bool, Any, int, bool, Any, Any) -> None + self.dumper = dumper + if self.dumper is not None: + self.dumper._emitter = self # The stream should have the methods `write` and possibly `flush`. self.stream = stream # Encoding can be overriden by STREAM-START. - self.encoding = None + self.encoding = None # type: Union[None, Text] # Emitter is a state machine with a stack of states to handle nested # structures. - self.states = [] - self.state = self.expect_stream_start + self.states = [] # type: List[Any] + self.state = self.expect_stream_start # type: Any # Current event and the event queue. - self.events = [] - self.event = None + self.events = [] # type: List[Any] + self.event = None # type: Any # The current indentation level and the stack of previous indents. - self.indents = [] - self.indent = None + self.indents = [] # type: List[Union[None, int]] + self.indent = None # type: Union[None, int] # Flow level. self.flow_level = 0 @@ -86,7 +94,7 @@ class Emitter(object): self.column = 0 self.whitespace = True self.indention = True - self.no_newline = None # set if directly after `- ` + self.no_newline = None # type: Union[None, bool] # set if directly after `- ` # Whether the document requires an explicit document indicator self.open_ended = False @@ -98,37 +106,39 @@ class Emitter(object): # Formatting details. self.canonical = canonical self.allow_unicode = allow_unicode - self.block_seq_indent = block_seq_indent if block_seq_indent else 0 + self.block_seq_indent = block_seq_indent if block_seq_indent else 0 # type: ignore self.top_level_colon_align = top_level_colon_align self.best_indent = 2 - if indent and 1 < indent < 10: + if indent and 1 < indent < 10: # type: ignore self.best_indent = indent # if self.best_indent < self.block_seq_indent + 1: # self.best_indent = self.block_seq_indent + 1 self.best_width = 80 - if width and width > self.best_indent*2: + if width and width > self.best_indent*2: # type: ignore self.best_width = width self.best_line_break = u'\n' if line_break in [u'\r', u'\n', u'\r\n']: self.best_line_break = line_break # Tag prefixes. - self.tag_prefixes = None + self.tag_prefixes = None # type: Any # Prepared anchor and tag. - self.prepared_anchor = None - self.prepared_tag = None + self.prepared_anchor = None # type: Any + self.prepared_tag = None # type: Any # Scalar analysis and style. - self.analysis = None - self.style = None + self.analysis = None # type: Any + self.style = None # type: Any def dispose(self): + # type: () -> None # Reset the state attributes (to clear self-references) - self.states = [] + self.states = [] # type: List[Any] self.state = None def emit(self, event): + # type: (Any) -> None if dbg(DBG_EVENT): nprint(event) self.events.append(event) @@ -140,6 +150,7 @@ class Emitter(object): # In some cases, we wait for a few next events before emitting. def need_more_events(self): + # type: () -> bool if not self.events: return True event = self.events[0] @@ -153,6 +164,7 @@ class Emitter(object): return False def need_events(self, count): + # type: (int) -> bool level = 0 for event in self.events[1:]: if isinstance(event, (DocumentStartEvent, CollectionStartEvent)): @@ -166,6 +178,7 @@ class Emitter(object): return (len(self.events) < count+1) def increase_indent(self, flow=False, sequence=None, indentless=False): + # type: (bool, bool, bool) -> None self.indents.append(self.indent) if self.indent is None: if flow: @@ -182,6 +195,7 @@ class Emitter(object): # Stream handlers. def expect_stream_start(self): + # type: () -> None if isinstance(self.event, StreamStartEvent): if PY2: if self.event.encoding \ @@ -198,14 +212,17 @@ class Emitter(object): self.event) def expect_nothing(self): + # type: () -> None raise EmitterError("expected nothing, but got %s" % self.event) # Document handlers. def expect_first_document_start(self): + # type: () -> Any return self.expect_document_start(first=True) def expect_document_start(self, first=False): + # type: (bool) -> None if isinstance(self.event, DocumentStartEvent): if (self.event.version or self.event.tags) and self.open_ended: self.write_indicator(u'...', True) @@ -245,6 +262,7 @@ class Emitter(object): self.event) def expect_document_end(self): + # type: () -> None if isinstance(self.event, DocumentEndEvent): self.write_indent() if self.event.explicit: @@ -257,6 +275,7 @@ class Emitter(object): self.event) def expect_document_root(self): + # type: () -> None self.states.append(self.expect_document_end) self.expect_node(root=True) @@ -264,6 +283,7 @@ class Emitter(object): def expect_node(self, root=False, sequence=False, mapping=False, simple_key=False): + # type: (bool, bool, bool, bool) -> None self.root_context = root self.sequence_context = sequence # not used in PyYAML self.mapping_context = mapping @@ -300,12 +320,14 @@ class Emitter(object): raise EmitterError("expected NodeEvent, but got %s" % self.event) def expect_alias(self): + # type: () -> None if self.event.anchor is None: raise EmitterError("anchor is not specified for alias") self.process_anchor(u'*') self.state = self.states.pop() def expect_scalar(self): + # type: () -> None self.increase_indent(flow=True) self.process_scalar() self.indent = self.indents.pop() @@ -314,12 +336,14 @@ class Emitter(object): # Flow sequence handlers. def expect_flow_sequence(self): + # type: () -> None self.write_indicator(u'[', True, whitespace=True) self.flow_level += 1 self.increase_indent(flow=True, sequence=True) self.state = self.expect_first_flow_sequence_item def expect_first_flow_sequence_item(self): + # type: () -> None if isinstance(self.event, SequenceEndEvent): self.indent = self.indents.pop() self.flow_level -= 1 @@ -332,6 +356,7 @@ class Emitter(object): self.expect_node(sequence=True) def expect_flow_sequence_item(self): + # type: () -> None if isinstance(self.event, SequenceEndEvent): self.indent = self.indents.pop() self.flow_level -= 1 @@ -353,12 +378,14 @@ class Emitter(object): # Flow mapping handlers. def expect_flow_mapping(self): + # type: () -> None self.write_indicator(u'{', True, whitespace=True) self.flow_level += 1 self.increase_indent(flow=True, sequence=False) self.state = self.expect_first_flow_mapping_key def expect_first_flow_mapping_key(self): + # type: () -> None if isinstance(self.event, MappingEndEvent): self.indent = self.indents.pop() self.flow_level -= 1 @@ -379,6 +406,7 @@ class Emitter(object): self.expect_node(mapping=True) def expect_flow_mapping_key(self): + # type: () -> None if isinstance(self.event, MappingEndEvent): # if self.event.comment and self.event.comment[1]: # self.write_pre_comment(self.event) @@ -405,11 +433,13 @@ class Emitter(object): self.expect_node(mapping=True) def expect_flow_mapping_simple_value(self): + # type: () -> None self.write_indicator(self.prefixed_colon, False) self.states.append(self.expect_flow_mapping_key) self.expect_node(mapping=True) def expect_flow_mapping_value(self): + # type: () -> None if self.canonical or self.column > self.best_width: self.write_indent() self.write_indicator(self.prefixed_colon, True) @@ -419,14 +449,17 @@ class Emitter(object): # Block sequence handlers. def expect_block_sequence(self): + # type: () -> None indentless = (self.mapping_context and not self.indention) self.increase_indent(flow=False, sequence=True, indentless=indentless) self.state = self.expect_first_block_sequence_item def expect_first_block_sequence_item(self): + # type: () -> Any return self.expect_block_sequence_item(first=True) def expect_block_sequence_item(self, first=False): + # type: (bool) -> None if not first and isinstance(self.event, SequenceEndEvent): if self.event.comment and self.event.comment[1]: # final comments from a doc @@ -446,13 +479,16 @@ class Emitter(object): # Block mapping handlers. def expect_block_mapping(self): + # type: () -> None self.increase_indent(flow=False, sequence=False) self.state = self.expect_first_block_mapping_key def expect_first_block_mapping_key(self): + # type: () -> None return self.expect_block_mapping_key(first=True) def expect_block_mapping_key(self, first=False): + # type: (Any) -> None if not first and isinstance(self.event, MappingEndEvent): if self.event.comment and self.event.comment[1]: # final comments from a doc @@ -476,6 +512,7 @@ class Emitter(object): self.expect_node(mapping=True) def expect_block_mapping_simple_value(self): + # type: () -> None if getattr(self.event, 'style', None) != '?': # prefix = u'' if self.indent == 0 and self.top_level_colon_align is not None: @@ -488,6 +525,7 @@ class Emitter(object): self.expect_node(mapping=True) def expect_block_mapping_value(self): + # type: () -> None self.write_indent() self.write_indicator(self.prefixed_colon, True, indention=True) self.states.append(self.expect_block_mapping_key) @@ -496,14 +534,17 @@ class Emitter(object): # Checkers. def check_empty_sequence(self): - return (isinstance(self.event, SequenceStartEvent) and self.events and + # type: () -> bool + return (isinstance(self.event, SequenceStartEvent) and bool(self.events) and isinstance(self.events[0], SequenceEndEvent)) def check_empty_mapping(self): - return (isinstance(self.event, MappingStartEvent) and self.events and + # type: () -> bool + return (isinstance(self.event, MappingStartEvent) and bool(self.events) and isinstance(self.events[0], MappingEndEvent)) def check_empty_document(self): + # type: () -> bool if not isinstance(self.event, DocumentStartEvent) or not self.events: return False event = self.events[0] @@ -511,6 +552,7 @@ class Emitter(object): event.tag is None and event.implicit and event.value == u'') def check_simple_key(self): + # type: () -> bool length = 0 if isinstance(self.event, NodeEvent) and self.event.anchor is not None: if self.prepared_anchor is None: @@ -536,6 +578,7 @@ class Emitter(object): # Anchor, Tag, and Scalar processors. def process_anchor(self, indicator): + # type: (Any) -> None if self.event.anchor is None: self.prepared_anchor = None return @@ -546,6 +589,7 @@ class Emitter(object): self.prepared_anchor = None def process_tag(self): + # type: () -> None tag = self.event.tag if isinstance(self.event, ScalarEvent): if self.style is None: @@ -571,6 +615,7 @@ class Emitter(object): self.prepared_tag = None def choose_scalar_style(self): + # type: () -> Any if self.analysis is None: self.analysis = self.analyze_scalar(self.event.value) if self.event.style == '"' or self.canonical: @@ -596,6 +641,7 @@ class Emitter(object): return '"' def process_scalar(self): + # type: () -> None if self.analysis is None: self.analysis = self.analyze_scalar(self.event.value) if self.style is None: @@ -624,6 +670,7 @@ class Emitter(object): # Analyzers. def prepare_version(self, version): + # type: (Any) -> Any major, minor = version if major != 1: raise EmitterError("unsupported YAML version: %d.%d" % @@ -631,6 +678,7 @@ class Emitter(object): return u'%d.%d' % (major, minor) def prepare_tag_handle(self, handle): + # type: (Any) -> Any if not handle: raise EmitterError("tag handle must not be empty") if handle[0] != u'!' or handle[-1] != u'!': @@ -644,9 +692,10 @@ class Emitter(object): return handle def prepare_tag_prefix(self, prefix): + # type: (Any) -> Any if not prefix: raise EmitterError("tag prefix must not be empty") - chunks = [] + chunks = [] # type: List[Any] start = end = 0 if prefix[0] == u'!': end = 1 @@ -667,6 +716,7 @@ class Emitter(object): return u''.join(chunks) def prepare_tag(self, tag): + # type: (Any) -> Any if not tag: raise EmitterError("tag must not be empty") if tag == u'!': @@ -679,7 +729,7 @@ class Emitter(object): and (prefix == u'!' or len(prefix) < len(tag)): handle = self.tag_prefixes[prefix] suffix = tag[len(prefix):] - chunks = [] + chunks = [] # type: List[Any] start = end = 0 while end < len(suffix): ch = suffix[end] @@ -703,6 +753,7 @@ class Emitter(object): return u'!<%s>' % suffix_text def prepare_anchor(self, anchor): + # type: (Any) -> Any if not anchor: raise EmitterError("anchor must not be empty") for ch in anchor: @@ -712,7 +763,7 @@ class Emitter(object): return anchor def analyze_scalar(self, scalar): - + # type: (Any) -> Any # Empty scalar is a special case. if not scalar: return ScalarAnalysis( @@ -874,19 +925,23 @@ class Emitter(object): # Writers. def flush_stream(self): + # type: () -> None if hasattr(self.stream, 'flush'): self.stream.flush() def write_stream_start(self): + # type: () -> None # Write BOM if needed. - if self.encoding and self.encoding.startswith('utf-16'): + if self.encoding and self.encoding.startswith('utf-16'): # type: ignore self.stream.write(u'\uFEFF'.encode(self.encoding)) def write_stream_end(self): + # type: () -> None self.flush_stream() def write_indicator(self, indicator, need_whitespace, whitespace=False, indention=False): + # type: (Any, Any, bool, bool) -> None if self.whitespace or not need_whitespace: data = indicator else: @@ -895,15 +950,16 @@ class Emitter(object): self.indention = self.indention and indention self.column += len(data) self.open_ended = False - if self.encoding: + if bool(self.encoding): data = data.encode(self.encoding) self.stream.write(data) def write_indent(self): + # type: () -> None indent = self.indent or 0 if not self.indention or self.column > indent \ or (self.column == indent and not self.whitespace): - if self.no_newline: + if bool(self.no_newline): self.no_newline = False else: self.write_line_break() @@ -911,38 +967,42 @@ class Emitter(object): self.whitespace = True data = u' '*(indent-self.column) self.column = indent - if self.encoding: + if self.encoding: # type: ignore data = data.encode(self.encoding) - self.stream.write(data) + self.stream.write(data) # type: ignore def write_line_break(self, data=None): + # type: (Any) -> None if data is None: data = self.best_line_break self.whitespace = True self.indention = True self.line += 1 self.column = 0 - if self.encoding: + if bool(self.encoding): data = data.encode(self.encoding) self.stream.write(data) def write_version_directive(self, version_text): + # type: (Any) -> None data = u'%%YAML %s' % version_text - if self.encoding: + if self.encoding: # type: ignore data = data.encode(self.encoding) - self.stream.write(data) + self.stream.write(data) # type: ignore self.write_line_break() def write_tag_directive(self, handle_text, prefix_text): + # type: (Any, Any) -> None data = u'%%TAG %s %s' % (handle_text, prefix_text) - if self.encoding: + if self.encoding: # type: ignore data = data.encode(self.encoding) - self.stream.write(data) + self.stream.write(data) # type: ignore self.write_line_break() # Scalar streams. def write_single_quoted(self, text, split=True): + # type: (Any, Any) -> None self.write_indicator(u'\'', True) spaces = False breaks = False @@ -959,7 +1019,7 @@ class Emitter(object): else: data = text[start:end] self.column += len(data) - if self.encoding: + if bool(self.encoding): data = data.encode(self.encoding) self.stream.write(data) start = end @@ -979,14 +1039,14 @@ class Emitter(object): if start < end: data = text[start:end] self.column += len(data) - if self.encoding: + if bool(self.encoding): data = data.encode(self.encoding) self.stream.write(data) start = end if ch == u'\'': data = u'\'\'' self.column += 2 - if self.encoding: + if bool(self.encoding): data = data.encode(self.encoding) self.stream.write(data) start = end + 1 @@ -1015,6 +1075,7 @@ class Emitter(object): } def write_double_quoted(self, text, split=True): + # type: (Any, Any) -> None self.write_indicator(u'"', True) start = end = 0 while end <= len(text): @@ -1028,7 +1089,7 @@ class Emitter(object): if start < end: data = text[start:end] self.column += len(data) - if self.encoding: + if bool(self.encoding): data = data.encode(self.encoding) self.stream.write(data) start = end @@ -1042,7 +1103,7 @@ class Emitter(object): else: data = u'\\U%08X' % ord(ch) self.column += len(data) - if self.encoding: + if bool(self.encoding): data = data.encode(self.encoding) self.stream.write(data) start = end+1 @@ -1052,7 +1113,7 @@ class Emitter(object): if start < end: start = end self.column += len(data) - if self.encoding: + if bool(self.encoding): data = data.encode(self.encoding) self.stream.write(data) self.write_indent() @@ -1061,13 +1122,14 @@ class Emitter(object): if text[start] == u' ': data = u'\\' self.column += len(data) - if self.encoding: + if bool(self.encoding): data = data.encode(self.encoding) self.stream.write(data) end += 1 self.write_indicator(u'"', False) def determine_block_hints(self, text): + # type: (Any) -> Any hints = u'' if text: if text[0] in u' \n\x85\u2028\u2029': @@ -1079,6 +1141,7 @@ class Emitter(object): return hints def write_folded(self, text): + # type: (Any) -> None hints = self.determine_block_hints(text) self.write_indicator(u'>'+hints, True) if hints[-1:] == u'+': @@ -1113,7 +1176,7 @@ class Emitter(object): else: data = text[start:end] self.column += len(data) - if self.encoding: + if bool(self.encoding): data = data.encode(self.encoding) self.stream.write(data) start = end @@ -1121,7 +1184,7 @@ class Emitter(object): if ch is None or ch in u' \n\x85\u2028\u2029': data = text[start:end] self.column += len(data) - if self.encoding: + if bool(self.encoding): data = data.encode(self.encoding) self.stream.write(data) if ch is None: @@ -1133,6 +1196,7 @@ class Emitter(object): end += 1 def write_literal(self, text): + # type: (Any) -> None hints = self.determine_block_hints(text) self.write_indicator(u'|'+hints, True) if hints[-1:] == u'+': @@ -1157,7 +1221,7 @@ class Emitter(object): else: if ch is None or ch in u'\n\x85\u2028\u2029': data = text[start:end] - if self.encoding: + if bool(self.encoding): data = data.encode(self.encoding) self.stream.write(data) if ch is None: @@ -1168,6 +1232,7 @@ class Emitter(object): end += 1 def write_plain(self, text, split=True): + # type: (Any, Any) -> None if self.root_context: self.open_ended = True if not text: @@ -1175,9 +1240,9 @@ class Emitter(object): if not self.whitespace: data = u' ' self.column += len(data) - if self.encoding: + if self.encoding: # type: ignore data = data.encode(self.encoding) - self.stream.write(data) + self.stream.write(data) # type: ignore self.whitespace = False self.indention = False spaces = False @@ -1197,9 +1262,9 @@ class Emitter(object): else: data = text[start:end] self.column += len(data) - if self.encoding: + if self.encoding: # type: ignore data = data.encode(self.encoding) - self.stream.write(data) + self.stream.write(data) # type: ignore start = end elif breaks: if ch not in u'\n\x85\u2028\u2029': @@ -1218,9 +1283,9 @@ class Emitter(object): if ch is None or ch in u' \n\x85\u2028\u2029': data = text[start:end] self.column += len(data) - if self.encoding: + if self.encoding: # type: ignore data = data.encode(self.encoding) - self.stream.write(data) + self.stream.write(data) # type: ignore start = end if ch is not None: spaces = (ch == u' ') @@ -1228,6 +1293,7 @@ class Emitter(object): end += 1 def write_comment(self, comment): + # type: (Any) -> None value = comment.value # print('{:02d} {:02d} {!r}'.format(self.column, comment.start_mark.column, value)) if value[-1] == '\n': @@ -1248,7 +1314,7 @@ class Emitter(object): nr_spaces = 1 value = ' ' * nr_spaces + value try: - if self.encoding: + if bool(self.encoding): value = value.encode(self.encoding) except UnicodeDecodeError: pass @@ -1258,6 +1324,7 @@ class Emitter(object): self.write_line_break() def write_pre_comment(self, event): + # type: (Any) -> None comments = event.comment[1] if comments is None: return @@ -1276,6 +1343,7 @@ class Emitter(object): raise def write_post_comment(self, event): + # type: (Any) -> None if self.event.comment[0] is None: return comment = event.comment[0] diff --git a/error.py b/error.py index e140c5f..c35623f 100644 --- a/error.py +++ b/error.py @@ -4,25 +4,28 @@ from __future__ import absolute_import import warnings -from typing import Any, Dict, Optional, List # NOQA +from typing import Any, Dict, Optional, List, Text # NOQA from ruamel.yaml.compat import utf8 -__all__ = ['FileMark', 'StringMark', 'CommentMark', - 'YAMLError', 'MarkedYAMLError', 'ReusedAnchorWarning', - 'UnsafeLoaderWarning'] +__all__ = [ + 'FileMark', 'StringMark', 'CommentMark', 'YAMLError', 'MarkedYAMLError', + 'ReusedAnchorWarning', 'UnsafeLoaderWarning', +] class StreamMark(object): __slots__ = 'name', 'index', 'line', 'column', def __init__(self, name, index, line, column): + # type: (Any, int, int, int) -> None self.name = name self.index = index self.line = line self.column = column def __str__(self): + # type: () -> Any where = " in \"%s\", line %d, column %d" \ % (self.name, self.line+1, self.column+1) return where @@ -36,11 +39,13 @@ class StringMark(StreamMark): __slots__ = 'name', 'index', 'line', 'column', 'buffer', 'pointer', def __init__(self, name, index, line, column, buffer, pointer): + # type: (Any, int, int, int, Any, Any) -> None StreamMark.__init__(self, name, index, line, column) self.buffer = buffer self.pointer = pointer def get_snippet(self, indent=4, max_length=75): + # type: (int, int) -> Any if self.buffer is None: # always False return None head = '' @@ -68,6 +73,7 @@ class StringMark(StreamMark): + ' '*(indent+self.pointer-start+len(head)) + caret def __str__(self): + # type: () -> Any snippet = self.get_snippet() where = " in \"%s\", line %d, column %d" \ % (self.name, self.line+1, self.column+1) @@ -99,8 +105,8 @@ class MarkedYAMLError(YAMLError): self.note = note def __str__(self): - # type: () -> str - lines = [] + # type: () -> Any + lines = [] # type: List[str] if self.context is not None: lines.append(self.context) if self.context_mark is not None \ diff --git a/events.py b/events.py index a92be74..8c5c127 100644 --- a/events.py +++ b/events.py @@ -2,8 +2,11 @@ # Abstract classes. +from typing import Any, Dict, Optional, List # NOQA + def CommentCheck(): + # type: () -> None pass @@ -11,6 +14,7 @@ class Event(object): __slots__ = 'start_mark', 'end_mark', 'comment', def __init__(self, start_mark=None, end_mark=None, comment=CommentCheck): + # type: (Any, Any, Any) -> None self.start_mark = start_mark self.end_mark = end_mark # assert comment is not CommentCheck @@ -19,6 +23,7 @@ class Event(object): self.comment = comment def __repr__(self): + # type: () -> Any attributes = [key for key in ['anchor', 'tag', 'implicit', 'value', 'flow_style', 'style'] if hasattr(self, key)] @@ -33,6 +38,7 @@ class NodeEvent(Event): __slots__ = 'anchor', def __init__(self, anchor, start_mark=None, end_mark=None, comment=None): + # type: (Any, Any, Any, Any) -> None Event.__init__(self, start_mark, end_mark, comment) self.anchor = anchor @@ -42,6 +48,7 @@ class CollectionStartEvent(NodeEvent): def __init__(self, anchor, tag, implicit, start_mark=None, end_mark=None, flow_style=None, comment=None): + # type: (Any, Any, Any, Any, Any, Any, Any) -> None NodeEvent.__init__(self, anchor, start_mark, end_mark, comment) self.tag = tag self.implicit = implicit @@ -59,6 +66,7 @@ class StreamStartEvent(Event): def __init__(self, start_mark=None, end_mark=None, encoding=None, comment=None): + # type: (Any, Any, Any, Any) -> None Event.__init__(self, start_mark, end_mark, comment) self.encoding = encoding @@ -72,6 +80,7 @@ class DocumentStartEvent(Event): def __init__(self, start_mark=None, end_mark=None, explicit=None, version=None, tags=None, comment=None): + # type: (Any, Any, Any, Any, Any, Any) -> None Event.__init__(self, start_mark, end_mark, comment) self.explicit = explicit self.version = version @@ -83,6 +92,7 @@ class DocumentEndEvent(Event): def __init__(self, start_mark=None, end_mark=None, explicit=None, comment=None): + # type: (Any, Any, Any, Any) -> None Event.__init__(self, start_mark, end_mark, comment) self.explicit = explicit @@ -96,6 +106,7 @@ class ScalarEvent(NodeEvent): def __init__(self, anchor, tag, implicit, value, start_mark=None, end_mark=None, style=None, comment=None): + # type: (Any, Any, Any, Any, Any, Any, Any, Any) -> None NodeEvent.__init__(self, anchor, start_mark, end_mark, comment) self.tag = tag self.implicit = implicit diff --git a/loader.py b/loader.py index 661683d..a9b237f 100644 --- a/loader.py +++ b/loader.py @@ -2,6 +2,10 @@ from __future__ import absolute_import +from typing import Any, Dict, List # NOQA + +from ruamel.yaml.compat import StreamTextType, VersionType # NOQA + from ruamel.yaml.reader import Reader from ruamel.yaml.scanner import Scanner, RoundTripScanner from ruamel.yaml.parser import Parser, RoundTripParser @@ -15,40 +19,45 @@ __all__ = ['BaseLoader', 'SafeLoader', 'Loader', 'RoundTripLoader'] class BaseLoader(Reader, Scanner, Parser, Composer, BaseConstructor, VersionedResolver): def __init__(self, stream, version=None, preserve_quotes=None): - Reader.__init__(self, stream) - Scanner.__init__(self) - Parser.__init__(self) - Composer.__init__(self) - BaseConstructor.__init__(self) - VersionedResolver.__init__(self) + # type: (StreamTextType, VersionType, bool) -> None + Reader.__init__(self, stream, loader=self) + Scanner.__init__(self, loader=self) + Parser.__init__(self, loader=self) + Composer.__init__(self, loader=self) + BaseConstructor.__init__(self, loader=self) + VersionedResolver.__init__(self, version, loader=self) class SafeLoader(Reader, Scanner, Parser, Composer, SafeConstructor, VersionedResolver): def __init__(self, stream, version=None, preserve_quotes=None): - Reader.__init__(self, stream) - Scanner.__init__(self) - Parser.__init__(self) - Composer.__init__(self) - SafeConstructor.__init__(self) - VersionedResolver.__init__(self) + # type: (StreamTextType, VersionType, bool) -> None + Reader.__init__(self, stream, loader=self) + Scanner.__init__(self, loader=self) + Parser.__init__(self, loader=self) + Composer.__init__(self, loader=self) + SafeConstructor.__init__(self, loader=self) + VersionedResolver.__init__(self, version, loader=self) class Loader(Reader, Scanner, Parser, Composer, Constructor, VersionedResolver): def __init__(self, stream, version=None, preserve_quotes=None): - Reader.__init__(self, stream) - Scanner.__init__(self) - Parser.__init__(self) - Composer.__init__(self) - Constructor.__init__(self) - VersionedResolver.__init__(self) + # type: (StreamTextType, VersionType, bool) -> None + Reader.__init__(self, stream, loader=self) + Scanner.__init__(self, loader=self) + Parser.__init__(self, loader=self) + Composer.__init__(self, loader=self) + Constructor.__init__(self, loader=self) + VersionedResolver.__init__(self, version, loader=self) class RoundTripLoader(Reader, RoundTripScanner, RoundTripParser, Composer, RoundTripConstructor, VersionedResolver): def __init__(self, stream, version=None, preserve_quotes=None): - Reader.__init__(self, stream) - RoundTripScanner.__init__(self) - RoundTripParser.__init__(self) - Composer.__init__(self) - RoundTripConstructor.__init__(self, preserve_quotes=preserve_quotes) - VersionedResolver.__init__(self, version) + # type: (StreamTextType, VersionType, bool) -> None + # self.reader = Reader.__init__(self, stream) + Reader.__init__(self, stream, loader=self) + RoundTripScanner.__init__(self, loader=self) + RoundTripParser.__init__(self, loader=self) + Composer.__init__(self, loader=self) + RoundTripConstructor.__init__(self, preserve_quotes=preserve_quotes, loader=self) + VersionedResolver.__init__(self, version, loader=self) diff --git a/main.py b/main.py index f4768f5..25a8f36 100644 --- a/main.py +++ b/main.py @@ -2,9 +2,12 @@ from __future__ import absolute_import, unicode_literals +import warnings -from typing import List, Set, Dict, Tuple, Optional, Union, BinaryIO, IO, Any # NOQA +from typing import List, Set, Dict, Union, Any # NOQA + +import ruamel.yaml from ruamel.yaml.error import * # NOQA from ruamel.yaml.tokens import * # NOQA @@ -14,11 +17,16 @@ from ruamel.yaml.nodes import * # NOQA from ruamel.yaml.loader import BaseLoader, SafeLoader, Loader, RoundTripLoader # NOQA from ruamel.yaml.dumper import BaseDumper, SafeDumper, Dumper, RoundTripDumper # NOQA from ruamel.yaml.compat import StringIO, BytesIO, with_metaclass, PY3 -from ruamel.yaml.compat import StreamType, StreamTextType # NOQA +from ruamel.yaml.compat import StreamType, StreamTextType, VersionType # NOQA +from ruamel.yaml.resolver import VersionedResolver, Resolver # NOQA +from ruamel.yaml.representer import (BaseRepresenter, SafeRepresenter, Representer, + RoundTripRepresenter) +from ruamel.yaml.constructor import (BaseConstructor, SafeConstructor, Constructor, + RoundTripConstructor) +from ruamel.yaml.loader import Loader as UnsafeLoader -# import io -VersionType = Union[List[int], str, Tuple[int, int]] +# import io def scan(stream, Loader=Loader): @@ -28,10 +36,10 @@ def scan(stream, Loader=Loader): """ loader = Loader(stream) try: - while loader.check_token(): - yield loader.get_token() + while loader.scanner.check_token(): + yield loader.scanner.get_token() finally: - loader.dispose() + loader._parser.dispose() def parse(stream, Loader=Loader): @@ -41,10 +49,10 @@ def parse(stream, Loader=Loader): """ loader = Loader(stream) try: - while loader.check_event(): - yield loader.get_event() + while loader._parser.check_event(): + yield loader._parser.get_event() finally: - loader.dispose() + loader._parser.dispose() def compose(stream, Loader=Loader): @@ -69,9 +77,9 @@ def compose_all(stream, Loader=Loader): loader = Loader(stream) try: while loader.check_node(): - yield loader.get_node() + yield loader._composer.get_node() finally: - loader.dispose() + loader._parser.dispose() def load(stream, Loader=None, version=None, preserve_quotes=None): @@ -81,15 +89,13 @@ def load(stream, Loader=None, version=None, preserve_quotes=None): and produce the corresponding Python object. """ if Loader is None: - from ruamel.yaml.loader import Loader as UnsafeLoader - import warnings warnings.warn(UnsafeLoaderWarning.text, UnsafeLoaderWarning, stacklevel=2) Loader = UnsafeLoader loader = Loader(stream, version, preserve_quotes=preserve_quotes) try: - return loader.get_single_data() + return loader._constructor.get_single_data() finally: - loader.dispose() + loader._parser.dispose() def load_all(stream, Loader=None, version=None, preserve_quotes=None): @@ -99,16 +105,14 @@ def load_all(stream, Loader=None, version=None, preserve_quotes=None): and produce corresponding Python objects. """ if Loader is None: - from ruamel.yaml.loader import Loader as UnsafeLoader - import warnings warnings.warn(UnsafeLoaderWarning.text, UnsafeLoaderWarning, stacklevel=2) Loader = UnsafeLoader loader = Loader(stream, version, preserve_quotes=preserve_quotes) try: - while loader.check_data(): - yield loader.get_data() + while loader._constructor.check_data(): + yield loader._constructor.get_data() finally: - loader.dispose() + loader._parser.dispose() def safe_load(stream, version=None): @@ -169,8 +173,8 @@ def emit(events, stream=None, Dumper=Dumper, for event in events: dumper.emit(event) finally: - dumper.dispose() - if getvalue: + dumper._emitter.dispose() + if getvalue is not None: return getvalue() enc = None if PY3 else 'utf-8' @@ -198,13 +202,13 @@ def serialize_all(nodes, stream=None, Dumper=Dumper, encoding=encoding, version=version, tags=tags, explicit_start=explicit_start, explicit_end=explicit_end) try: - dumper.open() + dumper._serializer.open() for node in nodes: dumper.serialize(node) - dumper.close() + dumper._serializer.close() finally: - dumper.dispose() - if getvalue: + dumper._emitter.dispose() + if getvalue is not None: return getvalue() @@ -248,13 +252,17 @@ def dump_all(documents, stream=None, Dumper=Dumper, top_level_colon_align=top_level_colon_align, prefix_colon=prefix_colon, ) try: - dumper.open() + dumper._serializer.open() for data in documents: - dumper.represent(data) - dumper.close() + try: + dumper._representer.represent(data) + except AttributeError: + # print(dir(dumper._representer)) + raise + dumper._serializer.close() finally: - dumper.dispose() - if getvalue: + dumper._emitter.dispose() + if getvalue is not None: return getvalue() return None @@ -327,72 +335,173 @@ def round_trip_dump(data, stream=None, Dumper=RoundTripDumper, top_level_colon_align=top_level_colon_align, prefix_colon=prefix_colon) -def add_implicit_resolver(tag, regexp, first=None, - Loader=Loader, Dumper=Dumper): - # type: (Any, Any, Any, Any, Any) -> None +# Loader/Dumper are no longer composites, to get to the associated +# Resolver()/Representer(), etc., you need to instantiate the class + +def add_implicit_resolver(tag, regexp, first=None, Loader=None, Dumper=None, + resolver=Resolver): + # type: (Any, Any, Any, Any, Any, Any) -> None """ Add an implicit scalar detector. If an implicit scalar value matches the given regexp, the corresponding tag is assigned to the scalar. first is a sequence of possible initial characters or None. """ - Loader.add_implicit_resolver(tag, regexp, first) - Dumper.add_implicit_resolver(tag, regexp, first) + if Loader is None and Dumper is None: + resolver.add_implicit_resolver(tag, regexp, first) + return + if Loader: + if hasattr(Loader, 'add_implicit_resolver'): + Loader.add_implicit_resolver(tag, regexp, first) + elif issubclass(Loader, (BaseLoader, SafeLoader, ruamel.yaml.loader.Loader, + RoundTripLoader)): + Resolver.add_implicit_resolver(tag, regexp, first) + else: + raise NotImplementedError + if Dumper: + if hasattr(Dumper, 'add_implicit_resolver'): + Dumper.add_implicit_resolver(tag, regexp, first) + elif issubclass(Dumper, (BaseDumper, SafeDumper, ruamel.yaml.dumper.Dumper, + RoundTripDumper)): + Resolver.add_implicit_resolver(tag, regexp, first) + else: + raise NotImplementedError -def add_path_resolver(tag, path, kind=None, Loader=Loader, Dumper=Dumper): - # type: (Any, Any, Any, Any, Any) -> None +# this code currently not tested +def add_path_resolver(tag, path, kind=None, Loader=None, Dumper=None, + resolver=Resolver): + # type: (Any, Any, Any, Any, Any, Any) -> None """ Add a path based resolver for the given tag. A path is a list of keys that forms a path to a node in the representation tree. Keys can be string values, integers, or None. """ - Loader.add_path_resolver(tag, path, kind) - Dumper.add_path_resolver(tag, path, kind) + if Loader is None and Dumper is None: + resolver.add_path_resolver(tag, path, kind) + return + if Loader: + if hasattr(Loader, 'add_path_resolver'): + Loader.add_path_resolver(tag, path, kind) + elif issubclass(Loader, (BaseLoader, SafeLoader, ruamel.yaml.loader.Loader, + RoundTripLoader)): + Resolver.add_path_resolver(tag, path, kind) + else: + raise NotImplementedError + if Dumper: + if hasattr(Dumper, 'add_path_resolver'): + Dumper.add_path_resolver(tag, path, kind) + elif issubclass(Dumper, (BaseDumper, SafeDumper, ruamel.yaml.dumper.Dumper, + RoundTripDumper)): + Resolver.add_path_resolver(tag, path, kind) + else: + raise NotImplementedError -def add_constructor(tag, constructor, Loader=Loader): - # type: (Any, Any, Any) -> None +def add_constructor(tag, object_constructor, Loader=None, constructor=Constructor): + # type: (Any, Any, Any, Any) -> None """ - Add a constructor for the given tag. - Constructor is a function that accepts a Loader instance + Add an object constructor for the given tag. + object_onstructor is a function that accepts a Loader instance and a node object and produces the corresponding Python object. """ - Loader.add_constructor(tag, constructor) + if Loader is None: + constructor.add_constructor(tag, object_constructor) + else: + if hasattr(Loader, 'add_constructor'): + Loader.add_constructor(tag, object_constructor) + return + if issubclass(Loader, BaseLoader): + BaseConstructor.add_constructor(tag, object_constructor) + elif issubclass(Loader, SafeLoader): + SafeConstructor.add_constructor(tag, object_constructor) + elif issubclass(Loader, Loader): + Constructor.add_constructor(tag, object_constructor) + elif issubclass(Loader, RoundTripLoader): + RoundTripConstructor.add_constructor(tag, object_constructor) + else: + raise NotImplementedError -def add_multi_constructor(tag_prefix, multi_constructor, Loader=Loader): - # type: (Any, Any, Any) -> None +def add_multi_constructor(tag_prefix, multi_constructor, Loader=None, + constructor=Constructor): + # type: (Any, Any, Any, Any) -> None """ Add a multi-constructor for the given tag prefix. Multi-constructor is called for a node if its tag starts with tag_prefix. Multi-constructor accepts a Loader instance, a tag suffix, and a node object and produces the corresponding Python object. """ - Loader.add_multi_constructor(tag_prefix, multi_constructor) + if Loader is None: + constructor.add_multi_constructor(tag_prefix, multi_constructor) + else: + if False and hasattr(Loader, 'add_multi_constructor'): + Loader.add_multi_constructor(tag_prefix, constructor) + return + if issubclass(Loader, BaseLoader): + BaseConstructor.add_multi_constructor(tag_prefix, multi_constructor) + elif issubclass(Loader, SafeLoader): + SafeConstructor.add_multi_constructor(tag_prefix, multi_constructor) + elif issubclass(Loader, ruamel.yaml.loader.Loader): + Constructor.add_multi_constructor(tag_prefix, multi_constructor) + elif issubclass(Loader, RoundTripLoader): + RoundTripConstructor.add_multi_constructor(tag_prefix, multi_constructor) + else: + raise NotImplementedError -def add_representer(data_type, representer, Dumper=Dumper): - # type: (Any, Any, Any) -> None +def add_representer(data_type, object_representer, Dumper=None, representer=Representer): + # type: (Any, Any, Any, Any) -> None """ Add a representer for the given type. - Representer is a function accepting a Dumper instance + object_representer is a function accepting a Dumper instance and an instance of the given data type and producing the corresponding representation node. """ - Dumper.add_representer(data_type, representer) + if Dumper is None: + representer.add_representer(data_type, object_representer) + else: + if hasattr(Dumper, 'add_representer'): + Dumper.add_representer(data_type, object_representer) + return + if issubclass(Dumper, BaseDumper): + BaseRepresenter.add_representer(data_type, object_representer) + elif issubclass(Dumper, SafeDumper): + SafeRepresenter.add_representer(data_type, object_representer) + elif issubclass(Dumper, Dumper): + Representer.add_representer(data_type, object_representer) + elif issubclass(Dumper, RoundTripDumper): + RoundTripRepresenter.add_representer(data_type, object_representer) + else: + raise NotImplementedError -def add_multi_representer(data_type, multi_representer, Dumper=Dumper): - # type: (Any, Any, Any) -> None +# this code currently not tested +def add_multi_representer(data_type, multi_representer, Dumper=None, representer=Representer): + # type: (Any, Any, Any, Any) -> None """ Add a representer for the given type. - Multi-representer is a function accepting a Dumper instance + multi_representer is a function accepting a Dumper instance and an instance of the given data type or subtype and producing the corresponding representation node. """ - Dumper.add_multi_representer(data_type, multi_representer) + if Dumper is None: + representer.add_multi_representer(data_type, multi_representer) + else: + if hasattr(Dumper, 'add_multi_representer'): + Dumper.add_multi_representer(data_type, multi_representer) + return + if issubclass(Dumper, BaseDumper): + BaseRepresenter.add_multi_representer(data_type, multi_representer) + elif issubclass(Dumper, SafeDumper): + SafeRepresenter.add_multi_representer(data_type, multi_representer) + elif issubclass(Dumper, Dumper): + Representer.add_multi_representer(data_type, multi_representer) + elif issubclass(Dumper, RoundTripDumper): + RoundTripRepresenter.add_multi_representer(data_type, multi_representer) + else: + raise NotImplementedError class YAMLObjectMetaclass(type): @@ -403,8 +512,8 @@ class YAMLObjectMetaclass(type): # type: (Any, Any, Any) -> None super(YAMLObjectMetaclass, cls).__init__(name, bases, kwds) if 'yaml_tag' in kwds and kwds['yaml_tag'] is not None: - cls.yaml_loader.add_constructor(cls.yaml_tag, cls.from_yaml) # type: ignore - cls.yaml_dumper.add_representer(cls, cls.to_yaml) # type: ignore + cls.yaml_constructor.add_constructor(cls.yaml_tag, cls.from_yaml) # type: ignore + cls.yaml_representer.add_representer(cls, cls.to_yaml) # type: ignore class YAMLObject(with_metaclass(YAMLObjectMetaclass)): # type: ignore @@ -414,25 +523,25 @@ class YAMLObject(with_metaclass(YAMLObjectMetaclass)): # type: ignore """ __slots__ = () # no direct instantiation, so allow immutable subclasses - yaml_loader = Loader - yaml_dumper = Dumper + yaml_constructor = Constructor + yaml_representer = Representer yaml_tag = None # type: Any yaml_flow_style = None # type: Any @classmethod - def from_yaml(cls, loader, node): + def from_yaml(cls, constructor, node): # type: (Any, Any) -> Any """ Convert a representation node to a Python object. """ - return loader.construct_yaml_object(node, cls) + return constructor.construct_yaml_object(node, cls) @classmethod - def to_yaml(cls, dumper, data): + def to_yaml(cls, representer, data): # type: (Any, Any) -> Any """ Convert a Python object to a representation node. """ - return dumper.represent_yaml_object(cls.yaml_tag, data, cls, - flow_style=cls.yaml_flow_style) + return representer.represent_yaml_object(cls.yaml_tag, data, cls, + flow_style=cls.yaml_flow_style) diff --git a/nodes.py b/nodes.py index b518513..b96524c 100644 --- a/nodes.py +++ b/nodes.py @@ -2,11 +2,14 @@ from __future__ import print_function +from typing import Dict, Any, Text # NOQA + class Node(object): __slots__ = 'tag', 'value', 'start_mark', 'end_mark', 'comment', 'anchor', def __init__(self, tag, value, start_mark, end_mark, comment=None): + # type: (Any, Any, Any, Any, Any) -> None self.tag = tag self.value = value self.start_mark = start_mark @@ -15,6 +18,7 @@ class Node(object): self.anchor = None def __repr__(self): + # type: () -> str value = self.value # if isinstance(value, list): # if len(value) == 0: @@ -33,6 +37,7 @@ class Node(object): self.tag, value) def dump(self, indent=0): + # type: (int) -> None if isinstance(self.value, basestring): print('{}{}(tag={!r}, value={!r})'.format( ' ' * indent, self.__class__.__name__, self.tag, self.value)) @@ -69,6 +74,7 @@ class ScalarNode(Node): def __init__(self, tag, value, start_mark=None, end_mark=None, style=None, comment=None): + # type: (Any, Any, Any, Any, Any, Any) -> None Node.__init__(self, tag, value, start_mark, end_mark, comment=comment) self.style = style @@ -78,6 +84,7 @@ class CollectionNode(Node): def __init__(self, tag, value, start_mark=None, end_mark=None, flow_style=None, comment=None, anchor=None): + # type: (Any, Any, Any, Any, Any, Any, Any) -> None Node.__init__(self, tag, value, start_mark, end_mark, comment=comment) self.flow_style = flow_style self.anchor = anchor diff --git a/parser.py b/parser.py index dc5d57f..653eb68 100644 --- a/parser.py +++ b/parser.py @@ -74,7 +74,10 @@ from __future__ import absolute_import # need to have full path, as pkg_resources tries to load parser.py in __init__.py # only to not do anything with the package afterwards # and for Jython too -from ruamel.yaml.error import MarkedYAMLError # type: ignore + +from typing import Any, Dict, Optional, List # NOQA + +from ruamel.yaml.error import MarkedYAMLError from ruamel.yaml.tokens import * # NOQA from ruamel.yaml.events import * # NOQA from ruamel.yaml.scanner import Scanner, RoundTripScanner, ScannerError # NOQA @@ -96,20 +99,31 @@ class Parser(object): u'!!': u'tag:yaml.org,2002:', } - def __init__(self): + def __init__(self, loader): + # type: (Any) -> None + self.loader = loader + if self.loader is not None: + self.loader._parser = self self.current_event = None self.yaml_version = None - self.tag_handles = {} - self.states = [] - self.marks = [] - self.state = self.parse_stream_start + self.tag_handles = {} # type: Dict[Any, Any] + self.states = [] # type: List[Any] + self.marks = [] # type: List[Any] + self.state = self.parse_stream_start # type: Any + + @property + def scanner(self): + # type: () -> Any + return self.loader._scanner def dispose(self): + # type: () -> None # Reset the state attributes (to clear self-references) self.states = [] self.state = None def check_event(self, *choices): + # type: (Any) -> bool # Check the type of the next event. if self.current_event is None: if self.state: @@ -123,6 +137,7 @@ class Parser(object): return False def peek_event(self): + # type: () -> Any # Get the next event. if self.current_event is None: if self.state: @@ -130,6 +145,7 @@ class Parser(object): return self.current_event def get_event(self): + # type: () -> Any # Get the next event and proceed further. if self.current_event is None: if self.state: @@ -144,10 +160,10 @@ class Parser(object): # explicit_document ::= DIRECTIVE* DOCUMENT-START block_node? DOCUMENT-END* def parse_stream_start(self): - + # type: () -> Any # Parse the stream start. - token = self.get_token() - token.move_comment(self.peek_token()) + token = self.scanner.get_token() + token.move_comment(self.scanner.peek_token()) event = StreamStartEvent(token.start_mark, token.end_mark, encoding=token.encoding) @@ -157,12 +173,12 @@ class Parser(object): return event def parse_implicit_document_start(self): - + # type: () -> Any # Parse an implicit document. - if not self.check_token(DirectiveToken, DocumentStartToken, - StreamEndToken): + if not self.scanner.check_token(DirectiveToken, DocumentStartToken, + StreamEndToken): self.tag_handles = self.DEFAULT_TAGS - token = self.peek_token() + token = self.scanner.peek_token() start_mark = end_mark = token.start_mark event = DocumentStartEvent(start_mark, end_mark, explicit=False) @@ -177,31 +193,30 @@ class Parser(object): return self.parse_document_start() def parse_document_start(self): - + # type: () -> Any # Parse any extra document end indicators. - while self.check_token(DocumentEndToken): - self.get_token() - + while self.scanner.check_token(DocumentEndToken): + self.scanner.get_token() # Parse an explicit document. - if not self.check_token(StreamEndToken): - token = self.peek_token() + if not self.scanner.check_token(StreamEndToken): + token = self.scanner.peek_token() start_mark = token.start_mark version, tags = self.process_directives() - if not self.check_token(DocumentStartToken): + if not self.scanner.check_token(DocumentStartToken): raise ParserError(None, None, "expected '', but found %r" - % self.peek_token().id, - self.peek_token().start_mark) - token = self.get_token() + % self.scanner.peek_token().id, + self.scanner.peek_token().start_mark) + token = self.scanner.get_token() end_mark = token.end_mark event = DocumentStartEvent( start_mark, end_mark, - explicit=True, version=version, tags=tags) + explicit=True, version=version, tags=tags) # type: Any self.states.append(self.parse_document_end) self.state = self.parse_document_content else: # Parse the end of the stream. - token = self.get_token() + token = self.scanner.get_token() event = StreamEndEvent(token.start_mark, token.end_mark, comment=token.comment) assert not self.states @@ -210,13 +225,13 @@ class Parser(object): return event def parse_document_end(self): - + # type: () -> Any # Parse the document end. - token = self.peek_token() + token = self.scanner.peek_token() start_mark = end_mark = token.start_mark explicit = False - if self.check_token(DocumentEndToken): - token = self.get_token() + if self.scanner.check_token(DocumentEndToken): + token = self.scanner.get_token() end_mark = token.end_mark explicit = True event = DocumentEndEvent(start_mark, end_mark, explicit=explicit) @@ -227,20 +242,22 @@ class Parser(object): return event def parse_document_content(self): - if self.check_token( + # type: () -> Any + if self.scanner.check_token( DirectiveToken, DocumentStartToken, DocumentEndToken, StreamEndToken): - event = self.process_empty_scalar(self.peek_token().start_mark) + event = self.process_empty_scalar(self.scanner.peek_token().start_mark) self.state = self.states.pop() return event else: return self.parse_block_node() def process_directives(self): + # type: () -> Any self.yaml_version = None self.tag_handles = {} - while self.check_token(DirectiveToken): - token = self.get_token() + while self.scanner.check_token(DirectiveToken): + token = self.scanner.get_token() if token.name == u'YAML': if self.yaml_version is not None: raise ParserError( @@ -261,8 +278,8 @@ class Parser(object): "duplicate tag handle %r" % utf8(handle), token.start_mark) self.tag_handles[handle] = prefix - if self.tag_handles: - value = self.yaml_version, self.tag_handles.copy() + if bool(self.tag_handles): + value = self.yaml_version, self.tag_handles.copy() # type: Any else: value = self.yaml_version, None for key in self.DEFAULT_TAGS: @@ -287,43 +304,48 @@ class Parser(object): # flow_collection ::= flow_sequence | flow_mapping def parse_block_node(self): + # type: () -> Any return self.parse_node(block=True) def parse_flow_node(self): + # type: () -> Any return self.parse_node() def parse_block_node_or_indentless_sequence(self): + # type: () -> Any return self.parse_node(block=True, indentless_sequence=True) def transform_tag(self, handle, suffix): + # type: (Any, Any) -> Any return self.tag_handles[handle] + suffix def parse_node(self, block=False, indentless_sequence=False): - if self.check_token(AliasToken): - token = self.get_token() - event = AliasEvent(token.value, token.start_mark, token.end_mark) + # type: (bool, bool) -> Any + if self.scanner.check_token(AliasToken): + token = self.scanner.get_token() + event = AliasEvent(token.value, token.start_mark, token.end_mark) # type: Any self.state = self.states.pop() else: anchor = None tag = None start_mark = end_mark = tag_mark = None - if self.check_token(AnchorToken): - token = self.get_token() + if self.scanner.check_token(AnchorToken): + token = self.scanner.get_token() start_mark = token.start_mark end_mark = token.end_mark anchor = token.value - if self.check_token(TagToken): - token = self.get_token() + if self.scanner.check_token(TagToken): + token = self.scanner.get_token() tag_mark = token.start_mark end_mark = token.end_mark tag = token.value - elif self.check_token(TagToken): - token = self.get_token() + elif self.scanner.check_token(TagToken): + token = self.scanner.get_token() start_mark = tag_mark = token.start_mark end_mark = token.end_mark tag = token.value - if self.check_token(AnchorToken): - token = self.get_token() + if self.scanner.check_token(AnchorToken): + token = self.scanner.get_token() end_mark = token.end_mark anchor = token.value if tag is not None: @@ -343,17 +365,17 @@ class Parser(object): # "Please check 'http://pyyaml.org/wiki/YAMLNonSpecificTag' # and share your opinion.") if start_mark is None: - start_mark = end_mark = self.peek_token().start_mark + start_mark = end_mark = self.scanner.peek_token().start_mark event = None implicit = (tag is None or tag == u'!') - if indentless_sequence and self.check_token(BlockEntryToken): - end_mark = self.peek_token().end_mark + if indentless_sequence and self.scanner.check_token(BlockEntryToken): + end_mark = self.scanner.peek_token().end_mark event = SequenceStartEvent(anchor, tag, implicit, start_mark, end_mark) self.state = self.parse_indentless_sequence_entry else: - if self.check_token(ScalarToken): - token = self.get_token() + if self.scanner.check_token(ScalarToken): + token = self.scanner.get_token() end_mark = token.end_mark if (token.plain and tag is None) or tag == u'!': implicit = (True, False) @@ -367,23 +389,23 @@ class Parser(object): comment=token.comment ) self.state = self.states.pop() - elif self.check_token(FlowSequenceStartToken): - end_mark = self.peek_token().end_mark + elif self.scanner.check_token(FlowSequenceStartToken): + end_mark = self.scanner.peek_token().end_mark event = SequenceStartEvent( anchor, tag, implicit, start_mark, end_mark, flow_style=True) self.state = self.parse_flow_sequence_first_entry - elif self.check_token(FlowMappingStartToken): - end_mark = self.peek_token().end_mark + elif self.scanner.check_token(FlowMappingStartToken): + end_mark = self.scanner.peek_token().end_mark event = MappingStartEvent( anchor, tag, implicit, start_mark, end_mark, flow_style=True) self.state = self.parse_flow_mapping_first_key - elif block and self.check_token(BlockSequenceStartToken): - end_mark = self.peek_token().start_mark + elif block and self.scanner.check_token(BlockSequenceStartToken): + end_mark = self.scanner.peek_token().start_mark # should inserting the comment be dependent on the # indentation? - pt = self.peek_token() + pt = self.scanner.peek_token() comment = pt.comment # print('pt0', type(pt)) if comment is None or comment[1] is None: @@ -395,9 +417,9 @@ class Parser(object): comment=comment, ) self.state = self.parse_block_sequence_first_entry - elif block and self.check_token(BlockMappingStartToken): - end_mark = self.peek_token().start_mark - comment = self.peek_token().comment + elif block and self.scanner.check_token(BlockMappingStartToken): + end_mark = self.scanner.peek_token().start_mark + comment = self.scanner.peek_token().comment event = MappingStartEvent( anchor, tag, implicit, start_mark, end_mark, flow_style=False, comment=comment) @@ -413,7 +435,7 @@ class Parser(object): node = 'block' else: node = 'flow' - token = self.peek_token() + token = self.scanner.peek_token() raise ParserError( "while parsing a %s node" % node, start_mark, "expected the node content, but found %r" % token.id, @@ -424,29 +446,31 @@ class Parser(object): # BLOCK-END def parse_block_sequence_first_entry(self): - token = self.get_token() + # type: () -> Any + token = self.scanner.get_token() # move any comment from start token - # token.move_comment(self.peek_token()) + # token.move_comment(self.scanner.peek_token()) self.marks.append(token.start_mark) return self.parse_block_sequence_entry() def parse_block_sequence_entry(self): - if self.check_token(BlockEntryToken): - token = self.get_token() - token.move_comment(self.peek_token()) - if not self.check_token(BlockEntryToken, BlockEndToken): + # type: () -> Any + if self.scanner.check_token(BlockEntryToken): + token = self.scanner.get_token() + token.move_comment(self.scanner.peek_token()) + if not self.scanner.check_token(BlockEntryToken, BlockEndToken): self.states.append(self.parse_block_sequence_entry) return self.parse_block_node() else: self.state = self.parse_block_sequence_entry return self.process_empty_scalar(token.end_mark) - if not self.check_token(BlockEndToken): - token = self.peek_token() + if not self.scanner.check_token(BlockEndToken): + token = self.scanner.peek_token() raise ParserError( "while parsing a block collection", self.marks[-1], "expected , but found %r" % token.id, token.start_mark) - token = self.get_token() # BlockEndToken + token = self.scanner.get_token() # BlockEndToken event = SequenceEndEvent(token.start_mark, token.end_mark, comment=token.comment) self.state = self.states.pop() @@ -461,17 +485,18 @@ class Parser(object): # - nested def parse_indentless_sequence_entry(self): - if self.check_token(BlockEntryToken): - token = self.get_token() - token.move_comment(self.peek_token()) - if not self.check_token(BlockEntryToken, - KeyToken, ValueToken, BlockEndToken): + # type: () -> Any + if self.scanner.check_token(BlockEntryToken): + token = self.scanner.get_token() + token.move_comment(self.scanner.peek_token()) + if not self.scanner.check_token(BlockEntryToken, + KeyToken, ValueToken, BlockEndToken): self.states.append(self.parse_indentless_sequence_entry) return self.parse_block_node() else: self.state = self.parse_indentless_sequence_entry return self.process_empty_scalar(token.end_mark) - token = self.peek_token() + token = self.scanner.peek_token() event = SequenceEndEvent(token.start_mark, token.start_mark, comment=token.comment) self.state = self.states.pop() @@ -483,28 +508,30 @@ class Parser(object): # BLOCK-END def parse_block_mapping_first_key(self): - token = self.get_token() + # type: () -> Any + token = self.scanner.get_token() self.marks.append(token.start_mark) return self.parse_block_mapping_key() def parse_block_mapping_key(self): - if self.check_token(KeyToken): - token = self.get_token() - token.move_comment(self.peek_token()) - if not self.check_token(KeyToken, ValueToken, BlockEndToken): + # type: () -> Any + if self.scanner.check_token(KeyToken): + token = self.scanner.get_token() + token.move_comment(self.scanner.peek_token()) + if not self.scanner.check_token(KeyToken, ValueToken, BlockEndToken): self.states.append(self.parse_block_mapping_value) return self.parse_block_node_or_indentless_sequence() else: self.state = self.parse_block_mapping_value return self.process_empty_scalar(token.end_mark) - if not self.check_token(BlockEndToken): - token = self.peek_token() + if not self.scanner.check_token(BlockEndToken): + token = self.scanner.peek_token() raise ParserError( "while parsing a block mapping", self.marks[-1], "expected , but found %r" % token.id, token.start_mark) - token = self.get_token() - token.move_comment(self.peek_token()) + token = self.scanner.get_token() + token.move_comment(self.scanner.peek_token()) event = MappingEndEvent(token.start_mark, token.end_mark, comment=token.comment) self.state = self.states.pop() @@ -512,23 +539,24 @@ class Parser(object): return event def parse_block_mapping_value(self): - if self.check_token(ValueToken): - token = self.get_token() + # type: () -> Any + if self.scanner.check_token(ValueToken): + token = self.scanner.get_token() # value token might have post comment move it to e.g. block - if self.check_token(ValueToken): - token.move_comment(self.peek_token()) + if self.scanner.check_token(ValueToken): + token.move_comment(self.scanner.peek_token()) else: - token.move_comment(self.peek_token(), empty=True) - if not self.check_token(KeyToken, ValueToken, BlockEndToken): + token.move_comment(self.scanner.peek_token(), empty=True) + if not self.scanner.check_token(KeyToken, ValueToken, BlockEndToken): self.states.append(self.parse_block_mapping_key) return self.parse_block_node_or_indentless_sequence() else: self.state = self.parse_block_mapping_key return self.process_empty_scalar(token.end_mark, - comment=self.peek_token().comment) + comment=self.scanner.peek_token().comment) else: self.state = self.parse_block_mapping_key - token = self.peek_token() + token = self.scanner.peek_token() return self.process_empty_scalar(token.start_mark) # flow_sequence ::= FLOW-SEQUENCE-START @@ -543,33 +571,35 @@ class Parser(object): # generate an inline mapping (set syntax). def parse_flow_sequence_first_entry(self): - token = self.get_token() + # type: () -> Any + token = self.scanner.get_token() self.marks.append(token.start_mark) return self.parse_flow_sequence_entry(first=True) def parse_flow_sequence_entry(self, first=False): - if not self.check_token(FlowSequenceEndToken): + # type: (bool) -> Any + if not self.scanner.check_token(FlowSequenceEndToken): if not first: - if self.check_token(FlowEntryToken): - self.get_token() + if self.scanner.check_token(FlowEntryToken): + self.scanner.get_token() else: - token = self.peek_token() + token = self.scanner.peek_token() raise ParserError( "while parsing a flow sequence", self.marks[-1], "expected ',' or ']', but got %r" % token.id, token.start_mark) - if self.check_token(KeyToken): - token = self.peek_token() + if self.scanner.check_token(KeyToken): + token = self.scanner.peek_token() event = MappingStartEvent(None, None, True, token.start_mark, token.end_mark, - flow_style=True) + flow_style=True) # type: Any self.state = self.parse_flow_sequence_entry_mapping_key return event - elif not self.check_token(FlowSequenceEndToken): + elif not self.scanner.check_token(FlowSequenceEndToken): self.states.append(self.parse_flow_sequence_entry) return self.parse_flow_node() - token = self.get_token() + token = self.scanner.get_token() event = SequenceEndEvent(token.start_mark, token.end_mark, comment=token.comment) self.state = self.states.pop() @@ -577,9 +607,10 @@ class Parser(object): return event def parse_flow_sequence_entry_mapping_key(self): - token = self.get_token() - if not self.check_token(ValueToken, - FlowEntryToken, FlowSequenceEndToken): + # type: () -> Any + token = self.scanner.get_token() + if not self.scanner.check_token(ValueToken, + FlowEntryToken, FlowSequenceEndToken): self.states.append(self.parse_flow_sequence_entry_mapping_value) return self.parse_flow_node() else: @@ -587,9 +618,10 @@ class Parser(object): return self.process_empty_scalar(token.end_mark) def parse_flow_sequence_entry_mapping_value(self): - if self.check_token(ValueToken): - token = self.get_token() - if not self.check_token(FlowEntryToken, FlowSequenceEndToken): + # type: () -> Any + if self.scanner.check_token(ValueToken): + token = self.scanner.get_token() + if not self.scanner.check_token(FlowEntryToken, FlowSequenceEndToken): self.states.append(self.parse_flow_sequence_entry_mapping_end) return self.parse_flow_node() else: @@ -597,12 +629,13 @@ class Parser(object): return self.process_empty_scalar(token.end_mark) else: self.state = self.parse_flow_sequence_entry_mapping_end - token = self.peek_token() + token = self.scanner.peek_token() return self.process_empty_scalar(token.start_mark) def parse_flow_sequence_entry_mapping_end(self): + # type: () -> Any self.state = self.parse_flow_sequence_entry - token = self.peek_token() + token = self.scanner.peek_token() return MappingEndEvent(token.start_mark, token.start_mark) # flow_mapping ::= FLOW-MAPPING-START @@ -612,34 +645,36 @@ class Parser(object): # flow_mapping_entry ::= flow_node | KEY flow_node? (VALUE flow_node?)? def parse_flow_mapping_first_key(self): - token = self.get_token() + # type: () -> Any + token = self.scanner.get_token() self.marks.append(token.start_mark) return self.parse_flow_mapping_key(first=True) def parse_flow_mapping_key(self, first=False): - if not self.check_token(FlowMappingEndToken): + # type: (Any) -> Any + if not self.scanner.check_token(FlowMappingEndToken): if not first: - if self.check_token(FlowEntryToken): - self.get_token() + if self.scanner.check_token(FlowEntryToken): + self.scanner.get_token() else: - token = self.peek_token() + token = self.scanner.peek_token() raise ParserError( "while parsing a flow mapping", self.marks[-1], "expected ',' or '}', but got %r" % token.id, token.start_mark) - if self.check_token(KeyToken): - token = self.get_token() - if not self.check_token(ValueToken, - FlowEntryToken, FlowMappingEndToken): + if self.scanner.check_token(KeyToken): + token = self.scanner.get_token() + if not self.scanner.check_token(ValueToken, + FlowEntryToken, FlowMappingEndToken): self.states.append(self.parse_flow_mapping_value) return self.parse_flow_node() else: self.state = self.parse_flow_mapping_value return self.process_empty_scalar(token.end_mark) - elif not self.check_token(FlowMappingEndToken): + elif not self.scanner.check_token(FlowMappingEndToken): self.states.append(self.parse_flow_mapping_empty_value) return self.parse_flow_node() - token = self.get_token() + token = self.scanner.get_token() event = MappingEndEvent(token.start_mark, token.end_mark, comment=token.comment) self.state = self.states.pop() @@ -647,9 +682,10 @@ class Parser(object): return event def parse_flow_mapping_value(self): - if self.check_token(ValueToken): - token = self.get_token() - if not self.check_token(FlowEntryToken, FlowMappingEndToken): + # type: () -> Any + if self.scanner.check_token(ValueToken): + token = self.scanner.get_token() + if not self.scanner.check_token(FlowEntryToken, FlowMappingEndToken): self.states.append(self.parse_flow_mapping_key) return self.parse_flow_node() else: @@ -657,20 +693,23 @@ class Parser(object): return self.process_empty_scalar(token.end_mark) else: self.state = self.parse_flow_mapping_key - token = self.peek_token() + token = self.scanner.peek_token() return self.process_empty_scalar(token.start_mark) def parse_flow_mapping_empty_value(self): + # type: () -> Any self.state = self.parse_flow_mapping_key - return self.process_empty_scalar(self.peek_token().start_mark) + return self.process_empty_scalar(self.scanner.peek_token().start_mark) def process_empty_scalar(self, mark, comment=None): + # type: (Any, Any) -> Any return ScalarEvent(None, None, (True, False), u'', mark, mark, comment=comment) class RoundTripParser(Parser): """roundtrip is a safe loader, that wants to see the unmangled tag""" def transform_tag(self, handle, suffix): + # type: (Any, Any) -> Any # return self.tag_handles[handle]+suffix if handle == '!!' and suffix in (u'null', u'bool', u'int', u'float', u'binary', u'timestamp', u'omap', u'pairs', u'set', u'str', diff --git a/reader.py b/reader.py index a7f0c37..a2db2f8 100644 --- a/reader.py +++ b/reader.py @@ -23,10 +23,11 @@ from __future__ import absolute_import import codecs import re -from typing import Any, Dict, Optional, List # NOQA +from typing import Any, Dict, Optional, List, Union, Text # NOQA from ruamel.yaml.error import YAMLError, FileMark, StringMark from ruamel.yaml.compat import text_type, binary_type, PY3 +from ruamel.yaml.compat import StreamTextType # NOQA __all__ = ['Reader', 'ReaderError'] @@ -42,7 +43,7 @@ class ReaderError(YAMLError): self.reason = reason def __str__(self): - # type () -> str + # type: () -> str if isinstance(self.character, binary_type): return "'%s' codec can't decode byte #x%02x: %s\n" \ " in \"%s\", position %d" \ @@ -69,16 +70,20 @@ class Reader(object): # Yeah, it's ugly and slow. - def __init__(self, stream): + def __init__(self, stream, loader=None): + # type: (StreamTextType, Any) -> None + self.loader = loader + if self.loader is not None: + self.loader._reader = self self.name = None - self.stream = None + self.stream = None # type: Any # as .read is called self.stream_pointer = 0 self.eof = True self.buffer = u'' self.pointer = 0 - self.raw_buffer = None + self.raw_buffer = None # type: Any self.raw_decode = None - self.encoding = None + self.encoding = None # type: Union[None, Text] self.index = 0 self.line = 0 self.column = 0 @@ -98,6 +103,7 @@ class Reader(object): self.determine_encoding() def peek(self, index=0): + # type: (int) -> Text try: return self.buffer[self.pointer+index] except IndexError: @@ -105,14 +111,16 @@ class Reader(object): return self.buffer[self.pointer+index] def prefix(self, length=1): + # type: (int) -> Any if self.pointer+length >= len(self.buffer): self.update(length) return self.buffer[self.pointer:self.pointer+length] def forward(self, length=1): + # type: (int) -> None if self.pointer+length+1 >= len(self.buffer): self.update(length+1) - while length: + while length != 0: ch = self.buffer[self.pointer] self.pointer += 1 self.index += 1 @@ -125,25 +133,27 @@ class Reader(object): length -= 1 def get_mark(self): - if self.stream is None: + # type: () -> Any + if self.stream is None and self.stream is None: return StringMark(self.name, self.index, self.line, self.column, self.buffer, self.pointer) else: return FileMark(self.name, self.index, self.line, self.column) def determine_encoding(self): + # type: () -> None while not self.eof and (self.raw_buffer is None or len(self.raw_buffer) < 2): self.update_raw() if isinstance(self.raw_buffer, binary_type): if self.raw_buffer.startswith(codecs.BOM_UTF16_LE): - self.raw_decode = codecs.utf_16_le_decode + self.raw_decode = codecs.utf_16_le_decode # type: ignore self.encoding = 'utf-16-le' elif self.raw_buffer.startswith(codecs.BOM_UTF16_BE): - self.raw_decode = codecs.utf_16_be_decode + self.raw_decode = codecs.utf_16_be_decode # type: ignore self.encoding = 'utf-16-be' else: - self.raw_decode = codecs.utf_8_decode + self.raw_decode = codecs.utf_8_decode # type: ignore self.encoding = 'utf-8' self.update(1) @@ -167,14 +177,16 @@ class Reader(object): UNICODE_SIZE = 2 def check_printable(self, data): + # type: (Any) -> None match = self.NON_PRINTABLE.search(data) - if match: + if bool(match): character = match.group() position = self.index+(len(self.buffer)-self.pointer)+match.start() raise ReaderError(self.name, position, ord(character), 'unicode', "special characters are not allowed") def update(self, length): + # type: (int) -> None if self.raw_buffer is None: return self.buffer = self.buffer[self.pointer:] @@ -194,6 +206,9 @@ class Reader(object): if self.stream is not None: position = self.stream_pointer - \ len(self.raw_buffer) + exc.start + elif self.stream is not None: + position = self.stream_pointer - \ + len(self.raw_buffer) + exc.start else: position = exc.start raise ReaderError(self.name, position, character, @@ -210,6 +225,7 @@ class Reader(object): break def update_raw(self, size=None): + # type: (int) -> None if size is None: size = 4096 if PY3 else 1024 data = self.stream.read(size) diff --git a/representer.py b/representer.py index 00ac0c4..a64d825 100644 --- a/representer.py +++ b/representer.py @@ -3,7 +3,7 @@ from __future__ import absolute_import from __future__ import print_function -from typing import Dict, Any # NOQA +from typing import Dict, List, Any, Union # NOQA from ruamel.yaml.error import * # NOQA from ruamel.yaml.nodes import * # NOQA @@ -15,7 +15,7 @@ import datetime import sys import types if PY3: - import copyreg # type: ignore + import copyreg import base64 else: import copy_reg as copyreg # type: ignore @@ -28,34 +28,46 @@ __all__ = ['BaseRepresenter', 'SafeRepresenter', 'Representer', class RepresenterError(YAMLError): pass +if PY2: + def get_classobj_bases(cls): + # type: (Any) -> Any + bases = [cls] + for base in cls.__bases__: + bases.extend(get_classobj_bases(base)) + return bases + class BaseRepresenter(object): yaml_representers = {} # type: Dict[Any, Any] yaml_multi_representers = {} # type: Dict[Any, Any] - def __init__(self, default_style=None, default_flow_style=None): + def __init__(self, default_style=None, default_flow_style=None, dumper=None): + # type: (Any, Any, Any, Any) -> None + self.dumper = dumper + if self.dumper is not None: + self.dumper._representer = self self.default_style = default_style self.default_flow_style = default_flow_style - self.represented_objects = {} - self.object_keeper = [] - self.alias_key = None + self.represented_objects = {} # type: Dict[Any, Any] + self.object_keeper = [] # type: List[Any] + self.alias_key = None # type: Union[None, int] + + @property + def serializer(self): + # type: () -> Any + return self.dumper._serializer def represent(self, data): + # type: (Any) -> None node = self.represent_data(data) - self.serialize(node) - self.represented_objects = {} - self.object_keeper = [] + self.serializer.serialize(node) + self.represented_objects = {} # type: Dict[Any, Any] + self.object_keeper = [] # type: List[Any] self.alias_key = None - if PY2: - def get_classobj_bases(self, cls): - bases = [cls] - for base in cls.__bases__: - bases.extend(self.get_classobj_bases(base)) - return bases - def represent_data(self, data): + # type: (Any) -> Any if self.ignore_aliases(data): self.alias_key = None else: @@ -73,7 +85,7 @@ class BaseRepresenter(object): if PY2: # if type(data) is types.InstanceType: if isinstance(data, types.InstanceType): - data_types = self.get_classobj_bases(data.__class__) + \ + data_types = get_classobj_bases(data.__class__) + \ list(data_types) if data_types[0] in self.yaml_representers: node = self.yaml_representers[data_types[0]](self, data) @@ -94,6 +106,7 @@ class BaseRepresenter(object): return node def represent_key(self, data): + # type: (Any) -> Any """ David Fraser: Extract a method to represent keys in mappings, so that a subclass can choose not to quote them (for example) @@ -117,6 +130,7 @@ class BaseRepresenter(object): cls.yaml_multi_representers[data_type] = representer def represent_scalar(self, tag, value, style=None): + # type: (Any, Any, Any) -> Any if style is None: style = self.default_style node = ScalarNode(tag, value, style=style) @@ -125,7 +139,8 @@ class BaseRepresenter(object): return node def represent_sequence(self, tag, sequence, flow_style=None): - value = [] + # type: (Any, Any, Any) -> Any + value = [] # type: List[Any] node = SequenceNode(tag, value, flow_style=flow_style) if self.alias_key is not None: self.represented_objects[self.alias_key] = node @@ -143,7 +158,8 @@ class BaseRepresenter(object): return node def represent_omap(self, tag, omap, flow_style=None): - value = [] + # type: (Any, Any, Any) -> Any + value = [] # type: List[Any] node = SequenceNode(tag, value, flow_style=flow_style) if self.alias_key is not None: self.represented_objects[self.alias_key] = node @@ -163,7 +179,8 @@ class BaseRepresenter(object): return node def represent_mapping(self, tag, mapping, flow_style=None): - value = [] + # type: (Any, Any, Any) -> Any + value = [] # type: List[Any] node = MappingNode(tag, value, flow_style=flow_style) if self.alias_key is not None: self.represented_objects[self.alias_key] = node @@ -191,12 +208,14 @@ class BaseRepresenter(object): return node def ignore_aliases(self, data): + # type: (Any) -> bool return False class SafeRepresenter(BaseRepresenter): def ignore_aliases(self, data): + # type: (Any) -> bool # https://docs.python.org/3/reference/expressions.html#parenthesized-forms : # "i.e. two occurrences of the empty tuple may or may not yield the same object" # so "data is ()" should not be used @@ -204,16 +223,19 @@ class SafeRepresenter(BaseRepresenter): return True if isinstance(data, (binary_type, text_type, bool, int, float)): return True + return False def represent_none(self, data): - return self.represent_scalar(u'tag:yaml.org,2002:null', - u'null') + # type: (Any) -> Any + return self.represent_scalar(u'tag:yaml.org,2002:null', u'null') if PY3: def represent_str(self, data): + # type: (Any) -> Any return self.represent_scalar(u'tag:yaml.org,2002:str', data) def represent_binary(self, data): + # type: (Any) -> Any if hasattr(base64, 'encodebytes'): data = base64.encodebytes(data).decode('ascii') else: @@ -222,6 +244,7 @@ class SafeRepresenter(BaseRepresenter): style='|') else: def represent_str(self, data): + # type: (Any) -> Any tag = None style = None try: @@ -238,9 +261,11 @@ class SafeRepresenter(BaseRepresenter): return self.represent_scalar(tag, data, style=style) def represent_unicode(self, data): + # type: (Any) -> Any return self.represent_scalar(u'tag:yaml.org,2002:str', data) def represent_bool(self, data): + # type: (Any) -> Any if data: value = u'true' else: @@ -248,10 +273,12 @@ class SafeRepresenter(BaseRepresenter): return self.represent_scalar(u'tag:yaml.org,2002:bool', value) def represent_int(self, data): + # type: (Any) -> Any return self.represent_scalar(u'tag:yaml.org,2002:int', text_type(data)) if PY2: def represent_long(self, data): + # type: (Any) -> Any return self.represent_scalar(u'tag:yaml.org,2002:int', text_type(data)) @@ -260,6 +287,7 @@ class SafeRepresenter(BaseRepresenter): inf_value *= inf_value def represent_float(self, data): + # type: (Any) -> Any if data != data or (data == 0.0 and data == 1.0): value = u'.nan' elif data == self.inf_value: @@ -280,6 +308,7 @@ class SafeRepresenter(BaseRepresenter): return self.represent_scalar(u'tag:yaml.org,2002:float', value) def represent_list(self, data): + # type: (Any) -> Any # pairs = (len(data) > 0 and isinstance(data, list)) # if pairs: # for item in data: @@ -295,26 +324,32 @@ class SafeRepresenter(BaseRepresenter): # return SequenceNode(u'tag:yaml.org,2002:pairs', value) def represent_dict(self, data): + # type: (Any) -> Any return self.represent_mapping(u'tag:yaml.org,2002:map', data) def represent_ordereddict(self, data): + # type: (Any) -> Any return self.represent_omap(u'tag:yaml.org,2002:omap', data) def represent_set(self, data): - value = {} + # type: (Any) -> Any + value = {} # type: Dict[Any, None] for key in data: value[key] = None return self.represent_mapping(u'tag:yaml.org,2002:set', value) def represent_date(self, data): + # type: (Any) -> Any value = to_unicode(data.isoformat()) return self.represent_scalar(u'tag:yaml.org,2002:timestamp', value) def represent_datetime(self, data): + # type: (Any) -> Any value = to_unicode(data.isoformat(' ')) return self.represent_scalar(u'tag:yaml.org,2002:timestamp', value) def represent_yaml_object(self, tag, data, cls, flow_style=None): + # type: (Any, Any, Any, Any) -> Any if hasattr(data, '__getstate__'): state = data.__getstate__() else: @@ -322,6 +357,7 @@ class SafeRepresenter(BaseRepresenter): return self.represent_mapping(tag, state, flow_style=flow_style) def represent_undefined(self, data): + # type: (Any) -> None raise RepresenterError("cannot represent an object: %s" % data) SafeRepresenter.add_representer(type(None), @@ -383,6 +419,7 @@ SafeRepresenter.add_representer(None, class Representer(SafeRepresenter): if PY2: def represent_str(self, data): + # type: (Any) -> Any tag = None style = None try: @@ -399,6 +436,7 @@ class Representer(SafeRepresenter): return self.represent_scalar(tag, data, style=style) def represent_unicode(self, data): + # type: (Any) -> Any tag = None try: data.encode('ascii') @@ -408,12 +446,14 @@ class Representer(SafeRepresenter): return self.represent_scalar(tag, data) def represent_long(self, data): + # type: (Any) -> Any tag = u'tag:yaml.org,2002:int' if int(data) is not data: tag = u'tag:yaml.org,2002:python/long' return self.represent_scalar(tag, to_unicode(data)) def represent_complex(self, data): + # type: (Any) -> Any if data.imag == 0.0: data = u'%r' % data.real elif data.real == 0.0: @@ -425,19 +465,23 @@ class Representer(SafeRepresenter): return self.represent_scalar(u'tag:yaml.org,2002:python/complex', data) def represent_tuple(self, data): + # type: (Any) -> Any return self.represent_sequence(u'tag:yaml.org,2002:python/tuple', data) def represent_name(self, data): + # type: (Any) -> Any name = u'%s.%s' % (data.__module__, data.__name__) return self.represent_scalar(u'tag:yaml.org,2002:python/name:' + name, u'') def represent_module(self, data): + # type: (Any) -> Any return self.represent_scalar( u'tag:yaml.org,2002:python/module:'+data.__name__, u'') if PY2: def represent_instance(self, data): + # type: (Any) -> Any # For instances of classic classes, we use __getinitargs__ and # __getstate__ to serialize the data. @@ -473,13 +517,14 @@ class Representer(SafeRepresenter): u'tag:yaml.org,2002:python/object/new:' + class_name, args) value = {} - if args: + if bool(args): value['args'] = args value['state'] = state return self.represent_mapping( u'tag:yaml.org,2002:python/object/new:'+class_name, value) def represent_object(self, data): + # type: (Any) -> Any # We use __reduce__ API to save the data. data.__reduce__ returns # a tuple of length 2-5: # (function, args, state, listitems, dictitems) @@ -589,17 +634,21 @@ class RoundTripRepresenter(SafeRepresenter): # need to add type here and write out the .comment # in serializer and emitter - def __init__(self, default_style=None, default_flow_style=None): + def __init__(self, default_style=None, default_flow_style=None, dumper=None): + # type: (Any, Any, Any) -> None if default_flow_style is None: default_flow_style = False SafeRepresenter.__init__(self, default_style=default_style, - default_flow_style=default_flow_style) + default_flow_style=default_flow_style, + dumper=dumper) def represent_none(self, data): + # type: (Any) -> Any return self.represent_scalar(u'tag:yaml.org,2002:null', u'') def represent_preserved_scalarstring(self, data): + # type: (Any) -> Any tag = None style = '|' if PY2 and not isinstance(data, unicode): @@ -608,6 +657,7 @@ class RoundTripRepresenter(SafeRepresenter): return self.represent_scalar(tag, data, style=style) def represent_single_quoted_scalarstring(self, data): + # type: (Any) -> Any tag = None style = "'" if PY2 and not isinstance(data, unicode): @@ -616,6 +666,7 @@ class RoundTripRepresenter(SafeRepresenter): return self.represent_scalar(tag, data, style=style) def represent_double_quoted_scalarstring(self, data): + # type: (Any) -> Any tag = None style = '"' if PY2 and not isinstance(data, unicode): @@ -624,7 +675,8 @@ class RoundTripRepresenter(SafeRepresenter): return self.represent_scalar(tag, data, style=style) def represent_sequence(self, tag, sequence, flow_style=None): - value = [] + # type: (Any, Any, Any) -> Any + value = [] # type: List[Any] # if the flow_style is None, the flow style tacked on to the object # explicitly will be taken. If that is None as well the default flow # style rules @@ -664,6 +716,7 @@ class RoundTripRepresenter(SafeRepresenter): return node def represent_key(self, data): + # type: (Any) -> Any if isinstance(data, CommentedKeySeq): self.alias_key = None return self.represent_sequence(u'tag:yaml.org,2002:seq', data, @@ -671,7 +724,8 @@ class RoundTripRepresenter(SafeRepresenter): return SafeRepresenter.represent_key(self, data) def represent_mapping(self, tag, mapping, flow_style=None): - value = [] + # type: (Any, Any, Any) -> Any + value = [] # type: List[Any] try: flow_style = mapping.fa.flow_style(flow_style) except AttributeError: @@ -703,7 +757,7 @@ class RoundTripRepresenter(SafeRepresenter): except AttributeError: item_comments = {} merge_list = [m[1] for m in getattr(mapping, merge_attrib, [])] - if merge_list: + if bool(merge_list): items = mapping.non_merged_items() else: items = mapping.items() @@ -731,7 +785,7 @@ class RoundTripRepresenter(SafeRepresenter): node.flow_style = self.default_flow_style else: node.flow_style = best_style - if merge_list: + if bool(merge_list): # because of the call to represent_data here, the anchors # are marked as being used and thereby created if len(merge_list) == 1: @@ -744,7 +798,8 @@ class RoundTripRepresenter(SafeRepresenter): return node def represent_omap(self, tag, omap, flow_style=None): - value = [] + # type: (Any, Any, Any) -> Any + value = [] # type: List[Any] try: flow_style = omap.fa.flow_style(flow_style) except AttributeError: @@ -802,10 +857,11 @@ class RoundTripRepresenter(SafeRepresenter): return node def represent_set(self, setting): + # type: (Any) -> Any flow_style = False tag = u'tag:yaml.org,2002:set' # return self.represent_mapping(tag, value) - value = [] + value = [] # type: List[Any] flow_style = setting.fa.flow_style(flow_style) try: anchor = setting.yaml_anchor() @@ -851,6 +907,7 @@ class RoundTripRepresenter(SafeRepresenter): return node def represent_dict(self, data): + # type: (Any) -> Any """write out tag if saved on loading""" try: t = data.tag.value @@ -865,6 +922,7 @@ class RoundTripRepresenter(SafeRepresenter): return self.represent_mapping(tag, data) def represent_datetime(self, data): + # type: (Any) -> Any inter = 'T' if data._yaml['t'] else ' ' _yaml = data._yaml if _yaml['delta']: diff --git a/resolver.py b/resolver.py index 38546ac..eab1b41 100644 --- a/resolver.py +++ b/resolver.py @@ -4,11 +4,11 @@ from __future__ import absolute_import import re -from typing import Any, Dict # NOQA +from typing import Any, Dict, List, Union # NOQA from ruamel.yaml.error import * # NOQA from ruamel.yaml.nodes import * # NOQA -from ruamel.yaml.compat import string_types +from ruamel.yaml.compat import string_types, VersionType # NOQA __all__ = ['BaseResolver', 'Resolver', 'VersionedResolver'] @@ -100,11 +100,21 @@ class BaseResolver(object): yaml_implicit_resolvers = {} # type: Dict[Any, Any] yaml_path_resolvers = {} # type: Dict[Any, Any] - def __init__(self): - # type: () -> None - self._loader_version = None - self.resolver_exact_paths = [] - self.resolver_prefix_paths = [] + def __init__(self, loadumper=None): + # type: (Any, Any) -> None + self.loadumper = loadumper + if self.loadumper is not None: + self.loadumper._resolver = self.loadumper + self._loader_version = None # type: Any + self.resolver_exact_paths = [] # type: List[Any] + self.resolver_prefix_paths = [] # type: List[Any] + + @property + def parser(self): + # type: () -> Any + if self.loadumper is not None: + return self.loadumper._parser + return None @classmethod def add_implicit_resolver_base(cls, tag, regexp, first): @@ -138,6 +148,7 @@ class BaseResolver(object): @classmethod def add_path_resolver(cls, tag, path, kind=None): + # type: (Any, Any, Any) -> None # Note: `add_path_resolver` is experimental. The API could be changed. # `new_path` is a pattern that is matched against the path from the # root to the node that is being considered. `node_path` elements are @@ -152,11 +163,11 @@ class BaseResolver(object): # against a sequence value with the index equal to `index_check`. if 'yaml_path_resolvers' not in cls.__dict__: cls.yaml_path_resolvers = cls.yaml_path_resolvers.copy() - new_path = [] + new_path = [] # type: List[Any] for element in path: if isinstance(element, (list, tuple)): if len(element) == 2: - node_check, index_check = element + node_check, index_check = element # type: ignore elif len(element) == 1: node_check = element[0] index_check = True @@ -191,6 +202,7 @@ class BaseResolver(object): cls.yaml_path_resolvers[tuple(new_path), kind] = tag def descend_resolver(self, current_node, current_index): + # type: (Any, Any) -> None if not self.yaml_path_resolvers: return exact_paths = {} @@ -215,6 +227,7 @@ class BaseResolver(object): self.resolver_prefix_paths.append(prefix_paths) def ascend_resolver(self): + # type: () -> None if not self.yaml_path_resolvers: return self.resolver_exact_paths.pop() @@ -222,29 +235,31 @@ class BaseResolver(object): def check_resolver_prefix(self, depth, path, kind, current_node, current_index): + # type: (int, Text, Any, Any, Any) -> bool node_check, index_check = path[depth-1] if isinstance(node_check, string_types): if current_node.tag != node_check: - return + return False elif node_check is not None: if not isinstance(current_node, node_check): - return + return False if index_check is True and current_index is not None: - return + return False if (index_check is False or index_check is None) \ and current_index is None: - return + return False if isinstance(index_check, string_types): if not (isinstance(current_index, ScalarNode) and index_check == current_index.value): - return + return False elif isinstance(index_check, int) and not isinstance(index_check, bool): if index_check != current_index: - return + return False return True def resolve(self, kind, value, implicit): + # type: (Any, Any, Any) -> Any if kind is ScalarNode and implicit[0]: if value == u'': resolvers = self.yaml_implicit_resolvers.get(u'', []) @@ -255,7 +270,7 @@ class BaseResolver(object): if regexp.match(value): return tag implicit = implicit[1] - if self.yaml_path_resolvers: + if bool(self.yaml_path_resolvers): exact_paths = self.resolver_exact_paths[-1] if kind in exact_paths: return exact_paths[kind] @@ -270,6 +285,7 @@ class BaseResolver(object): @property def processing_version(self): + # type: () -> Any return None @@ -341,16 +357,18 @@ class VersionedResolver(BaseResolver): """ contrary to the "normal" resolver, the smart resolver delays loading the pattern matching rules. That way it can decide to load 1.1 rules - or the (default) 1.2 that no longer support octal without 0o, sexagesimals + or the (default) 1.2 rules, that no longer support octal without 0o, sexagesimals and Yes/No/On/Off booleans. """ - def __init__(self, version=None): - BaseResolver.__init__(self) + def __init__(self, version=None, loader=None): + # type: (VersionType, Any) -> None + BaseResolver.__init__(self, loader) self._loader_version = self.get_loader_version(version) - self._version_implicit_resolver = {} + self._version_implicit_resolver = {} # type: Dict[Any, Any] def add_version_implicit_resolver(self, version, tag, regexp, first): + # type: (VersionType, Any, Any, Any) -> None if first is None: first = [None] impl_resolver = self._version_implicit_resolver.setdefault(version, {}) @@ -358,6 +376,7 @@ class VersionedResolver(BaseResolver): impl_resolver.setdefault(ch, []).append((tag, regexp)) def get_loader_version(self, version): + # type: (Union[VersionType, None]) -> Any if version is None or isinstance(version, tuple): return version if isinstance(version, list): @@ -366,7 +385,8 @@ class VersionedResolver(BaseResolver): return tuple(map(int, version.split(u'.'))) @property - def resolver(self): + def versioned_resolver(self): + # type: () -> Any """ select the resolver based on the version we are parsing """ @@ -378,17 +398,18 @@ class VersionedResolver(BaseResolver): return self._version_implicit_resolver[version] def resolve(self, kind, value, implicit): + # type: (Any, Any, Any) -> Any if kind is ScalarNode and implicit[0]: if value == u'': - resolvers = self.resolver.get(u'', []) + resolvers = self.versioned_resolver.get(u'', []) else: - resolvers = self.resolver.get(value[0], []) - resolvers += self.resolver.get(None, []) + resolvers = self.versioned_resolver.get(value[0], []) + resolvers += self.versioned_resolver.get(None, []) for tag, regexp in resolvers: if regexp.match(value): return tag implicit = implicit[1] - if self.yaml_path_resolvers: + if bool(self.yaml_path_resolvers): exact_paths = self.resolver_exact_paths[-1] if kind in exact_paths: return exact_paths[kind] @@ -403,11 +424,11 @@ class VersionedResolver(BaseResolver): @property def processing_version(self): + # type: () -> Any try: - version = self.yaml_version + version = self.parser.yaml_version except AttributeError: - # dumping - version = self.use_version + version = self.loadumper._serializer.use_version # dumping if version is None: version = self._loader_version if version is None: diff --git a/scalarstring.py b/scalarstring.py index c6e5734..da43bcf 100644 --- a/scalarstring.py +++ b/scalarstring.py @@ -3,6 +3,8 @@ from __future__ import absolute_import from __future__ import print_function +from typing import Text, Any, Dict, List # NOQA + from ruamel.yaml.compat import text_type __all__ = ["ScalarString", "PreservedScalarString", "SingleQuotedScalarString", @@ -13,13 +15,15 @@ class ScalarString(text_type): __slots__ = () def __new__(cls, *args, **kw): - return text_type.__new__(cls, *args, **kw) + # type: (Any, Any) -> Any + return text_type.__new__(cls, *args, **kw) # type: ignore class PreservedScalarString(ScalarString): __slots__ = () def __new__(cls, value): + # type: (Text) -> Any return ScalarString.__new__(cls, value) @@ -27,6 +31,7 @@ class SingleQuotedScalarString(ScalarString): __slots__ = () def __new__(cls, value): + # type: (Text) -> Any return ScalarString.__new__(cls, value) @@ -34,14 +39,17 @@ class DoubleQuotedScalarString(ScalarString): __slots__ = () def __new__(cls, value): + # type: (Text) -> Any return ScalarString.__new__(cls, value) def preserve_literal(s): + # type: (Text) -> Text return PreservedScalarString(s.replace('\r\n', '\n').replace('\r', '\n')) def walk_tree(base): + # type: (Any) -> None """ the routine here walks over a simple yaml tree (recursing in dict values and list items) and converts strings that @@ -51,15 +59,14 @@ def walk_tree(base): if isinstance(base, dict): for k in base: - v = base[k] + v = base[k] # type: Text if isinstance(v, string_types) and '\n' in v: base[k] = preserve_literal(v) else: walk_tree(v) elif isinstance(base, list): for idx, elem in enumerate(base): - if isinstance(elem, string_types) and '\n' in elem: - print(elem) - base[idx] = preserve_literal(elem) + if isinstance(elem, string_types) and '\n' in elem: # type: ignore + base[idx] = preserve_literal(elem) # type: ignore else: walk_tree(elem) diff --git a/scanner.py b/scanner.py index 51f6cb9..68b043c 100644 --- a/scanner.py +++ b/scanner.py @@ -29,6 +29,9 @@ from __future__ import print_function, absolute_import, division, unicode_litera # # Read comments in the Scanner code for more details. # + +from typing import Any, Dict, Optional, List, Union, Text # NOQA + from ruamel.yaml.error import MarkedYAMLError from ruamel.yaml.tokens import * # NOQA from ruamel.yaml.compat import utf8, unichr, PY3, check_anchorname_char @@ -44,6 +47,7 @@ class SimpleKey(object): # See below simple keys treatment. def __init__(self, token_number, required, index, line, column, mark): + # type: (Any, Any, int, int, int, Any) -> None self.token_number = token_number self.required = required self.index = index @@ -54,7 +58,8 @@ class SimpleKey(object): class Scanner(object): - def __init__(self): + def __init__(self, loader=None): + # type: (Any) -> None """Initialize the scanner.""" # It is assumed that Scanner and Reader will have a common descendant. # Reader do the dirty work of checking for BOM and converting the @@ -65,6 +70,10 @@ class Scanner(object): # self.prefix(l=1) # peek the next l characters # self.forward(l=1) # read the next l characters and move the pointer + self.loader = loader + if self.loader is not None: + self.loader._scanner = self + # Had we reached the end of the stream? self.done = False @@ -73,7 +82,7 @@ class Scanner(object): self.flow_level = 0 # List of processed tokens that are not yet emitted. - self.tokens = [] + self.tokens = [] # type: List[Any] # Add the STREAM-START token. self.fetch_stream_start() @@ -85,7 +94,7 @@ class Scanner(object): self.indent = -1 # Past indentation levels. - self.indents = [] + self.indents = [] # type: List[int] # Variables related to simple keys treatment. @@ -115,15 +124,21 @@ class Scanner(object): # (token_number, required, index, line, column, mark) # A simple key may start with ALIAS, ANCHOR, TAG, SCALAR(flow), # '[', or '{' tokens. - self.possible_simple_keys = {} + self.possible_simple_keys = {} # type: Dict[Any, Any] + + @property + def reader(self): + # type: () -> Any + return self.loader._reader # Public methods. def check_token(self, *choices): + # type: (Any) -> bool # Check if the next token is one of the given types. while self.need_more_tokens(): self.fetch_more_tokens() - if self.tokens: + if bool(self.tokens): if not choices: return True for choice in choices: @@ -132,23 +147,26 @@ class Scanner(object): return False def peek_token(self): + # type: () -> Any # Return the next token, but do not delete if from the queue. while self.need_more_tokens(): self.fetch_more_tokens() - if self.tokens: + if bool(self.tokens): return self.tokens[0] def get_token(self): + # type: () -> Any # Return the next token. while self.need_more_tokens(): self.fetch_more_tokens() - if self.tokens: + if bool(self.tokens): self.tokens_taken += 1 return self.tokens.pop(0) # Private methods. def need_more_tokens(self): + # type: () -> bool if self.done: return False if not self.tokens: @@ -158,24 +176,27 @@ class Scanner(object): self.stale_possible_simple_keys() if self.next_possible_simple_key() == self.tokens_taken: return True + return False - def fetch_more_tokens(self): + def fetch_comment(self, comment): + # type: (Any) -> None + raise NotImplementedError + def fetch_more_tokens(self): + # type: () -> Any # Eat whitespaces and comments until we reach the next token. comment = self.scan_to_next_token() - if comment is not None: # never happens for base scanner return self.fetch_comment(comment) - # Remove obsolete possible simple keys. self.stale_possible_simple_keys() # Compare the current indentation and column. It may add some tokens # and decrease the current indentation level. - self.unwind_indent(self.column) + self.unwind_indent(self.reader.column) # Peek the next character. - ch = self.peek() + ch = self.reader.peek() # Is it the end of stream? if ch == u'\0': @@ -266,11 +287,12 @@ class Scanner(object): # No? It's an error. Let's produce a nice error message. raise ScannerError("while scanning for the next token", None, "found character %r that cannot start any token" - % utf8(ch), self.get_mark()) + % utf8(ch), self.reader.get_mark()) # Simple keys treatment. def next_possible_simple_key(self): + # type: () -> Any # Return the number of the nearest possible simple key. Actually we # don't need to loop through the whole dictionary. We may replace it # with the following code: @@ -286,6 +308,7 @@ class Scanner(object): return min_token_number def stale_possible_simple_keys(self): + # type: () -> None # Remove entries that are no longer possible simple keys. According to # the YAML specification, simple keys # - should be limited to a single line, @@ -294,21 +317,22 @@ class Scanner(object): # height (may cause problems if indentation is broken though). for level in list(self.possible_simple_keys): key = self.possible_simple_keys[level] - if key.line != self.line \ - or self.index-key.index > 1024: + if key.line != self.reader.line \ + or self.reader.index - key.index > 1024: if key.required: raise ScannerError( "while scanning a simple key", key.mark, - "could not find expected ':'", self.get_mark()) + "could not find expected ':'", self.reader.get_mark()) del self.possible_simple_keys[level] def save_possible_simple_key(self): + # type: () -> None # The next token may start a simple key. We check if it's possible # and save its position. This function is called for # ALIAS, ANCHOR, TAG, SCALAR(flow), '[', and '{'. # Check if a simple key is required at the current position. - required = not self.flow_level and self.indent == self.column + required = not self.flow_level and self.indent == self.reader.column # The next token might be a simple key. Let's save it's number and # position. @@ -317,10 +341,12 @@ class Scanner(object): token_number = self.tokens_taken+len(self.tokens) key = SimpleKey( token_number, required, - self.index, self.line, self.column, self.get_mark()) + self.reader.index, self.reader.line, self.reader.column, + self.reader.get_mark()) self.possible_simple_keys[self.flow_level] = key def remove_possible_simple_key(self): + # type: () -> None # Remove the saved possible key position at the current flow level. if self.flow_level in self.possible_simple_keys: key = self.possible_simple_keys[self.flow_level] @@ -328,14 +354,14 @@ class Scanner(object): if key.required: raise ScannerError( "while scanning a simple key", key.mark, - "could not find expected ':'", self.get_mark()) + "could not find expected ':'", self.reader.get_mark()) del self.possible_simple_keys[self.flow_level] # Indentation functions. def unwind_indent(self, column): - + # type: (Any) -> None # In flow context, tokens should respect indentation. # Actually the condition should be `self.indent >= column` according to # the spec. But this condition will prohibit intuitively correct @@ -346,20 +372,21 @@ class Scanner(object): # if self.flow_level and self.indent > column: # raise ScannerError(None, None, # "invalid intendation or unclosed '[' or '{'", - # self.get_mark()) + # self.reader.get_mark()) # In the flow context, indentation is ignored. We make the scanner less # restrictive then specification requires. - if self.flow_level: + if bool(self.flow_level): return # In block context, we may need to issue the BLOCK-END tokens. while self.indent > column: - mark = self.get_mark() + mark = self.reader.get_mark() self.indent = self.indents.pop() self.tokens.append(BlockEndToken(mark, mark)) def add_indent(self, column): + # type: (int) -> bool # Check if we need to increase indentation. if self.indent < column: self.indents.append(self.indent) @@ -370,37 +397,32 @@ class Scanner(object): # Fetchers. def fetch_stream_start(self): + # type: () -> None # We always add STREAM-START as the first token and STREAM-END as the # last token. - # Read the token. - mark = self.get_mark() - + mark = self.reader.get_mark() # Add STREAM-START. self.tokens.append(StreamStartToken(mark, mark, - encoding=self.encoding)) + encoding=self.reader.encoding)) def fetch_stream_end(self): - + # type: () -> None # Set the current intendation to -1. self.unwind_indent(-1) - # Reset simple keys. self.remove_possible_simple_key() self.allow_simple_key = False - self.possible_simple_keys = {} - + self.possible_simple_keys = {} # type: Dict[Any, Any] # Read the token. - mark = self.get_mark() - + mark = self.reader.get_mark() # Add STREAM-END. self.tokens.append(StreamEndToken(mark, mark)) - # The steam is finished. self.done = True def fetch_directive(self): - + # type: () -> None # Set the current intendation to -1. self.unwind_indent(-1) @@ -412,13 +434,15 @@ class Scanner(object): self.tokens.append(self.scan_directive()) def fetch_document_start(self): + # type: () -> None self.fetch_document_indicator(DocumentStartToken) def fetch_document_end(self): + # type: () -> None self.fetch_document_indicator(DocumentEndToken) def fetch_document_indicator(self, TokenClass): - + # type: (Any) -> None # Set the current intendation to -1. self.unwind_indent(-1) @@ -428,106 +452,97 @@ class Scanner(object): self.allow_simple_key = False # Add DOCUMENT-START or DOCUMENT-END. - start_mark = self.get_mark() - self.forward(3) - end_mark = self.get_mark() + start_mark = self.reader.get_mark() + self.reader.forward(3) + end_mark = self.reader.get_mark() self.tokens.append(TokenClass(start_mark, end_mark)) def fetch_flow_sequence_start(self): + # type: () -> None self.fetch_flow_collection_start(FlowSequenceStartToken) def fetch_flow_mapping_start(self): + # type: () -> None self.fetch_flow_collection_start(FlowMappingStartToken) def fetch_flow_collection_start(self, TokenClass): - + # type: (Any) -> None # '[' and '{' may start a simple key. self.save_possible_simple_key() - # Increase the flow level. self.flow_level += 1 - # Simple keys are allowed after '[' and '{'. self.allow_simple_key = True - # Add FLOW-SEQUENCE-START or FLOW-MAPPING-START. - start_mark = self.get_mark() - self.forward() - end_mark = self.get_mark() + start_mark = self.reader.get_mark() + self.reader.forward() + end_mark = self.reader.get_mark() self.tokens.append(TokenClass(start_mark, end_mark)) def fetch_flow_sequence_end(self): + # type: () -> None self.fetch_flow_collection_end(FlowSequenceEndToken) def fetch_flow_mapping_end(self): + # type: () -> None self.fetch_flow_collection_end(FlowMappingEndToken) def fetch_flow_collection_end(self, TokenClass): - + # type: (Any) -> None # Reset possible simple key on the current level. self.remove_possible_simple_key() - # Decrease the flow level. self.flow_level -= 1 - # No simple keys after ']' or '}'. self.allow_simple_key = False - # Add FLOW-SEQUENCE-END or FLOW-MAPPING-END. - start_mark = self.get_mark() - self.forward() - end_mark = self.get_mark() + start_mark = self.reader.get_mark() + self.reader.forward() + end_mark = self.reader.get_mark() self.tokens.append(TokenClass(start_mark, end_mark)) def fetch_flow_entry(self): - + # type: () -> None # Simple keys are allowed after ','. self.allow_simple_key = True - # Reset possible simple key on the current level. self.remove_possible_simple_key() - # Add FLOW-ENTRY. - start_mark = self.get_mark() - self.forward() - end_mark = self.get_mark() + start_mark = self.reader.get_mark() + self.reader.forward() + end_mark = self.reader.get_mark() self.tokens.append(FlowEntryToken(start_mark, end_mark)) def fetch_block_entry(self): - + # type: () -> None # Block context needs additional checks. if not self.flow_level: - # Are we allowed to start a new entry? if not self.allow_simple_key: raise ScannerError(None, None, "sequence entries are not allowed here", - self.get_mark()) - + self.reader.get_mark()) # We may need to add BLOCK-SEQUENCE-START. - if self.add_indent(self.column): - mark = self.get_mark() + if self.add_indent(self.reader.column): + mark = self.reader.get_mark() self.tokens.append(BlockSequenceStartToken(mark, mark)) - # It's an error for the block entry to occur in the flow context, # but we let the parser detect this. else: pass - # Simple keys are allowed after '-'. self.allow_simple_key = True - # Reset possible simple key on the current level. self.remove_possible_simple_key() # Add BLOCK-ENTRY. - start_mark = self.get_mark() - self.forward() - end_mark = self.get_mark() + start_mark = self.reader.get_mark() + self.reader.forward() + end_mark = self.reader.get_mark() self.tokens.append(BlockEntryToken(start_mark, end_mark)) def fetch_key(self): - + # type: () -> None # Block context needs additional checks. if not self.flow_level: @@ -535,11 +550,11 @@ class Scanner(object): if not self.allow_simple_key: raise ScannerError(None, None, "mapping keys are not allowed here", - self.get_mark()) + self.reader.get_mark()) # We may need to add BLOCK-MAPPING-START. - if self.add_indent(self.column): - mark = self.get_mark() + if self.add_indent(self.reader.column): + mark = self.reader.get_mark() self.tokens.append(BlockMappingStartToken(mark, mark)) # Simple keys are allowed after '?' in the block context. @@ -549,13 +564,13 @@ class Scanner(object): self.remove_possible_simple_key() # Add KEY. - start_mark = self.get_mark() - self.forward() - end_mark = self.get_mark() + start_mark = self.reader.get_mark() + self.reader.forward() + end_mark = self.reader.get_mark() self.tokens.append(KeyToken(start_mark, end_mark)) def fetch_value(self): - + # type: () -> None # Do we determine a simple key? if self.flow_level in self.possible_simple_keys: # Add KEY. @@ -588,14 +603,14 @@ class Scanner(object): if not self.allow_simple_key: raise ScannerError(None, None, "mapping values are not allowed here", - self.get_mark()) + self.reader.get_mark()) # If this value starts a new block mapping, we need to add # BLOCK-MAPPING-START. It will be detected as an error later by # the parser. if not self.flow_level: - if self.add_indent(self.column): - mark = self.get_mark() + if self.add_indent(self.reader.column): + mark = self.reader.get_mark() self.tokens.append(BlockMappingStartToken(mark, mark)) # Simple keys are allowed after ':' in the block context. @@ -605,142 +620,134 @@ class Scanner(object): self.remove_possible_simple_key() # Add VALUE. - start_mark = self.get_mark() - self.forward() - end_mark = self.get_mark() + start_mark = self.reader.get_mark() + self.reader.forward() + end_mark = self.reader.get_mark() self.tokens.append(ValueToken(start_mark, end_mark)) def fetch_alias(self): - + # type: () -> None # ALIAS could be a simple key. self.save_possible_simple_key() - # No simple keys after ALIAS. self.allow_simple_key = False - # Scan and add ALIAS. self.tokens.append(self.scan_anchor(AliasToken)) def fetch_anchor(self): - + # type: () -> None # ANCHOR could start a simple key. self.save_possible_simple_key() - # No simple keys after ANCHOR. self.allow_simple_key = False - # Scan and add ANCHOR. self.tokens.append(self.scan_anchor(AnchorToken)) def fetch_tag(self): - + # type: () -> None # TAG could start a simple key. self.save_possible_simple_key() - # No simple keys after TAG. self.allow_simple_key = False - # Scan and add TAG. self.tokens.append(self.scan_tag()) def fetch_literal(self): + # type: () -> None self.fetch_block_scalar(style='|') def fetch_folded(self): + # type: () -> None self.fetch_block_scalar(style='>') def fetch_block_scalar(self, style): - + # type: (Any) -> None # A simple key may follow a block scalar. self.allow_simple_key = True - # Reset possible simple key on the current level. self.remove_possible_simple_key() - # Scan and add SCALAR. self.tokens.append(self.scan_block_scalar(style)) def fetch_single(self): + # type: () -> None self.fetch_flow_scalar(style='\'') def fetch_double(self): + # type: () -> None self.fetch_flow_scalar(style='"') def fetch_flow_scalar(self, style): - + # type: (Any) -> None # A flow scalar could be a simple key. self.save_possible_simple_key() - # No simple keys after flow scalars. self.allow_simple_key = False - # Scan and add SCALAR. self.tokens.append(self.scan_flow_scalar(style)) def fetch_plain(self): - + # type: () -> None # A plain scalar could be a simple key. self.save_possible_simple_key() - # No simple keys after plain scalars. But note that `scan_plain` will # change this flag if the scan is finished at the beginning of the # line. self.allow_simple_key = False - # Scan and add SCALAR. May change `allow_simple_key`. self.tokens.append(self.scan_plain()) # Checkers. def check_directive(self): - + # type: () -> Any # DIRECTIVE: ^ '%' ... # The '%' indicator is already checked. - if self.column == 0: + if self.reader.column == 0: return True + return None def check_document_start(self): - + # type: () -> Any # DOCUMENT-START: ^ '---' (' '|'\n') - if self.column == 0: - if self.prefix(3) == u'---' \ - and self.peek(3) in u'\0 \t\r\n\x85\u2028\u2029': + if self.reader.column == 0: + if self.reader.prefix(3) == u'---' \ + and self.reader.peek(3) in u'\0 \t\r\n\x85\u2028\u2029': return True + return None def check_document_end(self): - + # type: () -> Any # DOCUMENT-END: ^ '...' (' '|'\n') - if self.column == 0: - if self.prefix(3) == u'...' \ - and self.peek(3) in u'\0 \t\r\n\x85\u2028\u2029': + if self.reader.column == 0: + if self.reader.prefix(3) == u'...' \ + and self.reader.peek(3) in u'\0 \t\r\n\x85\u2028\u2029': return True + return None def check_block_entry(self): - + # type: () -> Any # BLOCK-ENTRY: '-' (' '|'\n') - return self.peek(1) in u'\0 \t\r\n\x85\u2028\u2029' + return self.reader.peek(1) in u'\0 \t\r\n\x85\u2028\u2029' def check_key(self): - + # type: () -> Any # KEY(flow context): '?' - if self.flow_level: + if bool(self.flow_level): return True - # KEY(block context): '?' (' '|'\n') - else: - return self.peek(1) in u'\0 \t\r\n\x85\u2028\u2029' + return self.reader.peek(1) in u'\0 \t\r\n\x85\u2028\u2029' def check_value(self): - + # type: () -> Any # VALUE(flow context): ':' - if self.flow_level: + if bool(self.flow_level): return True - # VALUE(block context): ':' (' '|'\n') - else: - return self.peek(1) in u'\0 \t\r\n\x85\u2028\u2029' + return self.reader.peek(1) in u'\0 \t\r\n\x85\u2028\u2029' def check_plain(self): + # type: () -> Any # A plain scalar may start with any non-space character except: # '-', '?', ':', ',', '[', ']', '{', '}', # '#', '&', '*', '!', '|', '>', '\'', '\"', @@ -753,14 +760,15 @@ class Scanner(object): # Note that we limit the last rule to the block context (except the # '-' character) because we want the flow context to be space # independent. - ch = self.peek() + ch = self.reader.peek() return ch not in u'\0 \t\r\n\x85\u2028\u2029-?:,[]{}#&*!|>\'\"%@`' or \ - (self.peek(1) not in u'\0 \t\r\n\x85\u2028\u2029' and + (self.reader.peek(1) not in u'\0 \t\r\n\x85\u2028\u2029' and (ch == u'-' or (not self.flow_level and ch in u'?:'))) # Scanners. def scan_to_next_token(self): + # type: () -> Any # We ignore spaces, line breaks and comments. # If we find a line break in the block context, we set the flag # `allow_simple_key` on. @@ -780,145 +788,155 @@ class Scanner(object): # `unwind_indent` before issuing BLOCK-END. # Scanners for block, flow, and plain scalars need to be modified. - if self.index == 0 and self.peek() == u'\uFEFF': - self.forward() + if self.reader.index == 0 and self.reader.peek() == u'\uFEFF': + self.reader.forward() found = False while not found: - while self.peek() == u' ': - self.forward() - if self.peek() == u'#': - while self.peek() not in u'\0\r\n\x85\u2028\u2029': - self.forward() + while self.reader.peek() == u' ': + self.reader.forward() + if self.reader.peek() == u'#': + while self.reader.peek() not in u'\0\r\n\x85\u2028\u2029': + self.reader.forward() if self.scan_line_break(): if not self.flow_level: self.allow_simple_key = True else: found = True + return None def scan_directive(self): + # type: () -> Any # See the specification for details. - start_mark = self.get_mark() - self.forward() + start_mark = self.reader.get_mark() + self.reader.forward() name = self.scan_directive_name(start_mark) value = None if name == u'YAML': value = self.scan_yaml_directive_value(start_mark) - end_mark = self.get_mark() + end_mark = self.reader.get_mark() elif name == u'TAG': value = self.scan_tag_directive_value(start_mark) - end_mark = self.get_mark() + end_mark = self.reader.get_mark() else: - end_mark = self.get_mark() - while self.peek() not in u'\0\r\n\x85\u2028\u2029': - self.forward() + end_mark = self.reader.get_mark() + while self.reader.peek() not in u'\0\r\n\x85\u2028\u2029': + self.reader.forward() self.scan_directive_ignored_line(start_mark) return DirectiveToken(name, value, start_mark, end_mark) def scan_directive_name(self, start_mark): + # type: (Any) -> Any # See the specification for details. length = 0 - ch = self.peek(length) + ch = self.reader.peek(length) while u'0' <= ch <= u'9' or u'A' <= ch <= u'Z' or u'a' <= ch <= u'z' \ or ch in u'-_:.': length += 1 - ch = self.peek(length) + ch = self.reader.peek(length) if not length: raise ScannerError( "while scanning a directive", start_mark, "expected alphabetic or numeric character, but found %r" - % utf8(ch), self.get_mark()) - value = self.prefix(length) - self.forward(length) - ch = self.peek() + % utf8(ch), self.reader.get_mark()) + value = self.reader.prefix(length) + self.reader.forward(length) + ch = self.reader.peek() if ch not in u'\0 \r\n\x85\u2028\u2029': raise ScannerError( "while scanning a directive", start_mark, "expected alphabetic or numeric character, but found %r" - % utf8(ch), self.get_mark()) + % utf8(ch), self.reader.get_mark()) return value def scan_yaml_directive_value(self, start_mark): + # type: (Any) -> Any # See the specification for details. - while self.peek() == u' ': - self.forward() + while self.reader.peek() == u' ': + self.reader.forward() major = self.scan_yaml_directive_number(start_mark) - if self.peek() != '.': + if self.reader.peek() != '.': raise ScannerError( "while scanning a directive", start_mark, "expected a digit or '.', but found %r" - % utf8(self.peek()), - self.get_mark()) - self.forward() + % utf8(self.reader.peek()), + self.reader.get_mark()) + self.reader.forward() minor = self.scan_yaml_directive_number(start_mark) - if self.peek() not in u'\0 \r\n\x85\u2028\u2029': + if self.reader.peek() not in u'\0 \r\n\x85\u2028\u2029': raise ScannerError( "while scanning a directive", start_mark, "expected a digit or ' ', but found %r" - % utf8(self.peek()), - self.get_mark()) + % utf8(self.reader.peek()), + self.reader.get_mark()) return (major, minor) def scan_yaml_directive_number(self, start_mark): + # type: (Any) -> Any # See the specification for details. - ch = self.peek() + ch = self.reader.peek() if not (u'0' <= ch <= u'9'): raise ScannerError( "while scanning a directive", start_mark, "expected a digit, but found %r" % utf8(ch), - self.get_mark()) + self.reader.get_mark()) length = 0 - while u'0' <= self.peek(length) <= u'9': + while u'0' <= self.reader.peek(length) <= u'9': length += 1 - value = int(self.prefix(length)) - self.forward(length) + value = int(self.reader.prefix(length)) + self.reader.forward(length) return value def scan_tag_directive_value(self, start_mark): + # type: (Any) -> Any # See the specification for details. - while self.peek() == u' ': - self.forward() + while self.reader.peek() == u' ': + self.reader.forward() handle = self.scan_tag_directive_handle(start_mark) - while self.peek() == u' ': - self.forward() + while self.reader.peek() == u' ': + self.reader.forward() prefix = self.scan_tag_directive_prefix(start_mark) return (handle, prefix) def scan_tag_directive_handle(self, start_mark): + # type: (Any) -> Any # See the specification for details. value = self.scan_tag_handle('directive', start_mark) - ch = self.peek() + ch = self.reader.peek() if ch != u' ': raise ScannerError("while scanning a directive", start_mark, "expected ' ', but found %r" % utf8(ch), - self.get_mark()) + self.reader.get_mark()) return value def scan_tag_directive_prefix(self, start_mark): + # type: (Any) -> Any # See the specification for details. value = self.scan_tag_uri('directive', start_mark) - ch = self.peek() + ch = self.reader.peek() if ch not in u'\0 \r\n\x85\u2028\u2029': raise ScannerError("while scanning a directive", start_mark, "expected ' ', but found %r" % utf8(ch), - self.get_mark()) + self.reader.get_mark()) return value def scan_directive_ignored_line(self, start_mark): + # type: (Any) -> None # See the specification for details. - while self.peek() == u' ': - self.forward() - if self.peek() == u'#': - while self.peek() not in u'\0\r\n\x85\u2028\u2029': - self.forward() - ch = self.peek() + while self.reader.peek() == u' ': + self.reader.forward() + if self.reader.peek() == u'#': + while self.reader.peek() not in u'\0\r\n\x85\u2028\u2029': + self.reader.forward() + ch = self.reader.peek() if ch not in u'\0\r\n\x85\u2028\u2029': raise ScannerError( "while scanning a directive", start_mark, "expected a comment or a line break, but found %r" - % utf8(ch), self.get_mark()) + % utf8(ch), self.reader.get_mark()) self.scan_line_break() def scan_anchor(self, TokenClass): + # type: (Any) -> Any # The specification does not restrict characters for anchors and # aliases. This may lead to problems, for instance, the document: # [ *alias, value ] @@ -927,56 +945,57 @@ class Scanner(object): # and # [ *alias , "value" ] # Therefore we restrict aliases to numbers and ASCII letters. - start_mark = self.get_mark() - indicator = self.peek() + start_mark = self.reader.get_mark() + indicator = self.reader.peek() if indicator == u'*': name = 'alias' else: name = 'anchor' - self.forward() + self.reader.forward() length = 0 - ch = self.peek(length) + ch = self.reader.peek(length) # while u'0' <= ch <= u'9' or u'A' <= ch <= u'Z' or u'a' <= ch <= u'z' \ # or ch in u'-_': while check_anchorname_char(ch): length += 1 - ch = self.peek(length) + ch = self.reader.peek(length) if not length: raise ScannerError( "while scanning an %s" % name, start_mark, "expected alphabetic or numeric character, but found %r" - % utf8(ch), self.get_mark()) - value = self.prefix(length) - self.forward(length) + % utf8(ch), self.reader.get_mark()) + value = self.reader.prefix(length) + self.reader.forward(length) # ch1 = ch - # ch = self.peek() # no need to peek, ch is already set + # ch = self.reader.peek() # no need to peek, ch is already set # assert ch1 == ch if ch not in u'\0 \t\r\n\x85\u2028\u2029?:,[]{}%@`': raise ScannerError( "while scanning an %s" % name, start_mark, "expected alphabetic or numeric character, but found %r" - % utf8(ch), self.get_mark()) - end_mark = self.get_mark() + % utf8(ch), self.reader.get_mark()) + end_mark = self.reader.get_mark() return TokenClass(value, start_mark, end_mark) def scan_tag(self): + # type: () -> Any # See the specification for details. - start_mark = self.get_mark() - ch = self.peek(1) + start_mark = self.reader.get_mark() + ch = self.reader.peek(1) if ch == u'<': handle = None - self.forward(2) + self.reader.forward(2) suffix = self.scan_tag_uri('tag', start_mark) - if self.peek() != u'>': + if self.reader.peek() != u'>': raise ScannerError( "while parsing a tag", start_mark, - "expected '>', but found %r" % utf8(self.peek()), - self.get_mark()) - self.forward() + "expected '>', but found %r" % utf8(self.reader.peek()), + self.reader.get_mark()) + self.reader.forward() elif ch in u'\0 \t\r\n\x85\u2028\u2029': handle = None suffix = u'!' - self.forward() + self.reader.forward() else: length = 1 use_handle = False @@ -985,36 +1004,36 @@ class Scanner(object): use_handle = True break length += 1 - ch = self.peek(length) + ch = self.reader.peek(length) handle = u'!' if use_handle: handle = self.scan_tag_handle('tag', start_mark) else: handle = u'!' - self.forward() + self.reader.forward() suffix = self.scan_tag_uri('tag', start_mark) - ch = self.peek() + ch = self.reader.peek() if ch not in u'\0 \r\n\x85\u2028\u2029': raise ScannerError("while scanning a tag", start_mark, "expected ' ', but found %r" % utf8(ch), - self.get_mark()) + self.reader.get_mark()) value = (handle, suffix) - end_mark = self.get_mark() + end_mark = self.reader.get_mark() return TagToken(value, start_mark, end_mark) def scan_block_scalar(self, style): + # type: (Any) -> Any # See the specification for details. - if style == '>': folded = True else: folded = False - chunks = [] - start_mark = self.get_mark() + chunks = [] # type: List[Any] + start_mark = self.reader.get_mark() # Scan the header. - self.forward() + self.reader.forward() chomping, increment = self.scan_block_scalar_indicators(start_mark) self.scan_block_scalar_ignored_line(start_mark) @@ -1031,24 +1050,24 @@ class Scanner(object): line_break = u'' # Scan the inner part of the block scalar. - while self.column == indent and self.peek() != u'\0': + while self.reader.column == indent and self.reader.peek() != u'\0': chunks.extend(breaks) - leading_non_space = self.peek() not in u' \t' + leading_non_space = self.reader.peek() not in u' \t' length = 0 - while self.peek(length) not in u'\0\r\n\x85\u2028\u2029': + while self.reader.peek(length) not in u'\0\r\n\x85\u2028\u2029': length += 1 - chunks.append(self.prefix(length)) - self.forward(length) + chunks.append(self.reader.prefix(length)) + self.reader.forward(length) line_break = self.scan_line_break() breaks, end_mark = self.scan_block_scalar_breaks(indent) - if self.column == indent and self.peek() != u'\0': + if self.reader.column == indent and self.reader.peek() != u'\0': # Unfortunately, folding rules are ambiguous. # # This is the folding according to the specification: if folded and line_break == u'\n' \ - and leading_non_space and self.peek() not in u' \t': + and leading_non_space and self.reader.peek() not in u' \t': if not breaks: chunks.append(u' ') else: @@ -1059,7 +1078,7 @@ class Scanner(object): # # if folded and line_break == u'\n': # if not breaks: - # if self.peek() not in ' \t': + # if self.reader.peek() not in ' \t': # chunks.append(u' ') # else: # chunks.append(line_break) @@ -1070,7 +1089,7 @@ class Scanner(object): # Process trailing line breaks. The 'chomping' setting determines # whether they are included in the value. - trailing = [] + trailing = [] # type: List[Any] if chomping in [None, True]: chunks.append(line_break) if chomping is True: @@ -1090,32 +1109,33 @@ class Scanner(object): # Keep track of the trailing whitespace and following comments # as a comment token, if isn't all included in the actual value. - comment_end_mark = self.get_mark() + comment_end_mark = self.reader.get_mark() comment = CommentToken(''.join(trailing), end_mark, comment_end_mark) token.add_post_comment(comment) return token def scan_block_scalar_indicators(self, start_mark): + # type: (Any) -> Any # See the specification for details. chomping = None increment = None - ch = self.peek() + ch = self.reader.peek() if ch in u'+-': if ch == '+': chomping = True else: chomping = False - self.forward() - ch = self.peek() + self.reader.forward() + ch = self.reader.peek() if ch in u'0123456789': increment = int(ch) if increment == 0: raise ScannerError( "while scanning a block scalar", start_mark, "expected indentation indicator in the range 1-9, " - "but found 0", self.get_mark()) - self.forward() + "but found 0", self.reader.get_mark()) + self.reader.forward() elif ch in u'0123456789': increment = int(ch) if increment == 0: @@ -1123,67 +1143,71 @@ class Scanner(object): "while scanning a block scalar", start_mark, "expected indentation indicator in the range 1-9, " "but found 0", - self.get_mark()) - self.forward() - ch = self.peek() + self.reader.get_mark()) + self.reader.forward() + ch = self.reader.peek() if ch in u'+-': if ch == '+': chomping = True else: chomping = False - self.forward() - ch = self.peek() + self.reader.forward() + ch = self.reader.peek() if ch not in u'\0 \r\n\x85\u2028\u2029': raise ScannerError( "while scanning a block scalar", start_mark, "expected chomping or indentation indicators, but found %r" - % utf8(ch), self.get_mark()) + % utf8(ch), self.reader.get_mark()) return chomping, increment def scan_block_scalar_ignored_line(self, start_mark): + # type: (Any) -> Any # See the specification for details. - while self.peek() == u' ': - self.forward() - if self.peek() == u'#': - while self.peek() not in u'\0\r\n\x85\u2028\u2029': - self.forward() - ch = self.peek() + while self.reader.peek() == u' ': + self.reader.forward() + if self.reader.peek() == u'#': + while self.reader.peek() not in u'\0\r\n\x85\u2028\u2029': + self.reader.forward() + ch = self.reader.peek() if ch not in u'\0\r\n\x85\u2028\u2029': raise ScannerError( "while scanning a block scalar", start_mark, "expected a comment or a line break, but found %r" - % utf8(ch), self.get_mark()) + % utf8(ch), self.reader.get_mark()) self.scan_line_break() def scan_block_scalar_indentation(self): + # type: () -> Any # See the specification for details. chunks = [] max_indent = 0 - end_mark = self.get_mark() - while self.peek() in u' \r\n\x85\u2028\u2029': - if self.peek() != u' ': + end_mark = self.reader.get_mark() + while self.reader.peek() in u' \r\n\x85\u2028\u2029': + if self.reader.peek() != u' ': chunks.append(self.scan_line_break()) - end_mark = self.get_mark() + end_mark = self.reader.get_mark() else: - self.forward() - if self.column > max_indent: - max_indent = self.column + self.reader.forward() + if self.reader.column > max_indent: + max_indent = self.reader.column return chunks, max_indent, end_mark def scan_block_scalar_breaks(self, indent): + # type: (int) -> Any # See the specification for details. chunks = [] - end_mark = self.get_mark() - while self.column < indent and self.peek() == u' ': - self.forward() - while self.peek() in u'\r\n\x85\u2028\u2029': + end_mark = self.reader.get_mark() + while self.reader.column < indent and self.reader.peek() == u' ': + self.reader.forward() + while self.reader.peek() in u'\r\n\x85\u2028\u2029': chunks.append(self.scan_line_break()) - end_mark = self.get_mark() - while self.column < indent and self.peek() == u' ': - self.forward() + end_mark = self.reader.get_mark() + while self.reader.column < indent and self.reader.peek() == u' ': + self.reader.forward() return chunks, end_mark def scan_flow_scalar(self, style): + # type: (Any) -> Any # See the specification for details. # Note that we loose indentation rules for quoted scalars. Quoted # scalars don't need to adhere indentation because " and ' clearly @@ -1194,16 +1218,16 @@ class Scanner(object): double = True else: double = False - chunks = [] - start_mark = self.get_mark() - quote = self.peek() - self.forward() + chunks = [] # type: List[Any] + start_mark = self.reader.get_mark() + quote = self.reader.peek() + self.reader.forward() chunks.extend(self.scan_flow_scalar_non_spaces(double, start_mark)) - while self.peek() != quote: + while self.reader.peek() != quote: chunks.extend(self.scan_flow_scalar_spaces(double, start_mark)) chunks.extend(self.scan_flow_scalar_non_spaces(double, start_mark)) - self.forward() - end_mark = self.get_mark() + self.reader.forward() + end_mark = self.reader.get_mark() return ScalarToken(u''.join(chunks), False, start_mark, end_mark, style) @@ -1235,42 +1259,43 @@ class Scanner(object): } def scan_flow_scalar_non_spaces(self, double, start_mark): + # type: (Any, Any) -> Any # See the specification for details. - chunks = [] + chunks = [] # type: List[Any] while True: length = 0 - while self.peek(length) not in u'\'\"\\\0 \t\r\n\x85\u2028\u2029': + while self.reader.peek(length) not in u'\'\"\\\0 \t\r\n\x85\u2028\u2029': length += 1 - if length: - chunks.append(self.prefix(length)) - self.forward(length) - ch = self.peek() - if not double and ch == u'\'' and self.peek(1) == u'\'': + if length != 0: + chunks.append(self.reader.prefix(length)) + self.reader.forward(length) + ch = self.reader.peek() + if not double and ch == u'\'' and self.reader.peek(1) == u'\'': chunks.append(u'\'') - self.forward(2) + self.reader.forward(2) elif (double and ch == u'\'') or (not double and ch in u'\"\\'): chunks.append(ch) - self.forward() + self.reader.forward() elif double and ch == u'\\': - self.forward() - ch = self.peek() + self.reader.forward() + ch = self.reader.peek() if ch in self.ESCAPE_REPLACEMENTS: chunks.append(self.ESCAPE_REPLACEMENTS[ch]) - self.forward() + self.reader.forward() elif ch in self.ESCAPE_CODES: length = self.ESCAPE_CODES[ch] - self.forward() + self.reader.forward() for k in range(length): - if self.peek(k) not in u'0123456789ABCDEFabcdef': + if self.reader.peek(k) not in u'0123456789ABCDEFabcdef': raise ScannerError( "while scanning a double-quoted scalar", start_mark, "expected escape sequence of %d hexdecimal " "numbers, but found %r" % - (length, utf8(self.peek(k))), self.get_mark()) - code = int(self.prefix(length), 16) + (length, utf8(self.reader.peek(k))), self.reader.get_mark()) + code = int(self.reader.prefix(length), 16) chunks.append(unichr(code)) - self.forward(length) + self.reader.forward(length) elif ch in u'\r\n\x85\u2028\u2029': self.scan_line_break() chunks.extend(self.scan_flow_scalar_breaks( @@ -1279,23 +1304,24 @@ class Scanner(object): raise ScannerError( "while scanning a double-quoted scalar", start_mark, "found unknown escape character %r" % utf8(ch), - self.get_mark()) + self.reader.get_mark()) else: return chunks def scan_flow_scalar_spaces(self, double, start_mark): + # type: (Any, Any) -> Any # See the specification for details. chunks = [] length = 0 - while self.peek(length) in u' \t': + while self.reader.peek(length) in u' \t': length += 1 - whitespaces = self.prefix(length) - self.forward(length) - ch = self.peek() + whitespaces = self.reader.prefix(length) + self.reader.forward(length) + ch = self.reader.peek() if ch == u'\0': raise ScannerError( "while scanning a quoted scalar", start_mark, - "found unexpected end of stream", self.get_mark()) + "found unexpected end of stream", self.reader.get_mark()) elif ch in u'\r\n\x85\u2028\u2029': line_break = self.scan_line_break() breaks = self.scan_flow_scalar_breaks(double, start_mark) @@ -1309,62 +1335,64 @@ class Scanner(object): return chunks def scan_flow_scalar_breaks(self, double, start_mark): + # type: (Any, Any) -> Any # See the specification for details. - chunks = [] + chunks = [] # type: List[Any] while True: # Instead of checking indentation, we check for document # separators. - prefix = self.prefix(3) + prefix = self.reader.prefix(3) if (prefix == u'---' or prefix == u'...') \ - and self.peek(3) in u'\0 \t\r\n\x85\u2028\u2029': + and self.reader.peek(3) in u'\0 \t\r\n\x85\u2028\u2029': raise ScannerError("while scanning a quoted scalar", start_mark, "found unexpected document separator", - self.get_mark()) - while self.peek() in u' \t': - self.forward() - if self.peek() in u'\r\n\x85\u2028\u2029': + self.reader.get_mark()) + while self.reader.peek() in u' \t': + self.reader.forward() + if self.reader.peek() in u'\r\n\x85\u2028\u2029': chunks.append(self.scan_line_break()) else: return chunks def scan_plain(self): + # type: () -> Any # See the specification for details. # We add an additional restriction for the flow context: # plain scalars in the flow context cannot contain ',', ': ' and '?'. # We also keep track of the `allow_simple_key` flag here. # Indentation rules are loosed for the flow context. - chunks = [] - start_mark = self.get_mark() + chunks = [] # type: List[Any] + start_mark = self.reader.get_mark() end_mark = start_mark indent = self.indent+1 # We allow zero indentation for scalars, but then we need to check for # document separators at the beginning of the line. # if indent == 0: # indent = 1 - spaces = [] + spaces = [] # type: List[Any] while True: length = 0 - if self.peek() == u'#': + if self.reader.peek() == u'#': break while True: - ch = self.peek(length) + ch = self.reader.peek(length) if (ch == u':' and - self.peek(length+1) not in u'\0 \t\r\n\x85\u2028\u2029'): + self.reader.peek(length+1) not in u'\0 \t\r\n\x85\u2028\u2029'): pass elif (ch in u'\0 \t\r\n\x85\u2028\u2029' or (not self.flow_level and ch == u':' and - self.peek(length+1) in u'\0 \t\r\n\x85\u2028\u2029') or + self.reader.peek(length+1) in u'\0 \t\r\n\x85\u2028\u2029') or (self.flow_level and ch in u',:?[]{}')): break length += 1 # It's not clear what we should do with ':' in the flow context. if (self.flow_level and ch == u':' and - self.peek(length+1) not in u'\0 \t\r\n\x85\u2028\u2029,[]{}'): - self.forward(length) + self.reader.peek(length+1) not in u'\0 \t\r\n\x85\u2028\u2029,[]{}'): + self.reader.forward(length) raise ScannerError( "while scanning a plain scalar", start_mark, - "found unexpected ':'", self.get_mark(), + "found unexpected ':'", self.reader.get_mark(), "Please check " "http://pyyaml.org/wiki/YAMLColonInFlowContext " "for details.") @@ -1372,12 +1400,12 @@ class Scanner(object): break self.allow_simple_key = False chunks.extend(spaces) - chunks.append(self.prefix(length)) - self.forward(length) - end_mark = self.get_mark() + chunks.append(self.reader.prefix(length)) + self.reader.forward(length) + end_mark = self.reader.get_mark() spaces = self.scan_plain_spaces(indent, start_mark) - if not spaces or self.peek() == u'#' \ - or (not self.flow_level and self.column < indent): + if not spaces or self.reader.peek() == u'#' \ + or (not self.flow_level and self.reader.column < indent): break token = ScalarToken(u''.join(chunks), True, start_mark, end_mark) @@ -1388,32 +1416,33 @@ class Scanner(object): return token def scan_plain_spaces(self, indent, start_mark): + # type: (Any, Any) -> Any # See the specification for details. # The specification is really confusing about tabs in plain scalars. # We just forbid them completely. Do not use tabs in YAML! chunks = [] length = 0 - while self.peek(length) in u' ': + while self.reader.peek(length) in u' ': length += 1 - whitespaces = self.prefix(length) - self.forward(length) - ch = self.peek() + whitespaces = self.reader.prefix(length) + self.reader.forward(length) + ch = self.reader.peek() if ch in u'\r\n\x85\u2028\u2029': line_break = self.scan_line_break() self.allow_simple_key = True - prefix = self.prefix(3) + prefix = self.reader.prefix(3) if (prefix == u'---' or prefix == u'...') \ - and self.peek(3) in u'\0 \t\r\n\x85\u2028\u2029': + and self.reader.peek(3) in u'\0 \t\r\n\x85\u2028\u2029': return breaks = [] - while self.peek() in u' \r\n\x85\u2028\u2029': - if self.peek() == ' ': - self.forward() + while self.reader.peek() in u' \r\n\x85\u2028\u2029': + if self.reader.peek() == ' ': + self.reader.forward() else: breaks.append(self.scan_line_break()) - prefix = self.prefix(3) + prefix = self.reader.prefix(3) if (prefix == u'---' or prefix == u'...') \ - and self.peek(3) in u'\0 \t\r\n\x85\u2028\u2029': + and self.reader.peek(3) in u'\0 \t\r\n\x85\u2028\u2029': return if line_break != u'\n': chunks.append(line_break) @@ -1425,87 +1454,91 @@ class Scanner(object): return chunks def scan_tag_handle(self, name, start_mark): + # type: (Any, Any) -> Any # See the specification for details. # For some strange reasons, the specification does not allow '_' in # tag handles. I have allowed it anyway. - ch = self.peek() + ch = self.reader.peek() if ch != u'!': raise ScannerError("while scanning a %s" % name, start_mark, "expected '!', but found %r" % utf8(ch), - self.get_mark()) + self.reader.get_mark()) length = 1 - ch = self.peek(length) + ch = self.reader.peek(length) if ch != u' ': while u'0' <= ch <= u'9' or u'A' <= ch <= u'Z' \ or u'a' <= ch <= u'z' \ or ch in u'-_': length += 1 - ch = self.peek(length) + ch = self.reader.peek(length) if ch != u'!': - self.forward(length) + self.reader.forward(length) raise ScannerError("while scanning a %s" % name, start_mark, "expected '!', but found %r" % utf8(ch), - self.get_mark()) + self.reader.get_mark()) length += 1 - value = self.prefix(length) - self.forward(length) + value = self.reader.prefix(length) + self.reader.forward(length) return value def scan_tag_uri(self, name, start_mark): + # type: (Any, Any) -> Any # See the specification for details. # Note: we do not check if URI is well-formed. chunks = [] length = 0 - ch = self.peek(length) + ch = self.reader.peek(length) while u'0' <= ch <= u'9' or u'A' <= ch <= u'Z' or u'a' <= ch <= u'z' \ or ch in u'-;/?:@&=+$,_.!~*\'()[]%': if ch == u'%': - chunks.append(self.prefix(length)) - self.forward(length) + chunks.append(self.reader.prefix(length)) + self.reader.forward(length) length = 0 chunks.append(self.scan_uri_escapes(name, start_mark)) else: length += 1 - ch = self.peek(length) - if length: - chunks.append(self.prefix(length)) - self.forward(length) + ch = self.reader.peek(length) + if length != 0: + chunks.append(self.reader.prefix(length)) + self.reader.forward(length) length = 0 if not chunks: raise ScannerError("while parsing a %s" % name, start_mark, "expected URI, but found %r" % utf8(ch), - self.get_mark()) + self.reader.get_mark()) return u''.join(chunks) def scan_uri_escapes(self, name, start_mark): + # type: (Any, Any) -> Any # See the specification for details. - code_bytes = [] - mark = self.get_mark() - while self.peek() == u'%': - self.forward() + code_bytes = [] # type: List[Any] + mark = self.reader.get_mark() + while self.reader.peek() == u'%': + self.reader.forward() for k in range(2): - if self.peek(k) not in u'0123456789ABCDEFabcdef': + if self.reader.peek(k) not in u'0123456789ABCDEFabcdef': raise ScannerError( "while scanning a %s" % name, start_mark, "expected URI escape sequence of 2 hexdecimal numbers," " but found %r" - % utf8(self.peek(k)), self.get_mark()) + % utf8(self.reader.peek(k)), self.reader.get_mark()) if PY3: - code_bytes.append(int(self.prefix(2), 16)) + code_bytes.append(int(self.reader.prefix(2), 16)) else: - code_bytes.append(chr(int(self.prefix(2), 16))) - self.forward(2) + code_bytes.append(chr(int(self.reader.prefix(2), 16))) + self.reader.forward(2) try: if PY3: value = bytes(code_bytes).decode('utf-8') else: - value = unicode(''.join(code_bytes), 'utf-8') + value = unicode(''.join(code_bytes), 'utf-8') # type: ignore except UnicodeDecodeError as exc: raise ScannerError("while scanning a %s" % name, start_mark, str(exc), mark) return value def scan_line_break(self): + # type: () -> Any # Transforms: # '\r\n' : '\n' # '\r' : '\n' @@ -1514,26 +1547,27 @@ class Scanner(object): # '\u2028' : '\u2028' # '\u2029 : '\u2029' # default : '' - ch = self.peek() + ch = self.reader.peek() if ch in u'\r\n\x85': - if self.prefix(2) == u'\r\n': - self.forward(2) + if self.reader.prefix(2) == u'\r\n': + self.reader.forward(2) else: - self.forward() + self.reader.forward() return u'\n' elif ch in u'\u2028\u2029': - self.forward() + self.reader.forward() return ch return u'' class RoundTripScanner(Scanner): def check_token(self, *choices): + # type: (Any) -> bool # Check if the next token is one of the given types. while self.need_more_tokens(): self.fetch_more_tokens() self._gather_comments() - if self.tokens: + if bool(self.tokens): if not choices: return True for choice in choices: @@ -1542,16 +1576,19 @@ class RoundTripScanner(Scanner): return False def peek_token(self): + # type: () -> Any # Return the next token, but do not delete if from the queue. while self.need_more_tokens(): self.fetch_more_tokens() self._gather_comments() - if self.tokens: + if bool(self.tokens): return self.tokens[0] + return None def _gather_comments(self): + # type: () -> Any """combine multiple comment lines""" - comments = [] + comments = [] # type: List[Any] if not self.tokens: return comments if isinstance(self.tokens[0], CommentToken): @@ -1578,11 +1615,12 @@ class RoundTripScanner(Scanner): self.fetch_more_tokens() def get_token(self): + # type: () -> Any # Return the next token. while self.need_more_tokens(): self.fetch_more_tokens() self._gather_comments() - if self.tokens: + if bool(self.tokens): # only add post comment to single line tokens: # scalar, value token. FlowXEndToken, otherwise # hidden streamtokens could get them (leave them and they will be @@ -1600,8 +1638,10 @@ class RoundTripScanner(Scanner): self.tokens[0].add_post_comment(self.tokens.pop(1)) self.tokens_taken += 1 return self.tokens.pop(0) + return None def fetch_comment(self, comment): + # type: (Any) -> None value, start_mark, end_mark = comment while value and value[-1] == u' ': # empty line within indented key context @@ -1612,6 +1652,7 @@ class RoundTripScanner(Scanner): # scanner def scan_to_next_token(self): + # type: () -> Any # We ignore spaces, line breaks and comments. # If we find a line break in the block context, we set the flag # `allow_simple_key` on. @@ -1631,51 +1672,54 @@ class RoundTripScanner(Scanner): # `unwind_indent` before issuing BLOCK-END. # Scanners for block, flow, and plain scalars need to be modified. - if self.index == 0 and self.peek() == u'\uFEFF': - self.forward() + if self.reader.index == 0 and self.reader.peek() == u'\uFEFF': + self.reader.forward() found = False while not found: - while self.peek() == u' ': - self.forward() - ch = self.peek() + while self.reader.peek() == u' ': + self.reader.forward() + ch = self.reader.peek() if ch == u'#': - start_mark = self.get_mark() + start_mark = self.reader.get_mark() comment = ch - self.forward() + self.reader.forward() while ch not in u'\0\r\n\x85\u2028\u2029': - ch = self.peek() + ch = self.reader.peek() if ch == u'\0': # don't gobble the end-of-stream character break comment += ch - self.forward() + self.reader.forward() # gather any blank lines following the comment too ch = self.scan_line_break() while len(ch) > 0: comment += ch ch = self.scan_line_break() - end_mark = self.get_mark() + end_mark = self.reader.get_mark() if not self.flow_level: self.allow_simple_key = True return comment, start_mark, end_mark - if self.scan_line_break(): - start_mark = self.get_mark() + if bool(self.scan_line_break()): + start_mark = self.reader.get_mark() if not self.flow_level: self.allow_simple_key = True - ch = self.peek() + ch = self.reader.peek() if ch == '\n': # empty toplevel lines - start_mark = self.get_mark() + start_mark = self.reader.get_mark() comment = '' while ch: ch = self.scan_line_break(empty_line=True) comment += ch - if self.peek() == '#': # empty line followed by indented real comment + if self.reader.peek() == '#': + # empty line followed by indented real comment comment = comment.rsplit('\n', 1)[0] + '\n' - end_mark = self.get_mark() + end_mark = self.reader.get_mark() return comment, start_mark, end_mark else: found = True + return None def scan_line_break(self, empty_line=False): + # type: (bool) -> Text # Transforms: # '\r\n' : '\n' # '\r' : '\n' @@ -1684,18 +1728,18 @@ class RoundTripScanner(Scanner): # '\u2028' : '\u2028' # '\u2029 : '\u2029' # default : '' - ch = self.peek() + ch = self.reader.peek() # type: Text if ch in u'\r\n\x85': - if self.prefix(2) == u'\r\n': - self.forward(2) + if self.reader.prefix(2) == u'\r\n': + self.reader.forward(2) else: - self.forward() + self.reader.forward() return u'\n' elif ch in u'\u2028\u2029': - self.forward() + self.reader.forward() return ch elif empty_line and ch in '\t ': - self.forward() + self.reader.forward() return ch return u'' diff --git a/serializer.py b/serializer.py index d769b9c..7cac44d 100644 --- a/serializer.py +++ b/serializer.py @@ -4,8 +4,11 @@ from __future__ import absolute_import import re +from typing import Any, Dict, Union, Text # NOQA + from ruamel.yaml.error import YAMLError from ruamel.yaml.compat import nprint, DBG_NODE, dbg, string_types +from ruamel.yaml.compat import VersionType # NOQA from ruamel.yaml.events import ( StreamStartEvent, StreamEndEvent, MappingStartEvent, MappingEndEvent, @@ -30,24 +33,39 @@ class Serializer(object): ANCHOR_RE = re.compile(u'id(?!000$)\\d{3,}') def __init__(self, encoding=None, explicit_start=None, explicit_end=None, - version=None, tags=None): + version=None, tags=None, dumper=None): + # type: (Any, bool, bool, VersionType, Any, Any) -> None + self.dumper = dumper + if self.dumper is not None: + self.dumper._serializer = self self.use_encoding = encoding self.use_explicit_start = explicit_start self.use_explicit_end = explicit_end if isinstance(version, string_types): self.use_version = tuple(map(int, version.split('.'))) else: - self.use_version = version + self.use_version = version # type: ignore self.use_tags = tags - self.serialized_nodes = {} - self.anchors = {} + self.serialized_nodes = {} # type: Dict[Any, Any] + self.anchors = {} # type: Dict[Any, Any] self.last_anchor_id = 0 - self.closed = None + self.closed = None # type: Union[None, bool] self._templated_id = None + @property + def emitter(self): + # type: () -> Any + return self.dumper._emitter + + @property + def resolver(self): + # type: () -> Any + return self.dumper._resolver + def open(self): + # type: () -> None if self.closed is None: - self.emit(StreamStartEvent(encoding=self.use_encoding)) + self.emitter.emit(StreamStartEvent(encoding=self.use_encoding)) self.closed = False elif self.closed: raise SerializerError("serializer is closed") @@ -55,16 +73,18 @@ class Serializer(object): raise SerializerError("serializer is already opened") def close(self): + # type: () -> None if self.closed is None: raise SerializerError("serializer is not opened") elif not self.closed: - self.emit(StreamEndEvent()) + self.emitter.emit(StreamEndEvent()) self.closed = True # def __del__(self): # self.close() def serialize(self, node): + # type: (Any) -> None if dbg(DBG_NODE): nprint('Serializing nodes') node.dump() @@ -72,17 +92,18 @@ class Serializer(object): raise SerializerError("serializer is not opened") elif self.closed: raise SerializerError("serializer is closed") - self.emit(DocumentStartEvent(explicit=self.use_explicit_start, - version=self.use_version, - tags=self.use_tags)) + self.emitter.emit(DocumentStartEvent(explicit=self.use_explicit_start, + version=self.use_version, + tags=self.use_tags)) self.anchor_node(node) self.serialize_node(node, None, None) - self.emit(DocumentEndEvent(explicit=self.use_explicit_end)) + self.emitter.emit(DocumentEndEvent(explicit=self.use_explicit_end)) self.serialized_nodes = {} self.anchors = {} self.last_anchor_id = 0 def anchor_node(self, node): + # type: (Any) -> None if node in self.anchors: if self.anchors[node] is None: self.anchors[node] = self.generate_anchor(node) @@ -103,6 +124,7 @@ class Serializer(object): self.anchor_node(value) def generate_anchor(self, node): + # type: (Any) -> Any try: anchor = node.anchor.value except: @@ -113,22 +135,23 @@ class Serializer(object): return anchor def serialize_node(self, node, parent, index): + # type: (Any, Any, Any) -> None alias = self.anchors[node] if node in self.serialized_nodes: - self.emit(AliasEvent(alias)) + self.emitter.emit(AliasEvent(alias)) else: self.serialized_nodes[node] = True - self.descend_resolver(parent, index) + self.resolver.descend_resolver(parent, index) if isinstance(node, ScalarNode): # here check if the node.tag equals the one that would result from parsing # if not equal quoting is necessary for strings - detected_tag = self.resolve(ScalarNode, node.value, (True, False)) - default_tag = self.resolve(ScalarNode, node.value, (False, True)) + detected_tag = self.resolver.resolve(ScalarNode, node.value, (True, False)) + default_tag = self.resolver.resolve(ScalarNode, node.value, (False, True)) implicit = (node.tag == detected_tag), (node.tag == default_tag) - self.emit(ScalarEvent(alias, node.tag, implicit, node.value, - style=node.style, comment=node.comment)) + self.emitter.emit(ScalarEvent(alias, node.tag, implicit, node.value, + style=node.style, comment=node.comment)) elif isinstance(node, SequenceNode): - implicit = (node.tag == self.resolve(SequenceNode, node.value, True)) + implicit = (node.tag == self.resolver.resolve(SequenceNode, node.value, True)) comment = node.comment # print('comment >>>>>>>>>>>>>.', comment, node.flow_style) end_comment = None @@ -141,16 +164,16 @@ class Serializer(object): end_comment = comment[2] else: end_comment = None - self.emit(SequenceStartEvent(alias, node.tag, implicit, - flow_style=node.flow_style, - comment=node.comment)) + self.emitter.emit(SequenceStartEvent(alias, node.tag, implicit, + flow_style=node.flow_style, + comment=node.comment)) index = 0 for item in node.value: self.serialize_node(item, node, index) index += 1 - self.emit(SequenceEndEvent(comment=[seq_comment, end_comment])) + self.emitter.emit(SequenceEndEvent(comment=[seq_comment, end_comment])) elif isinstance(node, MappingNode): - implicit = (node.tag == self.resolve(MappingNode, node.value, True)) + implicit = (node.tag == self.resolver.resolve(MappingNode, node.value, True)) comment = node.comment end_comment = None map_comment = None @@ -160,15 +183,16 @@ class Serializer(object): # comment[0] = None if comment and len(comment) > 2: end_comment = comment[2] - self.emit(MappingStartEvent(alias, node.tag, implicit, - flow_style=node.flow_style, - comment=node.comment)) + self.emitter.emit(MappingStartEvent(alias, node.tag, implicit, + flow_style=node.flow_style, + comment=node.comment)) for key, value in node.value: self.serialize_node(key, node, None) self.serialize_node(value, node, key) - self.emit(MappingEndEvent(comment=[map_comment, end_comment])) - self.ascend_resolver() + self.emitter.emit(MappingEndEvent(comment=[map_comment, end_comment])) + self.resolver.ascend_resolver() def templated_id(s): + # type: (Text) -> Any return Serializer.ANCHOR_RE.match(s) diff --git a/timestamp.py b/timestamp.py index 8b8715d..02ea085 100644 --- a/timestamp.py +++ b/timestamp.py @@ -5,15 +5,20 @@ from __future__ import print_function, absolute_import, division, unicode_litera import datetime import copy +from typing import Any, Dict, Optional, List # NOQA + class TimeStamp(datetime.datetime): def __init__(self, *args, **kw): - self._yaml = dict(t=False, tz=None, delta=0) + # type: (Any, Any) -> None + self._yaml = dict(t=False, tz=None, delta=0) # type: Dict[Any, Any] def __new__(cls, *args, **kw): # datetime is immutable - return datetime.datetime.__new__(cls, *args, **kw) + # type: (Any, Any) -> Any + return datetime.datetime.__new__(cls, *args, **kw) # type: ignore def __deepcopy__(self, memo): + # type: (Any) -> Any ts = TimeStamp(self.year, self.month, self.day, self.hour, self.minute, self.second) ts._yaml = copy.deepcopy(self._yaml) diff --git a/tokens.py b/tokens.py index 9001216..ed572e7 100644 --- a/tokens.py +++ b/tokens.py @@ -1,15 +1,19 @@ # # header # coding: utf-8 +from typing import Any, Dict, Optional, List # NOQA + class Token(object): __slots__ = 'start_mark', 'end_mark', '_comment', def __init__(self, start_mark, end_mark): + # type: (Any, Any) -> None self.start_mark = start_mark self.end_mark = end_mark def __repr__(self): + # type: () -> Any attributes = [key for key in self.__slots__ if not key.endswith('_mark') and hasattr('self', key)] attributes.sort() @@ -18,24 +22,29 @@ class Token(object): return '%s(%s)' % (self.__class__.__name__, arguments) def add_post_comment(self, comment): + # type: (Any) -> None if not hasattr(self, '_comment'): self._comment = [None, None] self._comment[0] = comment def add_pre_comments(self, comments): + # type: (Any) -> None if not hasattr(self, '_comment'): self._comment = [None, None] assert self._comment[1] is None self._comment[1] = comments def get_comment(self): + # type: () -> Any return getattr(self, '_comment', None) @property def comment(self): + # type: () -> Any return getattr(self, '_comment', None) def move_comment(self, target, empty=False): + # type: (Any, bool) -> Any """move a comment from this token to target (normally next token) used to combine e.g. comments before a BlockEntryToken to the ScalarToken that follows it @@ -66,6 +75,7 @@ class Token(object): return self def split_comment(self): + # type: () -> Any """ split the post part of a comment, and return it as comment to be added. Delete second part if [None, None] abc: # this goes to sequence @@ -89,6 +99,7 @@ class DirectiveToken(Token): id = '' def __init__(self, name, value, start_mark, end_mark): + # type: (Any, Any, Any, Any) -> None Token.__init__(self, start_mark, end_mark) self.name = name self.value = value @@ -109,6 +120,7 @@ class StreamStartToken(Token): id = '' def __init__(self, start_mark=None, end_mark=None, encoding=None): + # type: (Any, Any, Any) -> None Token.__init__(self, start_mark, end_mark) self.encoding = encoding @@ -178,6 +190,7 @@ class AliasToken(Token): id = '' def __init__(self, value, start_mark, end_mark): + # type: (Any, Any, Any) -> None Token.__init__(self, start_mark, end_mark) self.value = value @@ -187,6 +200,7 @@ class AnchorToken(Token): id = '' def __init__(self, value, start_mark, end_mark): + # type: (Any, Any, Any) -> None Token.__init__(self, start_mark, end_mark) self.value = value @@ -196,6 +210,7 @@ class TagToken(Token): id = '' def __init__(self, value, start_mark, end_mark): + # type: (Any, Any, Any) -> None Token.__init__(self, start_mark, end_mark) self.value = value @@ -205,6 +220,7 @@ class ScalarToken(Token): id = '' def __init__(self, value, plain, start_mark, end_mark, style=None): + # type: (Any, Any, Any, Any, Any) -> None Token.__init__(self, start_mark, end_mark) self.value = value self.plain = plain @@ -221,5 +237,6 @@ class CommentToken(Token): self.value = value def reset(self): + # type: () -> None if hasattr(self, 'pre_done'): delattr(self, 'pre_done') diff --git a/tox.ini b/tox.ini index 5744346..2524272 100644 --- a/tox.ini +++ b/tox.ini @@ -20,4 +20,4 @@ commands = [flake8] show-source = True max-line-length = 95 -exclude = _test/lib,.hg,.git,.tox,dist,.cache,__pycache__,ruamel.zip2tar.egg-info +exclude = _test/lib,.hg,.git,.tox,dist,.cache,__pycache__,ruamel.zip2tar.egg-info,try_* diff --git a/util.py b/util.py index bb061ce..4ce14ec 100644 --- a/util.py +++ b/util.py @@ -6,7 +6,7 @@ some helper functions that might be generally useful from __future__ import absolute_import, print_function -from typing import Any, Dict, Optional, List # NOQA +from typing import Any, Dict, Optional, List, Text # NOQA from .compat import text_type, binary_type from .compat import StreamTextType, StringIO # NOQA @@ -85,7 +85,7 @@ def configobj_walker(cfg): walks over a ConfigObj (INI file with comments) generating corresponding YAML output (including comments """ - from configobj import ConfigObj + from configobj import ConfigObj # type: ignore assert isinstance(cfg, ConfigObj) for c in cfg.initial_comment: if c.strip(): -- cgit v1.2.1