summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorEli Bendersky <eliben@gmail.com>2012-02-03 11:22:25 +0200
committerEli Bendersky <eliben@gmail.com>2012-02-03 11:22:25 +0200
commit12f0c9d099e997c91a2622c25ce112632ef5e115 (patch)
treeadeaf93e542e7a3d536637f47ce99c6414a392cb
parent5433326654f0f266fa692c6f76477ea0accde6f8 (diff)
downloadpycparser-12f0c9d099e997c91a2622c25ce112632ef5e115.tar.gz
Transform the AST to create a correct representation of the cases inside a switch statement
-rw-r--r--pycparser/_c_ast.cfg4
-rw-r--r--pycparser/ast_transforms.py105
-rw-r--r--pycparser/c_ast.py14
-rw-r--r--pycparser/c_parser.py8
-rw-r--r--tests/test_c_parser.py49
5 files changed, 168 insertions, 12 deletions
diff --git a/pycparser/_c_ast.cfg b/pycparser/_c_ast.cfg
index 9feaf1a..ca2379b 100644
--- a/pycparser/_c_ast.cfg
+++ b/pycparser/_c_ast.cfg
@@ -25,7 +25,7 @@ BinaryOp: [op, left*, right*]
Break: []
-Case: [expr*, stmt*]
+Case: [expr*, stmts**]
Cast: [to_type*, expr*]
@@ -59,7 +59,7 @@ Decl: [name, quals, storage, funcspec, type*, init*, bitsize*]
DeclList: [decls**]
-Default: [stmt*]
+Default: [stmts**]
DoWhile: [cond*, stmt*]
diff --git a/pycparser/ast_transforms.py b/pycparser/ast_transforms.py
new file mode 100644
index 0000000..b30ae3c
--- /dev/null
+++ b/pycparser/ast_transforms.py
@@ -0,0 +1,105 @@
+#------------------------------------------------------------------------------
+# pycparser: ast_transforms.py
+#
+# Some utilities used by the parser to create a friendlier AST.
+#
+# Copyright (C) 2008-2012, Eli Bendersky
+# License: BSD
+#------------------------------------------------------------------------------
+
+from . import c_ast
+
+
+def fix_switch_cases(switch_node):
+ """ The 'case' statements in a 'switch' come out of parsing with one
+ child node, so subsequent statements are just tucked to the parent
+ Compound. Additionally, consecutive (fall-through) case statements
+ come out messy. This is a peculiarity of the C grammar. The following:
+
+ switch (myvar) {
+ case 10:
+ k = 10;
+ p = k + 1;
+ return 10;
+ case 20:
+ case 30:
+ return 20;
+ default:
+ break;
+ }
+
+ Creates this tree (pseudo-dump):
+
+ Switch
+ ID: myvar
+ Compound:
+ Case 10:
+ k = 10
+ p = k + 1
+ return 10
+ Case 20:
+ Case 30:
+ return 20
+ Default:
+ break
+
+ The goal of this transform it to fix this mess, turning it into the
+ following:
+
+ Switch
+ ID: myvar
+ Compound:
+ Case 10:
+ k = 10
+ p = k + 1
+ return 10
+ Case 20:
+ Case 30:
+ return 20
+ Default:
+ break
+
+ A fixed AST node is returned. The argument may be modified.
+ """
+ assert isinstance(switch_node, c_ast.Switch)
+ if not isinstance(switch_node.stmt, c_ast.Compound):
+ return switch_node
+
+ # The new Compound child for the Switch, which will collect children in the
+ # correct order
+ new_compound = c_ast.Compound([], switch_node.stmt.coord)
+
+ # The last Case/Default node
+ last_case = None
+
+ # Goes over the children of the Compound below the Switch, adding them
+ # either directly below new_compound or below the last Case as appropriate
+ for child in switch_node.stmt.block_items:
+ if isinstance(child, (c_ast.Case, c_ast.Default)):
+ # If it's a Case/Default:
+ # 1. Add it to the Compound and mark as "last case"
+ # 2. If its immediate child is also a Case or Default, promote it
+ # to a sibling.
+ new_compound.block_items.append(child)
+ _extract_nested_case(child, new_compound.block_items)
+ last_case = new_compound.block_items[-1]
+ else:
+ # Other statements are added as childrent to the last case, if it
+ # exists.
+ if last_case is None:
+ new_compound.block_items.append(child)
+ else:
+ last_case.stmts.append(child)
+
+ switch_node.stmt = new_compound
+ return switch_node
+
+
+def _extract_nested_case(case_node, stmts_list):
+ """ Recursively extract consecutive Case statements that are made nested
+ by the parser and add them to the stmts_list.
+ """
+ if isinstance(case_node.stmts[0], (c_ast.Case, c_ast.Default)):
+ stmts_list.append(case_node.stmts.pop())
+ _extract_nested_case(stmts_list[-1], stmts_list)
+
diff --git a/pycparser/c_ast.py b/pycparser/c_ast.py
index 5868b9b..a1c92fb 100644
--- a/pycparser/c_ast.py
+++ b/pycparser/c_ast.py
@@ -194,15 +194,16 @@ class Break(Node):
attr_names = ()
class Case(Node):
- def __init__(self, expr, stmt, coord=None):
+ def __init__(self, expr, stmts, coord=None):
self.expr = expr
- self.stmt = stmt
+ self.stmts = stmts
self.coord = coord
def children(self):
nodelist = []
if self.expr is not None: nodelist.append(("expr", self.expr))
- if self.stmt is not None: nodelist.append(("stmt", self.stmt))
+ for i, child in enumerate(self.stmts or []):
+ nodelist.append(("stmts[%d]" % i, child))
return tuple(nodelist)
attr_names = ()
@@ -303,13 +304,14 @@ class DeclList(Node):
attr_names = ()
class Default(Node):
- def __init__(self, stmt, coord=None):
- self.stmt = stmt
+ def __init__(self, stmts, coord=None):
+ self.stmts = stmts
self.coord = coord
def children(self):
nodelist = []
- if self.stmt is not None: nodelist.append(("stmt", self.stmt))
+ for i, child in enumerate(self.stmts or []):
+ nodelist.append(("stmts[%d]" % i, child))
return tuple(nodelist)
attr_names = ()
diff --git a/pycparser/c_parser.py b/pycparser/c_parser.py
index a53ccd4..cd21c3d 100644
--- a/pycparser/c_parser.py
+++ b/pycparser/c_parser.py
@@ -13,6 +13,7 @@ import ply.yacc
from . import c_ast
from .c_lexer import CLexer
from .plyparser import PLYParser, Coord, ParseError
+from .ast_transforms import fix_switch_cases
class CParser(PLYParser):
@@ -1085,11 +1086,11 @@ class CParser(PLYParser):
def p_labeled_statement_2(self, p):
""" labeled_statement : CASE constant_expression COLON statement """
- p[0] = c_ast.Case(p[2], p[4], self._coord(p.lineno(1)))
+ p[0] = c_ast.Case(p[2], [p[4]], self._coord(p.lineno(1)))
def p_labeled_statement_3(self, p):
""" labeled_statement : DEFAULT COLON statement """
- p[0] = c_ast.Default(p[3], self._coord(p.lineno(1)))
+ p[0] = c_ast.Default([p[3]], self._coord(p.lineno(1)))
def p_selection_statement_1(self, p):
""" selection_statement : IF LPAREN expression RPAREN statement """
@@ -1101,7 +1102,8 @@ class CParser(PLYParser):
def p_selection_statement_3(self, p):
""" selection_statement : SWITCH LPAREN expression RPAREN statement """
- p[0] = c_ast.Switch(p[3], p[5], self._coord(p.lineno(1)))
+ p[0] = fix_switch_cases(
+ c_ast.Switch(p[3], p[5], self._coord(p.lineno(1))))
def p_iteration_statement_1(self, p):
""" iteration_statement : WHILE LPAREN expression RPAREN statement """
diff --git a/tests/test_c_parser.py b/tests/test_c_parser.py
index c4292d7..dbf7533 100644
--- a/tests/test_c_parser.py
+++ b/tests/test_c_parser.py
@@ -1284,6 +1284,14 @@ class TestCParser_whole_code(TestCParser_base):
self.assert_num_klass_nodes(ps1, Return, 1)
def test_switch_statement(self):
+ def assert_case_node(node, const_value):
+ self.failUnless(isinstance(node, Case))
+ self.failUnless(isinstance(node.expr, Constant))
+ self.assertEqual(node.expr.value, const_value)
+
+ def assert_default_node(node):
+ self.failUnless(isinstance(node, Default))
+
s1 = r'''
int foo(void) {
switch (myvar) {
@@ -1301,7 +1309,46 @@ class TestCParser_whole_code(TestCParser_base):
}
'''
ps1 = self.parse(s1)
- #~ ps1.show()
+ switch = ps1.ext[0].body.block_items[0]
+
+ block = switch.stmt.block_items
+ assert_case_node(block[0], '10')
+ self.assertEqual(len(block[0].stmts), 3)
+ assert_case_node(block[1], '20')
+ self.assertEqual(len(block[1].stmts), 0)
+ assert_case_node(block[2], '30')
+ self.assertEqual(len(block[2].stmts), 1)
+ assert_default_node(block[3])
+
+ s2 = r'''
+ int foo(void) {
+ switch (myvar) {
+ default:
+ joe = moe;
+ return 10;
+ case 10:
+ case 20:
+ case 30:
+ case 40:
+ break;
+ }
+ return 0;
+ }
+ '''
+ ps2 = self.parse(s2)
+ switch = ps2.ext[0].body.block_items[0]
+
+ block = switch.stmt.block_items
+ assert_default_node(block[0])
+ self.assertEqual(len(block[0].stmts), 2)
+ assert_case_node(block[1], '10')
+ self.assertEqual(len(block[1].stmts), 0)
+ assert_case_node(block[2], '20')
+ self.assertEqual(len(block[1].stmts), 0)
+ assert_case_node(block[3], '30')
+ self.assertEqual(len(block[1].stmts), 0)
+ assert_case_node(block[4], '40')
+ self.assertEqual(len(block[4].stmts), 1)
def test_for_statement(self):
s2 = r'''