| 1 | """
 | 
| 2 | control_flow_pass.py - AST pass that builds a control flow graph.
 | 
| 3 | """
 | 
| 4 | import collections
 | 
| 5 | from typing import overload, Union, Optional, Dict
 | 
| 6 | 
 | 
| 7 | import mypy
 | 
| 8 | from mypy.nodes import (Block, Expression, Statement, ExpressionStmt, StrExpr,
 | 
| 9 |                         CallExpr, FuncDef, IfStmt, NameExpr, MemberExpr)
 | 
| 10 | 
 | 
| 11 | from mypy.types import CallableType, Instance, Type, UnionType
 | 
| 12 | 
 | 
| 13 | from mycpp.crash import catch_errors
 | 
| 14 | from mycpp.util import join_name, split_py_name
 | 
| 15 | from mycpp.visitor import SimpleVisitor, T
 | 
| 16 | from mycpp import util
 | 
| 17 | from mycpp import pass_state
 | 
| 18 | 
 | 
| 19 | 
 | 
class UnsupportedException(Exception):
    """Signals a construct that this pass does not support.

    Build.accept() catches this exception so that the traversal can carry on
    with the remaining nodes.
    """
 | 
| 22 | 
 | 
| 23 | 
 | 
| 24 | class Build(SimpleVisitor):
 | 
| 25 | 
 | 
| 26 |     def __init__(self, types: Dict[Expression, Type], virtual, local_vars,
 | 
| 27 |                  dot_exprs):
 | 
| 28 | 
 | 
| 29 |         self.types = types
 | 
| 30 |         self.cfgs = collections.defaultdict(pass_state.ControlFlowGraph)
 | 
| 31 |         self.current_statement_id = None
 | 
| 32 |         self.current_class_name = None
 | 
| 33 |         self.current_func_node = None
 | 
| 34 |         self.loop_stack = []
 | 
| 35 |         self.virtual = virtual
 | 
| 36 |         self.local_vars = local_vars
 | 
| 37 |         self.dot_exprs = dot_exprs
 | 
| 38 |         self.callees = {} # statement object -> SymbolPath of the callee
 | 
| 39 | 
 | 
| 40 |     def current_cfg(self):
 | 
| 41 |         if not self.current_func_node:
 | 
| 42 |             return None
 | 
| 43 | 
 | 
| 44 |         return self.cfgs[split_py_name(self.current_func_node.fullname)]
 | 
| 45 | 
 | 
| 46 |     def resolve_callee(self, o: CallExpr) -> Optional[util.SymbolPath]:
 | 
| 47 |         """
 | 
| 48 |         Returns the fully qualified name of the callee in the given call
 | 
| 49 |         expression.
 | 
| 50 | 
 | 
| 51 |         Member functions are prefixed by the names of the classes that contain
 | 
| 52 |         them. For example, the name of the callee in the last statement of the
 | 
| 53 |         snippet below is `module.SomeObject.Foo`.
 | 
| 54 | 
 | 
| 55 |             x = module.SomeObject()
 | 
| 56 |             x.Foo()
 | 
| 57 | 
 | 
| 58 |         Free-functions defined in the local module are referred to by their
 | 
| 59 |         normal fully qualified names. The function `foo` in a module called
 | 
| 60 |         `moduleA` would is named `moduleA.foo`. Calls to free-functions defined
 | 
| 61 |         in imported modules are named the same way.
 | 
| 62 |         """
 | 
| 63 | 
 | 
| 64 |         if isinstance(o.callee, NameExpr):
 | 
| 65 |             return split_py_name(o.callee.fullname)
 | 
| 66 | 
 | 
| 67 |         elif isinstance(o.callee, MemberExpr):
 | 
| 68 |             if isinstance(o.callee.expr, NameExpr):
 | 
| 69 |                 is_module = isinstance(self.dot_exprs.get(o.callee),
 | 
| 70 |                                        pass_state.ModuleMember)
 | 
