# -*- coding: utf-8 -*-

import argparse
import os
import time
import torch
from torch.autograd import Variable as V
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torchvision.datasets as dset
import torchvision.transforms as trn
import torchvision.models as models
import torch.utils.model_zoo as model_zoo
import numpy as np
from resnext_50_32x4d import resnext_50_32x4d
from resnext_101_32x4d import resnext_101_32x4d
from resnext_101_64x4d import resnext_101_64x4d
from densenet_cosine_264_k48 import densenet_cosine_264_k48
from condensenet_converted import CondenseNet

parser = argparse.ArgumentParser(description='Evaluates robustness of various nets on ImageNet',
                                 formatter_class=argparse.ArgumentDefaultsHelpFormatter)
# Architecture
# Required: there is no sensible default network, and the model-setup code
# cannot pick one when this is omitted.
parser.add_argument('--model-name', '-m', type=str, required=True,
                    choices=['alexnet', 'squeezenet1.0', 'squeezenet1.1', 'condensenet4', 'condensenet8',
                             'vgg11', 'vgg', 'vggbn',
                             'densenet121', 'densenet169', 'densenet201', 'densenet161', 'densenet264',
                             'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152',
                             'resnext50', 'resnext101', 'resnext101_64'],
                    help='network to evaluate ("vgg" = VGG-19, "vggbn" = VGG-19 with batch norm).')
# Acceleration
parser.add_argument('--ngpu', type=int, default=1, help='0 = CPU.')
args = parser.parse_args()
print(args)

# /////////////// Model Setup ///////////////

# Cache directory for downloaded model-zoo weights and locally converted checkpoints.
MODEL_DIR = '/share/data/lang/users/dan/.torch/models'
# Base URL of the official PyTorch model zoo.
ZOO_URL = 'https://download.pytorch.org/models/'

# Networks whose weights come from the model zoo:
#   name -> (zero-argument constructor, checkpoint filename, evaluation batch size).
# Names are matched exactly (not by substring) so e.g. 'vgg11' really maps to VGG-11.
ZOO_MODELS = {
    'alexnet': (models.AlexNet, 'alexnet-owt-4df8aa71.pth', 256),
    'squeezenet1.0': (lambda: models.SqueezeNet(version=1.0), 'squeezenet1_0-a815701f.pth', 256),
    'squeezenet1.1': (lambda: models.SqueezeNet(version=1.1), 'squeezenet1_1-f364aa15.pth', 256),
    'vgg11': (models.vgg11, 'vgg11-bbd30ac9.pth', 64),
    'vgg': (models.vgg19, 'vgg19-dcbb9e9d.pth', 64),
    'vggbn': (models.vgg19_bn, 'vgg19_bn-c79401a0.pth', 64),
    'densenet121': (models.densenet121, 'densenet121-a639ec97.pth', 128),
    'densenet169': (models.densenet169, 'densenet169-6f0f7f60.pth', 128),
    'densenet201': (models.densenet201, 'densenet201-c1103571.pth', 64),
    'densenet161': (models.densenet161, 'densenet161-8d451a50.pth', 64),
    # densenet_cosine_264_k48 is an already-built network object, not a
    # constructor, so wrap it to fit the zero-argument-constructor slot.
    'densenet264': (lambda: densenet_cosine_264_k48, 'densenet_cosine_264_k48.pth', 64),
    'resnet18': (models.resnet18, 'resnet18-5c106cde.pth', 256),
    'resnet34': (models.resnet34, 'resnet34-333f7ec4.pth', 128),
    'resnet50': (models.resnet50, 'resnet50-19c8e357.pth', 128),
    'resnet101': (models.resnet101, 'resnet101-5d3b4d8f.pth', 128),
    'resnet152': (models.resnet152, 'resnet152-b121ed2d.pth', 64),
}

# Converted ResNeXt networks whose checkpoints live directly in MODEL_DIR:
#   name -> (network object, checkpoint filename, evaluation batch size).
LOCAL_MODELS = {
    'resnext50': (resnext_50_32x4d, 'resnext_50_32x4d.pth', 64),
    'resnext101': (resnext_101_32x4d, 'resnext_101_32x4d.pth', 64),
    'resnext101_64': (resnext_101_64x4d, 'resnext_101_64x4d.pth', 64),
}


