| 1 | """
 | 
| 2 | visitor.py - AST pass that accepts everything.
 | 
| 3 | """
 | 
| 4 | from typing import overload, Union, Optional
 | 
| 5 | 
 | 
| 6 | import mypy
 | 
| 7 | from mypy.visitor import ExpressionVisitor, StatementVisitor
 | 
| 8 | from mypy.nodes import (Expression, Statement, ExpressionStmt, StrExpr,
 | 
| 9 |                         CallExpr)
 | 
| 10 | 
 | 
| 11 | from mycpp.crash import catch_errors
 | 
| 12 | from mycpp.util import split_py_name
 | 
| 13 | from mycpp import util
 | 
| 14 | 
 | 
# Result type parameter for ExpressionVisitor in SimpleVisitor's bases and
# the return annotation of its visit_* methods.  None of SimpleVisitor's
# visit methods return a value, so a plain None is used here instead of a
# typing.TypeVar (presumably why mypy can't fully check this; see TODO).
T = None  # TODO: Make it type check?
 | 
| 16 | 
 | 
| 17 | 
 | 
class UnsupportedException(Exception):
    """Raised by a visit method for a construct the pass does not support.

    SimpleVisitor.accept() catches this exception around each node's
    accept() call instead of letting it propagate out of the traversal.
    """
 | 
| 20 | 
 | 
| 21 | 
 | 
