OILS / mycpp / control_flow_pass.py View on Github | oilshell.org

348 lines, 240 significant
1"""
2control_flow_pass.py - AST pass that builds a control flow graph.
3"""
4import collections
5from typing import overload, Union, Optional, Dict
6
7import mypy
8from mypy.nodes import (Block, Expression, Statement, ExpressionStmt, StrExpr,
9 CallExpr, FuncDef, IfStmt, NameExpr, MemberExpr)
10
11from mypy.types import CallableType, Instance, Type, UnionType
12
13from mycpp.crash import catch_errors
14from mycpp.util import join_name, split_py_name
15from mycpp.visitor import SimpleVisitor, T
16from mycpp import util
17from mycpp import pass_state
18
19
class UnsupportedException(Exception):
    """Raised by a visitor method for a construct it cannot handle.

    Build.accept() catches this exception so that the pass can keep going
    after hitting an unsupported node.
    """
22
23
class Build(SimpleVisitor):
    """AST visitor that builds a control flow graph for each function.

    Graphs are stored in `self.cfgs`, keyed by the split fully qualified name
    of the function (see current_cfg()). Call expressions encountered inside
    a function are recorded as pass_state.FunctionCall facts attached to the
    statement that contains them, and also in `self.callees`.
    """

    def __init__(self, types: Dict[Expression, Type], virtual, local_vars,
                 dot_exprs):
        """
        Args:
          types: maps expression nodes to their mypy types.
          virtual: object queried with IsVirtual(class_name, method_name)
            and indexed through its `virtuals` mapping.
          local_vars: maps a function node to an iterable of
            (variable name, type) pairs; the type may be a string or a mypy
            type object.
          dot_exprs: maps member expressions to pass_state objects; only
            pass_state.ModuleMember is checked for here.
        """
        self.types = types
        # Function SymbolPath -> ControlFlowGraph, created on first access.
        self.cfgs = collections.defaultdict(pass_state.ControlFlowGraph)
        # ID returned by cfg.AddStatement() for the statement being visited.
        self.current_statement_id = None
        self.current_class_name = None
        self.current_func_node = None
        # Loop contexts of the enclosing loops; innermost is last.
        self.loop_stack = []
        self.virtual = virtual
        self.local_vars = local_vars
        self.dot_exprs = dot_exprs
        self.callees = {}  # statement object -> SymbolPath of the callee

    def current_cfg(self):
        """Returns the graph of the function being visited, or None when the
        visitor is not inside a function."""
        if not self.current_func_node:
            return None

        return self.cfgs[split_py_name(self.current_func_node.fullname)]

    def resolve_callee(self, o: CallExpr) -> Optional[util.SymbolPath]:
        """
        Returns the fully qualified name of the callee in the given call
        expression.

        Member functions are prefixed by the names of the classes that contain
        them. For example, the name of the callee in the last statement of the
        snippet below is `module.SomeObject.Foo`.

            x = module.SomeObject()
            x.Foo()

        Free-functions defined in the local module are referred to by their
        normal fully qualified names. The function `foo` in a module called
        `moduleA` is named `moduleA.foo`. Calls to free-functions defined
        in imported modules are named the same way.

        Returns None when the callee can't be resolved to a name that this
        pass cares about.
        """
        callee = o.callee
        if isinstance(callee, NameExpr):
            return split_py_name(callee.fullname)

        if not isinstance(callee, MemberExpr):
            # Don't currently get here
            raise AssertionError()

        method = (callee.name, )
        receiver = callee.expr

        if not isinstance(receiver, NameExpr):
            # The receiver is itself an expression (e.g. a.b.Foo() or
            # f().Foo()), so go by its recorded type.
            t = self.types.get(receiver)
            if isinstance(t, Instance):
                return split_py_name(t.type.fullname) + method

            if isinstance(t, UnionType):
                assert len(t.items) == 2
                return split_py_name(t.items[0].type.fullname) + method

            if receiver and getattr(receiver, 'fullname', None):
                return split_py_name(receiver.fullname) + method

            # constructors of things that we don't care about.
            return None

        if isinstance(self.dot_exprs.get(callee), pass_state.ModuleMember):
            return split_py_name(receiver.fullname) + method

        if receiver.name == 'self':
            assert self.current_class_name
            return self.current_class_name + method

        # The receiver is a plain name: look for a local variable of the
        # current function with that name.
        local_type = next(
            (t for name, t in self.local_vars.get(self.current_func_node, [])
             if name == receiver.name), None)

        if not local_type:
            # context or exception handler. probably safe to ignore.
            return None

        if isinstance(local_type, str):
            return split_py_name(local_type) + method

        if isinstance(local_type, Instance):
            return split_py_name(local_type.type.fullname) + method

        if isinstance(local_type, UnionType):
            assert len(local_type.items) == 2
            # The path ends with the method name, like every other branch.
            # (This used to append the receiver variable's name instead.)
            return split_py_name(local_type.items[0].type.fullname) + method

        assert not isinstance(local_type, CallableType)
        # primitive type or string. don't care.
        return None

    #
    # COPIED from IRBuilder
    #

    @overload
    def accept(self, node: Expression) -> T:
        ...

    @overload
    def accept(self, node: Statement) -> None:
        ...

    def accept(self, node: Union[Statement, Expression]) -> Optional[T]:
        """Visits a node, allocating a CFG statement for it when appropriate.

        Expressions return whatever their visitor returns. Statements return
        None; when inside a function, each non-Block statement first gets a
        fresh statement ID from the current graph.
        """
        with catch_errors(self.module_path, node.line):
            if isinstance(node, Expression):
                try:
                    res = node.accept(self)
                    #res = self.coerce(res, self.node_type(node), node.line)

                # If we hit an error during compilation, we want to
                # keep trying, so we can produce more error
                # messages. Generate a temp of the right type to keep
                # from causing more downstream trouble.
                except UnsupportedException:
                    # NOTE(review): alloc_temp() and node_type() are not
                    # defined in this class - confirm the base class has them.
                    res = self.alloc_temp(self.node_type(node))
                return res
            else:
                try:
                    cfg = self.current_cfg()
                    # Most statements have empty visitors because they don't
                    # require any special logic. Create statements for them
                    # here. Don't create statements from blocks to avoid
                    # stuttering.
                    if cfg and not isinstance(node, Block):
                        self.current_statement_id = cfg.AddStatement()

                    node.accept(self)
                except UnsupportedException:
                    # Deliberately best-effort: skip what we can't handle.
                    pass
                return None

    # Not in superclasses:

    def visit_mypy_file(self, o: 'mypy.nodes.MypyFile') -> T:
        """Visits every top-level definition of a module, unless
        util.ShouldSkipPyFile() says to skip the file."""
        if util.ShouldSkipPyFile(o):
            return

        self.module_path = o.path

        for node in o.defs:
            # skip module docstring
            if isinstance(node, ExpressionStmt) and isinstance(
                    node.expr, StrExpr):
                continue
            self.accept(node)

    # Statements

    def _visit_loop(self, o) -> None:
        """Shared by for and while loops: visits o.expr, then o.body with the
        loop context on the loop stack so break/continue can find it."""
        cfg = self.current_cfg()
        with pass_state.CfgLoopContext(
                cfg, entry=self.current_statement_id) as loop:
            self.accept(o.expr)
            self.loop_stack.append(loop)
            self.accept(o.body)
            self.loop_stack.pop()

    def visit_for_stmt(self, o: 'mypy.nodes.ForStmt') -> T:
        self._visit_loop(o)

    def _handle_switch(self, expr, o, cfg):
        """Adds one branch per case of a switch-like `with` statement.

        The body of the `with` must be a single if statement; its cases are
        extracted by util._collect_cases(), which also returns the default
        block, if any.
        """
        assert len(o.body.body) == 1, o.body.body
        if_node = o.body.body[0]
        assert isinstance(if_node, IfStmt), if_node
        cases = []
        default_block = util._collect_cases(self.module_path, if_node, cases)
        with pass_state.CfgBranchContext(
                cfg, self.current_statement_id) as branch_ctx:
            for case_expr, body in cases:
                # Check before use, so that a missing case expression fails
                # with this assertion.
                assert case_expr is not None, case_expr
                self.accept(case_expr)
                with branch_ctx.AddBranch():
                    self.accept(body)

            if default_block:
                with branch_ctx.AddBranch():
                    self.accept(default_block)

    def visit_with_stmt(self, o: 'mypy.nodes.WithStmt') -> T:
        """Handles `with` statements: the switch-like helpers become
        branches, anything else becomes a plain block."""
        cfg = self.current_cfg()
        assert len(o.expr) == 1, o.expr
        expr = o.expr[0]
        assert isinstance(expr, CallExpr), expr
        self.accept(expr)

        callee_name = expr.callee.name
        if callee_name in ('switch', 'str_switch', 'tagswitch'):
            self._handle_switch(expr, o, cfg)
        else:
            with pass_state.CfgBlockContext(cfg, self.current_statement_id):
                self.accept(o.body)

    def visit_func_def(self, o: 'mypy.nodes.FuncDef') -> T:
        """Builds the graph for one function or method body."""
        if o.name == '__repr__':  # Don't translate
            return

        # For virtual methods, pretend that the method on the base class calls
        # the same method on every subclass. This way call sites using the
        # abstract base class will over-approximate the set of call paths they
        # can take when checking if they can reach MaybeCollect().
        if self.current_class_name and self.virtual.IsVirtual(
                self.current_class_name, o.name):
            key = (self.current_class_name, o.name)
            base = self.virtual.virtuals[key]
            if base:
                sub = join_name(self.current_class_name + (o.name, ),
                                delim='.')
                base_key = base[0] + (base[1], )
                cfg = self.cfgs[base_key]
                cfg.AddFact(0, pass_state.FunctionCall(sub))

        self.current_func_node = o
        self.accept(o.body)
        self.current_func_node = None
        self.current_statement_id = None

    def visit_class_def(self, o: 'mypy.nodes.ClassDef') -> T:
        """Visits the statements of a class body with current_class_name
        set to the class's split fully qualified name."""
        self.current_class_name = split_py_name(o.fullname)
        for stmt in o.defs.body:
            # Ignore things that look like docstrings
            if (isinstance(stmt, ExpressionStmt) and
                    isinstance(stmt.expr, StrExpr)):
                continue

            if isinstance(stmt, FuncDef) and stmt.name == '__repr__':
                continue

            self.accept(stmt)

        self.current_class_name = None

    def visit_while_stmt(self, o: 'mypy.nodes.WhileStmt') -> T:
        self._visit_loop(o)

    def visit_return_stmt(self, o: 'mypy.nodes.ReturnStmt') -> T:
        """Marks the current statement as a dead end, then visits the
        returned expression, if any."""
        cfg = self.current_cfg()
        if cfg:
            cfg.AddDeadend(self.current_statement_id)

        if o.expr:
            self.accept(o.expr)

    def visit_if_stmt(self, o: 'mypy.nodes.IfStmt') -> T:
        """Visits the conditions, then adds one branch for the bodies in
        o.body and one for the else body, if there is one."""
        cfg = self.current_cfg()
        for expr in o.expr:
            self.accept(expr)

        with pass_state.CfgBranchContext(
                cfg, self.current_statement_id) as branch_ctx:
            with branch_ctx.AddBranch():
                for node in o.body:
                    self.accept(node)

            if o.else_body:
                with branch_ctx.AddBranch():
                    self.accept(o.else_body)

    def visit_break_stmt(self, o: 'mypy.nodes.BreakStmt') -> T:
        """Registers the current statement as a break of the innermost
        loop."""
        if self.loop_stack:
            self.loop_stack[-1].AddBreak(self.current_statement_id)

    def visit_continue_stmt(self, o: 'mypy.nodes.ContinueStmt') -> T:
        """Registers the current statement as a continue of the innermost
        loop."""
        if self.loop_stack:
            self.loop_stack[-1].AddContinue(self.current_statement_id)

    def visit_raise_stmt(self, o: 'mypy.nodes.RaiseStmt') -> T:
        """Marks the current statement as a dead end, then visits the raised
        expression, if any."""
        cfg = self.current_cfg()
        if cfg:
            cfg.AddDeadend(self.current_statement_id)

        if o.expr:
            self.accept(o.expr)

    def visit_try_stmt(self, o: 'mypy.nodes.TryStmt') -> T:
        """Adds a branch for the try body, and one branch per handler that
        is attached to the exit of the try body's branch."""
        cfg = self.current_cfg()
        with pass_state.CfgBranchContext(cfg,
                                         self.current_statement_id) as try_ctx:
            with try_ctx.AddBranch() as try_block:
                self.accept(o.body)

            # Only the handler blocks are needed here, not the exception
            # types or the variables they are bound to.
            for handler in o.handlers:
                with try_ctx.AddBranch(try_block.exit):
                    self.accept(handler)

    # Expressions

    def visit_call_expr(self, o: 'mypy.nodes.CallExpr') -> T:
        """Records a FunctionCall fact on the current statement when the
        callee can be resolved, then visits the callee and the arguments."""
        cfg = self.current_cfg()
        if self.current_func_node:
            full_callee = self.resolve_callee(o)
            if full_callee:
                self.callees[o] = full_callee
                cfg.AddFact(
                    self.current_statement_id,
                    pass_state.FunctionCall(join_name(full_callee, delim='.')))

        self.accept(o.callee)
        for arg in o.args:
            self.accept(arg)