"""Common subexpression elimination (CSE) script written by Serge!
"""

import ast
import astunparse
import functools
from copy import deepcopy
import sys
import os
import itertools

from sspam.tools import asttools


# Operators whose two operands may be swapped: a op b == b op a.
COMMUTATIVE_OPERATORS = (
    ast.Add,
    ast.Mult,
    ast.BitOr,
    ast.BitXor,
    ast.BitAnd,
)
# Every commutative operator listed above is also associative, so the same
# tuple is reused to decide which operator chains may be regrouped.
ASSOCIATIVE_OPERATORS = COMMUTATIVE_OPERATORS
# All binary operators on which CSE is attempted.
BINARY_OPERATORS = COMMUTATIVE_OPERATORS + (
    ast.Sub,
    ast.LShift,
    ast.RShift,
    ast.Div,
)
# Some operators are deliberately not considered: pow, matmult, etc.


class ForwardSubstitute(ast.NodeTransformer):
    """
    Perform the typical forward substitution transformation,
    based on the following strong assumptions:

    * all values are integers (no aliasing)
    * the input consists of a sequence of assignments / expressions
    * the assignments are in the form id = expr
    * the assignments are in SSA form
    * expressions only consist of binary operators, names and nums

    Pruned assignments are recorded in ``self.substitutions``
    (target id -> replacement node); every later read of such an id is
    replaced by a deep copy of its replacement.
    """

    def __init__(self):
        # target id -> AST node replacing every later read of that id
        self.substitutions = {}

    def visit_Assign(self, node):
        """
        Check if the assignment can be forward substituted.

        In that case register the substitution and prune the statement
        (a NodeTransformer visitor returning None removes the statement
        from the enclosing body). Assignments whose target is never read
        are pruned as well.

        Relies on ``self.use_count`` (name -> read count), which is set
        up by ``run``.
        """
        assert isinstance(node.targets[0], ast.Name) and len(node.targets) == 1
        targetid = node.targets[0].id
        # the target is never read: drop the dead assignment
        if targetid not in self.use_count:
            return
        # literals can always be propagated
        if isinstance(node.value, ast.Num):
            self.substitutions[targetid] = node.value
            return None
        # identifiers copy are always propagated
        # with extra care to keep the substitution table valid in case of
        # assignment chaining (a = b ; c = a)
        # (the source name is resolved here directly, not via visit_Name)
        elif isinstance(node.value, ast.Name):
            # this copy consumes one read of the source name
            self.use_count[node.value.id] -= 1
            if self.use_count[node.value.id] > 0:
                # there is at least another use
                # both names end up with the sum of their remaining read
                # counts (presumably because reads of the target become
                # reads of the source once substituted)
                self.use_count[targetid] += self.use_count[node.value.id]
                self.use_count[node.value.id] = self.use_count[targetid]

            if node.value.id in self.substitutions:
                # the source assignment was itself pruned: chain through it
                sub = self.substitutions[node.value.id]
                # inline if the target is read once or the replacement is
                # trivial (a name or a number), so nothing gets duplicated
                cond = (self.use_count[targetid] == 1 or
                        isinstance(sub, (ast.Name, ast.Num)))
                if cond:
                    self.substitutions[targetid] = sub
                    return None
                else:
                    # keep the assignment, but with the resolved value
                    node.value = sub
                    return node
            else:
                self.substitutions[targetid] = node.value
                return None
        # other assignment are only propagated if they are used once
        # (the value's own names are substituted before it is recorded)
        elif self.use_count[targetid] == 1:
            self.substitutions[targetid] = self.generic_visit(node.value)
            return None
        else:
            # keep the assignment; only substitute names inside it
            return self.generic_visit(node)

    def visit_Name(self, node):
        """
        Substitute if needed.

        Returns a deep copy of the registered replacement for ``node.id``,
        so every use site gets its own subtree; otherwise returns ``node``.
        """
        sub = self.substitutions.get(node.id)
        if sub:
            return deepcopy(sub)
        else:
            return node

    def run(self, node):
        """
        Entry point: perform the needed analysis and run the transformation.

        Computes read counts with UseCount, then transforms ``node`` in
        place. The return value of ``visit`` is discarded, so ``node`` is
        presumably a container (e.g. a Module) whose children get rewritten.
        """
        uc = UseCount()
        self.use_count = uc.run(node)
        self.visit(node)


