summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRoland Hedberg <roland.hedberg@adm.umu.se>2016-05-16 20:48:56 +0200
committerRoland Hedberg <roland.hedberg@adm.umu.se>2016-05-16 20:48:56 +0200
commit9ef92af7d8d444be4269cd9a35342deb58850a18 (patch)
tree7723cb3a289e0160c03e6b22644ab4d8009b5e8d
parent9c04dc7ebbae1791b4234bc0f4f2949f1fe9f3c4 (diff)
downloadpysaml2-9ef92af7d8d444be4269cd9a35342deb58850a18.tar.gz
Deal with entity category (CoCo) that have more complex evaluation rules.
-rwxr-xr-xexample/sp-wsgi/sp.py7
-rw-r--r--src/saml2/assertion.py111
-rw-r--r--src/saml2/attributemaps/saml_uri.py2
-rw-r--r--src/saml2/entity_category/edugain.py1
4 files changed, 71 insertions, 50 deletions
diff --git a/example/sp-wsgi/sp.py b/example/sp-wsgi/sp.py
index d1c182e6..38a737ee 100755
--- a/example/sp-wsgi/sp.py
+++ b/example/sp-wsgi/sp.py
@@ -356,9 +356,14 @@ class ACS(Service):
return resp(self.environ, self.start_response)
try:
+ conv_info = {'remote_addr': self.environ['REMOTE_ADDR'],
+ 'request_uri': self.environ['REQUEST_URI'],
+ 'entity_id': self.sp.config.entityid,
+ 'endpoints': self.sp.config.getattr('endpoints', 'sp')}
+
self.response = self.sp.parse_authn_request_response(
response, binding, self.outstanding_queries,
- self.cache.outstanding_certs)
+ self.cache.outstanding_certs, conv_info=conv_info)
except UnknownPrincipal as excp:
logger.error("UnknownPrincipal: %s", excp)
resp = ServiceError("UnknownPrincipal: %s" % (excp,))
diff --git a/src/saml2/assertion.py b/src/saml2/assertion.py
index 3a5220ec..1b678b63 100644
--- a/src/saml2/assertion.py
+++ b/src/saml2/assertion.py
@@ -82,10 +82,12 @@ def filter_on_attributes(ava, required=None, optional=None, acs=None,
try:
friendly_name = attr["friendly_name"]
except KeyError:
- friendly_name = get_local_name(acs, attr["name"], attr["name_format"])
+ friendly_name = get_local_name(acs, attr["name"],
+ attr["name_format"])
_fn = _match(friendly_name, ava)
- if not _fn: # In the unlikely case that someone has provided us with URIs as attribute names
+ if not _fn: # In the unlikely case that someone has provided us with
+ # URIs as attribute names
_fn = _match(attr["name"], ava)
return _fn
@@ -152,8 +154,8 @@ def filter_on_demands(ava, required=None, optional=None):
for val in vals:
if val not in ava[lava[attr]]:
raise MissingValue(
- "Required attribute value missing: %s,%s" % (attr,
- val))
+ "Required attribute value missing: %s,%s" % (attr,
+ val))
else:
raise MissingValue("Required attribute missing: %s" % (attr,))
@@ -266,6 +268,11 @@ def restriction_from_attribute_spec(attributes):
def post_entity_categories(maps, **kwargs):
restrictions = {}
+ try:
+ required = [d['friendly_name'].lower() for d in kwargs['required']]
+ except KeyError:
+ required = []
+
if kwargs["mds"]:
try:
ecs = kwargs["mds"].entity_categories(kwargs["sp_entity_id"])
@@ -275,11 +282,14 @@ def post_entity_categories(maps, **kwargs):
restrictions[attr] = None
else:
for ec_map in maps:
- for key, val in ec_map.items():
+ for key, (atlist, only_required) in ec_map.items():
if key == "": # always released
- attrs = val
+ attrs = atlist
elif isinstance(key, tuple):
- attrs = val
+ if only_required:
+ attrs = [a for a in atlist if a in required]
+ else:
+ attrs = atlist
for _key in key:
try:
assert _key in ecs
@@ -287,7 +297,10 @@ def post_entity_categories(maps, **kwargs):
attrs = []
break
elif key in ecs:
- attrs = val
+ if only_required:
+ attrs = [a for a in atlist if a in required]
+ else:
+ attrs = atlist
else:
attrs = []
@@ -332,10 +345,15 @@ class Policy(object):
ecs = []
for cat in items:
_mod = importlib.import_module(
- "saml2.entity_category.%s" % cat)
+ "saml2.entity_category.%s" % cat)
_ec = {}
for key, items in _mod.RELEASE.items():
- _ec[key] = [k.lower() for k in items]
+ alist = [k.lower() for k in items]
+ try:
+ _only_required = _mod.ONLY_REQUIRED[key]
+ except (AttributeError, KeyError):
+ _only_required = False
+ _ec[key] = (alist, _only_required)
ecs.append(_ec)
spec["entity_categories"] = ecs
try:
@@ -444,7 +462,7 @@ class Policy(object):
pass
return []
- def get_entity_categories(self, sp_entity_id, mds):
+ def get_entity_categories(self, sp_entity_id, mds, required):
"""
:param sp_entity_id:
@@ -452,7 +470,7 @@ class Policy(object):
:return: A dictionary with restrictions
"""
- kwargs = {"mds": mds}
+ kwargs = {"mds": mds, 'required': required}
return self.get("entity_categories", sp_entity_id, default={},
post_func=post_entity_categories, **kwargs)
@@ -483,19 +501,15 @@ class Policy(object):
"""
_ava = None
- if required or optional:
- logger.debug("required: %s, optional: %s", required, optional)
- _ava = filter_on_attributes(
- ava.copy(), required, optional, self.acs,
- self.get_fail_on_missing_requested(sp_entity_id))
- _rest = self.get_entity_categories(sp_entity_id, mdstore)
+ _rest = self.get_entity_categories(sp_entity_id, mdstore, required)
if _rest:
- ava_ec = filter_attribute_value_assertions(ava.copy(), _rest)
- if _ava is None:
- _ava = ava_ec
- else:
- _ava.update(ava_ec)
+ _ava = filter_attribute_value_assertions(ava.copy(), _rest)
+ elif required or optional:
+ logger.debug("required: %s, optional: %s", required, optional)
+ _ava = filter_on_attributes(
+ ava.copy(), required, optional, self.acs,
+ self.get_fail_on_missing_requested(sp_entity_id))
_rest = self.get_attribute_restrictions(sp_entity_id)
if _rest:
@@ -537,9 +551,9 @@ class Policy(object):
# How long might depend on who's getting it
not_on_or_after=self.not_on_or_after(sp_entity_id),
audience_restriction=[factory(
- saml.AudienceRestriction,
- audience=[factory(saml.Audience,
- text=sp_entity_id)])])
+ saml.AudienceRestriction,
+ audience=[factory(saml.Audience,
+ text=sp_entity_id)])])
def get_sign(self, sp_entity_id):
"""
@@ -569,7 +583,7 @@ def _authn_context_class_ref(authn_class, authn_auth=None):
return factory(saml.AuthnContext,
authn_context_class_ref=cntx_class,
authenticating_authority=factory(
- saml.AuthenticatingAuthority, text=authn_auth))
+ saml.AuthenticatingAuthority, text=authn_auth))
else:
return factory(saml.AuthnContext,
authn_context_class_ref=cntx_class)
@@ -585,7 +599,7 @@ def _authn_context_decl(decl, authn_auth=None):
return factory(saml.AuthnContext,
authn_context_decl=decl,
authenticating_authority=factory(
- saml.AuthenticatingAuthority, text=authn_auth))
+ saml.AuthenticatingAuthority, text=authn_auth))
def _authn_context_decl_ref(decl_ref, authn_auth=None):
@@ -598,7 +612,7 @@ def _authn_context_decl_ref(decl_ref, authn_auth=None):
return factory(saml.AuthnContext,
authn_context_decl_ref=decl_ref,
authenticating_authority=factory(
- saml.AuthenticatingAuthority, text=authn_auth))
+ saml.AuthenticatingAuthority, text=authn_auth))
def authn_statement(authn_class=None, authn_auth=None,
@@ -624,29 +638,29 @@ def authn_statement(authn_class=None, authn_auth=None,
if authn_class:
res = factory(
- saml.AuthnStatement,
- authn_instant=_instant,
- session_index=sid(),
- authn_context=_authn_context_class_ref(
- authn_class, authn_auth))
+ saml.AuthnStatement,
+ authn_instant=_instant,
+ session_index=sid(),
+ authn_context=_authn_context_class_ref(
+ authn_class, authn_auth))
elif authn_decl:
res = factory(
- saml.AuthnStatement,
- authn_instant=_instant,
- session_index=sid(),
- authn_context=_authn_context_decl(authn_decl, authn_auth))
+ saml.AuthnStatement,
+ authn_instant=_instant,
+ session_index=sid(),
+ authn_context=_authn_context_decl(authn_decl, authn_auth))
elif authn_decl_ref:
res = factory(
- saml.AuthnStatement,
- authn_instant=_instant,
- session_index=sid(),
- authn_context=_authn_context_decl_ref(authn_decl_ref,
- authn_auth))
+ saml.AuthnStatement,
+ authn_instant=_instant,
+ session_index=sid(),
+ authn_context=_authn_context_decl_ref(authn_decl_ref,
+ authn_auth))
else:
res = factory(
- saml.AuthnStatement,
- authn_instant=_instant,
- session_index=sid())
+ saml.AuthnStatement,
+ authn_instant=_instant,
+ session_index=sid())
if subject_locality:
res.subject_locality = saml.SubjectLocality(text=subject_locality)
@@ -688,7 +702,8 @@ def do_subject(policy, sp_entity_id, name_id, **farg):
specs = farg['subject_confirmation']
if isinstance(specs, list):
- res = [do_subject_confirmation(policy, sp_entity_id, **s) for s in specs]
+ res = [do_subject_confirmation(policy, sp_entity_id, **s) for s in
+ specs]
else:
res = [do_subject_confirmation(policy, sp_entity_id, **specs)]
@@ -736,7 +751,7 @@ class Assertion(dict):
_name_format = NAME_FORMAT_URI
attr_statement = saml.AttributeStatement(attribute=from_local(
- attrconvs, self, _name_format))
+ attrconvs, self, _name_format))
if encrypt == "attributes":
for attr in attr_statement.attribute:
diff --git a/src/saml2/attributemaps/saml_uri.py b/src/saml2/attributemaps/saml_uri.py
index e905be4b..afdfee25 100644
--- a/src/saml2/attributemaps/saml_uri.py
+++ b/src/saml2/attributemaps/saml_uri.py
@@ -26,7 +26,7 @@ MAP = {
EDUPERSON_OID+'6': 'eduPersonPrimaryAffiliation',
EDUPERSON_OID+'7': 'eduPersonPrimaryOrgUnitDN',
EDUPERSON_OID+'8': 'eduPersonPrincipalName',
- EDUPERSON_OID+'9': 'eduPersonPrincipalNamePrior',
+ EDUPERSON_OID+'9': 'eduPersonPrincipalName',
EDUPERSON_OID+'10': 'eduPersonScopedAffiliation',
EDUPERSON_OID+'11': 'eduPersonTargetedID',
EDUPERSON_OID+'12': 'eduPersonAssurance',
diff --git a/src/saml2/entity_category/edugain.py b/src/saml2/entity_category/edugain.py
index 62a50941..533af426 100644
--- a/src/saml2/entity_category/edugain.py
+++ b/src/saml2/entity_category/edugain.py
@@ -12,3 +12,4 @@ RELEASE = {
"schacHomeOrganization"]
}
+ONLY_REQUIRED = {COCO: True}