diff options
-rw-r--r-- | src/saml2/config.py | 3 | ||||
-rw-r--r-- | src/saml2/eptid.py | 15 | ||||
-rw-r--r-- | src/saml2/ident.py | 20 | ||||
-rw-r--r-- | src/saml2/mdstore.py | 3 | ||||
-rw-r--r-- | src/saml2/mongo_store.py | 8 | ||||
-rw-r--r-- | src/saml2/s_utils.py | 9 |
6 files changed, 49 insertions, 9 deletions
diff --git a/src/saml2/config.py b/src/saml2/config.py index d7953a0b..bc934f0a 100644 --- a/src/saml2/config.py +++ b/src/saml2/config.py @@ -8,6 +8,7 @@ import os import re import logging import logging.handlers +import six from importlib import import_module @@ -296,7 +297,7 @@ class Config(object): def unicode_convert(self, item): try: - return unicode(item, "utf-8") + return six.text_type(item, "utf-8") except TypeError: _uc = self.unicode_convert if isinstance(item, dict): diff --git a/src/saml2/eptid.py b/src/saml2/eptid.py index b2551b44..af5e9924 100644 --- a/src/saml2/eptid.py +++ b/src/saml2/eptid.py @@ -8,6 +8,7 @@ import hashlib import shelve import logging +import six logger = logging.getLogger(__name__) @@ -21,13 +22,23 @@ class Eptid(object): md5 = hashlib.md5() for arg in args: md5.update(arg.encode("utf-8")) - md5.update(sp) - md5.update(self.secret) + if isinstance(sp, six.binary_type): + md5.update(sp) + else: + md5.update(sp.encode('utf-8')) + if isinstance(self.secret, six.binary_type): + md5.update(self.secret) + else: + md5.update(self.secret.encode('utf-8')) md5.digest() hashval = md5.hexdigest() + if isinstance(hashval, six.binary_type): + hashval = hashval.decode('ascii') return "!".join([idp, sp, hashval]) def __getitem__(self, key): + if six.PY3 and isinstance(key, six.binary_type): + key = key.decode('utf-8') return self._db[key] def __setitem__(self, key, value): diff --git a/src/saml2/ident.py b/src/saml2/ident.py index 49f8a632..c99a3bd4 100644 --- a/src/saml2/ident.py +++ b/src/saml2/ident.py @@ -7,7 +7,7 @@ from hashlib import sha256 from six.moves.urllib.parse import quote from six.moves.urllib.parse import unquote from saml2 import SAMLError -from saml2.s_utils import rndstr +from saml2.s_utils import rndbytes from saml2.s_utils import PolicyError from saml2.saml import NameID from saml2.saml import NAMEID_FORMAT_PERSISTENT @@ -46,6 +46,16 @@ def code(item): return ",".join(_res) +def code_binary(item): + """ + Return a binary 'code' suitable for hashing. + """ + code_str = code(item) + if isinstance(code_str, six.string_types): + return code_str.encode('utf-8') + return code_str + + def decode(txt): """Turns a coded string by code() into a NameID class instance. @@ -75,11 +85,17 @@ class IdentDB(object): self.name_qualifier = name_qualifier def _create_id(self, nformat, name_qualifier="", sp_name_qualifier=""): - _id = sha256(rndstr(32)) + _id = sha256(rndbytes(32)) + if not isinstance(nformat, six.binary_type): + nformat = nformat.encode('utf-8') _id.update(nformat) if name_qualifier: + if not isinstance(name_qualifier, six.binary_type): + name_qualifier = name_qualifier.encode('utf-8') _id.update(name_qualifier) if sp_name_qualifier: + if not isinstance(sp_name_qualifier, six.binary_type): + sp_name_qualifier = sp_name_qualifier.encode('utf-8') _id.update(sp_name_qualifier) return _id.hexdigest() diff --git a/src/saml2/mdstore.py b/src/saml2/mdstore.py index d6db2bdc..68206dfc 100644 --- a/src/saml2/mdstore.py +++ b/src/saml2/mdstore.py @@ -3,6 +3,7 @@ import logging import os import sys import json +import six from hashlib import sha1 from os.path import isfile, join @@ -487,6 +488,8 @@ class InMemoryMetaData(MetaData): try: for srv in ent[desc]: if "artifact_resolution_service" in srv: + if isinstance(eid, six.string_types): + eid = eid.encode('utf-8') s = sha1(eid) res[s.digest()] = ent except KeyError: diff --git a/src/saml2/mongo_store.py b/src/saml2/mongo_store.py index 71908825..ad9b0728 100644 --- a/src/saml2/mongo_store.py +++ b/src/saml2/mongo_store.py @@ -9,7 +9,7 @@ from saml2.eptid import Eptid from saml2.mdstore import InMemoryMetaData from saml2.s_utils import PolicyError -from saml2.ident import code, IdentDB, Unknown +from saml2.ident import code_binary, IdentDB, Unknown from saml2.mdie import to_dict, from_dict from saml2 import md @@ -59,7 +59,7 @@ class SessionStorageMDB(object): def store_assertion(self, assertion, to_sign): name_id = assertion.subject.name_id - nkey = sha1(code(name_id)).hexdigest() + nkey = sha1(code_binary(name_id)).hexdigest() doc = { "name_id_key": nkey, @@ -94,7 +94,7 @@ class SessionStorageMDB(object): :return: """ result = [] - key = sha1(code(name_id)).hexdigest() + key = sha1(code_binary(name_id)).hexdigest() for item in self.assertion.find({"name_id_key": key}): assertion = from_dict(item["assertion"], ONTS, True) if session_index or requested_context: @@ -114,7 +114,7 @@ class SessionStorageMDB(object): def remove_authn_statements(self, name_id): logger.debug("remove authn about: %s" % name_id) - key = sha1(code(name_id)).hexdigest() + key = sha1(code_binary(name_id)).hexdigest() for item in self.assertion.find({"name_id_key": key}): self.assertion.remove(item["_id"]) diff --git a/src/saml2/s_utils.py b/src/saml2/s_utils.py index 9d2391ef..271dea93 100644 --- a/src/saml2/s_utils.py +++ b/src/saml2/s_utils.py @@ -167,6 +167,15 @@ def rndstr(size=16, alphabet=""): alphabet = string.ascii_letters[0:52] + string.digits return type(alphabet)().join(rng.choice(alphabet) for _ in range(size)) +def rndbytes(size=16, alphabet=""): + """ + Returns rndstr always as a binary type + """ + x = rndstr(size, alphabet) + if isinstance(x, six.string_types): + return x.encode('utf-8') + return x + def sid(): """creates an unique SID for each session. |