| 1 | #!/usr/bin/env python
 | 
| 2 | from __future__ import print_function
 | 
| 3 | """
 | 
| 4 | asdl_cpp.py
 | 
| 5 | 
 | 
| 6 | Turn an ASDL schema into C++ code.
 | 
| 7 | 
 | 
| 8 | TODO:
 | 
| 9 | - Optional fields
 | 
| 10 |   - in osh, it's only used in two places:
 | 
| 11 |   - arith_expr? for slice length
 | 
| 12 |   - word? for var replace
 | 
| 13 |   - So you're already using pointers, can encode the NULL pointer.
 | 
| 14 | 
 | 
| 15 | - Change everything to use references instead of pointers?  Non-nullable.
 | 
| 16 | - Unify ClassDefVisitor and MethodBodyVisitor.
 | 
| 17 |   - Whether you need a separate method body should be a flag.
 | 
| 18 |   - offset calculations are duplicated
 | 
| 19 | - generate a C++ pretty-printer
 | 
| 20 | 
 | 
| 21 | Technically we don't even need alignment?  I guess the reason is to increase
 | 
| 22 | address space.  If 1, then we have 16MiB of code.  If 4, then we have 64 MiB.
 | 
| 23 | 
 | 
| 24 | Everything is decoded on the fly, or is a char*, which I don't think has to be
 | 
| 25 | aligned (because the natural alignment would be 1 byte anyway.)
 | 
| 26 | """
 | 
| 27 | 
 | 
| 28 | import sys
 | 
| 29 | 
 | 
| 30 | from asdl import asdl_ as asdl
 | 
| 31 | from asdl import encode
 | 
| 32 | from asdl import visitor
 | 
| 33 | 
 | 
| 34 | from osh.meta import Id
 | 
| 35 | 
 | 
class ChainOfVisitors:
  """Fan one module out to several visitors.

  VisitModule() forwards the same module to each wrapped visitor, in the
  order the visitors were given to the constructor.
  """

  def __init__(self, *visitors):
    # Kept as the tuple of positional arguments, in call order.
    self.visitors = visitors

  def VisitModule(self, module):
    """Call VisitModule(module) on every chained visitor, first to last."""
    for each_visitor in self.visitors:
      each_visitor.VisitModule(module)
 | 
| 43 | 
 | 
| 44 | 
 | 
| 45 | _BUILTINS = {
 | 
| 46 |     'string': 'char*',  # A read-only string is a char*
 | 
| 47 |     'int': 'int',
 | 
| 48 |     'bool': 'bool',
 | 
| 49 |     'id': 'Id',  # Application specific hack for now
 | 
| 50 | }
 | 
| 51 | 
 | 
class ForwardDeclareVisitor(visitor.AsdlVisitor):
  """Print forward declarations.

  An ASDL schema may refer to a type before defining it, but C++ needs every
  class name declared before use.  So emit `class <name>_t;` for each compound
  sum and each product type ahead of the real definitions.
  """

  def VisitCompoundSum(self, sum, name, depth):
    self.Emit("class %s_t;" % name, depth)

  def VisitProduct(self, product, name, depth):
    self.Emit("class %s_t;" % name, depth)

  def EmitFooter(self):
    # A single blank line separates the declarations from what follows.
    self.Emit("", 0)
 | 
| 65 | 
 | 
| 66 | 
 | 