def strip_module_prefix(state_dict, prefix='module.'):
    """Remove ``prefix`` from the start of every key of ``state_dict``, in place.

    Keys are popped from the front (``popitem(False)``, so an OrderedDict is
    required) and re-inserted at the back; one pass over the original length
    therefore visits every original key exactly once and keeps key order.
    The same dict object is returned, so any attributes attached to it are
    preserved. Keys that do not start with ``prefix`` are left unchanged.
    """
    for _ in range(len(state_dict)):
        name, value = state_dict.popitem(False)
        if name.startswith(prefix):
            name = name[len(prefix):]
        state_dict[name] = value
    return state_dict


if args.model_name in ZOO_MODELS:
    build, filename, args.test_bs = ZOO_MODELS[args.model_name]
    net = build()
    net.load_state_dict(model_zoo.load_url(ZOO_URL + filename, model_dir=MODEL_DIR))

elif args.model_name in ('condensenet4', 'condensenet8'):
    # The two CondenseNet variants differ only in their group / condense factor.
    factor = int(args.model_name[len('condensenet'):])
    # CondenseNet is configured through an args-like object, so fill in its fields.
    args.evaluate = True
    args.stages = [4, 6, 8, 10, 8]
    args.growth = [8, 16, 32, 64, 128]
    args.data = 'imagenet'
    args.num_classes = 1000
    args.bottleneck = 4
    args.group_1x1 = factor
    args.group_3x3 = factor
    args.reduction = 0.5
    args.condense_factor = factor
    net = CondenseNet(args)
    state_dict = torch.load('./converted_condensenet_{}.pth'.format(factor))['state_dict']
    net.load_state_dict(strip_module_prefix(state_dict))
    args.test_bs = 256

elif args.model_name in LOCAL_MODELS:
    net, filename, args.test_bs = LOCAL_MODELS[args.model_name]
    net.load_state_dict(torch.load(os.path.join(MODEL_DIR, filename)))

else:
    # argparse `choices` rejects unknown names, so this is reached only when
    # --model-name was not supplied; fail with a usage message instead of a
    # NameError on `net` further down.
    parser.error('unknown or missing --model-name: {!r}'.format(args.model_name))

# Number of DataLoader worker processes used by the loaders below.
args.prefetch = 4

# Inference only: freeze every weight so autograd does not track them.
# (Assigning `p.volatile = True` has no effect on PyTorch >= 0.4, where the
# volatile flag was removed; `requires_grad = False` works on old and new versions.)
for p in net.parameters():
    p.requires_grad = False

if args.ngpu > 1:
    net = torch.nn.DataParallel(net, device_ids=list(range(args.ngpu)))

if args.ngpu > 0:
    net.cuda()

# Fixed seeds for reproducibility.
torch.manual_seed(1)
np.random.seed(1)
if args.ngpu > 0:
    torch.cuda.manual_seed(1)

net.eval()
cudnn.benchmark = True  # fire on all cylinders

print('Model Loaded')

# /////////////// Data Loader ///////////////

# Per-channel ImageNet statistics used by trn.Normalize (show_performance
# normalizes the distorted images with the same values).
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

# Clean validation pipeline: resize to 256, center-crop 224, tensorize, normalize.
_clean_transform = trn.Compose([
    trn.Resize(256),
    trn.CenterCrop(224),
    trn.ToTensor(),
    trn.Normalize(mean, std),
])
_clean_dataset = dset.ImageFolder(
    root="/share/data/vision-greg/ImageNet/clsloc/images/val",
    transform=_clean_transform,
)
clean_loader = torch.utils.data.DataLoader(
    _clean_dataset,
    batch_size=args.test_bs,
    shuffle=False,
    num_workers=args.prefetch,
    pin_memory=True,
)


# /////////////// Further Setup ///////////////

