diff options
author | Andy McCurdy <andy@andymccurdy.com> | 2017-08-02 14:06:00 -0700 |
---|---|---|
committer | Andy McCurdy <andy@andymccurdy.com> | 2017-08-02 14:06:00 -0700 |
commit | ea4581a6f8386adafde2d6640b50ec7e1aaa673e (patch) | |
tree | 472f2f8986627d9dbf7d64f74399c1aad656c2c9 /redis | |
parent | 29a7ecebd0f4709fc184c6af57c6eb1a59cb071e (diff) | |
download | redis-py-ea4581a6f8386adafde2d6640b50ec7e1aaa673e.tar.gz |
add an Encoder object responsible for encoding/decoding bytes and strings
this simplifies multiple places that need to encode and decode values
Diffstat (limited to 'redis')
-rwxr-xr-x | redis/client.py | 56 | ||||
-rwxr-xr-x | redis/connection.py | 88 |
2 files changed, 70 insertions, 74 deletions
diff --git a/redis/client.py b/redis/client.py index b3fff38..f829656 100755 --- a/redis/client.py +++ b/redis/client.py @@ -2339,10 +2339,7 @@ class PubSub(object): self.connection = None # we need to know the encoding options for this connection in order # to lookup channel and pattern names for callback handlers. - encoding_options = self.connection_pool.get_encoding() - self.encoding = encoding_options['encoding'] - self.encoding_errors = encoding_options['encoding_errors'] - self.decode_responses = encoding_options['decode_responses'] + self.encoder = self.connection_pool.get_encoder() self.reset() def __del__(self): @@ -2374,29 +2371,14 @@ class PubSub(object): if self.channels: channels = {} for k, v in iteritems(self.channels): - if not self.decode_responses: - k = k.decode(self.encoding, self.encoding_errors) - channels[k] = v + channels[self.encoder.decode(k, force=True)] = v self.subscribe(**channels) if self.patterns: patterns = {} for k, v in iteritems(self.patterns): - if not self.decode_responses: - k = k.decode(self.encoding, self.encoding_errors) - patterns[k] = v + patterns[self.encoder.decode(k, force=True)] = v self.psubscribe(**patterns) - def encode(self, value): - """ - Encode the value so that it's identical to what we'll - read off the connection - """ - if self.decode_responses and isinstance(value, bytes): - value = value.decode(self.encoding, self.encoding_errors) - elif not self.decode_responses and isinstance(value, unicode): - value = value.encode(self.encoding, self.encoding_errors) - return value - @property def subscribed(self): "Indicates if there are subscriptions to any channels or patterns" @@ -2446,6 +2428,16 @@ class PubSub(object): return None return self._execute(connection, connection.read_response) + def _normalize_keys(self, data): + """ + normalize channel/pattern names to be either bytes or strings + based on whether responses are automatically decoded. 
this saves us + from coercing the value for each message coming in. + """ + encode = self.encoder.encode + decode = self.encoder.decode + return dict([(decode(encode(k)), v) for k, v in iteritems(data)]) + def psubscribe(self, *args, **kwargs): """ Subscribe to channel patterns. Patterns supplied as keyword arguments @@ -2456,15 +2448,13 @@ class PubSub(object): """ if args: args = list_or_args(args[0], args[1:]) - new_patterns = {} - new_patterns.update(dict.fromkeys(imap(self.encode, args))) - for pattern, handler in iteritems(kwargs): - new_patterns[self.encode(pattern)] = handler + new_patterns = dict.fromkeys(args) + new_patterns.update(kwargs) ret_val = self.execute_command('PSUBSCRIBE', *iterkeys(new_patterns)) # update the patterns dict AFTER we send the command. we don't want to # subscribe twice to these patterns, once for the command and again # for the reconnection. - self.patterns.update(new_patterns) + self.patterns.update(self._normalize_keys(new_patterns)) return ret_val def punsubscribe(self, *args): @@ -2486,15 +2476,13 @@ class PubSub(object): """ if args: args = list_or_args(args[0], args[1:]) - new_channels = {} - new_channels.update(dict.fromkeys(imap(self.encode, args))) - for channel, handler in iteritems(kwargs): - new_channels[self.encode(channel)] = handler + new_channels = dict.fromkeys(args) + new_channels.update(kwargs) ret_val = self.execute_command('SUBSCRIBE', *iterkeys(new_channels)) # update the channels dict AFTER we send the command. we don't want to # subscribe twice to these channels, once for the command and again # for the reconnection. 
- self.channels.update(new_channels) + self.channels.update(self._normalize_keys(new_channels)) return ret_val def unsubscribe(self, *args): @@ -2938,10 +2926,8 @@ class Script(object): if isinstance(script, basestring): # We need the encoding from the client in order to generate an # accurate byte representation of the script - encoding_options = registered_client.connection_pool.get_encoding() - encoding = encoding_options['encoding'] - encoding_errors = encoding_options['encoding_errors'] - script = script.encode(encoding, encoding_errors) + encoder = registered_client.connection_pool.get_encoder() + script = encoder.encode(script) self.sha = hashlib.sha1(script).hexdigest() def __call__(self, keys=[], args=[], client=None): diff --git a/redis/connection.py b/redis/connection.py index 77eb44d..c8f0513 100755 --- a/redis/connection.py +++ b/redis/connection.py @@ -94,6 +94,38 @@ class Token(object): return self.value +class Encoder(object): + "Encode strings to bytes and decode bytes to strings" + + def __init__(self, encoding, encoding_errors, decode_responses): + self.encoding = encoding + self.encoding_errors = encoding_errors + self.decode_responses = decode_responses + + def encode(self, value): + "Return a bytestring representation of the value" + if isinstance(value, Token): + return value.encoded_value + elif isinstance(value, bytes): + return value + elif isinstance(value, (int, long)): + value = b(str(value)) + elif isinstance(value, float): + value = b(repr(value)) + elif not isinstance(value, basestring): + # an object we don't know how to deal with. 
default to unicode() + value = unicode(value) + if isinstance(value, unicode): + value = value.encode(self.encoding, self.encoding_errors) + return value + + def decode(self, value, force=False): + "Return a unicode string from the byte representation" + if (self.decode_responses or force) and isinstance(value, bytes): + value = value.decode(self.encoding, self.encoding_errors) + return value + + class BaseParser(object): EXCEPTION_CLASSES = { 'ERR': { @@ -217,10 +249,9 @@ class SocketBuffer(object): class PythonParser(BaseParser): "Plain Python parsing class" - encoding = None - def __init__(self, socket_read_size): self.socket_read_size = socket_read_size + self.encoder = None self._sock = None self._buffer = None @@ -234,8 +265,7 @@ class PythonParser(BaseParser): "Called when the socket connects" self._sock = connection._sock self._buffer = SocketBuffer(self._sock, self.socket_read_size) - if connection.decode_responses: - self.encoding = connection.encoding + self.encoder = connection.encoder def on_disconnect(self): "Called when the socket disconnects" @@ -245,7 +275,7 @@ class PythonParser(BaseParser): if self._buffer is not None: self._buffer.close() self._buffer = None - self.encoding = None + self.encoder = None def can_read(self): return self._buffer and bool(self._buffer.length) @@ -292,8 +322,8 @@ class PythonParser(BaseParser): if length == -1: return None response = [self.read_response() for i in xrange(length)] - if isinstance(response, bytes) and self.encoding: - response = response.decode(self.encoding) + if isinstance(response, bytes): + response = self.encoder.decode(response) return response @@ -324,8 +354,8 @@ class HiredisParser(BaseParser): if not HIREDIS_SUPPORTS_CALLABLE_ERRORS: kwargs['replyError'] = ResponseError - if connection.decode_responses: - kwargs['encoding'] = connection.encoding + if connection.encoder.decode_responses: + kwargs['encoding'] = connection.encoder.encoding self._reader = hiredis.Reader(**kwargs) 
self._next_response = False @@ -421,9 +451,7 @@ class Connection(object): self.socket_keepalive = socket_keepalive self.socket_keepalive_options = socket_keepalive_options or {} self.retry_on_timeout = retry_on_timeout - self.encoding = encoding - self.encoding_errors = encoding_errors - self.decode_responses = decode_responses + self.encoder = Encoder(encoding, encoding_errors, decode_responses) self._sock = None self._parser = parser_class(socket_read_size=socket_read_size) self._description_args = { @@ -601,22 +629,6 @@ class Connection(object): raise response return response - def encode(self, value): - "Return a bytestring representation of the value" - if isinstance(value, Token): - return value.encoded_value - elif isinstance(value, bytes): - return value - elif isinstance(value, (int, long)): - value = b(str(value)) - elif isinstance(value, float): - value = b(repr(value)) - elif not isinstance(value, basestring): - value = unicode(value) - if isinstance(value, unicode): - value = value.encode(self.encoding, self.encoding_errors) - return value - def pack_command(self, *args): "Pack a series of arguments into the Redis protocol" output = [] @@ -635,7 +647,7 @@ class Connection(object): buff = SYM_EMPTY.join( (SYM_STAR, b(str(len(args))), SYM_CRLF)) - for arg in imap(self.encode, args): + for arg in imap(self.encoder.encode, args): # to avoid large string mallocs, chunk the command into the # output list if we're sending large values if len(buff) > 6000 or len(arg) > 6000: @@ -724,9 +736,7 @@ class UnixDomainSocketConnection(Connection): self.password = password self.socket_timeout = socket_timeout self.retry_on_timeout = retry_on_timeout - self.encoding = encoding - self.encoding_errors = encoding_errors - self.decode_responses = decode_responses + self.encoder = Encoder(encoding, encoding_errors, decode_responses) self._sock = None self._parser = parser_class(socket_read_size=socket_read_size) self._description_args = { @@ -956,14 +966,14 @@ class 
ConnectionPool(object): self._in_use_connections.add(connection) return connection - def get_encoding(self): - "Return information about the encoding settings" + def get_encoder(self): + "Return an encoder based on encoding settings" kwargs = self.connection_kwargs - return { - 'encoding': kwargs.get('encoding', 'utf-8'), - 'encoding_errors': kwargs.get('encoding_errors', 'strict'), - 'decode_responses': kwargs.get('decode_responses', False) - } + return Encoder( + encoding=kwargs.get('encoding', 'utf-8'), + encoding_errors=kwargs.get('encoding_errors', 'strict'), + decode_responses=kwargs.get('decode_responses', False) + ) def make_connection(self): "Create a new connection" |