'''
IDA Pro plugin that emulates WebAssembly instructions and shows their effects.

Usage:
  1. Select a sequence of instructions in the disassembly view.
  2. Run this script.
  3. Review the output view. You'll see results like this:

      globals:
        $frame_stack: ($frame_stack - 0x40)
      locals:
        $local5: $frame_stack
        $local6: 0x40
        $frame_pointer: ($frame_stack - 0x40)
        $local8: 0x0

Notes:

This plugin supports a subset of the WebAssembly opcodes.
Currently it does not implement branching or comparison instructions.
Therefore, you should select instructions within a single basic block.
'''
import logging
import collections

import wasm
import wasm.decode
import wasm.opcodes
import netnode

import idaapi
import ida_bytes


# module-level logger: main() reports errors here and the Emulator emits its debug traces here.
logger = logging.getLogger('wasm-emu')



def get_type_sort_order(v):
    '''
    map a value to the rank of its type, for use when ordering mixed values.

    lower ranks sort first:
      binary-op (1) < memory (2) < global (3) < local-var (4) < i32 (5)

    raises ValueError when `v` is not one of these types.
    '''
    # checked top to bottom; the first matching type wins.
    ranked_types = (
        (BinaryOperation, 1),
        (Memory, 2),
        (GlobalVariable, 3),
        (LocalVariable, 4),
        (I32, 5),
    )
    for value_type, rank in ranked_types:
        if isinstance(v, value_type):
            return rank

    raise ValueError('unexpected value type')


def cmp(a, b):
    '''
    define a general purpose ordering for (possibly complex) values.

    useful when rendering the memory map, which has signature Dict[Any, Any],
     where the key is any of our types or nodes (I32, LocalVariable, etc).

    note: despite the name, this is a "less than" predicate, not a three-way
     comparison: it returns True when `a` sorts strictly before `b`, and
     False otherwise (including when the two are equivalent).

    values of different types are ordered by `get_type_sort_order`;
     values of the same type are ordered by their contents.

    raises ValueError when either value is not one of the supported types.
    '''
    at = get_type_sort_order(a)
    bt = get_type_sort_order(b)

    if at != bt:
        return at < bt

    if isinstance(a, I32):
        return a.value < b.value

    if isinstance(a, LocalVariable):
        return a.local_index < b.local_index

    if isinstance(a, GlobalVariable):
        return a.global_index < b.global_index

    if isinstance(a, Memory):
        return cmp(a.address, b.address)

    if isinstance(a, BinaryOperation):
        if a.operation != b.operation:
            return a.operation < b.operation

        # lexicographic order: the lhs decides unless both sides are equivalent,
        # in which case the rhs decides.
        if cmp(a.lhs, b.lhs):
            return True
        if cmp(b.lhs, a.lhs):
            return False
        return cmp(a.rhs, b.rhs)

    # get_type_sort_order() rejects unknown types first,
    # so this is only a safety net.
    raise ValueError('unexpected value type: ' + str(type(a)))


class I32(object):
    '''
    a concrete integer value.

    `value` is stored as a plain python integer and is not wrapped to 32 bits here.
    '''
    def __init__(self, value):
        self.value = value

    def render(self, ctx={}):
        '''
        format the value as uppercase hex, like `0x40`.

        negative values are rendered with a leading sign, like `-0x40`.
        `ctx` is accepted for symmetry with the other value types and is not used.
        '''
        if self.value < 0:
            # put the sign in front of the prefix, rather than emitting `0x-40`.
            return '-0x{0:X}'.format(-self.value)
        return '0x{0:X}'.format(self.value)

    def __lt__(self, other):
        # delegate to the module-level ordering so that mixed value types can be sorted.
        return cmp(self, other)


class LocalVariable(object):
    '''
    a symbolic reference to a local variable, identified by its index.
    '''
    def __init__(self, local_index):
        self.local_index = local_index

    def __lt__(self, other):
        # ordering is defined by the module-level `cmp` predicate.
        return cmp(self, other)

    def render(self, ctx={}):
        '''
        render the name of the local, as resolved by `render_local` using `ctx`.
        '''
        label = render_local(self.local_index, ctx=ctx)
        return '{local}'.format(local=label)


