# This file is part of Androguard.
#
# Copyright (c) 2012 Geoffroy Gueguen <geoffroy.gueguen@gmail.com>
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from collections import defaultdict
from tools.modified.androguard.decompiler.dad.basic_blocks import (build_node_from_block,
                                                    StatementBlock, CondBlock)
from tools.modified.androguard.decompiler.dad.util import get_type
from tools.modified.androguard.decompiler.dad.instruction import Variable

# Module-level logger, registered under the name 'dad.graph'.
logger = logging.getLogger('dad.graph')


class Graph(object):
    '''
    Control flow graph whose nodes are linked by regular edges and by
    exception (catch) edges.

    Both kinds of edges are stored in their own successor mapping together
    with a matching reverse (predecessor) mapping.
    '''
    def __init__(self):
        self.entry = None
        self.exit = None
        self.nodes = list()
        # Nodes sorted by reverse post order number, see compute_rpo().
        self.rpo = []
        # node -> list of successors (regular / catch edges).
        self.edges = defaultdict(list)
        self.catch_edges = defaultdict(list)
        # node -> list of predecessors (regular / catch edges).
        self.reverse_edges = defaultdict(list)
        self.reverse_catch_edges = defaultdict(list)
        # Both mappings are (re)built by number_ins().
        self.loc_to_ins = None
        self.loc_to_node = None

    def sucs(self, node):
        '''Return the successors of `node` through regular edges.'''
        return self.edges.get(node, [])

    def all_sucs(self, node):
        '''
        Return a new list with the regular successors of `node` followed by
        its catch successors.
        '''
        return self.edges.get(node, []) + self.catch_edges.get(node, [])

    def preds(self, node):
        '''
        Return the predecessors of `node` through regular edges, leaving out
        the ones flagged `in_catch`.
        '''
        return [n for n in self.reverse_edges.get(node, [])
                if not n.in_catch]

    def all_preds(self, node):
        '''
        Return a new list with the regular predecessors of `node` followed
        by its catch predecessors.
        '''
        return (self.reverse_edges.get(node, []) +
                self.reverse_catch_edges.get(node, []))

    def add_node(self, node):
        '''Append `node` to the list of nodes of the graph.'''
        self.nodes.append(node)

    def add_edge(self, e1, e2):
        '''
        Add the regular edge e1 -> e2. Adding an edge which already exists
        does nothing.
        '''
        lsucs = self.edges[e1]
        if e2 not in lsucs:
            lsucs.append(e2)
        lpreds = self.reverse_edges[e2]
        if e1 not in lpreds:
            lpreds.append(e1)

    def add_catch_edge(self, e1, e2):
        '''
        Add the catch edge e1 -> e2. Adding an edge which already exists
        does nothing.
        '''
        lsucs = self.catch_edges[e1]
        if e2 not in lsucs:
            lsucs.append(e2)
        lpreds = self.reverse_catch_edges[e2]
        if e1 not in lpreds:
            lpreds.append(e1)

    def remove_node(self, node):
        '''
        Remove `node` from the graph and unlink it from its neighbours.

        The node is removed from the successor lists of its predecessors
        and from the predecessor lists of its successors, for both kinds of
        edges. Its own catch mappings are popped, while its own entries in
        `edges` and `reverse_edges` are left in place.
        Raises ValueError if `node` is not in `self.nodes`.
        '''
        preds = self.reverse_edges.get(node, [])
        for pred in preds:
            self.edges[pred].remove(node)

        succs = self.edges.get(node, [])
        for suc in succs:
            self.reverse_edges[suc].remove(node)

        exc_preds = self.reverse_catch_edges.pop(node, [])
        for pred in exc_preds:
            self.catch_edges[pred].remove(node)

        exc_succs = self.catch_edges.pop(node, [])
        for suc in exc_succs:
            self.reverse_catch_edges[suc].remove(node)

        self.nodes.remove(node)
        if node in self.rpo:
            self.rpo.remove(node)

    def number_ins(self):
        '''
        Number the instructions of the nodes, following the order of
        `self.rpo`, and rebuild `loc_to_ins` and `loc_to_node`.

        Each node numbers its own instructions through node.number_ins(),
        which receives the first free number and returns the next one.
        `loc_to_node` is keyed by the (first, last) numbers of each node.
        '''
        self.loc_to_ins = {}
        self.loc_to_node = {}
        num = 0
        for node in self.rpo:
            start_node = num
            num = node.number_ins(num)
            end_node = num - 1
            self.loc_to_ins.update(node.get_loc_with_ins())
            self.loc_to_node[start_node, end_node] = node

    def get_ins_from_loc(self, loc):
        '''Return the instruction numbered `loc`, or None if unknown.'''
        return self.loc_to_ins.get(loc)

    def get_node_from_loc(self, loc):
        '''
        Return the node whose instruction range contains `loc`, or None if
        there is no such node.
        '''
        # items() is available on both Python 2 and Python 3, unlike
        # iteritems().
        for (start, end), node in self.loc_to_node.items():
            if start <= loc <= end:
                return node

    def remove_ins(self, loc):
        '''
        Remove the instruction numbered `loc` from the node containing it
        and from `loc_to_ins`.
        '''
        ins = self.get_ins_from_loc(loc)
        self.get_node_from_loc(loc).remove_ins(loc, ins)
        self.loc_to_ins.pop(loc)

    def compute_rpo(self):
        '''
        Number the nodes in reverse post order.
        An RPO traversal visits as many predecessors of a node as possible
        before visiting the node itself.
        Only the nodes reached by post_order() get a new `num`; `self.rpo`
        is then the list of all the nodes sorted by `num`.
        '''
        nb = len(self.nodes) + 1
        for node in self.post_order():
            node.num = nb - node.po
        self.rpo = sorted(self.nodes, key=lambda n: n.num)

    def post_order(self):
        '''
        Yield the nodes reachable from the entry in post-order, i.e. we
        visit all the children of a node before visiting the node itself.
        Regular and catch successors are both followed, and the post-order
        number of each node (starting at 1) is stored in its `po` attribute.
        '''
        def _visit(n, cnt):
            visited.add(n)
            for suc in self.all_sucs(n):
                if suc not in visited:
                    for cnt, s in _visit(suc, cnt):
                        yield cnt, s
            n.po = cnt
            yield cnt + 1, n
        visited = set()
        for _, node in _visit(self.entry, 1):
            yield node

    def draw(self, name, dname, draw_branches=True):
        '''
        Write a picture of the graph to the file '<dname>/<name>.png'.

        When `draw_branches` is true, the true and false branches of
        conditional nodes are drawn in green and red; other regular edges
        are blue and catch edges are dashed black.
        pydot is imported here so that it is only needed when drawing.
        '''
        from pydot import Dot, Edge
        g = Dot()
        g.set_node_defaults(color='lightgray', style='filled', shape='box',
                            fontname='Courier', fontsize='10')
        for node in sorted(self.nodes, key=lambda x: x.num):
            if draw_branches and node.type.is_cond:
                g.add_edge(Edge(str(node), str(node.true), color='green'))
                g.add_edge(Edge(str(node), str(node.false), color='red'))
            else:
                for suc in self.sucs(node):
                    g.add_edge(Edge(str(node), str(suc), color='blue'))
            for except_node in self.catch_edges.get(node, []):
                g.add_edge(Edge(str(node), str(except_node),
                                color='black', style='dashed'))

        g.write_png('%s/%s.png' % (dname, name))

    def immediate_dominators(self):
        '''Return the mapping computed by dom_lt() for this graph.'''
        return dom_lt(self)

    def __len__(self):
        return len(self.nodes)

    def __repr__(self):
        return str(self.nodes)

    def __iter__(self):
        for node in self.nodes:
            yield node


