summaryrefslogtreecommitdiff
path: root/src/saml2/assertion.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/saml2/assertion.py')
-rw-r--r--src/saml2/assertion.py422
1 files changed, 209 insertions, 213 deletions
diff --git a/src/saml2/assertion.py b/src/saml2/assertion.py
index dd18affb..30cd90ab 100644
--- a/src/saml2/assertion.py
+++ b/src/saml2/assertion.py
@@ -5,6 +5,7 @@ import importlib
import logging
import re
import six
+from warnings import warn as _warn
from saml2 import saml
from saml2 import xmlenc
@@ -18,6 +19,7 @@ from saml2.saml import NAME_FORMAT_URI
from saml2.time_util import instant
from saml2.time_util import in_a_while
+
logger = logging.getLogger(__name__)
@@ -276,52 +278,54 @@ def restriction_from_attribute_spec(attributes):
return restr
-def post_entity_categories(maps, **kwargs):
- restrictions = {}
- try:
- required = [d['friendly_name'].lower() for d in kwargs['required']]
- except (KeyError, TypeError):
- required = []
+def compile(restrictions):
+ """ This is only for IdPs or AAs, and it's about limiting what
+ is returned to the SP.
+ In the configuration file, restrictions on which values that
+ can be returned are specified with the help of regular expressions.
+ This function goes through and pre-compiles the regular expressions.
- if kwargs["mds"]:
- if "sp_entity_id" in kwargs:
- ecs = kwargs["mds"].entity_categories(kwargs["sp_entity_id"])
- for ec_map in maps:
- for key, (atlist, only_required) in ec_map.items():
- if key == "": # always released
- attrs = atlist
- elif isinstance(key, tuple):
- if only_required:
- attrs = [a for a in atlist if a in required]
- else:
- attrs = atlist
- for _key in key:
- if _key not in ecs:
- attrs = []
- break
- elif key in ecs:
- if only_required:
- attrs = [a for a in atlist if a in required]
- else:
- attrs = atlist
- else:
- attrs = []
+ :param restrictions: policy configuration
+ :return: The assertion with the string specification replaced with
+ a compiled regular expression.
+ """
+ for who, spec in restrictions.items():
+ spec = spec or {}
- for attr in attrs:
- restrictions[attr] = None
- else:
- for ec_map in maps:
- for attr in ec_map[""]:
- restrictions[attr] = None
+ entity_categories = spec.get("entity_categories", [])
+ ecs = []
+ for cat in entity_categories:
+ try:
+ _mod = importlib.import_module(cat)
+ except ImportError:
+ _mod = importlib.import_module("saml2.entity_category.%s" % cat)
+
+ _ec = {}
+ for key, items in _mod.RELEASE.items():
+ alist = [k.lower() for k in items]
+ _only_required = getattr(_mod, "ONLY_REQUIRED", {}).get(key, False)
+ _ec[key] = (alist, _only_required)
+ ecs.append(_ec)
+ spec["entity_categories"] = ecs or None
+
+ attribute_restrictions = spec.get("attribute_restrictions") or {}
+ _attribute_restrictions = {}
+ for key, values in attribute_restrictions.items():
+ lkey = key.lower()
+ values = [] if not values else values
+ _attribute_restrictions[lkey] = (
+ [re.compile(value) for value in values] or None
+ )
+ spec["attribute_restrictions"] = _attribute_restrictions or None
return restrictions
class Policy(object):
- """ handles restrictions on assertions """
+ """Handles restrictions on assertions."""
- def __init__(self, restrictions=None, config=None):
- self._config = config
+ def __init__(self, restrictions=None, mds=None):
+ self.metadata_store = mds
self._restrictions = self.setup_restrictions(restrictions)
logger.debug("policy restrictions: %s", self._restrictions)
self.acs = []
@@ -331,116 +335,50 @@ class Policy(object):
return None
restrictions = copy.deepcopy(restrictions)
- # TODO: Split policy config in service_providers and registration_authorities
- # "policy": {
- # "service_providers": {
- # "default": ...,
- # "urn:mace:example.com:saml:roland:sp": ...,
- # },
- # "registration_authorities": {
- # "default": ...,
- # "http://www.swamid.se": ...,
- # },
- # },
- registration_authorities = restrictions.pop('registration_authorities', None)
- restrictions = self.compile(restrictions)
- if registration_authorities:
- restrictions['registration_authorities'] = self.compile(registration_authorities)
+ restrictions = compile(restrictions)
return restrictions
- @staticmethod
- def compile(restrictions):
- """ This is only for IdPs or AAs, and it's about limiting what
- is returned to the SP.
- In the configuration file, restrictions on which values that
- can be returned are specified with the help of regular expressions.
- This function goes through and pre-compiles the regular expressions.
-
- :param restrictions: policy configuration
- :return: The assertion with the string specification replaced with
- a compiled regular expression.
- """
- for who, spec in restrictions.items():
- if spec is None:
- continue
-
- entity_categories = spec.get("entity_categories")
- if entity_categories is not None:
- ecs = []
- for cat in entity_categories:
- try:
- _mod = importlib.import_module(cat)
- except ImportError:
- _mod = importlib.import_module(
- "saml2.entity_category.%s" % cat)
- _ec = {}
- for key, items in _mod.RELEASE.items():
- alist = [k.lower() for k in items]
- try:
- _only_required = _mod.ONLY_REQUIRED[key]
- except (AttributeError, KeyError):
- _only_required = False
- _ec[key] = (alist, _only_required)
- ecs.append(_ec)
- spec["entity_categories"] = ecs
-
- attribute_restrictions = spec.get("attribute_restrictions")
- if attribute_restrictions is None:
- continue
-
- _attribute_restrictions = {}
- for key, values in attribute_restrictions.items():
- if not values:
- _attribute_restrictions[key.lower()] = None
- continue
- _attribute_restrictions[key.lower()] = [re.compile(value) for value in values]
-
- spec["attribute_restrictions"] = _attribute_restrictions
-
- return restrictions
-
- def _lookup_registry_authority(self, sp_entity_id):
- if self._config and self._config.metadata:
- registration_info = self._config.metadata.registration_info(sp_entity_id)
- return registration_info.get('registration_authority')
- return None
-
- def get(self, attribute, sp_entity_id, default=None, post_func=None,
- **kwargs):
+ def get(self, attribute, sp_entity_id, default=None):
"""
:param attribute:
:param sp_entity_id:
:param default:
- :param post_func:
:return:
"""
if not self._restrictions:
return default
- registration_authority_name = self._lookup_registry_authority(sp_entity_id)
- registration_authorities = self._restrictions.get("registration_authorities")
-
- val = None
- # Specific SP takes precedence
- if sp_entity_id in self._restrictions:
- val = self._restrictions[sp_entity_id].get(attribute)
- # Second choice is if the SP is part of a configured registration authority
- elif registration_authorities and registration_authority_name in registration_authorities:
- val = registration_authorities[registration_authority_name].get(attribute)
- # Third is to try default for registration authorities
- elif registration_authorities and 'default' in registration_authorities:
- val = registration_authorities['default'].get(attribute)
- # Lastly we try default for SPs
- elif 'default' in self._restrictions:
- val = self._restrictions.get('default').get(attribute)
-
- if val is None:
- return default
- elif post_func:
- return post_func(val, sp_entity_id=sp_entity_id, **kwargs)
- else:
- return val
+ ra_info = (
+ self.metadata_store.registration_info(sp_entity_id) or {}
+ if self.metadata_store is not None
+ else {}
+ )
+ ra_entity_id = ra_info.get("registration_authority")
+
+ sp_restrictions = self._restrictions.get(sp_entity_id)
+ ra_restrictions = self._restrictions.get(ra_entity_id)
+ default_restrictions = (
+ self._restrictions.get("default")
+ or self._restrictions.get("")
+ )
+ restrictions = (
+ sp_restrictions
+ if sp_restrictions is not None
+ else ra_restrictions
+ if ra_restrictions is not None
+ else default_restrictions
+ if default_restrictions is not None
+ else {}
+ )
+
+ attribute_restriction = restrictions.get(attribute)
+ restriction = (
+ attribute_restriction
+ if attribute_restriction is not None
+ else default
+ )
+ return restriction
def get_nameid_format(self, sp_entity_id):
""" Get the NameIDFormat to used for the entity id
@@ -484,18 +422,78 @@ class Policy(object):
return self.get("fail_on_missing_requested", sp_entity_id, default=True)
- def get_entity_categories(self, sp_entity_id, mds, required):
+ def get_sign(self, sp_entity_id):
+ """
+ Possible choices
+ "sign": ["response", "assertion", "on_demand"]
+
+ :param sp_entity_id:
+ :return:
+ """
+
+ return self.get("sign", sp_entity_id, default=[])
+
+ def get_entity_categories(self, sp_entity_id, mds=None, required=None):
"""
:param sp_entity_id:
- :param mds: MetadataStore instance
:param required: required attributes
:return: A dictionary with restrictions
"""
- kwargs = {"mds": mds, 'required': required}
+ if mds is not None:
+ warn_msg = (
+ "The mds parameter for saml2.assertion.Policy.get_entity_categories "
+ "is deprecated; "
+ "instead, initialize the Policy object setting the mds param."
+ )
+ logger.warning(warn_msg)
+ _warn(warn_msg, DeprecationWarning)
+
+ def post_entity_categories(maps, sp_entity_id=None, mds=None, required=None):
+ restrictions = {}
+ required = [d['friendly_name'].lower() for d in (required or [])]
+
+ if mds:
+ ecs = mds.entity_categories(sp_entity_id)
+ for ec_map in maps:
+ for key, (atlist, only_required) in ec_map.items():
+ if key == "": # always released
+ attrs = atlist
+ elif isinstance(key, tuple):
+ if only_required:
+ attrs = [a for a in atlist if a in required]
+ else:
+ attrs = atlist
+ for _key in key:
+ if _key not in ecs:
+ attrs = []
+ break
+ elif key in ecs:
+ if only_required:
+ attrs = [a for a in atlist if a in required]
+ else:
+ attrs = atlist
+ else:
+ attrs = []
+
+ for attr in attrs:
+ restrictions[attr] = None
- return self.get("entity_categories", sp_entity_id, default={}, post_func=post_entity_categories, **kwargs)
+ return restrictions
+
+ sentinel = object()
+ result1 = self.get("entity_categories", sp_entity_id, default=sentinel)
+ if result1 is sentinel:
+ return {}
+
+ result2 = post_entity_categories(
+ result1,
+ sp_entity_id=sp_entity_id,
+ mds=(mds or self.metadata_store),
+ required=required,
+ )
+ return result2
def not_on_or_after(self, sp_entity_id):
""" When the assertion stops being valid, should not be
@@ -507,74 +505,84 @@ class Policy(object):
return in_a_while(**self.get_lifetime(sp_entity_id))
- def get_sign(self, sp_entity_id):
- """
- Possible choices
- "sign": ["response", "assertion", "on_demand"]
-
- :param sp_entity_id:
- :return:
- """
-
- return self.get("sign", sp_entity_id, default=[])
-
- def filter(self, ava, sp_entity_id, mdstore, required=None, optional=None):
+ def filter(self, ava, sp_entity_id, mdstore=None, required=None, optional=None):
""" What attribute and attribute values returns depends on what
- the SP has said it wants in the request or in the metadata file and
- what the IdP/AA wants to release. An assumption is that what the SP
+ the SP or the registration authority has said it wants in the request
+ or in the metadata file and what the IdP/AA wants to release.
+ An assumption is that what the SP or the registration authority
asks for overrides whatever is in the metadata. But of course the
IdP never releases anything it doesn't want to.
:param ava: The information about the subject as a dictionary
:param sp_entity_id: The entity ID of the SP
- :param mdstore: A Metadata store
:param required: Attributes that the SP requires in the assertion
:param optional: Attributes that the SP regards as optional
:return: A possibly modified AVA
"""
- _ava = None
-
- if not self.acs: # acs MUST have a value, fall back to default.
+ if mdstore is not None:
+ warn_msg = (
+ "The mdstore parameter for saml2.assertion.Policy.filter "
+ "is deprecated; "
+ "instead, initialize the Policy object setting the mds param."
+ )
+ logger.warning(warn_msg)
+ _warn(warn_msg, DeprecationWarning)
+
+ # acs MUST have a value, fall back to default.
+ if not self.acs:
self.acs = ac_factory()
- _rest = self.get_entity_categories(sp_entity_id, mdstore, required)
- if _rest:
- _ava = filter_attribute_value_assertions(ava.copy(), _rest)
+ subject_ava = ava.copy()
+
+ # entity category restrictions
+ _ent_rest = self.get_entity_categories(sp_entity_id, mds=mdstore, required=required)
+ if _ent_rest:
+ subject_ava = filter_attribute_value_assertions(subject_ava, _ent_rest)
elif required or optional:
logger.debug("required: %s, optional: %s", required, optional)
- _ava = filter_on_attributes(
- ava.copy(), required, optional, self.acs,
- self.get_fail_on_missing_requested(sp_entity_id))
-
- _rest = self.get_attribute_restrictions(sp_entity_id)
- if _rest:
- if _ava is None:
- _ava = ava.copy()
- _ava = filter_attribute_value_assertions(_ava, _rest)
- elif _ava is None:
- _ava = ava.copy()
-
- if _ava is None:
- return {}
- else:
- return _ava
+ subject_ava = filter_on_attributes(
+ subject_ava,
+ required,
+ optional,
+ self.acs,
+ self.get_fail_on_missing_requested(sp_entity_id),
+ )
+
+ # attribute restrictions
+ _attr_rest = self.get_attribute_restrictions(sp_entity_id)
+ subject_ava = filter_attribute_value_assertions(subject_ava, _attr_rest)
+
+ return subject_ava or {}
def restrict(self, ava, sp_entity_id, metadata=None):
- """ Identity attribute names are expected to be expressed in
- the local lingo (== friendlyName)
+ """ Identity attribute names are expected to be expressed as FriendlyNames
:return: A filtered ava according to the IdPs/AAs rules and
the list of required/optional attributes according to the SP.
If the requirements can't be met an exception is raised.
"""
- if metadata:
- spec = metadata.attribute_requirement(sp_entity_id)
- if spec:
- return self.filter(ava, sp_entity_id, metadata,
- spec["required"], spec["optional"])
-
- return self.filter(ava, sp_entity_id, metadata, [], [])
+ if metadata is not None:
+ warn_msg = (
+ "The metadata parameter for saml2.assertion.Policy.restrict "
+ "is deprecated and ignored; "
+ "instead, initialize the Policy object setting the mds param."
+ )
+ logger.warning(warn_msg)
+ _warn(warn_msg, DeprecationWarning)
+
+ metadata_store = metadata or self.metadata_store
+ spec = (
+ metadata_store.attribute_requirement(sp_entity_id) or {}
+ if metadata_store
+ else {}
+ )
+ return self.filter(
+ ava,
+ sp_entity_id,
+ required=spec.get("required"),
+ optional=spec.get("optional"),
+ )
def conditions(self, sp_entity_id):
""" Return a saml.Condition instance
@@ -582,27 +590,18 @@ class Policy(object):
:param sp_entity_id: The SP entity ID
:return: A saml.Condition instance
"""
- return factory(saml.Conditions,
- not_before=instant(),
- # How long might depend on who's getting it
- not_on_or_after=self.not_on_or_after(sp_entity_id),
- audience_restriction=[factory(
- saml.AudienceRestriction,
- audience=[factory(saml.Audience,
- text=sp_entity_id)])])
-
- def entity_category_attributes(self, ec):
- # TODO: Not used. Remove?
- if not self._restrictions:
- return None
-
- ec_maps = self._restrictions["default"]["entity_categories"]
- for ec_map in ec_maps:
- try:
- return ec_map[ec]
- except KeyError:
- pass
- return []
+ return factory(
+ saml.Conditions,
+ not_before=instant(),
+ # How long might depend on who's getting it
+ not_on_or_after=self.not_on_or_after(sp_entity_id),
+ audience_restriction=[
+ factory(
+ saml.AudienceRestriction,
+ audience=[factory(saml.Audience, text=sp_entity_id)],
+ ),
+ ],
+ )
class EntityCategories(object):
@@ -740,12 +739,10 @@ def do_subject_confirmation(policy, sp_entity_id, key_info=None, **treeargs):
def do_subject(policy, sp_entity_id, name_id, **farg):
- #
specs = farg['subject_confirmation']
if isinstance(specs, list):
- res = [do_subject_confirmation(policy, sp_entity_id, **s) for s in
- specs]
+ res = [do_subject_confirmation(policy, sp_entity_id, **s) for s in specs]
else:
res = [do_subject_confirmation(policy, sp_entity_id, **specs)]
@@ -833,17 +830,16 @@ class Assertion(dict):
return _ass
- def apply_policy(self, sp_entity_id, policy, metadata=None):
+ def apply_policy(self, sp_entity_id, policy):
""" Apply policy to the assertion I'm representing
:param sp_entity_id: The SP entity ID
:param policy: The policy
- :param metadata: Metadata to use
:return: The resulting AVA after the policy is applied
"""
policy.acs = self.acs
- ava = policy.restrict(self, sp_entity_id, metadata)
+ ava = policy.restrict(self, sp_entity_id)
for key, val in list(self.items()):
if key in ava: