summaryrefslogtreecommitdiff
path: root/src/saml2/client_base.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/saml2/client_base.py')
-rw-r--r--src/saml2/client_base.py28
1 files changed, 24 insertions, 4 deletions
diff --git a/src/saml2/client_base.py b/src/saml2/client_base.py
index 90a4fbd1..4546ef07 100644
--- a/src/saml2/client_base.py
+++ b/src/saml2/client_base.py
@@ -17,7 +17,9 @@ from saml2.mdstore import locations
from saml2.profile import paos, ecp
from saml2.saml import NAMEID_FORMAT_PERSISTENT
from saml2.saml import NAMEID_FORMAT_TRANSIENT
-from saml2.samlp import AuthnQuery, RequestedAuthnContext
+from saml2.saml import AuthnContextClassRef
+from saml2.samlp import AuthnQuery
+from saml2.samlp import RequestedAuthnContext
from saml2.samlp import NameIDMappingRequest
from saml2.samlp import AttributeQuery
from saml2.samlp import AuthzDecisionQuery
@@ -358,18 +360,36 @@ class Base(Entity):
provider_name = self._my_name()
args["provider_name"] = provider_name
+ requested_authn_context = (
+ kwargs.pop("requested_authn_context", None)
+ or self.config.getattr("requested_authn_context", "sp")
+ or {}
+ )
+ requested_authn_context_accrs = requested_authn_context.get(
+ "authn_context_class_ref", []
+ )
+ requested_authn_context_comparison = requested_authn_context.get(
+ "comparison", "exact"
+ )
+ if requested_authn_context_accrs:
+ args["requested_authn_context"] = RequestedAuthnContext(
+ authn_context_class_ref=[
+ AuthnContextClassRef(accr)
+ for accr in requested_authn_context_accrs
+ ],
+ comparison=requested_authn_context_comparison,
+ )
+
# Allow argument values either as class instances or as dictionaries
# all of these have cardinality 0..1
_msg = AuthnRequest()
- for param in ["scoping", "requested_authn_context", "conditions", "subject"]:
+ for param in ["scoping", "conditions", "subject"]:
_item = kwargs.pop(param, None)
if not _item:
continue
if isinstance(_item, _msg.child_class(param)):
args[param] = _item
- elif isinstance(_item, dict):
- args[param] = RequestedAuthnContext(**_item)
else:
raise ValueError("Wrong type for param {name}".format(name=param))