diff options
-rw-r--r-- | ChangeLog | 4 | ||||
-rw-r--r-- | astroid/brain/brain_functools.py | 28 | ||||
-rw-r--r-- | astroid/node_classes.py | 12 | ||||
-rw-r--r-- | astroid/objects.py | 29 | ||||
-rw-r--r-- | astroid/tests/unittest_brain.py | 5 |
5 files changed, 50 insertions, 28 deletions
@@ -6,6 +6,10 @@ What's New in astroid 2.2.0? ============================ Release Date: TBA +* Fix a bug where a call to a function that has been previously called via + functools.partial was wrongly inferred + + Close PyCQA/pylint#2588 * Fix a recursion error caused by inferring the ``slice`` builtin. diff --git a/astroid/brain/brain_functools.py b/astroid/brain/brain_functools.py index 473b5bfe..726c82fc 100644 --- a/astroid/brain/brain_functools.py +++ b/astroid/brain/brain_functools.py @@ -12,6 +12,7 @@ from astroid import extract_node from astroid import helpers from astroid.interpreter import objectmodel from astroid import MANAGER +from astroid import objects LRU_CACHE = "functools.lru_cache" @@ -98,31 +99,8 @@ def _functools_partial_inference(node, context=None): "wrapped function received unknown parameters" ) - # Return a wrapped() object that can be used further for inference - class PartialFunction(astroid.FunctionDef): - - filled_positionals = len(call.positional_arguments[1:]) - filled_keywords = list(call.keyword_arguments) - - def infer_call_result(self, caller=None, context=None): - nonlocal call - filled_args = call.positional_arguments[1:] - filled_keywords = call.keyword_arguments - - if context: - current_passed_keywords = { - keyword for (keyword, _) in context.callcontext.keywords - } - for keyword, value in filled_keywords.items(): - if keyword not in current_passed_keywords: - context.callcontext.keywords.append((keyword, value)) - - call_context_args = context.callcontext.args or [] - context.callcontext.args = filled_args + call_context_args - - return super().infer_call_result(caller=caller, context=context) - - partial_function = PartialFunction( + partial_function = objects.PartialFunction( + call, name=inferred_wrapped_function.name, doc=inferred_wrapped_function.doc, lineno=inferred_wrapped_function.lineno, diff --git a/astroid/node_classes.py b/astroid/node_classes.py index 4e5f5612..0ab9be08 100644 --- a/astroid/node_classes.py +++ b/astroid/node_classes.py @@ -1217,8 +1217,16 @@ class LookupMixIn: # want to clear previous assignments if any (hence the test on # optional_assign) if not (optional_assign or are_exclusive(_stmts[pindex], node)): - del _stmt_parents[pindex] - del _stmts[pindex] + if ( + # In case of partial function node, if the statement is different + # from the origin function then it can be deleted otherwise it should + # remain to be able to correctly infer the call to origin function. + not node.is_function + or node.qname() != "PartialFunction" + or node.name != _stmts[pindex].name + ): + del _stmt_parents[pindex] + del _stmts[pindex] if isinstance(node, AssignName): if not optional_assign and stmt.parent is mystmt.parent: _stmts = [] diff --git a/astroid/objects.py b/astroid/objects.py index b68d3596..b511cda0 100644 --- a/astroid/objects.py +++ b/astroid/objects.py @@ -248,6 +248,35 @@ class DictValues(bases.Proxy): __repr__ = node_classes.NodeNG.__repr__ +class PartialFunction(scoped_nodes.FunctionDef): + """A class representing partial function obtained via functools.partial""" + + def __init__( + self, call, name=None, doc=None, lineno=None, col_offset=None, parent=None + ): + super().__init__(name, doc, lineno, col_offset, parent) + self.filled_positionals = len(call.positional_arguments[1:]) + self.filled_args = call.positional_arguments[1:] + self.filled_keywords = call.keyword_arguments + + def infer_call_result(self, caller=None, context=None): + if context: + current_passed_keywords = { + keyword for (keyword, _) in context.callcontext.keywords + } + for keyword, value in self.filled_keywords.items(): + if keyword not in current_passed_keywords: + context.callcontext.keywords.append((keyword, value)) + + call_context_args = context.callcontext.args or [] + context.callcontext.args = self.filled_args + call_context_args + + return super().infer_call_result(caller=caller, context=context) + + def qname(self): + return self.__class__.__name__ + + # TODO: Hack to solve the circular import problem between node_classes and objects # This is not needed in 2.0, which has a cleaner design overall node_classes.Dict.__bases__ = (node_classes.NodeNG, DictInstance) diff --git a/astroid/tests/unittest_brain.py b/astroid/tests/unittest_brain.py index 031bce7a..5eec16ad 100644 --- a/astroid/tests/unittest_brain.py +++ b/astroid/tests/unittest_brain.py @@ -1776,9 +1776,12 @@ class TestFunctoolsPartial: partial(other_test, c=4)(1, 3) #@ partial(other_test, 4, c=4)(4) #@ partial(other_test, 4, c=4)(b=5) #@ + test(1, 2) #@ + partial(other_test, 1, 2)(c=3) #@ + partial(test, b=4)(a=3) #@ """ ) - expected_values = [4, 7, 7, 3, 12, 16, 32, 36] + expected_values = [4, 7, 7, 3, 12, 16, 32, 36, 3, 9, 7] for node, expected_value in zip(ast_nodes, expected_values): inferred = next(node.infer()) assert isinstance(inferred, astroid.Const) |