class ClassDefVisitor(visitor.AsdlVisitor):
  """Generate C++ classes and type-safe enums.

  Accessors whose bodies need a static_cast<> downcast are only declared inside
  the class; their out-of-line definitions are collected in self.footer and
  written by EmitFooter(), after every class has been declared.
  """

  def __init__(self, f, enc_params, type_lookup, enum_types=None):
    visitor.AsdlVisitor.__init__(self, f)
    self.ref_width = enc_params.ref_width
    self.type_lookup = type_lookup
    self.enum_types = enum_types or {}
    self.pointer_type = enc_params.pointer_type
    self.footer = []  # lines

  def _GetCppType(self, field):
    """Return a string for the C++ name of the type."""
    type_name = field.type

    cpp_type = _BUILTINS.get(type_name)
    if cpp_type is not None:
      return cpp_type

    typ = self.type_lookup.ByTypeName(type_name)
    if isinstance(typ, asdl.Sum) and asdl.is_simple(typ):
      # Use the enum instead of the class.
      return "%s_e" % type_name

    # - Pointer for optional type.
    # - ints and strings should generally not be optional?  We don't have them
    # in osh yet, so leave it out for now.
    if field.opt:
      return "%s_t*" % type_name

    return "%s_t&" % type_name

  def EmitFooter(self):
    """Write the deferred out-of-line accessor definitions."""
    for line in self.footer:
      self.f.write(line)

  def _EmitEnum(self, sum, name, depth):
    """Emit `enum class <name>_e : uint8_t` with one enumerator per variant.

    Enumerators are numbered from 1 because zero is reserved.
    """
    enum = [
        "%s = %d" % (variant.name, i)
        for i, variant in enumerate(sum.types, 1)  # zero is reserved
    ]

    self.Emit("enum class %s_e : uint8_t {" % name, depth)
    self.Emit(", ".join(enum), depth + 1)
    self.Emit("};", depth)
    self.Emit("", depth)

  def _EmitAttributes(self, attributes, depth):
    """Emit one data member per ASDL attribute (rudimentary handling).

    Must be called while the owning class body is still open.  Only ASDL
    builtin types are accepted.
    """
    for field in attributes:
      attr_type = str(field.type)
      assert attr_type in asdl.builtin_types, attr_type
      self.Emit("%s %s;" % (attr_type, field.name), depth)

  def VisitSimpleSum(self, sum, name, depth):
    self._EmitEnum(sum, name, depth)

  def VisitCompoundSum(self, sum, name, depth):
    """Emit the tag enum, the <name>_t base class, then one class per variant."""
    self._EmitEnum(sum, name, depth)

    self.Emit("class %s_t : public Obj {" % name, depth)
    self.Emit(" public:", depth)
    # All sum types have a tag, stored in the first byte.
    self.Emit("%s_e tag() const {" % name, depth + 1)
    self.Emit("return static_cast<%s_e>(bytes_[0]);" % name, depth + 2)
    self.Emit("}", depth + 1)
    # Attributes are members of the base class, so they must be emitted before
    # it is closed.  (They used to be emitted after all the variant classes,
    # i.e. at namespace scope.)
    self._EmitAttributes(sum.attributes, depth + 1)
    self.Emit("};", depth)
    self.Emit("", depth)

    # TODO: This should be replaced with a call to the generic
    # self.VisitChildren()
    super_name = "%s_t" % name
    for t in sum.types:
      self.VisitConstructor(t, super_name, depth)

  def VisitConstructor(self, cons, def_name, depth):
    """Emit a class for one sum-type variant, deriving from `def_name`.

    Variants without fields get no class of their own.
    """
    if not cons.fields:
      return

    self.Emit("class %s : public %s {" % (cons.name, def_name), depth)
    self.Emit(" public:", depth)
    offset = 1  # byte 0 holds the tag
    for f in cons.fields:
      self.VisitField(f, cons.name, offset, depth + 1)
      offset += self.ref_width
    self.Emit("};", depth)
    self.Emit("", depth)

  def VisitProduct(self, product, name, depth):
    """Emit the <name>_t class for a product type (no tag byte)."""
    self.Emit("class %s_t : public Obj {" % name, depth)
    self.Emit(" public:", depth)
    type_name = '%s_t' % name
    offset = 0
    for f in product.fields:
      self.VisitField(f, type_name, offset, depth + 1)
      offset += self.ref_width

    self._EmitAttributes(product.attributes, depth + 1)
    self.Emit("};", depth)
    self.Emit("", depth)

  def VisitField(self, field, type_name, offset, depth):
    """Emit the accessor(s) for one field stored at byte `offset`.

    Even though they are inline, some of them can't be in the class {}, because
    static_cast<> requires inheritance relationships to be already declared.  We
    have to print all the classes first, then all the bodies that might use
    static_cast<>.

    http://stackoverflow.com/questions/5808758/why-is-a-static-cast-from-a-pointer-to-base-to-a-pointer-to-derived-invalid

    Args:
      field: the ASDL field being accessed.
      type_name: C++ class that owns the accessor; used to qualify the
        out-of-line definitions appended to self.footer.
      offset: byte offset of the field's encoded value within the object.
      depth: indentation depth for the in-class code.
    """
    # NOTE: the templates below are filled from locals(), so the names ctype,
    # name, pointer_type, offset and maybe_qual_name must not be renamed.
    ctype = self._GetCppType(field)
    name = field.name
    pointer_type = self.pointer_type
    # Either 'left' or 'BoolBinary::left', depending on whether it's inline.
    # Mutated later.
    maybe_qual_name = name

    func_proto = None
    func_header = None
    body_line1 = None
    inline_body = None

    if field.seq:  # Array/repeated
      # For size accessor, follow the ref, and then it's the first integer.
      size_header = (
          'inline int %(name)s_size(const %(pointer_type)s* base) const {')
      size_body = "return Ref(base, %(offset)d).Int(0);"

      self.Emit(size_header % locals(), depth)
      self.Emit(size_body % locals(), depth + 1)
      self.Emit("}", depth)

      # Elements follow the size; each one is a 3-byte integer (the width
      # decoded by Obj::Int()).
      ARRAY_OFFSET = 'int a = (index+1) * 3;'
      A_POINTER = (
          'inline const %(ctype)s %(maybe_qual_name)s('
          'const %(pointer_type)s* base, int index) const')

      if ctype in ('bool', 'int'):
        func_header = A_POINTER + ' {'
        body_line1 = ARRAY_OFFSET
        inline_body = 'return Ref(base, %(offset)d).Int(a);'

      elif ctype.endswith('_e') or ctype in self.enum_types:
        func_header = A_POINTER + ' {'
        body_line1 = ARRAY_OFFSET
        inline_body = (
            'return static_cast<const %(ctype)s>(Ref(base, %(offset)d).Int(a));')

      elif ctype == 'char*':
        func_header = A_POINTER + ' {'
        body_line1 = ARRAY_OFFSET
        inline_body = 'return Ref(base, %(offset)d).Str(base, a);'

      else:
        # Write function prototype now; write body later.
        func_proto = A_POINTER + ';'

        maybe_qual_name = '%s::%s' % (type_name, name)
        func_def = A_POINTER + ' {'
        # This static_cast<> (downcast) causes problems if put within "class
        # {}".
        func_body = (
            'return static_cast<const %(ctype)s>('
            'Ref(base, %(offset)d).Ref(base, a));')

        self.footer.extend(visitor.FormatLines(func_def % locals(), 0))
        self.footer.extend(visitor.FormatLines(ARRAY_OFFSET, 1))
        self.footer.extend(visitor.FormatLines(func_body % locals(), 1))
        self.footer.append('}\n\n')
        maybe_qual_name = name  # RESET for later

    else:  # not repeated
      SIMPLE = "inline %(ctype)s %(maybe_qual_name)s() const {"
      POINTER = (
          'inline const %(ctype)s %(maybe_qual_name)s('
          'const %(pointer_type)s* base) const')

      if ctype in ('bool', 'int'):
        func_header = SIMPLE
        inline_body = 'return Int(%(offset)d);'

      elif ctype.endswith('_e') or ctype in self.enum_types:
        func_header = SIMPLE
        inline_body = 'return static_cast<const %(ctype)s>(Int(%(offset)d));'

      elif ctype == 'char*':
        func_header = POINTER + " {"
        inline_body = 'return Str(base, %(offset)d);'

      else:
        # Write function prototype now; write body later.
        func_proto = POINTER + ";"

        maybe_qual_name = '%s::%s' % (type_name, name)
        func_def = POINTER + ' {'
        if field.opt:
          func_body = (
              'return static_cast<const %(ctype)s>(Optional(base, %(offset)d));')
        else:
          func_body = (
              'return static_cast<const %(ctype)s>(Ref(base, %(offset)d));')

        # depth 0 for bodies
        self.footer.extend(visitor.FormatLines(func_def % locals(), 0))
        self.footer.extend(visitor.FormatLines(func_body % locals(), 1))
        self.footer.append('}\n\n')
        maybe_qual_name = name  # RESET for later

    if func_proto:
      self.Emit(func_proto % locals(), depth)
    else:
      self.Emit(func_header % locals(), depth)
      if body_line1:
        self.Emit(body_line1, depth + 1)
      self.Emit(inline_body % locals(), depth + 1)
      self.Emit("}", depth)
 | 
