diff options
author | Alex Grönholm <alex.gronholm+git@nextday.fi> | 2012-08-07 00:15:53 +0300 |
---|---|---|
committer | Alex Grönholm <alex.gronholm+git@nextday.fi> | 2012-08-07 00:15:53 +0300 |
commit | 0aaa404f3525641fe5bb3b50ca303d05a326cdce (patch) | |
tree | cd9bb0d8c37b01a98a5a3891891061605d4c30e2 | |
parent | 826364f3bb19abca0e1e6a92266abb017f5481f4 (diff) | |
download | redis-py-0aaa404f3525641fe5bb3b50ca303d05a326cdce.tar.gz |
Fixed Python 3.2+ compatibility
-rw-r--r-- | redis/_compat.py | 48 | ||||
-rw-r--r-- | redis/client.py | 29 | ||||
-rw-r--r-- | redis/connection.py | 59 | ||||
-rw-r--r-- | redis/utils.py | 5 | ||||
-rw-r--r-- | tests/__init__.py | 12 | ||||
-rw-r--r-- | tests/encoding.py | 11 | ||||
-rw-r--r-- | tests/lock.py | 4 | ||||
-rw-r--r-- | tests/pipeline.py | 2 | ||||
-rw-r--r-- | tests/pubsub.py | 17 | ||||
-rw-r--r-- | tests/server_commands.py | 45 |
10 files changed, 148 insertions, 84 deletions
diff --git a/redis/_compat.py b/redis/_compat.py new file mode 100644 index 0000000..2dbf62e --- /dev/null +++ b/redis/_compat.py @@ -0,0 +1,48 @@ +"""Internal module for Python 2 backwards compatibility.""" +import sys + + +if sys.version_info[0] < 3: + from urlparse import urlparse + from itertools import imap, izip + from string import letters as ascii_letters + try: + from cStringIO import StringIO as BytesIO + except ImportError: + from StringIO import StringIO as BytesIO + + iteritems = lambda x: x.iteritems() + dictkeys = lambda x: x.keys() + dictvalues = lambda x: x.values() + nativestr = lambda x: x if isinstance(x, str) else x.encode(errors='replace') + u = lambda x: x.decode() + b = lambda x: x + next = lambda x: x.next() + byte_to_chr = lambda x: x + unichr = unichr + xrange = xrange + basestring = basestring + unicode = unicode + bytes = str + long = long +else: + from urllib.parse import urlparse + from io import BytesIO + from string import ascii_letters + + iteritems = lambda x: x.items() + dictkeys = lambda x: list(x.keys()) + dictvalues = lambda x: list(x.values()) + byte_to_chr = lambda x: chr(x) + nativestr = lambda x: x if isinstance(x, str) else x.decode(errors='replace') + u = lambda x: x + b = lambda x: x.encode('iso-8859-1') + next = next + unichr = chr + imap = map + izip = zip + xrange = range + basestring = str + unicode = str + bytes = bytes + long = int diff --git a/redis/client.py b/redis/client.py index a40fb93..a4e6854 100644 --- a/redis/client.py +++ b/redis/client.py @@ -1,9 +1,10 @@ from __future__ import with_statement +from itertools import starmap import datetime import time import warnings -from itertools import imap, izip, starmap +from redis._compat import b, izip, imap, iteritems, dictkeys, dictvalues, basestring, long, nativestr from redis.connection import ConnectionPool, UnixDomainSocketConnection from redis.exceptions import ( ConnectionError, @@ -119,7 +120,7 @@ def zset_score_pairs(response, **options): return response score_cast_func = options.get('score_cast_func', float) it = iter(response) - return zip(it, imap(score_cast_func, it)) + return list(izip(it, imap(score_cast_func, it))) def int_or_none(response): @@ -166,13 +167,13 @@ class StrictRedis(object): string_keys_to_dict( # these return OK, or int if redis-server is >=1.3.4 'LPUSH RPUSH', - lambda r: isinstance(r, long) and r or r == 'OK' + lambda r: isinstance(r, long) and r or nativestr(r) == 'OK' ), string_keys_to_dict('ZSCORE ZINCRBY', float_or_none), string_keys_to_dict( 'FLUSHALL FLUSHDB LSET LTRIM MSET RENAME ' 'SAVE SELECT SET SHUTDOWN SLAVEOF WATCH UNWATCH', - lambda r: r == 'OK' + lambda r: nativestr(r) == 'OK' ), string_keys_to_dict('BLPOP BRPOP', lambda r: r and tuple(r) or None), string_keys_to_dict( @@ -196,7 +197,7 @@ class StrictRedis(object): 'INFO': parse_info, 'LASTSAVE': timestamp_to_datetime, 'OBJECT': parse_object, - 'PING': lambda r: r == 'PONG', + 'PING': lambda r: nativestr(r) == 'PONG', 'RANDOMKEY': lambda r: r and r or None, } ) @@ -495,7 +496,7 @@ class StrictRedis(object): def mset(self, mapping): "Sets each key in the ``mapping`` dict to its corresponding value" items = [] - for pair in mapping.iteritems(): + for pair in iteritems(mapping): items.extend(pair) return self.execute_command('MSET', *items) @@ -505,7 +506,7 @@ class StrictRedis(object): none of the keys are already set """ items = [] - for pair in mapping.iteritems(): + for pair in iteritems(mapping): items.extend(pair) return self.execute_command('MSETNX', *items) @@ -891,7 +892,7 @@ class StrictRedis(object): raise RedisError("ZADD requires an equal number of " "values and scores") pieces.extend(args) - for pair in kwargs.iteritems(): + for pair in iteritems(kwargs): pieces.append(pair[1]) pieces.append(pair[0]) return self.execute_command('ZADD', name, *pieces) @@ -1061,7 +1062,7 @@ class StrictRedis(object): def _zaggregate(self, command, dest, keys, aggregate=None): pieces = [command, dest, len(keys)] if isinstance(keys, dict): - keys, weights = keys.keys(), keys.values() + keys, weights = dictkeys(keys), dictvalues(keys) else: weights = None pieces.extend(keys) @@ -1124,7 +1125,7 @@ class StrictRedis(object): if not mapping: raise DataError("'hmset' with 'mapping' of length 0") items = [] - for pair in mapping.iteritems(): + for pair in iteritems(mapping): items.extend(pair) return self.execute_command('HMSET', name, *items) @@ -1222,7 +1223,7 @@ class Redis(StrictRedis): raise RedisError("ZADD requires an equal number of " "values and scores") pieces.extend(reversed(args)) - for pair in kwargs.iteritems(): + for pair in iteritems(kwargs): pieces.append(pair[1]) pieces.append(pair[0]) return self.execute_command('ZADD', name, *pieces) @@ -1498,7 +1499,7 @@ class BasePipeline(object): return self def _execute_transaction(self, connection, commands): - all_cmds = ''.join(starmap(connection.pack_command, + all_cmds = b('').join(starmap(connection.pack_command, [args for args, options in commands])) connection.send_packed_command(all_cmds) # we don't care about the multi/exec any longer @@ -1530,8 +1531,8 @@ class BasePipeline(object): def _execute_pipeline(self, connection, commands): # build up all commands into a single request to increase network perf - all_cmds = ''.join(starmap(connection.pack_command, - [args for args, options in commands])) + all_cmds = b('').join(starmap(connection.pack_command, + [args for args, options in commands])) connection.send_packed_command(all_cmds) return [self.parse_response(connection, args[0], **options) for args, options in commands] diff --git a/redis/connection.py b/redis/connection.py index 4bb82f3..8aecc54 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -1,7 +1,9 @@ +from itertools import chain import os import socket -from itertools import chain, imap +import sys +from redis._compat import b, xrange, imap, byte_to_chr, unicode, bytes, long, BytesIO, nativestr from redis.exceptions import ( RedisError, ConnectionError, @@ -11,11 +13,6 @@ from redis.exceptions import ( ) try: - from cStringIO import StringIO -except ImportError: - from StringIO import StringIO - -try: import hiredis hiredis_available = True except ImportError: @@ -62,7 +59,7 @@ class PythonParser(object): # https://github.com/andymccurdy/redis-py/issues/205 # read smaller chunks at a time to work around this try: - buf = StringIO() + buf = BytesIO() while bytes_left > 0: read_len = min(bytes_left, self.MAX_READ_LENGTH) buf.write(self._fp.read(read_len)) @@ -75,7 +72,8 @@ class PythonParser(object): # no length, read a full line return self._fp.readline()[:-2] - except (socket.error, socket.timeout), e: + except (socket.error, socket.timeout): + e = sys.exc_info()[1] raise ConnectionError("Error while reading from socket: %s" % (e.args,)) @@ -84,17 +82,17 @@ class PythonParser(object): if not response: raise ConnectionError("Socket closed on remote end") - byte, response = response[0], response[1:] + byte, response = byte_to_chr(response[0]), response[1:] if byte not in ('-', '+', ':', '$', '*'): raise InvalidResponse("Protocol Error") # server returned an error if byte == '-': - if response.startswith('ERR '): + if nativestr(response).startswith('ERR '): response = response[4:] return ResponseError(response) - if response.startswith('LOADING '): + if nativestr(response).startswith('LOADING '): # If we're loading the dataset into memory, kill the socket # so we re-initialize (and re-SELECT) next time. raise ConnectionError("Redis is loading data into memory") @@ -116,7 +114,7 @@ class PythonParser(object): if length == -1: return None response = [self.read_response() for i in xrange(length)] - if isinstance(response, str) and self.encoding: + if isinstance(response, bytes) and self.encoding: response = response.decode(self.encoding) return response @@ -154,7 +152,8 @@ class HiredisParser(object): while response is False: try: buffer = self._sock.recv(4096) - except (socket.error, socket.timeout), e: + except (socket.error, socket.timeout): + e = sys.exc_info()[1] raise ConnectionError("Error while reading from socket: %s" % (e.args,)) if not buffer: @@ -162,7 +161,7 @@ class HiredisParser(object): self._reader.feed(buffer) # proactively, but not conclusively, check if more data is in the # buffer. if the data received doesn't end with \n, there's more. - if not buffer.endswith('\n'): + if not buffer.endswith(b('\n')): continue response = self._reader.gets() return response @@ -203,7 +202,8 @@ class Connection(object): return try: sock = self._connect() - except socket.error, e: + except socket.error: + e = sys.exc_info()[1] raise ConnectionError(self._error_message(e)) self._sock = sock @@ -233,13 +233,13 @@ class Connection(object): # if a password is specified, authenticate if self.password: self.send_command('AUTH', self.password) - if self.read_response() != 'OK': + if nativestr(self.read_response()) != 'OK': raise AuthenticationError('Invalid Password') # if a database is specified, switch to it if self.db: self.send_command('SELECT', self.db) - if self.read_response() != 'OK': + if nativestr(self.read_response()) != 'OK': raise ConnectionError('Invalid Database') def disconnect(self): @@ -259,7 +259,8 @@ class Connection(object): self.connect() try: self._sock.sendall(command) - except socket.error, e: + except socket.error: + e = sys.exc_info()[1] self.disconnect() if len(e.args) == 1: _errno, errmsg = 'UNKNOWN', e.args[0] @@ -288,15 +289,27 @@ class Connection(object): def encode(self, value): "Return a bytestring representation of the value" + if isinstance(value, bytes): + return value + if not isinstance(value, unicode): + value = str(value) if isinstance(value, unicode): - return value.encode(self.encoding, self.encoding_errors) - return str(value) + value = value.encode(self.encoding, self.encoding_errors) + return value def pack_command(self, *args): "Pack a series of arguments into a value Redis command" - command = ['$%s\r\n%s\r\n' % (len(enc_value), enc_value) - for enc_value in imap(self.encode, args)] - return '*%s\r\n%s' % (len(command), ''.join(command)) + output = BytesIO() + output.write(b('*')) + output.write(b(str(len(args)))) + output.write(b('\r\n')) + for enc_value in imap(self.encode, args): + output.write(b('$')) + output.write(b(str(len(enc_value)))) + output.write(b('\r\n')) + output.write(enc_value) + output.write(b('\r\n')) + return output.getvalue() class UnixDomainSocketConnection(Connection): diff --git a/redis/utils.py b/redis/utils.py index 466509b..a9fe83b 100644 --- a/redis/utils.py +++ b/redis/utils.py @@ -1,6 +1,5 @@ -from urlparse import urlparse - -from .client import Redis +from redis.client import Redis +from redis._compat import urlparse DEFAULT_DATABASE_ID = 0 diff --git a/tests/__init__.py b/tests/__init__.py index 777e77a..fb73cde 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,11 +1,11 @@ import unittest -from server_commands import ServerCommandsTestCase -from connection_pool import ConnectionPoolTestCase -from pipeline import PipelineTestCase -from lock import LockTestCase -from pubsub import PubSubTestCase, PubSubRedisDownTestCase -from encoding import PythonParserEncodingTestCase, HiredisEncodingTestCase +from tests.server_commands import ServerCommandsTestCase +from tests.connection_pool import ConnectionPoolTestCase +from tests.pipeline import PipelineTestCase +from tests.lock import LockTestCase +from tests.pubsub import PubSubTestCase, PubSubRedisDownTestCase +from tests.encoding import PythonParserEncodingTestCase, HiredisEncodingTestCase try: import hiredis diff --git a/tests/encoding.py b/tests/encoding.py index 248e05e..8923f68 100644 --- a/tests/encoding.py +++ b/tests/encoding.py @@ -1,6 +1,7 @@ from __future__ import with_statement import unittest +from redis._compat import unichr, u, unicode from redis.connection import ConnectionPool, PythonParser, HiredisParser import redis @@ -15,17 +16,17 @@ class EncodingTestCase(unittest.TestCase): self.client.flushdb() def test_simple_encoding(self): - unicode_string = unichr(3456) + u'abcd' + unichr(3421) + unicode_string = unichr(3456) + u('abcd') + unichr(3421) self.client.set('unicode-string', unicode_string) cached_val = self.client.get('unicode-string') self.assertEquals( - 'unicode', type(cached_val).__name__, - 'Cache returned value with type "%s", expected "unicode"' % - type(cached_val).__name__) + unicode.__name__, type(cached_val).__name__, + 'Cache returned value with type "%s", expected "%s"' % + (type(cached_val).__name__, unicode.__name__)) self.assertEqual(unicode_string, cached_val) def test_list_encoding(self): - unicode_string = unichr(3456) + u'abcd' + unichr(3421) + unicode_string = unichr(3456) + u('abcd') + unichr(3421) result = [unicode_string, unicode_string, unicode_string] for i in range(len(result)): self.client.rpush('a', unicode_string) diff --git a/tests/lock.py b/tests/lock.py index 9a9a397..fc2d336 100644 --- a/tests/lock.py +++ b/tests/lock.py @@ -17,7 +17,7 @@ class LockTestCase(unittest.TestCase): def test_lock(self): lock = self.client.lock('foo') self.assert_(lock.acquire()) - self.assertEquals(self.client['foo'], str(Lock.LOCK_FOREVER)) + self.assertEquals(self.client['foo'], str(Lock.LOCK_FOREVER).encode()) lock.release() self.assertEquals(self.client.get('foo'), None) @@ -51,7 +51,7 @@ class LockTestCase(unittest.TestCase): def test_context_manager(self): with self.client.lock('foo'): - self.assertEquals(self.client['foo'], str(Lock.LOCK_FOREVER)) + self.assertEquals(self.client['foo'], str(Lock.LOCK_FOREVER).encode()) self.assertEquals(self.client.get('foo'), None) def test_float_timeout(self): diff --git a/tests/pipeline.py b/tests/pipeline.py index e20e48a..42ddad6 100644 --- a/tests/pipeline.py +++ b/tests/pipeline.py @@ -6,7 +6,7 @@ import redis class PipelineTestCase(unittest.TestCase): def setUp(self): - self.client = redis.Redis(host='localhost', port=6379, db=9) + self.client = redis.Redis(host='localhost', port=6379, db=9, decode_responses=True) self.client.flushdb() def tearDown(self): diff --git a/tests/pubsub.py b/tests/pubsub.py index 090297f..90bf619 100644 --- a/tests/pubsub.py +++ b/tests/pubsub.py @@ -1,12 +1,13 @@ import unittest +from redis._compat import next from redis.exceptions import ConnectionError import redis class PubSubTestCase(unittest.TestCase): def setUp(self): - self.connection_pool = redis.ConnectionPool() + self.connection_pool = redis.ConnectionPool(decode_responses=True) self.client = redis.Redis(connection_pool=self.connection_pool) self.pubsub = self.client.pubsub() @@ -24,7 +25,7 @@ class PubSubTestCase(unittest.TestCase): # there should be now 2 messages in the buffer, a subscribe and the # one we just published self.assertEquals( - self.pubsub.listen().next(), + next(self.pubsub.listen()), { 'type': 'subscribe', 'pattern': None, @@ -33,7 +34,7 @@ class PubSubTestCase(unittest.TestCase): } ) self.assertEquals( - self.pubsub.listen().next(), + next(self.pubsub.listen()), { 'type': 'message', 'pattern': None, @@ -49,7 +50,7 @@ class PubSubTestCase(unittest.TestCase): ) # unsubscribe message should be in the buffer self.assertEquals( - self.pubsub.listen().next(), + next(self.pubsub.listen()), { 'type': 'unsubscribe', 'pattern': None, @@ -69,7 +70,7 @@ class PubSubTestCase(unittest.TestCase): # there should be now 2 messages in the buffer, a subscribe and the # one we just published self.assertEquals( - self.pubsub.listen().next(), + next(self.pubsub.listen()), { 'type': 'psubscribe', 'pattern': None, @@ -78,7 +79,7 @@ class PubSubTestCase(unittest.TestCase): } ) self.assertEquals( - self.pubsub.listen().next(), + next(self.pubsub.listen()), { 'type': 'pmessage', 'pattern': 'f*', @@ -94,7 +95,7 @@ class PubSubTestCase(unittest.TestCase): ) # unsubscribe message should be in the buffer self.assertEquals( - self.pubsub.listen().next(), + next(self.pubsub.listen()), { 'type': 'punsubscribe', 'pattern': None, @@ -107,7 +108,7 @@ class PubSubTestCase(unittest.TestCase): class PubSubRedisDownTestCase(unittest.TestCase): def setUp(self): self.connection_pool = redis.ConnectionPool(port=6390) - self.client = redis.Redis(connection_pool=self.connection_pool) + self.client = redis.Redis(connection_pool=self.connection_pool, decode_responses=True) self.pubsub = self.client.pubsub() def tearDown(self): diff --git a/tests/server_commands.py b/tests/server_commands.py index bd935b9..cbc802b 100644 --- a/tests/server_commands.py +++ b/tests/server_commands.py @@ -1,16 +1,16 @@ -from string import letters from distutils.version import StrictVersion import unittest import datetime import time +from redis._compat import unichr, u, b, ascii_letters, iteritems, dictkeys, dictvalues from redis.client import parse_info import redis class ServerCommandsTestCase(unittest.TestCase): def get_client(self, cls=redis.Redis): - return cls(host='localhost', port=6379, db=9) + return cls(host='localhost', port=6379, db=9, decode_responses=True) def setUp(self): self.client = self.get_client() @@ -39,17 +39,18 @@ class ServerCommandsTestCase(unittest.TestCase): def test_get_and_set(self): # get and set can't be tested independently of each other - self.assertEquals(self.client.get('a'), None) - byte_string = 'value' + client = redis.Redis(host='localhost', port=6379, db=9) + self.assertEquals(client.get('a'), None) + byte_string = b('value') integer = 5 - unicode_string = unichr(3456) + u'abcd' + unichr(3421) - self.assert_(self.client.set('byte_string', byte_string)) - self.assert_(self.client.set('integer', 5)) - self.assert_(self.client.set('unicode_string', unicode_string)) - self.assertEquals(self.client.get('byte_string'), byte_string) - self.assertEquals(self.client.get('integer'), str(integer)) + unicode_string = unichr(3456) + u('abcd') + unichr(3421) + self.assert_(client.set('byte_string', byte_string)) + self.assert_(client.set('integer', 5)) + self.assert_(client.set('unicode_string', unicode_string)) + self.assertEquals(client.get('byte_string'), byte_string) + self.assertEquals(client.get('integer'), b(str(integer))) self.assertEquals( - self.client.get('unicode_string').decode('utf-8'), + client.get('unicode_string').decode('utf-8'), unicode_string) def test_getitem_and_setitem(self): @@ -211,7 +212,7 @@ class ServerCommandsTestCase(unittest.TestCase): def test_mset(self): d = {'a': '1', 'b': '2', 'c': '3'} self.assert_(self.client.mset(d)) - for k, v in d.iteritems(): + for k, v in iteritems(d): self.assertEquals(self.client[k], v) def test_msetnx(self): @@ -219,7 +220,7 @@ class ServerCommandsTestCase(unittest.TestCase): self.assert_(self.client.msetnx(d)) d2 = {'a': 'x', 'd': '4'} self.assert_(not self.client.msetnx(d2)) - for k, v in d.iteritems(): + for k, v in iteritems(d): self.assertEquals(self.client[k], v) self.assertEquals(self.client.get('d'), None) @@ -990,7 +991,7 @@ class ServerCommandsTestCase(unittest.TestCase): # HASHES def make_hash(self, key, d): - for k, v in d.iteritems(): + for k, v in iteritems(d): self.client.hset(key, k, v) def test_hget_and_hset(self): @@ -1097,7 +1098,7 @@ class ServerCommandsTestCase(unittest.TestCase): self.assertEquals(self.client.hincrby('a', 'a1'), 2) self.assertEquals(self.client.hincrby('a', 'a1', amount=2), 4) # negative values decrement - self.assertEquals(self.client.hincrby('a', 'a1', amount=-3), 1) + self.assertEquals(self.client.hincrby('a', 'a1', amount= -3), 1) # hash that exists, but key that doesn't self.assertEquals(self.client.hincrby('a', 'a2', amount=3), 3) # finally a key that's not an int @@ -1114,7 +1115,7 @@ class ServerCommandsTestCase(unittest.TestCase): # real logic h = {'a1': '1', 'a2': '2', 'a3': '3'} self.make_hash('a', h) - keys = h.keys() + keys = dictkeys(h) keys.sort() remote_keys = self.client.hkeys('a') remote_keys.sort() @@ -1143,7 +1144,7 @@ class ServerCommandsTestCase(unittest.TestCase): # real logic h = {'a1': '1', 'a2': '2', 'a3': '3'} self.make_hash('a', h) - vals = h.values() + vals = dictvalues(h) vals.sort() remote_vals = self.client.hvals('a') remote_vals.sort() @@ -1296,15 +1297,15 @@ class ServerCommandsTestCase(unittest.TestCase): 'foo\tbar\x07': '789', } # fill in lists - for key, value in mapping.iteritems(): + for key, value in iteritems(mapping): for c in value: self.assertTrue(self.client.rpush(key, c)) # check that KEYS returns all the keys as they are - self.assertEqual(sorted(self.client.keys('*')), sorted(mapping.keys())) + self.assertEqual(sorted(self.client.keys('*')), sorted(dictkeys(mapping))) # check that it is possible to get list content by key name - for key in mapping.keys(): + for key in dictkeys(mapping): self.assertEqual(self.client.lrange(key, 0, -1), list(mapping[key])) @@ -1346,8 +1347,8 @@ class ServerCommandsTestCase(unittest.TestCase): "The PythonParser has some special cases for return values > 1MB" # load up 5MB of data into a key data = [] - for i in range(5000000 / len(letters)): - data.append(letters) + for i in range(5000000 // len(ascii_letters)): + data.append(ascii_letters) data = ''.join(data) self.client.set('a', data) self.assertEquals(self.client.get('a'), data) |