# noqa: ignore=F811

import idaapi

import os


from bap.utils import trace
from PyQt5 import QtWidgets
from PyQt5 import QtGui

from PyQt5.QtCore import (
    Qt,
    QFile,
    QIODevice,
    QCryptographicHash as QCrypto,
    QRegExp,
    QTimer,
    QAbstractItemModel,
    QModelIndex,
    QVariant,
    pyqtSignal,
    QSortFilterProxyModel)


def add_insn_to_trace_view(ea, tid=1):
    """Record an instruction event at address ``ea`` in IDA's Trace Window.

    ``tid`` is the thread id column of the event (defaults to 1).
    """
    insn_event = 1  # IDA's tev_insn event type
    idaapi.dbg_add_tev(insn_event, tid, ea)


@trace.handler('pc-changed', requires=['machine-id', 'pc'])
def tev_insn(state, ev):
    "stores each visited instruction to the IDA Trace Window"
    # the Primus machine id is reused as the Trace Window's thread id
    pc, machine = state['pc'], state['machine-id']
    add_insn_to_trace_view(pc, tid=machine)


@trace.handler('pc-changed', requires=['pc'])
def tev_insn0(state, ev):
    """stores each visited instruction to the IDA Trace Window.

    But doesn't set the pid/tid field, and keep it equal to 0
    (This enables interoperation with the debugger)
    """
    # NOTE(review): no tid is passed here, so add_insn_to_trace_view's
    # default (tid=1 in this file) is used, not the 0 promised above --
    # confirm which value the debugger interoperation actually needs.
    pc = state['pc']
    add_insn_to_trace_view(pc)


@trace.handler('call', requires=['machine-id', 'pc'])
def tev_call(state, call):
    "stores call events to the IDA Trace Window"
    # the call site is the current pc; the callee is resolved from the
    # symbol name carried in the first field of the event
    caller_ea = state['pc']
    callee_ea = idaapi.get_name_ea(0, call[0])
    thread = state['machine-id']
    idaapi.dbg_add_call_tev(thread, caller_ea, callee_ea)


# Results accumulated by the 'incident' / 'incident-location' trace handlers.
# Nothing in this file clears them, so they accumulate across trace loads.
# Incident objects, in the order their events were seen.
incidents = []
# location id (int) -> list of Point objects built by parse_point
locations = {}


@trace.handler('incident')
def incident(state, data):
    # data[0] is the incident name, the remaining fields are location ids
    name = data[0]
    location_ids = list(map(int, data[1:]))
    incidents.append(Incident(name, location_ids))


@trace.handler('incident-location')
def incident_location(state, data):
    # data[0] is the location id; data[1] holds the point strings of its
    # backtrace (each "addr" or "machine:addr", see parse_point).
    # Named location_id rather than `id` to avoid shadowing the builtin.
    location_id = int(data[0])
    locations[location_id] = [parse_point(p) for p in data[1]]


# we are using PyQt5 here, because IDAPython relies on a system
# openssl 0.9.8 which is quite outdated and not available on most
# modern installations
def md5sum(filename):
    """computes md5sum of a file with the given ``filename``

    The return value is a 32 byte hexadecimal ASCII representation of
    the md5 sum (same value as returned by the ``md5sum filename`` command)

    Raises IOError if the file can't be opened and ValueError if its
    contents can't be hashed.
    """
    stream = QFile(filename)
    # Open in binary mode: QIODevice.Text translates "\r\n" to "\n" while
    # reading, which would produce a digest different from `md5sum`.
    if not stream.open(QIODevice.ReadOnly):
        raise IOError("Can't open file: " + filename)
    try:
        hasher = QCrypto(QCrypto.Md5)
        if not hasher.addData(stream):
            raise ValueError('Unable to hash file: ' + filename)
    finally:
        # release the handle even when hashing fails
        stream.close()
    digest = hasher.result().toHex().data()
    # QByteArray.data() is bytes on Python 3 (str on Python 2); return a
    # plain native str in both cases instead of "b'...'"
    if not isinstance(digest, str):
        digest = digest.decode('ascii')
    return digest


