#!/usr/bin/env python
# Jay Smith
# jay.smith@fireeye.com
#
########################################################################
# Copyright 2015 FireEye
# Copyright 2012 Mandiant
#
# FireEye licenses this file to you under the Apache License, Version
# 2.0 (the "License"); you may not use this file except in compliance with the
# License. You may obtain a copy of the License at:
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
########################################################################
#
# New lib that helps with tracking arguments to functions.
# New version that uses vivisect for analysis
#
########################################################################

import sys
import copy
import struct
import pprint
import logging
import binascii

import idc
import idaapi
import idautils

import jayutils

import vivisect
import vivisect.impemu as viv_imp
import vivisect.impemu.monitor as viv_imp_monitor
from visgraph import pathcore as vg_path

# Integer types accepted as "stack address" argument names. Python 2 has a
# separate ``long`` type; Python 3 only has ``int``.
try:
    _INTEGER_TYPES = (int, long)
except NameError:
    _INTEGER_TYPES = (int,)

# struct format characters for the scalar write sizes that map directly
# onto a single struct field.
_WRITELOG_FORMATS = {1: 'B', 2: 'H', 4: 'I', 8: 'Q'}


def _argSortKey(argName):
    '''
    Sort key for mixed lists of stack argument numbers (ints) and register
    names (strings): integers first, then strings, each in natural order.
    This is the order Python 2 produced when sorting such a mixed list
    directly; Python 3 refuses to compare int with str, hence the key.
    '''
    return (not isinstance(argName, _INTEGER_TYPES), argName)


########################################################################
#
#
########################################################################

class RegMonitor(viv_imp_monitor.EmulationMonitor):
    '''
    This tracks all register changes, even if it's not currently an
    interesting reg because we need to trace register changes backwards.

    After emulation, reg_map maps the start address of each instruction
    that changed at least one register to a dict of
    {register name: value after the instruction}.
    '''

    def __init__(self, regs):
        viv_imp_monitor.EmulationMonitor.__init__(self)
        self.logger = jayutils.getLogger('argracker.RegMonitor')
        self.regs = regs[:]
        self.reg_map = {}
        # Snapshot taken by prehook and compared against in posthook.
        self.cachedRegs = {}
        self.startEip = None

    def prehook(self, emu, op, starteip):
        '''Snapshot the registers before the instruction at starteip runs.'''
        try:
            self.cachedRegs = emu.getRegisters()
            self.startEip = starteip
        except Exception as err:
            self.logger.exception('Error in prehook: %s', str(err))

    def posthook(self, emu, op, endeip):
        '''
        Record every register whose value differs from the prehook
        snapshot, keyed by the address saved in prehook.
        '''
        try:
            curRegs = emu.getRegisters()
            changed = dict(
                (name, val) for name, val in curRegs.items()
                if self.cachedRegs[name] != val
            )
            if changed:
                self.reg_map[self.startEip] = changed
        except Exception as err:
            self.logger.exception('Error in posthook: %s', str(err))


########################################################################
#
#
########################################################################

#maps a va to the vg_path node that contains it in an emu run
def build_emu_va_map(node, **kwargs):
    '''
    Path-walk callback (passed to jayutils.path_bfs in this file): copies
    every entry of the node's log into a dict keyed by the entry's pc.

    Keyword arguments (all required; the call is a no-op if any is None):
        res:     dict that receives {pc: (pc, va, bytes)}
        emu:     the emulator the path came from (only checked for presence)
        logtype: name of the node property to read, e.g. 'writelog' or
                 'readlog'

    If several entries share a pc, the last one visited wins.
    '''
    res = kwargs.get('res')
    emu = kwargs.get('emu')
    logtype = kwargs.get('logtype')
    if (res is None) or (emu is None) or (logtype is None):
        return
    for entry in vg_path.getNodeProp(node, logtype):
        # Unpacking also validates that the entry is a 3-tuple.
        pc, _va, _data = entry
        res[pc] = entry


def formatWriteLogEntry(entry):
    '''Return a "pc: va: hexbytes" string for a (pc, va, bytes) log entry.'''
    pc, va, data = entry
    return '0x%08x: 0x%08x: %s' % (pc, va, binascii.hexlify(data))


