diff options
Diffstat (limited to 'src/saml2')
-rw-r--r-- | src/saml2/__init__.py | 15 | ||||
-rw-r--r-- | src/saml2/client_base.py | 5 | ||||
-rw-r--r-- | src/saml2/entity.py | 20 | ||||
-rw-r--r-- | src/saml2/response.py | 111 | ||||
-rw-r--r-- | src/saml2/sigver.py | 30 |
5 files changed, 118 insertions, 63 deletions
diff --git a/src/saml2/__init__.py b/src/saml2/__init__.py index 87bfafa6..4493c9c8 100644 --- a/src/saml2/__init__.py +++ b/src/saml2/__init__.py @@ -591,13 +591,14 @@ class SamlBase(ExtensionContainer): def get_xml_string_with_self_contained_assertion_within_advice_encrypted_assertion(self, assertion_tag, advice_tag): for tmp_encrypted_assertion in self.assertion.advice.encrypted_assertion: - prefix_map = self.get_prefix_map([tmp_encrypted_assertion._to_element_tree(). - find(assertion_tag)]) - - tree = self._to_element_tree() - - self.set_prefixes(tree.find(assertion_tag).find(advice_tag).find(tmp_encrypted_assertion._to_element_tree() - .tag).find(assertion_tag), prefix_map) + if tmp_encrypted_assertion.encrypted_data is None: + prefix_map = self.get_prefix_map([tmp_encrypted_assertion._to_element_tree().find(assertion_tag)]) + tree = self._to_element_tree() + encs = tree.find(assertion_tag).find(advice_tag).findall(tmp_encrypted_assertion._to_element_tree().tag) + for enc in encs: + assertion = enc.find(assertion_tag) + if assertion is not None: + self.set_prefixes(assertion, prefix_map) return ElementTree.tostring(tree, encoding="UTF-8") diff --git a/src/saml2/client_base.py b/src/saml2/client_base.py index 3f7a227b..da35ae9c 100644 --- a/src/saml2/client_base.py +++ b/src/saml2/client_base.py @@ -542,7 +542,7 @@ class Base(Entity): # ======== response handling =========== def parse_authn_request_response(self, xmlstr, binding, outstanding=None, - outstanding_certs=None, decrypt=True, pefim=False): + outstanding_certs=None): """ Deal with an AuthnResponse :param xmlstr: The reply as a xml string @@ -573,12 +573,11 @@ class Base(Entity): "attribute_converters": self.config.attribute_converters, "allow_unknown_attributes": self.config.allow_unknown_attributes, - "decrypt": decrypt } try: resp = self._parse_response(xmlstr, AuthnResponse, "assertion_consumer_service", - binding, pefim=pefim, **kwargs) + binding, **kwargs) except StatusError as err: logger.error("SAML status error: %s" % err) raise diff --git a/src/saml2/entity.py b/src/saml2/entity.py index e7a75a5a..e7ef879c 100644 --- a/src/saml2/entity.py +++ b/src/saml2/entity.py @@ -978,7 +978,7 @@ class Entity(HTTPBase): # ------------------------------------------------------------------------ def _parse_response(self, xmlstr, response_cls, service, binding, - outstanding_certs=None, pefim=False, **kwargs): + outstanding_certs=None, **kwargs): """ Deal with a Response :param xmlstr: The response as a xml string @@ -1040,23 +1040,23 @@ class Entity(HTTPBase): logger.debug("XMLSTR: %s" % xmlstr) if response: + keys = None if outstanding_certs: try: cert = outstanding_certs[response.in_response_to] except KeyError: - key_file = "" + keys = None else: - _, key_file = make_temp("%s" % cert["key"], - decode=False) - else: - key_file = "" + if not isinstance(cert, list): + cert = [cert] + keys = [] + for _cert in cert: + keys.append(_cert["key"]) only_identity_in_encrypted_assertion = False if "only_identity_in_encrypted_assertion" in kwargs: only_identity_in_encrypted_assertion = kwargs["only_identity_in_encrypted_assertion"] - decrypt = True - if "decrypt" in kwargs: - decrypt = kwargs["decrypt"] - response = response.verify(key_file, decrypt=decrypt, pefim=pefim) + + response = response.verify(keys) if not response: return None diff --git a/src/saml2/response.py b/src/saml2/response.py index d46a0b0f..6f733280 100644 --- a/src/saml2/response.py +++ b/src/saml2/response.py @@ -395,7 +395,7 @@ class StatusResponse(object): def loads(self, xmldata, decode=True, origxml=None): return self._loads(xmldata, decode, origxml) - def verify(self, key_file="", decrypt=True, pefim=False): + def verify(self, keys=None): try: return self._verify() except AssertionError: @@ -636,18 +636,19 @@ class AuthnResponse(StatusResponse): """ ava = {} - if self.assertion.advice: - if self.assertion.advice.assertion: - for tmp_assertion in self.assertion.advice.assertion: - if tmp_assertion.attribute_statement: - assert len(tmp_assertion.attribute_statement) == 1 - ava.update(self.read_attribute_statement(tmp_assertion.attribute_statement[0])) - if self.assertion.attribute_statement: - assert len(self.assertion.attribute_statement) == 1 - _attr_statem = self.assertion.attribute_statement[0] - ava.update(self.read_attribute_statement(_attr_statem)) - if not ava: - logger.error("Missing Attribute Statement") + for _assertion in self.assertions: + if _assertion.advice: + if _assertion.advice.assertion: + for tmp_assertion in _assertion.advice.assertion: + if tmp_assertion.attribute_statement: + assert len(tmp_assertion.attribute_statement) == 1 + ava.update(self.read_attribute_statement(tmp_assertion.attribute_statement[0])) + if _assertion.attribute_statement: + assert len(_assertion.attribute_statement) == 1 + _attr_statem = _assertion.attribute_statement[0] + ava.update(self.read_attribute_statement(_attr_statem)) + if not ava: + logger.error("Missing Attribute Statement") return ava def _bearer_confirmed(self, data): @@ -796,14 +797,14 @@ class AuthnResponse(StatusResponse): logger.exception("get subject") raise - def decrypt_assertions(self, encrypted_assertions, decr_txt, issuer=None): + def decrypt_assertions(self, encrypted_assertions, decr_txt, issuer=None, verified=False): res = [] for encrypted_assertion in encrypted_assertions: if encrypted_assertion.extension_elements: assertions = extension_elements_to_elements( encrypted_assertion.extension_elements, [saml, samlp]) for assertion in assertions: - if assertion.signature: + if assertion.signature and not verified: if not self.sec.check_signature( assertion, origdoc=decr_txt, node_name=class_name(assertion), issuer=issuer): @@ -812,7 +813,35 @@ class AuthnResponse(StatusResponse): res.append(assertion) return res - def parse_assertion(self, key_file="", decrypt=True, pefim=False): + def find_encrypt_data_assertion(self, enc_assertions): + for _assertion in enc_assertions: + if _assertion.encrypted_data is not None: + return True + + def find_encrypt_data_assertion_list(self, _assertions): + for _assertion in _assertions: + if _assertion.advice: + if _assertion.advice.encrypted_assertion: + res = self.find_encrypt_data_assertion(_assertion.advice.encrypted_assertion) + if res: + return True + + def find_encrypt_data(self, resp): + _has_encrypt_data = False + if resp.encrypted_assertion: + res = self.find_encrypt_data_assertion(resp.encrypted_assertion) + if res: + return True + if resp.assertion: + for tmp_assertion in resp.assertion: + if tmp_assertion.advice: + if tmp_assertion.advice.encrypted_assertion: + res = self.find_encrypt_data_assertion(tmp_assertion.advice.encrypted_assertion) + if res: + return True + return False + + def parse_assertion(self, keys=None): if self.context == "AuthnQuery": # can contain one or more assertions pass @@ -823,13 +852,13 @@ class AuthnResponse(StatusResponse): except AssertionError: raise Exception("No assertion part") - has_encrypted_assertions = self.response.encrypted_assertion - if not has_encrypted_assertions and self.response.assertion: - for tmp_assertion in self.response.assertion: - if tmp_assertion.advice: - if tmp_assertion.advice.encrypted_assertion: - has_encrypted_assertions = True - break + has_encrypted_assertions = self.find_encrypt_data(self.response) #self.response.encrypted_assertion + #if not has_encrypted_assertions and self.response.assertion: + # for tmp_assertion in self.response.assertion: + # if tmp_assertion.advice: + # if tmp_assertion.advice.encrypted_assertion: + # has_encrypted_assertions = True + # break if self.response.assertion: logger.debug("***Unencrypted assertion***") @@ -837,16 +866,25 @@ class AuthnResponse(StatusResponse): if not self._assertion(assertion, False): return False - if has_encrypted_assertions and decrypt: + if has_encrypted_assertions: _enc_assertions = [] logger.debug("***Encrypted assertion/-s***") - decr_text = self.sec.decrypt(self.xmlstr, key_file) - resp = samlp.response_from_string(decr_text) + decr_text = "%s" % self.response + resp = self.response + while self.find_encrypt_data(resp): + decr_text = self.sec.decrypt_keys(decr_text, keys) + resp = samlp.response_from_string(decr_text) _enc_assertions = self.decrypt_assertions(resp.encrypted_assertion, decr_text) - decr_text = self.sec.decrypt(decr_text, key_file) - resp = samlp.response_from_string(decr_text) + while self.find_encrypt_data(resp) or self.find_encrypt_data_assertion_list(_enc_assertions): + decr_text = self.sec.decrypt_keys(decr_text, keys) + resp = samlp.response_from_string(decr_text) + _enc_assertions = self.decrypt_assertions(resp.encrypted_assertion, decr_text, verified=True) + #_enc_assertions = self.decrypt_assertions(resp.encrypted_assertion, decr_text, verified=True) + all_assertions = _enc_assertions if resp.assertion: - for tmp_ass in resp.assertion: + all_assertions = all_assertions + resp.assertion + if len(all_assertions) > 0: + for tmp_ass in all_assertions: if tmp_ass.advice and tmp_ass.advice.encrypted_assertion: advice_res = self.decrypt_assertions(tmp_ass.advice.encrypted_assertion, decr_text, @@ -855,21 +893,20 @@ class AuthnResponse(StatusResponse): tmp_ass.advice.assertion.extend(advice_res) else: tmp_ass.advice.assertion = advice_res - if not pefim: - _enc_assertions.extend(advice_res) tmp_ass.advice.encrypted_assertion = [] self.response.assertion = resp.assertion for assertion in _enc_assertions: if not self._assertion(assertion, True): return False + else: + self.assertions.append(assertion) + self.xmlstr = decr_text self.response.encrypted_assertion = [] if self.response.assertion: for assertion in self.response.assertion: - if assertion.advice and assertion.advice.assertion: - for advice_assertion in assertion.advice.assertion: - self.assertions.append(assertion) + self.assertions.append(assertion) if self.assertions and len(self.assertions) > 0: self.assertion = self.assertions[0] @@ -880,7 +917,7 @@ class AuthnResponse(StatusResponse): return True - def verify(self, key_file="", decrypt=True, pefim=False): + def verify(self, keys=None): """ Verify that the assertion is syntactically correct and the signature is correct if present. :param key_file: If not the default key file should be used this is it. @@ -898,7 +935,7 @@ class AuthnResponse(StatusResponse): if not isinstance(self.response, samlp.Response): return self - if self.parse_assertion(key_file, decrypt=decrypt, pefim=pefim): + if self.parse_assertion(keys): return self else: logger.error("Could not parse the assertion") @@ -1126,7 +1163,7 @@ class AssertionIDResponse(object): return self._postamble() - def verify(self, key_file="", decrypt=True, pefim=False): + def verify(self, keys=None): try: valid_instance(self.response) except NotValid as exc: diff --git a/src/saml2/sigver.py b/src/saml2/sigver.py index 096134e2..ca928686 100644 --- a/src/saml2/sigver.py +++ b/src/saml2/sigver.py @@ -769,7 +769,7 @@ class CryptoBackendXmlSec1(CryptoBackend): return output def encrypt_assertion(self, statement, enc_key, template, - key_type="des-192", node_xpath=None): + key_type="des-192", node_xpath=None, node_id=None): """ Will encrypt an assertion @@ -792,6 +792,8 @@ class CryptoBackendXmlSec1(CryptoBackend): com_list = [self.xmlsec, "encrypt", "--pubkey-cert-pem", enc_key, "--session-key", key_type, "--xml-data", fil, "--node-xpath", node_xpath] + if node_id: + com_list.extend(["--node-id", node_id]) (_stdout, _stderr, output) = self._run_xmlsec( com_list, [tmpl], exception=EncryptError, validate_output=False) @@ -1300,22 +1302,38 @@ class SecurityContext(object): """ raise NotImplemented() - def decrypt(self, enctext, key_file=None): + def decrypt_keys(self, enctext, keys=None): """ Decrypting an encrypted text by the use of a private key. :param enctext: The encrypted text as a string :return: The decrypted text """ + if not isinstance(keys, list): + keys = [keys] _enctext = self.crypto.decrypt(enctext, self.key_file) if _enctext is not None and len(_enctext) > 0: return _enctext - if key_file is not None and len(key_file.strip()) > 0: - _enctext = self.crypto.decrypt(enctext, key_file) - if _enctext is not None and len(_enctext) > 0: - return _enctext + for _key in keys: + if _key is not None and len(_key.strip()) > 0: + _, key_file = make_temp("%s" % _key, decode=False) + _enctext = self.crypto.decrypt(enctext, key_file) + if _enctext is not None and len(_enctext) > 0: + return _enctext + return enctext + + def decrypt(self, enctext, key_file=None): + """ Decrypting an encrypted text by the use of a private key. + + :param enctext: The encrypted text as a string + :return: The decrypted text + """ _enctext = self.crypto.decrypt(enctext, self.key_file) if _enctext is not None and len(_enctext) > 0: return _enctext + if key_file is not None and len(key_file.strip()) > 0: + _enctext = self.crypto.decrypt(enctext, key_file) + if _enctext is not None and len(_enctext) > 0: + return _enctext return enctext def verify_signature(self, signedtext, cert_file=None, cert_type="pem", |