import os
import sys
import time
import glob
import math
import numpy as np
import torch
import utils
import logging
import argparse
import torch.nn as nn
import torch.utils
import torch.nn.functional as F
import torchvision.datasets as dset
import torch.backends.cudnn as cudnn
import torch.distributions.categorical as cate
import torchvision.utils as vutils

from torch.autograd import Variable
from model_search import Network
from architect import Architect
from tensorboardX import SummaryWriter

parser = argparse.ArgumentParser("cifar")
parser.add_argument('--data', type=str, default='../data', help='location of the data corpus')
parser.add_argument('--batch_size', type=int, default=64, help='batch size')
parser.add_argument('--batch_increase', default=8, type=int, help='how much does the batch size increase after making a decision')
parser.add_argument('--learning_rate', type=float, default=0.025, help='init learning rate')
parser.add_argument('--learning_rate_min', type=float, default=0.001, help='min learning rate')
parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
parser.add_argument('--weight_decay', type=float, default=3e-4, help='weight decay')
parser.add_argument('--report_freq', type=float, default=50, help='report frequency')
parser.add_argument('--gpu', type=int, default=0, help='gpu device id')
parser.add_argument('--init_channels', type=int, default=16, help='num of init channels')
parser.add_argument('--layers', type=int, default=8, help='total number of layers')
parser.add_argument('--model_path', type=str, default='saved_models', help='path to save the model')
parser.add_argument('--cutout', action='store_true', default=False, help='use cutout')
parser.add_argument('--cutout_length', type=int, default=16, help='cutout length')
parser.add_argument('--drop_path_prob', type=float, default=0.3, help='drop path probability')
parser.add_argument('--save', type=str, default='EXP', help='experiment name')
parser.add_argument('--seed', type=int, default=2, help='random seed')
parser.add_argument('--grad_clip', type=float, default=5, help='gradient clipping')
parser.add_argument('--train_portion', type=float, default=0.5, help='portion of training data')
parser.add_argument('--unrolled', action='store_true', default=False, help='use one-step unrolled validation loss')
parser.add_argument('--arch_learning_rate', type=float, default=3e-4, help='learning rate for arch encoding')
parser.add_argument('--arch_weight_decay', type=float, default=1e-3, help='weight decay for arch encoding')
parser.add_argument('--warmup_dec_epoch', type=int, default=9, help='warmup decision epoch')
parser.add_argument('--decision_freq', type=int, default=5, help='decision freq epoch')
parser.add_argument('--use_history', action='store_true', help='use history for decision')
parser.add_argument('--history_size', type=int, default=4, help='number of stored epoch scores')
parser.add_argument('--post_val', action='store_true', default=False, help='validate after each decision')
args = parser.parse_args()

args.save = 'search-{}-{}'.format(args.save, time.strftime("%Y%m%d-%H%M%S"))
utils.create_exp_dir(args.save, scripts_to_save=glob.glob('*.py'))

log_format = '%(asctime)s %(message)s'
logging.basicConfig(stream=sys.stdout, level=logging.INFO,
    format=log_format, datefmt='%m/%d %I:%M:%S %p')
fh = logging.FileHandler(os.path.join(args.save, 'log.txt'))
fh.setFormatter(logging.Formatter(log_format))
logging.getLogger().addHandler(fh)

writer = SummaryWriter(log_dir=args.save, max_queue=50)

CIFAR_CLASSES = 10

def histogram_average(history, probs):
  histogram_inter = torch.zeros(probs.shape[0], dtype=torch.float).cuda()
  if not history:
    return histogram_inter
  for hist in history:
    histogram_inter += utils.histogram_intersection(hist, probs)
  histogram_inter /= len(history)
  return histogram_inter

def score_image(type, score, epoch):
  score_img = vutils.make_grid(
    torch.unsqueeze(torch.unsqueeze(torch.unsqueeze(score, 1), 2), 3),
    nrow=7,
    normalize=True,
    pad_value=0.5)
  writer.add_image(type + '_score', score_img, epoch)


