diff options
Diffstat (limited to 'src/saml2')
-rw-r--r-- | src/saml2/client_base.py | 4 | ||||
-rw-r--r-- | src/saml2/ident.py | 23 | ||||
-rw-r--r-- | src/saml2/mongo_store.py | 23 |
3 files changed, 44 insertions, 6 deletions
diff --git a/src/saml2/client_base.py b/src/saml2/client_base.py index 39a7d0ed..15e3b0ec 100644 --- a/src/saml2/client_base.py +++ b/src/saml2/client_base.py @@ -339,6 +339,10 @@ class Base(Entity): # If no nameid_format has been set in the configuration # or passed in then transient is the default. if nameid_format is None: + # SAML 2.0 errata says AllowCreate MUST NOT be used for + # transient ids - to make a conservative change this is + # only applied for the default cause + allow_create = None nameid_format = NAMEID_FORMAT_TRANSIENT # If a list has been configured or passed in choose the diff --git a/src/saml2/ident.py b/src/saml2/ident.py index db8365bc..d6a6620a 100644 --- a/src/saml2/ident.py +++ b/src/saml2/ident.py @@ -155,6 +155,16 @@ class IdentDB(object): pass def get_nameid(self, userid, nformat, sp_name_qualifier, name_qualifier): + if nformat == NAMEID_FORMAT_PERSISTENT: + nameid = self.match_local_id(userid, sp_name_qualifier, name_qualifier) + if nameid: + logger.debug( + "Found existing persistent NameId {nid} for user {uid}".format( + nid=nameid, uid=userid + ) + ) + return nameid + _id = self.create_id(nformat, name_qualifier, sp_name_qualifier) if nformat == NAMEID_FORMAT_EMAILADDRESS: @@ -163,11 +173,12 @@ class IdentDB(object): _id = "%s@%s" % (_id, self.domain) - # if nformat == NAMEID_FORMAT_PERSISTENT: - # _id = userid - - nameid = NameID(format=nformat, sp_name_qualifier=sp_name_qualifier, - name_qualifier=name_qualifier, text=_id) + nameid = NameID( + format=nformat, + sp_name_qualifier=sp_name_qualifier, + name_qualifier=name_qualifier, + text=_id, + ) self.store(userid, nameid) return nameid @@ -236,7 +247,7 @@ class IdentDB(object): def construct_nameid(self, userid, local_policy=None, sp_name_qualifier=None, name_id_policy=None, name_qualifier=""): - """ Returns a name_id for the object. How the name_id is + """ Returns a name_id for the userid. How the name_id is constructed depends on the context. :param local_policy: The policy the server is configured to follow diff --git a/src/saml2/mongo_store.py b/src/saml2/mongo_store.py index 6a8f9f45..903dd481 100644 --- a/src/saml2/mongo_store.py +++ b/src/saml2/mongo_store.py @@ -1,3 +1,4 @@ +import datetime from hashlib import sha1 import logging @@ -5,6 +6,8 @@ from pymongo import MongoClient from pymongo.mongo_replica_set_client import MongoReplicaSetClient import pymongo.uri_parser import pymongo.errors +from saml2.saml import NAMEID_FORMAT_PERSISTENT + from saml2.eptid import Eptid from saml2.mdstore import InMemoryMetaData from saml2.mdstore import metadata_modules @@ -163,6 +166,23 @@ class IdentMDB(IdentDB): return item[self.mdb.primary_key] return None + def match_local_id(self, userid, sp_name_qualifier, name_qualifier): + """ + Match a local persistent identifier. + + Look for an existing persistent NameID matching userid, + sp_name_qualifier and name_qualifier. + """ + filter = { + "name_id.sp_name_qualifier": sp_name_qualifier, + "name_id.name_qualifier": name_qualifier, + "name_id.format": NAMEID_FORMAT_PERSISTENT, + } + res = self.mdb.get(value=userid, **filter) + if not res: + return None + return from_dict(res[0]["name_id"], ONTS, True) + def remove_remote(self, name_id): cnid = to_dict(name_id, MMODS, True) self.mdb.remove(name_id=cnid) @@ -192,6 +212,9 @@ class MDB(object): else: doc = {} doc.update(kwargs) + # Add timestamp to all documents to allow external garbage collecting + if "created_at" not in doc: + doc["created_at"] = datetime.datetime.utcnow() _ = self.db.insert(doc) def get(self, value=None, **kwargs): |