| 71 |                 if is_module:
 | 
| 72 |                     return split_py_name(
 | 
| 73 |                         o.callee.expr.fullname) + (o.callee.name, )
 | 
| 74 | 
 | 
| 75 |                 elif o.callee.expr.name == 'self':
 | 
| 76 |                     assert self.current_class_name
 | 
| 77 |                     return self.current_class_name + (o.callee.name, )
 | 
| 78 | 
 | 
| 79 |                 else:
 | 
| 80 |                     local_type = None
 | 
| 81 |                     for name, t in self.local_vars.get(self.current_func_node,
 | 
| 82 |                                                        []):
 | 
| 83 |                         if name == o.callee.expr.name:
 | 
| 84 |                             local_type = t
 | 
| 85 |                             break
 | 
| 86 | 
 | 
| 87 |                     if local_type:
 | 
| 88 |                         if isinstance(local_type, str):
 | 
| 89 |                             return split_py_name(local_type) + (
 | 
| 90 |                                 o.callee.name, )
 | 
| 91 | 
 | 
| 92 |                         elif isinstance(local_type, Instance):
 | 
| 93 |                             return split_py_name(
 | 
| 94 |                                 local_type.type.fullname) + (o.callee.name, )
 | 
| 95 | 
 | 
| 96 |                         elif isinstance(local_type, UnionType):
 | 
| 97 |                             assert len(local_type.items) == 2
 | 
| 98 |                             return split_py_name(
 | 
| 99 |                                 local_type.items[0].type.fullname) + (
 | 
| 100 |                                     o.callee.expr.name, )
 | 
| 101 | 
 | 
| 102 |                         else:
 | 
| 103 |                             assert not isinstance(local_type, CallableType)
 | 
| 104 |                             # primitive type or string. don't care.
 | 
| 105 |                             return None
 | 
| 106 | 
 | 
| 107 |                     else:
 | 
| 108 |                         # context or exception handler. probably safe to ignore.
 | 
| 109 |                         return None
 | 
| 110 | 
 | 
| 111 |             else:
 | 
| 112 |                 t = self.types.get(o.callee.expr)
 | 
| 113 |                 if isinstance(t, Instance):
 | 
| 114 |                     return split_py_name(t.type.fullname) + (o.callee.name, )
 | 
| 115 | 
 | 
| 116 |                 elif isinstance(t, UnionType):
 | 
| 117 |                     assert len(t.items) == 2
 | 
| 118 |                     return split_py_name(
 | 
| 119 |                         t.items[0].type.fullname) + (o.callee.name, )
 | 
| 120 | 
 | 
| 121 |                 elif o.callee.expr and getattr(o.callee.expr, 'fullname',
 | 
| 122 |                                                None):
 | 
| 123 |                     return split_py_name(
 | 
| 124 |                         o.callee.expr.fullname) + (o.callee.name, )
 | 
| 125 | 
 | 
| 126 |                 else:
 | 
| 127 |                     # constructors of things that we don't care about.
 | 
| 128 |                     return None
 | 
| 129 | 
 | 
| 130 |         # Don't currently get here
 | 
| 131 |         raise AssertionError()
 | 
| 132 | 
 | 
| 133 |     #
 | 
| 134 |     # COPIED from IRBuilder
 | 
| 135 |     #
 | 
| 136 | 
 | 
| 137 |     @overload
 | 
| 138 |     def accept(self, node: Expression) -> T:
 | 
| 139 |         ...
 | 
| 140 | 
 | 
| 141 |     @overload
 | 
| 142 |     def accept(self, node: Statement) -> None:
 | 
| 143 |         ...
 | 
| 144 | 
 | 
| 145 |     def accept(self, node: Union[Statement, Expression]) -> Optional[T]:
 | 
| 146 |         with catch_errors(self.module_path, node.line):
 | 
| 147 |             if isinstance(node, Expression):
 | 
