diff options
Diffstat (limited to 'main.py')
-rw-r--r-- | main.py | 103 |
1 files changed, 68 insertions, 35 deletions
@@ -35,7 +35,7 @@ from ruamel.yaml.constructor import ( from ruamel.yaml.loader import Loader as UnsafeLoader if False: # MYPY - from typing import List, Set, Dict, Union, Any # NOQA + from typing import List, Set, Dict, Union, Any, Callable # NOQA from ruamel.yaml.compat import StreamType, StreamTextType, VersionType # NOQA if PY3: @@ -61,7 +61,7 @@ class YAML(object): def __init__( self, _kw=enforce, typ=None, pure=False, output=None, plug_ins=None # input=None, ): - # type: (Any, Any, Any, Any) -> None + # type: (Any, Optional[Text], Any, Any, Any) -> None """ _kw: not used, forces keyword arguments in 2.7 (in 3 you can do (*, safe_load=..) typ: 'rt'/None -> RoundTripLoader/RoundTripDumper, (default) @@ -83,7 +83,7 @@ class YAML(object): # self._input = input self._output = output - self._context_manager = False + self._context_manager = None # type: Any self.plug_ins = [] # type: List[Any] for pu in ([] if plug_ins is None else plug_ins) + self.official_plug_ins(): @@ -394,7 +394,7 @@ class YAML(object): class XLoader(self.Parser, self.Constructor, rslvr): # type: ignore def __init__(selfx, stream, version=None, preserve_quotes=None): - # type: (StreamTextType, Union[None, VersionType], Union[None, bool]) -> None # NOQA + # type: (StreamTextType, Optional[VersionType], Optional[bool]) -> None # NOQA CParser.__init__(selfx, stream) selfx._parser = selfx._composer = selfx self.Constructor.__init__(selfx, loader=selfx) @@ -411,7 +411,17 @@ class YAML(object): if self._context_manager: if not self._output: raise TypeError('Missing output stream while dumping from context manager') - self._context_manager.dump(data, transform=transform) + if _kw is not enforce: + raise TypeError( + '{}.dump() takes one positional argument but at least ' + 'two were given ({!r})'.format(self.__class__.__name__, _kw) + ) + if transform is not None: + raise TypeError( + '{}.dump() in the context manager cannot have transform keyword ' + ''.format(self.__class__.__name__) + ) + self._context_manager.dump(data) else: # old style if stream is None: raise TypeError('Need a stream argument when not dumping from context manager') @@ -419,6 +429,23 @@ class YAML(object): def dump_all(self, documents, stream, _kw=enforce, transform=None): # type: (Any, Union[Path, StreamType], Any, Any) -> Any + if self._context_manager: + raise NotImplementedError + if _kw is not enforce: + raise TypeError( + '{}.dump(_all) takes two positional argument but at least ' + 'three were given ({!r})'.format(self.__class__.__name__, _kw) + ) + self._output = stream + self._context_manager = YAMLContextManager(self, transform=transform) + for data in documents: + self._context_manager.dump(data) + self._context_manager.teardown_output() + self._output = None + self._context_manager = None + + def Xdump_all(self, documents, stream, _kw=enforce, transform=None): + # type: (Any, Union[Path, StreamType], Any, Any) -> Any """ Serialize a sequence of Python objects into a YAML stream. """ @@ -515,7 +542,7 @@ class YAML(object): top_level_colon_align=None, prefix_colon=None, ): - # type: (StreamType, Any, Any, Any, Union[None, bool], Union[None, int], Union[None, int], Union[None, bool], Any, Any, Union[None, bool], Union[None, bool], Any, Any, Any, Any, Any) -> None # NOQA + # type: (StreamType, Any, Any, Any, Optional[bool], Optional[int], Optional[int], Optional[bool], Any, Any, Optional[bool], Optional[bool], Any, Any, Any, Any, Any) -> None # NOQA CEmitter.__init__( selfx, stream, @@ -615,10 +642,12 @@ class YAML(object): # ### context manager def __enter__(self): + # type: () -> Any self._context_manager = YAMLContextManager(self) return self def __exit__(self, typ, value, traceback): + # type: (Any, Any, Any) -> None if typ: print('typ', typ) self._context_manager.teardown_output() @@ -657,12 +686,13 @@ class YAML(object): class YAMLContextManager(object): - def __init__(self, yaml): + def __init__(self, yaml, transform=None): + # type: (Any, Optional[Callable]) -> None self._yaml = yaml self._output_inited = False self._output_path = None self._output = self._yaml._output - self._transform = False + self._transform = transform # self._input_inited = False # self._input = input @@ -673,7 +703,7 @@ class YAMLContextManager(object): if not hasattr(self._output, 'write') and hasattr(self._output, 'open'): # pathlib.Path() instance, open with the same mode self._output_path = self._output - self._output = self._output_path.open('r') + self._output = self._output_path.open('w') # if not hasattr(self._stream, 'write') and hasattr(stream, 'open'): # if not hasattr(self._input, 'read') and hasattr(self._input, 'open'): @@ -681,14 +711,15 @@ class YAMLContextManager(object): # self._input_path = self._input # self._input = self._input_path.open('r') - # if self._transform is not None: - # self._fstream = self._stream - # if self._yaml.encoding is None: - # self._stream = StringIO() - # else: - # self._stream = BytesIO() + if self._transform is not None: + self._fstream = self._output + if self._yaml.encoding is None: + self._output = StringIO() + else: + self._output = BytesIO() def teardown_output(self): + # type: () -> None if self._output_inited: self._yaml.serializer.close() else: @@ -702,21 +733,22 @@ class YAMLContextManager(object): delattr(self._yaml, '_serializer') delattr(self._yaml, '_emitter') except AttributeError: - if not typ: - raise + raise if self._transform: - val = self._stream.getvalue() # type: ignore + val = self._output.getvalue() if self._yaml.encoding: val = val.decode(self._yaml.encoding) if self._fstream is None: self._transform(val) else: self._fstream.write(self._transform(val)) - self._fstream.close() + self._fstream.flush() + self._output = self._fstream # maybe not necessary if self._output_path is not None: self._output.close() - def init_output(self, data): + def init_output(self, first_data): + # type: (Any) -> None if self._yaml.top_level_colon_align is True: tlca = max([len(str(x)) for x in first_data]) # type: Any else: @@ -725,7 +757,8 @@ class YAMLContextManager(object): self._yaml.serializer.open() self._output_inited = True - def dump(self, data, transform=None): + def dump(self, data): + # type: (Any) -> None if not self._output_inited: self.init_output(data) try: @@ -853,7 +886,7 @@ def compose_all(stream, Loader=Loader): def load(stream, Loader=None, version=None, preserve_quotes=None): - # type: (StreamTextType, Any, Union[None, VersionType], Any) -> Any + # type: (StreamTextType, Any, Optional[VersionType], Any) -> Any """ Parse the first YAML document in a stream and produce the corresponding Python object. @@ -869,7 +902,7 @@ def load(stream, Loader=None, version=None, preserve_quotes=None): def load_all(stream, Loader=None, version=None, preserve_quotes=None): - # type: (Union[None, StreamTextType], Any, Union[None, VersionType], Union[None, bool]) -> Any # NOQA + # type: (Optional[StreamTextType], Any, Optional[VersionType], Optional[bool]) -> Any # NOQA """ Parse all YAML documents in a stream and produce corresponding Python objects. @@ -886,7 +919,7 @@ def load_all(stream, Loader=None, version=None, preserve_quotes=None): def safe_load(stream, version=None): - # type: (StreamTextType, Union[None, VersionType]) -> Any + # type: (StreamTextType, Optional[VersionType]) -> Any """ Parse the first YAML document in a stream and produce the corresponding Python object. @@ -896,7 +929,7 @@ def safe_load(stream, version=None): def safe_load_all(stream, version=None): - # type: (StreamTextType, Union[None, VersionType]) -> Any + # type: (StreamTextType, Optional[VersionType]) -> Any """ Parse all YAML documents in a stream and produce corresponding Python objects. @@ -906,7 +939,7 @@ def safe_load_all(stream, version=None): def round_trip_load(stream, version=None, preserve_quotes=None): - # type: (StreamTextType, Union[None, VersionType], Union[None, bool]) -> Any + # type: (StreamTextType, Optional[VersionType], Optional[bool]) -> Any """ Parse the first YAML document in a stream and produce the corresponding Python object. @@ -916,7 +949,7 @@ def round_trip_load(stream, version=None, preserve_quotes=None): def round_trip_load_all(stream, version=None, preserve_quotes=None): - # type: (StreamTextType, Union[None, VersionType], Union[None, bool]) -> Any + # type: (StreamTextType, Optional[VersionType], Optional[bool]) -> Any """ Parse all YAML documents in a stream and produce corresponding Python objects. @@ -935,7 +968,7 @@ def emit( allow_unicode=None, line_break=None, ): - # type: (Any, Union[None, StreamType], Any, Union[None, bool], Union[int, None], Union[None, int], Union[None, bool], Any) -> Any # NOQA + # type: (Any, Optional[StreamType], Any, Optional[bool], Union[int, None], Optional[int], Optional[bool], Any) -> Any # NOQA """ Emit YAML parsing events into a stream. If stream is None, return the produced string instead. @@ -983,7 +1016,7 @@ def serialize_all( version=None, tags=None, ): - # type: (Any, Union[None, StreamType], Any, Any, Union[None, int], Union[None, int], Union[None, bool], Any, Any, Union[None, bool], Union[None, bool], Union[None, VersionType], Any) -> Any # NOQA + # type: (Any, Optional[StreamType], Any, Any, Optional[int], Optional[int], Optional[bool], Any, Any, Optional[bool], Optional[bool], Optional[VersionType], Any) -> Any # NOQA """ Serialize a sequence of representation trees into a YAML stream. If stream is None, return the produced string instead. @@ -1024,7 +1057,7 @@ def serialize_all( def serialize(node, stream=None, Dumper=Dumper, **kwds): - # type: (Any, Union[None, StreamType], Any, Any) -> Any + # type: (Any, Optional[StreamType], Any, Any) -> Any """ Serialize a representation tree into a YAML stream. If stream is None, return the produced string instead. @@ -1052,7 +1085,7 @@ def dump_all( top_level_colon_align=None, prefix_colon=None, ): - # type: (Any, Union[None, StreamType], Any, Any, Any, Union[None, bool], Union[None, int], Union[None, int], Union[None, bool], Any, Any, Union[None, bool], Union[None, bool], Any, Any, Any, Any, Any) -> Union[None, str] # NOQA + # type: (Any, Optional[StreamType], Any, Any, Any, Optional[bool], Optional[int], Optional[int], Optional[bool], Any, Any, Optional[bool], Optional[bool], Any, Any, Any, Any, Any) -> Optional[str] # NOQA """ Serialize a sequence of Python objects into a YAML stream. If stream is None, return the produced string instead. @@ -1122,7 +1155,7 @@ def dump( tags=None, block_seq_indent=None, ): - # type: (Any, Union[None, StreamType], Any, Any, Any, Union[None, bool], Union[None, int], Union[None, int], Union[None, bool], Any, Any, Union[None, bool], Union[None, bool], Union[None, VersionType], Any, Any) -> Union[None, str] # NOQA + # type: (Any, Optional[StreamType], Any, Any, Any, Optional[bool], Optional[int], Optional[int], Optional[bool], Any, Any, Optional[bool], Optional[bool], Optional[VersionType], Any, Any) -> Optional[str] # NOQA """ Serialize a Python object into a YAML stream. If stream is None, return the produced string instead. @@ -1151,7 +1184,7 @@ def dump( def safe_dump_all(documents, stream=None, **kwds): - # type: (Any, Union[None, StreamType], Any) -> Union[None, str] + # type: (Any, Optional[StreamType], Any) -> Optional[str] """ Serialize a sequence of Python objects into a YAML stream. Produce only basic YAML tags. @@ -1161,7 +1194,7 @@ def safe_dump_all(documents, stream=None, **kwds): def safe_dump(data, stream=None, **kwds): - # type: (Any, Union[None, StreamType], Any) -> Union[None, str] + # type: (Any, Optional[StreamType], Any) -> Optional[str] """ Serialize a Python object into a YAML stream. Produce only basic YAML tags. @@ -1190,7 +1223,7 @@ def round_trip_dump( top_level_colon_align=None, prefix_colon=None, ): - # type: (Any, Union[None, StreamType], Any, Any, Any, Union[None, bool], Union[None, int], Union[None, int], Union[None, bool], Any, Any, Union[None, bool], Union[None, bool], Union[None, VersionType], Any, Any, Any, Any) -> Union[None, str] # NOQA + # type: (Any, Optional[StreamType], Any, Any, Any, Optional[bool], Optional[int], Optional[int], Optional[bool], Any, Any, Optional[bool], Optional[bool], Optional[VersionType], Any, Any, Any, Any) -> Optional[str] # NOQA allow_unicode = True if allow_unicode is None else allow_unicode return dump_all( [data], |