import unittest
import train
import torch
from datasets.get import get_dataset
import torch.nn as nn
from models.layers.AsyncTFBase import AsyncTFBase
from models.layers.AsyncTFCriterion import AsyncTFCriterion
from opts import parse
import subprocess
# Remove compiled bytecode (*.pyc) below ./exp/.. (the directory that
# contains ./exp), presumably so stale .pyc files from edited modules are
# not picked up. The arguments go straight to `find` with no shell in
# between, so the glob must not carry shell quotes -- a quoted '"*.pyc"'
# would be matched literally, quotes included, and delete nothing.
# `call` waits for `find` to finish instead of leaving it running
# concurrently with the tests.
subprocess.call(['find', './exp/..', '-iname', '*.pyc', '-delete'])


def opts(opt):
    """Overwrite hyper-parameters on ``opt`` for the mock-dataset test.

    Args:
        opt: the options object produced by ``parse()``; the attributes below
            are assigned in place and nothing is returned.
    """
    opt.dataset = 'mock_dataset1'
    opt.lr_decay_rate = 100
    opt.lr = 1e-1
    opt.temporal_weight = 5.
    opt.temporalloss_weight = 0.05
    opt.memory_decay = 1.0
    opt.sigma = 150
    opt.print_freq = 9
    opt.weight_decay = 5e-4
    opt.memory_size = 20
    opt.nclass = 5
    opt.nhidden = 10
    # Other settings tried during development: opt.adjustment = True,
    # opt.nhidden = 2.


def simpletest1(epochs=40):
    """Train and evaluate AsyncTFBase on the mock dataset.

    Tests whether the code can learn a simple sequence: trains for ``epochs``
    epochs, then evaluates on the training loader, the validation loader
    (five passes) and the per-video validation loader.

    Args:
        epochs: number of training epochs (default 40).

    Returns:
        tuple ``(top1, top1val, ap)``: the first value returned by
        ``trainer.validate`` on the training loader, the first value from the
        last validation-loader pass, and the result of
        ``trainer.validate_video``.
    """
    opt = parse()
    opts(opt)
    train_loader, val_loader, valvideo_loader = get_dataset(opt)
    trainer = train.Trainer()
    in_features = 100  # input width of the base model, also passed to AsyncTFBase
    basemodel = nn.Linear(in_features, opt.nclass)
    model = AsyncTFBase(basemodel, in_features, opt).cuda()
    criterion = AsyncTFCriterion(opt).cuda()
    optimizer = torch.optim.SGD(model.parameters(), opt.lr,
                                momentum=opt.momentum,
                                weight_decay=opt.weight_decay)
    for i in range(epochs):
        top1, _ = trainer.train(train_loader, model, criterion, optimizer, i, opt)
        # Builtin next() works on Python 2 and 3; generator.next() is Py2-only.
        # NOTE(review): `.data[0]` on a scalar norm only works on PyTorch < 0.4;
        # newer versions need `.item()`.
        print('cls weights: {}, aa weights: {}'.format(
            next(model.mA.parameters()).norm().data[0],
            next(model.mAAa.parameters()).norm().data[0]))
    top1, _ = trainer.validate(train_loader, model, criterion, epochs, opt)

    for i in range(5):
        top1val, _ = trainer.validate(val_loader, model, criterion, epochs + i, opt)
        print('top1val: {}'.format(top1val))

    video_epoch = -1  # epoch index handed to validate_video (not a real epoch)
    ap = trainer.validate_video(valvideo_loader, model, criterion, video_epoch, opt)
    return top1, top1val, ap


class AsyncTests(unittest.TestCase):
    """End-to-end checks that AsyncTFBase learns the mock dataset."""

    def test1(self):
        """A single run must exceed 90 top-1 (train and val) and 0.9 AP."""
        top1, top1val, ap = simpletest1()
        # assertGreater reports both operands on failure; failUnless is a
        # deprecated alias of assertTrue that Python 3.12 removed.
        self.assertGreater(top1, 90)
        self.assertGreater(top1val, 90)
        self.assertGreater(ap, 0.9)

    @unittest.skip('stress test: 20 full training runs, too slow for routine use')
    def test2(self):
        """Stress test: over 80% of independent runs must clear 85 top-1."""
        top1s, top1vals, aps = [], [], []
        trials = 20
        for _ in range(trials):
            top1, top1val, ap = simpletest1()
            top1s.append(top1)
            top1vals.append(top1val)
            aps.append(ap)
        top1s = [1 if x > 85 else 0 for x in top1s]
        top1vals = [1 if x > 85 else 0 for x in top1vals]
        aps = [1 if x > .85 else 0 for x in aps]
        print('top1s: {}/{} \t top1vals: {}/{} \t aps: {}/{}'.format(
            sum(top1s), trials, sum(top1vals), trials, sum(aps), trials))
        self.assertGreater(sum(top1s), .8 * len(top1s))
        self.assertGreater(sum(top1vals), .8 * len(top1vals))


def main():
    """Hand control to unittest's command-line runner for ``__main__``."""
    unittest.main(module='__main__')


if __name__ == '__main__':
    # Only when executed as a script; importing this module runs nothing.
    main()