#!/usr/bin/env python
# coding: utf-8
#
# Author:   Kazuto Nakashima
# URL:      http://kazuto1011.github.io
# Created:  _list-11-01
# Modified by: Subhabrata Choudhury

from __future__ import absolute_import, division, print_function

import pickle
import os
import sys
import time
import inspect
import shutil
import os.path as osp
import click
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import yaml
from addict import Dict
from tensorboardX import SummaryWriter
from torchnet.meter import MovingAverageValueMeter
import torchvision.models.resnet
from tqdm import tqdm
from libs.datasets import get_dataset
from libs.models import DeepLabV2_ResNet101_MSC
from libs.utils.loss import CrossEntropyLoss2d
import json

def get_params(model, key):
    # For Dilated FCN
    if key == "1x":
        for m in model.named_modules():
            if "layer" in m[0] or "vgg" in m[0]:
                if isinstance(m[1], nn.Conv2d):
                    for p in m[1].parameters():
                        yield p
    # For conv weight in the ASPP module
    if key == "10x":
        for m in model.named_modules():
            if "aspp" in m[0]:
                if isinstance(m[1], nn.Conv2d):
                    yield m[1].weight
    # For conv bias in the ASPP module
    if key == "20x":
        for m in model.named_modules():
            if "aspp" in m[0]:
                if isinstance(m[1], nn.Conv2d):
                    yield m[1].bias
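# Note: the "1x" / "10x" / "20x" keys follow the DeepLab convention (cf. lr_mult /
# decay_mult in train.prototxt): backbone convolutions are trained at the base learning
# rate, while ASPP weights and biases receive 10x and 20x multipliers. A minimal sketch
# of how these groups are consumed (the variable names here are illustrative only):
#
#     optimizer = torch.optim.SGD([
#         {"params": get_params(model, "1x"),  "lr": base_lr},
#         {"params": get_params(model, "10x"), "lr": 10 * base_lr},
#         {"params": get_params(model, "20x"), "lr": 20 * base_lr},
#     ], momentum=0.9)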



class RandomImageSampler(torch.utils.data.Sampler):
    r"""Samples images in random order, pairing each image index with a visibility-mask id.

    Images from the seen set are paired with mask id 0 (seen classes only); an image drawn
    for the i-th novel class is paired with mask id i + 1, which additionally exposes that class.
    """

    def __init__(self, seenset, novelset):
        self.data_index = []
        for v in seenset:
            self.data_index.append([v, 0])
        for v, i in novelset:
            self.data_index.append([v, i + 1])

    def __iter__(self):
        return iter([self.data_index[i] for i in np.random.permutation(len(self.data_index))])

    def __len__(self):
        return len(self.data_index)

def poly_lr_scheduler(optimizer, init_lr, iter, lr_decay_iter, max_iter, power):
    """Polynomial learning-rate decay, applied every ``lr_decay_iter`` iterations."""
    # Only update on multiples of lr_decay_iter, and never past max_iter.
    if iter % lr_decay_iter != 0 or iter > max_iter:
        return None
    new_lr = init_lr * (1 - float(iter) / max_iter) ** power
    # The three parameter groups use 1x / 10x / 20x multipliers (see get_params above).
    optimizer.param_groups[0]["lr"] = new_lr
    optimizer.param_groups[1]["lr"] = 10 * new_lr
    optimizer.param_groups[2]["lr"] = 20 * new_lr
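# Illustrative decay values (not taken from any config): with init_lr = 2.5e-4,
# max_iter = 20000 and power = 0.9, the base-group learning rate is
#   iter =     0 -> 2.5e-4
#   iter = 10000 -> 2.5e-4 * (1 - 0.5) ** 0.9 ~= 1.34e-4
#   iter -> max_iter -> approaches 0.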


def resize_target(target, size):
    # Nearest-neighbour interpolation so that label IDs are never blended.
    new_target = np.zeros((target.shape[0], size, size), np.int32)
    for i, t in enumerate(target.numpy()):
        new_target[i, ...] = cv2.resize(t, (size,) * 2, interpolation=cv2.INTER_NEAREST)
    return torch.from_numpy(new_target).long()



@click.command()
@click.option("-c", "--config", type=str, required=True)
@click.option("--cuda/--no-cuda", default=True)
@click.option("--excludeval/--no-excludeval", default=False)
@click.option("--embedding", default='fastnvec')
@click.option("--continue-from", type=int)
@click.option("--nolog", is_flag=True)
@click.option("--inputmix", type=str, default='seen')
@click.option("--imagedataset", default='cocostuff')
@click.option("--experimentid", type=str)
@click.option("--nshot", type=int)
@click.option("--ishot", type=int, default=0)
def main(config, cuda, excludeval, embedding, continue_from, nolog, inputmix, imagedataset, experimentid, nshot, ishot):
    frame = inspect.currentframe()
    args, _, _, values = inspect.getargvalues(frame)
    #print(values)

    # Root of the dataset splits and word vectors.
    datadir = os.path.join('data/datasets', imagedataset)
    if not nolog:
        # Name the save directory after the experiment id, or the current timestamp.
        if experimentid is None:
            savedir = time.strftime('%Y%m%d%H%M%S')
        else:
            savedir = experimentid
        # Full path where logs, checkpoints and a code snapshot are written.
        savepath = os.path.join('logs', imagedataset, savedir)
        # Create the log directory if it does not already exist.
        try:
            os.makedirs(savepath)
            print("Log dir:", savepath)
        except FileExistsError:
            pass
        if continue_from is None:
            # Snapshot the code, the config and the command-line arguments for reproducibility.
            shutil.copytree('./libs/', savepath+'/libs')
            shutil.copy2(osp.abspath(inspect.stack()[0][1]), savepath)
            shutil.copy2(config, savepath)
            args_dict = {}
            for a in args:
                args_dict[a] = values[a]
            with open(savepath+'/args.json', 'w') as fp:
                json.dump(args_dict, fp)

    cuda = cuda and torch.cuda.is_available()
    device = torch.device("cuda" if cuda else "cpu")

    if cuda:
        current_device = torch.cuda.current_device()
        print("Running on", torch.cuda.get_device_name(current_device))
    else:
        print("Running on CPU")

    # Configuration
    with open(config) as f:
        CONFIG = Dict(yaml.load(f, Loader=yaml.FullLoader))
    visibility_mask = {}
    if excludeval:
        seen_classes = np.load(datadir+'/split/seen_cls.npy')
    else:
        seen_classes = np.asarray(np.concatenate([np.load(datadir+'/split/seen_cls.npy'), np.load(datadir+'/split/val_cls.npy')]),dtype=int)

    novel_classes = np.load(datadir+'/split/novel_cls.npy')
    seen_novel_classes = np.concatenate([seen_classes, novel_classes])

    # Map raw label IDs to contiguous training indices; unmapped IDs stay at -1,
    # which the loss ignores (ignore_index=-1 below).
    seen_map = np.array([-1]*256)
    for i, n in enumerate(list(seen_classes)):
        seen_map[n] = i

    visibility_mask[0] = seen_map.copy()
    for i, n in enumerate(list(novel_classes)):
        visibility_mask[i+1] = seen_map.copy()
        visibility_mask[i+1][n] = seen_classes.shape[0]+i
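    # Each visibility_mask[k] maps raw label IDs (0..255) to contiguous training indices,
    # leaving everything else at -1. Mask 0 exposes only the seen classes; mask i + 1
    # exposes the seen classes plus the i-th novel class at index len(seen_classes) + i.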
    if excludeval:
        train = np.load(datadir+'/split/train_list.npy')[:-CONFIG.VAL_SIZE]
    else:
        train = np.load(datadir+'/split/train_list.npy')

    novelset = []
    seenset = []

    # inverse_dict_train maps a class ID to the indices of training images containing it.
    # Each ishot value starts its window at offset ishot * 20, so different ishot runs draw
    # from disjoint candidate pools (for nshot <= 20); the first nshot images are taken.
    if inputmix == 'novel' or inputmix == 'both':
        inverse_dict = pickle.load(open(datadir + '/split/inverse_dict_train.pkl', 'rb'))
        for icls, key in enumerate(novel_classes):
            if inverse_dict[key].size > 0:
                for v in inverse_dict[key][ishot * 20: ishot * 20 + nshot]:
                    novelset.append((v, icls))
                    #print((v, icls))

    if inputmix == 'both':
        seenset = []
        inverse_dict = pickle.load(open(datadir + '/split/inverse_dict_train.pkl', 'rb'))
        for icls, key in enumerate(seen_classes):
            if inverse_dict[key].size > 0:
                for v in inverse_dict[key][ishot * 20: ishot * 20 + nshot]:
                    seenset.append(v)

    if inputmix == 'seen':
        # Use the full training list when only seen classes are trained.
        seenset = range(train.shape[0])

    sampler = RandomImageSampler(seenset, novelset)

    # The nshot suffix ('n'/'s'/'b' for novel/seen/both) is only used below for checkpoint and log naming.
    if inputmix == 'novel':
        visible_classes = seen_novel_classes
        if nshot is not None:
            nshot = str(nshot) + 'n'
    elif inputmix == 'seen':
        visible_classes = seen_classes
        if nshot is not None:
            nshot = str(nshot)+'s'
    elif inputmix == 'both':
        visible_classes = seen_novel_classes
        if nshot is not None:
            nshot = str(nshot)+'b'

    
    print("Visible classes:", visible_classes.size, " \nClasses are: ", visible_classes, "\nTrain Images:", train.shape[0])

    # Dataset (e.g. COCO-Stuff 10k or 164k)
    dataset = get_dataset(CONFIG.DATASET)(train=train, test=None,
            root=CONFIG.ROOT,
            split=CONFIG.SPLIT.TRAIN,
            base_size=513,
            crop_size=CONFIG.IMAGE.SIZE.TRAIN,
            mean=(CONFIG.IMAGE.MEAN.B, CONFIG.IMAGE.MEAN.G, CONFIG.IMAGE.MEAN.R),
            warp=CONFIG.WARP_IMAGE,
            scale=(0.5, 1.5),
            flip=True,
            visibility_mask=visibility_mask
        )

    # DataLoader
    loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=CONFIG.BATCH_SIZE.TRAIN,
        num_workers=CONFIG.NUM_WORKERS,
        sampler=sampler
    )
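    # The sampler yields [image_index, mask_id] pairs; the dataset is expected to use the
    # mask_id to pick the corresponding entry of visibility_mask when relabelling targets.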
    
    # Class embeddings: word2vec, fasttext, or their concatenation ('fastnvec').
    if embedding == 'word2vec':
        class_emb = pickle.load(open(datadir + '/word_vectors/word2vec.pkl', "rb"))
    elif embedding == 'fasttext':
        class_emb = pickle.load(open(datadir + '/word_vectors/fasttext.pkl', "rb"))
    elif embedding == 'fastnvec':
        class_emb = np.concatenate([pickle.load(open(datadir + '/word_vectors/fasttext.pkl', "rb")),
                                    pickle.load(open(datadir + '/word_vectors/word2vec.pkl', "rb"))], axis=1)
    else:
        print("Invalid embedding:", embedding)
        sys.exit()

    print("Class embedding shape:", class_emb.shape)
    # L2-normalise each class embedding and move it to the selected device.
    class_emb = F.normalize(torch.tensor(class_emb), p=2, dim=1).to(device)

    loader_iter = iter(loader)
    DeepLab = DeepLabV2_ResNet101_MSC
    #import ipdb; ipdb.set_trace()
    # Pretrained weights used to initialise the network.
    state_dict = torch.load(CONFIG.INIT_MODEL)

    # Model load
    model = DeepLab(class_emb.shape[1], class_emb[visible_classes])
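    # The network is built with the embedding dimensionality and the embeddings of the
    # visible classes; presumably the final layer scores pixel features against these
    # class embeddings, so only the visible classes can be predicted at this stage.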
    if continue_from is not None and continue_from > 0:
        print("Loading checkpoint: {}".format(continue_from))
        #import ipdb; ipdb.set_trace()
        model = nn.DataParallel(model)
        state_file = osp.join(savepath, "checkpoint_{}.pth".format(continue_from))
        if osp.isfile(state_file + '.tar'):
            state_dict = torch.load(state_file + '.tar')
            model.load_state_dict(state_dict['state_dict'], strict=True)
        elif osp.isfile(state_file):
            state_dict = torch.load(state_file)
            model.load_state_dict(state_dict, strict=True)
        else:
            print("Checkpoint {} not found".format(continue_from))
            sys.exit()

    else:
        model.load_state_dict(state_dict, strict=False)  # set strict=True to verify that the pretrained weights load as expected if performance looks off
        model = nn.DataParallel(model)
    model.to(device)
    # Optimizer

    optimizer = {
        "sgd": torch.optim.SGD(
            # cf lr_mult and decay_mult in train.prototxt
            params=[
                {
                    "params": get_params(model.module, key="1x"),
                    "lr": CONFIG.LR,
                    "weight_decay": CONFIG.WEIGHT_DECAY,
                },
                {
                    "params": get_params(model.module, key="10x"),
                    "lr": 10 * CONFIG.LR,
                    "weight_decay": CONFIG.WEIGHT_DECAY,
                },
                {
                    "params": get_params(model.module, key="20x"),
                    "lr": 20 * CONFIG.LR,
                    "weight_decay": 0.0,
                }
            ],
            momentum=CONFIG.MOMENTUM,
        ),
        "adam": torch.optim.Adam(
            # cf lr_mult and decay_mult in train.prototxt
            params=[
                {
                    "params": get_params(model.module, key="1x"),
                    "lr": CONFIG.LR,
                    "weight_decay": CONFIG.WEIGHT_DECAY,
                },
                {
                    "params": get_params(model.module, key="10x"),
                    "lr": 10 * CONFIG.LR,
                    "weight_decay": CONFIG.WEIGHT_DECAY,
                },
                {
                    "params": get_params(model.module, key="20x"),
                    "lr": 20 * CONFIG.LR,
                    "weight_decay": 0.0,
                }
            ]
        )
        # Add any other optimizer
    }.get(CONFIG.OPTIMIZER)
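    # Both optimizers above are constructed eagerly and the unused one is discarded;
    # dict.get returns None for an unrecognised CONFIG.OPTIMIZER, which would fail at the
    # first optimizer call below.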

    # Restore the optimizer state when resuming from a full checkpoint.
    if 'optimizer' in state_dict:
        optimizer.load_state_dict(state_dict['optimizer'])
    print("Learning rate:", CONFIG.LR)
    # Loss definition
    criterion = nn.CrossEntropyLoss(ignore_index=-1)
    criterion.to(device)
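    # ignore_index=-1 matches the default value in seen_map / visibility_mask, so pixels
    # whose class is not visible under the current mask contribute nothing to the loss.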

    if not nolog:
        # TensorBoard Logger
        if continue_from is not None:
            writer = SummaryWriter(savepath+'/runs/fs_{}_{}_{}'.format(continue_from, nshot, ishot))
        else:
            writer = SummaryWriter(savepath+'/runs')
        loss_meter = MovingAverageValueMeter(20)

    model.train()
    model.module.scale.freeze_bn()
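    # freeze_bn presumably keeps the batch-norm statistics fixed during fine-tuning, as is
    # common for DeepLab training with small effective batch sizes.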

    pbar = tqdm(
        range(1, CONFIG.ITER_MAX + 1),
        total=CONFIG.ITER_MAX,
        leave=False,
        dynamic_ncols=True,
    )
    for iteration in pbar:

        # Set a learning rate
        poly_lr_scheduler(
            optimizer=optimizer,
            init_lr=CONFIG.LR,
            iter=iteration - 1,
            lr_decay_iter=CONFIG.LR_DECAY,
            max_iter=CONFIG.ITER_MAX,
            power=CONFIG.POLY_POWER,
        )

        # Clear gradients (ready to accumulate)
        optimizer.zero_grad()

        iter_loss = 0
        for i in range(1, CONFIG.ITER_SIZE + 1):
            try:
                data, target = next(loader_iter)
            except StopIteration:
                # Restart the loader once the sampler is exhausted.
                loader_iter = iter(loader)
                data, target = next(loader_iter)

            # Image
            data = data.to(device)

            # Propagate forward
            outputs = model(data)

            # Loss: sum over the multi-scale outputs
            loss = 0
            for output in outputs:
                # Resize target for {100%, 75%, 50%, Max} outputs
                target_ = resize_target(target, output.size(2))
                target_ = target_.to(device)
                loss += criterion(output, target_)

            # Backpropagate (just compute gradients wrt the loss)
            #print(loss)
            loss /= float(CONFIG.ITER_SIZE)
            loss.backward()

            iter_loss += float(loss)

        #print(iter_loss)
        pbar.set_postfix(loss="%.3f" % iter_loss)

        # Update weights with accumulated gradients
        optimizer.step()
        if not nolog:
            loss_meter.add(iter_loss)
            # TensorBoard
            if iteration % CONFIG.ITER_TB == 0:
                writer.add_scalar("train_loss", loss_meter.value()[0], iteration)
                for i, o in enumerate(optimizer.param_groups):
                    writer.add_scalar("train_lr_group{}".format(i), o["lr"], iteration)
                if False:  # This produces a large log file
                    for name, param in model.named_parameters():
                        name = name.replace(".", "/")
                        writer.add_histogram(name, param, iteration, bins="auto")
                        if param.requires_grad:
                            writer.add_histogram(
                                name + "/grad", param.grad, iteration, bins="auto"
                            )

            # Save a model
            if continue_from is not None:
                if iteration in CONFIG.ITER_SAVE:
                    torch.save(
                        {
                            'iteration': iteration,
                            'state_dict': model.state_dict(),
                        },
                        osp.join(savepath, "checkpoint_{}_{}_{}_{}.pth.tar".format(continue_from, nshot, ishot, iteration)),
                    )

                # Save a model (short term) [unnecessary for fewshot]
                if False and iteration % 100 == 0:
                    torch.save(
                        {
                            'iteration': iteration,
                            'state_dict': model.state_dict(),
                        },
                        osp.join(savepath, "checkpoint_{}_{}_{}_current.pth.tar".format(continue_from, nshot, ishot)),
                    )
                    print(osp.join(savepath, "checkpoint_{}_{}_{}_current.pth.tar".format(continue_from, nshot, ishot)))
            else:
                if iteration % CONFIG.ITER_SAVE == 0:
                    torch.save(
                        {
                            'iteration': iteration,
                            'state_dict': model.state_dict(),
                            'optimizer' : optimizer.state_dict(),
                        },
                        osp.join(savepath, "checkpoint_{}.pth.tar".format(iteration)),
                    )

                # Save a model (short term)
                if iteration % 100 == 0:
                    torch.save(
                        {
                            'iteration': iteration,
                            'state_dict': model.state_dict(),
                            'optimizer' : optimizer.state_dict(),
                        },
                        osp.join(savepath, "checkpoint_current.pth.tar"),
                    )

    if not nolog:
        if continue_from is not None:
            torch.save(
                {
                    'iteration': iteration,
                    'state_dict': model.state_dict(),
                },
                osp.join(savepath, "checkpoint_{}_{}_{}_{}.pth.tar".format(continue_from, nshot, ishot, CONFIG.ITER_MAX))
            )
        else:
            torch.save(
                {
                    'iteration': iteration,
                    'state_dict': model.state_dict(),
                    'optimizer' : optimizer.state_dict(),
                },
                osp.join(savepath, "checkpoint_{}.pth.tar".format(CONFIG.ITER_MAX))
            )

if __name__ == "__main__":
    main()