import torch.utils.data as data
import torch
import numpy as np
import os
from os.path import join
from PIL import Image, ImageOps
import random
from random import randrange
import pyflow

max_flow = 150.0  # flow magnitudes (in pixels) beyond this are treated as outliers and clipped in rescale_flow

def is_image_file(filename):
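    """Return True if `filename` ends with a recognized image extension."""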
    return any(filename.endswith(extension) for extension in [".png", ".jpg", ".jpeg"])

def load_img(filepath, scale):
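    """Load a random triplet of consecutive frames from the sequence folder `filepath`.

    Returns the two downscaled end frames (`input`), the three modcropped HR frames
    (`target`), the downscaled middle frame (`target_l`), and the sorted list of all
    frame file names in the folder.
    """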
    frame_names = os.listdir(filepath)
    frame_names.sort()
    
    rate = 1
    #for vimeo90k-septuplet (multiple temporal scales)
    #if random.random() < 0.5:
    #    rate = 2
    
    index = randrange(0, len(frame_names) - (2 * rate))
    
    target = [modcrop(Image.open(join(filepath, frame_names[i])).convert('RGB'), scale) for i in range(index, index + 3 * rate, rate)]
    
    w, h = target[0].size  # PIL .size is (width, height)
    w_in, h_in = w // scale, h // scale
    
    target_l = target[1].resize((w_in, h_in), Image.BICUBIC)
    input = [target[j].resize((w_in, h_in), Image.BICUBIC) for j in [0, 2]]
    
    return input, target, target_l, frame_names

def load_img_test(filepath, scale):
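    """Load a whole test sequence: modcrop every frame, then return the first and
    last frame downscaled by `scale`, together with the sorted frame file names."""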
    frame_names = os.listdir(filepath)
    frame_names.sort()
    
    target = [modcrop(Image.open(join(filepath, name)).convert('RGB'), scale) for name in frame_names]
    w, h = target[0].size  # PIL .size is (width, height)
    w_in, h_in = w // scale, h // scale
    
    input = [target[j].resize((w_in, h_in), Image.BICUBIC) for j in [0, len(frame_names) - 1]]
    
    return input, frame_names

def load_img_nodown(filepath):
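    """Load the first and last frame of the sequence at their original resolution
    (no downscaling), together with the sorted frame file names."""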
    frame_names = os.listdir(filepath)
    frame_names.sort()
    
    input = [Image.open(join(filepath, frame_names[i])).convert('RGB') for i in [0, len(frame_names) - 1]]
    
    return input, frame_names
    
def get_flow(im1, im2):
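    """Estimate dense optical flow from im1 to im2 with pyflow's coarse-to-fine
    solver and return it as an (H, W, 2) array rescaled by rescale_flow."""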
    im1 = np.array(im1)
    im2 = np.array(im2)
    im1 = im1.astype(float) / 255.
    im2 = im2.astype(float) / 255.
    
    # Flow Options:
    alpha = 0.012
    ratio = 0.75 #0.95 #0.75
    minWidth = 20 #50 #20
    nOuterFPIterations = 7
    nInnerFPIterations = 1
    nSORIterations = 30
    colType = 0  # 0 or default:RGB, 1:GRAY (but pass gray image with shape (h,w,1))
    
    u, v, im2W = pyflow.coarse2fine_flow(im1, im2, alpha, ratio, minWidth, nOuterFPIterations, nInnerFPIterations,nSORIterations, colType)
    flow = np.concatenate((u[..., None], v[..., None]), axis=2)
    
    # Clip and rescale to a fixed range (rescale_flow is called with max_range=-1, min_range=1)
    flow = rescale_flow(flow, -1, 1)
    return flow

def rescale_flow(x,max_range,min_range):
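    """Clip flow values to [-max_flow, max_flow] and map them linearly so that
    max_flow -> max_range and -max_flow -> min_range. Note that get_flow calls this
    as rescale_flow(flow, -1, 1), which also flips the sign of the flow."""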
    # Clip outliers to [-max_flow, max_flow] before rescaling
    x = np.clip(x, -max_flow, max_flow)
    
    max_val = max_flow
    min_val = -max_flow
    return (max_range - min_range) / (max_val - min_val) * (x - max_val) + max_range

def modcrop(img, modulo):
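    """Crop `img` (anchored at the top-left corner) so both dimensions are divisible by `modulo`."""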
    (iw, ih) = img.size  # PIL .size is (width, height)
    iw = iw - (iw % modulo)
    ih = ih - (ih % modulo)
    img = img.crop((0, 0, iw, ih))
    return img

def get_patch(img_in, img_tar, img_tar_l, patch_size, scale, ix=-1, iy=-1):
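    """Crop a random ip x ip patch from the LR inputs and `img_tar_l`, and the
    corresponding tp x tp (= scale * ip) patch from the HR targets. Pass `ix`/`iy`
    to reuse a previously drawn offset; `info_patch` records the offsets used."""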
    (iw, ih) = img_in[0].size  # PIL .size is (width, height) of the LR frames

    tp = scale * patch_size  # HR patch size
    ip = tp // scale         # LR patch size (== patch_size)

    # ix indexes rows (vertical offset), iy indexes columns (horizontal offset)
    if ix == -1:
        ix = random.randrange(0, ih - ip + 1)
    if iy == -1:
        iy = random.randrange(0, iw - ip + 1)

    (tx, ty) = (scale * ix, scale * iy)

    # PIL crop boxes are (left, upper, right, lower)
    img_in = [j.crop((iy, ix, iy + ip, ix + ip)) for j in img_in]
    img_tar = [j.crop((ty, tx, ty + tp, tx + tp)) for j in img_tar]
    img_tar_l = img_tar_l.crop((iy, ix, iy + ip, ix + ip))
                
    info_patch = {
        'ix': ix, 'iy': iy, 'ip': ip, 'tx': tx, 'ty': ty, 'tp': tp}

    return img_in, img_tar, img_tar_l, info_patch

def augment(img_in, img_tar, img_tar_l, flip_h=True, rot=True):
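    """Randomly apply the same top-bottom flip (ImageOps.flip), left-right flip
    (ImageOps.mirror) and 180-degree rotation to inputs and targets so they stay
    aligned; `info_aug` records which augmentations were applied."""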
    info_aug = {'flip_h': False, 'flip_v': False, 'trans': False}
    
    if random.random() < 0.5 and flip_h:
        img_in = [ImageOps.flip(j) for j in img_in]
        img_tar = [ImageOps.flip(j) for j in img_tar]
        img_tar_l = ImageOps.flip(img_tar_l)
        info_aug['flip_h'] = True

    if rot:
        if random.random() < 0.5:
            img_in = [ImageOps.mirror(j) for j in img_in]
            img_tar = [ImageOps.mirror(j) for j in img_tar]
            img_tar_l = ImageOps.mirror(img_tar_l)
            info_aug['flip_v'] = True
        if random.random() < 0.5:
            img_in = [j.rotate(180) for j in img_in]
            img_tar = [j.rotate(180) for j in img_tar]
            img_tar_l = img_tar_l.rotate(180)
            info_aug['trans'] = True

    return img_in, img_tar, img_tar_l, info_aug
    
class DatasetFromFolder(data.Dataset):
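    """Training set for space-time super-resolution. Each item is a random frame
    triplet from one sequence: the two downscaled end frames, the three HR frames,
    the downscaled middle frame, and bidirectional pyflow flows between the two
    downscaled end frames."""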
    def __init__(self, image_dir, upscale_factor, data_augmentation, file_list, patch_size, transform=None):
        super(DatasetFromFolder, self).__init__()
        with open(join(image_dir, file_list)) as f:
            alist = [line.rstrip() for line in f]
        self.image_filenames = [join(image_dir, x) for x in alist]
        self.upscale_factor = upscale_factor
        self.transform = transform
        self.data_augmentation = data_augmentation
        self.patch_size = patch_size

    def __getitem__(self, index):
        input, target, target_l, file_list = load_img(self.image_filenames[index], self.upscale_factor)

        if self.patch_size != 0:
            input, target, target_l, _ = get_patch(input,target,target_l,self.patch_size, self.upscale_factor)
        
        if self.data_augmentation:
            input, target, target_l, _ = augment(input, target, target_l)
            
        flow_f = get_flow(input[0],input[1])
        flow_b = get_flow(input[1],input[0])
                    
        if self.transform:
            input = [self.transform(j) for j in input]
            target = [self.transform(j) for j in target]
            target_l = self.transform(target_l)
            flow_f = torch.from_numpy(flow_f.transpose(2,0,1)) 
            flow_b = torch.from_numpy(flow_b.transpose(2,0,1)) 

        return input, target, target_l, flow_f, flow_b, file_list, self.image_filenames[index]

    def __len__(self):
        return len(self.image_filenames)

class DatasetFromFolderFlow(data.Dataset):
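    """Like DatasetFromFolder, but also returns gt_flow_f/gt_flow_b, obtained by
    summing the flows through the downscaled middle frame (frame 0 -> middle plus
    middle -> frame 2, and the reverse direction)."""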
    def __init__(self, image_dir, upscale_factor, data_augmentation, file_list, patch_size, transform=None):
        super(DatasetFromFolderFlow, self).__init__()
        with open(join(image_dir, file_list)) as f:
            alist = [line.rstrip() for line in f]
        self.image_filenames = [join(image_dir, x) for x in alist]
        self.upscale_factor = upscale_factor
        self.transform = transform
        self.data_augmentation = data_augmentation
        self.patch_size = patch_size

    def __getitem__(self, index):
        input, target, target_l, file_list = load_img(self.image_filenames[index], self.upscale_factor)

        if self.patch_size != 0:
            input, target, target_l, _ = get_patch(input,target,target_l,self.patch_size, self.upscale_factor)
        
        if self.data_augmentation:
            input, target, target_l, _ = augment(input, target, target_l)
            
        flow_f = get_flow(input[0],input[1])
        flow_b = get_flow(input[1],input[0])
        
        gt_flow_f = get_flow(input[0],target_l) + get_flow(target_l,input[1])
        gt_flow_b = get_flow(input[1],target_l) + get_flow(target_l,input[0])
                    
        if self.transform:
            input = [self.transform(j) for j in input]
            target = [self.transform(j) for j in target]
            target_l = self.transform(target_l)
            flow_f = torch.from_numpy(flow_f.transpose(2,0,1)) 
            flow_b = torch.from_numpy(flow_b.transpose(2,0,1)) 
            gt_flow_f = torch.from_numpy(gt_flow_f.transpose(2,0,1)) 
            gt_flow_b = torch.from_numpy(gt_flow_b.transpose(2,0,1)) 

        return input, target, target_l, flow_f, flow_b, gt_flow_f, gt_flow_b,file_list, self.image_filenames[index]

    def __len__(self):
        return len(self.image_filenames)
        
class DatasetFromFolderFlowLR(data.Dataset):
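    """Returns the HR frame triplet together with bidirectional flows between its
    end frames and gt_flow_f/gt_flow_b composed through the HR middle frame; no
    downscaled inputs are returned."""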
    def __init__(self, image_dir, upscale_factor, data_augmentation, file_list, patch_size, transform=None):
        super(DatasetFromFolderFlowLR, self).__init__()
        with open(join(image_dir, file_list)) as f:
            alist = [line.rstrip() for line in f]
        self.image_filenames = [join(image_dir, x) for x in alist]
        self.upscale_factor = upscale_factor
        self.transform = transform
        self.data_augmentation = data_augmentation
        self.patch_size = patch_size

    def __getitem__(self, index):
        input, target, target_l, file_list = load_img(self.image_filenames[index], self.upscale_factor)

        if self.patch_size != 0:
            input, target, target_l, _ = get_patch(input,target,target_l,self.patch_size, self.upscale_factor)
        
        if self.data_augmentation:
            input, target, target_l, _ = augment(input, target, target_l)
            
        flow_f = get_flow(target[0],target[2])
        flow_b = get_flow(target[2],target[0])
        
        gt_flow_f = get_flow(target[0],target[1]) + get_flow(target[1],target[2])
        gt_flow_b = get_flow(target[2],target[1]) + get_flow(target[1],target[0])
                    
        if self.transform:
            target = [self.transform(j) for j in target]
            flow_f = torch.from_numpy(flow_f.transpose(2,0,1)) 
            flow_b = torch.from_numpy(flow_b.transpose(2,0,1)) 
            gt_flow_f = torch.from_numpy(gt_flow_f.transpose(2,0,1)) 
            gt_flow_b = torch.from_numpy(gt_flow_b.transpose(2,0,1)) 
            

        return target, flow_f, flow_b, gt_flow_f, gt_flow_b, file_list, self.image_filenames[index]

    def __len__(self):
        return len(self.image_filenames)
    
class DatasetFromFolderTest(data.Dataset):
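    """Test set: for each sequence, returns the downscaled first and last frames
    and the bidirectional flows between them."""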
    def __init__(self, image_dir, upscale_factor, file_list, transform=None):
        super(DatasetFromFolderTest, self).__init__()
        with open(join(image_dir, file_list)) as f:
            alist = [line.rstrip() for line in f]
        self.image_filenames = [join(image_dir, x) for x in alist]
        self.upscale_factor = upscale_factor
        self.transform = transform

    def __getitem__(self, index):
        input, file_list = load_img_test(self.image_filenames[index], self.upscale_factor)
            
        flow_f = get_flow(input[0],input[1])
        flow_b = get_flow(input[1],input[0])
        
        if self.transform:
            input = [self.transform(j) for j in input]
            flow_f = torch.from_numpy(flow_f.transpose(2,0,1)) 
            flow_b = torch.from_numpy(flow_b.transpose(2,0,1)) 
            
        return input, flow_f, flow_b, file_list, self.image_filenames[index]
      
    def __len__(self):
        return len(self.image_filenames)

class DatasetFromFolderInterp(data.Dataset):
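    """Frame-interpolation set: returns the first and last frames of each sequence
    at their original resolution plus the bidirectional flows between them."""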
    def __init__(self, image_dir, file_list, transform=None):
        super(DatasetFromFolderInterp, self).__init__()
        with open(join(image_dir, file_list)) as f:
            alist = [line.rstrip() for line in f]
        self.image_filenames = [join(image_dir, x) for x in alist]
        self.transform = transform

    def __getitem__(self, index):
        input, file_list = load_img_nodown(self.image_filenames[index])
            
        flow_f = get_flow(input[0],input[1])
        flow_b = get_flow(input[1],input[0])
        
        if self.transform:
            input = [self.transform(j) for j in input]
            flow_f = torch.from_numpy(flow_f.transpose(2,0,1)) 
            flow_b = torch.from_numpy(flow_b.transpose(2,0,1)) 
            
        return input, flow_f, flow_b, file_list, self.image_filenames[index]
      
    def __len__(self):
        return len(self.image_filenames)
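
if __name__ == '__main__':
    # Minimal usage sketch (not part of the original module). It assumes a
    # Vimeo-style layout in which `image_dir` contains one sub-folder of frames
    # per sequence and an index file (hypothetically named 'sep_trainlist.txt')
    # listing those sub-folders, one per line. The paths, upscale factor, patch
    # size, and the torchvision ToTensor transform are placeholders; adjust them
    # to your setup. Running this requires pyflow and the data to be present.
    from torch.utils.data import DataLoader
    from torchvision.transforms import ToTensor

    train_set = DatasetFromFolder(image_dir='./vimeo_septuplet/sequences',
                                  upscale_factor=4,
                                  data_augmentation=True,
                                  file_list='sep_trainlist.txt',
                                  patch_size=32,
                                  transform=ToTensor())
    # batch_size=1 avoids collation issues if sequences differ in frame count.
    loader = DataLoader(train_set, batch_size=1, shuffle=True)

    lr_input, hr_target, lr_middle, flow_f, flow_b, names, seq_path = next(iter(loader))
    print(len(lr_input), lr_input[0].shape)    # 2 downscaled end frames
    print(len(hr_target), hr_target[0].shape)  # 3 full-resolution frames
    print(flow_f.shape, flow_b.shape)          # bidirectional LR flows, (B, 2, H, W)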