diff options
author | hippo91 <guillaume.peillex@gmail.com> | 2018-12-02 16:20:06 +0100 |
---|---|---|
committer | Claudiu Popa <pcmanticore@gmail.com> | 2019-01-23 09:09:22 +0100 |
commit | 87b55a6ea6a421b4702dbc178a45165f20d1e775 (patch) | |
tree | 18c75077ec735554381ecd88be576c3f18cec91a | |
parent | c54ca8e30d8cce8660ddac842c58b66de2733322 (diff) | |
download | astroid-git-87b55a6ea6a421b4702dbc178a45165f20d1e775.tar.gz |
Avoid statement deletion in the _filter_stmts method of the LookupMixin class for PartialFunction
In the case where the node is a PartialFunction and its name is the same as the current statement's name,
avoid the statement deletion.
The problem was that a call to a function that has been previously called vit a functools.partial was wrongly inferred.
The bug comes from the _filter_stmts method of the LookupMixin class. The deletion of the current statement should not be
made in the case where the node is an instance of the PartialFunction class and if the node's name is the same as the statement's name.
This change also extracts PartialFunction from brain_functools into astroid.objects so that we remove a circular import problem.
Close PyCQA/pylint#2588
-rw-r--r-- | ChangeLog | 4 | ||||
-rw-r--r-- | astroid/brain/brain_functools.py | 28 | ||||
-rw-r--r-- | astroid/node_classes.py | 12 | ||||
-rw-r--r-- | astroid/objects.py | 29 | ||||
-rw-r--r-- | astroid/tests/unittest_brain.py | 5 |
5 files changed, 50 insertions, 28 deletions
@@ -6,6 +6,10 @@ What's New in astroid 2.2.0? ============================ Release Date: TBA +* Fix a bug where a call to a function that has been previously called via + functools.partial was wrongly inferred + + Close PyCQA/pylint#2588 * Fix a recursion error caused by inferring the ``slice`` builtin. diff --git a/astroid/brain/brain_functools.py b/astroid/brain/brain_functools.py index 473b5bfe..726c82fc 100644 --- a/astroid/brain/brain_functools.py +++ b/astroid/brain/brain_functools.py @@ -12,6 +12,7 @@ from astroid import extract_node from astroid import helpers from astroid.interpreter import objectmodel from astroid import MANAGER +from astroid import objects LRU_CACHE = "functools.lru_cache" @@ -98,31 +99,8 @@ def _functools_partial_inference(node, context=None): "wrapped function received unknown parameters" ) - # Return a wrapped() object that can be used further for inference - class PartialFunction(astroid.FunctionDef): - - filled_positionals = len(call.positional_arguments[1:]) - filled_keywords = list(call.keyword_arguments) - - def infer_call_result(self, caller=None, context=None): - nonlocal call - filled_args = call.positional_arguments[1:] - filled_keywords = call.keyword_arguments - - if context: - current_passed_keywords = { - keyword for (keyword, _) in context.callcontext.keywords - } - for keyword, value in filled_keywords.items(): - if keyword not in current_passed_keywords: - context.callcontext.keywords.append((keyword, value)) - - call_context_args = context.callcontext.args or [] - context.callcontext.args = filled_args + call_context_args - - return super().infer_call_result(caller=caller, context=context) - - partial_function = PartialFunction( + partial_function = objects.PartialFunction( + call, name=inferred_wrapped_function.name, doc=inferred_wrapped_function.doc, lineno=inferred_wrapped_function.lineno, diff --git a/astroid/node_classes.py b/astroid/node_classes.py index 4e5f5612..0ab9be08 100644 --- a/astroid/node_classes.py +++ b/astroid/node_classes.py @@ -1217,8 +1217,16 @@ class LookupMixIn: # want to clear previous assignments if any (hence the test on # optional_assign) if not (optional_assign or are_exclusive(_stmts[pindex], node)): - del _stmt_parents[pindex] - del _stmts[pindex] + if ( + # In case of partial function node, if the statement is different + # from the origin function then it can be deleted otherwise it should + # remain to be able to correctly infer the call to origin function. + not node.is_function + or node.qname() != "PartialFunction" + or node.name != _stmts[pindex].name + ): + del _stmt_parents[pindex] + del _stmts[pindex] if isinstance(node, AssignName): if not optional_assign and stmt.parent is mystmt.parent: _stmts = [] diff --git a/astroid/objects.py b/astroid/objects.py index b68d3596..b511cda0 100644 --- a/astroid/objects.py +++ b/astroid/objects.py @@ -248,6 +248,35 @@ class DictValues(bases.Proxy): __repr__ = node_classes.NodeNG.__repr__ +class PartialFunction(scoped_nodes.FunctionDef): + """A class representing partial function obtained via functools.partial""" + + def __init__( + self, call, name=None, doc=None, lineno=None, col_offset=None, parent=None + ): + super().__init__(name, doc, lineno, col_offset, parent) + self.filled_positionals = len(call.positional_arguments[1:]) + self.filled_args = call.positional_arguments[1:] + self.filled_keywords = call.keyword_arguments + + def infer_call_result(self, caller=None, context=None): + if context: + current_passed_keywords = { + keyword for (keyword, _) in context.callcontext.keywords + } + for keyword, value in self.filled_keywords.items(): + if keyword not in current_passed_keywords: + context.callcontext.keywords.append((keyword, value)) + + call_context_args = context.callcontext.args or [] + context.callcontext.args = self.filled_args + call_context_args + + return super().infer_call_result(caller=caller, context=context) + + def qname(self): + return self.__class__.__name__ + + # TODO: Hack to solve the circular import problem between node_classes and objects # This is not needed in 2.0, which has a cleaner design overall node_classes.Dict.__bases__ = (node_classes.NodeNG, DictInstance) diff --git a/astroid/tests/unittest_brain.py b/astroid/tests/unittest_brain.py index 031bce7a..5eec16ad 100644 --- a/astroid/tests/unittest_brain.py +++ b/astroid/tests/unittest_brain.py @@ -1776,9 +1776,12 @@ class TestFunctoolsPartial: partial(other_test, c=4)(1, 3) #@ partial(other_test, 4, c=4)(4) #@ partial(other_test, 4, c=4)(b=5) #@ + test(1, 2) #@ + partial(other_test, 1, 2)(c=3) #@ + partial(test, b=4)(a=3) #@ """ ) - expected_values = [4, 7, 7, 3, 12, 16, 32, 36] + expected_values = [4, 7, 7, 3, 12, 16, 32, 36, 3, 9, 7] for node, expected_value in zip(ast_nodes, expected_values): inferred = next(node.infer()) assert isinstance(inferred, astroid.Const) |