def edge_decision(type, alphas, selected_idxs, candidate_flags, probs_history, epoch, model, args):
  mat = F.softmax(torch.stack(alphas, dim=0), dim=-1).detach()
  print(mat)
  importance = torch.sum(mat[:, 1:], dim=-1)
  # logging.info(type + " importance {}".format(importance))

  probs = mat[:, 1:] / importance[:, None]
  # print(type + " probs", probs)
  entropy = cate.Categorical(probs=probs).entropy() / math.log(probs.size()[1])
  # logging.info(type + " entropy {}".format(entropy))

  if args.use_history: # SGAS Cri.2
    # logging.info(type + " probs history {}".format(probs_history))
    histogram_inter = histogram_average(probs_history, probs)
    # logging.info(type + " histogram intersection average {}".format(histogram_inter))
    probs_history.append(probs)
    if (len(probs_history) > args.history_size):
      probs_history.pop(0)

    score = utils.normalize(importance) * utils.normalize(
      1 - entropy) * utils.normalize(histogram_inter)
    # logging.info(type + " score {}".format(score))
  else: # SGAS Cri.1
    score = utils.normalize(importance) * utils.normalize(1 - entropy)
    # logging.info(type + " score {}".format(score))

  if torch.sum(candidate_flags.int()) > 0 and \
      epoch >= args.warmup_dec_epoch and \
      (epoch - args.warmup_dec_epoch) % args.decision_freq == 0:
    masked_score = torch.min(score,
                              (2 * candidate_flags.float() - 1) * np.inf)
    selected_edge_idx = torch.argmax(masked_score)
    selected_op_idx = torch.argmax(probs[selected_edge_idx]) + 1 # add 1 since none op
    selected_idxs[selected_edge_idx] = selected_op_idx

    candidate_flags[selected_edge_idx] = False
    alphas[selected_edge_idx].requires_grad = False
    if type == 'normal':
      reduction = False
    elif type == 'reduce':
      reduction = True
    else:
      raise Exception('Unknown Cell Type')
    candidate_flags, selected_idxs = model.check_edges(candidate_flags,
                                                       selected_idxs,
                                                       reduction=reduction)
    logging.info("#" * 30 + " Decision Epoch " + "#" * 30)
    logging.info("epoch {}, {}_selected_idxs {}, added edge {} with op idx {}".format(epoch,
                                                      type,
                                                      selected_idxs,
                                                      selected_edge_idx,
                                                      selected_op_idx))
    print(type + "_candidate_flags {}".format(candidate_flags))
    score_image(type, score, epoch)
    return True, selected_idxs, candidate_flags

  else:
    logging.info("#" * 30 + " Not a Decision Epoch " + "#" * 30)
    logging.info("epoch {}, {}_selected_idxs {}".format(epoch,
                                                 type,
                                                 selected_idxs))
    print(type + "_candidate_flags {}".format(candidate_flags))
    score_image(type, score, epoch)
    return False, selected_idxs, candidate_flags