class GlobalVariable(object):
    '''
    a symbolic reference to a global variable, identified by its index.
    '''
    def __init__(self, global_index):
        self.global_index = global_index

    def __lt__(self, other):
        # ordering is defined by the module-level `cmp` predicate.
        return cmp(self, other)

    def render(self, ctx={}):
        '''
        render the name of the global, as resolved by `render_global` using `ctx`.
        '''
        label = render_global(self.global_index, ctx=ctx)
        return '{g}'.format(g=label)


def is_frame_pointer(value, ctx={}):
    '''
    return True when `value` renders, under `ctx`, as exactly `$frame_pointer`.
    '''
    rendered = render(value, ctx=ctx)
    return rendered == '$frame_pointer'


class Memory(object):
    '''
    a symbolic reference to the contents of memory at `address`.

    `address` is itself a value (such as an I32 or an AddOperation).
    '''
    def __init__(self, address):
        self.address = address

    def render(self, ctx={}):
        '''
        render as `memory[<address>]`.
        '''
        # no need to reduce the address here: render() reduces its argument itself.
        return 'memory[{addr}]'.format(addr=render(self.address, ctx=ctx))

    def __lt__(self, other):
        # ordering is defined by the module-level `cmp` predicate.
        return cmp(self, other)


class BinaryOperation(object):
    '''
    a symbolic expression: `operation` applied to the operands `lhs` and `rhs`.

    `operation` is the operator text used when rendering, such as `+`.
    '''
    def __init__(self, operation, lhs, rhs):
        self.operation = operation
        self.lhs = lhs
        self.rhs = rhs

    def __lt__(self, other):
        # ordering is defined by the module-level `cmp` predicate.
        return cmp(self, other)

    def render(self, ctx={}):
        '''
        render as `(<lhs> <operation> <rhs>)`, rendering both operands with `ctx`.
        '''
        left = render(self.lhs, ctx=ctx)
        right = render(self.rhs, ctx=ctx)
        return '({lhs} {op} {rhs})'.format(lhs=left, op=self.operation, rhs=right)


class AddOperation(BinaryOperation):
    '''symbolic addition, rendered with the `+` operator.'''
    def __init__(self, lhs, rhs):
        BinaryOperation.__init__(self, '+', lhs, rhs)


class SubOperation(BinaryOperation):
    '''symbolic subtraction, rendered with the `-` operator.'''
    def __init__(self, lhs, rhs):
        BinaryOperation.__init__(self, '-', lhs, rhs)


class AndOperation(BinaryOperation):
    '''symbolic bitwise and, rendered with the `&` operator.'''
    def __init__(self, lhs, rhs):
        BinaryOperation.__init__(self, '&', lhs, rhs)


class ShlOperation(BinaryOperation):
    '''symbolic left shift, rendered with the `<<` operator.'''
    def __init__(self, lhs, rhs):
        BinaryOperation.__init__(self, '<<', lhs, rhs)


class ShruOperation(BinaryOperation):
    '''symbolic right shift (created for `i32.shr_u`), rendered with the `>>` operator.'''
    def __init__(self, lhs, rhs):
        BinaryOperation.__init__(self, '>>', lhs, rhs)


class XorOperation(BinaryOperation):
    '''symbolic bitwise xor, rendered with the `^` operator.'''
    def __init__(self, lhs, rhs):
        BinaryOperation.__init__(self, '^', lhs, rhs)


