#!/usr/bin/env python
import os
from pathutils import full_path
from saml2 import samlp
from saml2.assertion import Policy
from saml2.config import IdPConfig
from saml2.ident import IdentDB
from saml2.saml import NAMEID_FORMAT_PERSISTENT
from saml2.saml import NAMEID_FORMAT_TRANSIENT
def _eq(l1, l2):
return set(l1) == set(l2)
# Policy used under the "default" key of the IdP service.
_DEFAULT_IDP_POLICY = {
    "lifetime": {"minutes": 15},
    "attribute_restrictions": None,  # means all I have
    "name_form": "urn:oasis:names:tc:SAML:2.0:attrname-format:uri",
    "nameid_format": NAMEID_FORMAT_PERSISTENT,
}

# Virtual organizations, keyed by their identifier URL.
_VIRTUAL_ORGANIZATIONS = {
    "http://vo.example.org/biomed": {
        "nameid_format": "urn:oid:2.16.756.1.2.5.1.1.1-NameID",
        "common_identifier": "uid",
    },
    "http://vo.example.org/design": {
        "nameid_format": NAMEID_FORMAT_PERSISTENT,
        "common_identifier": "uid",
    },
}

# Module-level IdP configuration, loaded once at import time.
CONFIG = IdPConfig().load(
    {
        "entityid": "urn:mace:example.com:idp:2",
        "name": "test",
        "service": {
            "idp": {
                "endpoints": {
                    "single_sign_on_service": ["http://idp.example.org/"],
                },
                "policy": {"default": _DEFAULT_IDP_POLICY},
            }
        },
        "virtual_organization": _VIRTUAL_ORGANIZATIONS,
    }
)
# Serialized <NameIDPolicy> requests fed to samlp.name_id_policy_from_string().
# Each one names, via SPNameQualifier, the virtual organization that the
# corresponding test_vo_* test expects to see on the resulting NameID.
# The XML declaration must be the very first thing in the string, so the
# literal starts right after the opening quotes.
NAME_ID_POLICY_1 = """<?xml version="1.0" encoding="utf-8"?>
<NameIDPolicy xmlns="urn:oasis:names:tc:SAML:2.0:protocol"
  SPNameQualifier="http://vo.example.org/biomed"
/>
"""

