import gzip, math, os, os.path, re, signal, struct, sys, yaml
from cmd import Cmd
import lz4.block

import colorama
from colorama import Fore, Back, Style

from unicorn import *
from unicorn.arm64_const import *

from capstone import *

from ceval import ceval, compile
import util
from util import *
import inlines, relocation
from svc import SvcHandler
from threadmanager import ThreadManager
import mmio

# Bit flags selecting which diagnostics are active. They are OR-ed together
# and stored in CTU.flags.
TRACE_NONE = 0
TRACE_INSTRUCTION = 1 # Disassemble each executed instruction (trace_insn)
TRACE_BLOCK = 2 # Print the address of each block entered (trace_block)
TRACE_FUNCTION = 4 # Print call/return arrows from the tracked call stack
TRACE_MEMORY = 8 # Print each memory read and write
TRACE_MEMCHECK = 16 # Track written bytes and report reads of unwritten memory

def colorDepth(depth):
	"""Return the colorama prefix used to colour output at call depth `depth`.

	The palette has seven entries; depths beyond that wrap around."""
	palette = (
		Fore.RED,
		Fore.WHITE,
		Fore.GREEN,
		Fore.YELLOW,
		Style.BRIGHT + Fore.BLUE,
		Fore.MAGENTA,
		Fore.CYAN,
	)
	slot = depth % len(palette)
	return palette[slot]

# How many instructions to execute per thread slice. hook_insn_bytes counts
# executed instructions and asks the thread manager for the next thread once
# this many have run.
INSN_PER_SLICE = 100000000

class HandleJar(object):
	"""Dictionary-like table mapping handle numbers to emulator objects.

	Looking up an unknown handle does not raise: a warning is printed, the
	owning CTU's debugger is entered via debugbreak(), and None is returned."""

	def __init__(self, ctu):
		self.ctu = ctu
		self.jar = {}

	def __setitem__(self, handle, obj):
		self.jar[handle] = obj

	def __getitem__(self, handle):
		try:
			return self.jar[handle]
		except KeyError:
			pass
		# Single-argument print: the parentheses are only grouping here.
		print('~~ Unknown handle 0x%08x ~~' % handle)
		self.ctu.debugbreak()
		return None

	def __delitem__(self, handle):
		del self.jar[handle]

	def __contains__(self, handle):
		return handle in self.jar

	def items(self):
		return self.jar.items()

	def replace(self, old, new):
		"""Rebuild the table with every entry that *is* `old` swapped for `new`."""
		swapped = {}
		for handle, obj in self.jar.items():
			swapped[handle] = new if obj is old else obj
		self.jar = swapped

