diff options
Diffstat (limited to 'src/saml2/client_base.py')
-rw-r--r-- | src/saml2/client_base.py | 136 |
1 files changed, 114 insertions, 22 deletions
diff --git a/src/saml2/client_base.py b/src/saml2/client_base.py index 55e5b1fc..531ddea5 100644 --- a/src/saml2/client_base.py +++ b/src/saml2/client_base.py @@ -10,6 +10,8 @@ import six from saml2.entity import Entity +import saml2.attributemaps as attributemaps + from saml2.mdstore import destinations from saml2.profile import paos, ecp from saml2.saml import NAMEID_FORMAT_TRANSIENT @@ -18,6 +20,9 @@ from saml2.samlp import NameIDMappingRequest from saml2.samlp import AttributeQuery from saml2.samlp import AuthzDecisionQuery from saml2.samlp import AuthnRequest +from saml2.samlp import Extensions +from saml2.extension import sp_type +from saml2.extension import requested_attributes import saml2 import time @@ -207,7 +212,7 @@ class Base(Entity): nameid_format=None, service_url_binding=None, message_id=0, consent=None, extensions=None, sign=None, - allow_create=False, sign_prepare=False, sign_alg=None, + allow_create=None, sign_prepare=False, sign_alg=None, digest_alg=None, **kwargs): """ Creates an authentication request. @@ -235,26 +240,30 @@ class Base(Entity): args = {} - try: - args["assertion_consumer_service_url"] = kwargs[ - "assertion_consumer_service_urls"][0] - del kwargs["assertion_consumer_service_urls"] - except KeyError: + if self.config.getattr('hide_assertion_consumer_service', 'sp'): + args["assertion_consumer_service_url"] = None + binding = None + else: try: args["assertion_consumer_service_url"] = kwargs[ - "assertion_consumer_service_url"] - del kwargs["assertion_consumer_service_url"] + "assertion_consumer_service_urls"][0] + del kwargs["assertion_consumer_service_urls"] except KeyError: try: - args["assertion_consumer_service_index"] = str(kwargs[ - "assertion_consumer_service_index"]) - del kwargs["assertion_consumer_service_index"] + args["assertion_consumer_service_url"] = kwargs[ + "assertion_consumer_service_url"] + del kwargs["assertion_consumer_service_url"] except KeyError: - if service_url_binding is None: - service_urls = self.service_urls(binding) - else: - service_urls = self.service_urls(service_url_binding) - args["assertion_consumer_service_url"] = service_urls[0] + try: + args["assertion_consumer_service_index"] = str( + kwargs["assertion_consumer_service_index"]) + del kwargs["assertion_consumer_service_index"] + except KeyError: + if service_url_binding is None: + service_urls = self.service_urls(binding) + else: + service_urls = self.service_urls(service_url_binding) + args["assertion_consumer_service_url"] = service_urls[0] try: args["provider_name"] = kwargs["provider_name"] @@ -268,7 +277,7 @@ class Base(Entity): # all of these have cardinality 0..1 _msg = AuthnRequest() for param in ["scoping", "requested_authn_context", "conditions", - "subject", "scoping"]: + "subject"]: try: _item = kwargs[param] except KeyError: @@ -288,10 +297,15 @@ class Base(Entity): args["name_id_policy"] = kwargs["name_id_policy"] del kwargs["name_id_policy"] except KeyError: - if allow_create: - allow_create = "true" - else: - allow_create = "false" + if allow_create is None: + allow_create = self.config.getattr("name_id_format_allow_create", "sp") + if allow_create is None: + allow_create = "false" + else: + if allow_create is True: + allow_create = "true" + else: + allow_create = "false" if nameid_format == "": name_id_policy = None @@ -299,12 +313,21 @@ class Base(Entity): if nameid_format is None: nameid_format = self.config.getattr("name_id_format", "sp") + # If no nameid_format has been set in the configuration + # or passed in then transient is the default. if nameid_format is None: nameid_format = NAMEID_FORMAT_TRANSIENT + + # If a list has been configured or passed in choose the + # first since NameIDPolicy can only have one format specified. elif isinstance(nameid_format, list): - # NameIDPolicy can only have one format specified nameid_format = nameid_format[0] + # Allow a deployer to signal that no format should be specified + # in the NameIDPolicy by passing in or configuring the string 'None'. + elif nameid_format == 'None': + nameid_format = None + name_id_policy = samlp.NameIDPolicy(allow_create=allow_create, format=nameid_format) @@ -321,6 +344,75 @@ class Base(Entity): except KeyError: nsprefix = None + try: + force_authn = kwargs['force_authn'] + except KeyError: + force_authn = self.config.getattr('force_authn', 'sp') + finally: + if force_authn: + args['force_authn'] = 'true' + + conf_sp_type = self.config.getattr('sp_type', 'sp') + conf_sp_type_in_md = self.config.getattr('sp_type_in_metadata', 'sp') + if conf_sp_type and conf_sp_type_in_md is False: + if not extensions: + extensions = Extensions() + item = sp_type.SPType(text=conf_sp_type) + extensions.add_extension_element(item) + + requested_attrs = self.config.getattr('requested_attributes', 'sp') + if requested_attrs: + if not extensions: + extensions = Extensions() + + attributemapsmods = [] + for modname in attributemaps.__all__: + attributemapsmods.append(getattr(attributemaps, modname)) + + items = [] + for attr in requested_attrs: + friendly_name = attr.get('friendly_name') + name = attr.get('name') + name_format = attr.get('name_format') + is_required = str(attr.get('required', False)).lower() + + if not name and not friendly_name: + raise ValueError( + "Missing required attribute: '{}' or '{}'".format( + 'name', 'friendly_name')) + + if not name: + for mod in attributemapsmods: + try: + name = mod.MAP['to'][friendly_name] + except KeyError: + continue + else: + if not name_format: + name_format = mod.MAP['identifier'] + break + + if not friendly_name: + for mod in attributemapsmods: + try: + friendly_name = mod.MAP['fro'][name] + except KeyError: + continue + else: + if not name_format: + name_format = mod.MAP['identifier'] + break + + items.append(requested_attributes.RequestedAttribute( + is_required=is_required, + name_format=name_format, + friendly_name=friendly_name, + name=name)) + + item = requested_attributes.RequestedAttributes( + extension_elements=items) + extensions.add_extension_element(item) + if kwargs: _args, extensions = self._filter_args(AuthnRequest(), extensions, **kwargs) |