# coding=utf-8
__author__ = 'Anatoli Kalysch'

import operator
import re
from _collections import defaultdict
from copy import deepcopy

from TraceOptimizations import *
from dynamic.TraceRepresentation import Trace, Traceline
from idaapi import *
from idautils import *
from idc import *
from lib.VMRepresentation import get_vmr, VMContext

#############################################
### VISUALIZATION AND RESULT PRESENTATION ###
#############################################


def visualize_cli(cluster):
    """
    Visualize the cluster on the console or IDA Output window.

    Single trace lines are printed with a '- single:' prefix; clusters are printed with a
    '+ cluster <first addr> - <last addr>:' header followed by their lines. Empty clusters
    are skipped.
    :param cluster: clustered trace (Tracelines and lists of Tracelines) or a plain trace
    """
    print('')
    print('')
    for elem in cluster:
        if isinstance(elem, Traceline):
            print('- single:' + elem.to_str_line())
        elif isinstance(elem, list) and elem:
            print('')
            print("+ cluster %s - %s:" % (hex(elem[0].addr), hex(elem[-1].addr)))
            for member in elem:
                print(' ' + member.to_str_line())


# #######################################
# ###   CLUSTERING AND RELATED FUNC   ###
# #######################################


def len_check(cluster):
    """
    Length check for a clustered trace of trace lines.

    :param cluster: clustered trace (Tracelines and lists of Tracelines)
    :return: number of Tracelines in the trace, counting every member of every cluster
    """
    length = 0
    for elem in cluster:
        if isinstance(elem, Traceline):
            length += 1
        elif isinstance(elem, list):
            length += len(elem)
    return length


def get_addr(op):
    """
    Get the address of a trace line, or (recursively) the list of addresses for a list of trace lines.

    :param op: Traceline or (possibly nested) list of Tracelines
    :return: address, list of addresses, or None for any other type
    """
    if isinstance(op, Traceline):
        # simple element
        return op.addr
    elif isinstance(op, list):
        # recursive in case list of lists
        return [get_addr(elem) for elem in op]


def address_count(trace):
    """
    Count the different occurrences of the addresses in the trace.

    :param trace: execution trace of the binary
    :return: list of (addr, count) tuples, sorted from the highest to the lowest count
    """
    # one pass instead of list.count() per line (which was quadratic)
    counts = defaultdict(int)
    for line in trace:
        counts[line.addr] += 1
    # sort the analysis result by most common addresses
    sorted_result = sorted(counts.items(), key=operator.itemgetter(1))
    sorted_result.reverse()
    return sorted_result


def repetition_cluster_round(cluster_list):
    """
    One round of repetition cluster analysis.

    Neighbouring elements (0,1), (2,3), ... form candidate pairs. For a pair whose two elements
    occur equally often in cluster_list, every adjacent occurrence of the pair (compared by
    address) is searched; if there are more than vmr.cluster_magic occurrences, each occurrence
    is merged into one cluster (a list of Tracelines). Afterwards trace lines at BADADDR and
    clusters that end up empty are removed.
    :param cluster_list: list of clusters (Tracelines and/or lists of Tracelines); modified in place
    :return: the same cluster_list object
    """
    vmr = get_vmr()
    assert isinstance(cluster_list, list)
    test_length = len_check(cluster_list)
    candidate_pairs = [[cluster_list[i], cluster_list[i + 1]] for i in range(0, len(cluster_list) - 1, 2)]
    # each pair is tested for validity
    for cluster in candidate_pairs:
        if cluster_list.count(cluster[0]) != cluster_list.count(cluster[1]):
            continue
        occurrence = 0
        pop_indexes = []
        try:
            first_addr = get_addr(cluster[0])
            second_addr = get_addr(cluster[1])
            for j in range(len(cluster_list) - 1):
                # if they are adjacent, they are labeled valid
                if get_addr(cluster_list[j]) == first_addr and get_addr(cluster_list[j + 1]) == second_addr:
                    pop_indexes.append(j + 1)
                    occurrence += 1
            if occurrence > vmr.cluster_magic:
                # the pair repeats often enough: merge the second element into the first one
                # every earlier pop shifts the following indexes one to the left
                pop_ctr = 0
                for ind in pop_indexes:
                    addition = cluster_list.pop(ind - pop_ctr)
                    pop_ctr += 1
                    pos = ind - pop_ctr
                    base = cluster_list[pos]
                    if isinstance(base, Traceline):
                        if isinstance(addition, Traceline):
                            cluster_list[pos] = [base, addition]
                        elif isinstance(addition, list):
                            cluster_list[pos] = [base] + addition
                    elif isinstance(base, list):
                        if isinstance(addition, Traceline):
                            base.append(addition)
                        elif isinstance(addition, list):
                            base.extend(addition)
        except Exception as e:
            # best effort: a failing pair must not abort the whole round
            print(e)

    # merging must not lose trace elements; checked before the deliberate BADADDR clean up below
    assert test_length == len_check(cluster_list)

    # clean up cluster list; rebuilt instead of removing while iterating (which skipped elements)
    cleaned = []
    for cluster in cluster_list:
        if isinstance(cluster, list):
            # [Traceline, Traceline, ...]
            cluster[:] = [line for line in cluster if line.addr != BADADDR]
            if cluster:
                cleaned.append(cluster)
        elif cluster and cluster.addr != BADADDR:
            cleaned.append(cluster)
    cluster_list[:] = cleaned
    return cluster_list


