summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--redis/client.py69
-rw-r--r--tests/__init__.py2
-rw-r--r--tests/lock.py47
3 files changed, 118 insertions, 0 deletions
diff --git a/redis/client.py b/redis/client.py
index 3dabcf4..17ffb7c 100644
--- a/redis/client.py
+++ b/redis/client.py
@@ -283,6 +283,12 @@ class Redis(threading.local):
self.errors
)
+ def lock(self, name):
+ """
+ Return a new Lock object using key ``name`` that mimics
+ the behavior of threading.Lock
+ """
+ return Lock(self, name)
#### COMMAND EXECUTION AND PROTOCOL PARSING ####
def _execute_command(self, command_name, command, **options):
@@ -1293,3 +1299,66 @@ class Pipeline(Redis):
raise RedisError("Cannot select a different database from a pipeline")
+class Lock(object):
+ """
+ A shared, distributed Lock. Using Redis for locking allows the Lock
+ to be shared across processes and/or machines.
+
+ It's left to the user to resolve deadlock issues and make sure
+ multiple clients play nicely together.
+ """
+
+ LOCK_FOREVER = 2**31+1 # 1 past max unix time
+
+ def __init__(self, redis, name):
+ self.redis = redis
+ self.name = name
+ self.acquired_until = None
+
+ def acquire(self, blocking=True, timeout=None):
+ """
+ Use Redis to hold a shared, distributed lock named ``name``.
+ Returns True once the lock is acquired.
+
+ If ``blocking`` is False, always return immediately. If the lock
+ was acquired, return True, otherwise return False.
+
+ ``timeout`` indicates the maxium lifetime of the lock. If None,
+ lock forever.
+
+ Note: If using ``timeout``, you should make sure all the hosts
+ that are running clients are within the same timezone and are using
+ a network time service like ntp.
+ """
+ while 1:
+ unixtime = int(time.time())
+ if timeout:
+ timeout_at = unixtime + timeout
+ else:
+ timeout_at = Lock.LOCK_FOREVER
+ if self.redis.setnx(self.name, timeout_at):
+ self.acquired_until = timeout_at
+ return True
+ # We want blocking, but didn't acquire the lock
+ # check to see if the current lock is expired
+ existing = long(self.redis.get(self.name) or 1)
+ if existing < unixtime:
+ # the previous lock is expired, attempt to overwrite it
+ existing = long(self.redis.getset(self.name, timeout_at) or 1)
+ if existing < unixtime:
+ # we successfully acquired the lock
+ self.acquired_until = timeout_at
+ return True
+ if not blocking:
+ return False
+ time.sleep(0.1)
+
+ def release(self):
+ "Releases the already acquired lock"
+ if self.acquired_until is None:
+ raise ValueError("Cannot release an unlocked lock")
+ existing = long(self.redis.get(self.name) or 1)
+ # if the lock time is in the future, delete the lock
+ if existing >= self.acquired_until:
+ self.redis.delete(self.name)
+ self.acquired_until = None
diff --git a/tests/__init__.py b/tests/__init__.py
index 45e55b0..8931b07 100644
--- a/tests/__init__.py
+++ b/tests/__init__.py
@@ -2,10 +2,12 @@ import unittest
from server_commands import ServerCommandsTestCase
from connection_pool import ConnectionPoolTestCase
from pipeline import PipelineTestCase
+from lock import LockTestCase
def all_tests():
suite = unittest.TestSuite()
suite.addTest(unittest.makeSuite(ServerCommandsTestCase))
suite.addTest(unittest.makeSuite(ConnectionPoolTestCase))
suite.addTest(unittest.makeSuite(PipelineTestCase))
+ suite.addTest(unittest.makeSuite(LockTestCase))
return suite
diff --git a/tests/lock.py b/tests/lock.py
new file mode 100644
index 0000000..a6ec9e7
--- /dev/null
+++ b/tests/lock.py
@@ -0,0 +1,47 @@
+import redis
+import time
+import unittest
+from redis.client import Lock
+
+class LockTestCase(unittest.TestCase):
+ def setUp(self):
+ self.client = redis.Redis(host='localhost', port=6379, db=9)
+ self.client.flushdb()
+
+ def tearDown(self):
+ self.client.flushdb()
+
+ def test_lock(self):
+ lock = self.client.lock('foo')
+ self.assert_(lock.acquire())
+ self.assertEquals(self.client['foo'], str(Lock.LOCK_FOREVER))
+ lock.release()
+ self.assertEquals(self.client['foo'], None)
+
+ def test_competing_locks(self):
+ lock1 = self.client.lock('foo')
+ lock2 = self.client.lock('foo')
+ self.assert_(lock1.acquire())
+ self.assertFalse(lock2.acquire(blocking=False))
+ lock1.release()
+ self.assert_(lock2.acquire())
+ self.assertFalse(lock1.acquire(blocking=False))
+ lock2.release()
+
+ def test_timeouts(self):
+ lock1 = self.client.lock('foo')
+ lock2 = self.client.lock('foo')
+ self.assert_(lock1.acquire(timeout=1))
+ self.assertEquals(lock1.acquired_until, long(time.time()) + 1)
+ self.assertEquals(lock1.acquired_until, long(self.client['foo']))
+ self.assertFalse(lock2.acquire(blocking=False))
+ time.sleep(2) # need to wait up to 2 seconds for lock to timeout
+ self.assert_(lock2.acquire(blocking=False))
+ lock2.release()
+
+ def test_non_blocking(self):
+ lock1 = self.client.lock('foo')
+ self.assert_(lock1.acquire(blocking=False))
+ self.assert_(lock1.acquired_until)
+ lock1.release()
+ self.assert_(lock1.acquired_until is None)