class UseCount(ast.NodeVisitor):
    """
    Simple read-count analysis.

    Maps every variable name to the number of times it is read (loaded),
    under the same assumptions as ForwardSubstitute. Names that are only
    stored never appear in the mapping.
    """

    def __init__(self):
        self.result = {}

    def visit_Name(self, node):
        'Count this occurrence when the name is read'
        if not isinstance(node.ctx, ast.Load):
            return
        previous = self.result.setdefault(node.id, 0)
        self.result[node.id] = previous + 1

    def run(self, node):
        'Visit node and return the name -> read count mapping'
        self.visit(node)
        return self.result


def node_hash(node):
    """
    Helper function to compute a unique hashable representation of a node.

    Names map to ``(id,)``, numbers to ``(str(n),)`` and binary operations
    to ``(op_class_name, node_hash(left), node_hash(right))``. The nested
    tuples can be used as dict keys and ordered with ``<``.

    Raises:
        TypeError: if ``node`` (or any operand inside it) is not a Name,
            a Num or a BinOp.
    """
    if isinstance(node, ast.Name):
        return (node.id,)
    if isinstance(node, ast.Num):
        return (str(node.n),)
    if isinstance(node, ast.BinOp):
        children = node_hash(node.left), node_hash(node.right)
        return (type(node.op).__name__,) + children
    # An explicit raise is kept under `python -O`, whereas `assert False`
    # would be stripped and the function would silently return None.
    raise TypeError('unhandled node type: ' + ast.dump(node))


class HandleCommutativity(ast.NodeTransformer):
    """
    Canonicalise the operand order of commutative binary operations.

    Working bottom-up, the two operands of a commutative operator are
    swapped when the right one has the smaller node_hash, so operand order
    becomes deterministic.
    """

    def visit_BinOp(self, node):
        'Normalise the children first, then order operands if commutative'
        node = self.generic_visit(node)
        if not isinstance(node.op, COMMUTATIVE_OPERATORS):
            return node
        left_key = node_hash(node.left)
        right_key = node_hash(node.right)
        if right_key < left_key:
            node.left, node.right = node.right, node.left
        return node


class PromoteUnaryOp(ast.NodeTransformer):
    """
    Rewrite unary operations as equivalent binary ones.

    ``+x`` becomes ``x``, ``-x`` becomes ``-1 * x`` and ``~x`` becomes
    ``-1 ^ x``; any other unary operator (e.g. ``not``) fails an assertion.
    """

    def visit_UnaryOp(self, node):
        'Replace +x, -x and ~x by expressions without UnaryOp'
        operand = self.visit(node.operand)
        unary = node.op
        if isinstance(unary, ast.UAdd):
            return operand
        # unary op class -> binary op class applied as (-1 <binop> operand)
        for unary_type, binary_type in ((ast.USub, ast.Mult),
                                        (ast.Invert, ast.BitXor)):
            if isinstance(unary, unary_type):
                return ast.BinOp(ast.Num(-1), binary_type(), operand)
        assert False, 'unhandled node type: ' + ast.dump(node)


