From 0a4c358f45bc851a140c9b1818974add74982a08 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Hallet?= Date: Fri, 12 Mar 2021 19:26:09 +0100 Subject: Ouput the according KeyName in encrypted answer Signed-off-by: Ivan Kanakarakis --- src/saml2/entity.py | 9 +++++---- src/saml2/mdstore.py | 24 ++++++++++------------ src/saml2/sigver.py | 12 ++++++----- tests/test_30_mdstore.py | 42 +++++++++++++++++++++++++++++++++++---- tests/test_42_enc.py | 13 +++++++++++- tests/test_70_redirect_signing.py | 2 +- 6 files changed, 73 insertions(+), 29 deletions(-) diff --git a/src/saml2/entity.py b/src/saml2/entity.py index f6ca396c..6658a8cf 100644 --- a/src/saml2/entity.py +++ b/src/saml2/entity.py @@ -30,7 +30,6 @@ from saml2.saml import NameID from saml2.saml import EncryptedAssertion from saml2.saml import Issuer from saml2.saml import NAMEID_FORMAT_ENTITY -from saml2.response import AuthnResponse from saml2.response import LogoutResponse from saml2.response import UnsolicitedResponse from saml2.time_util import instant @@ -683,11 +682,11 @@ class Entity(HTTPBase): _certs = [] if encrypt_cert: - _certs.append(encrypt_cert) + _certs.append((None, encrypt_cert)) elif sp_entity_id is not None: _certs = self.metadata.certs(sp_entity_id, "any", "encryption") exception = None - for _cert in _certs: + for _cert_name, _cert in _certs: wrapped_cert, unwrapped_cert = get_pem_wrapped_unwrapped(_cert) try: tmp = make_temp( @@ -698,7 +697,9 @@ class Entity(HTTPBase): response = self.sec.encrypt_assertion( response, tmp.name, - pre_encryption_part(encrypt_cert=unwrapped_cert), + pre_encryption_part( + key_name=_cert_name, encrypt_cert=unwrapped_cert + ), node_xpath=node_xpath, ) return response diff --git a/src/saml2/mdstore.py b/src/saml2/mdstore.py index d001999d..40f7232e 100644 --- a/src/saml2/mdstore.py +++ b/src/saml2/mdstore.py @@ -490,20 +490,16 @@ class MetaData(object): def extract_certs(srvs): res = [] for srv in srvs: - if "key_descriptor" in srv: - for key in srv["key_descriptor"]: - if "use" in key and key["use"] == use: - for dat in key["key_info"]["x509_data"]: - cert = repack_cert( - dat["x509_certificate"]["text"]) - if cert not in res: - res.append(cert) - elif not "use" in key: - for dat in key["key_info"]["x509_data"]: - cert = repack_cert( - dat["x509_certificate"]["text"]) - if cert not in res: - res.append(cert) + for key in srv.get("key_descriptor", []): + key_use = key.get("use") + key_info = key.get("key_info") or {} + key_name = (key_info.get("key_name") or [{"text": None}])[0] + key_name_txt = key_name.get("text") + if "use" not in key or key_use == use: + for dat in key_info["x509_data"]: + cert = repack_cert(dat["x509_certificate"]["text"]) + if cert not in res: + res.append((key_name_txt, cert)) return res diff --git a/src/saml2/sigver.py b/src/saml2/sigver.py index 0bfad44a..8c7a3f4c 100644 --- a/src/saml2/sigver.py +++ b/src/saml2/sigver.py @@ -1451,7 +1451,7 @@ class SecurityContext(object): _certs = [] certs = [] - for cert in _certs: + for cert_name, cert in _certs: if isinstance(cert, six.string_types): content = pem_format(cert) tmp = make_temp(content, @@ -1943,7 +1943,7 @@ def pre_encryption_part( *, msg_enc=TRIPLE_DES_CBC, key_enc=RSA_OAEP_MGF1P, - key_name='my-rsa-key', + key_name=None, encrypted_key_id=None, encrypted_data_id=None, encrypt_cert=None, @@ -1958,9 +1958,11 @@ def pre_encryption_part( if encrypt_cert else None ) - key_info = ds.KeyInfo( - key_name=ds.KeyName(text=key_name), - x509_data=x509_data, + key_name = ds.KeyName(text=key_name) if key_name else None + key_info = ( + ds.KeyInfo(key_name=key_name, x509_data=x509_data) + if key_name or x509_data + else None ) encrypted_key = EncryptedKey( diff --git a/tests/test_30_mdstore.py b/tests/test_30_mdstore.py index bf54594e..8ce058b7 100644 --- a/tests/test_30_mdstore.py +++ b/tests/test_30_mdstore.py @@ -99,7 +99,6 @@ TEST_METADATA_STRING = """ """.format(cert_data=TEST_CERT) - ATTRCONV = ac_factory(full_path("attributemaps")) METADATACONF = { @@ -522,10 +521,45 @@ def test_load_string(): def test_get_certs_from_metadata(): mds = MetadataStore(ATTRCONV, None) mds.imp(METADATACONF["11"]) - certs1 = mds.certs("http://xenosmilus.umdc.umu.se/simplesaml/saml2/idp/metadata.php", "any") - certs2 = mds.certs("http://xenosmilus.umdc.umu.se/simplesaml/saml2/idp/metadata.php", "idpsso") - assert certs1[0] == certs2[0] == TEST_CERT + cert_any_name, cert_any = mds.certs( + "http://xenosmilus.umdc.umu.se/simplesaml/saml2/idp/metadata.php", "any" + )[0] + cert_idpsso_name, cert_idpsso = mds.certs( + "http://xenosmilus.umdc.umu.se/simplesaml/saml2/idp/metadata.php", "idpsso" + )[0] + + assert cert_any_name is None + assert cert_idpsso_name is None + + +def test_get_unnamed_certs_from_metadata(): + mds = MetadataStore(ATTRCONV, None) + mds.imp(METADATACONF["11"]) + + cert_any_name, cert_any = mds.certs( + "http://xenosmilus.umdc.umu.se/simplesaml/saml2/idp/metadata.php", "any" + )[0] + cert_idpsso_name, cert_idpsso = mds.certs( + "http://xenosmilus.umdc.umu.se/simplesaml/saml2/idp/metadata.php", "idpsso" + )[0] + + assert cert_any_name is None + assert cert_idpsso_name is None + + +def test_get_named_certs_from_metadata(): + mds = MetadataStore(ATTRCONV, None) + mds.imp(METADATACONF["3"]) + + cert_sign_name, cert_sign = mds.certs( + "https://coip-test.sunet.se/shibboleth", "spsso", "signing" + )[0] + cert_enc_name, cert_enc = mds.certs( + "https://coip-test.sunet.se/shibboleth", "spsso", "encryption" + )[0] + + assert cert_sign_name == cert_enc_name == "coip-test.sunet.se" def test_get_certs_from_metadata_without_keydescriptor(): diff --git a/tests/test_42_enc.py b/tests/test_42_enc.py index c268aa76..54825902 100644 --- a/tests/test_42_enc.py +++ b/tests/test_42_enc.py @@ -12,7 +12,7 @@ from pathutils import full_path __author__ = 'roland' -TMPL_NO_HEADER = """my-rsa-key""" +TMPL_NO_HEADER = """{key_info}""" TMPL = f"\n{TMPL_NO_HEADER}" IDENTITY = {"eduPersonAffiliation": ["staff", "member"], @@ -47,6 +47,7 @@ def test_pre_enc_with_pregenerated_key(): expected = TMPL_NO_HEADER.format( ed_id=tmpl.id, ek_id=tmpl.key_info.encrypted_key.id, + key_info='' ) assert str(tmpl) == expected @@ -56,6 +57,16 @@ def test_pre_enc_with_generated_key(): expected = TMPL_NO_HEADER.format( ed_id=tmpl.id, ek_id=tmpl.key_info.encrypted_key.id, + key_info='' + ) + assert str(tmpl) == expected + +def test_pre_enc_with_named_key(): + tmpl = pre_encryption_part(key_name="my-rsa-key") + expected = TMPL_NO_HEADER.format( + ed_id=tmpl.id, + ek_id=tmpl.key_info.encrypted_key.id, + key_info='my-rsa-key' ) assert str(tmpl) == expected diff --git a/tests/test_70_redirect_signing.py b/tests/test_70_redirect_signing.py index 5286d4c6..26d6c486 100644 --- a/tests/test_70_redirect_signing.py +++ b/tests/test_70_redirect_signing.py @@ -49,7 +49,7 @@ def test(): for cert in _certs: if verify_redirect_signature( list_values2simpletons(_dict), sp.sec.sec_backend, - cert): + cert[1]): verified_ok = True assert verified_ok -- cgit v1.2.1