OILS / mycpp / control_flow_pass.py View on Github | oilshell.org

519 lines, 321 significant
1"""
2control_flow_pass.py - AST pass that builds a control flow graph.
3"""
4import collections
5from typing import overload, Union, Optional, Dict
6
7import mypy
8from mypy.nodes import (Block, Expression, Statement, ExpressionStmt, StrExpr,
9 CallExpr, FuncDef, IfStmt, NameExpr, MemberExpr,
10 IndexExpr, TupleExpr, IntExpr)
11
12from mypy.types import CallableType, Instance, Type, UnionType, NoneTyp, TupleType
13
14from mycpp.crash import catch_errors
15from mycpp.util import join_name, split_py_name
16from mycpp.visitor import SimpleVisitor, T
17from mycpp import util
18from mycpp import pass_state
19
20
class UnsupportedException(Exception):
    """Raised while visiting a node to abandon that node only.

    Build.accept() catches it and carries on with the next node, so one
    unsupported construct does not stop the whole pass.
    """
23
24
def GetObjectTypeName(t: Type) -> util.SymbolPath:
    """Return the SymbolPath of the class named by type `t`.

    Args:
      t: an Instance, or a two-member UnionType (e.g. Optional[Foo]). For a
         union, the first member that is not None is used.

    Raises:
      AssertionError: for a union that does not have exactly two members, or
        for any other kind of type.
    """
    if isinstance(t, Instance):
        return split_py_name(t.type.fullname)

    if isinstance(t, UnionType):
        assert len(t.items) == 2
        # Optional[X] may list None first or second; name the non-None side.
        if isinstance(t.items[0], NoneTyp):
            return GetObjectTypeName(t.items[1])
        return GetObjectTypeName(t.items[0])

    # An explicit raise (not `assert False`) so that running under -O cannot
    # silently return None and fail later on `None + (member,)`.
    raise AssertionError(t)
