OILS / asdl / gen_python.py View on Github | oilshell.org

611 lines, 395 significant
1#!/usr/bin/env python2
2"""gen_python.py: Generate Python code from an ASDL schema."""
3from __future__ import print_function
4
5from collections import defaultdict
6
7from asdl import ast
8from asdl import visitor
9from asdl.util import log
10
11_ = log # shut up lint
12
# Maps an ASDL primitive type name to the MyPy type string emitted for it.
# _MyPyType() falls back to this table for any named type that did not
# resolve to a Sum, Product or Use, so a missing entry surfaces as a KeyError.
_PRIMITIVES = {
    'string': 'str',
    'int': 'int',
    'uint16': 'int',
    'BigInt': 'mops.BigInt',
    'float': 'float',
    'bool': 'bool',
    'any': 'Any',
    # TODO: frontend/syntax.asdl should properly import id enum instead of
    # hard-coding it here.
    'id': 'Id_t',
}
25
26
def _MyPyType(typ):
    """Translate an ASDL type expression into a MyPy type string.

    Args:
      typ: an ast.ParameterizedType (Dict, List or Optional) or an
        ast.NamedType.

    Returns:
      The type as it should appear in a generated '# type:' comment, e.g.
      'List[str]', 'Optional[Token]' or 'command_t'.

    Raises:
      AssertionError: if typ is neither kind of node, or is a parameterized
        type other than Dict, List or Optional.
      KeyError: if a named type did not resolve to a Sum, Product or Use and
        is not listed in _PRIMITIVES.
    """
    if isinstance(typ, ast.ParameterizedType):
        type_name = typ.type_name

        if type_name == 'Dict':
            k_type = _MyPyType(typ.children[0])
            v_type = _MyPyType(typ.children[1])
            return 'Dict[%s, %s]' % (k_type, v_type)

        if type_name in ('List', 'Optional'):
            return '%s[%s]' % (type_name, _MyPyType(typ.children[0]))

        # Fail loudly instead of silently returning None, which would be
        # formatted into the generated type comment as the string 'None'.
        # _DefaultValue() rejects unknown parameterized types the same way.
        raise AssertionError(type_name)

    if isinstance(typ, ast.NamedType):
        resolved = typ.resolved
        if resolved:
            if isinstance(resolved, ast.Sum):  # includes SimpleSum
                return '%s_t' % typ.name
            if isinstance(resolved, ast.Product):
                return typ.name
            if isinstance(resolved, ast.Use):
                return ast.TypeNameHeuristic(typ.name)

        # 'id' falls through here
        return _PRIMITIVES[typ.name]

    raise AssertionError('Unexpected type expression: %r' % (typ, ))
56
57
def _DefaultValue(typ, mypy_type):
    """Values that the static CreateNull() constructor passes.

    mypy_type is used to cast None, to maintain mypy --strict for ASDL.

    We circumvent the type system on CreateNull().  Then the user is
    responsible for filling in all the fields.  If they do so, we can
    rely on it when reading fields at runtime.

    Args:
      typ: the field's ASDL type (ast.ParameterizedType or ast.NamedType).
      mypy_type: the string _MyPyType() produced for the same field.

    Returns:
      A Python source expression (as a string) used as the constructor
      argument for this field inside the generated CreateNull().

    Raises:
      AssertionError: for an unknown parameterized type, or a node that is
        neither a ParameterizedType nor a NamedType.
    """
    # Literal defaults for primitive names.  'uint16' shares the integer
    # sentinel: _PRIMITIVES maps it to 'int', so it should not default to
    # None while 'int' defaults to -1.
    primitive_defaults = {
        'id': '-1',  # hard-coded HACK
        'int': '-1',
        'uint16': '-1',
        'BigInt': 'mops.BigInt(-1)',
        'bool': 'False',
        'float': '0.0',  # or should it be NaN?
        'string': "''",
    }

    if isinstance(typ, ast.ParameterizedType):
        type_name = typ.type_name

        if type_name == 'List':
            # 'alloc_lists' is the parameter of the generated CreateNull().
            return "[] if alloc_lists else cast('%s', None)" % mypy_type

        # TODO: Dict can respect alloc_dicts=True
        if type_name in ('Optional', 'Dict'):
            return "cast('%s', None)" % mypy_type

        raise AssertionError(type_name)

    if isinstance(typ, ast.NamedType):
        type_name = typ.name

        literal = primitive_defaults.get(type_name)
        if literal is not None:
            return literal

        if isinstance(typ.resolved, ast.SimpleSum):
            sum_type = typ.resolved
            # Just make it the first variant.  We could define "Undef" for
            # each enum, but it doesn't seem worth it.
            # NOTE(review): VisitSimpleSum names the namespace '<name>_i' for
            # integer sums, and drops the suffix for 'no_namespace_suffix';
            # '<name>_e' would not exist for those -- confirm they never
            # reach this branch.
            return '%s_e.%s' % (type_name, sum_type.types[0].name)

        # CompoundSum or Product type
        return 'cast(%s, None)' % mypy_type

    raise AssertionError('Unexpected type expression: %r' % (typ, ))
