diff options
-rw-r--r-- | src/saml2/client_base.py | 19 | ||||
-rw-r--r-- | tests/test_51_client.py | 36 |
2 files changed, 39 insertions, 16 deletions
diff --git a/src/saml2/client_base.py b/src/saml2/client_base.py index 6d8edcfa..93845ff6 100644 --- a/src/saml2/client_base.py +++ b/src/saml2/client_base.py @@ -371,13 +371,6 @@ class Base(Entity): except KeyError: nsprefix = None - force_authn = ( - kwargs.get("force_authn") - or self.config.getattr('force_authn', 'sp') - ) - if str(force_authn).lower() == 'true': - args['force_authn'] = 'true' - conf_sp_type = self.config.getattr('sp_type', 'sp') conf_sp_type_in_md = self.config.getattr('sp_type_in_metadata', 'sp') if conf_sp_type and conf_sp_type_in_md is False: @@ -439,9 +432,17 @@ class Base(Entity): extension_elements=items) extensions.add_extension_element(item) + force_authn = str( + kwargs.pop("force_authn", None) + or self.config.getattr("force_authn", "sp") + ).lower() in ["true", "1"] + if force_authn: + kwargs["force_authn"] = "true" + if kwargs: - _args, extensions = self._filter_args(AuthnRequest(), extensions, - **kwargs) + _args, extensions = self._filter_args( + AuthnRequest(), extensions, **kwargs + ) args.update(_args) args.pop("id", None) diff --git a/tests/test_51_client.py b/tests/test_51_client.py index 75dd8f75..f24fb709 100644 --- a/tests/test_51_client.py +++ b/tests/test_51_client.py @@ -286,16 +286,38 @@ class TestClient: assert c.attributes['FriendlyName'] assert c.attributes['NameFormat'] - def test_create_auth_request_unset_force_authn(self): + def test_create_auth_request_unset_force_authn_by_default(self): req_id, req = self.client.create_authn_request( - "http://www.example.com/sso", sign=False, message_id="id1") - assert bool(req.force_authn) == False + "http://www.example.com/sso", sign=False, message_id="id1" + ) + assert req.force_authn is None - def test_create_auth_request_set_force_authn(self): + def test_create_auth_request_set_force_authn_not_true_or_1(self): req_id, req = self.client.create_authn_request( - "http://www.example.com/sso", sign=False, message_id="id1", - force_authn="true") - assert bool(req.force_authn) == True + "http://www.example.com/sso", + sign=False, + message_id="id1", + force_authn="0", + ) + assert req.force_authn is None + + def test_create_auth_request_set_force_authn_true(self): + req_id, req = self.client.create_authn_request( + "http://www.example.com/sso", + sign=False, + message_id="id1", + force_authn="true", + ) + assert req.force_authn == "true" + + def test_create_auth_request_set_force_authn_1(self): + req_id, req = self.client.create_authn_request( + "http://www.example.com/sso", + sign=False, + message_id="id1", + force_authn="true", + ) + assert req.force_authn == "true" def test_create_auth_request_nameid_policy_allow_create(self): conf = config.SPConfig() |