summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorxi <xi@18f92427-320e-0410-9341-c67f048884a3>2006-04-18 14:35:28 +0000
committerxi <xi@18f92427-320e-0410-9341-c67f048884a3>2006-04-18 14:35:28 +0000
commit26665cbc66afb67dc3ee8c0131b8bd851e243928 (patch)
treecfafd978049ce08c6992517b1f94d33d630a304f
parent4e78012403a451090b72af539df0d976e8148ec4 (diff)
downloadpyyaml-26665cbc66afb67dc3ee8c0131b8bd851e243928.tar.gz
Add constructors for some simple python types.
git-svn-id: http://svn.pyyaml.org/pyyaml/trunk@139 18f92427-320e-0410-9341-c67f048884a3
-rw-r--r--lib/yaml/constructor.py131
-rw-r--r--lib/yaml/representer.py157
-rw-r--r--lib/yaml/serializer.py11
3 files changed, 256 insertions, 43 deletions
diff --git a/lib/yaml/constructor.py b/lib/yaml/constructor.py
index 9fa9085..ff205c2 100644
--- a/lib/yaml/constructor.py
+++ b/lib/yaml/constructor.py
@@ -17,7 +17,7 @@ try:
except NameError:
from sets import Set as set
-import binascii, re
+import binascii, re, sys
class ConstructorError(MarkedYAMLError):
pass
@@ -61,7 +61,7 @@ class BaseConstructor(Composer):
tag_suffix = node.tag[len(tag_prefix):]
constructor = lambda node: \
self.yaml_multi_constructors[tag_prefix](self, tag_suffix, node)
- break
+ break
else:
if None in self.yaml_multi_constructors:
constructor = lambda node: \
@@ -75,6 +75,8 @@ class BaseConstructor(Composer):
constructor = self.construct_sequence
elif isinstance(node, MappingNode):
constructor = self.construct_mapping
+ else:
+ print node.tag
data = constructor(node)
self.constructed_objects[node] = data
return data
@@ -349,15 +351,12 @@ class SafeConstructor(BaseConstructor):
return self.construct_mapping(node)
def construct_yaml_object(self, node, cls):
- mapping = self.construct_mapping(node)
- state = {}
- for key in mapping:
- state[key.replace('-', '_')] = mapping[key]
+ state = self.construct_mapping(node)
data = cls.__new__(cls)
if hasattr(data, '__setstate__'):
- data.__setstate__(mapping)
+ data.__setstate__(state)
else:
- data.__dict__.update(mapping)
+ data.__dict__.update(state)
return data
def construct_undefined(self, node):
@@ -418,5 +417,119 @@ SafeConstructor.add_constructor(None,
SafeConstructor.construct_undefined)
class Constructor(SafeConstructor):
- pass
+
+ def construct_python_str(self, node):
+ return self.construct_scalar(node).encode('utf-8')
+
+ def construct_python_unicode(self, node):
+ return self.construct_scalar(node)
+
+ def construct_python_long(self, node):
+ return long(self.construct_yaml_int(node))
+
+ def construct_python_complex(self, node):
+ return complex(self.construct_scalar(node))
+
+ def construct_python_tuple(self, node):
+ return tuple(self.construct_yaml_seq(node))
+
+ def find_python_module(self, name, mark):
+ if not name:
+ raise ConstructorError("while constructing a Python module", mark,
+ "expected non-empty name appended to the tag", mark)
+ try:
+ __import__(name)
+ except ImportError, exc:
+ raise ConstructorError("while constructing a Python module", mark,
+ "cannot find module %r (%s)" % (name.encode('utf-8'), exc), mark)
+ return sys.modules[name]
+
+ def find_python_name(self, name, mark):
+ if not name:
+ raise ConstructorError("while constructing a Python object", mark,
+ "expected non-empty name appended to the tag", mark)
+ if u'.' in name:
+ module_name, object_name = name.rsplit('.', 1)
+ else:
+ module_name = '__builtin__'
+ object_name = name
+ try:
+ __import__(module_name)
+ except ImportError, exc:
+ raise ConstructorError("while constructing a Python object", mark,
+ "cannot find module %r (%s)" % (module_name.encode('utf-8'), exc), mark)
+ module = sys.modules[module_name]
+ if not hasattr(module, object_name):
+ raise ConstructorError("while constructing a Python object", mark,
+ "cannot find %r in the module %r" % (object_name.encode('utf-8'),
+ module.__name__), mark)
+ return getattr(module, object_name)
+
+ def construct_python_name(self, suffix, node):
+ value = self.construct_scalar(node)
+ if value:
+ raise ConstructorError("while constructing a Python name", node.start_mark,
+ "expected the empty value, but found %r" % value.encode('utf-8'),
+ node.start_mark)
+ return self.find_python_name(suffix, node.start_mark)
+
+ def construct_python_module(self, suffix, node):
+ value = self.construct_scalar(node)
+ if value:
+ raise ConstructorError("while constructing a Python module", node.start_mark,
+ "expected the empty value, but found %r" % value.encode('utf-8'),
+ node.start_mark)
+ return self.find_python_module(suffix, node.start_mark)
+
+Constructor.add_constructor(
+ u'tag:yaml.org,2002:python/none',
+ Constructor.construct_yaml_null)
+
+Constructor.add_constructor(
+ u'tag:yaml.org,2002:python/bool',
+ Constructor.construct_yaml_bool)
+
+Constructor.add_constructor(
+ u'tag:yaml.org,2002:python/str',
+ Constructor.construct_python_str)
+
+Constructor.add_constructor(
+ u'tag:yaml.org,2002:python/unicode',
+ Constructor.construct_python_unicode)
+
+Constructor.add_constructor(
+ u'tag:yaml.org,2002:python/int',
+ Constructor.construct_yaml_int)
+
+Constructor.add_constructor(
+ u'tag:yaml.org,2002:python/long',
+ Constructor.construct_python_long)
+
+Constructor.add_constructor(
+ u'tag:yaml.org,2002:python/float',
+ Constructor.construct_yaml_float)
+
+Constructor.add_constructor(
+ u'tag:yaml.org,2002:python/complex',
+ Constructor.construct_python_complex)
+
+Constructor.add_constructor(
+ u'tag:yaml.org,2002:python/list',
+ Constructor.construct_yaml_seq)
+
+Constructor.add_constructor(
+ u'tag:yaml.org,2002:python/tuple',
+ Constructor.construct_python_tuple)
+
+Constructor.add_constructor(
+ u'tag:yaml.org,2002:python/dict',
+ Constructor.construct_yaml_map)
+
+Constructor.add_multi_constructor(
+ u'tag:yaml.org,2002:python/name:',
+ Constructor.construct_python_name)
+
+Constructor.add_multi_constructor(
+ u'tag:yaml.org,2002:python/module:',
+ Constructor.construct_python_module)
diff --git a/lib/yaml/representer.py b/lib/yaml/representer.py
index 32c91d1..797d865 100644
--- a/lib/yaml/representer.py
+++ b/lib/yaml/representer.py
@@ -16,6 +16,8 @@ try:
except NameError:
from sets import Set as set
+import sys
+
class RepresenterError(YAMLError):
pass
@@ -31,6 +33,22 @@ class BaseRepresenter:
self.serialize(node)
self.represented_objects = {}
+ class C: pass
+ c = C()
+ def f(): pass
+ classobj_type = type(C)
+ instance_type = type(c)
+ function_type = type(f)
+ builtin_function_type = type(abs)
+ module_type = type(sys)
+ del C, c, f
+
+ def get_classobj_bases(self, cls):
+ bases = [cls]
+ for base in cls.__bases__:
+ bases.extend(self.get_classobj_bases(base))
+ return bases
+
def represent_object(self, data):
if self.ignore_aliases(data):
alias_key = None
@@ -43,7 +61,10 @@ class BaseRepresenter:
raise RepresenterError("recursive objects are not allowed: %r" % data)
return node
self.represented_objects[alias_key] = None
- for data_type in type(data).__mro__:
+ data_types = type(data).__mro__
+ if type(data) is self.instance_type:
+ data_types = self.get_classobj_bases(data.__class__)+data_types
+ for data_type in data_types:
if data_type in self.yaml_representers:
node = self.yaml_representers[data_type](self, data)
break
@@ -72,16 +93,17 @@ class BaseRepresenter:
return SequenceNode(tag, value, flow_style=flow_style)
def represent_mapping(self, tag, mapping, flow_style=None):
- value = {}
if hasattr(mapping, 'keys'):
+ value = {}
for item_key in mapping.keys():
item_value = mapping[item_key]
value[self.represent_object(item_key)] = \
self.represent_object(item_value)
else:
+ value = []
for item_key, item_value in mapping:
- value[self.represent_object(item_key)] = \
- self.represent_object(item_value)
+ value.append((self.represent_object(item_key),
+ self.represent_object(item_value)))
return MappingNode(tag, value, flow_style=flow_style)
def ignore_aliases(self, data):
@@ -100,22 +122,20 @@ class SafeRepresenter(BaseRepresenter):
u'null')
def represent_str(self, data):
- encoding = None
+ tag = None
+ style = None
try:
- unicode(data, 'ascii')
- encoding = 'ascii'
+ data = unicode(data, 'ascii')
+ tag = u'tag:yaml.org,2002:str'
except UnicodeDecodeError:
try:
- unicode(data, 'utf-8')
- encoding = 'utf-8'
+ data = unicode(data, 'utf-8')
+ tag = u'tag:yaml.org,2002:str'
except UnicodeDecodeError:
- pass
- if encoding:
- return self.represent_scalar(u'tag:yaml.org,2002:str',
- unicode(data, encoding))
- else:
- return self.represent_scalar(u'tag:yaml.org,2002:binary',
- unicode(data.encode('base64')), style='|')
+ data = data.encode('base64')
+ tag = u'tag:yaml.org,2002:binary'
+ style = '|'
+ return self.represent_scalar(tag, data, style=style)
def represent_unicode(self, data):
return self.represent_scalar(u'tag:yaml.org,2002:str', data)
@@ -144,15 +164,16 @@ class SafeRepresenter(BaseRepresenter):
elif data == self.nan_value or data != data:
value = u'.nan'
else:
- value = unicode(data)
+ value = unicode(repr(data))
return self.represent_scalar(u'tag:yaml.org,2002:float', value)
def represent_list(self, data):
- pairs = (len(data) > 0)
- for item in data:
- if not isinstance(item, tuple) or len(item) != 2:
- pairs = False
- break
+ pairs = (len(data) > 0 and isinstance(data, list))
+ if pairs:
+ for item in data:
+ if not isinstance(item, tuple) or len(item) != 2:
+ pairs = False
+ break
if not pairs:
return self.represent_sequence(u'tag:yaml.org,2002:seq', data)
value = []
@@ -189,14 +210,7 @@ class SafeRepresenter(BaseRepresenter):
state = data.__getstate__()
else:
state = data.__dict__.copy()
- mapping = state
- if hasattr(state, 'keys'):
- mapping = []
- keys = state.keys()
- keys.sort()
- for key in keys:
- mapping.append((key.replace('_', '-'), state[key]))
- return self.represent_mapping(tag, mapping, flow_style=flow_style)
+ return self.represent_mapping(tag, state, flow_style=flow_style)
def represent_undefined(self, data):
raise RepresenterError("cannot represent an object: %s" % data)
@@ -225,6 +239,9 @@ SafeRepresenter.add_representer(float,
SafeRepresenter.add_representer(list,
SafeRepresenter.represent_list)
+SafeRepresenter.add_representer(tuple,
+ SafeRepresenter.represent_list)
+
SafeRepresenter.add_representer(dict,
SafeRepresenter.represent_dict)
@@ -241,5 +258,83 @@ SafeRepresenter.add_representer(None,
SafeRepresenter.represent_undefined)
class Representer(SafeRepresenter):
- pass
+
+ def represent_str(self, data):
+ tag = None
+ style = None
+ try:
+ data = unicode(data, 'ascii')
+ tag = u'tag:yaml.org,2002:str'
+ except UnicodeDecodeError:
+ try:
+ data = unicode(data, 'utf-8')
+ tag = u'tag:yaml.org,2002:python/str'
+ except UnicodeDecodeError:
+ data = data.encode('base64')
+ tag = u'tag:yaml.org,2002:binary'
+ style = '|'
+ return self.represent_scalar(tag, data, style=style)
+
+ def represent_unicode(self, data):
+ tag = None
+ try:
+ data.encode('ascii')
+ tag = u'tag:yaml.org,2002:python/unicode'
+ except UnicodeEncodeError:
+ tag = u'tag:yaml.org,2002:str'
+ return self.represent_scalar(tag, data)
+
+ def represent_long(self, data):
+ tag = u'tag:yaml.org,2002:int'
+ if int(data) is not data:
+ tag = u'tag:yaml.org,2002:python/long'
+ return self.represent_scalar(tag, unicode(data))
+
+ def represent_complex(self, data):
+ if data.real != 0.0:
+ data = u'%r+%rj' % (data.real, data.imag)
+ else:
+ data = u'%rj' % data.imag
+ return self.represent_scalar(u'tag:yaml.org,2002:python/complex', data)
+
+ def represent_tuple(self, data):
+ return self.represent_sequence(u'tag:yaml.org,2002:python/tuple', data)
+
+ def represent_name(self, data):
+ name = u'%s.%s' % (data.__module__, data.__name__)
+ return self.represent_scalar(u'tag:yaml.org,2002:python/name:'+name, u'')
+
+ def represent_module(self, data):
+ return self.represent_scalar(
+ u'tag:yaml.org,2002:python/module:'+data.__name__, u'')
+
+Representer.add_representer(str,
+ Representer.represent_str)
+
+Representer.add_representer(unicode,
+ Representer.represent_unicode)
+
+Representer.add_representer(long,
+ Representer.represent_long)
+
+Representer.add_representer(complex,
+ Representer.represent_complex)
+
+Representer.add_representer(tuple,
+ Representer.represent_tuple)
+
+Representer.add_representer(type,
+ Representer.represent_name)
+
+Representer.add_representer(Representer.classobj_type,
+ Representer.represent_name)
+
+Representer.add_representer(Representer.function_type,
+ Representer.represent_name)
+
+Representer.add_representer(Representer.builtin_function_type,
+ Representer.represent_name)
+
+Representer.add_representer(Representer.module_type,
+ Representer.represent_module)
diff --git a/lib/yaml/serializer.py b/lib/yaml/serializer.py
index b57826d..937be9a 100644
--- a/lib/yaml/serializer.py
+++ b/lib/yaml/serializer.py
@@ -67,9 +67,14 @@ class Serializer:
for item in node.value:
self.anchor_node(item)
elif isinstance(node, MappingNode):
- for key in node.value:
- self.anchor_node(key)
- self.anchor_node(node.value[key])
+ if hasattr(node.value, 'keys'):
+ for key in node.value.keys():
+ self.anchor_node(key)
+ self.anchor_node(node.value[key])
+ else:
+ for key, value in node.value:
+ self.anchor_node(key)
+ self.anchor_node(value)
def generate_anchor(self, node):
self.last_anchor_id += 1