#
# KTH Royal Institute of Technology
#
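"""Evaluation script for the frame interpolation model.

Computes average SSIM and PSNR of interpolated frames against ground truth,
both on a test video and on the validation set, and compares the model against
a simple linear-interpolation (frame averaging) baseline.
"""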

import torch
from torchvision.transforms import CenterCrop
from os.path import join
from src.model import Net
from src.interpolate import interpolate
from src.extract_frames import extract_frames
from src.data_manager import load_img
from src.dataset import pil_to_tensor, get_validation_set
from src.utilities import psnr
from src.loss import ssim
import src.config as config


def test_metrics(model, video_path=None, frames=None, output_folder=None):
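    """Measure average SSIM and PSNR of `model` on frame triplets.

    Either `video_path` or `frames` must be provided. Triplets of the form
    (previous, ground truth, next) are sampled every `stride` frames; the middle
    frame is interpolated from its two neighbours and compared to the ground
    truth. If `output_folder` is given, the interpolated frames are saved there
    as JPEGs. Prints the average SSIM and PSNR over all triplets.
    """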

    if video_path is not None and frames is None:
        frames, _ = extract_frames(video_path)

    total_ssim = 0
    total_psnr = 0
    stride = 30

    # Sample a (previous, ground truth, next) triplet every `stride` frames
    triplets = []
    for i in range(1 + (len(frames) - 3) // stride):
        tup = (frames[i*stride], frames[i*stride + 1], frames[i*stride + 2])
        triplets.append(tup)

    iters = len(triplets)

    for i in range(iters):
        x1, gt, x2 = triplets[i]
        pred = interpolate(model, x1, x2)
        if output_folder is not None:
            frame_path = join(output_folder, f'wiz_{i}.jpg')
            pred.save(frame_path)
        gt = pil_to_tensor(gt)
        pred = pil_to_tensor(pred)
        total_ssim += ssim(pred, gt).item()
        total_psnr += psnr(pred, gt).item()
        print(f'#{i+1}/{iters} done')

    avg_ssim = total_ssim / iters
    avg_psnr = total_psnr / iters

    print(f'avg_ssim: {avg_ssim}, avg_psnr: {avg_psnr}')
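
# Example usage (sketch): evaluate a trained checkpoint on a video.
# `Net.from_file` and the checkpoint path below mirror the commented-out calls
# in test_all(); the video and output paths are placeholders, not files that
# necessarily exist in this repository.
#
#   model = Net.from_file('./trained_models/best_model_qualitative.pth')
#   test_metrics(model, video_path='/path/to/video.mp4', output_folder='/tmp/frames/')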


def test_wiz(model, output_folder=None):
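    """Run test_metrics on the 540p 'See You Again' test video."""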
    video_path = '/project/videos/see_you_again_540.mp4'
    test_metrics(model, video_path=video_path, output_folder=output_folder)


def test_on_validation_set(model, validation_set=None):
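    """Measure average SSIM and PSNR of `model` on the validation set.

    Each image in a validation tuple is centre-cropped to config.CROP_SIZE, the
    middle frame is interpolated from its two neighbours, and the averages are
    printed at the end.
    """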

    if validation_set is None:
        validation_set = get_validation_set()

    total_ssim = 0
    total_psnr = 0
    iters = len(validation_set.tuples)

    crop = CenterCrop(config.CROP_SIZE)

    for i, tup in enumerate(validation_set.tuples):
        x1, gt, x2 = [crop(load_img(p)) for p in tup]
        pred = interpolate(model, x1, x2)
        gt = pil_to_tensor(gt)
        pred = pil_to_tensor(pred)
        total_ssim += ssim(pred, gt).item()
        total_psnr += psnr(pred, gt).item()
        print(f'#{i+1}/{iters} done')

    avg_ssim = total_ssim / iters
    avg_psnr = total_psnr / iters

    print(f'avg_ssim: {avg_ssim}, avg_psnr: {avg_psnr}')


def test_linear_interp(validation_set=None):
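    """Baseline: predict the middle frame as the pixel-wise average of its neighbours.

    Reports average SSIM and PSNR on the validation set for comparison with the
    trained model.
    """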

    if validation_set is None:
        validation_set = get_validation_set()

    total_ssim = 0
    total_psnr = 0
    iters = len(validation_set.tuples)

    crop = CenterCrop(config.CROP_SIZE)

    for tup in validation_set.tuples:
        x1, gt, x2 = [pil_to_tensor(crop(load_img(p))) for p in tup]
        pred = torch.mean(torch.stack((x1, x2), dim=0), dim=0)
        total_ssim += ssim(pred, gt).item()
        total_psnr += psnr(pred, gt).item()

    avg_ssim = total_ssim / iters
    avg_psnr = total_psnr / iters

    print(f'avg_ssim: {avg_ssim}, avg_psnr: {avg_psnr}')


def test_all():
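    """Summarise the evaluation results.

    The evaluation calls below are commented out to avoid re-running them; the
    hard-coded print statements echo the metrics recorded from earlier runs.
    """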

    print('===> Loading pure L1...')
    # pure_l1 = Net.from_file('./trained_models/last_pure_l1.pth')

    print('===> Testing latest pure L1...')
    # test_on_validation_set(pure_l1)
    print('avg_ssim: 0.8197908288240433, avg_psnr: 29.126618137359618')

    print('===> Testing linear interp...')
    # test_linear_interp()
    print('avg_ssim: 0.6868560968339443, avg_psnr: 26.697076902389526')

    print('===> Loading best models...')
    # best_model_qualitative = Net.from_file('./trained_models/best_model_qualitative.pth')
    # best_model_quantitative = Net.from_file('./trained_models/best_model_quantitative.pth')

    print('===> Testing Wiz (qualitative)...')
    # test_wiz(best_model_qualitative, output_folder='/project/exp/wiz_qual/')
    print('avg_ssim: 0.9658980375842044, avg_psnr: 37.27564642554835')

    print('===> Testing Wiz (quantitative)...')
    # test_wiz(best_model_quantitative, output_folder='/project/exp/wiz_quant/')
    print('avg_ssim: 0.9638479389642415, avg_psnr: 36.52394056822124')


if __name__ == '__main__':
    test_all()