import re import struct import logging from capstone import Cs, CS_ARCH_X86, CS_MODE_32, CS_MODE_64 from smda.utility.PriorityQueue import PriorityQueue from .definitions import DEFAULT_PROLOGUES, GAP_SEQUENCES from .LanguageAnalyzer import LanguageAnalyzer from .FunctionCandidate import FunctionCandidate LOGGER = logging.getLogger(__name__) class FunctionCandidateManager(object): def __init__(self, config): self.config = config self.lang_analyzer = None self.disassembly = None self.bitness = None self._code_areas = [] self.candidates = {} self.candidate_queue = [] self.cached_candidates = None self._candidate_offsets = [] self.candidate_index = 0 self._all_call_refs = {} self.symbol_addresses = [] self.identified_alignment = 0 # gap filling self.function_gaps = None self.max_function_addr = 0 self.gap_pointer = None self.previously_analyzed_gap = 0 self.capstone = None def init(self, disassembly): if disassembly.binary_info.code_areas: self._code_areas = disassembly.binary_info.code_areas self.disassembly = disassembly self.lang_analyzer = LanguageAnalyzer(disassembly) self.disassembly.language = self.lang_analyzer.identify() self.bitness = disassembly.binary_info.bitness self.capstone = Cs(CS_ARCH_X86, CS_MODE_32) if self.bitness == 64: self.capstone = Cs(CS_ARCH_X86, CS_MODE_64) self.locateCandidates() self.disassembly.identified_alignment = self.identified_alignment self._buildQueue() def _passesCodeFilter(self, addr): if addr is None: return False if self._code_areas: for area in self._code_areas: if area[0] <= addr < area[1]: return True return False return True def getBitMask(self): if self.bitness == 64: return 0xFFFFFFFFFFFFFFFF return 0xFFFFFFFF def setInitialCandidate(self, addr): if addr in self.candidates: self.candidates[addr].setInitialCandidate(True) def isFunctionCandidate(self, addr): return addr in self.candidates def getFunctionCandidate(self, addr): if addr in self.candidates: return self.candidates[addr] return None def getAbortedCandidates(self): aborted = [] for addr, candidate in self.candidates.items(): if candidate.analysis_aborted: aborted.append(addr) return aborted def updateAnalysisAborted(self, addr, reason): LOGGER.debug("function analysis of 0x%08x aborted: %s", addr, reason) if addr in self.candidates: self.candidates[addr].setAnalysisAborted(reason) def updateAnalysisFinished(self, addr): LOGGER.debug("function analysis of 0x%08x successfully completed.", addr) if addr in self.candidates: self.candidates[addr].setAnalysisCompleted() def updateCandidates(self, state): if self.config.HIGH_ACCURACY: conflicts = state.identifyCallConflicts(self._all_call_refs) if conflicts: for candidate_addr, conflict in conflicts.items(): self.candidates[candidate_addr].removeCallRefs(conflict) self.candidate_queue.update() def addCandidate(self, addr, is_gap=False, reference_source=None): if not self._passesCodeFilter(addr): return False self.ensureCandidate(addr) self.candidates[addr].setIsGapCandidate(is_gap) if reference_source: self.candidates[addr].addCallRef(reference_source) self.candidate_queue.add(self.candidates[addr]) self.candidate_queue.update() def getNextFunctionStartCandidate(self): for candidate in self.candidate_queue: if not (candidate.isFinished() or candidate.getScore() == 0): if self.identified_alignment and candidate.alignment < self.identified_alignment: continue yield candidate def _logCandidateStats(self): logging.debug("Candidate Statistics:") try: maxc = max([c.getScore() for c in self.candidates.values()]) minc = min([c.getScore() for c in self.candidates.values()]) candidates_2 = len([c.getScore() for c in self.candidates.values() if c.getScore() == 2]) candidates_1 = len([c.getScore() for c in self.candidates.values() if c.getScore() == 1]) candidates_0 = len([c.getScore() for c in self.candidates.values() if c.getScore() == 0]) logging.debug(" Max: %f, Min: %f", maxc, minc) logging.debug(" 2: %d, 1: %d, 0: %d", candidates_2, candidates_1, candidates_0) except: logging.debug(" No candidates found.") def getFunctionStartCandidates(self): return self._candidate_offsets def updateFunctionGaps(self): gaps = [] prev_ins = 0 min_code = min(self.disassembly.code_map) if self.disassembly.code_map else self.getBitMask() max_code = max(self.disassembly.code_map) if self.disassembly.code_map else 0 for code_area in self._code_areas: if code_area[0] < min_code < code_area[1] and min_code != code_area[0]: gaps.append([code_area[0], min_code, min_code - code_area[0]]) if code_area[0] < max_code < code_area[1] and max_code != code_area[1]: gaps.append([max_code, code_area[1], code_area[1] - max_code]) for ins in sorted(self.disassembly.code_map.keys()): if prev_ins != 0: if ins - prev_ins > 1: gaps.append([prev_ins + 1, ins, ins - prev_ins]) prev_ins = ins self.function_gaps = sorted(gaps) def initGapSearch(self): if self.gap_pointer is None: LOGGER.debug("initGapSearch()") self.gap_pointer = self.getBitMask() self.updateFunctionGaps() if self.function_gaps: self.gap_pointer = self.function_gaps[0][0] LOGGER.debug("initGapSearch() gaps are:") for gap in self.function_gaps: LOGGER.debug("initGapSearch() 0x%08x - 0x%08x == %d", gap[0], gap[1], gap[2]) return def getNextGap(self, dont_skip=False): next_gap = self.getBitMask() for gap in self.function_gaps: if gap[0] > self.gap_pointer: next_gap = gap[0] break LOGGER.debug("getNextGap(%s) for 0x%08x based on gap_map: 0x%08x", dont_skip, self.gap_pointer, next_gap) # we potentially just disassembled a function and want to continue directly behind it in case we would otherwise miss more if dont_skip: if self.gap_pointer in self.disassembly.code_map: function = self.disassembly.ins2fn[self.gap_pointer] next_gap = min(next_gap, self.disassembly.function_borders[function][1]) LOGGER.debug("getNextGap(%s) without skip => after checking versus code map: 0x%08x", dont_skip, next_gap) LOGGER.debug("getNextGap(%s) final gap_ptr: 0x%08x", dont_skip, next_gap) return next_gap def isEffectiveNop(self, byte_sequence): if byte_sequence in GAP_SEQUENCES[len(byte_sequence)]: return True return False def isAlignmentSequence(self, instruction_sequence): is_alignment_sequence = False if len(instruction_sequence) > 0: current_offset = instruction_sequence[0].address for instruction in instruction_sequence: if instruction.bytes in GAP_SEQUENCES[len(instruction.bytes)]: current_offset += len(instruction.bytes) if current_offset % 16 == 0: is_alignment_sequence = True break else: break return is_alignment_sequence def nextGapCandidate(self, start_gap_pointer=None): if self.gap_pointer is None: self.initGapSearch() if start_gap_pointer: self.gap_pointer = start_gap_pointer LOGGER.debug("nextGapCandidate() finding new gap candidate, current gap_ptr: 0x%08x", self.gap_pointer) while True: if self.disassembly.binary_info.base_addr + self.disassembly.binary_info.binary_size < self.gap_pointer: LOGGER.debug("nextGapCandidate() gap_ptr: 0x%08x - finishing", self.gap_pointer) return None gap_offset = self.gap_pointer - self.disassembly.binary_info.base_addr if gap_offset >= self.disassembly.binary_info.binary_size: return None # compatibility with python2/3... try: byte = self.disassembly.getRawByte(gap_offset) except: print("0x%08x" % self.disassembly.binary_info.base_addr, "0x%08x" % self.disassembly.binary_info.binary_size, "0x%08x" % self.gap_pointer, "0x%08x" % gap_offset) # try to find padding symbols and skip them if isinstance(byte, int): byte = struct.pack("B", byte) if byte in GAP_SEQUENCES[1]: LOGGER.debug("nextGapCandidate() found 0xCC / 0x00 - gap_ptr += 1: 0x%08x", self.gap_pointer) self.gap_pointer += 1 continue # try to find instructions that directly encode as NOP and skip them ins_buf = [i for i in self.capstone.disasm(self.disassembly.getRawBytes(gap_offset, 15), gap_offset)] if ins_buf and ins_buf[0].mnemonic == "nop": nop_instruction = ins_buf[0].mnemonic + " " + ins_buf[0].op_str nop_length = len(ins_buf[0].bytes) LOGGER.debug("nextGapCandidate() found nop instruction (%s) - gap_ptr += %d: 0x%08x", nop_instruction, nop_length, self.gap_pointer) self.gap_pointer += nop_length continue # try to find effective NOPs and skip them. found_multi_byte_nop = False for gap_length in range(max(GAP_SEQUENCES.keys()), 1, -1): if self.disassembly.getRawBytes(gap_offset, gap_length) in GAP_SEQUENCES[gap_length]: LOGGER.debug("nextGapCandidate() found %d byte effective nop - gap_ptr += %d: 0x%08x", gap_length, gap_length, self.gap_pointer) self.gap_pointer += gap_length found_multi_byte_nop = True break if found_multi_byte_nop: continue # we know this place from data already if self.gap_pointer in self.disassembly.data_map: LOGGER.debug("nextGapCandidate() gap_ptr is already inside data map: 0x%08x", self.gap_pointer) self.gap_pointer += 1 continue if self.gap_pointer in self.disassembly.code_map: LOGGER.debug("nextGapCandidate() gap_ptr is already inside code map: 0x%08x", self.gap_pointer) self.gap_pointer = self.getNextGap() continue # we may have a candidate here LOGGER.debug("nextGapCandidate() using 0x%08x as candidate", self.gap_pointer) start_byte = self.disassembly.getRawByte(gap_offset) has_common_prologue = True # start_byte in FunctionCandidate(self.gap_pointer, start_byte, self.bitness).common_gap_starts[self.bitness] if self.previously_analyzed_gap == self.gap_pointer: LOGGER.debug("--- HRM, nextGapCandidate() gap_ptr at: 0x%08x was previously analyzed", self.gap_pointer) self.gap_pointer = self.getNextGap(dont_skip=True) elif not has_common_prologue: LOGGER.debug("--- HRM, nextGapCandidate() gap_ptr at: 0x%08x has no common prologue (0x%08x)", self.gap_pointer, ord(start_byte)) self.gap_pointer = self.getNextGap(dont_skip=True) else: self.previously_analyzed_gap = self.gap_pointer self.addGapCandidate(self.gap_pointer) return self.gap_pointer return None def checkFunctionOverlap(self): function_boundaries = [] for function in self.disassembly.functions: min_addr = self.getBitMask() max_addr = 0 for block in self.disassembly.functions[function]: min_addr = min(min_addr, min([instruction[0] for instruction in block])) max_addr = max(max_addr, max([instruction[0] + instruction[1] for instruction in block])) function_boundaries.append((min_addr, max_addr)) current_entry = (0, 0) for entry in sorted(function_boundaries): if current_entry[1] > entry[0]: return True current_entry = entry return False def checkCodePadding(self): pattern_count = 0 pattern_functions = [] for pattern in re.finditer(r"((\xCC){2,}|(\x90){2,})", self.disassembly.binary_info.binary): pattern_count += 1 pattern_functions.append(pattern.span()[1] + 1) def ensureCandidate(self, addr): """ create candidate if it does not exist yet, returns True if newly created, else False """ if addr not in self.candidates: self.candidates[addr] = FunctionCandidate(self.disassembly.binary_info, addr) return True return False def addGapCandidate(self, addr): if not self._passesCodeFilter(addr): return False self.ensureCandidate(addr) self.candidates[addr].setIsGapCandidate(True) def addTailcallCandidate(self, addr): if not self._passesCodeFilter(addr): return False self.ensureCandidate(addr) self.candidates[addr].setIsTailcallCandidate(True) def addReferenceCandidate(self, addr, source_ref): if not self._passesCodeFilter(addr): return False if self.ensureCandidate(addr): self._all_call_refs[source_ref] = addr self.candidates[addr].addCallRef(source_ref) def addLanguageSpecCandidate(self, addr, lang_spec): if not self._passesCodeFilter(addr): return False self.ensureCandidate(addr) self.candidates[addr].setLanguageSpec(lang_spec) def addPrologueCandidate(self, addr): if not self._passesCodeFilter(addr): return False return self.ensureCandidate(addr) def addSymbolCandidate(self, addr): if not self._passesCodeFilter(addr): return False self.ensureCandidate(addr) self.candidates[addr].setIsSymbol(True) self.candidates[addr].setInitialCandidate(True) def resolvePointerReference(self, offset): if self.bitness == 32: addr_block = self.disassembly.getRawBytes(offset + 2, 4) function_pointer = struct.unpack("I", addr_block)[0] return self.disassembly.dereferenceDword(function_pointer) if self.bitness == 64: addr_block = self.disassembly.getRawBytes(offset + 2, 4) function_pointer = struct.unpack("i", addr_block)[0] # we need to calculate RIP + offset + 7 (48 ff 25 ** ** ** **) if self.disassembly.getRawBytes(offset, 2) == "\xFF\x25": function_pointer += offset + 7 elif self.disassembly.getRawBytes(offset, 2) == "\xFF\x15": function_pointer += offset + 6 else: raise Exception("resolvePointerReference: should only be used on call/jmp * ptr") return self.disassembly.binary_info.base_addr + function_pointer raise Exception("resolvePointerReference: undefined bitness") def _identifyAlignment(self): identified_alignment = 0 if self.config.USE_ALIGNMENT: num_candidates = sum([1 for addr, candidate in self.candidates.items() if len(candidate.call_ref_sources) > 1]) num_aligned_16_candidates = sum([1 for addr, candidate in self.candidates.items() if len(candidate.call_ref_sources) > 1 and candidate.alignment == 16]) num_aligned_4_candidates = sum([1 for addr, candidate in self.candidates.items() if len(candidate.call_ref_sources) > 1 and candidate.alignment >= 4]) if num_candidates: alignment_16_ratio = 1.0 * num_aligned_16_candidates / num_candidates alignment_4_ratio = 1.0 * num_aligned_4_candidates / num_candidates if num_candidates > 20 and alignment_4_ratio > 0.95: identified_alignment = 4 if num_candidates > 20 and alignment_16_ratio > 0.95: identified_alignment = 16 return identified_alignment def locateCandidates(self): self.locateSymbolCandidates() self.locateReferenceCandidates() self.locatePrologueCandidates() self.locateLangSpecCandidates() self.locateStubChainCandidates() self.identified_alignment = self._identifyAlignment() def _buildQueue(self): LOGGER.debug("Located %d function candidates", len(self.candidates)) # increase lookup speed with static list self._candidate_offsets = [c.addr for c in self.candidates.values()] self.cached_candidates = list(self.candidates.values()) self.candidate_queue = PriorityQueue(content=self.cached_candidates) def locateSymbolCandidates(self): for symbol_addr in self.symbol_addresses: self.addSymbolCandidate(symbol_addr) def locateReferenceCandidates(self): # check for potential call instructions and check if their destinations have a common function prologue for call_match in re.finditer(b"\xE8", self.disassembly.binary_info.binary): if not self._passesCodeFilter(self.disassembly.binary_info.base_addr + call_match.start()): continue if len(self.disassembly.binary_info.binary) - call_match.start() > 5: packed_call = self.disassembly.getRawBytes(call_match.start() + 1, 4) rel_call_offset = struct.unpack("i", packed_call)[0] # ignore zero offset calls, as they will likely not lead to functions but are rather used for positioning in shellcode etc if rel_call_offset == 0: continue call_destination = (self.disassembly.binary_info.base_addr + rel_call_offset + call_match.start() + 5) & self.getBitMask() if self.disassembly.isAddrWithinMemoryImage(call_destination): self.addReferenceCandidate(call_destination, self.disassembly.binary_info.base_addr + call_match.start()) self.setInitialCandidate(call_destination) # also check for "jmp dword ptr <offset>", as they sometimes point to local functions (i.e. non-API) if self.bitness == 32: for match in re.finditer(b"\xFF\x25", self.disassembly.binary_info.binary): function_addr = self.resolvePointerReference(match.start()) if not self._passesCodeFilter(function_addr): continue if self.disassembly.isAddrWithinMemoryImage(function_addr): self.addReferenceCandidate(function_addr, self.disassembly.binary_info.base_addr + match.start()) self.setInitialCandidate(function_addr) # also check for "call dword ptr <offset>", as they sometimes point to local functions (i.e. non-API) for match in re.finditer(b"\xFF\x15", self.disassembly.binary_info.binary): function_addr = self.resolvePointerReference(match.start()) if not self._passesCodeFilter(function_addr): continue if self.disassembly.isAddrWithinMemoryImage(function_addr): self.addReferenceCandidate(function_addr, self.disassembly.binary_info.base_addr + match.start()) self.setInitialCandidate(function_addr) def locatePrologueCandidates(self): # next check for the default function prologue regardless of references for re_prologue in DEFAULT_PROLOGUES: for prologue_match in re.finditer(re_prologue, self.disassembly.binary_info.binary): if not self._passesCodeFilter(self.disassembly.binary_info.base_addr + prologue_match.start()): continue self.addPrologueCandidate((self.disassembly.binary_info.base_addr + prologue_match.start()) & self.getBitMask()) self.setInitialCandidate((self.disassembly.binary_info.base_addr + prologue_match.start()) & self.getBitMask()) def locateLangSpecCandidates(self): # if the sample is highly likely delphi, extract t-string-objects and use their function-addresses as high-confidence function starts delphi_candidates = set([]) if self.lang_analyzer.checkDelphi(): LOGGER.debug("Programming language recognized as Delphi, adding function start addresses from TObjects") t_objects = self.lang_analyzer.getDelphiObjects() for t_string in t_objects: delphi_candidates.update(set(t_objects[t_string])) LOGGER.debug("delphi candidates based on TObject analysis: %d", len(delphi_candidates)) for obj in delphi_candidates: self.addLanguageSpecCandidate(obj, "delphi") def locateStubChainCandidates(self): # binaries often contain long sequences of stubs, consisting only of jmp dword ptr <offset>, add such chains as candidates for block in re.finditer(b"(?P<block>(\xFF\x25[\S\s]{4}){2,})", self.disassembly.binary_info.binary): for match in re.finditer(b"\xFF\x25(?P<function>[\S\s]{4})", block.group("block")): stub_addr = self.disassembly.binary_info.base_addr + block.start() + match.start() if not self._passesCodeFilter(stub_addr): continue self.addPrologueCandidate(stub_addr & self.getBitMask()) self.setInitialCandidate(stub_addr & self.getBitMask()) self.candidates[stub_addr].setIsStub(True) # structure for plt entries is similar but interleaved with additional code not considered functions for block in re.finditer(b"(?P<block>(\xFF\x25[\S\s]{4}\x68[\S\s]{4}\xE9[\S\s]{4}){2,})", self.disassembly.binary_info.binary): for match in re.finditer(b"\xFF\x25(?P<function>[\S\s]{4})", block.group("block")): stub_addr = self.disassembly.binary_info.base_addr + block.start() + match.start() if not self._passesCodeFilter(stub_addr): continue self.addPrologueCandidate(stub_addr & self.getBitMask()) self.setInitialCandidate(stub_addr & self.getBitMask()) self.candidates[stub_addr].setIsStub(True) # define data bytes inbetween for offset in range(10): self.disassembly.data_map.add(stub_addr + 6 + offset)