summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorClaudiu Popa <pcmanticore@gmail.com>2015-10-07 17:22:48 +0300
committerClaudiu Popa <pcmanticore@gmail.com>2015-10-07 17:22:48 +0300
commit053f20781a8f01c57a4ad36fd6b3f8d6dbd75608 (patch)
tree06b4a1e42386500c3d3d74cecc6484d109ee661e
parentd35aad3f97ce329742d554a5960f8494d7727154 (diff)
downloadastroid-053f20781a8f01c57a4ad36fd6b3f8d6dbd75608.tar.gz
Change arguments.ArgumentsInference to arguments.CallSite
This new class can be used to obtain the already unpacked arguments and keyword arguments that a call site uses, which is especially useful when some of arguments are packed into Starred nodes.
-rw-r--r--astroid/arguments.py90
-rw-r--r--astroid/protocols.py11
-rw-r--r--astroid/tests/unittest_inference.py87
3 files changed, 155 insertions, 33 deletions
diff --git a/astroid/arguments.py b/astroid/arguments.py
index 2642620..bdce172 100644
--- a/astroid/arguments.py
+++ b/astroid/arguments.py
@@ -25,29 +25,61 @@ from astroid import util
import six
-class ArgumentInference(object):
+class CallSite(object):
"""Class for understanding arguments passed into a call site
- It needs the arguments and the keyword arguments that were
- passed into a given call site.
+ It needs a call context, which contains the arguments and the
+ keyword arguments that were passed into a given call site.
In order to infer what an argument represents, call
:meth:`infer_argument` with the corresponding function node
and the argument name.
"""
- def __init__(self, args, keywords):
- self._args = self._unpack_args(args)
- self._keywords = self._unpack_keywords(keywords)
- args = [arg for arg in self._args if arg is not util.YES]
- keywords = {key: value for key, value in self._keywords.items()
- if value is not util.YES}
- self._args_failure = len(args) != len(self._args)
- self._kwargs_failure = len(keywords) != len(self._keywords)
- self._args = args
- self._keywords = keywords
-
- @staticmethod
- def _unpack_keywords(keywords):
+ def __init__(self, callcontext):
+ args = callcontext.args
+ keywords = callcontext.keywords
+
+ self._duplicated_kwargs = {}
+ self._unpacked_args = self._unpack_args(args)
+ self._unpacked_kwargs = self._unpack_keywords(keywords)
+
+ self.positional_arguments = [
+ arg for arg in self._unpacked_args
+ if arg is not util.YES
+ ]
+ self.keyword_arguments = {
+ key: value for key, value in self._unpacked_kwargs.items()
+ if value is not util.YES
+ }
+
+ @classmethod
+ def from_call(cls, call_node):
+ """Get a CallSite object from the given Call node."""
+ callcontext = contextmod.CallContext(call_node.args,
+ call_node.keywords)
+ return cls(callcontext)
+
+ def has_invalid_arguments(self):
+ """Check if in the current CallSite were passed *invalid* arguments
+
+ This can mean multiple things. For instance, if an unpacking
+ of an invalid object was passed, then this method will return True.
+ Other cases can be when the arguments can't be inferred by astroid,
+ for example, by passing objects which aren't known statically.
+ """
+ return len(self.positional_arguments) != len(self._unpacked_args)
+
+ def has_invalid_keywords(self):
+ """Check if in the current CallSite were passed *invalid* keyword arguments
+
+ For instance, unpacking a dictionary with integer keys is invalid
+ (**{1:2}), because the keys must be strings, which will make this
+ method to return True. Other cases where this might return True if
+ objects which can't be inferred were passed.
+ """
+ return len(self.keyword_arguments) != len(self._unpacked_kwargs)
+
+ def _unpack_keywords(self, keywords):
values = {}
context = contextmod.InferenceContext()
for name, value in keywords:
@@ -78,7 +110,8 @@ class ArgumentInference(object):
continue
if dict_key.value in values:
# The name is already in the dictionary
- values[name] = util.YES
+ values[dict_key.value] = util.YES
+ self._duplicated_kwargs[dict_key.value] = True
continue
values[dict_key.value] = dict_value
else:
@@ -110,23 +143,28 @@ class ArgumentInference(object):
def infer_argument(self, funcnode, name, context):
"""infer a function argument value according to the call context"""
+ if name in self._duplicated_kwargs:
+ raise exceptions.InferenceError(name)
+
# Look into the keywords first, maybe it's already there.
try:
- return self._keywords[name].infer(context)
+ return self.keyword_arguments[name].infer(context)
except KeyError:
pass
# Too many arguments given and no variable arguments.
- if len(self._args) > len(funcnode.args.args):
+ if len(self.positional_arguments) > len(funcnode.args.args):
if not funcnode.args.vararg:
raise exceptions.InferenceError(name)
- positional = self._args[:len(funcnode.args.args)]
- vararg = self._args[len(funcnode.args.args):]
+ positional = self.positional_arguments[:len(funcnode.args.args)]
+ vararg = self.positional_arguments[len(funcnode.args.args):]
argindex = funcnode.args.find_argname(name)[0]
kwonlyargs = set(arg.name for arg in funcnode.args.kwonlyargs)
- kwargs = {key: value for key, value in self._keywords.items()
- if key not in kwonlyargs}
+ kwargs = {
+ key: value for key, value in self.keyword_arguments.items()
+ if key not in kwonlyargs
+ }
# If there are too few positionals compared to
# what the function expects to receive, check to see
# if the missing positional arguments were passed
@@ -159,14 +197,14 @@ class ArgumentInference(object):
argindex -= 1
# 2. search arg index
try:
- return self._args[argindex].infer(context)
+ return self.positional_arguments[argindex].infer(context)
except IndexError:
pass
if funcnode.args.kwarg == name:
# It wants all the keywords that were passed into
# the call site.
- if self._kwargs_failure:
+ if self.has_invalid_keywords():
raise exceptions.InferenceError
kwarg = nodes.Dict(lineno=funcnode.args.lineno,
col_offset=funcnode.args.col_offset,
@@ -177,7 +215,7 @@ class ArgumentInference(object):
elif funcnode.args.vararg == name:
# It wants all the args that were passed into
# the call site.
- if self._args_failure:
+ if self.has_invalid_arguments():
raise exceptions.InferenceError
args = nodes.Tuple(lineno=funcnode.args.lineno,
col_offset=funcnode.args.col_offset,
diff --git a/astroid/protocols.py b/astroid/protocols.py
index 24a85fa..2780573 100644
--- a/astroid/protocols.py
+++ b/astroid/protocols.py
@@ -282,9 +282,8 @@ def _arguments_infer_argname(self, name, context):
return
if context and context.callcontext:
- inferator = arguments.ArgumentInference(context.callcontext.args,
- context.callcontext.keywords)
- for value in inferator.infer_argument(self.parent, name, context):
+ call_site = arguments.CallSite(context.callcontext)
+ for value in call_site.infer_argument(self.parent, name, context):
yield value
return
@@ -316,10 +315,8 @@ def arguments_assigned_stmts(self, node, context, asspath=None):
callcontext = context.callcontext
context = contextmod.copy_context(context)
context.callcontext = None
- inferator = arguments.ArgumentInference(
- callcontext.args,
- callcontext.keywords)
- return inferator.infer_argument(self.parent, node.name, context)
+ args = arguments.CallSite(callcontext)
+ return args.infer_argument(self.parent, node.name, context)
return _arguments_infer_argname(self, node.name, context)
nodes.Arguments.assigned_stmts = arguments_assigned_stmts
diff --git a/astroid/tests/unittest_inference.py b/astroid/tests/unittest_inference.py
index 8b6d143..4423312 100644
--- a/astroid/tests/unittest_inference.py
+++ b/astroid/tests/unittest_inference.py
@@ -30,6 +30,8 @@ from astroid.builder import parse
from astroid.inference import infer_end as inference_infer_end
from astroid.bases import Instance, BoundMethod, UnboundMethod,\
path_wrapper, BUILTINS
+from astroid import arguments
+from astroid import context
from astroid import helpers
from astroid import objects
from astroid import test_utils
@@ -3588,6 +3590,19 @@ class ArgumentsTest(unittest.TestCase):
value = self._get_dict_value(inferred)
self.assertEqual(value, expected_value)
+ def test_kwargs_are_overriden(self):
+ ast_nodes = test_utils.extract_node('''
+ def test(f):
+ return f
+ test(f=23, **{'f': 34}) #@
+ def test(f=None):
+ return f
+ test(f=23, **{'f':23}) #@
+ ''')
+ for ast_node in ast_nodes:
+ inferred = next(ast_node.infer())
+ self.assertEqual(inferred, util.YES)
+
def test_fail_to_infer_args(self):
ast_nodes = test_utils.extract_node('''
def test(a, **kwargs): return a
@@ -3681,5 +3696,77 @@ class SliceTest(unittest.TestCase):
self.assertEqual(inferred.name, 'slice')
+class CallSiteTest(unittest.TestCase):
+
+ @staticmethod
+ def _call_site_from_call(call):
+ return arguments.CallSite.from_call(call)
+
+ def _test_call_site_pair(self, code, expected_args, expected_keywords):
+ ast_node = test_utils.extract_node(code)
+ call_site = self._call_site_from_call(ast_node)
+ self.assertEqual(len(call_site.positional_arguments), len(expected_args))
+ self.assertEqual([arg.value for arg in call_site.positional_arguments],
+ expected_args)
+ self.assertEqual(len(call_site.keyword_arguments), len(expected_keywords))
+ for keyword, value in expected_keywords.items():
+ self.assertIn(keyword, call_site.keyword_arguments)
+ self.assertEqual(call_site.keyword_arguments[keyword].value, value)
+
+ def _test_call_site(self, pairs):
+ for pair in pairs:
+ self._test_call_site_pair(*pair)
+
+ @test_utils.require_version('3.5')
+ def test_call_site_starred_args(self):
+ pairs = [
+ (
+ "f(*(1, 2), *(2, 3), *(3, 4), **{'a':1}, **{'b': 2})",
+ [1, 2, 2, 3, 3, 4],
+ {'a': 1, 'b': 2}
+ ),
+ (
+ "f(1, 2, *(3, 4), 5, *(6, 7), f=24, **{'c':3})",
+ [1, 2, 3, 4, 5, 6, 7],
+ {'f':24, 'c': 3},
+ ),
+ # Too many fs passed into.
+ (
+ "f(f=24, **{'f':24})", [], {},
+ ),
+ ]
+ self._test_call_site(pairs)
+
+ def test_call_site(self):
+ pairs = [
+ (
+ "f(1, 2)", [1, 2], {}
+ ),
+ (
+ "f(1, 2, *(1, 2))", [1, 2, 1, 2], {}
+ ),
+ (
+ "f(a=1, b=2, c=3)", [], {'a':1, 'b':2, 'c':3}
+ )
+ ]
+ self._test_call_site(pairs)
+
+ def _test_call_site_valid_arguments(self, values, invalid):
+ for value in values:
+ ast_node = test_utils.extract_node(value)
+ call_site = self._call_site_from_call(ast_node)
+ self.assertEqual(call_site.has_invalid_arguments(), invalid)
+
+ def test_call_site_valid_arguments(self):
+ values = [
+ "f(*lala)", "f(*1)", "f(*object)",
+ ]
+ self._test_call_site_valid_arguments(values, invalid=True)
+ values = [
+ "f()", "f(*(1, ))", "f(1, 2, *(2, 3))",
+ ]
+ self._test_call_site_valid_arguments(values, invalid=False)
+
+
if __name__ == '__main__':
unittest.main()