| 148 |                 try:
 | 
| 149 |                     res = node.accept(self)
 | 
| 150 |                     #res = self.coerce(res, self.node_type(node), node.line)
 | 
| 151 | 
 | 
| 152 |                 # If we hit an error during compilation, we want to
 | 
| 153 |                 # keep trying, so we can produce more error
 | 
| 154 |                 # messages. Generate a temp of the right type to keep
 | 
| 155 |                 # from causing more downstream trouble.
 | 
| 156 |                 except UnsupportedException:
 | 
| 157 |                     res = self.alloc_temp(self.node_type(node))
 | 
| 158 |                 return res
 | 
| 159 |             else:
 | 
| 160 |                 try:
 | 
| 161 |                     cfg = self.current_cfg()
 | 
| 162 |                     # Most statements have empty visitors because they don't
 | 
| 163 |                     # require any special logic. Create statements for them
 | 
| 164 |                     # here. Don't create statements from blocks to avoid
 | 
| 165 |                     # stuttering.
 | 
| 166 |                     if cfg and not isinstance(node, Block):
 | 
| 167 |                         self.current_statement_id = cfg.AddStatement()
 | 
| 168 | 
 | 
| 169 |                     node.accept(self)
 | 
| 170 |                 except UnsupportedException:
 | 
| 171 |                     pass
 | 
| 172 |                 return None
 | 
| 173 | 
 | 
| 174 |     # Not in superclasses:
 | 
| 175 | 
 | 
| 176 |     def visit_mypy_file(self, o: 'mypy.nodes.MypyFile') -> T:
 | 
| 177 |         if util.ShouldSkipPyFile(o):
 | 
| 178 |             return
 | 
| 179 | 
 | 
| 180 |         self.module_path = o.path
 | 
| 181 | 
 | 
| 182 |         for node in o.defs:
 | 
| 183 |             # skip module docstring
 | 
| 184 |             if isinstance(node, ExpressionStmt) and isinstance(
 | 
| 185 |                     node.expr, StrExpr):
 | 
| 186 |                 continue
 | 
| 187 |             self.accept(node)
 | 
| 188 | 
 | 
| 189 |     # Statements
 | 
| 190 | 
 | 
| 191 |     def visit_for_stmt(self, o: 'mypy.nodes.ForStmt') -> T:
 | 
| 192 |         cfg = self.current_cfg()
 | 
| 193 |         with pass_state.CfgLoopContext(
 | 
| 194 |                 cfg, entry=self.current_statement_id) as loop:
 | 
| 195 |             self.accept(o.expr)
 | 
| 196 |             self.loop_stack.append(loop)
 | 
| 197 |             self.accept(o.body)
 | 
| 198 |             self.loop_stack.pop()
 | 
| 199 | 
 | 
| 200 |     def _handle_switch(self, expr, o, cfg):
 | 
| 201 |         assert len(o.body.body) == 1, o.body.body
 | 
| 202 |         if_node = o.body.body[0]
 | 
| 203 |         assert isinstance(if_node, IfStmt), if_node
 | 
| 204 |         cases = []
 | 
| 205 |         default_block = util._collect_cases(self.module_path, if_node, cases)
 | 
| 206 |         with pass_state.CfgBranchContext(
 | 
| 207 |                 cfg, self.current_statement_id) as branch_ctx:
 | 
| 208 |             for expr, body in cases:
 | 
| 209 |                 self.accept(expr)
 | 
| 210 |                 assert expr is not None, expr
 | 
| 211 |                 with branch_ctx.AddBranch():
 | 
| 212 |                     self.accept(body)
 | 
| 213 | 
 | 
| 214 |             if default_block:
 | 
| 215 |                 with branch_ctx.AddBranch():
 | 
| 216 |                     self.accept(default_block)
 | 
| 217 | 
 | 
| 218 |     def visit_with_stmt(self, o: 'mypy.nodes.WithStmt') -> T:
 | 
