summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRebecka Gulliksson <rebecka.gulliksson@umu.se>2015-12-28 19:43:37 +0100
committerRebecka Gulliksson <rebecka.gulliksson@umu.se>2015-12-28 19:57:02 +0100
commit7044da38abda80d7e493a8ca7552ca8c0443a45f (patch)
treed41d6c419e56b0f8f543254a75a682090a0726fc
parent7a7b02d792a5ad17aa88336165069263fd387d89 (diff)
downloadpysaml2-7044da38abda80d7e493a8ca7552ca8c0443a45f.tar.gz
Filter optional attributes in the exact same way as required attributes.
-rw-r--r--src/saml2/assertion.py53
-rw-r--r--tests/test_20_assertion.py24
2 files changed, 49 insertions, 28 deletions
diff --git a/src/saml2/assertion.py b/src/saml2/assertion.py
index 20480319..c6e24d2f 100644
--- a/src/saml2/assertion.py
+++ b/src/saml2/assertion.py
@@ -78,12 +78,30 @@ def filter_on_attributes(ava, required=None, optional=None, acs=None,
:return: The modified attribute value assertion
"""
- def _attr_name(attr):
- """Get the friendly name of an attribute name"""
+ def _match_attr_name(attr, ava):
try:
- return attr["friendly_name"]
+ friendly_name = attr["friendly_name"]
except KeyError:
- return get_local_name(acs, attr["name"], attr["name_format"])
+ friendly_name = get_local_name(acs, attr["name"], attr["name_format"])
+
+ _fn = _match(friendly_name, ava)
+ if not _fn: # In the unlikely case that someone has provided us with URIs as attribute names
+ _fn = _match(attr["name"], ava)
+
+ return _fn
+
+ def _apply_attr_value_restrictions(attr, res, must=False):
+ try:
+ values = [av["text"] for av in attr["attribute_value"]]
+ except KeyError:
+ values = []
+
+ try:
+ res[_fn].extend(_filter_values(ava[_fn], values))
+ except KeyError:
+ res[_fn] = _filter_values(ava[_fn], values)
+
+ return _filter_values(ava[_fn], values, must)
res = {}
@@ -91,39 +109,22 @@ def filter_on_attributes(ava, required=None, optional=None, acs=None,
required = []
for attr in required:
- _name = _attr_name(attr)
- _fn = _match(_name, ava)
- if not _fn: # In the unlikely case that someone has provided us
- # with URIs as attribute names
- _fn = _match(attr["name"], ava)
+ _fn = _match_attr_name(attr, ava)
if _fn:
- try:
- values = [av["text"] for av in attr["attribute_value"]]
- except KeyError:
- values = []
- res[_fn] = _filter_values(ava[_fn], values, True)
- continue
+ _apply_attr_value_restrictions(attr, res, True)
elif fail_on_unfulfilled_requirements:
desc = "Required attribute missing: '%s' (%s)" % (attr["name"],
- _name)
+ _fn)
raise MissingValue(desc)
if optional is None:
optional = []
for attr in optional:
- _name = _attr_name(attr)
- _fn = _match(_name, ava)
+ _fn = _match_attr_name(attr, ava)
if _fn:
- try:
- values = [av["text"] for av in attr["attribute_value"]]
- except KeyError:
- values = []
- try:
- res[_fn].extend(_filter_values(ava[_fn], values))
- except KeyError:
- res[_fn] = _filter_values(ava[_fn], values)
+ _apply_attr_value_restrictions(attr, res, False)
return res
diff --git a/tests/test_20_assertion.py b/tests/test_20_assertion.py
index 4b3c0ea9..8043c911 100644
--- a/tests/test_20_assertion.py
+++ b/tests/test_20_assertion.py
@@ -1,4 +1,6 @@
# coding=utf-8
+import pytest
+
from saml2.authn_context import pword
from saml2.mdie import to_dict
from saml2 import md, assertion
@@ -81,9 +83,10 @@ def test_filter_on_attributes_1():
def test_filter_on_attributes_without_friendly_name():
ava = {"eduPersonTargetedID": "test@example.com", "eduPersonAffiliation": "test",
"extra": "foo"}
- eptid = to_dict(Attribute(name="urn:oid:1.3.6.1.4.1.5923.1.1.1.10", name_format=NAME_FORMAT_URI), ONTS)
+ eptid = to_dict(
+ Attribute(name="urn:oid:1.3.6.1.4.1.5923.1.1.1.10", name_format=NAME_FORMAT_URI), ONTS)
ep_affiliation = to_dict(
- Attribute(name="urn:oid:1.3.6.1.4.1.5923.1.1.1.1", name_format=NAME_FORMAT_URI), ONTS)
+ Attribute(name="urn:oid:1.3.6.1.4.1.5923.1.1.1.1", name_format=NAME_FORMAT_URI), ONTS)
restricted_ava = filter_on_attributes(ava, required=[eptid], optional=[ep_affiliation],
acs=ac_factory())
@@ -91,6 +94,23 @@ def test_filter_on_attributes_without_friendly_name():
"eduPersonAffiliation": "test"}
+def test_filter_on_attributes_with_missing_required_attribute():
+ ava = {"extra": "foo"}
+ eptid = to_dict(Attribute(
+ friendly_name="eduPersonTargetedID", name="urn:oid:1.3.6.1.4.1.5923.1.1.1.10",
+ name_format=NAME_FORMAT_URI), ONTS)
+ with pytest.raises(MissingValue):
+ filter_on_attributes(ava, required=[eptid])
+
+
+def test_filter_on_attributes_with_missing_optional_attribute():
+ ava = {"extra": "foo"}
+ eptid = to_dict(Attribute(
+ friendly_name="eduPersonTargetedID", name="urn:oid:1.3.6.1.4.1.5923.1.1.1.10",
+ name_format=NAME_FORMAT_URI), ONTS)
+ assert filter_on_attributes(ava, optional=[eptid]) == {}
+
+
# ----------------------------------------------------------------------
def test_lifetime_1():