# coding=utf-8
# Copyright 2018 Google LLC & Hwalsuk Lee.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for GANs with different regularizers."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl import flags
from absl.testing import parameterized
from compare_gan import datasets
from compare_gan import test_utils
from compare_gan.gans import consts as c
from compare_gan.gans import loss_lib
from compare_gan.gans import penalty_lib
from compare_gan.gans.modular_gan import ModularGAN
import gin
import tensorflow as tf


FLAGS = flags.FLAGS

# NOTE: these stay lists (not tuples). absl's `parameterized.parameters`
# expands a single non-tuple iterable into one test case per element, but
# treats a single tuple argument as ONE test case.

# Architectures run through one conditional training step each.
TEST_ARCHITECTURES = [
    c.RESNET5_ARCH,
    c.RESNET_BIGGAN_ARCH,
    c.RESNET_CIFAR_ARCH,
]

# Loss functions; each is bound to the gin parameter `loss.fn` by the tests.
TEST_LOSSES = [
    loss_lib.non_saturating,
    loss_lib.wasserstein,
    loss_lib.least_squares,
    loss_lib.hinge,
]

# Penalty functions; each is bound to the gin parameter `penalty.fn`.
TEST_PENALTIES = [
    penalty_lib.no_penalty,
    penalty_lib.dragan_penalty,
    penalty_lib.wgangp_penalty,
    penalty_lib.l2_penalty,
]


class ModularGANConditionalTest(parameterized.TestCase,
                                test_utils.CompareGanTestCase):
  """Smoke tests for training a conditional ModularGAN.

  Each parameterized test builds a conditional ModularGAN, wraps it in an
  estimator with `use_tpu=False` and runs a single training step. Each test
  varies only one of: architecture, loss function or penalty.
  """

  def _get_parameters(self, architecture):
    """Returns the ModularGAN `parameters` dict shared by every test here.

    Args:
      architecture: Architecture constant from `compare_gan.gans.consts`.

    Returns:
      A freshly built dict, so callers may mutate it without affecting
      other tests.
    """
    return {
        "architecture": architecture,
        "lambda": 1,
        "z_dim": 120,
    }

  def _runSingleTrainingStep(self, architecture, loss_fn, penalty_fn,
                             labeled_dataset=True):
    """Builds a conditional ModularGAN and trains it for exactly one step.

    Args:
      architecture: Architecture constant from `compare_gan.gans.consts`.
      loss_fn: Function bound to the gin parameter `loss.fn`.
      penalty_fn: Function bound to the gin parameter `penalty.fn`.
      labeled_dataset: If True (default), train on "cifar10". If False, use
        "celeb_a", a dataset without labels; a conditional ModularGAN is
        expected to reject it with a ValueError (as checked by
        `testUnlabledDatasetRaisesError`).
    """
    with gin.unlock_config():
      gin.bind_parameter("penalty.fn", penalty_fn)
      gin.bind_parameter("loss.fn", loss_fn)
    model_dir = self._get_empty_model_dir()
    run_config = tf.contrib.tpu.RunConfig(
        model_dir=model_dir,
        tpu_config=tf.contrib.tpu.TPUConfig(iterations_per_loop=1))
    # Previously `labeled_dataset` was accepted but silently ignored; honour
    # it so callers asking for an unlabeled dataset actually get one.
    dataset_name = "cifar10" if labeled_dataset else "celeb_a"
    dataset = datasets.get_dataset(dataset_name)
    gan = ModularGAN(
        dataset=dataset,
        parameters=self._get_parameters(architecture),
        conditional=True,
        model_dir=model_dir)
    # Tiny batch and a single step: this checks that the graph builds and
    # trains, not that the model converges.
    estimator = gan.as_estimator(run_config, batch_size=2, use_tpu=False)
    estimator.train(gan.input_fn, steps=1)

  @parameterized.parameters(TEST_ARCHITECTURES)
  def testSingleTrainingStepArchitectures(self, architecture):
    """Trains one step per architecture with hinge loss and no penalty."""
    self._runSingleTrainingStep(architecture, loss_lib.hinge,
                                penalty_lib.no_penalty, labeled_dataset=True)

  @parameterized.parameters(TEST_LOSSES)
  def testSingleTrainingStepLosses(self, loss_fn):
    """Trains one step per loss function on the CIFAR ResNet architecture."""
    self._runSingleTrainingStep(c.RESNET_CIFAR_ARCH, loss_fn,
                                penalty_lib.no_penalty, labeled_dataset=True)

  @parameterized.parameters(TEST_PENALTIES)
  def testSingleTrainingStepPenalties(self, penalty_fn):
    """Trains one step per penalty on the CIFAR ResNet arch with hinge loss."""
    self._runSingleTrainingStep(c.RESNET_CIFAR_ARCH, loss_lib.hinge, penalty_fn,
                                labeled_dataset=True)

  def testUnlabledDatasetRaisesError(self):
    """A conditional ModularGAN must refuse a dataset without labels."""
    with gin.unlock_config():
      gin.bind_parameter("loss.fn", loss_lib.hinge)
    # Use dataset without labels.
    dataset = datasets.get_dataset("celeb_a")
    model_dir = self._get_empty_model_dir()
    with self.assertRaises(ValueError):
      gan = ModularGAN(
          dataset=dataset,
          parameters=self._get_parameters(c.RESNET_CIFAR_ARCH),
          conditional=True,
          model_dir=model_dir)
      # Only reached if the constructor fails to raise; `del` keeps linters
      # quiet about the otherwise-unused binding.
      del gan


if __name__ == "__main__":
  # Run the test cases above through TensorFlow's test runner.
  tf.test.main()