diff options
author | Torsten Marek <tmarek@google.com> | 2013-06-20 13:51:04 +0200 |
---|---|---|
committer | Torsten Marek <tmarek@google.com> | 2013-06-20 13:51:04 +0200 |
commit | 03ce489bdcb712e23d82d1917153b9aae51dd1a8 (patch) | |
tree | fb9b404dec2d420ef07d1625e82c32aa8a5052a3 /test_utils.py | |
parent | f816bceffc7c43c049773223f94b776961fd3e78 (diff) | |
download | astroid-git-03ce489bdcb712e23d82d1917153b9aae51dd1a8.tar.gz |
Allow selecting several statements/expressions in test_utils.extract_node.
Diffstat (limited to 'test_utils.py')
-rw-r--r-- | test_utils.py | 67 |
1 files changed, 38 insertions, 29 deletions
diff --git a/test_utils.py b/test_utils.py index 7b027043..8e54eb92 100644 --- a/test_utils.py +++ b/test_utils.py @@ -14,19 +14,18 @@ _TRANSIENT_FUNCTION = '__' _STATEMENT_SELECTOR = '#@' -def _extract_expression(node): - """Find an expression in a call to _TRANSIENT_FUNCTION and extract it. +def _extract_expressions(node): + """Find expressions in a call to _TRANSIENT_FUNCTION and extract them. - The function walks the AST recursively to search for an expression that - is wrapped into a call to _TRANSIENT_FUNCTION. If it finds such an + The function walks the AST recursively to search for expressions that + are wrapped into a call to _TRANSIENT_FUNCTION. If it finds such an expression, it completely removes the function call node from the tree, replacing it by the wrapped expression inside the parent. :param node: An astroid node. :type node: astroid.bases.NodeNG - :returns: The wrapped expression on the modified tree, or None if no such + :yields: The sequence of wrapped expressions on the modified tree expression can be found. - :rtype: astroid.bases.NodeNG or None """ if (isinstance(node, nodes.CallFunc) and isinstance(node.func, nodes.Name) @@ -46,12 +45,11 @@ def _extract_expression(node): child[idx] = real_expr elif child is node: setattr(node.parent, name, real_expr) - return real_expr + yield real_expr else: for child in node.get_children(): - result = _extract_expression(child) - if result: - return result + for result in _extract_expressions(child): + yield result def _find_statement_by_line(node, line): @@ -92,7 +90,7 @@ def extract_node(code, module_name=''): """Parses some Python code as a module and extracts a designated AST node. Statements: - To extract a statement node, append #@ to the end of the line + To extract one or more statement nodes, append #@ to the end of the line Examples: >>> def x(): @@ -107,8 +105,8 @@ def extract_node(code, module_name=''): The funcion object 'meth' will be extracted. - Expression: - To extract an arbitrary expression, surround it with the fake + Expressions: + To extract arbitrary expressions, surround them with the fake function call __(...). After parsing, the surrounded expression will be returned and the whole AST (accessible via the returned node's parent attribute) will look like the function call was @@ -128,37 +126,48 @@ def extract_node(code, module_name=''): The node containing the function call 'len' will be extracted. - If no statement or expression is selected, the last toplevel + If no statements or expressions are selected, the last toplevel statement will be returned. - If the selected is a discard statement, (i.e. an expression turned - into a statement), the wrapped expression is returned instead. + If the selected statement is a discard statement, (i.e. an expression + turned into a statement), the wrapped expression is returned instead. + + For convenience, singleton lists are unpacked. :param str code: A piece of Python code that is parsed as a module. Will be passed through textwrap.dedent first. :param str module_name: The name of the module. - :returns: The designated node from the parse tree. - :rtype: astroid.bases.NodeNG + :returns: The designated node from the parse tree, or a list of nodes. + :rtype: astroid.bases.NodeNG, or a list of nodes. """ - requested_line = None + def _extract(node): + if isinstance(node, nodes.Discard): + return node.value + else: + return node + + requested_lines = [] for idx, line in enumerate(code.splitlines()): if line.strip().endswith(_STATEMENT_SELECTOR): - requested_line = idx + 1 - break + requested_lines.append(idx + 1) tree = build_module(code, module_name=module_name) - if requested_line: - node = _find_statement_by_line(tree, requested_line) + extracted = [] + if requested_lines: + for line in requested_lines: + extracted.append(_find_statement_by_line(tree, line)) else: - node = _extract_expression(tree) + # Modifies the tree. + extracted = list(_extract_expressions(tree)) - if not node: - node = tree.body[-1] + if not extracted: + extracted.append(tree.body[-1]) - if isinstance(node, nodes.Discard): - return node.value + extracted = [_extract(node) for node in extracted] + if len(extracted) == 1: + return extracted[0] else: - return node + return extracted def build_module(code, module_name=''): |