def reduce(value):
    '''
    simplify a (possibly nested) value and return the simplified value.

    only BinaryOperation nodes are rewritten; anything else is returned as-is.
    both operands are reduced first, and then, for additions only:

      - `A + 0` and `0 + A` become `A`
      - a constant on the lhs is moved to the rhs
      - `(A + c1) + c2` becomes `A + (c1 + c2)`, or just `A` when the constants cancel out

    other operations are rebuilt from their reduced operands via `type(value)(lhs, rhs)`,
     so `value` must be an instance of a two-argument subclass such as AddOperation.
    '''
    if not isinstance(value, BinaryOperation):
        return value

    rhs = reduce(value.rhs)
    lhs = reduce(value.lhs)

    if isinstance(value, AddOperation):
        # A + 0 = A
        if isinstance(rhs, I32) and rhs.value == 0:
            return lhs

        # 0 + A = A
        if isinstance(lhs, I32) and lhs.value == 0:
            return rhs

        # A + B = B + A
        # and we prefer integers on the rhs
        if isinstance(lhs, I32) and not isinstance(rhs, I32):
            lhs, rhs = rhs, lhs

        # (A + B) + C = A + (B + C)
        # and reduce the B + C if constant
        if (isinstance(rhs, I32)
              and isinstance(lhs, AddOperation)
              and isinstance(lhs.rhs, I32)):
            total = lhs.rhs.value + rhs.value
            if total == 0:
                # the constants cancel out, so apply A + 0 = A once more
                # rather than returning `A + 0x0`.
                return lhs.lhs
            return AddOperation(lhs.lhs, I32(total))

    return type(value)(lhs, rhs)


def render_local(index, ctx={}):
    '''
    return the display name for the local with the given index.

    the default name is `$local<index>`. when `ctx['regvars']` maps that
     default name to another name, the mapped name is returned instead.
    '''
    default_name = '$local{0:d}'.format(index)
    aliases = ctx.get('regvars', {})
    return aliases[default_name] if default_name in aliases else default_name


def render_global(index, ctx={}):
    '''
    return the display name for the global with the given index.

    the default name is `$global<index>`. when `ctx['globals']` maps that
     default name to another name, the mapped name is returned instead.
    '''
    default_name = '$global{0:d}'.format(index)
    aliases = ctx.get('globals', {})
    return aliases[default_name] if default_name in aliases else default_name


def render(value, ctx={}):
    '''
    reduce `value` and return its text representation.

    `ctx` may carry naming hints. the `frame` entry, a mapping from offset to
     field name, is consulted here; `ctx` is also handed on to the value's own
     `render` method.

    raises NotImplementedError when the reduced value is not one of the known value types.
    '''
    simplified = reduce(value)

    known_types = (I32, LocalVariable, GlobalVariable, Memory, BinaryOperation)
    if not isinstance(simplified, known_types):
        raise NotImplementedError('value type: ' + str(type(simplified)))

    # render `(frame_pointer + struct_offset)`
    # as `frame_pointer.fieldname`
    if (isinstance(simplified, AddOperation)
          and is_frame_pointer(simplified.lhs, ctx=ctx)
          and isinstance(simplified.rhs, I32)
          and simplified.rhs.value in ctx.get('frame', {})):
        field_name = ctx['frame'][simplified.rhs.value]
        return '$frame.{field}'.format(field=field_name)

    # the value is reduced once more right before rendering.
    return reduce(simplified).render(ctx=ctx)