def create_bb_diff(bb, ctx_reg_size, prev_line_ctx):
    """
    Condense a basic block into one representative Traceline.

    Addr and thread id are taken from the last line of the block. Every register is shown as
    'before -> after' if its value differs between the first and the last line of the block,
    otherwise with its final value; registers only present at the end are copied. Disasm and
    comments are chosen by a heuristic (e.g. jumps and register-only movs are dropped, stores
    to memory are turned into '[addr]=value' comments).
    :param bb: basic block, non-empty list of Tracelines
    :param ctx_reg_size: register size passed to get_reg when resolving register names
    :param prev_line_ctx: context of the line executed right before the basic block
    :return: Traceline summarizing the basic block
    :raises Exception: if a register of prev_line_ctx is missing in the last line's context
    """
    first = bb[0]
    last = bb[-1]
    # compare key *sets*: dict key order says nothing about which registers are present
    keys_f = set(prev_line_ctx.keys())
    keys_l = set(last.ctx.keys())
    if not keys_f <= keys_l:
        # registers may appear during a basic block but never vanish; should not happen in normal execution
        raise Exception('[*] Keys at the end of basic block %s-%s were LESS than at the beginning!'
                        % (first.addr, last.addr))
    context = {}
    for key in keys_f:
        if first.ctx[key] != last.ctx[key]:
            context[key] = first.ctx[key] + ' -> ' + last.ctx[key]
        else:
            context[key] = last.ctx[key]
    for key in keys_l - keys_f:
        context[key] = last.ctx[key]

    disasm = []
    comment = []
    last_ctx = prev_line_ctx
    for idx, line in enumerate(bb):
        if line.comment is not None:
            comment.append(line.comment)
        if line.disasm[0].startswith('mov'):
            # skip this mov if the next line is a mov whose destination has the same register
            # class (get_reg_class of both destinations; presumably None == None for two memory stores)
            try:
                nxt = bb[idx + 1]
                if nxt.disasm[0].startswith('mov') and get_reg_class(nxt.disasm[1]) == get_reg_class(line.disasm[1]):
                    continue
            except Exception:
                pass
            if line.disasm[1].startswith('[') and line.disasm[1].endswith(']'):
                comment.append(line.disasm[1] + '=' + line.disasm[2])
            elif get_reg_class(line.disasm[1]) is not None:
                continue
        elif line.disasm[0].startswith('j'):
            continue
        elif line.comment is not None and len(line.disasm) == 3 and line.disasm[1].startswith('['):
            # memory destination: extend the comment of this line by '<mnemonic> <source value>'
            if get_reg_class(line.disasm[2]) is not None:
                comment[-1] = comment[-1] + ' ' + line.disasm[0] + ' ' + last_ctx[get_reg(line.disasm[2], ctx_reg_size)]
            else:
                comment[-1] = comment[-1] + ' ' + line.disasm[0] + ' ' + line.disasm[2]
        elif line.comment is not None and len(line.disasm) == 3 and line.disasm[2].startswith('['):
            # memory source: extend the comment of this line by '<mnemonic> <destination value>'
            if get_reg_class(line.disasm[1]) is not None:
                comment[-1] = comment[-1] + ' ' + line.disasm[0] + ' ' + last_ctx[get_reg(line.disasm[1], ctx_reg_size)]
            else:
                comment[-1] = comment[-1] + ' ' + line.disasm[0] + ' ' + line.disasm[1]
        disasm.append(line.disasm)
        last_ctx = line.ctx
    return Traceline(addr=last.addr, thread_id=last.thread_id, ctx=context, disasm=disasm, comment=comment)


