summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAurelien Campeas <aurelien.campeas@logilab.fr>2010-11-30 12:19:25 +0100
committerAurelien Campeas <aurelien.campeas@logilab.fr>2010-11-30 12:19:25 +0100
commit419ccc94c73f7143ab7cca1fd03d70787428d097 (patch)
tree30dfff17477c2d3e1969cd27229991db9782b952
parentd76f9681147d4b4f676ed044bd5165aec180642b (diff)
downloadlogilab-common-419ccc94c73f7143ab7cca1fd03d70787428d097.tar.gz
[decorators] prevent caching of decorator functions
-rw-r--r--decorators.py10
-rw-r--r--test/unittest_decorators.py5
2 files changed, 12 insertions, 3 deletions
diff --git a/decorators.py b/decorators.py
index 2218087..9942e5a 100644
--- a/decorators.py
+++ b/decorators.py
@@ -18,14 +18,18 @@
""" A few useful function/method decorators. """
__docformat__ = "restructuredtext en"
-from types import MethodType
+import types
from time import clock, time
import sys, re
# XXX rewrite so we can use the decorator syntax when keyarg has to be specified
+def _is_generator_function(callableobj):
+ return callableobj.func_code.co_flags & 0x20
+
def cached(callableobj, keyarg=None):
"""Simple decorator to cache result of method call."""
+ assert not _is_generator_function(callableobj), 'cannot cache generator function: %s' % callableobj
if callableobj.func_code.co_argcount == 1 or keyarg == 0:
def cache_wrapper1(self, *args):
@@ -140,8 +144,8 @@ class iclassmethod(object):
self.func = func
def __get__(self, instance, objtype):
if instance is None:
- return MethodType(self.func, objtype, objtype.__class__)
- return MethodType(self.func, instance, objtype)
+ return types.MethodType(self.func, objtype, objtype.__class__)
+ return types.MethodType(self.func, instance, objtype)
def __set__(self, instance, value):
raise AttributeError("can't set attribute")
diff --git a/test/unittest_decorators.py b/test/unittest_decorators.py
index 0217067..a016027 100644
--- a/test/unittest_decorators.py
+++ b/test/unittest_decorators.py
@@ -43,6 +43,11 @@ class DecoratorsTC(TestCase):
inst = MyClass()
self.assertEqual(inst.foo(4), 16)
+ def test_cannot_cache_generator(self):
+ def foo():
+ yield 42
+ self.assertRaises(AssertionError, cached, foo)
+
def test_cached_preserves_docstrings_and_name(self):
class Foo(object):
@cached