OILS / opy / _regtest / src / opy / compiler2 / pyassem.py View on Github | oilshell.org

769 lines, 563 significant
1"""A flow graph representation for Python bytecode"""
2from __future__ import print_function
3
4import dis
5import types
6import sys
7
8from . import misc
9from .consts import CO_OPTIMIZED, CO_NEWLOCALS, CO_VARARGS, CO_VARKEYWORDS
10
class FlowGraph:
    """A graph of basic blocks that is filled in one instruction at a time.

    Attributes set by the constructor:
      entry   -- the first block of the graph
      exit    -- a block labelled "exit"
      current -- the block that emit() currently appends to
      blocks  -- a misc.Set holding every block created through the graph
    """

    # Tracing is off by default.  _enable_debug() / _disable_debug() shadow
    # this class-level value on the individual instance.
    _debug = 0

    def __init__(self):
        # The entry block is created before the exit block, so it gets the
        # lower block id.
        first = Block()
        self.entry = first
        self.current = first
        self.exit = Block("exit")
        self.blocks = misc.Set()
        for known in (self.entry, self.exit):
            self.blocks.add(known)

    def _enable_debug(self):
        """Turn on tracing output for this graph instance."""
        self._debug = 1

    def _disable_debug(self):
        """Turn off tracing output for this graph instance."""
        self._debug = 0

    def startBlock(self, block):
        """Make `block` the current block; no edge is added to the graph."""
        if self._debug:
            finished = self.current
            if finished:
                print("end", repr(finished))
                print(" next", finished.next)
                print(" prev", finished.prev)
                print(" ", finished.get_children())
            print(repr(block))
        self.current = block

    def nextBlock(self, block=None):
        """Link the current block to `block` as its "next" and switch to it.

        A fresh block is created when `block` is None.
        """
        # XXX think we need to specify when there is implicit transfer
        # from one block to the next.  It might be better to represent this
        # with explicit JUMP_ABSOLUTE instructions that are optimized
        # out when they are unnecessary.
        #
        # The strategy used here: every block has one child designated as
        # "next", which get_children() returns last.  The emission order is
        # expected to place the "next" block immediately after its parent.
        # Worry: maintaining this invariant could be tricky.
        target = self.newBlock() if block is None else block

        # Note: if the current block ends with an unconditional control
        # transfer, adding an implicit transfer is technically incorrect and
        # leads to code being generated for unreachable blocks.  That does
        # not appear to be common in Python code, and since the built-in
        # compiler doesn't optimize it out we don't either.
        self.current.addNext(target)
        self.startBlock(target)

    def newBlock(self):
        """Create a block, record it in self.blocks and return it."""
        fresh = Block()
        self.blocks.add(fresh)
        return fresh

    def startExitBlock(self):
        """Make the exit block the current block."""
        self.startBlock(self.exit)

    def emit(self, *inst):
        """Append the instruction tuple `inst` to the current block.

        When the instruction is an (opname, Block) pair, the target block is
        first registered as an out-edge of the current block.
        """
        if self._debug:
            print("\t", inst)
        if len(inst) == 2:
            operand = inst[1]
            if isinstance(operand, Block):
                self.current.addOutEdge(operand)
        self.current.emit(inst)

    def getBlocksInOrder(self):
        """Return the blocks in emission order.

        The ordering is delegated to order_blocks(), called with the entry
        and exit blocks of this graph.
        """
        return order_blocks(self.entry, self.exit)

    def getBlocks(self):
        """Return the elements of the set of all blocks."""
        return self.blocks.elements()

    def getRoot(self):
        """Return nodes appropriate for use with dominator"""
        return self.entry

    def getContainedGraphs(self):
        """Return a flat list of the graphs contained in any block."""
        return [graph
                for block in self.getBlocks()
                for graph in block.getContainedGraphs()]
