summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAdrien Di Mascio <Adrien.DiMascio@logilab.fr>2009-01-30 14:38:30 +0100
committerAdrien Di Mascio <Adrien.DiMascio@logilab.fr>2009-01-30 14:38:30 +0100
commite7b0eb487b77f36c776b477f7ce1497792b5b8fa (patch)
tree09ef8ce97cfbd7770656d97e2622ed09dbdce106
parented183004a65ee342d125889fd361873e935d59bc (diff)
downloadlogilab-common-e7b0eb487b77f36c776b477f7ce1497792b5b8fa.tar.gz
[decorators] new monkeypatch decorator (simple class extension)
-rw-r--r--decorators.py24
-rw-r--r--test/unittest_decorators.py31
2 files changed, 55 insertions, 0 deletions
diff --git a/decorators.py b/decorators.py
index 7f1ff29..001563e 100644
--- a/decorators.py
+++ b/decorators.py
@@ -143,3 +143,27 @@ def locked(acquire, release):
release(self)
return wrapper
return decorator
+
+
+def monkeypatch(klass, methodname=None):
+ """Decorator extending class with the decorated function
+ >>> class A:
+ ... pass
+ >>> @monkeypatch(A)
+ ... def meth(self):
+ ... return 12
+ ...
+ >>> a = A()
+ >>> a.meth()
+ 12
+ >>> @monkeypatch(A, 'foo')
+ ... def meth(self):
+ ... return 12
+ ...
+ >>> a.foo()
+ 12
+ """
+ def decorator(func):
+ setattr(klass, methodname or func.__name__, func)
+ return func
+ return decorator
diff --git a/test/unittest_decorators.py b/test/unittest_decorators.py
new file mode 100644
index 0000000..b2fd8b5
--- /dev/null
+++ b/test/unittest_decorators.py
@@ -0,0 +1,31 @@
+"""unit tests for the decorators module
+"""
+
+from logilab.common.testlib import TestCase, unittest_main
+from logilab.common.decorators import monkeypatch
+
+class DecoratorsTC(TestCase):
+
+ def test_monkeypatch_with_same_name(self):
+ class MyClass: pass
+ @monkeypatch(MyClass)
+ def meth1(self):
+ return 12
+ self.assertEquals([attr for attr in dir(MyClass) if attr[:2] != '__'],
+ ['meth1'])
+ inst = MyClass()
+ self.assertEquals(inst.meth1(), 12)
+
+ def test_monkeypatch_with_custom_name(self):
+ class MyClass: pass
+ @monkeypatch(MyClass, 'foo')
+ def meth2(self, param):
+ return param + 12
+ self.assertEquals([attr for attr in dir(MyClass) if attr[:2] != '__'],
+ ['foo'])
+ inst = MyClass()
+ self.assertEquals(inst.foo(4), 16)
+
+
+if __name__ == '__main__':
+ unittest_main()