| 1 | #!/usr/bin/env python2
 | 
| 2 | """
 | 
| 3 | classes.py - Test out inheritance.
 | 
| 4 | """
 | 
| 5 | from __future__ import print_function
 | 
| 6 | 
 | 
| 7 | import cStringIO
 | 
| 8 | import os
 | 
| 9 | import sys
 | 
| 10 | 
 | 
| 11 | from mycpp import mylib
 | 
| 12 | from mycpp.mylib import log
 | 
| 13 | 
 | 
| 14 | from typing import IO, cast
 | 
| 15 | 
 | 
| 16 | # Based on asdl/format.py
 | 
| 17 | 
 | 
| 18 | 
 | 
class ColorOutput(object):
    """Base class for plain-text, ANSI-color, and HTML-color output.

    Forwards text to a writer and keeps a running character count in
    num_chars.
    """

    def __init__(self, f):
        # type: (mylib.Writer) -> None
        self.f = f  # destination writer
        self.num_chars = 0  # characters passed through write() so far

    def write(self, s):
        # type: (str) -> None
        """Send s to the writer, then add its length to num_chars."""
        self.f.write(s)
        # Subclasses that emit markup should count only visible characters.
        self.num_chars = self.num_chars + len(s)
 | 
| 31 | 
 | 
| 32 | 
 | 
class TextOutput(ColorOutput):
    """Plain-text output.

    Obeys the ColorOutput interface but adds no color codes: write() is
    inherited unchanged, so text goes straight to the underlying writer.
    """

    def __init__(self, f):
        # type: (mylib.Writer) -> None
        """
    This docstring used to interfere with __init__ detection
    """
        # The docstring above is deliberately placed after the type comment:
        # per its own text, this layout used to confuse __init__ detection,
        # so it is kept as a check.
        # Note: translated into an initializer list.
        ColorOutput.__init__(self, f)
        print('TextOutput constructor')
        self.i = 0  # field only in derived class

    def MutateFields(self):
        # type: () -> None
        """Overwrite one inherited field and one derived field."""
        self.num_chars = 42
        self.i = 43

    def PrintFields(self):
        # type: () -> None
        """Print the inherited field, then the derived field."""
        print("num_chars = %d" % self.num_chars)  # field from base
        print("i = %d" % self.i)  # field from derived
 | 
| 55 | 
 | 
| 56 | 
 | 
| 57 | #
 | 
| 58 | # Heterogeneous linked list to test field masks, inheritance, virtual dispatch,
 | 
| 59 | # constructors, etc.
 | 
| 60 | #
 | 
| 61 | 
 | 
| 62 | 
 | 
class Abstract(object):
    """Root of the heterogeneous node hierarchy; it holds no fields."""

    def __init__(self):
        # type: () -> None
        # mycpp requires an explicit constructor, even an empty one.
        pass

    def TypeString(self):
        # type: () -> str
        """Describe this node. Subclasses override it; this version raises."""
        # TODO: could be translated to TypeString() = 0; in C++
        raise NotImplementedError()
 | 
| 75 | 
 | 
| 76 | 
 | 
class Base(Abstract):
    """Concrete node that carries only a link to the following node."""

    def __init__(self, n):
        # type: (Base) -> None
        Abstract.__init__(self)
        self.next = n  # following node, or None at the tail

    def TypeString(self):
        # type: () -> str
        """Return 'Base(next)' when a node is linked, else 'Base(null)'."""
        if self.next:
            link = 'next'
        else:
            link = 'null'
        return "Base(%s)" % link
 | 
| 87 | 
 | 
| 88 | 
 | 
class DerivedI(Base):
    """Node that carries an integer payload in addition to its link."""

    def __init__(self, n, i):
        # type: (Base, int) -> None
        Base.__init__(self, n)
        self.i = i  # integer payload

    def Integer(self):
        # type: () -> int
        """Return the integer payload."""
        return self.i

    def TypeString(self):
        # type: () -> str
        """Return e.g. 'DerivedI(null, 1)'."""
        if self.next:
            link = 'next'
        else:
            link = 'null'
        return "DerivedI(%s, %d)" % (link, self.i)
 | 
