summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorandy <andy@andymccurdy.com>2012-06-10 23:29:41 -0700
committerandy <andy@andymccurdy.com>2012-06-10 23:29:41 -0700
commit17ade97a064248f309b771c8ccf3a9ea96e79356 (patch)
treeaed1cccf4784ca9f944a5c77d09dcdf392bb68cb
parentd2a7156f7fd108a24d8aa321e2ecd67b3f1f34c6 (diff)
parenta99b650071ff83368d1b3b1d75d8870a682a2390 (diff)
downloadredis-py-17ade97a064248f309b771c8ccf3a9ea96e79356.tar.gz
Merge remote-tracking branch 'encoding/2.4.11-fix' into encoding
Conflicts: redis/connection.py
-rw-r--r--redis/client.py8
-rw-r--r--redis/connection.py35
-rw-r--r--tests/__init__.py4
-rw-r--r--tests/encoding.py44
-rw-r--r--tests/server_commands.py4
5 files changed, 80 insertions, 15 deletions
diff --git a/redis/client.py b/redis/client.py
index 839b156..84f43b1 100644
--- a/redis/client.py
+++ b/redis/client.py
@@ -186,15 +186,17 @@ class StrictRedis(object):
def __init__(self, host='localhost', port=6379,
db=0, password=None, socket_timeout=None,
- connection_pool=None,
- charset='utf-8', errors='strict', unix_socket_path=None):
+ connection_pool=None, charset='utf-8',
+ errors='strict', decode_responses=False,
+ unix_socket_path=None):
if not connection_pool:
kwargs = {
'db': db,
'password': password,
'socket_timeout': socket_timeout,
'encoding': charset,
- 'encoding_errors': errors
+ 'encoding_errors': errors,
+ 'decode_responses': decode_responses,
}
# based on input, setup appropriate connection args
if unix_socket_path:
diff --git a/redis/connection.py b/redis/connection.py
index 6df60d8..8036784 100644
--- a/redis/connection.py
+++ b/redis/connection.py
@@ -23,6 +23,7 @@ except ImportError:
class PythonParser(object):
"Plain Python parsing class"
MAX_READ_LENGTH = 1000000
+ encoding = None
def __init__(self):
self._fp = None
@@ -36,6 +37,8 @@ class PythonParser(object):
def on_connect(self, connection):
"Called when the socket connects"
self._fp = connection._sock.makefile('rb')
+ if connection.decode_responses:
+ self.encoding = connection.encoding
def on_disconnect(self):
"Called when the socket disconnects"
@@ -81,6 +84,9 @@ class PythonParser(object):
byte, response = response[0], response[1:]
+ if byte not in ('-', '+', ':', '$', '*'):
+ raise InvalidResponse("Protocol Error")
+
# server returned an error
if byte == '-':
if response.startswith('ERR '):
@@ -92,24 +98,25 @@ class PythonParser(object):
raise ConnectionError("Redis is loading data into memory")
# single value
elif byte == '+':
- return response
+ pass
# int value
elif byte == ':':
- return long(response)
+ response = long(response)
# bulk response
elif byte == '$':
length = int(response)
if length == -1:
return None
response = self.read(length)
- return response
# multi-bulk response
elif byte == '*':
length = int(response)
if length == -1:
return None
- return [self.read_response() for i in xrange(length)]
- raise InvalidResponse("Protocol Error")
+ response = [self.read_response() for i in xrange(length)]
+ if isinstance(response, str) and self.encoding:
+ response = response.decode(self.encoding)
+ return response
class HiredisParser(object):
"Parser class for connections using Hiredis"
@@ -125,9 +132,13 @@ class HiredisParser(object):
def on_connect(self, connection):
self._sock = connection._sock
- self._reader = hiredis.Reader(
- protocolError=InvalidResponse,
- replyError=ResponseError)
+ kwargs = {
+ 'protocolError': InvalidResponse,
+ 'replyError': ResponseError,
+ }
+ if connection.decode_responses:
+ kwargs['encoding'] = connection.encoding
+ self._reader = hiredis.Reader(**kwargs)
def on_disconnect(self):
self._sock = None
@@ -163,7 +174,8 @@ class Connection(object):
"Manages TCP communication to and from a Redis server"
def __init__(self, host='localhost', port=6379, db=0, password=None,
socket_timeout=None, encoding='utf-8',
- encoding_errors='strict', parser_class=DefaultParser):
+ encoding_errors='strict', decode_responses=False,
+ parser_class=DefaultParser):
self.pid = os.getpid()
self.host = host
self.port = port
@@ -172,6 +184,7 @@ class Connection(object):
self.socket_timeout = socket_timeout
self.encoding = encoding
self.encoding_errors = encoding_errors
+ self.decode_responses = decode_responses
self._sock = None
self._parser = parser_class()
@@ -285,7 +298,8 @@ class Connection(object):
class UnixDomainSocketConnection(Connection):
def __init__(self, path='', db=0, password=None,
socket_timeout=None, encoding='utf-8',
- encoding_errors='strict', parser_class=DefaultParser):
+ encoding_errors='strict', decode_responses=False,
+ parser_class=DefaultParser):
self.pid = os.getpid()
self.path = path
self.db = db
@@ -293,6 +307,7 @@ class UnixDomainSocketConnection(Connection):
self.socket_timeout = socket_timeout
self.encoding = encoding
self.encoding_errors = encoding_errors
+ self.decode_responses = decode_responses
self._sock = None
self._parser = parser_class()
diff --git a/tests/__init__.py b/tests/__init__.py
index f7db7b1..56227ad 100644
--- a/tests/__init__.py
+++ b/tests/__init__.py
@@ -4,6 +4,7 @@ from connection_pool import ConnectionPoolTestCase
from pipeline import PipelineTestCase
from lock import LockTestCase
from pubsub import PubSubTestCase, PubSubRedisDownTestCase
+from encoding import PythonParserEncodingTestCase, HiredisEncodingTestCase
use_hiredis = False
try:
@@ -20,4 +21,7 @@ def all_tests():
suite.addTest(unittest.makeSuite(LockTestCase))
suite.addTest(unittest.makeSuite(PubSubTestCase))
suite.addTest(unittest.makeSuite(PubSubRedisDownTestCase))
+ suite.addTest(unittest.makeSuite(PythonParserEncodingTestCase))
+ if use_hiredis:
+ suite.addTest(unittest.makeSuite(HiredisEncodingTestCase))
return suite
diff --git a/tests/encoding.py b/tests/encoding.py
new file mode 100644
index 0000000..245a6df
--- /dev/null
+++ b/tests/encoding.py
@@ -0,0 +1,44 @@
+from __future__ import with_statement
+import redis
+from redis.connection import ConnectionPool, PythonParser, HiredisParser
+import unittest
+
+class EncodingTestCase(unittest.TestCase):
+ def setUp(self):
+ self.client = redis.Redis(host='localhost', port=6379, db=9,
+ charset='utf-8')
+ self.client.flushdb()
+
+ def tearDown(self):
+ self.client.flushdb()
+
+ def test_simple_encoding(self):
+ unicode_string = unichr(3456) + u'abcd' + unichr(3421)
+ self.client.set('unicode-string', unicode_string)
+ cached_val = self.client.get('unicode-string')
+ self.assertEquals('unicode', type(cached_val).__name__,
+ 'Cache returned value with type "%s", expected "unicode"' \
+ % type(cached_val).__name__
+ )
+ self.assertEqual(unicode_string, cached_val)
+
+ def test_list_encoding(self):
+ unicode_string = unichr(3456) + u'abcd' + unichr(3421)
+ result = [unicode_string, unicode_string, unicode_string]
+ for i in range(len(result)):
+ self.client.rpush('a', unicode_string)
+ self.assertEquals(self.client.lrange('a', 0, -1), result)
+
+class PythonParserEncodingTestCase(EncodingTestCase):
+ def setUp(self):
+ pool = ConnectionPool(host='localhost', port=6379, db=9,
+ encoding='utf-8', decode_responses=True, parser_class=PythonParser)
+ self.client = redis.Redis(connection_pool=pool)
+ self.client.flushdb()
+
+class HiredisEncodingTestCase(EncodingTestCase):
+ def setUp(self):
+ pool = ConnectionPool(host='localhost', port=6379, db=9,
+ encoding='utf-8', decode_responses=True, parser_class=HiredisParser)
+ self.client = redis.Redis(connection_pool=pool)
+ self.client.flushdb()
diff --git a/tests/server_commands.py b/tests/server_commands.py
index 3302641..eaae2c2 100644
--- a/tests/server_commands.py
+++ b/tests/server_commands.py
@@ -48,8 +48,8 @@ class ServerCommandsTestCase(unittest.TestCase):
self.assertEquals(self.client.get('byte_string'), byte_string)
self.assertEquals(self.client.get('integer'), str(integer))
self.assertEquals(
- self.client.get('unicode_string').decode('utf-8'),
- unicode_string)
+ self.client.get('unicode_string').decode('utf-8'),
+ unicode_string)
def test_getitem_and_setitem(self):
self.client['a'] = 'bar'