def split_if_nodes(graph):
    '''
    Split IfNodes in two nodes, the first node is the header node, the
    second one is only composed of the jump condition.

    Only conditional nodes holding more than one instruction are split.
    The graph is modified in place: the original node is replaced by a
    '<name>-pre' statement node followed by a '<name>-cond' condition node.
    '''
    # Maps every node to the node which stands for it after the split: a
    # split node maps to its pre node, every other node maps to itself.
    node_map = {n: n for n in graph}
    to_update = set()
    # Iterate over a copy, as nodes are added and removed along the way.
    for node in graph.nodes[:]:
        if node.type.is_cond:
            if len(node.get_ins()) > 1:
                # Everything but the last instruction goes in the pre node,
                # the last instruction alone makes the condition node.
                pre_ins = node.get_ins()[:-1]
                last_ins = node.get_ins()[-1]
                pre_node = StatementBlock('%s-pre' % node.name, pre_ins)
                cond_node = CondBlock('%s-cond' % node.name, [last_ins])
                node_map[node] = pre_node
                node_map[pre_node] = pre_node
                node_map[cond_node] = cond_node

                pre_node.copy_from(node)
                cond_node.copy_from(node)
                for var in node.var_to_declare:
                    pre_node.add_variable_declaration(var)
                pre_node.type.is_stmt = True
                cond_node.true = node.true
                cond_node.false = node.false

                # Redirect the incoming edges of the node to the pre node.
                for pred in graph.all_preds(node):
                    pred_node = node_map[pred]
                    # Verify that the link is not an exception link
                    if node not in graph.sucs(pred):
                        graph.add_catch_edge(pred_node, pre_node)
                        continue
                    # A node looping on itself: the back edge now leaves
                    # from the condition node.
                    if pred is node:
                        pred_node = cond_node
                    if pred.type.is_cond:  # and not (pred is node):
                        if pred.true is node:
                            pred_node.true = pre_node
                        if pred.false is node:
                            pred_node.false = pre_node
                    graph.add_edge(pred_node, pre_node)
                # The outgoing regular edges leave from the condition node.
                for suc in graph.sucs(node):
                    graph.add_edge(cond_node, node_map[suc])

                # We link all the exceptions to the pre node instead of the
                # condition node, which should not trigger any of them.
                for suc in graph.catch_edges.get(node, []):
                    graph.add_catch_edge(pre_node, node_map[suc])

                if node is graph.entry:
                    graph.entry = pre_node

                graph.add_node(pre_node)
                graph.add_node(cond_node)
                graph.add_edge(pre_node, cond_node)
                pre_node.update_attribute_with(node_map)
                cond_node.update_attribute_with(node_map)
                graph.remove_node(node)
        else:
            # Non conditional nodes are updated once node_map is complete.
            to_update.add(node)
    for node in to_update:
        node.update_attribute_with(node_map)


