import torch
from torch.utils.checkpoint import checkpoint

from .adlib import SVD 
svd = SVD.apply
#from .adlib import EigenSolver
#symeig = EigenSolver.apply

def renormalize(*tensors):
    """Perform one renormalization step on the corner and edge tensors.

    Index conventions used in the comments:
    T(up,left,down,right), C(down,right), EL(up,right,down), EU(left,down,right).
    The same edge tensor E is used in both the EL and EU roles.

    Args:
        *tensors: exactly four positional values ``(C, E, T, chi)``: the
            corner matrix, the edge tensor, the bulk tensor and the maximum
            bond dimension to keep.

    Returns:
        A 4-tuple ``(C, E, spectrum, truncation_error)``:
        the new corner (symmetrized, divided by its norm), the new edge
        (symmetrized under exchange of its first and last index, NOT
        normalized), the absolute singular values of the density matrix
        divided by their maximum (full spectrum, not truncated), and the sum
        of the discarded singular values divided by the sum of all of them.
    """
    corner, edge, bulk, chi = tensors

    dim_t = bulk.shape[0]
    dim_e = edge.shape[0]
    full_dim = dim_e*dim_t
    kept = min(full_dim, chi)  # number of states retained after truncation

    # step 1: construct the density matrix rho
    rho = torch.tensordot(corner, edge, ([1], [0]))     # C(ef)*EU(fga)=rho(ega)
    rho = torch.tensordot(rho, edge, ([0], [0]))        # rho(ega)*EL(ehc)=rho(gahc)
    rho = torch.tensordot(rho, bulk, ([0, 2], [0, 1]))  # rho(gahc)*T(ghdb)=rho(acdb)
    rho = rho.permute(0, 3, 1, 2).contiguous()          # rho(acdb)->rho(abcd)
    rho = rho.view(full_dim, full_dim)                  # rho(abcd)->rho(ab;cd)

    # symmetrize, then fix the overall scale
    rho = rho + rho.t()
    rho = rho/rho.norm()

    # step 2: get the isometry from the leading left singular vectors
    # (svd is SVD.apply from .adlib; it is unpacked here as U, S, V)
    left_vectors, singular_values, _ = svd(rho)
    truncation_error = singular_values[kept:].sum()/singular_values.sum()
    isometry = left_vectors[:, :kept]  # projection operator

    # step 3: renormalize C and E
    new_corner = isometry.t() @ rho @ isometry  # C(kept, kept)

    ## EL(u,r,d)
    isometry = isometry.view(dim_e, dim_t, kept)
    new_edge = torch.tensordot(edge, isometry, ([0], [0]))            # EL(def)P(dga)=E(efga)
    new_edge = torch.tensordot(new_edge, bulk, ([0, 2], [1, 0]))      # E(efga)T(gehb)=E(fahb)
    new_edge = torch.tensordot(new_edge, isometry, ([0, 2], [0, 1]))  # E(fahb)P(fhc)=E(abc)

    # step 4: symmetrize C and E
    new_corner = 0.5*(new_corner + new_corner.t())
    new_edge = 0.5*(new_edge + new_edge.permute(2, 1, 0))

    normalized_corner = new_corner/new_corner.norm()
    normalized_spectrum = singular_values.abs()/singular_values.abs().max()
    return normalized_corner, new_edge, normalized_spectrum, truncation_error


def CTMRG(T, chi, max_iter, use_checkpoint=False):
    """Repeat ``renormalize`` until the returned spectrum stops changing.

    Args:
        T: bulk tensor indexed as T(up, left, down, right).
        chi: maximum bond dimension passed on to ``renormalize``.
        max_iter: upper bound on the number of renormalization steps.
        use_checkpoint: if true, every step is run through
            ``torch.utils.checkpoint.checkpoint`` to save memory.

    Returns:
        The pair ``(C, E)`` from the last executed step, indexed as
        C(down, right) and E(up, right, down); E is divided by its norm
        after every step.
    """
    # ctmrg convergence threshold, tighter for double precision
    if T.dtype is torch.float64:
        threshold = 1E-12
    else:
        threshold = 1E-6

    # initial environment: C(down, right), E(up, right, down)
    C = T.sum((0, 1))
    E = T.sum(1).permute(0, 2, 1)

    # pick the way a single step is executed once, outside the loop
    if use_checkpoint:
        def step(*args):
            # use checkpoint to save memory
            return checkpoint(renormalize, *args)
    else:
        step = renormalize

    accumulated_error = 0.0  # only consumed by the debug print below
    previous_spectrum = torch.zeros(chi, dtype=T.dtype, device=T.device)
    change = 1E1
    for iteration in range(max_iter):
        C, E, spectrum, error = step(C, E, T, torch.tensor(chi))

        edge_norm = E.norm()
        E = E/edge_norm
        accumulated_error += error.item()

        # the spectra are only comparable when they have the same length;
        # otherwise the previous value of `change` is kept
        if spectrum.numel() == previous_spectrum.numel():
            change = (spectrum - previous_spectrum).norm().item()
        #print( 'n: %d, Enorm: %g, error: %e, diff: %e' % (iteration, edge_norm, error.item(), change) )
        if change < threshold:
            break
        previous_spectrum = spectrum
    #print ('ctmrg converged at iterations %d to %.5e, truncation error: %.5f'%(iteration, change, accumulated_error/iteration))

    return C, E

if __name__=='__main__':
    # Smoke test: run CTMRG on a random symmetric tensor, with and without
    # checkpointing, and print how far the results are from being symmetric.
    import time  # NOTE: not used by the demo below
    torch.manual_seed(42)
    device = 'cpu'
    D, chi, max_iter = 6, 80, 100

    # T(u,l,d,r)
    T = torch.randn(D, D, D, D, dtype=torch.float64, device=device, requires_grad=True)

    # average T with each permuted copy in turn
    symmetries = (
        (0, 3, 2, 1),  # left-right symmetry
        (2, 1, 0, 3),  # up-down symmetry
        (3, 2, 1, 0),  # skew-diagonal symmetry
        (1, 0, 3, 2),  # diagonal symmetry
    )
    for perm in symmetries:
        T = (T + T.permute(*perm))/2.
    T = T/T.norm()

    # run the checkpointed path first, then the plain one; the printed
    # diagnostics refer to the result of the last (plain) run
    for use_checkpoint in (True, False):
        C, E = CTMRG(T, chi, max_iter, use_checkpoint=use_checkpoint)
    print( 'diffC = ', torch.dist( C, C.t() ) )
    print( 'diffE = ', torch.dist( E, E.permute(2,1,0) ) )