import os
from rex.exploit import CannotExploit, ExploitFactory
import tempfile

import logging
# module-level logger used by CGCExploitFactory for its debug/info messages
l = logging.getLogger("rex.exploit.cgc.cgc_exploit_factory")

class CGCExploitFactory(ExploitFactory):
    '''
    Exploit factory for CGC challenges.

    Applies every technique registered in ``Techniques`` for the crash's OS and
    sorts the resulting exploits into type1 (register setting) and type2
    (leaker) lists according to the technique's ``pov_type``.

    NOTE(review): ``self.os``, ``self.rop``, ``self.shellcode``, ``self.crash``
    and ``self.blacklist_techniques`` are not set in this class; they are
    presumably provided by ``ExploitFactory`` -- confirm against the base class.
    '''

    cgc_registers = ["eax", "ecx", "edx", "ebx", "esp", "ebp", "esi", "edi"]

    def __init__(self, crash, blacklist_techniques=None):
        '''
        :param crash: an exploitable crash object
        :param blacklist_techniques: a set of techniques to skip
        '''
        super(CGCExploitFactory, self).__init__(crash, blacklist_techniques=blacklist_techniques)

        # list of type1 (register setting) exploit objects
        self.register_setters = [ ]

        # list of type2 (leaker) exploit objects
        self.leakers = [ ]

        # interesting inputs generated by techniques which
        # don't directly create exploits
        self.manipulations = [ ]

        # type1 exploit with the highest grade, set by initialize()
        self.best_type1 = None

        # type2 exploit with the highest grade, set by initialize()
        self.best_type2 = None

    def has_type1(self):
        '''
        Tests if an exploit factory has a type1 exploit.

        :return: True if initialize() selected a best type1 exploit
        '''

        return self.best_type1 is not None

    def has_type2(self):
        '''
        Tests if an exploit factory has a type2 exploit.

        :return: True if initialize() selected a best type2 exploit
        '''

        return self.best_type2 is not None

    def generate_report(self, register_setters, leakers):
        '''
        Write a plain-text exploitation report to a new, uniquely named
        'rex-results-*' file in the current working directory.

        :param register_setters: iterable of register setting exploits, each written via str()
        :param leakers: iterable of leaker exploits, each written via str()
        '''

        # mkstemp creates the file atomically, unlike the deprecated mktemp
        # which only returns a name and leaves a window for another process
        # to claim it. Note that mkstemp creates the file owner-readable only.
        fd, stat_name = tempfile.mkstemp(dir=".", prefix='rex-results-')

        # the context manager closes the file even if a write below raises
        with os.fdopen(fd, 'w') as f:
            l.info("exploitation report being written to '%s'", stat_name)

            f.write("Binary %s:\n" % os.path.basename(self.crash.project.filename))
            f.write("Register setting exploits:\n")
            for register_setter in register_setters:
                f.write("\t%s\n" % str(register_setter))
            f.write("\n")
            f.write("Leaker exploits:\n")
            for leaker in leakers:
                f.write("\t%s\n" % str(leaker))

    @staticmethod
    def _grade_exploit(exploit):
        '''
        grade an exploit based on whether it can bypass nx and aslr
        the higher the score the better

        :param exploit: object exposing ``bypasses_nx`` and ``bypasses_aslr``
        :return: integer grade; bypassing nx counts 2, bypassing aslr counts 1
        '''

        grade = int(exploit.bypasses_nx) * 2
        grade += int(exploit.bypasses_aslr)

        return grade

    def _technique_os(self):
        '''
        :return: the key used to look up techniques; any OS name starting
                 with 'UNIX' is folded into 'unix', others are used as is
        '''

        return 'unix' if self.os.startswith('UNIX') else self.os

    def _record_exploit(self, pov_type, exploit):
        '''
        File an exploit under register_setters (pov_type 1) or leakers
        (pov_type 2); any other pov_type is not recorded.
        '''

        if pov_type == 1:
            self.register_setters.append(exploit)
        elif pov_type == 2:
            self.leakers.append(exploit)

    def _run_techniques(self, blacklisted):
        '''
        Apply the techniques for this OS, record each exploit produced and
        yield it.

        :param blacklisted: if False run only techniques that are not
                            blacklisted, if True run only the blacklisted ones
        '''

        for technique in Techniques[self._technique_os()]:
            p = technique(self.crash, self.rop, self.shellcode)
            if (p.name in self.blacklist_techniques) != blacklisted:
                continue

            try:
                l.debug("applying technique %s", p.name)
                result = p.apply()
            except CannotExploit as e:
                l.debug("technique failed: %s", e)
                continue

            if result is None:
                continue

            # record before yielding so the exploit is kept even when the
            # consumer stops iterating right after receiving it
            self._record_exploit(p.pov_type, result)
            yield result

    def _generate_exploits(self):
        '''
        Yield exploits from the non-blacklisted techniques; if none of them
        produced a recorded exploit, fall back to the blacklisted techniques.
        '''

        for exploit in self._run_techniques(False):
            yield exploit

        if not self.register_setters and not self.leakers:
            l.debug("no exploits, running blacklisted")
            for exploit in self._run_techniques(True):
                yield exploit

    def _best_exploit(self, exploits):
        '''
        :param exploits: non-empty list of exploits
        :return: the exploit with the highest grade; among equal grades the
                 one appearing last in the list wins
        '''

        # max() keeps the first maximum it sees, so iterating in reverse
        # selects the last of the equally graded exploits
        return max(reversed(exploits), key=self._grade_exploit)

    def initialize(self):
        '''
        Apply all techniques, collect the resulting exploits and select the
        best type1 and type2 exploit by grade.
        '''

        for _ in self._generate_exploits():
            pass

        # pick the best register setting exploits based on grading
        l.debug("done applying techniques, grading exploits...")
        if self.register_setters:
            self.best_type1 = self._best_exploit(self.register_setters)

        # pick the best leaker exploits based on grading
        if self.leakers:
            self.best_type2 = self._best_exploit(self.leakers)

    def yield_exploits(self):
        '''
        Lazily apply all techniques, yielding each exploit as soon as it is
        produced. Exploits are also recorded in register_setters / leakers;
        best_type1 and best_type2 are not updated by this method.
        '''

        for exploit in self._generate_exploits():
            yield exploit

from .techniques import Techniques