def _strip_brackets(operand):
    """Return operand without any '[' or ']' characters, e.g. '[ebp-4]' -> 'ebp-4'."""
    return ''.join(c for c in operand if c not in '[]')


def extract_stack_change(line, stack_changes):
    """
    Extract the stack changes (= '[addr]=value' stack comments) of a line into stack_changes.

    For every address the history of written values is kept as 'v1->v2->...'; a value equal to
    the latest recorded one is not appended again. Malformed comments are reported and skipped.
    :param line: a trace line whose comment is an iterable of comment strings
    :param stack_changes: dict mapping stack address -> value history; missing or 0 means "nothing recorded yet"
    :return: tuple (line, updated stack_changes)
    """
    for comment in filter(None, line.comment):
        try:
            addr, value = _strip_brackets(comment).split('=')
            previous = stack_changes.get(addr, 0)
            if previous == 0:
                stack_changes[addr] = value
            elif not previous.endswith(value):
                stack_changes[addr] = previous + '->' + value
            # else: value is already the latest entry -> keep the recorded history
        except Exception as e:
            print(e)
            print(e.args)
    return line, stack_changes


def create_cluster_gist(cluster, ctx_reg_size, prev_line_ctx, stack_changes):
    """
    Subdivide a cluster into basic blocks and collect the stack changes of all blocks.

    For each bb a representative traceline is created (create_bb_diff) which consists of
    relevant instructions, relevant stack changes and the difference in the registers between
    first and last bb line. Lines after the last basic block end form a final block.
    :param cluster: list of Tracelines
    :param ctx_reg_size: register size passed on to create_bb_diff
    :param prev_line_ctx: context of the line executed right before the cluster
    :param stack_changes: stack_changes dict, see extract_stack_change
    :return: updated stack_changes
    """
    bbs = []
    bb = []
    # subdivide the cluster by basic blocks
    for line in cluster:
        bb.append(line)
        if is_basic_block_end(line.addr):
            bbs.append(bb)
            bb = []
    # do not silently drop trailing lines that do not end in a basic block end
    if bb:
        bbs.append(bb)
    for bb in bbs:
        bb_gist = create_bb_diff(bb, ctx_reg_size, prev_line_ctx)
        bb_gist, stack_changes = extract_stack_change(bb_gist, stack_changes)
        prev_line_ctx = bb[-1].ctx
    return stack_changes


def repetition_clustering(trace, **kwargs):
    """
    Cluster the trace into groups of repeating instructions (=clusters) and non-repeating instructions (=singles).

    The passed trace list itself is clustered in place by repetition_cluster_round.
    :param trace: instruction trace
    :param kwargs: rounds=number of clustering rounds; if missing/0, rounds are repeated until
                   the number of elements no longer changes (greedy)
    :return: list where an element is either a list of Tracelines (=cluster) or a Traceline (=single)
    :raises Exception: if trace is None
    """
    rounds = kwargs.get('rounds', None)
    if trace is None:
        raise Exception("[*] Empty trace, nothing to cluster!")
    clusters_final = trace
    if rounds:
        for _ in range(int(rounds)):
            clusters_final = repetition_cluster_round(clusters_final)
    else:
        # assuming greedy, since it produces best results in most cases
        while True:
            before = len(clusters_final)
            clusters_final = repetition_cluster_round(clusters_final)
            if len(clusters_final) == before:
                break
    return clusters_final