37
38
class Build(SimpleVisitor):
    """AST visitor that builds a ControlFlowGraph for each function it visits.

    Results are accumulated on the instance:

      cfgs:    function SymbolPath -> pass_state.ControlFlowGraph
      callees: CallExpr node -> SymbolPath of the function it calls
    """

    def __init__(self, types: Dict[Expression, Type], virtual, local_vars,
                 dot_exprs):
        """
        Args:
          types: mypy's map from expression node to its type.
          virtual: object providing IsVirtual(class_name, method_name) and a
            `virtuals` dict; used to link virtual methods to their overrides.
          local_vars: map from FuncDef to a list of (name, type) pairs for that
            function's local variables. The type may be a str or a mypy Type.
          dot_exprs: map from MemberExpr to a pass_state.ModuleMember,
            HeapObjectMember or StackObjectMember describing what it refers to.
        """
        self.types = types
        # Function SymbolPath -> ControlFlowGraph, created on first access.
        self.cfgs = collections.defaultdict(pass_state.ControlFlowGraph)
        self.current_statement_id = None
        self.current_class_name = None
        self.current_func_node = None
        # CfgLoopContexts enclosing the current statement, innermost last;
        # break/continue attach to the innermost one.
        self.loop_stack = []
        self.virtual = virtual
        self.local_vars = local_vars
        self.dot_exprs = dot_exprs
        self.callees = {}  # CallExpr node -> SymbolPath of the callee
        # Set by visit_mypy_file(). Initialized here so accept() does not hit
        # an AttributeError if it runs before any file has been visited.
        self.module_path = None

    def current_cfg(self):
        """Return the CFG of the function being visited, or None when not
        inside a function (module or class scope)."""
        if not self.current_func_node:
            return None

        return self.cfgs[split_py_name(self.current_func_node.fullname)]

    def resolve_callee(self, o: CallExpr) -> Optional[util.SymbolPath]:
        """
        Returns the fully qualified name of the callee in the given call
        expression, or None for callees that this pass does not track.

        Member functions are prefixed by the names of the classes that contain
        them. For example, the name of the callee in the last statement of the
        snippet below is `module.SomeObject.Foo`.

            x = module.SomeObject()
            x.Foo()

        Free-functions defined in the local module are referred to by their
        normal fully qualified names. The function `foo` in a module called
        `moduleA` is named `moduleA.foo`. Calls to free-functions defined
        in imported modules are named the same way.
        """
        if isinstance(o.callee, NameExpr):
            return split_py_name(o.callee.fullname)

        elif isinstance(o.callee, MemberExpr):
            method = o.callee.name
            receiver = o.callee.expr

            if isinstance(receiver, NameExpr):
                is_module = isinstance(self.dot_exprs.get(o.callee),
                                       pass_state.ModuleMember)
                if is_module:
                    return split_py_name(receiver.fullname) + (method, )

                elif receiver.name == 'self':
                    assert self.current_class_name
                    return self.current_class_name + (method, )

                else:
                    return self._resolve_local_method(receiver.name, method)

            else:
                t = self.types.get(receiver)
                if isinstance(t, (Instance, UnionType)):
                    # GetObjectTypeName picks the non-None side of Optional[X]
                    # instead of assuming X is listed first.
                    return GetObjectTypeName(t) + (method, )

                elif receiver and getattr(receiver, 'fullname', None):
                    return split_py_name(receiver.fullname) + (method, )

                else:
                    # constructors of things that we don't care about.
                    return None

        # Don't currently get here
        raise AssertionError()

    def _resolve_local_method(self, var_name: str,
                              method: str) -> Optional[util.SymbolPath]:
        """Resolve `var_name.method(...)` from the declared type of the local
        variable `var_name` in the current function; None if untracked."""
        local_type = None
        for name, t in self.local_vars.get(self.current_func_node, []):
            if name == var_name:
                local_type = t
                break

        if not local_type:
            # context or exception handler. probably safe to ignore.
            return None

        if isinstance(local_type, str):
            return split_py_name(local_type) + (method, )

        if isinstance(local_type, (Instance, UnionType)):
            # The suffix is the method being called, never the receiver's
            # variable name, for unions as well as plain instances.
            return GetObjectTypeName(local_type) + (method, )

        assert not isinstance(local_type, CallableType)
        # primitive type or string. don't care.
        return None

    def get_variable_name(self, expr: Expression) -> Optional[util.SymbolPath]:
        """
        To do dataflow analysis we need to track changes to objects, which
        requires naming them. This function returns the name of the object
        referred to by the given expression. If the expression doesn't refer to
        an object or variable it returns None.

        Objects are named slightly differently than they appear in the source
        code.

        Objects referenced by local variables are referred to by the name of the
        local. For example, the name of the object in both statements below is
        `x`.

            x = module.SomeObject()
            x = None

        Member expressions are named after the parent object's type. For
        example, the names of the objects in the member assignment statements
        below are both `module.SomeObject.member_a`. This makes it possible to
        track data flow across object members without having to track individual
        heap objects, which would increase the search space for analyses and
        slow things down.

            x = module.SomeObject()
            y = module.SomeObject()
            x.member_a = 'foo'
            y.member_a = 'bar'

        Index expressions are named after their bases, for the same reasons as
        member expressions. The coarse-grained precision should lead to an
        over-approximation of where objects are in use, but should not miss any
        references. This should be fine for our purposes. In the snippet below
        the last two assignments are named `x` and `module.SomeObject.a_list`.

            x = [None] # list[Thing]
            y = module.SomeObject()
            x[0] = Thing()
            y.a_list[1] = Blah()

        Index expressions over tuples are treated differently, though. Tuples
        have a fixed size, tend to be small, and their elements have distinct
        types. So, each element can be (and probably needs to be) individually
        named. In the snippet below, the name of the RHS in the second
        assignment is `t.0`.

            t = (1, 2, 3, 4)
            x = t[0]

        The examples above all deal with assignments, but these rules apply to
        any expression that uses an object.
        """
        if (isinstance(expr, NameExpr) and
                expr.name not in {'True', 'False', 'None'}):
            return (expr.name, )

        elif isinstance(expr, MemberExpr):
            dot_expr = self.dot_exprs[expr]
            if isinstance(dot_expr, pass_state.ModuleMember):
                return dot_expr.module_path + (dot_expr.member, )

            elif isinstance(dot_expr, (pass_state.HeapObjectMember,
                                       pass_state.StackObjectMember)):
                return GetObjectTypeName(
                    dot_expr.object_type) + (dot_expr.member, )

        elif isinstance(expr, IndexExpr):
            if isinstance(self.types[expr.base], TupleType):
                assert isinstance(expr.index, IntExpr), expr.index
                base_name = self.get_variable_name(expr.base)
                if base_name is None:
                    # The tuple itself is unnamed (e.g. a call result), so
                    # its elements cannot be named either.
                    return None
                return base_name + (str(expr.index.value), )

            return self.get_variable_name(expr.base)

        return None

    #
    # COPIED from IRBuilder
    #

    @overload
    def accept(self, node: Expression) -> T:
        ...

    @overload
    def accept(self, node: Statement) -> None:
        ...

    def accept(self, node: Union[Statement, Expression]) -> Optional[T]:
        """Visit `node`, attributing errors to this file and the node's line.

        Inside a function, every statement other than a Block gets a fresh CFG
        statement id (stored in current_statement_id) before its visitor runs.
        Visitors then only need to attach facts and edges to that id.
        """
        with catch_errors(self.module_path, node.line):
            if isinstance(node, Expression):
                try:
                    res = node.accept(self)
                    #res = self.coerce(res, self.node_type(node), node.line)

                # If we hit an error during compilation, we want to
                # keep trying, so we can produce more error
                # messages. Generate a temp of the right type to keep
                # from causing more downstream trouble.
                except UnsupportedException:
                    # NOTE(review): alloc_temp()/node_type() are not defined in
                    # this class (copied from IRBuilder). This only works if
                    # SimpleVisitor provides them. Nothing in this file raises
                    # UnsupportedException today.
                    res = self.alloc_temp(self.node_type(node))
                return res
            else:
                try:
                    cfg = self.current_cfg()
                    # Most statements have empty visitors because they don't
                    # require any special logic. Create statements for them
                    # here. Don't create statements from blocks to avoid
                    # stuttering.
                    if cfg and not isinstance(node, Block):
                        self.current_statement_id = cfg.AddStatement()

                    node.accept(self)
                except UnsupportedException:
                    pass
                return None

    # Not in superclasses:

    def visit_mypy_file(self, o: 'mypy.nodes.MypyFile') -> T:
        """Visit every top-level definition of a module, skipping files that
        util.ShouldSkipPyFile() rejects and the module docstring."""
        if util.ShouldSkipPyFile(o):
            return

        self.module_path = o.path

        for node in o.defs:
            # skip module docstring
            if isinstance(node, ExpressionStmt) and isinstance(
                    node.expr, StrExpr):
                continue
            self.accept(node)

    # Statements

    def _visit_loop(self, o) -> None:
        """Shared by for and while loops.

        The loop header expression is visited inside the loop context. The body
        is visited with the loop pushed on loop_stack so that break and
        continue statements can attach to it.
        """
        cfg = self.current_cfg()
        with pass_state.CfgLoopContext(
                cfg, entry=self.current_statement_id) as loop:
            self.accept(o.expr)
            self.loop_stack.append(loop)
            try:
                self.accept(o.body)
            finally:
                # Keep loop_stack balanced even if visiting the body raises.
                self.loop_stack.pop()

    def visit_for_stmt(self, o: 'mypy.nodes.ForStmt') -> T:
        self._visit_loop(o)

    def _handle_switch(self, expr, o, cfg):
        """Model a `with switch(...)`-style statement as one CFG branch per case.

        The `with` body must be a single if/elif/else chain. Each case gets its
        own branch, plus one more for the final else block if there is one.
        `expr` is the switch call; the caller has already visited it, and it is
        not used here.
        """
        assert len(o.body.body) == 1, o.body.body
        if_node = o.body.body[0]
        assert isinstance(if_node, IfStmt), if_node
        cases = []
        default_block = util._collect_cases(self.module_path, if_node, cases)
        with pass_state.CfgBranchContext(
                cfg, self.current_statement_id) as branch_ctx:
            for case_expr, body in cases:
                assert case_expr is not None, case_expr
                self.accept(case_expr)
                with branch_ctx.AddBranch():
                    self.accept(body)

            if default_block:
                with branch_ctx.AddBranch():
                    self.accept(default_block)

    def visit_with_stmt(self, o: 'mypy.nodes.WithStmt') -> T:
        """Visit a single-expression `with` statement.

        switch/str_switch/tagswitch become branches. Any other context manager
        body is visited as a plain block.
        """
        cfg = self.current_cfg()
        assert len(o.expr) == 1, o.expr
        expr = o.expr[0]
        assert isinstance(expr, CallExpr), expr
        self.accept(expr)

        if expr.callee.name in ('switch', 'str_switch', 'tagswitch'):
            self._handle_switch(expr, o, cfg)
        else:
            with pass_state.CfgBlockContext(cfg, self.current_statement_id):
                self.accept(o.body)

    def visit_func_def(self, o: 'mypy.nodes.FuncDef') -> T:
        """Build the CFG for function `o`; __repr__ is skipped.

        Each parameter is recorded as a Definition at statement 0. If `o`
        overrides a virtual method, its base method's CFG also gets a
        FunctionCall fact naming `o`.
        """
        if o.name == '__repr__':  # Don't translate
            return

        # For virtual methods, pretend that the method on the base class calls
        # the same method on every subclass. This way call sites using the
        # abstract base class will over-approximate the set of call paths they
        # can take when checking if they can reach MaybeCollect().
        if self.current_class_name and self.virtual.IsVirtual(
                self.current_class_name, o.name):
            key = (self.current_class_name, o.name)
            base = self.virtual.virtuals[key]
            if base:
                sub = join_name(self.current_class_name + (o.name, ),
                                delim='.')
                base_key = base[0] + (base[1], )
                base_cfg = self.cfgs[base_key]
                base_cfg.AddFact(0, pass_state.FunctionCall(sub))

        self.current_func_node = o
        cfg = self.current_cfg()
        for arg in o.arguments:
            cfg.AddFact(0, pass_state.Definition((arg.variable.name, )))

        self.accept(o.body)
        self.current_func_node = None
        self.current_statement_id = None

    def visit_class_def(self, o: 'mypy.nodes.ClassDef') -> T:
        """Visit a class body with current_class_name set, skipping
        docstring-like string statements and __repr__."""
        self.current_class_name = split_py_name(o.fullname)
        for stmt in o.defs.body:
            # Ignore things that look like docstrings
            if (isinstance(stmt, ExpressionStmt) and
                    isinstance(stmt.expr, StrExpr)):
                continue

            if isinstance(stmt, FuncDef) and stmt.name == '__repr__':
                continue

            self.accept(stmt)

        self.current_class_name = None

    def visit_while_stmt(self, o: 'mypy.nodes.WhileStmt') -> T:
        self._visit_loop(o)

    def visit_return_stmt(self, o: 'mypy.nodes.ReturnStmt') -> T:
        """Mark this statement as a CFG dead end, then visit the returned
        expression (it may contain calls)."""
        cfg = self.current_cfg()
        if cfg:
            cfg.AddDeadend(self.current_statement_id)

        if o.expr:
            self.accept(o.expr)

    def visit_if_stmt(self, o: 'mypy.nodes.IfStmt') -> T:
        """Visit the condition(s), then the if-body and else-body as separate
        CFG branches. Which parts are visited is decided by the
        util.ShouldVisit* helpers."""
        cfg = self.current_cfg()

        if util.ShouldVisitIfExpr(o):
            for expr in o.expr:
                self.accept(expr)

        with pass_state.CfgBranchContext(
                cfg, self.current_statement_id) as branch_ctx:
            if util.ShouldVisitIfBody(o):
                with branch_ctx.AddBranch():
                    for node in o.body:
                        self.accept(node)

            if util.ShouldVisitElseBody(o):
                with branch_ctx.AddBranch():
                    self.accept(o.else_body)

    def visit_break_stmt(self, o: 'mypy.nodes.BreakStmt') -> T:
        if self.loop_stack:
            self.loop_stack[-1].AddBreak(self.current_statement_id)

    def visit_continue_stmt(self, o: 'mypy.nodes.ContinueStmt') -> T:
        if self.loop_stack:
            self.loop_stack[-1].AddContinue(self.current_statement_id)

    def visit_raise_stmt(self, o: 'mypy.nodes.RaiseStmt') -> T:
        """Mark this statement as a CFG dead end, then visit the raised
        expression (it may contain calls)."""
        cfg = self.current_cfg()
        if cfg:
            cfg.AddDeadend(self.current_statement_id)

        if o.expr:
            self.accept(o.expr)

    def visit_try_stmt(self, o: 'mypy.nodes.TryStmt') -> T:
        """Model try/except as one branch for the try body plus one branch per
        except handler."""
        cfg = self.current_cfg()
        with pass_state.CfgBranchContext(cfg,
                                         self.current_statement_id) as try_ctx:
            with try_ctx.AddBranch() as try_block:
                self.accept(o.body)

            # Each handler branch is created with try_block.exit, presumably
            # so that it follows the end of the try body.
            for handler in o.handlers:
                with try_ctx.AddBranch(try_block.exit):
                    self.accept(handler)

        # NOTE(review): o.else_body and o.finally_body are not visited, so any
        # statements or calls in them are missing from the CFG.

    def visit_assignment_stmt(self, o: 'mypy.nodes.AssignmentStmt') -> T:
        """Record how an assignment defines or aliases its target(s).

        If the RHS names a variable or member, an Assignment(lhs, rhs) fact is
        recorded so aliases can be tracked. Otherwise (literals, calls) a
        Definition(lhs) fact is recorded. Tuple construction and unpacking are
        expanded per element, using the `name.<index>` naming scheme from
        get_variable_name().
        """
        cfg = self.current_cfg()
        if cfg:
            assert len(o.lvalues) == 1
            lval = o.lvalues[0]
            if isinstance(lval, TupleExpr):
                lval_names = [
                    self.get_variable_name(item) for item in lval.items
                ]
            else:
                lval_names = [self.get_variable_name(lval)]

            assert lval_names, o

            rval_type = self.types[o.rvalue]
            if isinstance(o.rvalue, CallExpr):
                # The RHS is either an object constructor or something that
                # returns a primitive type (e.g. Tuple[int, int] or str).
                # XXX: When we add inter-procedural analysis we should treat
                # these not as definitions but as some new kind of assignment.
                rval_names = [None for _ in lval_names]

            elif isinstance(o.rvalue, TupleExpr) and len(lval_names) == 1:
                # We're constructing a tuple. Since tuples have a fixed
                # (and usually small) size, we can name each of the
                # elements.
                base = lval_names[0]
                lval_names = [
                    base + (str(i), ) for i in range(len(o.rvalue.items))
                ]
                rval_names = [
                    self.get_variable_name(item) for item in o.rvalue.items
                ]

            elif isinstance(rval_type, TupleType):
                # We're unpacking a tuple. Like the tuple construction case,
                # give each element a name.
                rval_name = self.get_variable_name(o.rvalue)
                assert rval_name, o.rvalue
                rval_names = [
                    rval_name + (str(i), ) for i in range(len(lval_names))
                ]

            else:
                rval_names = [self.get_variable_name(o.rvalue)]

            assert len(rval_names) == len(lval_names)

            for lhs, rhs in zip(lval_names, rval_names):
                assert lhs, lval
                if rhs:
                    # In this case the RHS is another variable. Record the
                    # assignment so we can keep track of aliases.
                    cfg.AddFact(self.current_statement_id,
                                pass_state.Assignment(lhs, rhs))
                else:
                    # In this case the RHS is either some kind of literal (e.g.
                    # [] or 'foo') or a call to an object constructor. Mark this
                    # statement as a (re-)definition of a variable.
                    cfg.AddFact(
                        self.current_statement_id,
                        pass_state.Definition(lhs),
                    )

        for target in o.lvalues:
            self.accept(target)

        self.accept(o.rvalue)

    # Expressions

    def visit_call_expr(self, o: 'mypy.nodes.CallExpr') -> T:
        """Inside a function, record the resolved callee in self.callees and as
        a FunctionCall fact on the current statement, then visit the callee
        and arguments."""
        cfg = self.current_cfg()
        if self.current_func_node:
            full_callee = self.resolve_callee(o)
            if full_callee:
                self.callees[o] = full_callee
                cfg.AddFact(
                    self.current_statement_id,
                    pass_state.FunctionCall(join_name(full_callee, delim='.')))

        self.accept(o.callee)
        for arg in o.args:
            self.accept(arg)
517 self.accept(o.callee)
518 for arg in o.args:
519 self.accept(arg)