import copy
import logging
import shelve
from hashlib import sha256

import six
from six.moves.urllib.parse import quote
from six.moves.urllib.parse import unquote

from saml2 import SAMLError
from saml2.s_utils import rndbytes
from saml2.s_utils import PolicyError
from saml2.saml import NameID
from saml2.saml import NAMEID_FORMAT_PERSISTENT
from saml2.saml import NAMEID_FORMAT_TRANSIENT
from saml2.saml import NAMEID_FORMAT_EMAILADDRESS

__author__ = 'rolandh'

logger = logging.getLogger(__name__)

# NameID attributes serialized by code()/decode(). The position of a name in
# this list is the digit used for it in the coded form, so the order must not
# change or previously stored codes will decode into the wrong attributes.
ATTR = ["name_qualifier", "sp_name_qualifier", "format", "sp_provided_id",
        "text"]


class Unknown(SAMLError):
    """Raised when a NameID cannot be mapped to a known local user."""


def code(item):
    """
    Turn a NameID class instance into a quoted string of comma separated
    attribute,value pairs. The attribute names are replaced with digits
    (the attribute's index in ATTR); attributes with empty values are
    left out.

    :param item: The class instance
    :return: A quoted string
    """
    _res = []
    for i, attr in enumerate(ATTR):
        val = getattr(item, attr)
        if val:
            # quote() percent-escapes ',', '=' and ' ', which serve as
            # separators here and in IdentDB's space-separated storage.
            _res.append("%d=%s" % (i, quote(val)))
    return ",".join(_res)


def code_binary(item):
    """
    Return a binary 'code' suitable for hashing.
    """
    code_str = code(item)
    if isinstance(code_str, six.string_types):
        return code_str.encode('utf-8')
    return code_str


def decode(txt):
    """Turns a coded string by code() into a NameID class instance.

    Parts without '=' and parts whose index is not a valid position in
    ATTR are ignored.

    :param txt: The coded string
    :return: A NameID instance
    """
    _nid = NameID()
    for part in txt.split(","):
        if "=" not in part:
            continue
        index, val = part.split("=", 1)
        try:
            setattr(_nid, ATTR[int(index)], unquote(val))
        except (ValueError, IndexError):
            # Malformed or out-of-range attribute index: skip this part.
            pass
    return _nid


def _qualifier_matches(stored, wanted):
    """Return True if a stored qualifier satisfies a requested one.

    A non-empty stored value must equal the requested value; an empty
    stored value only matches an empty request.
    """
    if stored:
        return stored == wanted
    return not wanted


