diff options
author | ubershmekel <devnull@localhost> | 2009-11-02 22:27:47 +0000 |
---|---|---|
committer | ubershmekel <devnull@localhost> | 2009-11-02 22:27:47 +0000 |
commit | 921d934a0d61fd2863eb939aa872baca4156a85b (patch) | |
tree | 24c6553c3bc2d64727931b4c59ea74e779114358 | |
parent | 06ac90e1a75d81fc4a0003e6a4bee92d9c203b55 (diff) | |
download | argparse-921d934a0d61fd2863eb939aa872baca4156a85b.tar.gz |
argparse.run now works with bound methods (instance methods and class methods).
-rw-r--r-- | argparse.py | 9 | ||||
-rw-r--r-- | test/test_argparse.py | 12 |
2 files changed, 21 insertions, 0 deletions
diff --git a/argparse.py b/argparse.py index 9ebf845..c1459fb 100644 --- a/argparse.py +++ b/argparse.py @@ -95,6 +95,7 @@ import re as _re import sys as _sys import textwrap as _textwrap import inspect as _inspect +import types as _types from gettext import gettext as _ @@ -2414,6 +2415,14 @@ def _getfunctionspec(function): else: arg_names, varargs, varkw, defaults = _inspect.getargspec(function) kwonlyargs, kwonlydefaults, annotations = [], {}, {} + + # A fix for class-methods and instance-methods is to remove the first + # argument name (which is self or cls). + # In case of (*args, **kwargs) we don't intervene. + if isinstance(function, _types.MethodType): + if len(arg_names) > 0: + arg_names.pop(0) + return (arg_names, varargs, varkw, defaults, kwonlyargs, kwonlydefaults, annotations) diff --git a/test/test_argparse.py b/test/test_argparse.py index 462b065..8dc3c25 100644 --- a/test/test_argparse.py +++ b/test/test_argparse.py @@ -4077,10 +4077,22 @@ class TestAddFunction(TestCase): result = argparse.run(func, proc, args="proc monster".split()) self.failUnlessEqual(('monster', '...'), result) + def test_runner_on_bound_functions(self): + + class Bacon: + def proc(self, flying, spaghetti='...'): + return flying, spaghetti + + meat = Bacon() + result = argparse.run(meat.proc, args="monster".split()) + self.failUnlessEqual(('monster', '...'), result) + # only compile and test annotations if this is Python >= 3 if sys.version_info[0] >= 3: def test_annotations(self): + # the function with annotations is a string to avoid syntax errors + # in python 2.x func_source = textwrap.dedent(''' def func(foo:bool, bar:int, spam:int=101): return foo, bar |