| 289 | 
 | 
| 290 | 
 | 
| 291 | # Used by osh/ast_gen.py
 | 
class CEnumVisitor(visitor.AsdlVisitor):
  """Emit C `#define` constants for the variants of each simple sum type."""

  def VisitSimpleSum(self, sum, name, depth):
    # Just use #define, since enums aren't namespaced.  Numbering starts at 1,
    # the same as the C++ enums, where zero is reserved.
    number = 0
    for variant in sum.types:
      number += 1
      self.Emit('#define %s__%s %d' % (name, variant.name, number), depth)
    self.Emit("", depth)
 | 
| 300 | 
 | 
def main(argv):
  """Command-line entry point.

  Usage: asdl_cpp.py cpp SCHEMA_PATH

  The 'cpp' action loads the ASDL schema at SCHEMA_PATH and writes the
  generated C++ code to stdout.

  Raises:
    RuntimeError: if the action or the schema path is missing, or the action
      is unknown.
  """
  try:
    action = argv[1]
  except IndexError:
    raise RuntimeError('Action required')

  # TODO: Also generate a switch/static_cast<> pretty printer in C++!  For
  # debugging.  Might need to detect cycles though.
  if action == 'cpp':
    try:
      schema_path = argv[2]
    except IndexError:
      # Report a missing argument the same way as a missing action, rather
      # than letting a bare IndexError escape.
      raise RuntimeError('Schema path required')

    app_types = {'id': asdl.UserType(Id)}
    with open(schema_path) as input_f:
      module, type_lookup = asdl.LoadSchema(input_f, app_types)

    # TODO: gen_cpp.py should be a library and the application should add Id?
    # Or we should enable ASDL metaprogramming, and let Id be a metaprogrammed
    # simple sum type.

    f = sys.stdout

    # How do mutation of strings, arrays, etc.  work?  Are they like C++
    # containers, or their own?  I think they mirror the oil language
    # semantics.
    # Every node should have a mirror.  MutableObj.  MutableRef (pointer).
    # MutableArithVar -- has std::string.  The mirrors are heap allocated.
    # All the mutable ones should support Dump()/Encode()?
    # You can just write more at the end... don't need to disturb existing
    # nodes?  Rewrite pointers.

    alignment = 4
    enc = encode.Params(alignment)
    d = {'pointer_type': enc.pointer_type}

    f.write("""\
#include <cstdint>

class Obj {
 public:
  // Decode a 3 byte integer from little endian
  inline int Int(int n) const;

  inline const Obj& Ref(const %(pointer_type)s* base, int n) const;

  inline const Obj* Optional(const %(pointer_type)s* base, int n) const;

  // NUL-terminated
  inline const char* Str(const %(pointer_type)s* base, int n) const;

 protected:
  uint8_t bytes_[1];  // first is ID; rest are a payload
};

""" % d)

    # Id should be treated as an enum.
    c = ChainOfVisitors(
        ForwardDeclareVisitor(f),
        ClassDefVisitor(f, enc, type_lookup, enum_types=['Id']))
    c.VisitModule(module)

    # uint32_t* and char*/Obj* aren't related, so we need to use
    # reinterpret_cast<>.
    # http://stackoverflow.com/questions/10151834/why-cant-i-static-cast-between-char-and-unsigned-char
    f.write("""\
inline int Obj::Int(int n) const {
  return bytes_[n] + (bytes_[n+1] << 8) + (bytes_[n+2] << 16);
}

inline const Obj& Obj::Ref(const %(pointer_type)s* base, int n) const {
  int offset = Int(n);
  return reinterpret_cast<const Obj&>(base[offset]);
}

inline const Obj* Obj::Optional(const %(pointer_type)s* base, int n) const {
  int offset = Int(n);
  if (offset) {
    return reinterpret_cast<const Obj*>(base + offset);
  } else {
    return nullptr;
  }
}

inline const char* Obj::Str(const %(pointer_type)s* base, int n) const {
  int offset = Int(n);
  return reinterpret_cast<const char*>(base + offset);
}
""" % d)

  else:
    raise RuntimeError('Invalid action %r' % action)
 | 
| 392 | 
 | 
| 393 | 
 | 
if __name__ == '__main__':
  # main() signals usage problems with RuntimeError; report it as a single
  # FATAL line on stderr and exit with status 1 instead of a traceback.
  try:
    main(sys.argv)
  except RuntimeError as err:
    print('FATAL: %s' % err, file=sys.stderr)
    sys.exit(1)
 |