diff options
author | andy <andy@whiskeymedia.com> | 2012-08-23 20:38:36 -0700 |
---|---|---|
committer | andy <andy@whiskeymedia.com> | 2012-08-23 20:39:39 -0700 |
commit | 6a3875b9d964189f84c4509234e4d12db4222d89 (patch) | |
tree | 61633c003ad378c1ce161885e62aaf3e1caf0d89 | |
parent | fcb12f9ac8e081a35768b9ba501ac4dae3862420 (diff) | |
download | redis-py-6a3875b9d964189f84c4509234e4d12db4222d89.tar.gz |
implementing LUA scripting, still need tests.
-rw-r--r-- | redis/client.py | 118 | ||||
-rw-r--r-- | redis/connection.py | 23 | ||||
-rw-r--r-- | redis/exceptions.py | 3 |
3 files changed, 131 insertions, 13 deletions
diff --git a/redis/client.py b/redis/client.py index 4e0de94..152c75c 100644 --- a/redis/client.py +++ b/redis/client.py @@ -12,6 +12,7 @@ from redis.exceptions import ( RedisError, ResponseError, WatchError, + NoScriptError ) SYM_EMPTY = b('') @@ -140,13 +141,21 @@ def float_or_none(response): def parse_config(response, **options): - # this is stupid, but don't have a better option right now if options['parse'] == 'GET': response = [nativestr(i) if i is not None else None for i in response] return response and pairs_to_dict(response) or {} return nativestr(response) == 'OK' +def parse_script(response, **options): + parse = options['parse'] + if parse in ('FLUSH', 'KILL'): + return response == 'OK' + if parse == 'EXISTS': + return list(imap(bool, response)) + return response + + class StrictRedis(object): """ Implementation of the Redis protocol. @@ -204,6 +213,7 @@ class StrictRedis(object): 'OBJECT': parse_object, 'PING': lambda r: nativestr(r) == 'PONG', 'RANDOMKEY': lambda r: r and r or None, + 'SCRIPT': parse_script, 'TIME': lambda x: (int(x[0]), int(x[1])) } ) @@ -1208,6 +1218,61 @@ class StrictRedis(object): """ return self.execute_command('PUBLISH', channel, message) + def eval(self, script, numkeys, *keys_and_args): + """ + Execute the LUA ``script``, specifying the ``numkeys`` the script + will touch and the key names and argument values in ``keys_and_args``. + Returns the result of the script. + + In practice, use the object returned by ``register_script``. This + function exists purely for Redis API completion. + """ + return self.execute_command('EVAL', script, numkeys, *keys_and_args) + + def evalsha(self, sha, numkeys, *keys_and_args): + """ + Use the ``sha`` to execute a LUA script already registered via EVAL + or SCRIPT LOAD. Specify the ``numkeys`` the script will touch and the + key names and argument values in ``keys_and_args``. Returns the result + of the script. + + In practice, use the object returned by ``register_script``. This + function exists purely for Redis API completion. + """ + return self.execute_command('EVALSHA', sha, numkeys, *keys_and_args) + + def script_exists(self, *args): + """ + Check if a script exists in the script cache by specifying the SHAs of + each script as ``args``. Returns a list of boolean values indicating if + if each already script exists in the cache. + """ + options = {'parse': 'EXISTS'} + return self.execute_command('SCRIPT', 'EXISTS', *args, **options) + + def script_flush(self): + "Flush all scripts from the script cache" + options = {'parse': 'FLUSH'} + return self.execute_command('SCRIPT', 'FLUSH', **options) + + def script_kill(self): + "Kill the currently executing LUA script" + options = {'parse': 'KILL'} + return self.execute_command('SCRIPT', 'KILL', **options) + + def script_load(self, script): + "Load a LUA ``script`` into the script cache. Returns the SHA." + options = {'parse': 'LOAD'} + return self.execute_command('SCRIPT', 'LOAD', script, **options) + + def register_script(self, script, *keys): + """ + Register a LUA ``script`` specifying the ``keys`` it will touch. + Returns a Script object that is callable and hides the complexity of + deal with scripts, keys, and shas. This is the preferred way to work + with LUA scripts. + """ + return Script(self, script, *keys) class Redis(StrictRedis): """ @@ -1482,6 +1547,7 @@ class BasePipeline(object): def reset(self): self.command_stack = [] + self.scripts = set() # make sure to reset the connection state in the event that we were # watching something if self.watching and self.connection: @@ -1612,8 +1678,21 @@ class BasePipeline(object): self.watching = True return result + def load_scripts(self): + # make sure all scripts that are about to be run on this pipeline exist + scripts = list(self.scripts) + immediate = self.immediate_execute_command + shas = [s.sha for s in scripts] + exists = immediate('SCRIPT', 'EXISTS', *shas, **{'parse': 'EXISTS'}) + if not all(exists): + for s, exist in izip(scripts, exists): + if not exist: + immediate('SCRIPT', 'LOAD', s.script, **{'parse': 'LOAD'}) + def execute(self): "Execute all the commands in the current pipeline" + if self.scripts: + self.load_scripts() stack = self.command_stack if self.transaction or self.explicit_transaction: stack = [(('MULTI', ), {})] + stack + [(('EXEC', ), {})] @@ -1648,19 +1727,19 @@ class BasePipeline(object): self.reset() def watch(self, *names): - """ - Watches the values at keys ``names`` - """ + "Watches the values at keys ``names``" if self.explicit_transaction: raise RedisError('Cannot issue a WATCH after a MULTI') return self.execute_command('WATCH', *names) def unwatch(self): - """ - Unwatches all previously specified keys - """ + "Unwatches all previously specified keys" return self.watching and self.execute_command('UNWATCH') or True + def script_load_for_pipeline(self, script): + "Make sure scripts are loaded prior to pipeline execution" + self.scripts.add(script) + class StrictPipeline(BasePipeline, StrictRedis): "Pipeline for the StrictRedis class" @@ -1672,6 +1751,31 @@ class Pipeline(BasePipeline, Redis): pass +class Script(object): + "An executable LUA script object returned by ``register_script``" + + def __init__(self, registered_client, script): + self.registered_client = registered_client + self.script = script + self.sha = registered_client.script_load(script) + + def __call__(self, keys=[], args=[], client=None): + "Execute the script, passing any required ``args``" + client = client or self.registered_client + args = tuple(keys) + tuple(args) + # make sure the Redis server knows about the script + if isinstance(client, BasePipeline): + # make sure this script is good to go on pipeline + client.script_load_for_pipeline(self.script) + try: + return client.evalsha(self.sha, len(keys), *args) + except NoScriptError: + # Maybe the client is pointed to a differnet server than the client + # that created this instance? + self.sha = client.script_load(self.script) + return client.evalsha(self.sha, len(keys), *args) + + class LockError(RedisError): "Errors thrown from the Lock" pass diff --git a/redis/connection.py b/redis/connection.py index 5a2d8df..f2b164e 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -10,7 +10,8 @@ from redis.exceptions import ( ConnectionError, ResponseError, InvalidResponse, - AuthenticationError + AuthenticationError, + NoScriptError, ) try: @@ -31,6 +32,11 @@ class PythonParser(object): MAX_READ_LENGTH = 1000000 encoding = None + EXCEPTION_CLASSES = { + 'ERR': ResponseError, + 'NOSCRIPT': NoScriptError, + } + def __init__(self): self._fp = None @@ -84,6 +90,14 @@ class PythonParser(object): raise ConnectionError("Error while reading from socket: %s" % (e.args,)) + def parse_error(self, response): + "Parse an error response" + error_code = response.split(' ')[0] + if error_code in self.EXCEPTION_CLASSES: + response = response[len(error_code) + 1:] + return self.EXCEPTION_CLASSES[error_code](response) + return ResponseError(response) + def read_response(self): response = self.read() if not response: @@ -100,12 +114,9 @@ class PythonParser(object): # if we're loading the dataset into memory, kill the socket # so we re-initialize (and re-SELECT) next time. raise ConnectionError("Redis is loading data into memory") - # if the error starts with ERR, trim that off - if nativestr(response).startswith('ERR '): - response = response[4:] # *return*, not raise the exception class. if it is meant to be # raised, it will be at a higher level. - return ResponseError(response) + return self.parse_error(response) # single value elif byte == '+': pass @@ -293,7 +304,7 @@ class Connection(object): except: self.disconnect() raise - if response.__class__ == ResponseError: + if isinstance(response, ResponseError): raise response return response diff --git a/redis/exceptions.py b/redis/exceptions.py index 746bf07..3462c36 100644 --- a/redis/exceptions.py +++ b/redis/exceptions.py @@ -31,3 +31,6 @@ class PubSubError(RedisError): class WatchError(RedisError): pass + +class NoScriptError(ResponseError): + pass |