def transformWriteLogEntry(entry, bigend=False):
    '''
    Transforms a writelog entry to a (pc, value) tuple.

    entry:  (pc, va, bytes) tuple; bytes must be 1, 2, 4, 8 or 16 long.
    bigend: decode the bytes as big-endian instead of little-endian.

    Raises RuntimeError for any other byte length.
    '''
    pc, _va, data = entry
    blen = len(data)
    order = '>' if bigend else '<'

    fmt = _WRITELOG_FORMATS.get(blen)
    if fmt is not None:
        return (pc, struct.unpack_from(order + fmt, data)[0])

    if blen == 16:
        # No 128-bit struct field exists: read two 64-bit halves and
        # combine them. Which half is the high one depends on byte order.
        first, second = struct.unpack_from(order + 'QQ', data)
        if bigend:
            hi, lo = first, second
        else:
            lo, hi = first, second
        return (pc, (hi << 64) | lo)

    raise RuntimeError('Unexpected len of writelog bytes: %d' % blen)


class TrackerState(object):
    '''
    One backwards-tracing state: the set of arguments still wanted, the
    ones already resolved, and the temporary locations currently being
    followed. States are copied when the backwards walk forks.
    '''

    def __init__(self, tracker, baseEntry, num, regs):
        '''
        tracker:   owning ArgTracker (provides ptrsize, logs and regMon)
        baseEntry: (pc, va, bytes) write log entry of the instruction the
                   trace starts from; its va is used as the stack base
        num:       number of stack arguments wanted (numbered 1..num)
        regs:      list of register names wanted

        desiredState: list of stackArgNums and register names
        '''
        self.tracker = tracker
        self.baseEntry = baseEntry
        self.num = num
        self.regs = regs[:]
        self.ptrsize = tracker.ptrsize
        #resultArgs: {stack arg number or register name: (pc, value)}
        self.resultArgs = {}
        #stackArgLocs: pre-calculated locations we're looking for for stack writes
        self.stackArgLocs = []
        #tempMapping: used to follow data movement backwards. if a value we're interested
        # in is loaded in another register or a memory location, maps the current
        # location (register name or memory address) to the argument being traced
        self.tempMapping = {}
        #desiredState: list of stackArgNums and register names
        self.desiredState = []
        self.setDesiredState(baseEntry, num, regs)
        self.setStackArgLocs(baseEntry, num, regs)

    def copy(self):
        '''Return an independent copy sharing only the tracker reference.'''
        cp = TrackerState(self.tracker, self.baseEntry, self.num, self.regs)
        cp.resultArgs = copy.deepcopy(self.resultArgs)
        cp.tempMapping = copy.deepcopy(self.tempMapping)
        cp.desiredState = copy.deepcopy(self.desiredState)
        return cp

    def __str__(self):
        return '\n'.join([
            '%s: %s,' % (self.getArgNameRep(k), repr(self.resultArgs.get(k)))
            for k in self.desiredState
        ])

    def processWriteLog(self, tracker, cVa):
        '''
        Examine the memory write (if any) performed at cVa. A write to a
        wanted stack slot is saved directly; a write to a traced temporary
        location is either resolved or redirected to the source register.
        '''
        wlogEntry = tracker.va_write_map.get(cVa, None)
        if wlogEntry is None:
            return
        pc, writeVa, _data = wlogEntry

        if (writeVa in self.stackArgLocs) and \
                (self.getStackArgNum(writeVa) not in self.resultArgs):
            #it's a stack arg value
            pc, value = transformWriteLogEntry(wlogEntry)
            self.saveResult(writeVa, pc, value)
            return

        if writeVa not in self.tempMapping:
            return

        #argName: the actual value we're tracing back
        argName = self.tempMapping.pop(writeVa)
        pc, value = transformWriteLogEntry(wlogEntry)

        #we found a temp value tracing backwards, but need to determine if it's a constant
        # or if we need to continue tracing backwards. basically as long as it's not
        # a register, we stop?
        mnem = idc.GetMnem(pc)
        if mnem.startswith('push'):
            srcOpIdx = 0
        elif mnem.startswith('mov'):
            srcOpIdx = 1
        else:
            #TODO: any other data movement instructions need to be traced rather
            # than using the observed write log value?
            self.saveResult(argName, pc, value)
            return

        #process data movements instructions:
        if idc.GetOpType(pc, srcOpIdx) == idc.o_reg:
            #need to trace the new reg now
            newReg = idc.GetOpnd(pc, srcOpIdx)
            self.tempMapping[newReg] = argName
        else:
            #not a register, so currently assuming we can use the stored value
            self.saveResult(argName, pc, value)

    def getArgNameRep(self, argName):
        '''Format integer argument names as hex; return others unchanged.'''
        if isinstance(argName, _INTEGER_TYPES):
            return '0x%08x' % argName
        return argName

    def getStackArgNum(self, writeVa):
        '''Convert a stack address to its pointer-sized slot index from startSp.'''
        return (writeVa - self.startSp) // self.ptrsize

    def saveResult(self, argName, pc, value):
        '''
        Saves a tuple (pc, value) to the found argument. Assumes if argName
        is an integer, it's the address of an expected stack argument. If
        argName is a string, it's a register name for an expected argument.

        Raises RuntimeError for any other argName type.
        '''
        if isinstance(argName, _INTEGER_TYPES):
            argNum = self.getStackArgNum(argName)
            self.resultArgs[argNum] = (pc, value)
        elif isinstance(argName, str):
            self.resultArgs[argName] = (pc, value)
        else:
            raise RuntimeError('Unknown argName type: %s' % type(argName))

    def processRegMon(self, tracker, cVa):
        '''
        Examine the register changes recorded at cVa. A change to a wanted
        or traced register is either resolved to its observed value or
        redirected to the location the value was loaded from.

        Raises RuntimeError if a traced pop has no read log entry.
        '''
        if tracker.regMon is None:
            return
        #figure out if one of the monitored regs is modified in this instruction
        # and if has not already been stored -> just want the first reg value
        regMods = tracker.regMon.reg_map.get(cVa)
        if regMods is None:
            return

        for reg in regMods:
            interesting1 = (reg in self.regs) and (reg not in self.resultArgs)
            interesting2 = (reg in self.tempMapping)
            if (not interesting1) and (not interesting2):
                #modified reg isn't interesting: either a function arg or a temp traced value
                continue

            mnem = idc.GetMnem(cVa)
            argName = reg
            if interesting1:
                self.regs.remove(reg)
            if interesting2:
                argName = self.tempMapping.pop(reg)

            if mnem.startswith('pop'):
                #add the current stack read address to the temporary tracking
                rlogEntry = tracker.va_read_map.get(cVa, None)
                if rlogEntry is None:
                    raise RuntimeError('readlog entry does not exist for a pop')
                _pc, readVa, _data = rlogEntry
                self.tempMapping[readVa] = argName
            elif mnem.startswith('mov') and idc.GetOpType(cVa, 1) == idc.o_reg:
                #change to track this reg backwards
                newReg = idc.GetOpnd(cVa, 1)
                self.tempMapping[newReg] = argName
            else:
                #mov from a non-register, or any other instruction: use the
                # modified result.
                #TODO: any other data movement instructions that should be traced back?
                self.saveResult(argName, cVa, regMods[reg])

    def setStackArgLocs(self, baseEntry, num, regs):
        '''Compute startSp and the stack addresses of arguments 1..num.'''
        # NOTE(review): baseEntry[1] is presumably the slot the call
        # instruction writes its return address to - confirm against caller.
        self.startSp = baseEntry[1]
        # desiredSp: the stack write addresses that correspond to the arguments we want
        self.stackArgLocs = [
            self.startSp + self.ptrsize * (1 + i) for i in range(num)
        ]

    def setDesiredState(self, baseEntry, num, regs):
        '''Build the sorted list of wanted stack arg numbers and registers.'''
        desiredState = [(i + 1) for i in range(num)]
        desiredState.extend(regs)
        self.desiredState = sorted(desiredState, key=_argSortKey)

    def isComplete(self):
        '''
        Return True once every wanted argument has a result.

        Raises RuntimeError if the number of results matches but the keys
        do not, which indicates an unexpected result was stored.
        '''
        if len(self.desiredState) == len(self.resultArgs):
            if self.desiredState == sorted(self.resultArgs.keys(), key=_argSortKey):
                return True
            else:
                raise RuntimeError('Matching len of resultArgs, but not equal!')
        return False


