#!/usr/bin/env python
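"""Fully supervised training script for the first CapsuleNet acoustic model on the
ASpIRE dataset: parses options, runs the epoch loop with per-epoch validation and
checkpointing, and reports the final test accuracy."""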
import sys
import argparse
from pathlib import Path

import numpy as np
import torch

from ..utils.logger import logger, set_logfile
from ..utils.audio import AudioDataLoader
from ..utils import misc
from ..dataset.aspire import AspireDataset

from .model import FirstCapsuleNetModel


def parse_options(argv):
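    """Parse command-line options, set up file logging under --log-dir,
    and apply the CUDA and random-seed settings before returning the args."""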
    parser = argparse.ArgumentParser(description="First CapsuleNet AM with fully supervised training")
    # for training
    parser.add_argument('--data-path', default='data/aspire', type=str, help="dataset path to use in training")
    parser.add_argument('--num-workers', default=4, type=int, help="number of dataloader workers")
    parser.add_argument('--num-epochs', default=500, type=int, help="number of epochs to run")
    parser.add_argument('--batch-size', default=16, type=int, help="number of samples (and labels) per batch")
    parser.add_argument('--init-lr', default=0.0001, type=float, help="initial learning rate for Adam optimizer")
    parser.add_argument('--num-iterations', default=3, type=int, help="number of routing iterations")
    # optional
    parser.add_argument('--use-cuda', default=False, action='store_true', help="use cuda")
    parser.add_argument('--seed', default=None, type=int, help="seed for controlling randomness in this example")
    parser.add_argument('--log-dir', default='./logs', type=str, help="directory for log files and saved models")
    parser.add_argument('--model-prefix', default='capsule_aspire', type=str, help="prefix for saved model files")
    parser.add_argument('--continue-from', default=None, type=str, help="model file path to continue training from")

    args = parser.parse_args(argv)

    print(f"begins logging to file: {str(Path(args.log_dir).resolve() / 'train.log')}")
    set_logfile(Path(args.log_dir, "train.log"))

    logger.info(f"PyTorch version: {torch.__version__}")
    logger.info(f"Training started with command: {' '.join(sys.argv)}")
    args_str = [f"{k}={v}" for (k, v) in vars(args).items()]
    logger.info(f"args: {' '.join(args_str)}")

    if args.use_cuda:
        logger.info("using cuda")
        torch.set_default_tensor_type("torch.cuda.FloatTensor")

    if args.seed is not None:
        torch.manual_seed(args.seed)
        np.random.seed(args.seed)
        if args.use_cuda:
            torch.cuda.manual_seed(args.seed)

    return args


def train(argv):
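    """Build the model from the parsed options, train and validate it each epoch
    while checkpointing, then evaluate once on the test split."""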
    args = parse_options(argv)

    def get_model_file_path(desc):
        return misc.get_model_file_path(args.log_dir, args.model_prefix, desc)

    # build the model; all parsed options (batch size, learning rate,
    # routing iterations, continue-from path, ...) are forwarded as keyword arguments
    model = FirstCapsuleNetModel(**vars(args))

    # best validation (dev) accuracy seen across epochs
    best_valid_acc = 0.0

    # run training for the requested number of epochs,
    # resuming from model.epoch when continuing from a checkpoint
    for i in range(model.epoch, args.num_epochs):
        # cap the number of entries drawn from each dataset split every epoch
        sizes = { "train": 4000, "dev": 400 }

        # prepare data loaders
        datasets, data_loaders = dict(), dict()
        for mode in ["train", "dev"]:
            datasets[mode] = AspireDataset(root=args.data_path, mode=mode, data_size=sizes[mode])
            data_loaders[mode] = AudioDataLoader(datasets[mode], batch_size=args.batch_size,
                                                 num_workers=args.num_workers, shuffle=True,
                                                 use_cuda=args.use_cuda, pin_memory=True)
        # train an epoch
        model.train_epoch(data_loaders["train"])
        logger.info(f"epoch {model.epoch:03d}: "
                    f"training loss {model.meter_loss.value()[0]:5.3f} "
                    f"training accuracy {model.meter_accuracy.value()[0]:6.3f}")

        # validate
        model.test(data_loaders["dev"])
        logger.info(f"epoch {model.epoch:03d}: "
                    f"validating loss {model.meter_loss.value()[0]:5.3f} "
                    f"validating accuracy {model.meter_accuracy.value()[0]:6.3f}")

        # update the best validation accuracy seen so far and keep a separate
        # checkpoint of the model state whenever it improves
        if best_valid_acc < model.meter_accuracy.value()[0]:
            best_valid_acc = model.meter_accuracy.value()[0]
            model.save(get_model_file_path("best"))
        # save a checkpoint for this epoch
        model.save(get_model_file_path(f"epoch_{model.epoch:04d}"))
        # increase epoch num
        model.epoch += 1

    # final evaluation on the test split (assumes AspireDataset also exposes a "test" mode)
    test_dataset = AspireDataset(root=args.data_path, mode="test")
    test_loader = AudioDataLoader(test_dataset, batch_size=args.batch_size,
                                  num_workers=args.num_workers, shuffle=False,
                                  use_cuda=args.use_cuda, pin_memory=True)
    model.test(test_loader)

    logger.info(f"best validation accuracy {best_valid_acc:6.3f} "
                f"test accuracy {model.meter_accuracy.value()[0]:6.3f}")

    # save the final model
    model.save(get_model_file_path("final"))


if __name__ == "__main__":
    train(sys.argv[1:])