"""
Author: Gurkirt Singh
 https://github.com/Gurkirt

 Copyright (c) 2017, Gurkirt Singh

    This code and is available
    under the terms of MIT License provided in LICENSE.
    Please retain this notice and LICENSE if you use
    this file (or any portion of it) in your project.
    ---------------------------------------------------------

purpose: of this file is to define Kinetics dataset class so it can be used with
torch.util.dataloader class

"""

import os
import json
import random
import collections.abc

import cv2
import numpy as np
import torch
import torch.utils.data as data
from numpy import random as nprandom
from PIL import Image
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

def pilresize(img, size, interpolation=Image.BILINEAR):
    """Resize the input PIL Image to the given size.

    Args:
        img (PIL Image): Image to be resized.
        size (sequence or int): Desired output size. If size is a sequence like
            (h, w), the output size will be matched to this. If size is an int,
            the smaller edge of the image will be matched to this number while
            maintaining the aspect ratio, i.e. if height > width, the image
            will be rescaled to (size * height / width, size).
        interpolation (int, optional): Desired interpolation. Default is
            ``PIL.Image.BILINEAR``.

    Returns:
        PIL Image: Resized image.
    """

    if not (isinstance(size, int) or (isinstance(size, collections.abc.Iterable) and len(size) == 2)):
        raise TypeError('Got inappropriate size arg: {}'.format(size))

    if isinstance(size, int):
        w, h = img.size
        if (w <= h and w == size) or (h <= w and h == size):
            return img
        if w < h:
            ow = size
            oh = int(size * float(h) / w)
            return img.resize((ow, oh), interpolation)
        else:
            oh = size
            ow = int(size * float(w) / h)
            return img.resize((ow, oh), interpolation)
    else:
        return img.resize(size[::-1], interpolation)
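
# Example with hypothetical values: an int size matches the shorter edge, so a
# 640x480 image resized with size=256 becomes 341x256, while an (h, w) tuple
# resizes to exactly that shape:
#   img = pilresize(img, 256)         # shorter edge -> 256, aspect kept
#   img = pilresize(img, (299, 299))  # exact 299x299 output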


def pil_random_crop(img, scale_size, output_size, params=None):
    """Rescale the shorter edge to scale_size, then take a random
    output_size x output_size crop with a random horizontal flip. If params
    is given, its (i, j, flip) values are reused so that every frame of a
    sequence receives the same crop and flip."""
    img = pilresize(img, scale_size)
    th = output_size
    tw = output_size
    if params is None:
        w, h = img.size
        if w == tw and h == th:
            return img, [0, 0, False]
        i = random.randint(0, h - th)
        j = random.randint(0, w - tw)
        flip = random.random() < 0.5
    else:
        i, j, flip = params
    img = img.crop((j, i, j + tw, i + th))
    if flip:
        img = img.transpose(Image.FLIP_LEFT_RIGHT)

    return img, [i, j, flip]
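
# Sketch of how the returned params are reused so every frame of a clip gets
# the identical crop and flip (this mirrors __getitem__ below; `frames` is a
# hypothetical list of PIL images from one clip):
#   first, params = pil_random_crop(frames[0], 321, 299)
#   rest = [pil_random_crop(f, 321, 299, params=params)[0] for f in frames[1:]]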

def cv_random_crop(img, scale_size, output_size, params=None):
    """Take a random crop covering 60-100% of each dimension of a cv2 (numpy)
    image, with a random horizontal flip. If params is given, its (rect, flip)
    values are reused so every frame of a sequence gets the same crop and flip.
    scale_size and output_size are unused here; resizing to the network input
    size is left to the subsequent transform."""
    if params is None:
        height, width, _ = img.shape
        w = nprandom.uniform(0.6 * width, width)
        h = nprandom.uniform(0.6 * height, height)
        left = nprandom.uniform(width - w)
        top = nprandom.uniform(height - h)
        # convert to an integer rect x1, y1, x2, y2
        rect = np.array([int(left), int(top), int(left + w), int(top + h)])
        flip = random.random() < 0.5
    else:
        rect, flip = params

    img = img[rect[1]:rect[3], rect[0]:rect[2], :]
    if flip:
        img = cv2.flip(img, 1)  # apply the horizontal flip, as the PIL variant does

    return img, [rect, flip]

def pil_loader(path):
    # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
    with open(path, 'rb') as f:
        with Image.open(f) as img:
            return img.convert('RGB')


def default_loader(path):
    return pil_loader(path)

def cv_loader(path):
    return cv2.imread(path)
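
# Note: cv2.imread returns an HxWx3 uint8 array in BGR channel order (or None
# for an unreadable path). A hypothetical RGB variant, in case a downstream
# transform expects RGB input:
#   def cv_loader_rgb(path):
#       return cv2.cvtColor(cv2.imread(path), cv2.COLOR_BGR2RGB)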

def make_lists(annot_file, subsets, frame_step, seq_len=1, gap=1):
    """Build the per-frame sample list from the annotation file. Each entry of
    image_list is [video_index, centre_frame_number, label]; centre frames are
    sampled every frame_step frames, keeping a margin of (seq_len//2)*gap on
    either side so a full sequence can be extracted around each one."""
    with open(annot_file, 'r') as f:
        annoData = json.load(f)
    database = annoData["database"]
    classes = annoData["classes"]
    video_list = []
    video_labels = []
    vcount = -1
    image_list = []
    totalcount = 0
    short_count = 0  # number of videos with fewer than 64 frames
    for videoname in sorted(database.keys()):
        video_info = database[videoname]
        isthere = video_info['isthere']
        if isthere and video_info['subset'] in subsets:
            video_list.append(videoname)
            label = 0
            vcount += 1
            numf = video_info['numf']
            if numf < 64:
                short_count += 1
            if numf > seq_len * 2:
                if 'test' not in subsets:
                    label = video_info['cls']
                maxf = numf - (seq_len // 2) * gap - 1
                indexs = np.arange((seq_len // 2) * gap, maxf, frame_step)
                if indexs.shape[0] > 0:
                    for fid in indexs:
                        totalcount += 1
                        image_list.append([vcount, int(fid + 1), label])
            video_labels.append(label)
    print('{} images loaded from {} videos'.format(totalcount, vcount + 1))
    print('Videos with fewer than 64 frames:', short_count)
    return image_list, video_list, classes, video_labels
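
# For reference, the annotation layout make_lists expects, inferred from the
# fields accessed above; the concrete names and values are hypothetical:
# {
#   "classes":  {"abseiling": 0, "air drumming": 1, ...},
#   "database": {
#     "video_0001": {"isthere": true, "subset": "train", "numf": 250, "cls": 0},
#     ...
#   }
# }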


class KINETICS(data.Dataset):
    """Kinetics dataset.

    Input is an image (or a sequence of images); target is the class annotation.

    Arguments:
        root (string): base directory of the Kinetics dataset
        input_type (string): input type, for example rgb, farneback, brox etc.
    """
    def __init__(self, root, input_type, transform=None, target_transform=None,
                 dataset_name='kinetics', datasubset='200', subsets=['train',], exp_name='',
                 netname='inceptionv3', scale_size=321, input_size=299,
                 frame_step=6, seq_len=1, gap=1):

        assert seq_len % 2 == 1, 'seq_len must be an odd integer'
        self.root = root
        self.mode = 'train' in subsets  # True when building the training set
        self.scale_size = scale_size
        self.input_size = input_size
        self.exp_name = exp_name
        self.seq_len = seq_len
        self.gap = gap

        self.input_type = input_type
        self.subsets = subsets
        self.transform = transform
        self.target_transform = target_transform
        self.name = dataset_name
        # VGG-style networks use the OpenCV loader and crop; all other
        # networks go through PIL.
        self.loader = pil_loader
        self.random_crop = pil_random_crop
        if netname.find('vgg') > -1:
            self.loader = cv_loader
            self.random_crop = cv_random_crop

        self.annot_file = self.root + "hfiles/Annots.json"

        assert len(datasubset) > 1

        self.datasubset = datasubset
        self.annot_file = self.root + "hfiles/Annots_{}.json".format(datasubset)
        self.gtval_file = self.root + "hfiles/Annots_{}.json".format(datasubset)

        print('Annot File: ', self.annot_file, ' Mode is set to ', self.mode)

        # template for frame paths: <root>/<input_type>-images/<video>/<frame>.jpg
        self.img_path = os.path.join(root, input_type + '-images', '%s.jpg')
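        # e.g. with root='/data/kinetics/' and input_type='rgb' (hypothetical
        # paths), frame 37 of video 'abc123' resolves to
        # '/data/kinetics/rgb-images/abc123/00037.jpg'.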

        image_list, video_list, classes, video_labels = make_lists(
            self.annot_file, subsets, frame_step, seq_len=self.seq_len, gap=self.gap)

        self.classes = classes.keys()
        self.num_classes = len(self.classes)
        self.video_list = video_list
        self.image_list = image_list
        print('Initialised Kinetics data for', subsets, 'set with', len(image_list), 'images')

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: (image, target) where target is class_index of the target class.
        """

        imids = self.image_list[index]
        vid_num = int(imids[0])
        videoname = self.video_list[vid_num]
        frame_num = int(imids[1])
        target = np.int64(imids[2])
        half_len = self.seq_len // 2
        gap = self.gap
        # frame numbers of the sequence centred on frame_num, spaced gap apart
        frame_nums = np.arange(frame_num - half_len * gap, frame_num + half_len * gap + 1, gap)
        assert len(frame_nums) == self.seq_len, 'number of frame indices should equal seq_len'
        imgs = []
        for fn in frame_nums:
            path = self.img_path % '{:s}/{:05d}'.format(videoname, fn)
            imgs.append(self.loader(path))
        params = None
        if self.transform is not None:
            for ind in range(self.seq_len):
                if self.mode:  # training: identical random crop/flip for every frame
                    imgs[ind], params = self.random_crop(imgs[ind], self.scale_size, self.input_size, params=params)
                imgs[ind] = self.transform(imgs[ind])
                imgs[ind] = imgs[ind].squeeze()
        if self.target_transform is not None:
            target = self.target_transform(target)
        if self.seq_len == 1:
            input_imgs = imgs[0]
        else:
            input_imgs = torch.cat(imgs, 0)

        return input_imgs, target, vid_num, frame_num

    def __len__(self):
        return len(self.image_list)
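

if __name__ == '__main__':
    # Minimal usage sketch, assuming a local copy of the dataset; the root
    # path and transform below are placeholders, not part of the original
    # file. KINETICS plugs straight into torch.utils.data.DataLoader, and
    # __getitem__ returns (input_imgs, target, vid_num, frame_num).
    import torchvision.transforms as transforms
    dataset = KINETICS('/path/to/kinetics/', 'rgb',
                       transform=transforms.ToTensor(),
                       subsets=['train'], seq_len=1)
    loader = data.DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)
    for input_imgs, target, vid_num, frame_num in loader:
        print(input_imgs.size(), target.size())
        break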