def cluster_removal(trace, **kwargs):
    """
    Remove the lines of the most frequently executed addresses from the trace.

    The *threshold* highest distinct occurrence counts are determined; every line whose address
    occurs with one of those counts is removed -> often the vm handler routine that fetches the
    next vm instruction.
    :param trace: instruction trace; modified in place
    :param kwargs: threshold=how many distinct occurrence counts to remove (default 1)
    :return: the same trace object without the removed lines
    """
    threshold = kwargs.get('threshold', 1)
    addr_list = address_count(trace)
    # how often are the *threshold* most common addresses repeated?
    top_counts = set()
    for _, count in addr_list:
        top_counts.add(count)
        if len(top_counts) >= threshold:
            break
    # collect the addresses to remove
    kill_addrs = set(addr for addr, count in addr_list if count in top_counts)
    # pop from the back so the indexes still to be visited stay valid
    for idx in range(len(trace) - 1, -1, -1):
        if trace[idx].addr in kill_addrs:
            trace.pop(idx)
    return trace


#####################################
###      VM ANALYSIS FUNCTIONS    ###
#####################################


def find_vm_addr(trace):
    """
    Find the virtual machine function start address.

    The function containing the most traced push instructions is taken as candidate. If another
    function in the same segment is bigger, the user is asked (AskAddr) to choose between both.
    :param trace: instruction trace (not modified)
    :return: virtual function start addr (or the AskAddr result)
    :raises ValueError: if no traced push lies inside a function known to IDA
    """
    push_dict = defaultdict(lambda: 0)
    vm_func_dict = defaultdict(lambda: 0)
    # try to find the vm segment via series of push commands, which identify the vm_addr also
    for line in trace:
        try:
            if line.disasm[0] == 'push':
                func_start = GetFunctionAttr(line.addr, FUNCATTR_START)
                # GetFunctionAttr yields BADADDR for addresses outside any function
                if func_start != BADADDR:
                    push_dict[func_start] += 1
        except Exception:
            pass
    if not push_dict:
        raise ValueError('[*] Could not find the VM function: no traced push instruction lies inside a function!')
    vm_func = max(push_dict, key=push_dict.get)
    vm_seg_start = SegStart(vm_func)
    vm_seg_end = SegEnd(vm_func)
    # test whether the vm_func is the biggest func in the segment
    for func in Functions(vm_seg_start, vm_seg_end):
        vm_func_dict[func] = GetFunctionAttr(func, FUNCATTR_END) - GetFunctionAttr(func, FUNCATTR_START)
    biggest_func = max(vm_func_dict, key=vm_func_dict.get)
    if biggest_func != vm_func:
        return AskAddr(vm_func, "Found two possible addresses for the VM function start address: 0x%X and 0x%X. Choose one!"
                       % (vm_func, biggest_func))
    return vm_func


def extract_vm_segment(trace):
    """
    Identify the VM segment and extract only the VM part of the trace.

    The segment is the first one whose name starts with '.vmp'; if there is none, the segment
    of the VM function found by find_vm_addr is used.
    :param trace: instruction trace (not modified)
    :return: (list of trace lines inside the vm segment, start addr of vm segment, end addr of vm segment)
    """
    vm_seg_start = None
    vm_seg_end = None
    # try to find the vm segment via name -> easiest case but also easy to foil
    for seg in Segments():
        if SegName(seg).startswith('.vmp'):
            vm_seg_start = SegStart(seg)
            vm_seg_end = SegEnd(seg)
            break
    # if that fails, find the vm_function and use its segment
    if vm_seg_start is None or vm_seg_end is None:
        vm_addr = find_vm_addr(trace)
        vm_seg_start = SegStart(vm_addr)
        vm_seg_end = SegEnd(vm_addr)
    # IDA segment end addresses are exclusive: keep [start, end)
    vm_lines = [line for line in trace if vm_seg_start <= line.addr < vm_seg_end]
    return vm_lines, vm_seg_start, vm_seg_end