113
114
def _HNodeExpr(abbrev, typ, var_name):
    # type: (str, ast.TypeExpr, str) -> Tuple[str, bool]
    """Build the source expression that converts one value into an hnode.

    Args:
      abbrev: name of the tree method to call on child nodes, e.g.
        'PrettyTree' or 'AbbreviatedTree'.
      typ: ASDL type of the value.  An Optional wrapper is stripped first.
      var_name: source text of the expression holding the value, e.g.
        'self.name' or a loop variable.

    Returns:
      A (code_str, none_guard) pair.  code_str is the generated expression.
      none_guard is True when code_str calls a method on the value, so the
      caller has to deal with None before evaluating it.

    Raises:
      AssertionError: if typ is neither a ParameterizedType nor a NamedType.
    """
    # Leaf nodes for primitive names; each template takes var_name once.
    leaf_templates = {
        'bool': "hnode.Leaf('T' if %s else 'F', color_e.OtherConst)",
        'int': 'hnode.Leaf(str(%s), color_e.OtherConst)',
        'uint16': 'hnode.Leaf(str(%s), color_e.OtherConst)',
        'BigInt': 'hnode.Leaf(mops.ToStr(%s), color_e.OtherConst)',
        'float': 'hnode.Leaf(str(%s), color_e.OtherConst)',
        'string': 'NewLeaf(%s, color_e.StringConst)',
        # TODO: Remove this.  Used for value.Obj().
        'any': 'hnode.External(%s)',
        # was meta.UserType
        # This assumes it's Id, which is a simple SumType.  TODO: Remove this.
        'id': 'hnode.Leaf(Id_str(%s), color_e.UserType)',
    }

    if typ.IsOptional():
        typ = typ.children[0]  # descend one level

    if isinstance(typ, ast.ParameterizedType):
        # Note: unlike the named-type case below, 'trav' is not passed here.
        return '%s.%s()' % (var_name, abbrev), True

    if not isinstance(typ, ast.NamedType):
        raise AssertionError('Unexpected type expression: %r' % (typ, ))

    type_name = typ.name

    template = leaf_templates.get(type_name)
    if template is not None:
        return template % var_name, False

    if typ.resolved and isinstance(typ.resolved, ast.SimpleSum):
        code_str = 'hnode.Leaf(%s_str(%s), color_e.TypeName)' % (type_name,
                                                                 var_name)
        return code_str, False

    # Compound sum or product: recurse through the child's own tree method.
    return '%s.%s(trav=trav)' % (var_name, abbrev), True