class HandlerSelector(QtWidgets.QGroupBox):
    """Checkable group with one check box per registered trace handler.

    ``options`` maps each handler name to its check box; the box's tooltip
    is the handler's docstring.
    """
    def __init__(self, parent=None):
        super(HandlerSelector, self).__init__("Trace Event Processors", parent)
        self.setFlat(True)
        column = QtWidgets.QVBoxLayout(self)
        self.options = {}
        for handler_name in trace.handlers:
            self.options[handler_name] = self._add_option(column, handler_name)
        column.addStretch(1)
        self.setCheckable(True)
        self.setChecked(True)
        self.setLayout(column)

    @staticmethod
    def _add_option(column, handler_name):
        """Create the check box for ``handler_name`` and append it to ``column``."""
        checkbox = QtWidgets.QCheckBox(handler_name)
        checkbox.setToolTip(trace.handlers[handler_name].__doc__)
        column.addWidget(checkbox)
        return checkbox


class MachineSelector(QtWidgets.QWidget):
    """Line edit for choosing which machines (threads) to load.

    Accepts "all" or a comma-separated list of integers; ``selected()``
    turns the text into None (meaning all) or a list of ints.
    """
    def __init__(self, parent=None):
        super(MachineSelector, self).__init__(parent)
        box = QtWidgets.QHBoxLayout(self)
        label = MonitoringLabel('List of &machines (threads)')
        # re-export the label's validity check and signal so owners can
        # watch this widget without reaching into its children
        self.is_ready = label.is_ready
        self.updated = label.updated
        box.addWidget(label)
        self._machines = QtWidgets.QLineEdit('all')
        # a single literal: a backslash continuation inside the string would
        # embed the next line's indentation into the tooltip text
        self._machines.setToolTip(
            'an integer, a comma-separated list of integers, or "all"')
        grammar = QRegExp(r'\s*(all|\d+\s*(,\s*\d+\s*)*)\s*')
        valid = QtGui.QRegExpValidator(grammar)
        self._machines.setValidator(valid)
        label.setBuddy(self._machines)
        box.addWidget(self._machines)
        box.addStretch(1)
        self.setLayout(box)

    def selected(self):
        """Return None for "all", otherwise the list of machine ids.

        Raises ValueError if the text is not acceptable to the validator.
        """
        if not self._machines.hasAcceptableInput():
            raise ValueError('invalid input')
        data = self._machines.text().strip()
        if data == 'all':
            return None
        else:
            return [int(x) for x in data.split(',')]


class MonitoringLabel(QtWidgets.QLabel):
    """A label that tracks whether its buddy widget holds acceptable input.

    Each time the buddy's text changes, ``updated`` is emitted and the label
    is re-rendered, in red while the buddy's input is not acceptable.
    """

    updated = pyqtSignal()

    def __init__(self, text='', buddy=None, parent=None):
        super(MonitoringLabel, self).__init__(parent)
        self.setText(text)
        if buddy:
            self.setBuddy(buddy)

    def setText(self, text):
        """Show ``text`` and remember the undecorated version in ``self.text``."""
        super(MonitoringLabel, self).setText(text)
        # NOTE: this instance attribute shadows QLabel.text()
        self.text = text

    def setBuddy(self, buddy):
        """Attach ``buddy`` and start tracking its text changes."""
        super(MonitoringLabel, self).setBuddy(buddy)
        buddy.textChanged.connect(lambda _: self.update())
        self.update()

    def is_ready(self):
        """True when there is no buddy or the buddy's input is acceptable."""
        buddy = self.buddy()
        if not buddy:
            return True
        return buddy.hasAcceptableInput()

    def update(self):
        """Emit ``updated`` and redraw the label according to readiness."""
        self.updated.emit()
        if self.is_ready():
            shown = self.text
        else:
            shown = '<font color=red>' + self.text + '</font>'
        super(MonitoringLabel, self).setText(shown)


