import argparse, os, sys, subprocess
from tqdm import tqdm
from glob import glob
from os.path import *
import importlib

import numpy as np
import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from PIL import Image
import pickle as pkl

from models import VGG_graph_matching
from dataloader import MpiSintelClean, MpiSintelFinal, ImagesFromFolder
from logger import Logger
from utils import flow_utils, tools


def init_config():
    """
    INITIALIZE EVERYTHING (Argument Parser, Logging, GPU)
    """
    parser = argparse.ArgumentParser(description='Deep Learning of Graph matching')

    # dataset
    parser.add_argument('--dataset', type=str, default='sintel', help='dataset to use: middlebury/sintel')
    parser.add_argument('--data_path', type=str, default='data/', help='Path to dataset root directory')

    # select mode
    parser.add_argument('--eval', action='store_true', default=False, help='Perform Inference')
    parser.add_argument('--load_path', type=str, default='')
    parser.add_argument('--save_flow', action='store_true', default=False, help='Save flow files during evaluation')

    # others
    parser.add_argument('--seed', type=int, default=7, metavar='S', help='random seed')
    parser.add_argument('--number_workers', '-nw', '--num_workers', type=int, default=8)
    parser.add_argument('--number_gpus', '-ng', type=int, default=2, help='number of GPUs to use')
    parser.add_argument('--cuda', action='store_true', default=True, help='Use GPU')
    parser.add_argument('--use_vgg', action='store_true', default=True, help='Use VGG weights')


    args = parser.parse_args()
    args.cuda = args.cuda and torch.cuda.is_available()

    save_dir = "models/%s" % args.dataset
    log_dir = "logs/%s" % args.dataset

    config_file = "config_%s" % args.dataset
    params = importlib.import_module(config_file).params
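    # config_<dataset>.py is expected to define a `params` dict; its entries
    # (e.g. batch_size, batch_size_test, n_epochs used below) are merged into args.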

    args = argparse.Namespace(**vars(args), **params)

    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    if not os.path.exists(log_dir):
        os.makedirs(log_dir)


    id_ = "%s_seed-%d" % \
            (args.dataset, args.seed)

    save_path = os.path.join(save_dir, id_ + '.pt')

    args.save_path = save_path

    args.log_path = os.path.join(log_dir, id_ + ".log")
    print("log path", args.log_path)



    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)
        torch.backends.cudnn.deterministic = True


    args.effective_number_workers = args.number_workers * args.number_gpus
    args.device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")

    return args

def _apply_loss(d, d_gt):
    """
    LOSS CALCULATION OF THE BATCH

    Arguments:
    ----------
        - d: Computed displacements
        - d_gt: Ground truth displacements

    Returns:
    --------
        - loss: per-sample loss (one value per element in the batch)
    
    """

    # Set all entries to 0 whose ground-truth displacement magnitude exceeds 10px
    pixel_thresh = 10
    disp_magnitude = torch.sqrt(torch.pow(d_gt[:, :, 0], 2) + torch.pow(d_gt[:, :, 1], 2)).unsqueeze(-1).expand(-1, -1, 2)
    idx = disp_magnitude > pixel_thresh
    zeros = torch.zeros_like(d_gt)
    d = torch.where(idx, zeros, d)
    d_gt = torch.where(idx, zeros, d_gt)

    # Calculate the loss according to the formula in the paper
    return torch.sum(torch.sqrt(torch.diagonal(torch.bmm(d - d_gt, (d - d_gt).permute(0, 2, 1)), dim1=-2, dim2=-1)), dim=1)
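# Note: assuming d and d_gt are (batch, n_points, 2) tensors on the same device, the
# return expression above equals torch.norm(d - d_gt, dim=-1).sum(dim=1), i.e. the sum
# of per-keypoint Euclidean distances between predicted and ground-truth displacements.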


def get_mask(height, width, grid_size = 10):
    """
    Get binary masks over the relu_4_2 and relu_5_1 feature maps, marking the cells
    that correspond to a grid_size x grid_size set of evenly spaced image locations.
    """
    x_jump = int(width / grid_size)
    y_jump = int(height / grid_size)
    x_idx = np.linspace(int(x_jump / 2), int(width - x_jump / 2), grid_size, dtype=np.int32)
    y_idx = np.linspace(int(y_jump / 2), int(height - y_jump / 2), grid_size, dtype=np.int32)
    f_mask = torch.zeros((height // (2**4), width // (2**4))).byte()
    u_mask = torch.zeros((height // (2**3), width // (2**3))).byte()
    for i in x_idx:
        for j in y_idx:
            f_mask[j // (2**4), i // (2**4)] = 1
            u_mask[j // (2**3), i // (2**3)] = 1
    return u_mask, f_mask
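# For example (hypothetical input size): u_mask, f_mask = get_mask(448, 1024) returns
# a 56x128 mask for the stride-8 feature map (relu_4_2) and a 28x64 mask for the
# stride-16 feature map (relu_5_1).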


def test(args, epoch, model, data_loader):
    """
    TESTING PROCEDURE

    Parameters:
    -----------
        - args: various arguments
        - epoch: current epoch number
        - model: specified model to test
        - data_loader: specified test data_loader

    Returns:
    --------
        - average_loss: average loss per batch
        - pck: Percentage of Correct Keypoints metric

    """
    
    statistics = []
    total_loss = 0

    model.eval()
    title = 'Validating Epoch {}'.format(epoch)
    progress = tqdm(tools.IteratorTimer(data_loader), ncols=120, total=len(data_loader), smoothing=.9, miniters=1, leave=True, desc=title)
    predictions = []
    gt = []

    sys.stdout.flush()
    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(progress):

            target = target.to(args.device)
            d = model(data[0].to(args.device), im_2=data[1].to(args.device))
            loss = _apply_loss(d, target).mean()
            total_loss += loss.item()
            predictions.extend(d.cpu().numpy())
            gt.extend(target.cpu().numpy())

            # Print out statistics
            statistics.append(loss.item())
            title = '{} Epoch {}'.format('Validating', epoch)

            progress.set_description(title + '\tLoss:\t'+ str(statistics[-1]))
            sys.stdout.flush()


    progress.close()
    pck = tools.calc_pck(np.asarray(predictions), np.asarray(gt))
    print('PCK for epoch %d is %f'%(epoch, pck))

    return total_loss / float(batch_idx + 1), pck


def train(args, epoch, model, data_loader, optimizer):
    """
    TRAINING PROCEDURE

    Parameters:
    -----------
        - args: various arguments
        - epoch: current epoch number
        - model: specified model to train
        - data_loader: specified train data_loader
        - optimizer: specified optimizer to use

    Returns:
    --------
        - average_loss: average loss per batch

    """
    
    statistics = []
    total_loss = 0

    model.train()
    title = 'Training Epoch {}'.format(epoch)
    progress = tqdm(tools.IteratorTimer(data_loader), ncols=120, total=len(data_loader), smoothing=.9, miniters=1, leave=True, desc=title)

    sys.stdout.flush()

    for batch_idx, (data, target) in enumerate(progress):

        target = target.to(args.device)

        optimizer.zero_grad()
        d = model(data[0].to(args.device), im_2=data[1].to(args.device))
        loss = _apply_loss(d, target).mean()

        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        assert not np.isnan(total_loss)

        # Print out statistics
        statistics.append(loss.item())
        title = '{} Epoch {}'.format('Training', epoch)

        progress.set_description(title + '\tLoss:\t'+ str(statistics[-1]))
        sys.stdout.flush()


    progress.close()

    return total_loss / float(batch_idx + 1)

# ====================================================================================================================================
# MAIN PROCEDURE
# ====================================================================================================================================

if __name__ == '__main__':
    args = init_config()
    if not args.eval:
        sys.stdout = Logger(args.log_path)


    gpuargs = {'num_workers': args.effective_number_workers,
               'pin_memory': True,
               'drop_last' : True} if args.cuda else {}
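    # The extra DataLoader options above (worker processes, pinned memory, dropping
    # the last incomplete batch) are only applied when running on the GPU.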

    # ImageNet mean/std normalization, matching the pretrained VGG-16 weights
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    transform = transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])



    # Define path to selected dataset
    if args.dataset.lower() == 'sintel':    

        # Check if specified path exists
        if not os.path.exists(os.path.join(args.data_path,'sintel/training')):
            raise Exception("Could not find specified dataset => Use argument --data_path to specify path")

        # Setup data loader    
        train_dataset = MpiSintelFinal(os.path.join(args.data_path, 'sintel/training'), transforms=transform)
        val_dataset = MpiSintelFinal(os.path.join(args.data_path, 'sintel/training'), train=False, sequence_list=train_dataset.sequence_list, transforms=transform)

    else:
        raise Exception('Dataset not supported yet.')


    # Create training and validation data loaders; batch sizes are scaled by the
    # number of GPUs so each device sees batch_size samples under nn.DataParallel
    n_devices = max(1, torch.cuda.device_count())
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size * n_devices, shuffle=True, **gpuargs)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size_test * n_devices, shuffle=False, **gpuargs)

    # Load VGG graph matching model
    model = VGG_graph_matching()

    # Load vgg parameters if specified
    if args.use_vgg:
        model.copy_params_from_vgg16()

    # Setup GPU use
    if torch.cuda.device_count() > 1 and args.number_gpus > 1:
        model = nn.DataParallel(model)
        print("Using", torch.cuda.device_count(), "GPUs!")
    model = model.to(args.device)

    optimizer = optim.Adam(model.parameters(), lr = 1e-5)
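    # Adam with a small fixed learning rate; no learning-rate schedule is used here.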
    best_pck = 0.

    # Perform specified trainings epochs
    for i in range(1, args.n_epochs+1):
        train(args, i, model, train_loader, optimizer)
        loss, pck = test(args, i, model, val_loader)
        if pck > best_pck:
            best_pck = pck
            torch.save(model.state_dict(), args.save_path)