summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorIvan Kanakarakis <ivan.kanak@gmail.com>2020-06-05 23:41:14 +0300
committerGitHub <noreply@github.com>2020-06-05 23:41:14 +0300
commitcc1c8561b0211682397a14e48163219c37aa695b (patch)
treed7a4d00c659c087d1d89437249b50152c84ef8db
parent3e5373d52bae3251fe983ef573217fb63bc23c26 (diff)
parent17b03f3c0468db58e116a0f0b669b50ff4559850 (diff)
downloadpysaml2-cc1c8561b0211682397a14e48163219c37aa695b.tar.gz
Merge pull request #691 from IdentityPython/feat-requested-attributes-per-request
Set eIDAS RequestedAttributes per AuthnRequest
-rw-r--r--src/saml2/attribute_converter.py2
-rw-r--r--src/saml2/client_base.py111
-rw-r--r--tests/server_conf.py6
-rw-r--r--tests/test_51_client.py29
4 files changed, 92 insertions, 56 deletions
diff --git a/src/saml2/attribute_converter.py b/src/saml2/attribute_converter.py
index bc937702..adfe659d 100644
--- a/src/saml2/attribute_converter.py
+++ b/src/saml2/attribute_converter.py
@@ -62,7 +62,7 @@ def ac_factory(path=""):
if path not in sys.path:
sys.path.insert(0, path)
- for fil in os.listdir(path):
+ for fil in sorted(os.listdir(path)):
if fil.endswith(".py"):
mod = import_module(fil[:-3])
for key, item in mod.__dict__.items():
diff --git a/src/saml2/client_base.py b/src/saml2/client_base.py
index ea0c86f9..b396fa16 100644
--- a/src/saml2/client_base.py
+++ b/src/saml2/client_base.py
@@ -12,8 +12,6 @@ import logging
from saml2.entity import Entity
-import saml2.attributemaps as attributemaps
-
from saml2.mdstore import destinations
from saml2.profile import paos, ecp
from saml2.saml import NAMEID_FORMAT_TRANSIENT
@@ -24,7 +22,8 @@ from saml2.samlp import AuthzDecisionQuery
from saml2.samlp import AuthnRequest
from saml2.samlp import Extensions
from saml2.extension import sp_type
-from saml2.extension import requested_attributes
+from saml2.extension.requested_attributes import RequestedAttribute
+from saml2.extension.requested_attributes import RequestedAttributes
import saml2
from saml2.soap import make_soap_enveloped_saml_thingy
@@ -235,7 +234,7 @@ class Base(Entity):
service_url_binding=None, message_id=0,
consent=None, extensions=None, sign=None,
allow_create=None, sign_prepare=False, sign_alg=None,
- digest_alg=None, **kwargs):
+ digest_alg=None, requested_attributes=None, **kwargs):
""" Creates an authentication request.
:param destination: Where the request should be sent.
@@ -253,6 +252,11 @@ class Base(Entity):
:param allow_create: If the identity provider is allowed, in the course
of fulfilling the request, to create a new identifier to represent
the principal.
+ :param requested_attributes: A list of dicts which define attributes to
+ be used as eIDAS Requested Attributes for this request. If not
+ defined the configuration option requested_attributes will be used,
+ if defined. The format is the same as the requested_attributes
+ configuration option.
:param kwargs: Extra key word arguments
:return: either a tuple of request ID and <samlp:AuthnRequest> instance
or a tuple of request ID and str when sign is set to True
@@ -379,58 +383,62 @@ class Base(Entity):
item = sp_type.SPType(text=conf_sp_type)
extensions.add_extension_element(item)
- requested_attrs = self.config.getattr('requested_attributes', 'sp')
- if requested_attrs:
- if not extensions:
- extensions = Extensions()
+ requested_attrs = (
+ requested_attributes
+ or self.config.getattr('requested_attributes', 'sp')
+ or []
+ )
+
+ if not extensions:
+ extensions = Extensions()
+
+ items = []
+ for attr in requested_attrs:
+ friendly_name = attr.get('friendly_name')
+ name = attr.get('name')
+ name_format = attr.get('name_format')
+ is_required = str(attr.get('required', False)).lower()
+
+ if not name and not friendly_name:
+ raise ValueError(
+ "Missing required attribute: '{}' or '{}'".format(
+ 'name', 'friendly_name'
+ )
+ )
+
+ if not name:
+ for converter in self.config.attribute_converters:
+ try:
+ name = converter._to[friendly_name.lower()]
+ except KeyError:
+ continue
+ else:
+ if not name_format:
+ name_format = converter.name_format
+ break
- attributemapsmods = []
- for modname in attributemaps.__all__:
- attributemapsmods.append(getattr(attributemaps, modname))
-
- items = []
- for attr in requested_attrs:
- friendly_name = attr.get('friendly_name')
- name = attr.get('name')
- name_format = attr.get('name_format')
- is_required = str(attr.get('required', False)).lower()
-
- if not name and not friendly_name:
- raise ValueError(
- "Missing required attribute: '{}' or '{}'".format(
- 'name', 'friendly_name'))
-
- if not name:
- for mod in attributemapsmods:
- try:
- name = mod.MAP['to'][friendly_name]
- except KeyError:
- continue
- else:
- if not name_format:
- name_format = mod.MAP['identifier']
- break
-
- if not friendly_name:
- for mod in attributemapsmods:
- try:
- friendly_name = mod.MAP['fro'][name]
- except KeyError:
- continue
- else:
- if not name_format:
- name_format = mod.MAP['identifier']
- break
+ if not friendly_name:
+ for converter in self.config.attribute_converters:
+ try:
+ friendly_name = converter._fro[name.lower()]
+ except KeyError:
+ continue
+ else:
+ if not name_format:
+ name_format = converter.name_format
+ break
- items.append(requested_attributes.RequestedAttribute(
+ items.append(
+ RequestedAttribute(
is_required=is_required,
name_format=name_format,
friendly_name=friendly_name,
- name=name))
+ name=name,
+ )
+ )
- item = requested_attributes.RequestedAttributes(
- extension_elements=items)
- extensions.add_extension_element(item)
+ item = RequestedAttributes(extension_elements=items)
+ extensions.add_extension_element(item)
force_authn = str(
kwargs.pop("force_authn", None)
@@ -450,8 +458,7 @@ class Base(Entity):
if sign is None:
sign = self.authn_requests_signed
- if (sign and self.sec.cert_handler.generate_cert()) or \
- client_crt is not None:
+ if (sign and self.sec.cert_handler.generate_cert()) or client_crt is not None:
with self.lock:
self.sec.cert_handler.update_cert(True, client_crt)
if client_crt is not None:
diff --git a/tests/server_conf.py b/tests/server_conf.py
index 4b528119..2b87b942 100644
--- a/tests/server_conf.py
+++ b/tests/server_conf.py
@@ -16,15 +16,15 @@ CONFIG = {
"idp": ["urn:mace:example.com:saml:roland:idp"],
"requested_attributes": [
{
- "name": "http://eidas.europa.eu/attributes/naturalperson/DateOfBirth",
+ "name": "urn:oid:1.3.6.1.4.1.5923.1.1.1.2",
"required": False,
},
{
- "friendly_name": "PersonIdentifier",
+ "friendly_name": "eduPersonNickname",
"required": True,
},
{
- "friendly_name": "PlaceOfBirth",
+ "friendly_name": "eduPersonScopedAffiliation",
},
],
}
diff --git a/tests/test_51_client.py b/tests/test_51_client.py
index f6fc2759..302b459d 100644
--- a/tests/test_51_client.py
+++ b/tests/test_51_client.py
@@ -286,6 +286,35 @@ class TestClient:
assert c.attributes['FriendlyName']
assert c.attributes['NameFormat']
+ def test_create_auth_request_requested_attributes(self):
+ req_attr = [{"friendly_name": "eduPersonOrgUnitDN", "required": True}]
+ ar_id, ar = self.client.create_authn_request(
+ "http://www.example.com/sso",
+ message_id="id1",
+ requested_attributes=req_attr
+ )
+
+ req_attrs_nodes = (
+ e
+ for e in ar.extensions.extension_elements
+ if e.tag == RequestedAttributes.c_tag
+ )
+ req_attrs_node = next(req_attrs_nodes, None)
+ assert req_attrs_node is not None
+
+ attrs = (
+ child
+ for child in req_attrs_node.children
+ if child.friendly_name == "eduPersonOrgUnitDN"
+ )
+ attr = next(attrs, None)
+ assert attr is not None
+ assert attr.c_tag == RequestedAttribute.c_tag
+ assert attr.is_required == 'true'
+ assert attr.name == 'urn:mace:dir:attribute-def:eduPersonOrgUnitDN'
+ assert attr.friendly_name == 'eduPersonOrgUnitDN'
+ assert attr.name_format == 'urn:oasis:names:tc:SAML:2.0:attrname-format:basic'
+
def test_create_auth_request_unset_force_authn_by_default(self):
req_id, req = self.client.create_authn_request(
"http://www.example.com/sso", sign=False, message_id="id1"