diff options
-rw-r--r-- | astroid/__init__.py | 3 | ||||
-rw-r--r-- | astroid/test_utils.py | 1 | ||||
-rw-r--r-- | astroid/tests/unittest_zipper.py | 16 | ||||
-rw-r--r-- | astroid/tree/zipper.py | 28 |
4 files changed, 37 insertions, 11 deletions
diff --git a/astroid/__init__.py b/astroid/__init__.py index e32e6ba0..7d84335e 100644 --- a/astroid/__init__.py +++ b/astroid/__init__.py @@ -82,8 +82,7 @@ from astroid.interpreter.lookup import builtin_lookup from astroid.builder import parse from astroid.util import Uninferable, YES -# ??? -from astroid.tree import zipper +# TODO from astroid.tree.base import NodeNG # transform utilities (filters and decorator) diff --git a/astroid/test_utils.py b/astroid/test_utils.py index ff81cd59..c1ff5ebd 100644 --- a/astroid/test_utils.py +++ b/astroid/test_utils.py @@ -37,6 +37,7 @@ def _extract_expressions(node): and isinstance(node.func, nodes.Name) and node.func.name == _TRANSIENT_FUNCTION): real_expr = node.args[0] + # real_expr = node.down().right().down() real_expr.parent = node.parent # Search for node in all _astng_fields (the fields checked when # get_children is called) of its parent. Some of those fields may diff --git a/astroid/tests/unittest_zipper.py b/astroid/tests/unittest_zipper.py index 1c2cca17..a10b7179 100644 --- a/astroid/tests/unittest_zipper.py +++ b/astroid/tests/unittest_zipper.py @@ -83,7 +83,7 @@ def ast_from_file_name(name): return AST_CACHE[name] with open(name, 'r') as source_file: # print(name) - root = astroid.parse(source_file.read()) + root = astroid.parse(source_file.read()).__wrapped__ ast = ASTMap() AssignLabels()(ast, root) to_visit = [1] @@ -137,7 +137,7 @@ def postorder_descendants(label, ast, dont_recurse_on=None): # This test function uses a set-based implementation for finding the # common parent rather than the reverse-based implementation in the -# functional code. +# actual code. def common_ancestor(label1, label2, ast): ancestors = set() while label1: @@ -162,6 +162,14 @@ def traverse_to_node(label, ast, location): location = move(location) return location +def get_children(label, ast): + for child_label in ast[label].children: + if isinstance(ast[child_label].node, collections.Sequence): + for grandchild_label in ast[child_label].children: + yield grandchild_label + else: + yield child_label + # This function and strategy creates a strategy for generating a # random node class, for testing that the iterators properly exclude # nodes of that type and their descendants. @@ -213,7 +221,7 @@ class TestZipper(unittest.TestCase): nodes = tuple(ast) random_label = choice(nodes) random_node = zipper.Zipper(ast[random_label].node) - for node, label in zip(random_node.get_children(), ast[random_label].children): + for node, label in zip(random_node.children(), ast[random_label].children): self.assertIs(node.__wrapped__, ast[label].node) for node, label in zip(random_node.preorder_descendants(), preorder_descendants(random_label, ast)): self.assertIs(node.__wrapped__, ast[label].node) @@ -223,6 +231,8 @@ class TestZipper(unittest.TestCase): self.assertIs(node.__wrapped__, ast[label].node) for node, label in zip(random_node.postorder_descendants(dont_recurse_on=node_type), postorder_descendants(random_label, ast, dont_recurse_on=node_type)): self.assertIs(node.__wrapped__, ast[label].node) + for node, label in zip(random_node.get_children(), get_children(random_label, ast)): + self.assertIs(node.__wrapped__, ast[label].node) @hypothesis.settings(perform_health_check=False) @hypothesis.given(ast_strategy, strategies.choices()) diff --git a/astroid/tree/zipper.py b/astroid/tree/zipper.py index cd30b6e9..51360851 100644 --- a/astroid/tree/zipper.py +++ b/astroid/tree/zipper.py @@ -260,7 +260,7 @@ class Zipper(wrapt.ObjectProxy): location = location.up() return location - def get_children(self): + def children(self): '''Iterates over the children of the focus.''' child = self.down() while child is not None: @@ -285,10 +285,10 @@ class Zipper(wrapt.ObjectProxy): yield location if dont_recurse_on is None: to_visit.extend(c for c in - reversed(tuple(location.get_children()))) + reversed(tuple(location.children()))) else: to_visit.extend(c for c in - reversed(tuple(location.get_children())) + reversed(tuple(location.children())) if not isinstance(c, dont_recurse_on)) def postorder_descendants(self, dont_recurse_on=None): @@ -306,10 +306,10 @@ class Zipper(wrapt.ObjectProxy): visited_ancestors.append(location) if dont_recurse_on is None: to_visit.extend(c for c in - reversed(tuple(location.get_children()))) + reversed(tuple(location.children()))) else: to_visit.extend(c for c in - reversed(tuple(location.get_children())) + reversed(tuple(location.children())) if not isinstance(c, dont_recurse_on)) continue visited_ancestors.pop() @@ -325,7 +325,7 @@ class Zipper(wrapt.ObjectProxy): not include nodes of this type or types or any of the descendants of those nodes. ''' - return (d for d in self.preorder_descendants(skip_class) if isinstance(node, cls)) + return (d for d in self.preorder_descendants(skip_class) if isinstance(d, cls)) # if isinstance(self, cls): # yield self # child = self.down() @@ -350,6 +350,22 @@ class Zipper(wrapt.ObjectProxy): else: return location + def get_children(self): + '''Iterates over nodes that are children or grandchildren, no + sequences. + + ''' + child = self.down() + while child is not None: + if isinstance(child, collections.Sequence): + grandchild = child.down() + for _ in range(len(child)): + yield grandchild + grandchild = grandchild.right() + else: + yield child + child = child.right() + def last_child(self): return self.rightmost() |