diff options
author | Roland Hedberg <roland.hedberg@adm.umu.se> | 2015-11-18 10:38:39 +0100 |
---|---|---|
committer | Roland Hedberg <roland.hedberg@adm.umu.se> | 2015-11-18 10:38:39 +0100 |
commit | 2ce425c84ceb2f37e8cf9f2eb383abdbc1a2f087 (patch) | |
tree | 7239bd458d27939bb705d6fe68d61ae5b95069c0 /src/saml2 | |
parent | 0218b0b064e4d088e7771200b96ecd3454f04154 (diff) | |
download | pysaml2-2ce425c84ceb2f37e8cf9f2eb383abdbc1a2f087.tar.gz |
Fixed a problem in parsing metadata extensions.
Diffstat (limited to 'src/saml2')
-rw-r--r-- | src/saml2/__init__.py | 2 | ||||
-rw-r--r-- | src/saml2/mdstore.py | 73 | ||||
-rw-r--r-- | src/saml2/response.py | 106 |
3 files changed, 113 insertions, 68 deletions
diff --git a/src/saml2/__init__.py b/src/saml2/__init__.py index 4afd69ca..16555e32 100644 --- a/src/saml2/__init__.py +++ b/src/saml2/__init__.py @@ -979,7 +979,7 @@ def extension_elements_to_elements(extension_elements, schemas): if isinstance(schemas, list): pass elif isinstance(schemas, dict): - schemas = schemas.values() + schemas = list(schemas.values()) else: return res diff --git a/src/saml2/mdstore.py b/src/saml2/mdstore.py index 06940d96..5d701428 100644 --- a/src/saml2/mdstore.py +++ b/src/saml2/mdstore.py @@ -1,19 +1,20 @@ from __future__ import print_function + +import hashlib import logging import os import sys import json -import six +import requests +import six from hashlib import sha1 from os.path import isfile, join from saml2.httpbase import HTTPBase from saml2.extension.idpdisc import BINDING_DISCO from saml2.extension.idpdisc import DiscoveryResponse from saml2.md import EntitiesDescriptor - from saml2.mdie import to_dict - from saml2 import md from saml2 import samlp from saml2 import SAMLError @@ -67,6 +68,20 @@ ENTITY_CATEGORY_SUPPORT = "http://macedir.org/entity-category-support" # --------------------------------------------------- +def load_extensions(): + from saml2 import extension + import pkgutil + + package = extension + prefix = package.__name__ + "." + ext_map = {} + for importer, modname, ispkg in pkgutil.iter_modules(package.__path__, + prefix): + module = __import__(modname, fromlist="dummy") + ext_map[module.NAMESPACE] = module + + return ext_map + def destinations(srvs): return [s["location"] for s in srvs] @@ -564,8 +579,8 @@ class InMemoryMetaData(MetaData): return True node_name = self.node_name \ - or "%s:%s" % (md.EntitiesDescriptor.c_namespace, - md.EntitiesDescriptor.c_tag) + or "%s:%s" % (md.EntitiesDescriptor.c_namespace, + md.EntitiesDescriptor.c_tag) if self.security.verify_signature( txt, node_name=node_name, cert_file=self.cert): @@ -705,27 +720,31 @@ class MetaDataMDX(InMemoryMetaData): """ Uses the md protocol to fetch entity information """ - def __init__(self, entity_transform, onts, attrc, url, security, cert, - http, **kwargs): + @staticmethod + def sha1_entity_transform(entity_id): + return "{{sha1}}{}".format( + hashlib.sha1(entity_id.encode("utf-8")).hexdigest()) + + def __init__(self, url, entity_transform=None): """ - :params entity_transform: function transforming (e.g. base64 or sha1 + :params url: mdx service url + :params entity_transform: function transforming (e.g. base64, + sha1 hash or URL quote hash) the entity id. It is applied to the entity id before it is - concatenated with the request URL sent to the MDX server. - :params onts: - :params attrc: - :params url: - :params security: SecurityContext() - :params cert: - :params http: + concatenated with the request URL sent to the MDX server. Defaults to + sha1 transformation. """ - super(MetaDataMDX, self).__init__(onts, attrc, **kwargs) + super(MetaDataMDX, self).__init__(None, None) self.url = url - self.security = security - self.cert = cert - self.http = http - self.entity_transform = entity_transform + + if entity_transform: + self.entity_transform = entity_transform + else: + + self.entity_transform = MetaDataMDX.sha1_entity_transform def load(self): + # Do nothing pass def __getitem__(self, item): @@ -733,13 +752,9 @@ class MetaDataMDX(InMemoryMetaData): return self.entity[item] except KeyError: mdx_url = "%s/entities/%s" % (self.url, self.entity_transform(item)) - response = self.http.send( - mdx_url, headers={'Accept': SAML_METADATA_CONTENT_TYPE}) + response = requests.get(mdx_url, headers={ + 'Accept': SAML_METADATA_CONTENT_TYPE}) if response.status_code == 200: - node_name = self.node_name \ - or "%s:%s" % (md.EntitiesDescriptor.c_namespace, - md.EntitiesDescriptor.c_tag) - _txt = response.text.encode("utf-8") if self.parse_and_check_signature(_txt): @@ -748,6 +763,12 @@ class MetaDataMDX(InMemoryMetaData): logger.info("Response status: %s", response.status_code) raise KeyError + def single_sign_on_service(self, entity_id, binding=None, typ="idpsso"): + if binding is None: + binding = BINDING_HTTP_REDIRECT + return self.service(entity_id, "idpsso_descriptor", + "single_sign_on_service", binding) + class MetadataStore(MetaData): def __init__(self, onts, attrc, config, ca_certs=None, diff --git a/src/saml2/response.py b/src/saml2/response.py index 810266cf..a3487155 100644 --- a/src/saml2/response.py +++ b/src/saml2/response.py @@ -58,6 +58,7 @@ from saml2.validate import NotValid logger = logging.getLogger(__name__) + # --------------------------------------------------------------------------- @@ -160,9 +161,11 @@ class StatusUnknownPrincipal(StatusError): class StatusUnsupportedBinding(StatusError): pass + class StatusResponder(StatusError): pass + STATUSCODE2EXCEPTION = { STATUS_VERSION_MISMATCH: StatusVersionMismatch, STATUS_AUTHN_FAILED: StatusAuthnFailed, @@ -186,6 +189,8 @@ STATUSCODE2EXCEPTION = { STATUS_UNSUPPORTED_BINDING: StatusUnsupportedBinding, STATUS_RESPONDER: StatusResponder, } + + # --------------------------------------------------------------------------- @@ -206,7 +211,8 @@ def for_me(conditions, myself): if audience.text.strip() == myself: return True else: - #print("Not for me: %s != %s" % (audience.text.strip(), myself)) + # print("Not for me: %s != %s" % (audience.text.strip(), + # myself)) pass return False @@ -336,7 +342,7 @@ class StatusResponse(object): logger.exception("EXCEPTION: %s", excp) raise - #print("<", self.response) + # print("<", self.response) return self._postamble() @@ -377,7 +383,7 @@ class StatusResponse(object): if self.request_id and self.in_response_to and \ self.in_response_to != self.request_id: logger.error("Not the id I expected: %s != %s", - self.in_response_to, self.request_id) + self.in_response_to, self.request_id) return None try: @@ -391,9 +397,9 @@ class StatusResponse(object): if self.asynchop: if self.response.destination and \ - self.response.destination not in self.return_addrs: + self.response.destination not in self.return_addrs: logger.error("%s not in %s", self.response.destination, - self.return_addrs) + self.return_addrs) return None assert self.issue_instant_ok() @@ -436,7 +442,7 @@ class NameIDMappingResponse(StatusResponse): request_id=0, asynchop=True): StatusResponse.__init__(self, sec_context, return_addrs, timeslack, request_id, asynchop) - self.signature_check = self.sec\ + self.signature_check = self.sec \ .correctly_signed_name_id_mapping_response @@ -506,7 +512,7 @@ class AuthnResponse(StatusResponse): if self.asynchop: if self.in_response_to in self.outstanding_queries: self.came_from = self.outstanding_queries[self.in_response_to] - #del self.outstanding_queries[self.in_response_to] + # del self.outstanding_queries[self.in_response_to] try: if not self.check_subject_confirmation_in_response_to( self.in_response_to): @@ -632,12 +638,12 @@ class AuthnResponse(StatusResponse): def read_attribute_statement(self, attr_statem): logger.debug("Attribute Statement: %s", attr_statem) - for aconv in self.attribute_converters: - logger.debug("Converts name format: %s", aconv.name_format) + # for aconv in self.attribute_converters: + # logger.debug("Converts name format: %s", aconv.name_format) self.decrypt_attributes(attr_statem) return to_local(self.attribute_converters, attr_statem, - self.allow_unknown_attributes) + self.allow_unknown_attributes) def get_identity(self): """ The assertion can contain zero or one attributeStatements @@ -650,7 +656,8 @@ class AuthnResponse(StatusResponse): for tmp_assertion in _assertion.advice.assertion: if tmp_assertion.attribute_statement: assert len(tmp_assertion.attribute_statement) == 1 - ava.update(self.read_attribute_statement(tmp_assertion.attribute_statement[0])) + ava.update(self.read_attribute_statement( + tmp_assertion.attribute_statement[0])) if _assertion.attribute_statement: assert len(_assertion.attribute_statement) == 1 _attr_statem = _assertion.attribute_statement[0] @@ -681,7 +688,7 @@ class AuthnResponse(StatusResponse): if data.in_response_to in self.outstanding_queries: self.came_from = self.outstanding_queries[ data.in_response_to] - #del self.outstanding_queries[data.in_response_to] + # del self.outstanding_queries[data.in_response_to] elif self.allow_unsolicited: pass else: @@ -690,7 +697,7 @@ class AuthnResponse(StatusResponse): # recognize logger.debug("in response to: '%s'", data.in_response_to) logger.info("outstanding queries: %s", - self.outstanding_queries.keys()) + self.outstanding_queries.keys()) raise Exception( "Combination of session id and requestURI I don't " "recall") @@ -768,7 +775,8 @@ class AuthnResponse(StatusResponse): logger.debug("signed") if not verified and self.do_not_verify is False: try: - self.sec.check_signature(assertion, class_name(assertion),self.xmlstr) + self.sec.check_signature(assertion, class_name(assertion), + self.xmlstr) except Exception as exc: logger.error("correctly_signed_response: %s", exc) raise @@ -778,10 +786,10 @@ class AuthnResponse(StatusResponse): logger.debug("assertion keys: %s", assertion.keyswv()) logger.debug("outstanding_queries: %s", self.outstanding_queries) - #if self.context == "AuthnReq" or self.context == "AttrQuery": + # if self.context == "AuthnReq" or self.context == "AttrQuery": if self.context == "AuthnReq": self.authn_statement_ok() - # elif self.context == "AttrQuery": + # elif self.context == "AttrQuery": # self.authn_statement_ok(True) if not self.condition_ok(): @@ -789,7 +797,7 @@ class AuthnResponse(StatusResponse): logger.debug("--- Getting Identity ---") - #if self.context == "AuthnReq" or self.context == "AttrQuery": + # if self.context == "AuthnReq" or self.context == "AttrQuery": # self.ava = self.get_identity() # logger.debug("--- AVA: %s", self.ava) @@ -805,13 +813,17 @@ class AuthnResponse(StatusResponse): logger.exception("get subject") raise - def decrypt_assertions(self, encrypted_assertions, decr_txt, issuer=None, verified=False): - """ Moves the decrypted assertion from the encrypted assertion to a list. + def decrypt_assertions(self, encrypted_assertions, decr_txt, issuer=None, + verified=False): + """ Moves the decrypted assertion from the encrypted assertion to a + list. :param encrypted_assertions: A list of encrypted assertions. - :param decr_txt: The string representation containing the decrypted data. Used when verifying signatures. + :param decr_txt: The string representation containing the decrypted + data. Used when verifying signatures. :param issuer: The issuer of the response. - :param verified: If True do not verify signatures, otherwise verify the signature if it exists. + :param verified: If True do not verify signatures, otherwise verify + the signature if it exists. :return: A list of decrypted assertions. """ res = [] @@ -824,7 +836,8 @@ class AuthnResponse(StatusResponse): if not self.sec.check_signature( assertion, origdoc=decr_txt, node_name=class_name(assertion), issuer=issuer): - logger.error("Failed to verify signature on '%s'", assertion) + logger.error("Failed to verify signature on '%s'", + assertion) raise SignatureError() res.append(assertion) return res @@ -836,11 +849,12 @@ class AuthnResponse(StatusResponse): :return: True encrypted data exists otherwise false. """ for _assertion in enc_assertions: - if _assertion.encrypted_data is not None: - return True + if _assertion.encrypted_data is not None: + return True def find_encrypt_data_assertion_list(self, _assertions): - """ Verifies if a list of assertions contains encrypted data in the advice element. + """ Verifies if a list of assertions contains encrypted data in the + advice element. :param _assertions: A list of assertions. :return: True encrypted data exists otherwise false. @@ -848,12 +862,14 @@ class AuthnResponse(StatusResponse): for _assertion in _assertions: if _assertion.advice: if _assertion.advice.encrypted_assertion: - res = self.find_encrypt_data_assertion(_assertion.advice.encrypted_assertion) + res = self.find_encrypt_data_assertion( + _assertion.advice.encrypted_assertion) if res: return True def find_encrypt_data(self, resp): - """ Verifies if a saml response contains encrypted assertions with encrypted data. + """ Verifies if a saml response contains encrypted assertions with + encrypted data. :param resp: A saml response. :return: True encrypted data exists otherwise false. @@ -867,7 +883,8 @@ class AuthnResponse(StatusResponse): for tmp_assertion in resp.assertion: if tmp_assertion.advice: if tmp_assertion.advice.encrypted_assertion: - res = self.find_encrypt_data_assertion(tmp_assertion.advice.encrypted_assertion) + res = self.find_encrypt_data_assertion( + tmp_assertion.advice.encrypted_assertion) if res: return True return False @@ -875,7 +892,8 @@ class AuthnResponse(StatusResponse): def parse_assertion(self, keys=None): """ Parse the assertions for a saml response. - :param keys: A string representing a RSA key or a list of strings containing RSA keys. + :param keys: A string representing a RSA key or a list of strings + containing RSA keys. :return: True if the assertions are parsed otherwise False. """ if self.context == "AuthnQuery": @@ -884,12 +902,13 @@ class AuthnResponse(StatusResponse): else: # This is a saml2int limitation try: assert len(self.response.assertion) == 1 or \ - len(self.response.encrypted_assertion) == 1 + len(self.response.encrypted_assertion) == 1 except AssertionError: raise Exception("No assertion part") - has_encrypted_assertions = self.find_encrypt_data(self.response) #self.response.encrypted_assertion - #if not has_encrypted_assertions and self.response.assertion: + has_encrypted_assertions = self.find_encrypt_data(self.response) # + # self.response.encrypted_assertion + # if not has_encrypted_assertions and self.response.assertion: # for tmp_assertion in self.response.assertion: # if tmp_assertion.advice: # if tmp_assertion.advice.encrypted_assertion: @@ -912,15 +931,20 @@ class AuthnResponse(StatusResponse): decr_text_old = decr_text decr_text = self.sec.decrypt_keys(decr_text, keys) resp = samlp.response_from_string(decr_text) - _enc_assertions = self.decrypt_assertions(resp.encrypted_assertion, decr_text) + _enc_assertions = self.decrypt_assertions(resp.encrypted_assertion, + decr_text) decr_text_old = None - while (self.find_encrypt_data(resp) or self.find_encrypt_data_assertion_list(_enc_assertions)) and \ + while (self.find_encrypt_data( + resp) or self.find_encrypt_data_assertion_list( + _enc_assertions)) and \ decr_text_old != decr_text: decr_text_old = decr_text decr_text = self.sec.decrypt_keys(decr_text, keys) resp = samlp.response_from_string(decr_text) - _enc_assertions = self.decrypt_assertions(resp.encrypted_assertion, decr_text, verified=True) - #_enc_assertions = self.decrypt_assertions(resp.encrypted_assertion, decr_text, verified=True) + _enc_assertions = self.decrypt_assertions( + resp.encrypted_assertion, decr_text, verified=True) + # _enc_assertions = self.decrypt_assertions( + # resp.encrypted_assertion, decr_text, verified=True) all_assertions = _enc_assertions if resp.assertion: all_assertions = all_assertions + resp.assertion @@ -928,9 +952,10 @@ class AuthnResponse(StatusResponse): for tmp_ass in all_assertions: if tmp_ass.advice and tmp_ass.advice.encrypted_assertion: - advice_res = self.decrypt_assertions(tmp_ass.advice.encrypted_assertion, - decr_text, - tmp_ass.issuer) + advice_res = self.decrypt_assertions( + tmp_ass.advice.encrypted_assertion, + decr_text, + tmp_ass.issuer) if tmp_ass.advice.assertion: tmp_ass.advice.assertion.extend(advice_res) else: @@ -1211,7 +1236,7 @@ class AssertionIDResponse(object): logger.exception("EXCEPTION: %s", excp) raise - #print("<", self.response) + # print("<", self.response) return self._postamble() @@ -1233,4 +1258,3 @@ class AssertionIDResponse(object): logger.debug("response: %s", self.response) return self - |