summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/saml2/config.py3
-rw-r--r--src/saml2/eptid.py15
-rw-r--r--src/saml2/ident.py20
-rw-r--r--src/saml2/mdstore.py3
-rw-r--r--src/saml2/mongo_store.py8
-rw-r--r--src/saml2/s_utils.py9
6 files changed, 49 insertions, 9 deletions
diff --git a/src/saml2/config.py b/src/saml2/config.py
index d7953a0b..bc934f0a 100644
--- a/src/saml2/config.py
+++ b/src/saml2/config.py
@@ -8,6 +8,7 @@ import os
import re
import logging
import logging.handlers
+import six
from importlib import import_module
@@ -296,7 +297,7 @@ class Config(object):
def unicode_convert(self, item):
try:
- return unicode(item, "utf-8")
+ return six.text_type(item, "utf-8")
except TypeError:
_uc = self.unicode_convert
if isinstance(item, dict):
diff --git a/src/saml2/eptid.py b/src/saml2/eptid.py
index b2551b44..af5e9924 100644
--- a/src/saml2/eptid.py
+++ b/src/saml2/eptid.py
@@ -8,6 +8,7 @@ import hashlib
import shelve
import logging
+import six
logger = logging.getLogger(__name__)
@@ -21,13 +22,23 @@ class Eptid(object):
md5 = hashlib.md5()
for arg in args:
md5.update(arg.encode("utf-8"))
- md5.update(sp)
- md5.update(self.secret)
+ if isinstance(sp, six.binary_type):
+ md5.update(sp)
+ else:
+ md5.update(sp.encode('utf-8'))
+ if isinstance(self.secret, six.binary_type):
+ md5.update(self.secret)
+ else:
+ md5.update(self.secret.encode('utf-8'))
md5.digest()
hashval = md5.hexdigest()
+ if isinstance(hashval, six.binary_type):
+ hashval = hashval.decode('ascii')
return "!".join([idp, sp, hashval])
def __getitem__(self, key):
+ if six.PY3 and isinstance(key, six.binary_type):
+ key = key.decode('utf-8')
return self._db[key]
def __setitem__(self, key, value):
diff --git a/src/saml2/ident.py b/src/saml2/ident.py
index 49f8a632..c99a3bd4 100644
--- a/src/saml2/ident.py
+++ b/src/saml2/ident.py
@@ -7,7 +7,7 @@ from hashlib import sha256
from six.moves.urllib.parse import quote
from six.moves.urllib.parse import unquote
from saml2 import SAMLError
-from saml2.s_utils import rndstr
+from saml2.s_utils import rndbytes
from saml2.s_utils import PolicyError
from saml2.saml import NameID
from saml2.saml import NAMEID_FORMAT_PERSISTENT
@@ -46,6 +46,16 @@ def code(item):
return ",".join(_res)
+def code_binary(item):
+ """
+ Return a binary 'code' suitable for hashing.
+ """
+ code_str = code(item)
+ if isinstance(code_str, six.string_types):
+ return code_str.encode('utf-8')
+ return code_str
+
+
def decode(txt):
"""Turns a coded string by code() into a NameID class instance.
@@ -75,11 +85,17 @@ class IdentDB(object):
self.name_qualifier = name_qualifier
def _create_id(self, nformat, name_qualifier="", sp_name_qualifier=""):
- _id = sha256(rndstr(32))
+ _id = sha256(rndbytes(32))
+ if not isinstance(nformat, six.binary_type):
+ nformat = nformat.encode('utf-8')
_id.update(nformat)
if name_qualifier:
+ if not isinstance(name_qualifier, six.binary_type):
+ name_qualifier = name_qualifier.encode('utf-8')
_id.update(name_qualifier)
if sp_name_qualifier:
+ if not isinstance(sp_name_qualifier, six.binary_type):
+ sp_name_qualifier = sp_name_qualifier.encode('utf-8')
_id.update(sp_name_qualifier)
return _id.hexdigest()
diff --git a/src/saml2/mdstore.py b/src/saml2/mdstore.py
index d6db2bdc..68206dfc 100644
--- a/src/saml2/mdstore.py
+++ b/src/saml2/mdstore.py
@@ -3,6 +3,7 @@ import logging
import os
import sys
import json
+import six
from hashlib import sha1
from os.path import isfile, join
@@ -487,6 +488,8 @@ class InMemoryMetaData(MetaData):
try:
for srv in ent[desc]:
if "artifact_resolution_service" in srv:
+ if isinstance(eid, six.string_types):
+ eid = eid.encode('utf-8')
s = sha1(eid)
res[s.digest()] = ent
except KeyError:
diff --git a/src/saml2/mongo_store.py b/src/saml2/mongo_store.py
index 71908825..ad9b0728 100644
--- a/src/saml2/mongo_store.py
+++ b/src/saml2/mongo_store.py
@@ -9,7 +9,7 @@ from saml2.eptid import Eptid
from saml2.mdstore import InMemoryMetaData
from saml2.s_utils import PolicyError
-from saml2.ident import code, IdentDB, Unknown
+from saml2.ident import code_binary, IdentDB, Unknown
from saml2.mdie import to_dict, from_dict
from saml2 import md
@@ -59,7 +59,7 @@ class SessionStorageMDB(object):
def store_assertion(self, assertion, to_sign):
name_id = assertion.subject.name_id
- nkey = sha1(code(name_id)).hexdigest()
+ nkey = sha1(code_binary(name_id)).hexdigest()
doc = {
"name_id_key": nkey,
@@ -94,7 +94,7 @@ class SessionStorageMDB(object):
:return:
"""
result = []
- key = sha1(code(name_id)).hexdigest()
+ key = sha1(code_binary(name_id)).hexdigest()
for item in self.assertion.find({"name_id_key": key}):
assertion = from_dict(item["assertion"], ONTS, True)
if session_index or requested_context:
@@ -114,7 +114,7 @@ class SessionStorageMDB(object):
def remove_authn_statements(self, name_id):
logger.debug("remove authn about: %s" % name_id)
- key = sha1(code(name_id)).hexdigest()
+ key = sha1(code_binary(name_id)).hexdigest()
for item in self.assertion.find({"name_id_key": key}):
self.assertion.remove(item["_id"])
diff --git a/src/saml2/s_utils.py b/src/saml2/s_utils.py
index 9d2391ef..271dea93 100644
--- a/src/saml2/s_utils.py
+++ b/src/saml2/s_utils.py
@@ -167,6 +167,15 @@ def rndstr(size=16, alphabet=""):
alphabet = string.ascii_letters[0:52] + string.digits
return type(alphabet)().join(rng.choice(alphabet) for _ in range(size))
+def rndbytes(size=16, alphabet=""):
+ """
+ Returns rndstr always as a binary type
+ """
+ x = rndstr(size, alphabet)
+ if isinstance(x, six.string_types):
+ return x.encode('utf-8')
+ return x
+
def sid():
"""creates an unique SID for each session.