summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorRoland Hedberg <roland.hedberg@adm.umu.se>2015-12-11 13:02:49 +0100
committerRoland Hedberg <roland.hedberg@adm.umu.se>2015-12-11 13:02:49 +0100
commit6200f158dbad1acf9bf6982a738c58620452f813 (patch)
tree3fd0a53efa2cc70cae8b72289fa5cb7f39bdea7f /src
parent82d3b4da6ebd19f556d2f4d377236a05bb64cd75 (diff)
downloadpysaml2-6200f158dbad1acf9bf6982a738c58620452f813.tar.gz
Reworked the security backend so you should now be able to use a HSM again for XML security. Support for non-XML crypto using HSMs are on the way.
Diffstat (limited to 'src')
-rw-r--r--src/saml2/client.py27
-rw-r--r--src/saml2/config.py3
-rw-r--r--src/saml2/entity.py118
-rw-r--r--src/saml2/entity_category/swamid.py3
-rw-r--r--src/saml2/httpbase.py9
-rw-r--r--src/saml2/mdstore.py5
-rw-r--r--src/saml2/pack.py38
-rw-r--r--src/saml2/server.py302
-rw-r--r--src/saml2/sigver.py207
9 files changed, 408 insertions, 304 deletions
diff --git a/src/saml2/client.py b/src/saml2/client.py
index fca9859b..65c2bcad 100644
--- a/src/saml2/client.py
+++ b/src/saml2/client.py
@@ -66,7 +66,8 @@ class Saml2Client(Base):
:return: session id and AuthnRequest info
"""
- reqid, negotiated_binding, info = self.prepare_for_negotiated_authenticate(
+ reqid, negotiated_binding, info = \
+ self.prepare_for_negotiated_authenticate(
entityid=entityid,
relay_state=relay_state,
binding=binding,
@@ -137,7 +138,8 @@ class Saml2Client(Base):
raise SignOnError(
"No supported bindings available for authentication")
- def global_logout(self, name_id, reason="", expire=None, sign=None, sign_alg=None, digest_alg=None):
+ def global_logout(self, name_id, reason="", expire=None, sign=None,
+ sign_alg=None, digest_alg=None):
""" More or less a layer of indirection :-/
Bootstrapping the whole thing by finding all the IdPs that should
be notified.
@@ -162,10 +164,12 @@ class Saml2Client(Base):
# find out which IdPs/AAs I should notify
entity_ids = self.users.issuers_of_info(name_id)
- return self.do_logout(name_id, entity_ids, reason, expire, sign, sign_alg=sign_alg, digest_alg=digest_alg)
+ return self.do_logout(name_id, entity_ids, reason, expire, sign,
+ sign_alg=sign_alg, digest_alg=digest_alg)
def do_logout(self, name_id, entity_ids, reason, expire, sign=None,
- expected_binding=None, sign_alg=None, digest_alg=None, **kwargs):
+ expected_binding=None, sign_alg=None, digest_alg=None,
+ **kwargs):
"""
:param name_id: Identifier of the Subject (a NameID instance)
@@ -227,22 +231,22 @@ class Saml2Client(Base):
sign = self.logout_requests_signed
sigalg = None
- key = None
if sign:
if binding == BINDING_HTTP_REDIRECT:
- sigalg = kwargs.get("sigalg", ds.DefaultSignature().get_sign_alg())
- key = kwargs.get("key", self.signkey)
+ sigalg = kwargs.get(
+ "sigalg", ds.DefaultSignature().get_sign_alg())
+ #key = kwargs.get("key", self.signkey)
srequest = str(request)
else:
- srequest = self.sign(request, sign_alg=sign_alg, digest_alg=digest_alg)
+ srequest = self.sign(request, sign_alg=sign_alg,
+ digest_alg=digest_alg)
else:
srequest = str(request)
relay_state = self._relay_state(req_id)
http_info = self.apply_binding(binding, srequest, destination,
- relay_state, sigalg=sigalg,
- key=key)
+ relay_state, sigalg=sigalg)
if binding == BINDING_SOAP:
response = self.send(**http_info)
@@ -318,7 +322,8 @@ class Saml2Client(Base):
return self.do_logout(decode(status["name_id"]),
status["entity_ids"],
status["reason"], status["not_on_or_after"],
- status["sign"], sign_alg=sign_alg, digest_alg=digest_alg)
+ status["sign"], sign_alg=sign_alg,
+ digest_alg=digest_alg)
def _use_soap(self, destination, query_type, **kwargs):
_create_func = getattr(self, "create_%s" % query_type)
diff --git a/src/saml2/config.py b/src/saml2/config.py
index 952526f9..24e891f4 100644
--- a/src/saml2/config.py
+++ b/src/saml2/config.py
@@ -53,7 +53,8 @@ COMMON_ARGS = [
"tmp_key_file",
"validate_certificate",
"extensions",
- "allow_unknown_attributes"
+ "allow_unknown_attributes",
+ "crypto_backend"
]
SP_ARGS = [
diff --git a/src/saml2/entity.py b/src/saml2/entity.py
index b5fcb3d2..174d2e83 100644
--- a/src/saml2/entity.py
+++ b/src/saml2/entity.py
@@ -1,5 +1,6 @@
import base64
# from binascii import hexlify
+from binascii import hexlify
import copy
import logging
from hashlib import sha1
@@ -148,12 +149,6 @@ class Entity(HTTPBase):
raise Exception(
"Could not fetch certificate from %s" % _val)
- try:
- self.signkey = RSA.importKey(
- open(self.config.getattr("key_file", ""), 'r').read())
- except (KeyError, TypeError):
- self.signkey = None
-
HTTPBase.__init__(self, self.config.verify_ssl_cert,
self.config.ca_certs, self.config.key_file,
self.config.cert_file)
@@ -227,8 +222,12 @@ class Entity(HTTPBase):
info["method"] = "POST"
elif binding == BINDING_HTTP_REDIRECT:
logger.info("HTTP REDIRECT")
+ if 'sigalg' in kwargs:
+ signer = self.sec.sec_backend.get_signer(kwargs['sigalg'])
+ else:
+ signer = None
info = self.use_http_get(msg_str, destination, relay_state, typ,
- **kwargs)
+ signer=signer, **kwargs)
info["url"] = str(destination)
info["method"] = "GET"
elif binding == BINDING_SOAP or binding == BINDING_PAOS:
@@ -419,9 +418,12 @@ class Entity(HTTPBase):
# --------------------------------------------------------------------------
- def sign(self, msg, mid=None, to_sign=None, sign_prepare=False, sign_alg=None, digest_alg=None):
+ def sign(self, msg, mid=None, to_sign=None, sign_prepare=False,
+ sign_alg=None, digest_alg=None):
if msg.signature is None:
- msg.signature = pre_signature_part(msg.id, self.sec.my_cert, 1, sign_alg=sign_alg, digest_alg=digest_alg)
+ msg.signature = pre_signature_part(msg.id, self.sec.my_cert, 1,
+ sign_alg=sign_alg,
+ digest_alg=digest_alg)
if sign_prepare:
return msg
@@ -478,7 +480,8 @@ class Entity(HTTPBase):
req.register_prefix(nsprefix)
if sign:
- return reqid, self.sign(req, sign_prepare=sign_prepare, sign_alg=sign_alg, digest_alg=digest_alg)
+ return reqid, self.sign(req, sign_prepare=sign_prepare,
+ sign_alg=sign_alg, digest_alg=digest_alg)
else:
logger.info("REQUEST: %s", req)
return reqid, req
@@ -542,7 +545,7 @@ class Entity(HTTPBase):
:return: A new samlp.Resonse with the designated assertion encrypted.
"""
_certs = []
- cbxs = CryptoBackendXmlSec1(self.config.xmlsec_binary)
+
if encrypt_cert:
_certs = []
_certs.append(encrypt_cert)
@@ -558,9 +561,9 @@ class Entity(HTTPBase):
if end_cert not in _cert:
_cert = "%s%s" % (_cert, end_cert)
_, cert_file = make_temp(_cert.encode('ascii'), decode=False)
- response = cbxs.encrypt_assertion(response, cert_file,
- pre_encryption_part(),
- node_xpath=node_xpath)
+ response = self.sec.encrypt_assertion(response, cert_file,
+ pre_encryption_part(),
+ node_xpath=node_xpath)
return response
except Exception as ex:
exception = ex
@@ -575,7 +578,8 @@ class Entity(HTTPBase):
encrypt_assertion_self_contained=False,
encrypted_advice_attributes=False,
encrypt_cert_advice=None, encrypt_cert_assertion=None,
- sign_assertion=None, pefim=False, sign_alg=None, digest_alg=None, **kwargs):
+ sign_assertion=None, pefim=False, sign_alg=None,
+ digest_alg=None, **kwargs):
""" Create a Response.
Encryption:
encrypt_assertion must be true for encryption to be
@@ -619,7 +623,8 @@ class Entity(HTTPBase):
response = response_factory(issuer=_issuer,
in_response_to=in_response_to,
- status=status, sign_alg=sign_alg, digest_alg=digest_alg)
+ status=status, sign_alg=sign_alg,
+ digest_alg=digest_alg)
if consumer_url:
response.destination = consumer_url
@@ -636,15 +641,19 @@ class Entity(HTTPBase):
encrypt_assertion = False
if encrypt_assertion or (
- encrypted_advice_attributes and response.assertion.advice is
+ encrypted_advice_attributes and
+ response.assertion.advice is
not None and
- len(response.assertion.advice.assertion) == 1):
+ len(response.assertion.advice.assertion) == 1):
if sign:
response.signature = pre_signature_part(response.id,
- self.sec.my_cert, 1, sign_alg=sign_alg, digest_alg=digest_alg)
+ self.sec.my_cert, 1,
+ sign_alg=sign_alg,
+ digest_alg=digest_alg)
sign_class = [(class_name(response), response.id)]
- cbxs = CryptoBackendXmlSec1(self.config.xmlsec_binary)
- encrypt_advice = False
+ else:
+ sign_class = []
+
if encrypted_advice_attributes and response.assertion.advice is \
not None \
and len(response.assertion.advice.assertion) > 0:
@@ -664,7 +673,8 @@ class Entity(HTTPBase):
to_sign_advice = []
if sign_assertion and not pefim:
tmp_assertion.signature = pre_signature_part(
- tmp_assertion.id, self.sec.my_cert, 1, sign_alg=sign_alg, digest_alg=digest_alg)
+ tmp_assertion.id, self.sec.my_cert, 1,
+ sign_alg=sign_alg, digest_alg=digest_alg)
to_sign_advice.append(
(class_name(tmp_assertion), tmp_assertion.id))
@@ -687,10 +697,9 @@ class Entity(HTTPBase):
response = signed_instance_factory(response,
self.sec,
to_sign_advice)
- response = self._encrypt_assertion(encrypt_cert_advice,
- sp_entity_id,
- response,
- node_xpath=node_xpath)
+ response = self._encrypt_assertion(
+ encrypt_cert_advice, sp_entity_id, response,
+ node_xpath=node_xpath)
response = response_from_string(response)
if encrypt_assertion:
@@ -700,10 +709,9 @@ class Entity(HTTPBase):
if not isinstance(_assertions, list):
_assertions = [_assertions]
for _assertion in _assertions:
- _assertion.signature = pre_signature_part(_assertion.id,
- self.sec.my_cert,
- 1,
- sign_alg=sign_alg, digest_alg=digest_alg)
+ _assertion.signature = pre_signature_part(
+ _assertion.id, self.sec.my_cert, 1,
+ sign_alg=sign_alg, digest_alg=digest_alg)
to_sign_assertion.append(
(class_name(_assertion), _assertion.id))
if encrypt_assertion_self_contained:
@@ -717,7 +725,7 @@ class Entity(HTTPBase):
response = pre_encrypt_assertion(response)
response = \
response.get_xml_string_with_self_contained_assertion_within_encrypted_assertion(
- assertion_tag)
+ assertion_tag)
else:
response = pre_encrypt_assertion(response)
if to_sign_assertion:
@@ -735,11 +743,13 @@ class Entity(HTTPBase):
return response
if sign:
- return self.sign(response, to_sign=to_sign, sign_alg=sign_alg, digest_alg=digest_alg)
+ return self.sign(response, to_sign=to_sign, sign_alg=sign_alg,
+ digest_alg=digest_alg)
else:
return response
- def _status_response(self, response_class, issuer, status, sign=False, sign_alg=None, digest_alg=None,
+ def _status_response(self, response_class, issuer, status, sign=False,
+ sign_alg=None, digest_alg=None,
**kwargs):
""" Create a StatusResponse.
@@ -768,7 +778,8 @@ class Entity(HTTPBase):
status=status, **kwargs)
if sign:
- return self.sign(response, mid, sign_alg=sign_alg, digest_alg=digest_alg)
+ return self.sign(response, mid, sign_alg=sign_alg,
+ digest_alg=digest_alg)
else:
return response
@@ -847,7 +858,8 @@ class Entity(HTTPBase):
# ------------------------------------------------------------------------
def create_error_response(self, in_response_to, destination, info,
- sign=False, issuer=None, sign_alg=None, digest_alg=None, **kwargs):
+ sign=False, issuer=None, sign_alg=None,
+ digest_alg=None, **kwargs):
""" Create a error response.
:param in_response_to: The identifier of the message this is a response
@@ -871,7 +883,8 @@ class Entity(HTTPBase):
subject_id=None, name_id=None,
reason=None, expire=None, message_id=0,
consent=None, extensions=None, sign=False,
- session_indexes=None, sign_alg=None, digest_alg=None):
+ session_indexes=None, sign_alg=None,
+ digest_alg=None):
""" Constructs a LogoutRequest
:param destination: Destination of the request
@@ -915,10 +928,12 @@ class Entity(HTTPBase):
return self._message(LogoutRequest, destination, message_id,
consent, extensions, sign, name_id=name_id,
reason=reason, not_on_or_after=expire,
- issuer=self._issuer(), sign_alg=sign_alg, digest_alg=digest_alg, **args)
+ issuer=self._issuer(), sign_alg=sign_alg,
+ digest_alg=digest_alg, **args)
def create_logout_response(self, request, bindings=None, status=None,
- sign=False, issuer=None, sign_alg=None, digest_alg=None):
+ sign=False, issuer=None, sign_alg=None,
+ digest_alg=None):
""" Create a LogoutResponse.
:param request: The request this is a response to
@@ -936,14 +951,16 @@ class Entity(HTTPBase):
issuer = self._issuer()
response = self._status_response(samlp.LogoutResponse, issuer, status,
- sign, sign_alg=sign_alg, digest_alg=digest_alg, **rinfo)
+ sign, sign_alg=sign_alg,
+ digest_alg=digest_alg, **rinfo)
logger.info("Response: %s", response)
return response
def create_artifact_resolve(self, artifact, destination, sessid,
- consent=None, extensions=None, sign=False, sign_alg=None, digest_alg=None):
+ consent=None, extensions=None, sign=False,
+ sign_alg=None, digest_alg=None):
"""
Create a ArtifactResolve request
@@ -959,10 +976,12 @@ class Entity(HTTPBase):
artifact = Artifact(text=artifact)
return self._message(ArtifactResolve, destination, sessid,
- consent, extensions, sign, artifact=artifact, sign_alg=sign_alg, digest_alg=digest_alg)
+ consent, extensions, sign, artifact=artifact,
+ sign_alg=sign_alg, digest_alg=digest_alg)
def create_artifact_response(self, request, artifact, bindings=None,
- status=None, sign=False, issuer=None, sign_alg=None, digest_alg=None):
+ status=None, sign=False, issuer=None,
+ sign_alg=None, digest_alg=None):
"""
Create an ArtifactResponse
:return:
@@ -970,7 +989,8 @@ class Entity(HTTPBase):
rinfo = self.response_args(request, bindings)
response = self._status_response(ArtifactResponse, issuer, status,
- sign=sign, sign_alg=sign_alg, digest_alg=digest_alg, **rinfo)
+ sign=sign, sign_alg=sign_alg,
+ digest_alg=digest_alg, **rinfo)
msg = element_to_extension_element(self.artifact[artifact])
response.extension_elements = [msg]
@@ -983,7 +1003,8 @@ class Entity(HTTPBase):
consent=None, extensions=None, sign=False,
name_id=None, new_id=None,
encrypted_id=None, new_encrypted_id=None,
- terminate=None, sign_alg=None, digest_alg=None):
+ terminate=None, sign_alg=None,
+ digest_alg=None):
"""
:param destination:
@@ -1020,7 +1041,8 @@ class Entity(HTTPBase):
"provided")
return self._message(ManageNameIDRequest, destination, consent=consent,
- extensions=extensions, sign=sign, sign_alg=sign_alg, digest_alg=digest_alg, **kwargs)
+ extensions=extensions, sign=sign,
+ sign_alg=sign_alg, digest_alg=digest_alg, **kwargs)
def parse_manage_name_id_request(self, xmlstr, binding=BINDING_SOAP):
""" Deal with a LogoutRequest
@@ -1036,13 +1058,15 @@ class Entity(HTTPBase):
"manage_name_id_service", binding)
def create_manage_name_id_response(self, request, bindings=None,
- status=None, sign=False, issuer=None, sign_alg=None, digest_alg=None,
+ status=None, sign=False, issuer=None,
+ sign_alg=None, digest_alg=None,
**kwargs):
rinfo = self.response_args(request, bindings)
response = self._status_response(samlp.ManageNameIDResponse, issuer,
- status, sign, sign_alg=sign_alg, digest_alg=digest_alg, **rinfo)
+ status, sign, sign_alg=sign_alg,
+ digest_alg=digest_alg, **rinfo)
logger.info("Response: %s", response)
diff --git a/src/saml2/entity_category/swamid.py b/src/saml2/entity_category/swamid.py
index 94ebe62e..5dd717ec 100644
--- a/src/saml2/entity_category/swamid.py
+++ b/src/saml2/entity_category/swamid.py
@@ -2,7 +2,8 @@ __author__ = 'rolandh'
NAME = ["givenName", "displayName", "sn"]
-STATIC_ORG_INFO = ["c", "o", "co", "norEduOrgAcronym", "schacHomeOrganization"]
+STATIC_ORG_INFO = ["c", "o", "co", "norEduOrgAcronym", "schacHomeOrganization",
+ 'schacHomeOrganizationType']
OTHER = ["eduPersonPrincipalName", "eduPersonScopedAffiliation", "mail"]
# These give you access to information
diff --git a/src/saml2/httpbase.py b/src/saml2/httpbase.py
index 1dd49d19..b6e5fe78 100644
--- a/src/saml2/httpbase.py
+++ b/src/saml2/httpbase.py
@@ -386,7 +386,7 @@ class HTTPBase(object):
@staticmethod
def use_http_get(message, destination, relay_state,
- typ="SAMLRequest", sigalg="", key=None, **kwargs):
+ typ="SAMLRequest", sigalg="", signer=None, **kwargs):
"""
Send a message using GET, this is the HTTP-Redirect case so
no direct response is expected to this request.
@@ -395,12 +395,13 @@ class HTTPBase(object):
:param destination:
:param relay_state:
:param typ: Whether a Request, Response or Artifact
- :param sigalg: The signature algorithm to use.
- :param key: Key to use for signing
+ :param sigalg: Which algorithm the signature function will use to sign
+ the message
+ :param signer: A signing function that can be used to sign the message
:return: dictionary
"""
if not isinstance(message, six.string_types):
message = "%s" % (message,)
return http_redirect_message(message, destination, relay_state, typ,
- sigalg, key)
+ sigalg, signer)
diff --git a/src/saml2/mdstore.py b/src/saml2/mdstore.py
index c6bb1ecb..07f199e2 100644
--- a/src/saml2/mdstore.py
+++ b/src/saml2/mdstore.py
@@ -43,6 +43,9 @@ class ToOld(Exception):
pass
+class SourceNotFound(Exception):
+ pass
+
REQ2SRV = {
# IDP
"authn_request": "single_sign_on_service",
@@ -718,7 +721,7 @@ class MetaDataExtern(InMemoryMetaData):
return self.parse_and_check_signature(_txt)
else:
logger.info("Response status: %s", response.status_code)
- return False
+ raise SourceNotFound(self.url)
class MetaDataMD(InMemoryMetaData):
diff --git a/src/saml2/pack.py b/src/saml2/pack.py
index ed4142a0..ff7bd0ad 100644
--- a/src/saml2/pack.py
+++ b/src/saml2/pack.py
@@ -20,11 +20,13 @@ from saml2.sigver import REQ_ORDER
from saml2.sigver import RESP_ORDER
from saml2.sigver import SIGNER_ALGS
import six
+from saml2.xmldsig import SIG_ALLOWED_ALG
logger = logging.getLogger(__name__)
try:
from xml.etree import cElementTree as ElementTree
+
if ElementTree.VERSION < '1.3.0':
# cElementTree has no support for register_namespace
# neither _namespace_map, thus we sacrify performance
@@ -106,7 +108,7 @@ def http_post_message(message, relay_state="", typ="SAMLRequest", **kwargs):
def http_redirect_message(message, location, relay_state="", typ="SAMLRequest",
- sigalg=None, key=None, **kwargs):
+ sigalg='', signer=None, **kwargs):
"""The HTTP Redirect binding defines a mechanism by which SAML protocol
messages can be transmitted within URL parameters.
Messages are encoded for use with this binding using a URL encoding
@@ -118,7 +120,9 @@ def http_redirect_message(message, location, relay_state="", typ="SAMLRequest",
:param location: Where the message should be posted to
:param relay_state: for preserving and conveying state information
:param typ: What type of message it is SAMLRequest/SAMLResponse/SAMLart
- :param sigalg: The signature algorithm to use.
+ :param sigalg: Which algorithm the signature function will use to sign
+ the message
+ :param signer: A signature function that can be used to sign the message
:param key: Key to use for signing
:return: A tuple containing header information and a HTML message.
"""
@@ -141,20 +145,15 @@ def http_redirect_message(message, location, relay_state="", typ="SAMLRequest",
if relay_state:
args["RelayState"] = relay_state
- if sigalg:
- # sigalgs, one of the ones defined in xmldsig
-
+ if signer:
+ # sigalgs, should be one defined in xmldsig
+ assert sigalg in [b for a, b in SIG_ALLOWED_ALG]
args["SigAlg"] = sigalg
- try:
- signer = SIGNER_ALGS[sigalg]
- except:
- raise Unsupported("Signing algorithm")
- else:
- string = "&".join([urlencode({k: args[k]})
- for k in _order if k in args]).encode('ascii')
- args["Signature"] = base64.b64encode(signer.sign(string, key))
- string = urlencode(args)
+ string = "&".join([urlencode({k: args[k]})
+ for k in _order if k in args]).encode('ascii')
+ args["Signature"] = base64.b64encode(signer.sign(string))
+ string = urlencode(args)
else:
string = urlencode(args)
@@ -240,11 +239,11 @@ def parse_soap_enveloped_saml(text, body_class, header_class=None):
envelope = ElementTree.fromstring(text)
assert envelope.tag == '{%s}Envelope' % NAMESPACE
- #print(len(envelope))
+ # print(len(envelope))
body = None
header = {}
for part in envelope:
- #print(">",part.tag)
+ # print(">",part.tag)
if part.tag == '{%s}Body' % NAMESPACE:
for sub in part:
try:
@@ -255,11 +254,11 @@ def parse_soap_enveloped_saml(text, body_class, header_class=None):
elif part.tag == '{%s}Header' % NAMESPACE:
if not header_class:
raise Exception("Header where I didn't expect one")
- #print("--- HEADER ---")
+ # print("--- HEADER ---")
for sub in part:
- #print(">>",sub.tag)
+ # print(">>",sub.tag)
for klass in header_class:
- #print("?{%s}%s" % (klass.c_namespace,klass.c_tag))
+ # print("?{%s}%s" % (klass.c_namespace,klass.c_tag))
if sub.tag == "{%s}%s" % (klass.c_namespace, klass.c_tag):
header[sub.tag] = \
saml2.create_class_from_element_tree(klass, sub)
@@ -267,6 +266,7 @@ def parse_soap_enveloped_saml(text, body_class, header_class=None):
return body, header
+
# -----------------------------------------------------------------------------
PACKING = {
diff --git a/src/saml2/server.py b/src/saml2/server.py
index f62b32aa..b77b2c4d 100644
--- a/src/saml2/server.py
+++ b/src/saml2/server.py
@@ -338,8 +338,10 @@ class Server(Entity):
status=None, authn=None, issuer=None, policy=None,
sign_assertion=False, sign_response=False,
best_effort=False, encrypt_assertion=False,
- encrypt_cert_advice=None, encrypt_cert_assertion=None, authn_statement=None,
- encrypt_assertion_self_contained=False, encrypted_advice_attributes=False,
+ encrypt_cert_advice=None, encrypt_cert_assertion=None,
+ authn_statement=None,
+ encrypt_assertion_self_contained=False,
+ encrypted_advice_attributes=False,
pefim=False, sign_alg=None, digest_alg=None):
""" Create a response. A layer of indirection.
@@ -417,8 +419,10 @@ class Server(Entity):
to_sign = []
if not encrypt_assertion:
if sign_assertion:
- assertion.signature = pre_signature_part(assertion.id, self.sec.my_cert, 1,
- sign_alg=sign_alg, digest_alg=digest_alg)
+ assertion.signature = pre_signature_part(assertion.id,
+ self.sec.my_cert, 1,
+ sign_alg=sign_alg,
+ digest_alg=digest_alg)
to_sign.append((class_name(assertion), assertion.id))
args["assertion"] = assertion
@@ -434,7 +438,8 @@ class Server(Entity):
encrypt_assertion_self_contained=encrypt_assertion_self_contained,
encrypted_advice_attributes=encrypted_advice_attributes,
sign_assertion=sign_assertion,
- pefim=pefim, sign_alg=sign_alg, digest_alg=digest_alg,
+ pefim=pefim, sign_alg=sign_alg,
+ digest_alg=digest_alg,
**args)
# ------------------------------------------------------------------------
@@ -444,7 +449,8 @@ class Server(Entity):
sp_entity_id, userid="", name_id=None,
status=None, issuer=None,
sign_assertion=False, sign_response=False,
- attributes=None, sign_alg=None, digest_alg=None, **kwargs):
+ attributes=None, sign_alg=None,
+ digest_alg=None, **kwargs):
""" Create an attribute assertion response.
:param identity: A dictionary with attributes and values that are
@@ -494,17 +500,115 @@ class Server(Entity):
if sign_assertion:
assertion.signature = pre_signature_part(assertion.id,
- self.sec.my_cert, 1, sign_alg=sign_alg, digest_alg=digest_alg)
+ self.sec.my_cert, 1,
+ sign_alg=sign_alg,
+ digest_alg=digest_alg)
# Just the assertion or the response and the assertion ?
to_sign = [(class_name(assertion), assertion.id)]
args["assertion"] = assertion
return self._response(in_response_to, destination, status, issuer,
- sign_response, to_sign, sign_alg=sign_alg, digest_alg=digest_alg, **args)
+ sign_response, to_sign, sign_alg=sign_alg,
+ digest_alg=digest_alg, **args)
# ------------------------------------------------------------------------
+ def gather_authn_response_args(self, sp_entity_id, name_id_policy, userid,
+ **kwargs):
+ param_default = {
+ 'sign_assertion': False,
+ 'sign_response': False,
+ 'encrypt_assertion': False,
+ 'encrypt_assertion_self_contained': True,
+ 'encrypted_advice_attributes': False,
+ 'encrypt_cert_advice': None,
+ 'encrypt_cert_assertion': None
+ }
+
+ args = {}
+
+ try:
+ args["policy"] = kwargs["release_policy"]
+ except KeyError:
+ args["policy"] = self.config.getattr("policy", "idp")
+
+ try:
+ args['best_effort'] = kwargs["best_effort"]
+ except KeyError:
+ args['best_effort'] = False
+
+ for param in ['sign_assertion', 'sign_response', 'encrypt_assertion',
+ 'encrypt_assertion_self_contained',
+ 'encrypted_advice_attributes', 'encrypt_cert_advice',
+ 'encrypt_cert_assertion']:
+ try:
+ _val = kwargs[param]
+ except:
+ _val = None
+
+ if _val is None:
+ _val = self.config.getattr(param, "idp")
+
+ if _val is None:
+ args[param] = param_default[param]
+ else:
+ args[param] = _val
+
+ for arg, attr, eca, pefim in [
+ ('encrypted_advice_attributes', 'verify_encrypt_cert_advice',
+ 'encrypt_cert_advice', kwargs["pefim"]),
+ ('encrypt_assertion', 'verify_encrypt_cert_assertion',
+ 'encrypt_cert_assertion', False)]:
+
+ if args[arg] or pefim:
+ _enc_cert = self.config.getattr(attr, "idp")
+
+ if _enc_cert is not None:
+ if kwargs[eca] is None:
+ raise CertificateError(
+ "No SPCertEncType certificate for encryption "
+ "contained in authentication "
+ "request.")
+ if not _enc_cert(kwargs[eca]):
+ raise CertificateError(
+ "Invalid certificate for encryption!")
+
+ if 'name_id' not in kwargs or not kwargs['name_id']:
+ nid_formats = []
+ for _sp in self.metadata[sp_entity_id]["spsso_descriptor"]:
+ if "name_id_format" in _sp:
+ nid_formats.extend([n["text"] for n in
+ _sp["name_id_format"]])
+ try:
+ snq = name_id_policy.sp_name_qualifier
+ except AttributeError:
+ snq = sp_entity_id
+
+ if not snq:
+ snq = sp_entity_id
+
+ kwa = {"sp_name_qualifier": snq}
+
+ try:
+ kwa["format"] = name_id_policy.format
+ except AttributeError:
+ pass
+
+ _nids = self.ident.find_nameid(userid, **kwa)
+ # either none or one
+ if _nids:
+ args['name_id'] = _nids[0]
+ else:
+ args['name_id'] = self.ident.construct_nameid(
+ userid, args['policy'], sp_entity_id, name_id_policy)
+ logger.debug("construct_nameid: %s => %s", userid,
+ args['name_id'])
+ else:
+ args['name_id'] = kwargs['name_id']
+
+ return args
+
def create_authn_response(self, identity, in_response_to, destination,
sp_entity_id, name_id_policy=None, userid=None,
name_id=None, authn=None, issuer=None,
@@ -513,7 +617,8 @@ class Server(Entity):
encrypt_cert_assertion=None,
encrypt_assertion=None,
encrypt_assertion_self_contained=True,
- encrypted_advice_attributes=False, pefim=False, sign_alg=None, digest_alg=None,
+ encrypted_advice_attributes=False, pefim=False,
+ sign_alg=None, digest_alg=None,
**kwargs):
""" Constructs an AuthenticationResponse
@@ -547,105 +652,22 @@ class Server(Entity):
"""
try:
- policy = kwargs["release_policy"]
- except KeyError:
- policy = self.config.getattr("policy", "idp")
-
- try:
- best_effort = kwargs["best_effort"]
- except KeyError:
- best_effort = False
-
- if sign_assertion is None:
- sign_assertion = self.config.getattr("sign_assertion", "idp")
- if sign_assertion is None:
- sign_assertion = False
-
- if sign_response is None:
- sign_response = self.config.getattr("sign_response", "idp")
- if sign_response is None:
- sign_response = False
-
- if encrypt_assertion is None:
- encrypt_assertion = self.config.getattr("encrypt_assertion", "idp")
- if encrypt_assertion is None:
- encrypt_assertion = False
-
- if encrypt_assertion_self_contained is None:
- encrypt_assertion_self_contained = self.config.getattr(
- "encrypt_assertion_self_contained", "idp")
- if encrypt_assertion_self_contained is None:
- encrypt_assertion_self_contained = True
-
- if encrypted_advice_attributes is None:
- encrypted_advice_attributes = self.config.getattr(
- "encrypted_advice_attributes", "idp")
- if encrypted_advice_attributes is None:
- encrypted_advice_attributes = False
-
- if encrypted_advice_attributes or pefim:
- verify_encrypt_cert = self.config.getattr(
- "verify_encrypt_cert_advice", "idp")
- if verify_encrypt_cert is not None:
- if encrypt_cert_advice is None:
- raise CertificateError(
- "No SPCertEncType certificate for encryption "
- "contained in authentication "
- "request.")
- if not verify_encrypt_cert(encrypt_cert_advice):
- raise CertificateError(
- "Invalid certificate for encryption!")
-
- if encrypt_assertion:
- verify_encrypt_cert = self.config.getattr(
- "verify_encrypt_cert_assertion", "idp")
- if verify_encrypt_cert is not None:
- if encrypt_cert_assertion is None:
- raise CertificateError(
- "No SPCertEncType certificate for encryption "
- "contained in authentication "
- "request.")
- if not verify_encrypt_cert(encrypt_cert_assertion):
- raise CertificateError(
- "Invalid certificate for encryption!")
-
- if not name_id:
- try:
- nid_formats = []
- for _sp in self.metadata[sp_entity_id]["spsso_descriptor"]:
- if "name_id_format" in _sp:
- nid_formats.extend([n["text"] for n in
- _sp["name_id_format"]])
- try:
- snq = name_id_policy.sp_name_qualifier
- except AttributeError:
- snq = sp_entity_id
-
- if not snq:
- snq = sp_entity_id
-
- kwa = {"sp_name_qualifier": snq}
-
- try:
- kwa["format"] = name_id_policy.format
- except AttributeError:
- pass
-
- _nids = self.ident.find_nameid(userid, **kwa)
- # either none or one
- if _nids:
- name_id = _nids[0]
- else:
- name_id = self.ident.construct_nameid(userid, policy,
- sp_entity_id,
- name_id_policy)
- logger.debug("construct_nameid: %s => %s", userid, name_id)
- except IOError as exc:
- response = self.create_error_response(in_response_to,
- destination,
- sp_entity_id,
- exc, name_id)
- return ("%s" % response).split("\n")
+ args = self.gather_authn_response_args(
+ sp_entity_id, name_id_policy=name_id_policy, userid=userid,
+ name_id=name_id, sign_response=sign_response,
+ sign_assertion=sign_assertion,
+ encrypt_cert_advice=encrypt_cert_advice,
+ encrypt_cert_assertion=encrypt_cert_assertion,
+ encrypt_assertion=encrypt_assertion,
+ encrypt_assertion_self_contained=encrypt_assertion_self_contained,
+ encrypted_advice_attributes=encrypted_advice_attributes,
+ pefim=pefim, **kwargs)
+ except IOError as exc:
+ response = self.create_error_response(in_response_to,
+ destination,
+ sp_entity_id,
+ exc, name_id)
+ return ("%s" % response).split("\n")
try:
_authn = authn
@@ -654,44 +676,14 @@ class Server(Entity):
self.sec.cert_handler.generate_cert():
with self.lock:
self.sec.cert_handler.update_cert(True)
- return self._authn_response(in_response_to,
- # in_response_to
- destination, # consumer_url
- sp_entity_id, # sp_entity_id
- identity,
- # identity as dictionary
- name_id,
- authn=_authn,
- issuer=issuer,
- policy=policy,
- sign_assertion=sign_assertion,
- sign_response=sign_response,
- best_effort=best_effort,
- encrypt_assertion=encrypt_assertion,
- encrypt_assertion_self_contained=encrypt_assertion_self_contained,
- encrypted_advice_attributes=encrypted_advice_attributes,
- encrypt_cert_advice=encrypt_cert_advice,
- encrypt_cert_assertion=encrypt_cert_assertion,
- pefim=pefim,
- sign_alg=sign_alg, digest_alg=digest_alg)
- return self._authn_response(in_response_to, # in_response_to
- destination, # consumer_url
- sp_entity_id, # sp_entity_id
- identity, # identity as dictionary
- name_id,
- authn=_authn,
- issuer=issuer,
- policy=policy,
- sign_assertion=sign_assertion,
- sign_response=sign_response,
- best_effort=best_effort,
- encrypt_assertion=encrypt_assertion,
- encrypt_assertion_self_contained=encrypt_assertion_self_contained,
- encrypted_advice_attributes=encrypted_advice_attributes,
- encrypt_cert_advice=encrypt_cert_advice,
- encrypt_cert_assertion=encrypt_cert_assertion,
- pefim=pefim,
- sign_alg=sign_alg, digest_alg=digest_alg)
+ return self._authn_response(
+ in_response_to, destination, sp_entity_id, identity,
+ authn=_authn, issuer=issuer, pefim=pefim,
+ sign_alg=sign_alg, digest_alg=digest_alg, **args)
+ return self._authn_response(
+ in_response_to, destination, sp_entity_id, identity,
+ authn=_authn, issuer=issuer, pefim=pefim, sign_alg=sign_alg,
+ digest_alg=digest_alg, **args)
except MissingValue as exc:
return self.create_error_response(in_response_to, destination,
@@ -710,8 +702,9 @@ class Server(Entity):
sign_response, sign_assertion,
authn_decl=authn_decl)
- #noinspection PyUnusedLocal
- def create_assertion_id_request_response(self, assertion_id, sign=False, sign_alg=None,
+ # noinspection PyUnusedLocal
+ def create_assertion_id_request_response(self, assertion_id, sign=False,
+ sign_alg=None,
digest_alg=None, **kwargs):
"""
@@ -728,7 +721,9 @@ class Server(Entity):
if to_sign:
if assertion.signature is None:
assertion.signature = pre_signature_part(assertion.id,
- self.sec.my_cert, 1, sign_alg=sign_alg, digest_alg=digest_alg)
+ self.sec.my_cert, 1,
+ sign_alg=sign_alg,
+ digest_alg=digest_alg)
return signed_instance_factory(assertion, self.sec, to_sign)
else:
@@ -738,7 +733,8 @@ class Server(Entity):
def create_name_id_mapping_response(self, name_id=None, encrypted_id=None,
in_response_to=None,
issuer=None, sign_response=False,
- status=None, sign_alg=None, digest_alg=None, **kwargs):
+ status=None, sign_alg=None,
+ digest_alg=None, **kwargs):
"""
protocol for mapping a principal's name identifier into a
different name identifier for the same principal.
@@ -768,7 +764,8 @@ class Server(Entity):
def create_authn_query_response(self, subject, session_index=None,
requested_context=None, in_response_to=None,
issuer=None, sign_response=False,
- status=None, sign_alg=None, digest_alg=None, **kwargs):
+ status=None, sign_alg=None, digest_alg=None,
+ **kwargs):
"""
A successful <Response> will contain one or more assertions containing
authentication statements.
@@ -789,7 +786,8 @@ class Server(Entity):
args = {}
return self._response(in_response_to, "", status, issuer,
- sign_response, to_sign=[], sign_alg=sign_alg, digest_alg=digest_alg, **args)
+ sign_response, to_sign=[], sign_alg=sign_alg,
+ digest_alg=digest_alg, **args)
# ---------
diff --git a/src/saml2/sigver.py b/src/saml2/sigver.py
index ec48f4ea..f89f347c 100644
--- a/src/saml2/sigver.py
+++ b/src/saml2/sigver.py
@@ -235,7 +235,7 @@ def _make_vals(val, klass, seccont, klass_inst=None, prop=None, part=False,
"""
cinst = None
- #print("make_vals(%s, %s)" % (val, klass))
+ # print("make_vals(%s, %s)" % (val, klass))
if isinstance(val, dict):
cinst = _instance(klass, val, seccont, base64encode=base64encode,
@@ -264,7 +264,7 @@ def _instance(klass, ava, seccont, base64encode=False, elements_to_sign=None):
instance = klass()
for prop in instance.c_attributes.values():
- #print("# %s" % (prop))
+ # print("# %s" % (prop))
if prop in ava:
if isinstance(ava[prop], bool):
setattr(instance, prop, str(ava[prop]).encode('utf-8'))
@@ -277,9 +277,9 @@ def _instance(klass, ava, seccont, base64encode=False, elements_to_sign=None):
instance.set_text(ava["text"], base64encode)
for prop, klassdef in instance.c_children.values():
- #print("## %s, %s" % (prop, klassdef))
+ # print("## %s, %s" % (prop, klassdef))
if prop in ava:
- #print("### %s" % ava[prop])
+ # print("### %s" % ava[prop])
if isinstance(klassdef, list):
# means there can be a list of values
_make_vals(ava[prop], klassdef[0], seccont, instance, prop,
@@ -319,9 +319,9 @@ def signed_instance_factory(instance, seccont, elements_to_sign=None):
signed_xml = seccont.sign_statement(
signed_xml, node_name=node_name, node_id=nodeid)
- #print("xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx")
- #print("%s" % signed_xml)
- #print("xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx")
+ # print("xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx")
+ # print("%s" % signed_xml)
+ # print("xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx")
return signed_xml
else:
return instance
@@ -367,6 +367,7 @@ def make_temp(string, suffix="", decode=True, delete=True):
def split_len(seq, length):
return [seq[i:i + length] for i in range(0, len(seq), length)]
+
# --------------------------------------------------------------------------
M2_TIME_FORMAT = "%b %d %H:%M:%S %Y"
@@ -396,15 +397,16 @@ def active_cert(key):
assert not_after > utc_now()
return True
except:
- cert = crypto.load_certificate(crypto.FILETYPE_PEM, cert_str)
- assert cert.has_expired() == 0
- assert not OpenSSLWrapper().certificate_not_valid_yet(cert)
- return True
+ cert = crypto.load_certificate(crypto.FILETYPE_PEM, cert_str)
+ assert cert.has_expired() == 0
+ assert not OpenSSLWrapper().certificate_not_valid_yet(cert)
+ return True
except AssertionError:
return False
except AttributeError:
return False
+
def cert_from_key_info(key_info, ignore_age=False):
""" Get all X509 certs from a KeyInfo instance. Care is taken to make sure
that the certs are continues sequences of bytes.
@@ -418,7 +420,7 @@ def cert_from_key_info(key_info, ignore_age=False):
"""
res = []
for x509_data in key_info.x509_data:
- #print("X509Data",x509_data)
+ # print("X509Data",x509_data)
x509_certificate = x509_data.x509_certificate
cert = x509_certificate.text.strip()
cert = "\n".join(split_len("".join([s.strip() for s in
@@ -515,13 +517,13 @@ def key_from_key_value_dict(key_info):
# =============================================================================
-#def rsa_load(filename):
+# def rsa_load(filename):
# """Read a PEM-encoded RSA key pair from a file."""
# return M2Crypto.RSA.load_key(filename, M2Crypto.util
# .no_passphrase_callback)
#
#
-#def rsa_loads(key):
+# def rsa_loads(key):
# """Read a PEM-encoded RSA key pair from a string."""
# return M2Crypto.RSA.load_key_string(key,
# M2Crypto.util.no_passphrase_callback)
@@ -582,6 +584,9 @@ def sha1_digest(msg):
class Signer(object):
"""Abstract base class for signing algorithms."""
+ def __init__(self, key):
+ self.key = key
+
def sign(self, msg, key):
"""Sign ``msg`` with ``key`` and return the signature."""
raise NotImplementedError
@@ -592,15 +597,22 @@ class Signer(object):
class RSASigner(Signer):
- def __init__(self, digest):
+ def __init__(self, digest, key=None):
+ Signer.__init__(self, key)
self.digest = digest
- def sign(self, msg, key):
+ def sign(self, msg, key=None):
+ if key is None:
+ key = self.key
+
h = self.digest.new(msg)
signer = PKCS1_v1_5.new(key)
return signer.sign(h)
- def verify(self, msg, sig, key):
+ def verify(self, msg, sig, key=None):
+ if key is None:
+ key = self.key
+
h = self.digest.new(msg)
verifier = PKCS1_v1_5.new(key)
return verifier.verify(h, sig)
@@ -618,7 +630,25 @@ REQ_ORDER = ["SAMLRequest", "RelayState", "SigAlg"]
RESP_ORDER = ["SAMLResponse", "RelayState", "SigAlg"]
-def verify_redirect_signature(saml_msg, cert=None, sigkey=None):
+class RSACrypto(object):
+ def __init__(self, key):
+ self.key = key
+
+ def get_signer(self, sigalg, sigkey=None):
+ try:
+ signer = SIGNER_ALGS[sigalg]
+ except KeyError:
+ return None
+ else:
+ if sigkey:
+ signer.key = sigkey
+ else:
+ signer.key = self.key
+
+ return signer
+
+
+def verify_redirect_signature(saml_msg, crypto, cert=None, sigkey=None):
"""
:param saml_msg: A dictionary with strings as values, *NOT* lists as
@@ -628,7 +658,7 @@ def verify_redirect_signature(saml_msg, cert=None, sigkey=None):
"""
try:
- signer = SIGNER_ALGS[saml_msg["SigAlg"]]
+ signer = crypto.get_signer(saml_msg["SigAlg"], sigkey)
except KeyError:
raise Unsupported("Signature algorithm: %s" % saml_msg["SigAlg"])
else:
@@ -646,10 +676,12 @@ def verify_redirect_signature(saml_msg, cert=None, sigkey=None):
string = "&".join(
[urlencode({k: _args[k]}) for k in _order if k in
_args]).encode('ascii')
+
if cert:
_key = extract_rsa_key_from_x509_cert(pem_format(cert))
else:
_key = sigkey
+
_sign = base64.b64decode(saml_msg["Signature"])
return bool(signer.verify(string, _sign, _key))
@@ -665,6 +697,7 @@ def make_str(txt):
else:
return txt.decode("utf8")
+
# ---------------------------------------------------------------------------
@@ -677,7 +710,6 @@ def read_cert_from_file(cert_file, cert_type):
:return: A base64 encoded certificate as a string or the empty string
"""
-
if not cert_file:
return ""
@@ -689,7 +721,7 @@ def read_cert_from_file(cert_file, cert_type):
for pattern in ("-----BEGIN CERTIFICATE-----",
"-----BEGIN PUBLIC KEY-----"):
if pattern in lines:
- lines = lines[lines.index(pattern)+1:]
+ lines = lines[lines.index(pattern) + 1:]
break
else:
raise CertificateError("Strange beginning of PEM file")
@@ -703,7 +735,6 @@ def read_cert_from_file(cert_file, cert_type):
raise CertificateError("Strange end of PEM file")
return make_str("".join(lines).encode("utf8"))
-
if cert_type in ["der", "cer", "crt"]:
data = read_file(cert_file, 'rb')
_cert = base64.b64encode(data)
@@ -757,6 +788,11 @@ class CryptoBackendXmlSec1(CryptoBackend):
else:
self._xmlsec_delete_tmpfiles = True
+ try:
+ self.non_xml_crypto = RSACrypto(kwargs['rsa_key'])
+ except KeyError:
+ pass
+
def version(self):
com_list = [self.xmlsec, "--version"]
pof = Popen(com_list, stderr=PIPE, stdout=PIPE)
@@ -808,7 +844,8 @@ class CryptoBackendXmlSec1(CryptoBackend):
if isinstance(statement, SamlBase):
statement = pre_encrypt_assertion(statement)
- _, fil = make_temp(str(statement).encode('utf-8'), decode=False, delete=False)
+ _, fil = make_temp(str(statement).encode('utf-8'), decode=False,
+ delete=False)
_, tmpl = make_temp(str(template).encode('utf-8'), decode=False)
if not node_xpath:
@@ -882,7 +919,7 @@ class CryptoBackendXmlSec1(CryptoBackend):
return signed_statement.decode('utf-8')
logger.error(
"Signing operation failed :\nstdout : %s\nstderr : %s",
- stdout, stderr)
+ stdout, stderr)
raise SigverError(stderr)
except DecryptError:
raise SigverError("Signing failed")
@@ -917,7 +954,7 @@ class CryptoBackendXmlSec1(CryptoBackend):
if self.__DEBUG:
try:
- print( " ".join(com_list))
+ print(" ".join(com_list))
except TypeError:
print("cert_type", cert_type)
print("cert_file", cert_file)
@@ -1040,7 +1077,7 @@ class CryptoBackendXMLSecurity(CryptoBackend):
def security_context(conf, debug=None):
""" Creates a security context based on the configuration
- :param conf: The configuration
+ :param conf: The configuration, this is a Config instance
:return: A SecurityContext instance
"""
if not conf:
@@ -1061,6 +1098,8 @@ def security_context(conf, debug=None):
if _only_md is None:
_only_md = False
+ sec_backend = None
+
if conf.crypto_backend == 'xmlsec1':
xmlsec_binary = conf.xmlsec_binary
if not xmlsec_binary:
@@ -1071,10 +1110,20 @@ def security_context(conf, debug=None):
xmlsec_binary = get_xmlsec_binary(_path)
# verify that xmlsec is where it's supposed to be
if not os.path.exists(xmlsec_binary):
- #if not os.access(, os.F_OK):
+ # if not os.access(, os.F_OK):
raise SigverError(
"xmlsec binary not in '%s' !" % xmlsec_binary)
crypto = _get_xmlsec_cryptobackend(xmlsec_binary, debug=debug)
+ _file_name = conf.getattr("key_file", "")
+ if _file_name:
+ try:
+ rsa_key = import_rsa_key_from_file(_file_name)
+ except Exception as err:
+ logger.error("Could not import key from {}: {}".format(_file_name,
+ err))
+ raise
+ else:
+ sec_backend = RSACrypto(rsa_key)
elif conf.crypto_backend == 'XMLSecurity':
# new and somewhat untested pyXMLSecurity crypto backend.
crypto = CryptoBackendXMLSecurity(debug=debug)
@@ -1082,6 +1131,7 @@ def security_context(conf, debug=None):
raise SigverError('Unknown crypto_backend %s' % (
repr(conf.crypto_backend)))
+
enc_key_files = []
if conf.encryption_keypairs is not None:
for _encryption_keypair in conf.encryption_keypairs:
@@ -1096,36 +1146,43 @@ def security_context(conf, debug=None):
tmp_cert_file=conf.tmp_cert_file,
tmp_key_file=conf.tmp_key_file,
validate_certificate=conf.validate_certificate,
- enc_key_files=enc_key_files, encryption_keypairs=conf.encryption_keypairs)
+ enc_key_files=enc_key_files,
+ encryption_keypairs=conf.encryption_keypairs,
+ sec_backend=sec_backend)
def encrypt_cert_from_item(item):
_encrypt_cert = None
try:
try:
- _elem = extension_elements_to_elements(item.extensions.extension_elements,[pefim, ds])
+ _elem = extension_elements_to_elements(
+ item.extensions.extension_elements, [pefim, ds])
except:
- _elem = extension_elements_to_elements(item.extension_elements[0].children,
- [pefim, ds])
+ _elem = extension_elements_to_elements(
+ item.extension_elements[0].children,
+ [pefim, ds])
for _tmp_elem in _elem:
if isinstance(_tmp_elem, SPCertEnc):
for _tmp_key_info in _tmp_elem.key_info:
- if _tmp_key_info.x509_data is not None and len(_tmp_key_info.x509_data) > 0:
- _encrypt_cert = _tmp_key_info.x509_data[0].x509_certificate.text
+ if _tmp_key_info.x509_data is not None and len(
+ _tmp_key_info.x509_data) > 0:
+ _encrypt_cert = _tmp_key_info.x509_data[
+ 0].x509_certificate.text
break
- #_encrypt_cert = _elem[0].x509_data[0].x509_certificate.text
-# else:
-# certs = cert_from_instance(item)
-# if len(certs) > 0:
-# _encrypt_cert = certs[0]
+ # _encrypt_cert = _elem[0].x509_data[
+ # 0].x509_certificate.text
+ # else:
+ # certs = cert_from_instance(item)
+ # if len(certs) > 0:
+ # _encrypt_cert = certs[0]
except Exception as _exception:
pass
-# if _encrypt_cert is None:
-# certs = cert_from_instance(item)
-# if len(certs) > 0:
-# _encrypt_cert = certs[0]
+ # if _encrypt_cert is None:
+ # certs = cert_from_instance(item)
+ # if len(certs) > 0:
+ # _encrypt_cert = certs[0]
if _encrypt_cert is not None:
if _encrypt_cert.find("-----BEGIN CERTIFICATE-----\n") == -1:
@@ -1135,7 +1192,6 @@ def encrypt_cert_from_item(item):
return _encrypt_cert
-
class CertHandlerExtra(object):
def __init__(self):
pass
@@ -1146,14 +1202,14 @@ class CertHandlerExtra(object):
def generate_cert(self, generate_cert_info, root_cert_string,
root_key_string):
raise Exception("generate_cert function must be implemented")
- #Excepts to return (cert_string, key_string)
+ # Excepts to return (cert_string, key_string)
def use_validate_cert_func(self):
raise Exception("use_validate_cert_func function must be implemented")
def validate_cert(self, cert_str, root_cert_string, root_key_string):
raise Exception("validate_cert function must be implemented")
- #Excepts to return True/False
+ # Excepts to return True/False
class CertHandler(object):
@@ -1180,7 +1236,7 @@ class CertHandler(object):
self._verify_cert = False
self._generate_cert = False
- #This cert do not have to be valid, it is just the last cert to be
+ # This cert do not have to be valid, it is just the last cert to be
# validated.
self._last_cert_verified = None
self._last_validated_cert = None
@@ -1206,7 +1262,7 @@ class CertHandler(object):
self._cert_info = None
self._generate_cert_func_active = False
if generate_cert_info is not None and len(self._cert_str) > 0 and \
- len(self._key_str) > 0 and tmp_key_file is not \
+ len(self._key_str) > 0 and tmp_key_file is not \
None and tmp_cert_file is not None:
self._generate_cert = True
self._cert_info = generate_cert_info
@@ -1236,7 +1292,7 @@ class CertHandler(object):
if (self._generate_cert and active) or client_crt is not None:
if client_crt is not None:
self._tmp_cert_str = client_crt
- #No private key for signing
+ # No private key for signing
self._tmp_key_str = ""
elif self._cert_handler_extra_class is not None and \
self._cert_handler_extra_class.use_generate_cert_func():
@@ -1244,9 +1300,9 @@ class CertHandler(object):
self._cert_handler_extra_class.generate_cert(
self._cert_info, self._cert_str, self._key_str)
else:
- self._tmp_cert_str, self._tmp_key_str = self._osw\
+ self._tmp_cert_str, self._tmp_key_str = self._osw \
.create_certificate(
- self._cert_info, request=True)
+ self._cert_info, request=True)
self._tmp_cert_str = self._osw.create_cert_signed_certificate(
self._cert_str, self._key_str, self._tmp_cert_str)
valid, mess = self._osw.verify(self._cert_str,
@@ -1273,12 +1329,18 @@ class SecurityContext(object):
debug=False, template="", encrypt_key_type="des-192",
only_use_keys_in_metadata=False, cert_handler_extra_class=None,
generate_cert_info=None, tmp_cert_file=None,
- tmp_key_file=None, validate_certificate=None, enc_key_files=None, enc_key_type="pem",
- encryption_keypairs=None, enc_cert_type="pem"):
+ tmp_key_file=None, validate_certificate=None,
+ enc_key_files=None, enc_key_type="pem",
+ encryption_keypairs=None, enc_cert_type="pem",
+ sec_backend=None):
self.crypto = crypto
assert (isinstance(self.crypto, CryptoBackend))
+ if sec_backend:
+ assert (isinstance(sec_backend, RSACrypto))
+ self.sec_backend = sec_backend
+
# Your private key for signing
self.key_file = key_file
self.key_type = key_type
@@ -1355,7 +1417,8 @@ class SecurityContext(object):
:param key_type: The type of session key to use.
:return: The encrypted text
"""
- raise NotImplemented()
+ return self.crypto.encrypt_assertion(statement, enc_key, template,
+ key_type, node_xpath)
def decrypt_keys(self, enctext, keys=None):
""" Decrypting an encrypted text by the use of a private key.
@@ -1394,9 +1457,9 @@ class SecurityContext(object):
if _enctext is not None and len(_enctext) > 0:
return _enctext
if key_file is not None and len(key_file.strip()) > 0:
- _enctext = self.crypto.decrypt(enctext, key_file)
- if _enctext is not None and len(_enctext) > 0:
- return _enctext
+ _enctext = self.crypto.decrypt(enctext, key_file)
+ if _enctext is not None and len(_enctext) > 0:
+ return _enctext
return enctext
def verify_signature(self, signedtext, cert_file=None, cert_type="pem",
@@ -1428,7 +1491,7 @@ class SecurityContext(object):
def _check_signature(self, decoded_xml, item, node_name=NODE_NAME,
origdoc=None, id_attr="", must=False,
only_valid_cert=False, issuer=None):
- #print(item)
+ # print(item)
try:
_issuer = item.issuer.text.strip()
except AttributeError:
@@ -1459,16 +1522,17 @@ class SecurityContext(object):
if not certs and not self.only_use_keys_in_metadata:
logger.debug("==== Certs from instance ====")
certs = [make_temp(pem_format(cert), suffix=".pem",
- decode=False, delete=self._xmlsec_delete_tmpfiles)
- for cert in cert_from_instance(item)]
+ decode=False,
+ delete=self._xmlsec_delete_tmpfiles)
+ for cert in cert_from_instance(item)]
else:
logger.debug("==== Certs from metadata ==== %s: %s ====", issuer,
- certs)
+ certs)
if not certs:
raise MissingKey("%s" % issuer)
- #print(certs)
+ # print(certs)
verified = False
last_pem_file = None
@@ -1552,7 +1616,8 @@ class SecurityContext(object):
if not msg.signature:
if must:
- raise SignatureError("Required signature missing on %s" % msgtype)
+ raise SignatureError(
+ "Required signature missing on %s" % msgtype)
else:
return msg
@@ -1697,9 +1762,9 @@ class SecurityContext(object):
return response
- #--------------------------------------------------------------------------
+ # --------------------------------------------------------------------------
# SIGNATURE PART
- #--------------------------------------------------------------------------
+ # --------------------------------------------------------------------------
def sign_statement_using_xmlsec(self, statement, **kwargs):
""" Deprecated function. See sign_statement(). """
return self.sign_statement(statement, **kwargs)
@@ -1760,7 +1825,8 @@ class SecurityContext(object):
return self.sign_statement(statement, class_name(
samlp.AttributeQuery()), **kwargs)
- def multiple_signatures(self, statement, to_sign, key=None, key_file=None, sign_alg=None, digest_alg=None):
+ def multiple_signatures(self, statement, to_sign, key=None, key_file=None,
+ sign_alg=None, digest_alg=None):
"""
Sign multiple parts of a statement
@@ -1779,7 +1845,9 @@ class SecurityContext(object):
sid = item.id
if not item.signature:
- item.signature = pre_signature_part(sid, self.cert_file, sign_alg=sign_alg, digest_alg=digest_alg)
+ item.signature = pre_signature_part(sid, self.cert_file,
+ sign_alg=sign_alg,
+ digest_alg=digest_alg)
statement = self.sign_statement(statement, class_name(item),
key=key, key_file=key_file,
@@ -1917,12 +1985,14 @@ def pre_encrypt_assertion(response):
return response
-def response_factory(sign=False, encrypt=False, sign_alg=None, digest_alg=None, **kwargs):
+def response_factory(sign=False, encrypt=False, sign_alg=None, digest_alg=None,
+ **kwargs):
response = samlp.Response(id=sid(), version=VERSION,
issue_instant=instant())
if sign:
- response.signature = pre_signature_part(kwargs["id"], sign_alg=sign_alg, digest_alg=digest_alg)
+ response.signature = pre_signature_part(kwargs["id"], sign_alg=sign_alg,
+ digest_alg=digest_alg)
if encrypt:
pass
@@ -1931,6 +2001,7 @@ def response_factory(sign=False, encrypt=False, sign_alg=None, digest_alg=None,
return response
+
# ----------------------------------------------------------------------------
if __name__ == '__main__':
import argparse