| 1 | #!/usr/bin/env python2
 | 
| 2 | """gen_python.py: Generate Python code from an ASDL schema."""
 | 
from __future__ import print_function

from collections import defaultdict
from typing import Tuple

from asdl import ast
from asdl import visitor
from asdl.util import log
 | 
| 10 | 
 | 
| 11 | _ = log  # shut up lint
 | 
| 12 | 
 | 
| 13 | _PRIMITIVES = {
 | 
| 14 |     'string': 'str',
 | 
| 15 |     'int': 'int',
 | 
| 16 |     'uint16': 'int',
 | 
| 17 |     'BigInt': 'mops.BigInt',
 | 
| 18 |     'float': 'float',
 | 
| 19 |     'bool': 'bool',
 | 
| 20 |     'any': 'Any',
 | 
| 21 |     # TODO: frontend/syntax.asdl should properly import id enum instead of
 | 
| 22 |     # hard-coding it here.
 | 
| 23 |     'id': 'Id_t',
 | 
| 24 | }
 | 
| 25 | 
 | 
| 26 | 
 | 
def _MyPyType(typ):
    """ASDL type to MyPy Type.

    Args:
      typ: an ast.ParameterizedType (Dict, List, or Optional) or an
        ast.NamedType.

    Returns:
      The MyPy type as a string, e.g. 'List[Optional[word_t]]'.

    Raises:
      AssertionError: for a parameterized type other than Dict, List, or
        Optional, or for any other kind of type expression.
      KeyError: for a name that is neither resolved nor in _PRIMITIVES.
    """
    if isinstance(typ, ast.ParameterizedType):
        type_name = typ.type_name

        if type_name == 'Dict':
            k_type = _MyPyType(typ.children[0])
            v_type = _MyPyType(typ.children[1])
            return 'Dict[%s, %s]' % (k_type, v_type)

        if type_name in ('List', 'Optional'):
            return '%s[%s]' % (type_name, _MyPyType(typ.children[0]))

        # Fail here with the offending name, rather than handing None back to
        # callers that interpolate the result into generated code.
        raise AssertionError(type_name)

    if isinstance(typ, ast.NamedType):
        resolved = typ.resolved
        if resolved:
            if isinstance(resolved, ast.Sum):  # includes SimpleSum
                return '%s_t' % typ.name
            if isinstance(resolved, ast.Product):
                return typ.name
            if isinstance(resolved, ast.Use):
                return ast.TypeNameHeuristic(typ.name)

        # Primitives like 'int' and 'string', and the hard-coded 'id', fall
        # through here.
        return _PRIMITIVES[typ.name]

    raise AssertionError(typ)
 | 
| 56 | 
 | 
| 57 | 
 | 
def _DefaultValue(typ, mypy_type):
    """Values that the static CreateNull() constructor passes.

    mypy_type is used to cast None, to maintain mypy --strict for ASDL.

    We circumvent the type system on CreateNull().  Then the user is
    responsible for filling in all the fields.  If they do so, we can
    rely on it when reading fields at runtime.

    Args:
      typ: the field's ASDL type expression.
      mypy_type: the field's MyPy type string, as returned by _MyPyType().

    Returns:
      A Python expression as a string, e.g. '-1', "''", or
      "cast('Optional[str]', None)".

    Raises:
      AssertionError: for an unknown parameterized type, or a type expression
        that is neither an ast.ParameterizedType nor an ast.NamedType.
    """
    if isinstance(typ, ast.ParameterizedType):
        type_name = typ.type_name

        if type_name == 'Optional':
            return "cast('%s', None)" % mypy_type

        if type_name == 'List':
            # alloc_lists is the parameter of the generated CreateNull().
            return "[] if alloc_lists else cast('%s', None)" % mypy_type

        if type_name == 'Dict':  # TODO: can respect alloc_dicts=True
            return "cast('%s', None)" % mypy_type

        raise AssertionError(type_name)

    if isinstance(typ, ast.NamedType):
        type_name = typ.name

        # 'id' is a hard-coded HACK.  'uint16' is an integer primitive just
        # like 'int' (both map to 'int' in _PRIMITIVES), so it gets the same
        # placeholder instead of falling through to cast(int, None).
        if type_name in ('id', 'int', 'uint16'):
            return '-1'

        if type_name == 'BigInt':
            return 'mops.BigInt(-1)'

        if type_name == 'bool':
            return 'False'

        if type_name == 'float':
            return '0.0'  # or should it be NaN?

        if type_name == 'string':
            return "''"

        if isinstance(typ.resolved, ast.SimpleSum):
            sum_type = typ.resolved
            # Just make it the first variant.  We could define "Undef" for
            # each enum, but it doesn't seem worth it.
            #
            # The namespace class must be named the way
            # GenMyPyVisitor.VisitSimpleSum() names it: Foo_i for integer
            # sums, Foo_e otherwise, and plain Foo with no_namespace_suffix.
            generate = sum_type.generate
            if 'no_namespace_suffix' in generate:
                ns_name = type_name
            elif 'integers' in generate or 'uint16' in generate:
                ns_name = '%s_i' % type_name
            else:
                ns_name = '%s_e' % type_name
            return '%s.%s' % (ns_name, sum_type.types[0].name)

        # CompoundSum or Product type
        return 'cast(%s, None)' % mypy_type

    raise AssertionError(typ)
 | 
