summaryrefslogtreecommitdiff
path: root/src/saml2/entity.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/saml2/entity.py')
-rw-r--r--src/saml2/entity.py84
1 files changed, 60 insertions, 24 deletions
diff --git a/src/saml2/entity.py b/src/saml2/entity.py
index 672ad6f7..af2d8ba4 100644
--- a/src/saml2/entity.py
+++ b/src/saml2/entity.py
@@ -190,8 +190,16 @@ class Entity(HTTPBase):
return Issuer(text=self.config.entityid,
format=NAMEID_FORMAT_ENTITY)
- def apply_binding(self, binding, msg_str, destination="", relay_state="",
- response=False, sign=False, **kwargs):
+ def apply_binding(
+ self,
+ binding,
+ msg_str,
+ destination="",
+ relay_state="",
+ response=False,
+ sign=None,
+ **kwargs,
+ ):
"""
Construct the necessary HTTP arguments dependent on Binding
@@ -218,19 +226,25 @@ class Entity(HTTPBase):
# info["url"] = destination
# info["method"] = "POST"
# else:
- info = self.use_http_form_post(msg_str, destination,
- relay_state, typ)
+ info = self.use_http_form_post(msg_str, destination, relay_state, typ)
info["url"] = destination
info["method"] = "POST"
elif binding == BINDING_HTTP_REDIRECT:
logger.info("HTTP REDIRECT")
sigalg = kwargs.get("sigalg")
- if sign and sigalg:
- signer = self.sec.sec_backend.get_signer(sigalg)
- else:
- signer = None
- info = self.use_http_get(msg_str, destination, relay_state, typ,
- signer=signer, **kwargs)
+ signer = (
+ self.sec.sec_backend.get_signer(sigalg)
+ if sign and sigalg
+ else None
+ )
+ info = self.use_http_get(
+ msg_str,
+ destination,
+ relay_state,
+ typ,
+ signer=signer,
+ **kwargs,
+ )
info["url"] = str(destination)
info["method"] = "GET"
elif binding == BINDING_SOAP or binding == BINDING_PAOS:
@@ -416,12 +430,19 @@ class Entity(HTTPBase):
# --------------------------------------------------------------------------
- def sign(self, msg, mid=None, to_sign=None, sign_prepare=False,
- sign_alg=None, digest_alg=None):
+ def sign(
+ self,
+ msg,
+ mid=None,
+ to_sign=None,
+ sign_prepare=False,
+ sign_alg=None,
+ digest_alg=None,
+ ):
if msg.signature is None:
- msg.signature = pre_signature_part(msg.id, self.sec.my_cert, 1,
- sign_alg=sign_alg,
- digest_alg=digest_alg)
+ msg.signature = pre_signature_part(
+ msg.id, self.sec.my_cert, 1, sign_alg=sign_alg, digest_alg=digest_alg
+ )
if sign_prepare:
return msg
@@ -437,9 +458,20 @@ class Entity(HTTPBase):
logger.info("REQUEST: %s", msg)
return signed_instance_factory(msg, self.sec, to_sign)
- def _message(self, request_cls, destination=None, message_id=0,
- consent=None, extensions=None, sign=False, sign_prepare=False,
- nsprefix=None, sign_alg=None, digest_alg=None, **kwargs):
+ def _message(
+ self,
+ request_cls,
+ destination=None,
+ message_id=0,
+ consent=None,
+ extensions=None,
+ sign=False,
+ sign_prepare=False,
+ nsprefix=None,
+ sign_alg=None,
+ digest_alg=None,
+ **kwargs,
+ ):
"""
Some parameters appear in all requests so simplify by doing
it in one place
@@ -480,13 +512,17 @@ class Entity(HTTPBase):
req = self.msg_cb(req)
reqid = req.id
-
if sign:
- return reqid, self.sign(req, sign_prepare=sign_prepare,
- sign_alg=sign_alg, digest_alg=digest_alg)
- else:
- logger.info("REQUEST: %s", req)
- return reqid, req
+ signed_req = self.sign(
+ req,
+ sign_prepare=sign_prepare,
+ sign_alg=sign_alg,
+ digest_alg=digest_alg,
+ )
+ req = signed_req
+
+ logger.info("REQUEST: %s", req)
+ return reqid, req
@staticmethod
def _filter_args(instance, extensions=None, **kwargs):