diff options
author | Aurelien Campeas <aurelien.campeas@logilab.fr> | 2010-11-30 12:19:25 +0100 |
---|---|---|
committer | Aurelien Campeas <aurelien.campeas@logilab.fr> | 2010-11-30 12:19:25 +0100 |
commit | 419ccc94c73f7143ab7cca1fd03d70787428d097 (patch) | |
tree | 30dfff17477c2d3e1969cd27229991db9782b952 | |
parent | d76f9681147d4b4f676ed044bd5165aec180642b (diff) | |
download | logilab-common-419ccc94c73f7143ab7cca1fd03d70787428d097.tar.gz |
[decorators] prevent caching of decorator functions
-rw-r--r-- | decorators.py | 10 | ||||
-rw-r--r-- | test/unittest_decorators.py | 5 |
2 files changed, 12 insertions, 3 deletions
diff --git a/decorators.py b/decorators.py index 2218087..9942e5a 100644 --- a/decorators.py +++ b/decorators.py @@ -18,14 +18,18 @@ """ A few useful function/method decorators. """ __docformat__ = "restructuredtext en" -from types import MethodType +import types from time import clock, time import sys, re # XXX rewrite so we can use the decorator syntax when keyarg has to be specified +def _is_generator_function(callableobj): + return callableobj.func_code.co_flags & 0x20 + def cached(callableobj, keyarg=None): """Simple decorator to cache result of method call.""" + assert not _is_generator_function(callableobj), 'cannot cache generator function: %s' % callableobj if callableobj.func_code.co_argcount == 1 or keyarg == 0: def cache_wrapper1(self, *args): @@ -140,8 +144,8 @@ class iclassmethod(object): self.func = func def __get__(self, instance, objtype): if instance is None: - return MethodType(self.func, objtype, objtype.__class__) - return MethodType(self.func, instance, objtype) + return types.MethodType(self.func, objtype, objtype.__class__) + return types.MethodType(self.func, instance, objtype) def __set__(self, instance, value): raise AttributeError("can't set attribute") diff --git a/test/unittest_decorators.py b/test/unittest_decorators.py index 0217067..a016027 100644 --- a/test/unittest_decorators.py +++ b/test/unittest_decorators.py @@ -43,6 +43,11 @@ class DecoratorsTC(TestCase): inst = MyClass() self.assertEqual(inst.foo(4), 16) + def test_cannot_cache_generator(self): + def foo(): + yield 42 + self.assertRaises(AssertionError, cached, foo) + def test_cached_preserves_docstrings_and_name(self): class Foo(object): @cached |