summaryrefslogtreecommitdiff
path: root/src/saml2
diff options
context:
space:
mode:
authorRoland Hedberg <roland.hedberg@adm.umu.se>2015-11-18 10:38:39 +0100
committerRoland Hedberg <roland.hedberg@adm.umu.se>2015-11-18 10:38:39 +0100
commit2ce425c84ceb2f37e8cf9f2eb383abdbc1a2f087 (patch)
tree7239bd458d27939bb705d6fe68d61ae5b95069c0 /src/saml2
parent0218b0b064e4d088e7771200b96ecd3454f04154 (diff)
downloadpysaml2-2ce425c84ceb2f37e8cf9f2eb383abdbc1a2f087.tar.gz
Fixed a problem in parsing metadata extensions.
Diffstat (limited to 'src/saml2')
-rw-r--r--src/saml2/__init__.py2
-rw-r--r--src/saml2/mdstore.py73
-rw-r--r--src/saml2/response.py106
3 files changed, 113 insertions, 68 deletions
diff --git a/src/saml2/__init__.py b/src/saml2/__init__.py
index 4afd69ca..16555e32 100644
--- a/src/saml2/__init__.py
+++ b/src/saml2/__init__.py
@@ -979,7 +979,7 @@ def extension_elements_to_elements(extension_elements, schemas):
if isinstance(schemas, list):
pass
elif isinstance(schemas, dict):
- schemas = schemas.values()
+ schemas = list(schemas.values())
else:
return res
diff --git a/src/saml2/mdstore.py b/src/saml2/mdstore.py
index 06940d96..5d701428 100644
--- a/src/saml2/mdstore.py
+++ b/src/saml2/mdstore.py
@@ -1,19 +1,20 @@
from __future__ import print_function
+
+import hashlib
import logging
import os
import sys
import json
-import six
+import requests
+import six
from hashlib import sha1
from os.path import isfile, join
from saml2.httpbase import HTTPBase
from saml2.extension.idpdisc import BINDING_DISCO
from saml2.extension.idpdisc import DiscoveryResponse
from saml2.md import EntitiesDescriptor
-
from saml2.mdie import to_dict
-
from saml2 import md
from saml2 import samlp
from saml2 import SAMLError
@@ -67,6 +68,20 @@ ENTITY_CATEGORY_SUPPORT = "http://macedir.org/entity-category-support"
# ---------------------------------------------------
+def load_extensions():
+ from saml2 import extension
+ import pkgutil
+
+ package = extension
+ prefix = package.__name__ + "."
+ ext_map = {}
+ for importer, modname, ispkg in pkgutil.iter_modules(package.__path__,
+ prefix):
+ module = __import__(modname, fromlist="dummy")
+ ext_map[module.NAMESPACE] = module
+
+ return ext_map
+
def destinations(srvs):
return [s["location"] for s in srvs]
@@ -564,8 +579,8 @@ class InMemoryMetaData(MetaData):
return True
node_name = self.node_name \
- or "%s:%s" % (md.EntitiesDescriptor.c_namespace,
- md.EntitiesDescriptor.c_tag)
+ or "%s:%s" % (md.EntitiesDescriptor.c_namespace,
+ md.EntitiesDescriptor.c_tag)
if self.security.verify_signature(
txt, node_name=node_name, cert_file=self.cert):
@@ -705,27 +720,31 @@ class MetaDataMDX(InMemoryMetaData):
""" Uses the md protocol to fetch entity information
"""
- def __init__(self, entity_transform, onts, attrc, url, security, cert,
- http, **kwargs):
+ @staticmethod
+ def sha1_entity_transform(entity_id):
+ return "{{sha1}}{}".format(
+ hashlib.sha1(entity_id.encode("utf-8")).hexdigest())
+
+ def __init__(self, url, entity_transform=None):
"""
- :params entity_transform: function transforming (e.g. base64 or sha1
+ :params url: mdx service url
+ :params entity_transform: function transforming (e.g. base64,
+ sha1 hash or URL quote
hash) the entity id. It is applied to the entity id before it is
- concatenated with the request URL sent to the MDX server.
- :params onts:
- :params attrc:
- :params url:
- :params security: SecurityContext()
- :params cert:
- :params http:
+ concatenated with the request URL sent to the MDX server. Defaults to
+ sha1 transformation.
"""
- super(MetaDataMDX, self).__init__(onts, attrc, **kwargs)
+ super(MetaDataMDX, self).__init__(None, None)
self.url = url
- self.security = security
- self.cert = cert
- self.http = http
- self.entity_transform = entity_transform
+
+ if entity_transform:
+ self.entity_transform = entity_transform
+ else:
+
+ self.entity_transform = MetaDataMDX.sha1_entity_transform
def load(self):
+ # Do nothing
pass
def __getitem__(self, item):
@@ -733,13 +752,9 @@ class MetaDataMDX(InMemoryMetaData):
return self.entity[item]
except KeyError:
mdx_url = "%s/entities/%s" % (self.url, self.entity_transform(item))
- response = self.http.send(
- mdx_url, headers={'Accept': SAML_METADATA_CONTENT_TYPE})
+ response = requests.get(mdx_url, headers={
+ 'Accept': SAML_METADATA_CONTENT_TYPE})
if response.status_code == 200:
- node_name = self.node_name \
- or "%s:%s" % (md.EntitiesDescriptor.c_namespace,
- md.EntitiesDescriptor.c_tag)
-
_txt = response.text.encode("utf-8")
if self.parse_and_check_signature(_txt):
@@ -748,6 +763,12 @@ class MetaDataMDX(InMemoryMetaData):
logger.info("Response status: %s", response.status_code)
raise KeyError
+ def single_sign_on_service(self, entity_id, binding=None, typ="idpsso"):
+ if binding is None:
+ binding = BINDING_HTTP_REDIRECT
+ return self.service(entity_id, "idpsso_descriptor",
+ "single_sign_on_service", binding)
+
class MetadataStore(MetaData):
def __init__(self, onts, attrc, config, ca_certs=None,
diff --git a/src/saml2/response.py b/src/saml2/response.py
index 810266cf..a3487155 100644
--- a/src/saml2/response.py
+++ b/src/saml2/response.py
@@ -58,6 +58,7 @@ from saml2.validate import NotValid
logger = logging.getLogger(__name__)
+
# ---------------------------------------------------------------------------
@@ -160,9 +161,11 @@ class StatusUnknownPrincipal(StatusError):
class StatusUnsupportedBinding(StatusError):
pass
+
class StatusResponder(StatusError):
pass
+
STATUSCODE2EXCEPTION = {
STATUS_VERSION_MISMATCH: StatusVersionMismatch,
STATUS_AUTHN_FAILED: StatusAuthnFailed,
@@ -186,6 +189,8 @@ STATUSCODE2EXCEPTION = {
STATUS_UNSUPPORTED_BINDING: StatusUnsupportedBinding,
STATUS_RESPONDER: StatusResponder,
}
+
+
# ---------------------------------------------------------------------------
@@ -206,7 +211,8 @@ def for_me(conditions, myself):
if audience.text.strip() == myself:
return True
else:
- #print("Not for me: %s != %s" % (audience.text.strip(), myself))
+ # print("Not for me: %s != %s" % (audience.text.strip(),
+ # myself))
pass
return False
@@ -336,7 +342,7 @@ class StatusResponse(object):
logger.exception("EXCEPTION: %s", excp)
raise
- #print("<", self.response)
+ # print("<", self.response)
return self._postamble()
@@ -377,7 +383,7 @@ class StatusResponse(object):
if self.request_id and self.in_response_to and \
self.in_response_to != self.request_id:
logger.error("Not the id I expected: %s != %s",
- self.in_response_to, self.request_id)
+ self.in_response_to, self.request_id)
return None
try:
@@ -391,9 +397,9 @@ class StatusResponse(object):
if self.asynchop:
if self.response.destination and \
- self.response.destination not in self.return_addrs:
+ self.response.destination not in self.return_addrs:
logger.error("%s not in %s", self.response.destination,
- self.return_addrs)
+ self.return_addrs)
return None
assert self.issue_instant_ok()
@@ -436,7 +442,7 @@ class NameIDMappingResponse(StatusResponse):
request_id=0, asynchop=True):
StatusResponse.__init__(self, sec_context, return_addrs, timeslack,
request_id, asynchop)
- self.signature_check = self.sec\
+ self.signature_check = self.sec \
.correctly_signed_name_id_mapping_response
@@ -506,7 +512,7 @@ class AuthnResponse(StatusResponse):
if self.asynchop:
if self.in_response_to in self.outstanding_queries:
self.came_from = self.outstanding_queries[self.in_response_to]
- #del self.outstanding_queries[self.in_response_to]
+ # del self.outstanding_queries[self.in_response_to]
try:
if not self.check_subject_confirmation_in_response_to(
self.in_response_to):
@@ -632,12 +638,12 @@ class AuthnResponse(StatusResponse):
def read_attribute_statement(self, attr_statem):
logger.debug("Attribute Statement: %s", attr_statem)
- for aconv in self.attribute_converters:
- logger.debug("Converts name format: %s", aconv.name_format)
+ # for aconv in self.attribute_converters:
+ # logger.debug("Converts name format: %s", aconv.name_format)
self.decrypt_attributes(attr_statem)
return to_local(self.attribute_converters, attr_statem,
- self.allow_unknown_attributes)
+ self.allow_unknown_attributes)
def get_identity(self):
""" The assertion can contain zero or one attributeStatements
@@ -650,7 +656,8 @@ class AuthnResponse(StatusResponse):
for tmp_assertion in _assertion.advice.assertion:
if tmp_assertion.attribute_statement:
assert len(tmp_assertion.attribute_statement) == 1
- ava.update(self.read_attribute_statement(tmp_assertion.attribute_statement[0]))
+ ava.update(self.read_attribute_statement(
+ tmp_assertion.attribute_statement[0]))
if _assertion.attribute_statement:
assert len(_assertion.attribute_statement) == 1
_attr_statem = _assertion.attribute_statement[0]
@@ -681,7 +688,7 @@ class AuthnResponse(StatusResponse):
if data.in_response_to in self.outstanding_queries:
self.came_from = self.outstanding_queries[
data.in_response_to]
- #del self.outstanding_queries[data.in_response_to]
+ # del self.outstanding_queries[data.in_response_to]
elif self.allow_unsolicited:
pass
else:
@@ -690,7 +697,7 @@ class AuthnResponse(StatusResponse):
# recognize
logger.debug("in response to: '%s'", data.in_response_to)
logger.info("outstanding queries: %s",
- self.outstanding_queries.keys())
+ self.outstanding_queries.keys())
raise Exception(
"Combination of session id and requestURI I don't "
"recall")
@@ -768,7 +775,8 @@ class AuthnResponse(StatusResponse):
logger.debug("signed")
if not verified and self.do_not_verify is False:
try:
- self.sec.check_signature(assertion, class_name(assertion),self.xmlstr)
+ self.sec.check_signature(assertion, class_name(assertion),
+ self.xmlstr)
except Exception as exc:
logger.error("correctly_signed_response: %s", exc)
raise
@@ -778,10 +786,10 @@ class AuthnResponse(StatusResponse):
logger.debug("assertion keys: %s", assertion.keyswv())
logger.debug("outstanding_queries: %s", self.outstanding_queries)
- #if self.context == "AuthnReq" or self.context == "AttrQuery":
+ # if self.context == "AuthnReq" or self.context == "AttrQuery":
if self.context == "AuthnReq":
self.authn_statement_ok()
- # elif self.context == "AttrQuery":
+ # elif self.context == "AttrQuery":
# self.authn_statement_ok(True)
if not self.condition_ok():
@@ -789,7 +797,7 @@ class AuthnResponse(StatusResponse):
logger.debug("--- Getting Identity ---")
- #if self.context == "AuthnReq" or self.context == "AttrQuery":
+ # if self.context == "AuthnReq" or self.context == "AttrQuery":
# self.ava = self.get_identity()
# logger.debug("--- AVA: %s", self.ava)
@@ -805,13 +813,17 @@ class AuthnResponse(StatusResponse):
logger.exception("get subject")
raise
- def decrypt_assertions(self, encrypted_assertions, decr_txt, issuer=None, verified=False):
- """ Moves the decrypted assertion from the encrypted assertion to a list.
+ def decrypt_assertions(self, encrypted_assertions, decr_txt, issuer=None,
+ verified=False):
+ """ Moves the decrypted assertion from the encrypted assertion to a
+ list.
:param encrypted_assertions: A list of encrypted assertions.
- :param decr_txt: The string representation containing the decrypted data. Used when verifying signatures.
+ :param decr_txt: The string representation containing the decrypted
+ data. Used when verifying signatures.
:param issuer: The issuer of the response.
- :param verified: If True do not verify signatures, otherwise verify the signature if it exists.
+ :param verified: If True do not verify signatures, otherwise verify
+ the signature if it exists.
:return: A list of decrypted assertions.
"""
res = []
@@ -824,7 +836,8 @@ class AuthnResponse(StatusResponse):
if not self.sec.check_signature(
assertion, origdoc=decr_txt,
node_name=class_name(assertion), issuer=issuer):
- logger.error("Failed to verify signature on '%s'", assertion)
+ logger.error("Failed to verify signature on '%s'",
+ assertion)
raise SignatureError()
res.append(assertion)
return res
@@ -836,11 +849,12 @@ class AuthnResponse(StatusResponse):
:return: True encrypted data exists otherwise false.
"""
for _assertion in enc_assertions:
- if _assertion.encrypted_data is not None:
- return True
+ if _assertion.encrypted_data is not None:
+ return True
def find_encrypt_data_assertion_list(self, _assertions):
- """ Verifies if a list of assertions contains encrypted data in the advice element.
+ """ Verifies if a list of assertions contains encrypted data in the
+ advice element.
:param _assertions: A list of assertions.
:return: True encrypted data exists otherwise false.
@@ -848,12 +862,14 @@ class AuthnResponse(StatusResponse):
for _assertion in _assertions:
if _assertion.advice:
if _assertion.advice.encrypted_assertion:
- res = self.find_encrypt_data_assertion(_assertion.advice.encrypted_assertion)
+ res = self.find_encrypt_data_assertion(
+ _assertion.advice.encrypted_assertion)
if res:
return True
def find_encrypt_data(self, resp):
- """ Verifies if a saml response contains encrypted assertions with encrypted data.
+ """ Verifies if a saml response contains encrypted assertions with
+ encrypted data.
:param resp: A saml response.
:return: True encrypted data exists otherwise false.
@@ -867,7 +883,8 @@ class AuthnResponse(StatusResponse):
for tmp_assertion in resp.assertion:
if tmp_assertion.advice:
if tmp_assertion.advice.encrypted_assertion:
- res = self.find_encrypt_data_assertion(tmp_assertion.advice.encrypted_assertion)
+ res = self.find_encrypt_data_assertion(
+ tmp_assertion.advice.encrypted_assertion)
if res:
return True
return False
@@ -875,7 +892,8 @@ class AuthnResponse(StatusResponse):
def parse_assertion(self, keys=None):
""" Parse the assertions for a saml response.
- :param keys: A string representing a RSA key or a list of strings containing RSA keys.
+ :param keys: A string representing a RSA key or a list of strings
+ containing RSA keys.
:return: True if the assertions are parsed otherwise False.
"""
if self.context == "AuthnQuery":
@@ -884,12 +902,13 @@ class AuthnResponse(StatusResponse):
else: # This is a saml2int limitation
try:
assert len(self.response.assertion) == 1 or \
- len(self.response.encrypted_assertion) == 1
+ len(self.response.encrypted_assertion) == 1
except AssertionError:
raise Exception("No assertion part")
- has_encrypted_assertions = self.find_encrypt_data(self.response) #self.response.encrypted_assertion
- #if not has_encrypted_assertions and self.response.assertion:
+ has_encrypted_assertions = self.find_encrypt_data(self.response) #
+ # self.response.encrypted_assertion
+ # if not has_encrypted_assertions and self.response.assertion:
# for tmp_assertion in self.response.assertion:
# if tmp_assertion.advice:
# if tmp_assertion.advice.encrypted_assertion:
@@ -912,15 +931,20 @@ class AuthnResponse(StatusResponse):
decr_text_old = decr_text
decr_text = self.sec.decrypt_keys(decr_text, keys)
resp = samlp.response_from_string(decr_text)
- _enc_assertions = self.decrypt_assertions(resp.encrypted_assertion, decr_text)
+ _enc_assertions = self.decrypt_assertions(resp.encrypted_assertion,
+ decr_text)
decr_text_old = None
- while (self.find_encrypt_data(resp) or self.find_encrypt_data_assertion_list(_enc_assertions)) and \
+ while (self.find_encrypt_data(
+ resp) or self.find_encrypt_data_assertion_list(
+ _enc_assertions)) and \
decr_text_old != decr_text:
decr_text_old = decr_text
decr_text = self.sec.decrypt_keys(decr_text, keys)
resp = samlp.response_from_string(decr_text)
- _enc_assertions = self.decrypt_assertions(resp.encrypted_assertion, decr_text, verified=True)
- #_enc_assertions = self.decrypt_assertions(resp.encrypted_assertion, decr_text, verified=True)
+ _enc_assertions = self.decrypt_assertions(
+ resp.encrypted_assertion, decr_text, verified=True)
+ # _enc_assertions = self.decrypt_assertions(
+ # resp.encrypted_assertion, decr_text, verified=True)
all_assertions = _enc_assertions
if resp.assertion:
all_assertions = all_assertions + resp.assertion
@@ -928,9 +952,10 @@ class AuthnResponse(StatusResponse):
for tmp_ass in all_assertions:
if tmp_ass.advice and tmp_ass.advice.encrypted_assertion:
- advice_res = self.decrypt_assertions(tmp_ass.advice.encrypted_assertion,
- decr_text,
- tmp_ass.issuer)
+ advice_res = self.decrypt_assertions(
+ tmp_ass.advice.encrypted_assertion,
+ decr_text,
+ tmp_ass.issuer)
if tmp_ass.advice.assertion:
tmp_ass.advice.assertion.extend(advice_res)
else:
@@ -1211,7 +1236,7 @@ class AssertionIDResponse(object):
logger.exception("EXCEPTION: %s", excp)
raise
- #print("<", self.response)
+ # print("<", self.response)
return self._postamble()
@@ -1233,4 +1258,3 @@ class AssertionIDResponse(object):
logger.debug("response: %s", self.response)
return self
-