diff options
author | Roland Hedberg <roland.hedberg@adm.umu.se> | 2013-02-09 18:57:26 +0100 |
---|---|---|
committer | Roland Hedberg <roland.hedberg@adm.umu.se> | 2013-02-09 18:57:26 +0100 |
commit | f295e06ab790c5454c164255328ebba4a0c4f2da (patch) | |
tree | f7b29dd6acb2515cd0afbec4a33a82bc812c2b37 /src/saml2/ident.py | |
parent | 71246a38292f2f598163ce4b69a3c298f95d04c2 (diff) | |
download | pysaml2-f295e06ab790c5454c164255328ebba4a0c4f2da.tar.gz |
Rewrote to use NameID instances every where where I previously used just the text part of the instance.
Diffstat (limited to 'src/saml2/ident.py')
-rw-r--r-- | src/saml2/ident.py | 46 |
1 files changed, 27 insertions, 19 deletions
diff --git a/src/saml2/ident.py b/src/saml2/ident.py index c091f286..6fabe0b8 100644 --- a/src/saml2/ident.py +++ b/src/saml2/ident.py @@ -19,9 +19,11 @@ logger = logging.getLogger(__name__) ATTR = ["name_qualifier", "sp_name_qualifier", "format", "sp_provided_id", "text"] + class Unknown(Exception): pass + def code(item): _res = [] i = 0 @@ -32,13 +34,15 @@ def code(item): i += 1 return ",".join(_res) -def decode(str): + +def decode(txt): _nid = NameID() - for part in str.split(","): + for part in txt.split(","): i, val = part.split("=") setattr(_nid, ATTR[int(i)], unquote(val)) return _nid + class IdentDB(object): """ A class that handles identifiers of entities Keeps a list of all nameIDs returned per SP @@ -51,19 +55,19 @@ class IdentDB(object): self.domain = domain self.name_qualifier = name_qualifier - def _create_id(self, format, name_qualifier="", sp_name_qualifier=""): + def _create_id(self, nformat, name_qualifier="", sp_name_qualifier=""): _id = sha256(rndstr(32)) - _id.update(format) + _id.update(nformat) if name_qualifier: _id.update(name_qualifier) if sp_name_qualifier: _id.update(sp_name_qualifier) return _id.hexdigest() - def create_id(self, format, name_qualifier="", sp_name_qualifier=""): - _id = self._create_id(format, name_qualifier, sp_name_qualifier) + def create_id(self, nformat, name_qualifier="", sp_name_qualifier=""): + _id = self._create_id(nformat, name_qualifier, sp_name_qualifier) while _id in self.db: - _id = self._create_id(format, name_qualifier, sp_name_qualifier) + _id = self._create_id(nformat, name_qualifier, sp_name_qualifier) return _id def store(self, ident, name_id): @@ -92,30 +96,30 @@ class IdentDB(object): del self.db[_cn] - def remove_local(self, id): - if isinstance(id, unicode): - id = id.encode("utf-8") + def remove_local(self, sid): + if isinstance(sid, unicode): + sid = sid.encode("utf-8") try: - for val in self.db[id].split(" "): + for val in self.db[sid].split(" "): try: del self.db[val] except KeyError: pass - del self.db[id] + del self.db[sid] except KeyError: pass - def get_nameid(self, userid, format, sp_name_qualifier, name_qualifier): - _id = self.create_id(format, name_qualifier, sp_name_qualifier) + def get_nameid(self, userid, nformat, sp_name_qualifier, name_qualifier): + _id = self.create_id(nformat, name_qualifier, sp_name_qualifier) - if format == NAMEID_FORMAT_EMAILADDRESS: + if nformat == NAMEID_FORMAT_EMAILADDRESS: if not self.domain: raise Exception("Can't issue email nameids, unknown domain") _id = "%s@%s" % (_id, self.domain) - nameid = NameID(format=format, sp_name_qualifier=sp_name_qualifier, + nameid = NameID(format=nformat, sp_name_qualifier=sp_name_qualifier, name_qualifier=name_qualifier, text=_id) self.store(userid, nameid) @@ -150,8 +154,9 @@ class IdentDB(object): if not name_qualifier: name_qualifier = self.name_qualifier - return {"format":nameid_format, "sp_name_qualifier": sp_name_qualifier, - "name_qualifier":name_qualifier} + return {"nformat": nameid_format, + "sp_name_qualifier": sp_name_qualifier, + "name_qualifier": name_qualifier} def construct_nameid(self, userid, local_policy=None, sp_name_qualifier=None, name_id_policy=None, @@ -175,7 +180,8 @@ class IdentDB(object): return self.get_nameid(userid, NAMEID_FORMAT_TRANSIENT, sp_name_qualifier, name_qualifier) - def persistent_nameid(self, userid, sp_name_qualifier="", name_qualifier=""): + def persistent_nameid(self, userid, sp_name_qualifier="", + name_qualifier=""): nameid = self.match_local_id(userid, sp_name_qualifier, name_qualifier) if nameid: return nameid @@ -194,6 +200,8 @@ class IdentDB(object): try: return self.db[code(name_id)] except KeyError: + logger.debug("name: %s" % code(name_id)) + logger.debug("id keys: %s" % self.db.keys()) return None def match_local_id(self, userid, sp_name_qualifier, name_qualifier): |