1 | """
|
2 | visitor.py - AST pass that accepts everything.
|
3 | """
|
4 | from typing import overload, Union, Optional
|
5 |
|
6 | import mypy
|
7 | from mypy.visitor import ExpressionVisitor, StatementVisitor
|
8 | from mypy.nodes import (Expression, Statement, ExpressionStmt, StrExpr,
|
9 | CallExpr, NameExpr)
|
10 |
|
11 | from mycpp.crash import catch_errors
|
12 | from mycpp.util import split_py_name
|
13 | from mycpp import util
|
14 |
|
# Result type of the ExpressionVisitor half of SimpleVisitor: the class
# subclasses ExpressionVisitor[T] and its visit_* methods are annotated
# "-> T".  At runtime this is simply None.
T = None  # TODO: Make it type check?
|
16 |
|
17 |
|
class UnsupportedException(Exception):
    """Signals a construct that a visitor cannot handle.

    SimpleVisitor.accept() catches this around each node it visits, so a
    single unsupported construct does not abort the whole pass.
    """
|
20 |
|
21 |
|
22 | class SimpleVisitor(ExpressionVisitor[T], StatementVisitor[None]):
|
23 | """
|
    A simple AST visitor that accepts every node in the AST. Derived classes
|
25 | can override the visit methods that are relevant to them.
|
26 | """
|
27 |
|
def __init__(self):
    """Initialize per-pass visitor state."""
    # Set by visit_class_def() while a class body is visited; None at
    # construction.
    self.current_class_name = None
    # Path of the module being visited.  visit_mypy_file() sets it and
    # accept() passes it to catch_errors().  Initialized here so that
    # accept() does not raise AttributeError if it is called before any
    # file has been visited.
    self.module_path = None
|
30 |
|
31 | #
|
32 | # COPIED from IRBuilder
|
33 | #
|
34 |
|
@overload
def accept(self, node: Expression) -> T:
    ...

@overload
def accept(self, node: Statement) -> None:
    ...

def accept(self, node: Union[Statement, Expression]) -> Optional[T]:
    """Visit *node* by dispatching through node.accept(self).

    Copied from IRBuilder.  The visit runs inside
    catch_errors(self.module_path, node.line), presumably so failures are
    reported with a source location.  An UnsupportedException raised by the
    visit is swallowed so the pass keeps going and can report more errors.

    Returns:
      For an Expression: the value its visit method returned.  If that
      raised UnsupportedException, a placeholder from
      self.alloc_temp(self.node_type(node)) when the visitor defines both
      hooks, otherwise None.
      For a Statement: None.
    """
    with catch_errors(self.module_path, node.line):
        if isinstance(node, Expression):
            try:
                res = node.accept(self)
            except UnsupportedException:
                # If we hit an error during compilation, we want to keep
                # trying, so we can produce more error messages.  Generate a
                # temp of the right type to keep from causing more
                # downstream trouble.  SimpleVisitor itself does not define
                # alloc_temp()/node_type() (they come from IRBuilder), so
                # only use them when a subclass supplies both; otherwise
                # yield None instead of raising AttributeError here.
                alloc_temp = getattr(self, 'alloc_temp', None)
                node_type = getattr(self, 'node_type', None)
                if alloc_temp is not None and node_type is not None:
                    res = alloc_temp(node_type(node))
                else:
                    res = None
            return res
        else:
            try:
                node.accept(self)
            except UnsupportedException:
                # Deliberate best effort: skip the statement and continue.
                pass
            return None
|
63 |
|
64 | # Not in superclasses:
|
65 |
|
def visit_mypy_file(self, o: 'mypy.nodes.MypyFile') -> T:
    """Visit the top-level statements of module *o*.

    Modules for which util.ShouldSkipPyFile() is true are ignored.
    Otherwise o.path is recorded in self.module_path (accept() hands it to
    catch_errors()), and every top-level statement is visited except bare
    string-literal statements such as the module docstring.
    """
    if util.ShouldSkipPyFile(o):
        return

    self.module_path = o.path

    # Lazily filter out docstring-like statements as we go.
    wanted = (
        stmt for stmt in o.defs
        if not (isinstance(stmt, ExpressionStmt) and
                isinstance(stmt.expr, StrExpr))
    )
    for stmt in wanted:
        self.accept(stmt)
|
78 |
|
# Compound statements and definitions
|
80 |
|
def visit_for_stmt(self, o: 'mypy.nodes.ForStmt') -> T:
    """Visit a for loop: its target, iterable, body and optional else block.

    The loop target (o.index) is an assignment target, so it is visited
    first, just as visit_assignment_stmt visits lvalues before the rvalue.
    The 'for ... else' block is visited last when present, so that every
    child of the loop is reached.
    """
    self.accept(o.index)
    self.accept(o.expr)
    self.accept(o.body)
    if o.else_body:
        self.accept(o.else_body)
