OILS / mycpp / control_flow_pass.py View on Github | oilshell.org

518 lines, 321 significant
1"""
2control_flow_pass.py - AST pass that builds a control flow graph.
3"""
4import collections
5from typing import overload, Union, Optional, Dict
6
7import mypy
8from mypy.nodes import (Block, Expression, Statement, ExpressionStmt, StrExpr,
9 CallExpr, FuncDef, IfStmt, NameExpr, MemberExpr,
10 IndexExpr, TupleExpr, IntExpr)
11
12from mypy.types import CallableType, Instance, Type, UnionType, NoneTyp, TupleType
13
14from mycpp.crash import catch_errors
15from mycpp.util import join_name, split_py_name
16from mycpp.visitor import SimpleVisitor, T
17from mycpp import util
18from mycpp import pass_state
19
20
class UnsupportedException(Exception):
    """Raised while visiting a node that this pass does not support.

    `Build.accept` catches this exception so that traversal can continue
    with the next node.
    """
23
24
def GetObjectTypeName(t: Type) -> util.SymbolPath:
    """Return the fully qualified name of the class behind type `t`.

    An `Instance` is named after its class. A two-member `UnionType` is named
    after its second member when the first one is `NoneTyp`, and after its
    first member otherwise. Any other kind of type trips an assertion.
    """
    if isinstance(t, UnionType):
        assert len(t.items) == 2
        # Skip over the None half of the union, whichever side it is on.
        member = t.items[1] if isinstance(t.items[0], NoneTyp) else t.items[0]
        return GetObjectTypeName(member)

    if isinstance(t, Instance):
        return split_py_name(t.type.fullname)

    assert False, t
