summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorClément Hallet <clement.hallet@inuse.eu>2021-03-12 19:26:09 +0100
committerIvan Kanakarakis <ivan.kanak@gmail.com>2021-11-02 13:29:46 +0200
commit0a4c358f45bc851a140c9b1818974add74982a08 (patch)
tree24574e5ca2a78c938e8c0c48f70baacad9fafd89
parent59172fcfdce37ce354fd8f30a166c7f8cd6fd4dd (diff)
downloadpysaml2-0a4c358f45bc851a140c9b1818974add74982a08.tar.gz
Ouput the according KeyName in encrypted answer
Signed-off-by: Ivan Kanakarakis <ivan.kanak@gmail.com>
-rw-r--r--src/saml2/entity.py9
-rw-r--r--src/saml2/mdstore.py24
-rw-r--r--src/saml2/sigver.py12
-rw-r--r--tests/test_30_mdstore.py42
-rw-r--r--tests/test_42_enc.py13
-rw-r--r--tests/test_70_redirect_signing.py2
6 files changed, 73 insertions, 29 deletions
diff --git a/src/saml2/entity.py b/src/saml2/entity.py
index f6ca396c..6658a8cf 100644
--- a/src/saml2/entity.py
+++ b/src/saml2/entity.py
@@ -30,7 +30,6 @@ from saml2.saml import NameID
from saml2.saml import EncryptedAssertion
from saml2.saml import Issuer
from saml2.saml import NAMEID_FORMAT_ENTITY
-from saml2.response import AuthnResponse
from saml2.response import LogoutResponse
from saml2.response import UnsolicitedResponse
from saml2.time_util import instant
@@ -683,11 +682,11 @@ class Entity(HTTPBase):
_certs = []
if encrypt_cert:
- _certs.append(encrypt_cert)
+ _certs.append((None, encrypt_cert))
elif sp_entity_id is not None:
_certs = self.metadata.certs(sp_entity_id, "any", "encryption")
exception = None
- for _cert in _certs:
+ for _cert_name, _cert in _certs:
wrapped_cert, unwrapped_cert = get_pem_wrapped_unwrapped(_cert)
try:
tmp = make_temp(
@@ -698,7 +697,9 @@ class Entity(HTTPBase):
response = self.sec.encrypt_assertion(
response,
tmp.name,
- pre_encryption_part(encrypt_cert=unwrapped_cert),
+ pre_encryption_part(
+ key_name=_cert_name, encrypt_cert=unwrapped_cert
+ ),
node_xpath=node_xpath,
)
return response
diff --git a/src/saml2/mdstore.py b/src/saml2/mdstore.py
index d001999d..40f7232e 100644
--- a/src/saml2/mdstore.py
+++ b/src/saml2/mdstore.py
@@ -490,20 +490,16 @@ class MetaData(object):
def extract_certs(srvs):
res = []
for srv in srvs:
- if "key_descriptor" in srv:
- for key in srv["key_descriptor"]:
- if "use" in key and key["use"] == use:
- for dat in key["key_info"]["x509_data"]:
- cert = repack_cert(
- dat["x509_certificate"]["text"])
- if cert not in res:
- res.append(cert)
- elif not "use" in key:
- for dat in key["key_info"]["x509_data"]:
- cert = repack_cert(
- dat["x509_certificate"]["text"])
- if cert not in res:
- res.append(cert)
+ for key in srv.get("key_descriptor", []):
+ key_use = key.get("use")
+ key_info = key.get("key_info") or {}
+ key_name = (key_info.get("key_name") or [{"text": None}])[0]
+ key_name_txt = key_name.get("text")
+ if "use" not in key or key_use == use:
+ for dat in key_info["x509_data"]:
+ cert = repack_cert(dat["x509_certificate"]["text"])
+ if cert not in res:
+ res.append((key_name_txt, cert))
return res
diff --git a/src/saml2/sigver.py b/src/saml2/sigver.py
index 0bfad44a..8c7a3f4c 100644
--- a/src/saml2/sigver.py
+++ b/src/saml2/sigver.py
@@ -1451,7 +1451,7 @@ class SecurityContext(object):
_certs = []
certs = []
- for cert in _certs:
+ for cert_name, cert in _certs:
if isinstance(cert, six.string_types):
content = pem_format(cert)
tmp = make_temp(content,
@@ -1943,7 +1943,7 @@ def pre_encryption_part(
*,
msg_enc=TRIPLE_DES_CBC,
key_enc=RSA_OAEP_MGF1P,
- key_name='my-rsa-key',
+ key_name=None,
encrypted_key_id=None,
encrypted_data_id=None,
encrypt_cert=None,
@@ -1958,9 +1958,11 @@ def pre_encryption_part(
if encrypt_cert
else None
)
- key_info = ds.KeyInfo(
- key_name=ds.KeyName(text=key_name),
- x509_data=x509_data,
+ key_name = ds.KeyName(text=key_name) if key_name else None
+ key_info = (
+ ds.KeyInfo(key_name=key_name, x509_data=x509_data)
+ if key_name or x509_data
+ else None
)
encrypted_key = EncryptedKey(
diff --git a/tests/test_30_mdstore.py b/tests/test_30_mdstore.py
index bf54594e..8ce058b7 100644
--- a/tests/test_30_mdstore.py
+++ b/tests/test_30_mdstore.py
@@ -99,7 +99,6 @@ TEST_METADATA_STRING = """
</EntitiesDescriptor>
""".format(cert_data=TEST_CERT)
-
ATTRCONV = ac_factory(full_path("attributemaps"))
METADATACONF = {
@@ -522,10 +521,45 @@ def test_load_string():
def test_get_certs_from_metadata():
mds = MetadataStore(ATTRCONV, None)
mds.imp(METADATACONF["11"])
- certs1 = mds.certs("http://xenosmilus.umdc.umu.se/simplesaml/saml2/idp/metadata.php", "any")
- certs2 = mds.certs("http://xenosmilus.umdc.umu.se/simplesaml/saml2/idp/metadata.php", "idpsso")
- assert certs1[0] == certs2[0] == TEST_CERT
+ cert_any_name, cert_any = mds.certs(
+ "http://xenosmilus.umdc.umu.se/simplesaml/saml2/idp/metadata.php", "any"
+ )[0]
+ cert_idpsso_name, cert_idpsso = mds.certs(
+ "http://xenosmilus.umdc.umu.se/simplesaml/saml2/idp/metadata.php", "idpsso"
+ )[0]
+
+ assert cert_any_name is None
+ assert cert_idpsso_name is None
+
+
+def test_get_unnamed_certs_from_metadata():
+ mds = MetadataStore(ATTRCONV, None)
+ mds.imp(METADATACONF["11"])
+
+ cert_any_name, cert_any = mds.certs(
+ "http://xenosmilus.umdc.umu.se/simplesaml/saml2/idp/metadata.php", "any"
+ )[0]
+ cert_idpsso_name, cert_idpsso = mds.certs(
+ "http://xenosmilus.umdc.umu.se/simplesaml/saml2/idp/metadata.php", "idpsso"
+ )[0]
+
+ assert cert_any_name is None
+ assert cert_idpsso_name is None
+
+
+def test_get_named_certs_from_metadata():
+ mds = MetadataStore(ATTRCONV, None)
+ mds.imp(METADATACONF["3"])
+
+ cert_sign_name, cert_sign = mds.certs(
+ "https://coip-test.sunet.se/shibboleth", "spsso", "signing"
+ )[0]
+ cert_enc_name, cert_enc = mds.certs(
+ "https://coip-test.sunet.se/shibboleth", "spsso", "encryption"
+ )[0]
+
+ assert cert_sign_name == cert_enc_name == "coip-test.sunet.se"
def test_get_certs_from_metadata_without_keydescriptor():
diff --git a/tests/test_42_enc.py b/tests/test_42_enc.py
index c268aa76..54825902 100644
--- a/tests/test_42_enc.py
+++ b/tests/test_42_enc.py
@@ -12,7 +12,7 @@ from pathutils import full_path
__author__ = 'roland'
-TMPL_NO_HEADER = """<ns0:EncryptedData xmlns:ns0="http://www.w3.org/2001/04/xmlenc#" xmlns:ns1="http://www.w3.org/2000/09/xmldsig#" Id="{ed_id}" Type="http://www.w3.org/2001/04/xmlenc#Element"><ns0:EncryptionMethod Algorithm="http://www.w3.org/2001/04/xmlenc#tripledes-cbc" /><ns1:KeyInfo><ns0:EncryptedKey Id="{ek_id}"><ns0:EncryptionMethod Algorithm="http://www.w3.org/2001/04/xmlenc#rsa-oaep-mgf1p" /><ns1:KeyInfo><ns1:KeyName>my-rsa-key</ns1:KeyName></ns1:KeyInfo><ns0:CipherData><ns0:CipherValue /></ns0:CipherData></ns0:EncryptedKey></ns1:KeyInfo><ns0:CipherData><ns0:CipherValue /></ns0:CipherData></ns0:EncryptedData>"""
+TMPL_NO_HEADER = """<ns0:EncryptedData xmlns:ns0="http://www.w3.org/2001/04/xmlenc#" xmlns:ns1="http://www.w3.org/2000/09/xmldsig#" Id="{ed_id}" Type="http://www.w3.org/2001/04/xmlenc#Element"><ns0:EncryptionMethod Algorithm="http://www.w3.org/2001/04/xmlenc#tripledes-cbc" /><ns1:KeyInfo><ns0:EncryptedKey Id="{ek_id}"><ns0:EncryptionMethod Algorithm="http://www.w3.org/2001/04/xmlenc#rsa-oaep-mgf1p" />{key_info}<ns0:CipherData><ns0:CipherValue /></ns0:CipherData></ns0:EncryptedKey></ns1:KeyInfo><ns0:CipherData><ns0:CipherValue /></ns0:CipherData></ns0:EncryptedData>"""
TMPL = f"<?xml version='1.0' encoding='UTF-8'?>\n{TMPL_NO_HEADER}"
IDENTITY = {"eduPersonAffiliation": ["staff", "member"],
@@ -47,6 +47,7 @@ def test_pre_enc_with_pregenerated_key():
expected = TMPL_NO_HEADER.format(
ed_id=tmpl.id,
ek_id=tmpl.key_info.encrypted_key.id,
+ key_info=''
)
assert str(tmpl) == expected
@@ -56,6 +57,16 @@ def test_pre_enc_with_generated_key():
expected = TMPL_NO_HEADER.format(
ed_id=tmpl.id,
ek_id=tmpl.key_info.encrypted_key.id,
+ key_info=''
+ )
+ assert str(tmpl) == expected
+
+def test_pre_enc_with_named_key():
+ tmpl = pre_encryption_part(key_name="my-rsa-key")
+ expected = TMPL_NO_HEADER.format(
+ ed_id=tmpl.id,
+ ek_id=tmpl.key_info.encrypted_key.id,
+ key_info='<ns1:KeyInfo><ns1:KeyName>my-rsa-key</ns1:KeyName></ns1:KeyInfo>'
)
assert str(tmpl) == expected
diff --git a/tests/test_70_redirect_signing.py b/tests/test_70_redirect_signing.py
index 5286d4c6..26d6c486 100644
--- a/tests/test_70_redirect_signing.py
+++ b/tests/test_70_redirect_signing.py
@@ -49,7 +49,7 @@ def test():
for cert in _certs:
if verify_redirect_signature(
list_values2simpletons(_dict), sp.sec.sec_backend,
- cert):
+ cert[1]):
verified_ok = True
assert verified_ok