#!/usr/bin/env python
"""Tests for NameID construction and lookup through ``saml2.ident.IdentDB``."""

import os

from pathutils import full_path

from saml2 import samlp
from saml2.assertion import Policy
from saml2.config import IdPConfig
from saml2.ident import IdentDB
from saml2.saml import NAMEID_FORMAT_PERSISTENT
from saml2.saml import NAMEID_FORMAT_TRANSIENT


# Base name of the on-disk identifier database used by the tests.
_DB_NAME = "subject.db"

# File-name suffixes removed around the test run. The empty suffix is the
# plain database file; the others are the variants the original setup code
# already cleaned up.
_DB_SUFFIXES = ("", ".db", ".dir", ".dat", ".bak")

_ATTRNAME_FORMAT_URI = "urn:oasis:names:tc:SAML:2.0:attrname-format:uri"


def _eq(l1, l2):
    """Return True when both iterables hold the same set of items."""
    return set(l1) == set(l2)


def _remove_db_files():
    """Delete every known on-disk variant of the test database, if present."""
    for suffix in _DB_SUFFIXES:
        try:
            os.remove(full_path(f"{_DB_NAME}{suffix}"))
        except OSError:
            # Best effort: a missing file simply means there is nothing to clean.
            pass


def _make_policy(nameid_format):
    """Build the default policy shared by the tests for *nameid_format*."""
    return Policy(
        {
            "default": {
                "name_form": _ATTRNAME_FORMAT_URI,
                "nameid_format": nameid_format,
                "attribute_restrictions": {
                    "surName": [".*berg"],
                },
            }
        }
    )


CONFIG = IdPConfig().load(
    {
        "entityid": "urn:mace:example.com:idp:2",
        "name": "test",
        "service": {
            "idp": {
                "endpoints": {
                    "single_sign_on_service": ["http://idp.example.org/"],
                },
                "policy": {
                    "default": {
                        "lifetime": {"minutes": 15},
                        "attribute_restrictions": None,  # means all I have
                        "name_form": "urn:oasis:names:tc:SAML:2.0:attrname-format:uri",
                        "nameid_format": NAMEID_FORMAT_PERSISTENT,
                    }
                },
            }
        },
        "virtual_organization": {
            "http://vo.example.org/biomed": {
                "nameid_format": "urn:oid:2.16.756.1.2.5.1.1.1-NameID",
                "common_identifier": "uid",
            },
            "http://vo.example.org/design": {
                "nameid_format": NAMEID_FORMAT_PERSISTENT,
                "common_identifier": "uid",
            },
        },
    }
)

# NOTE(review): both literals below contain only whitespace, yet they are fed
# to samlp.name_id_policy_from_string() in the VO tests. This looks like the
# XML payload was lost somewhere upstream -- confirm against the canonical
# version of this file before relying on test_vo_1 / test_vo_2.
NAME_ID_POLICY_1 = """ """

NAME_ID_POLICY_2 = """ """


