summaryrefslogtreecommitdiff
path: root/test_utils.py
diff options
context:
space:
mode:
authorTorsten Marek <tmarek@google.com>2013-06-20 13:51:04 +0200
committerTorsten Marek <tmarek@google.com>2013-06-20 13:51:04 +0200
commit03ce489bdcb712e23d82d1917153b9aae51dd1a8 (patch)
treefb9b404dec2d420ef07d1625e82c32aa8a5052a3 /test_utils.py
parentf816bceffc7c43c049773223f94b776961fd3e78 (diff)
downloadastroid-git-03ce489bdcb712e23d82d1917153b9aae51dd1a8.tar.gz
Allow selecting several statements/expressions in test_utils.extract_node.
Diffstat (limited to 'test_utils.py')
-rw-r--r--test_utils.py67
1 files changed, 38 insertions, 29 deletions
diff --git a/test_utils.py b/test_utils.py
index 7b027043..8e54eb92 100644
--- a/test_utils.py
+++ b/test_utils.py
@@ -14,19 +14,18 @@ _TRANSIENT_FUNCTION = '__'
_STATEMENT_SELECTOR = '#@'
-def _extract_expression(node):
- """Find an expression in a call to _TRANSIENT_FUNCTION and extract it.
+def _extract_expressions(node):
+ """Find expressions in a call to _TRANSIENT_FUNCTION and extract them.
- The function walks the AST recursively to search for an expression that
- is wrapped into a call to _TRANSIENT_FUNCTION. If it finds such an
+ The function walks the AST recursively to search for expressions that
+ are wrapped into a call to _TRANSIENT_FUNCTION. If it finds such an
expression, it completely removes the function call node from the tree,
replacing it by the wrapped expression inside the parent.
:param node: An astroid node.
:type node: astroid.bases.NodeNG
- :returns: The wrapped expression on the modified tree, or None if no such
+ :yields: The sequence of wrapped expressions on the modified tree
expression can be found.
- :rtype: astroid.bases.NodeNG or None
"""
if (isinstance(node, nodes.CallFunc)
and isinstance(node.func, nodes.Name)
@@ -46,12 +45,11 @@ def _extract_expression(node):
child[idx] = real_expr
elif child is node:
setattr(node.parent, name, real_expr)
- return real_expr
+ yield real_expr
else:
for child in node.get_children():
- result = _extract_expression(child)
- if result:
- return result
+ for result in _extract_expressions(child):
+ yield result
def _find_statement_by_line(node, line):
@@ -92,7 +90,7 @@ def extract_node(code, module_name=''):
"""Parses some Python code as a module and extracts a designated AST node.
Statements:
- To extract a statement node, append #@ to the end of the line
+ To extract one or more statement nodes, append #@ to the end of the line
Examples:
>>> def x():
@@ -107,8 +105,8 @@ def extract_node(code, module_name=''):
The funcion object 'meth' will be extracted.
- Expression:
- To extract an arbitrary expression, surround it with the fake
+ Expressions:
+ To extract arbitrary expressions, surround them with the fake
function call __(...). After parsing, the surrounded expression
will be returned and the whole AST (accessible via the returned
node's parent attribute) will look like the function call was
@@ -128,37 +126,48 @@ def extract_node(code, module_name=''):
The node containing the function call 'len' will be extracted.
- If no statement or expression is selected, the last toplevel
+ If no statements or expressions are selected, the last toplevel
statement will be returned.
- If the selected is a discard statement, (i.e. an expression turned
- into a statement), the wrapped expression is returned instead.
+ If the selected statement is a discard statement, (i.e. an expression
+ turned into a statement), the wrapped expression is returned instead.
+
+ For convenience, singleton lists are unpacked.
:param str code: A piece of Python code that is parsed as
a module. Will be passed through textwrap.dedent first.
:param str module_name: The name of the module.
- :returns: The designated node from the parse tree.
- :rtype: astroid.bases.NodeNG
+ :returns: The designated node from the parse tree, or a list of nodes.
+ :rtype: astroid.bases.NodeNG, or a list of nodes.
"""
- requested_line = None
+ def _extract(node):
+ if isinstance(node, nodes.Discard):
+ return node.value
+ else:
+ return node
+
+ requested_lines = []
for idx, line in enumerate(code.splitlines()):
if line.strip().endswith(_STATEMENT_SELECTOR):
- requested_line = idx + 1
- break
+ requested_lines.append(idx + 1)
tree = build_module(code, module_name=module_name)
- if requested_line:
- node = _find_statement_by_line(tree, requested_line)
+ extracted = []
+ if requested_lines:
+ for line in requested_lines:
+ extracted.append(_find_statement_by_line(tree, line))
else:
- node = _extract_expression(tree)
+ # Modifies the tree.
+ extracted = list(_extract_expressions(tree))
- if not node:
- node = tree.body[-1]
+ if not extracted:
+ extracted.append(tree.body[-1])
- if isinstance(node, nodes.Discard):
- return node.value
+ extracted = [_extract(node) for node in extracted]
+ if len(extracted) == 1:
+ return extracted[0]
else:
- return node
+ return extracted
def build_module(code, module_name=''):