class SimpleVisitor(ExpressionVisitor[T], StatementVisitor[None]):
    """
    A simple AST visitor that accepts every node in the AST. Derived classes
    can override the visit methods that are relevant to them.

    Child nodes are dispatched through accept(), which wraps each visit in
    catch_errors() with the current file path and line, and swallows
    UnsupportedException so traversal continues past unsupported constructs.
    """

    # Path of the file being visited; set by visit_mypy_file().  The
    # class-level default lets accept() run even before any file has been
    # visited, instead of failing with AttributeError.
    module_path: Optional[str] = None

    def __init__(self):
        # split_py_name() of the class whose body is being visited, or None
        # outside of any class.  Maintained by visit_class_def().
        self.current_class_name = None

    #
    # COPIED from IRBuilder (minus its typed-temp fallback, see accept())
    #

    @overload
    def accept(self, node: Expression) -> T:
        ...

    @overload
    def accept(self, node: Statement) -> None:
        ...

    def accept(self, node: Union[Statement, Expression]) -> Optional[T]:
        """Dispatch node to the matching visit_* method.

        Returns the visit result for expressions and None for statements.
        If the visit raises UnsupportedException, it is swallowed and None is
        returned, so that we keep going and can report more errors.
        """
        with catch_errors(self.module_path, node.line):
            try:
                res = node.accept(self)
            except UnsupportedException:
                # IRBuilder allocates a temp of the node's type here to avoid
                # downstream trouble.  This visitor defines no alloc_temp() /
                # node_type(), so calling them would just raise AttributeError;
                # produce no value instead.
                return None

            if isinstance(node, Expression):
                return res
            return None  # statements never produce a value

    # Not in superclasses:

    def visit_mypy_file(self, o: 'mypy.nodes.MypyFile') -> T:
        """Visit each top-level definition of a module, unless util skips it."""
        if util.ShouldSkipPyFile(o):
            return

        self.module_path = o.path

        for node in o.defs:
            # Skip bare string statements at module level, such as the module
            # docstring.
            if isinstance(node, ExpressionStmt) and isinstance(
                    node.expr, StrExpr):
                continue
            self.accept(node)

    # Compound statements and definitions

    def visit_for_stmt(self, o: 'mypy.nodes.ForStmt') -> T:
        self.accept(o.expr)
        self.accept(o.body)
        # for ... else: the else clause is a block of statements too.
        if o.else_body is not None:
            self.accept(o.else_body)

    def visit_with_stmt(self, o: 'mypy.nodes.WithStmt') -> T:
        # Only the single-manager form `with f(...):` is expected here.
        assert len(o.expr) == 1, o.expr
        expr = o.expr[0]
        assert isinstance(expr, CallExpr), expr
        self.accept(expr)
        self.accept(o.body)

    def visit_func_def(self, o: 'mypy.nodes.FuncDef') -> T:
        if o.name == '__repr__':  # Don't translate
            return

        # Default-value expressions, then the function body.
        for arg in o.arguments:
            if arg.initializer is not None:
                self.accept(arg.initializer)

        self.accept(o.body)

    def visit_class_def(self, o: 'mypy.nodes.ClassDef') -> T:
        # Save the enclosing name so a nested class doesn't reset it to None
        # for the rest of the outer class body.
        outer_class_name = self.current_class_name
        self.current_class_name = split_py_name(o.fullname)
        try:
            for stmt in o.defs.body:
                self.accept(stmt)
        finally:
            self.current_class_name = outer_class_name

    # Statements

    def visit_assignment_stmt(self, o: 'mypy.nodes.AssignmentStmt') -> T:
        for lval in o.lvalues:
            self.accept(lval)

        self.accept(o.rvalue)

    def visit_operator_assignment_stmt(
            self, o: 'mypy.nodes.OperatorAssignmentStmt') -> T:
        self.accept(o.lvalue)
        self.accept(o.rvalue)

    def visit_block(self, block: 'mypy.nodes.Block') -> T:
        for stmt in block.body:
            # Ignore things that look like docstrings
            if (isinstance(stmt, ExpressionStmt) and
                    isinstance(stmt.expr, StrExpr)):
                continue

            self.accept(stmt)

    def visit_expression_stmt(self, o: 'mypy.nodes.ExpressionStmt') -> T:
        self.accept(o.expr)

    def visit_while_stmt(self, o: 'mypy.nodes.WhileStmt') -> T:
        self.accept(o.expr)
        self.accept(o.body)
        # while ... else: the else clause is a block of statements too.
        if o.else_body is not None:
            self.accept(o.else_body)

    def visit_return_stmt(self, o: 'mypy.nodes.ReturnStmt') -> T:
        if o.expr is not None:
            self.accept(o.expr)

    def visit_if_stmt(self, o: 'mypy.nodes.IfStmt') -> T:
        # util decides which parts of the statement are visited; presumably
        # this skips code under conditions mycpp doesn't translate (confirm
        # in util).
        if util.ShouldVisitIfExpr(o):
            for expr in o.expr:
                self.accept(expr)

        if util.ShouldVisitIfBody(o):
            for body in o.body:
                self.accept(body)

        # Guard against a missing else block; accept(None) would fail.
        if util.ShouldVisitElseBody(o) and o.else_body is not None:
            self.accept(o.else_body)

    def visit_raise_stmt(self, o: 'mypy.nodes.RaiseStmt') -> T:
        if o.expr is not None:
            self.accept(o.expr)
        # The cause in `raise X from Y`.
        if o.from_expr is not None:
            self.accept(o.from_expr)

    def visit_try_stmt(self, o: 'mypy.nodes.TryStmt') -> T:
        self.accept(o.body)
        for handler in o.handlers:
            self.accept(handler)
        # The optional else: and finally: clauses are statement blocks too.
        if o.else_body is not None:
            self.accept(o.else_body)
        if o.finally_body is not None:
            self.accept(o.finally_body)

    def visit_del_stmt(self, o: 'mypy.nodes.DelStmt') -> T:
        self.accept(o.expr)

    # Expressions

    def visit_generator_expr(self, o: 'mypy.nodes.GeneratorExpr') -> T:
        self.accept(o.left_expr)

        for expr in o.indices:
            self.accept(expr)

        for expr in o.sequences:
            self.accept(expr)

        for conditions in o.condlists:
            for expr in conditions:
                self.accept(expr)

    def visit_list_comprehension(self, o: 'mypy.nodes.ListComprehension') -> T:
        self.accept(o.generator)

    def visit_member_expr(self, o: 'mypy.nodes.MemberExpr') -> T:
        self.accept(o.expr)

    def visit_yield_expr(self, o: 'mypy.nodes.YieldExpr') -> T:
        # A bare `yield` has no operand.
        if o.expr is not None:
            self.accept(o.expr)

    def visit_op_expr(self, o: 'mypy.nodes.OpExpr') -> T:
        self.accept(o.left)
        self.accept(o.right)

    def visit_comparison_expr(self, o: 'mypy.nodes.ComparisonExpr') -> T:
        for operand in o.operands:
            self.accept(operand)

    def visit_unary_expr(self, o: 'mypy.nodes.UnaryExpr') -> T:
        self.accept(o.expr)

    def visit_list_expr(self, o: 'mypy.nodes.ListExpr') -> T:
        for item in o.items:
            self.accept(item)

    def visit_dict_expr(self, o: 'mypy.nodes.DictExpr') -> T:
        for key, value in o.items:
            # The key is None for a `**mapping` unpacking entry.
            if key is not None:
                self.accept(key)
            self.accept(value)

    def visit_tuple_expr(self, o: 'mypy.nodes.TupleExpr') -> T:
        for item in o.items:
            self.accept(item)

    def visit_index_expr(self, o: 'mypy.nodes.IndexExpr') -> T:
        self.accept(o.base)
        self.accept(o.index)

    def visit_slice_expr(self, o: 'mypy.nodes.SliceExpr') -> T:
        # Each part of begin:end:stride is optional.
        if o.begin_index is not None:
            self.accept(o.begin_index)

        if o.end_index is not None:
            self.accept(o.end_index)

        if o.stride is not None:
            self.accept(o.stride)

    def visit_conditional_expr(self, o: 'mypy.nodes.ConditionalExpr') -> T:
        self.accept(o.cond)
        self.accept(o.if_expr)
        self.accept(o.else_expr)

    def visit_call_expr(self, o: 'mypy.nodes.CallExpr') -> T:
        self.accept(o.callee)
        for arg in o.args:
            self.accept(arg)
 |