class ExistingFileValidator(QtGui.QValidator):
    """Validator accepting only paths of existing regular files.

    Any other text is reported as Intermediate (never Invalid), so the
    user can keep editing it.
    """
    def __init__(self, parent=None):
        super(ExistingFileValidator, self).__init__(parent)

    def validate(self, name, pos):
        state = self.Acceptable if os.path.isfile(name) else self.Intermediate
        return (state, name, pos)


class TraceFileSelector(QtWidgets.QWidget):
    """Path entry (plus browse button) for the trace file to load.

    ``text()`` returns the current path; ``is_ready``/``updated`` report
    whether it names an existing file.
    """

    def __init__(self, parent=None):
        super(TraceFileSelector, self).__init__(parent)
        row = QtWidgets.QHBoxLayout(self)
        caption = MonitoringLabel('Trace &file:')
        self.is_ready = caption.is_ready
        self.updated = caption.updated
        row.addWidget(caption)
        self.location = QtWidgets.QLineEdit('incidents')
        self.text = self.location.text
        self.location.setValidator(ExistingFileValidator())
        caption.setBuddy(self.location)
        row.addWidget(self.location)
        row.addWidget(self._make_browse_button())
        row.addStretch(1)
        self.setLayout(row)

    def _make_browse_button(self):
        """Return a button opening a file dialog that fills ``self.location``."""
        button = QtWidgets.QPushButton(self)
        icon = self.style().standardIcon(QtWidgets.QStyle.SP_DialogOpenButton)
        button.setIcon(icon)
        chooser = QtWidgets.QFileDialog(self)
        button.clicked.connect(chooser.open)
        chooser.fileSelected.connect(self.location.setText)
        return button


class IncidentView(QtWidgets.QWidget):
    """Browser for loaded incidents: a filterable tree plus a Trace button.

    Rows are incidents (level 0), their locations (level 1) and the points
    of each location's backtrace (level 2). Double-clicking a point jumps
    to its address; the Trace button loads the selected location's
    backtrace into IDA's Trace Window.
    """
    def __init__(self, parent=None):
        super(IncidentView, self).__init__(parent)
        self.view = QtWidgets.QTreeView()
        self.view.setAllColumnsShowFocus(True)
        self.view.setUniformRowHeights(True)
        box = QtWidgets.QVBoxLayout()
        box.addWidget(self.view)
        self.load_trace = QtWidgets.QPushButton('&Trace')
        self.load_trace.setToolTip('Load into the Trace Window')
        self.load_trace.setEnabled(False)
        # the Trace button only applies to location/point rows, so
        # re-evaluate it whenever the user interacts with a row
        for activation_signal in [
                self.view.activated,
                self.view.entered,
                self.view.pressed]:
            activation_signal.connect(lambda _: self.update_controls_state())
        self.load_trace.clicked.connect(self.load_current_trace)
        self.view.doubleClicked.connect(self.jump_to_index)
        hbox = QtWidgets.QHBoxLayout()
        self.filter = QtWidgets.QLineEdit()
        self.filter.textChanged.connect(self.filter_model)
        filter_label = QtWidgets.QLabel('&Search')
        filter_label.setBuddy(self.filter)
        hbox.addWidget(filter_label)
        hbox.addWidget(self.filter)
        hbox.addWidget(self.load_trace)
        box.addLayout(hbox)
        self.setLayout(box)
        # created by display(); None until the first trace is shown
        self.model = None
        self.proxy = None

    def display(self, incidents, locations):
        """Show ``incidents``/``locations``, replacing any previous model."""
        old_model, old_proxy = self.model, self.proxy
        self.model = IncidentModel(incidents, locations, self)
        self.proxy = QSortFilterProxyModel(self)
        self.proxy.setSourceModel(self.model)
        self.proxy.setFilterRole(self.model.filter_role)
        # honour whatever is already typed into the search box
        self.proxy.setFilterRegExp(QRegExp(self.filter.text()))
        self.view.setModel(self.proxy)
        # the old current index died with the old model; refresh the button
        # so it cannot act on a stale selection
        self.update_controls_state()
        # the previous proxy/model are parented to this widget and would
        # otherwise live as long as it does
        for stale in (old_proxy, old_model):
            if stale is not None:
                stale.deleteLater()

    def filter_model(self, txt):
        """Filter the rows by the regular expression ``txt``."""
        if self.proxy:
            self.proxy.setFilterRegExp(QRegExp(txt))

    def update_controls_state(self):
        """Enable the Trace button only for location or point rows."""
        curr = self.view.currentIndex()
        self.load_trace.setEnabled(curr.isValid() and
                                   curr.parent().isValid())

    def load_current_trace(self):
        """Add the backtrace of the selected location to the Trace Window.

        Works when a location (level 1) or one of its points (level 2) is
        selected. Raises ValueError for any other selection.
        """
        idx = self.proxy.mapToSource(self.view.currentIndex())
        level = index_level(idx)
        if not idx.isValid() or level not in (1, 2):
            raise ValueError('load_current_trace: invalid index')

        if level == 2:
            idx = idx.parent()

        incident = self.model.incidents[idx.parent().row()]
        location = incident.locations[idx.row()]
        # a location id may have no recorded backtrace (the model then
        # shows it without children); treat it as empty instead of raising
        known = self.model.locations or {}
        backtrace = known.get(location, [])

        for p in reversed(backtrace):
            self.load_trace_point(p)

    def jump_to_index(self, idx):
        """Jump IDA's view to the address of a double-clicked point row."""
        idx = self.proxy.mapToSource(idx)
        if index_level(idx) != 2:
            # don't mess with parents, they are used to create children
            return
        grandpa = idx.parent().parent()
        incident = self.model.incidents[grandpa.row()]
        location = incident.locations[idx.parent().row()]
        # named `backtrace` so the module-level `trace` import isn't shadowed
        backtrace = self.model.locations[location]
        point = backtrace[idx.row()]
        self.show_trace_point(point)

    def load_trace_point(self, p):
        add_insn_to_trace_view(p.addr)

    def show_trace_point(self, p):
        idaapi.jumpto(p.addr)


