diff options
Diffstat (limited to 'astroid/test_utils.py')
-rw-r--r-- | astroid/test_utils.py | 13 |
1 files changed, 12 insertions, 1 deletions
diff --git a/astroid/test_utils.py b/astroid/test_utils.py index b15ab766..ff81cd59 100644 --- a/astroid/test_utils.py +++ b/astroid/test_utils.py @@ -19,6 +19,7 @@ _TRANSIENT_FUNCTION = '__' # when calling extract_node. _STATEMENT_SELECTOR = '#@' + def _extract_expressions(node): """Find expressions in a call to _TRANSIENT_FUNCTION and extract them. @@ -46,8 +47,18 @@ def _extract_expressions(node): child = getattr(node.parent, name) if isinstance(child, (list, tuple)): for idx, compound_child in enumerate(child): - if compound_child is node: + + # Can't find a cleaner way to do this. + if isinstance(compound_child, nodes.Parameter): + if compound_child.default is node: + child[idx].default = real_expr + elif compound_child.annotation is node: + child[idx].annotation = real_expr + else: + child[idx] = real_expr + elif compound_child is node: child[idx] = real_expr + elif child is node: setattr(node.parent, name, real_expr) yield real_expr |