summaryrefslogtreecommitdiff
path: root/redis
diff options
context:
space:
mode:
authorAndy McCurdy <andy@andymccurdy.com>2017-08-02 14:06:00 -0700
committerAndy McCurdy <andy@andymccurdy.com>2017-08-02 14:06:00 -0700
commitea4581a6f8386adafde2d6640b50ec7e1aaa673e (patch)
tree472f2f8986627d9dbf7d64f74399c1aad656c2c9 /redis
parent29a7ecebd0f4709fc184c6af57c6eb1a59cb071e (diff)
downloadredis-py-ea4581a6f8386adafde2d6640b50ec7e1aaa673e.tar.gz
add an Encoder object responsible for encoding/decoding bytes and strings
this simplifies multiple places that needs to encode and decode values
Diffstat (limited to 'redis')
-rwxr-xr-xredis/client.py56
-rwxr-xr-xredis/connection.py88
2 files changed, 70 insertions, 74 deletions
diff --git a/redis/client.py b/redis/client.py
index b3fff38..f829656 100755
--- a/redis/client.py
+++ b/redis/client.py
@@ -2339,10 +2339,7 @@ class PubSub(object):
self.connection = None
# we need to know the encoding options for this connection in order
# to lookup channel and pattern names for callback handlers.
- encoding_options = self.connection_pool.get_encoding()
- self.encoding = encoding_options['encoding']
- self.encoding_errors = encoding_options['encoding_errors']
- self.decode_responses = encoding_options['decode_responses']
+ self.encoder = self.connection_pool.get_encoder()
self.reset()
def __del__(self):
@@ -2374,29 +2371,14 @@ class PubSub(object):
if self.channels:
channels = {}
for k, v in iteritems(self.channels):
- if not self.decode_responses:
- k = k.decode(self.encoding, self.encoding_errors)
- channels[k] = v
+ channels[self.encoder.decode(k, force=True)] = v
self.subscribe(**channels)
if self.patterns:
patterns = {}
for k, v in iteritems(self.patterns):
- if not self.decode_responses:
- k = k.decode(self.encoding, self.encoding_errors)
- patterns[k] = v
+ patterns[self.encoder.decode(k, force=True)] = v
self.psubscribe(**patterns)
- def encode(self, value):
- """
- Encode the value so that it's identical to what we'll
- read off the connection
- """
- if self.decode_responses and isinstance(value, bytes):
- value = value.decode(self.encoding, self.encoding_errors)
- elif not self.decode_responses and isinstance(value, unicode):
- value = value.encode(self.encoding, self.encoding_errors)
- return value
-
@property
def subscribed(self):
"Indicates if there are subscriptions to any channels or patterns"
@@ -2446,6 +2428,16 @@ class PubSub(object):
return None
return self._execute(connection, connection.read_response)
+ def _normalize_keys(self, data):
+ """
+ normalize channel/pattern names to be either bytes or strings
+ based on whether responses are automatically decoded. this saves us
+ from coercing the value for each message coming in.
+ """
+ encode = self.encoder.encode
+ decode = self.encoder.decode
+ return dict([(decode(encode(k)), v) for k, v in iteritems(data)])
+
def psubscribe(self, *args, **kwargs):
"""
Subscribe to channel patterns. Patterns supplied as keyword arguments
@@ -2456,15 +2448,13 @@ class PubSub(object):
"""
if args:
args = list_or_args(args[0], args[1:])
- new_patterns = {}
- new_patterns.update(dict.fromkeys(imap(self.encode, args)))
- for pattern, handler in iteritems(kwargs):
- new_patterns[self.encode(pattern)] = handler
+ new_patterns = dict.fromkeys(args)
+ new_patterns.update(kwargs)
ret_val = self.execute_command('PSUBSCRIBE', *iterkeys(new_patterns))
# update the patterns dict AFTER we send the command. we don't want to
# subscribe twice to these patterns, once for the command and again
# for the reconnection.
- self.patterns.update(new_patterns)
+ self.patterns.update(self._normalize_keys(new_patterns))
return ret_val
def punsubscribe(self, *args):
@@ -2486,15 +2476,13 @@ class PubSub(object):
"""
if args:
args = list_or_args(args[0], args[1:])
- new_channels = {}
- new_channels.update(dict.fromkeys(imap(self.encode, args)))
- for channel, handler in iteritems(kwargs):
- new_channels[self.encode(channel)] = handler
+ new_channels = dict.fromkeys(args)
+ new_channels.update(kwargs)
ret_val = self.execute_command('SUBSCRIBE', *iterkeys(new_channels))
# update the channels dict AFTER we send the command. we don't want to
# subscribe twice to these channels, once for the command and again
# for the reconnection.
- self.channels.update(new_channels)
+ self.channels.update(self._normalize_keys(new_channels))
return ret_val
def unsubscribe(self, *args):
@@ -2938,10 +2926,8 @@ class Script(object):
if isinstance(script, basestring):
# We need the encoding from the client in order to generate an
# accurate byte representation of the script
- encoding_options = registered_client.connection_pool.get_encoding()
- encoding = encoding_options['encoding']
- encoding_errors = encoding_options['encoding_errors']
- script = script.encode(encoding, encoding_errors)
+ encoder = registered_client.connection_pool.get_encoder()
+ script = encoder.encode(script)
self.sha = hashlib.sha1(script).hexdigest()
def __call__(self, keys=[], args=[], client=None):
diff --git a/redis/connection.py b/redis/connection.py
index 77eb44d..c8f0513 100755
--- a/redis/connection.py
+++ b/redis/connection.py
@@ -94,6 +94,38 @@ class Token(object):
return self.value
+class Encoder(object):
+ "Encode strings to bytes and decode bytes to strings"
+
+ def __init__(self, encoding, encoding_errors, decode_responses):
+ self.encoding = encoding
+ self.encoding_errors = encoding_errors
+ self.decode_responses = decode_responses
+
+ def encode(self, value):
+ "Return a bytestring representation of the value"
+ if isinstance(value, Token):
+ return value.encoded_value
+ elif isinstance(value, bytes):
+ return value
+ elif isinstance(value, (int, long)):
+ value = b(str(value))
+ elif isinstance(value, float):
+ value = b(repr(value))
+ elif not isinstance(value, basestring):
+ # an object we don't know how to deal with. default to unicode()
+ value = unicode(value)
+ if isinstance(value, unicode):
+ value = value.encode(self.encoding, self.encoding_errors)
+ return value
+
+ def decode(self, value, force=False):
+ "Return a unicode string from the byte representation"
+ if (self.decode_responses or force) and isinstance(value, bytes):
+ value = value.decode(self.encoding, self.encoding_errors)
+ return value
+
+
class BaseParser(object):
EXCEPTION_CLASSES = {
'ERR': {
@@ -217,10 +249,9 @@ class SocketBuffer(object):
class PythonParser(BaseParser):
"Plain Python parsing class"
- encoding = None
-
def __init__(self, socket_read_size):
self.socket_read_size = socket_read_size
+ self.encoder = None
self._sock = None
self._buffer = None
@@ -234,8 +265,7 @@ class PythonParser(BaseParser):
"Called when the socket connects"
self._sock = connection._sock
self._buffer = SocketBuffer(self._sock, self.socket_read_size)
- if connection.decode_responses:
- self.encoding = connection.encoding
+ self.encoder = connection.encoder
def on_disconnect(self):
"Called when the socket disconnects"
@@ -245,7 +275,7 @@ class PythonParser(BaseParser):
if self._buffer is not None:
self._buffer.close()
self._buffer = None
- self.encoding = None
+ self.encoder = None
def can_read(self):
return self._buffer and bool(self._buffer.length)
@@ -292,8 +322,8 @@ class PythonParser(BaseParser):
if length == -1:
return None
response = [self.read_response() for i in xrange(length)]
- if isinstance(response, bytes) and self.encoding:
- response = response.decode(self.encoding)
+ if isinstance(response, bytes):
+ response = self.encoder.decode(response)
return response
@@ -324,8 +354,8 @@ class HiredisParser(BaseParser):
if not HIREDIS_SUPPORTS_CALLABLE_ERRORS:
kwargs['replyError'] = ResponseError
- if connection.decode_responses:
- kwargs['encoding'] = connection.encoding
+ if connection.encoder.decode_responses:
+ kwargs['encoding'] = connection.encoder.encoding
self._reader = hiredis.Reader(**kwargs)
self._next_response = False
@@ -421,9 +451,7 @@ class Connection(object):
self.socket_keepalive = socket_keepalive
self.socket_keepalive_options = socket_keepalive_options or {}
self.retry_on_timeout = retry_on_timeout
- self.encoding = encoding
- self.encoding_errors = encoding_errors
- self.decode_responses = decode_responses
+ self.encoder = Encoder(encoding, encoding_errors, decode_responses)
self._sock = None
self._parser = parser_class(socket_read_size=socket_read_size)
self._description_args = {
@@ -601,22 +629,6 @@ class Connection(object):
raise response
return response
- def encode(self, value):
- "Return a bytestring representation of the value"
- if isinstance(value, Token):
- return value.encoded_value
- elif isinstance(value, bytes):
- return value
- elif isinstance(value, (int, long)):
- value = b(str(value))
- elif isinstance(value, float):
- value = b(repr(value))
- elif not isinstance(value, basestring):
- value = unicode(value)
- if isinstance(value, unicode):
- value = value.encode(self.encoding, self.encoding_errors)
- return value
-
def pack_command(self, *args):
"Pack a series of arguments into the Redis protocol"
output = []
@@ -635,7 +647,7 @@ class Connection(object):
buff = SYM_EMPTY.join(
(SYM_STAR, b(str(len(args))), SYM_CRLF))
- for arg in imap(self.encode, args):
+ for arg in imap(self.encoder.encode, args):
# to avoid large string mallocs, chunk the command into the
# output list if we're sending large values
if len(buff) > 6000 or len(arg) > 6000:
@@ -724,9 +736,7 @@ class UnixDomainSocketConnection(Connection):
self.password = password
self.socket_timeout = socket_timeout
self.retry_on_timeout = retry_on_timeout
- self.encoding = encoding
- self.encoding_errors = encoding_errors
- self.decode_responses = decode_responses
+ self.encoder = Encoder(encoding, encoding_errors, decode_responses)
self._sock = None
self._parser = parser_class(socket_read_size=socket_read_size)
self._description_args = {
@@ -956,14 +966,14 @@ class ConnectionPool(object):
self._in_use_connections.add(connection)
return connection
- def get_encoding(self):
- "Return information about the encoding settings"
+ def get_encoder(self):
+ "Return an encoder based on encoding settings"
kwargs = self.connection_kwargs
- return {
- 'encoding': kwargs.get('encoding', 'utf-8'),
- 'encoding_errors': kwargs.get('encoding_errors', 'strict'),
- 'decode_responses': kwargs.get('decode_responses', False)
- }
+ return Encoder(
+ encoding=kwargs.get('encoding', 'utf-8'),
+ encoding_errors=kwargs.get('encoding_errors', 'strict'),
+ decode_responses=kwargs.get('decode_responses', False)
+ )
def make_connection(self):
"Create a new connection"