class Substitute(ast.NodeTransformer):
    """
    Perform node substitution after cse

    I.e. add extra assignments and use the assigned value in the
    original expression: each registered BinOp is replaced by an
    expression over freshly named variables, and the assignments defining
    those variables are inserted just before the top-level statement
    (Assign or Expr) that contains it.
    """

    def __init__(self, prefix, op, term_to_node,
                 rewritten_terms, result_nodes):
        """
        Precompute the rewrite of every node of ``result_nodes``.

        :param prefix: format string with one ``{}`` placeholder, filled with
            a running counter to name each new variable.
        :param op: AST operator class (e.g. ``ast.Add``); instantiated to
            build every new BinOp.
        :param term_to_node: maps a term id to the AST node it stands for.
        :param rewritten_terms: for each result node, a sequence of term
            groups; each group (a tuple of term ids, as built by simple_cse)
            is bound to one new variable.
        :param result_nodes: BinOp nodes to rewrite, parallel to
            ``rewritten_terms``.
        """
        # pylint: disable=too-many-arguments
        self.prefix = prefix
        self.op = op
        # result node -> (replacement expression, assignments to insert)
        self.rewrite = dict()
        # target ids already emitted, so each new variable is assigned once
        self.assigned_values = set()

        # term group -> Name of the variable holding it; shared across all
        # result nodes so that a common group maps to a single variable
        assigns = {}
        for terms_part, result_node in zip(rewritten_terms, result_nodes):
            # a node listed more than once keeps its first rewrite
            if result_node in self.rewrite:
                continue
            new_node = None
            # groups first introduced by this result node, in order
            ordered_assign_keys = []
            ordered_assign_values = []
            # NOTE(review): simple_cse passes tuples of term ids here, never
            # ast.Name, so this branch looks unreachable from that caller
            if len(terms_part) == 1 and isinstance(terms_part[0], ast.Name):
                new_node = terms_part[0]
            else:
                for term in terms_part:
                    if term not in assigns:
                        new_id = self.prefix.format(len(assigns))
                        new_subnode = ast.Name(new_id, ast.Load())
                        assigns[term] = new_subnode
                        ordered_assign_keys.append(term)
                        ordered_assign_values.append(new_subnode)
                    else:
                        # later occurrences get their own copy of the Name
                        new_subnode = deepcopy(assigns[term])
                    # chain the group variables with self.op, left to right
                    if new_node:
                        new_node = ast.BinOp(new_node, self.op(), new_subnode)
                    else:
                        new_node = new_subnode
            # build `<var> = t0 op t1 op ...` for every newly introduced group;
            # NOTE(review): term_to_node nodes are inserted without being
            # copied, so one node object may appear at several places
            binlamb = lambda x, y: ast.BinOp(x, self.op(), y)
            new_assigns = [ast.Assign([ast.Name(target.id, ast.Store())], val)
                           for target, val in
                           zip(ordered_assign_values,
                               [functools.reduce(binlamb, [term_to_node[term]
                                                           for term in terms])
                                for terms in ordered_assign_keys])]
            self.rewrite[result_node] = new_node, new_assigns

    def visit_TopLevelStmt(self, node):
        """
        Visitor for top level statements (Assign and Expr).

        Returns a list: the assignments collected while rewriting the
        statement, followed by the statement itself. Any of them that is
        an Assign to an id already emitted during this visit is skipped.
        """
        self.new_assigns = []
        # NOTE: evaluation order matters. The left operand binds the (still
        # empty) list object first; generic_visit then appends to that same
        # list through visit_BinOp, so the concatenation sees the additions.
        candidate_assigns = self.new_assigns + [self.generic_visit(node)]
        new_node = []
        for candidate_assign in candidate_assigns:
            if isinstance(candidate_assign, ast.Assign):
                if candidate_assign.targets[0].id not in self.assigned_values:
                    self.assigned_values.add(candidate_assign.targets[0].id)
                    new_node.append(candidate_assign)
            else:
                new_node.append(candidate_assign)
        # drop the per-statement buffer; visit_BinOp outside a top-level
        # statement would then fail with AttributeError
        del self.new_assigns
        return new_node

    # both statement kinds share the same handling
    visit_Expr = visit_Assign = visit_TopLevelStmt

    def visit_BinOp(self, node):
        """
        Rewrite node if needed.

        Children are processed first. If ``node`` (by identity) has a
        registered rewrite, its assignments are queued for the enclosing
        top-level statement and the replacement expression is returned.
        """
        node = self.generic_visit(node)
        rewrite = self.rewrite.get(node)
        if rewrite:
            new_node, new_assigns = rewrite
            for new_assign in new_assigns:
                self.new_assigns.append(new_assign)
            return new_node
        else:
            return node