class ArgTracker(object):
    '''
    Recovers the values passed to a call by emulating the containing
    function with vivisect and walking IDA code references backwards
    from the call site.
    '''

    def __init__(self, vw, maxIters=1000):
        '''
        vw:       workspace object providing getEmulator()
        maxIters: upper bound on backwards-walk iterations per query
        '''
        self.logger = jayutils.getLogger('argracker.ArgTracker')
        self.logger.debug('Starting up here')
        self.vw = vw
        self.lastFunc = 0
        # Per-function emulation caches, filled in by getPushArgs.
        self.va_write_map = None
        self.va_read_map = None
        self.regMon = None
        self.codesize = jayutils.getx86CodeSize()
        self.ptrsize = self.codesize // 8
        self.queue = []
        self.maxIters = maxIters

    def printWriteLog(self, wlog):
        '''Debug-log every entry of a write log.'''
        for ent in wlog:
            self.logger.debug(formatWriteLogEntry(ent))

    def isCargsComplete(self, cargs, num, regs):
        '''Return True if cargs has keys 1..num and every name in regs.'''
        return all([(i + 1) in cargs for i in range(num)]) and \
            all([i in cargs for i in regs])

    def getPushArgs(self, va, num, regs=None):
        '''
        num -> first arg is 1, 2nd is 2, ...

        Returns a list of dicts, one per completed backwards path. Keys
        are stack arg numbers (starting at 1, 2.. num) and register names;
        every value is a tuple (pc, value).

        Returns an empty list if va has no write log entry or if tracing
        raised an error.
        '''
        if regs is None:
            regs = []

        funcStart = idc.GetFunctionAttr(va, idc.FUNCATTR_START)
        # NOTE: the function start comes from IDA only; it is not
        # cross-checked against the vivisect workspace.

        #map a every (?) va in a function to the pathnode it was found in
        if funcStart != self.lastFunc:
            emu = self.vw.getEmulator(True, True)
            self.logger.debug('Generating va_write_map for function 0x%08x', funcStart)
            self.regMon = RegMonitor(regs)
            emu.setEmulationMonitor(self.regMon)
            emu.runFunction(funcStart, maxhit=1, maxloop=1)
            #cache the last va_write_map for a given function
            self.va_write_map = {}
            self.va_read_map = {}
            self.lastFunc = funcStart
            jayutils.path_bfs(emu.path, build_emu_va_map, res=self.va_write_map, emu=emu, logtype='writelog')
            jayutils.path_bfs(emu.path, build_emu_va_map, res=self.va_read_map, emu=emu, logtype='readlog')
        else:
            self.logger.debug('Using cached va_write_map')

        baseEntry = self.va_write_map.get(va, None)
        if baseEntry is None:
            self.logger.error('Node does not have write log. Requires a call instruction (which writes to the stack) for this to work: 0x%08x', va)
            return []
        self.startSp = baseEntry[1]
        return self.analyzeTracker(baseEntry, va, num, regs)

    def analyzeTracker(self, baseEntry, va, num, regs):
        '''
        Breadth-first walk backwards from va over IDA code references,
        feeding each visited address to a TrackerState until the state is
        complete, the function start is reached, or maxIters is exceeded.

        Returns the list of resultArgs dicts of every completed state, or
        an empty list if processing an address raised an exception.
        '''
        funcStart = idc.GetFunctionAttr(va, idc.FUNCATTR_START)
        initState = TrackerState(self, baseEntry, num, regs)
        count = 0
        ret = []
        touched = set()
        self.queue = [(va, initState)]
        while self.queue:
            if count > self.maxIters:
                self.logger.error('Max graph traveral iterations reached: (0x%08x) %d. Stopping early. Consider increasing ArgTracker maxIters (unless this is a bug)', va, count)
                break
            cVa, cState = self.queue.pop(0)
            touched.add(cVa)
            try:
                cState.processWriteLog(self, cVa)
                cState.processRegMon(self, cVa)
            except Exception as err:
                self.logger.exception('Error in process: %s', str(err))
                return []

            if cState.isComplete():
                ret.append(cState.resultArgs)
            elif cVa != funcStart:
                # Not complete: queue every unvisited predecessor with its
                # own copy of the state. Hitting the function start ends
                # the path without a result.
                for ref in idautils.CodeRefsTo(cVa, True):
                    if ref not in touched:
                        self.queue.append((ref, cState.copy()))
            count += 1
        return ret


def main():
    '''
    Demo driver: for every cross-reference to CreateThread, recover six
    stack arguments and print the call site plus where and what argument 3
    was set to.
    '''
    #jayutils.configLogger(__name__, logging.DEBUG)
    jayutils.configLogger(__name__, logging.INFO)
    logger = jayutils.getLogger('')
    logger.debug('Starting up in main')
    filePath = jayutils.getInputFilepath()
    if filePath is None:
        logger.info('No input file provided. Stopping')
        return
    vw = jayutils.loadWorkspace(filePath)
    logger.debug('Loaded workspace')
    tracker = ArgTracker(vw)

    funcEa = idc.LocByName('CreateThread')
    if funcEa == idc.BADADDR:
        logger.info('CreateThread not found. Returning now')
        return
    for xref in idautils.XrefsTo(funcEa):
        argsList = tracker.getPushArgs(xref.frm, 6)
        for argDict in argsList:
            print('-' * 60)
            pc, value = argDict[3]
            print('0x%08x: 0x%08x: 0x%08x' % (xref.frm, pc, value))


if __name__ == '__main__':
    main()