96
97
def order_blocks(start_block, exit_block):
    """Return the list of blocks in the order they must be emitted.

    Rules:
      - when a block has a next block, the next block is emitted right
        after it;
      - when a block has followers (relative jump targets), it is emitted
        before them;
      - every block reachable from start_block is emitted.

    exit_block is additionally appended after any chain whose last block is
    not exit_block and does not report an unconditional transfer, so it can
    appear in the result more than once.
    """
    order = []

    # Step 1: collect every block reachable from the start block.
    remaining = set()
    pending = [start_block]
    while pending:
        blk = pending.pop()
        if blk in remaining:
            continue
        remaining.add(blk)
        pending.extend(child for child in blk.get_children()
                       if child not in remaining)

    # Step 2: record ordering constraints.  dominators[x] is the set of
    # blocks that have to be emitted before x.
    dominators = {}
    for blk in remaining:
        if __debug__ and blk.next:
            assert blk is blk.next[0].prev[0], (blk, blk.next)
        # Every block gets an entry, even when nothing must precede it.
        dominators.setdefault(blk, set())
        for follower in blk.get_followers():
            # A follower is emitted as part of its whole "next" chain, so
            # walk the chain backwards and constrain each member as well.
            while True:
                dominators.setdefault(follower, set()).add(blk)
                if not follower.prev or follower.prev[0] is blk:
                    break
                follower = follower.prev[0]

    def find_next():
        # Pick a remaining block none of whose dominators is still waiting.
        for candidate in remaining:
            if not any(dom in remaining for dom in dominators[candidate]):
                return candidate
        assert 0, 'circular dependency, cannot find next block'

    # Step 3: emit chains of "next" blocks, picking a new chain head each
    # time the current chain runs out.
    blk = start_block
    while True:
        order.append(blk)
        remaining.discard(blk)
        if blk.next:
            blk = blk.next[0]
            continue
        if blk is not exit_block and not blk.has_unconditional_transfer():
            order.append(exit_block)
        if not remaining:
            break
        blk = find_next()
    return order
163
164
class Block:
    """A basic block: a list of instruction tuples plus graph links.

    Attributes:
      insts    -- instruction tuples, in emission order
      outEdges -- set of blocks registered through addOutEdge()
      label    -- optional text label
      bid      -- sequence number taken from the class-wide counter
      next     -- list holding at most one fall-through successor
      prev     -- list holding at most one fall-through predecessor
    """

    # Class-wide counter that hands out block ids.
    _count = 0

    # Opcodes that has_unconditional_transfer() treats as ending a block.
    _uncond_transfer = ('RETURN_VALUE', 'RAISE_VARARGS',
                        'JUMP_ABSOLUTE', 'JUMP_FORWARD', 'CONTINUE_LOOP',
                        )

    def __init__(self, label=''):
        self.insts = []
        self.outEdges = set()
        self.label = label
        self.bid = Block._count
        self.next = []
        self.prev = []
        Block._count = Block._count + 1

    # BUG FIX: This is needed for deterministic order in sets (and dicts?).
    # order_blocks() keeps a set() of blocks.  If we rely on the default
    # id(), then the output bytecode is NONDETERMINISTIC.
    def __hash__(self):
        return self.bid

    def __repr__(self):
        if not self.label:
            return "<block id=%d>" % (self.bid)
        return "<block %s id=%d>" % (self.label, self.bid)

    def __str__(self):
        listing = '\n'.join(str(inst) for inst in self.insts)
        return "<block %s %d:\n%s>" % (self.label, self.bid, listing)

    def emit(self, inst):
        """Append the instruction tuple `inst` to this block."""
        # The value is unused; the lookup still rejects an empty instruction.
        opname = inst[0]
        self.insts.append(inst)

    def getInstructions(self):
        """Return the (live) list of instruction tuples."""
        return self.insts

    def addOutEdge(self, block):
        """Record `block` as a jump target of this block."""
        self.outEdges.add(block)

    def addNext(self, block):
        """Make `block` the fall-through successor of this block.

        Asserts that this block has no other successor and that `block`
        has no other predecessor.
        """
        self.next.append(block)
        assert len(self.next) == 1, [str(b) for b in self.next]
        block.prev.append(self)
        assert len(block.prev) == 1, [str(b) for b in block.prev]

    def has_unconditional_transfer(self):
        """Tell whether the block ends with an unconditional transfer.

        Returns True or False when the last instruction is an
        (opname, arg) pair, depending on whether opname is listed in
        _uncond_transfer.  Returns None when the block is empty or when
        the last instruction is not a 2-tuple (for example an opcode
        emitted without an argument).
        """
        try:
            opname, _arg = self.insts[-1]
        except (IndexError, ValueError):
            return None
        return opname in self._uncond_transfer

    def get_children(self):
        """Return the out-edge blocks followed by the "next" block, if any."""
        return list(self.outEdges) + self.next

    def get_followers(self):
        """Get the whole list of followers, including the next block."""
        followers = set(self.next)
        # Blocks that must be emitted *after* this one, because of
        # bytecode offsets (e.g. relative jumps) pointing to them.
        followers.update(inst[1] for inst in self.insts
                         if inst[0] in PyFlowGraph.hasjrel)
        return followers

    def getContainedGraphs(self):
        """Return all graphs contained within this block.

        An instruction contributes a graph when its argument has a
        `graph` attribute; for example, a MAKE_FUNCTION block will contain
        a reference to the graph for the function body.
        """
        graphs = []
        for inst in self.insts:
            if len(inst) != 1:
                operand = inst[1]
                if hasattr(operand, 'graph'):
                    graphs.append(operand.graph)
        return graphs
