import dis
import struct
import re
from dis import opmap, cmp_op, opname

try:
    # Deprecated Python 2 module; ``new.code`` is the code-object constructor.
    import new
except ImportError:
    # Not available on this interpreter; get_code() falls back to the type
    # of the wrapped code object, which is the same constructor.
    new = None


class Bytecode(object):
    '''
    Class to store individual instruction as a node in the graph

    Attributes:
        opcode    -- numeric opcode (first byte of the instruction)
        oparg     -- 16-bit little-endian argument, or None when the opcode
                     is below dis.HAVE_ARGUMENT
        addr      -- address of the instruction
        prev/next -- neighbouring nodes in the doubly linked list
        xrefs     -- branch nodes that jump to this node
        target    -- node this instruction jumps to, or None
        co_lnotab -- source line number assigned to this instruction
    '''

    def __init__(self, addr, buffer, prev=None, next=None, xrefs=None):
        '''
        Decode one instruction from the start of buffer.

        addr   -- address to record for the instruction
        buffer -- byte string starting at the opcode; at least 3 bytes are
                  needed when the opcode takes an argument
        prev   -- previous node, if any
        next   -- next node, if any
        xrefs  -- optional iterable of nodes that branch here (copied)

        Raises IndexError when buffer is too short for the instruction.
        '''
        # bytearray yields integers for every byte-string flavour, so no
        # per-character ord() is needed.
        raw = bytearray(buffer)
        self.opcode = raw[0]
        self.addr = addr
        if self.opcode >= dis.HAVE_ARGUMENT:
            self.oparg = raw[1] | (raw[2] << 8)
        else:
            self.oparg = None
        self.prev = prev
        self.next = next
        # A fresh list per instance: never share a default list object.
        self.xrefs = [] if xrefs is None else list(xrefs)
        self.target = None
        self.co_lnotab = None

    def len(self):
        '''
        Returns the length of the bytecode
        1 for no argument
        3 for argument
        '''
        if self.opcode < dis.HAVE_ARGUMENT:
            return 1
        else:
            return 3

    def disassemble(self):
        '''
        Return disassembly of bytecode: the opcode name padded to 20
        columns, followed by the argument in hex when there is one.
        '''
        rvalue = opname[self.opcode].ljust(20)
        if self.opcode >= dis.HAVE_ARGUMENT:
            rvalue += " %04x" % (self.oparg)
        return rvalue

    def hex(self):
        '''
        Return ASCII hex representation of bytecode (opcode byte followed
        by the argument in little-endian byte order).
        '''
        rvalue = "%02x" % self.opcode
        if self.opcode >= dis.HAVE_ARGUMENT:
            rvalue += "%02x%02x" % \
                (self.oparg & 0xff, (self.oparg >> 8) & 0xff)
        return rvalue

    def bin(self):
        '''
        Return bytecode string
        '''
        if self.opcode >= dis.HAVE_ARGUMENT:
            return struct.pack("<BH", self.opcode, self.oparg)
        else:
            return struct.pack("<B", self.opcode)

    def get_target_addr(self):
        '''
        Returns the target address for the current instruction based on the
        current address, or None when the instruction is not a jump.
        '''
        rvalue = None
        if self.opcode in dis.hasjrel:
            # Relative jumps are measured from the end of this instruction.
            rvalue = self.addr + self.oparg + self.len()
        if self.opcode in dis.hasjabs:
            rvalue = self.oparg
        return rvalue


