# OILS / mycpp / control_flow_pass.py (View on Github | oilshell.org)
#
# 505 lines, 316 significant
"""
control_flow_pass.py - AST pass that builds a control flow graph.
"""
4import collections
5from typing import overload, Union, Optional, Dict
6
7import mypy
8from mypy.nodes import (Block, Expression, Statement, ExpressionStmt, StrExpr,
9 CallExpr, FuncDef, IfStmt, NameExpr, MemberExpr,
10 IndexExpr, TupleExpr)
11
12from mypy.types import CallableType, Instance, Type, UnionType, NoneTyp, TupleType
13
14from mycpp.crash import catch_errors
15from mycpp.util import join_name, split_py_name
16from mycpp.visitor import SimpleVisitor, T
17from mycpp import util
18from mycpp import pass_state
19
20
class UnsupportedException(Exception):
    """Raised for constructs this pass cannot handle.

    Build.accept() catches this exception so that traversal can continue
    with the next node.
    """
23
24
def GetObjectTypeName(t: Type) -> util.SymbolPath:
    """Return the fully qualified name of the class behind the type `t`.

    An Instance is named after its class. A two-member union is named after
    its first member, unless that member is None, in which case the second
    member is used. Any other kind of type trips an assertion.
    """
    if isinstance(t, UnionType):
        assert len(t.items) == 2
        # Skip over the None half of an Optional-style union.
        none_first = isinstance(t.items[0], NoneTyp)
        chosen = t.items[1] if none_first else t.items[0]
        return GetObjectTypeName(chosen)

    if isinstance(t, Instance):
        return split_py_name(t.type.fullname)

    assert False, t