class Emulator(object):
    '''
    symbolically execute a straight-line sequence of WebAssembly instructions.

    state is tracked in four places:
      - `stack`: the operand stack, a list of values with the top at the end
      - `locals`: local index -> value, for locals written by the code
      - `globals`: global index -> value, for globals written by the code
      - `memory`: one entry per byte written. concrete addresses are keyed by
         plain integer; symbolic addresses are keyed by the address expression.

    values are I32 constants when everything is known, and otherwise symbolic
     nodes (LocalVariable, GlobalVariable, Memory, BinaryOperation).

    there is no support for control flow: unsupported instructions raise NotImplementedError.
    '''

    def __init__(self, code):
        '''
        `code` is the raw bytecode; it is decoded with `wasm.decode.decode_bytecode`.
        '''
        self.code = code
        self.bc = wasm.decode.decode_bytecode(code)
        self.stack = []
        self.locals = {}
        self.globals = {}
        self.memory = {}

    def push(self, v):
        '''push `v` onto the operand stack.'''
        logger.debug('stack: pushed %s', render(v))
        self.stack.append(v)

    def pop(self):
        '''
        remove and return the top of the operand stack.

        raises IndexError when the stack is empty.
        '''
        v = self.stack.pop()
        logger.debug('stack: popped %s', render(v))
        return v

    @staticmethod
    def _effective_address(base, offset):
        '''compute `base + offset`: an I32 when `base` is concrete, else a symbolic addition.'''
        if isinstance(base, I32):
            return I32(base.value + offset)
        return AddOperation(base, I32(offset))

    def handle_I32_CONST(self, insn):
        self.push(I32(insn.imm.value))

    def handle_SET_LOCAL(self, insn):
        v = self.pop()
        logger.debug('locals: set %s: %s', render_local(insn.imm.local_index), render(v))
        self.locals[insn.imm.local_index] = v

    def handle_GET_LOCAL(self, insn):
        # a local that has not been written yet is represented symbolically.
        try:
            v = self.locals[insn.imm.local_index]
        except KeyError:
            v = LocalVariable(insn.imm.local_index)
        self.push(v)

    def handle_SET_GLOBAL(self, insn):
        v = self.pop()
        logger.debug('globals: set %s: %s', render_global(insn.imm.global_index), render(v))
        self.globals[insn.imm.global_index] = v

    def handle_GET_GLOBAL(self, insn):
        # a global that has not been written yet is represented symbolically.
        try:
            v = self.globals[insn.imm.global_index]
        except KeyError:
            v = GlobalVariable(insn.imm.global_index)
        self.push(v)

    def handle_I32_ADD(self, insn):
        v1 = self.pop()
        v0 = self.pop()

        # exactly one result is pushed, whichever case applies.
        if isinstance(v0, I32) and isinstance(v1, I32):
            self.push(I32(v0.value + v1.value))
        # 0 + V = V
        elif isinstance(v0, I32) and v0.value == 0:
            self.push(v1)
        # V + 0 = V
        elif isinstance(v1, I32) and v1.value == 0:
            self.push(v0)
        else:
            self.push(AddOperation(v0, v1))

    def handle_I32_SUB(self, insn):
        v1 = self.pop()
        v0 = self.pop()

        # exactly one result is pushed, whichever case applies.
        if isinstance(v0, I32) and isinstance(v1, I32):
            self.push(I32(v0.value - v1.value))
        # V - 0 = V
        # (note that 0 - V is *not* V, so only the rhs is checked)
        elif isinstance(v1, I32) and v1.value == 0:
            self.push(v0)
        else:
            self.push(SubOperation(v0, v1))

    def handle_I32_AND(self, insn):
        v1 = self.pop()
        v0 = self.pop()

        # TODO: special case u8 & 0xFF

        if isinstance(v0, I32) and isinstance(v1, I32):
            self.push(I32(v0.value & v1.value))
        else:
            self.push(AndOperation(v0, v1))

    def handle_I32_SHL(self, insn):
        v1 = self.pop()
        v0 = self.pop()

        if isinstance(v0, I32) and isinstance(v1, I32):
            # i32 shifts use the shift count modulo 32,
            # and the result wraps to 32 bits.
            self.push(I32((v0.value << (v1.value & 31)) & 0xFFFFFFFF))
        else:
            self.push(ShlOperation(v0, v1))

    def handle_I32_SHR_U(self, insn):
        v1 = self.pop()
        v0 = self.pop()

        if isinstance(v0, I32) and isinstance(v1, I32):
            # unsigned shift: view the operand as an unsigned 32-bit value first,
            # since python's `>>` would sign-extend a negative integer.
            self.push(I32((v0.value & 0xFFFFFFFF) >> (v1.value & 31)))
        else:
            self.push(ShruOperation(v0, v1))

    def handle_I32_XOR(self, insn):
        v1 = self.pop()
        v0 = self.pop()

        if isinstance(v0, I32) and isinstance(v1, I32):
            self.push(I32(v0.value ^ v1.value))
        else:
            self.push(XorOperation(v0, v1))

    def handle_I32_LOAD8_U(self, insn):
        base = self.pop()
        addr = self._effective_address(base, insn.imm.offset)

        # a byte previously stored at a concrete address is forwarded as-is.
        if isinstance(addr, I32) and addr.value in self.memory:
            self.push(self.memory[addr.value])
        else:
            self.push(Memory(addr))

    def handle_I32_LOAD(self, insn):
        base = self.pop()
        addr = self._effective_address(base, insn.imm.offset)

        if isinstance(addr, I32):
            cells = [self.memory.get(addr.value + i) for i in range(4)]
            # memory holds value objects, not integers: the word can only be
            # reassembled when all four bytes are known constants.
            if all(isinstance(cell, I32) for cell in cells):
                word = 0
                for i, cell in enumerate(cells):
                    # little endian: byte `i` holds bits [8*i, 8*i + 8)
                    word |= (cell.value & 0xFF) << (8 * i)
                self.push(I32(word))
                return

        self.push(Memory(addr))

    def handle_I32_STORE8(self, insn):
        value = self.pop()
        base = self.pop()
        addr = self._effective_address(base, insn.imm.offset)

        if isinstance(value, I32):
            v = I32(value.value & 0xFF)
        else:
            v = AndOperation(value, I32(0xFF))

        if isinstance(addr, I32):
            self.memory[addr.value] = v
        else:
            # ew: symbolic address for memory?
            self.memory[addr] = v

    def handle_I32_STORE(self, insn):
        value = self.pop()
        base = self.pop()
        addr = self._effective_address(base, insn.imm.offset)

        # split the word into four bytes, least significant first.
        if isinstance(value, I32):
            parts = [I32((value.value >> shift) & 0xFF) for shift in (0, 8, 16, 24)]
        else:
            parts = [AndOperation(value, I32(0xFF))]
            for shift in (8, 16, 24):
                # select the byte with a mask at its own position, then shift it down.
                mask = I32(0xFF << shift)
                parts.append(ShruOperation(AndOperation(value, mask), I32(shift)))

        for i, part in enumerate(parts):
            if isinstance(addr, I32):
                self.memory[addr.value + i] = part
            else:
                # ew: symbolic address for memory?
                # TODO: need to reduce here for symbolic addresses to match
                self.memory[AddOperation(addr, I32(i))] = part

    def handle_DEFAULT(self, insn):
        '''fallback for instructions without a handler: always raises NotImplementedError.'''
        raise NotImplementedError('instruction: {0!s}'.format(insn))

    def handle_insn(self, insn):
        '''dispatch one decoded instruction to its handler, by opcode id.'''
        logger.debug('trace: %s', insn.op.mnemonic)
        handler = {
            wasm.opcodes.OP_I32_CONST: self.handle_I32_CONST,
            wasm.opcodes.OP_I32_ADD: self.handle_I32_ADD,
            wasm.opcodes.OP_I32_SUB: self.handle_I32_SUB,
            wasm.opcodes.OP_I32_AND: self.handle_I32_AND,
            wasm.opcodes.OP_I32_SHL: self.handle_I32_SHL,
            wasm.opcodes.OP_I32_SHR_U: self.handle_I32_SHR_U,
            wasm.opcodes.OP_I32_XOR: self.handle_I32_XOR,
            wasm.opcodes.OP_I32_LOAD: self.handle_I32_LOAD,
            wasm.opcodes.OP_I32_LOAD8_U: self.handle_I32_LOAD8_U,
            wasm.opcodes.OP_I32_STORE: self.handle_I32_STORE,
            wasm.opcodes.OP_I32_STORE8: self.handle_I32_STORE8,
            wasm.opcodes.OP_SET_LOCAL: self.handle_SET_LOCAL,
            wasm.opcodes.OP_GET_LOCAL: self.handle_GET_LOCAL,
            wasm.opcodes.OP_SET_GLOBAL: self.handle_SET_GLOBAL,
            wasm.opcodes.OP_GET_GLOBAL: self.handle_GET_GLOBAL,
        }.get(insn.op.id, self.handle_DEFAULT)
        handler(insn)

    def run(self):
        '''execute every decoded instruction, in order.'''
        for insn in self.bc:
            self.handle_insn(insn)

    def render(self, ctx={}):
        '''
        return a multi-line text summary of the emulator state.

        sections (`globals:`, `locals:`, `stack:`, `memory:`) are only emitted
         when they have content. the stack is listed top first.
        `ctx` carries the naming hints understood by the module-level render functions.
        '''
        ret = []

        if self.globals:
            ret.append('globals:')
            for g in sorted(self.globals):
                ret.append('  ' + render_global(g, ctx) + ': ' + render(self.globals[g], ctx))

        if self.locals:
            ret.append('locals:')
            for local in sorted(self.locals):
                ret.append('  ' + render_local(local, ctx) + ': ' + render(self.locals[local], ctx))

        if self.stack:
            ret.append('stack:')
            for index, v in enumerate(reversed(self.stack)):
                ret.append('  {0:d}: '.format(index) + render(v, ctx=ctx))

        if self.memory:
            ret.append('memory:')
            cells = []
            for addr, v in self.memory.items():
                # concrete addresses are stored as plain integers: wrap them so that
                # they can be rendered, and ordered against symbolic addresses.
                if not isinstance(addr, BinaryOperation):
                    addr = I32(addr)
                cells.append((addr, v))

            # order by address only; the value types define `<` via the module-level `cmp`.
            cells.sort(key=lambda cell: cell[0])
            for addr, v in cells:
                ret.append('  {addr:s}: '.format(addr=render(addr, ctx=ctx)) + render(v, ctx=ctx))

        return '\n'.join(ret)


