summaryrefslogtreecommitdiff
path: root/src/saml2/response.py
diff options
context:
space:
mode:
authorRoland Hedberg <roland.hedberg@adm.umu.se>2015-11-18 10:38:39 +0100
committerRoland Hedberg <roland.hedberg@adm.umu.se>2015-11-18 10:38:39 +0100
commit2ce425c84ceb2f37e8cf9f2eb383abdbc1a2f087 (patch)
tree7239bd458d27939bb705d6fe68d61ae5b95069c0 /src/saml2/response.py
parent0218b0b064e4d088e7771200b96ecd3454f04154 (diff)
downloadpysaml2-2ce425c84ceb2f37e8cf9f2eb383abdbc1a2f087.tar.gz
Fixed a problem in parsing metadata extensions.
Diffstat (limited to 'src/saml2/response.py')
-rw-r--r--src/saml2/response.py106
1 files changed, 65 insertions, 41 deletions
diff --git a/src/saml2/response.py b/src/saml2/response.py
index 810266cf..a3487155 100644
--- a/src/saml2/response.py
+++ b/src/saml2/response.py
@@ -58,6 +58,7 @@ from saml2.validate import NotValid
logger = logging.getLogger(__name__)
+
# ---------------------------------------------------------------------------
@@ -160,9 +161,11 @@ class StatusUnknownPrincipal(StatusError):
class StatusUnsupportedBinding(StatusError):
pass
+
class StatusResponder(StatusError):
pass
+
STATUSCODE2EXCEPTION = {
STATUS_VERSION_MISMATCH: StatusVersionMismatch,
STATUS_AUTHN_FAILED: StatusAuthnFailed,
@@ -186,6 +189,8 @@ STATUSCODE2EXCEPTION = {
STATUS_UNSUPPORTED_BINDING: StatusUnsupportedBinding,
STATUS_RESPONDER: StatusResponder,
}
+
+
# ---------------------------------------------------------------------------
@@ -206,7 +211,8 @@ def for_me(conditions, myself):
if audience.text.strip() == myself:
return True
else:
- #print("Not for me: %s != %s" % (audience.text.strip(), myself))
+ # print("Not for me: %s != %s" % (audience.text.strip(),
+ # myself))
pass
return False
@@ -336,7 +342,7 @@ class StatusResponse(object):
logger.exception("EXCEPTION: %s", excp)
raise
- #print("<", self.response)
+ # print("<", self.response)
return self._postamble()
@@ -377,7 +383,7 @@ class StatusResponse(object):
if self.request_id and self.in_response_to and \
self.in_response_to != self.request_id:
logger.error("Not the id I expected: %s != %s",
- self.in_response_to, self.request_id)
+ self.in_response_to, self.request_id)
return None
try:
@@ -391,9 +397,9 @@ class StatusResponse(object):
if self.asynchop:
if self.response.destination and \
- self.response.destination not in self.return_addrs:
+ self.response.destination not in self.return_addrs:
logger.error("%s not in %s", self.response.destination,
- self.return_addrs)
+ self.return_addrs)
return None
assert self.issue_instant_ok()
@@ -436,7 +442,7 @@ class NameIDMappingResponse(StatusResponse):
request_id=0, asynchop=True):
StatusResponse.__init__(self, sec_context, return_addrs, timeslack,
request_id, asynchop)
- self.signature_check = self.sec\
+ self.signature_check = self.sec \
.correctly_signed_name_id_mapping_response
@@ -506,7 +512,7 @@ class AuthnResponse(StatusResponse):
if self.asynchop:
if self.in_response_to in self.outstanding_queries:
self.came_from = self.outstanding_queries[self.in_response_to]
- #del self.outstanding_queries[self.in_response_to]
+ # del self.outstanding_queries[self.in_response_to]
try:
if not self.check_subject_confirmation_in_response_to(
self.in_response_to):
@@ -632,12 +638,12 @@ class AuthnResponse(StatusResponse):
def read_attribute_statement(self, attr_statem):
logger.debug("Attribute Statement: %s", attr_statem)
- for aconv in self.attribute_converters:
- logger.debug("Converts name format: %s", aconv.name_format)
+ # for aconv in self.attribute_converters:
+ # logger.debug("Converts name format: %s", aconv.name_format)
self.decrypt_attributes(attr_statem)
return to_local(self.attribute_converters, attr_statem,
- self.allow_unknown_attributes)
+ self.allow_unknown_attributes)
def get_identity(self):
""" The assertion can contain zero or one attributeStatements
@@ -650,7 +656,8 @@ class AuthnResponse(StatusResponse):
for tmp_assertion in _assertion.advice.assertion:
if tmp_assertion.attribute_statement:
assert len(tmp_assertion.attribute_statement) == 1
- ava.update(self.read_attribute_statement(tmp_assertion.attribute_statement[0]))
+ ava.update(self.read_attribute_statement(
+ tmp_assertion.attribute_statement[0]))
if _assertion.attribute_statement:
assert len(_assertion.attribute_statement) == 1
_attr_statem = _assertion.attribute_statement[0]
@@ -681,7 +688,7 @@ class AuthnResponse(StatusResponse):
if data.in_response_to in self.outstanding_queries:
self.came_from = self.outstanding_queries[
data.in_response_to]
- #del self.outstanding_queries[data.in_response_to]
+ # del self.outstanding_queries[data.in_response_to]
elif self.allow_unsolicited:
pass
else:
@@ -690,7 +697,7 @@ class AuthnResponse(StatusResponse):
# recognize
logger.debug("in response to: '%s'", data.in_response_to)
logger.info("outstanding queries: %s",
- self.outstanding_queries.keys())
+ self.outstanding_queries.keys())
raise Exception(
"Combination of session id and requestURI I don't "
"recall")
@@ -768,7 +775,8 @@ class AuthnResponse(StatusResponse):
logger.debug("signed")
if not verified and self.do_not_verify is False:
try:
- self.sec.check_signature(assertion, class_name(assertion),self.xmlstr)
+ self.sec.check_signature(assertion, class_name(assertion),
+ self.xmlstr)
except Exception as exc:
logger.error("correctly_signed_response: %s", exc)
raise
@@ -778,10 +786,10 @@ class AuthnResponse(StatusResponse):
logger.debug("assertion keys: %s", assertion.keyswv())
logger.debug("outstanding_queries: %s", self.outstanding_queries)
- #if self.context == "AuthnReq" or self.context == "AttrQuery":
+ # if self.context == "AuthnReq" or self.context == "AttrQuery":
if self.context == "AuthnReq":
self.authn_statement_ok()
- # elif self.context == "AttrQuery":
+ # elif self.context == "AttrQuery":
# self.authn_statement_ok(True)
if not self.condition_ok():
@@ -789,7 +797,7 @@ class AuthnResponse(StatusResponse):
logger.debug("--- Getting Identity ---")
- #if self.context == "AuthnReq" or self.context == "AttrQuery":
+ # if self.context == "AuthnReq" or self.context == "AttrQuery":
# self.ava = self.get_identity()
# logger.debug("--- AVA: %s", self.ava)
@@ -805,13 +813,17 @@ class AuthnResponse(StatusResponse):
logger.exception("get subject")
raise
- def decrypt_assertions(self, encrypted_assertions, decr_txt, issuer=None, verified=False):
- """ Moves the decrypted assertion from the encrypted assertion to a list.
+ def decrypt_assertions(self, encrypted_assertions, decr_txt, issuer=None,
+ verified=False):
+ """ Moves the decrypted assertion from the encrypted assertion to a
+ list.
:param encrypted_assertions: A list of encrypted assertions.
- :param decr_txt: The string representation containing the decrypted data. Used when verifying signatures.
+ :param decr_txt: The string representation containing the decrypted
+ data. Used when verifying signatures.
:param issuer: The issuer of the response.
- :param verified: If True do not verify signatures, otherwise verify the signature if it exists.
+ :param verified: If True do not verify signatures, otherwise verify
+ the signature if it exists.
:return: A list of decrypted assertions.
"""
res = []
@@ -824,7 +836,8 @@ class AuthnResponse(StatusResponse):
if not self.sec.check_signature(
assertion, origdoc=decr_txt,
node_name=class_name(assertion), issuer=issuer):
- logger.error("Failed to verify signature on '%s'", assertion)
+ logger.error("Failed to verify signature on '%s'",
+ assertion)
raise SignatureError()
res.append(assertion)
return res
@@ -836,11 +849,12 @@ class AuthnResponse(StatusResponse):
:return: True encrypted data exists otherwise false.
"""
for _assertion in enc_assertions:
- if _assertion.encrypted_data is not None:
- return True
+ if _assertion.encrypted_data is not None:
+ return True
def find_encrypt_data_assertion_list(self, _assertions):
- """ Verifies if a list of assertions contains encrypted data in the advice element.
+ """ Verifies if a list of assertions contains encrypted data in the
+ advice element.
:param _assertions: A list of assertions.
:return: True encrypted data exists otherwise false.
@@ -848,12 +862,14 @@ class AuthnResponse(StatusResponse):
for _assertion in _assertions:
if _assertion.advice:
if _assertion.advice.encrypted_assertion:
- res = self.find_encrypt_data_assertion(_assertion.advice.encrypted_assertion)
+ res = self.find_encrypt_data_assertion(
+ _assertion.advice.encrypted_assertion)
if res:
return True
def find_encrypt_data(self, resp):
- """ Verifies if a saml response contains encrypted assertions with encrypted data.
+ """ Verifies if a saml response contains encrypted assertions with
+ encrypted data.
:param resp: A saml response.
:return: True encrypted data exists otherwise false.
@@ -867,7 +883,8 @@ class AuthnResponse(StatusResponse):
for tmp_assertion in resp.assertion:
if tmp_assertion.advice:
if tmp_assertion.advice.encrypted_assertion:
- res = self.find_encrypt_data_assertion(tmp_assertion.advice.encrypted_assertion)
+ res = self.find_encrypt_data_assertion(
+ tmp_assertion.advice.encrypted_assertion)
if res:
return True
return False
@@ -875,7 +892,8 @@ class AuthnResponse(StatusResponse):
def parse_assertion(self, keys=None):
""" Parse the assertions for a saml response.
- :param keys: A string representing a RSA key or a list of strings containing RSA keys.
+ :param keys: A string representing a RSA key or a list of strings
+ containing RSA keys.
:return: True if the assertions are parsed otherwise False.
"""
if self.context == "AuthnQuery":
@@ -884,12 +902,13 @@ class AuthnResponse(StatusResponse):
else: # This is a saml2int limitation
try:
assert len(self.response.assertion) == 1 or \
- len(self.response.encrypted_assertion) == 1
+ len(self.response.encrypted_assertion) == 1
except AssertionError:
raise Exception("No assertion part")
- has_encrypted_assertions = self.find_encrypt_data(self.response) #self.response.encrypted_assertion
- #if not has_encrypted_assertions and self.response.assertion:
+ has_encrypted_assertions = self.find_encrypt_data(self.response) #
+ # self.response.encrypted_assertion
+ # if not has_encrypted_assertions and self.response.assertion:
# for tmp_assertion in self.response.assertion:
# if tmp_assertion.advice:
# if tmp_assertion.advice.encrypted_assertion:
@@ -912,15 +931,20 @@ class AuthnResponse(StatusResponse):
decr_text_old = decr_text
decr_text = self.sec.decrypt_keys(decr_text, keys)
resp = samlp.response_from_string(decr_text)
- _enc_assertions = self.decrypt_assertions(resp.encrypted_assertion, decr_text)
+ _enc_assertions = self.decrypt_assertions(resp.encrypted_assertion,
+ decr_text)
decr_text_old = None
- while (self.find_encrypt_data(resp) or self.find_encrypt_data_assertion_list(_enc_assertions)) and \
+ while (self.find_encrypt_data(
+ resp) or self.find_encrypt_data_assertion_list(
+ _enc_assertions)) and \
decr_text_old != decr_text:
decr_text_old = decr_text
decr_text = self.sec.decrypt_keys(decr_text, keys)
resp = samlp.response_from_string(decr_text)
- _enc_assertions = self.decrypt_assertions(resp.encrypted_assertion, decr_text, verified=True)
- #_enc_assertions = self.decrypt_assertions(resp.encrypted_assertion, decr_text, verified=True)
+ _enc_assertions = self.decrypt_assertions(
+ resp.encrypted_assertion, decr_text, verified=True)
+ # _enc_assertions = self.decrypt_assertions(
+ # resp.encrypted_assertion, decr_text, verified=True)
all_assertions = _enc_assertions
if resp.assertion:
all_assertions = all_assertions + resp.assertion
@@ -928,9 +952,10 @@ class AuthnResponse(StatusResponse):
for tmp_ass in all_assertions:
if tmp_ass.advice and tmp_ass.advice.encrypted_assertion:
- advice_res = self.decrypt_assertions(tmp_ass.advice.encrypted_assertion,
- decr_text,
- tmp_ass.issuer)
+ advice_res = self.decrypt_assertions(
+ tmp_ass.advice.encrypted_assertion,
+ decr_text,
+ tmp_ass.issuer)
if tmp_ass.advice.assertion:
tmp_ass.advice.assertion.extend(advice_res)
else:
@@ -1211,7 +1236,7 @@ class AssertionIDResponse(object):
logger.exception("EXCEPTION: %s", excp)
raise
- #print("<", self.response)
+ # print("<", self.response)
return self._postamble()
@@ -1233,4 +1258,3 @@ class AssertionIDResponse(object):
logger.debug("response: %s", self.response)
return self
-