OILS / mycpp / control_flow_pass.py View on Github | oilshell.org

238 lines, 167 significant
1"""
2control_flow_pass.py - AST pass that builds a control flow graph.
3"""
4from typing import overload, Union, Optional, Dict
5
6import mypy
7from mypy.visitor import ExpressionVisitor, StatementVisitor
8from mypy.nodes import (Block, Expression, Statement, ExpressionStmt, StrExpr,
9 ForStmt, WhileStmt, CallExpr, FuncDef, IfStmt)
10
11from mypy.types import Type
12
13from mycpp.crash import catch_errors
14from mycpp.util import split_py_name
15from mycpp import util
16from mycpp import pass_state
17
# Stand-in for the ExpressionVisitor result type parameter (Build subclasses
# ExpressionVisitor[T]) and for the "-> T" return annotations in this module.
# The visit methods here all return None, so None is used as the value.
T = None  # TODO: Make it type check?
19
20
class UnsupportedException(Exception):
    """Exception for constructs this pass does not support (per its name).

    Build.accept() catches it around each visited node, so visiting can
    continue with the next node instead of aborting the whole pass.
    """
23
24
class Build(ExpressionVisitor[T], StatementVisitor[None]):
    """AST visitor that builds one control flow graph per translated function.

    visit_func_def() registers a pass_state.ControlFlowGraph in self.cfgs,
    keyed by split_py_name(func.fullname). The statement visitors then add
    statements, branches, loops and dead ends to that graph. Code outside any
    function is still walked, but nothing is recorded for it because
    current_cfg() returns None there. __repr__ methods are skipped entirely.
    """

    def __init__(self, types: Dict[Expression, Type]):
        # Expression -> mypy type map (not read anywhere in this class).
        self.types = types
        # split_py_name(fullname) -> pass_state.ControlFlowGraph
        self.cfgs = {}
        # Id returned by cfg.AddStatement() for the statement being visited;
        # branches, dead ends, breaks and continues are attached to it.
        self.current_statement_id = None
        # FuncDef whose body is being visited, or None outside functions.
        self.current_func_node = None
        # pass_state.CfgLoopContext objects for the enclosing loops, innermost
        # last; used to resolve break/continue.
        self.loop_stack = []
        # Path of the module being visited, used for error reports. Set by
        # visit_mypy_file(); initialized here so accept() cannot fail with an
        # AttributeError when it is used before visit_mypy_file().
        self.module_path = None

    def current_cfg(self):
        """Return the CFG of the function being visited, or None.

        None is returned outside of any function, and also when no CFG is
        registered under the current function's name.
        """
        if self.current_func_node is None:
            return None

        return self.cfgs.get(split_py_name(self.current_func_node.fullname))

    @staticmethod
    def _is_docstring(stmt) -> bool:
        """Return True for a bare string-literal statement (a docstring)."""
        return (isinstance(stmt, ExpressionStmt) and
                isinstance(stmt.expr, StrExpr))

    #
    # COPIED from IRBuilder
    #

    @overload
    def accept(self, node: Expression) -> T:
        ...

    @overload
    def accept(self, node: Statement) -> None:
        ...

    def accept(self, node: Union[Statement, Expression]) -> Optional[T]:
        """Visit node, first recording a CFG statement for most statements.

        Errors raised while visiting are reported by catch_errors() with the
        module path and the node's line. UnsupportedException is caught here
        so visiting continues with the next node.
        """
        with catch_errors(self.module_path, node.line):
            if isinstance(node, Expression):
                try:
                    return node.accept(self)
                except UnsupportedException:
                    # IRBuilder allocated a typed temp here; this pass has no
                    # alloc_temp()/node_type() and produces no expression
                    # values, so there is nothing to return.
                    return None

            try:
                cfg = self.current_cfg()
                # Most statements have empty visitors because they don't
                # require any special logic, so their CFG statement is
                # created here. Blocks and loops are handled by their
                # visitors.
                if cfg and not isinstance(node, (Block, ForStmt, WhileStmt)):
                    self.current_statement_id = cfg.AddStatement()

                node.accept(self)
            except UnsupportedException:
                pass
            return None

    # Not in superclasses:

    def visit_mypy_file(self, o: 'mypy.nodes.MypyFile') -> T:
        """Visit the top-level definitions of a module, unless it is skipped."""
        if util.ShouldSkipPyFile(o):
            return

        self.module_path = o.path

        for node in o.defs:
            # skip module docstring
            if self._is_docstring(node):
                continue
            self.accept(node)

    # LITERALS

    def _handle_loop(self, o: Union[ForStmt, WhileStmt]) -> None:
        """Visit a loop body inside a pass_state.CfgLoopContext.

        The context is pushed on loop_stack while the body is visited, so
        break/continue statements find their innermost loop. The loop's else
        clause (o.else_body) is not visited here.
        """
        cfg = self.current_cfg()
        if not cfg:
            return

        with pass_state.CfgLoopContext(cfg) as loop:
            self.loop_stack.append(loop)
            try:
                self.accept(o.body)
            finally:
                # Keep loop_stack balanced even if visiting the body raises.
                self.loop_stack.pop()

    def visit_for_stmt(self, o: 'mypy.nodes.ForStmt') -> T:
        self._handle_loop(o)

    def _handle_switch(self, expr, o, cfg):
        """Add one CFG branch per case of a switch/str_switch/tagswitch body.

        The with-body must be a single IfStmt. util._collect_cases() fills
        `cases` with (case expression, body) pairs and returns the default
        block, which becomes one more branch if present. The `expr` argument
        (the switch call) is not used here.
        """
        assert len(o.body.body) == 1, o.body.body
        if_node = o.body.body[0]
        assert isinstance(if_node, IfStmt), if_node
        cases = []
        default_block = util._collect_cases(self.module_path, if_node, cases)
        with pass_state.CfgBranchContext(
                cfg, self.current_statement_id) as branch_ctx:
            for case_expr, body in cases:
                assert case_expr is not None, case_expr
                with branch_ctx.AddBranch():
                    self.accept(body)

            if default_block:
                with branch_ctx.AddBranch():
                    self.accept(default_block)

    def visit_with_stmt(self, o: 'mypy.nodes.WithStmt') -> T:
        """Handle switch-style context managers and plain `with` blocks.

        `with switch(...)`, `with str_switch(...)` and `with tagswitch(...)`
        become branches, one per case. Any other single call context manager
        wraps its body in a pass_state.CfgBlockContext.
        """
        cfg = self.current_cfg()
        if not cfg:
            return

        assert len(o.expr) == 1, o.expr
        expr = o.expr[0]
        assert isinstance(expr, CallExpr), expr

        if expr.callee.name in ('switch', 'str_switch', 'tagswitch'):
            self._handle_switch(expr, o, cfg)
        else:
            with pass_state.CfgBlockContext(cfg, self.current_statement_id):
                for stmt in o.body.body:
                    self.accept(stmt)

    def visit_func_def(self, o: 'mypy.nodes.FuncDef') -> T:
        """Create a fresh CFG for the function and visit its body into it."""
        if o.name == '__repr__':  # Don't translate
            return

        self.cfgs[split_py_name(o.fullname)] = pass_state.ControlFlowGraph()

        # Save and restore the enclosing state. After a nested def, the rest
        # of the outer function is then still recorded in the outer CFG,
        # attached to the outer statement.
        outer_func_node = self.current_func_node
        outer_statement_id = self.current_statement_id
        self.current_func_node = o
        try:
            self.accept(o.body)
        finally:
            self.current_func_node = outer_func_node
            self.current_statement_id = outer_statement_id

    def visit_class_def(self, o: 'mypy.nodes.ClassDef') -> T:
        """Visit class members, skipping docstrings and __repr__."""
        for stmt in o.defs.body:
            # Ignore things that look like docstrings
            if self._is_docstring(stmt):
                continue

            if isinstance(stmt, FuncDef) and stmt.name == '__repr__':
                continue

            self.accept(stmt)

    # Statements

    def visit_block(self, block: 'mypy.nodes.Block') -> T:
        for stmt in block.body:
            # Ignore things that look like docstrings
            if self._is_docstring(stmt):
                continue

            self.accept(stmt)

    def visit_expression_stmt(self, o: 'mypy.nodes.ExpressionStmt') -> T:
        self.accept(o.expr)

    def visit_while_stmt(self, o: 'mypy.nodes.WhileStmt') -> T:
        self._handle_loop(o)

    def _add_deadend(self) -> None:
        """Mark the current statement as a dead end (used by return/raise)."""
        cfg = self.current_cfg()
        if cfg:
            cfg.AddDeadend(self.current_statement_id)

    def visit_return_stmt(self, o: 'mypy.nodes.ReturnStmt') -> T:
        self._add_deadend()

    def visit_if_stmt(self, o: 'mypy.nodes.IfStmt') -> T:
        """Add one branch per condition body, plus one for the else body."""
        cfg = self.current_cfg()
        if not cfg:
            return

        with pass_state.CfgBranchContext(
                cfg, self.current_statement_id) as branch_ctx:
            # o.body holds one block per if/elif condition. Each is a separate
            # path, so each gets its own branch, as switch cases do.
            for body in o.body:
                with branch_ctx.AddBranch():
                    self.accept(body)

            if o.else_body:
                with branch_ctx.AddBranch():
                    self.accept(o.else_body)

    def visit_break_stmt(self, o: 'mypy.nodes.BreakStmt') -> T:
        """Record a break on the innermost loop; ignored outside loops."""
        if self.loop_stack:
            self.loop_stack[-1].AddBreak(self.current_statement_id)

    def visit_continue_stmt(self, o: 'mypy.nodes.ContinueStmt') -> T:
        """Record a continue on the innermost loop; ignored outside loops."""
        if self.loop_stack:
            self.loop_stack[-1].AddContinue(self.current_statement_id)

    def visit_raise_stmt(self, o: 'mypy.nodes.RaiseStmt') -> T:
        self._add_deadend()

    def visit_try_stmt(self, o: 'mypy.nodes.TryStmt') -> T:
        """Branch into the try body, then into each handler from its exit."""
        cfg = self.current_cfg()
        if not cfg:
            return

        with pass_state.CfgBranchContext(cfg,
                                         self.current_statement_id) as try_ctx:
            with try_ctx.AddBranch() as try_block:
                self.accept(o.body)

            # Each except handler branches from try_block.exit.
            # NOTE(review): o.else_body and o.finally_body are not visited
            # here; confirm they are rejected by another pass.
            for handler in o.handlers:
                with try_ctx.AddBranch(try_block.exit):
                    self.accept(handler)