summaryrefslogtreecommitdiff
path: root/src/saml2/client_base.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/saml2/client_base.py')
-rw-r--r--src/saml2/client_base.py136
1 files changed, 114 insertions, 22 deletions
diff --git a/src/saml2/client_base.py b/src/saml2/client_base.py
index 55e5b1fc..531ddea5 100644
--- a/src/saml2/client_base.py
+++ b/src/saml2/client_base.py
@@ -10,6 +10,8 @@ import six
from saml2.entity import Entity
+import saml2.attributemaps as attributemaps
+
from saml2.mdstore import destinations
from saml2.profile import paos, ecp
from saml2.saml import NAMEID_FORMAT_TRANSIENT
@@ -18,6 +20,9 @@ from saml2.samlp import NameIDMappingRequest
from saml2.samlp import AttributeQuery
from saml2.samlp import AuthzDecisionQuery
from saml2.samlp import AuthnRequest
+from saml2.samlp import Extensions
+from saml2.extension import sp_type
+from saml2.extension import requested_attributes
import saml2
import time
@@ -207,7 +212,7 @@ class Base(Entity):
nameid_format=None,
service_url_binding=None, message_id=0,
consent=None, extensions=None, sign=None,
- allow_create=False, sign_prepare=False, sign_alg=None,
+ allow_create=None, sign_prepare=False, sign_alg=None,
digest_alg=None, **kwargs):
""" Creates an authentication request.
@@ -235,26 +240,30 @@ class Base(Entity):
args = {}
- try:
- args["assertion_consumer_service_url"] = kwargs[
- "assertion_consumer_service_urls"][0]
- del kwargs["assertion_consumer_service_urls"]
- except KeyError:
+ if self.config.getattr('hide_assertion_consumer_service', 'sp'):
+ args["assertion_consumer_service_url"] = None
+ binding = None
+ else:
try:
args["assertion_consumer_service_url"] = kwargs[
- "assertion_consumer_service_url"]
- del kwargs["assertion_consumer_service_url"]
+ "assertion_consumer_service_urls"][0]
+ del kwargs["assertion_consumer_service_urls"]
except KeyError:
try:
- args["assertion_consumer_service_index"] = str(kwargs[
- "assertion_consumer_service_index"])
- del kwargs["assertion_consumer_service_index"]
+ args["assertion_consumer_service_url"] = kwargs[
+ "assertion_consumer_service_url"]
+ del kwargs["assertion_consumer_service_url"]
except KeyError:
- if service_url_binding is None:
- service_urls = self.service_urls(binding)
- else:
- service_urls = self.service_urls(service_url_binding)
- args["assertion_consumer_service_url"] = service_urls[0]
+ try:
+ args["assertion_consumer_service_index"] = str(
+ kwargs["assertion_consumer_service_index"])
+ del kwargs["assertion_consumer_service_index"]
+ except KeyError:
+ if service_url_binding is None:
+ service_urls = self.service_urls(binding)
+ else:
+ service_urls = self.service_urls(service_url_binding)
+ args["assertion_consumer_service_url"] = service_urls[0]
try:
args["provider_name"] = kwargs["provider_name"]
@@ -268,7 +277,7 @@ class Base(Entity):
# all of these have cardinality 0..1
_msg = AuthnRequest()
for param in ["scoping", "requested_authn_context", "conditions",
- "subject", "scoping"]:
+ "subject"]:
try:
_item = kwargs[param]
except KeyError:
@@ -288,10 +297,15 @@ class Base(Entity):
args["name_id_policy"] = kwargs["name_id_policy"]
del kwargs["name_id_policy"]
except KeyError:
- if allow_create:
- allow_create = "true"
- else:
- allow_create = "false"
+ if allow_create is None:
+ allow_create = self.config.getattr("name_id_format_allow_create", "sp")
+ if allow_create is None:
+ allow_create = "false"
+ else:
+ if allow_create is True:
+ allow_create = "true"
+ else:
+ allow_create = "false"
if nameid_format == "":
name_id_policy = None
@@ -299,12 +313,21 @@ class Base(Entity):
if nameid_format is None:
nameid_format = self.config.getattr("name_id_format", "sp")
+ # If no nameid_format has been set in the configuration
+ # or passed in then transient is the default.
if nameid_format is None:
nameid_format = NAMEID_FORMAT_TRANSIENT
+
+ # If a list has been configured or passed in choose the
+ # first since NameIDPolicy can only have one format specified.
elif isinstance(nameid_format, list):
- # NameIDPolicy can only have one format specified
nameid_format = nameid_format[0]
+ # Allow a deployer to signal that no format should be specified
+ # in the NameIDPolicy by passing in or configuring the string 'None'.
+ elif nameid_format == 'None':
+ nameid_format = None
+
name_id_policy = samlp.NameIDPolicy(allow_create=allow_create,
format=nameid_format)
@@ -321,6 +344,75 @@ class Base(Entity):
except KeyError:
nsprefix = None
+ try:
+ force_authn = kwargs['force_authn']
+ except KeyError:
+ force_authn = self.config.getattr('force_authn', 'sp')
+ finally:
+ if force_authn:
+ args['force_authn'] = 'true'
+
+ conf_sp_type = self.config.getattr('sp_type', 'sp')
+ conf_sp_type_in_md = self.config.getattr('sp_type_in_metadata', 'sp')
+ if conf_sp_type and conf_sp_type_in_md is False:
+ if not extensions:
+ extensions = Extensions()
+ item = sp_type.SPType(text=conf_sp_type)
+ extensions.add_extension_element(item)
+
+ requested_attrs = self.config.getattr('requested_attributes', 'sp')
+ if requested_attrs:
+ if not extensions:
+ extensions = Extensions()
+
+ attributemapsmods = []
+ for modname in attributemaps.__all__:
+ attributemapsmods.append(getattr(attributemaps, modname))
+
+ items = []
+ for attr in requested_attrs:
+ friendly_name = attr.get('friendly_name')
+ name = attr.get('name')
+ name_format = attr.get('name_format')
+ is_required = str(attr.get('required', False)).lower()
+
+ if not name and not friendly_name:
+ raise ValueError(
+ "Missing required attribute: '{}' or '{}'".format(
+ 'name', 'friendly_name'))
+
+ if not name:
+ for mod in attributemapsmods:
+ try:
+ name = mod.MAP['to'][friendly_name]
+ except KeyError:
+ continue
+ else:
+ if not name_format:
+ name_format = mod.MAP['identifier']
+ break
+
+ if not friendly_name:
+ for mod in attributemapsmods:
+ try:
+ friendly_name = mod.MAP['fro'][name]
+ except KeyError:
+ continue
+ else:
+ if not name_format:
+ name_format = mod.MAP['identifier']
+ break
+
+ items.append(requested_attributes.RequestedAttribute(
+ is_required=is_required,
+ name_format=name_format,
+ friendly_name=friendly_name,
+ name=name))
+
+ item = requested_attributes.RequestedAttributes(
+ extension_elements=items)
+ extensions.add_extension_element(item)
+
if kwargs:
_args, extensions = self._filter_args(AuthnRequest(), extensions,
**kwargs)