summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAndy McCurdy <andy@andymccurdy.com>2011-07-11 00:14:39 -0700
committerAndy McCurdy <andy@andymccurdy.com>2011-07-11 00:14:39 -0700
commit53c928d44acd3d1fbcb3896cadad0bde4671987a (patch)
tree52ce6641f8136fc0ae972d82f37db44ace96f2ae
parent24b0f17a80b51ed2d7f7e1ca139428f27bf642c9 (diff)
downloadredis-py-53c928d44acd3d1fbcb3896cadad0bde4671987a.tar.gz
WATCH and UNWATCH have been broken since 2.4 because of connection pooling. This fix moves WATCH and UNWATCH to the Pipeline class, where they belong and tests to prove they work.
-rw-r--r--CHANGES5
-rw-r--r--redis/client.py92
-rw-r--r--tests/pipeline.py54
-rw-r--r--tests/server_commands.py21
4 files changed, 136 insertions, 36 deletions
diff --git a/CHANGES b/CHANGES
index ffd5cae..24fb67e 100644
--- a/CHANGES
+++ b/CHANGES
@@ -1,7 +1,10 @@
* 2.4.6 (in development)
* Variadic arguments for SADD, SREM, ZREN, HDEL, LPUSH, and RPUSH. Thanks
Raphaƫl Vinot.
- * Fix for #153, only check for \n rather than \r\n.
+ * Fixed a bug in the Hiredis Parser causing pooled connections to get
+ corrupted occasionally.
+ * Pipelines now contain WATCH and UNWATCH. Calling WATCH or UNWATCH from
+ the base client class will result in a deprecation warning.
* 2.4.5
* The PythonParser now works better when reading zero length strings.
* 2.4.4
diff --git a/redis/client.py b/redis/client.py
index 93ff9a9..8b8e7a8 100644
--- a/redis/client.py
+++ b/redis/client.py
@@ -503,13 +503,13 @@ class Redis(object):
"""
Watches the values at keys ``names``, or None if the key doesn't exist
"""
- return self.execute_command('WATCH', *names)
+ warnings.warn(DeprecationWarning('Call WATCH from a Pipeline object'))
def unwatch(self):
"""
Unwatches the value at key ``name``, or None of the key doesn't exist
"""
- return self.execute_command('UNWATCH')
+ warnings.warn(DeprecationWarning('Call UNWATCH from a Pipeline object'))
#### LIST COMMANDS ####
def blpop(self, keys, timeout=0):
@@ -1177,17 +1177,61 @@ class Pipeline(Redis):
def __init__(self, connection_pool, response_callbacks, transaction,
shard_hint):
self.connection_pool = connection_pool
+ self.connection = None
self.response_callbacks = response_callbacks
self.transaction = transaction
self.shard_hint = shard_hint
+
+ self._real_exec = self.default_execute_command
+ self._pipe_exec = self.pipeline_execute_command
self.reset()
+ def _get_watch(self):
+ return self._watching
+
+ def _set_watch(self, value):
+ self._watching = value
+ self.execute_command = value and self._real_exec or self._pipe_exec
+
+ watching = property(_get_watch, _set_watch)
+
def reset(self):
self.command_stack = []
+ self.watching = False
if self.transaction:
self.execute_command('MULTI')
+ if self.connection:
+ self.connection_pool.release(self.connection)
+ self.connection = None
- def execute_command(self, *args, **options):
+ def multi(self):
+ """
+ Start a transactional block of the pipeline after WATCH commands
+ are issued. End the transactional block with `execute`.
+ """
+ self.execute_command = self._pipe_exec
+
+ def default_execute_command(self, *args, **options):
+ """
+ Execute a command, but don't auto-retry on a ConnectionError. Used
+ when issuing WATCH or subsequent commands retrieving their values
+ but before MULTI is called.
+ """
+ command_name = args[0]
+ conn = self.connection
+ # if this is the first call, we need a connection
+ if not conn:
+ conn = self.connection_pool.get_connection(command_name,
+ self.shard_hint)
+ self.connection = conn
+ try:
+ conn.send_command(*args)
+ return self.parse_response(conn, command_name, **options)
+ except ConnectionError:
+ self.reset()
+ raise
+
+ def pipeline_execute_command(self, *args, **options):
"""
Stage a command to be executed when execute() is next called
@@ -1221,7 +1265,7 @@ class Pipeline(Redis):
if len(response) != len(commands):
raise ResponseError("Wrong number of response items from "
- "pipeline execution")
+ "pipeline execution")
# We have to run response callbacks manually
data = []
for r, cmd in izip(response, commands):
@@ -1249,16 +1293,50 @@ class Pipeline(Redis):
else:
execute = self._execute_pipeline
stack = self.command_stack
- self.reset()
- conn = self.connection_pool.get_connection('MULTI', self.shard_hint)
+ conn = self.connection or \
+ self.connection_pool.get_connection('MULTI', self.shard_hint)
try:
return execute(conn, stack)
except ConnectionError:
conn.disconnect()
+ # if we watching a variable, the watch is no longer valid since
+ # this conncetion has died.
+ if self.watching:
+ raise WatchError("Watched variable changed.")
return execute(conn, stack)
finally:
- self.connection_pool.release(conn)
+ self.reset()
+
+ def watch(self, *names):
+ """
+ Watches the values at keys ``names``
+ """
+ if not self.transaction:
+ raise RedisError("Can only WATCH when using transactions")
+ # if more than 'MULTI' is in the command_stack, we can't WATCH anymore
+ if self.watching and len(self.command_stack) > 1:
+ raise RedisError("Can only WATCH before issuing pipeline commands")
+ self.watching = True
+ return self.execute_command('WATCH', *names)
+ def unwatch(self):
+ """
+ Unwatches all previously specified keys
+ """
+ if not self.transaction:
+ raise RedisError("Can only UNWATCH when using transactions")
+ # if more than 'MULTI' is in the command_stack, we can't UNWATCH anymore
+ if self.watching:
+ if len(self.command_stack) > 1:
+ raise RedisError("Can only UNWATCH before issuing "
+ "pipeline commands")
+ response = self.execute_command('UNWATCH')
+ else:
+ response = True
+ # it's safe to reset() here because we are no longer bound to a
+ # single connection and we're sure the command stack is empty.
+ self.reset()
+ return response
class LockError(RedisError):
"Errors thrown from the Lock"
diff --git a/tests/pipeline.py b/tests/pipeline.py
index 6eda4f2..ee3b3c5 100644
--- a/tests/pipeline.py
+++ b/tests/pipeline.py
@@ -24,6 +24,14 @@ class PipelineTestCase(unittest.TestCase):
]
)
+ def test_pipeline_no_transaction(self):
+ pipe = self.client.pipeline(transaction=False)
+ pipe.set('a', 'a1').set('b', 'b1').set('c', 'c1')
+ self.assertEquals(pipe.execute(), [True, True, True])
+ self.assertEquals(self.client['a'], 'a1')
+ self.assertEquals(self.client['b'], 'b1')
+ self.assertEquals(self.client['c'], 'c1')
+
def test_invalid_command_in_pipeline(self):
# all commands but the invalid one should be excuted correctly
self.client['c'] = 'a'
@@ -46,11 +54,43 @@ class PipelineTestCase(unittest.TestCase):
self.assertEquals(pipe.set('z', 'zzz').execute(), [True])
self.assertEquals(self.client['z'], 'zzz')
- def test_pipeline_no_transaction(self):
- pipe = self.client.pipeline(transaction=False)
- pipe.set('a', 'a1').set('b', 'b1').set('c', 'c1')
- self.assertEquals(pipe.execute(), [True, True, True])
- self.assertEquals(self.client['a'], 'a1')
- self.assertEquals(self.client['b'], 'b1')
- self.assertEquals(self.client['c'], 'c1')
+ def test_watch_succeed(self):
+ self.client.set('a', 1)
+ self.client.set('b', 2)
+
+ pipe = self.client.pipeline()
+ pipe.watch('a', 'b')
+ self.assertEquals(pipe.watching, True)
+ a = pipe.get('a')
+ b = pipe.get('b')
+ self.assertEquals(a, '1')
+ self.assertEquals(b, '2')
+ pipe.multi()
+
+ pipe.set('c', 3)
+ self.assertEquals(pipe.execute(), [True])
+ self.assertEquals(pipe.watching, False)
+ def test_watch_failure(self):
+ self.client.set('a', 1)
+ self.client.set('b', 2)
+
+ pipe = self.client.pipeline()
+ pipe.watch('a', 'b')
+ self.client.set('b', 3)
+ pipe.multi()
+ pipe.get('a')
+ self.assertRaises(redis.WatchError, pipe.execute)
+ self.assertEquals(pipe.watching, False)
+
+ def test_unwatch(self):
+ self.client.set('a', 1)
+ self.client.set('b', 2)
+
+ pipe = self.client.pipeline()
+ pipe.watch('a', 'b')
+ self.client.set('b', 3)
+ pipe.unwatch()
+ self.assertEquals(pipe.watching, False)
+ pipe.get('a')
+ self.assertEquals(pipe.execute(), ['1'])
diff --git a/tests/server_commands.py b/tests/server_commands.py
index f3fd524..b106126 100644
--- a/tests/server_commands.py
+++ b/tests/server_commands.py
@@ -266,27 +266,6 @@ class ServerCommandsTestCase(unittest.TestCase):
self.client.zadd('a', **{'1': 1})
self.assertEquals(self.client.type('a'), 'zset')
- def test_watch(self):
- self.client.set("a", 1)
- self.client.set("b", 2)
-
- self.client.watch("a", "b")
- pipeline = self.client.pipeline()
- pipeline.set("a", 2)
- pipeline.set("b", 3)
- self.assertEquals(pipeline.execute(), [True, True])
-
- self.client.set("b", 1)
- self.client.watch("b")
- self.get_client().set("b", 2)
- pipeline = self.client.pipeline()
- pipeline.set("b", 3)
-
- self.assertRaises(redis.exceptions.WatchError, pipeline.execute)
-
- def test_unwatch(self):
- self.assertEquals(self.client.unwatch(), True)
-
# LISTS
def make_list(self, name, l):
for i in l: