| 1 | """
 | 
| 2 | control_flow_pass.py - AST pass that builds a control flow graph.
 | 
| 3 | """
 | 
| 4 | from typing import overload, Union, Optional, Dict
 | 
| 5 | 
 | 
| 6 | import mypy
 | 
| 7 | from mypy.visitor import ExpressionVisitor, StatementVisitor
 | 
| 8 | from mypy.nodes import (Block, Expression, Statement, ExpressionStmt, StrExpr,
 | 
| 9 |                         ForStmt, WhileStmt, CallExpr, FuncDef, IfStmt)
 | 
| 10 | 
 | 
| 11 | from mypy.types import Type
 | 
| 12 | 
 | 
| 13 | from mycpp.crash import catch_errors
 | 
| 14 | from mycpp.util import split_py_name
 | 
| 15 | from mycpp import util
 | 
| 16 | from mycpp import pass_state
 | 
| 17 | 
 | 
# Placeholder for the visitor result type, used as ExpressionVisitor[T] and
# in the `-> T` annotations of the Build class.  Because T is None, those
# annotations behave like `-> None`.
T = None  # TODO: Make it type check?
 | 
| 19 | 
 | 
| 20 | 
 | 
class UnsupportedException(Exception):
    """Signals that a node cannot be handled by this pass.

    Build.accept() catches this exception and continues with the next node,
    so that a single unsupported construct does not stop the whole pass.
    """
 | 
| 23 | 
 | 
| 24 | 
 | 
| 25 | class Build(ExpressionVisitor[T], StatementVisitor[None]):
 | 
| 26 | 
 | 
| 27 |     def __init__(self, types: Dict[Expression, Type]):
 | 
| 28 | 
 | 
| 29 |         self.types = types
 | 
| 30 |         self.cfgs = {}
 | 
| 31 |         self.current_statement_id = None
 | 
| 32 |         self.current_func_node = None
 | 
| 33 |         self.loop_stack = []
 | 
| 34 | 
 | 
| 35 |     def current_cfg(self):
 | 
| 36 |         if not self.current_func_node:
 | 
| 37 |             return None
 | 
| 38 | 
 | 
| 39 |         return self.cfgs.get(split_py_name(self.current_func_node.fullname))
 | 
| 40 | 
 | 
| 41 |     #
 | 
| 42 |     # COPIED from IRBuilder
 | 
| 43 |     #
 | 
| 44 | 
 | 
| 45 |     @overload
 | 
| 46 |     def accept(self, node: Expression) -> T:
 | 
| 47 |         ...
 | 
| 48 | 
 | 
| 49 |     @overload
 | 
| 50 |     def accept(self, node: Statement) -> None:
 | 
| 51 |         ...
 | 
| 52 | 
 | 
| 53 |     def accept(self, node: Union[Statement, Expression]) -> Optional[T]:
 | 
| 54 |         with catch_errors(self.module_path, node.line):
 | 
| 55 |             if isinstance(node, Expression):
 | 
| 56 |                 try:
 | 
| 57 |                     res = node.accept(self)
 | 
| 58 |                     #res = self.coerce(res, self.node_type(node), node.line)
 | 
| 59 | 
 | 
| 60 |                 # If we hit an error during compilation, we want to
 | 
| 61 |                 # keep trying, so we can produce more error
 | 
| 62 |                 # messages. Generate a temp of the right type to keep
 | 
| 63 |                 # from causing more downstream trouble.
 | 
| 64 |                 except UnsupportedException:
 | 
| 65 |                     res = self.alloc_temp(self.node_type(node))
 | 
| 66 |                 return res
 | 
| 67 |             else:
 | 
| 68 |                 try:
 | 
| 69 |                     cfg = self.current_cfg()
 | 
| 70 |                     # Most statements have empty visitors because they don't
 | 
| 71 |                     # require any special logic. Create statements for them
 | 
| 72 |                     # here. Blocks and loops are handled by their visitors.
 | 
| 73 |                     if (cfg and not isinstance(node, Block) and
 | 
| 74 |                             not isinstance(node, ForStmt) and
 | 
| 75 |                             not isinstance(node, WhileStmt)):
 | 
| 76 |                         self.current_statement_id = cfg.AddStatement()
 | 
| 77 | 
 | 
| 78 |                     node.accept(self)
 | 
| 79 |                 except UnsupportedException:
 | 
| 80 |                     pass
 | 
| 81 |                 return None
 | 
| 82 | 
 | 
| 83 |     # Not in superclasses:
 | 
| 84 | 
 | 
| 85 |     def visit_mypy_file(self, o: 'mypy.nodes.MypyFile') -> T:
 | 
| 86 |         if util.ShouldSkipPyFile(o):
 | 
| 87 |             return
 | 
| 88 | 
 | 
| 89 |         self.module_path = o.path
 | 
| 90 | 
 | 
| 91 |         for node in o.defs:
 | 
| 92 |             # skip module docstring
 | 
| 93 |             if isinstance(node, ExpressionStmt) and isinstance(
 | 
| 94 |                     node.expr, StrExpr):
 | 
| 95 |                 continue
 | 
| 96 |             self.accept(node)
 | 
| 97 | 
 | 
    # Loop, switch/with, and function/class definition visitors
 | 
| 99 | 
 | 
| 100 |     def visit_for_stmt(self, o: 'mypy.nodes.ForStmt') -> T:
 | 
| 101 |         cfg = self.current_cfg()
 | 
| 102 |         if not cfg:
 | 
| 103 |             return
 | 
| 104 | 
 | 
| 105 |         with pass_state.CfgLoopContext(cfg) as loop:
 | 
| 106 |             self.loop_stack.append(loop)
 | 
| 107 |             self.accept(o.body)
 | 
| 108 |             self.loop_stack.pop()
 | 
| 109 | 
 | 
| 110 |     def _handle_switch(self, expr, o, cfg):
 | 
| 111 |         assert len(o.body.body) == 1, o.body.body
 | 
| 112 |         if_node = o.body.body[0]
 | 
| 113 |         assert isinstance(if_node, IfStmt), if_node
 | 
| 114 |         cases = []
 | 
| 115 |         default_block = util._collect_cases(self.module_path, if_node, cases)
 | 
| 116 |         with pass_state.CfgBranchContext(
 | 
| 117 |                 cfg, self.current_statement_id) as branch_ctx:
 | 
| 118 |             for expr, body in cases:
 | 
| 119 |                 assert expr is not None, expr
 | 
| 120 |                 with branch_ctx.AddBranch():
 | 
| 121 |                     self.accept(body)
 | 
| 122 | 
 | 
| 123 |             if default_block:
 | 
| 124 |                 with branch_ctx.AddBranch():
 | 
| 125 |                     self.accept(default_block)
 | 
| 126 | 
 | 
| 127 |     def visit_with_stmt(self, o: 'mypy.nodes.WithStmt') -> T:
 | 
| 128 |         cfg = self.current_cfg()
 | 
| 129 |         if not cfg:
 | 
| 130 |             return
 | 
| 131 | 
 | 
| 132 |         assert len(o.expr) == 1, o.expr
 | 
| 133 |         expr = o.expr[0]
 | 
| 134 |         assert isinstance(expr, CallExpr), expr
 | 
| 135 | 
 | 
| 136 |         callee_name = expr.callee.name
 | 
| 137 |         if callee_name == 'switch':
 | 
| 138 |             self._handle_switch(expr, o, cfg)
 | 
| 139 |         elif callee_name == 'str_switch':
 | 
| 140 |             self._handle_switch(expr, o, cfg)
 | 
| 141 |         elif callee_name == 'tagswitch':
 | 
| 142 |             self._handle_switch(expr, o, cfg)
 | 
| 143 |         else:
 | 
| 144 |             with pass_state.CfgBlockContext(cfg, self.current_statement_id):
 | 
| 145 |                 for stmt in o.body.body:
 | 
| 146 |                     self.accept(stmt)
 | 
| 147 | 
 | 
| 148 |     def visit_func_def(self, o: 'mypy.nodes.FuncDef') -> T:
 | 
| 149 |         if o.name == '__repr__':  # Don't translate
 | 
| 150 |             return
 | 
| 151 | 
 | 
| 152 |         self.cfgs[split_py_name(o.fullname)] = pass_state.ControlFlowGraph()
 | 
| 153 |         self.current_func_node = o
 | 
| 154 |         self.accept(o.body)
 | 
| 155 |         self.current_func_node = None
 | 
| 156 | 
 | 
| 157 |     def visit_class_def(self, o: 'mypy.nodes.ClassDef') -> T:
 | 
| 158 |         for stmt in o.defs.body:
 | 
| 159 |             # Ignore things that look like docstrings
 | 
| 160 |             if (isinstance(stmt, ExpressionStmt) and
 | 
| 161 |                     isinstance(stmt.expr, StrExpr)):
 | 
| 162 |                 continue
 | 
| 163 | 
 | 
| 164 |             if isinstance(stmt, FuncDef) and stmt.name == '__repr__':
 | 
| 165 |                 continue
 | 
