summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--ChangeLog4
-rw-r--r--astroid/brain/brain_functools.py28
-rw-r--r--astroid/node_classes.py12
-rw-r--r--astroid/objects.py29
-rw-r--r--astroid/tests/unittest_brain.py5
5 files changed, 50 insertions, 28 deletions
diff --git a/ChangeLog b/ChangeLog
index 118170fd..2ffee9c5 100644
--- a/ChangeLog
+++ b/ChangeLog
@@ -6,6 +6,10 @@ What's New in astroid 2.2.0?
============================
Release Date: TBA
+* Fix a bug where a call to a function that has been previously called via
+ functools.partial was wrongly inferred
+
+ Close PyCQA/pylint#2588
* Fix a recursion error caused by inferring the ``slice`` builtin.
diff --git a/astroid/brain/brain_functools.py b/astroid/brain/brain_functools.py
index 473b5bfe..726c82fc 100644
--- a/astroid/brain/brain_functools.py
+++ b/astroid/brain/brain_functools.py
@@ -12,6 +12,7 @@ from astroid import extract_node
from astroid import helpers
from astroid.interpreter import objectmodel
from astroid import MANAGER
+from astroid import objects
LRU_CACHE = "functools.lru_cache"
@@ -98,31 +99,8 @@ def _functools_partial_inference(node, context=None):
"wrapped function received unknown parameters"
)
- # Return a wrapped() object that can be used further for inference
- class PartialFunction(astroid.FunctionDef):
-
- filled_positionals = len(call.positional_arguments[1:])
- filled_keywords = list(call.keyword_arguments)
-
- def infer_call_result(self, caller=None, context=None):
- nonlocal call
- filled_args = call.positional_arguments[1:]
- filled_keywords = call.keyword_arguments
-
- if context:
- current_passed_keywords = {
- keyword for (keyword, _) in context.callcontext.keywords
- }
- for keyword, value in filled_keywords.items():
- if keyword not in current_passed_keywords:
- context.callcontext.keywords.append((keyword, value))
-
- call_context_args = context.callcontext.args or []
- context.callcontext.args = filled_args + call_context_args
-
- return super().infer_call_result(caller=caller, context=context)
-
- partial_function = PartialFunction(
+ partial_function = objects.PartialFunction(
+ call,
name=inferred_wrapped_function.name,
doc=inferred_wrapped_function.doc,
lineno=inferred_wrapped_function.lineno,
diff --git a/astroid/node_classes.py b/astroid/node_classes.py
index 4e5f5612..0ab9be08 100644
--- a/astroid/node_classes.py
+++ b/astroid/node_classes.py
@@ -1217,8 +1217,16 @@ class LookupMixIn:
# want to clear previous assignments if any (hence the test on
# optional_assign)
if not (optional_assign or are_exclusive(_stmts[pindex], node)):
- del _stmt_parents[pindex]
- del _stmts[pindex]
+ if (
+ # In case of partial function node, if the statement is different
+ # from the origin function then it can be deleted otherwise it should
+ # remain to be able to correctly infer the call to origin function.
+ not node.is_function
+ or node.qname() != "PartialFunction"
+ or node.name != _stmts[pindex].name
+ ):
+ del _stmt_parents[pindex]
+ del _stmts[pindex]
if isinstance(node, AssignName):
if not optional_assign and stmt.parent is mystmt.parent:
_stmts = []
diff --git a/astroid/objects.py b/astroid/objects.py
index b68d3596..b511cda0 100644
--- a/astroid/objects.py
+++ b/astroid/objects.py
@@ -248,6 +248,35 @@ class DictValues(bases.Proxy):
__repr__ = node_classes.NodeNG.__repr__
+class PartialFunction(scoped_nodes.FunctionDef):
+ """A class representing partial function obtained via functools.partial"""
+
+ def __init__(
+ self, call, name=None, doc=None, lineno=None, col_offset=None, parent=None
+ ):
+ super().__init__(name, doc, lineno, col_offset, parent)
+ self.filled_positionals = len(call.positional_arguments[1:])
+ self.filled_args = call.positional_arguments[1:]
+ self.filled_keywords = call.keyword_arguments
+
+ def infer_call_result(self, caller=None, context=None):
+ if context:
+ current_passed_keywords = {
+ keyword for (keyword, _) in context.callcontext.keywords
+ }
+ for keyword, value in self.filled_keywords.items():
+ if keyword not in current_passed_keywords:
+ context.callcontext.keywords.append((keyword, value))
+
+ call_context_args = context.callcontext.args or []
+ context.callcontext.args = self.filled_args + call_context_args
+
+ return super().infer_call_result(caller=caller, context=context)
+
+ def qname(self):
+ return self.__class__.__name__
+
+
# TODO: Hack to solve the circular import problem between node_classes and objects
# This is not needed in 2.0, which has a cleaner design overall
node_classes.Dict.__bases__ = (node_classes.NodeNG, DictInstance)
diff --git a/astroid/tests/unittest_brain.py b/astroid/tests/unittest_brain.py
index 031bce7a..5eec16ad 100644
--- a/astroid/tests/unittest_brain.py
+++ b/astroid/tests/unittest_brain.py
@@ -1776,9 +1776,12 @@ class TestFunctoolsPartial:
partial(other_test, c=4)(1, 3) #@
partial(other_test, 4, c=4)(4) #@
partial(other_test, 4, c=4)(b=5) #@
+ test(1, 2) #@
+ partial(other_test, 1, 2)(c=3) #@
+ partial(test, b=4)(a=3) #@
"""
)
- expected_values = [4, 7, 7, 3, 12, 16, 32, 36]
+ expected_values = [4, 7, 7, 3, 12, 16, 32, 36, 3, 9, 7]
for node, expected_value in zip(ast_nodes, expected_values):
inferred = next(node.infer())
assert isinstance(inferred, astroid.Const)