diff options
author | Roland Hedberg <roland.hedberg@adm.umu.se> | 2013-02-09 18:57:26 +0100 |
---|---|---|
committer | Roland Hedberg <roland.hedberg@adm.umu.se> | 2013-02-09 18:57:26 +0100 |
commit | f295e06ab790c5454c164255328ebba4a0c4f2da (patch) | |
tree | f7b29dd6acb2515cd0afbec4a33a82bc812c2b37 /src/saml2 | |
parent | 71246a38292f2f598163ce4b69a3c298f95d04c2 (diff) | |
download | pysaml2-f295e06ab790c5454c164255328ebba4a0c4f2da.tar.gz |
Rewrote to use NameID instances every where where I previously used just the text part of the instance.
Diffstat (limited to 'src/saml2')
-rw-r--r-- | src/saml2/attribute_resolver.py | 17 | ||||
-rw-r--r-- | src/saml2/cache.py | 77 | ||||
-rw-r--r-- | src/saml2/client.py | 89 | ||||
-rw-r--r-- | src/saml2/client_base.py | 112 | ||||
-rw-r--r-- | src/saml2/httpbase.py | 4 | ||||
-rw-r--r-- | src/saml2/ident.py | 46 | ||||
-rw-r--r-- | src/saml2/pack.py | 8 | ||||
-rw-r--r-- | src/saml2/population.py | 54 | ||||
-rw-r--r-- | src/saml2/response.py | 115 | ||||
-rw-r--r-- | src/saml2/virtual_org.py | 40 |
10 files changed, 297 insertions, 265 deletions
diff --git a/src/saml2/attribute_resolver.py b/src/saml2/attribute_resolver.py index 06dbf125..dab809ce 100644 --- a/src/saml2/attribute_resolver.py +++ b/src/saml2/attribute_resolver.py @@ -35,16 +35,13 @@ class AttributeResolver(object): self.saml2client = saml2client self.metadata = saml2client.config.metadata - def extend(self, subject_id, issuer, vo_members, name_id_format=None, - sp_name_qualifier=None, real_id=None): + def extend(self, name_id, issuer, vo_members): """ - :param subject_id: The identifier by which the subject is know + :param name_id: The identifier by which the subject is know among all the participents of the VO :param issuer: Who am I the poses the query :param vo_members: The entity IDs of the IdP who I'm going to ask for extra attributes - :param name_id_format: Used to make the IdPs aware of what's going - on here :return: A dictionary with all the collected information about the subject """ @@ -53,17 +50,13 @@ class AttributeResolver(object): for ass in self.metadata.attribute_consuming_service(member): for attr_serv in ass.attribute_service: logger.info( - "Send attribute request to %s" % attr_serv.location) + "Send attribute request to %s" % attr_serv.location) if attr_serv.binding != BINDING_SOAP: continue # attribute query assumes SOAP binding session_info = self.saml2client.attribute_query( - subject_id, - attr_serv.location, - issuer_id=issuer, - sp_name_qualifier=sp_name_qualifier, - nameid_format=name_id_format, - real_id=real_id) + name_id, attr_serv.location, issuer_id=issuer, +) if session_info: result.append(session_info) return result diff --git a/src/saml2/cache.py b/src/saml2/cache.py index fc7b57c4..a128da3f 100644 --- a/src/saml2/cache.py +++ b/src/saml2/cache.py @@ -1,6 +1,7 @@ #!/usr/bin/env python import shelve +from saml2.ident import code, decode from saml2 import time_util import logging @@ -10,12 +11,15 @@ logger = logging.getLogger(__name__) # gathered from several different sources, all with their own # timeout time. + class ToOld(Exception): pass + class CacheError(Exception): pass + class Cache(object): def __init__(self, filename=None): if filename: @@ -25,18 +29,25 @@ class Cache(object): self._db = {} self._sync = False - def delete(self, subject_id): - del self._db[subject_id] + def delete(self, name_id): + """ + + :param name_id: The subject identifier, a NameID instance + """ + del self._db[code(name_id)] if self._sync: - self._db.sync() + try: + self._db.sync() + except AttributeError: + pass - def get_identity(self, subject_id, entities=None, + def get_identity(self, name_id, entities=None, check_not_on_or_after=True): """ Get all the identity information that has been received and are still valid about the subject. - :param subject_id: The identifier of the subject + :param name_id: The subject identifier, a NameID instance :param entities: The identifiers of the entities whoes assertions are interesting. If the list is empty all entities are interesting. :return: A 2-tuple consisting of the identity information (a @@ -45,7 +56,8 @@ class Cache(object): """ if not entities: try: - entities = self._db[subject_id].keys() + cni = code(name_id) + entities = self._db[cni].keys() except KeyError: return {}, [] @@ -53,7 +65,7 @@ class Cache(object): oldees = [] for entity_id in entities: try: - info = self.get(subject_id, entity_id, check_not_on_or_after) + info = self.get(name_id, entity_id, check_not_on_or_after) except ToOld: oldees.append(entity_id) continue @@ -70,74 +82,81 @@ class Cache(object): res[key] = vals return res, oldees - def get(self, subject_id, entity_id, check_not_on_or_after=True): + def get(self, name_id, entity_id, check_not_on_or_after=True): """ Get session information about a subject gotten from a specified IdP/AA. - :param subject_id: The identifier of the subject + :param name_id: The subject identifier, a NameID instance :param entity_id: The identifier of the entity_id :param check_not_on_or_after: if True it will check if this subject is still valid or if it is too old. Otherwise it will not check this. True by default. :return: The session information """ - (timestamp, info) = self._db[subject_id][entity_id] + cni = code(name_id) + (timestamp, info) = self._db[cni][entity_id] if check_not_on_or_after and time_util.after(timestamp): raise ToOld("past %s" % timestamp) return info or None - def set(self, subject_id, entity_id, info, not_on_or_after=0): - """ Stores session information in the cache. Assumes that the subject_id + def set(self, name_id, entity_id, info, not_on_or_after=0): + """ Stores session information in the cache. Assumes that the name_id is unique within the context of the Service Provider. - :param subject_id: The subject identifier + :param name_id: The subject identifier, a NameID instance :param entity_id: The identifier of the entity_id/receiver of an assertion :param info: The session info, the assertion is part of this :param not_on_or_after: A time after which the assertion is not valid. """ - if subject_id not in self._db: - self._db[subject_id] = {} + cni = code(name_id) + if cni not in self._db: + self._db[cni] = {} - self._db[subject_id][entity_id] = (not_on_or_after, info) + self._db[cni][entity_id] = (not_on_or_after, info) if self._sync: - self._db.sync() + try: + self._db.sync() + except AttributeError: + pass - def reset(self, subject_id, entity_id): + def reset(self, name_id, entity_id): """ Scrap the assertions received from a IdP or an AA about a special subject. - :param subject_id: The subjects identifier + :param name_id: The subject identifier, a NameID instance :param entity_id: The identifier of the entity_id of the assertion :return: """ - self.set(subject_id, entity_id, {}, 0) + self.set(name_id, entity_id, {}, 0) - def entities(self, subject_id): + def entities(self, name_id): """ Returns all the entities of assertions for a subject, disregarding whether the assertion still is valid or not. - :param subject_id: The identifier of the subject + :param name_id: The subject identifier, a NameID instance :return: A possibly empty list of entity identifiers """ - return self._db[subject_id].keys() + cni = code(name_id) + return self._db[cni].keys() - def receivers(self, subject_id): + def receivers(self, name_id): """ Another name for entities() just to make it more logic in the IdP scenario """ - return self.entities(subject_id) + return self.entities(name_id) - def active(self, subject_id, entity_id): + def active(self, name_id, entity_id): """ Returns the status of assertions from a specific entity_id. - :param subject_id: The ID of the subject + :param name_id: The ID of the subject :param entity_id: The entity ID of the entity_id of the assertion :return: True or False depending on if the assertion is still valid or not. """ try: - (timestamp, info) = self._db[subject_id][entity_id] + cni = code(name_id) + (timestamp, info) = self._db[cni][entity_id] except KeyError: return False @@ -151,4 +170,4 @@ class Cache(object): :return: list of subject identifiers """ - return self._db.keys() + return [decode(c) for c in self._db.keys()] diff --git a/src/saml2/client.py b/src/saml2/client.py index 0e23fef2..f7a79395 100644 --- a/src/saml2/client.py +++ b/src/saml2/client.py @@ -20,7 +20,6 @@ to conclude its tasks. """ from saml2.httpbase import HTTPError from saml2.s_utils import sid -from saml2.samlp import logout_response_from_string import saml2 try: @@ -46,6 +45,7 @@ from saml2 import BINDING_SOAP import logging logger = logging.getLogger(__name__) + class Saml2Client(Base): """ The basic pySAML2 service provider class """ @@ -81,12 +81,12 @@ class Saml2Client(Base): return req.id, info - def global_logout(self, subject_id, reason="", expire=None, sign=None): + def global_logout(self, name_id, reason="", expire=None, sign=None): """ More or less a layer of indirection :-/ Bootstrapping the whole thing by finding all the IdPs that should be notified. - :param subject_id: The identifier of the subject that wants to be + :param name_id: The identifier of the subject that wants to be logged out. :param reason: Why the subject wants to log out :param expire: The latest the log out should happen. @@ -99,17 +99,17 @@ class Saml2Client(Base): conversation. """ - logger.info("logout request for: %s" % subject_id) + logger.info("logout request for: %s" % name_id) # find out which IdPs/AAs I should notify - entity_ids = self.users.issuers_of_info(subject_id) + entity_ids = self.users.issuers_of_info(name_id) - return self.do_logout(subject_id, entity_ids, reason, expire, sign) + return self.do_logout(name_id, entity_ids, reason, expire, sign) - def do_logout(self, subject_id, entity_ids, reason, expire, sign=None): + def do_logout(self, name_id, entity_ids, reason, expire, sign=None): """ - :param subject_id: Identifier of the Subject + :param name_id: Identifier of the Subject a NameID instance :param entity_ids: List of entity ids for the IdPs that have provided information concerning the subject :param reason: The reason for doing the logout @@ -118,34 +118,33 @@ class Saml2Client(Base): :return: """ # check time - if not not_on_or_after(expire): # I've run out of time + if not not_on_or_after(expire): # I've run out of time # Do the local logout anyway - self.local_logout(subject_id) + self.local_logout(name_id) return 0, "504 Gateway Timeout", [], [] - # for all where I can use the SOAP binding, do those first not_done = entity_ids[:] responses = {} for entity_id in entity_ids: - response = False - - for binding in [BINDING_SOAP, - BINDING_HTTP_POST, + logger.debug("Logout from '%s'" % entity_id) + # for all where I can use the SOAP binding, do those first + for binding in [BINDING_SOAP, BINDING_HTTP_POST, BINDING_HTTP_REDIRECT]: srvs = self.metadata.single_logout_service(entity_id, binding, "idpsso") if not srvs: + logger.debug("No SLO '%s' service" % binding) continue destination = destinations(srvs)[0] - logger.info("destination to provider: %s" % destination) request = self.create_logout_request(destination, entity_id, - subject_id, reason=reason, + name_id=name_id, + reason=reason, expire=expire) - to_sign = [] + #to_sign = [] if binding.startswith("http://"): sign = True @@ -160,28 +159,28 @@ class Saml2Client(Base): relay_state = self._relay_state(request.id) http_info = self.apply_binding(binding, srequest, destination, - relay_state) + relay_state) if binding == BINDING_SOAP: - if response: - logger.info("Verifying response") - response = self.send(**http_info) + response = self.send(**http_info) - if response: + if response and response.status_code == 200: not_done.remove(entity_id) - logger.info("OK response from %s" % destination) - responses[entity_id] = logout_response_from_string(response) + response = response.text + logger.info("Response: %s" % response) + res = self.parse_logout_request_response(response) + responses[entity_id] = res else: logger.info("NOT OK response from %s" % destination) else: self.state[request.id] = {"entity_id": entity_id, - "operation": "SLO", - "entity_ids": entity_ids, - "subject_id": subject_id, - "reason": reason, - "not_on_of_after": expire, - "sign": sign} + "operation": "SLO", + "entity_ids": entity_ids, + "name_id": name_id, + "reason": reason, + "not_on_of_after": expire, + "sign": sign} responses[entity_id] = (binding, http_info) not_done.remove(entity_id) @@ -217,9 +216,9 @@ class Saml2Client(Base): issuer = response.issuer() logger.info("issuer: %s" % issuer) del self.state[response.in_response_to] - if status["entity_ids"] == [issuer]: # done + if status["entity_ids"] == [issuer]: # done self.local_logout(status["subject_id"]) - return 0, "200 Ok", [("Content-type","text/html")], [] + return 0, "200 Ok", [("Content-type", "text/html")], [] else: status["entity_ids"].remove(issuer) return self.do_logout(status["subject_id"], status["entity_ids"], @@ -277,16 +276,15 @@ class Saml2Client(Base): consent=None, extensions=None, sign=False): subject = saml.Subject( - name_id = saml.NameID(text=subject_id, - format=nameid_format, - sp_name_qualifier=sp_name_qualifier, - name_qualifier=name_qualifier)) + name_id=saml.NameID(text=subject_id, format=nameid_format, + sp_name_qualifier=sp_name_qualifier, + name_qualifier=name_qualifier)) srvs = self.metadata.authz_service(entity_id, BINDING_SOAP) for dest in destinations(srvs): resp = self._use_soap(dest, "authz_decision_query", - action=action, evidence=evidence, - resource=resource, subject=subject) + action=action, evidence=evidence, + resource=resource, subject=subject) if resp: return resp @@ -308,8 +306,8 @@ class Saml2Client(Base): for destination in destinations(srvs): res = self._use_soap(destination, "assertion_id_request", - assertion_id_refs=_id_refs, consent=consent, - extensions=extensions, sign=sign) + assertion_id_refs=_id_refs, consent=consent, + extensions=extensions, sign=sign) if res: return res @@ -321,9 +319,8 @@ class Saml2Client(Base): srvs = self.metadata.authn_request_service(entity_id, BINDING_SOAP) for destination in destinations(srvs): - resp = self._use_soap(destination, "authn_query", - consent=consent, extensions=extensions, - sign=sign) + resp = self._use_soap(destination, "authn_query", consent=consent, + extensions=extensions, sign=sign) if resp: return resp @@ -339,7 +336,8 @@ class Saml2Client(Base): :param entityid: To whom the query should be sent :param subject_id: The identifier of the subject - :param attribute: A dictionary of attributes and values that is asked for + :param attribute: A dictionary of attributes and values that is + asked for :param sp_name_qualifier: The unique identifier of the service provider or affiliation of providers for whom the identifier was generated. @@ -353,7 +351,6 @@ class Saml2Client(Base): HTTP args if BINDING_HTT_POST was used. """ - if real_id: response_args = {"real_id": real_id} else: diff --git a/src/saml2/client_base.py b/src/saml2/client_base.py index 70e7a0c6..b00abaa1 100644 --- a/src/saml2/client_base.py +++ b/src/saml2/client_base.py @@ -78,23 +78,28 @@ ECP_SERVICE = "urn:oasis:names:tc:SAML:2.0:profiles:SSO:ecp" ACTOR = "http://schemas.xmlsoap.org/soap/actor/next" MIME_PAOS = "application/vnd.paos+xml" + class IdpUnspecified(Exception): pass + class VerifyError(Exception): pass + class LogoutError(Exception): pass + class NoServiceDefined(Exception): pass + class Base(Entity): """ The basic pySAML2 service provider class """ def __init__(self, config=None, identity_cache=None, state_cache=None, - virtual_organization="",config_file=""): + virtual_organization="", config_file=""): """ :param config: A saml2.config.Config instance :param identity_cache: Where the class should store identity information @@ -108,12 +113,12 @@ class Base(Entity): # for server state storage if state_cache is None: - self.state = {} # in memory storage + self.state = {} # in memory storage else: self.state = state_cache for foo in ["allow_unsolicited", "authn_requests_signed", - "logout_requests_signed"]: + "logout_requests_signed"]: if self.config.getattr("sp", foo) == 'true': setattr(self, foo, True) else: @@ -166,25 +171,25 @@ class Base(Entity): # Public API # - def add_vo_information_about_user(self, subject_id): + def add_vo_information_about_user(self, name_id): """ Add information to the knowledge I have about the user. This is for Virtual organizations. - :param subject_id: The subject identifier + :param name_id: The subject identifier :return: A possibly extended knowledge. """ ava = {} try: - (ava, _) = self.users.get_identity(subject_id) + (ava, _) = self.users.get_identity(name_id) except KeyError: pass # is this a Virtual Organization situation if self.vorg: - if self.vorg.do_aggregation(subject_id): + if self.vorg.do_aggregation(name_id): # Get the extended identity - ava = self.users.get_identity(subject_id)[0] + ava = self.users.get_identity(name_id)[0] return ava #noinspection PyUnusedLocal @@ -228,7 +233,8 @@ class Base(Entity): args = {} try: - args["assertion_consumer_service_url"] = kwargs["assertion_consumer_service_url"] + args["assertion_consumer_service_url"] = kwargs[ + "assertion_consumer_service_url"] except KeyError: if service_url_binding is None: service_url = self.service_url(binding) @@ -247,16 +253,17 @@ class Base(Entity): try: args["name_id_policy"] = kwargs["name_id_policy"] del kwargs["name_id_policy"] - except: + except KeyError: if allow_create: - allow_create="true" + allow_create = "true" else: - allow_create="false" + allow_create = "false" # Profile stuff, should be configurable - if nameid_format is None or nameid_format == NAMEID_FORMAT_TRANSIENT: - name_id_policy = samlp.NameIDPolicy(allow_create=allow_create, - format=NAMEID_FORMAT_TRANSIENT) + if nameid_format is None or \ + nameid_format == NAMEID_FORMAT_TRANSIENT: + name_id_policy = samlp.NameIDPolicy( + allow_create=allow_create, format=NAMEID_FORMAT_TRANSIENT) else: name_id_policy = samlp.NameIDPolicy(allow_create=allow_create, format=nameid_format) @@ -272,28 +279,31 @@ class Base(Entity): if kwargs: if extensions is None: extensions = [] - fargs = [p for p,c,r in AuthnRequest.c_attributes.values()] - fargs.extend([p for p,c in AuthnRequest.c_children.values()]) - for key,val in kwargs.items(): + fargs = [p for p, c, r in AuthnRequest.c_attributes.values()] + fargs.extend([p for p, c in AuthnRequest.c_children.values()]) + for key, val in kwargs.items(): if key not in fargs: # extension elements allowed extensions.append(saml2.element_to_extension_element(val)) else: args[key] = val + try: + del args["id"] + except KeyError: + pass return self._message(AuthnRequest, destination, sid, consent, extensions, sign, protocol_binding=binding, scoping=scoping, **args) - def create_attribute_query(self, destination, name_id=None, attribute=None, sid=0, consent=None, extensions=None, sign=False, **kwargs): """ Constructs an AttributeQuery :param destination: To whom the query should be sent - :param subject_id: The identifier of the subject + :param name_id: The identifier of the subject :param attribute: A dictionary of attributes and values that is asked for. The key are one of 4 variants: 3-tuple of name_format,name and friendly_name, @@ -333,7 +343,7 @@ class Base(Entity): except KeyError: pass - subject = saml.Subject(name_id = name_id) + subject = saml.Subject(name_id=name_id) if attribute: attribute = do_attributes(attribute) @@ -342,11 +352,9 @@ class Base(Entity): extensions, sign, subject=subject, attribute=attribute) - # MUST use SOAP for # AssertionIDRequest, SubjectQuery, # AuthnQuery, AttributeQuery, or AuthzDecisionQuery - def create_authz_decision_query(self, destination, action, evidence=None, resource=None, subject=None, sid=0, consent=None, extensions=None, @@ -369,8 +377,9 @@ class Base(Entity): extensions, sign, action=action, evidence=evidence, resource=resource, subject=subject) - def create_authz_decision_query_using_assertion(self, destination, assertion, - action=None, resource=None, + def create_authz_decision_query_using_assertion(self, destination, + assertion, action=None, + resource=None, subject=None, sid=0, consent=None, extensions=None, @@ -397,14 +406,10 @@ class Base(Entity): else: _action = None - return self.create_authz_decision_query(destination, - _action, - saml.Evidence(assertion=assertion), - resource, subject, - sid=sid, - consent=consent, - extensions=extensions, - sign=sign) + return self.create_authz_decision_query( + destination, _action, saml.Evidence(assertion=assertion), + resource, subject, sid=sid, consent=consent, extensions=extensions, + sign=sign) def create_assertion_id_request(self, assertion_id_refs, **kwargs): """ @@ -442,10 +447,10 @@ class Base(Entity): requested_authn_context=authn_context) def create_name_id_mapping_request(self, name_id_policy, - name_id=None, base_id=None, - encrypted_id=None, destination=None, - sid=0, consent=None, extensions=None, - sign=False): + name_id=None, base_id=None, + encrypted_id=None, destination=None, + sid=0, consent=None, extensions=None, + sign=False): """ :param name_id_policy: @@ -464,16 +469,17 @@ class Base(Entity): assert name_id or base_id or encrypted_id if name_id: - return self._message(NameIDMappingRequest, destination, sid, consent, - extensions, sign, name_id_policy=name_id_policy, - name_id=name_id) + return self._message(NameIDMappingRequest, destination, sid, + consent, extensions, sign, + name_id_policy=name_id_policy, name_id=name_id) elif base_id: - return self._message(NameIDMappingRequest, destination, sid, consent, - extensions, sign, name_id_policy=name_id_policy, - base_id=base_id) + return self._message(NameIDMappingRequest, destination, sid, + consent, extensions, sign, + name_id_policy=name_id_policy, base_id=base_id) else: - return self._message(NameIDMappingRequest, destination, sid, consent, - extensions, sign, name_id_policy=name_id_policy, + return self._message(NameIDMappingRequest, destination, sid, + consent, extensions, sign, + name_id_policy=name_id_policy, encrypted_id=encrypted_id) # ======== response handling =========== @@ -549,7 +555,7 @@ class Base(Entity): "attribute_converters": self.config.attribute_converters} res = self._parse_response(response, AssertionIDResponse, "", binding, - **kwargs) + **kwargs) return res # ------------------------------------------------------------------------ @@ -594,7 +600,7 @@ class Base(Entity): # paos_request = paos.Request(must_understand="1", actor=ACTOR, response_consumer_url=my_url, - service = ECP_SERVICE) + service=ECP_SERVICE) # ---------------------------------------- # <ecp:RelayState> @@ -622,16 +628,16 @@ class Base(Entity): # SingleSignOnService _, location = self.pick_binding("single_sign_on_service", [_binding], entity_id=entityid) - authn_req = self.create_authn_request(location, - service_url_binding=BINDING_PAOS, - **kwargs) + authn_req = self.create_authn_request( + location, service_url_binding=BINDING_PAOS, **kwargs) # ---------------------------------------- # The SOAP envelope # ---------------------------------------- - soap_envelope = make_soap_enveloped_saml_thingy(authn_req,[paos_request, - relay_state]) + soap_envelope = make_soap_enveloped_saml_thingy(authn_req, + [paos_request, + relay_state]) return authn_req.id, "%s" % soap_envelope @@ -644,7 +650,7 @@ class Base(Entity): _relay_state = None for item in rdict["header"]: if item.c_tag == "RelayState" and\ - item.c_namespace == ecp.NAMESPACE: + item.c_namespace == ecp.NAMESPACE: _relay_state = item response = self.parse_authn_request_response(rdict["body"], diff --git a/src/saml2/httpbase.py b/src/saml2/httpbase.py index 0eb66661..9643e839 100644 --- a/src/saml2/httpbase.py +++ b/src/saml2/httpbase.py @@ -140,9 +140,9 @@ class HTTPBase(object): if morsel["max-age"]: std_attr["expires"] = _since_epoch(morsel["max-age"]) - for att, set in PAIRS.items(): + for att, item in PAIRS.items(): if std_attr[att]: - std_attr[set] = True + std_attr[item] = True if std_attr["domain"] and std_attr["domain"].startswith("."): std_attr["domain_initial_dot"] = True diff --git a/src/saml2/ident.py b/src/saml2/ident.py index c091f286..6fabe0b8 100644 --- a/src/saml2/ident.py +++ b/src/saml2/ident.py @@ -19,9 +19,11 @@ logger = logging.getLogger(__name__) ATTR = ["name_qualifier", "sp_name_qualifier", "format", "sp_provided_id", "text"] + class Unknown(Exception): pass + def code(item): _res = [] i = 0 @@ -32,13 +34,15 @@ def code(item): i += 1 return ",".join(_res) -def decode(str): + +def decode(txt): _nid = NameID() - for part in str.split(","): + for part in txt.split(","): i, val = part.split("=") setattr(_nid, ATTR[int(i)], unquote(val)) return _nid + class IdentDB(object): """ A class that handles identifiers of entities Keeps a list of all nameIDs returned per SP @@ -51,19 +55,19 @@ class IdentDB(object): self.domain = domain self.name_qualifier = name_qualifier - def _create_id(self, format, name_qualifier="", sp_name_qualifier=""): + def _create_id(self, nformat, name_qualifier="", sp_name_qualifier=""): _id = sha256(rndstr(32)) - _id.update(format) + _id.update(nformat) if name_qualifier: _id.update(name_qualifier) if sp_name_qualifier: _id.update(sp_name_qualifier) return _id.hexdigest() - def create_id(self, format, name_qualifier="", sp_name_qualifier=""): - _id = self._create_id(format, name_qualifier, sp_name_qualifier) + def create_id(self, nformat, name_qualifier="", sp_name_qualifier=""): + _id = self._create_id(nformat, name_qualifier, sp_name_qualifier) while _id in self.db: - _id = self._create_id(format, name_qualifier, sp_name_qualifier) + _id = self._create_id(nformat, name_qualifier, sp_name_qualifier) return _id def store(self, ident, name_id): @@ -92,30 +96,30 @@ class IdentDB(object): del self.db[_cn] - def remove_local(self, id): - if isinstance(id, unicode): - id = id.encode("utf-8") + def remove_local(self, sid): + if isinstance(sid, unicode): + sid = sid.encode("utf-8") try: - for val in self.db[id].split(" "): + for val in self.db[sid].split(" "): try: del self.db[val] except KeyError: pass - del self.db[id] + del self.db[sid] except KeyError: pass - def get_nameid(self, userid, format, sp_name_qualifier, name_qualifier): - _id = self.create_id(format, name_qualifier, sp_name_qualifier) + def get_nameid(self, userid, nformat, sp_name_qualifier, name_qualifier): + _id = self.create_id(nformat, name_qualifier, sp_name_qualifier) - if format == NAMEID_FORMAT_EMAILADDRESS: + if nformat == NAMEID_FORMAT_EMAILADDRESS: if not self.domain: raise Exception("Can't issue email nameids, unknown domain") _id = "%s@%s" % (_id, self.domain) - nameid = NameID(format=format, sp_name_qualifier=sp_name_qualifier, + nameid = NameID(format=nformat, sp_name_qualifier=sp_name_qualifier, name_qualifier=name_qualifier, text=_id) self.store(userid, nameid) @@ -150,8 +154,9 @@ class IdentDB(object): if not name_qualifier: name_qualifier = self.name_qualifier - return {"format":nameid_format, "sp_name_qualifier": sp_name_qualifier, - "name_qualifier":name_qualifier} + return {"nformat": nameid_format, + "sp_name_qualifier": sp_name_qualifier, + "name_qualifier": name_qualifier} def construct_nameid(self, userid, local_policy=None, sp_name_qualifier=None, name_id_policy=None, @@ -175,7 +180,8 @@ class IdentDB(object): return self.get_nameid(userid, NAMEID_FORMAT_TRANSIENT, sp_name_qualifier, name_qualifier) - def persistent_nameid(self, userid, sp_name_qualifier="", name_qualifier=""): + def persistent_nameid(self, userid, sp_name_qualifier="", + name_qualifier=""): nameid = self.match_local_id(userid, sp_name_qualifier, name_qualifier) if nameid: return nameid @@ -194,6 +200,8 @@ class IdentDB(object): try: return self.db[code(name_id)] except KeyError: + logger.debug("name: %s" % code(name_id)) + logger.debug("id keys: %s" % self.db.keys()) return None def match_local_id(self, userid, sp_name_qualifier, name_qualifier): diff --git a/src/saml2/pack.py b/src/saml2/pack.py index d758ce9e..a985e7a4 100644 --- a/src/saml2/pack.py +++ b/src/saml2/pack.py @@ -154,9 +154,11 @@ def make_soap_enveloped_saml_thingy(thingy, header_parts=None): if isinstance(thingy, basestring): # remove the first XML version/encoding line + logger.debug("thingy0: %s" % thingy) _part = thingy.split("\n") - thingy = _part[1] + thingy = "".join(_part[1:]) thingy = thingy.replace(PREFIX, "") + logger.debug("thingy: %s" % thingy) _child = ElementTree.Element('') _child.tag = '{%s}FuddleMuddle' % DUMMY_NAMESPACE body.append(_child) @@ -165,12 +167,12 @@ def make_soap_enveloped_saml_thingy(thingy, header_parts=None): # find an remove the namespace definition i = _str.find(DUMMY_NAMESPACE) j = _str.rfind("xmlns:", 0, i) - cut1 = _str[j:i+len(DUMMY_NAMESPACE)+1] + cut1 = _str[j:i + len(DUMMY_NAMESPACE) + 1] _str = _str.replace(cut1, "") first = _str.find("<%s:FuddleMuddle" % (cut1[6:9],)) last = _str.find(">", first+14) cut2 = _str[first:last+1] - return _str.replace(cut2,thingy) + return _str.replace(cut2, thingy) else: thingy.become_child_element_of(body) return ElementTree.tostring(envelope, encoding="UTF-8") diff --git a/src/saml2/population.py b/src/saml2/population.py index 58ee1cd4..503375a5 100644 --- a/src/saml2/population.py +++ b/src/saml2/population.py @@ -3,6 +3,7 @@ from saml2.cache import Cache logger = logging.getLogger(__name__) + class Population(object): def __init__(self, cache=None): if cache: @@ -17,45 +18,48 @@ class Population(object): """If there already are information from this source in the cache this function will overwrite that information""" - subject_id = session_info["name_id"] + name_id = session_info["name_id"] issuer = session_info["issuer"] del session_info["issuer"] - self.cache.set(subject_id, issuer, session_info, - session_info["not_on_or_after"]) - return subject_id + self.cache.set(name_id, issuer, session_info, + session_info["not_on_or_after"]) + return name_id - def stale_sources_for_person(self, subject_id, sources=None): - if not sources: # assume that all the members has be asked - # once before, hence they are represented in the cache - sources = self.cache.entities(subject_id) - sources = [m for m in sources \ - if not self.cache.active(subject_id, m)] + def stale_sources_for_person(self, name_id, sources=None): + """ + + :param name_id: Identifier of the subject, a NameID instance + :param sources: Sources for information about the subject + :return: + """ + if not sources: # assume that all the members has be asked + # once before, hence they are represented in the cache + sources = self.cache.entities(name_id) + sources = [m for m in sources if not self.cache.active(name_id, m)] return sources - def issuers_of_info(self, subject_id): - return self.cache.entities(subject_id) + def issuers_of_info(self, name_id): + return self.cache.entities(name_id) - def get_identity(self, subject_id, entities=None, - check_not_on_or_after=True): - return self.cache.get_identity(subject_id, entities, - check_not_on_or_after) + def get_identity(self, name_id, entities=None, check_not_on_or_after=True): + return self.cache.get_identity(name_id, entities, check_not_on_or_after) - def get_info_from(self, subject_id, entity_id): - return self.cache.get(subject_id, entity_id) + def get_info_from(self, name_id, entity_id): + return self.cache.get(name_id, entity_id) def subjects(self): """Returns the name id's for all the persons in the cache""" return self.cache.subjects() - def remove_person(self, subject_id): - self.cache.delete(subject_id) + def remove_person(self, name_id): + self.cache.delete(name_id) - def get_entityid(self, subject_id, source_id, check_not_on_or_after=True): + def get_entityid(self, name_id, source_id, check_not_on_or_after=True): try: - return self.cache.get(subject_id, source_id, - check_not_on_or_after)["name_id"] + return self.cache.get(name_id, source_id, check_not_on_or_after)[ + "name_id"] except (KeyError, ValueError): return "" - def sources(self, subject_id): - return self.cache.entities(subject_id) + def sources(self, name_id): + return self.cache.entities(name_id) diff --git a/src/saml2/response.py b/src/saml2/response.py index b82a2820..b83c97da 100644 --- a/src/saml2/response.py +++ b/src/saml2/response.py @@ -49,17 +49,21 @@ logger = logging.getLogger(__name__) # --------------------------------------------------------------------------- + class IncorrectlySigned(Exception): pass + class VerificationError(Exception): pass # --------------------------------------------------------------------------- + def _dummy(_): return None + def for_me(condition, myself): # Am I among the intended audiences for restriction in condition.audience_restriction: @@ -72,6 +76,7 @@ def for_me(condition, myself): return False + def authn_response(conf, return_addr, outstanding_queries=None, timeslack=0, asynchop=True, allow_unsolicited=False): sec = security_context(conf) @@ -82,8 +87,9 @@ def authn_response(conf, return_addr, outstanding_queries=None, timeslack=0, timeslack = 0 return AuthnResponse(sec, conf.attribute_converters, conf.entityid, - return_addr, outstanding_queries, timeslack, - asynchop=asynchop, allow_unsolicited=allow_unsolicited) + return_addr, outstanding_queries, timeslack, + asynchop=asynchop, allow_unsolicited=allow_unsolicited) + # comes in over SOAP so synchronous def attribute_response(conf, return_addr, timeslack=0, asynchop=False, @@ -96,8 +102,9 @@ def attribute_response(conf, return_addr, timeslack=0, asynchop=False, timeslack = 0 return AttributeResponse(sec, conf.attribute_converters, conf.entityid, - return_addr, timeslack, asynchop=asynchop, - test=test) + return_addr, timeslack, asynchop=asynchop, + test=test) + class StatusResponse(object): msgtype = "status_response" @@ -111,7 +118,7 @@ class StatusResponse(object): self.request_id = request_id self.xmlstr = "" - self.name_id = "" + self.name_id = None self.response = None self.not_on_or_after = 0 self.in_response_to = None @@ -121,7 +128,7 @@ class StatusResponse(object): def _clear(self): self.xmlstr = "" - self.name_id = "" + self.name_id = None self.response = None self.not_on_or_after = 0 @@ -149,9 +156,10 @@ class StatusResponse(object): # This will check signature on Assertion which is the default try: self.response = self.sec.check_signature(instance) - except SignatureError: # The response as a whole might be signed or not - self.response = self.sec.check_signature(instance, - samlp.NAMESPACE+":Response") + except SignatureError: + # The response as a whole might be signed or not + self.response = self.sec.check_signature( + instance, samlp.NAMESPACE + ":Response") else: self.not_signed = True self.response = instance @@ -190,9 +198,9 @@ class StatusResponse(object): def issue_instant_ok(self): """ Check that the response was issued at a reasonable time """ upper = time_util.shift_time(time_util.time_in_a_while(days=1), - self.timeslack).timetuple() + self.timeslack).timetuple() lower = time_util.shift_time(time_util.time_a_while_ago(days=1), - -self.timeslack).timetuple() + -self.timeslack).timetuple() # print "issue_instant: %s" % self.response.issue_instant # print "%s < x < %s" % (lower, upper) issued_at = str_to_time(self.response.issue_instant) @@ -200,10 +208,9 @@ class StatusResponse(object): def _verify(self): if self.request_id and self.in_response_to and \ - self.in_response_to != self.request_id: + self.in_response_to != self.request_id: logger.error("Not the id I expected: %s != %s" % ( - self.in_response_to, - self.request_id)) + self.in_response_to, self.request_id)) return None try: @@ -217,9 +224,9 @@ class StatusResponse(object): if self.asynchop: if self.response.destination and \ - self.response.destination != self.return_addr: + self.response.destination != self.return_addr: logger.error("%s != %s" % (self.response.destination, - self.return_addr)) + self.return_addr)) return None assert self.issue_instant_ok() @@ -244,14 +251,17 @@ class StatusResponse(object): def issuer(self): return self.response.issuer.text.strip() + class LogoutResponse(StatusResponse): msgtype = "logout_response" + def __init__(self, sec_context, return_addr=None, timeslack=0, asynchop=True): StatusResponse.__init__(self, sec_context, return_addr, timeslack, asynchop=asynchop) self.signature_check = self.sec.correctly_signed_logout_response + class NameIDMappingResponse(StatusResponse): msgtype = "name_id_mapping_response" @@ -261,6 +271,7 @@ class NameIDMappingResponse(StatusResponse): request_id, asynchop) self.signature_check = self.sec.correctly_signed_name_id_mapping_response + class ManageNameIDResponse(StatusResponse): msgtype = "manage_name_id_response" @@ -273,15 +284,16 @@ class ManageNameIDResponse(StatusResponse): # ---------------------------------------------------------------------------- + class AuthnResponse(StatusResponse): """ This is where all the profile compliance is checked. This one does saml2int compliance. """ msgtype = "authn_response" - def __init__(self, sec_context, attribute_converters, entity_id, - return_addr=None, outstanding_queries=None, - timeslack=0, asynchop=True, allow_unsolicited=False, - test=False): + def __init__(self, sec_context, attribute_converters, entity_id, + return_addr=None, outstanding_queries=None, + timeslack=0, asynchop=True, allow_unsolicited=False, + test=False): StatusResponse.__init__(self, sec_context, return_addr, timeslack, asynchop=asynchop) @@ -335,7 +347,8 @@ class AuthnResponse(StatusResponse): if validate_on_or_after(authn_statement.session_not_on_or_after, self.timeslack): self.session_not_on_or_after = calendar.timegm( - time_util.str_to_time(authn_statement.session_not_on_or_after)) + time_util.str_to_time( + authn_statement.session_not_on_or_after)) else: return False return True @@ -364,8 +377,7 @@ class AuthnResponse(StatusResponse): try: if condition.not_on_or_after: self.not_on_or_after = validate_on_or_after( - condition.not_on_or_after, - self.timeslack) + condition.not_on_or_after, self.timeslack) if condition.not_before: validate_before(condition.not_before, self.timeslack) except Exception, excp: @@ -375,7 +387,6 @@ class AuthnResponse(StatusResponse): else: self.not_on_or_after = 0 - if not for_me(condition, self.entity_id): if not lax: #print condition @@ -490,7 +501,7 @@ class AuthnResponse(StatusResponse): pass else: raise ValueError("Unknown subject confirmation method: %s" % ( - subject_confirmation.method,)) + subject_confirmation.method,)) subjconf.append(subject_confirmation) @@ -501,7 +512,7 @@ class AuthnResponse(StatusResponse): # The subject must contain a name_id assert subject.name_id - self.name_id = subject.name_id.text.strip() + self.name_id = subject.name_id return self.name_id def _assertion(self, assertion): @@ -561,7 +572,7 @@ class AuthnResponse(StatusResponse): def parse_assertion(self): try: assert len(self.response.assertion) == 1 or \ - len(self.response.encrypted_assertion) == 1 + len(self.response.encrypted_assertion) == 1 except AssertionError: raise Exception("No assertion part") @@ -571,8 +582,7 @@ class AuthnResponse(StatusResponse): else: logger.debug("***Encrypted response***") return self._encrypted_assertion( - self.response.encrypted_assertion[0]) - + self.response.encrypted_assertion[0]) def verify(self): """ Verify that the assertion is syntactically correct and @@ -615,7 +625,7 @@ class AuthnResponse(StatusResponse): return res def authz_decision_info(self): - res = {"permit":[], "deny": [], "indeterminate":[] } + res = {"permit": [], "deny": [], "indeterminate": []} for adstat in self.assertion.authz_decision_statement: # one of 'Permit', 'Deny', 'Indeterminate' res[adstat.decision.text.lower()] = adstat @@ -632,19 +642,18 @@ class AuthnResponse(StatusResponse): nooa = self.not_on_or_after if self.context == "AuthzQuery": - return {"name_id": self.name_id, - "came_from": self.came_from, "issuer": self.issuer(), - "not_on_or_after": nooa, + return {"name_id": self.name_id, "came_from": self.came_from, + "issuer": self.issuer(), "not_on_or_after": nooa, "authz_decision_info": self.authz_decision_info() } else: - return { "ava": self.ava, "name_id": self.name_id, + return {"ava": self.ava, "name_id": self.name_id, "came_from": self.came_from, "issuer": self.issuer(), - "not_on_or_after": nooa, - "authn_info": self.authn_info() } + "not_on_or_after": nooa, "authn_info": self.authn_info()} def __str__(self): return "%s" % self.xmlstr + class AuthnQueryResponse(AuthnResponse): msgtype = "authn_query_response" @@ -659,39 +668,44 @@ class AuthnQueryResponse(AuthnResponse): self.assertion = None self.context = "AuthnQueryResponse" - def condition_ok(self, lax=False): # Should I care about conditions ? + def condition_ok(self, lax=False): # Should I care about conditions ? return True + class AttributeResponse(AuthnResponse): msgtype = "attribute_response" def __init__(self, sec_context, attribute_converters, entity_id, - return_addr=None, timeslack=0, asynchop=False, test=False): + return_addr=None, timeslack=0, asynchop=False, test=False): AuthnResponse.__init__(self, sec_context, attribute_converters, - entity_id, return_addr, timeslack=timeslack, - asynchop=asynchop, test=test) + entity_id, return_addr, timeslack=timeslack, + asynchop=asynchop, test=test) self.entity_id = entity_id self.attribute_converters = attribute_converters self.assertion = None self.context = "AttrQuery" + class AuthzResponse(AuthnResponse): """ A successful response will be in the form of assertions containing authorization decision statements.""" msgtype = "authz_decision_response" + def __init__(self, sec_context, attribute_converters, entity_id, - return_addr=None, timeslack=0, asynchop=False): + return_addr=None, timeslack=0, asynchop=False): AuthnResponse.__init__(self, sec_context, attribute_converters, - entity_id, return_addr, - timeslack=timeslack, asynchop=asynchop) + entity_id, return_addr, timeslack=timeslack, + asynchop=asynchop) self.entity_id = entity_id self.attribute_converters = attribute_converters self.assertion = None self.context = "AuthzQuery" + class ArtifactResponse(AuthnResponse): msgtype = "artifact_response" + def __init__(self, sec_context, attribute_converters, entity_id, return_addr=None, timeslack=0, asynchop=False, test=False): @@ -704,10 +718,9 @@ class ArtifactResponse(AuthnResponse): self.context = "ArtifactResolve" -def response_factory(xmlstr, conf, return_addr=None, - outstanding_queries=None, - timeslack=0, decode=True, request_id=0, - origxml=None, asynchop=True, allow_unsolicited=False): +def response_factory(xmlstr, conf, return_addr=None, outstanding_queries=None, + timeslack=0, decode=True, request_id=0, origxml=None, + asynchop=True, allow_unsolicited=False): sec_context = security_context(conf) if not timeslack: try: @@ -723,9 +736,10 @@ def response_factory(xmlstr, conf, return_addr=None, try: response.loads(xmlstr, decode, origxml) if response.response.assertion or response.response.encrypted_assertion: - authnresp = AuthnResponse(sec_context, attribute_converters, - entity_id, return_addr, outstanding_queries, - timeslack, asynchop, allow_unsolicited) + authnresp = AuthnResponse(sec_context, attribute_converters, + entity_id, return_addr, + outstanding_queries, timeslack, asynchop, + allow_unsolicited) authnresp.update(response) return authnresp except TypeError: @@ -741,6 +755,7 @@ def response_factory(xmlstr, conf, return_addr=None, # =========================================================================== # A class of it's own + class AssertionIDResponse(object): msgtype = "assertion_id_response" diff --git a/src/saml2/virtual_org.py b/src/saml2/virtual_org.py index 9369e602..f2ffaec1 100644 --- a/src/saml2/virtual_org.py +++ b/src/saml2/virtual_org.py @@ -4,9 +4,10 @@ from saml2.saml import NAMEID_FORMAT_PERSISTENT logger = logging.getLogger(__name__) + class VirtualOrg(object): def __init__(self, sp, vorg, cnf): - self.sp = sp # The parent SP client instance + self.sp = sp # The parent SP client instance self._name = vorg self.common_identifier = cnf["common_identifier"] try: @@ -28,7 +29,7 @@ class VirtualOrg(object): """ return self.sp.config.metadata.vo_members(self._name) - def members_to_ask(self, subject_id): + def members_to_ask(self, name_id): """Find the member of the Virtual Organization that I haven't already spoken too """ @@ -40,12 +41,12 @@ class VirtualOrg(object): # Remove the ones I have cached data from about this subject vo_members = [m for m in vo_members if not self.sp.users.cache.active( - subject_id, m)] + name_id, m)] logger.info("VO members (not cached): %s" % vo_members) return vo_members - def get_common_identifier(self, subject_id): - (ava, _) = self.sp.users.get_identity(subject_id) + def get_common_identifier(self, name_id): + (ava, _) = self.sp.users.get_identity(name_id) if ava == {}: return None @@ -56,36 +57,23 @@ class VirtualOrg(object): except KeyError: return None - def do_aggregation(self, subject_id): + def do_aggregation(self, name_id): logger.info("** Do VO aggregation **\nSubjectID: %s, VO:%s" % ( - subject_id, self._name)) + name_id, self._name)) - to_ask = self.members_to_ask(subject_id) + to_ask = self.members_to_ask(name_id) if to_ask: - # Find the NameIDFormat and the SPNameQualifier - if self.nameid_format: - name_id_format = self.nameid_format - sp_name_qualifier = "" - else: - sp_name_qualifier = self._name - name_id_format = "" - - com_identifier = self.get_common_identifier(subject_id) + com_identifier = self.get_common_identifier(name_id) resolver = AttributeResolver(self.sp) # extends returns a list of session_infos - for session_info in resolver.extend(com_identifier, - self.sp.config.entityid, - to_ask, - name_id_format=name_id_format, - sp_name_qualifier=sp_name_qualifier, - real_id=subject_id): + for session_info in resolver.extend( + com_identifier, self.sp.config.entityid, to_ask): _ = self._cache_session(session_info) - logger.info(">Issuers: %s" % self.sp.users.issuers_of_info( - subject_id)) - logger.info("AVA: %s" % (self.sp.users.get_identity(subject_id),)) + logger.info(">Issuers: %s" % self.sp.users.issuers_of_info(name_id)) + logger.info("AVA: %s" % (self.sp.users.get_identity(name_id),)) return True else: |