summaryrefslogtreecommitdiff
path: root/src/saml2/cache.py
diff options
context:
space:
mode:
authorRoland Hedberg <roland.hedberg@adm.umu.se>2013-02-09 18:57:26 +0100
committerRoland Hedberg <roland.hedberg@adm.umu.se>2013-02-09 18:57:26 +0100
commitf295e06ab790c5454c164255328ebba4a0c4f2da (patch)
treef7b29dd6acb2515cd0afbec4a33a82bc812c2b37 /src/saml2/cache.py
parent71246a38292f2f598163ce4b69a3c298f95d04c2 (diff)
downloadpysaml2-f295e06ab790c5454c164255328ebba4a0c4f2da.tar.gz
Rewrote to use NameID instances every where where I previously used just the text part of the instance.
Diffstat (limited to 'src/saml2/cache.py')
-rw-r--r--src/saml2/cache.py77
1 files changed, 48 insertions, 29 deletions
diff --git a/src/saml2/cache.py b/src/saml2/cache.py
index fc7b57c4..a128da3f 100644
--- a/src/saml2/cache.py
+++ b/src/saml2/cache.py
@@ -1,6 +1,7 @@
#!/usr/bin/env python
import shelve
+from saml2.ident import code, decode
from saml2 import time_util
import logging
@@ -10,12 +11,15 @@ logger = logging.getLogger(__name__)
# gathered from several different sources, all with their own
# timeout time.
+
class ToOld(Exception):
pass
+
class CacheError(Exception):
pass
+
class Cache(object):
def __init__(self, filename=None):
if filename:
@@ -25,18 +29,25 @@ class Cache(object):
self._db = {}
self._sync = False
- def delete(self, subject_id):
- del self._db[subject_id]
+ def delete(self, name_id):
+ """
+
+ :param name_id: The subject identifier, a NameID instance
+ """
+ del self._db[code(name_id)]
if self._sync:
- self._db.sync()
+ try:
+ self._db.sync()
+ except AttributeError:
+ pass
- def get_identity(self, subject_id, entities=None,
+ def get_identity(self, name_id, entities=None,
check_not_on_or_after=True):
""" Get all the identity information that has been received and
are still valid about the subject.
- :param subject_id: The identifier of the subject
+ :param name_id: The subject identifier, a NameID instance
:param entities: The identifiers of the entities whoes assertions are
interesting. If the list is empty all entities are interesting.
:return: A 2-tuple consisting of the identity information (a
@@ -45,7 +56,8 @@ class Cache(object):
"""
if not entities:
try:
- entities = self._db[subject_id].keys()
+ cni = code(name_id)
+ entities = self._db[cni].keys()
except KeyError:
return {}, []
@@ -53,7 +65,7 @@ class Cache(object):
oldees = []
for entity_id in entities:
try:
- info = self.get(subject_id, entity_id, check_not_on_or_after)
+ info = self.get(name_id, entity_id, check_not_on_or_after)
except ToOld:
oldees.append(entity_id)
continue
@@ -70,74 +82,81 @@ class Cache(object):
res[key] = vals
return res, oldees
- def get(self, subject_id, entity_id, check_not_on_or_after=True):
+ def get(self, name_id, entity_id, check_not_on_or_after=True):
""" Get session information about a subject gotten from a
specified IdP/AA.
- :param subject_id: The identifier of the subject
+ :param name_id: The subject identifier, a NameID instance
:param entity_id: The identifier of the entity_id
:param check_not_on_or_after: if True it will check if this
subject is still valid or if it is too old. Otherwise it
will not check this. True by default.
:return: The session information
"""
- (timestamp, info) = self._db[subject_id][entity_id]
+ cni = code(name_id)
+ (timestamp, info) = self._db[cni][entity_id]
if check_not_on_or_after and time_util.after(timestamp):
raise ToOld("past %s" % timestamp)
return info or None
- def set(self, subject_id, entity_id, info, not_on_or_after=0):
- """ Stores session information in the cache. Assumes that the subject_id
+ def set(self, name_id, entity_id, info, not_on_or_after=0):
+ """ Stores session information in the cache. Assumes that the name_id
is unique within the context of the Service Provider.
- :param subject_id: The subject identifier
+ :param name_id: The subject identifier, a NameID instance
:param entity_id: The identifier of the entity_id/receiver of an
assertion
:param info: The session info, the assertion is part of this
:param not_on_or_after: A time after which the assertion is not valid.
"""
- if subject_id not in self._db:
- self._db[subject_id] = {}
+ cni = code(name_id)
+ if cni not in self._db:
+ self._db[cni] = {}
- self._db[subject_id][entity_id] = (not_on_or_after, info)
+ self._db[cni][entity_id] = (not_on_or_after, info)
if self._sync:
- self._db.sync()
+ try:
+ self._db.sync()
+ except AttributeError:
+ pass
- def reset(self, subject_id, entity_id):
+ def reset(self, name_id, entity_id):
""" Scrap the assertions received from a IdP or an AA about a special
subject.
- :param subject_id: The subjects identifier
+ :param name_id: The subject identifier, a NameID instance
:param entity_id: The identifier of the entity_id of the assertion
:return:
"""
- self.set(subject_id, entity_id, {}, 0)
+ self.set(name_id, entity_id, {}, 0)
- def entities(self, subject_id):
+ def entities(self, name_id):
""" Returns all the entities of assertions for a subject, disregarding
whether the assertion still is valid or not.
- :param subject_id: The identifier of the subject
+ :param name_id: The subject identifier, a NameID instance
:return: A possibly empty list of entity identifiers
"""
- return self._db[subject_id].keys()
+ cni = code(name_id)
+ return self._db[cni].keys()
- def receivers(self, subject_id):
+ def receivers(self, name_id):
""" Another name for entities() just to make it more logic in the IdP
scenario """
- return self.entities(subject_id)
+ return self.entities(name_id)
- def active(self, subject_id, entity_id):
+ def active(self, name_id, entity_id):
""" Returns the status of assertions from a specific entity_id.
- :param subject_id: The ID of the subject
+ :param name_id: The ID of the subject
:param entity_id: The entity ID of the entity_id of the assertion
:return: True or False depending on if the assertion is still
valid or not.
"""
try:
- (timestamp, info) = self._db[subject_id][entity_id]
+ cni = code(name_id)
+ (timestamp, info) = self._db[cni][entity_id]
except KeyError:
return False
@@ -151,4 +170,4 @@ class Cache(object):
:return: list of subject identifiers
"""
- return self._db.keys()
+ return [decode(c) for c in self._db.keys()]