class CTU(Cmd, object):
	def __init__(self, flags=0):
		"""Create the emulator, install the Unicorn hooks and reset all state.

		flags -- initial OR-ed TRACE_* bits. Previously this argument was
		         accepted but silently discarded; it is now stored in
		         self.flags (the default of 0 keeps the old behaviour).
		"""
		Cmd.__init__(self)

		colorama.init()
		# trace_insn/trace_block/trace_func stay quiet until this becomes True.
		self.initialized = False
		self.exiting = False
		self.firstLoad = True

		IPCMessage.ctu = self

		self.flags = flags
		self.sublevel = 0 # Nesting depth of debug() command loops
		self.breakpoints = set()
		self.watchpoints = [] # (source text, compiled predicate) pairs

		self.terminateOnFullSleep = False # Terminate when all threads go to sleep

		self.mu = Uc(UC_ARCH_ARM64, UC_MODE_ARM)
		self.md = Cs(CS_ARCH_ARM64, CS_MODE_ARM)

		self.mu.hook_add(UC_HOOK_CODE, self.hook_insn_bytes)
		self.mu.hook_add(UC_HOOK_BLOCK, self.trace_block)
		self.mu.hook_add(UC_HOOK_MEM_READ, self.trace_mem_read)
		self.mu.hook_add(UC_HOOK_MEM_WRITE, self.trace_mem_write)
		self.mu.hook_add(UC_HOOK_MEM_READ_UNMAPPED, self.trace_unmapped)
		self.mu.hook_add(UC_HOOK_MEM_WRITE_UNMAPPED, self.trace_unmapped)
		self.mu.hook_add(UC_HOOK_MEM_FETCH_UNMAPPED, self.trace_unmapped)

		self.insnhooks = {} # instruction word -> handler
		self.fetchhooks = {} # address -> handler

		self.termaddr = 1 << 61 # Pseudoaddress upon which to terminate execution
		self.mu.mem_map(self.termaddr, 0x1000)
		self.mu.mem_write(self.termaddr, '\x1F\x20\x03\xD5') # NOP

		# One hook per destination register for the instruction words
		# 0xD53BD060 + i; each writes the current thread's TLS base into Xi.
		# The double lambda binds `i` per iteration (avoids late binding).
		for i in xrange(30):
			self.hookinsn(0xD53BD060 + i, (lambda i: lambda _, __: self.tlshook(i))(i))

		self.svch = SvcHandler(self)

		# reset() iterates self.mappings, so it must exist before the call.
		self.mappings = []

		self.reset()
		self.enableFP()

		self.mu.mem_map(inlines.magicBase, 0x1000)

		self.execfunc = None
		self.initialized = True

	def reset(self):
		"""Return the emulator to a pristine pre-run state.

		Clears debugger state, unmaps every region recorded in self.mappings,
		then re-creates the MMIO window, handle table, thread manager, heap
		and stack, and finally writes 0 to registers 0..31 through reg()."""
		self.debugging = False
		self.started = False
		self.restarting = False
		self.singlestep = False
		self.mainGlobalScope = None

		self.skipbp = False
		
		# Drop everything mapped through map() since the previous reset.
		for addr, size in self.mappings:
			self.mu.mem_unmap(addr, size)
		self.mappings = []
		self.checkmaps = {} # base -> initialized-byte bitmap (memcheck)
		self.checktriggers = [] # recent (addr, size) ranges that break on bad reads

		self.usHeapSize = 0

		# Lay the MMIO devices out back to back starting at mmiobase; each
		# entry is (physical base, virtual base, size, device instance).
		self.mmiobase = 1 << 58
		self.mmiosize = 0
		self.mmiomap = []
		for cls in mmio.mmioClasses:
			self.mmiomap.append((cls.physbase, self.mmiobase + self.mmiosize, cls.size, cls(self)))
			self.mmiosize += cls.size
		self.map(self.mmiobase, self.mmiosize)

		self.writehooks = {}
		self.readhooks = {}

		self.handles = HandleJar(self)
		self.handleIter = 0xd000 # next handle number handed out by newHandle()

		# Two fixed handles; closeHandle() refuses to close either of them.
		self.handles[0xFFFF8001] = Process(0x1234)
		self.handles[0xDEADBEEF] = Process(0xDEAD)

		self.threads = ThreadManager(self)
		self.threadIter = 0 # instructions executed in the current slice

		self.exports = {}

		self.funcReplacements = {}

		self.loadbase = 0
		self.loadsize = 0

		# Heap grows upward from 7 << 24 ...
		self.heapbase = 7 << 24
		self.heapsize = 32 * 1024 * 1024 # 32MB
		self.heapoff = 0
		self.map(self.heapbase, self.heapsize)

		# ... and the stack occupies the 8MB directly below the same address.
		self.stacktop = 7 << 24
		self.stacksize = 8 * 1024 * 1024 # 8MB
		self.map(self.stacktop - self.stacksize, self.stacksize)

		# check=False: zero-filling must not mark the memory as initialized.
		self.writemem(self.heapbase, '\0' * self.heapsize, check=False)
		self.writemem(self.stacktop - self.stacksize, '\0' * self.stacksize, check=False)

		for i in xrange(32):
			self.reg(i, 0)

	@property
	def threadId(self):
		if self.threads.current is None:
			return '?'
		else:
			return str(self.threads.current.id)

	def newHandle(self, obj):
		i = self.handleIter
		self.handleIter += 1
		self.handles[i] = obj
		return i

	def replaceHandle(self, old, new):
		"""Make every handle that refers to object `old` refer to `new` instead."""
		self.handles.replace(old, new)

	def closeHandle(self, handle):
		if handle == 0xDEADBEEF or handle == 0xFFFF8001:
			return
		elif handle in self.handles:
			obj = self.handles[handle]
			print 'Closing handle:', obj
			if hasattr(obj, 'close'):
				obj.close()
			del self.handles[handle]

	def map(self, base, size):
		if (base & 0xFFF) != 0:
			off = base & 0xFFF
			base -= off
			size += off
		if (size & 0xFFF) != 0:
			size = (size & 0xFFFFFFFFFFFFF000) + 0x1000
		if (base, size) not in self.mappings:
			self.mappings.append((base, size))
			self.mu.mem_map(base, size)
			if self.flags & TRACE_MEMCHECK:
				self.checkmaps[base] = [0] * (size >> 3)

	def unmap(self, base, size):
		if (base & 0xFFF) != 0:
			off = base & 0xFFF
			base -= off
			size += off
		if (size & 0xFFF) != 0:
			size = (size & 0xFFFFFFFFFFFFF000) + 0x1000
		if (base, size) in self.mappings:
			del self.mappings[self.mappings.index((base, size))]
			self.mu.mem_unmap(base, size)
			if self.flags & TRACE_MEMCHECK:
				del self.checkmaps[base]

	def getmap(self, addr):
		for base, size in self.mappings:
			if base <= addr < base + size:
				return base, size
		return -1, -1

	def checkread(self, addr, size):
		"""Memcheck: report a read that touches a byte never marked as written.

		Does nothing unless TRACE_MEMCHECK is set. Only the first offending
		byte is reported. The debugger is entered when that byte lies in the
		first 0x100 bytes of the current thread's TLS or inside one of the
		ranges in self.checktriggers."""
		if not (self.flags & TRACE_MEMCHECK):
			return
		miss = None
		base, rsize = self.getmap(addr)
		for i in xrange(size):
			caddr = addr + i
			# Re-resolve the mapping when the read crosses out of the current
			# one; bytes outside every recorded mapping are skipped.
			if not (base <= caddr < base + rsize):
				base, rsize = self.getmap(caddr)
				if base == -1:
					continue
			off = caddr - base
			# Bitmap lookup: element off >> 3, bit off & 7.
			if (self.checkmaps[base][off >> 3] & (1 << (off & 7))) == 0:
				miss = caddr
				break
		# With no current thread, 1 << 64 keeps both TLS comparisons false.
		tlsbase = self.threads.current.tlsbase if self.threads.current is not None else 1 << 64
		if miss is not None:
			# NOTE(review): this line dereferences self.threads.current without
			# the None guard used for tlsbase above.
			print '[%s:%s] Read from uninitialized memory at %s (reading %i bytes from %s)' % (self.threadId, raw(self.threads.current.lastinsn), raw(miss), size, raw(addr))
			if tlsbase <= miss < tlsbase + 0x100:
				self.debugbreak()
			else:
				for taddr, tsize in self.checktriggers:
					if taddr <= miss < taddr + tsize:
						self.debugbreak()
		elif addr == tlsbase and size == 4:
			# A clean 4-byte read exactly at the TLS base marks those bytes as
			# unwritten again.
			self.checkwrite(addr, size, unset=True)

	def checkwrite(self, addr, size, unset=False, trigger=False):
		"""Memcheck: mark `size` bytes at `addr` as written (or unwritten).

		unset   -- clear the bits instead of setting them.
		trigger -- also remember (addr, size) in self.checktriggers, which
		           holds at most four ranges (oldest dropped first).
		Does nothing unless TRACE_MEMCHECK is set."""
		if not (self.flags & TRACE_MEMCHECK):
			return
		base, rsize = self.getmap(addr)
		for i in xrange(size):
			caddr = addr + i
			inside = base <= caddr < base + rsize
			if not inside:
				# Crossed out of the current mapping: look the byte up again
				# and skip it when nothing covers it.
				base, rsize = self.getmap(caddr)
				if base == -1:
					continue
			off = caddr - base
			word, mask = off >> 3, 1 << (off & 7)
			bitmap = self.checkmaps[base]
			if unset:
				bitmap[word] &= 0xFF ^ mask
			else:
				bitmap[word] |= mask

		if not trigger:
			return
		self.checktriggers.append((addr, size))
		if len(self.checktriggers) == 5:
			self.checktriggers.pop(0)

	def setup(self, func):
		"""Remember `func` as the entry function that runExecFunc() will call."""
		self.execfunc = func

	def load(self, dn):
		load = yaml.load(file(dn + '/load.yaml'))

		if 'nro' in load and not 'nxo' in load:
			load['nxo'] = load['nro']
		elif 'nso' in load and not 'nxo' in load:
			load['nxo'] = load['nso']

		if 'bundle' in load:
			self.loadmemory(dn + '/' + load['bundle'])
		elif 'mod' in load:
			self.loadmod(dn + '/' + load['mod'])
		elif 'nxo' in load:
			if not isinstance(load['nxo'], list):
				load['nxo'] = [load['nxo']]
			ibase = 0x7100000000
			self.loadbase = ibase
			allImports = []
			for name in load['nxo']:
				print 'Loading', name
				fn = dn + '/' + name
				if os.path.exists(fn):
					imports, exports = self.loadnso(fn, loadbase=ibase)
				else:
					imports, exports = self.loadnro(fn + '.nro', loadbase=ibase)
				self.exports.update(exports)
				allImports.append(imports)
				ibase += 0x100000000
			self.loadsize = ibase - self.loadbase
			if True:#self.firstLoad and len(load['nso']) == 1:
				Address.display_specialized = False

			for imports in allImports:
				for name, (addr, addend) in imports.items():
					if name in self.exports:
						self.write64(addr, self.exports[name] + addend)
					else:
						print 'Unresolved import:', name

		if 'maps' in load:
			for name, (base, fn) in load['maps'].items():
				mapLoader(dn + '/' + fn, name, base)

		if self.mainGlobalScope is not None:
			self.mainGlobalScope.update(util.addressTypes)

		self.firstLoad = False

	def runExecFunc(self):
		if self.execfunc is None:
			return

		self.mainGlobalScope = self.execfunc.func_globals
		self.execfunc(self)

	def run(self, flags=0):
		"""Reset the emulator and run the registered entry function.

		The flags in force before the reset are kept and OR-ed with `flags`."""
		fl = self.flags
		self.reset()
		self.flags = fl | flags
		self.runExecFunc()

	def enableFP(self):
		"""Run a small stub at address 0 to enable floating point access.

		The stub is called with X0 = 3 << 20 and must return a value whose
		bits 20-21 are both set; the scratch page is unmapped afterwards.

		Raises AssertionError when the returned value does not have both
		bits set.
		"""
		addr = 0
		self.mu.mem_map(addr, 0x1000)
		# Presumably, from the encodings (confirm with a disassembler):
		#   mrs x1, cpacr_el1 ; orr x0, x0, x1 ; msr cpacr_el1, x0
		#   mrs x0, cpacr_el1 ; ret
		self.mu.mem_write(addr, '\x41\x10\x38\xd5\x00\x00\x01\xaa\x40\x10\x18\xd5\x40\x10\x38\xd5\xc0\x03\x5f\xd6')
		# The call must not live inside the assert: under `python -O` asserts
		# are removed, which used to skip running the stub altogether.
		result = self.call(addr, 3 << 20)
		assert (result >> 20) & 3 == 3
		self.mu.mem_unmap(addr, 0x1000)

	def loadmod(self, fn):
		"""Load a raw image containing a MOD0 header and register it as 'Main'.

		The 32-bit value at file offset 4 locates the MOD0 header. The load
		address is the 64-bit value read 0x20 bytes past a second offset
		stored at MOD0+0x18. File contents at or beyond the BSS start are
		dropped and the BSS range is mapped without being written.

		Raises AssertionError when the MOD0 magic is missing.
		"""
		# Context manager so the handle is closed even if parsing fails.
		with open(fn, 'rb') as fp:
			data = fp.read()

		moff, = struct.unpack('<I', data[4:8])
		assert data[moff:moff+4] == 'MOD0'

		# BSS bounds are stored relative to the MOD0 header.
		bssStart, bssEnd = struct.unpack('<II', data[moff+0x08:moff+0x10])
		bssStart, bssEnd = bssStart + moff, bssEnd + moff
		moff += struct.unpack('<I', data[moff+0x18:moff+0x1C])[0]
		base, = struct.unpack('<Q', data[moff+0x20:moff+0x28])

		overlength = 0
		if bssStart < len(data):
			# The file extends into BSS: cut it there and map the BSS as
			# extra space after the image.
			data = data[:bssStart]
			overlength = bssEnd - bssStart
		else:
			self.map(base + bssStart, bssEnd - bssStart)

		self.map(base, len(data) + overlength)
		self.writemem(base, data)

		defineAddressClass('Main', base, len(data))

	def loadnso(self, fn, loadbase=0x7100000000, relocate=True):
		"""Load an NSO0 module at `loadbase`.

		The three LZ4-compressed segments are decompressed and laid out at
		their recorded locations, the image plus BSS is mapped, and an
		address class named after the file is defined.

		Returns whatever relocation.relocate() returns when `relocate` is
		true (load() unpacks it as (imports, exports)); otherwise None.
		Raises AssertionError when the NSO0 magic is missing.
		"""
		# Context manager so the handle is closed even if parsing fails.
		with open(fn, 'rb') as fp:
			data = fp.read()
		assert data[0:4] == 'NSO0'

		# (file offset, location in image, decompressed size) per segment.
		toff, tloc, tsize = struct.unpack('<III', data[0x10:0x1C])
		roff, rloc, rsize = struct.unpack('<III', data[0x20:0x2C])
		doff, dloc, dsize = struct.unpack('<III', data[0x30:0x3C])
		bsssize, = struct.unpack('<I', data[0x3C:0x40])

		# Each compressed segment is taken to end where the next one starts.
		text = lz4.block.decompress(data[toff:roff], uncompressed_size=tsize)
		rd = lz4.block.decompress(data[roff:doff], uncompressed_size=rsize)
		data = lz4.block.decompress(data[doff:], uncompressed_size=dsize)

		# Place each segment at its location: zero-pad up to it, or truncate
		# what is already there when the previous segment runs past it.
		full = text
		if rloc >= len(full):
			full += '\0' * (rloc - len(full))
			full += rd
		else:
			full = full[:rloc] + rd
		if dloc >= len(full):
			full += '\0' * (dloc - len(full))
			full += data
		else:
			full = full[:dloc] + data

		self.map(loadbase, len(full) + bsssize)
		self.writemem(loadbase, full)
		# Class name: file base name up to the first '.', title-cased.
		defineAddressClass(fn.rsplit('/', 1)[-1].split('.', 1)[0].title(), loadbase, len(full))

		if relocate:
			return relocation.relocate(self, loadbase)

	def loadnro(self, fn, loadbase=0x7100000000, relocate=True):
		"""Load an NRO0 module at `loadbase`.

		Segments are stored uncompressed; they are laid out at their
		recorded locations, padded up to the BSS start, and mapped together
		with the page-rounded BSS. An address class named after the file is
		defined.

		Returns whatever relocation.relocate() returns when `relocate` is
		true (load() unpacks it as (imports, exports)); otherwise None.
		Raises AssertionError when the NRO0 or MOD0 magic is missing.
		"""
		# Context manager so the handle is closed even if parsing fails.
		with open(fn, 'rb') as fp:
			data = fp.read()
		assert data[0x10:0x14] == 'NRO0'

		tloc, tsize, rloc, rsize, dloc, dsize = struct.unpack('<IIIIII', data[0x20:0x20 + 6 * 4])
		modoff, = struct.unpack('<I', data[4:8])
		assert data[modoff:modoff+4] == 'MOD0'
		bssoff, bssend = struct.unpack('<II', data[modoff+8:modoff+16])
		bsssize = bssend - bssoff

		text = data[tloc:tloc+tsize]
		rd = data[rloc:rloc+rsize]
		data = data[dloc:dloc+dsize]

		# Place each segment at its location: zero-pad up to it, or truncate
		# what is already there when the previous segment runs past it.
		full = text
		if rloc >= len(full):
			full += '\0' * (rloc - len(full))
			full += rd
		else:
			full = full[:rloc] + rd
		if dloc >= len(full):
			full += '\0' * (dloc - len(full))
			full += data
		else:
			full = full[:dloc] + data

		if len(full) < bssoff:
			full += '\0' * (bssoff - len(full))

		if bsssize & 0xFFF:
			bsssize = (bsssize & 0xFFFFF000) + 0x1000

		self.map(loadbase, len(full) + bsssize)
		try:
			self.writemem(loadbase, full)
		except Exception:
			# Still best effort (report and carry on), but no longer swallows
			# KeyboardInterrupt/SystemExit as the bare `except:` did.
			import traceback
			traceback.print_exc()
		# Class name: file base name up to the first '.', title-cased.
		defineAddressClass(fn.rsplit('/', 1)[-1].split('.', 1)[0].title(), loadbase, len(full))

		if relocate:
			return relocation.relocate(self, loadbase)

	def loadmemory(self, fn):
		"""Load a memory bundle: a header followed by (address, length, data) regions.

		If `fn` is missing but `fn`.gz exists, the gzip file is first
		decompressed to `fn`. Every region is mapped and written, then the
		'Main' and 'Wkc' address classes are defined from the header's two
		base addresses and the region sizes summed below."""
		if not os.path.isfile(fn) and os.path.isfile(fn + '.gz'):
			with gzip.GzipFile(fn + '.gz', 'rb') as ifp:
				with file(fn, 'wb') as ofp:
					print 'Decompressing membundle'
					ofp.write(ifp.read())
					print 'Done!'

		with file(fn, 'rb') as fp:
			# Header: region count (u32), main base (u64), wkc base (u64).
			regions, mainbase, wkcbase = struct.unpack('<IQQ', fp.read(20))
			rmap = []
			for i in xrange(regions):
				# Per region: address (u64), data length (u32), then the data.
				addr, dlen = struct.unpack('<QI', fp.read(12))
				data = fp.read(dlen)
				self.map(addr, dlen)
				rmap.append((addr, dlen))
				self.writemem(addr, data)

		# Walk the regions in address order. Starting at each base address,
		# sizes are accumulated for as long as each region begins exactly
		# where the previous one ended (`last`); the first gap stops it.
		mainsize = 0
		wkcsize = 0
		inMain = inWKC = False
		last = 0
		rmap.sort(key=lambda x: x[0])
		for (addr, dlen) in rmap:
			if addr == mainbase:
				inMain = True
				last = addr
			elif addr == wkcbase:
				inWKC = True
				last = addr

			if (inMain or inWKC) and last != addr:
				inMain = inWKC = False
			elif inMain:
				mainsize += dlen
				last = addr + dlen
			elif inWKC:
				wkcsize += dlen
				last = addr + dlen

		defineAddressClass('Main', mainbase, mainsize)
		defineAddressClass('Wkc', wkcbase, wkcsize)

	def findMmioObj(self, virtaddr):
		if self.mmiobase <= virtaddr < self.mmiobase + self.mmiosize:
			for pbase, vbase, size, obj in self.mmiomap:
				if vbase <= virtaddr < vbase + size:
					return obj, virtaddr - vbase + pbase
		return None, 0

	def trace_mem_read(self, mu, access, addr, size, value, user_data):
		obj, paddr = self.findMmioObj(addr)
		if obj is not None:
			nval = obj.read(paddr, size)
			if nval is None:
				nval = 0
			if size == 1:
				self.write8(addr, nval)
			elif size == 2:
				self.write16(addr, nval)
			elif size == 4:
				self.write32(addr, nval)
			elif size == 8:
				self.write64(addr, nval)
		#if addr == 0x710062b698:
		#	if size == 4:
		#		self.write32(addr, 0xdeadbeef)
		if self.flags & TRACE_MEMORY:
			value = None
			if size == 1:
				value = self.read8(addr, check=False)
			elif size == 2:
				value = self.read16(addr, check=False)
			elif size == 4:
				value = self.read32(addr, check=False)
			elif size == 8:
				value = self.read64(addr, check=False)
			print '[%s:%s] %i byte read from %s = %s' % (self.threadId, raw(self.threads.current.lastinsn), size, raw(addr), '0x%x' % value if value is not None else 'unmapped')

		if addr in self.readhooks:
			val = self.readhooks[addr](self, size)
			if val is not None:
				if size == 1:
					self.write8(addr, val)
				elif size == 2:
					self.write16(addr, val)
				elif size == 4:
					self.write32(addr, val)
				elif size == 8:
					self.write64(addr, val)
				print '[%s:%s] %i detoured byte read from %s = %s' % (self.threadId, raw(self.threads.current.lastinsn), size, raw(addr), '0x%x' % val if val is not None else 'unmapped')
			#self.debugbreak()
		self.checkread(addr, size)

	def trace_mem_write(self, mu, access, addr, size, value, user_data):
		obj, paddr = self.findMmioObj(addr)
		if obj is not None:
			obj.write(paddr, size, value)
		if self.flags & TRACE_MEMORY:
			print '[%s:%s] %i byte write to %s = 0x%x' % (self.threadId, raw(self.threads.current.lastinsn), size, raw(addr), value)
		if self.flags & TRACE_MEMCHECK:
			self.checkwrite(addr, size)
		if addr in self.writehooks:
			if size == 1:
				self.write8(addr, value)
			elif size == 2:
				self.write16(addr, value)
			elif size == 4:
				self.write32(addr, value)
			elif size == 8:
				self.write64(addr, value)

			if self.writehooks[addr](self, addr, size, value):
				del self.writehooks[addr]

	def trace_unmapped(self, mu, access, addr, size, value, user_data):
		"""Unicorn hook for unmapped fetches, reads and writes.

		Prints what was attempted and always drops into the debugger."""
		if access == UC_MEM_FETCH_UNMAPPED:
			print '[%s:%s] Unmapped fetch of %s' % (self.threadId, raw(self.threads.current.lastinsn), raw(addr))
		elif access == UC_MEM_READ_UNMAPPED:
			print '[%s:%s] Unmapped %i byte read from %s' % (self.threadId, raw(self.threads.current.lastinsn), size, raw(addr))
		elif access == UC_MEM_WRITE_UNMAPPED:
			print '[%s:%s] Unmapped %i byte write to %s = 0x%x' % (self.threadId, raw(self.threads.current.lastinsn), size, raw(addr), value)
		self.debugbreak()

	def trace_insn(self, mu, addr, size, user_data):
		"""Disassemble and print the instruction at `addr`, followed by X0.

		Silent until __init__ has finished (self.initialized)."""
		if not self.initialized:
			return
		for ins in self.md.disasm(str(mu.mem_read(addr, size)), addr):
			print "[%s] 0x%08x:    %s  %s" % (self.threadId, ins.address, ins.mnemonic, ins.op_str)
			print 'x0=0x%x' % self.reg(0)

	def trace_block(self, mu, addr, size, user_data):
		# Unicorn block hook: prints the block address when TRACE_BLOCK is set
		# and __init__ has finished. The triple-quoted string below is
		# disabled code kept as a bare string expression; it never executes.
		if not self.initialized or (self.flags & TRACE_BLOCK) == 0:
			return
		print '[%s] Block at %s' % (self.threadId, raw(addr, pad=True))
		"""if addr == MainAddress(0x1ec928) or addr == MainAddress(0x1ec9e0) or addr == MainAddress(0x1ecab4):
			print '\nFATAL:\n%s\n%s\n%s\n' % (
				self.readmem(self.reg(0), 0x100).split('\0', 1)[0], 
				self.readmem(self.reg(1), 0x100).split('\0', 1)[0], 
				self.readmem(self.reg(2), 0x100).split('\0', 1)[0]
			)
			self.stop()"""

	def trace_func(self, mu, addr, size, user_data):
		"""Maintain the current thread's call stack, one instruction at a time.

		When the previous instruction was flagged as a call (thread.blx),
		`addr` is pushed as a function entry. The instruction at `addr` is
		then classified with the bit masks below: a match on the first two
		sets thread.blx for the next instruction, a match on the third pops
		the stack. Arrows are printed only when TRACE_FUNCTION is set."""
		thread = self.threads.current
		if not self.initialized or thread is None:
			return
		if thread.blx:
			thread.callstack.append(addr)
			if self.flags & TRACE_FUNCTION:
				# NOTE(review): plen is only used by the commented-out print.
				plen = len('[%s]' % self.threadId + ' ' + '  ' * len(thread.callstack))
				#print ' ' * plen + '-> X0 -- %s  X1 -- %s' % (raw(self.reg(0)), raw(self.reg(1)))
				print '[%s]' % self.threadId, '  ' * len(thread.callstack) + colorDepth(len(thread.callstack)) + '-> %s' % raw(addr), Style.RESET_ALL
			thread.blx = False
		# NOTE(review): read32 returns None on an unmapped address, which
		# would make the mask tests below raise TypeError.
		insn = self.read32(addr)

		# Mask/match pairs, presumably for BL, BLR and RET -- confirm
		# against the A64 encoding tables.
		bl_mask  = 0b11101100 << 24
		bl_match = 0b10000100 << 24

		blr_mask  = 0b011011111010 << 20
		blr_match = 0b010001100010 << 20

		ret_mask = 0b011011110110 << 20
		ret_match = 0b010001100100 << 20

		if (insn & bl_mask) == bl_match or (insn & blr_mask) == blr_match:
			thread.blx = True
		elif (insn & ret_mask) == ret_match:
			if self.flags & TRACE_FUNCTION:
				if len(thread.callstack):
					plen = len('[%s]' % self.threadId + ' ' + '  ' * len(thread.callstack))
					# The indent and colour are computed before the pop() at
					# the end of the expression runs.
					print '[%s]' % self.threadId, '  ' * len(thread.callstack) + colorDepth(len(thread.callstack)) + '<- %s' % raw(thread.callstack.pop()), Style.RESET_ALL
					#print ' ' * plen + '<- X0 -- %s' % raw(self.reg(0))
			elif len(thread.callstack):
				thread.callstack.pop()

	def hook_insn_bytes(self, mu, addr, size, user_data):
		"""Unicorn per-instruction hook; the statement order below matters.

		1. Count the instruction and switch threads at the end of a slice.
		2. Run an inline or replacement function registered for `addr` and
		   return to the link register (X30) instead of executing the code.
		3. Run fetch hooks, then handle breakpoints, single-stepping and
		   watchpoints.
		4. Emit instruction/function traces.
		5. Run an instruction hook registered for the instruction word."""
		self.threads.switched = False
		self.threadIter += 1
		if self.threadIter >= INSN_PER_SLICE:
			self.threadIter = 0
			if self.threads.next(pcOffset=0):
				return
		thread = self.threads.current
		if addr in inlines.reverse:
			inlines.reverse[addr](self)
			self.pc = self.reg(30)
			self.threads.current.blx = False
			return
		elif addr in self.funcReplacements:
			func = self.funcReplacements[addr]
			func()
			# Return to the caller only if the replacement did not redirect
			# the PC itself.
			if self.pc == addr:
				self.pc = self.reg(30)
			if thread.blx:
				thread.blx = False
			elif self.flags & TRACE_FUNCTION:
				if len(thread.callstack):
					print '[%s]' % self.threadId, '  ' * len(thread.callstack) + colorDepth(len(thread.callstack)) + '<- %s' % raw(thread.callstack.pop()), Style.RESET_ALL
			return

		if self.restarting:
			return

		if addr in self.fetchhooks:
			self.fetchhooks[addr]()

		# skipbp is set after a breakpoint/watchpoint stop so that resuming
		# does not immediately stop again on the same instruction.
		if self.skipbp and not self.singlestep:
			self.skipbp = False
		elif self.singlestep or addr in self.breakpoints:
			if self.singlestep:
				self.singlestep = False
			else:
				print 'Breakpoint at %s' % raw(addr)
				self.skipbp = True
			self.debugbreak()
		else:
			for code, func in self.watchpoints:
				if func(self):
					print 'Watchpoint %s triggered at %s' % (code, raw(addr))
					self.skipbp = True
					self.debugbreak()
					break

		# trace_func always runs (it maintains the call stack); with both
		# traces on, a pending call entry is printed before the instruction.
		if self.flags & TRACE_INSTRUCTION and self.flags & TRACE_FUNCTION:
			if self.threads.current is not None and self.threads.current.blx:
				self.trace_func(mu, addr, size, user_data)
				self.trace_insn(mu, addr, size, user_data)
			else:
				self.trace_insn(mu, addr, size, user_data)
				self.trace_func(mu, addr, size, user_data)
		elif self.flags & TRACE_INSTRUCTION:
			self.trace_insn(mu, addr, size, user_data)
			self.trace_func(mu, addr, size, user_data)
		else:
			self.trace_func(mu, addr, size, user_data)

		self.threads.current.lastinsn = addr

		insn, = struct.unpack('<I', self.mu.mem_read(addr, 4))

		# A handler returning exactly False means "skip this instruction".
		if insn in self.insnhooks:
			if self.insnhooks[insn](self, addr) == False:
				self.pc += 4

	def hookinsn(self, insn, func=None):
		def sub(func):
			assert insn not in self.insnhooks
			self.insnhooks[insn] = func
		if func is None:
			return sub
		sub(func)

	def hookfetch(self, addr, func=None):
		"""Register `func` to run when the instruction at `addr` is executed.

		`addr` is normalised with native(). With `func` omitted, returns a
		registering function usable as a decorator (which returns None).
		Only one hook per address is allowed."""
		addr = native(addr)
		def register(handler):
			assert addr not in self.fetchhooks
			self.fetchhooks[addr] = handler
		if func is not None:
			register(func)
			return None
		return register

	def hookread(self, addr):
		"""Return a decorator registering a read hook for `addr`.

		`addr` is normalised with native(). The returned registering
		function returns None; only one hook per address is allowed."""
		target = native(addr)
		def register(handler):
			assert target not in self.readhooks
			self.readhooks[target] = handler
		return register

	def hookwrite(self, addr, func=None):
		"""Register `func` to run when memory at `addr` is written.

		trace_mem_write calls the hook as func(ctu, addr, size, value) and
		removes it when it returns a true value. With `func` omitted,
		returns a registering function usable as a decorator (which returns
		None). Only one hook per address is allowed.

		Two fixes over the previous version: the duplicate check looked in
		self.fetchhooks instead of self.writehooks, and `addr` was not
		normalised with native() as hookfetch/hookread do, while the lookup
		in trace_mem_write uses the plain address supplied by the emulator.
		"""
		addr = native(addr)
		def sub(func):
			assert addr not in self.writehooks
			self.writehooks[addr] = func
		if func is None:
			return sub
		sub(func)

	def replaceFunction(self, addr):
		"""Return a decorator that replaces the emulated function at `addr`.

		The Python function receives this CTU followed by X0..Xn-1, where n
		is its positional argument count minus one. A tuple or list result
		is written to X0, X1, ... in order; any other non-None result is
		written to X0. The decorator returns the function unchanged."""
		target = native(addr)

		def decorate(func):
			nargs = func.__code__.co_argcount - 1

			def thunk():
				values = [self.reg(idx) for idx in xrange(nargs)]
				result = func(self, *values)
				if isinstance(result, (tuple, list)):
					for idx, item in enumerate(result):
						self.reg(idx, item)
				elif result is not None:
					self.reg(0, result)

			self.funcReplacements[target] = thunk
			thunk.original = func

			return func

		return decorate

	def tlshook(self, reg):
		"""Write the current thread's TLS base into register `reg`.

		Returns False, which hook_insn_bytes treats as "skip the hooked
		instruction"."""
		self.reg(reg, self.threads.current.tlsbase)
		return False

	def call(self, pc, *args, **kwargs):
		"""Create a thread at `pc` with `args` and, if idle, run the emulator.

		pc     -- an address, or a name present in self.exports.
		_start -- keyword-only; when true, two thread registers are preset
		          (see below).
		When the emulator is already running, the thread is only created.
		Returns X0 as left in the emulator. Raises Restart when a restart
		was requested during the run; exits the process when self.exiting
		is set."""
		_start = kwargs['_start'] if '_start' in kwargs else False

		if pc in self.exports:
			print 'Calling', pc
			pc = self.exports[pc]
		thread = self.threads.create(native(pc), native(self.stacktop), *map(native, args))
		if _start:
			# Presumably regs[0] is the PC (it is read back as the next pc
			# below) and regs[2 + n] holds Xn -- confirm in ThreadManager.
			thread.regs[0+2] = 0
			thread.regs[1+2] = thread.handle

		if not self.started:
			self.started = True
			first = True
			# Keep restarting emulation for as long as a thread switch
			# happened or runnable threads remain.
			while first or (not self.exiting and (self.threads.switched or len(self.threads.running))):
				first = False
				self.threads.current.thaw()
				try:
					self.mu.emu_start(native(pc), self.termaddr + 4)
				except:
					# NOTE(review): bare except also catches SystemExit and
					# KeyboardInterrupt.
					import traceback
					traceback.print_exc()
					print 'Exception at %s' % raw(self.threads.current.lastinsn)
					self.dumpregs()
					break
				if self.threads.current is not None:
					pc = self.threads.current.regs[0]

			if self.exiting:
				sys.exit(0)

			self.threads.clear()
			self.started = False

		if self.restarting:
			self.restarting = False
			raise Restart()

		return self.mu.reg_read(UC_ARM64_REG_X0)

	def stop(self):
		"""Redirect the PC to the termination pseudo-address and flag exiting.

		call() exits the process once emulation stops with self.exiting set."""
		self.mu.reg_write(UC_ARM64_REG_PC, self.termaddr)
		self.exiting = True

	def malloc(self, size):
		self.heapoff += size
		assert self.heapoff <= self.heapsize
		return self.heapbase + self.heapoff - size

	def free(self, ptr):
		# Intentionally a no-op: malloc() is a bump allocator and never
		# reuses memory.
		pass # Lol.

	def writemem(self, addr, data, check=True):
		"""Write the byte string `data` to emulator memory at `addr`.

		check -- when true, also record the bytes as written for memcheck.
		Returns True on success, False when Unicorn rejects the write."""
		try:
			addr = native(addr)
			self.mu.mem_write(addr, data)

			if check:
				self.checkwrite(addr, len(data))
			return True
		except unicorn.UcError:
			return False

	def write8(self, addr, data, check=True):
		"""Write `data` as an unsigned 8-bit value; returns writemem()'s result."""
		payload = struct.pack('<B', data)
		return self.writemem(addr, payload, check=check)
	def write16(self, addr, data, check=True):
		"""Write `data` as an unsigned little-endian 16-bit value; returns writemem()'s result."""
		payload = struct.pack('<H', data)
		return self.writemem(addr, payload, check=check)
	def write32(self, addr, data, check=True):
		"""Write `data` as an unsigned little-endian 32-bit value; returns writemem()'s result."""
		payload = struct.pack('<I', data)
		return self.writemem(addr, payload, check=check)
	def write64(self, addr, data, check=True):
		"""Write `data` as an unsigned little-endian 64-bit value; returns writemem()'s result."""
		payload = struct.pack('<Q', data)
		return self.writemem(addr, payload, check=check)

	def readmem(self, addr, size, check=True):
		"""Read `size` bytes of emulator memory at `addr` as a byte string.

		check -- when true and TRACE_MEMCHECK is set, run the memcheck read
		         check first.
		Returns None when Unicorn rejects the read."""
		try:
			addr = native(addr)
			if check and self.flags & TRACE_MEMCHECK:
				self.checkread(addr, size)
			return str(self.mu.mem_read(addr, size))
		except unicorn.UcError:
			return None

	def read8(self, addr, check=True):
		"""Read an unsigned 8-bit value, or None when readmem() fails."""
		blob = self.readmem(addr, 1, check=check)
		if blob is None:
			return None
		return struct.unpack('<B', blob)[0]
	def readS8(self, addr, check=True):
		"""Read a signed 8-bit value, or None when readmem() fails."""
		blob = self.readmem(addr, 1, check=check)
		if blob is None:
			return None
		return struct.unpack('<b', blob)[0]
	def read16(self, addr, check=True):
		"""Read an unsigned little-endian 16-bit value, or None when readmem() fails."""
		blob = self.readmem(addr, 2, check=check)
		if blob is None:
			return None
		return struct.unpack('<H', blob)[0]
	def readS16(self, addr, check=True):
		"""Read a signed little-endian 16-bit value, or None when readmem() fails."""
		blob = self.readmem(addr, 2, check=check)
		if blob is None:
			return None
		return struct.unpack('<h', blob)[0]
	def read32(self, addr, check=True):
		"""Read an unsigned little-endian 32-bit value, or None when readmem() fails."""
		blob = self.readmem(addr, 4, check=check)
		if blob is None:
			return None
		return struct.unpack('<I', blob)[0]
	def readS32(self, addr, check=True):
		"""Read a signed little-endian 32-bit value, or None when readmem() fails."""
		blob = self.readmem(addr, 4, check=check)
		if blob is None:
			return None
		return struct.unpack('<i', blob)[0]
	def read64(self, addr, check=True):
		"""Read an unsigned little-endian 64-bit value, or None when readmem() fails."""
		blob = self.readmem(addr, 8, check=check)
		if blob is None:
			return None
		return struct.unpack('<Q', blob)[0]
	def readS64(self, addr, check=True):
		"""Read a signed little-endian 64-bit value, or None when readmem() fails."""
		blob = self.readmem(addr, 8, check=check)
		if blob is None:
			return None
		return struct.unpack('<q', blob)[0]

	def readstring(self, addr):
		if addr is None:
			return None
		ret = ''
		while True:
			c = self.read8(addr)
			addr += 1
			if c == 0 or c is None:
				return ret
			ret += chr(c)

	def memregions(self):
		"""Yield (begin, end, perms) covering the whole 64-bit address space.

		Mapped regions come from Unicorn with `end` made exclusive; the gaps
		between them are yielded with perms == -1."""
		lastend = 0
		for begin, end, perms in sorted(self.mu.mem_regions(), key=lambda x: x[0]):
			if begin > lastend:
				yield lastend, begin, -1
			# Unicorn reports an inclusive end address.
			yield begin, end + 1, perms
			lastend = end + 1
		if lastend != 1 << 64:
			yield lastend, 1 << 64, -1

	def reg(self, i, val=None):
		"""Read or write a general-purpose register.

		i   -- register index 0..31 (30 = LR, 31 = SP), or a
		       case-insensitive name: 'X0'..'X31', 'LR', 'SP'.
		val -- when given, written to the register after native().
		Returns the register value when reading, True after a write, and
		None when `i` does not name a register."""
		# Name -> index table, rebuilt on every call.
		sr = {'LR': 30, 'SP': 31}
		for ri in xrange(32):
			sr['X%i' % ri] = ri

		if isinstance(i, str) and i.upper() in sr:
			i = sr[i.upper()]

		# X0..X28 are addressed as an offset from UC_ARM64_REG_X0, X29/X30
		# as an offset from UC_ARM64_REG_X29, and SP has its own constant.
		# NOTE(review): an unrecognised string reaches these comparisons
		# unchanged; under Python 2 ordering it ends up in the else branch.
		if i <= 28:
			c = UC_ARM64_REG_X0 + i
		elif i == 29 or i == 30:
			c = UC_ARM64_REG_X29 + i - 29
		elif i == 31:
			c = UC_ARM64_REG_SP
		else:
			return None

		if val is None:
			return self.mu.reg_read(c)
		else:
			self.mu.reg_write(c, native(val))
			return True

	@property
	def pc(self):
		"""The emulator's program counter; assigning to it redirects execution."""
		return self.mu.reg_read(UC_ARM64_REG_PC)

	@pc.setter
	def pc(self, val):
		# Unlike reg(), the value is written without passing through native().
		self.mu.reg_write(UC_ARM64_REG_PC, val)

	def dumpregs(self):
		sr = {30: 'LR', 31: 'SP'}
		print '-' * 52
		for i in xrange(0, 32, 2):
			an = sr[i] if i in sr else 'X%i' % i
			bn = sr[i + 1] if i + 1 in sr else 'X%i' % (i + 1)
			an += ' ' * (3 - len(an))
			bn += ' ' * (3 - len(bn))
			print '%s - 0x%016x    %s - 0x%016x' % (
				an, self.reg(i),
				bn, self.reg(i + 1)
			)
		print '-' * 52
		print

	def dumpmem(self, addr, size, check=False):
		"""Print a hex/ASCII dump of `size` bytes at `addr`, 16 bytes per row.

		Prints a notice instead when the memory cannot be read."""
		addr = native(addr)
		data = self.readmem(addr, size, check=check)
		if data is None:
			print 'Unmapped memory at %s' % raw(addr)
			return
		data = map(ord, data)

		# Address column width: number of hex digits of the end address.
		fmt = '%%0%ix |' % (int(math.log(addr + size, 16)) + 1)
		for i in xrange(0, len(data), 16):
			print fmt % (addr + i),
			ascii = ''
			for j in xrange(16):
				if i + j < len(data):
					print '%02x' % data[i + j], 
					# Printable ASCII is shown as-is, everything else as '.'.
					if 0x20 <= data[i+j] <= 0x7E:
						ascii += chr(data[i+j])
					else:
						ascii += '.'
				else:
					# Pad a short final row so the ASCII column lines up.
					print '  ', 
					ascii += ' '
				if j == 7:
					# Extra gap between the two 8-byte halves.
					print '',
					ascii += ' '
			print '|', ascii

	def reprompt(self):
		if self.started:
			self.prompt = '[%s] ctu %s> ' % (self.threadId, raw(self.mu.reg_read(UC_ARM64_REG_PC)))
		else:
			self.prompt = 'ctu> '

	def debug(self, sub=False):
		"""Run the interactive command loop until a command ends it.

		self.sublevel counts nested loops (1 = top level). Ctrl-C restarts
		the loop instead of leaving it. `sub` is accepted but not used here."""
		self.debugging = True
		self.reprompt()
		try:
			self.sublevel += 1
			while True:
				try:
					self.cmdloop()
					break
				except KeyboardInterrupt:
					print
		finally:
			self.sublevel -= 1
		# Leaving a nested loop back into the top-level one: plain prompt.
		if self.sublevel == 1:
			self.prompt = 'ctu> '

	def debugbreak(self):
		"""Drop into a nested debugger loop from inside emulation.

		If a command raises Restart, the restart is flagged and emulation is
		stopped via stop()."""
		try:
			self.debug(sub=True)
		except Restart:
			self.restarting = True
			return self.stop()

	def print_topics(self, header, cmds, cmdlen, maxcol):
		"""Cmd override: hide EOF and the one-letter aliases from help listings.

		Nothing is printed when `header` is None."""
		nix = 'EOF', 'b', 'c', 's', 'r', 't'
		if header is not None:
			Cmd.print_topics(self, header, [cmd for cmd in cmds if cmd not in nix], cmdlen, maxcol)

	def do_EOF(self, line):
		# End-of-input (Ctrl-D): ask for confirmation before exiting. A second
		# EOF at the confirmation prompt exits without asking again.
		# self.exiting is set first so call() also exits if the SystemExit is
		# caught by its bare except.
		print
		try:
			if raw_input('Really exit? y/n: ').startswith('y'):
				self.exiting = True
				sys.exit()
		except EOFError:
			print
			self.exiting = True
			sys.exit()
	def do_exit(self, line):
		"""exit
		Exit the debugger."""
		sys.exit()

	def do_start(self, line):
		"""s/start
		Start or restart the code."""
		# From a nested loop (mid-emulation), unwind with Restart.
		if self.sublevel != 1:
			raise Restart()

		# Top level: reset and run, repeating for as long as the run itself
		# ends with a Restart request.
		while True:
			self.reset()
			try:
				self.runExecFunc()
				break
			except Restart:
				print 'got restart at', self.sublevel
				continue
	do_s = do_start

	def do_trace(self, line):
		"""t/trace (i/instruction | b/block | f/function | m/memory)
		Toggles tracing of instructions, blocks, functions, or memory."""
		if line.startswith('i'):
			self.flags ^= TRACE_INSTRUCTION
			print 'Instruction tracing', 'on' if self.flags & TRACE_INSTRUCTION else 'off'
		elif line.startswith('b'):
			self.flags ^= TRACE_BLOCK
			print 'Block tracing', 'on' if self.flags & TRACE_BLOCK else 'off'
		elif line.startswith('f'):
			self.flags ^= TRACE_FUNCTION
			print 'Function tracing', 'on' if self.flags & TRACE_FUNCTION else 'off'
		elif line.startswith('m'):
			self.flags ^= TRACE_MEMORY
			print 'Memory tracing', 'on' if self.flags & TRACE_MEMORY else 'off'
		else:
			print 'Unknown trace flag'
	do_t = do_trace

	def do_memcheck(self, line):
		"""mc/memcheck
		Toggles memory access validations."""
		# NOTE(review): map() only creates a bitmap for regions mapped while
		# the flag is set, so regions mapped earlier have no checkmaps entry.
		self.flags ^= TRACE_MEMCHECK
		print 'Memcheck', 'on' if self.flags & TRACE_MEMCHECK else 'off'
	do_mc = do_memcheck

	def do_break(self, addr):
		"""b/break [name]
		Without `name`, list breakpoints.
		Given a symbol name or address, toggle breakpoint."""
		if addr == '':
			print 'Breakpoints:'
			for addr in self.breakpoints:
				print '*', addr
			return

		# raw() resolves the text; BadAddr signals that it could not.
		try:
			addr = raw(addr)
		except BadAddr:
			print 'Invalid address/symbol'
			return

		if addr in self.breakpoints:
			print 'Removing breakpoint at %s' % addr
			self.breakpoints.remove(addr)
		else:
			print 'Breaking at %s' % addr
			self.breakpoints.add(addr)
	do_b = do_break
	def complete_break(self, text, line, begidx, endidx):
		"""Tab-completion for symbol names.

		Matches against everything typed after the command, then returns
		only the part of each symbol that lines up with `text`."""
		ftext = line.split(' ', 1)[1] if ' ' in line else ''
		# Number of leading characters of the argument that precede `text`.
		cut = len(ftext) - len(text)
		return [sym[cut:] for sym in symbols.keys() if sym.startswith(ftext)]
	complete_b = complete_break

	def do_bt(self, line):
		"""bt
		Prints the call stack."""
		print 'Call stack:'
		# Innermost frame first (the tracked stack is reversed).
		for i, x in enumerate(self.threads.current.callstack[::-1]):
			print '%03i: %s' % (i, raw(x))

	def do_sym(self, name):
		"""sym <name>
		Prints the address of a given symbol."""
		try:
			print raw(name)
		except BadAddr:
			print 'Invalid address/symbol'
	# Symbols complete the same way as for break.
	complete_sym = complete_break

	def do_continue(self, line):
		"""c/continue
		Continues execution of the code."""
		# sublevel 1 is the top-level loop: nothing is being emulated.
		if self.sublevel == 1:
			print 'Not running'
		else:
			# A true return value ends the nested cmdloop, resuming emulation.
			return True
	do_c = do_continue

	def do_next(self, line):
		"""n/next
		Step to the next instruction."""
		if self.sublevel == 1:
			print 'Not running'
		else:
			# hook_insn_bytes breaks again at the next instruction it sees.
			self.singlestep = True
			return True
	do_n = do_next

	def do_regs(self, line):
		"""r/reg/regs [reg [value]]
		No parameters: Display registers.
		Reg parameter: Display one register.
		Otherwise: Assign a value (always hex, or a symbol) to a register."""
		if line == '':
			return self.dumpregs()
		elif ' ' in line:
			r, v = line.split(' ', 1)
			try:
				v = raw(v)
				if self.reg(r, v) is None:
					print 'Invalid register'
			except BadAddr:
				print 'Invalid address/Symbol'
		else:
			v = self.reg(line)
			if v is False:
				print 'Invalid register'
			else:
				print '0x%016x' % v
	do_r = do_reg = do_regs

	def do_exec(self, line):
		"""x/exec <code>
		Evaluates a given line of C."""
		try:
			val = ceval(line, self)
		except:
			# Any failure in the evaluated expression is reported, not raised.
			import traceback
			traceback.print_exc()
			print 'Execution failed'
			return

		if val is not None:
			print '0x%x' % val
	do_x = do_exec

	def do_dump(self, line):
		"""dump <address> [size]
		Dumps `size` (default: 0x100) bytes of memory at an address.
		If the address takes the form `*register` (e.g. `*X1`) then the value of that register will be used."""
		line = list(line.split(' '))
		if len(line[0]) == 0:
			print 'No address'
		elif len(line) <= 2:
			if len(line[0]) and line[0][0] == '*':
				# `*REG`: use the register's current value as the address.
				line[0] = self.reg(line[0][1:])
				if line[0] is None:
					print 'Invalid register'
					return
			else:
				try:
					line[0] = raw(line[0])
				except BadAddr:
					print 'Invalid address/symbol'
					return
			if len(line) == 2:
				line[1] = parseInt(line[1])
				# Sizes of 0x10000 and above are refused.
				if line[1] is None or line[1] >= 0x10000:
					print 'Invalid size'
					return
			self.dumpmem(line[0], 0x100 if len(line) == 1 else line[1])
		else:
			print 'Too many parameters'

	def do_save(self, line):
		"""save <address> <size> <fn>
		Writes `size` bytes of memory to a file.
		If the address or size takes the form `*register` (e.g. `*X1`) then the value of that register will be used."""

		line = list(line.split(' '))
		addr, size, fn = line
		addr = self.reg(addr[1:]) if addr.startswith('*') else parseInt(addr)
		size = self.reg(size[1:]) if size.startswith('*') else parseInt(size)
		with file(fn, 'wb') as fp:
			fp.write(self.readmem(addr, size))
		print 'Wrote to file'

	def do_ad(self, line):
		"""ad
		Toggle address display specialization."""
		Address.display_specialized = not Address.display_specialized
		print '%s specialized address display' % ('Enabled' if Address.display_specialized else 'Disabled')
		# The prompt shows an address, so rebuild it in the new format.
		self.reprompt()

	def do_watch(self, line):
		"""w/watch [expression]
		Breaks when expression evaluates to true.
		Without an expression, list existing watchpoints."""
		if line == '':
			print 'Watchpoints:'
			for code, _ in self.watchpoints:
				print '*', code
			return

		# Entering the exact text of an existing watchpoint removes it.
		if line in [code for code, _ in self.watchpoints]:
			self.watchpoints = [(code, func) for code, func in self.watchpoints if code != line]
			print 'Watchpoint deleted'
		else:
			# `compile` is the one imported from ceval, not the builtin.
			self.watchpoints.append((line, compile(line)))
			print 'Watchpoint added'
	do_w = do_watch

	def do_memregions(self, line):
		"""mr/memregions
		Displays mapped memory regions."""
		print 'Mapped memory regions'
		print '---------------------'
		for begin, end, perms in self.memregions():
			# memregions() also yields the unmapped gaps, marked perms == -1.
			if perms != -1:
				print '%016x - %016x' % (begin, end)
	do_mr = do_memregions

