diff options
author | Claudiu Popa <pcmanticore@gmail.com> | 2018-06-15 10:48:13 +0200 |
---|---|---|
committer | Claudiu Popa <pcmanticore@gmail.com> | 2018-06-15 10:48:58 +0200 |
commit | 4ffa9eab67944c9caa9c6c9eac2327358fb3983f (patch) | |
tree | 5b2e61d1c291dc2e5e03568e399bb7120ffcc675 /astroid/protocols.py | |
parent | 76836190096be94ab30c7f7fe4d988e9e374d397 (diff) | |
download | astroid-git-4ffa9eab67944c9caa9c6c9eac2327358fb3983f.tar.gz |
Added inference support for starred nodes in for loops
Close #146
Diffstat (limited to 'astroid/protocols.py')
-rw-r--r-- | astroid/protocols.py | 105 |
1 files changed, 100 insertions, 5 deletions
diff --git a/astroid/protocols.py b/astroid/protocols.py index 63be7101..64904524 100644 --- a/astroid/protocols.py +++ b/astroid/protocols.py @@ -549,6 +549,18 @@ def starred_assigned_stmts(self, node=None, context=None, asspath=None): context: TODO asspath: TODO """ + + def _determine_starred_iteration_lookups(starred, target, lookups): + # Determine the lookups for the rhs of the iteration + itered = target.itered() + for index, element in enumerate(itered): + if isinstance(element, nodes.Starred) and element.value.name == starred.value.name: + lookups.append((index, len(itered))) + break + if isinstance(element, nodes.Tuple): + lookups.append((index, len(element.itered()))) + _determine_starred_iteration_lookups(starred, element, lookups) + stmt = self.statement() if not isinstance(stmt, (nodes.Assign, nodes.For)): raise exceptions.InferenceError('Statement {stmt!r} enclosing {node!r} ' @@ -556,22 +568,21 @@ def starred_assigned_stmts(self, node=None, context=None, asspath=None): node=self, stmt=stmt, unknown=node, context=context) + if context is None: + context = contextmod.InferenceContext() + if isinstance(stmt, nodes.Assign): value = stmt.value lhs = stmt.targets[0] - if sum(1 for node in lhs.nodes_of_class(nodes.Starred)) > 1: + if sum(1 for _ in lhs.nodes_of_class(nodes.Starred)) > 1: raise exceptions.InferenceError('Too many starred arguments in the ' ' assignment targets {lhs!r}.', node=self, targets=lhs, unknown=node, context=context) - if context is None: - context = contextmod.InferenceContext() try: rhs = next(value.infer(context)) - except StopIteration: - return except exceptions.InferenceError: yield util.Uninferable return @@ -615,5 +626,89 @@ def starred_assigned_stmts(self, node=None, context=None, asspath=None): yield packed break + if isinstance(stmt, nodes.For): + try: + inferred_iterable = next(stmt.iter.infer(context=context)) + except exceptions.InferenceError: + yield util.Uninferable + return + if inferred_iterable is util.Uninferable or not hasattr(inferred_iterable, 'itered'): + yield util.Uninferable + return + try: + itered = inferred_iterable.itered() + except TypeError: + yield util.Uninferable + return + + target = stmt.target + + if not isinstance(target, nodes.Tuple): + raise exceptions.InferenceError( + 'Could not make sense of this, the target must be a tuple', + context=context, + ) + + lookups = [] + _determine_starred_iteration_lookups(self, target, lookups) + if not lookups: + raise exceptions.InferenceError( + 'Could not make sense of this, needs at least a lookup', + context=context, + ) + + # Make the last lookup a slice, since that what we want for a Starred node + last_element_index, last_element_length = lookups[-1] + is_starred_last = last_element_index == (last_element_length - 1) + + lookup_slice = slice( + last_element_index, + None if is_starred_last else (last_element_length - last_element_index) + ) + lookups[-1] = lookup_slice + + for element in itered: + + # We probably want to infer the potential values *for each* element in an + # iterable, but we can't infer a list of all values, when only a list of + # step values are expected: + # + # for a, *b in [...]: + # b + # + # *b* should now point to just the elements at that particular iteration step, + # which astroid can't know about. + + found_element = None + for lookup in lookups: + if not hasattr(element, 'itered'): + break + if not isinstance(lookup, slice): + # Grab just the index, not the whole length + lookup = lookup[0] + try: + itered_inner_element = element.itered() + element = itered_inner_element[lookup] + except IndexError: + break + except TypeError: + # Most likely the itered() call failed, cannot make sense of this + yield util.Uninferable + return + else: + found_element = element + + unpacked = nodes.List( + ctx=Store, + parent=self, + lineno=self.lineno, + col_offset=self.col_offset, + ) + unpacked.postinit(elts=found_element or []) + yield unpacked + return + + yield util.Uninferable + nodes.Starred.assigned_stmts = starred_assigned_stmts |