import os

import argparse

from torch.backends import cudnn
from config import config, dataset_config, merge_cfg_arg

from dataloder import get_loader
from solver_cycle import Solver_cycleGAN
from solver_makeup import Solver_makeupGAN

def parse_args(argv=None):
    """Declare the training command-line options and parse them.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, which is
            what the script entry point relies on.

    Returns:
        argparse.Namespace: one attribute per option declared below. The
        ``--no_*`` switches write to the same attributes as their positive
        counterparts, so they add no extra attributes.
    """
    parser = argparse.ArgumentParser(description='Train GAN')
    # general
    parser.add_argument('--data_path', default='makeup/makeup_final/', type=str, help='training and test data path')
    parser.add_argument('--dataset', default='MAKEUP', type=str, help='dataset name, MAKEUP means two domain, MMAKEUP means multi-domain')
    parser.add_argument('--gpus', default='0', type=str, help='GPU device to train with')
    parser.add_argument('--batch_size', default=1, type=int, help='batch_size')
    parser.add_argument('--vis_step', default=1260, type=int, help='steps between visualization')
    parser.add_argument('--task_name', default='', type=str, help='task name')
    parser.add_argument('--checkpoint', default='', type=str, help='checkpoint to load')
    parser.add_argument('--ndis', default=1, type=int, help='train discriminator steps')
    parser.add_argument('--LR', default=2e-4, type=float, help='Learning rate')
    # NOTE(review): this help text duplicates the meaning of --epochs; confirm what --decay controls.
    parser.add_argument('--decay', default=0, type=int, help='epochs number for training')
    parser.add_argument('--model', default='makeupGAN', type=str, help='which model to use: cycleGAN/ makeupGAN')
    parser.add_argument('--epochs', default=300, type=int, help='nums of epochs')
    parser.add_argument('--whichG', default='branch', type=str, help='which Generator to choose, normal/branch, branch means two input branches')
    parser.add_argument('--norm', default='SN', type=str, help='normalization of discriminator, SN means spectrum normalization, none means no normalization')
    parser.add_argument('--d_repeat', default=3, type=int, help='the repeat Res-block in discriminator')
    parser.add_argument('--g_repeat', default=6, type=int, help='the repeat Res-block in Generator')
    parser.add_argument('--lambda_cls', default=1.0, type=float, help='the lambda_cls weight')
    parser.add_argument('--lambda_rec', default=10, type=int, help='lambda_A and lambda_B')
    parser.add_argument('--lambda_his', default=1.0, type=float, help='histogram loss on lips')
    parser.add_argument('--lambda_skin_1', default=0.1, type=float, help='histogram loss on skin equals to lambda_his* lambda_skin')
    parser.add_argument('--lambda_skin_2', default=0.1, type=float, help='histogram loss on skin equals to lambda_his* lambda_skin')
    parser.add_argument('--lambda_eye', default=1.0, type=float, help='histogram loss on eyes equals to lambda_his*lambda_eye')
    parser.add_argument('--content_layer', default='r41', type=str, help='vgg layer we use to output features')
    parser.add_argument('--lambda_vgg', default=5e-3, type=float, help='the param of vgg loss')
    parser.add_argument('--cls_list', default='SYMIX,MAKEMIX', type=str, help='the classes of makeup to train')

    # These switches default to True, so the positive flag on its own changes
    # nothing. Each one gets a --no_<name> counterpart that stores False into
    # the same destination, which is the only way to turn the feature off.
    toggles = (
        ('direct', 'direct means to add local cosmetic loss at the first, unified training'),
        ('lips', 'whether to finetune lips color'),
        ('skin', 'whether to finetune foundation color'),
        ('eye', 'whether to finetune eye shadow color'),
    )
    for name, help_text in toggles:
        parser.add_argument('--' + name, dest=name, action='store_true', default=True, help=help_text)
        parser.add_argument('--no_' + name, dest=name, action='store_false', help='disable --' + name)

    return parser.parse_args(argv)

def train_net():
    """Build the data loaders and the selected solver, then run training.

    Reads the module-level ``args``, ``config`` and ``dataset_config`` that
    the script entry point sets up before calling this function.

    Raises:
        SystemExit: with status 1 when ``args.model`` is neither
            ``'cycleGAN'`` nor ``'makeupGAN'``.
    """
    # enable cudnn
    cudnn.benchmark = True

    # Resolve the solver class before loading data, so an unsupported model
    # name fails immediately instead of after the loaders are built.
    solver_classes = {
        'cycleGAN': Solver_cycleGAN,
        'makeupGAN': Solver_makeupGAN,
    }
    solver_cls = solver_classes.get(args.model)
    if solver_cls is None:
        print("model that not support")
        # Exit with a non-zero status so callers can detect the failure; the
        # site-provided exit() helper would report success (status 0).
        raise SystemExit(1)

    # Original note said this returns train & test loaders; not verifiable from this file.
    data_loaders = get_loader(dataset_config, config, mode="train")
    solver = solver_cls(data_loaders, config, dataset_config)
    solver.train()



if __name__ == '__main__':
    # Script entry point: parse the CLI, fold it into the shared config,
    # validate the data location, then start training.
    args = parse_args()
    print("Call with args:")
    print(args)
    config = merge_cfg_arg(config, args)

    dataset_config.name = args.dataset

    print("The config is:")
    print(config)

    # Abort when the configured data path does not exist. Nothing is created
    # here. Exit with a non-zero status so the failure is visible to callers.
    if not os.path.exists(config.data_path):
        print("No datapath!!")
        raise SystemExit(1)

    # NOTE(review): if merge_cfg_arg copies args.data_path into
    # config.data_path, this join repeats the same relative path twice;
    # confirm against merge_cfg_arg.
    if args.data_path:
        dataset_config.dataset_path = os.path.join(config.data_path, args.data_path)

    train_net()