summaryrefslogtreecommitdiff
path: root/main.py
diff options
context:
space:
mode:
authorAnthon van der Neut <anthon@mnt.org>2018-08-05 21:34:36 +0200
committerAnthon van der Neut <anthon@mnt.org>2018-08-05 21:34:36 +0200
commit893db272efb6d7041e09aa09f04da2010ec92072 (patch)
tree0ed4b138a0f253845cf60c6382ae65e219721842 /main.py
parente58ed78b291889578477741fb5ad5f05bf914d6b (diff)
downloadruamel.yaml-893db272efb6d7041e09aa09f04da2010ec92072.tar.gz
initial contextmanager dumping
Diffstat (limited to 'main.py')
-rw-r--r--main.py151
1 files changed, 147 insertions, 4 deletions
diff --git a/main.py b/main.py
index 510650e..a2a11f4 100644
--- a/main.py
+++ b/main.py
@@ -1,6 +1,6 @@
# coding: utf-8
-from __future__ import absolute_import, unicode_literals
+from __future__ import absolute_import, unicode_literals, print_function
import sys
import os
@@ -37,6 +37,7 @@ from ruamel.yaml.loader import Loader as UnsafeLoader
if False: # MYPY
from typing import List, Set, Dict, Union, Any # NOQA
from ruamel.yaml.compat import StreamType, StreamTextType, VersionType # NOQA
+
if PY3:
from pathlib import Path
else:
@@ -57,7 +58,9 @@ enforce = object()
class YAML(object):
- def __init__(self, _kw=enforce, typ=None, pure=False, plug_ins=None):
+ def __init__(
+ self, _kw=enforce, typ=None, pure=False, output=None, plug_ins=None # input=None,
+ ):
# type: (Any, Any, Any, Any) -> None
"""
_kw: not used, forces keyword arguments in 2.7 (in 3 you can do (*, safe_load=..)
@@ -66,6 +69,7 @@ class YAML(object):
'unsafe' -> normal/unsafe Loader/Dumper
'base' -> baseloader
pure: if True only use Python modules
+ input/output: needed to work as context manager
plug_ins: a list of plug-in files
"""
if _kw is not enforce:
@@ -76,6 +80,11 @@ class YAML(object):
self.typ = 'rt' if typ is None else typ
self.pure = pure
+
+ # self._input = input
+ self._output = output
+ self._context_manager = False
+
self.plug_ins = [] # type: List[Any]
for pu in ([] if plug_ins is None else plug_ins) + self.official_plug_ins():
file_name = pu.replace(os.sep, '.')
@@ -282,6 +291,17 @@ class YAML(object):
# separate output resolver?
+ # def load(self, stream=None):
+ # if self._context_manager:
+ # if not self._input:
+ # raise TypeError("Missing input stream while dumping from context manager")
+ # for data in self._context_manager.load():
+ # yield data
+ # return
+ # if stream is None:
+ # raise TypeError("Need a stream argument when not loading from context manager")
+ # return self.load_one(stream)
+
def load(self, stream):
# type: (Union[Path, StreamTextType]) -> Any
"""
@@ -386,9 +406,16 @@ class YAML(object):
return loader, loader
return self.constructor, self.parser
- def dump(self, data, stream, _kw=enforce, transform=None):
+ def dump(self, data, stream=None, _kw=enforce, transform=None):
# type: (Any, Union[Path, StreamType], Any, Any) -> Any
- return self.dump_all([data], stream, _kw, transform=transform)
+ if self._context_manager:
+ if not self._output:
+ raise TypeError('Missing output stream while dumping from context manager')
+ self._context_manager.dump(data, transform=transform)
+ else: # old style
+ if stream is None:
+ raise TypeError('Need a stream argument when not dumping from context manager')
+ return self.dump_all([data], stream, _kw, transform=transform)
def dump_all(self, documents, stream, _kw=enforce, transform=None):
# type: (Any, Union[Path, StreamType], Any, Any) -> Any
@@ -585,6 +612,19 @@ class YAML(object):
self.constructor.add_constructor(tag, f_y)
return cls
+ # ### context manager
+
+ def __enter__(self):
+ self._context_manager = YAMLContextManager(self)
+ return self
+
+ def __exit__(self, typ, value, traceback):
+ if typ:
+ print('typ', typ)
+ self._context_manager.teardown_output()
+ # self._context_manager.teardown_input()
+ self._context_manager = None
+
# ### backwards compatibility
def _indent(self, mapping=None, sequence=None, offset=None):
# type: (Any, Any, Any) -> None
@@ -616,6 +656,109 @@ class YAML(object):
self.sequence_dash_offset = val
+class YAMLContextManager(object):
+ def __init__(self, yaml):
+ self._yaml = yaml
+ self._output_inited = False
+ self._output_path = None
+ self._output = self._yaml._output
+ self._transform = False
+
+ # self._input_inited = False
+ # self._input = input
+ # self._input_path = None
+ # self._transform = yaml.transform
+ # self._fstream = None
+
+ if not hasattr(self._output, 'write') and hasattr(self._output, 'open'):
+ # pathlib.Path() instance, open with the same mode
+ self._output_path = self._output
+ self._output = self._output_path.open('r')
+
+ # if not hasattr(self._stream, 'write') and hasattr(stream, 'open'):
+ # if not hasattr(self._input, 'read') and hasattr(self._input, 'open'):
+ # # pathlib.Path() instance, open with the same mode
+ # self._input_path = self._input
+ # self._input = self._input_path.open('r')
+
+ # if self._transform is not None:
+ # self._fstream = self._stream
+ # if self._yaml.encoding is None:
+ # self._stream = StringIO()
+ # else:
+ # self._stream = BytesIO()
+
+ def teardown_output(self):
+ if self._output_inited:
+ self._yaml.serializer.close()
+ else:
+ return
+ try:
+ self._yaml.emitter.dispose()
+ except AttributeError:
+ raise
+ # self.dumper.dispose() # cyaml
+ try:
+ delattr(self._yaml, '_serializer')
+ delattr(self._yaml, '_emitter')
+ except AttributeError:
+ if not typ:
+ raise
+ if self._transform:
+ val = self._stream.getvalue() # type: ignore
+ if self._yaml.encoding:
+ val = val.decode(self._yaml.encoding)
+ if self._fstream is None:
+ self._transform(val)
+ else:
+ self._fstream.write(self._transform(val))
+ self._fstream.close()
+ if self._output_path is not None:
+ self._output.close()
+
+ def init_output(self, data):
+ if self._yaml.top_level_colon_align is True:
+ tlca = max([len(str(x)) for x in first_data]) # type: Any
+ else:
+ tlca = self._yaml.top_level_colon_align
+ self._yaml.get_serializer_representer_emitter(self._output, tlca)
+ self._yaml.serializer.open()
+ self._output_inited = True
+
+ def dump(self, data, transform=None):
+ if not self._output_inited:
+ self.init_output(data)
+ try:
+ self._yaml.representer.represent(data)
+ except AttributeError:
+ # print(dir(dumper._representer))
+ raise
+
+ # def teardown_input(self):
+ # pass
+ #
+ # def init_input(self):
+ # # set the constructor and parser on YAML() instance
+ # self._yaml.get_constructor_parser(stream)
+ #
+ # def load(self):
+ # if not self._input_inited:
+ # self.init_input()
+ # try:
+ # while self._yaml.constructor.check_data():
+ # yield self._yaml.constructor.get_data()
+ # finally:
+ # parser.dispose()
+ # try:
+ # self._reader.reset_reader() # type: ignore
+ # except AttributeError:
+ # pass
+ # try:
+ # self._scanner.reset_scanner() # type: ignore
+ # except AttributeError:
+ # pass
+
+
def yaml_object(yml):
# type: (Any) -> Any
""" decorator for classes that needs to dump/load objects