def dynamic_vm_values(trace, code_start=BADADDR, code_end=BADADDR, silent=False):
    """
    Find the virtual machine context necessary for an automated static analysis.

    code_start = the bytecode start: first non-code head after the end of the vm function,
                 cross-checked against immediates pushed before the vm function is entered
    code_end   = the bytecode end: end of the vm segment unless given
    base_addr  = startaddr of the jmp table: the most used 'off_XXXX[...]' operand in the vm
                 trace, moved back over directly preceding data heads
    vm_addr    = startaddr of the vm function, see find_vm_addr
    :param trace: instruction trace (not modified)
    :param code_start: known bytecode start, BADADDR to search for it
    :param code_end: known bytecode end, BADADDR to use the vm segment end
    :param silent: if False the user is asked (AskAddr) when code_start is not among the pushed params
    :return: VMContext with code_start, code_end, base_addr and vm_addr set
    :raises ValueError: if no jump table offset operand was found in the vm trace
    """
    offset_counts = defaultdict(lambda: 0)
    vm_addr = find_vm_addr(trace)
    trace, vm_seg_start, vm_seg_end = extract_vm_segment(trace)
    code_addrs = []

    # try finding code_start: skip code heads following the end of the vm function
    if code_start == BADADDR:
        code_start = NextHead(GetFunctionAttr(vm_addr, FUNCATTR_END), BADADDR)
        # isCode() tests the *flags* of an address, not the address itself
        while code_start != BADADDR and isCode(GetFlags(code_start)):
            code_start = NextHead(code_start, BADADDR)

    for idx, line in enumerate(trace):
        # construct dict of offsets -> the jmp table should be the one most used
        if len(line.disasm) == 2:
            try:
                offset = re.findall(r'.*:off_([0123456789abcdefABCDEF]*)\[.*\]', line.disasm[1])[0]
                offset_counts[offset] += 1
            except Exception:
                pass
        # code_start additional search of vm_func params
        if line.addr == vm_addr:
            for prior in trace[:idx]:
                if prior.disasm[0] == 'push':
                    try:
                        arg = re.findall(r'.*_([0123456789ABCDEFabcdef]*)', prior.disasm[1])
                        if len(arg) == 1:
                            code_addrs.append(int(arg[0], 16))
                    except Exception as e:
                        print(e)

    # finalize base_addr
    if not offset_counts:
        raise ValueError('[*] Could not find the jump table: no off_XXXX[...] operand in the VM trace!')
    base_addr = int(max(offset_counts, key=offset_counts.get), 16)
    # the most used offset is probably the top of the table, but to be sure move back over
    # preceding heads without a mnemonic (data); stop at the start of the database
    prev_head = PrevHead(base_addr)
    while prev_head != BADADDR and GetMnem(prev_head) == '':
        base_addr = prev_head
        prev_head = PrevHead(base_addr)

    # finalize code_start
    if not silent and code_start not in code_addrs:
        code_start = AskAddr(code_start, "Start of bytecode mismatch! Found %x but parameter for vm seem to be %s"
                             % (code_start, [hex(c) for c in code_addrs]))

    # code_end: without further knowledge the bytecode is assumed to reach the vm segment end
    if code_end == BADADDR:
        code_end = vm_seg_end

    vm_ctx = VMContext()
    vm_ctx.code_start = code_start
    vm_ctx.code_end = code_end
    vm_ctx.base_addr = base_addr
    vm_ctx.vm_addr = vm_addr
    print('%s %s %s %s' % (code_start, code_end, base_addr, vm_addr))
    return vm_ctx


def find_virtual_regs(trace, manual=False, update=None):
    """
    Map the real registers onto the virtual registers on the stack.

    The trace is walked backwards; for the last pop into each register the stack pointer value
    of the preceding line is recorded (the context shows registers after execution). The
    mapping is also stored in vmr.vm_stack_reg_mapping. The trace is not modified.
    :param trace: instruction trace (Trace)
    :param manual: print the mapping to the console
    :param update: optional progress bar object with pbar_update()
    :return: defaultdict register name -> stack address (False for unknown registers)
    """
    vmr = get_vmr()
    assert isinstance(trace, Trace)
    virt_regs = defaultdict(lambda: False)
    # index 0 has no preceding line to take the stack pointer from
    for idx in range(len(trace) - 1, 0, -1):
        elem = trace[idx]
        try:
            if len(elem.disasm) == 0 or elem.disasm[0] != 'pop':
                continue
            opnd = elem.disasm[1]
            if get_reg_class(opnd) is None:
                # if not a register it is a mem_loc
                continue
            if virt_regs[opnd]:
                # a later pop already mapped this register
                continue
            # the context always shows the registers after the execution, so we need the SP from the instruction before
            virt_regs[opnd] = trace[idx - 1].ctx[get_reg('rsp', trace.ctx_reg_size)]
        except Exception:
            pass
    if update is not None:
        update.pbar_update(60)
    vmr.vm_stack_reg_mapping = virt_regs
    if manual:
        print(''.join('%s:%s\n' % (c, virt_regs[c]) for c in virt_regs.keys()))
    return virt_regs