class TraceLoaderController(QtWidgets.QWidget):
    """Form for picking a trace file, handlers and machines, then loading.

    Loading is driven by a QTimer: each timeout consumes one event from
    ``trace.Loader`` so the UI stays responsive. ``finished`` is emitted
    when loading stops (end of trace or the Stop button).
    """
    finished = pyqtSignal()

    def __init__(self, parent=None):
        super(TraceLoaderController, self).__init__(parent)
        self.loader = None
        # the open trace file, kept so it can be closed when loading stops
        self._stream = None
        box = QtWidgets.QVBoxLayout(self)
        self.location = TraceFileSelector(self)
        self.handlers = HandlerSelector(self)
        self.machines = MachineSelector(self)
        box.addWidget(self.location)
        box.addWidget(self.handlers)
        box.addWidget(self.machines)
        self.load = QtWidgets.QPushButton('&Load')
        self.load.setDefault(True)
        self.load.setEnabled(self.location.is_ready())
        self.cancel = QtWidgets.QPushButton('&Stop')
        self.cancel.setVisible(False)
        hor = QtWidgets.QHBoxLayout()
        hor.addWidget(self.load)
        hor.addWidget(self.cancel)
        self.progress = QtWidgets.QProgressBar()
        self.progress.setVisible(False)
        hor.addWidget(self.progress)
        hor.addStretch(2)
        box.addLayout(hor)

        def enable_load():
            self.load.setEnabled(self.location.is_ready() and
                                 self.machines.is_ready())
        self.location.updated.connect(enable_load)
        self.machines.updated.connect(enable_load)
        enable_load()
        self.processor = QTimer()
        self.processor.timeout.connect(self.process)
        self.load.clicked.connect(self.processor.start)
        self.cancel.clicked.connect(self.stop)
        self.setLayout(box)

    def start(self):
        """Open the selected trace file and configure a fresh loader."""
        self.cancel.setVisible(True)
        self.load.setVisible(False)
        filename = self.location.text()
        # `file()` exists only on Python 2; `open()` is equivalent there and
        # also works on Python 3
        self._stream = open(filename)
        self.loader = trace.Loader(self._stream)
        self.progress.setVisible(True)
        stat = os.stat(filename)
        self.progress.setRange(0, stat.st_size)
        machines = self.machines.selected()
        if machines is not None:
            self.loader.enable_filter('filter-machine', id=machines)

        for name in self.handlers.options:
            if self.handlers.options[name].isChecked():
                self.loader.enable_handlers([name])

    def stop(self):
        """Stop loading, restore the controls and emit ``finished``."""
        self._teardown()
        self.finished.emit()

    def _teardown(self):
        """Stop the timer, restore the controls and close the trace file."""
        self.processor.stop()
        self.progress.setVisible(False)
        self.cancel.setVisible(False)
        self.load.setVisible(True)
        self.loader = None
        if self._stream is not None:
            self._stream.close()
            self._stream = None

    def process(self):
        """Timer slot: lazily start loading, then consume one trace event."""
        if not self.loader:
            try:
                self.start()
            except Exception:
                # without this the timer keeps firing and retries the
                # failing start() forever; restore the UI, then re-raise
                self._teardown()
                raise
        try:
            self.loader.next()
            self.progress.setValue(self.loader.parser.lexer.instream.tell())
        except StopIteration:
            self.stop()


