diff options
author | Alphadelta14 <alpha@alphaservcomputing.solutions> | 2021-08-01 15:05:25 -0400 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-08-01 21:05:25 +0200 |
commit | ac95965bf32b4b4094bacedb422ba5de515bb85c (patch) | |
tree | 49136184e8d2d5abde6075c58ed95506dfa422b9 | |
parent | e17862654fc0585e63f5ab6785eab23787e91e9b (diff) | |
download | astroid-git-ac95965bf32b4b4094bacedb422ba5de515bb85c.tar.gz |
Fix incorrect scope for functools partials (#1097)
* Use scope parent
* Do not set name onto parent frame for partials
* Add test case that captures broken scopes
-rw-r--r-- | astroid/brain/brain_functools.py | 2 | ||||
-rw-r--r-- | astroid/node_classes.py | 12 | ||||
-rw-r--r-- | astroid/objects.py | 5 | ||||
-rw-r--r-- | tests/unittest_brain.py | 40 |
4 files changed, 47 insertions, 12 deletions
diff --git a/astroid/brain/brain_functools.py b/astroid/brain/brain_functools.py index 248e0fb9..9804d535 100644 --- a/astroid/brain/brain_functools.py +++ b/astroid/brain/brain_functools.py @@ -103,7 +103,7 @@ def _functools_partial_inference(node, context=None): doc=inferred_wrapped_function.doc, lineno=inferred_wrapped_function.lineno, col_offset=inferred_wrapped_function.col_offset, - parent=inferred_wrapped_function.parent, + parent=node.parent, ) partial_function.postinit( args=inferred_wrapped_function.args, diff --git a/astroid/node_classes.py b/astroid/node_classes.py index b3cb73c9..e9b2fcfa 100644 --- a/astroid/node_classes.py +++ b/astroid/node_classes.py @@ -1209,16 +1209,8 @@ class LookupMixIn: # want to clear previous assignments if any (hence the test on # optional_assign) if not (optional_assign or are_exclusive(_stmts[pindex], node)): - if ( - # In case of partial function node, if the statement is different - # from the origin function then it can be deleted otherwise it should - # remain to be able to correctly infer the call to origin function. - not node.is_function - or node.qname() != "PartialFunction" - or node.name != _stmts[pindex].name - ): - del _stmt_parents[pindex] - del _stmts[pindex] + del _stmt_parents[pindex] + del _stmts[pindex] if isinstance(node, AssignName): if not optional_assign and stmt.parent is mystmt.parent: _stmts = [] diff --git a/astroid/objects.py b/astroid/objects.py index 1a32b52d..a598c5bb 100644 --- a/astroid/objects.py +++ b/astroid/objects.py @@ -260,7 +260,10 @@ class PartialFunction(scoped_nodes.FunctionDef): def __init__( self, call, name=None, doc=None, lineno=None, col_offset=None, parent=None ): - super().__init__(name, doc, lineno, col_offset, parent) + super().__init__(name, doc, lineno, col_offset, parent=None) + # A typical FunctionDef automatically adds its name to the parent scope, + # but a partial should not, so defer setting parent until after init + self.parent = parent self.filled_positionals = len(call.positional_arguments[1:]) self.filled_args = call.positional_arguments[1:] self.filled_keywords = call.keyword_arguments diff --git a/tests/unittest_brain.py b/tests/unittest_brain.py index dd040bd4..b4bb85c3 100644 --- a/tests/unittest_brain.py +++ b/tests/unittest_brain.py @@ -2761,6 +2761,46 @@ class TestFunctoolsPartial: assert isinstance(inferred, astroid.Const) assert inferred.value == expected_value + def test_partial_assignment(self): + """Make sure partials are not assigned to original scope.""" + ast_nodes = astroid.extract_node( + """ + from functools import partial + def test(a, b): #@ + return a + b + test2 = partial(test, 1) + test2 #@ + def test3_scope(a): + test3 = partial(test, a) + test3 #@ + """ + ) + func1, func2, func3 = ast_nodes + assert func1.parent.scope() == func2.parent.scope() + assert func1.parent.scope() != func3.parent.scope() + partial_func3 = next(func3.infer()) + # use scope of parent, so that it doesn't just refer to self + scope = partial_func3.parent.scope() + assert scope.name == "test3_scope", "parented by closure" + + def test_partial_does_not_affect_scope(self): + """Make sure partials are not automatically assigned.""" + ast_nodes = astroid.extract_node( + """ + from functools import partial + def test(a, b): + return a + b + def scope(): + test2 = partial(test, 1) + test2 #@ + """ + ) + test2 = next(ast_nodes.infer()) + mod_scope = test2.root() + scope = test2.parent.scope() + assert set(mod_scope) == {"test", "scope", "partial"} + assert set(scope) == {"test2"} + def test_http_client_brain(): node = astroid.extract_node( |