from __future__ import print_function
import os
import torch
import torch.nn.functional as F
import numpy as np
import math 


# For a single depth map [stable version]
def depthToNormal(depthmap,
                  device,
                  K,
                  thres,
                  img_size
                 ):
    """Convert a [1, H, W] depth map into per-pixel normals and a confidence map.

    K holds the pinhole intrinsics [fx, fy, cx, cy]; thres is the maximum 3D
    distance between neighboring back-projected points (in depth units) for a
    pixel to be trusted.
    """
    K_torch = torch.tensor(K).to(device)  # [fx, fy, cx, cy]

    C, H, W = depthmap.size()
    assert(C == 1 and H == img_size[0] and W == img_size[1])

    depthmap = torch.reshape(depthmap, [H, W])  # drop the channel dim -> (H, W)

    # Pixel-index grids ('ij' indexing): X_grid varies along dim 0 and Y_grid
    # along dim 1, matching the axis convention used throughout this file.
    X_grid, Y_grid = torch.meshgrid([torch.arange(H, dtype=torch.float32, device=device),
                                     torch.arange(W, dtype=torch.float32, device=device)])

    # Back-project every pixel to camera space with the pinhole model.
    X = (X_grid - K_torch[2]) * depthmap / K_torch[0]
    Y = (Y_grid - K_torch[3]) * depthmap / K_torch[1]

    DepthPoints = torch.stack([X, Y, depthmap], dim=2)  # (H, W, 3) point cloud
    
    # Finite differences from each interior point to its four neighbors.
    delta_right = DepthPoints[2:, 1:-1, :] - DepthPoints[1:-1, 1:-1, :]
    delta_down  = DepthPoints[1:-1, 2:, :] - DepthPoints[1:-1, 1:-1, :]

    delta_left = DepthPoints[0:-2, 1:-1, :] - DepthPoints[1:-1, 1:-1, :]
    delta_up   = DepthPoints[1:-1, 0:-2, :] - DepthPoints[1:-1, 1:-1, :]

    # Two normal estimates per pixel from cross products of the tangent
    # vectors; dim=2 is explicit so torch.cross never guesses the axis.
    normal_crop1 = torch.cross(delta_down, delta_right, dim=2)
    normal_crop1 = F.normalize(normal_crop1, p=2, dim=2)

    normal_crop2 = torch.cross(delta_up, delta_left, dim=2)
    normal_crop2 = F.normalize(normal_crop2, p=2, dim=2)

    # Average the two estimates and renormalize.
    normal_crop = normal_crop1 + normal_crop2
    normal_crop = F.normalize(normal_crop, p=2, dim=2)
    
    
    # Paste the interior normals into a full-resolution (H, W, 3) map.
    normal = torch.zeros(H, W, 3).to(device)
    normal[1:-1, 1:-1, :] = normal_crop

    # A pixel is unreliable if either forward difference spans more than
    # `thres` in 3D (a depth discontinuity) or its depth is zero.
    confidence_map_crop = torch.ones(H - 2, W - 2).to(device)

    delta_right_norm = torch.norm(delta_right, p=2, dim=2)
    delta_down_norm  = torch.norm(delta_down, p=2, dim=2)
    confidence_map_crop[delta_right_norm > thres] = 0.0
    confidence_map_crop[delta_down_norm > thres] = 0.0

    confidence_map = torch.zeros(H, W).to(device)  # keep on `device`, like the crop
    confidence_map[1:-1, 1:-1] = confidence_map_crop
    confidence_map[depthmap == 0] = 0
    
    # change to CHW
    normal = normal.permute(2, 0, 1)

    # return the normal [3, H, W] and the [H, W] confidence map
    return normal, confidence_map
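

# A minimal usage sketch (added for illustration, not part of the original
# API): the device choice, the constant test depth, and the intrinsics below
# are assumptions matching the defaults used elsewhere in this file.
def _demo_depth_to_normal():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    depth = torch.full((1, 448, 448), 1000.0).to(device)  # fronto-parallel plane
    normal, confidence = depthToNormal(depth, device,
                                       K=[400.0, 400.0, 224.0, 224.0],
                                       thres=30, img_size=[448, 448])
    # For a constant-depth plane every interior normal points along the z axis.
    print(normal.size(), confidence.size())  # [3, 448, 448], [448, 448]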
  
    
def normalToSH(normal,
               device):
    """Evaluate the 9 order-2 spherical-harmonics basis functions at every
    pixel's normal; `normal` is [3, H, W] and the result is [9, H, W]."""

    # Constant factors of the order-2 SH basis (sign convention kept from the
    # original code).
    c = torch.zeros(9).to(device)

    c[0] =  1 / (2 * math.sqrt(math.pi))
    c[1] = -math.sqrt(3) / (2 * math.sqrt(math.pi))
    c[2] = -math.sqrt(3) / (2 * math.sqrt(math.pi))
    c[3] =  math.sqrt(3) / (2 * math.sqrt(math.pi))
    c[4] =  math.sqrt(15) / (2 * math.sqrt(math.pi))
    c[5] = -math.sqrt(15) / (2 * math.sqrt(math.pi))
    c[6] = -math.sqrt(15) / (2 * math.sqrt(math.pi))
    c[7] =  math.sqrt(15) / (4 * math.sqrt(math.pi))
    c[8] =  math.sqrt(5) / (4 * math.sqrt(math.pi))
    
    C, H, W = normal.size()

    spherical_harmonics = torch.zeros(9, H, W).to(device)

    # Basis terms: 1, nx, ny, nz, nx*ny, nx*nz, ny*nz, nx^2 - ny^2, 3*nz^2 - 1
    spherical_harmonics[0, :, :] = 1 * c[0]
    spherical_harmonics[1, :, :] = normal[0, :, :] * c[1]
    spherical_harmonics[2, :, :] = normal[1, :, :] * c[2]
    spherical_harmonics[3, :, :] = normal[2, :, :] * c[3]
    spherical_harmonics[4, :, :] = normal[0, :, :] * normal[1, :, :] * c[4]
    spherical_harmonics[5, :, :] = normal[0, :, :] * normal[2, :, :] * c[5]
    spherical_harmonics[6, :, :] = normal[1, :, :] * normal[2, :, :] * c[6]
    spherical_harmonics[7, :, :] = (normal[0, :, :] * normal[0, :, :] - normal[1, :, :] * normal[1, :, :]) * c[7]
    spherical_harmonics[8, :, :] = (3 * normal[2, :, :] * normal[2, :, :] - 1) * c[8]
    
    return spherical_harmonics
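

# A small sketch (illustrative): given the [9, H, W] basis image from
# normalToSH and a 9-vector of lighting coefficients (e.g. the output of
# RGBalbedoSHToLight below), Lambertian shading is the per-pixel dot product
# over the 9 SH channels.
def _demo_shading_from_sh(normal, light_9, device):
    SH = normalToSH(normal, device)                     # [9, H, W]
    shading = (light_9.view(9, 1, 1) * SH).sum(dim=0)   # [H, W]
    return shading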


def RGBalbedoSHToLight(colorImg, albedoImg, SH, confidence_map):

    # Zero out the confidence wherever color or albedo is zero, so only clean
    # pixels enter the least-squares fit.
    confidence_map[colorImg == 0] = 0
    confidence_map[albedoImg == 0] = 0

    idx_nonzero = confidence_map.nonzero()
    idx_non = torch.unbind(idx_nonzero, 1)  # (rows, cols); 2-D maps only

    colorImg_non = colorImg[idx_non]
    albedoImg_non = albedoImg[idx_non]

    # Shading = color / albedo, element-wise, at the confident pixels.
    shadingImg_non = torch.div(colorImg_non, albedoImg_non)
    shadingImg_non2 = shadingImg_non.view(-1, 1)

    # Gather the 9 SH channels at the same confident pixels: SH_NON is
    # (num_pixels, 9), one row of SH basis values per pixel.
    SH_NON = SH[:, idx_non[0], idx_non[1]].t()
    
    ## Solve the least-squares system SH_NON @ light = shading. For an (M, 9)
    ## system with M > 9, torch.gels returns an (M, 1) tensor whose first 9
    ## rows hold the solution (the remaining rows carry residual information).
    ## Note: torch.gels was deprecated after PyTorch 1.1; newer releases use
    ## torch.linalg.lstsq instead.
    ## https://pytorch.org/docs/stable/torch.html#torch.gels
    light, _ = torch.gels(shadingImg_non2, SH_NON)
    light_9 = light[0:9]  # keep the 9 SH lighting coefficients

    return (light_9, SH)
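

# Sketch of a consistency check (an assumed workflow, not in the original
# code): re-render the shading from the estimated light and measure the
# residual against color = albedo * shading on the confident pixels.
def _demo_relight_residual(colorImg, albedoImg, SH, confidence_map):
    light_9, _ = RGBalbedoSHToLight(colorImg, albedoImg, SH, confidence_map)
    shading = (light_9.view(9, 1, 1) * SH).sum(dim=0)         # [H, W]
    residual = (colorImg - albedoImg * shading) * confidence_map
    return residual.abs().mean()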


def RGBDalbedoToLight(colorImg,
                      depthImg,
                      albedoImg,
                      device,
                      K=[400.0, 400.0, 224.0, 224.0],
                      thres=30,
                      img_size=[448, 448]):
    """Full pipeline: depth -> normals -> SH basis -> least-squares lighting.

    Returns (light_9, SH): the 9 estimated SH lighting coefficients and the
    [9, H, W] SH basis image.
    """
    normal, confidence_map = depthToNormal(depthImg, device, K, thres, img_size)
    SH = normalToSH(normal, device)
    lighting_est = RGBalbedoSHToLight(colorImg, albedoImg, SH, confidence_map)

    return lighting_est
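

# An end-to-end usage sketch (all tensors are synthetic and illustrative; a
# real call would pass a registered RGB-D frame and its albedo). A curved
# depth surface is used so the normals, and hence the rows of the SH system,
# vary enough for the least squares to be well conditioned. As noted above,
# torch.gels requires PyTorch <= 1.1.
def _demo_rgbd_to_light():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    ys, xs = torch.meshgrid([torch.linspace(-1, 1, 448), torch.linspace(-1, 1, 448)])
    depth = (1000.0 + 100.0 * (xs ** 2 + ys ** 2)).view(1, 448, 448).to(device)
    color = torch.rand(448, 448).to(device) + 0.1    # keep away from exact zero
    albedo = torch.rand(448, 448).to(device) + 0.1
    light_9, SH = RGBDalbedoToLight(color, depth, albedo, device)
    print(light_9.view(-1))  # the 9 estimated SH lighting coefficients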


# Batched depth-to-normal [stable]
def depthToNormalBatch(depthmap,
                       device,
                       K=[400.0, 400.0, 224.0, 224.0],
                       thres=30,
                       img_size=[448, 448]):
    """Batched variant of depthToNormal: takes [N, 1, H, W] depth, returns
    [N, 3, H, W] normals and an [N, H, W] confidence map."""

    K_torch = torch.tensor(K).to(device)  # [fx, fy, cx, cy]

    N, C, H, W = depthmap.size()
    assert(C == 1 and H == img_size[0] and W == img_size[1])

    depthmap = torch.reshape(depthmap, [N, H, W])  # drop the channel dim

    X_grid, Y_grid = torch.meshgrid([torch.arange(H, dtype=torch.float32, device=device),
                                     torch.arange(W, dtype=torch.float32, device=device)])
    X_grid = X_grid.repeat(N, 1, 1)  # (N, H, W)
    Y_grid = Y_grid.repeat(N, 1, 1)  # (N, H, W)

    # Back-project every pixel to camera space, batched.
    X = (X_grid - K_torch[2]) * depthmap / K_torch[0]
    Y = (Y_grid - K_torch[3]) * depthmap / K_torch[1]

    DepthPoints = torch.stack([X, Y, depthmap], dim=3)  # (N, H, W, 3) points
    
    delta_right = DepthPoints[:, 2:, 1:-1, :] - DepthPoints[:, 1:-1, 1:-1, :]
    delta_down  = DepthPoints[:, 1:-1, 2:, :] - DepthPoints[:, 1:-1, 1:-1, :]

    delta_left = DepthPoints[:, 0:-2, 1:-1, :] - DepthPoints[:, 1:-1, 1:-1, :]
    delta_up   = DepthPoints[:, 1:-1, 0:-2, :] - DepthPoints[:, 1:-1, 1:-1, :]

    # dim=3 is passed explicitly: without it torch.cross picks the first
    # size-3 dimension, which would be the batch dimension whenever N == 3.
    normal_crop1 = torch.cross(delta_down, delta_right, dim=3)
    normal_crop1 = F.normalize(normal_crop1, p=2, dim=3)

    normal_crop2 = torch.cross(delta_up, delta_left, dim=3)
    normal_crop2 = F.normalize(normal_crop2, p=2, dim=3)

    # Unlike the single-image version, only the forward-difference estimate is
    # used; the backward estimate is computed but intentionally not blended in.
    normal_crop = normal_crop1  # + normal_crop2
    normal_crop = F.normalize(normal_crop, p=2, dim=3)

    normal = torch.zeros(N, H, W, 3).to(device)
    normal[:, 1:-1, 1:-1, :] = normal_crop
    
    # Confidence: here all four neighbor differences are checked against the
    # discontinuity threshold (the single-image version checks only two).
    confidence_map_crop = torch.ones(N, H - 2, W - 2).to(device)

    delta_right_norm = torch.norm(delta_right, p=2, dim=3)
    delta_down_norm  = torch.norm(delta_down, p=2, dim=3)
    confidence_map_crop[delta_right_norm > thres] = 0.0
    confidence_map_crop[delta_down_norm > thres] = 0.0

    delta_left_norm = torch.norm(delta_left, p=2, dim=3)
    delta_up_norm   = torch.norm(delta_up, p=2, dim=3)
    confidence_map_crop[delta_left_norm > thres] = 0.0
    confidence_map_crop[delta_up_norm > thres] = 0.0

    confidence_map = torch.zeros(N, H, W).to(device)  # keep on `device`, like the crop
    confidence_map[:, 1:-1, 1:-1] = confidence_map_crop
    confidence_map[depthmap == 0] = 0  # (N, H, W)
    
   
    # change to NCHW
    normal = normal.permute(0, 3, 1, 2)

    # return the normal [N, 3, H, W] and the [N, H, W] confidence map
    return normal, confidence_map
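

# A minimal batch usage sketch (illustrative shapes): the batch variant takes
# [N, 1, H, W] depth and returns [N, 3, H, W] normals plus an [N, H, W]
# confidence map, using the same default intrinsics as the single-image path.
def _demo_depth_to_normal_batch():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    depth = torch.full((4, 1, 448, 448), 1000.0).to(device)
    normal, confidence = depthToNormalBatch(depth, device)
    print(normal.size(), confidence.size())  # [4, 3, 448, 448], [4, 448, 448]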

    
def normalToSHBatch(normal,
                    device):
    """Batched variant of normalToSH: takes [N, 3, H, W] normals, returns
    [N, H, W, 9] SH basis values (note the channels-last layout)."""

    N, CC, H, W = normal.size()

    # Constant factors of the order-2 SH basis, identical to normalToSH.
    c = torch.zeros(9).to(device)

    c[0] =  1 / (2 * math.sqrt(math.pi))
    c[1] = -math.sqrt(3) / (2 * math.sqrt(math.pi))
    c[2] = -math.sqrt(3) / (2 * math.sqrt(math.pi))
    c[3] =  math.sqrt(3) / (2 * math.sqrt(math.pi))
    c[4] =  math.sqrt(15) / (2 * math.sqrt(math.pi))
    c[5] = -math.sqrt(15) / (2 * math.sqrt(math.pi))
    c[6] = -math.sqrt(15) / (2 * math.sqrt(math.pi))
    c[7] =  math.sqrt(15) / (4 * math.sqrt(math.pi))
    c[8] =  math.sqrt(5) / (4 * math.sqrt(math.pi))
    
    
    spherical_harmonics = torch.zeros(N, H, W, 9).to(device)

    spherical_harmonics[:, :, :, 0] = 1 * c[0]
    spherical_harmonics[:, :, :, 1] = normal[:, 0, :, :] * c[1]
    spherical_harmonics[:, :, :, 2] = normal[:, 1, :, :] * c[2]
    spherical_harmonics[:, :, :, 3] = normal[:, 2, :, :] * c[3]
    spherical_harmonics[:, :, :, 4] = normal[:, 0, :, :] * normal[:, 1, :, :] * c[4]
    spherical_harmonics[:, :, :, 5] = normal[:, 0, :, :] * normal[:, 2, :, :] * c[5]
    spherical_harmonics[:, :, :, 6] = normal[:, 1, :, :] * normal[:, 2, :, :] * c[6]
    spherical_harmonics[:, :, :, 7] = (normal[:, 0, :, :] * normal[:, 0, :, :] - normal[:, 1, :, :] * normal[:, 1, :, :]) * c[7]
    spherical_harmonics[:, :, :, 8] = (3 * normal[:, 2, :, :] * normal[:, 2, :, :] - 1) * c[8]
    
    return spherical_harmonics
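

# A layout sketch (illustrative): normalToSHBatch returns channels-last
# [N, H, W, 9], unlike the [9, H, W] of the single-image normalToSH.
# Contracting it with a batch of 9-vector lights gives per-pixel shading.
def _demo_batch_shading(normal_batch, light_batch, device):
    # normal_batch: [N, 3, H, W]; light_batch: [N, 9]
    SH = normalToSHBatch(normal_batch, device)                  # [N, H, W, 9]
    shading = (SH * light_batch.view(-1, 1, 1, 9)).sum(dim=3)   # [N, H, W]
    return shading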