import argparse
import os
import torch
import torch.nn as nn
from copy import deepcopy
from experiment.engine import MultiLabelMAPEngine
from experiment.models import vgg16_sp
from experiment.voc import Voc2007Classification

# Command-line interface for the VOC2007 training script.
# Help strings use argparse's ``%(default)s`` placeholder so the advertised
# defaults are always rendered from the real ``default=`` values (they had
# drifted: workers, batch size and weight decay all documented wrong values).
parser = argparse.ArgumentParser(description='Model Training')
parser.add_argument('data', metavar='DIR',
                    help='path to dataset (e.g. ../data/)')
# NOTE(review): image size is parsed as a *string* ('224'); confirm the
# engine expects a str here rather than an int.
parser.add_argument('--image-size', '-i', default='224', type=str,
                    metavar='N', help='image size (default: %(default)s)')
parser.add_argument('-j', '--workers', default=8, type=int, metavar='N',
                    help='number of data loading workers (default: %(default)s)')
parser.add_argument('--epochs', default=20, type=int, metavar='N',
                    help='number of total epochs to run (default: %(default)s)')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('-b', '--batch-size', default=16, type=int,
                    metavar='N', help='mini-batch size (default: %(default)s)')
parser.add_argument('--lr', '--learning-rate', default=0.01, type=float,
                    metavar='LR', help='initial learning rate (default: %(default)s)')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                    help='momentum (default: %(default)s)')
parser.add_argument('--weight-decay', '--wd', default=0.0005, type=float,
                    metavar='W', help='weight decay (default: %(default)s)')
parser.add_argument('--resume', default=None, type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                    help='evaluate model on validation set')
 
def main_voc2007():
    """Train (or, with ``--evaluate``, evaluate) a multi-label model on VOC 2007.

    Parses the command line with the module-level ``parser`` and builds the
    'trainval' and 'test' splits with ``Voc2007Classification``. It then
    creates ``vgg16_sp(num_classes, pretrained=True)`` and prints it. Finally
    it runs ``MultiLabelMAPEngine.multi_learning`` with a
    ``MultiLabelSoftMarginLoss`` criterion.

    Side effects: binds the module globals ``args`` and ``use_gpu``. It passes
    ``'logs/voc2007/'`` to the engine as ``save_model_path``.
    """
    global args, use_gpu
    args = parser.parse_args()

    use_gpu = torch.cuda.is_available()

    # define dataset
    train_dataset = Voc2007Classification(args.data, 'trainval')
    val_dataset = Voc2007Classification(args.data, 'test')
    num_classes = 20  # PASCAL VOC has 20 object categories

    # load model
    model = vgg16_sp(num_classes, pretrained=True)

    print(model)

    criterion = nn.MultiLabelSoftMarginLoss()

    state = {
        'batch_size': args.batch_size,
        'max_epochs': args.epochs,
        'image_size': args.image_size,
        'evaluate': args.evaluate,
        'resume': args.resume,
        'lr': args.lr,
        'momentum': args.momentum,
        'weight_decay': args.weight_decay,
        # NOTE(review): -j/--workers and --start-epoch were parsed but never
        # forwarded, so the flags had no effect; use_gpu was computed but
        # unused. They are passed under the key names the engine presumably
        # reads. Confirm against MultiLabelMAPEngine.
        'workers': args.workers,
        'start_epoch': args.start_epoch,
        'use_gpu': use_gpu,
    }
    # NOTE(review): presumably tells the engine how to score VOC objects
    # marked "difficult" during mAP evaluation. Confirm in the engine.
    state['difficult_examples'] = True
    state['save_model_path'] = 'logs/voc2007/'

    engine = MultiLabelMAPEngine(state)
    engine.multi_learning(model, criterion, train_dataset, val_dataset)


if __name__ == '__main__':
    main_voc2007()