| 103 | 
 | 
| 104 | 
 | 
class DerivedSS(Base):
    """Node that carries two string payloads in addition to its link."""

    def __init__(self, n, t, u):
        # type: (Base, str, str) -> None
        Base.__init__(self, n)
        self.t = t  # first string payload
        self.u = u  # second string payload

    def TypeString(self):
        # type: () -> str
        """Return e.g. 'DerivedSS(null, left, right)'."""
        if self.next:
            link = 'next'
        else:
            link = 'null'
        return "DerivedSS(%s, %s, %s)" % (link, self.t, self.u)
 | 
| 117 | 
 | 
| 118 | 
 | 
| 119 | #
 | 
| 120 | # Homogeneous Node
 | 
| 121 | #
 | 
| 122 | 
 | 
| 123 | 
 | 
class Node(object):
    """Plain list node with no methods besides __init__: no vtable pointer."""

    def __init__(self, n, i):
        # type: (Node, int) -> None
        self.next = n  # following node, or None at the tail
        self.i = i  # integer payload
 | 
| 131 | 
 | 
| 132 | 
 | 
def TestMethods():
    # type: () -> None
    """Exercise TextOutput: inherited write(), then field mutation/printing."""
    writer = mylib.Stdout()
    out = TextOutput(writer)
    for line in ['foo\n', 'bar\n']:
        out.write(line)
    log('Wrote %d bytes', out.num_chars)

    out.MutateFields()
    out.PrintFields()
 | 
| 144 | 
 | 
| 145 | 
 | 
def f(obj):
    # type: (Base) -> str
    """Call TypeString() through a Base-typed parameter.

    Whatever class obj actually is, its own TypeString() is what runs.
    """
    description = obj.TypeString()
    return description
 | 
| 149 | 
 | 
| 150 | 
 | 
def TestInheritance():
    # type: () -> None
    """Log TypeString() for each node class, called directly and via f()."""
    base = Base(None)
    derived_i = DerivedI(None, 1)
    derived_ss = DerivedSS(None, 'left', 'right')

    log('Integer() = %d', derived_i.Integer())

    # Direct calls on each concrete type.
    log("b.TypeString()   %s", base.TypeString())
    log("di.TypeString()  %s", derived_i.TypeString())
    log("dss.TypeString() %s", derived_ss.TypeString())

    # The same calls, made through a Base-typed parameter.
    log("f(b)           %s", f(base))
    log("f(di)          %s", f(derived_i))
    log("f(dss)         %s", f(derived_ss))
 | 
| 167 | 
 | 
| 168 | 
 | 
def run_tests():
    # type: () -> None
    """Run the method tests first, then the inheritance tests."""
    TestMethods()
    TestInheritance()
 | 
| 173 | 
 | 
| 174 | 
 | 
def BenchmarkWriter(n):
    # type: (int) -> None
    """Write the line 'foo' n times through TextOutput into a BufWriter.

    Logs the iteration count and the number of characters written.

    Args:
      n: number of writes to perform (n <= 0 performs none).
    """
    log('BenchmarkWriter')
    log('')

    # Named buf rather than f so it does not shadow the module-level f().
    buf = mylib.BufWriter()
    out = TextOutput(buf)

    for i in xrange(n):
        out.write('foo\n')
    log('  Ran %d iterations', n)
    log('  Wrote %d bytes', out.num_chars)
    log('')
 | 
| 191 | 
 | 
| 192 | 
 | 
