diff options
author | Claudiu Popa <pcmanticore@gmail.com> | 2015-10-07 17:22:48 +0300 |
---|---|---|
committer | Claudiu Popa <pcmanticore@gmail.com> | 2015-10-07 17:22:48 +0300 |
commit | 053f20781a8f01c57a4ad36fd6b3f8d6dbd75608 (patch) | |
tree | 06b4a1e42386500c3d3d74cecc6484d109ee661e | |
parent | d35aad3f97ce329742d554a5960f8494d7727154 (diff) | |
download | astroid-053f20781a8f01c57a4ad36fd6b3f8d6dbd75608.tar.gz |
Change arguments.ArgumentsInference to arguments.CallSite
This new class can be used to obtain the already unpacked arguments
and keyword arguments that a call site uses, which is especially
useful when some of arguments are packed into Starred nodes.
-rw-r--r-- | astroid/arguments.py | 90 | ||||
-rw-r--r-- | astroid/protocols.py | 11 | ||||
-rw-r--r-- | astroid/tests/unittest_inference.py | 87 |
3 files changed, 155 insertions, 33 deletions
diff --git a/astroid/arguments.py b/astroid/arguments.py index 2642620..bdce172 100644 --- a/astroid/arguments.py +++ b/astroid/arguments.py @@ -25,29 +25,61 @@ from astroid import util import six -class ArgumentInference(object): +class CallSite(object): """Class for understanding arguments passed into a call site - It needs the arguments and the keyword arguments that were - passed into a given call site. + It needs a call context, which contains the arguments and the + keyword arguments that were passed into a given call site. In order to infer what an argument represents, call :meth:`infer_argument` with the corresponding function node and the argument name. """ - def __init__(self, args, keywords): - self._args = self._unpack_args(args) - self._keywords = self._unpack_keywords(keywords) - args = [arg for arg in self._args if arg is not util.YES] - keywords = {key: value for key, value in self._keywords.items() - if value is not util.YES} - self._args_failure = len(args) != len(self._args) - self._kwargs_failure = len(keywords) != len(self._keywords) - self._args = args - self._keywords = keywords - - @staticmethod - def _unpack_keywords(keywords): + def __init__(self, callcontext): + args = callcontext.args + keywords = callcontext.keywords + + self._duplicated_kwargs = {} + self._unpacked_args = self._unpack_args(args) + self._unpacked_kwargs = self._unpack_keywords(keywords) + + self.positional_arguments = [ + arg for arg in self._unpacked_args + if arg is not util.YES + ] + self.keyword_arguments = { + key: value for key, value in self._unpacked_kwargs.items() + if value is not util.YES + } + + @classmethod + def from_call(cls, call_node): + """Get a CallSite object from the given Call node.""" + callcontext = contextmod.CallContext(call_node.args, + call_node.keywords) + return cls(callcontext) + + def has_invalid_arguments(self): + """Check if in the current CallSite were passed *invalid* arguments + + This can mean multiple things. For instance, if an unpacking + of an invalid object was passed, then this method will return True. + Other cases can be when the arguments can't be inferred by astroid, + for example, by passing objects which aren't known statically. + """ + return len(self.positional_arguments) != len(self._unpacked_args) + + def has_invalid_keywords(self): + """Check if in the current CallSite were passed *invalid* keyword arguments + + For instance, unpacking a dictionary with integer keys is invalid + (**{1:2}), because the keys must be strings, which will make this + method to return True. Other cases where this might return True if + objects which can't be inferred were passed. + """ + return len(self.keyword_arguments) != len(self._unpacked_kwargs) + + def _unpack_keywords(self, keywords): values = {} context = contextmod.InferenceContext() for name, value in keywords: @@ -78,7 +110,8 @@ class ArgumentInference(object): continue if dict_key.value in values: # The name is already in the dictionary - values[name] = util.YES + values[dict_key.value] = util.YES + self._duplicated_kwargs[dict_key.value] = True continue values[dict_key.value] = dict_value else: @@ -110,23 +143,28 @@ class ArgumentInference(object): def infer_argument(self, funcnode, name, context): """infer a function argument value according to the call context""" + if name in self._duplicated_kwargs: + raise exceptions.InferenceError(name) + # Look into the keywords first, maybe it's already there. try: - return self._keywords[name].infer(context) + return self.keyword_arguments[name].infer(context) except KeyError: pass # Too many arguments given and no variable arguments. - if len(self._args) > len(funcnode.args.args): + if len(self.positional_arguments) > len(funcnode.args.args): if not funcnode.args.vararg: raise exceptions.InferenceError(name) - positional = self._args[:len(funcnode.args.args)] - vararg = self._args[len(funcnode.args.args):] + positional = self.positional_arguments[:len(funcnode.args.args)] + vararg = self.positional_arguments[len(funcnode.args.args):] argindex = funcnode.args.find_argname(name)[0] kwonlyargs = set(arg.name for arg in funcnode.args.kwonlyargs) - kwargs = {key: value for key, value in self._keywords.items() - if key not in kwonlyargs} + kwargs = { + key: value for key, value in self.keyword_arguments.items() + if key not in kwonlyargs + } # If there are too few positionals compared to # what the function expects to receive, check to see # if the missing positional arguments were passed @@ -159,14 +197,14 @@ class ArgumentInference(object): argindex -= 1 # 2. search arg index try: - return self._args[argindex].infer(context) + return self.positional_arguments[argindex].infer(context) except IndexError: pass if funcnode.args.kwarg == name: # It wants all the keywords that were passed into # the call site. - if self._kwargs_failure: + if self.has_invalid_keywords(): raise exceptions.InferenceError kwarg = nodes.Dict(lineno=funcnode.args.lineno, col_offset=funcnode.args.col_offset, @@ -177,7 +215,7 @@ class ArgumentInference(object): elif funcnode.args.vararg == name: # It wants all the args that were passed into # the call site. - if self._args_failure: + if self.has_invalid_arguments(): raise exceptions.InferenceError args = nodes.Tuple(lineno=funcnode.args.lineno, col_offset=funcnode.args.col_offset, diff --git a/astroid/protocols.py b/astroid/protocols.py index 24a85fa..2780573 100644 --- a/astroid/protocols.py +++ b/astroid/protocols.py @@ -282,9 +282,8 @@ def _arguments_infer_argname(self, name, context): return if context and context.callcontext: - inferator = arguments.ArgumentInference(context.callcontext.args, - context.callcontext.keywords) - for value in inferator.infer_argument(self.parent, name, context): + call_site = arguments.CallSite(context.callcontext) + for value in call_site.infer_argument(self.parent, name, context): yield value return @@ -316,10 +315,8 @@ def arguments_assigned_stmts(self, node, context, asspath=None): callcontext = context.callcontext context = contextmod.copy_context(context) context.callcontext = None - inferator = arguments.ArgumentInference( - callcontext.args, - callcontext.keywords) - return inferator.infer_argument(self.parent, node.name, context) + args = arguments.CallSite(callcontext) + return args.infer_argument(self.parent, node.name, context) return _arguments_infer_argname(self, node.name, context) nodes.Arguments.assigned_stmts = arguments_assigned_stmts diff --git a/astroid/tests/unittest_inference.py b/astroid/tests/unittest_inference.py index 8b6d143..4423312 100644 --- a/astroid/tests/unittest_inference.py +++ b/astroid/tests/unittest_inference.py @@ -30,6 +30,8 @@ from astroid.builder import parse from astroid.inference import infer_end as inference_infer_end from astroid.bases import Instance, BoundMethod, UnboundMethod,\ path_wrapper, BUILTINS +from astroid import arguments +from astroid import context from astroid import helpers from astroid import objects from astroid import test_utils @@ -3588,6 +3590,19 @@ class ArgumentsTest(unittest.TestCase): value = self._get_dict_value(inferred) self.assertEqual(value, expected_value) + def test_kwargs_are_overriden(self): + ast_nodes = test_utils.extract_node(''' + def test(f): + return f + test(f=23, **{'f': 34}) #@ + def test(f=None): + return f + test(f=23, **{'f':23}) #@ + ''') + for ast_node in ast_nodes: + inferred = next(ast_node.infer()) + self.assertEqual(inferred, util.YES) + def test_fail_to_infer_args(self): ast_nodes = test_utils.extract_node(''' def test(a, **kwargs): return a @@ -3681,5 +3696,77 @@ class SliceTest(unittest.TestCase): self.assertEqual(inferred.name, 'slice') +class CallSiteTest(unittest.TestCase): + + @staticmethod + def _call_site_from_call(call): + return arguments.CallSite.from_call(call) + + def _test_call_site_pair(self, code, expected_args, expected_keywords): + ast_node = test_utils.extract_node(code) + call_site = self._call_site_from_call(ast_node) + self.assertEqual(len(call_site.positional_arguments), len(expected_args)) + self.assertEqual([arg.value for arg in call_site.positional_arguments], + expected_args) + self.assertEqual(len(call_site.keyword_arguments), len(expected_keywords)) + for keyword, value in expected_keywords.items(): + self.assertIn(keyword, call_site.keyword_arguments) + self.assertEqual(call_site.keyword_arguments[keyword].value, value) + + def _test_call_site(self, pairs): + for pair in pairs: + self._test_call_site_pair(*pair) + + @test_utils.require_version('3.5') + def test_call_site_starred_args(self): + pairs = [ + ( + "f(*(1, 2), *(2, 3), *(3, 4), **{'a':1}, **{'b': 2})", + [1, 2, 2, 3, 3, 4], + {'a': 1, 'b': 2} + ), + ( + "f(1, 2, *(3, 4), 5, *(6, 7), f=24, **{'c':3})", + [1, 2, 3, 4, 5, 6, 7], + {'f':24, 'c': 3}, + ), + # Too many fs passed into. + ( + "f(f=24, **{'f':24})", [], {}, + ), + ] + self._test_call_site(pairs) + + def test_call_site(self): + pairs = [ + ( + "f(1, 2)", [1, 2], {} + ), + ( + "f(1, 2, *(1, 2))", [1, 2, 1, 2], {} + ), + ( + "f(a=1, b=2, c=3)", [], {'a':1, 'b':2, 'c':3} + ) + ] + self._test_call_site(pairs) + + def _test_call_site_valid_arguments(self, values, invalid): + for value in values: + ast_node = test_utils.extract_node(value) + call_site = self._call_site_from_call(ast_node) + self.assertEqual(call_site.has_invalid_arguments(), invalid) + + def test_call_site_valid_arguments(self): + values = [ + "f(*lala)", "f(*1)", "f(*object)", + ] + self._test_call_site_valid_arguments(values, invalid=True) + values = [ + "f()", "f(*(1, ))", "f(1, 2, *(2, 3))", + ] + self._test_call_site_valid_arguments(values, invalid=False) + + if __name__ == '__main__': unittest.main() |