diff options
author | Ib Lundgren <ib.lundgren@gmail.com> | 2013-01-07 23:59:33 +0100 |
---|---|---|
committer | Ib Lundgren <ib.lundgren@gmail.com> | 2013-01-07 23:59:33 +0100 |
commit | 69ec1ab5698e1f8530615f7900f74ddf21aee183 (patch) | |
tree | 1fc63f7a47905772866eff752b343facad459aa1 | |
parent | 3e62dd062b5ef03d268134bda7c49960e0d8ea25 (diff) | |
download | oauthlib-69ec1ab5698e1f8530615f7900f74ddf21aee183.tar.gz |
Experimental default unicode conversion (#53, #68, #86)
-rw-r--r-- | oauthlib/common.py | 55 | ||||
-rw-r--r-- | oauthlib/oauth1/rfc5849/__init__.py | 65 | ||||
-rw-r--r-- | tests/oauth1/rfc5849/test_client.py | 9 | ||||
-rw-r--r-- | tests/test_common.py | 3 |
4 files changed, 60 insertions, 72 deletions
diff --git a/oauthlib/common.py b/oauthlib/common.py index a540f79..145f897 100644 --- a/oauthlib/common.py +++ b/oauthlib/common.py @@ -237,6 +237,32 @@ def safe_string_equals(a, b): return result == 0 +def to_unicode(data, encoding): + """Convert a number of different types of objects to unicode.""" + if isinstance(data, unicode_type): + return data + + if isinstance(data, bytes_type): + return unicode_type(data, encoding=encoding) + + if hasattr(data, '__iter__'): + try: + dict(data) + except TypeError: + pass + except ValueError: + # Assume it's a one dimensional data structure + return (to_unicode(i, encoding) for i in data) + else: + # We support 2.6 which lacks dict comprehensions + return dict(((to_unicode(k, encoding), to_unicode(v, encoding)) + for k, v in ( + data.items() if isinstance(data, dict) else data + ))) + + return data + + class Request(object): """A malleable representation of a signable HTTP request. @@ -252,26 +278,15 @@ class Request(object): """ def __init__(self, uri, http_method='GET', body=None, headers=None, - convert_to_unicode=False, encoding='utf-8'): - if convert_to_unicode: - if isinstance(uri, bytes_type): - uri = uri.decode(encoding) - if isinstance(http_method, bytes_type): - http_method = http_method.decode(encoding) - if isinstance(body, bytes_type): - body = body.decode(encoding) - unicode_headers = {} - for k, v in headers.items(): - k = k.decode(encoding) if isinstance(k, bytes_type) else k - v = v.decode(encoding) if isinstance(v, bytes_type) else v - unicode_headers[k] = v - headers = unicode_headers - - self.uri = uri - self.http_method = http_method - self.headers = headers or {} - self.body = body - self.decoded_body = extract_params(body) + encoding='utf-8'): + # Convert to unicode using encoding if given, else assume unicode + encode = lambda x: to_unicode(x, encoding) if encoding else x + + self.uri = encode(uri) + self.http_method = encode(http_method) + self.headers = encode(headers or {}) + self.body = encode(body) + self.decoded_body = extract_params(encode(body)) self.oauth_params = [] self._params = {} diff --git a/oauthlib/oauth1/rfc5849/__init__.py b/oauthlib/oauth1/rfc5849/__init__.py index 6099174..8e09cc3 100644 --- a/oauthlib/oauth1/rfc5849/__init__.py +++ b/oauthlib/oauth1/rfc5849/__init__.py @@ -23,7 +23,7 @@ else: bytes_type = str from oauthlib.common import Request, urlencode, generate_nonce -from oauthlib.common import generate_timestamp +from oauthlib.common import generate_timestamp, to_unicode from . import parameters, signature, utils SIGNATURE_HMAC = "HMAC-SHA1" @@ -48,48 +48,24 @@ class Client(object): signature_method=SIGNATURE_HMAC, signature_type=SIGNATURE_TYPE_AUTH_HEADER, rsa_key=None, verifier=None, realm=None, - convert_to_unicode=False, encoding='utf-8', - nonce=None, timestamp=None): - if convert_to_unicode: - if isinstance(client_key, bytes_type): - client_key = client_key.decode(encoding) - if isinstance(client_secret, bytes_type): - client_secret = client_secret.decode(encoding) - if isinstance(resource_owner_key, bytes_type): - resource_owner_key = resource_owner_key.decode(encoding) - if isinstance(resource_owner_secret, bytes_type): - resource_owner_secret = resource_owner_secret.decode(encoding) - if isinstance(callback_uri, bytes_type): - callback_uri = callback_uri.decode(encoding) - if isinstance(signature_method, bytes_type): - signature_method = signature_method.decode(encoding) - if isinstance(signature_type, bytes_type): - signature_type = signature_type.decode(encoding) - if isinstance(rsa_key, bytes_type): - rsa_key = rsa_key.decode(encoding) - if isinstance(verifier, bytes_type): - verifier = verifier.decode(encoding) - if isinstance(realm, bytes_type): - realm = realm.decode(encoding) - if isinstance(nonce, bytes_type): - nonce = nonce.decode(encoding) - if isinstance(timestamp, bytes_type): - timestamp = timestamp.decode(encoding) - - self.client_key = client_key - self.client_secret = client_secret - self.resource_owner_key = resource_owner_key - self.resource_owner_secret = resource_owner_secret - self.signature_method = signature_method - self.signature_type = signature_type - self.callback_uri = callback_uri - self.rsa_key = rsa_key - self.verifier = verifier - self.realm = realm - self.convert_to_unicode = convert_to_unicode - self.encoding = encoding - self.nonce = nonce - self.timestamp = timestamp + encoding='utf-8', nonce=None, timestamp=None): + + # Convert to unicode using encoding if given, else assume unicode + encode = lambda x: to_unicode(x, encoding) if encoding else x + + self.client_key = encode(client_key) + self.client_secret = encode(client_secret) + self.resource_owner_key = encode(resource_owner_key) + self.resource_owner_secret = encode(resource_owner_secret) + self.signature_method = encode(signature_method) + self.signature_type = encode(signature_type) + self.callback_uri = encode(callback_uri) + self.rsa_key = encode(rsa_key) + self.verifier = encode(verifier) + self.realm = encode(realm) + self.encoding = encode(encoding) + self.nonce = encode(nonce) + self.timestamp = encode(timestamp) if self.signature_method == SIGNATURE_RSA and self.rsa_key is None: raise ValueError('rsa_key is required when using RSA signature method.') @@ -137,7 +113,7 @@ class Client(object): """ nonce = (generate_nonce() if self.nonce is None else self.nonce) - timestamp = (generate_timestamp() + timestamp = (generate_timestamp() if self.timestamp is None else self.timestamp) params = [ ('oauth_nonce', nonce), @@ -215,7 +191,6 @@ class Client(object): """ # normalize request data request = Request(uri, http_method, body, headers, - convert_to_unicode=self.convert_to_unicode, encoding=self.encoding) # sanity check diff --git a/tests/oauth1/rfc5849/test_client.py b/tests/oauth1/rfc5849/test_client.py index c84ae5a..ff94f10 100644 --- a/tests/oauth1/rfc5849/test_client.py +++ b/tests/oauth1/rfc5849/test_client.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -from __future__ import absolute_import, unicode_literals +from __future__ import absolute_import, unicode_literals from oauthlib.oauth1.rfc5849 import Client, bytes_type from ...unittest import TestCase @@ -32,9 +32,8 @@ class ClientRealmTests(TestCase): class ClientConstructorTests(TestCase): def test_convert_to_unicode_resource_owner(self): - client = Client('client-key', - resource_owner_key=b'owner key', - convert_to_unicode=True) + client = Client('client-key', + resource_owner_key=b'owner key') self.assertFalse(isinstance(client.resource_owner_key, bytes_type)) self.assertEqual(client.resource_owner_key, 'owner key') @@ -42,7 +41,7 @@ class ClientConstructorTests(TestCase): client = Client('client-key', timestamp='1') params = dict(client.get_oauth_params()) self.assertEqual(params['oauth_timestamp'], '1') - + def test_give_explicit_nonce(self): client = Client('client-key', nonce='1') params = dict(client.get_oauth_params()) diff --git a/tests/test_common.py b/tests/test_common.py index 8ae89a8..08fdf04 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -56,8 +56,7 @@ class CommonTests(TestCase): r = Request(bytes_type('http://a.b/path?query', 'utf-8'), http_method=bytes_type('GET', 'utf-8'), body=bytes_type('you=shall+pass', 'utf-8'), - headers={bytes_type('a', 'utf-8'): bytes_type('b', 'utf-8')}, - convert_to_unicode=True) + headers={bytes_type('a', 'utf-8'): bytes_type('b', 'utf-8')}) self.assertEqual(r.uri, 'http://a.b/path?query') self.assertEqual(r.http_method, 'GET') self.assertEqual(r.body, 'you=shall+pass') |