diff options
Diffstat (limited to 'src/saml2/response.py')
-rw-r--r-- | src/saml2/response.py | 106 |
1 files changed, 65 insertions, 41 deletions
diff --git a/src/saml2/response.py b/src/saml2/response.py index 810266cf..a3487155 100644 --- a/src/saml2/response.py +++ b/src/saml2/response.py @@ -58,6 +58,7 @@ from saml2.validate import NotValid logger = logging.getLogger(__name__) + # --------------------------------------------------------------------------- @@ -160,9 +161,11 @@ class StatusUnknownPrincipal(StatusError): class StatusUnsupportedBinding(StatusError): pass + class StatusResponder(StatusError): pass + STATUSCODE2EXCEPTION = { STATUS_VERSION_MISMATCH: StatusVersionMismatch, STATUS_AUTHN_FAILED: StatusAuthnFailed, @@ -186,6 +189,8 @@ STATUSCODE2EXCEPTION = { STATUS_UNSUPPORTED_BINDING: StatusUnsupportedBinding, STATUS_RESPONDER: StatusResponder, } + + # --------------------------------------------------------------------------- @@ -206,7 +211,8 @@ def for_me(conditions, myself): if audience.text.strip() == myself: return True else: - #print("Not for me: %s != %s" % (audience.text.strip(), myself)) + # print("Not for me: %s != %s" % (audience.text.strip(), + # myself)) pass return False @@ -336,7 +342,7 @@ class StatusResponse(object): logger.exception("EXCEPTION: %s", excp) raise - #print("<", self.response) + # print("<", self.response) return self._postamble() @@ -377,7 +383,7 @@ class StatusResponse(object): if self.request_id and self.in_response_to and \ self.in_response_to != self.request_id: logger.error("Not the id I expected: %s != %s", - self.in_response_to, self.request_id) + self.in_response_to, self.request_id) return None try: @@ -391,9 +397,9 @@ class StatusResponse(object): if self.asynchop: if self.response.destination and \ - self.response.destination not in self.return_addrs: + self.response.destination not in self.return_addrs: logger.error("%s not in %s", self.response.destination, - self.return_addrs) + self.return_addrs) return None assert self.issue_instant_ok() @@ -436,7 +442,7 @@ class NameIDMappingResponse(StatusResponse): request_id=0, asynchop=True): StatusResponse.__init__(self, sec_context, return_addrs, timeslack, request_id, asynchop) - self.signature_check = self.sec\ + self.signature_check = self.sec \ .correctly_signed_name_id_mapping_response @@ -506,7 +512,7 @@ class AuthnResponse(StatusResponse): if self.asynchop: if self.in_response_to in self.outstanding_queries: self.came_from = self.outstanding_queries[self.in_response_to] - #del self.outstanding_queries[self.in_response_to] + # del self.outstanding_queries[self.in_response_to] try: if not self.check_subject_confirmation_in_response_to( self.in_response_to): @@ -632,12 +638,12 @@ class AuthnResponse(StatusResponse): def read_attribute_statement(self, attr_statem): logger.debug("Attribute Statement: %s", attr_statem) - for aconv in self.attribute_converters: - logger.debug("Converts name format: %s", aconv.name_format) + # for aconv in self.attribute_converters: + # logger.debug("Converts name format: %s", aconv.name_format) self.decrypt_attributes(attr_statem) return to_local(self.attribute_converters, attr_statem, - self.allow_unknown_attributes) + self.allow_unknown_attributes) def get_identity(self): """ The assertion can contain zero or one attributeStatements @@ -650,7 +656,8 @@ class AuthnResponse(StatusResponse): for tmp_assertion in _assertion.advice.assertion: if tmp_assertion.attribute_statement: assert len(tmp_assertion.attribute_statement) == 1 - ava.update(self.read_attribute_statement(tmp_assertion.attribute_statement[0])) + ava.update(self.read_attribute_statement( + tmp_assertion.attribute_statement[0])) if _assertion.attribute_statement: assert len(_assertion.attribute_statement) == 1 _attr_statem = _assertion.attribute_statement[0] @@ -681,7 +688,7 @@ class AuthnResponse(StatusResponse): if data.in_response_to in self.outstanding_queries: self.came_from = self.outstanding_queries[ data.in_response_to] - #del self.outstanding_queries[data.in_response_to] + # del self.outstanding_queries[data.in_response_to] elif self.allow_unsolicited: pass else: @@ -690,7 +697,7 @@ class AuthnResponse(StatusResponse): # recognize logger.debug("in response to: '%s'", data.in_response_to) logger.info("outstanding queries: %s", - self.outstanding_queries.keys()) + self.outstanding_queries.keys()) raise Exception( "Combination of session id and requestURI I don't " "recall") @@ -768,7 +775,8 @@ class AuthnResponse(StatusResponse): logger.debug("signed") if not verified and self.do_not_verify is False: try: - self.sec.check_signature(assertion, class_name(assertion),self.xmlstr) + self.sec.check_signature(assertion, class_name(assertion), + self.xmlstr) except Exception as exc: logger.error("correctly_signed_response: %s", exc) raise @@ -778,10 +786,10 @@ class AuthnResponse(StatusResponse): logger.debug("assertion keys: %s", assertion.keyswv()) logger.debug("outstanding_queries: %s", self.outstanding_queries) - #if self.context == "AuthnReq" or self.context == "AttrQuery": + # if self.context == "AuthnReq" or self.context == "AttrQuery": if self.context == "AuthnReq": self.authn_statement_ok() - # elif self.context == "AttrQuery": + # elif self.context == "AttrQuery": # self.authn_statement_ok(True) if not self.condition_ok(): @@ -789,7 +797,7 @@ class AuthnResponse(StatusResponse): logger.debug("--- Getting Identity ---") - #if self.context == "AuthnReq" or self.context == "AttrQuery": + # if self.context == "AuthnReq" or self.context == "AttrQuery": # self.ava = self.get_identity() # logger.debug("--- AVA: %s", self.ava) @@ -805,13 +813,17 @@ class AuthnResponse(StatusResponse): logger.exception("get subject") raise - def decrypt_assertions(self, encrypted_assertions, decr_txt, issuer=None, verified=False): - """ Moves the decrypted assertion from the encrypted assertion to a list. + def decrypt_assertions(self, encrypted_assertions, decr_txt, issuer=None, + verified=False): + """ Moves the decrypted assertion from the encrypted assertion to a + list. :param encrypted_assertions: A list of encrypted assertions. - :param decr_txt: The string representation containing the decrypted data. Used when verifying signatures. + :param decr_txt: The string representation containing the decrypted + data. Used when verifying signatures. :param issuer: The issuer of the response. - :param verified: If True do not verify signatures, otherwise verify the signature if it exists. + :param verified: If True do not verify signatures, otherwise verify + the signature if it exists. :return: A list of decrypted assertions. """ res = [] @@ -824,7 +836,8 @@ class AuthnResponse(StatusResponse): if not self.sec.check_signature( assertion, origdoc=decr_txt, node_name=class_name(assertion), issuer=issuer): - logger.error("Failed to verify signature on '%s'", assertion) + logger.error("Failed to verify signature on '%s'", + assertion) raise SignatureError() res.append(assertion) return res @@ -836,11 +849,12 @@ class AuthnResponse(StatusResponse): :return: True encrypted data exists otherwise false. """ for _assertion in enc_assertions: - if _assertion.encrypted_data is not None: - return True + if _assertion.encrypted_data is not None: + return True def find_encrypt_data_assertion_list(self, _assertions): - """ Verifies if a list of assertions contains encrypted data in the advice element. + """ Verifies if a list of assertions contains encrypted data in the + advice element. :param _assertions: A list of assertions. :return: True encrypted data exists otherwise false. @@ -848,12 +862,14 @@ class AuthnResponse(StatusResponse): for _assertion in _assertions: if _assertion.advice: if _assertion.advice.encrypted_assertion: - res = self.find_encrypt_data_assertion(_assertion.advice.encrypted_assertion) + res = self.find_encrypt_data_assertion( + _assertion.advice.encrypted_assertion) if res: return True def find_encrypt_data(self, resp): - """ Verifies if a saml response contains encrypted assertions with encrypted data. + """ Verifies if a saml response contains encrypted assertions with + encrypted data. :param resp: A saml response. :return: True encrypted data exists otherwise false. @@ -867,7 +883,8 @@ class AuthnResponse(StatusResponse): for tmp_assertion in resp.assertion: if tmp_assertion.advice: if tmp_assertion.advice.encrypted_assertion: - res = self.find_encrypt_data_assertion(tmp_assertion.advice.encrypted_assertion) + res = self.find_encrypt_data_assertion( + tmp_assertion.advice.encrypted_assertion) if res: return True return False @@ -875,7 +892,8 @@ class AuthnResponse(StatusResponse): def parse_assertion(self, keys=None): """ Parse the assertions for a saml response. - :param keys: A string representing a RSA key or a list of strings containing RSA keys. + :param keys: A string representing a RSA key or a list of strings + containing RSA keys. :return: True if the assertions are parsed otherwise False. """ if self.context == "AuthnQuery": @@ -884,12 +902,13 @@ class AuthnResponse(StatusResponse): else: # This is a saml2int limitation try: assert len(self.response.assertion) == 1 or \ - len(self.response.encrypted_assertion) == 1 + len(self.response.encrypted_assertion) == 1 except AssertionError: raise Exception("No assertion part") - has_encrypted_assertions = self.find_encrypt_data(self.response) #self.response.encrypted_assertion - #if not has_encrypted_assertions and self.response.assertion: + has_encrypted_assertions = self.find_encrypt_data(self.response) # + # self.response.encrypted_assertion + # if not has_encrypted_assertions and self.response.assertion: # for tmp_assertion in self.response.assertion: # if tmp_assertion.advice: # if tmp_assertion.advice.encrypted_assertion: @@ -912,15 +931,20 @@ class AuthnResponse(StatusResponse): decr_text_old = decr_text decr_text = self.sec.decrypt_keys(decr_text, keys) resp = samlp.response_from_string(decr_text) - _enc_assertions = self.decrypt_assertions(resp.encrypted_assertion, decr_text) + _enc_assertions = self.decrypt_assertions(resp.encrypted_assertion, + decr_text) decr_text_old = None - while (self.find_encrypt_data(resp) or self.find_encrypt_data_assertion_list(_enc_assertions)) and \ + while (self.find_encrypt_data( + resp) or self.find_encrypt_data_assertion_list( + _enc_assertions)) and \ decr_text_old != decr_text: decr_text_old = decr_text decr_text = self.sec.decrypt_keys(decr_text, keys) resp = samlp.response_from_string(decr_text) - _enc_assertions = self.decrypt_assertions(resp.encrypted_assertion, decr_text, verified=True) - #_enc_assertions = self.decrypt_assertions(resp.encrypted_assertion, decr_text, verified=True) + _enc_assertions = self.decrypt_assertions( + resp.encrypted_assertion, decr_text, verified=True) + # _enc_assertions = self.decrypt_assertions( + # resp.encrypted_assertion, decr_text, verified=True) all_assertions = _enc_assertions if resp.assertion: all_assertions = all_assertions + resp.assertion @@ -928,9 +952,10 @@ class AuthnResponse(StatusResponse): for tmp_ass in all_assertions: if tmp_ass.advice and tmp_ass.advice.encrypted_assertion: - advice_res = self.decrypt_assertions(tmp_ass.advice.encrypted_assertion, - decr_text, - tmp_ass.issuer) + advice_res = self.decrypt_assertions( + tmp_ass.advice.encrypted_assertion, + decr_text, + tmp_ass.issuer) if tmp_ass.advice.assertion: tmp_ass.advice.assertion.extend(advice_res) else: @@ -1211,7 +1236,7 @@ class AssertionIDResponse(object): logger.exception("EXCEPTION: %s", excp) raise - #print("<", self.response) + # print("<", self.response) return self._postamble() @@ -1233,4 +1258,3 @@ class AssertionIDResponse(object): logger.debug("response: %s", self.response) return self - |