| 113 | 
 | 
| 114 | 
 | 
def _HNodeExpr(abbrev, typ, var_name):
    # type: (str, ast.TypeExpr, str) -> Tuple[str, bool]
    """Return a Python expression that converts var_name to an hnode.

    Args:
      abbrev: method called on compound values, e.g. 'PrettyTree' or
        'AbbreviatedTree'.
      typ: ASDL type of the value.  A top-level Optional is unwrapped.
      var_name: expression holding the value, e.g. 'self.name' or 'i0'.

    Returns:
      A (code_str, none_guard) tuple.  none_guard is True when code_str calls
      a method on var_name, so the caller must rule out None before
      evaluating it.

    Raises:
      AssertionError: if typ is neither an ast.ParameterizedType nor an
        ast.NamedType.
    """
    none_guard = False

    if typ.IsOptional():
        typ = typ.children[0]  # descend one level

    if isinstance(typ, ast.ParameterizedType):
        # NOTE(review): this emits e.g. 'i0.PrettyTree()' for a nested
        # container such as List[List[T]]; plain Python lists/dicts have no
        # such method, and trav is not passed -- confirm no schema relies on
        # this.
        code_str = '%s.%s()' % (var_name, abbrev)
        none_guard = True

    elif isinstance(typ, ast.NamedType):
        type_name = typ.name

        if type_name == 'bool':
            code_str = "hnode.Leaf('T' if %s else 'F', color_e.OtherConst)" % var_name

        elif type_name in ('int', 'uint16', 'float'):
            code_str = 'hnode.Leaf(str(%s), color_e.OtherConst)' % var_name

        elif type_name == 'BigInt':
            code_str = 'hnode.Leaf(mops.ToStr(%s), color_e.OtherConst)' % var_name

        elif type_name == 'string':
            code_str = 'NewLeaf(%s, color_e.StringConst)' % var_name

        elif type_name == 'any':  # TODO: Remove this.  Used for value.Obj().
            code_str = 'hnode.External(%s)' % var_name

        elif type_name == 'id':  # was meta.UserType
            # This assumes it's Id, which is a simple SumType.  TODO: Remove this.
            code_str = 'hnode.Leaf(Id_str(%s), color_e.UserType)' % var_name

        elif typ.resolved and isinstance(typ.resolved, ast.SimpleSum):
            code_str = 'hnode.Leaf(%s_str(%s), color_e.TypeName)' % (type_name,
                                                                     var_name)

        else:
            # Compound sums, products, and imported types have a tree method.
            code_str = '%s.%s(trav=trav)' % (var_name, abbrev)
            none_guard = True

    else:
        raise AssertionError(typ)

    return code_str, none_guard
 | 
| 163 | 
 | 
| 164 | 
 | 