class GatherOpClasses(ast.NodeVisitor):
    """
    Collect the operand lists of every chain of one operator class.

    After visiting, ``result`` holds one operand list per chain of ``op``
    and ``result_nodes`` the BinOp at the top of that chain, in the same
    order. Chains are flattened only when ``op`` is associative; otherwise
    each BinOp is its own chain. Inner chains are recorded before the
    chains containing them. ``to_terms``/``from_terms`` convert between
    operand nodes and integer term ids, operands with equal node_hash
    sharing one id.
    """

    def __init__(self, op):
        # operator class whose chains are collected (e.g. ast.Add)
        self.op = op
        # parallel lists: operand list / top BinOp node of each chain
        self.result = []
        self.result_nodes = []
        # node_hash -> first node seen with that hash / its term id
        self.hash_to_node = {}
        self.hash_to_term = {}
        # term id -> representative node
        self.term_to_node = {}

    def _term_of(self, node):
        'Return the term id of node, allocating a new one on first sight'
        key = node_hash(node)
        if key not in self.hash_to_node:
            fresh = len(self.hash_to_term)
            self.hash_to_node[key] = node
            self.hash_to_term[key] = fresh
            self.term_to_node[fresh] = node
        return self.hash_to_term[key]

    def to_terms(self):
        'Return the collected operand lists as tuples of term ids'
        return [tuple(self._term_of(operand) for operand in chain)
                for chain in self.result]

    def from_terms(self, terms):
        'Return lists of fresh node copies for a sequence of term tuples'
        return [[deepcopy(self.term_to_node[term_id]) for term_id in group]
                for group in terms]

    def visit_BinOp(self, node, partial=False):
        'Flatten chains of self.op and record each top-level chain'
        if not isinstance(node.op, self.op):
            self.generic_visit(node)
            return []
        operands = []
        for child in (node.left, node.right):
            same_op_child = (isinstance(child, ast.BinOp)
                             and isinstance(child.op, self.op))
            if same_op_child and self.op in ASSOCIATIVE_OPERATORS:
                operands += self.visit_BinOp(child, partial=True)
            else:
                self.visit(child)
                operands.append(child)
        if not partial:
            self.result_nodes.append(node)
            self.result.append(operands)
        return operands


def simple_cse(node, operators=BINARY_OPERATORS):
    """
    Simple version of common subexpression elimination, applied in place.

    Each operator class of ``operators`` is processed in turn, for at most
    30 generations (a guard against endless rewriting). One generation:

    1. collects, with GatherOpClasses, the operand list of every chain of
       that operator and numbers the distinct operands (terms);
    2. counts how often each pair of terms occurs among the 2-combinations
       of every chain;
    3. if the most frequent pair occurs more than once, binds it to a new
       variable (named ``cse<generation><OpName><index>``) with Substitute
       and then runs ForwardSubstitute.

    The operator is done as soon as no chain is found or no pair repeats.
    """
    for op in operators:
        for generation in range(30):
            prefix = 'cse{}{}{{}}'.format(generation, op.__name__)
            gatherer = GatherOpClasses(op)
            gatherer.visit(node)
            terms = gatherer.to_terms()
            if not terms:
                break

            pairs_by_term = {}
            pair_count = {}
            for term in terms:
                pairs = list(itertools.combinations(term, 2))
                pairs_by_term[term] = pairs
                for pair in pairs:
                    pair_count[pair] = pair_count.get(pair, 0) + 1
            best_pair, best_count = max(pair_count.items(),
                                        key=lambda item: item[1])
            if best_count <= 1:
                break

            # split every chain containing the best pair into
            # (pair, leftover operands); other chains stay whole
            new_terms = []
            for term in terms:
                if best_pair not in pairs_by_term[term]:
                    new_terms.append((tuple(term),))
                    continue
                leftover = list(term)
                for elem in best_pair:
                    leftover.remove(elem)
                if leftover:
                    new_terms.append((tuple(best_pair), tuple(leftover)))
                else:
                    new_terms.append((tuple(best_pair),))

            Substitute(prefix, op, gatherer.term_to_node,
                       new_terms, gatherer.result_nodes).visit(node)
            ForwardSubstitute().run(node)


