import gzip, math, os, os.path, re, signal, struct, sys, yaml from cmd import Cmd import lz4.block import colorama from colorama import Fore, Back, Style from unicorn import * from unicorn.arm64_const import * from capstone import * from ceval import ceval, compile import util from util import * import inlines, relocation from svc import SvcHandler from threadmanager import ThreadManager import mmio TRACE_NONE = 0 TRACE_INSTRUCTION = 1 TRACE_BLOCK = 2 TRACE_FUNCTION = 4 TRACE_MEMORY = 8 TRACE_MEMCHECK = 16 def colorDepth(depth): colors = [Fore.RED, Fore.WHITE, Fore.GREEN, Fore.YELLOW, Style.BRIGHT + Fore.BLUE, Fore.MAGENTA, Fore.CYAN] return colors[depth % len(colors)] INSN_PER_SLICE = 100000000 # How many instructions to execute per thread slice class HandleJar(object): def __init__(self, ctu): self.ctu = ctu self.jar = {} def __setitem__(self, handle, obj): self.jar[handle] = obj def __getitem__(self, handle): if handle in self.jar: return self.jar[handle] print '~~ Unknown handle 0x%08x ~~' % handle self.ctu.debugbreak() return None def __delitem__(self, handle): del self.jar[handle] def __contains__(self, handle): return handle in self.jar def items(self): return self.jar.items() def replace(self, old, new): self.jar = {k:v if v is not old else new for k, v in self.jar.items()} class CTU(Cmd, object): def __init__(self, flags=0): Cmd.__init__(self) colorama.init() self.initialized = False self.exiting = False self.firstLoad = True IPCMessage.ctu = self self.flags = 0 self.sublevel = 0 self.breakpoints = set() self.watchpoints = [] self.terminateOnFullSleep = False # Terminate when all threads go to sleep self.mu = Uc(UC_ARCH_ARM64, UC_MODE_ARM) self.md = Cs(CS_ARCH_ARM64, CS_MODE_ARM) self.mu.hook_add(UC_HOOK_CODE, self.hook_insn_bytes) self.mu.hook_add(UC_HOOK_BLOCK, self.trace_block) self.mu.hook_add(UC_HOOK_MEM_READ, self.trace_mem_read) self.mu.hook_add(UC_HOOK_MEM_WRITE, self.trace_mem_write) self.mu.hook_add(UC_HOOK_MEM_READ_UNMAPPED, self.trace_unmapped) self.mu.hook_add(UC_HOOK_MEM_WRITE_UNMAPPED, self.trace_unmapped) self.mu.hook_add(UC_HOOK_MEM_FETCH_UNMAPPED, self.trace_unmapped) self.insnhooks = {} self.fetchhooks = {} self.termaddr = 1 << 61 # Pseudoaddress upon which to terminate execution self.mu.mem_map(self.termaddr, 0x1000) self.mu.mem_write(self.termaddr, '\x1F\x20\x03\xD5') # NOP for i in xrange(30): self.hookinsn(0xD53BD060 + i, (lambda i: lambda _, __: self.tlshook(i))(i)) self.svch = SvcHandler(self) self.mappings = [] self.reset() self.enableFP() self.mu.mem_map(inlines.magicBase, 0x1000) self.execfunc = None self.initialized = True def reset(self): self.debugging = False self.started = False self.restarting = False self.singlestep = False self.mainGlobalScope = None self.skipbp = False for addr, size in self.mappings: self.mu.mem_unmap(addr, size) self.mappings = [] self.checkmaps = {} self.checktriggers = [] self.usHeapSize = 0 self.mmiobase = 1 << 58 self.mmiosize = 0 self.mmiomap = [] for cls in mmio.mmioClasses: self.mmiomap.append((cls.physbase, self.mmiobase + self.mmiosize, cls.size, cls(self))) self.mmiosize += cls.size self.map(self.mmiobase, self.mmiosize) self.writehooks = {} self.readhooks = {} self.handles = HandleJar(self) self.handleIter = 0xd000 self.handles[0xFFFF8001] = Process(0x1234) self.handles[0xDEADBEEF] = Process(0xDEAD) self.threads = ThreadManager(self) self.threadIter = 0 self.exports = {} self.funcReplacements = {} self.loadbase = 0 self.loadsize = 0 self.heapbase = 7 << 24 self.heapsize = 32 * 1024 * 1024 # 32MB self.heapoff = 0 self.map(self.heapbase, self.heapsize) self.stacktop = 7 << 24 self.stacksize = 8 * 1024 * 1024 # 8MB self.map(self.stacktop - self.stacksize, self.stacksize) self.writemem(self.heapbase, '\0' * self.heapsize, check=False) self.writemem(self.stacktop - self.stacksize, '\0' * self.stacksize, check=False) for i in xrange(32): self.reg(i, 0) @property def threadId(self): if self.threads.current is None: return '?' else: return str(self.threads.current.id) def newHandle(self, obj): i = self.handleIter self.handleIter += 1 self.handles[i] = obj return i def replaceHandle(self, old, new): self.handles.replace(old, new) def closeHandle(self, handle): if handle == 0xDEADBEEF or handle == 0xFFFF8001: return elif handle in self.handles: obj = self.handles[handle] print 'Closing handle:', obj if hasattr(obj, 'close'): obj.close() del self.handles[handle] def map(self, base, size): if (base & 0xFFF) != 0: off = base & 0xFFF base -= off size += off if (size & 0xFFF) != 0: size = (size & 0xFFFFFFFFFFFFF000) + 0x1000 if (base, size) not in self.mappings: self.mappings.append((base, size)) self.mu.mem_map(base, size) if self.flags & TRACE_MEMCHECK: self.checkmaps[base] = [0] * (size >> 3) def unmap(self, base, size): if (base & 0xFFF) != 0: off = base & 0xFFF base -= off size += off if (size & 0xFFF) != 0: size = (size & 0xFFFFFFFFFFFFF000) + 0x1000 if (base, size) in self.mappings: del self.mappings[self.mappings.index((base, size))] self.mu.mem_unmap(base, size) if self.flags & TRACE_MEMCHECK: del self.checkmaps[base] def getmap(self, addr): for base, size in self.mappings: if base <= addr < base + size: return base, size return -1, -1 def checkread(self, addr, size): if not (self.flags & TRACE_MEMCHECK): return miss = None base, rsize = self.getmap(addr) for i in xrange(size): caddr = addr + i if not (base <= caddr < base + rsize): base, rsize = self.getmap(caddr) if base == -1: continue off = caddr - base if (self.checkmaps[base][off >> 3] & (1 << (off & 7))) == 0: miss = caddr break tlsbase = self.threads.current.tlsbase if self.threads.current is not None else 1 << 64 if miss is not None: print '[%s:%s] Read from uninitialized memory at %s (reading %i bytes from %s)' % (self.threadId, raw(self.threads.current.lastinsn), raw(miss), size, raw(addr)) if tlsbase <= miss < tlsbase + 0x100: self.debugbreak() else: for taddr, tsize in self.checktriggers: if taddr <= miss < taddr + tsize: self.debugbreak() elif addr == tlsbase and size == 4: self.checkwrite(addr, size, unset=True) def checkwrite(self, addr, size, unset=False, trigger=False): if not (self.flags & TRACE_MEMCHECK): return base, rsize = self.getmap(addr) for i in xrange(size): caddr = addr + i if not (base <= caddr < base + rsize): base, rsize = self.getmap(caddr) if base == -1: continue off = caddr - base if unset: self.checkmaps[base][off >> 3] &= 0xFF ^ (1 << (off & 7)) else: self.checkmaps[base][off >> 3] |= 1 << (off & 7) if trigger: self.checktriggers.append((addr, size)) if len(self.checktriggers) == 5: self.checktriggers.pop(0) def setup(self, func): self.execfunc = func def load(self, dn): load = yaml.load(file(dn + '/load.yaml')) if 'nro' in load and not 'nxo' in load: load['nxo'] = load['nro'] elif 'nso' in load and not 'nxo' in load: load['nxo'] = load['nso'] if 'bundle' in load: self.loadmemory(dn + '/' + load['bundle']) elif 'mod' in load: self.loadmod(dn + '/' + load['mod']) elif 'nxo' in load: if not isinstance(load['nxo'], list): load['nxo'] = [load['nxo']] ibase = 0x7100000000 self.loadbase = ibase allImports = [] for name in load['nxo']: print 'Loading', name fn = dn + '/' + name if os.path.exists(fn): imports, exports = self.loadnso(fn, loadbase=ibase) else: imports, exports = self.loadnro(fn + '.nro', loadbase=ibase) self.exports.update(exports) allImports.append(imports) ibase += 0x100000000 self.loadsize = ibase - self.loadbase if True:#self.firstLoad and len(load['nso']) == 1: Address.display_specialized = False for imports in allImports: for name, (addr, addend) in imports.items(): if name in self.exports: self.write64(addr, self.exports[name] + addend) else: print 'Unresolved import:', name if 'maps' in load: for name, (base, fn) in load['maps'].items(): mapLoader(dn + '/' + fn, name, base) if self.mainGlobalScope is not None: self.mainGlobalScope.update(util.addressTypes) self.firstLoad = False def runExecFunc(self): if self.execfunc is None: return self.mainGlobalScope = self.execfunc.func_globals self.execfunc(self) def run(self, flags=0): fl = self.flags self.reset() self.flags = fl | flags self.runExecFunc() def enableFP(self): addr = 0 self.mu.mem_map(addr, 0x1000) self.mu.mem_write(addr, '\x41\x10\x38\xd5\x00\x00\x01\xaa\x40\x10\x18\xd5\x40\x10\x38\xd5\xc0\x03\x5f\xd6') assert (self.call(addr, 3 << 20) >> 20) & 3 == 3 self.mu.mem_unmap(addr, 0x1000) def loadmod(self, fn): data = file(fn, 'rb').read() moff, = struct.unpack('<I', data[4:8]) assert data[moff:moff+4] == 'MOD0' bssStart, bssEnd = struct.unpack('<II', data[moff+0x08:moff+0x10]) bssStart, bssEnd = bssStart + moff, bssEnd + moff moff += struct.unpack('<I', data[moff+0x18:moff+0x1C])[0] base, = struct.unpack('<Q', data[moff+0x20:moff+0x28]) overlength = 0 if bssStart < len(data): data = data[:bssStart] overlength = bssEnd - bssStart else: self.map(base + bssStart, bssEnd - bssStart) self.map(base, len(data) + overlength) self.writemem(base, data) defineAddressClass('Main', base, len(data)) def loadnso(self, fn, loadbase=0x7100000000, relocate=True): data = file(fn, 'rb').read() assert data[0:4] == 'NSO0' toff, tloc, tsize = struct.unpack('<III', data[0x10:0x1C]) roff, rloc, rsize = struct.unpack('<III', data[0x20:0x2C]) doff, dloc, dsize = struct.unpack('<III', data[0x30:0x3C]) bsssize, = struct.unpack('<I', data[0x3C:0x40]) text = lz4.block.decompress(data[toff:roff], uncompressed_size=tsize) rd = lz4.block.decompress(data[roff:doff], uncompressed_size=rsize) data = lz4.block.decompress(data[doff:], uncompressed_size=dsize) full = text if rloc >= len(full): full += '\0' * (rloc - len(full)) full += rd else: full = full[:rloc] + rd if dloc >= len(full): full += '\0' * (dloc - len(full)) full += data else: full = full[:dloc] + data self.map(loadbase, len(full) + bsssize) self.writemem(loadbase, full) defineAddressClass(fn.rsplit('/', 1)[-1].split('.', 1)[0].title(), loadbase, len(full)) if relocate: return relocation.relocate(self, loadbase) def loadnro(self, fn, loadbase=0x7100000000, relocate=True): data = file(fn, 'rb').read() assert data[0x10:0x14] == 'NRO0' tloc, tsize, rloc, rsize, dloc, dsize = struct.unpack('<IIIIII', data[0x20:0x20 + 6 * 4]) modoff, = struct.unpack('<I', data[4:8]) assert data[modoff:modoff+4] == 'MOD0' bssoff, bssend = struct.unpack('<II', data[modoff+8:modoff+16]) bsssize = bssend - bssoff text = data[tloc:tloc+tsize] rd = data[rloc:rloc+rsize] data = data[dloc:dloc+dsize] full = text if rloc >= len(full): full += '\0' * (rloc - len(full)) full += rd else: full = full[:rloc] + rd if dloc >= len(full): full += '\0' * (dloc - len(full)) full += data else: full = full[:dloc] + data if len(full) < bssoff: full += '\0' * (bssoff - len(full)) if bsssize & 0xFFF: bsssize = (bsssize & 0xFFFFF000) + 0x1000 self.map(loadbase, len(full) + bsssize) try: self.writemem(loadbase, full) except: import traceback traceback.print_exc() defineAddressClass(fn.rsplit('/', 1)[-1].split('.', 1)[0].title(), loadbase, len(full)) if relocate: return relocation.relocate(self, loadbase) def loadmemory(self, fn): if not os.path.isfile(fn) and os.path.isfile(fn + '.gz'): with gzip.GzipFile(fn + '.gz', 'rb') as ifp: with file(fn, 'wb') as ofp: print 'Decompressing membundle' ofp.write(ifp.read()) print 'Done!' with file(fn, 'rb') as fp: regions, mainbase, wkcbase = struct.unpack('<IQQ', fp.read(20)) rmap = [] for i in xrange(regions): addr, dlen = struct.unpack('<QI', fp.read(12)) data = fp.read(dlen) self.map(addr, dlen) rmap.append((addr, dlen)) self.writemem(addr, data) mainsize = 0 wkcsize = 0 inMain = inWKC = False last = 0 rmap.sort(key=lambda x: x[0]) for (addr, dlen) in rmap: if addr == mainbase: inMain = True last = addr elif addr == wkcbase: inWKC = True last = addr if (inMain or inWKC) and last != addr: inMain = inWKC = False elif inMain: mainsize += dlen last = addr + dlen elif inWKC: wkcsize += dlen last = addr + dlen defineAddressClass('Main', mainbase, mainsize) defineAddressClass('Wkc', wkcbase, wkcsize) def findMmioObj(self, virtaddr): if self.mmiobase <= virtaddr < self.mmiobase + self.mmiosize: for pbase, vbase, size, obj in self.mmiomap: if vbase <= virtaddr < vbase + size: return obj, virtaddr - vbase + pbase return None, 0 def trace_mem_read(self, mu, access, addr, size, value, user_data): obj, paddr = self.findMmioObj(addr) if obj is not None: nval = obj.read(paddr, size) if nval is None: nval = 0 if size == 1: self.write8(addr, nval) elif size == 2: self.write16(addr, nval) elif size == 4: self.write32(addr, nval) elif size == 8: self.write64(addr, nval) #if addr == 0x710062b698: # if size == 4: # self.write32(addr, 0xdeadbeef) if self.flags & TRACE_MEMORY: value = None if size == 1: value = self.read8(addr, check=False) elif size == 2: value = self.read16(addr, check=False) elif size == 4: value = self.read32(addr, check=False) elif size == 8: value = self.read64(addr, check=False) print '[%s:%s] %i byte read from %s = %s' % (self.threadId, raw(self.threads.current.lastinsn), size, raw(addr), '0x%x' % value if value is not None else 'unmapped') if addr in self.readhooks: val = self.readhooks[addr](self, size) if val is not None: if size == 1: self.write8(addr, val) elif size == 2: self.write16(addr, val) elif size == 4: self.write32(addr, val) elif size == 8: self.write64(addr, val) print '[%s:%s] %i detoured byte read from %s = %s' % (self.threadId, raw(self.threads.current.lastinsn), size, raw(addr), '0x%x' % val if val is not None else 'unmapped') #self.debugbreak() self.checkread(addr, size) def trace_mem_write(self, mu, access, addr, size, value, user_data): obj, paddr = self.findMmioObj(addr) if obj is not None: obj.write(paddr, size, value) if self.flags & TRACE_MEMORY: print '[%s:%s] %i byte write to %s = 0x%x' % (self.threadId, raw(self.threads.current.lastinsn), size, raw(addr), value) if self.flags & TRACE_MEMCHECK: self.checkwrite(addr, size) if addr in self.writehooks: if size == 1: self.write8(addr, value) elif size == 2: self.write16(addr, value) elif size == 4: self.write32(addr, value) elif size == 8: self.write64(addr, value) if self.writehooks[addr](self, addr, size, value): del self.writehooks[addr] def trace_unmapped(self, mu, access, addr, size, value, user_data): if access == UC_MEM_FETCH_UNMAPPED: print '[%s:%s] Unmapped fetch of %s' % (self.threadId, raw(self.threads.current.lastinsn), raw(addr)) elif access == UC_MEM_READ_UNMAPPED: print '[%s:%s] Unmapped %i byte read from %s' % (self.threadId, raw(self.threads.current.lastinsn), size, raw(addr)) elif access == UC_MEM_WRITE_UNMAPPED: print '[%s:%s] Unmapped %i byte write to %s = 0x%x' % (self.threadId, raw(self.threads.current.lastinsn), size, raw(addr), value) self.debugbreak() def trace_insn(self, mu, addr, size, user_data): if not self.initialized: return for ins in self.md.disasm(str(mu.mem_read(addr, size)), addr): print "[%s] 0x%08x: %s %s" % (self.threadId, ins.address, ins.mnemonic, ins.op_str) print 'x0=0x%x' % self.reg(0) def trace_block(self, mu, addr, size, user_data): if not self.initialized or (self.flags & TRACE_BLOCK) == 0: return print '[%s] Block at %s' % (self.threadId, raw(addr, pad=True)) """if addr == MainAddress(0x1ec928) or addr == MainAddress(0x1ec9e0) or addr == MainAddress(0x1ecab4): print '\nFATAL:\n%s\n%s\n%s\n' % ( self.readmem(self.reg(0), 0x100).split('\0', 1)[0], self.readmem(self.reg(1), 0x100).split('\0', 1)[0], self.readmem(self.reg(2), 0x100).split('\0', 1)[0] ) self.stop()""" def trace_func(self, mu, addr, size, user_data): thread = self.threads.current if not self.initialized or thread is None: return if thread.blx: thread.callstack.append(addr) if self.flags & TRACE_FUNCTION: plen = len('[%s]' % self.threadId + ' ' + ' ' * len(thread.callstack)) #print ' ' * plen + '-> X0 -- %s X1 -- %s' % (raw(self.reg(0)), raw(self.reg(1))) print '[%s]' % self.threadId, ' ' * len(thread.callstack) + colorDepth(len(thread.callstack)) + '-> %s' % raw(addr), Style.RESET_ALL thread.blx = False insn = self.read32(addr) bl_mask = 0b11101100 << 24 bl_match = 0b10000100 << 24 blr_mask = 0b011011111010 << 20 blr_match = 0b010001100010 << 20 ret_mask = 0b011011110110 << 20 ret_match = 0b010001100100 << 20 if (insn & bl_mask) == bl_match or (insn & blr_mask) == blr_match: thread.blx = True elif (insn & ret_mask) == ret_match: if self.flags & TRACE_FUNCTION: if len(thread.callstack): plen = len('[%s]' % self.threadId + ' ' + ' ' * len(thread.callstack)) print '[%s]' % self.threadId, ' ' * len(thread.callstack) + colorDepth(len(thread.callstack)) + '<- %s' % raw(thread.callstack.pop()), Style.RESET_ALL #print ' ' * plen + '<- X0 -- %s' % raw(self.reg(0)) elif len(thread.callstack): thread.callstack.pop() def hook_insn_bytes(self, mu, addr, size, user_data): self.threads.switched = False self.threadIter += 1 if self.threadIter >= INSN_PER_SLICE: self.threadIter = 0 if self.threads.next(pcOffset=0): return thread = self.threads.current if addr in inlines.reverse: inlines.reverse[addr](self) self.pc = self.reg(30) self.threads.current.blx = False return elif addr in self.funcReplacements: func = self.funcReplacements[addr] func() if self.pc == addr: self.pc = self.reg(30) if thread.blx: thread.blx = False elif self.flags & TRACE_FUNCTION: if len(thread.callstack): print '[%s]' % self.threadId, ' ' * len(thread.callstack) + colorDepth(len(thread.callstack)) + '<- %s' % raw(thread.callstack.pop()), Style.RESET_ALL return if self.restarting: return if addr in self.fetchhooks: self.fetchhooks[addr]() if self.skipbp and not self.singlestep: self.skipbp = False elif self.singlestep or addr in self.breakpoints: if self.singlestep: self.singlestep = False else: print 'Breakpoint at %s' % raw(addr) self.skipbp = True self.debugbreak() else: for code, func in self.watchpoints: if func(self): print 'Watchpoint %s triggered at %s' % (code, raw(addr)) self.skipbp = True self.debugbreak() break if self.flags & TRACE_INSTRUCTION and self.flags & TRACE_FUNCTION: if self.threads.current is not None and self.threads.current.blx: self.trace_func(mu, addr, size, user_data) self.trace_insn(mu, addr, size, user_data) else: self.trace_insn(mu, addr, size, user_data) self.trace_func(mu, addr, size, user_data) elif self.flags & TRACE_INSTRUCTION: self.trace_insn(mu, addr, size, user_data) self.trace_func(mu, addr, size, user_data) else: self.trace_func(mu, addr, size, user_data) self.threads.current.lastinsn = addr insn, = struct.unpack('<I', self.mu.mem_read(addr, 4)) if insn in self.insnhooks: if self.insnhooks[insn](self, addr) == False: self.pc += 4 def hookinsn(self, insn, func=None): def sub(func): assert insn not in self.insnhooks self.insnhooks[insn] = func if func is None: return sub sub(func) def hookfetch(self, addr, func=None): addr = native(addr) def sub(func): assert addr not in self.fetchhooks self.fetchhooks[addr] = func if func is None: return sub sub(func) def hookread(self, addr): addr = native(addr) def sub(func): assert addr not in self.readhooks self.readhooks[addr] = func return sub def hookwrite(self, addr, func=None): def sub(func): assert addr not in self.fetchhooks self.writehooks[addr] = func if func is None: return sub sub(func) def replaceFunction(self, addr): addr = native(addr) def sub(func): regcount = func.__code__.co_argcount - 1 def dsub(): args = [self.reg(i) for i in xrange(regcount)] ret = func(self, *args) if isinstance(ret, tuple) or isinstance(ret, list): for i, v in enumerate(ret): self.reg(i, v) elif ret is not None: self.reg(0, ret) self.funcReplacements[addr] = dsub dsub.original = func return func return sub def tlshook(self, reg): self.reg(reg, self.threads.current.tlsbase) return False def call(self, pc, *args, **kwargs): _start = kwargs['_start'] if '_start' in kwargs else False if pc in self.exports: print 'Calling', pc pc = self.exports[pc] thread = self.threads.create(native(pc), native(self.stacktop), *map(native, args)) if _start: thread.regs[0+2] = 0 thread.regs[1+2] = thread.handle if not self.started: self.started = True first = True while first or (not self.exiting and (self.threads.switched or len(self.threads.running))): first = False self.threads.current.thaw() try: self.mu.emu_start(native(pc), self.termaddr + 4) except: import traceback traceback.print_exc() print 'Exception at %s' % raw(self.threads.current.lastinsn) self.dumpregs() break if self.threads.current is not None: pc = self.threads.current.regs[0] if self.exiting: sys.exit(0) self.threads.clear() self.started = False if self.restarting: self.restarting = False raise Restart() return self.mu.reg_read(UC_ARM64_REG_X0) def stop(self): self.mu.reg_write(UC_ARM64_REG_PC, self.termaddr) self.exiting = True def malloc(self, size): self.heapoff += size assert self.heapoff <= self.heapsize return self.heapbase + self.heapoff - size def free(self, ptr): pass # Lol. def writemem(self, addr, data, check=True): try: addr = native(addr) self.mu.mem_write(addr, data) if check: self.checkwrite(addr, len(data)) return True except unicorn.UcError: return False def write8(self, addr, data, check=True): return self.writemem(addr, struct.pack('<B', data), check=check) def write16(self, addr, data, check=True): return self.writemem(addr, struct.pack('<H', data), check=check) def write32(self, addr, data, check=True): return self.writemem(addr, struct.pack('<I', data), check=check) def write64(self, addr, data, check=True): return self.writemem(addr, struct.pack('<Q', data), check=check) def readmem(self, addr, size, check=True): try: addr = native(addr) if check and self.flags & TRACE_MEMCHECK: self.checkread(addr, size) return str(self.mu.mem_read(addr, size)) except unicorn.UcError: return None def read8(self, addr, check=True): v = self.readmem(addr, 1, check=check) return struct.unpack('<B', v)[0] if v is not None else None def readS8(self, addr, check=True): v = self.readmem(addr, 1, check=check) return struct.unpack('<b', v)[0] if v is not None else None def read16(self, addr, check=True): v = self.readmem(addr, 2, check=check) return struct.unpack('<H', v)[0] if v is not None else None def readS16(self, addr, check=True): v = self.readmem(addr, 2, check=check) return struct.unpack('<h', v)[0] if v is not None else None def read32(self, addr, check=True): v = self.readmem(addr, 4, check=check) return struct.unpack('<I', v)[0] if v is not None else None def readS32(self, addr, check=True): v = self.readmem(addr, 4, check=check) return struct.unpack('<i', v)[0] if v is not None else None def read64(self, addr, check=True): v = self.readmem(addr, 8, check=check) return struct.unpack('<Q', v)[0] if v is not None else None def readS64(self, addr, check=True): v = self.readmem(addr, 8, check=check) return struct.unpack('<q', v)[0] if v is not None else None def readstring(self, addr): if addr is None: return None ret = '' while True: c = self.read8(addr) addr += 1 if c == 0 or c is None: return ret ret += chr(c) def memregions(self): lastend = 0 for begin, end, perms in sorted(self.mu.mem_regions(), key=lambda x: x[0]): if begin > lastend: yield lastend, begin, -1 yield begin, end + 1, perms lastend = end + 1 if lastend != 1 << 64: yield lastend, 1 << 64, -1 def reg(self, i, val=None): sr = {'LR': 30, 'SP': 31} for ri in xrange(32): sr['X%i' % ri] = ri if isinstance(i, str) and i.upper() in sr: i = sr[i.upper()] if i <= 28: c = UC_ARM64_REG_X0 + i elif i == 29 or i == 30: c = UC_ARM64_REG_X29 + i - 29 elif i == 31: c = UC_ARM64_REG_SP else: return None if val is None: return self.mu.reg_read(c) else: self.mu.reg_write(c, native(val)) return True @property def pc(self): return self.mu.reg_read(UC_ARM64_REG_PC) @pc.setter def pc(self, val): self.mu.reg_write(UC_ARM64_REG_PC, val) def dumpregs(self): sr = {30: 'LR', 31: 'SP'} print '-' * 52 for i in xrange(0, 32, 2): an = sr[i] if i in sr else 'X%i' % i bn = sr[i + 1] if i + 1 in sr else 'X%i' % (i + 1) an += ' ' * (3 - len(an)) bn += ' ' * (3 - len(bn)) print '%s - 0x%016x %s - 0x%016x' % ( an, self.reg(i), bn, self.reg(i + 1) ) print '-' * 52 print def dumpmem(self, addr, size, check=False): addr = native(addr) data = self.readmem(addr, size, check=check) if data is None: print 'Unmapped memory at %s' % raw(addr) return data = map(ord, data) fmt = '%%0%ix |' % (int(math.log(addr + size, 16)) + 1) for i in xrange(0, len(data), 16): print fmt % (addr + i), ascii = '' for j in xrange(16): if i + j < len(data): print '%02x' % data[i + j], if 0x20 <= data[i+j] <= 0x7E: ascii += chr(data[i+j]) else: ascii += '.' else: print ' ', ascii += ' ' if j == 7: print '', ascii += ' ' print '|', ascii def reprompt(self): if self.started: self.prompt = '[%s] ctu %s> ' % (self.threadId, raw(self.mu.reg_read(UC_ARM64_REG_PC))) else: self.prompt = 'ctu> ' def debug(self, sub=False): self.debugging = True self.reprompt() try: self.sublevel += 1 while True: try: self.cmdloop() break except KeyboardInterrupt: print finally: self.sublevel -= 1 if self.sublevel == 1: self.prompt = 'ctu> ' def debugbreak(self): try: self.debug(sub=True) except Restart: self.restarting = True return self.stop() def print_topics(self, header, cmds, cmdlen, maxcol): nix = 'EOF', 'b', 'c', 's', 'r', 't' if header is not None: Cmd.print_topics(self, header, [cmd for cmd in cmds if cmd not in nix], cmdlen, maxcol) def do_EOF(self, line): print try: if raw_input('Really exit? y/n: ').startswith('y'): self.exiting = True sys.exit() except EOFError: print self.exiting = True sys.exit() def do_exit(self, line): """exit Exit the debugger.""" sys.exit() def do_start(self, line): """s/start Start or restart the code.""" if self.sublevel != 1: raise Restart() while True: self.reset() try: self.runExecFunc() break except Restart: print 'got restart at', self.sublevel continue do_s = do_start def do_trace(self, line): """t/trace (i/instruction | b/block | f/function | m/memory) Toggles tracing of instructions, blocks, functions, or memory.""" if line.startswith('i'): self.flags ^= TRACE_INSTRUCTION print 'Instruction tracing', 'on' if self.flags & TRACE_INSTRUCTION else 'off' elif line.startswith('b'): self.flags ^= TRACE_BLOCK print 'Block tracing', 'on' if self.flags & TRACE_BLOCK else 'off' elif line.startswith('f'): self.flags ^= TRACE_FUNCTION print 'Function tracing', 'on' if self.flags & TRACE_FUNCTION else 'off' elif line.startswith('m'): self.flags ^= TRACE_MEMORY print 'Memory tracing', 'on' if self.flags & TRACE_MEMORY else 'off' else: print 'Unknown trace flag' do_t = do_trace def do_memcheck(self, line): """mc/memcheck Toggles memory access validations.""" self.flags ^= TRACE_MEMCHECK print 'Memcheck', 'on' if self.flags & TRACE_MEMCHECK else 'off' do_mc = do_memcheck def do_break(self, addr): """b/break [name] Without `name`, list breakpoints. Given a symbol name or address, toggle breakpoint.""" if addr == '': print 'Breakpoints:' for addr in self.breakpoints: print '*', addr return try: addr = raw(addr) except BadAddr: print 'Invalid address/symbol' return if addr in self.breakpoints: print 'Removing breakpoint at %s' % addr self.breakpoints.remove(addr) else: print 'Breaking at %s' % addr self.breakpoints.add(addr) do_b = do_break def complete_break(self, text, line, begidx, endidx): ftext = line.split(' ', 1)[1] if ' ' in line else '' cut = len(ftext) - len(text) return [sym[cut:] for sym in symbols.keys() if sym.startswith(ftext)] complete_b = complete_break def do_bt(self, line): """bt Prints the call stack.""" print 'Call stack:' for i, x in enumerate(self.threads.current.callstack[::-1]): print '%03i: %s' % (i, raw(x)) def do_sym(self, name): """sym <name> Prints the address of a given symbol.""" try: print raw(name) except BadAddr: print 'Invalid address/symbol' complete_sym = complete_break def do_continue(self, line): """c/continue Continues execution of the code.""" if self.sublevel == 1: print 'Not running' else: return True do_c = do_continue def do_next(self, line): """n/next Step to the next instruction.""" if self.sublevel == 1: print 'Not running' else: self.singlestep = True return True do_n = do_next def do_regs(self, line): """r/reg/regs [reg [value]] No parameters: Display registers. Reg parameter: Display one register. Otherwise: Assign a value (always hex, or a symbol) to a register.""" if line == '': return self.dumpregs() elif ' ' in line: r, v = line.split(' ', 1) try: v = raw(v) if self.reg(r, v) is None: print 'Invalid register' except BadAddr: print 'Invalid address/Symbol' else: v = self.reg(line) if v is False: print 'Invalid register' else: print '0x%016x' % v do_r = do_reg = do_regs def do_exec(self, line): """x/exec <code> Evaluates a given line of C.""" try: val = ceval(line, self) except: import traceback traceback.print_exc() print 'Execution failed' return if val is not None: print '0x%x' % val do_x = do_exec def do_dump(self, line): """dump <address> [size] Dumps `size` (default: 0x100) bytes of memory at an address. If the address takes the form `*register` (e.g. `*X1`) then the value of that register will be used.""" line = list(line.split(' ')) if len(line[0]) == 0: print 'No address' elif len(line) <= 2: if len(line[0]) and line[0][0] == '*': line[0] = self.reg(line[0][1:]) if line[0] is None: print 'Invalid register' return else: try: line[0] = raw(line[0]) except BadAddr: print 'Invalid address/symbol' return if len(line) == 2: line[1] = parseInt(line[1]) if line[1] is None or line[1] >= 0x10000: print 'Invalid size' return self.dumpmem(line[0], 0x100 if len(line) == 1 else line[1]) else: print 'Too many parameters' def do_save(self, line): """save <address> <size> <fn> Writes `size` bytes of memory to a file. If the address or size takes the form `*register` (e.g. `*X1`) then the value of that register will be used.""" line = list(line.split(' ')) addr, size, fn = line addr = self.reg(addr[1:]) if addr.startswith('*') else parseInt(addr) size = self.reg(size[1:]) if size.startswith('*') else parseInt(size) with file(fn, 'wb') as fp: fp.write(self.readmem(addr, size)) print 'Wrote to file' def do_ad(self, line): """ad Toggle address display specialization.""" Address.display_specialized = not Address.display_specialized print '%s specialized address display' % ('Enabled' if Address.display_specialized else 'Disabled') self.reprompt() def do_watch(self, line): """w/watch [expression] Breaks when expression evaluates to true. Without an expression, list existing watchpoints.""" if line == '': print 'Watchpoints:' for code, _ in self.watchpoints: print '*', code return if line in [code for code, _ in self.watchpoints]: self.watchpoints = [(code, func) for code, func in self.watchpoints if code != line] print 'Watchpoint deleted' else: self.watchpoints.append((line, compile(line))) print 'Watchpoint added' do_w = do_watch def do_memregions(self, line): """mr/memregions Displays mapped memory regions.""" print 'Mapped memory regions' print '---------------------' for begin, end, perms in self.memregions(): if perms != -1: print '%016x - %016x' % (begin, end) do_mr = do_memregions def debug(*flags): def sub(func): ctu = CTU() ctu.setup(func) ctu.flags |= reduce(lambda a, x: a | x, flags, TRACE_NONE) ctu.debug() return func if len(flags) == 1 and callable(flags[0]): func = flags[0] flags = [TRACE_NONE] return sub(func) else: return sub def run(*flags): def sub(func): ctu = CTU() ctu.setup(func) ctu.flags |= reduce(lambda a, x: a | x, flags, TRACE_NONE) ctu.run() return func if len(flags) == 1 and callable(flags[0]): func = flags[0] flags = [TRACE_NONE] return sub(func) else: return sub