import torch
import torchnet as tnt
from torch.autograd import Variable
from torch.optim import Adam
from torchnet.engine import Engine
from torchnet.logger import VisdomPlotLogger, VisdomLogger
from torchvision.utils import make_grid
from tqdm import tqdm

import config
import utils
from capsnet import CapsuleNet
from loss import CapsuleLoss


def processor(sample):
    """Run one batch through the model and compute its capsule loss.

    Args:
        sample: A three-item sequence ``(data, labels, training)``. ``data``
            gets an extra dimension inserted at index 1, is cast to float and
            divided by 255 before being passed to ``utils.augmentation``.
            ``labels`` is used as the row index into an identity matrix of
            size ``config.NUM_CLASSES`` to build one-hot targets. ``training``
            decides whether the targets are also handed to the model.

    Returns:
        A ``(loss, classes)`` tuple, where ``classes`` is the first output of
        the model and ``loss`` is the value returned by ``capsule_loss``.
    """
    images, targets, is_training = sample

    # NOTE(review): augmentation runs for evaluation batches too -- confirm
    # that this is intended.
    scaled = images.unsqueeze(1).float() / 255.0
    inputs = utils.augmentation(scaled)

    identity = torch.eye(config.NUM_CLASSES)
    one_hot = identity.index_select(dim=0, index=targets)

    inputs, one_hot = Variable(inputs), Variable(one_hot)
    if torch.cuda.is_available():
        inputs, one_hot = inputs.cuda(), one_hot.cuda()

    # The one-hot targets are given to the model only in training mode.
    outputs = model(inputs, one_hot) if is_training else model(inputs)
    classes, reconstructions = outputs

    return capsule_loss(inputs, one_hot, classes, reconstructions), classes


def on_sample(state):
    """Append the ``train`` flag of ``state`` to the current sample in place.

    ``processor`` unpacks a third ``training`` element from each sample; this
    hook is what adds it. ``state['sample']`` must support ``append``.
    """
    batch = state['sample']
    batch.append(state['train'])


def reset_meters():
    """Reset the accuracy, loss and confusion meters, in that order."""
    for meter in (meter_accuracy, meter_loss, confusion_meter):
        meter.reset()


def on_forward(state):
    """Feed the result of one forward pass into the running meters.

    Args:
        state: Engine state dict. ``state['output'].data`` and
            ``state['sample'][1]`` are passed to the accuracy and confusion
            meters; ``state['loss']`` is reduced to a scalar for the loss
            meter.
    """
    predictions = state['output'].data
    targets = state['sample'][1]
    meter_accuracy.add(predictions, targets)
    confusion_meter.add(predictions, targets)

    loss_data = state['loss'].data
    # Scalar losses are 0-dim tensors on PyTorch >= 0.4, and indexing those
    # with [0] raises IndexError on later releases. Use .item() for that case
    # and keep the legacy first-element indexing otherwise.
    if loss_data.dim() == 0:
        loss_value = loss_data.item()
    else:
        loss_value = loss_data[0]
    meter_loss.add(loss_value)


def on_start_epoch(state):
    """Clear the meters and wrap the epoch's iterator in a tqdm progress bar."""
    reset_meters()
    progress_bar = tqdm(state['iterator'])
    state['iterator'] = progress_bar


