summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorhippo91 <guillaume.peillex@gmail.com>2018-12-02 16:20:06 +0100
committerClaudiu Popa <pcmanticore@gmail.com>2019-01-23 09:09:22 +0100
commit87b55a6ea6a421b4702dbc178a45165f20d1e775 (patch)
tree18c75077ec735554381ecd88be576c3f18cec91a
parentc54ca8e30d8cce8660ddac842c58b66de2733322 (diff)
downloadastroid-git-87b55a6ea6a421b4702dbc178a45165f20d1e775.tar.gz
Avoid statement deletion in the _filter_stmts method of the LookupMixin class for PartialFunction
In the case where the node is a PartialFunction and its name is the same as the current statement's name, avoid the statement deletion. The problem was that a call to a function that has been previously called vit a functools.partial was wrongly inferred. The bug comes from the _filter_stmts method of the LookupMixin class. The deletion of the current statement should not be made in the case where the node is an instance of the PartialFunction class and if the node's name is the same as the statement's name. This change also extracts PartialFunction from brain_functools into astroid.objects so that we remove a circular import problem. Close PyCQA/pylint#2588
-rw-r--r--ChangeLog4
-rw-r--r--astroid/brain/brain_functools.py28
-rw-r--r--astroid/node_classes.py12
-rw-r--r--astroid/objects.py29
-rw-r--r--astroid/tests/unittest_brain.py5
5 files changed, 50 insertions, 28 deletions
diff --git a/ChangeLog b/ChangeLog
index 118170fd..2ffee9c5 100644
--- a/ChangeLog
+++ b/ChangeLog
@@ -6,6 +6,10 @@ What's New in astroid 2.2.0?
============================
Release Date: TBA
+* Fix a bug where a call to a function that has been previously called via
+ functools.partial was wrongly inferred
+
+ Close PyCQA/pylint#2588
* Fix a recursion error caused by inferring the ``slice`` builtin.
diff --git a/astroid/brain/brain_functools.py b/astroid/brain/brain_functools.py
index 473b5bfe..726c82fc 100644
--- a/astroid/brain/brain_functools.py
+++ b/astroid/brain/brain_functools.py
@@ -12,6 +12,7 @@ from astroid import extract_node
from astroid import helpers
from astroid.interpreter import objectmodel
from astroid import MANAGER
+from astroid import objects
LRU_CACHE = "functools.lru_cache"
@@ -98,31 +99,8 @@ def _functools_partial_inference(node, context=None):
"wrapped function received unknown parameters"
)
- # Return a wrapped() object that can be used further for inference
- class PartialFunction(astroid.FunctionDef):
-
- filled_positionals = len(call.positional_arguments[1:])
- filled_keywords = list(call.keyword_arguments)
-
- def infer_call_result(self, caller=None, context=None):
- nonlocal call
- filled_args = call.positional_arguments[1:]
- filled_keywords = call.keyword_arguments
-
- if context:
- current_passed_keywords = {
- keyword for (keyword, _) in context.callcontext.keywords
- }
- for keyword, value in filled_keywords.items():
- if keyword not in current_passed_keywords:
- context.callcontext.keywords.append((keyword, value))
-
- call_context_args = context.callcontext.args or []
- context.callcontext.args = filled_args + call_context_args
-
- return super().infer_call_result(caller=caller, context=context)
-
- partial_function = PartialFunction(
+ partial_function = objects.PartialFunction(
+ call,
name=inferred_wrapped_function.name,
doc=inferred_wrapped_function.doc,
lineno=inferred_wrapped_function.lineno,
diff --git a/astroid/node_classes.py b/astroid/node_classes.py
index 4e5f5612..0ab9be08 100644
--- a/astroid/node_classes.py
+++ b/astroid/node_classes.py
@@ -1217,8 +1217,16 @@ class LookupMixIn:
# want to clear previous assignments if any (hence the test on
# optional_assign)
if not (optional_assign or are_exclusive(_stmts[pindex], node)):
- del _stmt_parents[pindex]
- del _stmts[pindex]
+ if (
+ # In case of partial function node, if the statement is different
+ # from the origin function then it can be deleted otherwise it should
+ # remain to be able to correctly infer the call to origin function.
+ not node.is_function
+ or node.qname() != "PartialFunction"
+ or node.name != _stmts[pindex].name
+ ):
+ del _stmt_parents[pindex]
+ del _stmts[pindex]
if isinstance(node, AssignName):
if not optional_assign and stmt.parent is mystmt.parent:
_stmts = []
diff --git a/astroid/objects.py b/astroid/objects.py
index b68d3596..b511cda0 100644
--- a/astroid/objects.py
+++ b/astroid/objects.py
@@ -248,6 +248,35 @@ class DictValues(bases.Proxy):
__repr__ = node_classes.NodeNG.__repr__
+class PartialFunction(scoped_nodes.FunctionDef):
+ """A class representing partial function obtained via functools.partial"""
+
+ def __init__(
+ self, call, name=None, doc=None, lineno=None, col_offset=None, parent=None
+ ):
+ super().__init__(name, doc, lineno, col_offset, parent)
+ self.filled_positionals = len(call.positional_arguments[1:])
+ self.filled_args = call.positional_arguments[1:]
+ self.filled_keywords = call.keyword_arguments
+
+ def infer_call_result(self, caller=None, context=None):
+ if context:
+ current_passed_keywords = {
+ keyword for (keyword, _) in context.callcontext.keywords
+ }
+ for keyword, value in self.filled_keywords.items():
+ if keyword not in current_passed_keywords:
+ context.callcontext.keywords.append((keyword, value))
+
+ call_context_args = context.callcontext.args or []
+ context.callcontext.args = self.filled_args + call_context_args
+
+ return super().infer_call_result(caller=caller, context=context)
+
+ def qname(self):
+ return self.__class__.__name__
+
+
# TODO: Hack to solve the circular import problem between node_classes and objects
# This is not needed in 2.0, which has a cleaner design overall
node_classes.Dict.__bases__ = (node_classes.NodeNG, DictInstance)
diff --git a/astroid/tests/unittest_brain.py b/astroid/tests/unittest_brain.py
index 031bce7a..5eec16ad 100644
--- a/astroid/tests/unittest_brain.py
+++ b/astroid/tests/unittest_brain.py
@@ -1776,9 +1776,12 @@ class TestFunctoolsPartial:
partial(other_test, c=4)(1, 3) #@
partial(other_test, 4, c=4)(4) #@
partial(other_test, 4, c=4)(b=5) #@
+ test(1, 2) #@
+ partial(other_test, 1, 2)(c=3) #@
+ partial(test, b=4)(a=3) #@
"""
)
- expected_values = [4, 7, 7, 3, 12, 16, 32, 36]
+ expected_values = [4, 7, 7, 3, 12, 16, 32, 36, 3, 9, 7]
for node, expected_value in zip(ast_nodes, expected_values):
inferred = next(node.infer())
assert isinstance(inferred, astroid.Const)