#!/usr/bin/env python

from __future__ import print_function
import argparse
import os

import chainer
from chainer import training
from chainer.training import extensions

import dataset
from models.vgg16 import VGG16
from models.generators import FCN32s, FCN16s, FCN8s 
from models.discriminators import (
    LargeFOV, LargeFOVLight, SmallFOV, SmallFOVLight, SPPDiscriminator)
from updater import GANUpdater, NonAdversarialUpdater
from extensions import TestModeEvaluator
import utils


def parse_args(generators, discriminators, updaters, argv=None):
    """Parse command-line options for the adversarial segmentation trainer.

    Args:
        generators (dict): Generator name -> config; keys become the
            valid ``--generator`` choices.
        discriminators (dict): Discriminator name -> config; keys become
            the valid ``--discriminator`` choices.
        updaters (dict): Updater name -> updater class; keys become the
            valid ``--updater`` choices.
        argv (list or None): Argument list to parse instead of
            ``sys.argv[1:]``.  ``None`` preserves the original behavior
            of reading the process arguments; passing a list makes the
            function testable and embeddable.

    Returns:
        argparse.Namespace: The parsed arguments.
    """
    parser = argparse.ArgumentParser(description='Semantic Segmentation using Adversarial Networks')
    parser.add_argument('--generator', choices=generators.keys(), default='fcn32s',
                        help='Generator(segmentor) architecture')
    parser.add_argument('--discriminator', choices=discriminators.keys(), default='largefov',
                        help='Discriminator architecture')
    parser.add_argument('--updater', choices=updaters.keys(), default='gan',
                        help='Updater')
    parser.add_argument('--initgen_path', default='pretrained_model/vgg16.npz',
                        help='Pretrained model of generator')
    parser.add_argument('--initdis_path', default=None,
                        help='Pretrained model of discriminator')
    parser.add_argument('--batchsize', '-b', type=int, default=1,
                        help='Number of images in each mini-batch')
    parser.add_argument('--iteration', '-i', type=int, default=100000,
                        help='Number of sweeps over the dataset to train')
    parser.add_argument('--gpu', '-g', type=int, default=-1,
                        help='GPU ID (negative value indicates CPU)')
    parser.add_argument('--out', '-o', default='snapshot',
                        help='Directory to output the result')
    parser.add_argument('--resume', '-r', default='',
                        help='Resume the training from snapshot')
    parser.add_argument('--evaluate_interval', type=int, default=1000,
                        help='Interval of evaluation')
    parser.add_argument('--snapshot_interval', type=int, default=10000,
                        help='Interval of snapshot')
    parser.add_argument('--display_interval', type=int, default=10,
                        help='Interval of displaying log to console')
    return parser.parse_args(argv)

def load_pretrained_model(initmodel_path, initmodel, model, n_class, device):
    """Load pretrained weights from ``initmodel_path`` into ``model``.

    The npz file is first deserialized into ``initmodel`` (whose
    architecture matches the saved weights), then the matching layers are
    copied over into the target ``model``.

    Note: ``n_class`` and ``device`` are accepted for call-site
    compatibility but are not used inside this function.

    Returns the target ``model`` with the copied weights.
    """
    print('Initializing the model')
    source = initmodel
    chainer.serializers.load_npz(initmodel_path, source)
    utils.copy_chainermodel(source, model)
    return model

def make_optimizer(model, lr=1e-10, momentum=0.99):
    """Create a MomentumSGD optimizer for ``model`` with weight decay.

    Args:
        model: The chainer model (link) to optimize.
        lr (float): Learning rate for MomentumSGD.
        momentum (float): Momentum coefficient.

    Returns:
        The configured optimizer, already set up on ``model``.
    """
    opt = chainer.optimizers.MomentumSGD(lr=lr, momentum=momentum)
    opt.setup(model)
    # Attach L2 regularization (weight decay) to every update step.
    weight_decay = chainer.optimizer.WeightDecay(0.0005)
    opt.add_hook(weight_decay, 'hook_dec')
    return opt

