summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorVlastimil Zíma <vlastimil.zima@nic.cz>2018-08-29 15:21:23 +0200
committerVlastimil Zíma <vlastimil.zima@nic.cz>2018-08-29 15:21:23 +0200
commit8bfc9e1a796f13bd80e70bdff25f38ffc761e4cf (patch)
tree4e348a6b13d6b85be2bf795e00e8012f03b66153
parent0693fc2038706a5f32e3393338c2d7d750bef1db (diff)
parent8241a199a48260744cf6cbe8cfa5edcf34f8a4c4 (diff)
downloadopenid-8bfc9e1a796f13bd80e70bdff25f38ffc761e4cf.tar.gz
Merge branch 'convert-extension-values'
-rw-r--r--openid/extensions/ax.py16
-rw-r--r--openid/extensions/sreg.py15
-rw-r--r--openid/oidutil.py15
-rw-r--r--openid/test/test_ax.py10
-rw-r--r--openid/test/test_oidutil.py22
-rw-r--r--openid/test/test_sreg.py6
6 files changed, 66 insertions, 18 deletions
diff --git a/openid/extensions/ax.py b/openid/extensions/ax.py
index faffdb4..79856f7 100644
--- a/openid/extensions/ax.py
+++ b/openid/extensions/ax.py
@@ -8,7 +8,7 @@ import six
from openid import extension
from openid.message import OPENID_NS, NamespaceMap
-from openid.oidutil import string_to_text
+from openid.oidutil import force_text, string_to_text
from openid.server.trustroot import TrustRoot
__all__ = [
@@ -421,9 +421,9 @@ class AXKeyValueMessage(AXMessage):
@param type_uri: The URI for the attribute
- @param value: The value to add to the response to the relying
- party for this attribute
- @type value: six.text_type
+ @param value: The value to add to the response to the relying party for this attribute. It the value is not
+ a text, it will be converted.
+ @type value: Any
@returns: None
"""
@@ -432,7 +432,7 @@ class AXKeyValueMessage(AXMessage):
except KeyError:
values = self.data[type_uri] = []
- values.append(value)
+ values.append(force_text(value))
def setValues(self, type_uri, values):
"""Set the values for the given attribute type. This replaces
@@ -440,11 +440,11 @@ class AXKeyValueMessage(AXMessage):
@param type_uri: The URI for the attribute
- @param values: A list of values to send for this attribute.
- @type values: List[six.text_type]
+ @param values: A list of values to send for this attribute. Values which are not text, will be converted.
+ @type values: List[Any]
"""
- self.data[type_uri] = values
+ self.data[type_uri] = [force_text(v) for v in values]
def _getExtensionKVArgs(self, aliases=None):
"""Get the extension arguments for the key/value pairs
diff --git a/openid/extensions/sreg.py b/openid/extensions/sreg.py
index 7f9828e..a90e4d4 100644
--- a/openid/extensions/sreg.py
+++ b/openid/extensions/sreg.py
@@ -42,7 +42,7 @@ import six
from openid.extension import Extension
from openid.message import NamespaceAliasRegistrationError, registerNamespaceAlias
-from openid.oidutil import string_to_text
+from openid.oidutil import force_text, string_to_text
__all__ = [
'SRegRequest',
@@ -424,12 +424,10 @@ class SRegResponse(Extension):
@param request: The simple registration request object
@type request: SRegRequest
- @param data: The simple registration data for this
- response, as a dictionary from unqualified simple
- registration field name to string (unicode) value. For
- instance, the nickname should be stored under the key
- 'nickname'.
- @type data: Dict[six.text_type, six.text_type], six.binary_type is deprecated
+ @param data: The simple registration data for this response, as a mapping of unqualified simple registration
+ field name to value. For instance, the nickname should be stored under the key 'nickname'. If the value is
+ missing or None, it will be skipped. If the value is not a text, it will be converted.
+ @type data: Dict[six.text_type, Any]
@returns: a simple registration response object
@rtype: SRegResponse
@@ -439,8 +437,7 @@ class SRegResponse(Extension):
for field in request.allRequestedFields():
value = data.get(field)
if value is not None:
- value = string_to_text(value, "Binary values for data are deprecated. Use text input instead.")
- self.data[field] = value
+ self.data[field] = force_text(value)
return self
# Assign getSRegArgs to a static method so that it can be
diff --git a/openid/oidutil.py b/openid/oidutil.py
index 0f9ff99..884d38f 100644
--- a/openid/oidutil.py
+++ b/openid/oidutil.py
@@ -162,3 +162,18 @@ def string_to_text(value, deprecate_msg):
warnings.warn(deprecate_msg, DeprecationWarning)
value = value.decode('utf-8')
return value
+
+
+def force_text(value):
+ """
+ Return a text object representing value in UTF-8 encoding.
+ """
+ if isinstance(value, six.text_type):
+ # It's already a text, just return it.
+ return value
+ elif isinstance(value, bytes):
+ # It's a byte string, decode it.
+ return value.decode('utf-8')
+ else:
+ # It's not a string, convert it.
+ return six.text_type(value)
diff --git a/openid/test/test_ax.py b/openid/test/test_ax.py
index 0b05961..d6d4c38 100644
--- a/openid/test/test_ax.py
+++ b/openid/test/test_ax.py
@@ -168,6 +168,16 @@ class ParseAXValuesTest(unittest.TestCase):
def test_singletonValue(self):
self.assertAXValues({'type.foo': 'urn:foo', 'value.foo': 'Westfall'}, {'urn:foo': ['Westfall']})
+ def test_add_value_convert(self):
+ message = ax.AXKeyValueMessage()
+ message.addValue('http://example.com/attribute', 1492)
+ self.assertEqual(message.get('http://example.com/attribute'), ['1492'])
+
+ def test_set_values_convert(self):
+ message = ax.AXKeyValueMessage()
+ message.setValues('http://example.com/attribute', [1492, True, None])
+ self.assertEqual(message.get('http://example.com/attribute'), ['1492', 'True', 'None'])
+
class FetchRequestTest(unittest.TestCase):
def setUp(self):
diff --git a/openid/test/test_oidutil.py b/openid/test/test_oidutil.py
index 8aacb58..9cd9943 100644
--- a/openid/test/test_oidutil.py
+++ b/openid/test/test_oidutil.py
@@ -12,7 +12,7 @@ from mock import sentinel
from testfixtures import ShouldWarn
from openid import oidutil
-from openid.oidutil import string_to_text
+from openid.oidutil import force_text, string_to_text
class TestBase64(unittest.TestCase):
@@ -179,3 +179,23 @@ class TestToText(unittest.TestCase):
self.assertIsInstance(result, six.text_type)
self.assertEqual(result, 'ěščřž')
+
+
+class TestForceText(unittest.TestCase):
+ """Test `force_text` utility function."""
+
+ def test_text(self):
+ self.assertEqual(force_text(''), '')
+ self.assertEqual(force_text('ascii'), 'ascii')
+ self.assertEqual(force_text('ůňíčóďé'), 'ůňíčóďé')
+
+ def test_bytes(self):
+ self.assertEqual(force_text(b''), '')
+ self.assertEqual(force_text(b'ascii'), 'ascii')
+ self.assertEqual(force_text('ůňíčóďé'.encode('utf-8')), 'ůňíčóďé')
+
+ def test_objects(self):
+ self.assertEqual(force_text(None), 'None')
+ self.assertEqual(force_text(14), '14')
+ self.assertEqual(force_text(True), 'True')
+ self.assertEqual(force_text(False), 'False')
diff --git a/openid/test/test_sreg.py b/openid/test/test_sreg.py
index 1224cbd..d4989b5 100644
--- a/openid/test/test_sreg.py
+++ b/openid/test/test_sreg.py
@@ -1,6 +1,7 @@
from __future__ import unicode_literals
import unittest
+from datetime import date
from openid.extensions import sreg
from openid.message import Message, NamespaceMap
@@ -461,6 +462,11 @@ class SendFieldsTest(unittest.TestCase):
sent_data = {'nickname': 'linusaur', 'email': 'president@whitehouse.gov', 'fullname': 'Leonhard Euler'}
self.assertEqual(sreg_data_resp, sent_data)
+ def test_extract_response_conversion(self):
+ sreg_request = sreg.SRegRequest(required=['dob'])
+ sreg_response = sreg.SRegResponse.extractResponse(sreg_request, {'dob': date(1989, 11, 17)})
+ self.assertEqual(sreg_response['dob'], '1989-11-17')
+
if __name__ == '__main__':
unittest.main()