OILS / mycpp / visitor.py View on Github | oilshell.org

238 lines, 158 significant
1"""
2visitor.py - AST pass that accepts everything.
3"""
4from typing import overload, Union, Optional
5
6import mypy
7from mypy.visitor import ExpressionVisitor, StatementVisitor
8from mypy.nodes import (Expression, Statement, ExpressionStmt, StrExpr,
9 CallExpr)
10
11from mycpp.crash import catch_errors
12from mycpp.util import split_py_name
13from mycpp import util
14
# Result type shared by the ExpressionVisitor methods below.  Being None,
# SimpleVisitor is effectively an ExpressionVisitor[None].
# TODO: Make it type check?
T = None


class UnsupportedException(Exception):
    """Error that SimpleVisitor.accept() catches and recovers from.

    Catching it lets traversal keep going, so that more error messages can
    be produced in a single run.
    """
20
21
class SimpleVisitor(ExpressionVisitor[T], StatementVisitor[None]):
    """
    A simple AST visitor that accepts every node in the AST. Derived classes
    can override the visit methods that are relevant to them.
    """

    def __init__(self):
        # split_py_name() of the fullname of the class whose body is being
        # visited; None when not inside any class.  See visit_class_def().
        self.current_class_name = None

    #
    # COPIED from IRBuilder
    #

    @overload
    def accept(self, node: Expression) -> T:
        ...

    @overload
    def accept(self, node: Statement) -> None:
        ...

    def accept(self, node: Union[Statement, Expression]) -> Optional[T]:
        """Visit one node, reporting errors against its source line.

        Errors are attributed via catch_errors() to self.module_path (set by
        visit_mypy_file()) and node.line.

        An UnsupportedException raised while visiting is caught so that
        traversal keeps going and more error messages can be produced.  For
        statements the result is always None; for expressions it is the
        visit method's result, or a placeholder after such an error.
        """
        with catch_errors(self.module_path, node.line):
            if isinstance(node, Expression):
                try:
                    return node.accept(self)
                except UnsupportedException:
                    # IRBuilder generated a temp of the right type here to
                    # avoid downstream trouble.  SimpleVisitor defines
                    # neither alloc_temp() nor node_type(), so calling them
                    # unconditionally raised AttributeError instead of
                    # recovering.  Use them only if a subclass has both.
                    alloc_temp = getattr(self, 'alloc_temp', None)
                    node_type = getattr(self, 'node_type', None)
                    if alloc_temp is None or node_type is None:
                        return None
                    return alloc_temp(node_type(node))

            try:
                node.accept(self)
            except UnsupportedException:
                pass
            return None

    # Not in superclasses:

    def visit_mypy_file(self, o: 'mypy.nodes.MypyFile') -> T:
        """Visit every top-level definition of a module.

        Does nothing when util.ShouldSkipPyFile(o) is true.  Otherwise
        records o.path in self.module_path (used by accept() for error
        reporting) and visits o.defs, skipping bare string statements such
        as the module docstring.
        """
        if util.ShouldSkipPyFile(o):
            return

        self.module_path = o.path

        for node in o.defs:
            if isinstance(node, ExpressionStmt) and isinstance(
                    node.expr, StrExpr):
                continue
            self.accept(node)

    # Compound statements

    def visit_for_stmt(self, o: 'mypy.nodes.ForStmt') -> T:
        # The loop target is an lvalue; visit it first, like the lvalues in
        # visit_assignment_stmt().
        self.accept(o.index)
        self.accept(o.expr)
        self.accept(o.body)
        if o.else_body is not None:
            self.accept(o.else_body)

    def visit_with_stmt(self, o: 'mypy.nodes.WithStmt') -> T:
        # Only a single context-manager call expression is supported.
        assert len(o.expr) == 1, o.expr
        expr = o.expr[0]
        assert isinstance(expr, CallExpr), expr
        self.accept(expr)
        self.accept(o.body)

    def visit_func_def(self, o: 'mypy.nodes.FuncDef') -> T:
        if o.name == '__repr__':  # Don't translate
            return

        # Default-argument values are expressions that need visiting too.
        for arg in o.arguments:
            if arg.initializer:
                self.accept(arg.initializer)

        self.accept(o.body)

    def visit_class_def(self, o: 'mypy.nodes.ClassDef') -> T:
        # Save and restore rather than resetting to None, so a nested class
        # doesn't erase the name for the rest of the enclosing class body.
        saved_class_name = self.current_class_name
        self.current_class_name = split_py_name(o.fullname)
        try:
            for stmt in o.defs.body:
                self.accept(stmt)
        finally:
            self.current_class_name = saved_class_name

    # Statements

    def visit_assignment_stmt(self, o: 'mypy.nodes.AssignmentStmt') -> T:
        for lval in o.lvalues:
            self.accept(lval)

        self.accept(o.rvalue)

    def visit_operator_assignment_stmt(
            self, o: 'mypy.nodes.OperatorAssignmentStmt') -> T:
        self.accept(o.lvalue)
        self.accept(o.rvalue)

    def visit_block(self, block: 'mypy.nodes.Block') -> T:
        for stmt in block.body:
            # Ignore things that look like docstrings
            if (isinstance(stmt, ExpressionStmt) and
                    isinstance(stmt.expr, StrExpr)):
                continue

            self.accept(stmt)

    def visit_expression_stmt(self, o: 'mypy.nodes.ExpressionStmt') -> T:
        self.accept(o.expr)

    def visit_while_stmt(self, o: 'mypy.nodes.WhileStmt') -> T:
        self.accept(o.expr)
        self.accept(o.body)
        if o.else_body is not None:
            self.accept(o.else_body)

    def visit_return_stmt(self, o: 'mypy.nodes.ReturnStmt') -> T:
        if o.expr:
            self.accept(o.expr)

    def visit_if_stmt(self, o: 'mypy.nodes.IfStmt') -> T:
        # The util.ShouldVisit* predicates decide which parts get visited.
        if util.ShouldVisitIfExpr(o):
            for expr in o.expr:
                self.accept(expr)

        if util.ShouldVisitIfBody(o):
            for body in o.body:
                self.accept(body)

        # Guard against a missing else: accept(None) would crash on .line.
        if util.ShouldVisitElseBody(o) and o.else_body is not None:
            self.accept(o.else_body)

    def visit_raise_stmt(self, o: 'mypy.nodes.RaiseStmt') -> T:
        if o.expr:
            self.accept(o.expr)

    def visit_try_stmt(self, o: 'mypy.nodes.TryStmt') -> T:
        self.accept(o.body)
        for handler in o.handlers:
            self.accept(handler)
        # The else: and finally: clauses can hold arbitrary statements too.
        if o.else_body is not None:
            self.accept(o.else_body)
        if o.finally_body is not None:
            self.accept(o.finally_body)

    def visit_del_stmt(self, o: 'mypy.nodes.DelStmt') -> T:
        self.accept(o.expr)

    # Expressions

    def visit_generator_expr(self, o: 'mypy.nodes.GeneratorExpr') -> T:
        self.accept(o.left_expr)

        for expr in o.indices:
            self.accept(expr)

        for expr in o.sequences:
            self.accept(expr)

        for condlist in o.condlists:
            for expr in condlist:
                self.accept(expr)

    def visit_list_comprehension(self, o: 'mypy.nodes.ListComprehension') -> T:
        self.accept(o.generator)

    def visit_member_expr(self, o: 'mypy.nodes.MemberExpr') -> T:
        self.accept(o.expr)

    def visit_yield_expr(self, o: 'mypy.nodes.YieldExpr') -> T:
        # A bare `yield` has no operand (o.expr is None).
        if o.expr:
            self.accept(o.expr)

    def visit_op_expr(self, o: 'mypy.nodes.OpExpr') -> T:
        self.accept(o.left)
        self.accept(o.right)

    def visit_comparison_expr(self, o: 'mypy.nodes.ComparisonExpr') -> T:
        for operand in o.operands:
            self.accept(operand)

    def visit_unary_expr(self, o: 'mypy.nodes.UnaryExpr') -> T:
        self.accept(o.expr)

    def visit_list_expr(self, o: 'mypy.nodes.ListExpr') -> T:
        for item in o.items:
            self.accept(item)

    def visit_dict_expr(self, o: 'mypy.nodes.DictExpr') -> T:
        for k, v in o.items:
            self.accept(k)
            self.accept(v)

    def visit_tuple_expr(self, o: 'mypy.nodes.TupleExpr') -> T:
        for item in o.items:
            self.accept(item)

    def visit_index_expr(self, o: 'mypy.nodes.IndexExpr') -> T:
        self.accept(o.base)
        self.accept(o.index)

    def visit_slice_expr(self, o: 'mypy.nodes.SliceExpr') -> T:
        # Each part of start:stop:step is optional.
        if o.begin_index:
            self.accept(o.begin_index)

        if o.end_index:
            self.accept(o.end_index)

        if o.stride:
            self.accept(o.stride)

    def visit_conditional_expr(self, o: 'mypy.nodes.ConditionalExpr') -> T:
        self.accept(o.cond)
        self.accept(o.if_expr)
        self.accept(o.else_expr)

    def visit_call_expr(self, o: 'mypy.nodes.CallExpr') -> T:
        self.accept(o.callee)
        for arg in o.args:
            self.accept(arg)