From 8241a199a48260744cf6cbe8cfa5edcf34f8a4c4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Tue, 28 Aug 2018 14:29:06 +0200 Subject: Convert data values for extensions to text --- openid/extensions/ax.py | 16 ++++++++-------- openid/extensions/sreg.py | 15 ++++++--------- openid/oidutil.py | 15 +++++++++++++++ openid/test/test_ax.py | 10 ++++++++++ openid/test/test_oidutil.py | 22 +++++++++++++++++++++- openid/test/test_sreg.py | 6 ++++++ 6 files changed, 66 insertions(+), 18 deletions(-) (limited to 'openid') diff --git a/openid/extensions/ax.py b/openid/extensions/ax.py index faffdb4..79856f7 100644 --- a/openid/extensions/ax.py +++ b/openid/extensions/ax.py @@ -8,7 +8,7 @@ import six from openid import extension from openid.message import OPENID_NS, NamespaceMap -from openid.oidutil import string_to_text +from openid.oidutil import force_text, string_to_text from openid.server.trustroot import TrustRoot __all__ = [ @@ -421,9 +421,9 @@ class AXKeyValueMessage(AXMessage): @param type_uri: The URI for the attribute - @param value: The value to add to the response to the relying - party for this attribute - @type value: six.text_type + @param value: The value to add to the response to the relying party for this attribute. It the value is not + a text, it will be converted. + @type value: Any @returns: None """ @@ -432,7 +432,7 @@ class AXKeyValueMessage(AXMessage): except KeyError: values = self.data[type_uri] = [] - values.append(value) + values.append(force_text(value)) def setValues(self, type_uri, values): """Set the values for the given attribute type. This replaces @@ -440,11 +440,11 @@ class AXKeyValueMessage(AXMessage): @param type_uri: The URI for the attribute - @param values: A list of values to send for this attribute. - @type values: List[six.text_type] + @param values: A list of values to send for this attribute. Values which are not text, will be converted. + @type values: List[Any] """ - self.data[type_uri] = values + self.data[type_uri] = [force_text(v) for v in values] def _getExtensionKVArgs(self, aliases=None): """Get the extension arguments for the key/value pairs diff --git a/openid/extensions/sreg.py b/openid/extensions/sreg.py index 7f9828e..a90e4d4 100644 --- a/openid/extensions/sreg.py +++ b/openid/extensions/sreg.py @@ -42,7 +42,7 @@ import six from openid.extension import Extension from openid.message import NamespaceAliasRegistrationError, registerNamespaceAlias -from openid.oidutil import string_to_text +from openid.oidutil import force_text, string_to_text __all__ = [ 'SRegRequest', @@ -424,12 +424,10 @@ class SRegResponse(Extension): @param request: The simple registration request object @type request: SRegRequest - @param data: The simple registration data for this - response, as a dictionary from unqualified simple - registration field name to string (unicode) value. For - instance, the nickname should be stored under the key - 'nickname'. - @type data: Dict[six.text_type, six.text_type], six.binary_type is deprecated + @param data: The simple registration data for this response, as a mapping of unqualified simple registration + field name to value. For instance, the nickname should be stored under the key 'nickname'. If the value is + missing or None, it will be skipped. If the value is not a text, it will be converted. + @type data: Dict[six.text_type, Any] @returns: a simple registration response object @rtype: SRegResponse @@ -439,8 +437,7 @@ class SRegResponse(Extension): for field in request.allRequestedFields(): value = data.get(field) if value is not None: - value = string_to_text(value, "Binary values for data are deprecated. Use text input instead.") - self.data[field] = value + self.data[field] = force_text(value) return self # Assign getSRegArgs to a static method so that it can be diff --git a/openid/oidutil.py b/openid/oidutil.py index 0f9ff99..884d38f 100644 --- a/openid/oidutil.py +++ b/openid/oidutil.py @@ -162,3 +162,18 @@ def string_to_text(value, deprecate_msg): warnings.warn(deprecate_msg, DeprecationWarning) value = value.decode('utf-8') return value + + +def force_text(value): + """ + Return a text object representing value in UTF-8 encoding. + """ + if isinstance(value, six.text_type): + # It's already a text, just return it. + return value + elif isinstance(value, bytes): + # It's a byte string, decode it. + return value.decode('utf-8') + else: + # It's not a string, convert it. + return six.text_type(value) diff --git a/openid/test/test_ax.py b/openid/test/test_ax.py index 0b05961..d6d4c38 100644 --- a/openid/test/test_ax.py +++ b/openid/test/test_ax.py @@ -168,6 +168,16 @@ class ParseAXValuesTest(unittest.TestCase): def test_singletonValue(self): self.assertAXValues({'type.foo': 'urn:foo', 'value.foo': 'Westfall'}, {'urn:foo': ['Westfall']}) + def test_add_value_convert(self): + message = ax.AXKeyValueMessage() + message.addValue('http://example.com/attribute', 1492) + self.assertEqual(message.get('http://example.com/attribute'), ['1492']) + + def test_set_values_convert(self): + message = ax.AXKeyValueMessage() + message.setValues('http://example.com/attribute', [1492, True, None]) + self.assertEqual(message.get('http://example.com/attribute'), ['1492', 'True', 'None']) + class FetchRequestTest(unittest.TestCase): def setUp(self): diff --git a/openid/test/test_oidutil.py b/openid/test/test_oidutil.py index 8aacb58..9cd9943 100644 --- a/openid/test/test_oidutil.py +++ b/openid/test/test_oidutil.py @@ -12,7 +12,7 @@ from mock import sentinel from testfixtures import ShouldWarn from openid import oidutil -from openid.oidutil import string_to_text +from openid.oidutil import force_text, string_to_text class TestBase64(unittest.TestCase): @@ -179,3 +179,23 @@ class TestToText(unittest.TestCase): self.assertIsInstance(result, six.text_type) self.assertEqual(result, 'ěščřž') + + +class TestForceText(unittest.TestCase): + """Test `force_text` utility function.""" + + def test_text(self): + self.assertEqual(force_text(''), '') + self.assertEqual(force_text('ascii'), 'ascii') + self.assertEqual(force_text('ůňíčóďé'), 'ůňíčóďé') + + def test_bytes(self): + self.assertEqual(force_text(b''), '') + self.assertEqual(force_text(b'ascii'), 'ascii') + self.assertEqual(force_text('ůňíčóďé'.encode('utf-8')), 'ůňíčóďé') + + def test_objects(self): + self.assertEqual(force_text(None), 'None') + self.assertEqual(force_text(14), '14') + self.assertEqual(force_text(True), 'True') + self.assertEqual(force_text(False), 'False') diff --git a/openid/test/test_sreg.py b/openid/test/test_sreg.py index 1224cbd..d4989b5 100644 --- a/openid/test/test_sreg.py +++ b/openid/test/test_sreg.py @@ -1,6 +1,7 @@ from __future__ import unicode_literals import unittest +from datetime import date from openid.extensions import sreg from openid.message import Message, NamespaceMap @@ -461,6 +462,11 @@ class SendFieldsTest(unittest.TestCase): sent_data = {'nickname': 'linusaur', 'email': 'president@whitehouse.gov', 'fullname': 'Leonhard Euler'} self.assertEqual(sreg_data_resp, sent_data) + def test_extract_response_conversion(self): + sreg_request = sreg.SRegRequest(required=['dob']) + sreg_response = sreg.SRegResponse.extractResponse(sreg_request, {'dob': date(1989, 11, 17)}) + self.assertEqual(sreg_response['dob'], '1989-11-17') + if __name__ == '__main__': unittest.main() -- cgit v1.2.1