summaryrefslogtreecommitdiff
path: root/src/saml2/assertion.py
diff options
context:
space:
mode:
authorRoland Hedberg <roland@catalogix.se>2017-10-11 15:37:14 +0200
committerRoland Hedberg <roland@catalogix.se>2017-10-11 15:37:14 +0200
commitf6a47025f034fc030a6cbb8b7ec901399b58118d (patch)
tree60b96e7a2b9326b16caa494f63e3772a99e0d645 /src/saml2/assertion.py
parent3360ee2734ea2fd32d131a2d908e38aa2919f076 (diff)
downloadpysaml2-f6a47025f034fc030a6cbb8b7ec901399b58118d.tar.gz
Ordered way to find a local name of an attribute.
Diffstat (limited to 'src/saml2/assertion.py')
-rw-r--r--src/saml2/assertion.py38
1 files changed, 25 insertions, 13 deletions
diff --git a/src/saml2/assertion.py b/src/saml2/assertion.py
index 0db4b723..8984db59 100644
--- a/src/saml2/assertion.py
+++ b/src/saml2/assertion.py
@@ -8,12 +8,15 @@ import six
from saml2 import saml
from saml2 import xmlenc
-from saml2.attribute_converter import from_local, get_local_name
+from saml2.attribute_converter import from_local, ac_factory
+from saml2.attribute_converter import get_local_name
from saml2.s_utils import assertion_factory
from saml2.s_utils import factory
-from saml2.s_utils import sid, MissingValue
+from saml2.s_utils import sid
+from saml2.s_utils import MissingValue
from saml2.saml import NAME_FORMAT_URI
-from saml2.time_util import instant, in_a_while
+from saml2.time_util import instant
+from saml2.time_util import in_a_while
logger = logging.getLogger(__name__)
@@ -78,15 +81,22 @@ def filter_on_attributes(ava, required=None, optional=None, acs=None,
"""
def _match_attr_name(attr, ava):
-
- local_name = get_local_name(acs, attr["name"], attr["name_format"])
- if not local_name:
- try:
- local_name = attr["friendly_name"]
- except KeyError:
- pass
+ local_name = None
+
+ for a in ['name_format', 'friendly_name']:
+ _val = attr.get(a)
+ if _val:
+ if a == 'name_format':
+ local_name = get_local_name(acs, attr['name'], _val)
+ else:
+ local_name = _val
+ break
+
+ if local_name:
+ _fn = _match(local_name, ava)
+ else:
+ _fn = None
- _fn = _match(local_name, ava)
if not _fn: # In the unlikely case that someone has provided us with
# URIs as attribute names
_fn = _match(attr["name"], ava)
@@ -117,8 +127,7 @@ def filter_on_attributes(ava, required=None, optional=None, acs=None,
if _fn:
_apply_attr_value_restrictions(attr, res, True)
elif fail_on_unfulfilled_requirements:
- desc = "Required attribute missing: '%s' (%s)" % (attr["name"],
- _fn)
+ desc = "Required attribute missing: '%s'" % (attr["name"])
raise MissingValue(desc)
if optional is None:
@@ -502,6 +511,9 @@ class Policy(object):
_ava = None
+ if not self.acs: # acs MUST have a value, fall back to default.
+ self.acs = ac_factory()
+
_rest = self.get_entity_categories(sp_entity_id, mdstore, required)
if _rest:
_ava = filter_attribute_value_assertions(ava.copy(), _rest)