class GenMyPyVisitor(visitor.AsdlVisitor):
    """Generate Python code with MyPy type annotations.

    Sum types are emitted as they are visited.  Product types are deferred to
    EmitFooter(), because a product used as a variant of several sums inherits
    from each of those sums' base classes, and those are only known once every
    sum has been visited.
    """

    def __init__(self,
                 f,
                 abbrev_mod_entries=None,
                 pretty_print_methods=True,
                 py_init_n=False,
                 simple_int_sums=None):
        """
        Args:
          f: passed through to visitor.AsdlVisitor; presumably the output
            stream that Emit() writes to.
          abbrev_mod_entries: function names like '_Token'.  When '_' plus a
            generated class name is in this list, that class's
            AbbreviatedTree() calls the function first.
          pretty_print_methods: whether to emit PrettyTree() and the
            AbbreviatedTree() methods.
          py_init_n: when true, CreateNull() is not emitted.
          simple_int_sums: stored on the instance; see the comment below.
        """
        visitor.AsdlVisitor.__init__(self, f)
        self.abbrev_mod_entries = abbrev_mod_entries or []
        self.pretty_print_methods = pretty_print_methods
        self.py_init_n = py_init_n

        # For Id to use different code gen.  It's used like an integer, not just
        # like an enum.
        # NOTE(review): no method of this class reads it.
        self.simple_int_sums = simple_int_sums or []

        # Product type name -> tag number assigned in VisitProduct().  Sum
        # variants that refer to a product reuse its tag.
        self._shared_type_tags = {}
        self._product_counter = 64  # matches asdl/gen_cpp.py

        # (product, name, depth, tag_num) tuples, generated in EmitFooter().
        self._products = []
        # Product type name -> base classes of the sums that use it,
        # e.g. ['expr_t', 'word_part_t'].
        self._product_bases = defaultdict(list)

    def _EmitDict(self, name, d, depth):
        """Emit a module-level dict named _<name>_str, with d's keys sorted."""
        self.Emit('_%s_str = {' % name, depth)
        for k in sorted(d):
            self.Emit('%d: %r,' % (k, d[k]), depth + 1)
        self.Emit('}', depth)
        self.Emit('', depth)

    def VisitSimpleSum(self, sum, sum_name, depth):
        """Emit a sum whose variants have no fields.

        With generate [integers] or [uint16], the variants are int constants
        and <name>_t is an alias for int.  Otherwise <name>_t is a
        pybase.SimpleObj subclass and each variant is an instance of it.  Tags
        start at 1.  In both cases a _<name>_str dict and a <name>_str()
        function map tags to 'name.Variant' strings.
        """
        int_to_str = {}
        variants = []
        for i, variant in enumerate(sum.types):
            tag_num = i + 1
            int_to_str[tag_num] = '%s.%s' % (sum_name, variant.name)
            variants.append((variant, tag_num))

        add_suffix = 'no_namespace_suffix' not in sum.generate
        gen_integers = 'integers' in sum.generate or 'uint16' in sum.generate

        if gen_integers:
            self.Emit('%s_t = int  # type alias for integer' % sum_name)
            self.Emit('')

            i_name = ('%s_i' % sum_name) if add_suffix else sum_name

            self.Emit('class %s(object):' % i_name, depth)

            for variant, tag_num in variants:
                self.Emit('  %s = %d' % (variant.name, tag_num), depth)

            # Help in sizing array.  Note that we're 1-based.
            self.Emit('  %s = %d' % ('ARRAY_SIZE', len(variants) + 1), depth)

        else:
            # First emit a type
            self.Emit('class %s_t(pybase.SimpleObj):' % sum_name, depth)
            self.Emit('  pass', depth)
            self.Emit('', depth)

            # Now emit a namespace
            e_name = ('%s_e' % sum_name) if add_suffix else sum_name
            self.Emit('class %s(object):' % e_name, depth)

            for variant, tag_num in variants:
                self.Emit('  %s = %s_t(%d)' % (variant.name, sum_name, tag_num),
                          depth)

        self.Emit('', depth)

        self._EmitDict(sum_name, int_to_str, depth)

        self.Emit('def %s_str(val):' % sum_name, depth)
        self.Emit('  # type: (%s_t) -> str' % sum_name, depth)
        self.Emit('  return _%s_str[val]' % sum_name, depth)
        self.Emit('', depth)

    def _EmitCodeForField(self, abbrev, field, counter):
        """Generate code that returns an hnode for a field.

        The emitted code appends Field(name, hnode) to the generated method's
        local list L.  List, Dict, and Optional fields are appended only when
        they are not None.

        Args:
          abbrev: 'PrettyTree' or 'AbbreviatedTree'; passed to _HNodeExpr().
          field: the ASDL field.
          counter: makes the generated local names (x0, i0, k0, ...) unique.
        """
        out_val_name = 'x%d' % counter

        if field.typ.IsList():
            iter_name = 'i%d' % counter

            typ = field.typ
            if typ.type_name == 'Optional':  # descend one level
                typ = typ.children[0]
            item_type = typ.children[0]

            self.Emit('  if self.%s is not None:  # List' % field.name)
            self.Emit('    %s = hnode.Array([])' % out_val_name)
            self.Emit('    for %s in self.%s:' % (iter_name, field.name))
            child_code_str, none_guard = _HNodeExpr(abbrev, item_type,
                                                    iter_name)

            if none_guard:  # e.g. for List[Optional[value_t]]
                # TODO: could consolidate with asdl/runtime.py NewLeaf(), which
                # also uses _ to mean None/nullptr
                self.Emit(
                    '      h = (hnode.Leaf("_", color_e.OtherConst) if %s is None else %s)'
                    % (iter_name, child_code_str))
                self.Emit('      %s.children.append(h)' % out_val_name)
            else:
                self.Emit('      %s.children.append(%s)' %
                          (out_val_name, child_code_str))

            self.Emit('    L.append(Field(%r, %s))' %
                      (field.name, out_val_name))

        elif field.typ.IsDict():
            k = 'k%d' % counter
            v = 'v%d' % counter

            typ = field.typ
            if typ.type_name == 'Optional':  # descend one level
                typ = typ.children[0]

            k_typ = typ.children[0]
            v_typ = typ.children[1]

            k_code_str, _ = _HNodeExpr(abbrev, k_typ, k)
            v_code_str, _ = _HNodeExpr(abbrev, v_typ, v)

            self.Emit('  if self.%s is not None:  # Dict' % field.name)
            self.Emit('    m = hnode.Leaf("Dict", color_e.OtherConst)')
            self.Emit('    %s = hnode.Array([m])' % out_val_name)
            self.Emit('    for %s, %s in self.%s.iteritems():' %
                      (k, v, field.name))
            self.Emit('      %s.children.append(%s)' %
                      (out_val_name, k_code_str))
            self.Emit('      %s.children.append(%s)' %
                      (out_val_name, v_code_str))
            self.Emit('    L.append(Field(%r, %s))' %
                      (field.name, out_val_name))

        elif field.typ.IsOptional():
            typ = field.typ.children[0]

            self.Emit('  if self.%s is not None:  # Optional' % field.name)
            child_code_str, _ = _HNodeExpr(abbrev, typ, 'self.%s' % field.name)
            self.Emit('    %s = %s' % (out_val_name, child_code_str))
            self.Emit('    L.append(Field(%r, %s))' %
                      (field.name, out_val_name))

        else:
            var_name = 'self.%s' % field.name
            code_str, obj_none_guard = _HNodeExpr(abbrev, field.typ, var_name)
            depth = self.current_depth
            if obj_none_guard:  # to satisfy MyPy type system
                self.Emit('  assert self.%s is not None' % field.name)
            self.Emit('  %s = %s' % (out_val_name, code_str), depth)

            self.Emit('  L.append(Field(%r, %s))' % (field.name, out_val_name),
                      depth)

    def _EmitTreeMethodHeader(self, method_name, pretty_cls_name):
        """Emit the signature and cycle-detecting preamble of a tree method.

        The generated method records id(self) in trav.seen, returns
        hnode.AlreadySeen() if it was already there, and binds out_node and
        its field list L for the per-field code that follows.
        """
        self.Emit('  def %s(self, trav=None):' % method_name)
        self.Emit('    # type: (Optional[TraversalState]) -> hnode_t')
        self.Emit('    trav = trav or TraversalState()')
        self.Emit('    heap_id = id(self)')
        self.Emit('    if heap_id in trav.seen:')
        # cut off recursion
        self.Emit('      return hnode.AlreadySeen(heap_id)')
        self.Emit('    trav.seen[heap_id] = True')
        self.Emit('    out_node = NewRecord(%r)' % pretty_cls_name)
        self.Emit('    L = out_node.fields')

    def _EmitTreeFieldsAndReturn(self, abbrev, fields):
        """Emit per-field hnode code, then 'return out_node'."""
        for local_id, field in enumerate(fields):
            self.Indent()
            self._EmitCodeForField(abbrev, field, local_id)
            self.Dedent()
            self.Emit('')
        self.Emit('    return out_node')
        self.Emit('')

    def _GenClass(self,
                  ast_node,
                  class_name,
                  base_classes,
                  tag_num,
                  class_ns=''):
        """Used for both Sum variants ("constructors") and Product types.

        Emits the class header, _type_tag, __slots__, and __init__().  Unless
        disabled, it also emits CreateNull() and the PrettyTree() and
        AbbreviatedTree() methods.

        Args:
          ast_node: the variant or product; its .fields become the slots.
          class_name: name of the emitted class.
          base_classes: sequence of base class names.
          tag_num: value of the class attribute _type_tag.
          class_ns: for variants like value.Str
        """
        all_fields = ast_node.fields
        field_names = [f.name for f in all_fields]

        self.Emit('class %s(%s):' % (class_name, ', '.join(base_classes)))
        self.Emit('  _type_tag = %d' % tag_num)
        self.Emit('  __slots__ = %s' % repr(tuple(field_names)))
        self.Emit('')

        #
        # __init__
        #

        # Joining 'self' with the field names avoids 'def __init__(self, ):'
        # for classes without fields.
        self.Emit('  def __init__(%s):' % ', '.join(['self'] + field_names))

        arg_types = []
        default_vals = []
        for f in all_fields:
            mypy_type = _MyPyType(f.typ)
            arg_types.append(mypy_type)
            default_vals.append(_DefaultValue(f.typ, mypy_type))

        # don't wrap the type comment
        self.Emit('    # type: (%s) -> None' % ', '.join(arg_types),
                  reflow=False)

        if not all_fields:
            self.Emit('    pass')  # for types like NoOp

        for f in all_fields:
            self.Emit('    self.%s = %s' % (f.name, f.name), reflow=False)

        self.Emit('')

        pretty_cls_name = '%s%s' % (class_ns, class_name)

        if all_fields and not self.py_init_n:
            self.Emit('  @staticmethod')
            self.Emit('  def CreateNull(alloc_lists=False):')
            # One parameter, so the type comment lists one argument type.
            self.Emit('    # type: (bool) -> %s' % pretty_cls_name)
            self.Emit('    return %s(%s)' %
                      (pretty_cls_name, ', '.join(default_vals)),
                      reflow=False)
            self.Emit('')

        if not self.pretty_print_methods:
            return

        #
        # PrettyTree
        #

        self._EmitTreeMethodHeader('PrettyTree', pretty_cls_name)
        self.Emit('')
        # Use the runtime type to be more like asdl/format.py
        self._EmitTreeFieldsAndReturn('PrettyTree', all_fields)

        #
        # _AbbreviatedTree
        #

        self._EmitTreeMethodHeader('_AbbreviatedTree', pretty_cls_name)
        self._EmitTreeFieldsAndReturn('AbbreviatedTree', all_fields)

        #
        # AbbreviatedTree: a user-supplied _<ClassName> function may override
        # the generic _AbbreviatedTree().
        #

        self.Emit('  def AbbreviatedTree(self, trav=None):')
        self.Emit('    # type: (Optional[TraversalState]) -> hnode_t')
        abbrev_name = '_%s' % class_name
        if abbrev_name in self.abbrev_mod_entries:
            self.Emit('    p = %s(self)' % abbrev_name)
            # If the user function didn't return anything, fall back.
            self.Emit(
                '    return p if p else self._AbbreviatedTree(trav=trav)')
        else:
            self.Emit('    return self._AbbreviatedTree(trav=trav)')
        self.Emit('')

    def VisitCompoundSum(self, sum, sum_name, depth):
        """Note that the following is_simple:

          cflow = Break | Continue

        But this is compound:

          cflow = Break | Continue | Return(int val)

        The generated code changes depending on which one it is.

        Raises:
          RuntimeError: if a variant refers to a product type that has not
            been visited yet, or two variants refer to the same product.
        """
        #log('%d variants in %s', len(sum.types), sum_name)

        # We emit THREE Python types for each meta.CompoundType:
        #
        # 1. enum for tag (cflow_e)
        # 2. base class for inheritance (cflow_t)
        # 3. namespace for classes (cflow)  -- TODO: Get rid of this one.
        #
        # Should code use cflow_e.tag or isinstance()?
        # isinstance() is better for MyPy I think.  But tag is better for C++.
        # int tag = static_cast<cflow>(node).tag;

        int_to_str = {}

        # enum for the tag
        self.Emit('class %s_e(object):' % sum_name, depth)

        for i, variant in enumerate(sum.types):
            shared = variant.shared_type
            if shared:
                if shared not in self._shared_type_tags:
                    raise RuntimeError(
                        'Sum %r refers to product type %r, which has no tag '
                        'yet; products must be visited before the sums that '
                        'use them' % (sum_name, shared))
                tag_num = self._shared_type_tags[shared]
                # e.g. DoubleQuoted may have base types expr_t, word_part_t
                base_class = sum_name + '_t'
                bases = self._product_bases[shared]
                if base_class in bases:
                    raise RuntimeError(
                        "Two tags in sum %r refer to product type %r" %
                        (sum_name, shared))
                bases.append(base_class)
            else:
                tag_num = i + 1
            self.Emit('  %s = %d' % (variant.name, tag_num), depth)
            int_to_str[tag_num] = variant.name
        self.Emit('', depth)

        self._EmitDict(sum_name, int_to_str, depth)

        self.Emit('def %s_str(tag, dot=True):' % sum_name, depth)
        self.Emit('  # type: (int, bool) -> str', depth)
        self.Emit('  v = _%s_str[tag]' % sum_name, depth)
        self.Emit('  if dot:', depth)
        self.Emit('    return "%s.%%s" %% v' % sum_name, depth)
        self.Emit('  else:', depth)
        self.Emit('    return v', depth)
        self.Emit('', depth)

        # the base class, e.g. 'oil_cmd'
        self.Emit('class %s_t(pybase.CompoundObj):' % sum_name, depth)
        self.Indent()

        # To imitate C++ API
        self.Emit('def tag(self):')
        self.Emit('  # type: () -> int')
        self.Emit('  return self._type_tag')
        # C++ also needs PrettyTree() etc. here, dispatching on tag(), but in
        # Python every method is virtual, so the variant classes' own methods
        # are enough.

        self.Dedent()
        depth = self.current_depth
        self.Emit('')

        # Declare any zero argument singleton classes outside of the main
        # "namespace" class.
        for i, variant in enumerate(sum.types):
            if variant.shared_type:
                continue  # Don't generate a class for shared types.
            if len(variant.fields) == 0:
                # We must use the old-style naming here, ie. command__NoOp, in order
                # to support zero field variants as constants.
                class_name = '%s__%s' % (sum_name, variant.name)
                self._GenClass(variant, class_name, (sum_name + '_t', ), i + 1)

        # Class that's just a NAMESPACE, e.g. for value.Str
        self.Emit('class %s(object):' % sum_name, depth)

        self.Indent()

        for i, variant in enumerate(sum.types):
            if variant.shared_type:
                continue

            if len(variant.fields) == 0:
                self.Emit('%s = %s__%s()' %
                          (variant.name, sum_name, variant.name))
                self.Emit('')
            else:
                # Use fully-qualified name, so we can have osh_cmd.Simple and
                # oil_cmd.Simple.
                fq_name = variant.name
                self._GenClass(variant,
                               fq_name, (sum_name + '_t', ),
                               i + 1,
                               class_ns=sum_name + '.')
        self.Emit('  pass', depth)  # in case every variant is first class

        self.Dedent()
        self.Emit('')

    def VisitProduct(self, product, name, depth):
        """Assign the product a tag now; generate its class in EmitFooter()."""
        self._shared_type_tags[name] = self._product_counter
        # Create a tuple of _GenClass args to create LAST.  They may inherit from
        # sum types that have yet to be defined.
        self._products.append((product, name, depth, self._product_counter))
        self._product_counter += 1

    def EmitFooter(self):
        """Generate the product types deferred by VisitProduct()."""
        for ast_node, name, depth, tag_num in self._products:
            # Figure out base classes AFTERWARD.
            bases = self._product_bases[name]
            if not bases:
                bases = ('pybase.CompoundObj', )
            self._GenClass(ast_node, name, bases, tag_num)
 |