diff options
Diffstat (limited to 'constructor.py')
-rw-r--r-- | constructor.py | 100 |
1 files changed, 60 insertions, 40 deletions
diff --git a/constructor.py b/constructor.py index 2bcb3bd..6797d95 100644 --- a/constructor.py +++ b/constructor.py @@ -10,7 +10,7 @@ import re import sys import types -from typing import Any, Dict, List # NOQA +from typing import Any, Dict, List, Set, Generator # NOQA from ruamel.yaml.error import (MarkedYAMLError) @@ -39,29 +39,49 @@ class BaseConstructor(object): yaml_constructors = {} # type: Dict[Any, Any] yaml_multi_constructors = {} # type: Dict[Any, Any] - def __init__(self, preserve_quotes=None): - # type: (bool) -> None + def __init__(self, preserve_quotes=None, loader=None): + # type: (bool, Any) -> None + self.loader = loader + if self.loader is not None: + self.loader._constructor = self + self.loader = loader self.constructed_objects = {} # type: Dict[Any, Any] self.recursive_objects = {} # type: Dict[Any, Any] self.state_generators = [] # type: List[Any] self.deep_construct = False self._preserve_quotes = preserve_quotes + @property + def composer(self): + # type: () -> Any + try: + return self.loader._composer + except AttributeError: + print('slt', type(self)) + print('slc', self.loader._composer) + print(dir(self)) + raise + + @property + def resolver(self): + # type: () -> Any + return self.loader._resolver + def check_data(self): # type: () -> Any # If there are more documents available? - return self.check_node() + return self.composer.check_node() def get_data(self): # type: () -> Any # Construct and return the next document. - if self.check_node(): - return self.construct_document(self.get_node()) + if self.composer.check_node(): + return self.construct_document(self.composer.get_node()) def get_single_data(self): # type: () -> Any # Ensure that the stream contains a single document and construct it. - node = self.get_single_node() + node = self.composer.get_single_node() if node is not None: return self.construct_document(node) return None @@ -69,7 +89,7 @@ class BaseConstructor(object): def construct_document(self, node): # type: (Any) -> Any data = self.construct_object(node) - while self.state_generators: + while bool(self.state_generators): state_generators = self.state_generators self.state_generators = [] for generator in state_generators: @@ -112,20 +132,20 @@ class BaseConstructor(object): elif None in self.yaml_constructors: constructor = self.yaml_constructors[None] elif isinstance(node, ScalarNode): - constructor = self.__class__.construct_scalar + constructor = self.__class__.construct_scalar # type: ignore elif isinstance(node, SequenceNode): - constructor = self.__class__.construct_sequence + constructor = self.__class__.construct_sequence # type: ignore elif isinstance(node, MappingNode): - constructor = self.__class__.construct_mapping + constructor = self.__class__.construct_mapping # type: ignore if tag_suffix is None: data = constructor(self, node) else: data = constructor(self, tag_suffix, node) if isinstance(data, types.GeneratorType): generator = data - data = next(generator) + data = next(generator) # type: ignore if self.deep_construct: - for dummy in generator: + for dummy in generator: # type: ignore pass else: self.state_generators.append(generator) @@ -172,7 +192,7 @@ class BaseConstructor(object): # keys can be list -> deep key = self.construct_object(key_node, deep=True) # lists are not hashable, but tuples are - if not isinstance(key, collections.Hashable): + if not isinstance(key, collections.Hashable): # type: ignore if isinstance(key, list): key = tuple(key) if PY2: @@ -238,7 +258,7 @@ class SafeConstructor(BaseConstructor): by inserting keys from the merge dict/list of dicts if not yet available in this node """ - merge = [] + merge = [] # type: List[Any] index = 0 while index < len(node.value): key_node, value_node = node.value[index] @@ -272,7 +292,7 @@ class SafeConstructor(BaseConstructor): index += 1 else: index += 1 - if merge: + if bool(merge): node.value = merge + node.value def construct_mapping(self, node, deep=False): @@ -321,9 +341,9 @@ class SafeConstructor(BaseConstructor): return sign*int(value_s[2:], 16) elif value_s.startswith('0o'): return sign*int(value_s[2:], 8) - elif self.processing_version != (1, 2) and value_s[0] == '0': + elif self.resolver.processing_version != (1, 2) and value_s[0] == '0': return sign*int(value_s, 8) - elif self.processing_version != (1, 2) and ':' in value_s: + elif self.resolver.processing_version != (1, 2) and ':' in value_s: digits = [int(part) for part in value_s.split(':')] digits.reverse() base = 1 @@ -438,7 +458,7 @@ class SafeConstructor(BaseConstructor): delta = -delta data = datetime.datetime(year, month, day, hour, minute, second, fraction) - if delta: + if delta: # type: ignore data -= delta return data @@ -499,7 +519,7 @@ class SafeConstructor(BaseConstructor): def construct_yaml_set(self, node): # type: (Any) -> Any - data = set() + data = set() # type: Set[Any] yield data value = self.construct_mapping(node) data.update(value) @@ -516,13 +536,13 @@ class SafeConstructor(BaseConstructor): def construct_yaml_seq(self, node): # type: (Any) -> Any - data = [] + data = [] # type: List[Any] yield data data.extend(self.construct_sequence(node)) def construct_yaml_map(self, node): # type: (Any) -> Any - data = {} + data = {} # type: Dict[Any, Any] yield data value = self.construct_mapping(node) data.update(value) @@ -597,6 +617,10 @@ SafeConstructor.add_constructor( SafeConstructor.add_constructor( None, SafeConstructor.construct_undefined) +if PY2: + class classobj: + pass + class Constructor(SafeConstructor): @@ -702,10 +726,6 @@ class Constructor(SafeConstructor): node.start_mark) return self.find_python_module(suffix, node.start_mark) - if PY2: - class classobj: - pass - def make_python_instance(self, suffix, node, args=None, kwds=None, newobj=False): # type: (Any, Any, Any, Any, bool) -> Any @@ -720,9 +740,9 @@ class Constructor(SafeConstructor): else: return cls(*args, **kwds) else: - if newobj and isinstance(cls, type(self.classobj)) \ + if newobj and isinstance(cls, type(classobj)) \ and not args and not kwds: - instance = self.classobj() + instance = classobj() instance.__class__ = cls return instance elif newobj and isinstance(cls, type): @@ -772,7 +792,7 @@ class Constructor(SafeConstructor): args = self.construct_sequence(node, deep=True) kwds = {} # type: Dict[Any, Any] state = {} # type: Dict[Any, Any] - listitems = [] # List[Any] + listitems = [] # type: List[Any] dictitems = {} # type: Dict[Any, Any] else: value = self.construct_mapping(node, deep=True) @@ -782,11 +802,11 @@ class Constructor(SafeConstructor): listitems = value.get('listitems', []) dictitems = value.get('dictitems', {}) instance = self.make_python_instance(suffix, node, args, kwds, newobj) - if state: + if bool(state): self.set_python_instance_state(instance, state) - if listitems: + if bool(listitems): instance.extend(listitems) - if dictitems: + if bool(dictitems): for key in dictitems: instance[key] = dictitems[key] return instance @@ -880,7 +900,7 @@ class RoundTripConstructor(SafeConstructor): if node.style == '|' and isinstance(node.value, text_type): return PreservedScalarString(node.value) - elif self._preserve_quotes and isinstance(node.value, text_type): + elif bool(self._preserve_quotes) and isinstance(node.value, text_type): if node.style == "'": return SingleQuotedScalarString(node.value) if node.style == '"': @@ -914,7 +934,7 @@ class RoundTripConstructor(SafeConstructor): seqtyp._yaml_add_comment(node.comment[:2]) if len(node.comment) > 2: seqtyp.yaml_end_comment_extend(node.comment[2], clear=True) - if node.anchor: + if node.anchor: # type: ignore from ruamel.yaml.serializer import templated_id if not templated_id(node.anchor): seqtyp.yaml_set_anchor(node.anchor) @@ -993,7 +1013,7 @@ class RoundTripConstructor(SafeConstructor): # type: () -> None pass - def construct_mapping(self, node, maptyp, deep=False): + def construct_mapping(self, node, maptyp=None, deep=False): # type: ignore # type: (Any, Any, bool) -> Any if not isinstance(node, MappingNode): raise ConstructorError( @@ -1006,7 +1026,7 @@ class RoundTripConstructor(SafeConstructor): maptyp._yaml_add_comment(node.comment[:2]) if len(node.comment) > 2: maptyp.yaml_end_comment_extend(node.comment[2], clear=True) - if node.anchor: + if node.anchor: # type: ignore from ruamel.yaml.serializer import templated_id if not templated_id(node.anchor): maptyp.yaml_set_anchor(node.anchor) @@ -1015,7 +1035,7 @@ class RoundTripConstructor(SafeConstructor): # keys can be list -> deep key = self.construct_object(key_node, deep=True) # lists are not hashable, but tuples are - if not isinstance(key, collections.Hashable): + if not isinstance(key, collections.Hashable): # type: ignore if isinstance(key, list): key_a = CommentedKeySeq(key) if key_node.flow_style is True: @@ -1072,7 +1092,7 @@ class RoundTripConstructor(SafeConstructor): typ._yaml_add_comment(node.comment[:2]) if len(node.comment) > 2: typ.yaml_end_comment_extend(node.comment[2], clear=True) - if node.anchor: + if node.anchor: # type: ignore from ruamel.yaml.serializer import templated_id if not templated_id(node.anchor): typ.yaml_set_anchor(node.anchor) @@ -1080,7 +1100,7 @@ class RoundTripConstructor(SafeConstructor): # keys can be list -> deep key = self.construct_object(key_node, deep=True) # lists are not hashable, but tuples are - if not isinstance(key, collections.Hashable): + if not isinstance(key, collections.Hashable): # type: ignore if isinstance(key, list): key = tuple(key) if PY2: @@ -1229,7 +1249,7 @@ class RoundTripConstructor(SafeConstructor): delta = datetime.timedelta(hours=tz_hour, minutes=tz_minute) if values['tz_sign'] == '-': delta = -delta - if delta: + if delta: # type: ignore dt = datetime.datetime(year, month, day, hour, minute) dt -= delta data = TimeStamp(dt.year, dt.month, dt.day, dt.hour, dt.minute, |