class PostProcessing(ast.NodeTransformer):
    """
    Actual cse might need some post-processing:

    - remove constant subexpressions: an assignment whose value is a
      constant expression is dropped, and later reads of its target are
      replaced by a copy of that value
    - change the final expression statement into ``result = <expr>``
    """

    def __init__(self):
        # target id -> constant expression it was assigned
        self.replace = {}

    def visit_Module(self, node):
        'Visit every top-level statement, dropping those turned into None'
        new_body = []
        for elem in node.body:
            # NodeTransformer.visit dispatches to visit_<ClassName> and falls
            # back to generic_visit for statements without a dedicated
            # visitor (e.g. Pass, If), instead of calling None.
            new_node = self.visit(elem)
            if new_node is not None:
                new_body.append(new_node)
        node.body = new_body
        return node

    def visit_Assign(self, node):
        'Register node id if value is a constant expression'
        if len(node.targets) != 1:
            return self.generic_visit(node)
        if asttools.CheckConstExpr().visit(node.value):
            self.replace[node.targets[0].id] = deepcopy(node.value)
            # returning None removes the assignment from the body
            return None
        return self.generic_visit(node)

    def visit_Name(self, node):
        'Replace id if it points to a constant expr'
        if isinstance(node.ctx, ast.Load) and node.id in self.replace:
            # Copy per use site: returning the stored node itself would
            # make several places of the tree share one node object.
            return deepcopy(self.replace[node.id])
        return node

    def visit_Expr(self, node):
        'Change last expression into an assignment to `result`'
        return ast.Assign([ast.Name('result', ast.Store())],
                          self.generic_visit(node.value))


def apply_cse(expr, outputfile=None):
    """
    Apply CSE on an expression file, an expression string or an AST.

    :param expr: path of a file containing the code, a source string (used
        when it is not an existing file path), an ``ast.Module``, or any
        other AST node, which is treated as a single expression.
    :param outputfile: optional path the resulting source is written to.
    :returns: ``(source string, resulting AST)``.
    :raises TypeError: if ``expr`` is neither a string nor an AST node.

    An AST given as input is copied first and never modified.
    """
    wrapped = False
    if isinstance(expr, str):
        if os.path.isfile(expr):
            with open(expr, 'r') as exprfile:
                expr_ast = ast.parse(exprfile.read())
        else:
            expr_ast = ast.parse(expr)
    elif isinstance(expr, ast.Module):
        expr_ast = deepcopy(expr)
    elif isinstance(expr, ast.AST):
        # Wrap a lone expression in a Module: Substitute returns the
        # hoisted assignments together with each top-level statement as a
        # list, which is only kept when that statement sits in a body.
        expr_ast = ast.Module(body=[ast.Expr(deepcopy(expr))])
        wrapped = True
    else:
        raise TypeError('expr must be a str or an ast.AST, not %s'
                        % type(expr).__name__)
    PromoteUnaryOp().visit(expr_ast)
    HandleCommutativity().visit(expr_ast)
    simple_cse(expr_ast)
    expr_ast = PostProcessing().visit(expr_ast)
    if wrapped and len(expr_ast.body) == 1:
        # nothing was hoisted: return the single `result = ...` statement,
        # as for a bare expression input before
        expr_ast = expr_ast.body[0]
    expr_string = astunparse.unparse(expr_ast).strip('\n')
    if outputfile:
        with open(outputfile, 'w') as output_file:
            output_file.write(expr_string)
    return expr_string, expr_ast


if __name__ == "__main__":
    if len(sys.argv) not in (2, 3):
        # usage errors go to stderr with a failure status
        sys.stderr.write("Usage: %s <input file> [output file]\n"
                         % sys.argv[0])
        sys.exit(1)

    # the optional second argument is a file the result is also written to
    outputfile = sys.argv[2] if len(sys.argv) == 3 else None
    # print(<single expression>) behaves the same on Python 2 and 3
    print(apply_cse(sys.argv[1], outputfile)[0])