def index_level(idx):
    """Return the depth of ``idx`` in its model's tree.

    An invalid index (the invisible root) is at depth -1, top-level rows
    at 0, their children at 1, and so on.
    """
    depth = -1
    while idx.isValid():
        depth += 1
        idx = idx.parent()
    return depth


def index_up(idx, level=0):
    """Return the ancestor of ``idx`` that is ``level`` steps above it.

    ``level=0`` returns ``idx`` itself. Raises ValueError for a negative
    ``level`` (which previously recursed until the stack overflowed).
    """
    if level < 0:
        raise ValueError(
            'index_up: level must be non-negative, got {}'.format(level))
    for _ in range(level):
        idx = idx.parent()
    return idx


class IncidentIndex(object):
    """Domain view of a model index, handed to the IncidentModel handlers.

    ``level`` is the depth of ``index`` as computed by index_level
    (0 = incident, 1 = location, 2 = backtrace point, -1 = root). The other
    properties resolve the incident, the location's backtrace and the
    point that the index refers to.
    """

    def __init__(self, model, index):
        self.model = model
        self.index = index

    @property
    def incidents(self):
        return self.model.incidents

    @property
    def level(self):
        return index_level(self.index)

    @property
    def column(self):
        return self.index.column()

    @property
    def row(self):
        return self.index.row()

    @property
    def incident(self):
        """The incident owning this index, found via its top-level ancestor."""
        ancestor = index_up(self.index, self.level)
        return self.incidents[ancestor.row()]

    @property
    def location(self):
        """The backtrace stored in model.locations for a location/point row.

        None for other levels, when the model has no locations, or when the
        location id is unknown.
        """
        depth = self.level
        if depth not in (1, 2):
            return None
        location_row = index_up(self.index, 1) if depth == 2 else self.index
        location_id = self.incident.locations[location_row.row()]
        known = self.model.locations
        return None if known is None else known.get(location_id)

    @property
    def point(self):
        """The backtrace entry of a level-2 row, None for other levels."""
        if self.level != 2:
            return None
        return self.location[self.row]