251
252# flags for code objects
253
# The flow graph is transformed in place; PyFlowGraph.stage is always one of
# these states.  getCode() asserts that they are reached in the order
# RAW -> FLAT -> CONV -> DONE.
RAW = "RAW"
FLAT = "FLAT"
CONV = "CONV"
DONE = "DONE"


class PyFlowGraph(FlowGraph):
    """A FlowGraph that can be assembled into a Python code object.

    getCode() drives the whole pipeline: compute the stack depth, flatten
    the blocks into one instruction list with resolved jump targets,
    convert symbolic arguments into indexes, encode the byte string, and
    finally build the code object.
    """
    super_init = FlowGraph.__init__

    def __init__(self, name, filename, args=(), optimized=0, klass=None):
        """Create an empty graph.

        name      -- name that is put in the code object
        filename  -- file name that is put in the code object
        args      -- argument names; TupleArg entries stand for nested
                     tuple arguments
        optimized -- when true, start with CO_OPTIMIZED | CO_NEWLOCALS
        klass     -- stored as self.klass; the name converters only add
                     names to varnames when it is None
        """
        self.super_init()
        self.name = name
        self.filename = filename
        self.docstring = None
        self.args = args  # XXX
        self.argcount = getArgCount(args)
        self.klass = klass
        if optimized:
            self.flags = CO_OPTIMIZED | CO_NEWLOCALS
        else:
            self.flags = 0
        self.consts = []
        self.names = []
        # Free variables found by the symbol table scan, including
        # variables used only in nested scopes, are included here.
        self.freevars = []
        self.cellvars = []
        # The closure list is used to track the order of cell
        # variables and free variables in the resulting code object.
        # The offsets used by LOAD_CLOSURE/LOAD_DEREF refer to both
        # kinds of variables.
        self.closure = []
        # A TupleArg is represented in varnames by its generated name.
        self.varnames = [var.getName() if isinstance(var, TupleArg) else var
                         for var in args]
        self.stage = RAW

    def setDocstring(self, doc):
        """Remember the docstring; convertArgs() makes it constant 0."""
        self.docstring = doc

    def setFlag(self, flag):
        """Set `flag` in the code flags.

        Setting CO_VARARGS also lowers argcount by one.
        """
        self.flags = self.flags | flag
        if flag == CO_VARARGS:
            self.argcount = self.argcount - 1

    def checkFlag(self, flag):
        """Return 1 when any bit of `flag` is set, otherwise None."""
        if self.flags & flag:
            return 1

    def setFreeVars(self, names):
        """Store a list copy of `names` as the free variables."""
        self.freevars = list(names)

    def setCellVars(self, names):
        """Store `names` (not copied) as the cell variables."""
        self.cellvars = names

    def getCode(self):
        """Get a Python code object"""
        assert self.stage == RAW
        self.computeStackDepth()
        self.flattenGraph()
        assert self.stage == FLAT
        self.convertArgs()
        assert self.stage == CONV
        self.makeByteCode()
        assert self.stage == DONE
        return self.newCodeObject()

    def dump(self, io=None):
        """Print the flattened instruction list with a running offset.

        io -- optional file-like object; when it is true the listing is
              written to it, otherwise to sys.stdout.

        Reads self.insts, which flattenGraph() creates.  Output is sent
        with print(file=...), so sys.stdout is never reassigned and cannot
        be left redirected if printing raises.
        """
        out = io if io else sys.stdout
        pc = 0
        for t in self.insts:
            opname = t[0]
            if opname == "SET_LINENO":
                print(file=out)
            if len(t) == 1:
                print("\t", "%3d" % pc, opname, file=out)
                pc = pc + 1
            else:
                # NOTE(review): SET_LINENO is counted as three bytes here,
                # whereas flattenGraph() gives it no size at all.
                print("\t", "%3d" % pc, opname, t[1], file=out)
                pc = pc + 3

    def computeStackDepth(self):
        """Compute the max stack depth and store it in self.stacksize.

        Approach is to compute the stack effect of each basic block.
        Then find the path through the code with the largest total
        effect.
        """
        depth = {}
        for b in self.getBlocks():
            depth[b] = findDepth(b.getInstructions())

        seen = {}

        def max_depth(b, d):
            # A block that was already visited contributes nothing more.
            if b in seen:
                return d
            seen[b] = 1
            d = d + depth[b]
            children = b.get_children()
            if children:
                return max([max_depth(c, d) for c in children])
            # A block without children falls through to the exit block.
            if not b.label == "exit":
                return max_depth(self.exit, d)
            return d

        self.stacksize = max_depth(self.entry, 0)

    def flattenGraph(self):
        """Arrange the blocks in order and resolve jumps"""
        assert self.stage == RAW
        self.insts = insts = []

        # First pass: concatenate the blocks and note where each one starts.
        pc = 0
        begin = {}
        for b in self.getBlocksInOrder():
            begin[b] = pc
            for inst in b.getInstructions():
                insts.append(inst)
                if len(inst) == 1:
                    pc = pc + 1
                elif inst[0] != "SET_LINENO":
                    # arg takes 2 bytes
                    pc = pc + 3

        # Second pass: replace block operands of jumps by offsets.
        pc = 0
        for i, inst in enumerate(insts):
            if len(inst) == 1:
                pc = pc + 1
            elif inst[0] != "SET_LINENO":
                pc = pc + 3
            opname = inst[0]
            if opname in self.hasjrel:
                # pc already points past this instruction, so the offset
                # is relative to the following instruction.
                insts[i] = opname, begin[inst[1]] - pc
            elif opname in self.hasjabs:
                insts[i] = opname, begin[inst[1]]
        self.stage = FLAT

    # Names of the opcodes with relative / absolute jump targets.
    hasjrel = set()
    for i in dis.hasjrel:
        hasjrel.add(dis.opname[i])
    hasjabs = set()
    for i in dis.hasjabs:
        hasjabs.add(dis.opname[i])

    def convertArgs(self):
        """Convert arguments from symbolic to concrete form"""
        assert self.stage == FLAT
        self.consts.insert(0, self.docstring)
        self.sort_cellvars()
        for i, t in enumerate(self.insts):
            if len(t) == 2:
                opname, oparg = t
                conv = self._converters.get(opname, None)
                if conv:
                    self.insts[i] = opname, conv(self, oparg)
        self.stage = CONV

    def sort_cellvars(self):
        """Sort cellvars in the order of varnames and build self.closure.

        Cell variables that are also in varnames come first, in varnames
        order; the others follow.  closure is cellvars followed by
        freevars.
        """
        cells = {}
        for name in self.cellvars:
            cells[name] = 1
        self.cellvars = [name for name in self.varnames
                         if name in cells]
        for name in self.cellvars:
            del cells[name]
        self.cellvars = self.cellvars + list(cells)
        self.closure = self.cellvars + self.freevars

    def _lookupName(self, name, L):
        """Return index of name in list, appending if necessary

        This routine uses a list instead of a dictionary, because a
        dictionary can't store two different keys if the keys have the
        same value but different types, e.g. 2 and 2L.  The compiler
        must treat these two separately, so it does an explicit type
        comparison before comparing the values.
        """
        t = type(name)
        for i, item in enumerate(L):
            if t == type(item) and item == name:
                return i
        end = len(L)
        L.append(name)
        return end

    # Maps an opcode name to the function converting its symbolic argument;
    # filled in below from every _convert_* name in the class body.
    _converters = {}

    def _convert_LOAD_CONST(self, arg):
        """Return the index of `arg` in consts, assembling graphs first."""
        if hasattr(arg, 'getCode'):
            arg = arg.getCode()
        return self._lookupName(arg, self.consts)

    def _convert_LOAD_FAST(self, arg):
        """Register `arg` in names and return its index in varnames."""
        self._lookupName(arg, self.names)
        return self._lookupName(arg, self.varnames)
    _convert_STORE_FAST = _convert_LOAD_FAST
    _convert_DELETE_FAST = _convert_LOAD_FAST

    def _convert_NAME(self, arg):
        """Return the index of `arg` in names.

        When klass is None the name is registered in varnames as well.
        """
        if self.klass is None:
            self._lookupName(arg, self.varnames)
        return self._lookupName(arg, self.names)
    _convert_LOAD_NAME = _convert_NAME
    _convert_STORE_NAME = _convert_NAME
    _convert_DELETE_NAME = _convert_NAME
    _convert_IMPORT_NAME = _convert_NAME
    _convert_IMPORT_FROM = _convert_NAME
    _convert_STORE_ATTR = _convert_NAME
    _convert_LOAD_ATTR = _convert_NAME
    _convert_DELETE_ATTR = _convert_NAME
    _convert_LOAD_GLOBAL = _convert_NAME
    _convert_STORE_GLOBAL = _convert_NAME
    _convert_DELETE_GLOBAL = _convert_NAME

    def _convert_DEREF(self, arg):
        """Register `arg` in names and varnames; return its closure index."""
        self._lookupName(arg, self.names)
        self._lookupName(arg, self.varnames)
        return self._lookupName(arg, self.closure)
    _convert_LOAD_DEREF = _convert_DEREF
    _convert_STORE_DEREF = _convert_DEREF

    def _convert_LOAD_CLOSURE(self, arg):
        """Register `arg` in varnames and return its closure index."""
        self._lookupName(arg, self.varnames)
        return self._lookupName(arg, self.closure)

    _cmp = list(dis.cmp_op)

    def _convert_COMPARE_OP(self, arg):
        """Return the position of the operator `arg` in dis.cmp_op."""
        return self._cmp.index(arg)

    # similarly for other opcodes...

    # Register the converters.  The namespace is snapshotted first because
    # the loop variables are themselves added to it while iterating.
    for name, obj in list(locals().items()):
        if name[:9] == "_convert_":
            opname = name[9:]
            _converters[opname] = obj
    del name, obj, opname

    def makeByteCode(self):
        """Encode self.insts into self.lnotab (a LineAddrTable).

        An instruction without argument becomes one byte; one with an
        argument becomes the opcode followed by the low and high argument
        bytes.  SET_LINENO emits no code and only advances the line table.
        """
        assert self.stage == CONV
        self.lnotab = lnotab = LineAddrTable()
        for t in self.insts:
            opname = t[0]
            if len(t) == 1:
                lnotab.addCode(self.opnum[opname])
            else:
                oparg = t[1]
                if opname == "SET_LINENO":
                    lnotab.nextLine(oparg)
                    continue
                hi, lo = twobyte(oparg)
                try:
                    lnotab.addCode(self.opnum[opname], lo, hi)
                except ValueError:
                    # Show the offending instruction, then propagate.
                    print(opname, oparg)
                    print(self.opnum[opname], lo, hi)
                    raise
        self.stage = DONE

    # Maps an opcode name to its number.
    opnum = {}
    for num in range(len(dis.opname)):
        opnum[dis.opname[num]] = num
    del num

    def newCodeObject(self):
        """Build the code object from the assembled pieces.

        nlocals is 0 unless CO_NEWLOCALS is set.  The argcount passed on
        is lowered by one when CO_VARKEYWORDS is set.
        """
        assert self.stage == DONE
        if (self.flags & CO_NEWLOCALS) == 0:
            nlocals = 0
        else:
            nlocals = len(self.varnames)
        argcount = self.argcount
        if self.flags & CO_VARKEYWORDS:
            argcount = argcount - 1
        return types.CodeType(argcount, nlocals, self.stacksize, self.flags,
                              self.lnotab.getCode(), self.getConsts(),
                              tuple(self.names), tuple(self.varnames),
                              self.filename, self.name, self.lnotab.firstline,
                              self.lnotab.getTable(), tuple(self.freevars),
                              tuple(self.cellvars))

    def getConsts(self):
        """Return a tuple for the const slot of the code object

        Must convert references to code (MAKE_FUNCTION) to code
        objects recursively.
        """
        l = []
        for elt in self.consts:
            if isinstance(elt, PyFlowGraph):
                elt = elt.getCode()
            l.append(elt)
        return tuple(l)
