summaryrefslogtreecommitdiff
path: root/src/saml2/entity.py
diff options
context:
space:
mode:
authorHans Hörberg <hans.horberg@umu.se>2015-11-06 12:17:34 +0100
committerHans Hörberg <hans.horberg@umu.se>2015-11-06 12:17:34 +0100
commit3b84f65d849c2035636e1691ae0d35fe664640b9 (patch)
treece3477e5f6943edd84d646deb67705fcfff76ccd /src/saml2/entity.py
parent0f209eb5498ca948f654c2128d3b037d18b8cb17 (diff)
downloadpysaml2-3b84f65d849c2035636e1691ae0d35fe664640b9.tar.gz
Added the possibility to set signature and digest algorithm on all the functions I identified.
pysaml2 has a default value for sign and digest. To make it possible to always use the same algorithm this default value has been replaced with a singleton class. The first time the singleton class is instantiated the sign and digest algorithm will be set. After that it cannot be changed. A good place to setup this single class is in the server setup. Example: ds.DefaultSignature(ds.SIG_RSA_SHA512, ds.DIGEST_SHA512)
Diffstat (limited to 'src/saml2/entity.py')
-rw-r--r--src/saml2/entity.py60
1 files changed, 33 insertions, 27 deletions
diff --git a/src/saml2/entity.py b/src/saml2/entity.py
index 60de964f..f8166771 100644
--- a/src/saml2/entity.py
+++ b/src/saml2/entity.py
@@ -409,9 +409,9 @@ class Entity(HTTPBase):
# --------------------------------------------------------------------------
- def sign(self, msg, mid=None, to_sign=None, sign_prepare=False):
+ def sign(self, msg, mid=None, to_sign=None, sign_prepare=False, sign_alg=None, digest_alg=None):
if msg.signature is None:
- msg.signature = pre_signature_part(msg.id, self.sec.my_cert, 1)
+ msg.signature = pre_signature_part(msg.id, self.sec.my_cert, 1, sign_alg=sign_alg)
if sign_prepare:
return msg
@@ -429,7 +429,7 @@ class Entity(HTTPBase):
def _message(self, request_cls, destination=None, message_id=0,
consent=None, extensions=None, sign=False, sign_prepare=False,
- nsprefix=None, **kwargs):
+ nsprefix=None, sign_alg=None, digest_alg=None, **kwargs):
"""
Some parameters appear in all requests so simplify by doing
it in one place
@@ -468,7 +468,7 @@ class Entity(HTTPBase):
req.register_prefix(nsprefix)
if sign:
- return reqid, self.sign(req, sign_prepare=sign_prepare)
+ return reqid, self.sign(req, sign_prepare=sign_prepare, sign_alg=sign_alg)
else:
logger.info("REQUEST: %s" % req)
return reqid, req
@@ -559,8 +559,10 @@ class Entity(HTTPBase):
def _response(self, in_response_to, consumer_url=None, status=None,
issuer=None, sign=False, to_sign=None, sp_entity_id=None,
- encrypt_assertion=False, encrypt_assertion_self_contained=False, encrypted_advice_attributes=False,
- encrypt_cert_advice=None, encrypt_cert_assertion=None,sign_assertion=None, pefim=False, **kwargs):
+ encrypt_assertion=False, encrypt_assertion_self_contained=False,
+ encrypted_advice_attributes=False,
+ encrypt_cert_advice=None, encrypt_cert_assertion=None,sign_assertion=None,
+ pefim=False, sign_alg=None, digest_alg=None, **kwargs):
""" Create a Response.
Encryption:
encrypt_assertion must be true for encryption to be performed. If encrypted_advice_attributes also is
@@ -596,7 +598,7 @@ class Entity(HTTPBase):
response = response_factory(issuer=_issuer,
in_response_to=in_response_to,
- status=status)
+ status=status, sign_alg=sign_alg)
if consumer_url:
response.destination = consumer_url
@@ -616,7 +618,7 @@ class Entity(HTTPBase):
len(response.assertion.advice.assertion) == 1):
if sign:
response.signature = pre_signature_part(response.id,
- self.sec.my_cert, 1)
+ self.sec.my_cert, 1, sign_alg=sign_alg)
sign_class = [(class_name(response), response.id)]
cbxs = CryptoBackendXmlSec1(self.config.xmlsec_binary)
encrypt_advice = False
@@ -635,7 +637,9 @@ class Entity(HTTPBase):
for tmp_assertion in _advice_assertions:
to_sign_advice = []
if sign_assertion and not pefim:
- tmp_assertion.signature = pre_signature_part(tmp_assertion.id, self.sec.my_cert, 1)
+ tmp_assertion.signature = pre_signature_part(tmp_assertion.id,
+ self.sec.my_cert, 1,
+ sign_alg=sign_alg)
to_sign_advice.append((class_name(tmp_assertion), tmp_assertion.id))
#tmp_assertion = response.assertion.advice.assertion[0]
_assertion.advice.encrypted_assertion[0].add_extension_element(tmp_assertion)
@@ -661,7 +665,9 @@ class Entity(HTTPBase):
if not isinstance(_assertions, list):
_assertions = [_assertions]
for _assertion in _assertions:
- _assertion.signature = pre_signature_part(_assertion.id, self.sec.my_cert, 1)
+ _assertion.signature = pre_signature_part(_assertion.id, self.sec.my_cert,
+ 1,
+ sign_alg=sign_alg)
to_sign_assertion.append((class_name(_assertion), _assertion.id))
if encrypt_assertion_self_contained:
try:
@@ -685,11 +691,11 @@ class Entity(HTTPBase):
return response
if sign:
- return self.sign(response, to_sign=to_sign)
+ return self.sign(response, to_sign=to_sign, sign_alg=sign_alg)
else:
return response
- def _status_response(self, response_class, issuer, status, sign=False,
+ def _status_response(self, response_class, issuer, status, sign=False, sign_alg=None, digest_alg=None,
**kwargs):
""" Create a StatusResponse.
@@ -718,7 +724,7 @@ class Entity(HTTPBase):
status=status, **kwargs)
if sign:
- return self.sign(response, mid)
+ return self.sign(response, mid, sign_alg=sign_alg)
else:
return response
@@ -797,7 +803,7 @@ class Entity(HTTPBase):
# ------------------------------------------------------------------------
def create_error_response(self, in_response_to, destination, info,
- sign=False, issuer=None, **kwargs):
+ sign=False, issuer=None, sign_alg=None, digest_alg=None, **kwargs):
""" Create a error response.
:param in_response_to: The identifier of the message this is a response
@@ -813,7 +819,7 @@ class Entity(HTTPBase):
status = error_status_factory(info)
return self._response(in_response_to, destination, status, issuer,
- sign)
+ sign, sign_alg=sign_alg)
# ------------------------------------------------------------------------
@@ -821,7 +827,7 @@ class Entity(HTTPBase):
subject_id=None, name_id=None,
reason=None, expire=None, message_id=0,
consent=None, extensions=None, sign=False,
- session_indexes=None):
+ session_indexes=None, sign_alg=None, digest_alg=None):
""" Constructs a LogoutRequest
:param destination: Destination of the request
@@ -865,10 +871,10 @@ class Entity(HTTPBase):
return self._message(LogoutRequest, destination, message_id,
consent, extensions, sign, name_id=name_id,
reason=reason, not_on_or_after=expire,
- issuer=self._issuer(), **args)
+ issuer=self._issuer(), sign_alg=sign_alg, **args)
def create_logout_response(self, request, bindings=None, status=None,
- sign=False, issuer=None):
+ sign=False, issuer=None, sign_alg=None, digest_alg=None):
""" Create a LogoutResponse.
:param request: The request this is a response to
@@ -886,14 +892,14 @@ class Entity(HTTPBase):
issuer = self._issuer()
response = self._status_response(samlp.LogoutResponse, issuer, status,
- sign, **rinfo)
+ sign, sign_alg=sign_alg, **rinfo)
logger.info("Response: %s" % (response,))
return response
def create_artifact_resolve(self, artifact, destination, sessid,
- consent=None, extensions=None, sign=False):
+ consent=None, extensions=None, sign=False, sign_alg=None, digest_alg=None):
"""
Create a ArtifactResolve request
@@ -909,10 +915,10 @@ class Entity(HTTPBase):
artifact = Artifact(text=artifact)
return self._message(ArtifactResolve, destination, sessid,
- consent, extensions, sign, artifact=artifact)
+ consent, extensions, sign, artifact=artifact, sign_alg=sign_alg)
def create_artifact_response(self, request, artifact, bindings=None,
- status=None, sign=False, issuer=None):
+ status=None, sign=False, issuer=None, sign_alg=None, digest_alg=None):
"""
Create an ArtifactResponse
:return:
@@ -920,7 +926,7 @@ class Entity(HTTPBase):
rinfo = self.response_args(request, bindings)
response = self._status_response(ArtifactResponse, issuer, status,
- sign=sign, **rinfo)
+ sign=sign, sign_alg=sign_alg, **rinfo)
msg = element_to_extension_element(self.artifact[artifact])
response.extension_elements = [msg]
@@ -933,7 +939,7 @@ class Entity(HTTPBase):
consent=None, extensions=None, sign=False,
name_id=None, new_id=None,
encrypted_id=None, new_encrypted_id=None,
- terminate=None):
+ terminate=None, sign_alg=None, digest_alg=None):
"""
:param destination:
@@ -969,7 +975,7 @@ class Entity(HTTPBase):
"One of NewID, NewEncryptedNameID or Terminate has to be provided")
return self._message(ManageNameIDRequest, destination, consent=consent,
- extensions=extensions, sign=sign, **kwargs)
+ extensions=extensions, sign=sign, sign_alg=sign_alg, **kwargs)
def parse_manage_name_id_request(self, xmlstr, binding=BINDING_SOAP):
""" Deal with a LogoutRequest
@@ -985,13 +991,13 @@ class Entity(HTTPBase):
"manage_name_id_service", binding)
def create_manage_name_id_response(self, request, bindings=None,
- status=None, sign=False, issuer=None,
+ status=None, sign=False, issuer=None, sign_alg=None, digest_alg=None,
**kwargs):
rinfo = self.response_args(request, bindings)
response = self._status_response(samlp.ManageNameIDResponse, issuer,
- status, sign, **rinfo)
+ status, sign, sign_alg=sign_alg, **rinfo)
logger.info("Response: %s" % (response,))