diff options
-rw-r--r-- | redis/client.py | 69 | ||||
-rw-r--r-- | tests/__init__.py | 2 | ||||
-rw-r--r-- | tests/lock.py | 47 |
3 files changed, 118 insertions, 0 deletions
diff --git a/redis/client.py b/redis/client.py index 3dabcf4..17ffb7c 100644 --- a/redis/client.py +++ b/redis/client.py @@ -283,6 +283,12 @@ class Redis(threading.local): self.errors ) + def lock(self, name): + """ + Return a new Lock object using key ``name`` that mimics + the behavior of threading.Lock + """ + return Lock(self, name) #### COMMAND EXECUTION AND PROTOCOL PARSING #### def _execute_command(self, command_name, command, **options): @@ -1293,3 +1299,66 @@ class Pipeline(Redis): raise RedisError("Cannot select a different database from a pipeline") +class Lock(object): + """ + A shared, distributed Lock. Using Redis for locking allows the Lock + to be shared across processes and/or machines. + + It's left to the user to resolve deadlock issues and make sure + multiple clients play nicely together. + """ + + LOCK_FOREVER = 2**31+1 # 1 past max unix time + + def __init__(self, redis, name): + self.redis = redis + self.name = name + self.acquired_until = None + + def acquire(self, blocking=True, timeout=None): + """ + Use Redis to hold a shared, distributed lock named ``name``. + Returns True once the lock is acquired. + + If ``blocking`` is False, always return immediately. If the lock + was acquired, return True, otherwise return False. + + ``timeout`` indicates the maxium lifetime of the lock. If None, + lock forever. + + Note: If using ``timeout``, you should make sure all the hosts + that are running clients are within the same timezone and are using + a network time service like ntp. + """ + while 1: + unixtime = int(time.time()) + if timeout: + timeout_at = unixtime + timeout + else: + timeout_at = Lock.LOCK_FOREVER + if self.redis.setnx(self.name, timeout_at): + self.acquired_until = timeout_at + return True + # We want blocking, but didn't acquire the lock + # check to see if the current lock is expired + existing = long(self.redis.get(self.name) or 1) + if existing < unixtime: + # the previous lock is expired, attempt to overwrite it + existing = long(self.redis.getset(self.name, timeout_at) or 1) + if existing < unixtime: + # we successfully acquired the lock + self.acquired_until = timeout_at + return True + if not blocking: + return False + time.sleep(0.1) + + def release(self): + "Releases the already acquired lock" + if self.acquired_until is None: + raise ValueError("Cannot release an unlocked lock") + existing = long(self.redis.get(self.name) or 1) + # if the lock time is in the future, delete the lock + if existing >= self.acquired_until: + self.redis.delete(self.name) + self.acquired_until = None diff --git a/tests/__init__.py b/tests/__init__.py index 45e55b0..8931b07 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -2,10 +2,12 @@ import unittest from server_commands import ServerCommandsTestCase from connection_pool import ConnectionPoolTestCase from pipeline import PipelineTestCase +from lock import LockTestCase def all_tests(): suite = unittest.TestSuite() suite.addTest(unittest.makeSuite(ServerCommandsTestCase)) suite.addTest(unittest.makeSuite(ConnectionPoolTestCase)) suite.addTest(unittest.makeSuite(PipelineTestCase)) + suite.addTest(unittest.makeSuite(LockTestCase)) return suite diff --git a/tests/lock.py b/tests/lock.py new file mode 100644 index 0000000..a6ec9e7 --- /dev/null +++ b/tests/lock.py @@ -0,0 +1,47 @@ +import redis +import time +import unittest +from redis.client import Lock + +class LockTestCase(unittest.TestCase): + def setUp(self): + self.client = redis.Redis(host='localhost', port=6379, db=9) + self.client.flushdb() + + def tearDown(self): + self.client.flushdb() + + def test_lock(self): + lock = self.client.lock('foo') + self.assert_(lock.acquire()) + self.assertEquals(self.client['foo'], str(Lock.LOCK_FOREVER)) + lock.release() + self.assertEquals(self.client['foo'], None) + + def test_competing_locks(self): + lock1 = self.client.lock('foo') + lock2 = self.client.lock('foo') + self.assert_(lock1.acquire()) + self.assertFalse(lock2.acquire(blocking=False)) + lock1.release() + self.assert_(lock2.acquire()) + self.assertFalse(lock1.acquire(blocking=False)) + lock2.release() + + def test_timeouts(self): + lock1 = self.client.lock('foo') + lock2 = self.client.lock('foo') + self.assert_(lock1.acquire(timeout=1)) + self.assertEquals(lock1.acquired_until, long(time.time()) + 1) + self.assertEquals(lock1.acquired_until, long(self.client['foo'])) + self.assertFalse(lock2.acquire(blocking=False)) + time.sleep(2) # need to wait up to 2 seconds for lock to timeout + self.assert_(lock2.acquire(blocking=False)) + lock2.release() + + def test_non_blocking(self): + lock1 = self.client.lock('foo') + self.assert_(lock1.acquire(blocking=False)) + self.assert_(lock1.acquired_until) + lock1.release() + self.assert_(lock1.acquired_until is None) |