summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorIvan Kanakarakis <ivan.kanak@gmail.com>2019-05-14 12:34:48 +0200
committerGitHub <noreply@github.com>2019-05-14 12:34:48 +0200
commit4e660fe05bb97093a6304dd2e44340a397d75a94 (patch)
tree3907a62a8c48b17db543cf0afcd5fd69483e2992
parent20de0557af129b3ef9e9f5f82d6a67e822eb2e5f (diff)
parent70084f518537e8b82278de1b4aa1988d79498a09 (diff)
downloadpysaml2-4e660fe05bb97093a6304dd2e44340a397d75a94.tar.gz
Merge pull request #616 from SUNET/eduid-nameid_fixes
Check for an existing local-persistent NameID when retrieving it
-rw-r--r--src/saml2/client_base.py4
-rw-r--r--src/saml2/ident.py23
-rw-r--r--src/saml2/mongo_store.py19
-rw-r--r--tests/test_50_server.py4
-rw-r--r--tests/test_51_client.py4
5 files changed, 44 insertions, 10 deletions
diff --git a/src/saml2/client_base.py b/src/saml2/client_base.py
index 39a7d0ed..15e3b0ec 100644
--- a/src/saml2/client_base.py
+++ b/src/saml2/client_base.py
@@ -339,6 +339,10 @@ class Base(Entity):
# If no nameid_format has been set in the configuration
# or passed in then transient is the default.
if nameid_format is None:
+ # SAML 2.0 errata says AllowCreate MUST NOT be used for
+ # transient ids - to make a conservative change this is
+ # only applied for the default cause
+ allow_create = None
nameid_format = NAMEID_FORMAT_TRANSIENT
# If a list has been configured or passed in choose the
diff --git a/src/saml2/ident.py b/src/saml2/ident.py
index db8365bc..d6a6620a 100644
--- a/src/saml2/ident.py
+++ b/src/saml2/ident.py
@@ -155,6 +155,16 @@ class IdentDB(object):
pass
def get_nameid(self, userid, nformat, sp_name_qualifier, name_qualifier):
+ if nformat == NAMEID_FORMAT_PERSISTENT:
+ nameid = self.match_local_id(userid, sp_name_qualifier, name_qualifier)
+ if nameid:
+ logger.debug(
+ "Found existing persistent NameId {nid} for user {uid}".format(
+ nid=nameid, uid=userid
+ )
+ )
+ return nameid
+
_id = self.create_id(nformat, name_qualifier, sp_name_qualifier)
if nformat == NAMEID_FORMAT_EMAILADDRESS:
@@ -163,11 +173,12 @@ class IdentDB(object):
_id = "%s@%s" % (_id, self.domain)
- # if nformat == NAMEID_FORMAT_PERSISTENT:
- # _id = userid
-
- nameid = NameID(format=nformat, sp_name_qualifier=sp_name_qualifier,
- name_qualifier=name_qualifier, text=_id)
+ nameid = NameID(
+ format=nformat,
+ sp_name_qualifier=sp_name_qualifier,
+ name_qualifier=name_qualifier,
+ text=_id,
+ )
self.store(userid, nameid)
return nameid
@@ -236,7 +247,7 @@ class IdentDB(object):
def construct_nameid(self, userid, local_policy=None,
sp_name_qualifier=None, name_id_policy=None,
name_qualifier=""):
- """ Returns a name_id for the object. How the name_id is
+ """ Returns a name_id for the userid. How the name_id is
constructed depends on the context.
:param local_policy: The policy the server is configured to follow
diff --git a/src/saml2/mongo_store.py b/src/saml2/mongo_store.py
index 6a8f9f45..4120e9e0 100644
--- a/src/saml2/mongo_store.py
+++ b/src/saml2/mongo_store.py
@@ -5,6 +5,8 @@ from pymongo import MongoClient
from pymongo.mongo_replica_set_client import MongoReplicaSetClient
import pymongo.uri_parser
import pymongo.errors
+from saml2.saml import NAMEID_FORMAT_PERSISTENT
+
from saml2.eptid import Eptid
from saml2.mdstore import InMemoryMetaData
from saml2.mdstore import metadata_modules
@@ -163,6 +165,23 @@ class IdentMDB(IdentDB):
return item[self.mdb.primary_key]
return None
+ def match_local_id(self, userid, sp_name_qualifier, name_qualifier):
+ """
+ Match a local persistent identifier.
+
+ Look for an existing persistent NameID matching userid,
+ sp_name_qualifier and name_qualifier.
+ """
+ filter = {
+ "name_id.sp_name_qualifier": sp_name_qualifier,
+ "name_id.name_qualifier": name_qualifier,
+ "name_id.format": NAMEID_FORMAT_PERSISTENT,
+ }
+ res = self.mdb.get(value=userid, **filter)
+ if not res:
+ return None
+ return from_dict(res[0]["name_id"], ONTS, True)
+
def remove_remote(self, name_id):
cnid = to_dict(name_id, MMODS, True)
self.mdb.remove(name_id=cnid)
diff --git a/tests/test_50_server.py b/tests/test_50_server.py
index dc6cbf42..ecef319e 100644
--- a/tests/test_50_server.py
+++ b/tests/test_50_server.py
@@ -267,7 +267,7 @@ class TestServer1():
assert resp_args["destination"] == "http://lingon.catalogix.se:8087/"
assert resp_args["in_response_to"] == "id1"
name_id_policy = resp_args["name_id_policy"]
- assert _eq(name_id_policy.keyswv(), ["format", "allow_create"])
+ assert _eq(name_id_policy.keyswv(), ["format"])
assert name_id_policy.format == saml.NAMEID_FORMAT_TRANSIENT
assert resp_args[
"sp_entity_id"] == "urn:mace:example.com:saml:roland:sp"
@@ -1341,7 +1341,7 @@ class TestServer1NonAsciiAva():
assert resp_args["destination"] == "http://lingon.catalogix.se:8087/"
assert resp_args["in_response_to"] == "id1"
name_id_policy = resp_args["name_id_policy"]
- assert _eq(name_id_policy.keyswv(), ["format", "allow_create"])
+ assert _eq(name_id_policy.keyswv(), ["format"])
assert name_id_policy.format == saml.NAMEID_FORMAT_TRANSIENT
assert resp_args[
"sp_entity_id"] == "urn:mace:example.com:saml:roland:sp"
diff --git a/tests/test_51_client.py b/tests/test_51_client.py
index 2e2b7f1c..75dd8f75 100644
--- a/tests/test_51_client.py
+++ b/tests/test_51_client.py
@@ -269,7 +269,7 @@ class TestClient:
assert ar.provider_name == "urn:mace:example.com:saml:roland:sp"
assert ar.issuer.text == "urn:mace:example.com:saml:roland:sp"
nid_policy = ar.name_id_policy
- assert nid_policy.allow_create == "false"
+ assert nid_policy.allow_create is None
assert nid_policy.format == saml.NAMEID_FORMAT_TRANSIENT
node_requested_attributes = None
@@ -1757,7 +1757,7 @@ class TestClientNonAsciiAva:
assert ar.provider_name == "urn:mace:example.com:saml:roland:sp"
assert ar.issuer.text == "urn:mace:example.com:saml:roland:sp"
nid_policy = ar.name_id_policy
- assert nid_policy.allow_create == "false"
+ assert nid_policy.allow_create is None
assert nid_policy.format == saml.NAMEID_FORMAT_TRANSIENT
node_requested_attributes = None