summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--redis/client.py26
-rw-r--r--tests/server_commands.py30
2 files changed, 53 insertions, 3 deletions
diff --git a/redis/client.py b/redis/client.py
index b9b843c..b68c1df 100644
--- a/redis/client.py
+++ b/redis/client.py
@@ -225,6 +225,26 @@ class Redis(object):
"""
return PubSub(self.connection_pool, shard_hint)
+ def connection(self):
+ """
+ Returns an instance of ``RedisSingleConnection`` which is bound to one
+ connection, allowing transactional commands to run in a thread-safe
+ manner.
+
+ Note that, unlike ``Redis``, ``RedisSingleConnection`` may raise a
+ ``ConnectionError`` which should be handled by the caller.
+
+ >>> with redis.connection() as cxn:
+ ... cxn.watch('foo')
+ ... old_foo = cxn.get('foo')
+ ... cxn.multi()
+ ... cxn.set('foo', old_foo + 1)
+ ... cxn.execute()
+ ...
+ >>>
+ """
+ return RedisConnection(connection_pool=self.connection_pool)
+
#### COMMAND EXECUTION AND PROTOCOL PARSING ####
def execute_command(self, *args, **options):
"Execute a command and return a parsed response"
@@ -1320,10 +1340,10 @@ class Pipeline(Redis):
return execute(conn, stack)
except ConnectionError:
conn.disconnect()
- # if we watching a variable, the watch is no longer valid since
- # this conncetion has died.
+ # if we were watching a variable, the watch is no longer valid since
+ # this connection has died.
if self.watching:
- raise WatchError("Watched variable changed.")
+ raise
return execute(conn, stack)
finally:
self.reset()
diff --git a/tests/server_commands.py b/tests/server_commands.py
index 9e3375d..6e9fe99 100644
--- a/tests/server_commands.py
+++ b/tests/server_commands.py
@@ -266,6 +266,36 @@ class ServerCommandsTestCase(unittest.TestCase):
self.client.zadd('a', **{'1': 1})
self.assertEquals(self.client.type('a'), 'zset')
+ def test_watch(self):
+ self.client.set("a", 1)
+ self.client.set("b", 2)
+
+ self.client.watch("a", "b")
+ pipeline = self.client.pipeline()
+ pipeline.set("a", 2)
+ pipeline.set("b", 3)
+ self.assertEquals(pipeline.execute(), [True, True])
+
+ self.client.set("b", 1)
+ self.client.watch("b")
+ self.get_client().set("b", 2)
+ pipeline = self.client.pipeline()
+ pipeline.set("b", 3)
+
+ self.assertRaises(redis.exceptions.WatchError, pipeline.execute)
+
+ self.client.set("b", 1)
+ self.client.watch("b")
+ self.client.set("b", 2)
+ pipeline = self.client.pipeline()
+ pipeline.set("b", 3)
+
+ self.assertEquals(self.client.get("b"), "2")
+ self.assertRaises(redis.exceptions.WatchError, pipeline.execute)
+
+ def test_unwatch(self):
+ self.assertEquals(self.client.unwatch(), True)
+
# LISTS
def make_list(self, name, l):
for i in l: