summaryrefslogtreecommitdiff
path: root/src/saml2
diff options
context:
space:
mode:
Diffstat (limited to 'src/saml2')
-rw-r--r--src/saml2/__init__.py15
-rw-r--r--src/saml2/client_base.py5
-rw-r--r--src/saml2/entity.py20
-rw-r--r--src/saml2/response.py111
-rw-r--r--src/saml2/sigver.py30
5 files changed, 118 insertions, 63 deletions
diff --git a/src/saml2/__init__.py b/src/saml2/__init__.py
index 87bfafa6..4493c9c8 100644
--- a/src/saml2/__init__.py
+++ b/src/saml2/__init__.py
@@ -591,13 +591,14 @@ class SamlBase(ExtensionContainer):
def get_xml_string_with_self_contained_assertion_within_advice_encrypted_assertion(self, assertion_tag, advice_tag):
for tmp_encrypted_assertion in self.assertion.advice.encrypted_assertion:
- prefix_map = self.get_prefix_map([tmp_encrypted_assertion._to_element_tree().
- find(assertion_tag)])
-
- tree = self._to_element_tree()
-
- self.set_prefixes(tree.find(assertion_tag).find(advice_tag).find(tmp_encrypted_assertion._to_element_tree()
- .tag).find(assertion_tag), prefix_map)
+ if tmp_encrypted_assertion.encrypted_data is None:
+ prefix_map = self.get_prefix_map([tmp_encrypted_assertion._to_element_tree().find(assertion_tag)])
+ tree = self._to_element_tree()
+ encs = tree.find(assertion_tag).find(advice_tag).findall(tmp_encrypted_assertion._to_element_tree().tag)
+ for enc in encs:
+ assertion = enc.find(assertion_tag)
+ if assertion is not None:
+ self.set_prefixes(assertion, prefix_map)
return ElementTree.tostring(tree, encoding="UTF-8")
diff --git a/src/saml2/client_base.py b/src/saml2/client_base.py
index 3f7a227b..da35ae9c 100644
--- a/src/saml2/client_base.py
+++ b/src/saml2/client_base.py
@@ -542,7 +542,7 @@ class Base(Entity):
# ======== response handling ===========
def parse_authn_request_response(self, xmlstr, binding, outstanding=None,
- outstanding_certs=None, decrypt=True, pefim=False):
+ outstanding_certs=None):
""" Deal with an AuthnResponse
:param xmlstr: The reply as a xml string
@@ -573,12 +573,11 @@ class Base(Entity):
"attribute_converters": self.config.attribute_converters,
"allow_unknown_attributes":
self.config.allow_unknown_attributes,
- "decrypt": decrypt
}
try:
resp = self._parse_response(xmlstr, AuthnResponse,
"assertion_consumer_service",
- binding, pefim=pefim, **kwargs)
+ binding, **kwargs)
except StatusError as err:
logger.error("SAML status error: %s" % err)
raise
diff --git a/src/saml2/entity.py b/src/saml2/entity.py
index e7a75a5a..e7ef879c 100644
--- a/src/saml2/entity.py
+++ b/src/saml2/entity.py
@@ -978,7 +978,7 @@ class Entity(HTTPBase):
# ------------------------------------------------------------------------
def _parse_response(self, xmlstr, response_cls, service, binding,
- outstanding_certs=None, pefim=False, **kwargs):
+ outstanding_certs=None, **kwargs):
""" Deal with a Response
:param xmlstr: The response as a xml string
@@ -1040,23 +1040,23 @@ class Entity(HTTPBase):
logger.debug("XMLSTR: %s" % xmlstr)
if response:
+ keys = None
if outstanding_certs:
try:
cert = outstanding_certs[response.in_response_to]
except KeyError:
- key_file = ""
+ keys = None
else:
- _, key_file = make_temp("%s" % cert["key"],
- decode=False)
- else:
- key_file = ""
+ if not isinstance(cert, list):
+ cert = [cert]
+ keys = []
+ for _cert in cert:
+ keys.append(_cert["key"])
only_identity_in_encrypted_assertion = False
if "only_identity_in_encrypted_assertion" in kwargs:
only_identity_in_encrypted_assertion = kwargs["only_identity_in_encrypted_assertion"]
- decrypt = True
- if "decrypt" in kwargs:
- decrypt = kwargs["decrypt"]
- response = response.verify(key_file, decrypt=decrypt, pefim=pefim)
+
+ response = response.verify(keys)
if not response:
return None
diff --git a/src/saml2/response.py b/src/saml2/response.py
index d46a0b0f..6f733280 100644
--- a/src/saml2/response.py
+++ b/src/saml2/response.py
@@ -395,7 +395,7 @@ class StatusResponse(object):
def loads(self, xmldata, decode=True, origxml=None):
return self._loads(xmldata, decode, origxml)
- def verify(self, key_file="", decrypt=True, pefim=False):
+ def verify(self, keys=None):
try:
return self._verify()
except AssertionError:
@@ -636,18 +636,19 @@ class AuthnResponse(StatusResponse):
"""
ava = {}
- if self.assertion.advice:
- if self.assertion.advice.assertion:
- for tmp_assertion in self.assertion.advice.assertion:
- if tmp_assertion.attribute_statement:
- assert len(tmp_assertion.attribute_statement) == 1
- ava.update(self.read_attribute_statement(tmp_assertion.attribute_statement[0]))
- if self.assertion.attribute_statement:
- assert len(self.assertion.attribute_statement) == 1
- _attr_statem = self.assertion.attribute_statement[0]
- ava.update(self.read_attribute_statement(_attr_statem))
- if not ava:
- logger.error("Missing Attribute Statement")
+ for _assertion in self.assertions:
+ if _assertion.advice:
+ if _assertion.advice.assertion:
+ for tmp_assertion in _assertion.advice.assertion:
+ if tmp_assertion.attribute_statement:
+ assert len(tmp_assertion.attribute_statement) == 1
+ ava.update(self.read_attribute_statement(tmp_assertion.attribute_statement[0]))
+ if _assertion.attribute_statement:
+ assert len(_assertion.attribute_statement) == 1
+ _attr_statem = _assertion.attribute_statement[0]
+ ava.update(self.read_attribute_statement(_attr_statem))
+ if not ava:
+ logger.error("Missing Attribute Statement")
return ava
def _bearer_confirmed(self, data):
@@ -796,14 +797,14 @@ class AuthnResponse(StatusResponse):
logger.exception("get subject")
raise
- def decrypt_assertions(self, encrypted_assertions, decr_txt, issuer=None):
+ def decrypt_assertions(self, encrypted_assertions, decr_txt, issuer=None, verified=False):
res = []
for encrypted_assertion in encrypted_assertions:
if encrypted_assertion.extension_elements:
assertions = extension_elements_to_elements(
encrypted_assertion.extension_elements, [saml, samlp])
for assertion in assertions:
- if assertion.signature:
+ if assertion.signature and not verified:
if not self.sec.check_signature(
assertion, origdoc=decr_txt,
node_name=class_name(assertion), issuer=issuer):
@@ -812,7 +813,35 @@ class AuthnResponse(StatusResponse):
res.append(assertion)
return res
- def parse_assertion(self, key_file="", decrypt=True, pefim=False):
+ def find_encrypt_data_assertion(self, enc_assertions):
+ for _assertion in enc_assertions:
+ if _assertion.encrypted_data is not None:
+ return True
+
+ def find_encrypt_data_assertion_list(self, _assertions):
+ for _assertion in _assertions:
+ if _assertion.advice:
+ if _assertion.advice.encrypted_assertion:
+ res = self.find_encrypt_data_assertion(_assertion.advice.encrypted_assertion)
+ if res:
+ return True
+
+ def find_encrypt_data(self, resp):
+ _has_encrypt_data = False
+ if resp.encrypted_assertion:
+ res = self.find_encrypt_data_assertion(resp.encrypted_assertion)
+ if res:
+ return True
+ if resp.assertion:
+ for tmp_assertion in resp.assertion:
+ if tmp_assertion.advice:
+ if tmp_assertion.advice.encrypted_assertion:
+ res = self.find_encrypt_data_assertion(tmp_assertion.advice.encrypted_assertion)
+ if res:
+ return True
+ return False
+
+ def parse_assertion(self, keys=None):
if self.context == "AuthnQuery":
# can contain one or more assertions
pass
@@ -823,13 +852,13 @@ class AuthnResponse(StatusResponse):
except AssertionError:
raise Exception("No assertion part")
- has_encrypted_assertions = self.response.encrypted_assertion
- if not has_encrypted_assertions and self.response.assertion:
- for tmp_assertion in self.response.assertion:
- if tmp_assertion.advice:
- if tmp_assertion.advice.encrypted_assertion:
- has_encrypted_assertions = True
- break
+ has_encrypted_assertions = self.find_encrypt_data(self.response) #self.response.encrypted_assertion
+ #if not has_encrypted_assertions and self.response.assertion:
+ # for tmp_assertion in self.response.assertion:
+ # if tmp_assertion.advice:
+ # if tmp_assertion.advice.encrypted_assertion:
+ # has_encrypted_assertions = True
+ # break
if self.response.assertion:
logger.debug("***Unencrypted assertion***")
@@ -837,16 +866,25 @@ class AuthnResponse(StatusResponse):
if not self._assertion(assertion, False):
return False
- if has_encrypted_assertions and decrypt:
+ if has_encrypted_assertions:
_enc_assertions = []
logger.debug("***Encrypted assertion/-s***")
- decr_text = self.sec.decrypt(self.xmlstr, key_file)
- resp = samlp.response_from_string(decr_text)
+ decr_text = "%s" % self.response
+ resp = self.response
+ while self.find_encrypt_data(resp):
+ decr_text = self.sec.decrypt_keys(decr_text, keys)
+ resp = samlp.response_from_string(decr_text)
_enc_assertions = self.decrypt_assertions(resp.encrypted_assertion, decr_text)
- decr_text = self.sec.decrypt(decr_text, key_file)
- resp = samlp.response_from_string(decr_text)
+ while self.find_encrypt_data(resp) or self.find_encrypt_data_assertion_list(_enc_assertions):
+ decr_text = self.sec.decrypt_keys(decr_text, keys)
+ resp = samlp.response_from_string(decr_text)
+ _enc_assertions = self.decrypt_assertions(resp.encrypted_assertion, decr_text, verified=True)
+ #_enc_assertions = self.decrypt_assertions(resp.encrypted_assertion, decr_text, verified=True)
+ all_assertions = _enc_assertions
if resp.assertion:
- for tmp_ass in resp.assertion:
+ all_assertions = all_assertions + resp.assertion
+ if len(all_assertions) > 0:
+ for tmp_ass in all_assertions:
if tmp_ass.advice and tmp_ass.advice.encrypted_assertion:
advice_res = self.decrypt_assertions(tmp_ass.advice.encrypted_assertion,
decr_text,
@@ -855,21 +893,20 @@ class AuthnResponse(StatusResponse):
tmp_ass.advice.assertion.extend(advice_res)
else:
tmp_ass.advice.assertion = advice_res
- if not pefim:
- _enc_assertions.extend(advice_res)
tmp_ass.advice.encrypted_assertion = []
self.response.assertion = resp.assertion
for assertion in _enc_assertions:
if not self._assertion(assertion, True):
return False
+ else:
+ self.assertions.append(assertion)
+
self.xmlstr = decr_text
self.response.encrypted_assertion = []
if self.response.assertion:
for assertion in self.response.assertion:
- if assertion.advice and assertion.advice.assertion:
- for advice_assertion in assertion.advice.assertion:
- self.assertions.append(assertion)
+ self.assertions.append(assertion)
if self.assertions and len(self.assertions) > 0:
self.assertion = self.assertions[0]
@@ -880,7 +917,7 @@ class AuthnResponse(StatusResponse):
return True
- def verify(self, key_file="", decrypt=True, pefim=False):
+ def verify(self, keys=None):
""" Verify that the assertion is syntactically correct and
the signature is correct if present.
:param key_file: If not the default key file should be used this is it.
@@ -898,7 +935,7 @@ class AuthnResponse(StatusResponse):
if not isinstance(self.response, samlp.Response):
return self
- if self.parse_assertion(key_file, decrypt=decrypt, pefim=pefim):
+ if self.parse_assertion(keys):
return self
else:
logger.error("Could not parse the assertion")
@@ -1126,7 +1163,7 @@ class AssertionIDResponse(object):
return self._postamble()
- def verify(self, key_file="", decrypt=True, pefim=False):
+ def verify(self, keys=None):
try:
valid_instance(self.response)
except NotValid as exc:
diff --git a/src/saml2/sigver.py b/src/saml2/sigver.py
index 096134e2..ca928686 100644
--- a/src/saml2/sigver.py
+++ b/src/saml2/sigver.py
@@ -769,7 +769,7 @@ class CryptoBackendXmlSec1(CryptoBackend):
return output
def encrypt_assertion(self, statement, enc_key, template,
- key_type="des-192", node_xpath=None):
+ key_type="des-192", node_xpath=None, node_id=None):
"""
Will encrypt an assertion
@@ -792,6 +792,8 @@ class CryptoBackendXmlSec1(CryptoBackend):
com_list = [self.xmlsec, "encrypt", "--pubkey-cert-pem", enc_key,
"--session-key", key_type, "--xml-data", fil,
"--node-xpath", node_xpath]
+ if node_id:
+ com_list.extend(["--node-id", node_id])
(_stdout, _stderr, output) = self._run_xmlsec(
com_list, [tmpl], exception=EncryptError, validate_output=False)
@@ -1300,22 +1302,38 @@ class SecurityContext(object):
"""
raise NotImplemented()
- def decrypt(self, enctext, key_file=None):
+ def decrypt_keys(self, enctext, keys=None):
""" Decrypting an encrypted text by the use of a private key.
:param enctext: The encrypted text as a string
:return: The decrypted text
"""
+ if not isinstance(keys, list):
+ keys = [keys]
_enctext = self.crypto.decrypt(enctext, self.key_file)
if _enctext is not None and len(_enctext) > 0:
return _enctext
- if key_file is not None and len(key_file.strip()) > 0:
- _enctext = self.crypto.decrypt(enctext, key_file)
- if _enctext is not None and len(_enctext) > 0:
- return _enctext
+ for _key in keys:
+ if _key is not None and len(_key.strip()) > 0:
+ _, key_file = make_temp("%s" % _key, decode=False)
+ _enctext = self.crypto.decrypt(enctext, key_file)
+ if _enctext is not None and len(_enctext) > 0:
+ return _enctext
+ return enctext
+
+ def decrypt(self, enctext, key_file=None):
+ """ Decrypting an encrypted text by the use of a private key.
+
+ :param enctext: The encrypted text as a string
+ :return: The decrypted text
+ """
_enctext = self.crypto.decrypt(enctext, self.key_file)
if _enctext is not None and len(_enctext) > 0:
return _enctext
+ if key_file is not None and len(key_file.strip()) > 0:
+ _enctext = self.crypto.decrypt(enctext, key_file)
+ if _enctext is not None and len(_enctext) > 0:
+ return _enctext
return enctext
def verify_signature(self, signedtext, cert_file=None, cert_type="pem",