class IncidentModel(QAbstractItemModel):
    """Tree model of incidents -> locations -> backtrace points.

    Cell contents are produced by the functions registered in ``handlers``
    (see defmethod); ``dispatch`` runs the first handler whose constraints
    match the requested role and index.
    """
    filter_role = Qt.UserRole
    sort_role = Qt.UserRole + 1

    # registry of {'name', 'constraints', 'accept'} dicts, shared by all models
    handlers = []

    def __init__(self, incidents, locations, parent=None):
        super(IncidentModel, self).__init__(parent)
        self.incidents = incidents
        self.locations = locations
        # internal id -> parent index; id 0 is used for top-level rows
        self.parents = {0: QModelIndex()}
        # last internal id handed out
        self.child_ids = 0
        # (parent id, parent row, parent column) -> internal id of its children
        self._child_id_of = {}

    @staticmethod
    def _matches(constraints, role, index):
        """True if every handler constraint holds for ``role``/``index``."""
        for key, value in constraints.items():
            if key == 'roles':
                ok = role in value
            elif key == 'level':
                ok = index.level == value
            elif key == 'column':
                ok = index.column == value
            else:
                ok = False  # unknown constraints never match
            if not ok:
                return False
        return True

    def dispatch(self, role, index):
        """Return the first matching handler's result, or None."""
        for handler in self.handlers:
            if self._matches(handler['constraints'], role, index):
                return handler['accept'](index)
        return None

    def index(self, row, col, parent):
        """Create the index for (row, col) under ``parent``.

        Each distinct parent node gets one internal id that is reused on
        every call, so ``parents`` stays bounded by the number of nodes
        instead of growing with every index() request made by the views.
        """
        if not self.hasIndex(row, col, parent):
            return QModelIndex()
        if not parent.isValid():
            return self.createIndex(row, col, 0)
        # a node is identified by its own internal id (which encodes its
        # parent) together with its row and column
        key = (parent.internalId(), parent.row(), parent.column())
        child_id = self._child_id_of.get(key)
        if child_id is None:
            self.child_ids += 1
            child_id = self.child_ids
            self._child_id_of[key] = child_id
            # store a copy: the argument may only be valid during this call
            self.parents[child_id] = QModelIndex(parent)
        return self.createIndex(row, col, child_id)

    def parent(self, child):
        """Return the parent index recorded for ``child`` by index()."""
        return self.parents[child.internalId()]

    def rowCount(self, parent):
        n = self.dispatch('row-count', IncidentIndex(self, parent))
        return 0 if n is None else n

    def columnCount(self, parent):
        return 2 if not parent.isValid() or parent.column() == 0 else 0

    def data(self, index, role):
        """Translate a Qt role to a handler role and wrap the result."""
        role = {
            Qt.DisplayRole: 'display',
            self.sort_role: 'sort',
            self.filter_role: 'filter'
        }.get(role)

        if role:
            return QVariant(self.dispatch(role, IncidentIndex(self, index)))
        else:
            return QVariant()


def defmethod(*args, **kwargs):
    """Register the decorated function as an IncidentModel data handler.

    Positional arguments are the roles the handler serves ('display',
    'sort', 'filter' or 'row-count'); keyword arguments are extra
    constraints (``level``, ``column``) checked against an IncidentIndex.
    The decorated function is returned unchanged, so its module-level name
    stays bound to it rather than to None.
    """
    constraints = dict(kwargs, roles=args)

    def register(method):
        IncidentModel.handlers.append({
            'name': method.__name__,
            'constraints': constraints,
            'accept': method})
        return method
    return register


@defmethod('display', level=2, column=0)
def display_point(msg):
    """Show a backtrace point's address in lowercase hex."""
    return format(msg.point.addr, 'x')


@defmethod('display', level=2, column=1)
def display_point_machine(msg):
    """Show the machine id of a backtrace point."""
    point = msg.point
    return point.machine


@defmethod('display', level=1, column=0)
def display_incident_location(msg):
    """Label a location row by its position within the incident."""
    return 'location-' + str(msg.row)


@defmethod('display', level=0, column=0)
def display_incident_name(msg):
    """Show the incident's name in the first column."""
    incident = msg.incident
    return incident.name


@defmethod('display', level=0, column=1)
def display_incident_id(msg):
    """Show the incident's row number in the second column."""
    row = msg.row
    return row


@defmethod('sort', 'filter', column=0)
def incident_name(msg):
    """Sort/filter every row of column 0 by its owning incident's name."""
    owner = msg.incident
    return owner.name


