import future
import builtins
import past
import six

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.autograd

from functools import reduce

try:
    from . import helpers as h
except ImportError:
    import helpers as h



def catNonNullErrors(op, ref_errs=None): # the way of things is ugly
    def doop(er1, er2):
        erS, erL = (er1, er2)
        sS, sL = (erS.size()[0], erL.size()[0])

        if sS == sL: # TODO: here we know we used transformers on either side which didnt introduce new error terms (this is a hack for hybrid zonotopes and doesn't work with adaptive error term adding).
            return op(erS,erL)

        if ref_errs is not None:
            sz = ref_errs.size()[0]
        else:
            sz = min(sS, sL)
    
        p1 = op(erS[:sz], erL[:sz])
        erSrem = erS[sz:]
        erLrem = erL[sz:]
        p2 = op(erSrem, h.zeros(erSrem.shape))
        p3 = op(h.zeros(erLrem.shape), erLrem)
        return torch.cat((p1,p2,p3), dim=0)
    return doop
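
# Note added for clarity (not in the original source): when the two operands carry a
# different number of error terms, the first min(k1, k2) terms are combined pairwise by
# `op` and each side's leftover terms are combined with zeros, so the result carries
# k1 + k2 - min(k1, k2) = max(k1, k2) terms (or is aligned to ref_errs when it is given).
# A hedged usage sketch, assuming e1 and e2 are error tensors of shape [k, batch, ...]:
#
#   add_errs = catNonNullErrors(lambda a, b: a + b)
#   combined = add_errs(e1, e2)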

def creluBoxy(dom):
    if dom.errors is None:
        if dom.beta is None:
            return dom.new(F.relu(dom.head), None, None)
        er = dom.beta 
        mx = F.relu(dom.head + er)
        mn = F.relu(dom.head - er)
        return dom.new((mn + mx) / 2, (mx - mn) / 2 , None)

    aber = torch.abs(dom.errors)

    sm = torch.sum(aber, 0) 

    if not dom.beta is None:
        sm += dom.beta

    mx = dom.head + sm
    mn = dom.head - sm

    should_box = mn.lt(0) * mx.gt(0)
    gtz = dom.head.gt(0).to_dtype()
    mx /= 2
    newhead = h.ifThenElse(should_box, mx, gtz * dom.head)
    newbeta = h.ifThenElse(should_box, mx, gtz * (dom.beta if not dom.beta is None else 0))
    newerr = (1 - should_box.to_dtype()) * gtz * dom.errors

    return dom.new(newhead, newbeta , newerr)
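
# Hedged worked example (values assumed for illustration): for a single neuron with
# head = 1 and total error radius 2, the concrete range is [-1, 3], which straddles 0,
# so the neuron is boxed: new head = new beta = 1.5 and its error terms are dropped.
# A neuron whose range lies entirely above 0 is passed through unchanged, and one whose
# range lies entirely below 0 is zeroed out.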


def creluBoxySound(dom):
    if dom.errors is None:
        if dom.beta is None:
            return dom.new(F.relu(dom.head), None, None)
        er = dom.beta 
        mx = F.relu(dom.head + er)
        mn = F.relu(dom.head - er)
        return dom.new((mn + mx) / 2, (mx - mn) / 2 + 2e-6 , None)

    aber = torch.abs(dom.errors)

    sm = torch.sum(aber, 0) 

    if not dom.beta is None:
        sm += dom.beta

    mx = dom.head + sm
    mn = dom.head - sm

    should_box = mn.lt(0) * mx.gt(0)
    gtz = dom.head.gt(0).to_dtype()
    mx /= 2
    newhead = h.ifThenElse(should_box, mx, gtz * dom.head)
    newbeta = h.ifThenElse(should_box, mx + 2e-6, gtz * (dom.beta if not dom.beta is None else 0))
    newerr = (1 - should_box.to_dtype()) * gtz * dom.errors

    return dom.new(newhead, newbeta, newerr)


def creluSwitch(dom):
    if dom.errors is None:
        if dom.beta is None:
            return dom.new(F.relu(dom.head), None, None)
        er = dom.beta 
        mx = F.relu(dom.head + er)
        mn = F.relu(dom.head - er)
        return dom.new((mn + mx) / 2, (mx - mn) / 2 , None)

    aber = torch.abs(dom.errors)

    sm = torch.sum(aber, 0) 

    if not dom.beta is None:
        sm += dom.beta

    mn = dom.head - sm
    mx = sm
    mx += dom.head

    should_box = mn.lt(0) * mx.gt(0)
    gtz = dom.head.gt(0)

    mn.neg_()
    should_boxer = mn.gt(mx)

    mn /= 2
    newhead = h.ifThenElse(should_box, h.ifThenElse(should_boxer, mx / 2, dom.head + mn ), gtz.to_dtype() * dom.head)
    zbet =  dom.beta if not dom.beta is None else 0
    newbeta = h.ifThenElse(should_box, h.ifThenElse(should_boxer, mx / 2, mn + zbet), gtz.to_dtype() * zbet)
    newerr  = h.ifThenElseL(should_box, 1 - should_boxer, gtz).to_dtype() * dom.errors

    return dom.new(newhead, newbeta , newerr)
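
# Note added for clarity: creluSwitch handles a crossing neuron with bounds [l, u]
# (l < 0 < u) in one of two ways. If the negative side dominates (-l > u) it boxes the
# neuron exactly like creluBoxy; otherwise it keeps the error terms, shifts the center
# up by -l/2, and adds -l/2 to beta to cover the clipped negative part.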

def creluSmooth(dom):
    if dom.errors is None:
        if dom.beta is None:
            return dom.new(F.relu(dom.head), None, None)
        er = dom.beta 
        mx = F.relu(dom.head + er)
        mn = F.relu(dom.head - er)
        return dom.new((mn + mx) / 2, (mx - mn) / 2 , None)

    aber = torch.abs(dom.errors)

    sm = torch.sum(aber, 0) 

    if not dom.beta is None:
        sm += dom.beta

    mn = dom.head - sm
    mx = sm
    mx += dom.head


    nmn = F.relu(-1 * mn)

    zbet =  (dom.beta if not dom.beta is None else 0)
    newheadS = dom.head + nmn / 2
    newbetaS = zbet + nmn / 2
    newerrS = dom.errors

    mmx = F.relu(mx)

    newheadB = mmx / 2
    newbetaB = newheadB
    newerrB = 0

    eps = 0.0001
    t = nmn / (mmx + nmn + eps) # mn.lt(0).to_dtype() * F.sigmoid(nmn - nmx)

    shouldnt_zero = mx.gt(0).to_dtype()

    newhead = shouldnt_zero * ( (1 - t) * newheadS + t * newheadB)
    newbeta = shouldnt_zero * ( (1 - t) * newbetaS + t * newbetaB)
    newerr =  shouldnt_zero * ( (1 - t) * newerrS  + t * newerrB)

    return dom.new(newhead, newbeta , newerr)
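
# Note added for clarity: creluSmooth blends the "shift" approximation (keep the error
# terms, move the center up by relu(-l)/2) and the "box" approximation (center and beta
# both relu(u)/2, no error terms) with weight t = relu(-l) / (relu(u) + relu(-l) + eps),
# and zeroes out any neuron whose upper bound is non-positive.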


def creluNIPS(dom):
    if dom.errors is None:
        if dom.beta is None:
            return dom.new(F.relu(dom.head), None, None)
        er = dom.beta 
        mx = F.relu(dom.head + er)
        mn = F.relu(dom.head - er)
        return dom.new((mn + mx) / 2, (mx - mn) / 2 , None)
    
    sm = torch.sum(torch.abs(dom.errors), 0) 

    if not dom.beta is None:
        sm += dom.beta

    mn = dom.head - sm
    mx =  dom.head + sm

    mngz = mn >= 0.0

    zs = h.zeros(dom.head.shape)

    diff = mx - mn

    lam = torch.where((mx > 0) & (diff > 0.0), mx / diff, zs)
    mu = lam * mn * (-0.5)

    betaz = zs if dom.beta is None else dom.beta 

    newhead = torch.where(mngz, dom.head , lam * dom.head + mu)
    mngz = mngz | (diff <= 0.0)
    newbeta = torch.where(mngz, betaz    , lam * betaz + mu ) # mu is always positive on this side
    newerr = torch.where(mngz, dom.errors, lam * dom.errors )
    return dom.new(newhead, newbeta, newerr)
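
# Note added for clarity (standard zonotope ReLU relaxation): for a crossing neuron
# with bounds l < 0 < u this uses lam = u / (u - l) and mu = -lam * l / 2, producing
# lam * x + mu with the approximation slack mu absorbed into beta; neurons with l >= 0
# pass through unchanged and neurons with u <= 0 collapse to zero.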




class MaxTypes:

    @staticmethod
    def ub(x):
        return x.ub()

    @staticmethod
    def only_beta(x):
        return x.beta if x.beta is not None else x.head * 0

    @staticmethod
    def head_beta(x):
        return MaxTypes.only_beta(x) + x.head

class HybridZonotope:

    def isSafe(self, target):
        od,_ = torch.min(h.preDomRes(self,target).lb(), 1)
        return od.gt(0.0).long()

    def isPoint(self):
        return False

    def labels(self):
        target = torch.max(self.ub(), 1)[1]
        l = list(h.preDomRes(self,target).lb()[0])
        return [target.item()] + [ i for i,v in zip(range(len(l)), l) if v <= 0]

    def relu(self):
        return self.customRelu(self)
    
    def __init__(self, head, beta, errors, customRelu = creluBoxy, **kargs):
        self.head = head
        self.errors = errors
        self.beta = beta
        self.customRelu = creluBoxy if customRelu is None else customRelu

    def new(self, *args, customRelu = None, **kargs):
        return self.__class__(*args, **kargs, customRelu = self.customRelu if customRelu is None else customRelu).checkSizes()

    def zono_to_hybrid(self, *args, **kargs): # we are already a hybrid zono.
        return self.new(self.head, self.beta, self.errors, **kargs)

    def hybrid_to_zono(self, *args, correlate=True, customRelu = None, **kargs):
        beta = self.beta
        errors = self.errors
        if correlate and beta is not None:
            batches = beta.shape[0]
            num_elem = h.product(beta.shape[1:])
            ei = h.getEi(batches, num_elem)
        
            if len(beta.shape) > 2:
                ei = ei.contiguous().view(num_elem, *beta.shape)
            err = ei * beta
            errors = torch.cat((err, errors), dim=0) if errors is not None else err
            beta = None

        return Zonotope(self.head, beta, errors if errors is not None else (self.beta * 0).unsqueeze(0), customRelu = self.customRelu if customRelu is None else customRelu)



    def abstractApplyLeaf(self, foo, *args, **kargs):
        return getattr(self, foo)(*args, **kargs)

    def decorrelate(self, cc_indx_batch_err): # keep these errors
        if self.errors is None:
            return self

        batch_size = self.head.shape[0]
        num_error_terms = self.errors.shape[0]

        

        beta = h.zeros(self.head.shape).to_dtype() if self.beta is None  else self.beta
        errors = h.zeros([0] + list(self.head.shape)).to_dtype() if self.errors is None else self.errors

        inds_i = torch.arange(self.head.shape[0], device=h.device).unsqueeze(1).long()
        errors = errors.to_dtype().permute(1,0, *list(range(len(self.errors.shape)))[2:])
        
        sm = errors.clone()
        sm[inds_i, cc_indx_batch_err] = 0
        
        beta = beta.to_dtype() + sm.abs().sum(dim=1)

        errors = errors[inds_i, cc_indx_batch_err]
        errors = errors.permute(1,0, *list(range(len(self.errors.shape)))[2:]).contiguous()
        return self.new(self.head, beta, errors)
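
    # Illustrative note (added): cc_indx_batch_err lists, per batch element, the error
    # terms to keep correlated; every other term has its magnitude folded into beta, so
    # the result is a coarser but cheaper hybrid zonotope. A hedged sketch with
    # hypothetical indices, keeping only terms 0 and 2 for every batch element:
    #
    #   idx = torch.tensor([[0, 2]] * batch_size, device=h.device)
    #   coarser = dom.decorrelate(idx)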
    
    def dummyDecorrelate(self, num_decorrelate):
        if num_decorrelate == 0 or self.errors is None:
            return self
        elif num_decorrelate >= self.errors.shape[0]:
            beta = self.beta
            if self.errors is not None:
                errs = self.errors.abs().sum(dim=0)
                if beta is None:
                    beta = errs
                else:
                    beta += errs
            return self.new(self.head, beta, None)
        return None

    def stochasticDecorrelate(self, num_decorrelate, choices = None, num_to_keep=False):
        dummy = self.dummyDecorrelate(num_decorrelate)
        if dummy is not None:
            return dummy
        num_error_terms = self.errors.shape[0]
        batch_size = self.head.shape[0]

        ucc_mask = h.ones([batch_size, self.errors.shape[0]]).long()
        cc_indx_batch_err = h.cudify(torch.multinomial(ucc_mask.to_dtype(), num_decorrelate if num_to_keep else num_error_terms - num_decorrelate, replacement=False)) if choices is None else choices
        return self.decorrelate(cc_indx_batch_err)

    def decorrelateMin(self, num_decorrelate, num_to_keep=False):
        dummy = self.dummyDecorrelate(num_decorrelate)
        if dummy is not None:
            return dummy

        num_error_terms = self.errors.shape[0]
        batch_size = self.head.shape[0]

        error_sum_b_e = self.errors.abs().view(self.errors.shape[0], batch_size, -1).sum(dim=2).permute(1,0)
        cc_indx_batch_err = error_sum_b_e.topk(num_decorrelate if num_to_keep else num_error_terms - num_decorrelate)[1]
        return self.decorrelate(cc_indx_batch_err)
      
    def correlate(self, cc_indx_batch_beta): # given in terms of the flattened matrix.
        num_correlate = h.product(cc_indx_batch_beta.shape[1:])
        
        beta = h.zeros(self.head.shape).to_dtype() if self.beta is None  else self.beta
        errors = h.zeros([0] + list(self.head.shape)).to_dtype() if self.errors is None else self.errors

        batch_size = beta.shape[0]
        new_errors = h.zeros([num_correlate] + list(self.head.shape)).to_dtype()
        
        inds_i = torch.arange(batch_size, device=h.device).unsqueeze(1).long()

        nc = torch.arange(num_correlate, device=h.device).unsqueeze(1).long()

        new_errors = new_errors.permute(1,0, *list(range(len(new_errors.shape)))[2:]).contiguous().view(batch_size, num_correlate, -1)
        new_errors[inds_i, nc.unsqueeze(0).expand([batch_size]+list(nc.shape)).squeeze(2), cc_indx_batch_beta] = beta.view(batch_size,-1)[inds_i, cc_indx_batch_beta]

        new_errors = new_errors.permute(1,0, *list(range(len(new_errors.shape)))[2:]).contiguous().view(num_correlate, batch_size, *beta.shape[1:])
        errors = torch.cat((errors, new_errors), dim=0)
            
        beta.view(batch_size, -1)[inds_i, cc_indx_batch_beta] = 0
        
        return self.new(self.head, beta, errors)
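
    # Illustrative note (added): correlate is the inverse direction of decorrelate. For
    # the flattened pixel indices given per batch element, the beta entry of each chosen
    # pixel is promoted to a fresh, dedicated error term and removed from beta, so those
    # pixels regain a symbolic error term instead of a plain interval radius.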

    def stochasticCorrelate(self, num_correlate, choices = None):
        if num_correlate == 0:
            return self

        domshape = self.head.shape
        batch_size = domshape[0]
        num_pixs = h.product(domshape[1:])
        num_correlate = min(num_correlate, num_pixs)
        ucc_mask = h.ones([batch_size, num_pixs ]).long()

        cc_indx_batch_beta = h.cudify(torch.multinomial(ucc_mask.to_dtype(), num_correlate, replacement=False)) if choices is None else choices
        return self.correlate(cc_indx_batch_beta)


    def correlateMaxK(self, num_correlate):
        if num_correlate == 0:
            return self
        
        domshape = self.head.shape
        batch_size = domshape[0]
        num_pixs = h.product(domshape[1:])
        num_correlate = min(num_correlate, num_pixs)

        concrete_max_image = self.ub().view(batch_size, -1)

        cc_indx_batch_beta = concrete_max_image.topk(num_correlate)[1]
        return self.correlate(cc_indx_batch_beta)

    def correlateMaxPool(self, *args, max_type = MaxTypes.ub , max_pool = F.max_pool2d, **kargs):
        domshape = self.head.shape
        batch_size = domshape[0]
        num_pixs = h.product(domshape[1:])

        concrete_max_image = max_type(self)

        cc_indx_batch_beta = max_pool(concrete_max_image, *args, return_indices=True, **kargs)[1].view(batch_size, -1)

        return self.correlate(cc_indx_batch_beta)

    def checkSizes(self):
        if not self.errors is None:
            if not self.errors.size()[1:] == self.head.size():
                raise Exception("Such bad sizes on error:", self.errors.shape, " head:", self.head.shape)
            if torch.isnan(self.errors).any():
                raise Exception("Such nan in errors")
        if not self.beta is None:
            if not self.beta.size() == self.head.size():
                raise Exception("Such bad sizes on beta")

            if torch.isnan(self.beta).any():
                raise Exception("Such nan in errors")
            if self.beta.lt(0.0).any():
                self.beta = self.beta.abs()
            
        return self

    def __mul__(self, flt):
        return self.new(self.head * flt, None if self.beta is None else self.beta * abs(flt), None if self.errors is None else self.errors * flt)
    
    def __truediv__(self, flt):
        flt = 1. / flt
        return self.new(self.head * flt, None if self.beta is None else self.beta * abs(flt), None if self.errors is None else self.errors * flt)

    def __add__(self, other):
        if isinstance(other, HybridZonotope):
            return self.new(self.head + other.head, h.msum(self.beta, other.beta, lambda a,b: a + b), h.msum(self.errors, other.errors, catNonNullErrors(lambda a,b: a + b)))
        else:
            # other has to be a standard variable or tensor
            return self.new(self.head + other, self.beta, self.errors)
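
    # Note added for clarity: addition of two hybrid zonotopes adds the centers, adds
    # the box radii (beta), and combines the error-term tensors with catNonNullErrors,
    # which pads the shorter tensor with zeros when the two operands carry a different
    # number of error terms.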

    def addPar(self, a, b):
        return self.new(a.head + b.head, h.msum(a.beta, b.beta, lambda a,b: a + b), h.msum(a.errors, b.errors, catNonNullErrors(lambda a,b: a + b, self.errors)))

    def __sub__(self, other):
        if isinstance(other, HybridZonotope):
            return self.new(self.head - other.head
                            , h.msum(self.beta, other.beta, lambda a,b: a + b)
                            , h.msum(self.errors, None if other.errors is None else -other.errors, catNonNullErrors(lambda a,b: a + b)))
        else:
            # other has to be a standard variable or tensor
            return self.new(self.head - other, self.beta, self.errors)

    def bmm(self, other):
        hd = self.head.bmm(other)
        bet = None if self.beta is None else self.beta.bmm(other.abs())

        if self.errors is None:
            er = None
        else:
            er = self.errors.matmul(other)
        return self.new(hd, bet, er)


    def getBeta(self):
        return self.head * 0 if self.beta is None else self.beta

    def getErrors(self):
        return (self.head * 0).unsqueeze(0) if self.errors is None else self.errors

    def merge(self, other, ref = None): # the vast majority of the time ref should be none here.  Not for parallel computation with powerset
        s_beta = self.getBeta() # so that beta is never none

        sbox_u = self.head + s_beta
        sbox_l = self.head - s_beta
        o_u = other.ub()
        o_l = other.lb()
        o_in_s = (o_u <= sbox_u) & (o_l >= sbox_l)

        s_err_mx = self.errors.abs().sum(dim=0)

        if not isinstance(other, HybridZonotope):
            new_head = (self.head + other.center()) / 2
            new_beta = torch.max(sbox_u + s_err_mx,o_u) - new_head
            return self.new(torch.where(o_in_s, self.head, new_head), torch.where(o_in_s, self.beta,new_beta), o_in_s.float() * self.errors)
        
        # TODO: could be more efficient if one of these doesn't have beta or errors but thats okay for now.
        s_u = sbox_u + s_err_mx
        s_l = sbox_l - s_err_mx

        obox_u = o_u - other.head
        obox_l = o_l + other.head

        s_in_o = (s_u <= obox_u) & (s_l >= obox_l)
        
        # TODO: could theoretically still do something better when one is contained partially in the other
        new_head = (self.head + other.center()) / 2
        new_beta = torch.max(sbox_u + self.getErrors().abs().sum(dim=0),o_u) - new_head

        return self.new(torch.where(o_in_s, self.head, torch.where(s_in_o, other.head, new_head))
                        , torch.where(o_in_s, s_beta,torch.where(s_in_o, other.getBeta(), new_beta))
                        , h.msum(o_in_s.float() * self.errors, s_in_o.float() * other.errors, catNonNullErrors(lambda a,b: a + b, ref_errs = ref.errors if ref is not None else ref))) # these are both zero otherwise
    

    def conv(self, conv, weight, bias = None, **kargs):
        er = self.errors # renamed locally so it does not shadow the helpers alias `h`
        inter = er if er is None else er.view(-1, *er.size()[2:])
        hd = conv(self.head, weight, bias=bias, **kargs)
        res = er if er is None else conv(inter, weight, bias=None, **kargs)

        return self.new( hd
                       , None if self.beta is None else conv(self.beta, weight.abs(), bias = None, **kargs)
                       , er if er is None else res.view(er.size()[0], er.size()[1], *res.size()[1:]))
            

    def conv1d(self, *args, **kargs):
        return self.conv(lambda x, *args, **kargs: x.conv1d(*args,**kargs), *args, **kargs)
                   
    def conv2d(self, *args, **kargs):
        return self.conv(lambda x, *args, **kargs: x.conv2d(*args,**kargs), *args, **kargs)                   

    def conv3d(self, *args, **kargs):
        return self.conv(lambda x, *args, **kargs: x.conv3d(*args,**kargs), *args, **kargs)

    def conv_transpose1d(self, *args, **kargs):
        return self.conv(lambda x, *args, **kargs: x.conv_transpose1d(*args,**kargs), *args, **kargs)
                   
    def conv_transpose2d(self, *args, **kargs):
        return self.conv(lambda x, *args, **kargs: x.conv_transpose2d(*args,**kargs), *args, **kargs)                   

    def conv_transpose3d(self, *args, **kargs):
        return self.conv(lambda x, *args, **kargs: x.conv_transpose3d(*args,**kargs), *args, **kargs)
        
    def matmul(self, other):
        return self.new(self.head.matmul(other), None if self.beta is None else self.beta.matmul(other.abs()), None if self.errors is None else self.errors.matmul(other))

    def unsqueeze(self, i):
        return self.new(self.head.unsqueeze(i), None if self.beta is None else self.beta.unsqueeze(i), None if self.errors is None else self.errors.unsqueeze(i + 1))

    def squeeze(self, dim):
        return self.new(self.head.squeeze(dim),
                        None if self.beta is None else self.beta.squeeze(dim),
                        None if self.errors is None else self.errors.squeeze(dim + 1 if dim >= 0 else dim))    

    def double(self):
        return self.new(self.head.double(), self.beta.double()  if self.beta is not None else None, self.errors.double() if self.errors is not None else None) 

    def float(self):
        return self.new(self.head.float(), self.beta.float()  if self.beta is not None else None, self.errors.float() if self.errors is not None else None) 

    def to_dtype(self):
        return self.new(self.head.to_dtype(), self.beta.to_dtype()  if self.beta is not None else None, self.errors.to_dtype() if self.errors is not None else None) 
    
    def sum(self, dim=1):
        return self.new(torch.sum(self.head,dim=dim), None if self.beta is None else torch.sum(self.beta,dim=dim), None if self.errors is None else torch.sum(self.errors, dim= dim + 1 if dim >= 0 else dim))

    def view(self,*newshape):
        return self.new(self.head.view(*newshape), 
                        None if self.beta is None else self.beta.view(*newshape),
                        None if self.errors is None else self.errors.view(self.errors.size()[0], *newshape))

    def gather(self,dim, index):
        return self.new(self.head.gather(dim, index), 
                        None if self.beta is None else self.beta.gather(dim, index),
                        None if self.errors is None else self.errors.gather(dim + 1, index.expand([self.errors.size()[0]] + list(index.size()))))
    
    def concretize(self):
        if self.errors is None:
            return self

        return self.new(self.head, torch.sum(self.concreteErrors().abs(),0), None) # maybe make a box?
    
    def cat(self,other, dim=0):
        return self.new(self.head.cat(other.head, dim = dim), 
                        h.msum(other.beta, self.beta, lambda a,b: a.cat(b, dim = dim)),
                        h.msum(self.errors, other.errors, catNonNullErrors(lambda a,b: a.cat(b, dim+1))))


    def split(self, split_size, dim = 0):
        heads = list(self.head.split(split_size, dim))
        betas = list(self.beta.split(split_size, dim)) if not self.beta is None else None
        errorss = list(self.errors.split(split_size, dim + 1)) if not self.errors is None else None
        
        def makeFromI(i):
            return self.new( heads[i], 
                             None if betas is None else betas[i], 
                             None if errorss is None else errorss[i])
        return tuple(makeFromI(i) for i in range(len(heads)))

        
    
    def concreteErrors(self):
        if self.beta is None and self.errors is None:
            raise Exception("shouldn't have both beta and errors be none")
        if self.errors is None:
            return self.beta.unsqueeze(0)
        if self.beta is None:
            return self.errors
        return torch.cat([self.beta.unsqueeze(0),self.errors], dim=0)


    def applyMonotone(self, foo, *args, **kargs):
        if self.beta is None and self.errors is None:
            return self.new(foo(self.head, *args, **kargs), None, None)

        beta = self.concreteErrors().abs().sum(dim=0)

        tp = foo(self.head + beta, *args, **kargs)
        bt = foo(self.head - beta, *args, **kargs)

        new_hybrid = self.new((tp + bt) / 2, (tp - bt) / 2 , None)


        if self.errors is not None:
            return new_hybrid.correlateMaxK(self.errors.shape[0])
        return new_hybrid

    def avg_pool2d(self, *args, **kargs):
        nhead = F.avg_pool2d(self.head, *args, **kargs)
        return self.new(nhead, 
                        None if self.beta is None else F.avg_pool2d(self.beta, *args, **kargs), 
                        None if self.errors is None else F.avg_pool2d(self.errors.view(-1, *self.head.shape[1:]), *args, **kargs).view(-1,*nhead.shape)) 

    def adaptive_avg_pool2d(self, *args, **kargs):
        nhead = F.adaptive_avg_pool2d(self.head, *args, **kargs)
        return self.new(nhead, 
                        None if self.beta is None else F.adaptive_avg_pool2d(self.beta, *args, **kargs), 
                        None if self.errors is None else F.adaptive_avg_pool2d(self.errors.view(-1, *self.head.shape[1:]), *args, **kargs).view(-1,*nhead.shape)) 

    def elu(self):
        return self.applyMonotone(F.elu)

    def selu(self):
        return self.applyMonotone(F.selu)

    def sigm(self):
        return self.applyMonotone(F.sigmoid)

    def softplus(self):
        if self.errors is None:
            if self.beta is None:
                return self.new(F.softplus(self.head), None , None)
            tp = F.softplus(self.head + self.beta)
            bt = F.softplus(self.head - self.beta)
            return self.new((tp + bt) / 2, (tp - bt) / 2 , None)

        errors = self.concreteErrors()
        o = h.ones(self.head.size())

        def sp(hd):
            return F.softplus(hd) # torch.log(o + torch.exp(hd))  # not very stable
        def spp(hd):
            ehd = torch.exp(hd)
            return ehd.div(ehd + o)
        def sppp(hd):
            ehd = torch.exp(hd)
            md = ehd + o
            return ehd.div(md.mul(md))

        fa = sp(self.head)
        fpa = spp(self.head)

        a = self.head

        k = torch.sum(errors.abs(), 0) 

        def evalG(r):
            return r.mul(r).mul(sppp(a + r))

        m = torch.max(evalG(h.zeros(k.size())), torch.max(evalG(k), evalG(-k)))
        m = h.ifThenElse( a.abs().lt(k), torch.max(m, torch.max(evalG(a), evalG(-a))), m)
        m /= 2
        
        return self.new(fa, m if self.beta is None else m + self.beta.mul(fpa), None if self.errors is None else self.errors.mul(fpa))

    def center(self):
        return self.head

    def vanillaTensorPart(self):
        return self.head

    def lb(self):
        return self.head - self.concreteErrors().abs().sum(dim=0)

    def ub(self):
        return self.head + self.concreteErrors().abs().sum(dim=0)
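
    # Note added for clarity: lb/ub concretize the abstract element into an interval,
    # i.e. head -/+ (|beta| + sum_i |errors_i|) per coordinate.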

    def size(self):
        return self.head.size()

    def diameter(self):
        abal = torch.abs(self.concreteErrors()).transpose(0,1)
        return abal.sum(1).sum(1) # perimeter

    def loss(self, target, **args):
        r = -h.preDomRes(self, target).lb()
        return F.softplus(r.max(1)[0])

    def deep_loss(self, act = F.relu, *args, **kargs):
        batch_size = self.head.shape[0]
        inds = torch.arange(batch_size, device=h.device).unsqueeze(1).long()

        def dl(l,u):
            ls, lsi = torch.sort(l, dim=1)
            ls_u = u[inds, lsi]

            def slidingMax(a): # using maxpool
                k = a.shape[1]
                ml = a.min(dim=1)[0].unsqueeze(1)

                inp = torch.cat((h.zeros([batch_size, k]), a - ml), dim=1)
                mpl = F.max_pool1d(inp.unsqueeze(1) ,  kernel_size = k, stride=1, padding = 0, return_indices=False).squeeze(1)
                return mpl[:,:-1] + ml
            
            return act(slidingMax(ls_u) - ls).sum(dim=1)

        l = self.lb().view(batch_size, -1)
        u = self.ub().view(batch_size, -1)
        return ( dl(l,u) + dl(-u,-l) ) / (2 * l.shape[1]) # make it easier to regularize against



class Zonotope(HybridZonotope):
    def applySuper(self, ret):
        batches = ret.head.size()[0]
        num_elem = h.product(ret.head.size()[1:])
        ei = h.getEi(batches, num_elem)

        if len(ret.head.size()) > 2:
            ei = ei.contiguous().view(num_elem, *ret.head.size())

        ret.errors = torch.cat( (ret.errors, ei * ret.beta) ) if not ret.beta is None else ret.errors
        ret.beta = None
        return ret.checkSizes()

    def zono_to_hybrid(self, *args, customRelu = None, **kargs): # convert this plain zonotope into a hybrid zonotope.
        return HybridZonotope(self.head, self.beta, self.errors, customRelu = self.customRelu if customRelu is None else customRelu)

    def hybrid_to_zono(self, *args, **kargs):
        return self.new(self.head, self.beta, self.errors, **kargs)

    def applyMonotone(self, *args, **kargs):
        return self.applySuper(super(Zonotope,self).applyMonotone(*args, **kargs))

    def softplus(self):
        return self.applySuper(super(Zonotope,self).softplus())

    def relu(self):
        return self.applySuper(super(Zonotope,self).relu())

    def splitRelu(self, *args, **kargs):
        return [self.applySuper(a) for a in super(Zonotope, self).splitRelu(*args, **kargs)]


def mysign(x):
    e = x.eq(0).to_dtype()
    r = x.sign().to_dtype()
    return r + e

def mulIfEq(grad,out,target):
    pred = out.max(1, keepdim=True)[1]
    is_eq = pred.eq(target.view_as(pred)).to_dtype()
    is_eq = is_eq.view([-1] + [1 for _ in grad.size()[1:]]).expand_as(grad)
    return is_eq
    

def stdLoss(out, target):
    if torch.__version__[0] == "0":
        return F.cross_entropy(out, target, reduce = False)
    else:
        return F.cross_entropy(out, target, reduction='none')



class ListDomain(object):

    def __init__(self, al, *args, **kargs):
        self.al = list(al)

    def new(self, *args, **kargs):
        return self.__class__(*args, **kargs)

    def isSafe(self,*args,**kargs):
        raise Exception("Domain Not Suitable For Testing")

    def labels(self):
        raise Exception("Domain Not Suitable For Testing")

    def isPoint(self):
        return all(a.isPoint() for a in self.al)

    def __mul__(self, flt):
        return self.new(a.__mul__(flt) for a in self.al)

    def __truediv__(self, flt):
        return self.new(a.__truediv__(flt) for a in self.al)

    def __add__(self, other):
        if isinstance(other, ListDomain):
            return self.new(a.__add__(o) for a,o in zip(self.al, other.al))
        else:
            return self.new(a.__add__(other) for a in self.al)

    def merge(self, other, ref = None):
        if ref is None:
            return self.new(a.merge(o) for a,o in zip(self.al,other.al) )
        return self.new(a.merge(o, ref = r) for a,o,r in zip(self.al,other.al, ref.al))

    def addPar(self, a, b):
        return self.new(s.addPar(av,bv) for s,av,bv in zip(self.al, a.al, b.al))

    def __sub__(self, other):
        if isinstance(other, ListDomain):
            return self.new(a.__sub__(o) for a,o in zip(self.al, other.al))
        else:
            return self.new(a.__sub__(other) for a in self.al)

    def abstractApplyLeaf(self, *args, **kargs):
        return self.new(a.abstractApplyLeaf(*args, **kargs) for a in self.al)

    def bmm(self, other):
        return self.new(a.bmm(other) for a in self.al)

    def matmul(self, other):
        return self.new(a.matmul(other) for a in self.al)

    def conv(self, *args, **kargs):
        return self.new(a.conv(*args, **kargs) for a in self.al)

    def conv1d(self, *args, **kargs):
        return self.new(a.conv1d(*args, **kargs) for a in self.al)

    def conv2d(self, *args, **kargs):
        return self.new(a.conv2d(*args, **kargs) for a in self.al)

    def conv3d(self, *args, **kargs):
        return self.new(a.conv3d(*args, **kargs) for a in self.al)

    def max_pool2d(self, *args, **kargs):
        return self.new(a.max_pool2d(*args, **kargs) for a in self.al)

    def avg_pool2d(self, *args, **kargs):
        return self.new(a.avg_pool2d(*args, **kargs) for a in self.al)

    def adaptive_avg_pool2d(self, *args, **kargs):
        return self.new(a.adaptive_avg_pool2d(*args, **kargs) for a in self.al)

    def unsqueeze(self, *args, **kargs):
        return self.new(a.unsqueeze(*args, **kargs) for a in self.al)

    def squeeze(self, *args, **kargs):
        return self.new(a.squeeze(*args, **kargs) for a in self.al)

    def view(self, *args, **kargs):
        return self.new(a.view(*args, **kargs) for a in self.al)

    def gather(self, *args, **kargs):
        return self.new(a.gather(*args, **kargs) for a in self.al)

    def sum(self, *args, **kargs):
        return self.new(a.sum(*args,**kargs) for a in self.al)

    def double(self):
        return self.new(a.double() for a in self.al)

    def float(self):
        return self.new(a.float() for a in self.al)

    def to_dtype(self):
        return self.new(a.to_dtype() for a in self.al)

    def vanillaTensorPart(self):
        return self.al[0].vanillaTensorPart()

    def center(self):
        return self.new(a.center() for a in self.al)

    def ub(self):
        return self.new(a.ub() for a in self.al)

    def lb(self):
        return self.new(a.lb() for a in self.al)

    def relu(self):
        return self.new(a.relu() for a in self.al)

    def splitRelu(self, *args, **kargs):
        return self.new(a.splitRelu(*args, **kargs) for a in self.al)

    def softplus(self):
        return self.new(a.softplus() for a in self.al)

    def elu(self):
        return self.new(a.elu() for a in self.al)

    def selu(self):
        return self.new(a.selu() for a in self.al)

    def sigm(self):
        return self.new(a.sigm() for a in self.al)

    def cat(self, other, *args, **kargs):
        return self.new(a.cat(o, *args, **kargs) for a,o in zip(self.al, other.al))


    def split(self, *args, **kargs):
        return [self.new(z) for z in zip(*[a.split(*args, **kargs) for a in self.al])]

    def size(self):
        return self.al[0].size()

    def loss(self, *args, **kargs):
        return sum(a.loss(*args, **kargs) for a in self.al)

    def deep_loss(self, *args, **kargs):
        return sum(a.deep_loss(*args, **kargs) for a in self.al)

    def checkSizes(self):
        for a in self.al:
            a.checkSizes()
        return self


class TaggedDomain(object):


    def __init__(self, a, tag = None):
        self.tag = tag
        self.a = a

    def isSafe(self,*args,**kargs):
        return self.a.isSafe(*args, **kargs)

    def isPoint(self):
        return self.a.isPoint()

    def labels(self):
        raise Exception("Domain Not Suitable For Testing")

    def __mul__(self, flt):
        return TaggedDomain(self.a.__mul__(flt), self.tag)

    def __truediv__(self, flt):
        return TaggedDomain(self.a.__truediv__(flt), self.tag)

    def __add__(self, other):
        if isinstance(other, TaggedDomain):
            return TaggedDomain(self.a.__add__(other.a), self.tag)
        else:
            return TaggedDomain(self.a.__add__(other), self.tag)

    def addPar(self, a,b):
        return TaggedDomain(self.a.addPar(a.a, b.a), self.tag)

    def __sub__(self, other):
        if isinstance(other, TaggedDomain):
            return TaggedDomain(self.a.__sub__(other.a), self.tag)
        else:
            return TaggedDomain(self.a.__sub__(other), self.tag)

    def bmm(self, other):
        return TaggedDomain(self.a.bmm(other), self.tag)

    def matmul(self, other):
        return TaggedDomain(self.a.matmul(other), self.tag)

    def conv(self, *args, **kargs):
        return TaggedDomain(self.a.conv(*args, **kargs) , self.tag)

    def conv1d(self, *args, **kargs):
        return TaggedDomain(self.a.conv1d(*args, **kargs), self.tag)

    def conv2d(self, *args, **kargs):
        return TaggedDomain(self.a.conv2d(*args, **kargs), self.tag)

    def conv3d(self, *args, **kargs):
        return TaggedDomain(self.a.conv3d(*args, **kargs), self.tag)

    def max_pool2d(self, *args, **kargs):
        return TaggedDomain(self.a.max_pool2d(*args, **kargs), self.tag)

    def avg_pool2d(self, *args, **kargs):
        return TaggedDomain(self.a.avg_pool2d(*args, **kargs), self.tag)

    def adaptive_avg_pool2d(self, *args, **kargs):
        return TaggedDomain(self.a.adaptive_avg_pool2d(*args, **kargs), self.tag)


    def unsqueeze(self, *args, **kargs):
        return TaggedDomain(self.a.unsqueeze(*args, **kargs), self.tag)

    def squeeze(self, *args, **kargs):
        return TaggedDomain(self.a.squeeze(*args, **kargs), self.tag)

    def abstractApplyLeaf(self, *args, **kargs):
        return TaggedDomain(self.a.abstractApplyLeaf(*args, **kargs), self.tag)

    def view(self, *args, **kargs):
        return TaggedDomain(self.a.view(*args, **kargs), self.tag)

    def gather(self, *args, **kargs):
        return TaggedDomain(self.a.gather(*args, **kargs), self.tag)

    def sum(self, *args, **kargs):
        return TaggedDomain(self.a.sum(*args,**kargs), self.tag)

    def double(self):
        return TaggedDomain(self.a.double(), self.tag)

    def float(self):
        return TaggedDomain(self.a.float(), self.tag)

    def to_dtype(self):
        return TaggedDomain(self.a.to_dtype(), self.tag)

    def vanillaTensorPart(self):
        return self.a.vanillaTensorPart()

    def center(self):
        return TaggedDomain(self.a.center(), self.tag)

    def ub(self):
        return TaggedDomain(self.a.ub(), self.tag)

    def lb(self):
        return TaggedDomain(self.a.lb(), self.tag)

    def relu(self):
        return TaggedDomain(self.a.relu(), self.tag)

    def splitRelu(self, *args, **kargs):
        return TaggedDomain(self.a.splitRelu(*args, **kargs), self.tag)

    def diameter(self):
        return self.a.diameter()

    def softplus(self):
        return TaggedDomain(self.a.softplus(), self.tag)

    def elu(self):
        return TaggedDomain(self.a.elu(), self.tag)

    def selu(self):
        return TaggedDomain(self.a.selu(), self.tag)

    def sigm(self):
        return TaggedDomain(self.a.sigm(), self.tag)


    def cat(self, other, *args, **kargs):
        return TaggedDomain(self.a.cat(other.a, *args, **kargs), self.tag)

    def split(self, *args, **kargs):
        return [TaggedDomain(z, self.tag) for z in self.a.split(*args, **kargs)]

    def size(self):
        return self.a.size()

    def loss(self, *args, **kargs):
        return self.tag.loss(self.a, *args, **kargs)

    def deep_loss(self, *args, **kargs):
        return self.a.deep_loss(*args, **kargs)

    def checkSizes(self):
        self.a.checkSizes()
        return self

    def merge(self, other, ref = None):
        return TaggedDomain(self.a.merge(other.a, ref = None if ref is None else ref.a), self.tag)
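

# A minimal, hedged smoke test (added for illustration; the shapes and values below are
# assumptions, not part of the original test suite). It builds a tiny HybridZonotope and
# pushes it through an affine map, then prints the concretized interval bounds. It only
# assumes torch is installed and the repo's helpers module is importable.
if __name__ == "__main__":
    W = torch.tensor([[1.0, -1.0], [0.5, 2.0]])           # 2x2 weight matrix
    head = torch.tensor([[1.0, 2.0]])                      # batch of 1, 2 features
    beta = torch.tensor([[0.1, 0.1]])                      # per-coordinate box radius
    errors = torch.tensor([[[0.2, 0.0]],                   # two shared error terms,
                           [[0.0, 0.3]]])                  # shape [num_errors, batch, features]
    z = HybridZonotope(head, beta, errors)
    out = z.matmul(W)                                      # affine transformer
    print("lower bound:", out.lb())
    print("upper bound:", out.ub())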