37
38
class Build(SimpleVisitor):
    """AST pass that builds a control flow graph for each function.

    Graphs are stored in `self.cfgs`, keyed by the SymbolPath of the function
    they describe. While walking statements the pass records facts on the
    current graph: `Definition` and `Assignment` facts for assignments and
    `FunctionCall` facts for call expressions. Resolved callees are also
    stored in `self.callees`.
    """

    def __init__(self, types: Dict[Expression, Type], virtual, local_vars,
                 dot_exprs):
        """
        Args:
          types: type of each expression (the mypy type map).
          virtual: object queried with IsVirtual() and .virtuals to find the
            base-class declaration of virtual methods.
          local_vars: maps a function node to an iterable of (name, type)
            pairs for its locals.
          dot_exprs: maps each MemberExpr to a pass_state member description
            (ModuleMember, HeapObjectMember or StackObjectMember).
        """
        self.types = types
        self.cfgs = collections.defaultdict(pass_state.ControlFlowGraph)
        self.current_statement_id = None
        self.current_class_name = None
        self.current_func_node = None
        self.loop_stack = []
        self.virtual = virtual
        self.local_vars = local_vars
        self.dot_exprs = dot_exprs
        self.callees = {}  # CallExpr -> SymbolPath of the callee

    def current_cfg(self):
        """Return the graph of the function being visited, or None when the
        visitor is not inside a function."""
        if not self.current_func_node:
            return None

        return self.cfgs[split_py_name(self.current_func_node.fullname)]

    def resolve_callee(self, o: CallExpr) -> Optional[util.SymbolPath]:
        """
        Returns the fully qualified name of the callee in the given call
        expression.

        Member functions are prefixed by the names of the classes that contain
        them. For example, the name of the callee in the last statement of the
        snippet below is `module.SomeObject.Foo`.

            x = module.SomeObject()
            x.Foo()

        Free-functions defined in the local module are referred to by their
        normal fully qualified names. The function `foo` in a module called
        `moduleA` is named `moduleA.foo`. Calls to free-functions defined
        in imported modules are named the same way.

        Returns None for callees that the analysis does not track.
        """
        callee = o.callee

        if isinstance(callee, NameExpr):
            return split_py_name(callee.fullname)

        elif isinstance(callee, MemberExpr):
            receiver = callee.expr
            method = (callee.name, )

            if isinstance(receiver, NameExpr):
                is_module = isinstance(self.dot_exprs.get(callee),
                                       pass_state.ModuleMember)
                if is_module:
                    return split_py_name(receiver.fullname) + method

                elif receiver.name == 'self':
                    assert self.current_class_name
                    return self.current_class_name + method

                else:
                    # The receiver is a local variable. Look up its type.
                    local_type = None
                    for name, t in self.local_vars.get(self.current_func_node,
                                                       []):
                        if name == receiver.name:
                            local_type = t
                            break

                    if local_type:
                        if isinstance(local_type, str):
                            return split_py_name(local_type) + method

                        elif isinstance(local_type, (Instance, UnionType)):
                            # The callee is the *method* on the receiver's
                            # class, not the receiver variable itself.
                            # GetObjectTypeName() picks the non-None member of
                            # an Optional regardless of its position.
                            return GetObjectTypeName(local_type) + method

                        else:
                            assert not isinstance(local_type, CallableType)
                            # primitive type or string. don't care.
                            return None

                    else:
                        # context or exception handler. probably safe to
                        # ignore.
                        return None

            else:
                t = self.types.get(receiver)
                if isinstance(t, (Instance, UnionType)):
                    return GetObjectTypeName(t) + method

                elif receiver and getattr(receiver, 'fullname', None):
                    return split_py_name(receiver.fullname) + method

                else:
                    # constructors of things that we don't care about.
                    return None

        # Don't currently get here
        raise AssertionError()

    def get_variable_name(self, expr: Expression) -> Optional[util.SymbolPath]:
        """
        To do dataflow analysis we need to track changes to objects, which
        requires naming them. This function returns the name of the object
        referred to by the given expression. If the expression doesn't refer to
        an object or variable it returns None.

        Objects are named slightly differently than they appear in the source
        code.

        Objects referenced by local variables are referred to by the name of the
        local. For example, the name of the object in both statements below is
        `x`.

            x = module.SomeObject()
            x = None

        Member expressions are named after the parent object's type. For
        example, the names of the objects in the member assignment statements
        below are both `module.SomeObject.member_a`. This makes it possible to
        track data flow across object members without having to track individual
        heap objects, which would increase the search space for analyses and
        slow things down.

            x = module.SomeObject()
            y = module.SomeObject()
            x.member_a = 'foo'
            y.member_a = 'bar'

        Index expressions are named after their bases, for the same reasons as
        member expressions. The coarse-grained precision should lead to an
        over-approximation of where objects are in use, but should not miss any
        references. This should be fine for our purposes. In the snippet below
        the last two assignments are named `x` and `module.SomeObject.a_list`.

            x = [None] # list[Thing]
            y = module.SomeObject()
            x[0] = Thing()
            y.a_list[1] = Blah()

        Index expressions over tuples are treated differently, though. Since
        they have a fixed size and tend to be small, their elements are
        individually named. In the snippet below, the name of the RHS in the
        second assignment is `t.0`.

            t = (1, 2, 3, 4)
            x = t[0]

        The examples above all deal with assignments, but these rules apply to
        any expression that uses an object.
        """
        if isinstance(expr,
                      NameExpr) and expr.name not in {'True', 'False', 'None'}:
            return (expr.name, )

        elif isinstance(expr, MemberExpr):
            dot_expr = self.dot_exprs[expr]
            if isinstance(dot_expr, pass_state.ModuleMember):
                return dot_expr.module_path + (dot_expr.member, )

            elif isinstance(dot_expr, (pass_state.HeapObjectMember,
                                       pass_state.StackObjectMember)):
                return GetObjectTypeName(
                    dot_expr.object_type) + (dot_expr.member, )

        elif isinstance(expr, IndexExpr):
            if isinstance(self.types[expr.base], TupleType):
                assert isinstance(expr.index, IntExpr)
                base_name = self.get_variable_name(expr.base)
                if base_name is None:
                    # The tuple itself has no name (e.g. it is the result of
                    # a call), so its elements have no name either.
                    return None

                return base_name + (str(expr.index.value), )

            return self.get_variable_name(expr.base)

        return None

    #
    # COPIED from IRBuilder
    #

    @overload
    def accept(self, node: Expression) -> T:
        ...

    @overload
    def accept(self, node: Statement) -> None:
        ...

    def accept(self, node: Union[Statement, Expression]) -> Optional[T]:
        """Visit `node`, first allocating a CFG statement for it if needed.

        Expressions return whatever their visitor returns. Statements return
        None. UnsupportedException is swallowed in both cases so that the
        traversal keeps going.
        """
        with catch_errors(self.module_path, node.line):
            if isinstance(node, Expression):
                try:
                    res = node.accept(self)
                    #res = self.coerce(res, self.node_type(node), node.line)

                # If we hit an error during compilation, we want to
                # keep trying, so we can produce more error
                # messages. Generate a temp of the right type to keep
                # from causing more downstream trouble.
                except UnsupportedException:
                    # NOTE(review): alloc_temp() and node_type() are not
                    # defined in this class; presumably they come from the
                    # base class or are leftovers from IRBuilder -- confirm.
                    res = self.alloc_temp(self.node_type(node))
                return res
            else:
                try:
                    cfg = self.current_cfg()
                    # Most statements have empty visitors because they don't
                    # require any special logic. Create statements for them
                    # here. Don't create statements from blocks to avoid
                    # stuttering.
                    if cfg and not isinstance(node, Block):
                        self.current_statement_id = cfg.AddStatement()

                    node.accept(self)
                except UnsupportedException:
                    pass
                return None

    # Not in superclasses:

    def visit_mypy_file(self, o: 'mypy.nodes.MypyFile') -> T:
        """Visit every top-level definition of a module, unless the module is
        one that util.ShouldSkipPyFile() rejects."""
        if util.ShouldSkipPyFile(o):
            return

        self.module_path = o.path

        for node in o.defs:
            # skip module docstring
            if isinstance(node, ExpressionStmt) and isinstance(
                    node.expr, StrExpr):
                continue
            self.accept(node)

    # Statements

    def _visit_loop(self, o) -> None:
        """Shared body of the for/while visitors.

        Visits the loop's `expr` and then its body inside a CfgLoopContext
        whose entry is the current statement. The context is kept on
        `loop_stack` while the body is visited so that break/continue
        statements can find their enclosing loop.
        """
        cfg = self.current_cfg()
        with pass_state.CfgLoopContext(
                cfg, entry=self.current_statement_id) as loop:
            self.accept(o.expr)
            self.loop_stack.append(loop)
            try:
                self.accept(o.body)
            finally:
                # Keep the stack balanced even if the body visitor raises.
                self.loop_stack.pop()

    def visit_for_stmt(self, o: 'mypy.nodes.ForStmt') -> T:
        self._visit_loop(o)

    def _handle_switch(self, expr, o, cfg):
        """Model a `with switch(...)`-style block as a multi-way branch.

        The with-body must consist of exactly one if statement. Each case
        collected from it (plus the default block, if any) becomes a branch
        of one CfgBranchContext rooted at the current statement.
        """
        assert len(o.body.body) == 1, o.body.body
        if_node = o.body.body[0]
        assert isinstance(if_node, IfStmt), if_node
        cases = []
        default_block = util._collect_cases(self.module_path, if_node, cases)
        with pass_state.CfgBranchContext(
                cfg, self.current_statement_id) as branch_ctx:
            for case_expr, case_body in cases:
                # Check before visiting; accept() dereferences the node.
                assert case_expr is not None, case_expr
                self.accept(case_expr)
                with branch_ctx.AddBranch():
                    self.accept(case_body)

            if default_block:
                with branch_ctx.AddBranch():
                    self.accept(default_block)

    def visit_with_stmt(self, o: 'mypy.nodes.WithStmt') -> T:
        """Visit a with statement.

        `switch`, `str_switch` and `tagswitch` are treated as switch
        statements; any other context manager is a plain block.
        """
        cfg = self.current_cfg()
        assert len(o.expr) == 1, o.expr
        expr = o.expr[0]
        assert isinstance(expr, CallExpr), expr
        self.accept(expr)

        if expr.callee.name in ('switch', 'str_switch', 'tagswitch'):
            self._handle_switch(expr, o, cfg)
        else:
            with pass_state.CfgBlockContext(cfg, self.current_statement_id):
                self.accept(o.body)

    def visit_func_def(self, o: 'mypy.nodes.FuncDef') -> T:
        """Build the graph for one function or method."""
        if o.name == '__repr__':  # Don't translate
            return

        # For virtual methods, pretend that the method on the base class calls
        # the same method on every subclass. This way call sites using the
        # abstract base class will over-approximate the set of call paths they
        # can take when checking if they can reach MaybeCollect().
        if self.current_class_name and self.virtual.IsVirtual(
                self.current_class_name, o.name):
            key = (self.current_class_name, o.name)
            base = self.virtual.virtuals[key]
            if base:
                sub = join_name(self.current_class_name + (o.name, ),
                                delim='.')
                base_key = base[0] + (base[1], )
                cfg = self.cfgs[base_key]
                cfg.AddFact(0, pass_state.FunctionCall(sub))

        self.current_func_node = o
        cfg = self.current_cfg()
        # Every argument is recorded as a definition on statement 0.
        for arg in o.arguments:
            cfg.AddFact(0, pass_state.Definition((arg.variable.name, )))

        self.accept(o.body)
        self.current_func_node = None
        self.current_statement_id = None

    def visit_class_def(self, o: 'mypy.nodes.ClassDef') -> T:
        """Visit the members of a class, skipping docstrings and __repr__."""
        self.current_class_name = split_py_name(o.fullname)
        for stmt in o.defs.body:
            # Ignore things that look like docstrings
            if (isinstance(stmt, ExpressionStmt) and
                    isinstance(stmt.expr, StrExpr)):
                continue

            if isinstance(stmt, FuncDef) and stmt.name == '__repr__':
                continue

            self.accept(stmt)

        self.current_class_name = None

    def visit_while_stmt(self, o: 'mypy.nodes.WhileStmt') -> T:
        self._visit_loop(o)

    def _visit_deadend(self, o) -> None:
        """Shared body of the return/raise visitors.

        Registers the current statement as a dead end on the current graph
        (if there is one) and then visits the statement's expression (if
        there is one).
        """
        cfg = self.current_cfg()
        if cfg:
            cfg.AddDeadend(self.current_statement_id)

        if o.expr:
            self.accept(o.expr)

    def visit_return_stmt(self, o: 'mypy.nodes.ReturnStmt') -> T:
        self._visit_deadend(o)

    def visit_if_stmt(self, o: 'mypy.nodes.IfStmt') -> T:
        """Visit an if statement as a branch with up to two arms.

        The util.ShouldVisit*() helpers decide which parts of the statement
        are visited at all.
        """
        cfg = self.current_cfg()

        if util.ShouldVisitIfExpr(o):
            for expr in o.expr:
                self.accept(expr)

        with pass_state.CfgBranchContext(
                cfg, self.current_statement_id) as branch_ctx:
            if util.ShouldVisitIfBody(o):
                with branch_ctx.AddBranch():
                    for node in o.body:
                        self.accept(node)

            if util.ShouldVisitElseBody(o):
                with branch_ctx.AddBranch():
                    self.accept(o.else_body)

    def visit_break_stmt(self, o: 'mypy.nodes.BreakStmt') -> T:
        if self.loop_stack:
            self.loop_stack[-1].AddBreak(self.current_statement_id)

    def visit_continue_stmt(self, o: 'mypy.nodes.ContinueStmt') -> T:
        if self.loop_stack:
            self.loop_stack[-1].AddContinue(self.current_statement_id)

    def visit_raise_stmt(self, o: 'mypy.nodes.RaiseStmt') -> T:
        self._visit_deadend(o)

    def visit_try_stmt(self, o: 'mypy.nodes.TryStmt') -> T:
        """Visit a try statement as a branch.

        The try body is the first branch. Each handler is added as a further
        branch, passing the try block's `exit` to AddBranch().
        """
        cfg = self.current_cfg()
        with pass_state.CfgBranchContext(cfg,
                                         self.current_statement_id) as try_ctx:
            with try_ctx.AddBranch() as try_block:
                self.accept(o.body)

            for handler in o.handlers:
                with try_ctx.AddBranch(try_block.exit):
                    self.accept(handler)

    def visit_assignment_stmt(self, o: 'mypy.nodes.AssignmentStmt') -> T:
        """Record Definition/Assignment facts for an assignment, then visit
        its lvalues and rvalue."""
        cfg = self.current_cfg()
        if cfg:
            assert len(o.lvalues) == 1
            lval = o.lvalues[0]
            lval_names = []
            if isinstance(lval, TupleExpr):
                lval_names.extend(
                    [self.get_variable_name(item) for item in lval.items])

            else:
                lval_names.append(self.get_variable_name(lval))

            assert lval_names, o

            rval_type = self.types[o.rvalue]
            rval_names = []
            if isinstance(o.rvalue, CallExpr):
                # The RHS is either an object constructor or something that
                # returns a primitive type (e.g. Tuple[int, int] or str).
                # XXX: When we add inter-procedural analysis we should treat
                # these not as definitions but as some new kind of assignment.
                rval_names = [None for _ in lval_names]

            elif isinstance(o.rvalue, TupleExpr) and len(lval_names) == 1:
                # We're constructing a tuple. Since tuples have a fixed
                # (and usually small) size, we can name each of the
                # elements.
                base = lval_names[0]
                lval_names = [
                    base + (str(i), ) for i in range(len(o.rvalue.items))
                ]
                rval_names = [
                    self.get_variable_name(item) for item in o.rvalue.items
                ]

            elif isinstance(rval_type, TupleType):
                # We're unpacking a tuple. Like the tuple construction case,
                # give each element a name.
                rval_name = self.get_variable_name(o.rvalue)
                assert rval_name, o.rvalue
                rval_names = [
                    rval_name + (str(i), ) for i in range(len(lval_names))
                ]

            else:
                rval_names = [self.get_variable_name(o.rvalue)]

            assert len(rval_names) == len(lval_names)

            for lhs, rhs in zip(lval_names, rval_names):
                assert lhs, lval
                if rhs:
                    # In this case the RHS is another variable. Record the
                    # assignment so we can keep track of aliases.
                    cfg.AddFact(self.current_statement_id,
                                pass_state.Assignment(lhs, rhs))
                else:
                    # In this case the RHS is either some kind of literal (e.g.
                    # [] or 'foo') or a call to an object constructor. Mark this
                    # statement as an (re-)definition of a variable.
                    cfg.AddFact(
                        self.current_statement_id,
                        pass_state.Definition(lhs),
                    )

        for target in o.lvalues:
            self.accept(target)

        self.accept(o.rvalue)

    # Expressions

    def visit_call_expr(self, o: 'mypy.nodes.CallExpr') -> T:
        """Record a FunctionCall fact for calls made inside a function whose
        callee can be resolved, then visit the callee and the arguments."""
        cfg = self.current_cfg()
        if self.current_func_node:
            full_callee = self.resolve_callee(o)
            if full_callee:
                self.callees[o] = full_callee
                cfg.AddFact(
                    self.current_statement_id,
                    pass_state.FunctionCall(join_name(full_callee, delim='.')))

        self.accept(o.callee)
        for arg in o.args:
            self.accept(arg)
516 self.accept(o.callee)
517 for arg in o.args:
518 self.accept(arg)