564
def isJump(opname):
    """Return 1 when `opname` begins with 'JUMP', otherwise None."""
    return 1 if opname[:4] == 'JUMP' else None
568
class TupleArg:
    """Helper for marking func defs with nested tuples in arglist

    count -- number used to build the generated argument name
    names -- the names unpacked from the tuple argument
    """

    def __init__(self, count, names):
        self.count, self.names = count, names

    def __repr__(self):
        fields = (self.count, self.names)
        return "TupleArg(%s, %s)" % fields

    def getName(self):
        """Return the generated name: a dot followed by the count."""
        return ".%d" % self.count
578
def getArgCount(args):
    """Return len(args) reduced by the names inside each TupleArg.

    For every TupleArg entry, the number of names in its flattened
    `names` is subtracted from the total.
    """
    total = len(args)
    for arg in args:
        if isinstance(arg, TupleArg):
            total = total - len(misc.flatten(arg.names))
    return total
587
def twobyte(val):
    """Convert an int argument into high and low bytes

    Returns the pair (val // 256, val % 256).
    """
    assert isinstance(val, int)
    high = val // 256
    low = val % 256
    return (high, low)
592
class LineAddrTable:
    """lnotab

    This class builds the lnotab, which is documented in compile.c.
    Here's a brief recap:

    For each SET_LINENO instruction after the first one, two bytes are
    added to lnotab.  (In some cases, multiple two-byte entries are
    added.)  The first byte is the distance in bytes between the
    instruction for the last SET_LINENO and the current SET_LINENO.
    The second byte is offset in line numbers.  If either offset is
    greater than 255, multiple two-byte entries are added -- see
    compile.c for the delicate details.
    """

    def __init__(self):
        self.code = []        # one single-character string per code byte
        self.codeOffset = 0   # number of code bytes added so far
        self.firstline = 0    # first line number seen (0: none yet)
        self.lastline = 0     # line number of the last recorded entry
        self.lastoff = 0      # code offset of the last recorded entry
        self.lnotab = []      # alternating address / line deltas

    def addCode(self, *args):
        """Append each integer in `args` as one code byte."""
        for byte in args:
            self.code.append(chr(byte))
        self.codeOffset += len(args)

    def nextLine(self, lineno):
        """Record that the code emitted from here on belongs to `lineno`."""
        if self.firstline == 0:
            # The first line only sets the baseline; no entry is written.
            self.firstline = lineno
            self.lastline = lineno
            return

        addr_delta = self.codeOffset - self.lastoff
        line_delta = lineno - self.lastline
        # Python assumes that lineno always increases with
        # increasing bytecode address (lnotab is unsigned char).
        # Depending on when SET_LINENO instructions are emitted
        # this is not always true.  Consider the code:
        # a = (1,
        # b)
        # In the bytecode stream, the assignment to "a" occurs
        # after the loading of "b".  This works with the C Python
        # compiler because it only generates a SET_LINENO instruction
        # for the assignment.
        if line_delta < 0:
            return

        table = self.lnotab
        # Deltas above 255 are split over several two-byte entries.
        while addr_delta > 255:
            table.extend((255, 0))
            addr_delta -= 255
        while line_delta > 255:
            table.extend((addr_delta, 255))
            line_delta -= 255
            addr_delta = 0
        if addr_delta > 0 or line_delta > 0:
            table.extend((addr_delta, line_delta))
        self.lastline = lineno
        self.lastoff = self.codeOffset

    def getCode(self):
        """Return the code bytes joined into one string."""
        return ''.join(self.code)

    def getTable(self):
        """Return the lnotab entries as a string, one character each."""
        return ''.join(chr(entry) for entry in self.lnotab)
