import torch
import torch.nn as nn
import numpy as np

from tqdm import tqdm
from torchvision import transforms as T
from torchvision import datasets
from torch.utils.data import DataLoader
from scratchai.utils import freeze, Topk
from scratchai.learners.metrics import accuracy
from scratchai.attacks.attacks import *
from scratchai.imgutils import get_trf
from scratchai.utils import name_from_object
from scratchai._config import CIFAR10, MNIST, IMGNET12


def optimize_linear(grads, eps, ordr):
  """
  Solves for optimal input to a linear function under a norm constraint.

  Arguments
  ---------
  grads : torch.Tensor
           The gradients of the input.
  eps : float
        Scalar specifying the constraint region.
  ordr : [np.inf, 1, 2] 
         Order of norm constraint.

  Returns
  -------
  opt_pert : torch.Tensor
             Optimal Perturbation.
  """

  red_ind = list(range(1, len(grads.size())))
  azdiv = torch.tensor(1e-12, dtype=grads.dtype, device=grads.device)

  if ordr == np.inf:
    opt_pert = torch.sign(grads)

  elif ordr == 1:
    abs_grad = torch.abs(grads)
    sign = torch.sign(grads)
    ori_shape = [1] * len(grads.size())
    ori_shape[0] = grads.size(0)

    max_abs_grad, _ = torch.max(abs_grad.view(grads.size(0), -1), 1)
    max_mask = abs_grad.eq(max_abs_grad.view(ori_shape)).float()
    num_ties = max_mask
    for red_scalar in red_ind:
      num_ties = torch.sum(num_ties, red_scalar, keepdims=True)
    opt_pert = sign * max_mask / num_ties
    # TODO tests

  elif ordr == 2:
    # TODO
    square = torch.max(azdiv, torch.sum(grads ** 2, red_ind, keepdim=True))
    opt_pert = grads / torch.sqrt(square)
    # TODO tests
  else:
    raise NotImplementedError('Only L-inf, L1 and L2 norms are '
                              'currently implemented.')

  scaled_pert = eps * opt_pert
  return scaled_pert


def clip_eta(eta, ord, eps):
  """
  Helper fucntion to clip the perturbation to epsilon norm ball.

  Args:
      eta: A tensor with the current perturbation
      ord: Order of the norm (mimics Numpy)
           Possible values: np.inf, 1 or 2.
      eps: Epsilon, bound of the perturbation.
  """

  # Clipping perturbation eta to self.ord norm ball
  if ord not in [np.inf, 1, 2]:
    raise ValueError('ord must be np.inf, 1, or 2.')

  reduce_ind = list(range(1, len(eta.shape)))
  azdiv = torch.tensor(1e-12)

  if ord == np.inf:
    eta = torch.clamp(eta, -eps, eps)
  else:
    if ord == 1:
      raise NotImplementedError("The expression below is not the correct way"
                                    " to project onto the L1 norm ball.")
      norm = torch.max(azdiv, torch.mean(torch.abs(eta), reduce_ind))

    elif ord == 2:
      # azdiv(avoid_zero_div) must go inside sqrt to avoid a divide by zero
      # in the gradient through this operation.
      norm = torch.sqrt(torch.max(azdiv, torch.mean(eta**2, reduce_ind)))

    # We must clip to within the norm ball, not 'normalize' onto the
    # surface of the ball
    factor = min(1., eps / norm)
    eta *= factor

  return eta



##################################################################
######### Functions to help benchmark attacks ####################
##################################################################

def benchmark_atk(atk, net:nn.Module, **kwargs):
  """
  Helper function to benchmark using a particular attack
  on a particular dataset. All benchmarks that are present in this
  repository are created using this function.

  Arguments
  ---------
  atk : scratchai.attacks.attacks
        The attack on which to use.
  net : nn.Module
        The net which is to be attacked.
  root : str
         The root directory of the dataset.
  dfunc : function
          The function that can take the root and torchvision.transforms
          and return a torchvision.Datasets object
          Defaults to datasets.ImageFolder
  trf : torchvision.Transforms
        The transforms that you want to apply.
        Defaults to (get_trf('rz256_cc224_tt_normimgnet')
  bs : int
       The batch size. Defaults to 4.
  
  """
  
  loader, topk, kwargs = pre_benchmark_atk(**kwargs)

  freeze(net)
  print ('[INFO] Net Frozen!')
  atk = atk(net, **kwargs)
  atk_name = name_from_object(atk)
  net_name = name_from_object(net)

  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  loss = 0; adv_loss = 0
  oatopk = Topk('original accuracy', topk)
  aatopk = Topk('adversarial accuracy', topk)
  net.to(device); net.eval()
  crit = nn.CrossEntropyLoss()

  for ii, (data, labl) in enumerate(tqdm(loader)):
    adv_data, data = atk(data.to(device).clone()), data.to(device)
    labl = labl.to(device)
    adv_out = net(adv_data); out = net(data)
    loss += crit(out, labl).item()
    adv_loss += crit(adv_out, labl).item()
    acc = accuracy(out, labl, topk)
    adv_acc = accuracy(adv_out, labl, topk)
    oatopk.update(acc, data.size(0)); aatopk.update(adv_acc, data.size(0))
  loss /= len(loader)
  adv_loss /= len(loader)
  
  print ('\nAttack Summary on {} with {} attack:'.format(net_name, atk_name))
  print ('-'*45)
  print (oatopk)
  print ('-'*35)
  print (aatopk)


def pre_benchmark_atk(**kwargs):
  """
  Helper function that sets all the defaults while performing checks
  for all the options passed before benchmarking attacks.
  """

  # Set the Default options if nothing explicit provided
  def_dict = {
    'bs'       : 4,
    'trf'      : get_trf('rz256_cc224_tt_normimgnet'),
    'dset'     : 'NA',
    'root'     : './',
    'topk'     : (1, 5),
    'dfunc'    : datasets.ImageFolder,
    'download' : True,
  }

  for key, val in def_dict.items(): 
    if key not in kwargs: kwargs[key] = val
  
  if kwargs['dset'] == 'NA':
    if 'loader' not in kwargs:
      dset = kwargs['dfunc'](kwargs['root'], transform=kwargs['trf'])
      loader = DataLoader(dset, batch_size=kwargs['bs'], num_workers=2)
    else:
      loader = kwargs['loader']
  
  # Set dataset specific functions here
  else:
    if kwargs['dset'] == IMGNET12:
      dset = datasets.ImageNet(kwargs['root'], split='test',
                      download=kwargs['download'], transform=kwargs['trf'])
    elif kwargs['dset'] == MNIST:
      kwargs['trf'] = get_trf('tt_normmnist')
      kwargs['dfunc'] = datasets.MNIST
      dset = kwargs['dfunc'](kwargs['root'], train=False, 
                     download=kwargs['download'], transform=kwargs['trf'])
    else: raise

    loader = DataLoader(dset, shuffle=False, batch_size=kwargs['bs'])
  topk = kwargs['topk']
    
  for key, val in kwargs.items():
    print ('[INFO] Setting {} to {}.'.format(key, kwargs[key]))

  # Deleting keys that is used just for benchmark_atk() function is 
  # important as the same kwargs dict is passed to initialize the attack
  # So, otherwise the attack will throw an exception
  for key in def_dict:
    del kwargs[key]
  if 'loader' in kwargs: del kwargs['loader']

  return loader, topk, kwargs