diff options
-rw-r--r-- | decorators.py | 19 | ||||
-rw-r--r-- | test/unittest_decorators.py | 14 |
2 files changed, 24 insertions, 9 deletions
diff --git a/decorators.py b/decorators.py index 6b73440..39144d1 100644 --- a/decorators.py +++ b/decorators.py @@ -62,7 +62,7 @@ class _SingleValueCache(object): def closure(self): def wrapped(*args, **kwargs): return self.__call__(*args, **kwargs) - wrapped.clear = self.clear + wrapped.cache_obj = self try: wrapped.__doc__ = self.callable.__doc__ wrapped.__name__ = self.callable.__name__ @@ -116,6 +116,13 @@ def cached(callableobj=None, keyarg=None, **kwargs): else: return decorator(callableobj) +def get_cache_impl(obj, funcname): + cls = obj.__class__ + member = getattr(cls, funcname) + if isinstance(member, property): + member = member.fget + return member.cache_obj + def clear_cache(obj, funcname): """Clear a cache handled by the :func:`cached` decorator. If 'x' class has @cached on its method `foo`, type @@ -124,17 +131,13 @@ def clear_cache(obj, funcname): to purge this method's cache on the instance. """ - cls = obj.__class__ - member = getattr(cls, funcname) - if isinstance(member, property): - member = member.fget - member.clear(obj) + get_cache_impl(obj, funcname).clear(obj) def copy_cache(obj, funcname, cacheobj): """Copy cache for <funcname> from cacheobj to obj.""" - cache = getattr(obj, funcname).cacheattr + cacheattr = get_cache_impl(obj, funcname).cacheattr try: - setattr(obj, cache, cacheobj.__dict__[cache]) + setattr(obj, cacheattr, cacheobj.__dict__[cacheattr]) except KeyError: pass diff --git a/test/unittest_decorators.py b/test/unittest_decorators.py index 1f92d55..75261c9 100644 --- a/test/unittest_decorators.py +++ b/test/unittest_decorators.py @@ -19,7 +19,7 @@ """ from logilab.common.testlib import TestCase, unittest_main -from logilab.common.decorators import monkeypatch, cached, clear_cache +from logilab.common.decorators import monkeypatch, cached, clear_cache, copy_cache class DecoratorsTC(TestCase): @@ -114,6 +114,18 @@ class DecoratorsTC(TestCase): clear_cache(foo, 'foo') self.assertFalse(hasattr(foo, '_foo')) + def test_copy_cache(self): + class Foo(object): + @cached(cacheattr=u'_foo') + def foo(self, args): + """ what's up doc ? """ + foo = Foo() + foo.foo(1) + self.assertEqual(foo._foo, {(1,): None}) + foo2 = Foo() + self.assertFalse(hasattr(foo2, '_foo')) + copy_cache(foo2, 'foo', foo) + self.assertEqual(foo2._foo, {(1,): None}) if __name__ == '__main__': unittest_main() |