diff options
author | andy <andy@andymccurdy.com> | 2012-06-10 23:29:41 -0700 |
---|---|---|
committer | andy <andy@andymccurdy.com> | 2012-06-10 23:29:41 -0700 |
commit | 17ade97a064248f309b771c8ccf3a9ea96e79356 (patch) | |
tree | aed1cccf4784ca9f944a5c77d09dcdf392bb68cb | |
parent | d2a7156f7fd108a24d8aa321e2ecd67b3f1f34c6 (diff) | |
parent | a99b650071ff83368d1b3b1d75d8870a682a2390 (diff) | |
download | redis-py-17ade97a064248f309b771c8ccf3a9ea96e79356.tar.gz |
Merge remote-tracking branch 'encoding/2.4.11-fix' into encoding
Conflicts:
redis/connection.py
-rw-r--r-- | redis/client.py | 8 | ||||
-rw-r--r-- | redis/connection.py | 35 | ||||
-rw-r--r-- | tests/__init__.py | 4 | ||||
-rw-r--r-- | tests/encoding.py | 44 | ||||
-rw-r--r-- | tests/server_commands.py | 4 |
5 files changed, 80 insertions, 15 deletions
diff --git a/redis/client.py b/redis/client.py index 839b156..84f43b1 100644 --- a/redis/client.py +++ b/redis/client.py @@ -186,15 +186,17 @@ class StrictRedis(object): def __init__(self, host='localhost', port=6379, db=0, password=None, socket_timeout=None, - connection_pool=None, - charset='utf-8', errors='strict', unix_socket_path=None): + connection_pool=None, charset='utf-8', + errors='strict', decode_responses=False, + unix_socket_path=None): if not connection_pool: kwargs = { 'db': db, 'password': password, 'socket_timeout': socket_timeout, 'encoding': charset, - 'encoding_errors': errors + 'encoding_errors': errors, + 'decode_responses': decode_responses, } # based on input, setup appropriate connection args if unix_socket_path: diff --git a/redis/connection.py b/redis/connection.py index 6df60d8..8036784 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -23,6 +23,7 @@ except ImportError: class PythonParser(object): "Plain Python parsing class" MAX_READ_LENGTH = 1000000 + encoding = None def __init__(self): self._fp = None @@ -36,6 +37,8 @@ class PythonParser(object): def on_connect(self, connection): "Called when the socket connects" self._fp = connection._sock.makefile('rb') + if connection.decode_responses: + self.encoding = connection.encoding def on_disconnect(self): "Called when the socket disconnects" @@ -81,6 +84,9 @@ class PythonParser(object): byte, response = response[0], response[1:] + if byte not in ('-', '+', ':', '$', '*'): + raise InvalidResponse("Protocol Error") + # server returned an error if byte == '-': if response.startswith('ERR '): @@ -92,24 +98,25 @@ class PythonParser(object): raise ConnectionError("Redis is loading data into memory") # single value elif byte == '+': - return response + pass # int value elif byte == ':': - return long(response) + response = long(response) # bulk response elif byte == '$': length = int(response) if length == -1: return None response = self.read(length) - return response # multi-bulk response elif byte == '*': length = int(response) if length == -1: return None - return [self.read_response() for i in xrange(length)] - raise InvalidResponse("Protocol Error") + response = [self.read_response() for i in xrange(length)] + if isinstance(response, str) and self.encoding: + response = response.decode(self.encoding) + return response class HiredisParser(object): "Parser class for connections using Hiredis" @@ -125,9 +132,13 @@ class HiredisParser(object): def on_connect(self, connection): self._sock = connection._sock - self._reader = hiredis.Reader( - protocolError=InvalidResponse, - replyError=ResponseError) + kwargs = { + 'protocolError': InvalidResponse, + 'replyError': ResponseError, + } + if connection.decode_responses: + kwargs['encoding'] = connection.encoding + self._reader = hiredis.Reader(**kwargs) def on_disconnect(self): self._sock = None @@ -163,7 +174,8 @@ class Connection(object): "Manages TCP communication to and from a Redis server" def __init__(self, host='localhost', port=6379, db=0, password=None, socket_timeout=None, encoding='utf-8', - encoding_errors='strict', parser_class=DefaultParser): + encoding_errors='strict', decode_responses=False, + parser_class=DefaultParser): self.pid = os.getpid() self.host = host self.port = port @@ -172,6 +184,7 @@ class Connection(object): self.socket_timeout = socket_timeout self.encoding = encoding self.encoding_errors = encoding_errors + self.decode_responses = decode_responses self._sock = None self._parser = parser_class() @@ -285,7 +298,8 @@ class Connection(object): class UnixDomainSocketConnection(Connection): def __init__(self, path='', db=0, password=None, socket_timeout=None, encoding='utf-8', - encoding_errors='strict', parser_class=DefaultParser): + encoding_errors='strict', decode_responses=False, + parser_class=DefaultParser): self.pid = os.getpid() self.path = path self.db = db @@ -293,6 +307,7 @@ class UnixDomainSocketConnection(Connection): self.socket_timeout = socket_timeout self.encoding = encoding self.encoding_errors = encoding_errors + self.decode_responses = decode_responses self._sock = None self._parser = parser_class() diff --git a/tests/__init__.py b/tests/__init__.py index f7db7b1..56227ad 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -4,6 +4,7 @@ from connection_pool import ConnectionPoolTestCase from pipeline import PipelineTestCase from lock import LockTestCase from pubsub import PubSubTestCase, PubSubRedisDownTestCase +from encoding import PythonParserEncodingTestCase, HiredisEncodingTestCase use_hiredis = False try: @@ -20,4 +21,7 @@ def all_tests(): suite.addTest(unittest.makeSuite(LockTestCase)) suite.addTest(unittest.makeSuite(PubSubTestCase)) suite.addTest(unittest.makeSuite(PubSubRedisDownTestCase)) + suite.addTest(unittest.makeSuite(PythonParserEncodingTestCase)) + if use_hiredis: + suite.addTest(unittest.makeSuite(HiredisEncodingTestCase)) return suite diff --git a/tests/encoding.py b/tests/encoding.py new file mode 100644 index 0000000..245a6df --- /dev/null +++ b/tests/encoding.py @@ -0,0 +1,44 @@ +from __future__ import with_statement +import redis +from redis.connection import ConnectionPool, PythonParser, HiredisParser +import unittest + +class EncodingTestCase(unittest.TestCase): + def setUp(self): + self.client = redis.Redis(host='localhost', port=6379, db=9, + charset='utf-8') + self.client.flushdb() + + def tearDown(self): + self.client.flushdb() + + def test_simple_encoding(self): + unicode_string = unichr(3456) + u'abcd' + unichr(3421) + self.client.set('unicode-string', unicode_string) + cached_val = self.client.get('unicode-string') + self.assertEquals('unicode', type(cached_val).__name__, + 'Cache returned value with type "%s", expected "unicode"' \ + % type(cached_val).__name__ + ) + self.assertEqual(unicode_string, cached_val) + + def test_list_encoding(self): + unicode_string = unichr(3456) + u'abcd' + unichr(3421) + result = [unicode_string, unicode_string, unicode_string] + for i in range(len(result)): + self.client.rpush('a', unicode_string) + self.assertEquals(self.client.lrange('a', 0, -1), result) + +class PythonParserEncodingTestCase(EncodingTestCase): + def setUp(self): + pool = ConnectionPool(host='localhost', port=6379, db=9, + encoding='utf-8', decode_responses=True, parser_class=PythonParser) + self.client = redis.Redis(connection_pool=pool) + self.client.flushdb() + +class HiredisEncodingTestCase(EncodingTestCase): + def setUp(self): + pool = ConnectionPool(host='localhost', port=6379, db=9, + encoding='utf-8', decode_responses=True, parser_class=HiredisParser) + self.client = redis.Redis(connection_pool=pool) + self.client.flushdb() diff --git a/tests/server_commands.py b/tests/server_commands.py index 3302641..eaae2c2 100644 --- a/tests/server_commands.py +++ b/tests/server_commands.py @@ -48,8 +48,8 @@ class ServerCommandsTestCase(unittest.TestCase): self.assertEquals(self.client.get('byte_string'), byte_string) self.assertEquals(self.client.get('integer'), str(integer)) self.assertEquals( - self.client.get('unicode_string').decode('utf-8'), - unicode_string) + self.client.get('unicode_string').decode('utf-8'), + unicode_string) def test_getitem_and_setitem(self): self.client['a'] = 'bar' |