summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorandy <andy@whiskeymedia.com>2012-08-23 20:38:36 -0700
committerandy <andy@whiskeymedia.com>2012-08-23 20:39:39 -0700
commit6a3875b9d964189f84c4509234e4d12db4222d89 (patch)
tree61633c003ad378c1ce161885e62aaf3e1caf0d89
parentfcb12f9ac8e081a35768b9ba501ac4dae3862420 (diff)
downloadredis-py-6a3875b9d964189f84c4509234e4d12db4222d89.tar.gz
implementing LUA scripting, still need tests.
-rw-r--r--redis/client.py118
-rw-r--r--redis/connection.py23
-rw-r--r--redis/exceptions.py3
3 files changed, 131 insertions, 13 deletions
diff --git a/redis/client.py b/redis/client.py
index 4e0de94..152c75c 100644
--- a/redis/client.py
+++ b/redis/client.py
@@ -12,6 +12,7 @@ from redis.exceptions import (
RedisError,
ResponseError,
WatchError,
+ NoScriptError
)
SYM_EMPTY = b('')
@@ -140,13 +141,21 @@ def float_or_none(response):
def parse_config(response, **options):
- # this is stupid, but don't have a better option right now
if options['parse'] == 'GET':
response = [nativestr(i) if i is not None else None for i in response]
return response and pairs_to_dict(response) or {}
return nativestr(response) == 'OK'
+def parse_script(response, **options):
+ parse = options['parse']
+ if parse in ('FLUSH', 'KILL'):
+ return response == 'OK'
+ if parse == 'EXISTS':
+ return list(imap(bool, response))
+ return response
+
+
class StrictRedis(object):
"""
Implementation of the Redis protocol.
@@ -204,6 +213,7 @@ class StrictRedis(object):
'OBJECT': parse_object,
'PING': lambda r: nativestr(r) == 'PONG',
'RANDOMKEY': lambda r: r and r or None,
+ 'SCRIPT': parse_script,
'TIME': lambda x: (int(x[0]), int(x[1]))
}
)
@@ -1208,6 +1218,61 @@ class StrictRedis(object):
"""
return self.execute_command('PUBLISH', channel, message)
+ def eval(self, script, numkeys, *keys_and_args):
+ """
+ Execute the LUA ``script``, specifying the ``numkeys`` the script
+ will touch and the key names and argument values in ``keys_and_args``.
+ Returns the result of the script.
+
+ In practice, use the object returned by ``register_script``. This
+ function exists purely for Redis API completion.
+ """
+ return self.execute_command('EVAL', script, numkeys, *keys_and_args)
+
+ def evalsha(self, sha, numkeys, *keys_and_args):
+ """
+ Use the ``sha`` to execute a LUA script already registered via EVAL
+ or SCRIPT LOAD. Specify the ``numkeys`` the script will touch and the
+ key names and argument values in ``keys_and_args``. Returns the result
+ of the script.
+
+ In practice, use the object returned by ``register_script``. This
+ function exists purely for Redis API completion.
+ """
+ return self.execute_command('EVALSHA', sha, numkeys, *keys_and_args)
+
+ def script_exists(self, *args):
+ """
+ Check if a script exists in the script cache by specifying the SHAs of
+ each script as ``args``. Returns a list of boolean values indicating if
+ if each already script exists in the cache.
+ """
+ options = {'parse': 'EXISTS'}
+ return self.execute_command('SCRIPT', 'EXISTS', *args, **options)
+
+ def script_flush(self):
+ "Flush all scripts from the script cache"
+ options = {'parse': 'FLUSH'}
+ return self.execute_command('SCRIPT', 'FLUSH', **options)
+
+ def script_kill(self):
+ "Kill the currently executing LUA script"
+ options = {'parse': 'KILL'}
+ return self.execute_command('SCRIPT', 'KILL', **options)
+
+ def script_load(self, script):
+ "Load a LUA ``script`` into the script cache. Returns the SHA."
+ options = {'parse': 'LOAD'}
+ return self.execute_command('SCRIPT', 'LOAD', script, **options)
+
+ def register_script(self, script, *keys):
+ """
+ Register a LUA ``script`` specifying the ``keys`` it will touch.
+ Returns a Script object that is callable and hides the complexity of
+ deal with scripts, keys, and shas. This is the preferred way to work
+ with LUA scripts.
+ """
+ return Script(self, script, *keys)
class Redis(StrictRedis):
"""
@@ -1482,6 +1547,7 @@ class BasePipeline(object):
def reset(self):
self.command_stack = []
+ self.scripts = set()
# make sure to reset the connection state in the event that we were
# watching something
if self.watching and self.connection:
@@ -1612,8 +1678,21 @@ class BasePipeline(object):
self.watching = True
return result
+ def load_scripts(self):
+ # make sure all scripts that are about to be run on this pipeline exist
+ scripts = list(self.scripts)
+ immediate = self.immediate_execute_command
+ shas = [s.sha for s in scripts]
+ exists = immediate('SCRIPT', 'EXISTS', *shas, **{'parse': 'EXISTS'})
+ if not all(exists):
+ for s, exist in izip(scripts, exists):
+ if not exist:
+ immediate('SCRIPT', 'LOAD', s.script, **{'parse': 'LOAD'})
+
def execute(self):
"Execute all the commands in the current pipeline"
+ if self.scripts:
+ self.load_scripts()
stack = self.command_stack
if self.transaction or self.explicit_transaction:
stack = [(('MULTI', ), {})] + stack + [(('EXEC', ), {})]
@@ -1648,19 +1727,19 @@ class BasePipeline(object):
self.reset()
def watch(self, *names):
- """
- Watches the values at keys ``names``
- """
+ "Watches the values at keys ``names``"
if self.explicit_transaction:
raise RedisError('Cannot issue a WATCH after a MULTI')
return self.execute_command('WATCH', *names)
def unwatch(self):
- """
- Unwatches all previously specified keys
- """
+ "Unwatches all previously specified keys"
return self.watching and self.execute_command('UNWATCH') or True
+ def script_load_for_pipeline(self, script):
+ "Make sure scripts are loaded prior to pipeline execution"
+ self.scripts.add(script)
+
class StrictPipeline(BasePipeline, StrictRedis):
"Pipeline for the StrictRedis class"
@@ -1672,6 +1751,31 @@ class Pipeline(BasePipeline, Redis):
pass
+class Script(object):
+ "An executable LUA script object returned by ``register_script``"
+
+ def __init__(self, registered_client, script):
+ self.registered_client = registered_client
+ self.script = script
+ self.sha = registered_client.script_load(script)
+
+ def __call__(self, keys=[], args=[], client=None):
+ "Execute the script, passing any required ``args``"
+ client = client or self.registered_client
+ args = tuple(keys) + tuple(args)
+ # make sure the Redis server knows about the script
+ if isinstance(client, BasePipeline):
+ # make sure this script is good to go on pipeline
+ client.script_load_for_pipeline(self.script)
+ try:
+ return client.evalsha(self.sha, len(keys), *args)
+ except NoScriptError:
+ # Maybe the client is pointed to a differnet server than the client
+ # that created this instance?
+ self.sha = client.script_load(self.script)
+ return client.evalsha(self.sha, len(keys), *args)
+
+
class LockError(RedisError):
"Errors thrown from the Lock"
pass
diff --git a/redis/connection.py b/redis/connection.py
index 5a2d8df..f2b164e 100644
--- a/redis/connection.py
+++ b/redis/connection.py
@@ -10,7 +10,8 @@ from redis.exceptions import (
ConnectionError,
ResponseError,
InvalidResponse,
- AuthenticationError
+ AuthenticationError,
+ NoScriptError,
)
try:
@@ -31,6 +32,11 @@ class PythonParser(object):
MAX_READ_LENGTH = 1000000
encoding = None
+ EXCEPTION_CLASSES = {
+ 'ERR': ResponseError,
+ 'NOSCRIPT': NoScriptError,
+ }
+
def __init__(self):
self._fp = None
@@ -84,6 +90,14 @@ class PythonParser(object):
raise ConnectionError("Error while reading from socket: %s" %
(e.args,))
+ def parse_error(self, response):
+ "Parse an error response"
+ error_code = response.split(' ')[0]
+ if error_code in self.EXCEPTION_CLASSES:
+ response = response[len(error_code) + 1:]
+ return self.EXCEPTION_CLASSES[error_code](response)
+ return ResponseError(response)
+
def read_response(self):
response = self.read()
if not response:
@@ -100,12 +114,9 @@ class PythonParser(object):
# if we're loading the dataset into memory, kill the socket
# so we re-initialize (and re-SELECT) next time.
raise ConnectionError("Redis is loading data into memory")
- # if the error starts with ERR, trim that off
- if nativestr(response).startswith('ERR '):
- response = response[4:]
# *return*, not raise the exception class. if it is meant to be
# raised, it will be at a higher level.
- return ResponseError(response)
+ return self.parse_error(response)
# single value
elif byte == '+':
pass
@@ -293,7 +304,7 @@ class Connection(object):
except:
self.disconnect()
raise
- if response.__class__ == ResponseError:
+ if isinstance(response, ResponseError):
raise response
return response
diff --git a/redis/exceptions.py b/redis/exceptions.py
index 746bf07..3462c36 100644
--- a/redis/exceptions.py
+++ b/redis/exceptions.py
@@ -31,3 +31,6 @@ class PubSubError(RedisError):
class WatchError(RedisError):
pass
+
+class NoScriptError(ResponseError):
+ pass