summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--decorators.py19
-rw-r--r--test/unittest_decorators.py14
2 files changed, 24 insertions, 9 deletions
diff --git a/decorators.py b/decorators.py
index 6b73440..39144d1 100644
--- a/decorators.py
+++ b/decorators.py
@@ -62,7 +62,7 @@ class _SingleValueCache(object):
def closure(self):
def wrapped(*args, **kwargs):
return self.__call__(*args, **kwargs)
- wrapped.clear = self.clear
+ wrapped.cache_obj = self
try:
wrapped.__doc__ = self.callable.__doc__
wrapped.__name__ = self.callable.__name__
@@ -116,6 +116,13 @@ def cached(callableobj=None, keyarg=None, **kwargs):
else:
return decorator(callableobj)
+def get_cache_impl(obj, funcname):
+ cls = obj.__class__
+ member = getattr(cls, funcname)
+ if isinstance(member, property):
+ member = member.fget
+ return member.cache_obj
+
def clear_cache(obj, funcname):
"""Clear a cache handled by the :func:`cached` decorator. If 'x' class has
@cached on its method `foo`, type
@@ -124,17 +131,13 @@ def clear_cache(obj, funcname):
to purge this method's cache on the instance.
"""
- cls = obj.__class__
- member = getattr(cls, funcname)
- if isinstance(member, property):
- member = member.fget
- member.clear(obj)
+ get_cache_impl(obj, funcname).clear(obj)
def copy_cache(obj, funcname, cacheobj):
"""Copy cache for <funcname> from cacheobj to obj."""
- cache = getattr(obj, funcname).cacheattr
+ cacheattr = get_cache_impl(obj, funcname).cacheattr
try:
- setattr(obj, cache, cacheobj.__dict__[cache])
+ setattr(obj, cacheattr, cacheobj.__dict__[cacheattr])
except KeyError:
pass
diff --git a/test/unittest_decorators.py b/test/unittest_decorators.py
index 1f92d55..75261c9 100644
--- a/test/unittest_decorators.py
+++ b/test/unittest_decorators.py
@@ -19,7 +19,7 @@
"""
from logilab.common.testlib import TestCase, unittest_main
-from logilab.common.decorators import monkeypatch, cached, clear_cache
+from logilab.common.decorators import monkeypatch, cached, clear_cache, copy_cache
class DecoratorsTC(TestCase):
@@ -114,6 +114,18 @@ class DecoratorsTC(TestCase):
clear_cache(foo, 'foo')
self.assertFalse(hasattr(foo, '_foo'))
+ def test_copy_cache(self):
+ class Foo(object):
+ @cached(cacheattr=u'_foo')
+ def foo(self, args):
+ """ what's up doc ? """
+ foo = Foo()
+ foo.foo(1)
+ self.assertEqual(foo._foo, {(1,): None})
+ foo2 = Foo()
+ self.assertFalse(hasattr(foo2, '_foo'))
+ copy_cache(foo2, 'foo', foo)
+ self.assertEqual(foo2._foo, {(1,): None})
if __name__ == '__main__':
unittest_main()