def PrintLength(node):
    # type: (Node) -> None
    """Log the payloads of the first 10 nodes, then the number of nodes.

    Args:
      node: head of the list, or None for an empty list.
    """
    current = node
    linked_list_len = 0
    # Count nodes, not links: the old loop broke out before counting the
    # tail node, so a 1-node list reported length 0. It also crashed on a
    # None head.
    while current is not None:
        if linked_list_len < 10:  # only echo a short prefix of long lists
            log('  -> %d', current.i)
        linked_list_len += 1
        current = current.next

    log('')
    log("  linked list len = %d", linked_list_len)
    log('')
 | 
| 212 | 
 | 
| 213 | 
 | 
def BenchmarkSimpleNode(n):
    # type: (int) -> None
    """Build a Node list of n nodes on top of a -1 tail node, then print it.

    Args:
      n: number of nodes to prepend; with n <= 0 only the tail node exists.
    """
    log('BenchmarkSimpleNode')
    log('')

    head = Node(None, -1)
    for i in xrange(n):
        head = Node(head, i)  # prepend: the newest node becomes the head

    # Pass head, which is always bound. The old code passed a loop-local
    # that was unbound (UnboundLocalError) when the loop ran zero times.
    PrintLength(head)
 | 
| 226 | 
 | 
| 227 | 
 | 
def PrintLengthBase(current):
    # type: (Base) -> None
    """Log TypeString() of the first 10 nodes, then the number of nodes.

    Args:
      current: head of the list, or None for an empty list. Calls go through
        the Base type, so each node's own TypeString() override is used.
    """
    linked_list_len = 0
    # Count nodes, not links: the old loop broke out before counting the
    # tail node, so a 1-node list reported length 0. It also crashed on a
    # None head.
    while current is not None:
        if linked_list_len < 10:  # only echo a short prefix of long lists
            log('  -> %s', current.TypeString())
        linked_list_len += 1
        current = current.next

    log('')
    log("  linked list len = %d", linked_list_len)
    log('')
 | 
| 245 | 
 | 
| 246 | 
 | 
def BenchmarkVirtualNodes(n):
    # type: (int) -> None
    """Build a heterogeneous list of nodes with virtual methods, then print it.

    Each iteration prepends a DerivedI, a DerivedSS, and a Base node, so the
    list holds 3 * n nodes on top of the initial Base(None) tail.

    Args:
      n: number of iterations; with n <= 0 only the tail node exists.
    """
    log('BenchmarkVirtualNodes')
    log('')

    head = Base(None)  # type: Base
    for i in xrange(n):
        node1 = DerivedI(head, i)

        # Allocate some children
        s1 = str(i)
        s2 = '+%d' % i
        node2 = DerivedSS(node1, s1, s2)

        head = Base(node2)

    # head is declared Base and is bound even when the loop runs zero times.
    # The old code passed the loop-local node3, which was unbound for n <= 0.
    PrintLengthBase(head)
 | 
| 271 | 
 | 
| 272 | 
 | 
def run_benchmarks():
    # type: () -> None
    """Run the writer, simple-node, and virtual-node benchmarks in order.

    To skip a benchmark, comment out its call.
    """
    # NOTE: Raising this exposes quadratic behavior
    #  30,000 iterations:  1.4 seconds in cxx-opt mode
    #  60,000 iterations:  5.0 seconds in cxx-opt mode
    BenchmarkWriter(30000)

    BenchmarkSimpleNode(10000)

    # Hits Collect() and ASAN finds bugs above 500 and before 1000
    # NOTE(review): this note originally sat above a commented-out
    # BenchmarkNodes(750) call; no BenchmarkNodes is defined in this file.
    # Confirm whether the note still applies.
    BenchmarkVirtualNodes(1000)
 | 
| 289 | 
 | 
| 290 | 
 | 
if __name__ == '__main__':
    # Set BENCHMARK in the environment to run the benchmarks instead of the
    # tests.
    if not os.getenv('BENCHMARK'):
        run_tests()
    else:
        log('Benchmarking...')
        run_benchmarks()
 | 
| 297 | 
 | 
| 298 | # vim: sw=2
 |