def main():
    """Build dataset, models, optimizers and trainer from CLI args, then train."""
    # Registry of generator (segmentor) architectures:
    # CLI name -> (model class, class used for weight initialization, learning rate).
    generators = {
        'fcn32s': (FCN32s, VGG16, 1e-10), # (model, initmodel, learning_rate)
        'fcn16s': (FCN16s, FCN32s, 1e-12),
        'fcn8s': (FCN8s, FCN16s, 1e-14),
    }
    # Registry of discriminator architectures:
    # CLI name -> (model class, init class, learning rate, BCE loss weight).
    discriminators = {
        'largefov': (LargeFOV, LargeFOV, 0.1, 1.0), # (model, initmodel, learning_rate, L_bce_weight)
        'largefov-light': (LargeFOVLight, LargeFOVLight, 0.1, 1.0),
        'smallfov': (SmallFOV, SmallFOV, 0.1, 0.1),
        'smallfov-light': (SmallFOVLight, SmallFOVLight, 0.2, 1.0),
        'sppdis': (SPPDiscriminator, SPPDiscriminator, 0.1, 1.0),
    }
    # Available training schemes: adversarial (GAN) or plain segmentation.
    updaters = {
        'gan': GANUpdater,
        'standard': NonAdversarialUpdater
    }

    args = parse_args(generators, discriminators, updaters)

    print('GPU: {}'.format(args.gpu))
    print('# Minibatch-size: {}'.format(args.batchsize))
    print('# iteration: {}'.format(args.iteration))

    # dataset
    train = dataset.PascalVOC2012Dataset('train')
    val = dataset.PascalVOC2012Dataset('val')
    # Number of semantic classes, derived from the dataset's label set.
    n_class = len(train.label_names)

    train_iter = chainer.iterators.SerialIterator(train, args.batchsize)
    # Validation iterates the set exactly once per evaluation, in order.
    val_iter = chainer.iterators.SerialIterator(val, args.batchsize, repeat=False, shuffle=False)

    # Set up a neural network to train and an optimizer
    if args.updater=='gan':
        gen_cls, initgen_cls, gen_lr = generators[args.generator]
        dis_cls, initdis_cls, dis_lr, L_bce_weight = discriminators[args.discriminator]
        print('# generator: {}'.format(gen_cls.__name__))
        print('# discriminator: {}'.format(dis_cls.__name__))
        print('')

        # Initialize generator
        if args.initgen_path:
            # Load pretrained weights via the init-model class, then copy
            # the matching layers into the freshly built generator.
            gen, initgen = gen_cls(n_class), initgen_cls(n_class)
            gen = load_pretrained_model(args.initgen_path, initgen, gen, n_class, args.gpu)
        else:
            gen = gen_cls(n_class)
        # Initialize discriminator
        if args.initdis_path:
            dis, initdis = dis_cls(n_class), initdis_cls(n_class)
            dis = load_pretrained_model(args.initdis_path, initdis, dis, n_class, args.gpu)
        else:
            dis = dis_cls(n_class)
        if args.gpu >= 0:
            chainer.cuda.get_device(args.gpu).use()  # Make a specified GPU current
            gen.to_gpu()  # Copy the model to the GPU
            dis.to_gpu()
        opt_gen = make_optimizer(gen, gen_lr)
        opt_dis = make_optimizer(dis, dis_lr)
        # GANUpdater takes both networks/optimizers as dicts keyed by role.
        model={'gen':gen,'dis':dis}
        optimizer={'gen': opt_gen, 'dis': opt_dis}
    elif args.updater=='standard':
        # Non-adversarial training uses only the generator registry.
        model_cls, initmodel_cls, lr = generators[args.generator]
        L_bce_weight = None  # no adversarial BCE term in this mode
        print('# model: {}'.format(model_cls.__name__))
        print('')
        if args.initgen_path:
            model, initmodel = model_cls(n_class), initmodel_cls(n_class)
            model = load_pretrained_model(args.initgen_path, initmodel, model, n_class, args.gpu)
        else:
            model = model_cls(n_class)
        if args.gpu >= 0:
            chainer.cuda.get_device(args.gpu).use()  # Make a specified GPU current
            model.to_gpu()  # Copy the model to the GPU
        optimizer = make_optimizer(model, lr)

    # Set up a trainer
    updater = updaters[args.updater](
        model=model,
        iterator=train_iter,
        optimizer=optimizer,
        device=args.gpu,
        L_bce_weight=L_bce_weight,
        n_class=n_class,)

    trainer = training.Trainer(updater, (args.iteration, 'iteration'), out=args.out)

    # All triggers are iteration-based (not epoch-based).
    evaluate_interval = (args.evaluate_interval, 'iteration') 
    snapshot_interval = (args.snapshot_interval, 'iteration')
    display_interval = (args.display_interval, 'iteration')

    # NOTE(review): the updater itself is passed as the evaluation target;
    # presumably TestModeEvaluator pulls the model(s) out of it — confirm
    # against extensions.py.  Also note it fires on snapshot_interval, not
    # evaluate_interval as the flag name suggests.
    trainer.extend(
        TestModeEvaluator(
            val_iter, updater, device=args.gpu),
        trigger=snapshot_interval,
        invoke_before_training=False)
    # Full-trainer snapshot (enables --resume).
    trainer.extend(
        extensions.snapshot(filename='snapshot_iter_{.updater.iteration}.npz'),
        trigger=snapshot_interval)

    if args.updater=='gan':
        # Save generator and discriminator weights separately for reuse.
        trainer.extend(extensions.snapshot_object(
            gen, 'gen_iter_{.updater.iteration}.npz'), trigger=snapshot_interval)
        trainer.extend(extensions.snapshot_object(
            dis, 'dis_iter_{.updater.iteration}.npz'), trigger=snapshot_interval)
        trainer.extend(extensions.LogReport(trigger=display_interval))
        trainer.extend(extensions.PrintReport([
            'iteration',
            'gen/loss', 'validation/gen/loss',
            'dis/loss',
            'gen/accuracy', 'validation/gen/accuracy',
            'gen/iu', 'validation/gen/iu',
            'elapsed_time',
        ]), trigger=display_interval)
    elif args.updater=='standard':
        trainer.extend(extensions.snapshot_object(
            model, 'model_iter_{.updater.iteration}.npz'), trigger=snapshot_interval)
        trainer.extend(extensions.LogReport(trigger=display_interval))
        trainer.extend(extensions.PrintReport([
            'iteration',
            'main/loss', 'validation/main/loss',
            'main/accuracy', 'validation/main/accuracy',
            'main/iu', 'validation/main/iu',
            'elapsed_time',
        ]), trigger=display_interval)

    trainer.extend(extensions.ProgressBar(update_interval=1))

    if args.resume:
        # Resume from a snapshot
        chainer.serializers.load_npz(args.resume, trainer)

    print('\nRun the training')
    trainer.run()

# Run training only when executed as a script, not when imported.
if __name__ == '__main__':
    main()