|
84 |
|
def visit_with_stmt(self, o: 'mypy.nodes.WithStmt') -> T:
    """Visit the context-manager call and the body of a 'with' statement.

    Only the form 'with f(...):' with a single manager is accepted; anything
    else trips an assertion.  An 'as' target, if any, is not visited.
    """
    managers = o.expr
    assert len(managers) == 1, managers
    call = managers[0]
    assert isinstance(call, CallExpr), call
    self.accept(call)
    self.accept(o.body)
|
91 |
|
def visit_func_def(self, o: 'mypy.nodes.FuncDef') -> T:
    """Visit a function's default-argument expressions, then its body.

    Methods named __repr__ are skipped entirely.
    """
    if o.name == '__repr__':
        # Don't translate
        return

    initializers = (arg.initializer for arg in o.arguments)
    for default in initializers:
        if default:
            self.accept(default)

    self.accept(o.body)
|
101 |
|
def visit_class_def(self, o: 'mypy.nodes.ClassDef') -> T:
    """Visit each statement of a class body with current_class_name set.

    While the body is visited, self.current_class_name is
    split_py_name(o.fullname).  Afterwards the previous value is restored,
    even if visiting raises.  This way, once a nested class has been
    visited, the enclosing class's name is in effect again.
    """
    # getattr: tolerate subclasses whose __init__ doesn't call ours.
    outer_class_name = getattr(self, 'current_class_name', None)
    self.current_class_name = split_py_name(o.fullname)
    try:
        # Iterates o.defs.body directly rather than going through
        # visit_block, so docstring-like statements are visited here too.
        for stmt in o.defs.body:
            self.accept(stmt)
    finally:
        self.current_class_name = outer_class_name
|
107 |
|
108 | # Statements
|
109 |
|
def visit_assignment_stmt(self, o: 'mypy.nodes.AssignmentStmt') -> T:
    """Visit every assignment target (left to right), then the value."""
    for target in o.lvalues:
        self.accept(target)
    value = o.rvalue
    self.accept(value)
|
115 |
|
def visit_operator_assignment_stmt(
        self, o: 'mypy.nodes.OperatorAssignmentStmt') -> T:
    """Visit the target of an augmented assignment (x += y), then the value."""
    target, value = o.lvalue, o.rvalue
    self.accept(target)
    self.accept(value)
|
120 |
|
def visit_block(self, block: 'mypy.nodes.Block') -> T:
    """Visit each statement of *block*, skipping docstring-like statements.

    Any statement consisting solely of a string literal is ignored.
    """
    for stmt in block.body:
        looks_like_docstring = (isinstance(stmt, ExpressionStmt) and
                                isinstance(stmt.expr, StrExpr))
        if not looks_like_docstring:
            self.accept(stmt)
|
129 |
|
def visit_expression_stmt(self, o: 'mypy.nodes.ExpressionStmt') -> T:
    """Visit the single expression wrapped by an expression statement."""
    inner = o.expr
    self.accept(inner)
|
132 |
|
def visit_while_stmt(self, o: 'mypy.nodes.WhileStmt') -> T:
    """Visit a while loop's condition, its body and its optional else block."""
    self.accept(o.expr)
    self.accept(o.body)
    # A 'while ... else' block is a child of the loop too; visit it the way
    # visit_if_stmt visits its else_body.
    if o.else_body:
        self.accept(o.else_body)
|
136 |
|
def visit_return_stmt(self, o: 'mypy.nodes.ReturnStmt') -> T:
    """Visit the returned value; a bare 'return' has nothing to visit."""
    value = o.expr
    if not value:
        return
    self.accept(value)
|
140 |
|
def visit_if_stmt(self, o: 'mypy.nodes.IfStmt') -> T:
    """Visit the conditions, branch bodies and else block of an if statement.

    A statement whose first condition is the bare name TYPE_CHECKING is
    skipped entirely, including any else branch: such blocks hold type
    expressions that don't type check.
    """
    first_cond = o.expr[0]
    guarded_by_type_checking = (isinstance(first_cond, NameExpr) and
                                first_cond.name == 'TYPE_CHECKING')
    if guarded_by_type_checking:
        return

    for cond in o.expr:
        self.accept(cond)

    for branch in o.body:
        self.accept(branch)

    if o.else_body:
        self.accept(o.else_body)
|
156 |
|
def visit_raise_stmt(self, o: 'mypy.nodes.RaiseStmt') -> T:
    """Visit the raised exception and its 'from' cause, whichever are present.

    A bare 'raise' has neither.
    """
    if o.expr:
        self.accept(o.expr)
    # 'raise X from Y': the cause is an expression child as well.
    if o.from_expr:
        self.accept(o.from_expr)