def auc(errs):  # area under the distortion-error curve
    """Return the normalized area under a piecewise-linear error curve.

    The values in ``errs`` are treated as equally spaced samples. The
    trapezoid rule is applied between each pair of neighbours and the total
    is divided by the number of intervals, so a constant curve
    ``[e, e, ..., e]`` yields ``e``.

    Args:
        errs: sliceable sequence of at least two numeric error values.

    Returns:
        The mean trapezoid height over the ``len(errs) - 1`` intervals.

    Raises:
        ValueError: if ``errs`` has fewer than two points (the area is undefined).
    """
    if len(errs) < 2:
        raise ValueError('auc needs at least two points, got {}'.format(len(errs)))
    # Float literals keep this a true average even under Python 2 integer
    # division (the rest of the file guards against it with `1.*`).
    area = sum((lo + hi) / 2.0 for lo, hi in zip(errs[:-1], errs[1:]))
    return area / float(len(errs) - 1)


# correct = 0
# for batch_idx, (data, target) in enumerate(clean_loader):
#     data = V(data.cuda(), volatile=True)
#
#     output = net(data)
#
#     pred = output.data.max(1)[1]
#     correct += pred.eq(target.cuda()).sum()
#
# clean_error = 1 - correct / len(clean_loader.dataset)
# print('Clean dataset error (%): {:.2f}'.format(100 * clean_error))


def show_performance(distortion_name):
    """Evaluate the global `net` on all five severities of one distortion.

    For each severity 1..5, the images under
    ``/share/data/vision-greg/DistortedImageNet/JPEG/<distortion_name>/<severity>``
    are center-cropped to 224, normalized with the module-level `mean`/`std`
    and classified; the top-1 error rate for that severity is recorded.
    The per-severity errors are printed as a tuple.

    Args:
        distortion_name: sub-directory name of the distortion (e.g. 'gaussian_noise').

    Returns:
        The mean top-1 error over the five severities, a fraction in [0, 1].
    """
    use_cuda = args.ngpu > 0  # --ngpu 0 means evaluate on the CPU
    errs = []

    for severity in range(1, 6):
        distorted_dataset = dset.ImageFolder(
            root='/share/data/vision-greg/DistortedImageNet/JPEG/' + distortion_name + '/' + str(severity),
            transform=trn.Compose([trn.CenterCrop(224), trn.ToTensor(), trn.Normalize(mean, std)]))

        # Pinned host memory only helps host->GPU copies.
        distorted_dataset_loader = torch.utils.data.DataLoader(
            distorted_dataset, batch_size=args.test_bs, shuffle=False, num_workers=args.prefetch,
            pin_memory=use_cuda)

        correct = 0
        for data, target in distorted_dataset_loader:
            if use_cuda:
                data, target = data.cuda(), target.cuda()

            output = net(V(data, volatile=True))

            pred = output.data.max(1)[1]
            # int() yields a plain Python number whether `.sum()` returns an int
            # (old PyTorch) or a 0-dim, possibly CUDA, tensor (newer PyTorch);
            # otherwise `errs` would hold tensors that np.mean cannot average.
            correct += int(pred.eq(target).sum())

        errs.append(1 - 1. * correct / len(distorted_dataset))

    print('\n=Average', tuple(errs))
    return np.mean(errs)


# /////////////// End Further Setup ///////////////


# /////////////// Display Results ///////////////
import collections

print('\nUsing ImageNet data')

# Distortion sub-directories to evaluate, in reporting order.
distortions = [
    'gaussian_noise',
    'shot_noise',
    'impulse_noise',
    'defocus_blur',
    'glass_blur',
    'motion_blur',
    'zoom_blur',
    'snow',
    'frost',
    'fog',
    'brightness',
    'contrast',
    'elastic_transform',
    'pixelate',
    'jpeg_compression',
    'speckle_noise',
    'gaussian_blur',
    'spatter',
    'saturate',
]

# One unnormalized corruption error (mean over severities) per distortion,
# printed as soon as it is available.
error_rates = []
for distortion_name in distortions:
    rate = show_performance(distortion_name)
    print('Distortion: {:15s}  | CE (unnormalized) (%): {:.2f}'.format(distortion_name, 100 * rate))
    error_rates.append(rate)


# Mean over all distortions, not divided by AlexNet's errors.
print('mCE (unnormalized by AlexNet errors) (%): {:.2f}'.format(100 * np.mean(error_rates)))