From 40041b642e11a8838cc99f4ec358d134192f4166 Mon Sep 17 00:00:00 2001 From: Roland Hedberg Date: Sun, 21 Apr 2013 16:40:02 +0200 Subject: Adding authn context support --- tools/parse_xsd2.py | 183 +++++++++++++++++++++++++++++++++++++++++----------- 1 file changed, 145 insertions(+), 38 deletions(-) (limited to 'tools') diff --git a/tools/parse_xsd2.py b/tools/parse_xsd2.py index 6688521b..844ac99a 100755 --- a/tools/parse_xsd2.py +++ b/tools/parse_xsd2.py @@ -6,16 +6,11 @@ import getopt import imp import sys import types +import errno -__version__ = 0.4 +__version__ = 0.5 -try: - from xml.etree import cElementTree as ElementTree -except ImportError: - try: - import cElementTree as ElementTree - except ImportError: - from elementtree import ElementTree +from xml.etree import cElementTree as ElementTree INDENT = 4*" " DEBUG = False @@ -47,6 +42,7 @@ def class_pyify(ref): PROTECTED_KEYWORDS = ["import", "def", "if", "else", "return", "for", "while", "not", "try", "except", "in"] + def def_init(imports, attributes): indent = INDENT+INDENT indent3 = INDENT+INDENT+INDENT @@ -77,6 +73,7 @@ def def_init(imports, attributes): line.append("%s):" % indent) return line + def base_init(imports): line = [] indent4 = INDENT+INDENT+INDENT+INDENT @@ -104,6 +101,7 @@ def base_init(imports): line.append("%s)" % indent4) return line + def initialize(attributes): indent = INDENT+INDENT line = [] @@ -121,6 +119,7 @@ def initialize(attributes): line.append("%sself.%s=%s" % (indent, _name, _vname)) return line + def _mod_typ(prop): try: (mod, typ) = prop.type @@ -139,6 +138,7 @@ def _mod_typ(prop): return mod, typ + def _mod_cname(prop, cdict): if hasattr(prop, "scoped"): cname = prop.class_name @@ -155,6 +155,7 @@ def _mod_cname(prop, cdict): return mod, cname + def leading_uppercase(string): try: return string[0].upper()+string[1:] @@ -163,6 +164,7 @@ def leading_uppercase(string): except TypeError: return "" + def leading_lowercase(string): try: return string[0].lower()+string[1:] @@ -171,6 +173,7 @@ def leading_lowercase(string): except TypeError: return "" + def rm_duplicates(properties): keys = [] clist = [] @@ -189,12 +192,14 @@ def rm_duplicates(properties): # res.append(item) # return res + def klass_namn(obj): if obj.class_name: return obj.class_name else: return obj.name + class PyObj(object): def __init__(self, name=None, pyname=None, root=None): self.name = name @@ -622,7 +627,7 @@ class PyType(PyObj): self.namespace = namespace def text(self, target_namespace, cdict, _child=True, ignore=None, - _session=None): + _session=None): if not self.properties and not self.type \ and not self.superior: self.done = True @@ -715,6 +720,8 @@ class PyAttribute(PyObj): self.namespace = namespace self.base = None self.type = typ + self.fixed = False + self.default = None def text(self, _target_namespace, cdict, _child=True): if isinstance(self.type, PyObj): @@ -752,6 +759,7 @@ class PyGroup(object): self.root = root self.properties = [] self.done = False + self.ref = [] def text(self, _target_namespace, _dict, _child, _ignore): return [], [] @@ -837,6 +845,9 @@ def _namespace_and_tag(obj, param, top): except ValueError: namespace = "" tag = param + # except AttributeError: + # namespace = "" + # tag = obj.name return namespace, tag @@ -852,6 +863,7 @@ class Simple(object): self.use = None self.ref = None self.scoped = False + self.itemType = None for attribute, value in elem.attrib.iteritems(): self.__setattr__(attribute, value) @@ -995,7 +1007,11 @@ class MaxExclusive(Simple): class List(Simple): pass - + +class Include(Simple): + pass + + # ----------------------------------------------------------------------------- def sequence(elem): @@ -1218,7 +1234,7 @@ class Element(Complex): objekt.type = typ objekt.value_type = {"base": typ} - except AttributeError, exc: + except AttributeError: # neither type nor reference, definitely local if hasattr(self, "parts"): if len(self.parts) == 1: @@ -1252,6 +1268,7 @@ class Element(Complex): return objekt + class SimpleType(Complex): def repr(self, top=None, _sup=None, _argv=None, _child=True, parent=""): if self.py_class: @@ -1284,6 +1301,7 @@ class SimpleType(Complex): self.py_class = obj return obj + class Sequence(Complex): def collect(self, top, sup, argv=None, parent=""): argv_copy = sd_copy(argv) @@ -1295,15 +1313,23 @@ class Sequence(Complex): print "#Sequence: %s" % argv return Complex.collect(self, top, sup, argv_copy, parent) + class SimpleContent(Complex): pass + class ComplexContent(Complex): pass + class Key(Complex): pass + +class Redefine(Complex): + pass + + class Extension(Complex): def collect(self, top, sup, argv=None, parent=""): if self._own or self._inherited: @@ -1330,8 +1356,8 @@ class Extension(Complex): #print "#EXT..-", ia self._inherited = iattr except (AttributeError, ValueError): - pass - + base = None + self._extend(top, sup, argv, parent, base) return self._own, self._inherited @@ -1525,6 +1551,7 @@ def pyify_0(name): res += "_" return res + def pyify(name): # AssertionIDRef res = [] @@ -1532,7 +1559,7 @@ def pyify(name): upc = [] pre = "" for char in name: - if char >= "A" and char <= "Z": + if "A" <= char <= "Z": upc.append(char) elif char == "-": upc.append("_") @@ -1558,6 +1585,7 @@ def pyify(name): return "".join(res) + def get_type_def( typ, defs): for cdef in defs: try: @@ -1599,6 +1627,7 @@ def sort_elements(els): return res, els + def output(elem, target_namespace, eldict, ignore=None): done = 0 @@ -1631,6 +1660,7 @@ def output(elem, target_namespace, eldict, ignore=None): return done + def intro(): print """#!/usr/bin/env python @@ -1644,6 +1674,7 @@ from saml2 import SamlBase #NAMESPACE = 'http://www.w3.org/2000/09/xmldsig#' + def block_items(objekt, block, eldict): if objekt not in block: if isinstance(objekt.type, PyType): @@ -1666,6 +1697,8 @@ def find_parent(elm, eldict): return find_parent(sup, eldict) elif elm.ref: sup = eldict[elm.ref] + if sup.name == elm.name: + return elm return find_parent(sup, eldict) else: if elm.superior: @@ -1676,6 +1709,7 @@ def find_parent(elm, eldict): return elm + class Schema(Complex): def __init__(self, elem, impo, add, modul, defs): @@ -1942,6 +1976,8 @@ _MAP = { "selector": Selector, "field": Field, "key": Key, + "include": Include, + "redefine": Redefine } ELEMENTFUNCTION = {} @@ -1961,7 +1997,6 @@ def evaluate(typ, elem): NS_MAP = "xmlns_map" def parse_nsmap(fil): - events = "start", "start-ns", "end-ns" root = None @@ -2015,10 +2050,99 @@ def get_mod(name, path=None): raise sys.modules[name] = mod_a return mod_a - + + +def recursive_add_xmlns_map(_sch, base): + for _part in _sch.parts: + _part.xmlns_map.update(base.xmlns_map) + if isinstance(_part, Complex): + recursive_add_xmlns_map(_part, base) + +def find_and_replace(base, mods): + base.xmlns_map = mods.xmlns_map + recursive_add_xmlns_map(base, mods) + rm = [] + for part in mods.parts: + try: + _name = part.name + except AttributeError: + continue + for _part in base.parts: + try: + if _name == _part.name: + rm.append(_part) + except AttributeError: + continue + for part in rm: + base.parts.remove(part) + base.parts.extend(mods.parts) + return base + +def read_schema(doc, add, defs, impo, modul, ignore, sdir): + for path in sdir: + fil = "%s%s" % (path, doc) + try: + fp = open(fil) + fp.close() + break + except IOError as e: + if e.errno == errno.EACCES: + continue + else: + raise Exception("Could not find schema file") + + tree = parse_nsmap(fil) + + known = NAMESPACE_BASE[:] + known.append(XML_NAMESPACE) + for key, namespace in tree._root.attrib["xmlns_map"].items(): + if namespace in known: + continue + else: + try: + modul[key] = modul[namespace] + impo[namespace][1] = key + except KeyError: + if namespace == tree._root.attrib["targetNamespace"]: + continue + elif namespace in ignore: + continue + else: + raise Exception("Undefined namespace: %s" % namespace) + + _schema = Schema(tree._root, impo, add, modul, defs) + _included_parts = [] + _remove_parts = [] + _replace = [] + for part in _schema.parts: + if isinstance(part, Include): + _sch = read_schema(part.schemaLocation, add, defs, impo, modul, + ignore, sdir) + # Add namespace information + recursive_add_xmlns_map(_sch, _schema) + _included_parts.extend(_sch.parts) + _remove_parts.append(part) + elif isinstance(part, Redefine): + # This is the schema that is going to be redefined + _redef = read_schema(part.schemaLocation, add, defs, impo, modul, + ignore, sdir) + # so find and replace + # Use the schema to be redefined as starting point + _replacement = find_and_replace(_redef, part) + _replace.append((part, _replacement.parts)) + + for part in _remove_parts: + _schema.parts.remove(part) + _schema.parts.extend(_included_parts) + if _replace: + for vad, med in _replace: + _schema.parts.remove(vad) + _schema.parts.extend(med) + return _schema + def main(argv): try: - opts, args = getopt.getopt(argv, "a:d:hi:I:", + opts, args = getopt.getopt(argv, "a:d:hi:I:s:", ["add=", "help", "import=", "defs="]) except getopt.GetoptError, err: # print help information and exit: @@ -2031,6 +2155,7 @@ def main(argv): impo = {} modul = {} ignore = [] + sdir = ["./"] for opt, arg in opts: if opt in ("-a", "--add"): @@ -2040,6 +2165,8 @@ def main(argv): elif opt in ("-h", "--help"): usage() sys.exit() + elif opt in ("-s", "--schemadir"): + sdir.append(arg) elif opt in ("-i", "--import"): mod = get_mod(arg, ['.']) modul[mod.NAMESPACE] = mod @@ -2053,28 +2180,8 @@ def main(argv): print "No XSD-file specified" usage() sys.exit(2) - - tree = parse_nsmap(args[0]) - - known = NAMESPACE_BASE[:] - known.append(XML_NAMESPACE) - for key, namespace in tree._root.attrib["xmlns_map"].items(): - if namespace in known: - continue - else: - try: - modul[key] = modul[namespace] - impo[namespace][1] = key - except KeyError: - if namespace == tree._root.attrib["targetNamespace"]: - continue - elif namespace in ignore: - continue - else: - raise Exception("Undefined namespace: %s" % namespace) - - schema = Schema(tree._root, impo, add, modul, defs) + schema = read_schema(args[0], add, defs, impo, modul, ignore, sdir) #print schema.__dict__ schema.out() -- cgit v1.2.1