import argparse

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn

from models import modules, net, resnet, densenet, senet
import net_mask
import loaddata
import util
import numpy as np

import os
import matplotlib
import matplotlib.image
matplotlib.rcParams['image.cmap'] = 'viridis'

import pdb

# Command-line interface of the script. main() parses these options into the
# module-level `args` namespace.
parser = argparse.ArgumentParser(description='single depth estimation')

# Epoch bookkeeping.
parser.add_argument(
    '--epochs',
    type=int,
    default=60,
    help='number of total epochs to run',
)
parser.add_argument(
    '--start-epoch',
    type=int,
    default=0,
    help='manual epoch number (useful on restarts)',
)

# Optimiser hyper-parameters; some flags also accept an alias spelling.
parser.add_argument(
    '--lr', '--learning-rate',
    type=float,
    default=0.0001,
    help='initial learning rate',
)
parser.add_argument(
    '--momentum',
    type=float,
    default=0.9,
    help='momentum',
)
parser.add_argument(
    '--weight-decay', '--wd',
    type=float,
    default=1e-4,
    help='weight decay (default: 1e-4)',
)

# Free-form label for the experiment.
parser.add_argument(
    '--name',
    type=str,
    default='train2_2',
    help='name of experiment',
)

def define_model(encoder='resnet'):
    """Build the depth-estimation network on top of the selected encoder.

    Args:
        encoder: Name of the backbone: 'resnet', 'densenet' or 'senet'.

    Returns:
        The object produced by ``net.model`` for the chosen encoder.

    Raises:
        ValueError: If ``encoder`` is not one of the supported names.
    """
    # Compare by value. The previous `is` checks tested object identity, which
    # only matched thanks to string-literal interning and silently failed for
    # equal strings built at runtime (e.g. read from the command line).
    if encoder == 'resnet':
        original_model = resnet.resnet50(pretrained=True)
        Encoder = modules.E_resnet(original_model)
        model = net.model(Encoder, num_features=2048,
                          block_channel=[256, 512, 1024, 2048])
    elif encoder == 'densenet':
        original_model = densenet.densenet161(pretrained=True)
        Encoder = modules.E_densenet(original_model)
        model = net.model(Encoder, num_features=2208,
                          block_channel=[192, 384, 1056, 2208])
    elif encoder == 'senet':
        original_model = senet.senet154(pretrained='imagenet')
        Encoder = modules.E_senet(original_model)
        model = net.model(Encoder, num_features=2048,
                          block_channel=[256, 512, 1024, 2048])
    else:
        # Previously an unknown name fell through to `return model` and
        # surfaced as an unhelpful UnboundLocalError.
        raise ValueError(
            "unknown encoder %r; expected 'resnet', 'densenet' or 'senet'"
            % (encoder,))

    return model
   

def main():
    """Parse the CLI, load both networks from disk and run the evaluation.

    The depth network is built by ``define_model`` and the mask network by
    ``net_mask.AutoED``; both are wrapped in ``DataParallel``, moved to the
    GPU, and restored from checkpoint files before ``test`` is called on the
    testing loader.
    """
    global args
    args = parser.parse_args()

    # The backbone is fixed here; it also selects the checkpoint file names.
    encoder_name = 'resnet'

    depth_net = define_model(encoder=encoder_name)
    mask_net = net_mask.AutoED(net_mask.drn_d_22(pretrained=True))

    depth_net = torch.nn.DataParallel(depth_net).cuda()
    mask_net = torch.nn.DataParallel(mask_net).cuda()

    # Checkpoints are loaded after the DataParallel wrap, as in the order
    # of operations this script has always used.
    depth_weights = torch.load('./pretrained_model/model_' + encoder_name)
    depth_net.load_state_dict(depth_weights)
    mask_weights = torch.load('./net_mask/mask_' + encoder_name)
    mask_net.load_state_dict(mask_weights)

    # The loader is requested with argument 1 (presumably the batch size --
    # confirm against loaddata.getTestingData).
    loader = loaddata.getTestingData(1)
    test(loader, depth_net, mask_net, 'mask_' + encoder_name)

 



def test(train_loader, model, model2, dir):
    """Evaluate the depth network on mask-filtered inputs and print the RMSE.

    For each batch, ``model2`` produces a mask from the image, the image is
    multiplied by that mask, and ``model`` predicts depth from the product.
    Per-batch errors against the ground-truth depth are accumulated with the
    ``util`` helpers, and the square root of the averaged 'MSE' entry is
    printed at the end.

    Args:
        train_loader: Iterable of batches, each a mapping with 'image' and
            'depth' entries. Despite the name, main() passes the testing
            loader here.
        model: Depth-estimation network; switched to eval mode.
        model2: Mask-prediction network; switched to eval mode.
        dir: Output directory name for mask images. Currently unused because
            the image-saving code below is commented out.

    Raises:
        ValueError: If the loader yields no samples, so no average exists.
    """
    totalNumber = 0
    errorSum = {'MSE': 0, 'RMSE': 0, 'ABS_REL': 0, 'LG10': 0,
                'MAE': 0,  'DELTA1': 0, 'DELTA2': 0, 'DELTA3': 0}
    model.eval()
    model2.eval()

    # if not os.path.exists(dir):
    #     os.mkdir(dir)

    # torch.no_grad() replaces the removed `Variable(..., volatile=True)`
    # idiom: no autograd graph is built during evaluation.
    with torch.no_grad():
        for i, sample_batched in enumerate(train_loader):
            image, depth_ = sample_batched['image'], sample_batched['depth']

            image = image.cuda()
            # `async=True` is a SyntaxError on Python 3.7+ (`async` became a
            # keyword); `non_blocking=True` is the replacement argument.
            depth_ = depth_.cuda(non_blocking=True)

            mask = model2(image)
            output = model(image * mask)

            # The batch size is read from the prediction itself; the former
            # extra `model(image)` forward pass existed only to obtain it.
            batchSize = output.size(0)
            errors = util.evaluateError(output, depth_)
            errorSum = util.addErrors(errorSum, errors, batchSize)
            totalNumber = totalNumber + batchSize

            # mask = mask.squeeze().view(228,304).data.cpu().float().numpy()
            # matplotlib.image.imsave(dir+'/mask'+str(i)+'.png', mask)

    if totalNumber == 0:
        # Previously this case reached the print with `averageError` unbound.
        raise ValueError('test loader produced no samples; nothing to evaluate')

    # Only the final average is reported, so compute it once after the loop
    # (assumes util.averageErrors does not mutate its arguments).
    averageError = util.averageErrors(errorSum, totalNumber)
    print('rmse:', np.sqrt(averageError['MSE']))

 



# Run the evaluation only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == '__main__':
    main()