summaryrefslogtreecommitdiff
path: root/tools
diff options
context:
space:
mode:
authorRoland Hedberg <roland.hedberg@adm.umu.se>2013-04-21 16:40:02 +0200
committerRoland Hedberg <roland.hedberg@adm.umu.se>2013-04-21 16:40:02 +0200
commit40041b642e11a8838cc99f4ec358d134192f4166 (patch)
tree361000a181548088811d91b0042a820f99beaf0a /tools
parentc0d9144172a41cfa508dadc239d3fb6b3dbe14a8 (diff)
downloadpysaml2-40041b642e11a8838cc99f4ec358d134192f4166.tar.gz
Adding authn context support
Diffstat (limited to 'tools')
-rwxr-xr-xtools/parse_xsd2.py183
1 files changed, 145 insertions, 38 deletions
diff --git a/tools/parse_xsd2.py b/tools/parse_xsd2.py
index 6688521b..844ac99a 100755
--- a/tools/parse_xsd2.py
+++ b/tools/parse_xsd2.py
@@ -6,16 +6,11 @@ import getopt
import imp
import sys
import types
+import errno
-__version__ = 0.4
+__version__ = 0.5
-try:
- from xml.etree import cElementTree as ElementTree
-except ImportError:
- try:
- import cElementTree as ElementTree
- except ImportError:
- from elementtree import ElementTree
+from xml.etree import cElementTree as ElementTree
INDENT = 4*" "
DEBUG = False
@@ -47,6 +42,7 @@ def class_pyify(ref):
PROTECTED_KEYWORDS = ["import", "def", "if", "else", "return", "for",
"while", "not", "try", "except", "in"]
+
def def_init(imports, attributes):
indent = INDENT+INDENT
indent3 = INDENT+INDENT+INDENT
@@ -77,6 +73,7 @@ def def_init(imports, attributes):
line.append("%s):" % indent)
return line
+
def base_init(imports):
line = []
indent4 = INDENT+INDENT+INDENT+INDENT
@@ -104,6 +101,7 @@ def base_init(imports):
line.append("%s)" % indent4)
return line
+
def initialize(attributes):
indent = INDENT+INDENT
line = []
@@ -121,6 +119,7 @@ def initialize(attributes):
line.append("%sself.%s=%s" % (indent, _name, _vname))
return line
+
def _mod_typ(prop):
try:
(mod, typ) = prop.type
@@ -139,6 +138,7 @@ def _mod_typ(prop):
return mod, typ
+
def _mod_cname(prop, cdict):
if hasattr(prop, "scoped"):
cname = prop.class_name
@@ -155,6 +155,7 @@ def _mod_cname(prop, cdict):
return mod, cname
+
def leading_uppercase(string):
try:
return string[0].upper()+string[1:]
@@ -163,6 +164,7 @@ def leading_uppercase(string):
except TypeError:
return ""
+
def leading_lowercase(string):
try:
return string[0].lower()+string[1:]
@@ -171,6 +173,7 @@ def leading_lowercase(string):
except TypeError:
return ""
+
def rm_duplicates(properties):
keys = []
clist = []
@@ -189,12 +192,14 @@ def rm_duplicates(properties):
# res.append(item)
# return res
+
def klass_namn(obj):
if obj.class_name:
return obj.class_name
else:
return obj.name
+
class PyObj(object):
def __init__(self, name=None, pyname=None, root=None):
self.name = name
@@ -622,7 +627,7 @@ class PyType(PyObj):
self.namespace = namespace
def text(self, target_namespace, cdict, _child=True, ignore=None,
- _session=None):
+ _session=None):
if not self.properties and not self.type \
and not self.superior:
self.done = True
@@ -715,6 +720,8 @@ class PyAttribute(PyObj):
self.namespace = namespace
self.base = None
self.type = typ
+ self.fixed = False
+ self.default = None
def text(self, _target_namespace, cdict, _child=True):
if isinstance(self.type, PyObj):
@@ -752,6 +759,7 @@ class PyGroup(object):
self.root = root
self.properties = []
self.done = False
+ self.ref = []
def text(self, _target_namespace, _dict, _child, _ignore):
return [], []
@@ -837,6 +845,9 @@ def _namespace_and_tag(obj, param, top):
except ValueError:
namespace = ""
tag = param
+ # except AttributeError:
+ # namespace = ""
+ # tag = obj.name
return namespace, tag
@@ -852,6 +863,7 @@ class Simple(object):
self.use = None
self.ref = None
self.scoped = False
+ self.itemType = None
for attribute, value in elem.attrib.iteritems():
self.__setattr__(attribute, value)
@@ -995,7 +1007,11 @@ class MaxExclusive(Simple):
class List(Simple):
pass
-
+
+class Include(Simple):
+ pass
+
+
# -----------------------------------------------------------------------------
def sequence(elem):
@@ -1218,7 +1234,7 @@ class Element(Complex):
objekt.type = typ
objekt.value_type = {"base": typ}
- except AttributeError, exc:
+ except AttributeError:
# neither type nor reference, definitely local
if hasattr(self, "parts"):
if len(self.parts) == 1:
@@ -1252,6 +1268,7 @@ class Element(Complex):
return objekt
+
class SimpleType(Complex):
def repr(self, top=None, _sup=None, _argv=None, _child=True, parent=""):
if self.py_class:
@@ -1284,6 +1301,7 @@ class SimpleType(Complex):
self.py_class = obj
return obj
+
class Sequence(Complex):
def collect(self, top, sup, argv=None, parent=""):
argv_copy = sd_copy(argv)
@@ -1295,15 +1313,23 @@ class Sequence(Complex):
print "#Sequence: %s" % argv
return Complex.collect(self, top, sup, argv_copy, parent)
+
class SimpleContent(Complex):
pass
+
class ComplexContent(Complex):
pass
+
class Key(Complex):
pass
+
+class Redefine(Complex):
+ pass
+
+
class Extension(Complex):
def collect(self, top, sup, argv=None, parent=""):
if self._own or self._inherited:
@@ -1330,8 +1356,8 @@ class Extension(Complex):
#print "#EXT..-", ia
self._inherited = iattr
except (AttributeError, ValueError):
- pass
-
+ base = None
+
self._extend(top, sup, argv, parent, base)
return self._own, self._inherited
@@ -1525,6 +1551,7 @@ def pyify_0(name):
res += "_"
return res
+
def pyify(name):
# AssertionIDRef
res = []
@@ -1532,7 +1559,7 @@ def pyify(name):
upc = []
pre = ""
for char in name:
- if char >= "A" and char <= "Z":
+ if "A" <= char <= "Z":
upc.append(char)
elif char == "-":
upc.append("_")
@@ -1558,6 +1585,7 @@ def pyify(name):
return "".join(res)
+
def get_type_def( typ, defs):
for cdef in defs:
try:
@@ -1599,6 +1627,7 @@ def sort_elements(els):
return res, els
+
def output(elem, target_namespace, eldict, ignore=None):
done = 0
@@ -1631,6 +1660,7 @@ def output(elem, target_namespace, eldict, ignore=None):
return done
+
def intro():
print """#!/usr/bin/env python
@@ -1644,6 +1674,7 @@ from saml2 import SamlBase
#NAMESPACE = 'http://www.w3.org/2000/09/xmldsig#'
+
def block_items(objekt, block, eldict):
if objekt not in block:
if isinstance(objekt.type, PyType):
@@ -1666,6 +1697,8 @@ def find_parent(elm, eldict):
return find_parent(sup, eldict)
elif elm.ref:
sup = eldict[elm.ref]
+ if sup.name == elm.name:
+ return elm
return find_parent(sup, eldict)
else:
if elm.superior:
@@ -1676,6 +1709,7 @@ def find_parent(elm, eldict):
return elm
+
class Schema(Complex):
def __init__(self, elem, impo, add, modul, defs):
@@ -1942,6 +1976,8 @@ _MAP = {
"selector": Selector,
"field": Field,
"key": Key,
+ "include": Include,
+ "redefine": Redefine
}
ELEMENTFUNCTION = {}
@@ -1961,7 +1997,6 @@ def evaluate(typ, elem):
NS_MAP = "xmlns_map"
def parse_nsmap(fil):
-
events = "start", "start-ns", "end-ns"
root = None
@@ -2015,10 +2050,99 @@ def get_mod(name, path=None):
raise
sys.modules[name] = mod_a
return mod_a
-
+
+
+def recursive_add_xmlns_map(_sch, base):
+ for _part in _sch.parts:
+ _part.xmlns_map.update(base.xmlns_map)
+ if isinstance(_part, Complex):
+ recursive_add_xmlns_map(_part, base)
+
+def find_and_replace(base, mods):
+ base.xmlns_map = mods.xmlns_map
+ recursive_add_xmlns_map(base, mods)
+ rm = []
+ for part in mods.parts:
+ try:
+ _name = part.name
+ except AttributeError:
+ continue
+ for _part in base.parts:
+ try:
+ if _name == _part.name:
+ rm.append(_part)
+ except AttributeError:
+ continue
+ for part in rm:
+ base.parts.remove(part)
+ base.parts.extend(mods.parts)
+ return base
+
+def read_schema(doc, add, defs, impo, modul, ignore, sdir):
+ for path in sdir:
+ fil = "%s%s" % (path, doc)
+ try:
+ fp = open(fil)
+ fp.close()
+ break
+ except IOError as e:
+ if e.errno == errno.EACCES:
+ continue
+ else:
+ raise Exception("Could not find schema file")
+
+ tree = parse_nsmap(fil)
+
+ known = NAMESPACE_BASE[:]
+ known.append(XML_NAMESPACE)
+ for key, namespace in tree._root.attrib["xmlns_map"].items():
+ if namespace in known:
+ continue
+ else:
+ try:
+ modul[key] = modul[namespace]
+ impo[namespace][1] = key
+ except KeyError:
+ if namespace == tree._root.attrib["targetNamespace"]:
+ continue
+ elif namespace in ignore:
+ continue
+ else:
+ raise Exception("Undefined namespace: %s" % namespace)
+
+ _schema = Schema(tree._root, impo, add, modul, defs)
+ _included_parts = []
+ _remove_parts = []
+ _replace = []
+ for part in _schema.parts:
+ if isinstance(part, Include):
+ _sch = read_schema(part.schemaLocation, add, defs, impo, modul,
+ ignore, sdir)
+ # Add namespace information
+ recursive_add_xmlns_map(_sch, _schema)
+ _included_parts.extend(_sch.parts)
+ _remove_parts.append(part)
+ elif isinstance(part, Redefine):
+ # This is the schema that is going to be redefined
+ _redef = read_schema(part.schemaLocation, add, defs, impo, modul,
+ ignore, sdir)
+ # so find and replace
+ # Use the schema to be redefined as starting point
+ _replacement = find_and_replace(_redef, part)
+ _replace.append((part, _replacement.parts))
+
+ for part in _remove_parts:
+ _schema.parts.remove(part)
+ _schema.parts.extend(_included_parts)
+ if _replace:
+ for vad, med in _replace:
+ _schema.parts.remove(vad)
+ _schema.parts.extend(med)
+ return _schema
+
def main(argv):
try:
- opts, args = getopt.getopt(argv, "a:d:hi:I:",
+ opts, args = getopt.getopt(argv, "a:d:hi:I:s:",
["add=", "help", "import=", "defs="])
except getopt.GetoptError, err:
# print help information and exit:
@@ -2031,6 +2155,7 @@ def main(argv):
impo = {}
modul = {}
ignore = []
+ sdir = ["./"]
for opt, arg in opts:
if opt in ("-a", "--add"):
@@ -2040,6 +2165,8 @@ def main(argv):
elif opt in ("-h", "--help"):
usage()
sys.exit()
+ elif opt in ("-s", "--schemadir"):
+ sdir.append(arg)
elif opt in ("-i", "--import"):
mod = get_mod(arg, ['.'])
modul[mod.NAMESPACE] = mod
@@ -2053,28 +2180,8 @@ def main(argv):
print "No XSD-file specified"
usage()
sys.exit(2)
-
- tree = parse_nsmap(args[0])
-
- known = NAMESPACE_BASE[:]
- known.append(XML_NAMESPACE)
- for key, namespace in tree._root.attrib["xmlns_map"].items():
- if namespace in known:
- continue
- else:
- try:
- modul[key] = modul[namespace]
- impo[namespace][1] = key
- except KeyError:
- if namespace == tree._root.attrib["targetNamespace"]:
- continue
- elif namespace in ignore:
- continue
- else:
- raise Exception("Undefined namespace: %s" % namespace)
-
- schema = Schema(tree._root, impo, add, modul, defs)
+ schema = read_schema(args[0], add, defs, impo, modul, ignore, sdir)
#print schema.__dict__
schema.out()