import torch
from .. import diffsat
from ..util import lmap
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import scipy.optimize as spo
import time
import sys
import signal
import re

def make_comp(op, a, b):
    a_is_inf = (isinstance(a, Fn) and a.t == 'normInf') and a.args[0].is_shape_preserving_arithmetic_in_var_const()
    b_is_inf = (isinstance(b, Fn) and b.t == 'normInf') and b.args[0].is_shape_preserving_arithmetic_in_var_const()
    if (((a.is_var() or a.is_const()) and (b.is_var() or b.is_const())) or
            (a_is_inf and op in ['le', 'lt'] and (b.is_var() or b.is_const())) or
            (b_is_inf and op in ['ge', 'gt'] and (a.is_var() or a.is_const()))):
        if a_is_inf:
            a = Fn('abs', lambda a: a.abs(), a.args[0])
        if b_is_inf:
            b = Fn('abs', lambda b: b.abs(), b.args[0])
        a = Fn('view(-1)', lambda a: a.view(-1), a)
        ashape = a.shape()
        b = Fn('view(-1)', lambda b: b.view(-1), b)
        bshape = b.shape()
        if ashape == bshape:
            return And(*[Comp(op, a[i], b[i]) for i in range(ashape[0])])
        elif ashape[0] == 1:
            return And(*[Comp(op, a, b[i]) for i in range(bshape[0])])
        elif bshape[0] == 1:
            return And(*[Comp(op, a[i], b) for i in range(ashape[0])])
        else:
            assert False, f"Shape mismatch: {ashape} {bshape}"
    else:
        return Comp(op, a, b)
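
# Illustrative note on make_comp (hypothetical names `x`, `c`): comparisons between
# variables and constants are flattened and expanded elementwise, e.g. for a
# Variable x of shape (3,)
#
#     x <= 1.0   ==>   And(x[0] <= 1.0, x[1] <= 1.0, x[2] <= 1.0)
#
# with a size-1 side broadcast against the other. An infinity-norm comparison such
# as normInf(x) <= c is rewritten as the elementwise conjunction abs(x) <= c.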

class DL2Tensor:

    def __init__(self):
        pass

    def upgrade_other(self, other):
        if isinstance(other, (float, int, torch.Tensor, np.ndarray)):
            return Constant(other, self.cuda)
        return other
    
    def __add__(self, other):
        other = self.upgrade_other(other)
        return Fn('+', lambda a, b: a + b, self, other)

    def __mul__(self, other):
        other = self.upgrade_other(other)
        return Fn('*', lambda a, b: a * b, self, other)

    def __sub__(self, other):
        other = self.upgrade_other(other)
        return Fn('-', lambda a, b: a - b, self, other)

    def sum(self):
        return Fn('sum', lambda a: a.sum(), self)
    
    def __lt__(self, other):
        other = self.upgrade_other(other)
        return make_comp('lt', self, other)

    def __le__(self, other):
        other = self.upgrade_other(other)
        return make_comp('le', self, other)

    def __gt__(self, other):
        other = self.upgrade_other(other)
        return make_comp('gt', self, other)

    def __ge__(self, other):
        other = self.upgrade_other(other)
        return make_comp('ge', self, other)

    def eq_(self, other):
        other = self.upgrade_other(other)
        return make_comp('eq', self, other)
    
    def __neg__(self):
        return Fn('neg', lambda a: -a, self)

    def shape(self):
        if self.is_var() or self.is_const() or self.is_shape_preserving_arithmetic_in_var_const():
            with torch.no_grad():
                return self.to_diffsat(cache=False).shape
        else:
            return None

    def reset_cache(self):
        pass
        
    # [] operator
    def __getitem__(self, key):
        return Fn('[]', lambda a, b: a[b], self, key)

    def is_var(self):
        return isinstance(self, Variable) or (isinstance(self, Fn) and (self.t == '[]' or 'view' in self.t) and self.args[0].is_var())

    def is_const(self):
        return isinstance(self, Constant) or (isinstance(self, Fn) and (self.t == '[]' or 'view' in self.t) and self.args[0].is_const())

    def is_shape_preserving_arithmetic_in_var_const(self):
        if self.is_var() or self.is_const():
            return True
        if isinstance(self, Fn):
            return self.t in ['+', '-', 'abs'] and all(a.is_shape_preserving_arithmetic_in_var_const() for a in self.args)
        return False
    
    # in as function name, not the operator
    def in_(self, interval):
        assert isinstance(interval, Interval)
        return And(interval.a <= self, self <= interval.b)

    def init(self, other):
        assert self.is_var()
        other = self.upgrade_other(other)
        value = other.tensor.clone().view(-1)
        var = self.to_diffsat().view(-1)
        with torch.no_grad():
            var[:] = value[:]

    def simplify(self, delete_box_constraints=False):
        return self        
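
# Minimal usage sketch of the expression DSL above (illustrative only; `x` and the
# interval bounds are made-up names/values, no particular network is assumed):
#
#     x = Variable('x', (2,))
#     constraint = And(x.in_(Interval(-1.0, 1.0)),   # -1 <= x_i <= 1 elementwise
#                      x.sum().eq_(0.0))             # the entries sum to zero
#     diffsat_expr = constraint.to_diffsat(cache=False)
#
# Each overloaded Python operator builds an Fn/Comp tree; to_diffsat() lowers that
# tree to diffsat terms, which expose .loss(args) and .satisfy(args).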

class DL2Logic:

    def __init__(self):
        pass

    def get_variables(self):
        variables = []
        for arg in self.args:
            if hasattr(arg, 'get_variables'):
                variables.extend(arg.get_variables())
        return variables

    def simplify(self, delete_box_constraints=False):
        return self


class Class(DL2Logic):

    def __init__(self, net, c):
        assert isinstance(net, Fn) and net.t == '()'
        self.net = net
        self.c = c

    def __str__(self):
        return f"(class, {self.net}, {self.c})"
        
    def get_variables(self):
        return self.net.get_variables()

    def reset_cache(self):
        self.net.reset_cache()
        self.c.reset_cache()
    
    def to_diffsat(self, cache=True, reset_cache=False):
        if reset_cache:
            self.reset_cache()
        logits = self.net.to_diffsat(cache=cache)
        c = self.c.to_diffsat(cache=cache) if hasattr(self.c, 'to_diffsat') else self.c
        c = int(c)
        batch_size, nr_classes = logits.shape
        assert batch_size == 1
        constraints = []
        for k in range(nr_classes):
            if k == c:
                continue
            constraints.append(diffsat.LT(logits[0, k], logits[0, c]))
        return diffsat.And(constraints)
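
# Class(net(x), c) encodes "the network assigns class c to x" as the conjunction of
# pairwise constraints logits[0, k] < logits[0, c] for all k != c, i.e. the argmax
# over the logits of a single-example batch equals c.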
    
class And(DL2Logic):

    def __init__(self, *args):
        assert all(map(lambda x: isinstance(x, DL2Logic), args))
        self.args = args

    def __str__(self):
        return " and ".join(map(str, self.args))

    def reset_cache(self):
        for arg in self.args:
            arg.reset_cache()
    
    def to_diffsat(self, cache=True, reset_cache=False):
        if reset_cache:
            self.reset_cache()
        return diffsat.And(lmap(lambda x: x.to_diffsat(cache=cache), self.args))

    def get_box_constraints(self):
        boxes = []
        for arg in self.args:
            if hasattr(arg, 'get_box_constraints'):
                boxes.extend(arg.get_box_constraints())
        return boxes

    def simplify(self, delete_box_constraints=False):
        new_args = []
        for arg in self.args:
            is_box = hasattr(arg, 'is_box_constraint') and arg.is_box_constraint()
            if not (delete_box_constraints and is_box):
                arg = arg.simplify(delete_box_constraints=delete_box_constraints)
                if (isinstance(arg, And) or isinstance(arg, Or)) and len(arg.args) == 0:
                    continue
                new_args.append(arg)
        return And(*new_args)

    
class Or(DL2Logic):

    def __init__(self, *args):
        assert all(map(lambda x: isinstance(x, DL2Logic), args))
        self.args = args

    def __str__(self):
        return " or ".join(map(str, self.args))

    def reset_cache(self):
        for arg in self.args:
            arg.reset_cache()
    
    def to_diffsat(self, cache=True, reset_cache=False):
        if reset_cache:
            self.reset_cache()
        return diffsat.Or(lmap(lambda x: x.to_diffsat(cache=cache), self.args))

    def get_box_constraints(self):
        return []  # box constraints cannot be extracted across a disjunction
    
        
class Comp(DL2Logic):

    def __init__(self, t, a, b):
        assert isinstance(a, DL2Tensor)
        assert isinstance(b, DL2Tensor)
        self.t = t
        self.a = a
        self.b = b

    def __str__(self):
        return f"({self.t} {self.a} {self.b})"

        return " or ".join(map(str, self.args))

    def reset_cache(self):
        self.a.reset_cache()
        self.b.reset_cache()
    
    def to_diffsat(self, cache=True, reset_cache=False):
        if reset_cache:
            self.reset_cache()
        op = {'eq': diffsat.EQ,
              'lt': diffsat.LT,
              'le': diffsat.LEQ,
              'gt': diffsat.GT,
              'ge': diffsat.GEQ}[self.t]
        a = self.a.to_diffsat(cache=cache)
        b = self.b.to_diffsat(cache=cache)
        if a.shape == torch.Size([1]):
            a = a.view([])
        if b.shape == torch.Size([1]):
            b = b.view([])
        assert a.shape == torch.Size([])
        assert b.shape == torch.Size([])
        return op(a, b)

    def get_variables(self):
        return self.a.get_variables() + self.b.get_variables()

    def is_box_constraint(self):
        return (self.a.is_const() and self.b.is_var()) or (self.a.is_var() and self.b.is_const())
    
    def get_box_constraints(self):
        return [self] if self.is_box_constraint() else []

    
class Fn(DL2Tensor):

    def __init__(self, t, fn, *args):
        self.t = t
        self.fn = fn
        self.args = args
        self.cuda = any([hasattr(a, 'cuda') and a.cuda for a in self.args])
        self.cache = None

    def __str__(self):
        return f"({self.t}, {','.join(map(str, self.args))})"

    def reset_cache(self):
        self.cache = None
        for a in self.args:
            if hasattr(a, 'reset_cache'):
                a.reset_cache()
    
    def to_diffsat(self, cache=True, reset_cache=False):
        if reset_cache:
            self.reset_cache()
        if cache and self.cache is not None:
            return self.cache
        args = [a.to_diffsat(cache=cache) if hasattr(a, 'to_diffsat') else a for a in self.args]
        result = self.fn(*args)
        if cache and self.cache is None:
            self.cache = result
        return result

    def get_variables(self):
        variables = []
        for arg in self.args:
            if hasattr(arg, 'get_variables'):
                variables.extend(arg.get_variables())
        return variables
        
class Variable(DL2Tensor):

    def __init__(self, name, shape, cuda=False):
        super().__init__()
        self.name = name
        self.shape = shape
        self.tensor = torch.zeros(self.shape)
        self.cuda = cuda
        if cuda:
            self.tensor = self.tensor.to('cuda:0')
        self.tensor.requires_grad_()

    def __str__(self):
        return self.name

    def to_diffsat(self, cache=True, reset_cache=False):
        if reset_cache:
            self.reset_cache()
        return self.tensor

    def get_variables(self):
        return [self]

    
class Constant(DL2Tensor):

    def __init__(self, value, cuda=False):
        super().__init__()
        self.value = value
        if isinstance(value, np.ndarray):
            # older PyTorch versions do not support bool tensors, so cast to uint8
            if value.dtype == np.bool_:
                value = value.astype(np.uint8)
            self.tensor = torch.tensor(value)
        else:
            self.tensor = torch.tensor(float(value))

        self.cuda = cuda
        if cuda:
            self.tensor = self.tensor.to('cuda:0')


    def __str__(self):
        if len(str(self.value)) < 10:
            return f"({self.value})"
        else:
            return f"(Constant{list(self.tensor.shape)})"

    def to_diffsat(self, cache=True, reset_cache=False):
        if reset_cache:
            self.reset_cache()
        return self.tensor

    def get_variables(self):
        return []

class Interval:

    def __init__(self, a, b, cuda=False):
        super().__init__()
        self.a = Constant(a, cuda)
        self.b = Constant(b, cuda)

    def __str__(self):
        return f"([{self.a}, {self.b}])"    
    
class Model(DL2Tensor):

    def __init__(self, model):
        self.model = model
        self.cuda = next(model.parameters()).is_cuda

    def __call__(self, *args):
        # apply the wrapped nn.Module to the (lowered) arguments
        return Fn('()', lambda m, *xs: m(*xs), self, *args)

    def __str__(self):
        return "(Model)"

    def to_diffsat(self, cache=True, reset_cache=False):
        if reset_cache:
            self.reset_cache()
        return self.model
    
    def get_variables(self):
        return []

    def __getattr__(self, attr):
        if attr.startswith('__'):
            raise AttributeError
        else:
            return ModelLayer(self, attr)
    
class ModelLayer(Model):

    def __init__(self, model, layer):
        assert layer in ['p']
        self.model = model
        self.cuda = model.cuda
        self.layer = layer

    def __str__(self):
        return f"(Model.{self.layer})"

    def to_diffsat(self, cache=True, reset_cache=False):
        if reset_cache:
            self.reset_cache()
        if self.layer == 'p':
            return torch.nn.Sequential(self.model.to_diffsat(cache=cache), torch.nn.Softmax(dim=1))
    
    def get_variables(self):
        return self.model.get_variables()

    def __getattr__(self, attr):
        raise AttributeError(attr)
    
            
def simplify(constraint, args):
    if args.opt == 'lbfgsb':
        boxes = constraint.get_box_constraints()
        constraint_s = constraint.simplify(delete_box_constraints=True)
        variables = list(set(constraint_s.get_variables()))
        if len(variables) == 0:
            # if we removed all variables, add dummy constraint
            variables = list(set(constraint.get_variables()))
            c = []
            for v in variables:
                c.append(v.eq_(v))
            constraint_s = And(*c)
        bounds = {}
        for var in variables:
            n = var.tensor.numel()
            # unconstrained by default: (-inf, +inf) for every flattened entry
            bounds[var] = (np.full(n, -np.inf), np.full(n, np.inf))
        for box in boxes:
            if box.a.is_const():
                const = box.a
                setop = box.b
                is_upper = box.t in ["eq", "ge", "gt"]
                is_lower = box.t in ["eq", "le", "lt"]
            else:
                const = box.b
                setop = box.a
                is_upper = box.t in ["eq", "le", "lt"]
                is_lower = box.t in ["eq", "ge", "gt"]
            var = setop.get_variables()[0]
            value = const.to_diffsat(cache=False).detach().cpu().numpy()

            if is_lower:
                bounds[var][0][setop.args[1]] = value
            if is_upper:
                bounds[var][1][setop.args[1]] = value
    else:
        constraint_s = constraint.simplify(delete_box_constraints=False)
        variables = list(set(constraint_s.get_variables()))
        bounds = None
    return constraint_s, variables, bounds
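
# For the L-BFGS-B backend, simplify() above splits the query: box constraints of
# the form const <= var[i] or var[i] <= const become per-entry (lower, upper) bound
# arrays that fmin_l_bfgs_b handles natively, while the remaining constraints stay
# in the differentiable loss. For all other optimizers the constraint is kept intact
# and bounds is None.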

def inner_opt(constraint_solve, constraint_check, variables, bounds, args):
    if args.opt == 'lbfgsb':    
        # SGD with lr=0 serves only as a zero_grad() handle for the variables;
        # the actual updates are performed by lbfgsb() below
        sgd = optim.SGD([v.tensor for v in variables], lr=0.0)
        for i in range(args.opt_iterations):
            satisfied = constraint_check.to_diffsat(cache=True).satisfy(args)
            if satisfied:
                break
            lbfgsb(variables, bounds, lambda: constraint_solve.to_diffsat(cache=True, reset_cache=True).loss(args), lambda: sgd.zero_grad())
    else:
        optimizer = args.opt([v.tensor for v in variables], lr=args.lr)
        for i in range(args.opt_max_iterations):
            satisfied = constraint_check.to_diffsat(cache=True).satisfy(args)
            if satisfied:
                break
            loss = constraint_solve.to_diffsat(cache=True, reset_cache=True).loss(args)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    return satisfied

def x_to_vars(x, variables, shapes_flat, shapes):
    running_shape = 0
    with torch.no_grad():
        for i, var in enumerate(variables):
            val = x[running_shape:(running_shape + shapes_flat[i])]
            running_shape += shapes_flat[i]
            var.tensor[:] = torch.from_numpy(val.reshape(shapes[i]))

def vars_to_x(variables):
    np_vars = [var.tensor.detach().cpu().numpy() for var in variables]
    shapes = [var.shape for var in np_vars]
    shapes_flat = [var.size for var in np_vars]
    # pack all variables back to back into one flat float64 vector for scipy
    x = np.concatenate([var.ravel() for var in np_vars]).astype(np.float64)
    return x, shapes, shapes_flat
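
# vars_to_x / x_to_vars convert between the torch tensors stored on each Variable
# and the single flat float64 vector seen by the scipy optimizers; shapes and
# shapes_flat record how to slice the flat vector and reshape each slice back into
# the original tensor shape.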


def basinhopping(constraint_solve, constraint_check, variables, bounds, args):
    x0, shapes, shapes_flat = vars_to_x(variables)
    
    def loss_fn(x):
        x_to_vars(x, variables, shapes_flat, shapes)
        return constraint_solve.to_diffsat(cache=True).loss(args)

    def local_optimization_step(fun, x0, *losargs, **loskwargs):
        loss_before = loss_fn(x0)
        inner_opt(constraint_solve, constraint_check, variables, bounds, args)
        r = spo.OptimizeResult()
        r.x, _, _ = vars_to_x(variables)
        loss_after = constraint_solve.to_diffsat(cache=True).loss(args)
        r.success = not (loss_before == loss_after and not constraint_check.to_diffsat(cache=True).satisfy(args))
        r.fun = loss_after
        return r

    def check_basinhopping(x, f, accept):
        if abs(f) <= 10 * args.eps_check:
            x_, _, _ = vars_to_x(variables)
            x_to_vars(x, variables, shapes_flat, shapes)
            if constraint_check.to_diffsat(cache=True).satisfy(args):
                return True
            else:
                x_to_vars(x_, variables, shapes_flat, shapes)
        return False
    
    minimizer_kwargs = {}
    minimizer_kwargs['method'] = local_optimization_step

    satisfied = constraint_check.to_diffsat(cache=True).satisfy(args)
    if satisfied:
        return True
    spo.basinhopping(loss_fn, x0, niter=1000, minimizer_kwargs=minimizer_kwargs, callback=check_basinhopping,
                     T=args.basinhopping_T, stepsize=args.basinhopping_stepsize)
    return constraint_check.to_diffsat(cache=True).satisfy(args)
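
# basinhopping() uses inner_opt as the local-minimization step of scipy's
# basinhopping: random perturbations of the flat variable vector help escape local
# minima of the diffsat loss, and check_basinhopping terminates the search as soon
# as a perturbed point actually satisfies constraint_check.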

class TimeoutException(Exception):
    pass

def solve(constraint, args, return_values=None):
    def solve_(constraint, args, return_values=None):
        t0 = time.time()
        if constraint is not None:
            constraint_s, variables, bounds = simplify(constraint, args)
            if args.use_basinhopping:
                satisfied = basinhopping(constraint_s, constraint, variables, bounds, args)
            else:
                satisfied = inner_opt(constraint_s, constraint, variables, bounds, args)
        else:
            satisfied = True

        if return_values is None:
            if constraint is not None:
                variables = list(set(constraint.get_variables()))
                ret = dict([(v.name, v.tensor.detach().cpu().numpy()) for v in variables])
            else:
                ret = dict()
        else:
            ret = [(str(r), r.to_diffsat(cache=True).detach().cpu().numpy()) for r in return_values]
            if len(ret) == 1:
                ret = ret[0][1]
            else:
                ret = dict(ret)
        t1 = time.time()
        return satisfied, ret, t1 - t0

    def timeout(signum, frame):
        raise TimeoutException()

    signal.signal(signal.SIGALRM, timeout)
    signal.alarm(args.timeout)
    try:
        solved, results, t = solve_(constraint, args, return_values=return_values)
    except TimeoutException:
        solved, results, t = False, None, args.timeout
    signal.alarm(0)  # cancel alarms
    torch.cuda.empty_cache()
    return solved, results, t
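
# The `args` object threaded through solve/inner_opt/basinhopping is an options
# namespace (e.g. from argparse). A minimal sketch of the attributes read in this
# file -- the names come from the code above, the values are only placeholders, and
# diffsat's loss()/satisfy() may read further attributes not listed here:
#
#     import argparse
#     args = argparse.Namespace(
#         opt=optim.Adam,            # torch optimizer class, or the string 'lbfgsb'
#         lr=0.01,                   # learning rate for the torch-optimizer branch
#         opt_iterations=50,         # iteration budget of the L-BFGS-B branch
#         opt_max_iterations=500,    # iteration budget of the torch-optimizer branch
#         use_basinhopping=False,    # wrap inner_opt in scipy's basinhopping
#         basinhopping_T=1.0,
#         basinhopping_stepsize=0.5,
#         eps_check=1e-4,            # tolerance used in check_basinhopping
#         timeout=60,                # SIGALRM timeout in seconds for solve()
#     )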

def lbfgsb(variables, bounds, loss_fn, zero_grad_fn):
    x, shapes, shapes_flat = vars_to_x(variables)
    bounds_list = []
    for var in variables:
        lower, upper = bounds[var]
        lower = lower.ravel()
        upper = upper.ravel()
        for i in range(lower.size):
            bounds_list.append((lower[i], upper[i]))

    def f(x):
        x_to_vars(x, variables, shapes_flat, shapes)
        loss = loss_fn()
        zero_grad_fn()
        loss.backward()
        with torch.no_grad():
            f = loss.detach().cpu().numpy().astype(np.float64)
            g = np.concatenate([var.tensor.grad.detach().cpu().numpy().ravel() for var in variables]).astype(np.float64)
        return f, g
    x, f, d = spo.fmin_l_bfgs_b(f, x, bounds=bounds_list)
    x_to_vars(x, variables, shapes_flat, shapes)
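
# End-to-end usage sketch (illustrative only): `net` is assumed to be a trained
# torch.nn.Module classifier and `args` an options namespace as sketched after
# solve() above; neither is defined in this file.
#
#     x = Variable('x', (1, 28 * 28))
#     m = Model(net)
#     constraint = And(x.in_(Interval(0.0, 1.0)),   # per-pixel box constraints
#                      Class(m(x), 7))              # classified as class 7
#     satisfied, values, t = solve(constraint, args)
#
# m.p(x) can be used instead of m(x) to state the same constraint over softmax
# probabilities rather than raw logits.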