def find_ops_callconv(trace, vmp_seg_start, vmp_seg_end):
    """
    Find params put on the stack before a function call.

    For every trace line inside [vmp_seg_start, vmp_seg_end] the preceding lines are searched
    backwards for call instructions. From each call the search continues backwards and collects
    push operands and the sources of movs into memory whose address uses a register of class 7
    (the esp class) or, for plain [x] operands, any non-register address. A scan stops at a call
    outside the segment once `call_depth` such calls were counted.

    NOTE(review): the whole trace prefix is rescanned for every VM line (roughly cubic in the
    trace length) and `calls` is never reset between scans; both are kept as-is so the
    collected operands stay the same.
    :param trace: instruction trace (Trace)
    :param vmp_seg_start: start of vm segment
    :param vmp_seg_end: end of vm segment (compared inclusively here)
    :return: list of operands (may contain duplicates)
    """
    # call_depth is the number of call instructions the algo goes through to analyze the passed args
    # -> useful if VM consists of more than one function (most have at least 2)
    call_depth = 2
    ops = []
    calls = 0
    for idx, line in enumerate(trace):
        if not vmp_seg_start <= line.addr <= vmp_seg_end:
            continue
        # we search backwards for call inst and then further for stack push or mov to stack addr
        for i in range(idx - 1, 0, -1):
            if not trace[i].disasm[0].startswith('call'):
                continue
            for j in range(i):
                cand = trace[i - j]
                if cand.disasm[0].startswith('call') and not vmp_seg_start <= cand.addr <= vmp_seg_end:
                    calls += 1
                    if calls >= call_depth:
                        break
                # push reg/const
                elif cand.disasm[0].startswith('push'):
                    ops.append(cand.disasm[1])
                # mov instructions; only xsp related
                elif cand.disasm[0].startswith('mov'):
                    try:
                        op1 = re.findall(r'.*\[(.*)\].*', cand.disasm[1])[0]
                        op2 = cand.disasm[2]
                        try:
                            # mov [xsp +/-/* reg/const], reg/const
                            # find math expr + or - or * or /
                            expr = re.findall(r'.*([\+\-\*\/]).*', cand.disasm[1])[0]
                            for part in op1.split(expr):
                                if get_reg_class(part) == 7:  # 7 is the esp class
                                    if get_reg_class(op2) is not None:  # mov [*xsp*], reg
                                        ops.append(cand.ctx[get_reg(op2, trace.ctx_reg_size)])
                                    else:  # mov [*xsp*], const
                                        ops.append(cand.disasm[2])
                                    break
                        except Exception:
                            # mov [xsp]/[mem], reg/const
                            if get_reg_class(op1) == 7 or get_reg_class(op1) is None:  # 7 is the esp class
                                if get_reg_class(op2) is not None:  # mov [xsp], reg
                                    ops.append(cand.ctx[get_reg(op2, trace.ctx_reg_size)])
                                else:  # mov [xsp], const
                                    ops.append(cand.disasm[2])
                    except Exception:
                        # if no [.*] was found, it means the mov instructions were not onto the stack
                        pass
    return ops