| 219 |         cfg = self.current_cfg()
 | 
| 220 |         assert len(o.expr) == 1, o.expr
 | 
| 221 |         expr = o.expr[0]
 | 
| 222 |         assert isinstance(expr, CallExpr), expr
 | 
| 223 |         self.accept(expr)
 | 
| 224 | 
 | 
| 225 |         callee_name = expr.callee.name
 | 
| 226 |         if callee_name == 'switch':
 | 
| 227 |             self._handle_switch(expr, o, cfg)
 | 
| 228 |         elif callee_name == 'str_switch':
 | 
| 229 |             self._handle_switch(expr, o, cfg)
 | 
| 230 |         elif callee_name == 'tagswitch':
 | 
| 231 |             self._handle_switch(expr, o, cfg)
 | 
| 232 |         else:
 | 
| 233 |             with pass_state.CfgBlockContext(cfg, self.current_statement_id):
 | 
| 234 |                 self.accept(o.body)
 | 
| 235 | 
 | 
| 236 |     def visit_func_def(self, o: 'mypy.nodes.FuncDef') -> T:
 | 
| 237 |         if o.name == '__repr__':  # Don't translate
 | 
| 238 |             return
 | 
| 239 | 
 | 
| 240 |         # For virtual methods, pretend that the method on the base class calls
 | 
| 241 |         # the same method on every subclass. This way call sites using the
 | 
| 242 |         # abstract base class will over-approximate the set of call paths they
 | 
| 243 |         # can take when checking if they can reach MaybeCollect().
 | 
| 244 |         if self.current_class_name and self.virtual.IsVirtual(
 | 
| 245 |                 self.current_class_name, o.name):
 | 
| 246 |             key = (self.current_class_name, o.name)
 | 
| 247 |             base = self.virtual.virtuals[key]
 | 
| 248 |             if base:
 | 
| 249 |                 sub = join_name(self.current_class_name + (o.name, ),
 | 
| 250 |                                 delim='.')
 | 
| 251 |                 base_key = base[0] + (base[1], )
 | 
| 252 |                 cfg = self.cfgs[base_key]
 | 
| 253 |                 cfg.AddFact(0, pass_state.FunctionCall(sub))
 | 
| 254 | 
 | 
| 255 |         self.current_func_node = o
 | 
| 256 |         self.accept(o.body)
 | 
| 257 |         self.current_func_node = None
 | 
| 258 |         self.current_statement_id = None
 | 
| 259 | 
 | 
| 260 |     def visit_class_def(self, o: 'mypy.nodes.ClassDef') -> T:
 | 
| 261 |         self.current_class_name = split_py_name(o.fullname)
 | 
| 262 |         for stmt in o.defs.body:
 | 
| 263 |             # Ignore things that look like docstrings
 | 
| 264 |             if (isinstance(stmt, ExpressionStmt) and
 | 
| 265 |                     isinstance(stmt.expr, StrExpr)):
 | 
| 266 |                 continue
 | 
| 267 | 
 | 
| 268 |             if isinstance(stmt, FuncDef) and stmt.name == '__repr__':
 | 
| 269 |                 continue
 | 
| 270 | 
 | 
| 271 |             self.accept(stmt)
 | 
| 272 | 
 | 
| 273 |         self.current_class_name = None
 | 
| 274 | 
 | 
| 275 |     def visit_while_stmt(self, o: 'mypy.nodes.WhileStmt') -> T:
 | 
| 276 |         cfg = self.current_cfg()
 | 
| 277 |         with pass_state.CfgLoopContext(
 | 
| 278 |                 cfg, entry=self.current_statement_id) as loop:
 | 
| 279 |             self.accept(o.expr)
 | 
| 280 |             self.loop_stack.append(loop)
 | 
| 281 |             self.accept(o.body)
 | 
| 282 |             self.loop_stack.pop()
 | 