163
164
class GenMyPyVisitor(visitor.AsdlVisitor):
    """Generate Python code with MyPy type annotations.

    Simple sums become an enum-like namespace, compound sums become a tag
    enum plus a base class plus a namespace of variant classes, and product
    types are deferred until EmitFooter() so that they can inherit from sum
    base classes defined later in the schema.

    The emitted templates nest by two spaces per level.  This assumes the
    inherited Emit() also indents two spaces per depth level -- TODO confirm
    against asdl/visitor.py.
    """

    def __init__(self,
                 f,
                 abbrev_mod_entries=None,
                 pretty_print_methods=True,
                 py_init_n=False,
                 simple_int_sums=None):
        """
        Args:
          f: passed through to visitor.AsdlVisitor.__init__ (presumably the
            output stream).
          abbrev_mod_entries: names of user-defined abbreviation functions.
            A class Foo uses one if '_Foo' is in this collection.
          pretty_print_methods: when False, no PrettyTree / AbbreviatedTree
            methods are generated.
          py_init_n: when True, the CreateNull() constructor is not generated.
          simple_int_sums: stored on the instance; not read in this class.
        """
        visitor.AsdlVisitor.__init__(self, f)
        self.abbrev_mod_entries = abbrev_mod_entries or []
        self.pretty_print_methods = pretty_print_methods
        self.py_init_n = py_init_n

        # For Id to use different code gen.  It's used like an integer, not
        # just like an enum.
        self.simple_int_sums = simple_int_sums or []

        # product type name -> type tag, filled in by VisitProduct()
        self._shared_type_tags = {}
        self._product_counter = 64  # matches asdl/gen_cpp.py

        # Deferred (product, name, depth, tag_num) tuples for EmitFooter()
        self._products = []
        # product type name -> list of sum base classes ('<sum>_t')
        self._product_bases = defaultdict(list)

    def _EmitDict(self, name, d, depth):
        """Emit the dict literal '_<name>_str' with entries sorted by key."""
        self.Emit('_%s_str = {' % name, depth)
        for k in sorted(d):
            self.Emit('%d: %r,' % (k, d[k]), depth + 1)
        self.Emit('}', depth)
        self.Emit('', depth)

    def VisitSimpleSum(self, sum, sum_name, depth):
        """Emit the type, namespace and <sum_name>_str() for a simple sum.

        Tags are 1-based.  With 'integers' or 'uint16' in sum.generate the
        variants are plain ints under '<sum_name>_i'; otherwise they are
        '<sum_name>_t' instances under '<sum_name>_e'.  The suffix is dropped
        when 'no_namespace_suffix' is in sum.generate.
        """
        variants = [(variant, i + 1) for i, variant in enumerate(sum.types)]
        int_to_str = {}
        for variant, tag_num in variants:
            int_to_str[tag_num] = '%s.%s' % (sum_name, variant.name)

        add_suffix = 'no_namespace_suffix' not in sum.generate
        gen_integers = 'integers' in sum.generate or 'uint16' in sum.generate

        if gen_integers:
            self.Emit('%s_t = int  # type alias for integer' % sum_name,
                      depth)
            self.Emit('', depth)

            i_name = ('%s_i' % sum_name) if add_suffix else sum_name

            self.Emit('class %s(object):' % i_name, depth)

            for variant, tag_num in variants:
                self.Emit('  %s = %d' % (variant.name, tag_num), depth)

            # Help in sizing array.  Note that we're 1-based.
            self.Emit('  %s = %d' % ('ARRAY_SIZE', len(variants) + 1), depth)

        else:
            # First emit a type
            self.Emit('class %s_t(pybase.SimpleObj):' % sum_name, depth)
            self.Emit('  pass', depth)
            self.Emit('', depth)

            # Now emit a namespace
            e_name = ('%s_e' % sum_name) if add_suffix else sum_name
            self.Emit('class %s(object):' % e_name, depth)

            for variant, tag_num in variants:
                self.Emit(
                    '  %s = %s_t(%d)' % (variant.name, sum_name, tag_num),
                    depth)

        self.Emit('', depth)

        self._EmitDict(sum_name, int_to_str, depth)

        self.Emit('def %s_str(val):' % sum_name, depth)
        self.Emit('  # type: (%s_t) -> str' % sum_name, depth)
        self.Emit('  return _%s_str[val]' % sum_name, depth)
        self.Emit('', depth)

    def _EmitCodeForField(self, abbrev, field, counter):
        """Generate code that returns an hnode for a field.

        Emitted at the current depth, which the caller has raised by one
        level; the templates add two more spaces to reach the body of the
        generated method.

        Args:
          abbrev: tree method to call on children, passed to _HNodeExpr().
          field: the ASDL field; its .typ selects List / Dict / Optional /
            plain handling.
          counter: index used to make the generated local names unique.
        """
        out_val_name = 'x%d' % counter

        if field.typ.IsList():
            iter_name = 'i%d' % counter

            typ = field.typ
            if typ.type_name == 'Optional':  # descend one level
                typ = typ.children[0]
            item_type = typ.children[0]

            self.Emit('  if self.%s is not None:  # List' % field.name)
            self.Emit('    %s = hnode.Array([])' % out_val_name)
            self.Emit('    for %s in self.%s:' % (iter_name, field.name))
            child_code_str, none_guard = _HNodeExpr(abbrev, item_type,
                                                    iter_name)

            if none_guard:  # e.g. for List[Optional[value_t]]
                # TODO: could consolidate with asdl/runtime.py NewLeaf(), which
                # also uses _ to mean None/nullptr
                self.Emit(
                    '      h = (hnode.Leaf("_", color_e.OtherConst) if %s is None else %s)'
                    % (iter_name, child_code_str))
                self.Emit('      %s.children.append(h)' % out_val_name)
            else:
                self.Emit('      %s.children.append(%s)' %
                          (out_val_name, child_code_str))

            self.Emit('    L.append(Field(%r, %s))' %
                      (field.name, out_val_name))

        elif field.typ.IsDict():
            k = 'k%d' % counter
            v = 'v%d' % counter

            typ = field.typ
            if typ.type_name == 'Optional':  # descend one level
                typ = typ.children[0]

            k_typ = typ.children[0]
            v_typ = typ.children[1]

            # The none_guard results are ignored for keys and values.
            k_code_str, _ = _HNodeExpr(abbrev, k_typ, k)
            v_code_str, _ = _HNodeExpr(abbrev, v_typ, v)

            self.Emit('  if self.%s is not None:  # Dict' % field.name)
            self.Emit('    m = hnode.Leaf("Dict", color_e.OtherConst)')
            self.Emit('    %s = hnode.Array([m])' % out_val_name)
            self.Emit('    for %s, %s in self.%s.iteritems():' %
                      (k, v, field.name))
            self.Emit('      %s.children.append(%s)' %
                      (out_val_name, k_code_str))
            self.Emit('      %s.children.append(%s)' %
                      (out_val_name, v_code_str))
            self.Emit('    L.append(Field(%r, %s))' %
                      (field.name, out_val_name))

        elif field.typ.IsOptional():
            typ = field.typ.children[0]

            self.Emit('  if self.%s is not None:  # Optional' % field.name)
            child_code_str, _ = _HNodeExpr(abbrev, typ, 'self.%s' % field.name)
            self.Emit('    %s = %s' % (out_val_name, child_code_str))
            self.Emit('    L.append(Field(%r, %s))' %
                      (field.name, out_val_name))

        else:
            var_name = 'self.%s' % field.name
            code_str, obj_none_guard = _HNodeExpr(abbrev, field.typ, var_name)
            if obj_none_guard:  # to satisfy MyPy type system
                self.Emit('  assert self.%s is not None' % field.name)
            self.Emit('  %s = %s' % (out_val_name, code_str))
            self.Emit('  L.append(Field(%r, %s))' %
                      (field.name, out_val_name))

    def _EmitTreeMethod(self, method_name, abbrev, fields, pretty_cls_name,
                        blank_after_prologue):
        """Emit one tree-building method of a generated class.

        Args:
          method_name: name of the generated method.
          abbrev: tree method the generated code calls on child nodes.
          fields: the ASDL fields to convert, in order.
          pretty_cls_name: record name passed to NewRecord().
          blank_after_prologue: emit an empty line before the first field.
        """
        self.Emit('  def %s(self, trav=None):' % method_name)
        self.Emit('    # type: (Optional[TraversalState]) -> hnode_t')
        self.Emit('    trav = trav or TraversalState()')
        self.Emit('    heap_id = id(self)')
        self.Emit('    if heap_id in trav.seen:')
        # cut off recursion
        self.Emit('      return hnode.AlreadySeen(heap_id)')
        self.Emit('    trav.seen[heap_id] = True')
        self.Emit('    out_node = NewRecord(%r)' % pretty_cls_name)
        self.Emit('    L = out_node.fields')
        if blank_after_prologue:
            self.Emit('')

        # Use the runtime type to be more like asdl/format.py
        for local_id, field in enumerate(fields):
            self.Indent()
            self._EmitCodeForField(abbrev, field, local_id)
            self.Dedent()
            self.Emit('')
        self.Emit('    return out_node')
        self.Emit('')

    def _GenClass(self,
                  ast_node,
                  class_name,
                  base_classes,
                  tag_num,
                  class_ns=''):
        """Used for both Sum variants ("constructors") and Product types.

        Emits the class header, _type_tag, __slots__, __init__, and, depending
        on the visitor options, CreateNull() and the tree methods.

        Args:
          ast_node: node whose .fields describe the class members.
          class_name: name of the generated class.
          base_classes: sequence of base class names.
          tag_num: value of the generated _type_tag.
          class_ns: for variants like value.Str
        """
        all_fields = ast_node.fields
        field_names = [f.name for f in all_fields]

        self.Emit('class %s(%s):' % (class_name, ', '.join(base_classes)))
        self.Emit('  _type_tag = %d' % tag_num)
        self.Emit('  __slots__ = %s' % repr(tuple(field_names)))
        self.Emit('')

        #
        # __init__
        #

        self.Emit('  def __init__(self, %s):' % ', '.join(field_names))

        arg_types = []
        default_vals = []
        for f in all_fields:
            mypy_type = _MyPyType(f.typ)
            arg_types.append(mypy_type)
            default_vals.append(_DefaultValue(f.typ, mypy_type))

        # don't wrap the type comment
        self.Emit('    # type: (%s) -> None' % ', '.join(arg_types),
                  reflow=False)

        if not all_fields:
            self.Emit('    pass')  # for types like NoOp

        for name in field_names:
            self.Emit('    self.%s = %s' % (name, name), reflow=False)

        self.Emit('')

        pretty_cls_name = '%s%s' % (class_ns, class_name)

        if all_fields and not self.py_init_n:
            self.Emit('  @staticmethod')
            self.Emit('  def CreateNull(alloc_lists=False):')
            # NOTE(review): the emitted type comment declares no parameters
            # although CreateNull takes alloc_lists -- confirm this is
            # intended before changing generated output.
            self.Emit('    # type: () -> %s' % pretty_cls_name)
            self.Emit('    return %s(%s)' %
                      (pretty_cls_name, ', '.join(default_vals)),
                      reflow=False)
            self.Emit('')

        if not self.pretty_print_methods:
            return

        # PrettyTree has an empty line after its prologue; _AbbreviatedTree
        # does not.
        self._EmitTreeMethod('PrettyTree', 'PrettyTree', all_fields,
                             pretty_cls_name, True)
        self._EmitTreeMethod('_AbbreviatedTree', 'AbbreviatedTree',
                             all_fields, pretty_cls_name, False)

        #
        # AbbreviatedTree: user hook first, generated fallback second
        #

        self.Emit('  def AbbreviatedTree(self, trav=None):')
        self.Emit('    # type: (Optional[TraversalState]) -> hnode_t')
        abbrev_name = '_%s' % class_name
        if abbrev_name in self.abbrev_mod_entries:
            self.Emit('    p = %s(self)' % abbrev_name)
            # If the user function didn't return anything, fall back.
            self.Emit(
                '    return p if p else self._AbbreviatedTree(trav=trav)')
        else:
            self.Emit('    return self._AbbreviatedTree(trav=trav)')
        self.Emit('')

    def VisitCompoundSum(self, sum, sum_name, depth):
        """Note that the following is_simple:

          cflow = Break | Continue

        But this is compound:

          cflow = Break | Continue | Return(int val)

        The generated code changes depending on which one it is.

        Raises:
          RuntimeError: if two variants of this sum refer to the same shared
            product type.
        """
        # We emit THREE Python types for each meta.CompoundType:
        #
        # 1. enum for tag (cflow_e)
        # 2. base class for inheritance (cflow_t)
        # 3. namespace for classes (cflow)  -- TODO: Get rid of this one.
        #
        # Should code use cflow_e.tag or isinstance()?
        # isinstance() is better for MyPy I think.  But tag is better for C++.
        # int tag = static_cast<cflow>(node).tag;

        base_class = sum_name + '_t'
        int_to_str = {}

        # enum for the tag
        self.Emit('class %s_e(object):' % sum_name, depth)

        for i, variant in enumerate(sum.types):
            if variant.shared_type:
                # The product must already have been seen by VisitProduct().
                tag_num = self._shared_type_tags[variant.shared_type]
                # e.g. DoubleQuoted may have base types expr_t, word_part_t
                bases = self._product_bases[variant.shared_type]
                if base_class in bases:
                    raise RuntimeError(
                        "Two tags in sum %r refer to product type %r" %
                        (sum_name, variant.shared_type))
                bases.append(base_class)
            else:
                tag_num = i + 1
            self.Emit('  %s = %d' % (variant.name, tag_num), depth)
            int_to_str[tag_num] = variant.name
        self.Emit('', depth)

        self._EmitDict(sum_name, int_to_str, depth)

        self.Emit('def %s_str(tag, dot=True):' % sum_name, depth)
        self.Emit('  # type: (int, bool) -> str', depth)
        self.Emit('  v = _%s_str[tag]' % sum_name, depth)
        self.Emit('  if dot:', depth)
        self.Emit('    return "%s.%%s" %% v' % sum_name, depth)
        self.Emit('  else:', depth)
        self.Emit('    return v', depth)
        self.Emit('', depth)

        # the base class, e.g. 'oil_cmd'
        self.Emit('class %s(pybase.CompoundObj):' % base_class, depth)
        self.Indent()

        # To imitate C++ API
        self.Emit('def tag(self):')
        self.Emit('  # type: () -> int')
        self.Emit('  return self._type_tag')

        # In C++ the base class would dispatch PrettyTree() and friends on
        # tag().  We don't need that in Python because every function is
        # virtual, so the rest of the class body is empty.
        self.Emit('pass', self.current_depth)

        self.Dedent()
        depth = self.current_depth
        self.Emit('')

        # Declare any zero argument singleton classes outside of the main
        # "namespace" class.
        for i, variant in enumerate(sum.types):
            if variant.shared_type:
                continue  # Don't generate a class for shared types.
            if len(variant.fields) == 0:
                # We must use the old-style naming here, ie. command__NoOp, in
                # order to support zero field variants as constants.
                class_name = '%s__%s' % (sum_name, variant.name)
                self._GenClass(variant, class_name, (base_class, ), i + 1)

        # Class that's just a NAMESPACE, e.g. for value.Str
        self.Emit('class %s(object):' % sum_name, depth)

        self.Indent()

        for i, variant in enumerate(sum.types):
            if variant.shared_type:
                continue

            if len(variant.fields) == 0:
                self.Emit('%s = %s__%s()' %
                          (variant.name, sum_name, variant.name))
                self.Emit('')
            else:
                # Use fully-qualified name, so we can have osh_cmd.Simple and
                # oil_cmd.Simple.
                fq_name = variant.name
                self._GenClass(variant,
                               fq_name, (base_class, ),
                               i + 1,
                               class_ns=sum_name + '.')
        # Emitted at the outer depth with a two-space template so it lands in
        # the namespace class body, in case every variant is first class.
        self.Emit('  pass', depth)

        self.Dedent()
        self.Emit('')

    def VisitProduct(self, product, name, depth):
        """Assign a type tag to a product type and defer emitting its class."""
        self._shared_type_tags[name] = self._product_counter
        # Create a tuple of _GenClass args to create LAST.  They may inherit
        # from sum types that have yet to be defined.
        self._products.append((product, name, depth, self._product_counter))
        self._product_counter += 1

    def EmitFooter(self):
        """Emit the classes for all product types deferred by VisitProduct()."""
        # Now generate all the product types we deferred.
        for ast_node, name, depth, tag_num in self._products:
            # Figure out base classes AFTERWARD.
            bases = self._product_bases[name]
            if not bases:
                bases = ('pybase.CompoundObj', )
            self._GenClass(ast_node, name, bases, tag_num)