def simplify(graph):
    '''
    Simplify the CFG by merging/deleting statement nodes when possible:
    If statement B follows statement A and if B has no other predecessor
    besides A, then we can merge A and B into a new statement node.
    We also remove nodes which do nothing except redirecting the control
    flow (nodes which only contain a goto).

    The graph is modified in place, and the passes are repeated until a
    pass performs no merge and no removal.
    '''
    redo = True
    while redo:
        redo = False
        # Maps the removed empty nodes to the node which replaces them.
        node_map = {}
        to_update = set()
        # Iterate over a copy, as nodes are removed along the way.
        for node in graph.nodes[:]:
            # `node in graph` skips the nodes removed earlier in this pass.
            if node.type.is_stmt and node in graph:
                sucs = graph.all_sucs(node)
                if len(sucs) != 1:
                    continue
                suc = sucs[0]
                if len(node.get_ins()) == 0:
                    # Empty node: link its predecessors straight to its
                    # successor, unless a switch node points to it.
                    if any(pred.type.is_switch
                           for pred in graph.all_preds(node)):
                        continue
                    if node is suc:
                        continue
                    node_map[node] = suc

                    for pred in graph.all_preds(node):
                        pred.update_attribute_with(node_map)
                        # Keep exception links as exception links.
                        if node not in graph.sucs(pred):
                            graph.add_catch_edge(pred, suc)
                            continue
                        graph.add_edge(pred, suc)
                    redo = True
                    if node is graph.entry:
                        graph.entry = suc
                    graph.remove_node(node)
                elif (suc.type.is_stmt and
                      len(graph.all_preds(suc)) == 1 and
                      suc not in graph.catch_edges and
                      not ((node is suc) or (suc is graph.entry))):
                    # Merge the successor into the current node.
                    ins_to_merge = suc.get_ins()
                    node.add_ins(ins_to_merge)
                    for var in suc.var_to_declare:
                        node.add_variable_declaration(var)
                    # The merged node may have no regular successor at all,
                    # in which case there is no edge to carry over.
                    suc_sucs = graph.sucs(suc)
                    new_suc = suc_sucs[0] if suc_sucs else None
                    if new_suc:
                        graph.add_edge(node, new_suc)
                    for exception_suc in graph.catch_edges.get(suc, []):
                        graph.add_catch_edge(node, exception_suc)
                    redo = True
                    graph.remove_node(suc)
            else:
                to_update.add(node)
        for node in to_update:
            node.update_attribute_with(node_map)


