diff options
Diffstat (limited to 'src/saml2/assertion.py')
-rw-r--r-- | src/saml2/assertion.py | 422 |
1 files changed, 209 insertions, 213 deletions
diff --git a/src/saml2/assertion.py b/src/saml2/assertion.py index dd18affb..30cd90ab 100644 --- a/src/saml2/assertion.py +++ b/src/saml2/assertion.py @@ -5,6 +5,7 @@ import importlib import logging import re import six +from warnings import warn as _warn from saml2 import saml from saml2 import xmlenc @@ -18,6 +19,7 @@ from saml2.saml import NAME_FORMAT_URI from saml2.time_util import instant from saml2.time_util import in_a_while + logger = logging.getLogger(__name__) @@ -276,52 +278,54 @@ def restriction_from_attribute_spec(attributes): return restr -def post_entity_categories(maps, **kwargs): - restrictions = {} - try: - required = [d['friendly_name'].lower() for d in kwargs['required']] - except (KeyError, TypeError): - required = [] +def compile(restrictions): + """ This is only for IdPs or AAs, and it's about limiting what + is returned to the SP. + In the configuration file, restrictions on which values that + can be returned are specified with the help of regular expressions. + This function goes through and pre-compiles the regular expressions. - if kwargs["mds"]: - if "sp_entity_id" in kwargs: - ecs = kwargs["mds"].entity_categories(kwargs["sp_entity_id"]) - for ec_map in maps: - for key, (atlist, only_required) in ec_map.items(): - if key == "": # always released - attrs = atlist - elif isinstance(key, tuple): - if only_required: - attrs = [a for a in atlist if a in required] - else: - attrs = atlist - for _key in key: - if _key not in ecs: - attrs = [] - break - elif key in ecs: - if only_required: - attrs = [a for a in atlist if a in required] - else: - attrs = atlist - else: - attrs = [] + :param restrictions: policy configuration + :return: The assertion with the string specification replaced with + a compiled regular expression. + """ + for who, spec in restrictions.items(): + spec = spec or {} - for attr in attrs: - restrictions[attr] = None - else: - for ec_map in maps: - for attr in ec_map[""]: - restrictions[attr] = None + entity_categories = spec.get("entity_categories", []) + ecs = [] + for cat in entity_categories: + try: + _mod = importlib.import_module(cat) + except ImportError: + _mod = importlib.import_module("saml2.entity_category.%s" % cat) + + _ec = {} + for key, items in _mod.RELEASE.items(): + alist = [k.lower() for k in items] + _only_required = getattr(_mod, "ONLY_REQUIRED", {}).get(key, False) + _ec[key] = (alist, _only_required) + ecs.append(_ec) + spec["entity_categories"] = ecs or None + + attribute_restrictions = spec.get("attribute_restrictions") or {} + _attribute_restrictions = {} + for key, values in attribute_restrictions.items(): + lkey = key.lower() + values = [] if not values else values + _attribute_restrictions[lkey] = ( + [re.compile(value) for value in values] or None + ) + spec["attribute_restrictions"] = _attribute_restrictions or None return restrictions class Policy(object): - """ handles restrictions on assertions """ + """Handles restrictions on assertions.""" - def __init__(self, restrictions=None, config=None): - self._config = config + def __init__(self, restrictions=None, mds=None): + self.metadata_store = mds self._restrictions = self.setup_restrictions(restrictions) logger.debug("policy restrictions: %s", self._restrictions) self.acs = [] @@ -331,116 +335,50 @@ class Policy(object): return None restrictions = copy.deepcopy(restrictions) - # TODO: Split policy config in service_providers and registration_authorities - # "policy": { - # "service_providers": { - # "default": ..., - # "urn:mace:example.com:saml:roland:sp": ..., - # }, - # "registration_authorities": { - # "default": ..., - # "http://www.swamid.se": ..., - # }, - # }, - registration_authorities = restrictions.pop('registration_authorities', None) - restrictions = self.compile(restrictions) - if registration_authorities: - restrictions['registration_authorities'] = self.compile(registration_authorities) + restrictions = compile(restrictions) return restrictions - @staticmethod - def compile(restrictions): - """ This is only for IdPs or AAs, and it's about limiting what - is returned to the SP. - In the configuration file, restrictions on which values that - can be returned are specified with the help of regular expressions. - This function goes through and pre-compiles the regular expressions. - - :param restrictions: policy configuration - :return: The assertion with the string specification replaced with - a compiled regular expression. - """ - for who, spec in restrictions.items(): - if spec is None: - continue - - entity_categories = spec.get("entity_categories") - if entity_categories is not None: - ecs = [] - for cat in entity_categories: - try: - _mod = importlib.import_module(cat) - except ImportError: - _mod = importlib.import_module( - "saml2.entity_category.%s" % cat) - _ec = {} - for key, items in _mod.RELEASE.items(): - alist = [k.lower() for k in items] - try: - _only_required = _mod.ONLY_REQUIRED[key] - except (AttributeError, KeyError): - _only_required = False - _ec[key] = (alist, _only_required) - ecs.append(_ec) - spec["entity_categories"] = ecs - - attribute_restrictions = spec.get("attribute_restrictions") - if attribute_restrictions is None: - continue - - _attribute_restrictions = {} - for key, values in attribute_restrictions.items(): - if not values: - _attribute_restrictions[key.lower()] = None - continue - _attribute_restrictions[key.lower()] = [re.compile(value) for value in values] - - spec["attribute_restrictions"] = _attribute_restrictions - - return restrictions - - def _lookup_registry_authority(self, sp_entity_id): - if self._config and self._config.metadata: - registration_info = self._config.metadata.registration_info(sp_entity_id) - return registration_info.get('registration_authority') - return None - - def get(self, attribute, sp_entity_id, default=None, post_func=None, - **kwargs): + def get(self, attribute, sp_entity_id, default=None): """ :param attribute: :param sp_entity_id: :param default: - :param post_func: :return: """ if not self._restrictions: return default - registration_authority_name = self._lookup_registry_authority(sp_entity_id) - registration_authorities = self._restrictions.get("registration_authorities") - - val = None - # Specific SP takes precedence - if sp_entity_id in self._restrictions: - val = self._restrictions[sp_entity_id].get(attribute) - # Second choice is if the SP is part of a configured registration authority - elif registration_authorities and registration_authority_name in registration_authorities: - val = registration_authorities[registration_authority_name].get(attribute) - # Third is to try default for registration authorities - elif registration_authorities and 'default' in registration_authorities: - val = registration_authorities['default'].get(attribute) - # Lastly we try default for SPs - elif 'default' in self._restrictions: - val = self._restrictions.get('default').get(attribute) - - if val is None: - return default - elif post_func: - return post_func(val, sp_entity_id=sp_entity_id, **kwargs) - else: - return val + ra_info = ( + self.metadata_store.registration_info(sp_entity_id) or {} + if self.metadata_store is not None + else {} + ) + ra_entity_id = ra_info.get("registration_authority") + + sp_restrictions = self._restrictions.get(sp_entity_id) + ra_restrictions = self._restrictions.get(ra_entity_id) + default_restrictions = ( + self._restrictions.get("default") + or self._restrictions.get("") + ) + restrictions = ( + sp_restrictions + if sp_restrictions is not None + else ra_restrictions + if ra_restrictions is not None + else default_restrictions + if default_restrictions is not None + else {} + ) + + attribute_restriction = restrictions.get(attribute) + restriction = ( + attribute_restriction + if attribute_restriction is not None + else default + ) + return restriction def get_nameid_format(self, sp_entity_id): """ Get the NameIDFormat to used for the entity id @@ -484,18 +422,78 @@ class Policy(object): return self.get("fail_on_missing_requested", sp_entity_id, default=True) - def get_entity_categories(self, sp_entity_id, mds, required): + def get_sign(self, sp_entity_id): + """ + Possible choices + "sign": ["response", "assertion", "on_demand"] + + :param sp_entity_id: + :return: + """ + + return self.get("sign", sp_entity_id, default=[]) + + def get_entity_categories(self, sp_entity_id, mds=None, required=None): """ :param sp_entity_id: - :param mds: MetadataStore instance :param required: required attributes :return: A dictionary with restrictions """ - kwargs = {"mds": mds, 'required': required} + if mds is not None: + warn_msg = ( + "The mds parameter for saml2.assertion.Policy.get_entity_categories " + "is deprecated; " + "instead, initialize the Policy object setting the mds param." + ) + logger.warning(warn_msg) + _warn(warn_msg, DeprecationWarning) + + def post_entity_categories(maps, sp_entity_id=None, mds=None, required=None): + restrictions = {} + required = [d['friendly_name'].lower() for d in (required or [])] + + if mds: + ecs = mds.entity_categories(sp_entity_id) + for ec_map in maps: + for key, (atlist, only_required) in ec_map.items(): + if key == "": # always released + attrs = atlist + elif isinstance(key, tuple): + if only_required: + attrs = [a for a in atlist if a in required] + else: + attrs = atlist + for _key in key: + if _key not in ecs: + attrs = [] + break + elif key in ecs: + if only_required: + attrs = [a for a in atlist if a in required] + else: + attrs = atlist + else: + attrs = [] + + for attr in attrs: + restrictions[attr] = None - return self.get("entity_categories", sp_entity_id, default={}, post_func=post_entity_categories, **kwargs) + return restrictions + + sentinel = object() + result1 = self.get("entity_categories", sp_entity_id, default=sentinel) + if result1 is sentinel: + return {} + + result2 = post_entity_categories( + result1, + sp_entity_id=sp_entity_id, + mds=(mds or self.metadata_store), + required=required, + ) + return result2 def not_on_or_after(self, sp_entity_id): """ When the assertion stops being valid, should not be @@ -507,74 +505,84 @@ class Policy(object): return in_a_while(**self.get_lifetime(sp_entity_id)) - def get_sign(self, sp_entity_id): - """ - Possible choices - "sign": ["response", "assertion", "on_demand"] - - :param sp_entity_id: - :return: - """ - - return self.get("sign", sp_entity_id, default=[]) - - def filter(self, ava, sp_entity_id, mdstore, required=None, optional=None): + def filter(self, ava, sp_entity_id, mdstore=None, required=None, optional=None): """ What attribute and attribute values returns depends on what - the SP has said it wants in the request or in the metadata file and - what the IdP/AA wants to release. An assumption is that what the SP + the SP or the registration authority has said it wants in the request + or in the metadata file and what the IdP/AA wants to release. + An assumption is that what the SP or the registration authority asks for overrides whatever is in the metadata. But of course the IdP never releases anything it doesn't want to. :param ava: The information about the subject as a dictionary :param sp_entity_id: The entity ID of the SP - :param mdstore: A Metadata store :param required: Attributes that the SP requires in the assertion :param optional: Attributes that the SP regards as optional :return: A possibly modified AVA """ - _ava = None - - if not self.acs: # acs MUST have a value, fall back to default. + if mdstore is not None: + warn_msg = ( + "The mdstore parameter for saml2.assertion.Policy.filter " + "is deprecated; " + "instead, initialize the Policy object setting the mds param." + ) + logger.warning(warn_msg) + _warn(warn_msg, DeprecationWarning) + + # acs MUST have a value, fall back to default. + if not self.acs: self.acs = ac_factory() - _rest = self.get_entity_categories(sp_entity_id, mdstore, required) - if _rest: - _ava = filter_attribute_value_assertions(ava.copy(), _rest) + subject_ava = ava.copy() + + # entity category restrictions + _ent_rest = self.get_entity_categories(sp_entity_id, mds=mdstore, required=required) + if _ent_rest: + subject_ava = filter_attribute_value_assertions(subject_ava, _ent_rest) elif required or optional: logger.debug("required: %s, optional: %s", required, optional) - _ava = filter_on_attributes( - ava.copy(), required, optional, self.acs, - self.get_fail_on_missing_requested(sp_entity_id)) - - _rest = self.get_attribute_restrictions(sp_entity_id) - if _rest: - if _ava is None: - _ava = ava.copy() - _ava = filter_attribute_value_assertions(_ava, _rest) - elif _ava is None: - _ava = ava.copy() - - if _ava is None: - return {} - else: - return _ava + subject_ava = filter_on_attributes( + subject_ava, + required, + optional, + self.acs, + self.get_fail_on_missing_requested(sp_entity_id), + ) + + # attribute restrictions + _attr_rest = self.get_attribute_restrictions(sp_entity_id) + subject_ava = filter_attribute_value_assertions(subject_ava, _attr_rest) + + return subject_ava or {} def restrict(self, ava, sp_entity_id, metadata=None): - """ Identity attribute names are expected to be expressed in - the local lingo (== friendlyName) + """ Identity attribute names are expected to be expressed as FriendlyNames :return: A filtered ava according to the IdPs/AAs rules and the list of required/optional attributes according to the SP. If the requirements can't be met an exception is raised. """ - if metadata: - spec = metadata.attribute_requirement(sp_entity_id) - if spec: - return self.filter(ava, sp_entity_id, metadata, - spec["required"], spec["optional"]) - - return self.filter(ava, sp_entity_id, metadata, [], []) + if metadata is not None: + warn_msg = ( + "The metadata parameter for saml2.assertion.Policy.restrict " + "is deprecated and ignored; " + "instead, initialize the Policy object setting the mds param." + ) + logger.warning(warn_msg) + _warn(warn_msg, DeprecationWarning) + + metadata_store = metadata or self.metadata_store + spec = ( + metadata_store.attribute_requirement(sp_entity_id) or {} + if metadata_store + else {} + ) + return self.filter( + ava, + sp_entity_id, + required=spec.get("required"), + optional=spec.get("optional"), + ) def conditions(self, sp_entity_id): """ Return a saml.Condition instance @@ -582,27 +590,18 @@ class Policy(object): :param sp_entity_id: The SP entity ID :return: A saml.Condition instance """ - return factory(saml.Conditions, - not_before=instant(), - # How long might depend on who's getting it - not_on_or_after=self.not_on_or_after(sp_entity_id), - audience_restriction=[factory( - saml.AudienceRestriction, - audience=[factory(saml.Audience, - text=sp_entity_id)])]) - - def entity_category_attributes(self, ec): - # TODO: Not used. Remove? - if not self._restrictions: - return None - - ec_maps = self._restrictions["default"]["entity_categories"] - for ec_map in ec_maps: - try: - return ec_map[ec] - except KeyError: - pass - return [] + return factory( + saml.Conditions, + not_before=instant(), + # How long might depend on who's getting it + not_on_or_after=self.not_on_or_after(sp_entity_id), + audience_restriction=[ + factory( + saml.AudienceRestriction, + audience=[factory(saml.Audience, text=sp_entity_id)], + ), + ], + ) class EntityCategories(object): @@ -740,12 +739,10 @@ def do_subject_confirmation(policy, sp_entity_id, key_info=None, **treeargs): def do_subject(policy, sp_entity_id, name_id, **farg): - # specs = farg['subject_confirmation'] if isinstance(specs, list): - res = [do_subject_confirmation(policy, sp_entity_id, **s) for s in - specs] + res = [do_subject_confirmation(policy, sp_entity_id, **s) for s in specs] else: res = [do_subject_confirmation(policy, sp_entity_id, **specs)] @@ -833,17 +830,16 @@ class Assertion(dict): return _ass - def apply_policy(self, sp_entity_id, policy, metadata=None): + def apply_policy(self, sp_entity_id, policy): """ Apply policy to the assertion I'm representing :param sp_entity_id: The SP entity ID :param policy: The policy - :param metadata: Metadata to use :return: The resulting AVA after the policy is applied """ policy.acs = self.acs - ava = policy.restrict(self, sp_entity_id, metadata) + ava = policy.restrict(self, sp_entity_id) for key, val in list(self.items()): if key in ava: |