# OILS / mycpp / control_flow_pass.py View on Github | oilshell.org
#
# 509 lines, 319 significant
"""
control_flow_pass.py - AST pass that builds a control flow graph.
"""
4import collections
5from typing import overload, Union, Optional, Dict
6
7import mypy
8from mypy.nodes import (Block, Expression, Statement, ExpressionStmt, StrExpr,
9 CallExpr, FuncDef, IfStmt, NameExpr, MemberExpr,
10 IndexExpr, TupleExpr)
11
12from mypy.types import CallableType, Instance, Type, UnionType, NoneTyp, TupleType
13
14from mycpp.crash import catch_errors
15from mycpp.util import join_name, split_py_name
16from mycpp.visitor import SimpleVisitor, T
17from mycpp import util
18from mycpp import pass_state
19
20
class UnsupportedException(Exception):
    """Signals a construct that this pass cannot handle.

    Build.accept() catches it so that traversal of the remaining nodes can
    continue.
    """
23
24
def GetObjectTypeName(t: Type) -> util.SymbolPath:
    """Return the symbol path of the class behind a mypy type.

    An Instance maps to the split full name of its class. A two-member union
    (e.g. Optional[X]) maps to the name of its second member when the first
    member is None, and to the name of its first member otherwise. Any other
    kind of type trips an assertion.
    """
    if isinstance(t, UnionType):
        assert len(t.items) == 2
        first = t.items[0]
        # Look through Optional[X] written with None first.
        chosen = t.items[1] if isinstance(first, NoneTyp) else first
        return GetObjectTypeName(chosen)

    if isinstance(t, Instance):
        return split_py_name(t.type.fullname)

    assert False, t