def dom_lt(graph):
    '''
    Dominator algorithm from Lengauer-Tarjan.

    Regular and catch successors are both followed (graph.all_sucs).
    Return a dict mapping each node reachable from graph.entry to its
    immediate dominator; the entry itself is mapped to None. The nodes
    which are not reachable from the entry are absent from the result.
    '''
    def _dfs(v, n):
        # Depth first numbering: semi[v] holds the DFS number of v and
        # vertex[] is the reverse mapping. 0 means 'not visited yet' in
        # semi, and 'no ancestor in the forest' in ancestor.
        semi[v] = n = n + 1
        vertex[n] = label[v] = v
        ancestor[v] = 0
        for w in graph.all_sucs(v):
            if not semi[w]:
                parent[w] = v
                n = _dfs(w, n)
            pred[w].add(v)
        return n

    def _compress(v):
        # Path compression: make v point to the root of its tree in the
        # forest, keeping in label[v] the vertex of minimal semi on the path.
        u = ancestor[v]
        if ancestor[u]:
            _compress(u)
            if semi[label[u]] < semi[label[v]]:
                label[v] = label[u]
            ancestor[v] = ancestor[u]

    def _eval(v):
        if ancestor[v]:
            _compress(v)
            return label[v]
        return v

    def _link(v, w):
        ancestor[w] = v

    parent, ancestor, vertex = {}, {}, {}
    label, dom = {}, {}
    pred, bucket = defaultdict(set), defaultdict(set)

    # Step 1: number the vertices in DFS order.
    semi = {v: 0 for v in graph.nodes}
    n = _dfs(graph.entry, 0)
    # range() is used (not xrange) so that this runs on Python 2 and 3.
    for i in range(n, 1, -1):
        w = vertex[i]
        # Step 2: compute the semidominator of w.
        for v in pred[w]:
            u = _eval(v)
            semi[w] = min(semi[w], semi[u])
        bucket[vertex[semi[w]]].add(w)
        pw = parent[w]
        _link(pw, w)
        # Step 3: implicitly define the immediate dominator of the
        # vertices waiting in the bucket of the parent of w.
        bpw = bucket[pw]
        while bpw:
            v = bpw.pop()
            u = _eval(v)
            dom[v] = u if semi[u] < semi[v] else pw
    # Step 4: explicitly define the immediate dominators.
    for i in range(2, n + 1):
        w = vertex[i]
        dw = dom[w]
        if dw != vertex[semi[w]]:
            dom[w] = dom[dw]
    dom[graph.entry] = None
    return dom


def bfs(start):
    '''
    Yield the blocks reachable from `start` in breadth-first order, each
    block exactly once.

    For every block, the targets listed in its exception analysis (when it
    has one) are queued before its children.
    '''
    seen = {start}
    queue = [start]

    def _enqueue(block):
        # Queue the block unless it has already been discovered.
        if block not in seen:
            queue.append(block)
            seen.add(block)

    while queue:
        current = queue.pop(0)
        yield current
        analysis = current.exception_analysis
        if analysis:
            for _, _, handler in analysis.exceptions:
                _enqueue(handler)
        for _, _, child in current.childs:
            _enqueue(child)


class GenInvokeRetName(object):
    '''
    Generator of temporary variables named 'tmp1', 'tmp2', ... which also
    remembers the last variable it handed out or was given.
    '''
    def __init__(self):
        self.ret = None
        self.num = 0

    def new(self):
        '''Create the next 'tmpN' variable, remember it and return it.'''
        self.num += 1
        fresh = Variable('tmp%d' % self.num)
        self.ret = fresh
        return fresh

    def set_to(self, ret):
        '''Remember `ret` as the last variable.'''
        self.ret = ret

    def last(self):
        '''Return the last remembered variable (None if there is none).'''
        return self.ret