658
class StackDepthTracker:
    """Estimate the stack growth of a straight-line instruction list.

    The effect of an opcode is looked up in three places, in this order:
    the `effect` table, the prefix `patterns`, and finally a method named
    after the opcode, which is called with the instruction's argument.
    An opcode found in none of them leaves the depth unchanged.
    """
    # XXX 1. need to keep track of stack depth on jumps
    # XXX 2. at least partly as a result, this code is broken

    def findDepth(self, insts, debug=0):
        """Return the highest running depth reached over `insts`.

        The running depth starts at 0, and the result is never below 0.
        With `debug` true, each instruction and the depths are printed.
        """
        current = 0
        high_water = 0
        for inst in insts:
            opname = inst[0]
            if debug:
                print(inst, end=' ')
            delta = self.effect.get(opname, None)
            if delta is None:
                # No exact entry: try the prefix patterns.
                for prefix, prefix_delta in self.patterns:
                    if opname.startswith(prefix):
                        delta = prefix_delta
                        break
            if delta is not None:
                current = current + delta
            else:
                # Last resort: a method named after the opcode.
                handler = getattr(self, opname, None)
                if handler is not None:
                    current = current + handler(inst[1])
            if current > high_water:
                high_water = current
            if debug:
                print(current, high_water)
        return high_water

    # Fixed stack effect per opcode.
    effect = {
        'POP_TOP': -1,
        'DUP_TOP': 1,
        'LIST_APPEND': -1,
        'SET_ADD': -1,
        'MAP_ADD': -2,
        'SLICE+1': -1,
        'SLICE+2': -1,
        'SLICE+3': -2,
        'STORE_SLICE+0': -1,
        'STORE_SLICE+1': -2,
        'STORE_SLICE+2': -2,
        'STORE_SLICE+3': -3,
        'DELETE_SLICE+0': -1,
        'DELETE_SLICE+1': -2,
        'DELETE_SLICE+2': -2,
        'DELETE_SLICE+3': -3,
        'STORE_SUBSCR': -3,
        'DELETE_SUBSCR': -2,
        # PRINT_EXPR?
        'PRINT_ITEM': -1,
        'RETURN_VALUE': -1,
        'YIELD_VALUE': -1,
        'EXEC_STMT': -3,
        'BUILD_CLASS': -2,
        'STORE_NAME': -1,
        'STORE_ATTR': -2,
        'DELETE_ATTR': -1,
        'STORE_GLOBAL': -1,
        'BUILD_MAP': 1,
        'COMPARE_OP': -1,
        'STORE_FAST': -1,
        'IMPORT_STAR': -1,
        'IMPORT_NAME': -1,
        'IMPORT_FROM': 1,
        'LOAD_ATTR': 0,  # unlike other loads
        # close enough...
        'SETUP_EXCEPT': 3,
        'SETUP_FINALLY': 3,
        'FOR_ITER': 1,
        'WITH_CLEANUP': -1,
    }

    # Effect by opcode-name prefix, used when `effect` has no entry.
    patterns = [
        ('BINARY_', -1),
        ('LOAD_', 1),
    ]

    # Opcodes whose effect depends on their argument.

    def UNPACK_SEQUENCE(self, count):
        return count - 1

    def BUILD_TUPLE(self, count):
        return 1 - count

    def BUILD_LIST(self, count):
        return 1 - count

    def BUILD_SET(self, count):
        return 1 - count

    def CALL_FUNCTION(self, argc):
        # The low byte is counted once and the high byte twice.
        high, low = argc // 256, argc % 256
        return -(low + high * 2)

    def CALL_FUNCTION_VAR(self, argc):
        return self.CALL_FUNCTION(argc) - 1

    def CALL_FUNCTION_KW(self, argc):
        return self.CALL_FUNCTION(argc) - 1

    def CALL_FUNCTION_VAR_KW(self, argc):
        return self.CALL_FUNCTION(argc) - 2

    def MAKE_FUNCTION(self, argc):
        return -argc

    def MAKE_CLOSURE(self, argc):
        # XXX need to account for free variables too!
        return -argc

    def BUILD_SLICE(self, argc):
        # Any argc other than 2 or 3 yields None.
        if argc == 2:
            return -1
        if argc == 3:
            return -2
        return None

    def DUP_TOPX(self, argc):
        return argc


# Module-level shortcut bound to one shared tracker instance.
findDepth = StackDepthTracker().findDepth