|
160 |
|
def visit_try_stmt(self, o: 'mypy.nodes.TryStmt') -> T:
    """Visit every part of a try statement.

    Order: the try body; then, for each except clause, its exception-type
    expression, its 'as' name and its handler body (each when present);
    then the 'else' block and the 'finally' block when present.
    """
    self.accept(o.body)
    # o.types, o.vars and o.handlers are parallel per-clause lists; a bare
    # 'except:' has no type and a clause without 'as' has no var.
    for exc_type, var, handler in zip(o.types, o.vars, o.handlers):
        if exc_type:
            self.accept(exc_type)
        if var:
            self.accept(var)
        self.accept(handler)
    if o.else_body:
        self.accept(o.else_body)
    if o.finally_body:
        self.accept(o.finally_body)
|
165 |
|
def visit_del_stmt(self, o: 'mypy.nodes.DelStmt') -> T:
    """Visit the expression being deleted by a 'del' statement."""
    target = o.expr
    self.accept(target)
|
168 |
|
169 | # Expressions
|
170 |
|
def visit_generator_expr(self, o: 'mypy.nodes.GeneratorExpr') -> T:
    """Visit the parts of a generator expression.

    Order: the element expression, every loop target, every iterable, and
    finally every 'if' condition of every for-clause.
    """
    self.accept(o.left_expr)

    for target in o.indices:
        self.accept(target)

    for iterable in o.sequences:
        self.accept(iterable)

    all_conditions = (cond for clause in o.condlists for cond in clause)
    for cond in all_conditions:
        self.accept(cond)
|
183 |
|
def visit_list_comprehension(self, o: 'mypy.nodes.ListComprehension') -> T:
    """Visit the generator expression underlying a list comprehension."""
    gen = o.generator
    self.accept(gen)
|
186 |
|
def visit_member_expr(self, o: 'mypy.nodes.MemberExpr') -> T:
    """Visit the object expression of an attribute access (obj.name)."""
    obj = o.expr
    self.accept(obj)
|
189 |
|
def visit_yield_expr(self, o: 'mypy.nodes.YieldExpr') -> T:
    """Visit the yielded value, if any.

    A bare 'yield' has no value (o.expr is None).  Passing None to accept()
    would fail when it reads node.line, so the value is skipped in that
    case, just as visit_return_stmt handles a bare 'return'.
    """
    if o.expr:
        self.accept(o.expr)
|
192 |
|
def visit_op_expr(self, o: 'mypy.nodes.OpExpr') -> T:
    """Visit both operands of a binary operator, left one first."""
    for operand in (o.left, o.right):
        self.accept(operand)
|
196 |
|
def visit_comparison_expr(self, o: 'mypy.nodes.ComparisonExpr') -> T:
    """Visit every operand of a (possibly chained) comparison, in order."""
    for node in o.operands:
        self.accept(node)
|
200 |
|
def visit_unary_expr(self, o: 'mypy.nodes.UnaryExpr') -> T:
    """Visit the single operand of a unary operator."""
    operand = o.expr
    self.accept(operand)
|
203 |
|
def visit_list_expr(self, o: 'mypy.nodes.ListExpr') -> T:
    """Visit each element of a list display; empty/missing items are a no-op."""
    for element in o.items or ():
        self.accept(element)
|
208 |
|
def visit_dict_expr(self, o: 'mypy.nodes.DictExpr') -> T:
    """Visit each key and value of a dict display, in source order.

    mypy's DictExpr uses a None key for '**mapping' entries.  Passing None
    to accept() would fail on node.line, so only the value is visited for
    those entries.
    """
    for key, value in o.items or ():
        if key is not None:
            self.accept(key)
        self.accept(value)
|
214 |
|
def visit_tuple_expr(self, o: 'mypy.nodes.TupleExpr') -> T:
    """Visit each element of a tuple; empty/missing items are a no-op."""
    for element in o.items or ():
        self.accept(element)
|
219 |
|
def visit_index_expr(self, o: 'mypy.nodes.IndexExpr') -> T:
    """Visit the subscripted object, then the subscript (base[index])."""
    for part in (o.base, o.index):
        self.accept(part)
|
223 |
|
def visit_slice_expr(self, o: 'mypy.nodes.SliceExpr') -> T:
    """Visit whichever of start, stop and step appear in a[start:stop:step]."""
    for bound in (o.begin_index, o.end_index, o.stride):
        if bound:
            self.accept(bound)
|
233 |
|
def visit_conditional_expr(self, o: 'mypy.nodes.ConditionalExpr') -> T:
    """Visit 'a if cond else b': the condition, then a, then b."""
    for part in (o.cond, o.if_expr, o.else_expr):
        self.accept(part)
|
238 |
|
def visit_call_expr(self, o: 'mypy.nodes.CallExpr') -> T:
    """Visit the callee, then each argument expression in o.args, in order."""
    self.accept(o.callee)
    for argument in o.args:
        self.accept(argument)
|