def make_node(graph, block, block_to_node, vmap, gen_ret):
    '''
    Return the node of `block`, linked in `graph` to the nodes of its
    exception targets and of its children.

    Nodes are looked up in `block_to_node`; the missing ones are built with
    build_node_from_block() and stored in that mapping. The returned node
    itself is not added to the list of nodes of the graph.
    For a conditional node, the `true` and `false` attributes are set to
    the matching child nodes.
    '''
    node = block_to_node.get(block)
    if node is None:
        node = build_node_from_block(block, vmap, gen_ret)
        block_to_node[block] = node
    if block.exception_analysis:
        for _type, _, exception_target in block.exception_analysis.exceptions:
            exception_node = block_to_node.get(exception_target)
            if exception_node is None:
                exception_node = build_node_from_block(exception_target,
                                                       vmap, gen_ret, _type)
                exception_node.set_catch_type(_type)
                exception_node.in_catch = True
                block_to_node[exception_target] = exception_node
            graph.add_catch_edge(node, exception_node)
    for _, _, child_block in block.childs:
        child_node = block_to_node.get(child_block)
        if child_node is None:
            child_node = build_node_from_block(child_block, vmap, gen_ret)
            block_to_node[child_block] = child_node
        graph.add_edge(node, child_node)
        if node.type.is_switch:
            node.add_case(child_node)
        if node.type.is_cond:
            # Floor division keeps these values integral whatever the
            # Python version ('/' is a true division on Python 3).
            # NOTE(review): the halving presumably converts byte offsets
            # into 16-bit code units - confirm.
            if_target = ((block.end // 2) - (block.last_length // 2) +
                         node.off_last_ins)
            child_addr = child_block.start // 2
            # The child located at the target of the jump is the true
            # branch, any other child is the false branch.
            if if_target == child_addr:
                node.true = child_node
            else:
                node.false = child_node

    # Check that both branches of the if point to something.
    # It may happen that both branches point to the same node, in this case
    # the false branch will be None. So we set it to the right node.
    # TODO: In this situation, we should transform the condition node into
    # a statement node
    if node.type.is_cond and node.false is None:
        node.false = node.true

    return node


def construct(start_block, vmap, exceptions):
    '''
    Build the control flow graph of the blocks reachable from
    `start_block`.

    The nodes are created in breadth-first order, then numbered in reverse
    post order along with their instructions. `graph.entry` is the node of
    `start_block`; `graph.exit` is the return node when there is exactly
    one, the last node in RPO when there are several, and stays None when
    there is none.

    `vmap` is handed over to make_node(). `exceptions` is not used by this
    function; it is kept so that the signature stays unchanged.
    '''
    bfs_blocks = bfs(start_block)

    graph = Graph()
    gen_ret = GenInvokeRetName()

    # Construction of a mapping of basic blocks into Nodes
    block_to_node = {}

    for block in bfs_blocks:
        node = make_node(graph, block, block_to_node, vmap, gen_ret)
        graph.add_node(node)

    graph.entry = block_to_node[start_block]
    del block_to_node, bfs_blocks

    graph.compute_rpo()
    graph.number_ins()

    # A node is part of a catch when all its predecessors which come
    # before it in RPO (regular and catch edges alike) are part of a catch.
    for node in graph.rpo:
        preds = [pred for pred in graph.all_preds(node)
                 if pred.num < node.num]
        if preds and all(pred.in_catch for pred in preds):
            node.in_catch = True

    # Create a list of Node which are 'return' node
    # There should be one and only one node of this type
    # If this is not the case, try to continue anyway by setting the exit node
    # to the one which has the greatest RPO number (not necessarily the case)
    lexit_nodes = [node for node in graph if node.type.is_return]

    if len(lexit_nodes) > 1:
        # Not sure that this case is possible...
        logger.error('Multiple exit nodes found !')
        graph.exit = graph.rpo[-1]
    elif len(lexit_nodes) < 1:
        # A method can have no return if it has throw statement(s) or if its
        # body is a while(1) without break/return.
        logger.debug('No exit node found !')
    else:
        graph.exit = lexit_nodes[0]

    return graph