| 283 | 
 | 
| 284 |     def visit_return_stmt(self, o: 'mypy.nodes.ReturnStmt') -> T:
 | 
| 285 |         cfg = self.current_cfg()
 | 
| 286 |         if cfg:
 | 
| 287 |             cfg.AddDeadend(self.current_statement_id)
 | 
| 288 | 
 | 
| 289 |         if o.expr:
 | 
| 290 |             self.accept(o.expr)
 | 
| 291 | 
 | 
| 292 |     def visit_if_stmt(self, o: 'mypy.nodes.IfStmt') -> T:
 | 
| 293 |         cfg = self.current_cfg()
 | 
| 294 |         for expr in o.expr:
 | 
| 295 |             self.accept(expr)
 | 
| 296 | 
 | 
| 297 |         with pass_state.CfgBranchContext(
 | 
| 298 |                 cfg, self.current_statement_id) as branch_ctx:
 | 
| 299 |             with branch_ctx.AddBranch():
 | 
| 300 |                 for node in o.body:
 | 
| 301 |                     self.accept(node)
 | 
| 302 | 
 | 
| 303 |             if o.else_body:
 | 
| 304 |                 with branch_ctx.AddBranch():
 | 
| 305 |                     self.accept(o.else_body)
 | 
| 306 | 
 | 
| 307 |     def visit_break_stmt(self, o: 'mypy.nodes.BreakStmt') -> T:
 | 
| 308 |         if len(self.loop_stack):
 | 
| 309 |             self.loop_stack[-1].AddBreak(self.current_statement_id)
 | 
| 310 | 
 | 
| 311 |     def visit_continue_stmt(self, o: 'mypy.nodes.ContinueStmt') -> T:
 | 
| 312 |         if len(self.loop_stack):
 | 
| 313 |             self.loop_stack[-1].AddContinue(self.current_statement_id)
 | 
| 314 | 
 | 
| 315 |     def visit_raise_stmt(self, o: 'mypy.nodes.RaiseStmt') -> T:
 | 
| 316 |         cfg = self.current_cfg()
 | 
| 317 |         if cfg:
 | 
| 318 |             cfg.AddDeadend(self.current_statement_id)
 | 
| 319 | 
 | 
| 320 |         if o.expr:
 | 
| 321 |             self.accept(o.expr)
 | 
| 322 | 
 | 
| 323 |     def visit_try_stmt(self, o: 'mypy.nodes.TryStmt') -> T:
 | 
| 324 |         cfg = self.current_cfg()
 | 
| 325 |         with pass_state.CfgBranchContext(cfg,
 | 
| 326 |                                          self.current_statement_id) as try_ctx:
 | 
| 327 |             with try_ctx.AddBranch() as try_block:
 | 
| 328 |                 self.accept(o.body)
 | 
| 329 | 
 | 
| 330 |             for t, v, handler in zip(o.types, o.vars, o.handlers):
 | 
| 331 |                 with try_ctx.AddBranch(try_block.exit):
 | 
| 332 |                     self.accept(handler)
 | 
| 333 | 
 | 
| 334 |     # Expressions
 | 
| 335 | 
 | 
| 336 |     def visit_call_expr(self, o: 'mypy.nodes.CallExpr') -> T:
 | 
| 337 |         cfg = self.current_cfg()
 | 
| 338 |         if self.current_func_node:
 | 
| 339 |             full_callee = self.resolve_callee(o)
 | 
| 340 |             if full_callee:
 | 
| 341 |                 self.callees[o] = full_callee
 | 
| 342 |                 cfg.AddFact(
 | 
| 343 |                     self.current_statement_id,
 | 
| 344 |                     pass_state.FunctionCall(join_name(full_callee, delim='.')))
 | 
| 345 | 
 | 
| 346 |         self.accept(o.callee)
 | 
| 347 |         for arg in o.args:
 | 
| 348 |             self.accept(arg)
 |