"""
Copyright 2018 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import pickle
import shutil
import numpy as np
import tensorflow as tf
from sklearn import linear_model
from six.moves import range
from tcav.cav import CAV, get_or_train_cav
from tensorflow.python.platform import flags
from tensorflow.python.platform import googletest

# Root directory for test artifacts; CavTest.setUp creates a 'test'
# subdirectory beneath it. Override on the command line with
# --tcav_test_tmpdir=<path>.
flags.DEFINE_string(
    'tcav_test_tmpdir', '/tmp', 'Temporary directory for test files')
# Handle to the global flag registry, used to read parsed flag values.
FLAGS = flags.FLAGS

class CavTest(googletest.TestCase):
  """Unit tests for `CAV` and `get_or_train_cav` from `tcav.cav`.

  Every test starts from the fixture built in `setUp`: an in-memory CAV with
  hand-set vectors, synthetic activations, and a pickled copy of the CAV on
  disk under `FLAGS.tcav_test_tmpdir`.
  """

  def setUp(self):
    """Makes a cav instance and writes it to a tmp directory.

    The cav instance uses preset values.
    """
    self.hparams = tf.contrib.training.HParams(
        model_type='linear', alpha=.01, max_iter=1000, tol=1e-3)
    self.concepts = ['concept1', 'concept2']
    self.bottleneck = 'bottleneck'
    self.accuracies = {'concept1': 0.8, 'concept2': 0.5, 'overall': 0.65}
    self.cav_vecs = [[1, 2, 3], [4, 5, 6]]

    self.test_subdirectory = os.path.join(FLAGS.tcav_test_tmpdir, 'test')
    self.cav_dir = self.test_subdirectory
    self.cav_file_name = CAV.cav_key(self.concepts, self.bottleneck,
                                     self.hparams.model_type,
                                     self.hparams.alpha) + '.pkl'
    self.save_path = os.path.join(self.cav_dir, self.cav_file_name)
    self.cav = CAV(self.concepts, self.bottleneck, self.hparams)
    # pretend that it was trained and cavs are stored
    self.cav.cavs = np.array(self.cav_vecs)
    # Synthetic activations: for the i-th concept, a (4, 3) array filled with
    # the value i, keyed by concept name and then by bottleneck name.
    shape = (1, 3)
    self.acts = {
        concept: {
            self.bottleneck: np.tile(i * np.ones(shape), (4, 1))
        } for i, concept in enumerate(self.concepts)
    }

    # Start from an empty directory so stale files from earlier runs cannot
    # influence the tests.
    if os.path.exists(self.cav_dir):
      shutil.rmtree(self.cav_dir)
    # makedirs (not mkdir) so a not-yet-existing --tcav_test_tmpdir works too.
    os.makedirs(self.cav_dir)
    # pickle produces bytes, so the file must be opened in binary mode.
    with tf.io.gfile.GFile(self.save_path, 'wb') as pkl_file:
      pickle.dump({
          'concepts': self.concepts,
          'bottleneck': self.bottleneck,
          'hparams': self.hparams,
          'accuracies': self.accuracies,
          'cavs': self.cav_vecs,
          'saved_path': self.save_path
      }, pkl_file)

  def tearDown(self):
    """Removes the scratch directory created in setUp."""
    shutil.rmtree(self.test_subdirectory, ignore_errors=True)

  def test_default_hparams(self):
    """Checks the alpha and model_type returned by CAV.default_hparams()."""
    hparam = CAV.default_hparams()
    self.assertEqual(hparam.alpha, 0.01)
    self.assertEqual(hparam.model_type, 'linear')

  def test_load_cav(self):
    """Load up the cav file written in setup function and check values."""
    cav_instance = CAV.load_cav(self.save_path)
    self.assertEqual(cav_instance.concepts, self.concepts)
    self.assertEqual(cav_instance.cavs, self.cav_vecs)

  def test_cav_key(self):
    """Checks the key format when cav_key is called through an instance."""
    self.assertEqual(
        self.cav.cav_key(self.concepts, self.bottleneck,
                         self.hparams.model_type, self.hparams.alpha),
        '-'.join(self.concepts) + '-' + self.bottleneck + '-' +
        self.hparams.model_type + '-' + str(self.hparams.alpha))

  def test_check_cav_exists(self):
    """Checks that the pickle written in setUp is reported as existing."""
    exists = self.cav.check_cav_exists(self.cav_dir, self.concepts,
                                       self.bottleneck, self.hparams)
    self.assertTrue(exists)

  def test__create_cav_training_set(self):
    """Spot-checks the features, labels and label names of the training set."""
    x, labels, labels2text = self.cav._create_cav_training_set(
        self.concepts, self.bottleneck, self.acts)
    # check values of some elements.
    # Each concept contributes 4 rows in self.acts, so row 0 is expected to
    # come from concept1 (value 0) and row 5 from concept2 (value 1).
    self.assertEqual(x[0][0], 0)
    self.assertEqual(x[5][0], 1)
    self.assertEqual(labels[0], 0)
    self.assertEqual(labels[5], 1)
    self.assertEqual(labels2text[0], 'concept1')

  def test_perturb_act(self):
    """Checks that adding the concept1 CAV [1, 2, 3] with alpha=1 works."""
    perturbed = self.cav.perturb_act(
        np.array([1., 0, 1.]), 'concept1', operation=np.add, alpha=1.0)
    # [1, 0, 1] + 1.0 * [1, 2, 3] == [2, 2, 4]
    self.assertEqual(2., perturbed[0])
    self.assertEqual(2., perturbed[1])
    self.assertEqual(4., perturbed[2])

  def test_get_key(self):
    """Checks the key format when cav_key is called through the class."""
    self.assertEqual(
        CAV.cav_key(self.concepts, self.bottleneck, self.hparams.model_type,
                    self.hparams.alpha),
        '-'.join([str(c) for c in self.concepts]) + '-' + self.bottleneck + '-'
        + self.hparams.model_type + '-' + str(self.hparams.alpha))

  def test_get_direction(self):
    """Checks get_direction against direct indexing into self.cav.cavs."""
    idx_concept1 = self.cav.concepts.index('concept1')
    cav_directly_from_member = self.cav.cavs[idx_concept1]
    cav_via_get_direction = self.cav.get_direction('concept1')
    np.testing.assert_array_equal(cav_directly_from_member,
                                  cav_via_get_direction)

  def test_train(self):
    """Trains on the synthetic activations and checks the CAV signs."""
    self.cav.train({c: self.acts[c] for c in self.concepts})
    # check values of some elements.
    # the first coefficients of the two cavs must have opposite signs
    # (their product is negative).
    self.assertLess(self.cav.cavs[0][0] * self.cav.cavs[1][0], 0)

  def test__train_lm(self):
    """Fits a linear model on trivially separable data and checks accuracy."""
    lm = linear_model.SGDClassifier(alpha=self.hparams.alpha)
    acc = self.cav._train_lm(lm, np.array([[0], [0], [0], [1], [1], [1]]),
                             np.array([0, 0, 0, 1, 1, 1]), {
                                 0: 0,
                                 1: 1
                             })
    # the given data is so easy it should get this almost perfect.
    self.assertGreater(acc[0], 0.99)
    self.assertGreater(acc[1], 0.99)

  def test_get_or_train_cav_save_test(self):
    """Checks get_or_train_cav returns the cav vectors saved in setUp.

    The expected values are the ones pickled into cav_dir by setUp, so this
    presumably exercises the load-from-disk path rather than retraining.
    """
    cav_instance = get_or_train_cav(
        self.concepts,
        self.bottleneck,
        self.acts,
        cav_dir=self.cav_dir,
        cav_hparams=self.hparams)
    # check values of some elements.
    self.assertEqual(cav_instance.cavs[0][0], self.cav_vecs[0][0])
    self.assertEqual(cav_instance.cavs[1][2], self.cav_vecs[1][2])

# Run the tests through the googletest runner when this file is executed as a
# script; nothing happens on import.
if __name__ == '__main__':
  googletest.main()