NAME_ID_POLICY_2 = """<?xml version="1.0" encoding="utf-8"?>
<NameIDPolicy xmlns="urn:oasis:names:tc:SAML:2.0:protocol"
  SPNameQualifier="http://vo.example.org/design"
/>
"""
class TestIdentifier:
    """Exercise NameID construction and lookup through ``IdentDB``.

    One ``IdentDB`` stored at ``subject.db`` is created for the whole class
    and shared by every test as ``self.id``.
    """

    # File-name suffixes to purge: the bare store name plus the suffixed
    # variants, so setup and teardown clean up exactly the same files.
    _DB_SUFFIXES = ("", ".db", ".dir", ".dat", ".bak")

    @classmethod
    def _remove_db_files(cls):
        """Best-effort removal of every file the identity store may leave."""
        for suffix in cls._DB_SUFFIXES:
            try:
                os.remove(full_path(f"subject.db{suffix}"))
            except OSError:
                # Usually the file simply does not exist; cleanup is
                # deliberately best-effort.
                pass

    @staticmethod
    def _policy(nameid_format):
        """Return the local policy used by the tests for *nameid_format*."""
        return Policy(
            {
                "default": {
                    "name_form": "urn:oasis:names:tc:SAML:2.0:attrname-format:uri",
                    "nameid_format": nameid_format,
                    "attribute_restrictions": {
                        "surName": [".*berg"],
                    },
                }
            }
        )

    @classmethod
    def setup_class(cls):
        """Drop leftovers from earlier runs and open the shared store."""
        cls._remove_db_files()
        cls.id = IdentDB(full_path("subject.db"), "example.com", "example")

    def test_persistent_1(self):
        """A persistent NameID carries the SP qualifier and maps back to the user."""
        policy = self._policy(NAMEID_FORMAT_PERSISTENT)

        nameid = self.id.construct_nameid("foobar", policy, "urn:mace:example.com:sp:1")

        assert _eq(nameid.keyswv(), ["format", "text", "sp_name_qualifier", "name_qualifier"])
        assert nameid.sp_name_qualifier == "urn:mace:example.com:sp:1"
        assert nameid.format == NAMEID_FORMAT_PERSISTENT

        id_ = self.id.find_local_id(nameid)

        assert id_ == "foobar"

    def test_persistent_2(self):
        """Asking twice with identical arguments yields equal NameIDs."""
        userid = "foobar"
        nameid1 = self.id.persistent_nameid(userid, sp_name_qualifier="sp1", name_qualifier="name0")
        nameid2 = self.id.persistent_nameid(userid, sp_name_qualifier="sp1", name_qualifier="name0")

        # persistent NameIDs should be _persistent_ :-)
        assert nameid1 == nameid2

    def test_transient_1(self):
        """A transient NameID has the transient format and hides the userid."""
        policy = self._policy(NAMEID_FORMAT_TRANSIENT)

        nameid = self.id.construct_nameid("foobar", policy, "urn:mace:example.com:sp:1")

        assert _eq(nameid.keyswv(), ["text", "format", "sp_name_qualifier", "name_qualifier"])
        assert nameid.format == NAMEID_FORMAT_TRANSIENT
        assert nameid.text != "foobar"

    def test_vo_1(self):
        """NameID built for the biomed VO using NAME_ID_POLICY_1."""
        policy = self._policy(NAMEID_FORMAT_PERSISTENT)

        name_id_policy = samlp.name_id_policy_from_string(NAME_ID_POLICY_1)
        print(name_id_policy)
        nameid = self.id.construct_nameid("foobar", policy, "http://vo.example.org/biomed", name_id_policy)

        print(nameid)
        assert _eq(nameid.keyswv(), ["text", "sp_name_qualifier", "format", "name_qualifier"])
        assert nameid.sp_name_qualifier == "http://vo.example.org/biomed"
        assert nameid.format == NAMEID_FORMAT_PERSISTENT
        # we want to *NOT* keep the user identifier in the nameid node
        assert nameid.text != "foobar"

    def test_vo_2(self):
        """NameID built for the design VO using NAME_ID_POLICY_2."""
        policy = self._policy(NAMEID_FORMAT_PERSISTENT)

        name_id_policy = samlp.name_id_policy_from_string(NAME_ID_POLICY_2)

        nameid = self.id.construct_nameid("foobar", policy, "http://vo.example.org/design", name_id_policy)

        assert _eq(nameid.keyswv(), ["text", "sp_name_qualifier", "format", "name_qualifier"])
        assert nameid.sp_name_qualifier == "http://vo.example.org/design"
        assert nameid.format == NAMEID_FORMAT_PERSISTENT
        # The userid handed in above is "foobar", so that is the value that
        # must not leak; the historical "foobar01" check is kept as well.
        assert nameid.text != "foobar"
        assert nameid.text != "foobar01"

    def test_persistent_nameid(self):
        """persistent_nameid is reversible and stable for one user/SP pair."""
        sp_id = "urn:mace:umu.se:sp"
        nameid = self.id.persistent_nameid("abcd0001", sp_id)
        remote_id = nameid.text.strip()
        print(remote_id)

        local = self.id.find_local_id(nameid)
        assert local == "abcd0001"

        # Always get the same
        nameid2 = self.id.persistent_nameid("abcd0001", sp_id)
        assert nameid.text.strip() == nameid2.text.strip()

    def test_transient_nameid(self):
        """transient_nameid is reversible but fresh on every request."""
        sp_id = "urn:mace:umu.se:sp"
        nameid = self.id.transient_nameid("abcd0001", sp_id)
        remote_id = nameid.text.strip()
        print(remote_id)

        local = self.id.find_local_id(nameid)
        assert local == "abcd0001"

        # Getting a new, means really getting a new !
        # Same (userid, SP) argument order as the first call, so the
        # comparison really is between two ids for the same user and SP.
        nameid2 = self.id.transient_nameid("abcd0001", sp_id)
        assert nameid.text.strip() != nameid2.text.strip()

    @classmethod
    def teardown_class(cls):
        """Remove the files the shared identity store left on disk."""
        cls._remove_db_files()