diff options
author | Roland Hedberg <roland.hedberg@adm.umu.se> | 2015-12-11 13:02:49 +0100 |
---|---|---|
committer | Roland Hedberg <roland.hedberg@adm.umu.se> | 2015-12-11 13:02:49 +0100 |
commit | 6200f158dbad1acf9bf6982a738c58620452f813 (patch) | |
tree | 3fd0a53efa2cc70cae8b72289fa5cb7f39bdea7f /src | |
parent | 82d3b4da6ebd19f556d2f4d377236a05bb64cd75 (diff) | |
download | pysaml2-6200f158dbad1acf9bf6982a738c58620452f813.tar.gz |
Reworked the security backend so you should now be able to use a HSM again for XML security. Support for non-XML crypto using HSMs are on the way.
Diffstat (limited to 'src')
-rw-r--r-- | src/saml2/client.py | 27 | ||||
-rw-r--r-- | src/saml2/config.py | 3 | ||||
-rw-r--r-- | src/saml2/entity.py | 118 | ||||
-rw-r--r-- | src/saml2/entity_category/swamid.py | 3 | ||||
-rw-r--r-- | src/saml2/httpbase.py | 9 | ||||
-rw-r--r-- | src/saml2/mdstore.py | 5 | ||||
-rw-r--r-- | src/saml2/pack.py | 38 | ||||
-rw-r--r-- | src/saml2/server.py | 302 | ||||
-rw-r--r-- | src/saml2/sigver.py | 207 |
9 files changed, 408 insertions, 304 deletions
diff --git a/src/saml2/client.py b/src/saml2/client.py index fca9859b..65c2bcad 100644 --- a/src/saml2/client.py +++ b/src/saml2/client.py @@ -66,7 +66,8 @@ class Saml2Client(Base): :return: session id and AuthnRequest info """ - reqid, negotiated_binding, info = self.prepare_for_negotiated_authenticate( + reqid, negotiated_binding, info = \ + self.prepare_for_negotiated_authenticate( entityid=entityid, relay_state=relay_state, binding=binding, @@ -137,7 +138,8 @@ class Saml2Client(Base): raise SignOnError( "No supported bindings available for authentication") - def global_logout(self, name_id, reason="", expire=None, sign=None, sign_alg=None, digest_alg=None): + def global_logout(self, name_id, reason="", expire=None, sign=None, + sign_alg=None, digest_alg=None): """ More or less a layer of indirection :-/ Bootstrapping the whole thing by finding all the IdPs that should be notified. @@ -162,10 +164,12 @@ class Saml2Client(Base): # find out which IdPs/AAs I should notify entity_ids = self.users.issuers_of_info(name_id) - return self.do_logout(name_id, entity_ids, reason, expire, sign, sign_alg=sign_alg, digest_alg=digest_alg) + return self.do_logout(name_id, entity_ids, reason, expire, sign, + sign_alg=sign_alg, digest_alg=digest_alg) def do_logout(self, name_id, entity_ids, reason, expire, sign=None, - expected_binding=None, sign_alg=None, digest_alg=None, **kwargs): + expected_binding=None, sign_alg=None, digest_alg=None, + **kwargs): """ :param name_id: Identifier of the Subject (a NameID instance) @@ -227,22 +231,22 @@ class Saml2Client(Base): sign = self.logout_requests_signed sigalg = None - key = None if sign: if binding == BINDING_HTTP_REDIRECT: - sigalg = kwargs.get("sigalg", ds.DefaultSignature().get_sign_alg()) - key = kwargs.get("key", self.signkey) + sigalg = kwargs.get( + "sigalg", ds.DefaultSignature().get_sign_alg()) + #key = kwargs.get("key", self.signkey) srequest = str(request) else: - srequest = self.sign(request, sign_alg=sign_alg, digest_alg=digest_alg) + srequest = self.sign(request, sign_alg=sign_alg, + digest_alg=digest_alg) else: srequest = str(request) relay_state = self._relay_state(req_id) http_info = self.apply_binding(binding, srequest, destination, - relay_state, sigalg=sigalg, - key=key) + relay_state, sigalg=sigalg) if binding == BINDING_SOAP: response = self.send(**http_info) @@ -318,7 +322,8 @@ class Saml2Client(Base): return self.do_logout(decode(status["name_id"]), status["entity_ids"], status["reason"], status["not_on_or_after"], - status["sign"], sign_alg=sign_alg, digest_alg=digest_alg) + status["sign"], sign_alg=sign_alg, + digest_alg=digest_alg) def _use_soap(self, destination, query_type, **kwargs): _create_func = getattr(self, "create_%s" % query_type) diff --git a/src/saml2/config.py b/src/saml2/config.py index 952526f9..24e891f4 100644 --- a/src/saml2/config.py +++ b/src/saml2/config.py @@ -53,7 +53,8 @@ COMMON_ARGS = [ "tmp_key_file", "validate_certificate", "extensions", - "allow_unknown_attributes" + "allow_unknown_attributes", + "crypto_backend" ] SP_ARGS = [ diff --git a/src/saml2/entity.py b/src/saml2/entity.py index b5fcb3d2..174d2e83 100644 --- a/src/saml2/entity.py +++ b/src/saml2/entity.py @@ -1,5 +1,6 @@ import base64 # from binascii import hexlify +from binascii import hexlify import copy import logging from hashlib import sha1 @@ -148,12 +149,6 @@ class Entity(HTTPBase): raise Exception( "Could not fetch certificate from %s" % _val) - try: - self.signkey = RSA.importKey( - open(self.config.getattr("key_file", ""), 'r').read()) - except (KeyError, TypeError): - self.signkey = None - HTTPBase.__init__(self, self.config.verify_ssl_cert, self.config.ca_certs, self.config.key_file, self.config.cert_file) @@ -227,8 +222,12 @@ class Entity(HTTPBase): info["method"] = "POST" elif binding == BINDING_HTTP_REDIRECT: logger.info("HTTP REDIRECT") + if 'sigalg' in kwargs: + signer = self.sec.sec_backend.get_signer(kwargs['sigalg']) + else: + signer = None info = self.use_http_get(msg_str, destination, relay_state, typ, - **kwargs) + signer=signer, **kwargs) info["url"] = str(destination) info["method"] = "GET" elif binding == BINDING_SOAP or binding == BINDING_PAOS: @@ -419,9 +418,12 @@ class Entity(HTTPBase): # -------------------------------------------------------------------------- - def sign(self, msg, mid=None, to_sign=None, sign_prepare=False, sign_alg=None, digest_alg=None): + def sign(self, msg, mid=None, to_sign=None, sign_prepare=False, + sign_alg=None, digest_alg=None): if msg.signature is None: - msg.signature = pre_signature_part(msg.id, self.sec.my_cert, 1, sign_alg=sign_alg, digest_alg=digest_alg) + msg.signature = pre_signature_part(msg.id, self.sec.my_cert, 1, + sign_alg=sign_alg, + digest_alg=digest_alg) if sign_prepare: return msg @@ -478,7 +480,8 @@ class Entity(HTTPBase): req.register_prefix(nsprefix) if sign: - return reqid, self.sign(req, sign_prepare=sign_prepare, sign_alg=sign_alg, digest_alg=digest_alg) + return reqid, self.sign(req, sign_prepare=sign_prepare, + sign_alg=sign_alg, digest_alg=digest_alg) else: logger.info("REQUEST: %s", req) return reqid, req @@ -542,7 +545,7 @@ class Entity(HTTPBase): :return: A new samlp.Resonse with the designated assertion encrypted. """ _certs = [] - cbxs = CryptoBackendXmlSec1(self.config.xmlsec_binary) + if encrypt_cert: _certs = [] _certs.append(encrypt_cert) @@ -558,9 +561,9 @@ class Entity(HTTPBase): if end_cert not in _cert: _cert = "%s%s" % (_cert, end_cert) _, cert_file = make_temp(_cert.encode('ascii'), decode=False) - response = cbxs.encrypt_assertion(response, cert_file, - pre_encryption_part(), - node_xpath=node_xpath) + response = self.sec.encrypt_assertion(response, cert_file, + pre_encryption_part(), + node_xpath=node_xpath) return response except Exception as ex: exception = ex @@ -575,7 +578,8 @@ class Entity(HTTPBase): encrypt_assertion_self_contained=False, encrypted_advice_attributes=False, encrypt_cert_advice=None, encrypt_cert_assertion=None, - sign_assertion=None, pefim=False, sign_alg=None, digest_alg=None, **kwargs): + sign_assertion=None, pefim=False, sign_alg=None, + digest_alg=None, **kwargs): """ Create a Response. Encryption: encrypt_assertion must be true for encryption to be @@ -619,7 +623,8 @@ class Entity(HTTPBase): response = response_factory(issuer=_issuer, in_response_to=in_response_to, - status=status, sign_alg=sign_alg, digest_alg=digest_alg) + status=status, sign_alg=sign_alg, + digest_alg=digest_alg) if consumer_url: response.destination = consumer_url @@ -636,15 +641,19 @@ class Entity(HTTPBase): encrypt_assertion = False if encrypt_assertion or ( - encrypted_advice_attributes and response.assertion.advice is + encrypted_advice_attributes and + response.assertion.advice is not None and - len(response.assertion.advice.assertion) == 1): + len(response.assertion.advice.assertion) == 1): if sign: response.signature = pre_signature_part(response.id, - self.sec.my_cert, 1, sign_alg=sign_alg, digest_alg=digest_alg) + self.sec.my_cert, 1, + sign_alg=sign_alg, + digest_alg=digest_alg) sign_class = [(class_name(response), response.id)] - cbxs = CryptoBackendXmlSec1(self.config.xmlsec_binary) - encrypt_advice = False + else: + sign_class = [] + if encrypted_advice_attributes and response.assertion.advice is \ not None \ and len(response.assertion.advice.assertion) > 0: @@ -664,7 +673,8 @@ class Entity(HTTPBase): to_sign_advice = [] if sign_assertion and not pefim: tmp_assertion.signature = pre_signature_part( - tmp_assertion.id, self.sec.my_cert, 1, sign_alg=sign_alg, digest_alg=digest_alg) + tmp_assertion.id, self.sec.my_cert, 1, + sign_alg=sign_alg, digest_alg=digest_alg) to_sign_advice.append( (class_name(tmp_assertion), tmp_assertion.id)) @@ -687,10 +697,9 @@ class Entity(HTTPBase): response = signed_instance_factory(response, self.sec, to_sign_advice) - response = self._encrypt_assertion(encrypt_cert_advice, - sp_entity_id, - response, - node_xpath=node_xpath) + response = self._encrypt_assertion( + encrypt_cert_advice, sp_entity_id, response, + node_xpath=node_xpath) response = response_from_string(response) if encrypt_assertion: @@ -700,10 +709,9 @@ class Entity(HTTPBase): if not isinstance(_assertions, list): _assertions = [_assertions] for _assertion in _assertions: - _assertion.signature = pre_signature_part(_assertion.id, - self.sec.my_cert, - 1, - sign_alg=sign_alg, digest_alg=digest_alg) + _assertion.signature = pre_signature_part( + _assertion.id, self.sec.my_cert, 1, + sign_alg=sign_alg, digest_alg=digest_alg) to_sign_assertion.append( (class_name(_assertion), _assertion.id)) if encrypt_assertion_self_contained: @@ -717,7 +725,7 @@ class Entity(HTTPBase): response = pre_encrypt_assertion(response) response = \ response.get_xml_string_with_self_contained_assertion_within_encrypted_assertion( - assertion_tag) + assertion_tag) else: response = pre_encrypt_assertion(response) if to_sign_assertion: @@ -735,11 +743,13 @@ class Entity(HTTPBase): return response if sign: - return self.sign(response, to_sign=to_sign, sign_alg=sign_alg, digest_alg=digest_alg) + return self.sign(response, to_sign=to_sign, sign_alg=sign_alg, + digest_alg=digest_alg) else: return response - def _status_response(self, response_class, issuer, status, sign=False, sign_alg=None, digest_alg=None, + def _status_response(self, response_class, issuer, status, sign=False, + sign_alg=None, digest_alg=None, **kwargs): """ Create a StatusResponse. @@ -768,7 +778,8 @@ class Entity(HTTPBase): status=status, **kwargs) if sign: - return self.sign(response, mid, sign_alg=sign_alg, digest_alg=digest_alg) + return self.sign(response, mid, sign_alg=sign_alg, + digest_alg=digest_alg) else: return response @@ -847,7 +858,8 @@ class Entity(HTTPBase): # ------------------------------------------------------------------------ def create_error_response(self, in_response_to, destination, info, - sign=False, issuer=None, sign_alg=None, digest_alg=None, **kwargs): + sign=False, issuer=None, sign_alg=None, + digest_alg=None, **kwargs): """ Create a error response. :param in_response_to: The identifier of the message this is a response @@ -871,7 +883,8 @@ class Entity(HTTPBase): subject_id=None, name_id=None, reason=None, expire=None, message_id=0, consent=None, extensions=None, sign=False, - session_indexes=None, sign_alg=None, digest_alg=None): + session_indexes=None, sign_alg=None, + digest_alg=None): """ Constructs a LogoutRequest :param destination: Destination of the request @@ -915,10 +928,12 @@ class Entity(HTTPBase): return self._message(LogoutRequest, destination, message_id, consent, extensions, sign, name_id=name_id, reason=reason, not_on_or_after=expire, - issuer=self._issuer(), sign_alg=sign_alg, digest_alg=digest_alg, **args) + issuer=self._issuer(), sign_alg=sign_alg, + digest_alg=digest_alg, **args) def create_logout_response(self, request, bindings=None, status=None, - sign=False, issuer=None, sign_alg=None, digest_alg=None): + sign=False, issuer=None, sign_alg=None, + digest_alg=None): """ Create a LogoutResponse. :param request: The request this is a response to @@ -936,14 +951,16 @@ class Entity(HTTPBase): issuer = self._issuer() response = self._status_response(samlp.LogoutResponse, issuer, status, - sign, sign_alg=sign_alg, digest_alg=digest_alg, **rinfo) + sign, sign_alg=sign_alg, + digest_alg=digest_alg, **rinfo) logger.info("Response: %s", response) return response def create_artifact_resolve(self, artifact, destination, sessid, - consent=None, extensions=None, sign=False, sign_alg=None, digest_alg=None): + consent=None, extensions=None, sign=False, + sign_alg=None, digest_alg=None): """ Create a ArtifactResolve request @@ -959,10 +976,12 @@ class Entity(HTTPBase): artifact = Artifact(text=artifact) return self._message(ArtifactResolve, destination, sessid, - consent, extensions, sign, artifact=artifact, sign_alg=sign_alg, digest_alg=digest_alg) + consent, extensions, sign, artifact=artifact, + sign_alg=sign_alg, digest_alg=digest_alg) def create_artifact_response(self, request, artifact, bindings=None, - status=None, sign=False, issuer=None, sign_alg=None, digest_alg=None): + status=None, sign=False, issuer=None, + sign_alg=None, digest_alg=None): """ Create an ArtifactResponse :return: @@ -970,7 +989,8 @@ class Entity(HTTPBase): rinfo = self.response_args(request, bindings) response = self._status_response(ArtifactResponse, issuer, status, - sign=sign, sign_alg=sign_alg, digest_alg=digest_alg, **rinfo) + sign=sign, sign_alg=sign_alg, + digest_alg=digest_alg, **rinfo) msg = element_to_extension_element(self.artifact[artifact]) response.extension_elements = [msg] @@ -983,7 +1003,8 @@ class Entity(HTTPBase): consent=None, extensions=None, sign=False, name_id=None, new_id=None, encrypted_id=None, new_encrypted_id=None, - terminate=None, sign_alg=None, digest_alg=None): + terminate=None, sign_alg=None, + digest_alg=None): """ :param destination: @@ -1020,7 +1041,8 @@ class Entity(HTTPBase): "provided") return self._message(ManageNameIDRequest, destination, consent=consent, - extensions=extensions, sign=sign, sign_alg=sign_alg, digest_alg=digest_alg, **kwargs) + extensions=extensions, sign=sign, + sign_alg=sign_alg, digest_alg=digest_alg, **kwargs) def parse_manage_name_id_request(self, xmlstr, binding=BINDING_SOAP): """ Deal with a LogoutRequest @@ -1036,13 +1058,15 @@ class Entity(HTTPBase): "manage_name_id_service", binding) def create_manage_name_id_response(self, request, bindings=None, - status=None, sign=False, issuer=None, sign_alg=None, digest_alg=None, + status=None, sign=False, issuer=None, + sign_alg=None, digest_alg=None, **kwargs): rinfo = self.response_args(request, bindings) response = self._status_response(samlp.ManageNameIDResponse, issuer, - status, sign, sign_alg=sign_alg, digest_alg=digest_alg, **rinfo) + status, sign, sign_alg=sign_alg, + digest_alg=digest_alg, **rinfo) logger.info("Response: %s", response) diff --git a/src/saml2/entity_category/swamid.py b/src/saml2/entity_category/swamid.py index 94ebe62e..5dd717ec 100644 --- a/src/saml2/entity_category/swamid.py +++ b/src/saml2/entity_category/swamid.py @@ -2,7 +2,8 @@ __author__ = 'rolandh' NAME = ["givenName", "displayName", "sn"] -STATIC_ORG_INFO = ["c", "o", "co", "norEduOrgAcronym", "schacHomeOrganization"] +STATIC_ORG_INFO = ["c", "o", "co", "norEduOrgAcronym", "schacHomeOrganization", + 'schacHomeOrganizationType'] OTHER = ["eduPersonPrincipalName", "eduPersonScopedAffiliation", "mail"] # These give you access to information diff --git a/src/saml2/httpbase.py b/src/saml2/httpbase.py index 1dd49d19..b6e5fe78 100644 --- a/src/saml2/httpbase.py +++ b/src/saml2/httpbase.py @@ -386,7 +386,7 @@ class HTTPBase(object): @staticmethod def use_http_get(message, destination, relay_state, - typ="SAMLRequest", sigalg="", key=None, **kwargs): + typ="SAMLRequest", sigalg="", signer=None, **kwargs): """ Send a message using GET, this is the HTTP-Redirect case so no direct response is expected to this request. @@ -395,12 +395,13 @@ class HTTPBase(object): :param destination: :param relay_state: :param typ: Whether a Request, Response or Artifact - :param sigalg: The signature algorithm to use. - :param key: Key to use for signing + :param sigalg: Which algorithm the signature function will use to sign + the message + :param signer: A signing function that can be used to sign the message :return: dictionary """ if not isinstance(message, six.string_types): message = "%s" % (message,) return http_redirect_message(message, destination, relay_state, typ, - sigalg, key) + sigalg, signer) diff --git a/src/saml2/mdstore.py b/src/saml2/mdstore.py index c6bb1ecb..07f199e2 100644 --- a/src/saml2/mdstore.py +++ b/src/saml2/mdstore.py @@ -43,6 +43,9 @@ class ToOld(Exception): pass +class SourceNotFound(Exception): + pass + REQ2SRV = { # IDP "authn_request": "single_sign_on_service", @@ -718,7 +721,7 @@ class MetaDataExtern(InMemoryMetaData): return self.parse_and_check_signature(_txt) else: logger.info("Response status: %s", response.status_code) - return False + raise SourceNotFound(self.url) class MetaDataMD(InMemoryMetaData): diff --git a/src/saml2/pack.py b/src/saml2/pack.py index ed4142a0..ff7bd0ad 100644 --- a/src/saml2/pack.py +++ b/src/saml2/pack.py @@ -20,11 +20,13 @@ from saml2.sigver import REQ_ORDER from saml2.sigver import RESP_ORDER from saml2.sigver import SIGNER_ALGS import six +from saml2.xmldsig import SIG_ALLOWED_ALG logger = logging.getLogger(__name__) try: from xml.etree import cElementTree as ElementTree + if ElementTree.VERSION < '1.3.0': # cElementTree has no support for register_namespace # neither _namespace_map, thus we sacrify performance @@ -106,7 +108,7 @@ def http_post_message(message, relay_state="", typ="SAMLRequest", **kwargs): def http_redirect_message(message, location, relay_state="", typ="SAMLRequest", - sigalg=None, key=None, **kwargs): + sigalg='', signer=None, **kwargs): """The HTTP Redirect binding defines a mechanism by which SAML protocol messages can be transmitted within URL parameters. Messages are encoded for use with this binding using a URL encoding @@ -118,7 +120,9 @@ def http_redirect_message(message, location, relay_state="", typ="SAMLRequest", :param location: Where the message should be posted to :param relay_state: for preserving and conveying state information :param typ: What type of message it is SAMLRequest/SAMLResponse/SAMLart - :param sigalg: The signature algorithm to use. + :param sigalg: Which algorithm the signature function will use to sign + the message + :param signer: A signature function that can be used to sign the message :param key: Key to use for signing :return: A tuple containing header information and a HTML message. """ @@ -141,20 +145,15 @@ def http_redirect_message(message, location, relay_state="", typ="SAMLRequest", if relay_state: args["RelayState"] = relay_state - if sigalg: - # sigalgs, one of the ones defined in xmldsig - + if signer: + # sigalgs, should be one defined in xmldsig + assert sigalg in [b for a, b in SIG_ALLOWED_ALG] args["SigAlg"] = sigalg - try: - signer = SIGNER_ALGS[sigalg] - except: - raise Unsupported("Signing algorithm") - else: - string = "&".join([urlencode({k: args[k]}) - for k in _order if k in args]).encode('ascii') - args["Signature"] = base64.b64encode(signer.sign(string, key)) - string = urlencode(args) + string = "&".join([urlencode({k: args[k]}) + for k in _order if k in args]).encode('ascii') + args["Signature"] = base64.b64encode(signer.sign(string)) + string = urlencode(args) else: string = urlencode(args) @@ -240,11 +239,11 @@ def parse_soap_enveloped_saml(text, body_class, header_class=None): envelope = ElementTree.fromstring(text) assert envelope.tag == '{%s}Envelope' % NAMESPACE - #print(len(envelope)) + # print(len(envelope)) body = None header = {} for part in envelope: - #print(">",part.tag) + # print(">",part.tag) if part.tag == '{%s}Body' % NAMESPACE: for sub in part: try: @@ -255,11 +254,11 @@ def parse_soap_enveloped_saml(text, body_class, header_class=None): elif part.tag == '{%s}Header' % NAMESPACE: if not header_class: raise Exception("Header where I didn't expect one") - #print("--- HEADER ---") + # print("--- HEADER ---") for sub in part: - #print(">>",sub.tag) + # print(">>",sub.tag) for klass in header_class: - #print("?{%s}%s" % (klass.c_namespace,klass.c_tag)) + # print("?{%s}%s" % (klass.c_namespace,klass.c_tag)) if sub.tag == "{%s}%s" % (klass.c_namespace, klass.c_tag): header[sub.tag] = \ saml2.create_class_from_element_tree(klass, sub) @@ -267,6 +266,7 @@ def parse_soap_enveloped_saml(text, body_class, header_class=None): return body, header + # ----------------------------------------------------------------------------- PACKING = { diff --git a/src/saml2/server.py b/src/saml2/server.py index f62b32aa..b77b2c4d 100644 --- a/src/saml2/server.py +++ b/src/saml2/server.py @@ -338,8 +338,10 @@ class Server(Entity): status=None, authn=None, issuer=None, policy=None, sign_assertion=False, sign_response=False, best_effort=False, encrypt_assertion=False, - encrypt_cert_advice=None, encrypt_cert_assertion=None, authn_statement=None, - encrypt_assertion_self_contained=False, encrypted_advice_attributes=False, + encrypt_cert_advice=None, encrypt_cert_assertion=None, + authn_statement=None, + encrypt_assertion_self_contained=False, + encrypted_advice_attributes=False, pefim=False, sign_alg=None, digest_alg=None): """ Create a response. A layer of indirection. @@ -417,8 +419,10 @@ class Server(Entity): to_sign = [] if not encrypt_assertion: if sign_assertion: - assertion.signature = pre_signature_part(assertion.id, self.sec.my_cert, 1, - sign_alg=sign_alg, digest_alg=digest_alg) + assertion.signature = pre_signature_part(assertion.id, + self.sec.my_cert, 1, + sign_alg=sign_alg, + digest_alg=digest_alg) to_sign.append((class_name(assertion), assertion.id)) args["assertion"] = assertion @@ -434,7 +438,8 @@ class Server(Entity): encrypt_assertion_self_contained=encrypt_assertion_self_contained, encrypted_advice_attributes=encrypted_advice_attributes, sign_assertion=sign_assertion, - pefim=pefim, sign_alg=sign_alg, digest_alg=digest_alg, + pefim=pefim, sign_alg=sign_alg, + digest_alg=digest_alg, **args) # ------------------------------------------------------------------------ @@ -444,7 +449,8 @@ class Server(Entity): sp_entity_id, userid="", name_id=None, status=None, issuer=None, sign_assertion=False, sign_response=False, - attributes=None, sign_alg=None, digest_alg=None, **kwargs): + attributes=None, sign_alg=None, + digest_alg=None, **kwargs): """ Create an attribute assertion response. :param identity: A dictionary with attributes and values that are @@ -494,17 +500,115 @@ class Server(Entity): if sign_assertion: assertion.signature = pre_signature_part(assertion.id, - self.sec.my_cert, 1, sign_alg=sign_alg, digest_alg=digest_alg) + self.sec.my_cert, 1, + sign_alg=sign_alg, + digest_alg=digest_alg) # Just the assertion or the response and the assertion ? to_sign = [(class_name(assertion), assertion.id)] args["assertion"] = assertion return self._response(in_response_to, destination, status, issuer, - sign_response, to_sign, sign_alg=sign_alg, digest_alg=digest_alg, **args) + sign_response, to_sign, sign_alg=sign_alg, + digest_alg=digest_alg, **args) # ------------------------------------------------------------------------ + def gather_authn_response_args(self, sp_entity_id, name_id_policy, userid, + **kwargs): + param_default = { + 'sign_assertion': False, + 'sign_response': False, + 'encrypt_assertion': False, + 'encrypt_assertion_self_contained': True, + 'encrypted_advice_attributes': False, + 'encrypt_cert_advice': None, + 'encrypt_cert_assertion': None + } + + args = {} + + try: + args["policy"] = kwargs["release_policy"] + except KeyError: + args["policy"] = self.config.getattr("policy", "idp") + + try: + args['best_effort'] = kwargs["best_effort"] + except KeyError: + args['best_effort'] = False + + for param in ['sign_assertion', 'sign_response', 'encrypt_assertion', + 'encrypt_assertion_self_contained', + 'encrypted_advice_attributes', 'encrypt_cert_advice', + 'encrypt_cert_assertion']: + try: + _val = kwargs[param] + except: + _val = None + + if _val is None: + _val = self.config.getattr(param, "idp") + + if _val is None: + args[param] = param_default[param] + else: + args[param] = _val + + for arg, attr, eca, pefim in [ + ('encrypted_advice_attributes', 'verify_encrypt_cert_advice', + 'encrypt_cert_advice', kwargs["pefim"]), + ('encrypt_assertion', 'verify_encrypt_cert_assertion', + 'encrypt_cert_assertion', False)]: + + if args[arg] or pefim: + _enc_cert = self.config.getattr(attr, "idp") + + if _enc_cert is not None: + if kwargs[eca] is None: + raise CertificateError( + "No SPCertEncType certificate for encryption " + "contained in authentication " + "request.") + if not _enc_cert(kwargs[eca]): + raise CertificateError( + "Invalid certificate for encryption!") + + if 'name_id' not in kwargs or not kwargs['name_id']: + nid_formats = [] + for _sp in self.metadata[sp_entity_id]["spsso_descriptor"]: + if "name_id_format" in _sp: + nid_formats.extend([n["text"] for n in + _sp["name_id_format"]]) + try: + snq = name_id_policy.sp_name_qualifier + except AttributeError: + snq = sp_entity_id + + if not snq: + snq = sp_entity_id + + kwa = {"sp_name_qualifier": snq} + + try: + kwa["format"] = name_id_policy.format + except AttributeError: + pass + + _nids = self.ident.find_nameid(userid, **kwa) + # either none or one + if _nids: + args['name_id'] = _nids[0] + else: + args['name_id'] = self.ident.construct_nameid( + userid, args['policy'], sp_entity_id, name_id_policy) + logger.debug("construct_nameid: %s => %s", userid, + args['name_id']) + else: + args['name_id'] = kwargs['name_id'] + + return args + def create_authn_response(self, identity, in_response_to, destination, sp_entity_id, name_id_policy=None, userid=None, name_id=None, authn=None, issuer=None, @@ -513,7 +617,8 @@ class Server(Entity): encrypt_cert_assertion=None, encrypt_assertion=None, encrypt_assertion_self_contained=True, - encrypted_advice_attributes=False, pefim=False, sign_alg=None, digest_alg=None, + encrypted_advice_attributes=False, pefim=False, + sign_alg=None, digest_alg=None, **kwargs): """ Constructs an AuthenticationResponse @@ -547,105 +652,22 @@ class Server(Entity): """ try: - policy = kwargs["release_policy"] - except KeyError: - policy = self.config.getattr("policy", "idp") - - try: - best_effort = kwargs["best_effort"] - except KeyError: - best_effort = False - - if sign_assertion is None: - sign_assertion = self.config.getattr("sign_assertion", "idp") - if sign_assertion is None: - sign_assertion = False - - if sign_response is None: - sign_response = self.config.getattr("sign_response", "idp") - if sign_response is None: - sign_response = False - - if encrypt_assertion is None: - encrypt_assertion = self.config.getattr("encrypt_assertion", "idp") - if encrypt_assertion is None: - encrypt_assertion = False - - if encrypt_assertion_self_contained is None: - encrypt_assertion_self_contained = self.config.getattr( - "encrypt_assertion_self_contained", "idp") - if encrypt_assertion_self_contained is None: - encrypt_assertion_self_contained = True - - if encrypted_advice_attributes is None: - encrypted_advice_attributes = self.config.getattr( - "encrypted_advice_attributes", "idp") - if encrypted_advice_attributes is None: - encrypted_advice_attributes = False - - if encrypted_advice_attributes or pefim: - verify_encrypt_cert = self.config.getattr( - "verify_encrypt_cert_advice", "idp") - if verify_encrypt_cert is not None: - if encrypt_cert_advice is None: - raise CertificateError( - "No SPCertEncType certificate for encryption " - "contained in authentication " - "request.") - if not verify_encrypt_cert(encrypt_cert_advice): - raise CertificateError( - "Invalid certificate for encryption!") - - if encrypt_assertion: - verify_encrypt_cert = self.config.getattr( - "verify_encrypt_cert_assertion", "idp") - if verify_encrypt_cert is not None: - if encrypt_cert_assertion is None: - raise CertificateError( - "No SPCertEncType certificate for encryption " - "contained in authentication " - "request.") - if not verify_encrypt_cert(encrypt_cert_assertion): - raise CertificateError( - "Invalid certificate for encryption!") - - if not name_id: - try: - nid_formats = [] - for _sp in self.metadata[sp_entity_id]["spsso_descriptor"]: - if "name_id_format" in _sp: - nid_formats.extend([n["text"] for n in - _sp["name_id_format"]]) - try: - snq = name_id_policy.sp_name_qualifier - except AttributeError: - snq = sp_entity_id - - if not snq: - snq = sp_entity_id - - kwa = {"sp_name_qualifier": snq} - - try: - kwa["format"] = name_id_policy.format - except AttributeError: - pass - - _nids = self.ident.find_nameid(userid, **kwa) - # either none or one - if _nids: - name_id = _nids[0] - else: - name_id = self.ident.construct_nameid(userid, policy, - sp_entity_id, - name_id_policy) - logger.debug("construct_nameid: %s => %s", userid, name_id) - except IOError as exc: - response = self.create_error_response(in_response_to, - destination, - sp_entity_id, - exc, name_id) - return ("%s" % response).split("\n") + args = self.gather_authn_response_args( + sp_entity_id, name_id_policy=name_id_policy, userid=userid, + name_id=name_id, sign_response=sign_response, + sign_assertion=sign_assertion, + encrypt_cert_advice=encrypt_cert_advice, + encrypt_cert_assertion=encrypt_cert_assertion, + encrypt_assertion=encrypt_assertion, + encrypt_assertion_self_contained=encrypt_assertion_self_contained, + encrypted_advice_attributes=encrypted_advice_attributes, + pefim=pefim, **kwargs) + except IOError as exc: + response = self.create_error_response(in_response_to, + destination, + sp_entity_id, + exc, name_id) + return ("%s" % response).split("\n") try: _authn = authn @@ -654,44 +676,14 @@ class Server(Entity): self.sec.cert_handler.generate_cert(): with self.lock: self.sec.cert_handler.update_cert(True) - return self._authn_response(in_response_to, - # in_response_to - destination, # consumer_url - sp_entity_id, # sp_entity_id - identity, - # identity as dictionary - name_id, - authn=_authn, - issuer=issuer, - policy=policy, - sign_assertion=sign_assertion, - sign_response=sign_response, - best_effort=best_effort, - encrypt_assertion=encrypt_assertion, - encrypt_assertion_self_contained=encrypt_assertion_self_contained, - encrypted_advice_attributes=encrypted_advice_attributes, - encrypt_cert_advice=encrypt_cert_advice, - encrypt_cert_assertion=encrypt_cert_assertion, - pefim=pefim, - sign_alg=sign_alg, digest_alg=digest_alg) - return self._authn_response(in_response_to, # in_response_to - destination, # consumer_url - sp_entity_id, # sp_entity_id - identity, # identity as dictionary - name_id, - authn=_authn, - issuer=issuer, - policy=policy, - sign_assertion=sign_assertion, - sign_response=sign_response, - best_effort=best_effort, - encrypt_assertion=encrypt_assertion, - encrypt_assertion_self_contained=encrypt_assertion_self_contained, - encrypted_advice_attributes=encrypted_advice_attributes, - encrypt_cert_advice=encrypt_cert_advice, - encrypt_cert_assertion=encrypt_cert_assertion, - pefim=pefim, - sign_alg=sign_alg, digest_alg=digest_alg) + return self._authn_response( + in_response_to, destination, sp_entity_id, identity, + authn=_authn, issuer=issuer, pefim=pefim, + sign_alg=sign_alg, digest_alg=digest_alg, **args) + return self._authn_response( + in_response_to, destination, sp_entity_id, identity, + authn=_authn, issuer=issuer, pefim=pefim, sign_alg=sign_alg, + digest_alg=digest_alg, **args) except MissingValue as exc: return self.create_error_response(in_response_to, destination, @@ -710,8 +702,9 @@ class Server(Entity): sign_response, sign_assertion, authn_decl=authn_decl) - #noinspection PyUnusedLocal - def create_assertion_id_request_response(self, assertion_id, sign=False, sign_alg=None, + # noinspection PyUnusedLocal + def create_assertion_id_request_response(self, assertion_id, sign=False, + sign_alg=None, digest_alg=None, **kwargs): """ @@ -728,7 +721,9 @@ class Server(Entity): if to_sign: if assertion.signature is None: assertion.signature = pre_signature_part(assertion.id, - self.sec.my_cert, 1, sign_alg=sign_alg, digest_alg=digest_alg) + self.sec.my_cert, 1, + sign_alg=sign_alg, + digest_alg=digest_alg) return signed_instance_factory(assertion, self.sec, to_sign) else: @@ -738,7 +733,8 @@ class Server(Entity): def create_name_id_mapping_response(self, name_id=None, encrypted_id=None, in_response_to=None, issuer=None, sign_response=False, - status=None, sign_alg=None, digest_alg=None, **kwargs): + status=None, sign_alg=None, + digest_alg=None, **kwargs): """ protocol for mapping a principal's name identifier into a different name identifier for the same principal. @@ -768,7 +764,8 @@ class Server(Entity): def create_authn_query_response(self, subject, session_index=None, requested_context=None, in_response_to=None, issuer=None, sign_response=False, - status=None, sign_alg=None, digest_alg=None, **kwargs): + status=None, sign_alg=None, digest_alg=None, + **kwargs): """ A successful <Response> will contain one or more assertions containing authentication statements. @@ -789,7 +786,8 @@ class Server(Entity): args = {} return self._response(in_response_to, "", status, issuer, - sign_response, to_sign=[], sign_alg=sign_alg, digest_alg=digest_alg, **args) + sign_response, to_sign=[], sign_alg=sign_alg, + digest_alg=digest_alg, **args) # --------- diff --git a/src/saml2/sigver.py b/src/saml2/sigver.py index ec48f4ea..f89f347c 100644 --- a/src/saml2/sigver.py +++ b/src/saml2/sigver.py @@ -235,7 +235,7 @@ def _make_vals(val, klass, seccont, klass_inst=None, prop=None, part=False, """ cinst = None - #print("make_vals(%s, %s)" % (val, klass)) + # print("make_vals(%s, %s)" % (val, klass)) if isinstance(val, dict): cinst = _instance(klass, val, seccont, base64encode=base64encode, @@ -264,7 +264,7 @@ def _instance(klass, ava, seccont, base64encode=False, elements_to_sign=None): instance = klass() for prop in instance.c_attributes.values(): - #print("# %s" % (prop)) + # print("# %s" % (prop)) if prop in ava: if isinstance(ava[prop], bool): setattr(instance, prop, str(ava[prop]).encode('utf-8')) @@ -277,9 +277,9 @@ def _instance(klass, ava, seccont, base64encode=False, elements_to_sign=None): instance.set_text(ava["text"], base64encode) for prop, klassdef in instance.c_children.values(): - #print("## %s, %s" % (prop, klassdef)) + # print("## %s, %s" % (prop, klassdef)) if prop in ava: - #print("### %s" % ava[prop]) + # print("### %s" % ava[prop]) if isinstance(klassdef, list): # means there can be a list of values _make_vals(ava[prop], klassdef[0], seccont, instance, prop, @@ -319,9 +319,9 @@ def signed_instance_factory(instance, seccont, elements_to_sign=None): signed_xml = seccont.sign_statement( signed_xml, node_name=node_name, node_id=nodeid) - #print("xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx") - #print("%s" % signed_xml) - #print("xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx") + # print("xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx") + # print("%s" % signed_xml) + # print("xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx") return signed_xml else: return instance @@ -367,6 +367,7 @@ def make_temp(string, suffix="", decode=True, delete=True): def split_len(seq, length): return [seq[i:i + length] for i in range(0, len(seq), length)] + # -------------------------------------------------------------------------- M2_TIME_FORMAT = "%b %d %H:%M:%S %Y" @@ -396,15 +397,16 @@ def active_cert(key): assert not_after > utc_now() return True except: - cert = crypto.load_certificate(crypto.FILETYPE_PEM, cert_str) - assert cert.has_expired() == 0 - assert not OpenSSLWrapper().certificate_not_valid_yet(cert) - return True + cert = crypto.load_certificate(crypto.FILETYPE_PEM, cert_str) + assert cert.has_expired() == 0 + assert not OpenSSLWrapper().certificate_not_valid_yet(cert) + return True except AssertionError: return False except AttributeError: return False + def cert_from_key_info(key_info, ignore_age=False): """ Get all X509 certs from a KeyInfo instance. Care is taken to make sure that the certs are continues sequences of bytes. @@ -418,7 +420,7 @@ def cert_from_key_info(key_info, ignore_age=False): """ res = [] for x509_data in key_info.x509_data: - #print("X509Data",x509_data) + # print("X509Data",x509_data) x509_certificate = x509_data.x509_certificate cert = x509_certificate.text.strip() cert = "\n".join(split_len("".join([s.strip() for s in @@ -515,13 +517,13 @@ def key_from_key_value_dict(key_info): # ============================================================================= -#def rsa_load(filename): +# def rsa_load(filename): # """Read a PEM-encoded RSA key pair from a file.""" # return M2Crypto.RSA.load_key(filename, M2Crypto.util # .no_passphrase_callback) # # -#def rsa_loads(key): +# def rsa_loads(key): # """Read a PEM-encoded RSA key pair from a string.""" # return M2Crypto.RSA.load_key_string(key, # M2Crypto.util.no_passphrase_callback) @@ -582,6 +584,9 @@ def sha1_digest(msg): class Signer(object): """Abstract base class for signing algorithms.""" + def __init__(self, key): + self.key = key + def sign(self, msg, key): """Sign ``msg`` with ``key`` and return the signature.""" raise NotImplementedError @@ -592,15 +597,22 @@ class Signer(object): class RSASigner(Signer): - def __init__(self, digest): + def __init__(self, digest, key=None): + Signer.__init__(self, key) self.digest = digest - def sign(self, msg, key): + def sign(self, msg, key=None): + if key is None: + key = self.key + h = self.digest.new(msg) signer = PKCS1_v1_5.new(key) return signer.sign(h) - def verify(self, msg, sig, key): + def verify(self, msg, sig, key=None): + if key is None: + key = self.key + h = self.digest.new(msg) verifier = PKCS1_v1_5.new(key) return verifier.verify(h, sig) @@ -618,7 +630,25 @@ REQ_ORDER = ["SAMLRequest", "RelayState", "SigAlg"] RESP_ORDER = ["SAMLResponse", "RelayState", "SigAlg"] -def verify_redirect_signature(saml_msg, cert=None, sigkey=None): +class RSACrypto(object): + def __init__(self, key): + self.key = key + + def get_signer(self, sigalg, sigkey=None): + try: + signer = SIGNER_ALGS[sigalg] + except KeyError: + return None + else: + if sigkey: + signer.key = sigkey + else: + signer.key = self.key + + return signer + + +def verify_redirect_signature(saml_msg, crypto, cert=None, sigkey=None): """ :param saml_msg: A dictionary with strings as values, *NOT* lists as @@ -628,7 +658,7 @@ def verify_redirect_signature(saml_msg, cert=None, sigkey=None): """ try: - signer = SIGNER_ALGS[saml_msg["SigAlg"]] + signer = crypto.get_signer(saml_msg["SigAlg"], sigkey) except KeyError: raise Unsupported("Signature algorithm: %s" % saml_msg["SigAlg"]) else: @@ -646,10 +676,12 @@ def verify_redirect_signature(saml_msg, cert=None, sigkey=None): string = "&".join( [urlencode({k: _args[k]}) for k in _order if k in _args]).encode('ascii') + if cert: _key = extract_rsa_key_from_x509_cert(pem_format(cert)) else: _key = sigkey + _sign = base64.b64decode(saml_msg["Signature"]) return bool(signer.verify(string, _sign, _key)) @@ -665,6 +697,7 @@ def make_str(txt): else: return txt.decode("utf8") + # --------------------------------------------------------------------------- @@ -677,7 +710,6 @@ def read_cert_from_file(cert_file, cert_type): :return: A base64 encoded certificate as a string or the empty string """ - if not cert_file: return "" @@ -689,7 +721,7 @@ def read_cert_from_file(cert_file, cert_type): for pattern in ("-----BEGIN CERTIFICATE-----", "-----BEGIN PUBLIC KEY-----"): if pattern in lines: - lines = lines[lines.index(pattern)+1:] + lines = lines[lines.index(pattern) + 1:] break else: raise CertificateError("Strange beginning of PEM file") @@ -703,7 +735,6 @@ def read_cert_from_file(cert_file, cert_type): raise CertificateError("Strange end of PEM file") return make_str("".join(lines).encode("utf8")) - if cert_type in ["der", "cer", "crt"]: data = read_file(cert_file, 'rb') _cert = base64.b64encode(data) @@ -757,6 +788,11 @@ class CryptoBackendXmlSec1(CryptoBackend): else: self._xmlsec_delete_tmpfiles = True + try: + self.non_xml_crypto = RSACrypto(kwargs['rsa_key']) + except KeyError: + pass + def version(self): com_list = [self.xmlsec, "--version"] pof = Popen(com_list, stderr=PIPE, stdout=PIPE) @@ -808,7 +844,8 @@ class CryptoBackendXmlSec1(CryptoBackend): if isinstance(statement, SamlBase): statement = pre_encrypt_assertion(statement) - _, fil = make_temp(str(statement).encode('utf-8'), decode=False, delete=False) + _, fil = make_temp(str(statement).encode('utf-8'), decode=False, + delete=False) _, tmpl = make_temp(str(template).encode('utf-8'), decode=False) if not node_xpath: @@ -882,7 +919,7 @@ class CryptoBackendXmlSec1(CryptoBackend): return signed_statement.decode('utf-8') logger.error( "Signing operation failed :\nstdout : %s\nstderr : %s", - stdout, stderr) + stdout, stderr) raise SigverError(stderr) except DecryptError: raise SigverError("Signing failed") @@ -917,7 +954,7 @@ class CryptoBackendXmlSec1(CryptoBackend): if self.__DEBUG: try: - print( " ".join(com_list)) + print(" ".join(com_list)) except TypeError: print("cert_type", cert_type) print("cert_file", cert_file) @@ -1040,7 +1077,7 @@ class CryptoBackendXMLSecurity(CryptoBackend): def security_context(conf, debug=None): """ Creates a security context based on the configuration - :param conf: The configuration + :param conf: The configuration, this is a Config instance :return: A SecurityContext instance """ if not conf: @@ -1061,6 +1098,8 @@ def security_context(conf, debug=None): if _only_md is None: _only_md = False + sec_backend = None + if conf.crypto_backend == 'xmlsec1': xmlsec_binary = conf.xmlsec_binary if not xmlsec_binary: @@ -1071,10 +1110,20 @@ def security_context(conf, debug=None): xmlsec_binary = get_xmlsec_binary(_path) # verify that xmlsec is where it's supposed to be if not os.path.exists(xmlsec_binary): - #if not os.access(, os.F_OK): + # if not os.access(, os.F_OK): raise SigverError( "xmlsec binary not in '%s' !" % xmlsec_binary) crypto = _get_xmlsec_cryptobackend(xmlsec_binary, debug=debug) + _file_name = conf.getattr("key_file", "") + if _file_name: + try: + rsa_key = import_rsa_key_from_file(_file_name) + except Exception as err: + logger.error("Could not import key from {}: {}".format(_file_name, + err)) + raise + else: + sec_backend = RSACrypto(rsa_key) elif conf.crypto_backend == 'XMLSecurity': # new and somewhat untested pyXMLSecurity crypto backend. crypto = CryptoBackendXMLSecurity(debug=debug) @@ -1082,6 +1131,7 @@ def security_context(conf, debug=None): raise SigverError('Unknown crypto_backend %s' % ( repr(conf.crypto_backend))) + enc_key_files = [] if conf.encryption_keypairs is not None: for _encryption_keypair in conf.encryption_keypairs: @@ -1096,36 +1146,43 @@ def security_context(conf, debug=None): tmp_cert_file=conf.tmp_cert_file, tmp_key_file=conf.tmp_key_file, validate_certificate=conf.validate_certificate, - enc_key_files=enc_key_files, encryption_keypairs=conf.encryption_keypairs) + enc_key_files=enc_key_files, + encryption_keypairs=conf.encryption_keypairs, + sec_backend=sec_backend) def encrypt_cert_from_item(item): _encrypt_cert = None try: try: - _elem = extension_elements_to_elements(item.extensions.extension_elements,[pefim, ds]) + _elem = extension_elements_to_elements( + item.extensions.extension_elements, [pefim, ds]) except: - _elem = extension_elements_to_elements(item.extension_elements[0].children, - [pefim, ds]) + _elem = extension_elements_to_elements( + item.extension_elements[0].children, + [pefim, ds]) for _tmp_elem in _elem: if isinstance(_tmp_elem, SPCertEnc): for _tmp_key_info in _tmp_elem.key_info: - if _tmp_key_info.x509_data is not None and len(_tmp_key_info.x509_data) > 0: - _encrypt_cert = _tmp_key_info.x509_data[0].x509_certificate.text + if _tmp_key_info.x509_data is not None and len( + _tmp_key_info.x509_data) > 0: + _encrypt_cert = _tmp_key_info.x509_data[ + 0].x509_certificate.text break - #_encrypt_cert = _elem[0].x509_data[0].x509_certificate.text -# else: -# certs = cert_from_instance(item) -# if len(certs) > 0: -# _encrypt_cert = certs[0] + # _encrypt_cert = _elem[0].x509_data[ + # 0].x509_certificate.text + # else: + # certs = cert_from_instance(item) + # if len(certs) > 0: + # _encrypt_cert = certs[0] except Exception as _exception: pass -# if _encrypt_cert is None: -# certs = cert_from_instance(item) -# if len(certs) > 0: -# _encrypt_cert = certs[0] + # if _encrypt_cert is None: + # certs = cert_from_instance(item) + # if len(certs) > 0: + # _encrypt_cert = certs[0] if _encrypt_cert is not None: if _encrypt_cert.find("-----BEGIN CERTIFICATE-----\n") == -1: @@ -1135,7 +1192,6 @@ def encrypt_cert_from_item(item): return _encrypt_cert - class CertHandlerExtra(object): def __init__(self): pass @@ -1146,14 +1202,14 @@ class CertHandlerExtra(object): def generate_cert(self, generate_cert_info, root_cert_string, root_key_string): raise Exception("generate_cert function must be implemented") - #Excepts to return (cert_string, key_string) + # Excepts to return (cert_string, key_string) def use_validate_cert_func(self): raise Exception("use_validate_cert_func function must be implemented") def validate_cert(self, cert_str, root_cert_string, root_key_string): raise Exception("validate_cert function must be implemented") - #Excepts to return True/False + # Excepts to return True/False class CertHandler(object): @@ -1180,7 +1236,7 @@ class CertHandler(object): self._verify_cert = False self._generate_cert = False - #This cert do not have to be valid, it is just the last cert to be + # This cert do not have to be valid, it is just the last cert to be # validated. self._last_cert_verified = None self._last_validated_cert = None @@ -1206,7 +1262,7 @@ class CertHandler(object): self._cert_info = None self._generate_cert_func_active = False if generate_cert_info is not None and len(self._cert_str) > 0 and \ - len(self._key_str) > 0 and tmp_key_file is not \ + len(self._key_str) > 0 and tmp_key_file is not \ None and tmp_cert_file is not None: self._generate_cert = True self._cert_info = generate_cert_info @@ -1236,7 +1292,7 @@ class CertHandler(object): if (self._generate_cert and active) or client_crt is not None: if client_crt is not None: self._tmp_cert_str = client_crt - #No private key for signing + # No private key for signing self._tmp_key_str = "" elif self._cert_handler_extra_class is not None and \ self._cert_handler_extra_class.use_generate_cert_func(): @@ -1244,9 +1300,9 @@ class CertHandler(object): self._cert_handler_extra_class.generate_cert( self._cert_info, self._cert_str, self._key_str) else: - self._tmp_cert_str, self._tmp_key_str = self._osw\ + self._tmp_cert_str, self._tmp_key_str = self._osw \ .create_certificate( - self._cert_info, request=True) + self._cert_info, request=True) self._tmp_cert_str = self._osw.create_cert_signed_certificate( self._cert_str, self._key_str, self._tmp_cert_str) valid, mess = self._osw.verify(self._cert_str, @@ -1273,12 +1329,18 @@ class SecurityContext(object): debug=False, template="", encrypt_key_type="des-192", only_use_keys_in_metadata=False, cert_handler_extra_class=None, generate_cert_info=None, tmp_cert_file=None, - tmp_key_file=None, validate_certificate=None, enc_key_files=None, enc_key_type="pem", - encryption_keypairs=None, enc_cert_type="pem"): + tmp_key_file=None, validate_certificate=None, + enc_key_files=None, enc_key_type="pem", + encryption_keypairs=None, enc_cert_type="pem", + sec_backend=None): self.crypto = crypto assert (isinstance(self.crypto, CryptoBackend)) + if sec_backend: + assert (isinstance(sec_backend, RSACrypto)) + self.sec_backend = sec_backend + # Your private key for signing self.key_file = key_file self.key_type = key_type @@ -1355,7 +1417,8 @@ class SecurityContext(object): :param key_type: The type of session key to use. :return: The encrypted text """ - raise NotImplemented() + return self.crypto.encrypt_assertion(statement, enc_key, template, + key_type, node_xpath) def decrypt_keys(self, enctext, keys=None): """ Decrypting an encrypted text by the use of a private key. @@ -1394,9 +1457,9 @@ class SecurityContext(object): if _enctext is not None and len(_enctext) > 0: return _enctext if key_file is not None and len(key_file.strip()) > 0: - _enctext = self.crypto.decrypt(enctext, key_file) - if _enctext is not None and len(_enctext) > 0: - return _enctext + _enctext = self.crypto.decrypt(enctext, key_file) + if _enctext is not None and len(_enctext) > 0: + return _enctext return enctext def verify_signature(self, signedtext, cert_file=None, cert_type="pem", @@ -1428,7 +1491,7 @@ class SecurityContext(object): def _check_signature(self, decoded_xml, item, node_name=NODE_NAME, origdoc=None, id_attr="", must=False, only_valid_cert=False, issuer=None): - #print(item) + # print(item) try: _issuer = item.issuer.text.strip() except AttributeError: @@ -1459,16 +1522,17 @@ class SecurityContext(object): if not certs and not self.only_use_keys_in_metadata: logger.debug("==== Certs from instance ====") certs = [make_temp(pem_format(cert), suffix=".pem", - decode=False, delete=self._xmlsec_delete_tmpfiles) - for cert in cert_from_instance(item)] + decode=False, + delete=self._xmlsec_delete_tmpfiles) + for cert in cert_from_instance(item)] else: logger.debug("==== Certs from metadata ==== %s: %s ====", issuer, - certs) + certs) if not certs: raise MissingKey("%s" % issuer) - #print(certs) + # print(certs) verified = False last_pem_file = None @@ -1552,7 +1616,8 @@ class SecurityContext(object): if not msg.signature: if must: - raise SignatureError("Required signature missing on %s" % msgtype) + raise SignatureError( + "Required signature missing on %s" % msgtype) else: return msg @@ -1697,9 +1762,9 @@ class SecurityContext(object): return response - #-------------------------------------------------------------------------- + # -------------------------------------------------------------------------- # SIGNATURE PART - #-------------------------------------------------------------------------- + # -------------------------------------------------------------------------- def sign_statement_using_xmlsec(self, statement, **kwargs): """ Deprecated function. See sign_statement(). """ return self.sign_statement(statement, **kwargs) @@ -1760,7 +1825,8 @@ class SecurityContext(object): return self.sign_statement(statement, class_name( samlp.AttributeQuery()), **kwargs) - def multiple_signatures(self, statement, to_sign, key=None, key_file=None, sign_alg=None, digest_alg=None): + def multiple_signatures(self, statement, to_sign, key=None, key_file=None, + sign_alg=None, digest_alg=None): """ Sign multiple parts of a statement @@ -1779,7 +1845,9 @@ class SecurityContext(object): sid = item.id if not item.signature: - item.signature = pre_signature_part(sid, self.cert_file, sign_alg=sign_alg, digest_alg=digest_alg) + item.signature = pre_signature_part(sid, self.cert_file, + sign_alg=sign_alg, + digest_alg=digest_alg) statement = self.sign_statement(statement, class_name(item), key=key, key_file=key_file, @@ -1917,12 +1985,14 @@ def pre_encrypt_assertion(response): return response -def response_factory(sign=False, encrypt=False, sign_alg=None, digest_alg=None, **kwargs): +def response_factory(sign=False, encrypt=False, sign_alg=None, digest_alg=None, + **kwargs): response = samlp.Response(id=sid(), version=VERSION, issue_instant=instant()) if sign: - response.signature = pre_signature_part(kwargs["id"], sign_alg=sign_alg, digest_alg=digest_alg) + response.signature = pre_signature_part(kwargs["id"], sign_alg=sign_alg, + digest_alg=digest_alg) if encrypt: pass @@ -1931,6 +2001,7 @@ def response_factory(sign=False, encrypt=False, sign_alg=None, digest_alg=None, return response + # ---------------------------------------------------------------------------- if __name__ == '__main__': import argparse |