diff options
-rw-r--r-- | redis/client.py | 26 | ||||
-rw-r--r-- | tests/server_commands.py | 30 |
2 files changed, 53 insertions, 3 deletions
diff --git a/redis/client.py b/redis/client.py index b9b843c..b68c1df 100644 --- a/redis/client.py +++ b/redis/client.py @@ -225,6 +225,26 @@ class Redis(object): """ return PubSub(self.connection_pool, shard_hint) + def connection(self): + """ + Returns an instance of ``RedisSingleConnection`` which is bound to one + connection, allowing transactional commands to run in a thread-safe + manner. + + Note that, unlike ``Redis``, ``RedisSingleConnection`` may raise a + ``ConnectionError`` which should be handled by the caller. + + >>> with redis.connection() as cxn: + ... cxn.watch('foo') + ... old_foo = cxn.get('foo') + ... cxn.multi() + ... cxn.set('foo', old_foo + 1) + ... cxn.execute() + ... + >>> + """ + return RedisConnection(connection_pool=self.connection_pool) + #### COMMAND EXECUTION AND PROTOCOL PARSING #### def execute_command(self, *args, **options): "Execute a command and return a parsed response" @@ -1320,10 +1340,10 @@ class Pipeline(Redis): return execute(conn, stack) except ConnectionError: conn.disconnect() - # if we watching a variable, the watch is no longer valid since - # this conncetion has died. + # if we were watching a variable, the watch is no longer valid since + # this connection has died. if self.watching: - raise WatchError("Watched variable changed.") + raise return execute(conn, stack) finally: self.reset() diff --git a/tests/server_commands.py b/tests/server_commands.py index 9e3375d..6e9fe99 100644 --- a/tests/server_commands.py +++ b/tests/server_commands.py @@ -266,6 +266,36 @@ class ServerCommandsTestCase(unittest.TestCase): self.client.zadd('a', **{'1': 1}) self.assertEquals(self.client.type('a'), 'zset') + def test_watch(self): + self.client.set("a", 1) + self.client.set("b", 2) + + self.client.watch("a", "b") + pipeline = self.client.pipeline() + pipeline.set("a", 2) + pipeline.set("b", 3) + self.assertEquals(pipeline.execute(), [True, True]) + + self.client.set("b", 1) + self.client.watch("b") + self.get_client().set("b", 2) + pipeline = self.client.pipeline() + pipeline.set("b", 3) + + self.assertRaises(redis.exceptions.WatchError, pipeline.execute) + + self.client.set("b", 1) + self.client.watch("b") + self.client.set("b", 2) + pipeline = self.client.pipeline() + pipeline.set("b", 3) + + self.assertEquals(self.client.get("b"), "2") + self.assertRaises(redis.exceptions.WatchError, pipeline.execute) + + def test_unwatch(self): + self.assertEquals(self.client.unwatch(), True) + # LISTS def make_list(self, name, l): for i in l: |