37
38
39class Build(SimpleVisitor):
40
41 def __init__(self, types: Dict[Expression, Type], virtual, local_vars,
42 dot_exprs):
43
44 self.types = types
45 self.cfgs = collections.defaultdict(pass_state.ControlFlowGraph)
46 self.current_statement_id = None
47 self.current_class_name = None
48 self.current_func_node = None
49 self.loop_stack = []
50 self.virtual = virtual
51 self.local_vars = local_vars
52 self.dot_exprs = dot_exprs
53 self.callees = {} # statement object -> SymbolPath of the callee
54
55 def current_cfg(self):
56 if not self.current_func_node:
57 return None
58
59 return self.cfgs[split_py_name(self.current_func_node.fullname)]
60
61 def resolve_callee(self, o: CallExpr) -> Optional[util.SymbolPath]:
62 """
63 Returns the fully qualified name of the callee in the given call
64 expression.
65
66 Member functions are prefixed by the names of the classes that contain
67 them. For example, the name of the callee in the last statement of the
68 snippet below is `module.SomeObject.Foo`.
69
70 x = module.SomeObject()
71 x.Foo()
72
73 Free-functions defined in the local module are referred to by their
74 normal fully qualified names. The function `foo` in a module called
75 `moduleA` would is named `moduleA.foo`. Calls to free-functions defined
76 in imported modules are named the same way.
77 """
78
79 if isinstance(o.callee, NameExpr):
80 return split_py_name(o.callee.fullname)
81
82 elif isinstance(o.callee, MemberExpr):
83 if isinstance(o.callee.expr, NameExpr):
84 is_module = isinstance(self.dot_exprs.get(o.callee),
85 pass_state.ModuleMember)
86 if is_module:
87 return split_py_name(
88 o.callee.expr.fullname) + (o.callee.name, )
89
90 elif o.callee.expr.name == 'self':
91 assert self.current_class_name
92 return self.current_class_name + (o.callee.name, )
93
94 else:
95 local_type = None
96 for name, t in self.local_vars.get(self.current_func_node,
97 []):
98 if name == o.callee.expr.name:
99 local_type = t
100 break
101
102 if local_type:
103 if isinstance(local_type, str):
104 return split_py_name(local_type) + (
105 o.callee.name, )
106
107 elif isinstance(local_type, Instance):
108 return split_py_name(
109 local_type.type.fullname) + (o.callee.name, )
110
111 elif isinstance(local_type, UnionType):
112 assert len(local_type.items) == 2
113 return split_py_name(
114 local_type.items[0].type.fullname) + (
115 o.callee.expr.name, )
116
117 else:
118 assert not isinstance(local_type, CallableType)
119 # primitive type or string. don't care.
120 return None
121
122 else:
123 # context or exception handler. probably safe to ignore.
124 return None
125
126 else:
127 t = self.types.get(o.callee.expr)
128 if isinstance(t, Instance):
129 return split_py_name(t.type.fullname) + (o.callee.name, )
130
131 elif isinstance(t, UnionType):
132 assert len(t.items) == 2
133 return split_py_name(
134 t.items[0].type.fullname) + (o.callee.name, )
135
136 elif o.callee.expr and getattr(o.callee.expr, 'fullname',
137 None):
138 return split_py_name(
139 o.callee.expr.fullname) + (o.callee.name, )
140
141 else:
142 # constructors of things that we don't care about.
143 return None
144
145 # Don't currently get here
146 raise AssertionError()
147
148 def get_variable_name(self, expr: Expression) -> Optional[util.SymbolPath]:
149 """
150 To do dataflow analysis we need to track changes to objects, which
151 requires naming them. This function returns the name of the object
152 referred to by the given expression. If the expression doesn't refer to
153 an object it returns None.
154
155 Objects are named slightly differently than they appear in the source
156 code.
157
158 Objects referenced by local variables are referred to by the name of the
159 local. For example, the name of the object in both statements below is
160 `x`.
161
162 x = module.SomeObject()
163 x = None
164
165 Member expressions are named after the parent object's type. For
166 example, the names of the objects in the member assignment statements
167 below are both `module.SomeObject.member_a`. This makes it possible to
168 track data flow across object members without having to track individual
169 heap objects, which would increase the search space for analyses and
170 slow things down.
171
172 x = module.SomeObject()
173 y = module.SomeObject()
174 x.member_a = 'foo'
175 y.member_a = 'bar'
176
177 Index expressions are named after their bases, for the same reasons as
178 member expressions. The coarse-grained precision should lead to an
179 over-approximation of where objects are in use, but should not miss any
180 references. This should be fine for our purposes. In the snippet below
181 the last two assignments are named `x` and `module.SomeObject.a_list`.
182
183 x = [None] # list[Thing]
184 y = module.SomeObject()
185 x[0] = Thing()
186 y.a_list[1] = Blah()
187
188 The examples above all deal with assignments, but these rules apply to
189 any expression that uses an object.
190
191 Returns None if expr does not refer to a variable or object.
192 """
193 if isinstance(expr,
194 NameExpr) and expr.name not in {'True', 'False', 'None'}:
195 return (expr.name, )
196
197 elif isinstance(expr, MemberExpr):
198 dot_expr = self.dot_exprs[expr]
199 if isinstance(dot_expr, pass_state.ModuleMember):
200 return dot_expr.module_path + (dot_expr.member, )
201
202 elif isinstance(dot_expr, pass_state.HeapObjectMember):
203 return GetObjectTypeName(
204 dot_expr.object_type) + (dot_expr.member, )
205
206 elif isinstance(dot_expr, pass_state.StackObjectMember):
207 return GetObjectTypeName(
208 dot_expr.object_type) + (dot_expr.member, )
209
210 elif isinstance(expr, IndexExpr):
211 return self.get_variable_name(expr.base)
212
213 return None
214
215 #
216 # COPIED from IRBuilder
217 #
218
219 @overload
220 def accept(self, node: Expression) -> T:
221 ...
222
223 @overload
224 def accept(self, node: Statement) -> None:
225 ...
226
227 def accept(self, node: Union[Statement, Expression]) -> Optional[T]:
228 with catch_errors(self.module_path, node.line):
229 if isinstance(node, Expression):
230 try:
231 res = node.accept(self)
232 #res = self.coerce(res, self.node_type(node), node.line)
233
234 # If we hit an error during compilation, we want to
235 # keep trying, so we can produce more error
236 # messages. Generate a temp of the right type to keep
237 # from causing more downstream trouble.
238 except UnsupportedException:
239 res = self.alloc_temp(self.node_type(node))
240 return res
241 else:
242 try:
243 cfg = self.current_cfg()
244 # Most statements have empty visitors because they don't
245 # require any special logic. Create statements for them
246 # here. Don't create statements from blocks to avoid
247 # stuttering.
248 if cfg and not isinstance(node, Block):
249 self.current_statement_id = cfg.AddStatement()
250
251 node.accept(self)
252 except UnsupportedException:
253 pass
254 return None
255
256 # Not in superclasses:
257
258 def visit_mypy_file(self, o: 'mypy.nodes.MypyFile') -> T:
259 if util.ShouldSkipPyFile(o):
260 return
261
262 self.module_path = o.path
263
264 for node in o.defs:
265 # skip module docstring
266 if isinstance(node, ExpressionStmt) and isinstance(
267 node.expr, StrExpr):
268 continue
269 self.accept(node)
270
271 # Statements
272
273 def visit_for_stmt(self, o: 'mypy.nodes.ForStmt') -> T:
274 cfg = self.current_cfg()
275 with pass_state.CfgLoopContext(
276 cfg, entry=self.current_statement_id) as loop:
277 self.accept(o.expr)
278 self.loop_stack.append(loop)
279 self.accept(o.body)
280 self.loop_stack.pop()
281
282 def _handle_switch(self, expr, o, cfg):
283 assert len(o.body.body) == 1, o.body.body
284 if_node = o.body.body[0]
285 assert isinstance(if_node, IfStmt), if_node
286 cases = []
287 default_block = util._collect_cases(self.module_path, if_node, cases)
288 with pass_state.CfgBranchContext(
289 cfg, self.current_statement_id) as branch_ctx:
290 for expr, body in cases:
291 self.accept(expr)
292 assert expr is not None, expr
293 with branch_ctx.AddBranch():
294 self.accept(body)
295
296 if default_block:
297 with branch_ctx.AddBranch():
298 self.accept(default_block)
299
300 def visit_with_stmt(self, o: 'mypy.nodes.WithStmt') -> T:
301 cfg = self.current_cfg()
302 assert len(o.expr) == 1, o.expr
303 expr = o.expr[0]
304 assert isinstance(expr, CallExpr), expr
305 self.accept(expr)
306
307 callee_name = expr.callee.name
308 if callee_name == 'switch':
309 self._handle_switch(expr, o, cfg)
310 elif callee_name == 'str_switch':
311 self._handle_switch(expr, o, cfg)
312 elif callee_name == 'tagswitch':
313 self._handle_switch(expr, o, cfg)
314 else:
315 with pass_state.CfgBlockContext(cfg, self.current_statement_id):
316 self.accept(o.body)
317
318 def visit_func_def(self, o: 'mypy.nodes.FuncDef') -> T:
319 if o.name == '__repr__': # Don't translate
320 return
321
322 # For virtual methods, pretend that the method on the base class calls
323 # the same method on every subclass. This way call sites using the
324 # abstract base class will over-approximate the set of call paths they
325 # can take when checking if they can reach MaybeCollect().
326 if self.current_class_name and self.virtual.IsVirtual(
327 self.current_class_name, o.name):
328 key = (self.current_class_name, o.name)
329 base = self.virtual.virtuals[key]
330 if base:
331 sub = join_name(self.current_class_name + (o.name, ),
332 delim='.')
333 base_key = base[0] + (base[1], )
334 cfg = self.cfgs[base_key]
335 cfg.AddFact(0, pass_state.FunctionCall(sub))
336
337 self.current_func_node = o
338 cfg = self.current_cfg()
339 for arg in o.arguments:
340 cfg.AddFact(0, pass_state.Definition((arg.variable.name,)))
341
342 self.accept(o.body)
343 self.current_func_node = None
344 self.current_statement_id = None
345
346 def visit_class_def(self, o: 'mypy.nodes.ClassDef') -> T:
347 self.current_class_name = split_py_name(o.fullname)
348 for stmt in o.defs.body:
349 # Ignore things that look like docstrings
350 if (isinstance(stmt, ExpressionStmt) and
351 isinstance(stmt.expr, StrExpr)):
352 continue
353
354 if isinstance(stmt, FuncDef) and stmt.name == '__repr__':
355 continue
356
357 self.accept(stmt)
358
359 self.current_class_name = None
360
361 def visit_while_stmt(self, o: 'mypy.nodes.WhileStmt') -> T:
362 cfg = self.current_cfg()
363 with pass_state.CfgLoopContext(
364 cfg, entry=self.current_statement_id) as loop:
365 self.accept(o.expr)
366 self.loop_stack.append(loop)
367 self.accept(o.body)
368 self.loop_stack.pop()
369
370 def visit_return_stmt(self, o: 'mypy.nodes.ReturnStmt') -> T:
371 cfg = self.current_cfg()
372 if cfg:
373 cfg.AddDeadend(self.current_statement_id)
374
375 if o.expr:
376 self.accept(o.expr)
377
378 def visit_if_stmt(self, o: 'mypy.nodes.IfStmt') -> T:
379 if util.MaybeSkipIfStmt(self, o):
380 return
381
382 cfg = self.current_cfg()
383 for expr in o.expr:
384 self.accept(expr)
385
386 with pass_state.CfgBranchContext(
387 cfg, self.current_statement_id) as branch_ctx:
388 with branch_ctx.AddBranch():
389 for node in o.body:
390 self.accept(node)
391
392 if o.else_body:
393 with branch_ctx.AddBranch():
394 self.accept(o.else_body)
395
396 def visit_break_stmt(self, o: 'mypy.nodes.BreakStmt') -> T:
397 if len(self.loop_stack):
398 self.loop_stack[-1].AddBreak(self.current_statement_id)
399
400 def visit_continue_stmt(self, o: 'mypy.nodes.ContinueStmt') -> T:
401 if len(self.loop_stack):
402 self.loop_stack[-1].AddContinue(self.current_statement_id)
403
404 def visit_raise_stmt(self, o: 'mypy.nodes.RaiseStmt') -> T:
405 cfg = self.current_cfg()
406 if cfg:
407 cfg.AddDeadend(self.current_statement_id)
408
409 if o.expr:
410 self.accept(o.expr)
411
412 def visit_try_stmt(self, o: 'mypy.nodes.TryStmt') -> T:
413 cfg = self.current_cfg()
414 with pass_state.CfgBranchContext(cfg,
415 self.current_statement_id) as try_ctx:
416 with try_ctx.AddBranch() as try_block:
417 self.accept(o.body)
418
419 for t, v, handler in zip(o.types, o.vars, o.handlers):
420 with try_ctx.AddBranch(try_block.exit):
421 self.accept(handler)
422
423 def visit_assignment_stmt(self, o: 'mypy.nodes.AssignmentStmt') -> T:
424 cfg = self.current_cfg()
425 if cfg:
426 assert len(o.lvalues) == 1
427 lval = o.lvalues[0]
428 lval_names = []
429 if isinstance(lval, TupleExpr):
430 lval_names.extend(
431 [self.get_variable_name(item) for item in lval.items])
432
433 else:
434 lval_names.append(self.get_variable_name(lval))
435
436 assert lval_names, o
437
438 rval_type = self.types[o.rvalue]
439 rval_names = []
440 if isinstance(o.rvalue, CallExpr):
441 # The RHS is either an object constructor or something that
442 # returns a primitive type (e.g. Tuple[int, int] or str).
443 # XXX: When we add inter-procedural analysis we should treat
444 # these not as definitions but as some new kind of assignment.
445 rval_names = [None for _ in lval_names]
446
447 else:
448 if isinstance(o.rvalue, TupleExpr) and len(lval_names) == 1:
449 # We're constructing a tuple. Since tuples have have a fixed
450 # (and usually small) size, we can name each of the
451 # elements.
452 base = lval_names[0]
453 lval_names = [
454 base + (str(i), ) for i in range(len(o.rvalue.items))
455 ]
456 rval_names = [
457 self.get_variable_name(item) for item in o.rvalue.items
458 ]
459
460 elif isinstance(rval_type, TupleType):
461 # We're unpacking a tuple. Like the tuple construction case,
462 # give each element a name.
463 rval_name = self.get_variable_name(o.rvalue)
464 assert rval_name, o.rvalue
465 rval_names = [
466 rval_name + (str(i), ) for i in range(len(lval_names))
467 ]
468
469 else:
470 rval_names = [self.get_variable_name(o.rvalue)]
471
472 assert len(rval_names) == len(lval_names)
473
474 for lhs, rhs in zip(lval_names, rval_names):
475 assert lhs, lval
476 if rhs:
477 # In this case rhe RHS is another variable. Record the
478 # assignment so we can keep track of aliases.
479 cfg.AddFact(self.current_statement_id,
480 pass_state.Assignment(lhs, rhs))
481 else:
482 # In this case the RHS is either some kind of literal (e.g.
483 # [] or 'foo') or a call to an object constructor. Mark this
484 # statement as an (re-)definition of a variable.
485 cfg.AddFact(
486 self.current_statement_id,
487 pass_state.Definition(lhs),
488 )
489
490 for lval in o.lvalues:
491 self.accept(lval)
492
493 self.accept(o.rvalue)
494
495 # Expressions
496
497 def visit_call_expr(self, o: 'mypy.nodes.CallExpr') -> T:
498 cfg = self.current_cfg()
499 if self.current_func_node:
500 full_callee = self.resolve_callee(o)
501 if full_callee:
502 self.callees[o] = full_callee
503 cfg.AddFact(
504 self.current_statement_id,
505 pass_state.FunctionCall(join_name(full_callee, delim='.')))
506
507 self.accept(o.callee)
508 for arg in o.args:
509 self.accept(arg)