summaryrefslogtreecommitdiff
path: root/plac/plac.py
diff options
context:
space:
mode:
Diffstat (limited to 'plac/plac.py')
-rw-r--r--plac/plac.py47
1 files changed, 33 insertions, 14 deletions
diff --git a/plac/plac.py b/plac/plac.py
index 309853a..21c4341 100644
--- a/plac/plac.py
+++ b/plac/plac.py
@@ -29,7 +29,7 @@ See plac/doc.html for the documentation.
"""
# this module should be kept Python 2.3 compatible
-__version__ = '0.4.0'
+__version__ = '0.4.1'
import re, sys, inspect, argparse
@@ -39,7 +39,7 @@ else:
class getfullargspec(object):
"A quick and dirty replacement for getfullargspec for Python 2.X"
def __init__(self, f):
- self.args, self.varargs, self.keywords, self.defaults = \
+ self.args, self.varargs, self.varkw, self.defaults = \
inspect.getargspec(f)
self.annotations = getattr(f, '__annotations__', {})
try:
@@ -57,8 +57,8 @@ def annotations(**ann):
args = fas.args
if fas.varargs:
args.append(fas.varargs)
- if fas.keywords:
- args.append(fas.keywords)
+ if fas.varkw:
+ args.append(fas.varkw)
for argname in ann:
if argname not in args:
raise NameError(
@@ -102,6 +102,20 @@ NONE = object() # sentinel use to signal the absence of a default
valid_attrs = getfullargspec(argparse.ArgumentParser.__init__).args[1:]
+class PlacHelpFormatter(argparse.HelpFormatter):
+ "Custom HelpFormatter which does not displau the default value twice"
+
+ def _format_action_invocation(self, action):
+ if not action.option_strings:
+ return self._metavar_formatter(action, action.dest)(1)[0]
+ long_short = tuple(action.option_strings)
+ if action.nargs == 0: # format is -s, --long
+ return '%s, %s' % long_short
+ else: # format is -s, --long ARGS
+ default = action.dest.upper()
+ args_string = self._format_args(action, default)
+ return '%s, %s %s' % (long_short + (args_string,))
+
def parser_from(func):
"""
Extract the arguments from the attributes of the passed function and
@@ -109,11 +123,12 @@ def parser_from(func):
"""
short_prefix = getattr(func, 'short_prefix', '-')
long_prefix = getattr(func, 'long_prefix', '--')
- attrs = {'description': func.__doc__}
+ attrs = dict(description=func.__doc__,
+ formatter_class=PlacHelpFormatter)
for n, v in vars(func).items():
if n in valid_attrs:
attrs[n] = v
- p = argparse.ArgumentParser(**attrs)
+ p = func.parser = argparse.ArgumentParser(**attrs)
f = p.argspec = getfullargspec(func)
defaults = f.defaults or ()
n_args = len(f.args)
@@ -121,10 +136,11 @@ def parser_from(func):
alldefaults = (NONE,) * (n_args - n_defaults) + defaults
for name, default in zip(f.args, alldefaults):
a = Annotation.from_(f.annotations.get(name, ()))
+ metavar = a.metavar
if default is NONE:
- dflt, metavar = None, a.metavar
+ dflt = None
else:
- dflt, metavar = default, a.metavar or str(default)
+ dflt = default
if a.kind in ('option', 'flag'):
short = short_prefix + (a.abbrev or name[0])
long = long_prefix + name
@@ -135,19 +151,22 @@ def parser_from(func):
p.add_argument(name, nargs='?', help=a.help, default=dflt,
type=a.type, choices=a.choices, metavar=metavar)
if a.kind == 'option':
+ if default is not NONE:
+ metavar = metavar or str(default)
p.add_argument(short, long, help=a.help, default=dflt,
type=a.type, choices=a.choices, metavar=metavar)
elif a.kind == 'flag':
- if default is not NONE:
- raise TypeError('Flag %r does not want a default' % name)
+ if default is not NONE and default is not False:
+ raise TypeError('Flag %r wants default False, got %r' %
+ (name, default))
p.add_argument(short, long, action='store_true', help=a.help)
if f.varargs:
a = Annotation.from_(f.annotations.get(f.varargs, ()))
p.add_argument(f.varargs, nargs='*', help=a.help, default=[],
type=a.type, metavar=a.metavar)
- if f.keywords:
- a = Annotation.from_(f.annotations.get(f.keywords, ()))
- p.add_argument(f.keywords, nargs='*', help=a.help, default={},
+ if f.varkw:
+ a = Annotation.from_(f.annotations.get(f.varkw, ()))
+ p.add_argument(f.varkw, nargs='*', help=a.help, default={},
type=a.type, metavar=a.metavar)
return p
@@ -171,7 +190,7 @@ def call(func, arglist=sys.argv[1:]):
provide a custom parse_annotation hook or replace the default one.
"""
p = parser_from(func)
- if p.argspec.keywords:
+ if p.argspec.varkw:
arglist, kwargs = extract_kwargs(arglist)
else:
kwargs = {}