summaryrefslogtreecommitdiff
path: root/main.py
diff options
context:
space:
mode:
authorAnthon van der Neut <anthon@mnt.org>2018-08-05 23:20:17 +0200
committerAnthon van der Neut <anthon@mnt.org>2018-08-05 23:20:17 +0200
commitbfd63d6184e3a43cb63e0831467819bc44513d50 (patch)
tree9426b3c93cb1fd7849ea678424b6bbd900ff2d68 /main.py
parent893db272efb6d7041e09aa09f04da2010ec92072 (diff)
downloadruamel.yaml-bfd63d6184e3a43cb63e0831467819bc44513d50.tar.gz
added context manager, mypy cleanup (w. Optional) added tests0.15.50
Diffstat (limited to 'main.py')
-rw-r--r--main.py103
1 files changed, 68 insertions, 35 deletions
diff --git a/main.py b/main.py
index a2a11f4..ce607bf 100644
--- a/main.py
+++ b/main.py
@@ -35,7 +35,7 @@ from ruamel.yaml.constructor import (
from ruamel.yaml.loader import Loader as UnsafeLoader
if False: # MYPY
- from typing import List, Set, Dict, Union, Any # NOQA
+ from typing import List, Set, Dict, Union, Any, Callable # NOQA
from ruamel.yaml.compat import StreamType, StreamTextType, VersionType # NOQA
if PY3:
@@ -61,7 +61,7 @@ class YAML(object):
def __init__(
self, _kw=enforce, typ=None, pure=False, output=None, plug_ins=None # input=None,
):
- # type: (Any, Any, Any, Any) -> None
+ # type: (Any, Optional[Text], Any, Any, Any) -> None
"""
_kw: not used, forces keyword arguments in 2.7 (in 3 you can do (*, safe_load=..)
typ: 'rt'/None -> RoundTripLoader/RoundTripDumper, (default)
@@ -83,7 +83,7 @@ class YAML(object):
# self._input = input
self._output = output
- self._context_manager = False
+ self._context_manager = None # type: Any
self.plug_ins = [] # type: List[Any]
for pu in ([] if plug_ins is None else plug_ins) + self.official_plug_ins():
@@ -394,7 +394,7 @@ class YAML(object):
class XLoader(self.Parser, self.Constructor, rslvr): # type: ignore
def __init__(selfx, stream, version=None, preserve_quotes=None):
- # type: (StreamTextType, Union[None, VersionType], Union[None, bool]) -> None # NOQA
+ # type: (StreamTextType, Optional[VersionType], Optional[bool]) -> None # NOQA
CParser.__init__(selfx, stream)
selfx._parser = selfx._composer = selfx
self.Constructor.__init__(selfx, loader=selfx)
@@ -411,7 +411,17 @@ class YAML(object):
if self._context_manager:
if not self._output:
raise TypeError('Missing output stream while dumping from context manager')
- self._context_manager.dump(data, transform=transform)
+ if _kw is not enforce:
+ raise TypeError(
+ '{}.dump() takes one positional argument but at least '
+ 'two were given ({!r})'.format(self.__class__.__name__, _kw)
+ )
+ if transform is not None:
+ raise TypeError(
+ '{}.dump() in the context manager cannot have transform keyword '
+ ''.format(self.__class__.__name__)
+ )
+ self._context_manager.dump(data)
else: # old style
if stream is None:
raise TypeError('Need a stream argument when not dumping from context manager')
@@ -419,6 +429,23 @@ class YAML(object):
def dump_all(self, documents, stream, _kw=enforce, transform=None):
# type: (Any, Union[Path, StreamType], Any, Any) -> Any
+ if self._context_manager:
+ raise NotImplementedError
+ if _kw is not enforce:
+ raise TypeError(
+ '{}.dump(_all) takes two positional argument but at least '
+ 'three were given ({!r})'.format(self.__class__.__name__, _kw)
+ )
+ self._output = stream
+ self._context_manager = YAMLContextManager(self, transform=transform)
+ for data in documents:
+ self._context_manager.dump(data)
+ self._context_manager.teardown_output()
+ self._output = None
+ self._context_manager = None
+
+ def Xdump_all(self, documents, stream, _kw=enforce, transform=None):
+ # type: (Any, Union[Path, StreamType], Any, Any) -> Any
"""
Serialize a sequence of Python objects into a YAML stream.
"""
@@ -515,7 +542,7 @@ class YAML(object):
top_level_colon_align=None,
prefix_colon=None,
):
- # type: (StreamType, Any, Any, Any, Union[None, bool], Union[None, int], Union[None, int], Union[None, bool], Any, Any, Union[None, bool], Union[None, bool], Any, Any, Any, Any, Any) -> None # NOQA
+ # type: (StreamType, Any, Any, Any, Optional[bool], Optional[int], Optional[int], Optional[bool], Any, Any, Optional[bool], Optional[bool], Any, Any, Any, Any, Any) -> None # NOQA
CEmitter.__init__(
selfx,
stream,
@@ -615,10 +642,12 @@ class YAML(object):
# ### context manager
def __enter__(self):
+ # type: () -> Any
self._context_manager = YAMLContextManager(self)
return self
def __exit__(self, typ, value, traceback):
+ # type: (Any, Any, Any) -> None
if typ:
print('typ', typ)
self._context_manager.teardown_output()
@@ -657,12 +686,13 @@ class YAML(object):
class YAMLContextManager(object):
- def __init__(self, yaml):
+ def __init__(self, yaml, transform=None):
+ # type: (Any, Optional[Callable]) -> None
self._yaml = yaml
self._output_inited = False
self._output_path = None
self._output = self._yaml._output
- self._transform = False
+ self._transform = transform
# self._input_inited = False
# self._input = input
@@ -673,7 +703,7 @@ class YAMLContextManager(object):
if not hasattr(self._output, 'write') and hasattr(self._output, 'open'):
# pathlib.Path() instance, open with the same mode
self._output_path = self._output
- self._output = self._output_path.open('r')
+ self._output = self._output_path.open('w')
# if not hasattr(self._stream, 'write') and hasattr(stream, 'open'):
# if not hasattr(self._input, 'read') and hasattr(self._input, 'open'):
@@ -681,14 +711,15 @@ class YAMLContextManager(object):
# self._input_path = self._input
# self._input = self._input_path.open('r')
- # if self._transform is not None:
- # self._fstream = self._stream
- # if self._yaml.encoding is None:
- # self._stream = StringIO()
- # else:
- # self._stream = BytesIO()
+ if self._transform is not None:
+ self._fstream = self._output
+ if self._yaml.encoding is None:
+ self._output = StringIO()
+ else:
+ self._output = BytesIO()
def teardown_output(self):
+ # type: () -> None
if self._output_inited:
self._yaml.serializer.close()
else:
@@ -702,21 +733,22 @@ class YAMLContextManager(object):
delattr(self._yaml, '_serializer')
delattr(self._yaml, '_emitter')
except AttributeError:
- if not typ:
- raise
+ raise
if self._transform:
- val = self._stream.getvalue() # type: ignore
+ val = self._output.getvalue()
if self._yaml.encoding:
val = val.decode(self._yaml.encoding)
if self._fstream is None:
self._transform(val)
else:
self._fstream.write(self._transform(val))
- self._fstream.close()
+ self._fstream.flush()
+ self._output = self._fstream # maybe not necessary
if self._output_path is not None:
self._output.close()
- def init_output(self, data):
+ def init_output(self, first_data):
+ # type: (Any) -> None
if self._yaml.top_level_colon_align is True:
tlca = max([len(str(x)) for x in first_data]) # type: Any
else:
@@ -725,7 +757,8 @@ class YAMLContextManager(object):
self._yaml.serializer.open()
self._output_inited = True
- def dump(self, data, transform=None):
+ def dump(self, data):
+ # type: (Any) -> None
if not self._output_inited:
self.init_output(data)
try:
@@ -853,7 +886,7 @@ def compose_all(stream, Loader=Loader):
def load(stream, Loader=None, version=None, preserve_quotes=None):
- # type: (StreamTextType, Any, Union[None, VersionType], Any) -> Any
+ # type: (StreamTextType, Any, Optional[VersionType], Any) -> Any
"""
Parse the first YAML document in a stream
and produce the corresponding Python object.
@@ -869,7 +902,7 @@ def load(stream, Loader=None, version=None, preserve_quotes=None):
def load_all(stream, Loader=None, version=None, preserve_quotes=None):
- # type: (Union[None, StreamTextType], Any, Union[None, VersionType], Union[None, bool]) -> Any # NOQA
+ # type: (Optional[StreamTextType], Any, Optional[VersionType], Optional[bool]) -> Any # NOQA
"""
Parse all YAML documents in a stream
and produce corresponding Python objects.
@@ -886,7 +919,7 @@ def load_all(stream, Loader=None, version=None, preserve_quotes=None):
def safe_load(stream, version=None):
- # type: (StreamTextType, Union[None, VersionType]) -> Any
+ # type: (StreamTextType, Optional[VersionType]) -> Any
"""
Parse the first YAML document in a stream
and produce the corresponding Python object.
@@ -896,7 +929,7 @@ def safe_load(stream, version=None):
def safe_load_all(stream, version=None):
- # type: (StreamTextType, Union[None, VersionType]) -> Any
+ # type: (StreamTextType, Optional[VersionType]) -> Any
"""
Parse all YAML documents in a stream
and produce corresponding Python objects.
@@ -906,7 +939,7 @@ def safe_load_all(stream, version=None):
def round_trip_load(stream, version=None, preserve_quotes=None):
- # type: (StreamTextType, Union[None, VersionType], Union[None, bool]) -> Any
+ # type: (StreamTextType, Optional[VersionType], Optional[bool]) -> Any
"""
Parse the first YAML document in a stream
and produce the corresponding Python object.
@@ -916,7 +949,7 @@ def round_trip_load(stream, version=None, preserve_quotes=None):
def round_trip_load_all(stream, version=None, preserve_quotes=None):
- # type: (StreamTextType, Union[None, VersionType], Union[None, bool]) -> Any
+ # type: (StreamTextType, Optional[VersionType], Optional[bool]) -> Any
"""
Parse all YAML documents in a stream
and produce corresponding Python objects.
@@ -935,7 +968,7 @@ def emit(
allow_unicode=None,
line_break=None,
):
- # type: (Any, Union[None, StreamType], Any, Union[None, bool], Union[int, None], Union[None, int], Union[None, bool], Any) -> Any # NOQA
+ # type: (Any, Optional[StreamType], Any, Optional[bool], Union[int, None], Optional[int], Optional[bool], Any) -> Any # NOQA
"""
Emit YAML parsing events into a stream.
If stream is None, return the produced string instead.
@@ -983,7 +1016,7 @@ def serialize_all(
version=None,
tags=None,
):
- # type: (Any, Union[None, StreamType], Any, Any, Union[None, int], Union[None, int], Union[None, bool], Any, Any, Union[None, bool], Union[None, bool], Union[None, VersionType], Any) -> Any # NOQA
+ # type: (Any, Optional[StreamType], Any, Any, Optional[int], Optional[int], Optional[bool], Any, Any, Optional[bool], Optional[bool], Optional[VersionType], Any) -> Any # NOQA
"""
Serialize a sequence of representation trees into a YAML stream.
If stream is None, return the produced string instead.
@@ -1024,7 +1057,7 @@ def serialize_all(
def serialize(node, stream=None, Dumper=Dumper, **kwds):
- # type: (Any, Union[None, StreamType], Any, Any) -> Any
+ # type: (Any, Optional[StreamType], Any, Any) -> Any
"""
Serialize a representation tree into a YAML stream.
If stream is None, return the produced string instead.
@@ -1052,7 +1085,7 @@ def dump_all(
top_level_colon_align=None,
prefix_colon=None,
):
- # type: (Any, Union[None, StreamType], Any, Any, Any, Union[None, bool], Union[None, int], Union[None, int], Union[None, bool], Any, Any, Union[None, bool], Union[None, bool], Any, Any, Any, Any, Any) -> Union[None, str] # NOQA
+ # type: (Any, Optional[StreamType], Any, Any, Any, Optional[bool], Optional[int], Optional[int], Optional[bool], Any, Any, Optional[bool], Optional[bool], Any, Any, Any, Any, Any) -> Optional[str] # NOQA
"""
Serialize a sequence of Python objects into a YAML stream.
If stream is None, return the produced string instead.
@@ -1122,7 +1155,7 @@ def dump(
tags=None,
block_seq_indent=None,
):
- # type: (Any, Union[None, StreamType], Any, Any, Any, Union[None, bool], Union[None, int], Union[None, int], Union[None, bool], Any, Any, Union[None, bool], Union[None, bool], Union[None, VersionType], Any, Any) -> Union[None, str] # NOQA
+ # type: (Any, Optional[StreamType], Any, Any, Any, Optional[bool], Optional[int], Optional[int], Optional[bool], Any, Any, Optional[bool], Optional[bool], Optional[VersionType], Any, Any) -> Optional[str] # NOQA
"""
Serialize a Python object into a YAML stream.
If stream is None, return the produced string instead.
@@ -1151,7 +1184,7 @@ def dump(
def safe_dump_all(documents, stream=None, **kwds):
- # type: (Any, Union[None, StreamType], Any) -> Union[None, str]
+ # type: (Any, Optional[StreamType], Any) -> Optional[str]
"""
Serialize a sequence of Python objects into a YAML stream.
Produce only basic YAML tags.
@@ -1161,7 +1194,7 @@ def safe_dump_all(documents, stream=None, **kwds):
def safe_dump(data, stream=None, **kwds):
- # type: (Any, Union[None, StreamType], Any) -> Union[None, str]
+ # type: (Any, Optional[StreamType], Any) -> Optional[str]
"""
Serialize a Python object into a YAML stream.
Produce only basic YAML tags.
@@ -1190,7 +1223,7 @@ def round_trip_dump(
top_level_colon_align=None,
prefix_colon=None,
):
- # type: (Any, Union[None, StreamType], Any, Any, Any, Union[None, bool], Union[None, int], Union[None, int], Union[None, bool], Any, Any, Union[None, bool], Union[None, bool], Union[None, VersionType], Any, Any, Any, Any) -> Union[None, str] # NOQA
+ # type: (Any, Optional[StreamType], Any, Any, Any, Optional[bool], Optional[int], Optional[int], Optional[bool], Any, Any, Optional[bool], Optional[bool], Optional[VersionType], Any, Any, Any, Any) -> Optional[str] # NOQA
allow_unicode = True if allow_unicode is None else allow_unicode
return dump_all(
[data],