def main():
  if not torch.cuda.is_available():
    logging.info('no gpu device available')
    sys.exit(1)

  np.random.seed(args.seed)
  torch.cuda.set_device(args.gpu)
  cudnn.benchmark = True
  torch.manual_seed(args.seed)
  cudnn.enabled = True
  torch.cuda.manual_seed(args.seed)
  logging.info('gpu device = %d' % args.gpu)
  logging.info("args = %s", args)

  criterion = nn.CrossEntropyLoss()
  criterion = criterion.cuda()
  model = Network(args.init_channels, CIFAR_CLASSES, args.layers, criterion)
  model = model.cuda()
  logging.info("param size = %fMB", utils.count_parameters_in_MB(model))

  optimizer = torch.optim.SGD(
      model.parameters(),
      args.learning_rate,
      momentum=args.momentum,
      weight_decay=args.weight_decay)

  train_transform, valid_transform = utils._data_transforms_cifar10(args)
  train_data = dset.CIFAR10(root=args.data, train=True, download=True, transform=train_transform)

  num_train = len(train_data)
  indices = list(range(num_train))
  split = int(np.floor(args.train_portion * num_train))

  train_queue = torch.utils.data.DataLoader(
      train_data, batch_size=args.batch_size,
      sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[:split]),
      pin_memory=True)

  valid_queue = torch.utils.data.DataLoader(
      train_data, batch_size=args.batch_size,
      sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[split:num_train]),
      pin_memory=True)

  num_edges = model._steps * 2
  post_train = 5
  epochs = args.warmup_dec_epoch + args.decision_freq * (num_edges - 1) + post_train + 1
  logging.info("total epochs: %d", epochs)
  scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, float(epochs), eta_min=args.learning_rate_min)

  architect = Architect(model, args)

  normal_selected_idxs = torch.tensor(len(model.alphas_normal) * [-1], requires_grad=False, dtype=torch.int).cuda()
  reduce_selected_idxs = torch.tensor(len(model.alphas_reduce) * [-1], requires_grad=False, dtype=torch.int).cuda()
  normal_candidate_flags = torch.tensor(len(model.alphas_normal) * [True], requires_grad=False, dtype=torch.bool).cuda()
  reduce_candidate_flags = torch.tensor(len(model.alphas_reduce) * [True], requires_grad=False, dtype=torch.bool).cuda()
  logging.info('normal_selected_idxs: {}'.format(normal_selected_idxs))
  logging.info('reduce_selected_idxs: {}'.format(reduce_selected_idxs))
  logging.info('normal_candidate_flags: {}'.format(normal_candidate_flags))
  logging.info('reduce_candidate_flags: {}'.format(reduce_candidate_flags))
  model.normal_selected_idxs = normal_selected_idxs
  model.reduce_selected_idxs = reduce_selected_idxs
  model.normal_candidate_flags = normal_candidate_flags
  model.reduce_candidate_flags = reduce_candidate_flags

  print(F.softmax(torch.stack(model.alphas_normal, dim=0), dim=-1).detach())
  print(F.softmax(torch.stack(model.alphas_reduce, dim=0), dim=-1).detach())

  count = 0
  normal_probs_history = []
  reduce_probs_history = []

  for epoch in range(epochs):
    scheduler.step()
    lr = scheduler.get_lr()[0]
    logging.info('epoch %d lr %e', epoch, lr)

    # training
    train_acc, train_obj = train(train_queue, valid_queue, model, architect, criterion, optimizer, lr)
    logging.info('train_acc %f', train_acc)

    # validation
    with torch.no_grad():
      valid_acc, valid_obj = infer(valid_queue, model, criterion)
      logging.info('valid_acc %f', valid_acc)


    saved_memory_normal, model.normal_selected_idxs, \
    model.normal_candidate_flags = edge_decision('normal',
                                                 model.alphas_normal,
                                                 model.normal_selected_idxs,
                                                 model.normal_candidate_flags,
                                                 normal_probs_history,
                                                 epoch,
                                                 model,
                                                 args)

    saved_memory_reduce, model.reduce_selected_idxs, \
    model.reduce_candidate_flags = edge_decision('reduce',
                                                 model.alphas_reduce,
                                                 model.reduce_selected_idxs,
                                                 model.reduce_candidate_flags,
                                                 reduce_probs_history,
                                                 epoch,
                                                 model,
                                                 args)

    if saved_memory_normal or saved_memory_reduce:
      del train_queue, valid_queue
      torch.cuda.empty_cache()

      count += 1
      new_batch_size = args.batch_size + args.batch_increase * count
      logging.info("new_batch_size = {}".format(new_batch_size))
      train_queue = torch.utils.data.DataLoader(
        train_data, batch_size=new_batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[:split]),
        pin_memory=True, num_workers=2)

      valid_queue = torch.utils.data.DataLoader(
        train_data, batch_size=new_batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[split:num_train]),
        pin_memory=True, num_workers=2)

      # post validation
      if args.post_val:
        with torch.no_grad():
          post_valid_acc, valid_obj = infer(valid_queue, model, criterion)
          logging.info('post_valid_acc %f', post_valid_acc)

    logging.info('genotype = %s', model.get_genotype(force=True))
    writer.add_scalar('stats/train_acc', train_acc, epoch)
    writer.add_scalar('stats/valid_acc', valid_acc, epoch)
    utils.save(model, os.path.join(args.save, 'weights.pt'))

  logging.info("#" * 30 + " Done " + "#" * 30)
  logging.info('genotype = %s', model.get_genotype())


def train(train_queue, valid_queue, model, architect, criterion, optimizer, lr):
  objs = utils.AverageMeter()
  top1 = utils.AverageMeter()
  top5 = utils.AverageMeter()

  for step, (input, target) in enumerate(train_queue):
    model.train()
    n = input.size(0)

    input = Variable(input, requires_grad=False).cuda()
    target = Variable(target, requires_grad=False).cuda(async=True)

    # get a random minibatch from the search queue with replacement
    input_search, target_search = next(iter(valid_queue))
    input_search = Variable(input_search, requires_grad=False).cuda()
    target_search = Variable(target_search, requires_grad=False).cuda(async=True)

    architect.step(input, target, input_search, target_search, lr, optimizer, unrolled=args.unrolled)

    optimizer.zero_grad()
    logits = model(input)
    loss = criterion(logits, target)

    loss.backward()
    nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
    optimizer.step()

    prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
    objs.update(loss.item(), n)
    top1.update(prec1.item(), n)
    top5.update(prec5.item(), n)

    if step % args.report_freq == 0:
      logging.info('train %03d %e %f %f', step, objs.avg, top1.avg, top5.avg)

  return top1.avg, objs.avg


def infer(valid_queue, model, criterion):
  objs = utils.AverageMeter()
  top1 = utils.AverageMeter()
  top5 = utils.AverageMeter()
  model.eval()

  for step, (input, target) in enumerate(valid_queue):
    input = Variable(input).cuda()
    target = Variable(target).cuda(async=True)

    logits = model(input)
    loss = criterion(logits, target)

    prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
    n = input.size(0)
    objs.update(loss.item(), n)
    top1.update(prec1.item(), n)
    top5.update(prec5.item(), n)

    if step % args.report_freq == 0:
      logging.info('valid %03d %e %f %f', step, objs.avg, top1.avg, top5.avg)

  return top1.avg, objs.avg


if __name__ == '__main__':
  main()