diff options
Diffstat (limited to 'src/saml2/entity.py')
-rw-r--r-- | src/saml2/entity.py | 84 |
1 files changed, 60 insertions, 24 deletions
diff --git a/src/saml2/entity.py b/src/saml2/entity.py index 672ad6f7..af2d8ba4 100644 --- a/src/saml2/entity.py +++ b/src/saml2/entity.py @@ -190,8 +190,16 @@ class Entity(HTTPBase): return Issuer(text=self.config.entityid, format=NAMEID_FORMAT_ENTITY) - def apply_binding(self, binding, msg_str, destination="", relay_state="", - response=False, sign=False, **kwargs): + def apply_binding( + self, + binding, + msg_str, + destination="", + relay_state="", + response=False, + sign=None, + **kwargs, + ): """ Construct the necessary HTTP arguments dependent on Binding @@ -218,19 +226,25 @@ class Entity(HTTPBase): # info["url"] = destination # info["method"] = "POST" # else: - info = self.use_http_form_post(msg_str, destination, - relay_state, typ) + info = self.use_http_form_post(msg_str, destination, relay_state, typ) info["url"] = destination info["method"] = "POST" elif binding == BINDING_HTTP_REDIRECT: logger.info("HTTP REDIRECT") sigalg = kwargs.get("sigalg") - if sign and sigalg: - signer = self.sec.sec_backend.get_signer(sigalg) - else: - signer = None - info = self.use_http_get(msg_str, destination, relay_state, typ, - signer=signer, **kwargs) + signer = ( + self.sec.sec_backend.get_signer(sigalg) + if sign and sigalg + else None + ) + info = self.use_http_get( + msg_str, + destination, + relay_state, + typ, + signer=signer, + **kwargs, + ) info["url"] = str(destination) info["method"] = "GET" elif binding == BINDING_SOAP or binding == BINDING_PAOS: @@ -416,12 +430,19 @@ class Entity(HTTPBase): # -------------------------------------------------------------------------- - def sign(self, msg, mid=None, to_sign=None, sign_prepare=False, - sign_alg=None, digest_alg=None): + def sign( + self, + msg, + mid=None, + to_sign=None, + sign_prepare=False, + sign_alg=None, + digest_alg=None, + ): if msg.signature is None: - msg.signature = pre_signature_part(msg.id, self.sec.my_cert, 1, - sign_alg=sign_alg, - digest_alg=digest_alg) + msg.signature = pre_signature_part( + msg.id, self.sec.my_cert, 1, sign_alg=sign_alg, digest_alg=digest_alg + ) if sign_prepare: return msg @@ -437,9 +458,20 @@ class Entity(HTTPBase): logger.info("REQUEST: %s", msg) return signed_instance_factory(msg, self.sec, to_sign) - def _message(self, request_cls, destination=None, message_id=0, - consent=None, extensions=None, sign=False, sign_prepare=False, - nsprefix=None, sign_alg=None, digest_alg=None, **kwargs): + def _message( + self, + request_cls, + destination=None, + message_id=0, + consent=None, + extensions=None, + sign=False, + sign_prepare=False, + nsprefix=None, + sign_alg=None, + digest_alg=None, + **kwargs, + ): """ Some parameters appear in all requests so simplify by doing it in one place @@ -480,13 +512,17 @@ class Entity(HTTPBase): req = self.msg_cb(req) reqid = req.id - if sign: - return reqid, self.sign(req, sign_prepare=sign_prepare, - sign_alg=sign_alg, digest_alg=digest_alg) - else: - logger.info("REQUEST: %s", req) - return reqid, req + signed_req = self.sign( + req, + sign_prepare=sign_prepare, + sign_alg=sign_alg, + digest_alg=digest_alg, + ) + req = signed_req + + logger.info("REQUEST: %s", req) + return reqid, req @staticmethod def _filter_args(instance, extensions=None, **kwargs): |