summaryrefslogtreecommitdiff
path: root/plac/plac.py
diff options
context:
space:
mode:
Diffstat (limited to 'plac/plac.py')
-rw-r--r--plac/plac.py33
1 files changed, 31 insertions, 2 deletions
diff --git a/plac/plac.py b/plac/plac.py
index d54d8e7..309853a 100644
--- a/plac/plac.py
+++ b/plac/plac.py
@@ -29,7 +29,7 @@ See plac/doc.html for the documentation.
"""
# this module should be kept Python 2.3 compatible
-__version__ = '0.3.1'
+__version__ = '0.4.0'
import re, sys, inspect, argparse
@@ -42,6 +42,10 @@ else:
self.args, self.varargs, self.keywords, self.defaults = \
inspect.getargspec(f)
self.annotations = getattr(f, '__annotations__', {})
+try:
+ set
+except NameError: # Python 2.3
+ from sets import Set as set
def annotations(**ann):
"""
@@ -53,6 +57,8 @@ def annotations(**ann):
args = fas.args
if fas.varargs:
args.append(fas.varargs)
+ if fas.keywords:
+ args.append(fas.keywords)
for argname in ann:
if argname not in args:
raise NameError(
@@ -139,8 +145,24 @@ def parser_from(func):
a = Annotation.from_(f.annotations.get(f.varargs, ()))
p.add_argument(f.varargs, nargs='*', help=a.help, default=[],
type=a.type, metavar=a.metavar)
+ if f.keywords:
+ a = Annotation.from_(f.annotations.get(f.keywords, ()))
+ p.add_argument(f.keywords, nargs='*', help=a.help, default={},
+ type=a.type, metavar=a.metavar)
return p
+def extract_kwargs(args):
+ arglist = []
+ kwargs = {}
+ for arg in args:
+ match = re.match(r'([a-zA-Z_]\w*)=', arg)
+ if match:
+ name = match.group(1)
+ kwargs[name] = arg[len(name)+1:]
+ else:
+ arglist.append(arg)
+ return arglist, kwargs
+
def call(func, arglist=sys.argv[1:]):
"""
Parse the given arglist by using an argparser inferred from the
@@ -149,7 +171,14 @@ def call(func, arglist=sys.argv[1:]):
provide a custom parse_annotation hook or replace the default one.
"""
p = parser_from(func)
+ if p.argspec.keywords:
+ arglist, kwargs = extract_kwargs(arglist)
+ else:
+ kwargs = {}
argdict = vars(p.parse_args(arglist))
args = [argdict[a] for a in p.argspec.args]
varargs = argdict.get(p.argspec.varargs, [])
- func(*(args + varargs))
+ collision = set(p.argspec.args) & set(kwargs)
+ if collision:
+ p.error('colliding keyword arguments: %s' % ' '.join(collision))
+ return func(*(args + varargs), **kwargs)