37
38
class Build(SimpleVisitor):
    """AST visitor that builds one control flow graph per function.

    Graphs are stored in `self.cfgs`, keyed by the SymbolPath of the function
    that is being visited. While walking a function body the visitor records
    facts (function calls, assignments, definitions) against the ID of the
    statement that is currently being visited.
    """

    def __init__(self, types: Dict[Expression, Type], virtual, local_vars,
                 dot_exprs):
        """
        Args:
          types: Maps expressions to their types.
          virtual: Queried through IsVirtual() and `.virtuals` to find the
            base-class declaration of a virtual method.
          local_vars: Maps a function node to (name, type) pairs for its
            local variables.
          dot_exprs: Maps MemberExprs to pass_state member descriptors
            (ModuleMember, HeapObjectMember, StackObjectMember).
        """
        self.types = types
        # SymbolPath of a function -> its ControlFlowGraph (created lazily).
        self.cfgs = collections.defaultdict(pass_state.ControlFlowGraph)
        self.current_statement_id = None
        self.current_class_name = None
        self.current_func_node = None
        # Innermost loop context is last.
        self.loop_stack = []
        self.virtual = virtual
        self.local_vars = local_vars
        self.dot_exprs = dot_exprs
        self.callees = {}  # CallExpr -> SymbolPath of the callee

    def current_cfg(self):
        """Return the CFG of the function being visited, or None if the
        visitor is not currently inside a function."""
        if not self.current_func_node:
            return None

        return self.cfgs[split_py_name(self.current_func_node.fullname)]

    def resolve_callee(self, o: CallExpr) -> Optional[util.SymbolPath]:
        """
        Returns the fully qualified name of the callee in the given call
        expression, or None if the callee is something we don't track.

        Member functions are prefixed by the names of the classes that contain
        them. For example, the name of the callee in the last statement of the
        snippet below is `module.SomeObject.Foo`.

            x = module.SomeObject()
            x.Foo()

        Free-functions defined in the local module are referred to by their
        normal fully qualified names. The function `foo` in a module called
        `moduleA` is named `moduleA.foo`. Calls to free-functions defined
        in imported modules are named the same way.

        Raises AssertionError if the callee is neither a NameExpr nor a
        MemberExpr.
        """
        callee = o.callee

        if isinstance(callee, NameExpr):
            return split_py_name(callee.fullname)

        elif isinstance(callee, MemberExpr):
            method = (callee.name, )
            receiver = callee.expr

            if isinstance(receiver, NameExpr):
                is_module = isinstance(self.dot_exprs.get(callee),
                                       pass_state.ModuleMember)
                if is_module:
                    return split_py_name(receiver.fullname) + method

                elif receiver.name == 'self':
                    assert self.current_class_name
                    return self.current_class_name + method

                else:
                    # The receiver is a local variable; look up its type.
                    local_type = None
                    for name, t in self.local_vars.get(self.current_func_node,
                                                       []):
                        if name == receiver.name:
                            local_type = t
                            break

                    if local_type:
                        if isinstance(local_type, str):
                            return split_py_name(local_type) + method

                        elif isinstance(local_type, Instance):
                            return split_py_name(
                                local_type.type.fullname) + method

                        elif isinstance(local_type, UnionType):
                            # Name the call after the non-None member of the
                            # union and the *method* (not the receiver
                            # variable), like every other branch here.
                            return GetObjectTypeName(local_type) + method

                        else:
                            assert not isinstance(local_type, CallableType)
                            # primitive type or string. don't care.
                            return None

                    else:
                        # context or exception handler. probably safe to
                        # ignore.
                        return None

            else:
                t = self.types.get(receiver)
                if isinstance(t, Instance):
                    return split_py_name(t.type.fullname) + method

                elif isinstance(t, UnionType):
                    # Pick the non-None member of the union.
                    return GetObjectTypeName(t) + method

                elif receiver and getattr(receiver, 'fullname', None):
                    return split_py_name(receiver.fullname) + method

                else:
                    # constructors of things that we don't care about.
                    return None

        # Don't currently get here
        raise AssertionError()

    def get_variable_name(self, expr: Expression) -> Optional[util.SymbolPath]:
        """
        To do dataflow analysis we need to track changes to objects, which
        requires naming them. This function returns the name of the object
        referred to by the given expression. If the expression doesn't refer to
        an object it returns None.

        Objects are named slightly differently than they appear in the source
        code.

        Objects referenced by local variables are referred to by the name of the
        local. For example, the name of the object in both statements below is
        `x`.

            x = module.SomeObject()
            x = None

        Member expressions are named after the parent object's type. For
        example, the names of the objects in the member assignment statements
        below are both `module.SomeObject.member_a`. This makes it possible to
        track data flow across object members without having to track individual
        heap objects, which would increase the search space for analyses and
        slow things down.

            x = module.SomeObject()
            y = module.SomeObject()
            x.member_a = 'foo'
            y.member_a = 'bar'

        Index expressions are named after their bases, for the same reasons as
        member expressions. The coarse-grained precision should lead to an
        over-approximation of where objects are in use, but should not miss any
        references. This should be fine for our purposes. In the snippet below
        the last two assignments are named `x` and `module.SomeObject.a_list`.

            x = [None] # list[Thing]
            y = module.SomeObject()
            x[0] = Thing()
            y.a_list[1] = Blah()

        The examples above all deal with assignments, but these rules apply to
        any expression that uses an object.

        Returns None if expr does not refer to a variable or object.
        """
        if isinstance(expr,
                      NameExpr) and expr.name not in {'True', 'False', 'None'}:
            return (expr.name, )

        elif isinstance(expr, MemberExpr):
            dot_expr = self.dot_exprs[expr]
            if isinstance(dot_expr, pass_state.ModuleMember):
                return dot_expr.module_path + (dot_expr.member, )

            elif isinstance(dot_expr, (pass_state.HeapObjectMember,
                                       pass_state.StackObjectMember)):
                # Heap and stack members are named the same way: after the
                # type of the parent object.
                return GetObjectTypeName(
                    dot_expr.object_type) + (dot_expr.member, )

        elif isinstance(expr, IndexExpr):
            return self.get_variable_name(expr.base)

        return None

    #
    # COPIED from IRBuilder
    #

    @overload
    def accept(self, node: Expression) -> T:
        ...

    @overload
    def accept(self, node: Statement) -> None:
        ...

    def accept(self, node: Union[Statement, Expression]) -> Optional[T]:
        """Visit `node`, returning the visitor result for expressions and
        None for statements.

        Before a non-Block statement inside a function is visited, a new
        statement is added to the current CFG and its ID is stored in
        `self.current_statement_id`.
        """
        with catch_errors(self.module_path, node.line):
            if isinstance(node, Expression):
                try:
                    res = node.accept(self)

                # If we hit an error during compilation, we want to
                # keep trying, so we can produce more error
                # messages. Generate a temp of the right type to keep
                # from causing more downstream trouble.
                except UnsupportedException:
                    # NOTE(review): alloc_temp() and node_type() are not
                    # defined in this class - confirm the base class has them.
                    res = self.alloc_temp(self.node_type(node))
                return res
            else:
                try:
                    cfg = self.current_cfg()
                    # Most statements have empty visitors because they don't
                    # require any special logic. Create statements for them
                    # here. Don't create statements from blocks to avoid
                    # stuttering.
                    if cfg and not isinstance(node, Block):
                        self.current_statement_id = cfg.AddStatement()

                    node.accept(self)
                except UnsupportedException:
                    pass
                return None

    # Not in superclasses:

    def visit_mypy_file(self, o: 'mypy.nodes.MypyFile') -> T:
        """Visit every top-level definition of a module, unless
        util.ShouldSkipPyFile() says the file should be skipped."""
        if util.ShouldSkipPyFile(o):
            return

        self.module_path = o.path

        for node in o.defs:
            # skip module docstring
            if isinstance(node, ExpressionStmt) and isinstance(
                    node.expr, StrExpr):
                continue
            self.accept(node)

    # Statements

    def _visit_loop(self, o) -> None:
        """Shared handling for `for` and `while`: visit the loop expression
        and body inside a loop context that is entered at the current
        statement."""
        cfg = self.current_cfg()
        with pass_state.CfgLoopContext(
                cfg, entry=self.current_statement_id) as loop:
            self.accept(o.expr)
            # Make the loop visible to break/continue statements in the body.
            self.loop_stack.append(loop)
            self.accept(o.body)
            self.loop_stack.pop()

    def visit_for_stmt(self, o: 'mypy.nodes.ForStmt') -> T:
        """Visit a `for` loop."""
        self._visit_loop(o)

    def _handle_switch(self, expr, o, cfg):
        """Visit a `with switch(...)`-style block, which must contain a single
        if statement. Each case, and the default block if there is one,
        becomes a branch of one branch context."""
        assert len(o.body.body) == 1, o.body.body
        if_node = o.body.body[0]
        assert isinstance(if_node, IfStmt), if_node
        cases = []
        default_block = util._collect_cases(self.module_path, if_node, cases)
        with pass_state.CfgBranchContext(
                cfg, self.current_statement_id) as branch_ctx:
            for case_expr, body in cases:
                self.accept(case_expr)
                assert case_expr is not None, case_expr
                with branch_ctx.AddBranch():
                    self.accept(body)

            if default_block:
                with branch_ctx.AddBranch():
                    self.accept(default_block)

    def visit_with_stmt(self, o: 'mypy.nodes.WithStmt') -> T:
        """Visit a `with` statement, which must have exactly one call
        expression. The switch helpers are treated as branches; anything else
        is a plain block."""
        cfg = self.current_cfg()
        assert len(o.expr) == 1, o.expr
        expr = o.expr[0]
        assert isinstance(expr, CallExpr), expr
        self.accept(expr)

        callee_name = expr.callee.name
        if callee_name in ('switch', 'str_switch', 'tagswitch'):
            self._handle_switch(expr, o, cfg)
        else:
            with pass_state.CfgBlockContext(cfg, self.current_statement_id):
                self.accept(o.body)

    def visit_func_def(self, o: 'mypy.nodes.FuncDef') -> T:
        """Visit a function body with `o` as the current function. Resets the
        current function and statement ID afterwards."""
        if o.name == '__repr__':  # Don't translate
            return

        # For virtual methods, pretend that the method on the base class calls
        # the same method on every subclass. This way call sites using the
        # abstract base class will over-approximate the set of call paths they
        # can take when checking if they can reach MaybeCollect().
        if self.current_class_name and self.virtual.IsVirtual(
                self.current_class_name, o.name):
            key = (self.current_class_name, o.name)
            base = self.virtual.virtuals[key]
            if base:
                sub = join_name(self.current_class_name + (o.name, ),
                                delim='.')
                base_key = base[0] + (base[1], )
                cfg = self.cfgs[base_key]
                cfg.AddFact(0, pass_state.FunctionCall(sub))

        self.current_func_node = o
        self.accept(o.body)
        self.current_func_node = None
        self.current_statement_id = None

    def visit_class_def(self, o: 'mypy.nodes.ClassDef') -> T:
        """Visit the statements of a class body with the class as the current
        class, skipping docstrings and `__repr__`."""
        self.current_class_name = split_py_name(o.fullname)
        for stmt in o.defs.body:
            # Ignore things that look like docstrings
            if (isinstance(stmt, ExpressionStmt) and
                    isinstance(stmt.expr, StrExpr)):
                continue

            if isinstance(stmt, FuncDef) and stmt.name == '__repr__':
                continue

            self.accept(stmt)

        self.current_class_name = None

    def visit_while_stmt(self, o: 'mypy.nodes.WhileStmt') -> T:
        """Visit a `while` loop."""
        self._visit_loop(o)

    def visit_return_stmt(self, o: 'mypy.nodes.ReturnStmt') -> T:
        """Mark the current statement as a dead end, then visit the returned
        expression, if any."""
        cfg = self.current_cfg()
        if cfg:
            cfg.AddDeadend(self.current_statement_id)

        if o.expr:
            self.accept(o.expr)

    def visit_if_stmt(self, o: 'mypy.nodes.IfStmt') -> T:
        """Visit an if statement: conditions first, then the bodies in one
        branch and the else body, if any, in another."""
        if util.MaybeSkipIfStmt(self, o):
            return

        cfg = self.current_cfg()
        for expr in o.expr:
            self.accept(expr)

        with pass_state.CfgBranchContext(
                cfg, self.current_statement_id) as branch_ctx:
            with branch_ctx.AddBranch():
                for node in o.body:
                    self.accept(node)

            if o.else_body:
                with branch_ctx.AddBranch():
                    self.accept(o.else_body)

    def visit_break_stmt(self, o: 'mypy.nodes.BreakStmt') -> T:
        """Record a break from the innermost loop at the current statement."""
        if self.loop_stack:
            self.loop_stack[-1].AddBreak(self.current_statement_id)

    def visit_continue_stmt(self, o: 'mypy.nodes.ContinueStmt') -> T:
        """Record a continue of the innermost loop at the current
        statement."""
        if self.loop_stack:
            self.loop_stack[-1].AddContinue(self.current_statement_id)

    def visit_raise_stmt(self, o: 'mypy.nodes.RaiseStmt') -> T:
        """Mark the current statement as a dead end, then visit the raised
        expression, if any."""
        cfg = self.current_cfg()
        if cfg:
            cfg.AddDeadend(self.current_statement_id)

        if o.expr:
            self.accept(o.expr)

    def visit_try_stmt(self, o: 'mypy.nodes.TryStmt') -> T:
        """Visit a try statement: the body is one branch, and each handler is
        a further branch that is added with the try block's exit."""
        cfg = self.current_cfg()
        with pass_state.CfgBranchContext(cfg,
                                         self.current_statement_id) as try_ctx:
            with try_ctx.AddBranch() as try_block:
                self.accept(o.body)

            # Only the handler bodies matter here; the exception types and
            # variables are not visited.
            for _type, _var, handler in zip(o.types, o.vars, o.handlers):
                with try_ctx.AddBranch(try_block.exit):
                    self.accept(handler)

    def visit_assignment_stmt(self, o: 'mypy.nodes.AssignmentStmt') -> T:
        """Record what an assignment does to each named target, then visit
        its lvalues and rvalue.

        Inside a function, each target gets either an Assignment fact (the
        right-hand side is another named object) or a Definition fact
        (anything else). Only single-lvalue assignments are supported.
        """
        cfg = self.current_cfg()
        if cfg:
            assert len(o.lvalues) == 1
            lval = o.lvalues[0]
            if isinstance(lval, TupleExpr):
                lval_names = [
                    self.get_variable_name(item) for item in lval.items
                ]
            else:
                lval_names = [self.get_variable_name(lval)]

            assert lval_names, o

            rval_type = self.types[o.rvalue]
            if isinstance(o.rvalue, CallExpr):
                # The RHS is either an object constructor or something that
                # returns a primitive type (e.g. Tuple[int, int] or str).
                # XXX: When we add inter-procedural analysis we should treat
                # these not as definitions but as some new kind of assignment.
                rval_names = [None for _ in lval_names]

            elif isinstance(o.rvalue, TupleExpr) and len(lval_names) == 1:
                # We're constructing a tuple. Since tuples have a fixed
                # (and usually small) size, we can name each of the
                # elements.
                base = lval_names[0]
                lval_names = [
                    base + (str(i), ) for i in range(len(o.rvalue.items))
                ]
                rval_names = [
                    self.get_variable_name(item) for item in o.rvalue.items
                ]

            elif isinstance(rval_type, TupleType):
                # We're unpacking a tuple. Like the tuple construction case,
                # give each element a name.
                rval_name = self.get_variable_name(o.rvalue)
                assert rval_name, o.rvalue
                rval_names = [
                    rval_name + (str(i), ) for i in range(len(lval_names))
                ]

            else:
                rval_names = [self.get_variable_name(o.rvalue)]

            assert len(rval_names) == len(lval_names)

            for lhs, rhs in zip(lval_names, rval_names):
                assert lhs, lval
                if rhs:
                    # In this case the RHS is another variable. Record the
                    # assignment so we can keep track of aliases.
                    cfg.AddFact(self.current_statement_id,
                                pass_state.Assignment(lhs, rhs))
                else:
                    # In this case the RHS has no name, e.g. a call result
                    # or a literal such as [] or 'foo'. Mark this statement
                    # as a (re-)definition of a variable.
                    cfg.AddFact(
                        self.current_statement_id,
                        pass_state.Definition(lhs),
                    )

        for lval in o.lvalues:
            self.accept(lval)

        self.accept(o.rvalue)

    # Expressions

    def visit_call_expr(self, o: 'mypy.nodes.CallExpr') -> T:
        """Record a FunctionCall fact for calls made inside a function whose
        callee can be resolved, then visit the callee and the arguments."""
        cfg = self.current_cfg()
        if self.current_func_node:
            full_callee = self.resolve_callee(o)
            if full_callee:
                self.callees[o] = full_callee
                cfg.AddFact(
                    self.current_statement_id,
                    pass_state.FunctionCall(join_name(full_callee, delim='.')))

        self.accept(o.callee)
        for arg in o.args:
            self.accept(arg)