def find_input(trace, manual=False, update=None):
    """
    Find input operands to the vm_function.

    Collects (upper-cased) register values of 'inst reg, ss:[...]' lines inside the vm segment,
    the operands found by find_ops_callconv and the known arguments of the vm function from
    vmr.func_args.
    :param trace: instruction trace (Trace, not modified)
    :param manual: console output y/n
    :param update: optional progress bar object with pbar_update()
    :return: a set of operands to the vm_function
    """
    vmr = get_vmr()
    func_args = []
    if vmr.func_args:
        func = GetFunctionName(find_vm_addr(trace))
        try:
            func_args = vmr.func_args[func]
        except KeyError:
            print('[*] No known arguments for function %s' % func)
    ops = set()
    if update is not None:
        update.pbar_update(20)
    # extract_vm_segment builds a new list, the full trace stays intact for find_ops_callconv
    ex_trace, vmp_seg_start, vmp_seg_end = extract_vm_segment(trace)
    if update is not None:
        update.pbar_update(20)
    for line in ex_trace:
        try:
            # case inst reg, ss:[reg]
            op = line.disasm[2]
            # following is ida only
            if op.startswith('ss:'):
                # get the reg value from ctx
                op = line.ctx[get_reg(line.disasm[1], trace.ctx_reg_size)]
                ops.add(op.upper())
        except Exception:
            pass
    try:
        # if we find the .vmp Segment addr or vm-function addr we should check the stack
        for op in find_ops_callconv(trace, vmp_seg_start, vmp_seg_end):
            ops.add(op.upper())  # set will eliminate double entries
    except Exception as e:
        print('[*] Stack parameter analysis failed: %s' % e)
    if update is not None:
        update.pbar_update(30)
    # known function args are added even if the stack analysis failed
    try:
        for op in func_args:
            ops.add(op.upper())
    except Exception as e:
        print('[*] Could not add function arguments: %s' % e)
    if update is not None:
        update.pbar_update(10)
    if manual:
        print('operands: %s' % ''.join('%s | ' % op for op in ops))
    return ops


def find_output(trace, manual=False, update=None):
    """
    Find output operands of the vm_function.

    Uses the context of the last 'ret' or 'pop' line executed inside the vm segment.
    :param trace: instruction trace (Trace, not modified)
    :param manual: console output y/n
    :param update: optional progress bar object with pbar_update()
    :return: set of upper-cased register values of that context (empty if no such line)
    """
    if update is not None:
        update.pbar_update(20)
    ex_trace, vmp_seg_start, vmp_seg_end = extract_vm_segment(trace)
    if update is not None:
        update.pbar_update(20)
    ctx = {}
    for line in reversed(ex_trace):
        if line.disasm and line.disasm[0].startswith(('ret', 'pop')):
            ctx = line.ctx
            break
    if update is not None:
        update.pbar_update(40)
    if manual:
        print(''.join('%s:%s\n' % (c, ctx[c]) for c in ctx.keys() if get_reg_class(c) is not None))
    return set(ctx[get_reg(reg, trace.ctx_reg_size)].upper() for reg in ctx if get_reg_class(reg) is not None)


