import numpy as np
from PIL import Image
import torch
import torchvision.transforms as transforms
from torch.autograd import Variable
from model import NetworkImageNet
from genotypes import PNASNet
from operations import *
from utils import preprocess_for_eval

import sys
import os
# Extend the module search path before the `pnasnet` import below;
# presumably that module lives in the sibling PNASNet.TF checkout.
sys.path.append('../PNASNet.TF')
# Set before TensorFlow is imported so the process only sees CUDA device 0.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

import tensorflow as tf
from pnasnet import build_pnasnet_large, pnasnet_large_arg_scope
# Short alias used for `slim.arg_scope` when the TF graph is built.
slim = tf.contrib.slim


class ConvertPNASNet(object):

  def __init__(self):
    """Run the full conversion: load the sample image, read TF, write PyTorch.

    Sets `self.image` to the sample picture and then calls
    `read_tf_weight()` followed by `write_pytorch_weight()`.
    """
    # Open inside a context manager so the underlying file handle is closed,
    # and force RGB because the placeholder fed with this image in
    # read_tf_weight is declared with shape (None, None, 3).
    with Image.open('data/cat.jpg') as sample:
      self.image = sample.convert('RGB')
    self.read_tf_weight()
    self.write_pytorch_weight()

  def read_tf_weight(self):
    """Build the TF PNASNet graph, restore the checkpoint and cache its values.

    Populates:
      self.weight_dict: maps each global variable name, with its last two
        characters removed (presumably the ':0' output suffix), to the
        restored numpy value.
      self.tf_logits, self.tf_end_points, self.tf_image_proc: results of
        running the graph on `self.image`.
    """
    self.weight_dict = {}
    image_ph = tf.placeholder(tf.uint8, (None, None, 3))
    image_proc = preprocess_for_eval(image_ph, 323, 323)
    with slim.arg_scope(pnasnet_large_arg_scope()):
      logits, end_points = build_pnasnet_large(
          tf.expand_dims(image_proc, 0), num_classes=1001, is_training=False)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    ckpt_restorer = tf.train.Saver()
    # Collect the variable list once so names and values are guaranteed to
    # come from the same ordered sequence.
    variables = tf.global_variables()

    # The session is only needed to copy values out of the graph; the `with`
    # block closes it afterwards instead of leaving it open.
    with tf.Session(config=config) as sess:
      ckpt_restorer.restore(sess, '../PNASNet.TF/data/model.ckpt')
      weight_vals = sess.run(variables)
      self.tf_logits, self.tf_end_points, self.tf_image_proc = sess.run(
          [logits, end_points, image_proc], feed_dict={image_ph: self.image})

    self.weight_dict = {
        var.name[:-2]: val for var, val in zip(variables, weight_vals)}

  def write_pytorch_weight(self):
    """Copy the cached TF weights into a PyTorch model, verify it and save it.

    Expects `read_tf_weight()` to have filled `self.weight_dict`,
    `self.tf_logits` and `self.tf_image_proc`. Sets `self.used_keys` and
    `self.pytorch_logits`, and writes the state dict to
    'data/PNASNet-5_Large.pth'.

    Raises:
      AssertionError: if a TF variable was never converted, if one was
        converted more than once, or if the PyTorch logits differ from the
        TF logits by 1e-5 or more.
    """
    model = NetworkImageNet(216, 1001, 12, False, PNASNet)
    model.drop_path_prob = 0
    model.eval()

    self.used_keys = []
    self.convert_conv(model.conv0, 'conv0/weights')
    self.convert_bn(model.conv0_bn, 'conv0_bn/gamma', 'conv0_bn/beta',
        'conv0_bn/moving_mean', 'conv0_bn/moving_variance')
    self.convert_cell(model.stem1, 'cell_stem_0/')
    self.convert_cell(model.stem2, 'cell_stem_1/')

    for i in range(12):
      self.convert_cell(model.cells[i], 'cell_{}/'.format(i))

    self.convert_fc(model.classifier, 'final_layer/FC/weights',
        'final_layer/FC/biases')

    print('Conversion complete!')
    print('Check 1: whether all TF variables are used...')
    # Comparing only the lengths lets a key converted twice hide a key that
    # was never converted, so check both conditions explicitly.
    unused_keys = sorted(set(self.weight_dict) - set(self.used_keys))
    assert not unused_keys, (
        'TF variables never converted: {}'.format(unused_keys))
    assert len(self.used_keys) == len(set(self.used_keys)), (
        'some TF variables were converted more than once')
    print('Pass!')

    model = model.cuda()
    # The TF image has channels last; the model is fed channels first.
    image = self.tf_image_proc.transpose((2, 0, 1))
    image = Variable(self.Tensor(image)).cuda()
    logits, _ = model(image.unsqueeze(0))
    self.pytorch_logits = logits.data.cpu().numpy()

    print('Check 2: whether logits have small diff...')
    max_diff = np.max(np.abs(self.tf_logits - self.pytorch_logits))
    assert max_diff < 1e-5, 'max abs logit diff is {}'.format(max_diff)
    print('Pass!')

    model_path = 'data/PNASNet-5_Large.pth'
    # Move the weights back to the CPU before saving so the checkpoint can be
    # loaded on hosts without CUDA and without needing map_location.
    torch.save(model.cpu().state_dict(), model_path)
    print('PyTorch model saved to {}'.format(model_path))

  def convert_cell(self, cell, name):
    """Copy all TF variables of one cell into the matching PyTorch modules.

    Args:
      cell: PyTorch cell exposing `preprocess0`, `preprocess1` and `_ops`.
      name: TF variable prefix of the cell, including the trailing '/'
        (for example 'cell_stem_0/').

    Every converted key is appended to `self.used_keys` by the
    `convert_conv` / `convert_bn` calls, in the order they are made here.
    """
    def load_bn(bn, scope):
      # Every batch-norm scope uses the same four variable names.
      self.convert_bn(bn, scope + '/gamma', scope + '/beta',
          scope + '/moving_mean', scope + '/moving_variance')

    def load_conv_bn(conv, bn, scope):
      # A 1x1 convolution followed by its batch norm.
      self.convert_conv(conv, scope + '1x1/weights')
      load_bn(bn, scope + 'bn_1')

    # cell.preprocess0
    pre0 = cell.preprocess0
    assert isinstance(pre0, (FactorizedReduce, ReLUConvBN, Identity))
    if isinstance(pre0, FactorizedReduce):
      self.convert_conv(pre0.conv_1, name + 'path1_conv/weights')
      self.convert_conv(pre0.conv_2, name + 'path2_conv/weights')
      load_bn(pre0.bn, name + 'final_path_bn')
    elif name + 'prev_1x1/weights' in self.weight_dict:
      self.convert_conv(pre0.op[1], name + 'prev_1x1/weights')
      load_bn(pre0.op[2], name + 'prev_bn')
    # else preprocess0 is Identity or = preprocess1; do nothing

    # cell.preprocess1
    pre1 = cell.preprocess1
    assert isinstance(pre1, ReLUConvBN)
    self.convert_conv(pre1.op[1], name + '1x1/weights')
    load_bn(pre1.op[2], name + 'beginning_bn')

    # cell._ops: even indices map to the 'left' branch, odd to the 'right'
    # branch of combination i // 2.
    for i, op in enumerate(cell._ops):
      side = 'left/' if i % 2 == 0 else 'right/'
      prefix = name + 'comb_iter_{}/'.format(i // 2) + side
      if isinstance(op, SepConv):
        suffix = '{0}x{0}'.format(op.op[1].kernel_size[0])
        # Two (depthwise, pointwise, bn) stages, starting at op.op[1] and
        # op.op[5] respectively.
        for stage, first in (('_1', 1), ('_2', 5)):
          sep_scope = prefix + 'separable_' + suffix + stage
          self.convert_conv(op.op[first],
              sep_scope + '/depthwise_weights', sep=True)
          self.convert_conv(op.op[first + 1],
              sep_scope + '/pointwise_weights', sep=False)
          load_bn(op.op[first + 2], prefix + 'bn_sep_' + suffix + stage)
      elif isinstance(op, ReLUConvBN):
        # skip_connect with stride > 1
        load_conv_bn(op.op[1], op.op[2], prefix)
      elif isinstance(op, nn.Sequential):
        # max_pool or avg_pool with C_in != C_out
        load_conv_bn(op[1], op[2], prefix)
      # Any other op type has no TF variables to copy.

  def convert_conv(self, conv2d, weights_key, sep=False):
    weights = self.weight_dict[weights_key]
    if sep:
      # TF: [filter_height, filter_width, in_channels, channel_multiplier]
      # TF: [1, 1, channel_multiplier * in_channels, channel_multiplier]
      # PyTorch: [out_channels, in_channels // groups, *kernel_size]
      weights = np.transpose(weights, (2, 3, 0, 1))
    else:
      # TF: [filter_height, filter_width, in_channels, out_channels]
      # PyTorch: [out_channels, in_channels, *kernel_size]
      weights = np.transpose(weights, (3, 2, 0, 1))
    assert conv2d.weight.shape == self.Param(weights).shape, '{0} vs {1}'.format(conv2d.weight.shape, self.Param(weights).shape)
    conv2d.weight = self.Param(weights)
    self.used_keys += [weights_key]

  def convert_bn(self, bn, gamma_key, beta_key, moving_mean_key, moving_var_key):
    gamma = self.weight_dict[gamma_key]
    beta = self.weight_dict[beta_key]
    moving_mean = self.weight_dict[moving_mean_key]
    moving_var = self.weight_dict[moving_var_key]
    assert bn.weight.shape == self.Param(gamma).shape
    assert bn.bias.shape == self.Param(beta).shape
    assert bn.running_mean.shape == self.Tensor(moving_mean).shape
    assert bn.running_var.shape == self.Tensor(moving_var).shape
    bn.weight = self.Param(gamma)
    bn.bias = self.Param(beta)
    bn.running_mean = self.Tensor(moving_mean)
    bn.running_var = self.Tensor(moving_var)
    self.used_keys += [gamma_key, beta_key, moving_mean_key, moving_var_key]

  def convert_fc(self, fc, weights_key, biases_key):
    weights = self.weight_dict[weights_key]
    biases = self.weight_dict[biases_key]
    weights = np.transpose(weights)
    assert fc.weight.shape == self.Param(weights).shape
    assert fc.bias.shape == self.Param(biases).shape
    fc.weight = self.Param(weights)
    fc.bias = self.Param(biases)
    self.used_keys += [weights_key, biases_key]

  def Param(self, x):
    return torch.nn.Parameter(torch.from_numpy(x))

  def Tensor(self, x):
    return torch.from_numpy(x)


# Script entry point: constructing the converter runs the whole pipeline,
# because its __init__ calls read_tf_weight() and write_pytorch_weight().
if __name__ == '__main__':
  ConvertPNASNet()