class TestIdentifier:
    """Exercise persistent and transient NameID handling of ``IdentDB``."""

    @classmethod
    def setup_class(cls):
        """Start from a clean slate and open the identifier database."""
        _remove_db_files()
        cls.id = IdentDB(full_path(_DB_NAME), "example.com", "example")

    def test_persistent_1(self):
        """A persistent NameID carries the SP qualifier and maps back to the user."""
        policy = _make_policy(NAMEID_FORMAT_PERSISTENT)

        nameid = self.id.construct_nameid("foobar", policy, "urn:mace:example.com:sp:1")

        assert _eq(nameid.keyswv(), ["format", "text", "sp_name_qualifier", "name_qualifier"])
        assert nameid.sp_name_qualifier == "urn:mace:example.com:sp:1"
        assert nameid.format == NAMEID_FORMAT_PERSISTENT

        id_ = self.id.find_local_id(nameid)

        assert id_ == "foobar"

    def test_persistent_2(self):
        """Asking twice with identical arguments yields equal persistent NameIDs."""
        userid = "foobar"
        nameid1 = self.id.persistent_nameid(userid, sp_name_qualifier="sp1", name_qualifier="name0")

        nameid2 = self.id.persistent_nameid(userid, sp_name_qualifier="sp1", name_qualifier="name0")

        # persistent NameIDs should be _persistent_ :-)
        assert nameid1 == nameid2

    def test_transient_1(self):
        """A transient NameID has the transient format and hides the user id."""
        policy = _make_policy(NAMEID_FORMAT_TRANSIENT)

        nameid = self.id.construct_nameid("foobar", policy, "urn:mace:example.com:sp:1")

        assert _eq(nameid.keyswv(), ["text", "format", "sp_name_qualifier", "name_qualifier"])
        assert nameid.format == NAMEID_FORMAT_TRANSIENT
        assert nameid.text != "foobar"

    def test_vo_1(self):
        """NameID for the biomed VO is persistent and does not expose the user id."""
        policy = _make_policy(NAMEID_FORMAT_PERSISTENT)

        name_id_policy = samlp.name_id_policy_from_string(NAME_ID_POLICY_1)
        print(name_id_policy)
        nameid = self.id.construct_nameid("foobar", policy, "http://vo.example.org/biomed", name_id_policy)

        print(nameid)
        assert _eq(nameid.keyswv(), ["text", "sp_name_qualifier", "format", "name_qualifier"])
        assert nameid.sp_name_qualifier == "http://vo.example.org/biomed"
        assert nameid.format == NAMEID_FORMAT_PERSISTENT
        # we want to *NOT* keep the user identifier in the nameid node
        assert nameid.text != "foobar"

    def test_vo_2(self):
        """NameID for the design VO is persistent and does not expose the user id."""
        policy = _make_policy(NAMEID_FORMAT_PERSISTENT)

        name_id_policy = samlp.name_id_policy_from_string(NAME_ID_POLICY_2)

        nameid = self.id.construct_nameid("foobar", policy, "http://vo.example.org/design", name_id_policy)

        assert _eq(nameid.keyswv(), ["text", "sp_name_qualifier", "format", "name_qualifier"])
        assert nameid.sp_name_qualifier == "http://vo.example.org/design"
        assert nameid.format == NAMEID_FORMAT_PERSISTENT
        # Compare against the user id actually passed in ("foobar"); the
        # previous literal "foobar01" could never match and made this vacuous.
        assert nameid.text != "foobar"

    def test_persistent_nameid(self):
        """persistent_nameid() is stable per (user, SP) and resolvable to the user."""
        sp_id = "urn:mace:umu.se:sp"
        nameid = self.id.persistent_nameid("abcd0001", sp_id)
        remote_id = nameid.text.strip()
        print(remote_id)
        local = self.id.find_local_id(nameid)
        assert local == "abcd0001"

        # Always get the same
        nameid2 = self.id.persistent_nameid("abcd0001", sp_id)
        assert nameid.text.strip() == nameid2.text.strip()

    def test_transient_nameid(self):
        """transient_nameid() resolves to the user and differs on every call."""
        sp_id = "urn:mace:umu.se:sp"
        nameid = self.id.transient_nameid("abcd0001", sp_id)
        remote_id = nameid.text.strip()
        print(remote_id)
        local = self.id.find_local_id(nameid)
        assert local == "abcd0001"

        # Getting a new, means really getting a new !
        # Use the same (user, SP) pair as above: with the arguments swapped the
        # two NameIDs would belong to different subjects and the comparison
        # below would prove nothing about transient ids being fresh.
        nameid2 = self.id.transient_nameid("abcd0001", sp_id)
        assert nameid.text.strip() != nameid2.text.strip()

    @classmethod
    def teardown_class(cls):
        """Remove the database files created during the test run."""
        _remove_db_files()