import os
import sys
import h5py
import torch
import numpy as np
from tqdm import tqdm
import math
from imageio import imread
from PIL import Image
import torchvision.transforms as transforms

sys.path.append(os.getcwd()) # HACK add the root folder
from lib.config import CONF
from lib.projection import ProjectionHelper

SCANNET_LIST = CONF.SCANNETV2_LIST
SCANNET_DATA = CONF.PREP_SCANS
SCANNET_FRAME_ROOT = CONF.SCANNET_FRAMES
SCANNET_FRAME_PATH = os.path.join(SCANNET_FRAME_ROOT, "{}") # the appended placeholder is the file name, e.g. <frame_id>.jpg

ENET_FEATURE_PATH = CONF.ENET_FEATURES_PATH
ENET_FEATURE_DATABASE = CONF.MULTIVIEW

# projection
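# the intrinsics are pre-scaled to the 41x32 projection image used for the depth maps below;
# 0.1 and 4.0 bound the valid depth range in meters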
INTRINSICS = [[37.01983, 0, 20, 0],[0, 38.52470, 15.5, 0],[0, 0, 1, 0],[0, 0, 0, 1]]
PROJECTOR = ProjectionHelper(INTRINSICS, 0.1, 4.0, [41, 32], 0.05)

def get_scene_list():
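    """ Return the sorted, de-duplicated list of scene ids from the ScanNet split file. """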
    with open(SCANNET_LIST, 'r') as f:
        return sorted(list(set(f.read().splitlines())))

def to_tensor(arr):
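    """ Copy a numpy array into a float tensor on the GPU. """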
    return torch.Tensor(arr).cuda()

def resize_crop_image(image, new_image_dims):
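    """ Resize to the target height (preserving aspect ratio), then center-crop to the target width.
        Both image_dims and new_image_dims are [width, height]. """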
    image_dims = [image.shape[1], image.shape[0]]
    if image_dims == new_image_dims:
        return image
    resize_width = int(math.floor(new_image_dims[1] * float(image_dims[0]) / float(image_dims[1])))
    image = transforms.Resize([new_image_dims[1], resize_width], interpolation=Image.NEAREST)(Image.fromarray(image))
    image = transforms.CenterCrop([new_image_dims[1], new_image_dims[0]])(image)
    image = np.array(image)
    
    return image

def load_image(file, image_dims):
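    """ Load a frame, resize/crop it to image_dims, and for color images
        convert to CHW layout and apply the mean/std normalization below. """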
    image = imread(file)
    # preprocess
    image = resize_crop_image(image, image_dims)
    if len(image.shape) == 3: # color image
        image = np.transpose(image, [2, 0, 1])  # HWC -> CHW
        image = transforms.Normalize(mean=[0.496342, 0.466664, 0.440796], std=[0.277856, 0.28623, 0.291129])(torch.Tensor(image.astype(np.float32) / 255.0))
    elif len(image.shape) == 2: # label image
        pass
    else:
        raise ValueError("unsupported image dimensions: {}".format(image.shape))

    return image

def load_pose(filename):
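    """ Parse a 4x4 camera-to-world pose matrix from a whitespace-separated text file. """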
    with open(filename) as f:
        lines = f.read().splitlines()
    assert len(lines) == 4
    lines = [line.split(" ")[:4] for line in lines]

    return np.asarray(lines).astype(np.float32)

def load_depth(file, image_dims):
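    """ Load a depth map, resize/crop it to image_dims, and convert millimeters to meters. """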
    depth_image = imread(file)
    # preprocess
    depth_image = resize_crop_image(depth_image, image_dims)
    depth_image = depth_image.astype(np.float32) / 1000.0

    return depth_image

def get_scene_data(scene_list):
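    """ Load the xyz coordinates of every preprocessed scene point cloud. """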
    scene_data = {}
    for scene_id in scene_list:
        scene_data[scene_id] = np.load(os.path.join(SCANNET_DATA, scene_id + ".npy"))[:, :3] # keep xyz only
    
    return scene_data

def compute_projection(points, depth, camera_to_world):
    """
        :param points: tensor containing all points of the point cloud (num_points, 3)
        :param depth: depth map (size: proj_image)
        :param camera_to_world: camera pose (4, 4)
        
        :return indices_3d (array with point indices that correspond to a pixel),
        :return indices_2d (array with pixel indices that correspond to a point)

        note:
            the first digit of indices represents the number of relevant points
            the rest digits are for the projection mapping
    """
    num_points = points.shape[0]
    num_frames = depth.shape[0]
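    # one extra slot per row: index 0 stores the number of valid correspondences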
    indices_3ds = torch.zeros(num_frames, num_points + 1).long().cuda()
    indices_2ds = torch.zeros(num_frames, num_points + 1).long().cuda()

    for i in range(num_frames):
        indices = PROJECTOR.compute_projection(to_tensor(points), to_tensor(depth[i]), to_tensor(camera_to_world[i]))
        if indices is not None: # skip frames without valid correspondences
            indices_3ds[i] = indices[0].long()
            indices_2ds[i] = indices[1].long()
        
    return indices_3ds, indices_2ds

if __name__ == "__main__":
    scene_list = get_scene_list()
    scene_data = get_scene_data(scene_list)
    with h5py.File(ENET_FEATURE_DATABASE, "w", libver="latest") as database:
        print("projecting multiview features to point cloud...")
        for scene_id in tqdm(scene_list):
            scene = scene_data[scene_id]
            # load frames
            frame_list = [f.split(".")[0] for f in os.listdir(SCANNET_FRAME_ROOT.format(scene_id, "color"))]
            scene_images = np.zeros((len(frame_list), 3, 256, 328))
            scene_depths = np.zeros((len(frame_list), 32, 41))
            scene_poses = np.zeros((len(frame_list), 4, 4))
            for i, frame_id in enumerate(frame_list):
                scene_images[i] = load_image(SCANNET_FRAME_PATH.format(scene_id, "color", "{}.jpg".format(frame_id)), [328, 256])
                scene_depths[i] = load_depth(SCANNET_FRAME_PATH.format(scene_id, "depth", "{}.png".format(frame_id)), [41, 32])
                scene_poses[i] = load_pose(SCANNET_FRAME_PATH.format(scene_id, "pose", "{}.txt".format(frame_id)))

            # compute projections for each frame
            projection_3d, projection_2d = compute_projection(scene, scene_depths, scene_poses)
            _, inds = torch.sort(projection_3d[:, 0], descending=True)
            projection_3d, projection_2d = projection_3d[inds], projection_2d[inds]
            
            # compute valid projections
            projections = []
            for i in range(projection_3d.shape[0]):
                num_valid = projection_3d[i, 0]
                if num_valid == 0:
                    continue

                projections.append((frame_list[inds[i].long().item()], projection_3d[i], projection_2d[i]))

            # project: accumulate 128-dim multiview features for every point
            point_features = to_tensor(scene).new(scene.shape[0], 128).fill_(0) # stays all-zero if no frame has valid projections
            for i, projection in enumerate(projections):
                frame_id, projection_3d, projection_2d = projection
                feat = to_tensor(np.load(ENET_FEATURE_PATH.format(scene_id, frame_id)))
                proj_feat = PROJECTOR.project(feat, projection_3d, projection_2d, scene.shape[0]).transpose(1, 0)
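                # the first frame initializes the features; later frames only fill
                # points whose 128-dim feature vector is still all zeros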
                if i == 0:
                    point_features = proj_feat
                else:
                    mask = ((point_features == 0).sum(1) == 128).nonzero().squeeze(1)
                    point_features[mask] = proj_feat[mask]

            # save
            database.create_dataset(scene_id, data=point_features.cpu().numpy())

    print("done!")