def main():
    '''
    emulate the instructions currently selected in IDA and print the resulting state.

    steps:
      1. read the selected range and fetch its bytes.
      2. collect naming hints: renamed locals (regvars), global names, and frame member names.
      3. run the emulator over the bytes and print the rendered state.

    returns -1 (after logging an error) when there is no selection, when the bytes
     cannot be read, or when the range is not within a single function;
     otherwise returns None.
    '''
    # these modules are used below but are not imported at the top of the file,
    # so import them here rather than relying on the host namespace to provide them.
    import idc
    import ida_name

    is_selected, sel_start, sel_end = idaapi.read_selection()
    if not is_selected:
        logger.error('range must be selected')
        return -1

    sel_end = idc.NextHead(sel_end)

    buf = ida_bytes.get_bytes(sel_start, sel_end - sel_start)
    if buf is None:
        logger.error('failed to fetch instruction bytes')
        return -1

    f = idaapi.get_func(sel_start)
    # `f` is dereferenced below, so a selection outside of any function is rejected too.
    if f is None or f != idaapi.get_func(sel_end):
        logger.error('range must be within a single function')
        return -1

    # find mappings from "$localN" to "custom_name"
    regvars = {}
    for i in range(0x1000):
        regvar = idaapi.find_regvar(f, sel_start, '$local%d' % (i))
        if regvar is None:
            continue
        regvars[regvar.canon] = regvar.user

        # stop early once as many regvars were found as the function reports.
        if len(regvars) >= f.regvarqty:
            break

    # find mappings from "$globalN" to the name at the global's offset
    globals_ = {}
    for i, offset in netnode.Netnode('$ wasm.offsets').get('globals', {}).items():
        globals_['$global' + i] = ida_name.get_name(offset)

    # find mappings from frame offset to member name.
    # only the first offset at which each name is seen is recorded.
    frame = {}
    if f.frame != idc.BADADDR:
        names = set()
        for i in range(idc.GetStrucSize(f.frame)):
            name = idc.GetMemberName(f.frame, i)
            if not name:
                continue
            if name in names:
                continue
            frame[i] = name
            names.add(name)

    emu = Emulator(buf)
    emu.run()
    print(emu.render(ctx={
        'regvars': regvars,
        'frame': frame,
        'globals': globals_,
    }))


# only configure logging and run when executed as a script,
# so that importing this module has no side effects.
if __name__ == '__main__':
    logging.basicConfig(level=logging.DEBUG)
    main()