| 166 | 
 | 
| 167 |             self.accept(stmt)
 | 
| 168 | 
 | 
| 169 |     # Statements
 | 
| 170 | 
 | 
| 171 |     def visit_block(self, block: 'mypy.nodes.Block') -> T:
 | 
| 172 |         for stmt in block.body:
 | 
| 173 |             # Ignore things that look like docstrings
 | 
| 174 |             if (isinstance(stmt, ExpressionStmt) and
 | 
| 175 |                     isinstance(stmt.expr, StrExpr)):
 | 
| 176 |                 continue
 | 
| 177 | 
 | 
| 178 |             self.accept(stmt)
 | 
| 179 | 
 | 
| 180 |     def visit_expression_stmt(self, o: 'mypy.nodes.ExpressionStmt') -> T:
 | 
| 181 |         self.accept(o.expr)
 | 
| 182 | 
 | 
| 183 |     def visit_while_stmt(self, o: 'mypy.nodes.WhileStmt') -> T:
 | 
| 184 |         cfg = self.current_cfg()
 | 
| 185 |         if not cfg:
 | 
| 186 |             return
 | 
| 187 | 
 | 
| 188 |         with pass_state.CfgLoopContext(cfg) as loop:
 | 
| 189 |             self.loop_stack.append(loop)
 | 
| 190 |             self.accept(o.body)
 | 
| 191 |             self.loop_stack.pop()
 | 
| 192 | 
 | 
| 193 |     def visit_return_stmt(self, o: 'mypy.nodes.ReturnStmt') -> T:
 | 
| 194 |         cfg = self.current_cfg()
 | 
| 195 |         if cfg:
 | 
| 196 |             cfg.AddDeadend(self.current_statement_id)
 | 
| 197 | 
 | 
| 198 |     def visit_if_stmt(self, o: 'mypy.nodes.IfStmt') -> T:
 | 
| 199 |         cfg = self.current_cfg()
 | 
| 200 |         if not cfg:
 | 
| 201 |             return
 | 
| 202 | 
 | 
| 203 |         with pass_state.CfgBranchContext(
 | 
| 204 |                 cfg, self.current_statement_id) as branch_ctx:
 | 
| 205 |             with branch_ctx.AddBranch():
 | 
| 206 |                 for node in o.body:
 | 
| 207 |                     self.accept(node)
 | 
| 208 | 
 | 
| 209 |             if o.else_body:
 | 
| 210 |                 with branch_ctx.AddBranch():
 | 
| 211 |                     self.accept(o.else_body)
 | 
| 212 | 
 | 
| 213 |     def visit_break_stmt(self, o: 'mypy.nodes.BreakStmt') -> T:
 | 
| 214 |         if len(self.loop_stack):
 | 
| 215 |             self.loop_stack[-1].AddBreak(self.current_statement_id)
 | 
| 216 | 
 | 
| 217 |     def visit_continue_stmt(self, o: 'mypy.nodes.ContinueStmt') -> T:
 | 
| 218 |         if len(self.loop_stack):
 | 
| 219 |             self.loop_stack[-1].AddContinue(self.current_statement_id)
 | 
| 220 | 
 | 
| 221 |     def visit_raise_stmt(self, o: 'mypy.nodes.RaiseStmt') -> T:
 | 
| 222 |         cfg = self.current_cfg()
 | 
| 223 |         if cfg:
 | 
| 224 |             cfg.AddDeadend(self.current_statement_id)
 | 
| 225 | 
 | 
    def visit_try_stmt(self, o: 'mypy.nodes.TryStmt') -> T:
        """Model a `try` statement as branches of one CfgBranchContext.

        The try body is the first branch, rooted at the current statement.
        Each except handler becomes another branch, created with
        try_block.exit (presumably so that the handler is entered from the
        try body's exit -- TODO confirm against pass_state).  This is a no-op
        when there is no current CFG.

        NOTE(review): in the lines shown, o.else_body and o.finally_body are
        not visited, and the per-handler exception types/vars (t, v) are
        unused.
        """
        cfg = self.current_cfg()
        if not cfg:
            return

        with pass_state.CfgBranchContext(cfg,
                                         self.current_statement_id) as try_ctx:
            # Branch for the try body; its context object is kept so the
            # handlers below can reference its exit.
            with try_ctx.AddBranch() as try_block:
                self.accept(o.body)

            # One branch per except clause; types, vars and handlers are
            # walked in lockstep.
            for t, v, handler in zip(o.types, o.vars, o.handlers):
                with try_ctx.AddBranch(try_block.exit):
                    self.accept(handler)
 |