def debug(*flags):
	"""Decorator: build a CTU around the function and open its debugger.

	Usable bare (`@debug`) or with TRACE_* flags (`@debug(TRACE_BLOCK)`).
	The emulator is created and its command loop entered at decoration
	time; the decorated function is returned unchanged."""
	bare = len(flags) == 1 and callable(flags[0])
	target = flags[0] if bare else None
	if bare:
		flags = [TRACE_NONE]

	def sub(func):
		ctu = CTU()
		ctu.setup(func)
		mask = TRACE_NONE
		for flag in flags:
			mask = mask | flag
		ctu.flags |= mask
		ctu.debug()
		return func

	if bare:
		return sub(target)
	return sub

def run(*flags):
	"""Decorator: build a CTU around the function and run it immediately.

	Usable bare (`@run`) or with TRACE_* flags (`@run(TRACE_BLOCK)`). The
	emulator is created and run at decoration time; the decorated function
	is returned unchanged."""
	bare = len(flags) == 1 and callable(flags[0])
	target = flags[0] if bare else None
	if bare:
		flags = [TRACE_NONE]

	def sub(func):
		ctu = CTU()
		ctu.setup(func)
		mask = TRACE_NONE
		for flag in flags:
			mask = mask | flag
		ctu.flags |= mask
		ctu.run()
		return func

	if bare:
		return sub(target)
	return sub