diff options
author | Rebecka Gulliksson <rebecka.gulliksson@umu.se> | 2015-12-28 19:25:08 +0100 |
---|---|---|
committer | Rebecka Gulliksson <rebecka.gulliksson@umu.se> | 2015-12-28 19:25:08 +0100 |
commit | 7a7b02d792a5ad17aa88336165069263fd387d89 (patch) | |
tree | 35001524083386fd5d4be75db81ffcdaf06c2b7d | |
parent | 344ba4a66badec2513701b4da6c8309c7d18cfe7 (diff) | |
download | pysaml2-7a7b02d792a5ad17aa88336165069263fd387d89.tar.gz |
Match the attribute name of optional attributes in the same way as for required attributes.
-rw-r--r-- | src/saml2/assertion.py | 139 | ||||
-rw-r--r-- | tests/test_20_assertion.py | 222 |
2 files changed, 187 insertions, 174 deletions
diff --git a/src/saml2/assertion.py b/src/saml2/assertion.py index 920083c1..20480319 100644 --- a/src/saml2/assertion.py +++ b/src/saml2/assertion.py @@ -16,7 +16,6 @@ from saml2.s_utils import sid, MissingValue from saml2.s_utils import factory from saml2.s_utils import assertion_factory - logger = logging.getLogger(__name__) @@ -78,25 +77,24 @@ def filter_on_attributes(ava, required=None, optional=None, acs=None, are missing fail or fail not depending on this parameter. :return: The modified attribute value assertion """ + + def _attr_name(attr): + """Get the friendly name of an attribute name""" + try: + return attr["friendly_name"] + except KeyError: + return get_local_name(acs, attr["name"], attr["name_format"]) + res = {} if required is None: required = [] - nform = "friendly_name" for attr in required: - try: - _name = attr[nform] - except KeyError: - if nform == "friendly_name": - _name = get_local_name(acs, attr["name"], - attr["name_format"]) - else: - continue - + _name = _attr_name(attr) _fn = _match(_name, ava) if not _fn: # In the unlikely case that someone has provided us - # with URIs as attribute names + # with URIs as attribute names _fn = _match(attr["name"], ava) if _fn: @@ -115,18 +113,17 @@ def filter_on_attributes(ava, required=None, optional=None, acs=None, optional = [] for attr in optional: - for nform in ["friendly_name", "name"]: - if nform in attr: - _fn = _match(attr[nform], ava) - if _fn: - try: - values = [av["text"] for av in attr["attribute_value"]] - except KeyError: - values = [] - try: - res[_fn].extend(_filter_values(ava[_fn], values)) - except KeyError: - res[_fn] = _filter_values(ava[_fn], values) + _name = _attr_name(attr) + _fn = _match(_name, ava) + if _fn: + try: + values = [av["text"] for av in attr["attribute_value"]] + except KeyError: + values = [] + try: + res[_fn].extend(_filter_values(ava[_fn], values)) + except KeyError: + res[_fn] = _filter_values(ava[_fn], values) return res @@ -154,8 +151,8 @@ def filter_on_demands(ava, required=None, optional=None): for val in vals: if val not in ava[lava[attr]]: raise MissingValue( - "Required attribute value missing: %s,%s" % (attr, - val)) + "Required attribute value missing: %s,%s" % (attr, + val)) else: raise MissingValue("Required attribute missing: %s" % (attr,)) @@ -334,7 +331,7 @@ class Policy(object): ecs = [] for cat in items: _mod = importlib.import_module( - "saml2.entity_category.%s" % cat) + "saml2.entity_category.%s" % cat) _ec = {} for key, items in _mod.RELEASE.items(): _ec[key] = [k.lower() for k in items] @@ -488,8 +485,8 @@ class Policy(object): if required or optional: logger.debug("required: %s, optional: %s", required, optional) _ava = filter_on_attributes( - ava.copy(), required, optional, self.acs, - self.get_fail_on_missing_requested(sp_entity_id)) + ava.copy(), required, optional, self.acs, + self.get_fail_on_missing_requested(sp_entity_id)) _rest = self.get_entity_categories(sp_entity_id, mdstore) if _rest: @@ -539,9 +536,9 @@ class Policy(object): # How long might depend on who's getting it not_on_or_after=self.not_on_or_after(sp_entity_id), audience_restriction=[factory( - saml.AudienceRestriction, - audience=[factory(saml.Audience, - text=sp_entity_id)])]) + saml.AudienceRestriction, + audience=[factory(saml.Audience, + text=sp_entity_id)])]) def get_sign(self, sp_entity_id): """ @@ -571,7 +568,7 @@ def _authn_context_class_ref(authn_class, authn_auth=None): return factory(saml.AuthnContext, authn_context_class_ref=cntx_class, authenticating_authority=factory( - saml.AuthenticatingAuthority, text=authn_auth)) + saml.AuthenticatingAuthority, text=authn_auth)) else: return factory(saml.AuthnContext, authn_context_class_ref=cntx_class) @@ -587,7 +584,7 @@ def _authn_context_decl(decl, authn_auth=None): return factory(saml.AuthnContext, authn_context_decl=decl, authenticating_authority=factory( - saml.AuthenticatingAuthority, text=authn_auth)) + saml.AuthenticatingAuthority, text=authn_auth)) def _authn_context_decl_ref(decl_ref, authn_auth=None): @@ -600,7 +597,7 @@ def _authn_context_decl_ref(decl_ref, authn_auth=None): return factory(saml.AuthnContext, authn_context_decl_ref=decl_ref, authenticating_authority=factory( - saml.AuthenticatingAuthority, text=authn_auth)) + saml.AuthenticatingAuthority, text=authn_auth)) def authn_statement(authn_class=None, authn_auth=None, @@ -626,29 +623,29 @@ def authn_statement(authn_class=None, authn_auth=None, if authn_class: res = factory( - saml.AuthnStatement, - authn_instant=_instant, - session_index=sid(), - authn_context=_authn_context_class_ref( - authn_class, authn_auth)) + saml.AuthnStatement, + authn_instant=_instant, + session_index=sid(), + authn_context=_authn_context_class_ref( + authn_class, authn_auth)) elif authn_decl: res = factory( - saml.AuthnStatement, - authn_instant=_instant, - session_index=sid(), - authn_context=_authn_context_decl(authn_decl, authn_auth)) + saml.AuthnStatement, + authn_instant=_instant, + session_index=sid(), + authn_context=_authn_context_decl(authn_decl, authn_auth)) elif authn_decl_ref: res = factory( - saml.AuthnStatement, - authn_instant=_instant, - session_index=sid(), - authn_context=_authn_context_decl_ref(authn_decl_ref, - authn_auth)) + saml.AuthnStatement, + authn_instant=_instant, + session_index=sid(), + authn_context=_authn_context_decl_ref(authn_decl_ref, + authn_auth)) else: res = factory( - saml.AuthnStatement, - authn_instant=_instant, - session_index=sid()) + saml.AuthnStatement, + authn_instant=_instant, + session_index=sid()) if subject_locality: res.subject_locality = saml.SubjectLocality(text=subject_locality) @@ -698,7 +695,7 @@ class Assertion(dict): _name_format = NAME_FORMAT_URI attr_statement = saml.AttributeStatement(attribute=from_local( - attrconvs, self, _name_format)) + attrconvs, self, _name_format)) if encrypt == "attributes": for attr in attr_statement.attribute: @@ -725,33 +722,33 @@ class Assertion(dict): if not add_subject: _ass = assertion_factory( - issuer=issuer, - conditions=conds, - subject=None + issuer=issuer, + conditions=conds, + subject=None ) else: _ass = assertion_factory( - issuer=issuer, - conditions=conds, - subject=factory( - saml.Subject, - name_id=name_id, - subject_confirmation=[factory( - saml.SubjectConfirmation, - method=saml.SCM_BEARER, - subject_confirmation_data=factory( - saml.SubjectConfirmationData, - in_response_to=in_response_to, - recipient=consumer_url, - not_on_or_after=policy.not_on_or_after(sp_entity_id)))] - ), + issuer=issuer, + conditions=conds, + subject=factory( + saml.Subject, + name_id=name_id, + subject_confirmation=[factory( + saml.SubjectConfirmation, + method=saml.SCM_BEARER, + subject_confirmation_data=factory( + saml.SubjectConfirmationData, + in_response_to=in_response_to, + recipient=consumer_url, + not_on_or_after=policy.not_on_or_after(sp_entity_id)))] + ), ) if _authn_statement: _ass.authn_statement = [_authn_statement] if not attr_statement.empty(): - _ass.attribute_statement=[attr_statement] + _ass.attribute_statement = [attr_statement] return _ass diff --git a/tests/test_20_assertion.py b/tests/test_20_assertion.py index 67e22a54..4b3c0ea9 100644 --- a/tests/test_20_assertion.py +++ b/tests/test_20_assertion.py @@ -50,6 +50,7 @@ mail = to_dict(md.RequestedAttribute(name="urn:oid:0.9.2342.19200300.100.1.3", friendly_name="mail", name_format=NAME_FORMAT_URI), ONTS) + # --------------------------------------------------------------------------- @@ -76,6 +77,20 @@ def test_filter_on_attributes_1(): assert list(ava.keys()) == ["serialNumber"] assert ava["serialNumber"] == ["12345"] + +def test_filter_on_attributes_without_friendly_name(): + ava = {"eduPersonTargetedID": "test@example.com", "eduPersonAffiliation": "test", + "extra": "foo"} + eptid = to_dict(Attribute(name="urn:oid:1.3.6.1.4.1.5923.1.1.1.10", name_format=NAME_FORMAT_URI), ONTS) + ep_affiliation = to_dict( + Attribute(name="urn:oid:1.3.6.1.4.1.5923.1.1.1.1", name_format=NAME_FORMAT_URI), ONTS) + + restricted_ava = filter_on_attributes(ava, required=[eptid], optional=[ep_affiliation], + acs=ac_factory()) + assert restricted_ava == {"eduPersonTargetedID": "test@example.com", + "eduPersonAffiliation": "test"} + + # ---------------------------------------------------------------------- def test_lifetime_1(): @@ -173,8 +188,8 @@ def test_ava_filter_2(): "mail": "derek@example.com"} # mail removed because it doesn't match the regular expression - _ava =policy.filter(ava, 'urn:mace:umu.se:saml:roland:sp', None, [mail], - [gn, sn]) + _ava = policy.filter(ava, 'urn:mace:umu.se:saml:roland:sp', None, [mail], + [gn, sn]) assert _eq(sorted(list(_ava.keys())), ["givenName", "surName"]) @@ -214,7 +229,7 @@ def test_ava_filter_dont_fail(): # mail removed because it doesn't match the regular expression # So it should fail if the 'fail_on_ ...' flag wasn't set - _ava = policy.filter(ava,'urn:mace:umu.se:saml:roland:sp', None, + _ava = policy.filter(ava, 'urn:mace:umu.se:saml:roland:sp', None, [mail], [gn, sn]) assert _ava @@ -228,6 +243,7 @@ def test_ava_filter_dont_fail(): assert _ava + def test_filter_attribute_value_assertions_0(AVA): p = Policy({ "default": { @@ -382,9 +398,9 @@ def test_filter_values_req_2(): def test_filter_values_req_3(): a = to_dict( - Attribute(name="urn:oid:2.5.4.5", name_format=NAME_FORMAT_URI, - friendly_name="serialNumber", - attribute_value=[AttributeValue(text="12345")]), ONTS) + Attribute(name="urn:oid:2.5.4.5", name_format=NAME_FORMAT_URI, + friendly_name="serialNumber", + attribute_value=[AttributeValue(text="12345")]), ONTS) required = [a] ava = {"serialNumber": ["12345"]} @@ -396,9 +412,9 @@ def test_filter_values_req_3(): def test_filter_values_req_4(): a = to_dict( - Attribute(name="urn:oid:2.5.4.5", name_format=NAME_FORMAT_URI, - friendly_name="serialNumber", - attribute_value=[AttributeValue(text="54321")]), ONTS) + Attribute(name="urn:oid:2.5.4.5", name_format=NAME_FORMAT_URI, + friendly_name="serialNumber", + attribute_value=[AttributeValue(text="54321")]), ONTS) required = [a] ava = {"serialNumber": ["12345"]} @@ -408,9 +424,9 @@ def test_filter_values_req_4(): def test_filter_values_req_5(): a = to_dict( - Attribute(name="urn:oid:2.5.4.5", name_format=NAME_FORMAT_URI, - friendly_name="serialNumber", - attribute_value=[AttributeValue(text="12345")]), ONTS) + Attribute(name="urn:oid:2.5.4.5", name_format=NAME_FORMAT_URI, + friendly_name="serialNumber", + attribute_value=[AttributeValue(text="12345")]), ONTS) required = [a] ava = {"serialNumber": ["12345", "54321"]} @@ -422,9 +438,9 @@ def test_filter_values_req_5(): def test_filter_values_req_6(): a = to_dict( - Attribute(name="urn:oid:2.5.4.5", name_format=NAME_FORMAT_URI, - friendly_name="serialNumber", - attribute_value=[AttributeValue(text="54321")]), ONTS) + Attribute(name="urn:oid:2.5.4.5", name_format=NAME_FORMAT_URI, + friendly_name="serialNumber", + attribute_value=[AttributeValue(text="54321")]), ONTS) required = [a] ava = {"serialNumber": ["12345", "54321"]} @@ -436,13 +452,13 @@ def test_filter_values_req_6(): def test_filter_values_req_opt_0(): r = to_dict( - Attribute(name="urn:oid:2.5.4.5", name_format=NAME_FORMAT_URI, - friendly_name="serialNumber", - attribute_value=[AttributeValue(text="54321")]), ONTS) + Attribute(name="urn:oid:2.5.4.5", name_format=NAME_FORMAT_URI, + friendly_name="serialNumber", + attribute_value=[AttributeValue(text="54321")]), ONTS) o = to_dict( - Attribute(name="urn:oid:2.5.4.5", name_format=NAME_FORMAT_URI, - friendly_name="serialNumber", - attribute_value=[AttributeValue(text="12345")]), ONTS) + Attribute(name="urn:oid:2.5.4.5", name_format=NAME_FORMAT_URI, + friendly_name="serialNumber", + attribute_value=[AttributeValue(text="12345")]), ONTS) ava = {"serialNumber": ["12345", "54321"]} @@ -453,14 +469,14 @@ def test_filter_values_req_opt_0(): def test_filter_values_req_opt_1(): r = to_dict( - Attribute(name="urn:oid:2.5.4.5", name_format=NAME_FORMAT_URI, - friendly_name="serialNumber", - attribute_value=[AttributeValue(text="54321")]), ONTS) + Attribute(name="urn:oid:2.5.4.5", name_format=NAME_FORMAT_URI, + friendly_name="serialNumber", + attribute_value=[AttributeValue(text="54321")]), ONTS) o = to_dict( - Attribute(name="urn:oid:2.5.4.5", name_format=NAME_FORMAT_URI, - friendly_name="serialNumber", - attribute_value=[AttributeValue(text="12345"), - AttributeValue(text="abcd0")]), ONTS) + Attribute(name="urn:oid:2.5.4.5", name_format=NAME_FORMAT_URI, + friendly_name="serialNumber", + attribute_value=[AttributeValue(text="12345"), + AttributeValue(text="abcd0")]), ONTS) ava = {"serialNumber": ["12345", "54321"]} @@ -472,30 +488,30 @@ def test_filter_values_req_opt_1(): def test_filter_values_req_opt_2(): r = [ to_dict( - Attribute( - friendly_name="surName", - name="urn:oid:2.5.4.4", - name_format="urn:oasis:names:tc:SAML:2.0:attrname-format:uri"), - ONTS), + Attribute( + friendly_name="surName", + name="urn:oid:2.5.4.4", + name_format="urn:oasis:names:tc:SAML:2.0:attrname-format:uri"), + ONTS), to_dict( - Attribute( - friendly_name="givenName", - name="urn:oid:2.5.4.42", - name_format="urn:oasis:names:tc:SAML:2.0:attrname-format:uri"), - ONTS), + Attribute( + friendly_name="givenName", + name="urn:oid:2.5.4.42", + name_format="urn:oasis:names:tc:SAML:2.0:attrname-format:uri"), + ONTS), to_dict( - Attribute( - friendly_name="mail", - name="urn:oid:0.9.2342.19200300.100.1.3", - name_format="urn:oasis:names:tc:SAML:2.0:attrname-format:uri"), - ONTS)] + Attribute( + friendly_name="mail", + name="urn:oid:0.9.2342.19200300.100.1.3", + name_format="urn:oasis:names:tc:SAML:2.0:attrname-format:uri"), + ONTS)] o = [ to_dict( - Attribute( - friendly_name="title", - name="urn:oid:2.5.4.12", - name_format="urn:oasis:names:tc:SAML:2.0:attrname-format:uri"), - ONTS)] + Attribute( + friendly_name="title", + name="urn:oid:2.5.4.12", + name_format="urn:oasis:names:tc:SAML:2.0:attrname-format:uri"), + ONTS)] ava = {"surname": ["Hedberg"], "givenName": ["Roland"], "eduPersonAffiliation": ["staff"], "uid": ["rohe0002"]} @@ -509,18 +525,18 @@ def test_filter_values_req_opt_2(): def test_filter_values_req_opt_4(): r = [ Attribute( - friendly_name="surName", - name="urn:oid:2.5.4.4", - name_format="urn:oasis:names:tc:SAML:2.0:attrname-format:uri"), + friendly_name="surName", + name="urn:oid:2.5.4.4", + name_format="urn:oasis:names:tc:SAML:2.0:attrname-format:uri"), Attribute( - friendly_name="givenName", - name="urn:oid:2.5.4.42", - name_format="urn:oasis:names:tc:SAML:2.0:attrname-format:uri")] + friendly_name="givenName", + name="urn:oid:2.5.4.42", + name_format="urn:oasis:names:tc:SAML:2.0:attrname-format:uri")] o = [ Attribute( - friendly_name="title", - name="urn:oid:2.5.4.12", - name_format="urn:oasis:names:tc:SAML:2.0:attrname-format:uri")] + friendly_name="title", + name="urn:oid:2.5.4.12", + name_format="urn:oasis:names:tc:SAML:2.0:attrname-format:uri")] acs = attribute_converter.ac_factory(full_path("attributemaps")) @@ -541,15 +557,15 @@ def test_filter_values_req_opt_4(): def test_filter_ava_0(): policy = Policy( - { - "default": { - "lifetime": {"minutes": 15}, - "attribute_restrictions": None # means all I have - }, - "urn:mace:example.com:saml:roland:sp": { - "lifetime": {"minutes": 5}, + { + "default": { + "lifetime": {"minutes": 15}, + "attribute_restrictions": None # means all I have + }, + "urn:mace:example.com:saml:roland:sp": { + "lifetime": {"minutes": 5}, + } } - } ) ava = {"givenName": ["Derek"], "surName": ["Jeter"], @@ -665,30 +681,30 @@ def test_filter_ava_4(): def test_req_opt(): req = [ to_dict( - md.RequestedAttribute( - friendly_name="surname", name="urn:oid:2.5.4.4", - name_format="urn:oasis:names:tc:SAML:2.0:attrname-format:uri", - is_required="true"), ONTS), + md.RequestedAttribute( + friendly_name="surname", name="urn:oid:2.5.4.4", + name_format="urn:oasis:names:tc:SAML:2.0:attrname-format:uri", + is_required="true"), ONTS), to_dict( - md.RequestedAttribute( - friendly_name="givenname", - name="urn:oid:2.5.4.42", - name_format="urn:oasis:names:tc:SAML:2.0:attrname-format:uri", - is_required="true"), ONTS), + md.RequestedAttribute( + friendly_name="givenname", + name="urn:oid:2.5.4.42", + name_format="urn:oasis:names:tc:SAML:2.0:attrname-format:uri", + is_required="true"), ONTS), to_dict( - md.RequestedAttribute( - friendly_name="edupersonaffiliation", - name="urn:oid:1.3.6.1.4.1.5923.1.1.1.1", - name_format="urn:oasis:names:tc:SAML:2.0:attrname-format:uri", - is_required="true"), ONTS)] + md.RequestedAttribute( + friendly_name="edupersonaffiliation", + name="urn:oid:1.3.6.1.4.1.5923.1.1.1.1", + name_format="urn:oasis:names:tc:SAML:2.0:attrname-format:uri", + is_required="true"), ONTS)] opt = [ to_dict( - md.RequestedAttribute( - friendly_name="title", - name="urn:oid:2.5.4.12", - name_format="urn:oasis:names:tc:SAML:2.0:attrname-format:uri", - is_required="false"), ONTS)] + md.RequestedAttribute( + friendly_name="title", + name="urn:oid:2.5.4.12", + name_format="urn:oasis:names:tc:SAML:2.0:attrname-format:uri", + is_required="false"), ONTS)] policy = Policy() ava = {'givenname': 'Roland', 'surname': 'Hedberg', @@ -702,18 +718,18 @@ def test_req_opt(): def test_filter_on_wire_representation_1(): r = [ Attribute( - friendly_name="surName", - name="urn:oid:2.5.4.4", - name_format="urn:oasis:names:tc:SAML:2.0:attrname-format:uri"), + friendly_name="surName", + name="urn:oid:2.5.4.4", + name_format="urn:oasis:names:tc:SAML:2.0:attrname-format:uri"), Attribute( - friendly_name="givenName", - name="urn:oid:2.5.4.42", - name_format="urn:oasis:names:tc:SAML:2.0:attrname-format:uri")] + friendly_name="givenName", + name="urn:oid:2.5.4.42", + name_format="urn:oasis:names:tc:SAML:2.0:attrname-format:uri")] o = [ Attribute( - friendly_name="title", - name="urn:oid:2.5.4.12", - name_format="urn:oasis:names:tc:SAML:2.0:attrname-format:uri")] + friendly_name="title", + name="urn:oid:2.5.4.12", + name_format="urn:oasis:names:tc:SAML:2.0:attrname-format:uri")] acs = attribute_converter.ac_factory(full_path("attributemaps")) @@ -727,18 +743,18 @@ def test_filter_on_wire_representation_1(): def test_filter_on_wire_representation_2(): r = [ Attribute( - friendly_name="surName", - name="urn:oid:2.5.4.4", - name_format="urn:oasis:names:tc:SAML:2.0:attrname-format:uri"), + friendly_name="surName", + name="urn:oid:2.5.4.4", + name_format="urn:oasis:names:tc:SAML:2.0:attrname-format:uri"), Attribute( - friendly_name="givenName", - name="urn:oid:2.5.4.42", - name_format="urn:oasis:names:tc:SAML:2.0:attrname-format:uri")] + friendly_name="givenName", + name="urn:oid:2.5.4.42", + name_format="urn:oasis:names:tc:SAML:2.0:attrname-format:uri")] o = [ Attribute( - friendly_name="title", - name="urn:oid:2.5.4.12", - name_format="urn:oasis:names:tc:SAML:2.0:attrname-format:uri")] + friendly_name="title", + name="urn:oid:2.5.4.12", + name_format="urn:oasis:names:tc:SAML:2.0:attrname-format:uri")] acs = attribute_converter.ac_factory(full_path("attributemaps")) @@ -784,7 +800,7 @@ def test_assertion_with_noop_attribute_conv(): # THis test doesn't work without a MetadataStore instance -#def test_filter_ava_5(): +# def test_filter_ava_5(): # policy = Policy({ # "default": { # "lifetime": {"minutes": 15}, |