@defmethod('row-count', level=-1)
def number_of_incidents(msg):
    """The invisible root has one child per incident."""
    all_incidents = msg.incidents
    return len(all_incidents)


@defmethod('row-count', level=0, column=0)
def number_of_locations(msg):
    """An incident row has one child per location id."""
    location_ids = msg.incident.locations
    return len(location_ids)


@defmethod('row-count', level=1, column=0)
def backtrace_length(msg):
    """A location row has one child per backtrace point (none if unknown)."""
    backtrace = msg.location
    if backtrace is None:
        return 0
    return len(backtrace)


class Incident(object):
    """A named incident together with the ids of its locations."""
    __slots__ = ['name', 'locations']

    def __init__(self, name, locations):
        self.name, self.locations = name, locations

    def __repr__(self):
        return 'Incident({!r}, {!r})'.format(self.name, self.locations)


class Point(object):
    """A trace location: an address, optionally tagged with a machine id."""
    __slots__ = ['addr', 'machine']

    def __init__(self, addr, machine=None):
        self.addr = addr
        self.machine = machine

    def __str__(self):
        # compare with None: machine id 0 is a legitimate value and was
        # previously dropped by a plain truthiness test
        if self.machine is not None:
            return '{}:{}'.format(self.machine, self.addr)
        return str(self.addr)

    def __repr__(self):
        # argument order matches the constructor, Point(addr, machine)
        if self.machine is not None:
            return 'Point({!r}, {!r})'.format(self.addr, self.machine)
        return 'Point({!r})'.format(self.addr)


def parse_point(data):
    """Parse ``"addr"`` or ``"machine:addr"`` into a Point.

    The address is hexadecimal and the machine id is decimal; any fields
    after the second colon-separated field are ignored.
    """
    fields = data.split(':')
    if len(fields) > 1:
        machine, addr = fields[0], fields[1]
        return Point(int(addr, 16), int(machine))
    return Point(int(data, 16))


class BapTraceMain(idaapi.PluginForm):
    """IDA form showing the trace loader and the incident browser side by side."""
    def OnCreate(self, form):
        widget = self.FormToPyQtWidget(form)
        self.control = TraceLoaderController(widget)
        self.incidents = IncidentView(widget)
        # every time loading stops, show the module-level incidents/locations
        self.control.finished.connect(
            lambda: self.incidents.display(incidents, locations))
        splitter = QtWidgets.QSplitter()
        splitter.addWidget(self.control)
        splitter.addWidget(self.incidents)
        layout = QtWidgets.QHBoxLayout()
        layout.addWidget(splitter)
        widget.setLayout(layout)


class BapTracePlugin(idaapi.plugin_t):
    """IDA plugin object that opens the Primus Observations form."""
    wanted_name = 'BAP: Load Observations'
    wanted_hotkey = ''
    flags = idaapi.PLUGIN_FIX
    comment = 'Load Primus Observations'
    help = """
    Loads Primus Observations into IDA for further analysis
    """

    def __init__(self):
        self.form = None
        self.name = 'Primus Observations'

    def init(self):
        # stay loaded after initialisation
        return idaapi.PLUGIN_KEEP

    def term(self):
        """Nothing to release."""

    def run(self, arg):
        # the form is created lazily once and re-shown on later runs
        if not self.form:
            self.form = BapTraceMain()
        show_flags = self.form.FORM_PERSIST | self.form.FORM_SAVE
        return self.form.Show(self.name, options=show_flags)


def PLUGIN_ENTRY():
    """Entry point IDA calls to obtain the plugin object."""
    plugin = BapTracePlugin()
    return plugin


# Form opened by bap_trace_test(); presumably held at module level so it
# stays alive while shown -- confirm.
main = None


def bap_trace_test():
    """Open the Primus Observations form without the plugin machinery.

    The form is stored in the module-level ``main``.
    """
    global main
    form = BapTraceMain()
    main = form
    form.Show('Primus Observations')