def follow_virt_reg(trace, **kwargs):
    """
    Follow a virtual register backwards and extract the relevant trace lines.

    Shows how the final value of a register came to be and which values (recursively) it
    consists of. Values and stack addresses to watch are collected while walking the trace
    backwards from the end.
    :param trace: instruction trace (Trace)
    :param kwargs: update=progress bar object, manual=ask for the register and print the result,
                   real_reg_name=register to follow (default rax-sized register),
                   virt_reg_addr=stack addr of the virtual register (default: via find_virtual_regs)
    :return: list of relevant tracelines in execution order, preceded by the line executed
             right before the first relevant one (if any)
    """
    assert isinstance(trace, Trace)
    update = kwargs.get('update', None)
    manual = kwargs.get('manual', False)
    if manual:
        real_reg_name = AskStr('eax', 'Which register do you want followed?')
        if real_reg_name is None:
            real_reg_name = get_reg('rax', trace.ctx_reg_size)
        else:
            real_reg_name = get_reg(real_reg_name, trace.ctx_reg_size)
    else:
        real_reg_name = kwargs.get('real_reg_name', get_reg('rax', trace.ctx_reg_size))
    virt_reg_addr = kwargs.get('virt_reg_addr', None)
    if virt_reg_addr is None:
        # find_virtual_regs only reads the trace, no copy needed
        virt_reg_addr = find_virtual_regs(trace)[real_reg_name]
    if update is not None:
        update.pbar_update(30)

    backtrace = Trace()
    watch_addrs = set()
    reg_vals = set()

    trace = optimization_const_propagation(trace)
    trace = optimization_stack_addr_propagation(trace)
    if update is not None:
        update.pbar_update(10)
    size = trace.ctx_reg_size

    # reversing the trace makes the backward traversal easier: trace[idx + 1] was executed before trace[idx]
    trace.reverse()
    try:
        # get reg value at the last pop into the followed register
        reg = get_reg(real_reg_name, size)
        for line in trace:
            if len(line.disasm) == 2 and line.disasm[0] == 'pop' and get_reg_class(line.disasm[1]) == get_reg_class(reg):
                reg_vals.add(line.ctx[reg])
                break

        watch_addrs.add(virt_reg_addr)
        for idx, line in enumerate(trace):
            assert isinstance(line, Traceline)
            if line.is_jmp:
                continue
            try:
                # +1 because trace is reversed to get to prev element
                prev = trace[idx + 1]
                for val in reg_vals.copy():
                    if val in line.ctx.values() and val not in prev.ctx.values():
                        backtrace.append(line)
                        # if val suddenly appears in the ctx there should be 2 possibilities:
                        # 1. it was moved from mem, so it was on the stack -> watch that stack address
                        if line.is_mov and line.is_op2_mem:
                            watch_addrs.add(_strip_brackets(line.disasm[2]))
                        # 2. it was computed -> if regs played a role in the computation add them to values to watch out for
                        elif not line.is_mov:
                            if line.disasm_len > 2:
                                if line.is_op1_reg:
                                    reg_vals.add(line.ctx[get_reg(line.disasm[1], size)])
                                if line.is_op1_mem:
                                    watch_addrs.add(_strip_brackets(line.disasm[1]))
                                if line.is_op2_reg:
                                    # not necessarily the case for lea
                                    reg_vals.add(line.ctx[get_reg(line.disasm[2], size)])
                                if line.is_op2_mem:
                                    watch_addrs.add(_strip_brackets(line.disasm[2]))
                            elif line.disasm_len == 2:
                                eax = get_reg('eax', size)
                                reg_vals.add(prev.ctx[eax])
                                if line.is_op1_reg:
                                    reg_vals.add(line.ctx[get_reg(line.disasm[1], size)])
                                    reg_vals.add(prev.ctx[get_reg(line.disasm[1], size)])
                                if line.ctx[eax] != prev.ctx[eax]:
                                    reg_vals.add(line.ctx[eax])
                                    reg_vals.add(prev.ctx[eax])
                                if line.disasm[0].startswith('not'):
                                    reg_vals.add(line.ctx[get_reg(line.disasm[1], size)])
                                    reg_vals.add(prev.ctx[get_reg(line.disasm[1], size)])
                                    backtrace.append(prev)
                                    backtrace.append(trace[idx - 1])
                                    try:
                                        reg_vals.add(prev.ctx[get_reg(prev.disasm[1], size)])
                                        reg_vals.add(trace[idx - 1].ctx[get_reg(prev.disasm[1], size)])
                                    except Exception:
                                        pass
            except Exception:
                # best effort: e.g. no previous line or a register missing in a context
                pass

            if watch_addrs:
                for addr in watch_addrs.copy():
                    try:
                        if addr in line.disasm[1]:
                            backtrace.append(line)
                            reg_vals.add(line.disasm[2])
                            # register holding the stored value
                            r = list(line.ctx.keys())[list(line.ctx.values()).index(line.disasm[2])]
                            # search earlier lines for the memory source loaded into that register;
                            # running past the start raises IndexError, which also skips the removal below
                            for i in range(len(trace)):
                                temp = trace[idx + i]
                                if len(temp.disasm) == 3 and temp.disasm[1][-2:] == r[-2:] and get_reg_class(r[-2:]) is not None:
                                    watch_addrs.add(temp.disasm[2][1:-1])
                                    break
                            if line.is_mov:
                                watch_addrs.remove(addr)
                    except Exception:
                        pass

        if update is not None:
            update.pbar_update(40)
        # reverse the reversed backtrace back into execution order
        backtrace.reverse()
        backtrace = [line for line in backtrace
                     if len(line.disasm) < 2 or line.disasm[1] not in ['esi', 'edi', 'ebp', 'rsi', 'rdi', 'rbp']]
        # put the line executed before the first backtrace line in front to see the contextual difference;
        # the trace is still reversed here, so that line is at position + 1
        if backtrace:
            try:
                pos = trace.index(backtrace[0])
            except ValueError:
                pos = None
            if pos is not None and pos + 1 < len(trace):
                backtrace.insert(0, trace[pos + 1])
    finally:
        # restore the order in case the optimization passes returned the caller's trace object
        trace.reverse()

    if manual:
        print('')
        for line in backtrace:
            print(line.to_str_line())
    return backtrace