class IdentDB(object):
    """ A class that handles identifiers of entities
    Keeps a list of all nameIDs returned per SP

    Storage layout in self.db (two kinds of entries):
      * user id  -> space separated list of code()-encoded NameIDs
      * NameID.text -> user id (reverse mapping)
    """

    def __init__(self, db, domain="", name_qualifier=""):
        """
        :param db: Either a file name (opened with shelve) or a mapping
            object used directly as the store.
        :param domain: Domain appended to generated email-format NameIDs.
        :param name_qualifier: Default NameQualifier for generated NameIDs.
        """
        if isinstance(db, six.string_types):
            self.db = shelve.open(db, protocol=2)
        else:
            self.db = db
        self.domain = domain
        self.name_qualifier = name_qualifier

    def _create_id(self, nformat, name_qualifier="", sp_name_qualifier=""):
        """Return a random hex identifier (sha256 over 32 random bytes
        plus the format and any non-empty qualifiers)."""
        _id = sha256(rndbytes(32))
        if not isinstance(nformat, six.binary_type):
            nformat = nformat.encode('utf-8')
        _id.update(nformat)
        if name_qualifier:
            if not isinstance(name_qualifier, six.binary_type):
                name_qualifier = name_qualifier.encode('utf-8')
            _id.update(name_qualifier)
        if sp_name_qualifier:
            if not isinstance(sp_name_qualifier, six.binary_type):
                sp_name_qualifier = sp_name_qualifier.encode('utf-8')
            _id.update(sp_name_qualifier)
        return _id.hexdigest()

    def create_id(self, nformat, name_qualifier="", sp_name_qualifier=""):
        """Return a new identifier that is not already a key in self.db."""
        _id = self._create_id(nformat, name_qualifier, sp_name_qualifier)
        while _id in self.db:
            _id = self._create_id(nformat, name_qualifier, sp_name_qualifier)
        return _id

    def store(self, ident, name_id):
        """
        Record name_id for the user and the reverse NameID -> user mapping.

        :param ident: user identifier
        :param name_id: NameID instance
        """
        # One user may have more than one NameID defined
        try:
            val = self.db[ident].split(" ")
        except KeyError:
            val = []

        _cn = code(name_id)
        val.append(_cn)
        self.db[ident] = " ".join(val)
        self.db[name_id.text] = ident

    def remove_remote(self, name_id):
        """
        Remove a NameID to userID mapping

        :param name_id: NameID instance
        :raises KeyError: if name_id.text is not a stored NameID
        """
        _cn = code(name_id)
        _id = self.db[name_id.text]
        try:
            vals = self.db[_id].split(" ")
            vals.remove(_cn)
            self.db[_id] = " ".join(vals)
        except (KeyError, ValueError):
            # The user record is gone, or no longer holds this exact coded
            # NameID; the reverse mapping is still removed below.
            pass

        del self.db[name_id.text]

    def remove_local(self, sid):
        """
        Remove a user and every NameID -> user mapping recorded for them.
        Missing entries are silently ignored.

        :param sid: user identifier
        """
        # Python 2 shelve keys must be byte strings; on Python 3 keys are
        # str and encoding would make the lookups below miss.
        if six.PY2 and isinstance(sid, six.text_type):
            sid = sid.encode("utf-8")

        try:
            for val in self.db[sid].split(" "):
                try:
                    nid = decode(val)
                    del self.db[nid.text]
                except KeyError:
                    pass
            del self.db[sid]
        except KeyError:
            pass

    def get_nameid(self, userid, nformat, sp_name_qualifier, name_qualifier):
        """
        Return a NameID for userid in the given format.

        For the persistent format an already stored matching NameID is
        reused; otherwise a new identifier is created and stored.

        :raises SAMLError: if an email NameID is requested but no domain
            is configured.
        """
        if nformat == NAMEID_FORMAT_PERSISTENT:
            nameid = self.match_local_id(userid, sp_name_qualifier,
                                         name_qualifier)
            if nameid:
                logger.debug(
                    "Found existing persistent NameId {nid} for user {uid}".format(
                        nid=nameid, uid=userid
                    )
                )
                return nameid

        _id = self.create_id(nformat, name_qualifier, sp_name_qualifier)

        if nformat == NAMEID_FORMAT_EMAILADDRESS:
            if not self.domain:
                raise SAMLError("Can't issue email nameids, unknown domain")

            _id = "%s@%s" % (_id, self.domain)

        nameid = NameID(
            format=nformat,
            sp_name_qualifier=sp_name_qualifier,
            name_qualifier=name_qualifier,
            text=_id,
        )

        self.store(userid, nameid)
        return nameid

    def find_nameid(self, userid, **kwargs):
        """
        Find a set of NameID's that matches the search criteria.

        :param userid: User id
        :param kwargs: The search filter a set of attribute/value pairs
        :return: a list of NameID instances
        """
        res = []
        try:
            _vals = self.db[userid]
        except KeyError:
            logger.debug("failed to find userid %s in IdentDB", userid)
            return res

        for val in _vals.split(" "):
            nid = decode(val)
            if all(getattr(nid, key, None) == _val
                   for key, _val in kwargs.items()):
                res.append(nid)

        return res

    def nim_args(self, local_policy=None, sp_name_qualifier="",
                 name_id_policy=None, name_qualifier=""):
        """
        Work out the keyword arguments for get_nameid().

        :param local_policy: Local policy; used for the NameID format when
            name_id_policy does not specify one.
        :param sp_name_qualifier: Used unless name_id_policy overrides it.
        :param name_id_policy: The requester's NameIDPolicy (or None).
        :param name_qualifier: Falls back to self.name_qualifier if empty.
        :return: dict with keys nformat, sp_name_qualifier, name_qualifier
        :raises SAMLError: if no NameID format can be determined.
        """

        logger.debug("local_policy: %s, name_id_policy: %s", local_policy,
                     name_id_policy)

        if name_id_policy and name_id_policy.sp_name_qualifier:
            sp_name_qualifier = name_id_policy.sp_name_qualifier

        if name_id_policy and name_id_policy.format:
            nameid_format = name_id_policy.format
        elif local_policy:
            nameid_format = local_policy.get_nameid_format(sp_name_qualifier)
        else:
            raise SAMLError("Unknown NameID format")

        if not name_qualifier:
            name_qualifier = self.name_qualifier

        return {"nformat": nameid_format,
                "sp_name_qualifier": sp_name_qualifier,
                "name_qualifier": name_qualifier}

    def construct_nameid(self, userid, local_policy=None,
                         sp_name_qualifier=None, name_id_policy=None,
                         name_qualifier=""):
        """ Returns a name_id for the userid. How the name_id is
        constructed depends on the context.

        :param local_policy: The policy the server is configured to follow
        :param userid: The local permanent identifier of the object
        :param sp_name_qualifier: The 'user'/-s of the name_id
        :param name_id_policy: The policy the server on the other side wants
            us to follow.
        :param name_qualifier: A domain qualifier
        :return: NameID instance precursor
        """

        args = self.nim_args(local_policy, sp_name_qualifier, name_id_policy)
        if name_qualifier:
            args["name_qualifier"] = name_qualifier
        else:
            args["name_qualifier"] = self.name_qualifier

        return self.get_nameid(userid, **args)

    def transient_nameid(self, userid, sp_name_qualifier="",
                         name_qualifier=""):
        """Create and store a new transient NameID for userid."""
        return self.get_nameid(userid, NAMEID_FORMAT_TRANSIENT,
                               sp_name_qualifier, name_qualifier)

    def persistent_nameid(self, userid, sp_name_qualifier="",
                          name_qualifier=""):
        """Return a stored matching non-transient NameID for userid, or
        create and store a new persistent one."""
        nameid = self.match_local_id(userid, sp_name_qualifier,
                                     name_qualifier)
        if nameid:
            return nameid
        return self.get_nameid(userid, NAMEID_FORMAT_PERSISTENT,
                               sp_name_qualifier, name_qualifier)

    def find_local_id(self, name_id):
        """
        Only find persistent IDs

        Looks up the user id stored for name_id.text.

        :param name_id: NameID instance
        :return: the local user id, or None if unknown
        """

        try:
            return self.db[name_id.text]
        except KeyError:
            logger.debug("name: %s", name_id.text)
            return None

    def match_local_id(self, userid, sp_name_qualifier, name_qualifier):
        """
        Return the first stored non-transient NameID of userid whose
        sp_name_qualifier and name_qualifier both match the requested
        ones (an empty stored qualifier only matches an empty request),
        or None.
        """
        try:
            for val in self.db[userid].split(" "):
                nid = decode(val)
                if nid.format == NAMEID_FORMAT_TRANSIENT:
                    continue
                snq = getattr(nid, "sp_name_qualifier", "")
                nq = getattr(nid, "name_qualifier", None)
                if (_qualifier_matches(snq, sp_name_qualifier) and
                        _qualifier_matches(nq, name_qualifier)):
                    return nid
        except KeyError:
            pass

        return None

    def handle_name_id_mapping_request(self, name_id, name_id_policy):
        """

        :param name_id: The NameID that specifies the principal
        :param name_id_policy: The NameIDPolicy of the requester
        :return: If an old name_id exists that match the name-id policy
            that is return otherwise if a new one can be created it
            will be and returned. If no old matching exists and a new
            is not allowed to be created None is returned.
        :raises Unknown: if name_id is not mapped to a local user.
        :raises PolicyError: if no match exists and AllowCreate is false.
        """
        _id = self.find_local_id(name_id)
        if not _id:
            raise Unknown("Unknown entity")

        # return an old one if present
        for val in self.db[_id].split(" "):
            _nid = decode(val)
            if _nid.format == name_id_policy.format:
                if _nid.sp_name_qualifier == name_id_policy.sp_name_qualifier:
                    return _nid

        # AllowCreate is an xs:boolean, whose lexical false values are
        # both "false" and "0".
        if str(name_id_policy.allow_create).lower() in ("false", "0"):
            raise PolicyError("Not allowed to create new identifier")

        # else create and return a new one
        return self.construct_nameid(_id, name_id_policy=name_id_policy)

    def handle_manage_name_id_request(self, name_id, new_id=None,
                                      new_encrypted_id="", terminate=""):
        """
        Requests from the SP is about the SPProvidedID attribute.
        So this is about adding,replacing and removing said attribute.

        :param name_id: NameID instance
        :param new_id: NewID instance
        :param new_encrypted_id: NewEncryptedID instance
        :param terminate: Terminate instance
        :return: The modified name_id
        """
        _id = self.find_local_id(name_id)

        # Keep an unmodified copy so the old coded form can be removed.
        orig_name_id = copy.copy(name_id)

        if new_id:
            name_id.sp_provided_id = new_id.text
        elif new_encrypted_id:
            # TODO: NewEncryptedID is not handled yet.
            pass
        elif terminate:
            name_id.sp_provided_id = None
        else:
            # NOOP
            return name_id

        self.remove_remote(orig_name_id)
        self.store(_id, name_id)
        return name_id

    def close(self):
        """Close the underlying store if it supports closing."""
        if hasattr(self.db, 'close'):
            self.db.close()

    def sync(self):
        """Flush the underlying store if it supports syncing."""
        if hasattr(self.db, 'sync'):
            self.db.sync()