OILS / mycpp / const_pass.py View on Github | oilshell.org

532 lines, 337 significant
1"""
2const_pass.py - AST pass that collects constants.
3
4Immutable string constants like 'StrFromC("foo")' are moved to the top level of
5the generated C++ program for efficiency.
6"""
7import json
8
9from typing import overload, Union, Optional, Dict, List
10
11import mypy
12from mypy.visitor import ExpressionVisitor, StatementVisitor
13from mypy.nodes import (Expression, Statement, ExpressionStmt, StrExpr,
14 ComparisonExpr, NameExpr, MemberExpr)
15
16from mypy.types import Type
17
18from mycpp.crash import catch_errors
19from mycpp import format_strings
20from mycpp.util import log
21from mycpp import util
22
# Placeholder used below as the type parameter of ExpressionVisitor and as the
# return annotation of every visit_* method in this module.
T = None  # TODO: Make it type check?
24
25
class UnsupportedException(Exception):
    """Signals a construct the pass cannot handle.

    Collect.accept() catches this exception so that traversal can continue
    with the remaining nodes.
    """
28
29
class Collect(ExpressionVisitor[T], StatementVisitor[None]):
    """AST visitor that hoists string literals into global constants.

    For every StrExpr reached by the traversal, one GLOBAL_STR / GLOBAL_STR2
    declaration line is appended to `const_code`, and the node is mapped to
    the generated identifier ('str0', 'str1', ...) in `const_lookup`.

    Both containers are owned by the caller and are mutated in place.  All
    visit_* methods return None; they exist only for their side effects.
    """

    # Modules whose definitions are not traversed at all.  A lot of them are
    # brought in by 'import typing'.
    _SKIPPED_MODULES = ('__future__', 'sys', 'types', 'typing', 'abc', '_ast',
                        'ast', '_weakrefset', 'collections', 'cStringIO',
                        're', 'builtins')

    def __init__(self, types: Dict[Expression, Type],
                 const_lookup: Dict[Expression, str], const_code: List[str]):
        """
        Args:
          types: expression -> type mapping; only consulted for assignments.
          const_lookup: output; filled with StrExpr node -> constant name.
          const_code: output; receives one declaration line per constant.
        """
        self.types = types
        self.const_lookup = const_lookup
        self.const_code = const_code
        self.unique_id = 0  # numeric suffix of the next 'strN' identifier

        self.indent = 0  # nesting depth, used only by the debug log

        # Path of the module being visited.  Set by visit_mypy_file();
        # initialized here so that accept() never hits a missing attribute.
        self.module_path: Optional[str] = None

    def out(self, msg, *args):
        """Append one line of generated code, %-formatting it if args are given."""
        if args:
            msg = msg % args
        self.const_code.append(msg)

    #
    # COPIED from IRBuilder
    #

    @overload
    def accept(self, node: Expression) -> T:
        ...

    @overload
    def accept(self, node: Statement) -> None:
        ...

    def accept(self, node: Union[Statement, Expression]) -> Optional[T]:
        """Dispatch `node` to the matching visit_* method.

        The dispatch runs inside catch_errors(), which is given the current
        module path and the node's line number.

        Returns:
          The visitor's result for expressions, None for statements.  None is
          also returned when the visit raised UnsupportedException, so the
          traversal keeps going and later problems can still be reported.
        """
        with catch_errors(self.module_path, node.line):
            try:
                result = node.accept(self)
            except UnsupportedException:
                # This class has no temporaries to substitute (the IRBuilder
                # original allocated one here), so there is nothing to return.
                return None

            if isinstance(node, Expression):
                return result
            return None

    def log(self, msg, *args):
        """Debug trace, indented by the current nesting depth.  Disabled."""
        if 0:  # quiet
            ind_str = self.indent * ' '
            log(ind_str + msg, *args)

    # Not in superclasses:

    def visit_mypy_file(self, o: 'mypy.nodes.MypyFile') -> T:
        """Traverse all top-level definitions of one module."""
        if o.fullname in self._SKIPPED_MODULES:
            # These module are special; their contents are currently all
            # built-in primitives.
            return

        self.module_path = o.path

        self.indent += 1
        for node in o.defs:
            # skip module docstring
            if isinstance(node, ExpressionStmt) and isinstance(
                    node.expr, StrExpr):
                continue
            self.accept(node)
        self.indent -= 1

    # LITERALS

    def visit_int_expr(self, o: 'mypy.nodes.IntExpr') -> T:
        self.log('IntExpr %d', o.value)

    def visit_str_expr(self, o: 'mypy.nodes.StrExpr') -> T:
        """Emit a global constant for this literal and record its name.

        Every StrExpr node gets its own identifier; no deduplication of equal
        strings is attempted here.
        """
        id_ = 'str%d' % self.unique_id
        self.unique_id += 1

        raw_string = format_strings.DecodeMyPyString(o.value)

        # json.dumps() produces the double-quoted, escaped literal text.
        macro = 'GLOBAL_STR2' if util.SMALL_STR else 'GLOBAL_STR'
        self.out('%s(%s, %s);', macro, id_, json.dumps(raw_string))

        self.const_lookup[o] = id_

    # The remaining literal kinds are deliberate no-ops: nothing is collected.

    def visit_bytes_expr(self, o: 'mypy.nodes.BytesExpr') -> T:
        pass

    def visit_unicode_expr(self, o: 'mypy.nodes.UnicodeExpr') -> T:
        pass

    def visit_float_expr(self, o: 'mypy.nodes.FloatExpr') -> T:
        pass

    def visit_complex_expr(self, o: 'mypy.nodes.ComplexExpr') -> T:
        pass

    # Expression
    #
    # Methods whose body is 'pass' do not descend into their node, so string
    # literals nested inside those constructs are not collected.

    def visit_ellipsis(self, o: 'mypy.nodes.EllipsisExpr') -> T:
        pass

    def visit_star_expr(self, o: 'mypy.nodes.StarExpr') -> T:
        pass

    def visit_name_expr(self, o: 'mypy.nodes.NameExpr') -> T:
        #self.log('NameExpr %s', o.name)
        pass

    def visit_member_expr(self, o: 'mypy.nodes.MemberExpr') -> T:
        if o.expr:
            self.accept(o.expr)

    def visit_yield_from_expr(self, o: 'mypy.nodes.YieldFromExpr') -> T:
        pass

    def visit_yield_expr(self, o: 'mypy.nodes.YieldExpr') -> T:
        pass

    def visit_call_expr(self, o: 'mypy.nodes.CallExpr') -> T:
        """Visit the callee and, except for probe() calls, the arguments."""
        self.log('CallExpr')
        self.accept(o.callee)  # could be f() or obj.method()

        # The callee is not always a name or member expression (it can be
        # another call, a subscript, ...), so it may have no 'name' at all.
        if getattr(o.callee, 'name', None) == 'probe':
            # don't generate constants for probe names
            return

        self.indent += 1
        for arg in o.args:
            self.accept(arg)
        self.indent -= 1

    def visit_op_expr(self, o: 'mypy.nodes.OpExpr') -> T:
        self.log('OpExpr')
        self.indent += 1
        self.accept(o.left)
        self.accept(o.right)
        self.indent -= 1

    def visit_comparison_expr(self, o: 'mypy.nodes.ComparisonExpr') -> T:
        self.log('ComparisonExpr')
        self.log(' operators %s', o.operators)

        # Operands are traced two levels below the ComparisonExpr itself.
        self.indent += 2
        for operand in o.operands:
            self.accept(operand)
        self.indent -= 2

    def visit_cast_expr(self, o: 'mypy.nodes.CastExpr') -> T:
        pass

    def visit_reveal_expr(self, o: 'mypy.nodes.RevealExpr') -> T:
        pass

    def visit_super_expr(self, o: 'mypy.nodes.SuperExpr') -> T:
        pass

    def visit_assignment_expr(self, o: 'mypy.nodes.AssignmentExpr') -> T:
        pass

    def visit_unary_expr(self, o: 'mypy.nodes.UnaryExpr') -> T:
        # e.g. a[-1] or 'not x'
        self.accept(o.expr)

    def visit_list_expr(self, o: 'mypy.nodes.ListExpr') -> T:
        # lists are MUTABLE, so we can't generate constants at the top level
        # but we want to visit the string literals!
        for item in o.items:
            self.accept(item)

    def visit_dict_expr(self, o: 'mypy.nodes.DictExpr') -> T:
        for k, v in o.items:
            self.accept(k)
            self.accept(v)

    def visit_tuple_expr(self, o: 'mypy.nodes.TupleExpr') -> T:
        for item in o.items:
            self.accept(item)

    def visit_set_expr(self, o: 'mypy.nodes.SetExpr') -> T:
        pass

    def visit_index_expr(self, o: 'mypy.nodes.IndexExpr') -> T:
        self.accept(o.base)
        self.accept(o.index)

    def visit_type_application(self, o: 'mypy.nodes.TypeApplication') -> T:
        pass

    def visit_lambda_expr(self, o: 'mypy.nodes.LambdaExpr') -> T:
        pass

    def visit_list_comprehension(self, o: 'mypy.nodes.ListComprehension') -> T:
        """Collect constants from every part of a list comprehension.

        All 'for' clauses are walked, each with its index expression, its
        sequence and its conditions.
        """
        gen = o.generator  # GeneratorExpr

        # We might use all of these, so collect constants.
        self.accept(gen.left_expr)
        for index_expr, seq, conds in zip(gen.indices, gen.sequences,
                                          gen.condlists):
            self.accept(index_expr)
            self.accept(seq)
            for c in conds:
                self.accept(c)

    def visit_set_comprehension(self, o: 'mypy.nodes.SetComprehension') -> T:
        pass

    def visit_dictionary_comprehension(
            self, o: 'mypy.nodes.DictionaryComprehension') -> T:
        pass

    def visit_generator_expr(self, o: 'mypy.nodes.GeneratorExpr') -> T:
        pass

    def visit_slice_expr(self, o: 'mypy.nodes.SliceExpr') -> T:
        # Each of the three parts is optional.
        for part in (o.begin_index, o.end_index, o.stride):
            if part:
                self.accept(part)

    def visit_conditional_expr(self, o: 'mypy.nodes.ConditionalExpr') -> T:
        self.accept(o.cond)
        self.accept(o.if_expr)
        self.accept(o.else_expr)

    def visit_backquote_expr(self, o: 'mypy.nodes.BackquoteExpr') -> T:
        pass

    def visit_type_var_expr(self, o: 'mypy.nodes.TypeVarExpr') -> T:
        pass

    def visit_type_alias_expr(self, o: 'mypy.nodes.TypeAliasExpr') -> T:
        pass

    def visit_namedtuple_expr(self, o: 'mypy.nodes.NamedTupleExpr') -> T:
        pass

    def visit_enum_call_expr(self, o: 'mypy.nodes.EnumCallExpr') -> T:
        pass

    def visit_typeddict_expr(self, o: 'mypy.nodes.TypedDictExpr') -> T:
        pass

    def visit_newtype_expr(self, o: 'mypy.nodes.NewTypeExpr') -> T:
        pass

    def visit__promote_expr(self, o: 'mypy.nodes.PromoteExpr') -> T:
        pass

    def visit_await_expr(self, o: 'mypy.nodes.AwaitExpr') -> T:
        pass

    def visit_temp_node(self, o: 'mypy.nodes.TempNode') -> T:
        pass

    def visit_assignment_stmt(self, o: 'mypy.nodes.AssignmentStmt') -> T:
        """Visit every lvalue, then the rvalue if a type is recorded for it."""
        self.log('AssignmentStmt')

        for lval in o.lvalues:
            # Not every lvalue has a recorded type; that only affects the log.
            self.log(' lval %s :: %s', lval, self.types.get(lval))
            self.accept(lval)

        # An rvalue without a recorded type is skipped entirely.  This seems
        # to only happen for Ellipsis, I guess in the abc module.
        if o.rvalue in self.types:
            self.indent += 1
            self.accept(o.rvalue)
            self.indent -= 1

    def visit_for_stmt(self, o: 'mypy.nodes.ForStmt') -> T:
        """Visit index, iterable and body.

        Raises:
          AssertionError: if the loop has an 'else' clause.
        """
        self.log('ForStmt')
        self.accept(o.index)  # index var expression
        self.accept(o.expr)  # the thing being iterated over
        self.accept(o.body)
        if o.else_body:
            raise AssertionError("can't translate for-else")

    def visit_with_stmt(self, o: 'mypy.nodes.WithStmt') -> T:
        # Only a single context manager expression is accepted.
        assert len(o.expr) == 1, o.expr
        self.accept(o.expr[0])
        self.accept(o.body)

    def visit_del_stmt(self, o: 'mypy.nodes.DelStmt') -> T:
        self.accept(o.expr)

    def visit_func_def(self, o: 'mypy.nodes.FuncDef') -> T:
        """Visit default-argument initializers, then the function body."""
        typ = o.type
        self.log('FuncDef %s :: %s', o.name, typ)

        # The signature is only inspected for the debug trace, so a function
        # without a type must not abort the pass.
        if typ is not None:
            for t, name in zip(typ.arg_types, typ.arg_names):
                self.log(' arg %s %s', t, name)
            self.log(' ret %s', typ.ret_type)

        self.indent += 1
        for arg in o.arguments:
            # e.g. foo=''
            if arg.initializer:
                self.accept(arg.initializer)

            # We can't use __str__ on these Argument objects?  That seems
            # like an oversight
            self.log('Argument %s', arg.variable)
            self.log(' type_annotation %s', arg.type_annotation)
            # I think these are for default values
            self.log(' initializer %s', arg.initializer)
            self.log(' kind %s', arg.kind)

        self.accept(o.body)
        self.indent -= 1

    def visit_overloaded_func_def(self,
                                  o: 'mypy.nodes.OverloadedFuncDef') -> T:
        pass

    def visit_class_def(self, o: 'mypy.nodes.ClassDef') -> T:
        self.log('const_pass ClassDef %s', o.name)
        for b in o.base_type_exprs:
            self.log(' base_type_expr %s', b)
        self.indent += 1
        self.accept(o.defs)
        self.indent -= 1

    def visit_global_decl(self, o: 'mypy.nodes.GlobalDecl') -> T:
        pass

    def visit_nonlocal_decl(self, o: 'mypy.nodes.NonlocalDecl') -> T:
        pass

    def visit_decorator(self, o: 'mypy.nodes.Decorator') -> T:
        pass

    def visit_var(self, o: 'mypy.nodes.Var') -> T:
        pass

    # Module structure

    def visit_import(self, o: 'mypy.nodes.Import') -> T:
        pass

    def visit_import_from(self, o: 'mypy.nodes.ImportFrom') -> T:
        pass

    def visit_import_all(self, o: 'mypy.nodes.ImportAll') -> T:
        pass

    # Statements

    def visit_block(self, block: 'mypy.nodes.Block') -> T:
        self.log('Block')
        self.indent += 1
        for stmt in block.body:
            # Ignore things that look like docstrings
            if isinstance(stmt, ExpressionStmt) and isinstance(
                    stmt.expr, StrExpr):
                continue
            self.accept(stmt)
        self.indent -= 1

    def visit_expression_stmt(self, o: 'mypy.nodes.ExpressionStmt') -> T:
        self.log('ExpressionStmt')
        self.indent += 1
        self.accept(o.expr)
        self.indent -= 1

    def visit_operator_assignment_stmt(
            self, o: 'mypy.nodes.OperatorAssignmentStmt') -> T:
        """Visit both operands of an augmented assignment such as s += 'x'."""
        self.log('OperatorAssignmentStmt')
        # Without this traversal a string literal on either side would never
        # get a constant or a const_lookup entry.
        self.accept(o.lvalue)
        self.accept(o.rvalue)

    def visit_while_stmt(self, o: 'mypy.nodes.WhileStmt') -> T:
        self.log('WhileStmt')
        self.accept(o.expr)
        self.accept(o.body)

    def visit_return_stmt(self, o: 'mypy.nodes.ReturnStmt') -> T:
        self.log('ReturnStmt')
        if o.expr:
            self.accept(o.expr)

    def visit_assert_stmt(self, o: 'mypy.nodes.AssertStmt') -> T:
        pass

    def visit_if_stmt(self, o: 'mypy.nodes.IfStmt') -> T:
        """Visit an 'if', skipping branches selected by special conditions.

        - 'if __name__ == ...' and 'if TYPE_CHECKING' are skipped entirely.
        - A member expression named CPP (mylib.CPP): only the 'if' body is
          visited.
        - A member expression named PYTHON (mylib.PYTHON): only the 'else'
          body, if any, is visited.

        In the special cases the condition itself is not visited.
        """
        # Copied from cppgen_pass.py
        # Not sure why this wouldn't be true
        assert len(o.expr) == 1, o.expr
        cond = o.expr[0]

        # Omit anything that looks like if __name__ == ...
        if (isinstance(cond, ComparisonExpr) and
                isinstance(cond.operands[0], NameExpr) and
                cond.operands[0].name == '__name__'):
            return

        # Omit if TYPE_CHECKING blocks.  They contain type expressions that
        # don't type check!
        if isinstance(cond, NameExpr) and cond.name == 'TYPE_CHECKING':
            return

        if isinstance(cond, MemberExpr):
            # mylib.CPP
            if cond.name == 'CPP':
                # just take the if block
                for node in o.body:
                    self.accept(node)
                return

            # mylib.PYTHON
            if cond.name == 'PYTHON':
                if o.else_body:
                    self.accept(o.else_body)
                return

        self.log('IfStmt')
        self.indent += 1
        for e in o.expr:
            self.accept(e)

        for node in o.body:
            self.accept(node)

        if o.else_body:
            self.accept(o.else_body)
        self.indent -= 1

    def visit_break_stmt(self, o: 'mypy.nodes.BreakStmt') -> T:
        pass

    def visit_continue_stmt(self, o: 'mypy.nodes.ContinueStmt') -> T:
        pass

    def visit_pass_stmt(self, o: 'mypy.nodes.PassStmt') -> T:
        pass

    def visit_raise_stmt(self, o: 'mypy.nodes.RaiseStmt') -> T:
        if o.expr:
            self.accept(o.expr)

    def visit_try_stmt(self, o: 'mypy.nodes.TryStmt') -> T:
        """Visit the 'try' body and every 'except' handler body."""
        self.accept(o.body)
        # The exception types and bound variables are not visited.
        for handler in o.handlers:
            self.accept(handler)

        # NOTE: o.else_body and o.finally_body are not visited, and no error
        # is raised for them here.

    def visit_print_stmt(self, o: 'mypy.nodes.PrintStmt') -> T:
        pass

    def visit_exec_stmt(self, o: 'mypy.nodes.ExecStmt') -> T:
        pass