def on_end_epoch(state):
    """Report training metrics, evaluate, checkpoint and log reconstructions.

    Steps, in order:
      1. Print and plot the loss/accuracy currently held by the meters.
      2. Reset the meters and run ``engine.test`` over
         ``utils.get_iterator(False)``.
      3. Plot the test loss/accuracy, log the confusion matrix and print the
         test metrics.
      4. Save ``model.state_dict()`` to ``epochs/epoch_<epoch>.pt``, creating
         the ``epochs`` directory if it does not exist yet.
      5. Run the model on the first batch of a fresh test iterator and log
         image grids of the inputs and of the reconstructions.

    Args:
        state: Engine state dict; only ``state['epoch']`` is read here.
    """
    # Local import so this function does not depend on a file-level change.
    import os

    epoch = state['epoch']

    train_loss = meter_loss.value()[0]
    train_accuracy = meter_accuracy.value()[0]
    print('[Epoch %d] Training Loss: %.4f (Accuracy: %.2f%%)' % (
        epoch, train_loss, train_accuracy))

    train_loss_logger.log(epoch, train_loss)
    train_accuracy_logger.log(epoch, train_accuracy)

    # The same meters are reused for the test pass, so clear them first.
    reset_meters()

    engine.test(processor, utils.get_iterator(False))
    test_loss = meter_loss.value()[0]
    test_accuracy = meter_accuracy.value()[0]
    test_loss_logger.log(epoch, test_loss)
    test_accuracy_logger.log(epoch, test_accuracy)
    confusion_logger.log(confusion_meter.value())

    print('[Epoch %d] Testing Loss: %.4f (Accuracy: %.2f%%)' % (
        epoch, test_loss, test_accuracy))

    # torch.save does not create missing directories; without this the first
    # checkpoint fails after a full epoch of training on a fresh checkout.
    checkpoint_path = 'epochs/epoch_%d.pt' % epoch
    os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
    torch.save(model.state_dict(), checkpoint_path)

    # reconstruction visualization

    test_sample = next(iter(utils.get_iterator(False)))

    ground_truth = test_sample[0].unsqueeze(1).float() / 255.0
    inputs = Variable(ground_truth)
    if torch.cuda.is_available():
        inputs = inputs.cuda()
    # No labels are passed here, matching the evaluation branch of processor.
    _, reconstructions = model(inputs)
    reconstruction = reconstructions.cpu().view_as(ground_truth).data

    # NOTE(review): newer torchvision releases rename make_grid's `range`
    # keyword to `value_range` -- confirm against the pinned version.
    grid_columns = int(config.BATCH_SIZE ** 0.5)
    ground_truth_logger.log(
        make_grid(ground_truth, nrow=grid_columns, normalize=True, range=(0, 1)).numpy())
    reconstruction_logger.log(
        make_grid(reconstruction, nrow=grid_columns, normalize=True, range=(0, 1)).numpy())


if __name__ == "__main__":
    # The names bound below (model, engine, meters, loggers, capsule_loss) are
    # read as module-level globals by processor and the hook functions above,
    # so they must keep these exact names.
    model = CapsuleNet()
    if torch.cuda.is_available():
        model.cuda()

    parameter_count = sum(param.numel() for param in model.parameters())
    print("# parameters:", parameter_count)

    optimizer = Adam(model.parameters())

    engine = Engine()

    # Running statistics, reset at epoch boundaries by reset_meters().
    meter_loss = tnt.meter.AverageValueMeter()
    meter_accuracy = tnt.meter.ClassErrorMeter(accuracy=True)
    confusion_meter = tnt.meter.ConfusionMeter(config.NUM_CLASSES, normalized=True)

    # Four line plots, created in the same order as their titles.
    train_loss_logger, train_accuracy_logger, test_loss_logger, test_accuracy_logger = (
        VisdomPlotLogger('line', opts={'title': plot_title})
        for plot_title in ('Train Loss', 'Train Accuracy', 'Test Loss', 'Test Accuracy')
    )

    confusion_opts = {
        'title': 'Confusion Matrix',
        'columnnames': list(range(config.NUM_CLASSES)),
        'rownames': list(range(config.NUM_CLASSES)),
    }
    confusion_logger = VisdomLogger('heatmap', opts=confusion_opts)
    ground_truth_logger = VisdomLogger('image', opts={'title': 'Ground Truth'})
    reconstruction_logger = VisdomLogger('image', opts={'title': 'Reconstruction'})

    capsule_loss = CapsuleLoss()

    # Register the callbacks under the hook names used by the engine.
    for hook_name, hook_fn in (
        ('on_sample', on_sample),
        ('on_forward', on_forward),
        ('on_start_epoch', on_start_epoch),
        ('on_end_epoch', on_end_epoch),
    ):
        engine.hooks[hook_name] = hook_fn

    engine.train(
        processor,
        utils.get_iterator(True),
        maxepoch=config.NUM_EPOCHS,
        optimizer=optimizer,
    )