summaryrefslogtreecommitdiff
path: root/src/saml2/ident.py
diff options
context:
space:
mode:
authorRoland Hedberg <roland.hedberg@adm.umu.se>2013-02-09 18:57:26 +0100
committerRoland Hedberg <roland.hedberg@adm.umu.se>2013-02-09 18:57:26 +0100
commitf295e06ab790c5454c164255328ebba4a0c4f2da (patch)
treef7b29dd6acb2515cd0afbec4a33a82bc812c2b37 /src/saml2/ident.py
parent71246a38292f2f598163ce4b69a3c298f95d04c2 (diff)
downloadpysaml2-f295e06ab790c5454c164255328ebba4a0c4f2da.tar.gz
Rewrote to use NameID instances every where where I previously used just the text part of the instance.
Diffstat (limited to 'src/saml2/ident.py')
-rw-r--r--src/saml2/ident.py46
1 files changed, 27 insertions, 19 deletions
diff --git a/src/saml2/ident.py b/src/saml2/ident.py
index c091f286..6fabe0b8 100644
--- a/src/saml2/ident.py
+++ b/src/saml2/ident.py
@@ -19,9 +19,11 @@ logger = logging.getLogger(__name__)
ATTR = ["name_qualifier", "sp_name_qualifier", "format", "sp_provided_id",
"text"]
+
class Unknown(Exception):
pass
+
def code(item):
_res = []
i = 0
@@ -32,13 +34,15 @@ def code(item):
i += 1
return ",".join(_res)
-def decode(str):
+
+def decode(txt):
_nid = NameID()
- for part in str.split(","):
+ for part in txt.split(","):
i, val = part.split("=")
setattr(_nid, ATTR[int(i)], unquote(val))
return _nid
+
class IdentDB(object):
""" A class that handles identifiers of entities
Keeps a list of all nameIDs returned per SP
@@ -51,19 +55,19 @@ class IdentDB(object):
self.domain = domain
self.name_qualifier = name_qualifier
- def _create_id(self, format, name_qualifier="", sp_name_qualifier=""):
+ def _create_id(self, nformat, name_qualifier="", sp_name_qualifier=""):
_id = sha256(rndstr(32))
- _id.update(format)
+ _id.update(nformat)
if name_qualifier:
_id.update(name_qualifier)
if sp_name_qualifier:
_id.update(sp_name_qualifier)
return _id.hexdigest()
- def create_id(self, format, name_qualifier="", sp_name_qualifier=""):
- _id = self._create_id(format, name_qualifier, sp_name_qualifier)
+ def create_id(self, nformat, name_qualifier="", sp_name_qualifier=""):
+ _id = self._create_id(nformat, name_qualifier, sp_name_qualifier)
while _id in self.db:
- _id = self._create_id(format, name_qualifier, sp_name_qualifier)
+ _id = self._create_id(nformat, name_qualifier, sp_name_qualifier)
return _id
def store(self, ident, name_id):
@@ -92,30 +96,30 @@ class IdentDB(object):
del self.db[_cn]
- def remove_local(self, id):
- if isinstance(id, unicode):
- id = id.encode("utf-8")
+ def remove_local(self, sid):
+ if isinstance(sid, unicode):
+ sid = sid.encode("utf-8")
try:
- for val in self.db[id].split(" "):
+ for val in self.db[sid].split(" "):
try:
del self.db[val]
except KeyError:
pass
- del self.db[id]
+ del self.db[sid]
except KeyError:
pass
- def get_nameid(self, userid, format, sp_name_qualifier, name_qualifier):
- _id = self.create_id(format, name_qualifier, sp_name_qualifier)
+ def get_nameid(self, userid, nformat, sp_name_qualifier, name_qualifier):
+ _id = self.create_id(nformat, name_qualifier, sp_name_qualifier)
- if format == NAMEID_FORMAT_EMAILADDRESS:
+ if nformat == NAMEID_FORMAT_EMAILADDRESS:
if not self.domain:
raise Exception("Can't issue email nameids, unknown domain")
_id = "%s@%s" % (_id, self.domain)
- nameid = NameID(format=format, sp_name_qualifier=sp_name_qualifier,
+ nameid = NameID(format=nformat, sp_name_qualifier=sp_name_qualifier,
name_qualifier=name_qualifier, text=_id)
self.store(userid, nameid)
@@ -150,8 +154,9 @@ class IdentDB(object):
if not name_qualifier:
name_qualifier = self.name_qualifier
- return {"format":nameid_format, "sp_name_qualifier": sp_name_qualifier,
- "name_qualifier":name_qualifier}
+ return {"nformat": nameid_format,
+ "sp_name_qualifier": sp_name_qualifier,
+ "name_qualifier": name_qualifier}
def construct_nameid(self, userid, local_policy=None,
sp_name_qualifier=None, name_id_policy=None,
@@ -175,7 +180,8 @@ class IdentDB(object):
return self.get_nameid(userid, NAMEID_FORMAT_TRANSIENT,
sp_name_qualifier, name_qualifier)
- def persistent_nameid(self, userid, sp_name_qualifier="", name_qualifier=""):
+ def persistent_nameid(self, userid, sp_name_qualifier="",
+ name_qualifier=""):
nameid = self.match_local_id(userid, sp_name_qualifier, name_qualifier)
if nameid:
return nameid
@@ -194,6 +200,8 @@ class IdentDB(object):
try:
return self.db[code(name_id)]
except KeyError:
+ logger.debug("name: %s" % code(name_id))
+ logger.debug("id keys: %s" % self.db.keys())
return None
def match_local_id(self, userid, sp_name_qualifier, name_qualifier):