class BytecodeGraph(object):
    '''
    Doubly linked list of Bytecode nodes built from a code object, with
    jump instructions linked to their destination nodes.

    self.bytecodes maps address -> node and self.head is the first node.

    NOTE(review): node addresses are offset by ``base`` while absolute jump
    operands are used unmodified, so absolute jumps only resolve to the
    intended node when base is 0 -- confirm the intended meaning of base.
    '''

    def __init__(self, code, base=0):
        '''
        code -- code object providing co_code, co_lnotab and co_firstlineno
        base -- address assigned to the first instruction
        '''
        self.base = base
        self.code = code
        self.head = None
        self.parse_bytecode()
        self.apply_lineno()

    def add_node(self, parent, bc, lnotab=None):
        '''
        Adds an instruction node to the graph directly after parent.

        The new node takes parent's line number unless lnotab is given.
        The node is not entered into self.bytecodes here; refactor()
        rebuilds that mapping from the linked list.
        '''
        # setup pointers for new node
        bc.next = parent.next
        bc.prev = parent

        if lnotab is None:
            bc.co_lnotab = parent.co_lnotab
        else:
            bc.co_lnotab = lnotab

        if parent.next is not None:
            parent.next.prev = bc

        parent.next = bc

    def apply_labels(self, start=None):
        '''
        Find all JMP REL and ABS bytecode sequences and update the target
        within branch instruction and add xref to the destination.

        A jump whose destination address has no node (for example a jump
        into the middle of another instruction) is left with target None
        instead of raising KeyError.
        '''
        for current in self.nodes(start):
            current.xrefs = []
            current.target = None

        for current in self.nodes(start):
            if current.opcode < dis.HAVE_ARGUMENT:
                continue

            label = current.get_target_addr()
            if label is None:
                continue

            destination = self.bytecodes.get(label)
            if destination is None:
                # Unresolvable target; parse_bytecode() reports these.
                continue

            if current not in destination.xrefs:
                destination.xrefs.append(current)
            current.target = destination

    def apply_lineno(self):
        '''
        Parses the code object co_lnotab list and applies line numbers to
        bytecode. This is used to create a new co_lnotab list after
        modifying bytecode.

        Every node receives a line number: instructions located before the
        first table entry (and all instructions when the table is empty)
        get co_firstlineno.
        '''
        table = bytearray(self.code.co_lnotab)

        # Turn the (byte increment, line increment) pairs into absolute
        # (address, line number) pairs.
        linenos = []
        addr = self.base
        lineno = self.code.co_firstlineno
        for byte_incr, line_incr in zip(table[0::2], table[1::2]):
            addr += byte_incr
            lineno += line_incr
            linenos.append((addr, lineno))

        current_lineno = self.code.co_firstlineno
        index = 0
        for x in self.nodes():
            # Consume every entry that starts at or before this address;
            # several entries may share an address when a large increment
            # was split across pairs.
            while index < len(linenos) and linenos[index][0] <= x.addr:
                current_lineno = linenos[index][1]
                index += 1
            x.co_lnotab = current_lineno

    def calc_lnotab(self):
        '''
        Creates a new co_lineno after modifying bytecode

        Returns the table as a byte string of (address increment, line
        increment) pairs. Increments larger than 255 are split across
        several pairs, since each field is a single unsigned byte.
        '''
        entries = []

        prev_lineno = self.code.co_firstlineno
        prev_offset = self.head.addr

        for current in self.nodes():
            if current.co_lnotab == prev_lineno:
                continue

            addr_incr = current.addr - prev_offset
            line_incr = current.co_lnotab - prev_lineno

            # Address step too large for one byte: advance the address in
            # chunks of 255 with no line change.
            while addr_incr > 0xff:
                entries.append(struct.pack("BB", 0xff, 0))
                addr_incr -= 0xff

            # Line step too large for one byte: the first chunk carries the
            # remaining address step, later chunks carry none.
            while line_incr > 0xff:
                entries.append(struct.pack("BB", addr_incr, 0xff))
                addr_incr = 0
                line_incr -= 0xff

            # A backwards line step cannot be expressed with an unsigned
            # byte; it is wrapped modulo 256 as this method always did.
            entries.append(struct.pack("BB", addr_incr, line_incr & 0xff))

            prev_lineno = current.co_lnotab
            prev_offset = current.addr

        return b"".join(entries)

    def delete_node(self, node):
        '''
        Deletes a node from the graph, removing the instruction from the
        produced bytecode stream
        '''
        successor = node.next

        # For each instruction pointing to instruction to be delete,
        # move the pointer to the next instruction
        for x in node.xrefs:
            x.target = successor
            if successor is not None and x not in successor.xrefs:
                successor.xrefs.append(x)

        # A deleted jump must no longer be listed as a reference to its
        # own destination.
        if node.target is not None and node in node.target.xrefs:
            node.target.xrefs.remove(node)

        # Clean up the doubly linked list
        if node.prev is not None:
            node.prev.next = successor
        if successor is not None:
            successor.prev = node.prev
        if node is self.head:
            self.head = successor

        # Nodes inserted with add_node() are not in the address map (or
        # share a stale address with another node); only drop the entry
        # that really belongs to this node.
        if self.bytecodes.get(node.addr) is node:
            del self.bytecodes[node.addr]

    def disassemble(self, start=None, count=None):
        '''
        Simple disassembly routine for analyzing nodes in the graph

        start -- node to begin at (default: head of the graph)
        count -- maximum number of instructions to list (default: all)
        '''
        lines = []
        for index, x in enumerate(self.nodes(start)):
            if count is not None and index >= count:
                break
            lines.append("[%04d] %04x %-6s %s\n" %
                         (x.co_lnotab, x.addr, x.hex(), x.disassemble()))
        return "".join(lines)

    def get_code(self, start=None):
        '''
        Produce a new code object based on the graph

        The graph is re-addressed with refactor() first. The argument list
        passed to the constructor is the Python 2 code-object signature.

        NOTE(review): when start is given the byte stream begins at that
        node, but addresses and the line table are still computed for the
        whole graph -- confirm this is what callers expect.
        '''
        self.refactor()

        # generate a new co_lineno
        new_co_lineno = self.calc_lnotab()

        # generate new bytecode stream
        new_co_code = b"".join(x.bin() for x in self.nodes(start))

        # create a new code object with modified bytecode and updated line
        # numbers; a new code object is necessary because co_code is
        # readonly
        make_code = new.code if new is not None else type(self.code)
        rvalue = make_code(self.code.co_argcount,
                           self.code.co_nlocals,
                           self.code.co_stacksize,
                           self.code.co_flags,
                           new_co_code,
                           self.code.co_consts,
                           self.code.co_names,
                           self.code.co_varnames,
                           self.code.co_filename,
                           self.code.co_name,
                           self.code.co_firstlineno,
                           new_co_lineno,
                           # keep closure information, otherwise code that
                           # uses free or cell variables is broken
                           self.code.co_freevars,
                           self.code.co_cellvars)

        return rvalue

    def nodes(self, start=None):
        '''
        Iterator for stepping through bytecodes in order, beginning at
        start or at the head of the graph.
        '''
        if start is None:
            current = self.head
        else:
            current = start

        # The generator simply ends when the list is exhausted; raising
        # StopIteration explicitly inside a generator is an error on
        # interpreters implementing PEP 479.
        while current is not None:
            yield current
            current = current.next

    def parse_bytecode(self):
        '''
        Parses the bytecode stream and creates an instruction graph

        Prints a warning for every jump whose destination is not the start
        of a decoded instruction.
        '''
        self.bytecodes = {}
        prev = None
        offset = 0

        targets = []

        while offset < len(self.code.co_code):
            node = Bytecode(self.base + offset,
                            self.code.co_code[offset:offset+3],
                            prev)

            self.bytecodes[node.addr] = node
            # Use the node itself: the map is keyed by base + offset, so
            # looking it up by the bare offset fails when base is non-zero.
            offset += node.len()

            if prev is not None:
                prev.next = node

            prev = node

            if node.get_target_addr() is not None:
                targets.append(node.get_target_addr())

        for x in targets:
            if x not in self.bytecodes:
                print("Nonlinear issue at offset: %08x" % x)

        self.head = self.bytecodes[self.base]
        self.apply_labels()
        return

    def patch_opargs(self, start=None):
        '''
        Updates branch instructions to correct offsets after adding or
        deleting bytecode

        Branches without a resolved target keep their current argument.
        '''
        for current in self.nodes(start):
            # No argument, skip to next
            if current.opcode < dis.HAVE_ARGUMENT:
                continue

            # Nothing to recompute from when the destination is unknown
            if current.target is None:
                continue

            # Patch relative offsets
            if current.opcode in dis.hasjrel:
                current.oparg = current.target.addr - \
                    (current.addr+current.len())

            # Patch absolute offsets
            elif current.opcode in dis.hasjabs:
                current.oparg = current.target.addr

    def refactor(self):
        '''
        iterates through all bytecodes and determines correct offset
        position in code sequence after adding or removing bytecode

        Rebuilds self.bytecodes, rewrites jump arguments from the target
        pointers and then recomputes targets and xrefs.
        '''
        offset = self.base
        new_bytecodes = {}

        for current in self.nodes():
            new_bytecodes[offset] = current
            current.addr = offset
            offset += current.len()

        self.bytecodes = new_bytecodes

        self.patch_opargs()
        self.apply_labels()