# coding=utf-8
# Copyright 2019 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Helper functions."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

import numpy as np
import torch
import torch_dct


def log_safe(x):
  """The same as torch.log(x), but clamps the input to prevent NaNs."""
  x = torch.as_tensor(x)
  return torch.log(torch.min(x, torch.tensor(33e37).to(x)))


def log1p_safe(x):
  """The same as torch.log1p(x), but clamps the input to prevent NaNs."""
  x = torch.as_tensor(x)
  return torch.log1p(torch.min(x, torch.tensor(33e37).to(x)))


def exp_safe(x):
  """The same as torch.exp(x), but clamps the input to prevent NaNs."""
  x = torch.as_tensor(x)
  return torch.exp(torch.min(x, torch.tensor(87.5).to(x)))


def expm1_safe(x):
  """The same as tf.math.expm1(x), but clamps the input to prevent NaNs."""
  x = torch.as_tensor(x)
  return torch.expm1(torch.min(x, torch.tensor(87.5).to(x)))
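

# Illustrative usage sketch (not part of the original API): the *_safe helpers
# clamp their inputs so that values which would overflow in float32 still
# produce finite outputs.
def _example_safe_ops():
  big = torch.tensor([1., 1e30, 3.4e38])
  # torch.exp() overflows to inf for inputs above ~88.7; exp_safe() clamps the
  # input at 87.5 first.
  assert torch.isfinite(exp_safe(torch.tensor([0., 100., 1000.]))).all()
  assert torch.isfinite(expm1_safe(torch.tensor(1000.)))
  # log_safe() and log1p_safe() clamp the input to ~3.3e38 before the log.
  assert torch.isfinite(log_safe(big)).all()
  assert torch.isfinite(log1p_safe(big)).all()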


def inv_softplus(y):
  """The inverse of tf.nn.softplus()."""
  y = torch.as_tensor(y)
  return torch.where(y > 87.5, y, torch.log(torch.expm1(y)))


def logit(y):
  """The inverse of tf.nn.sigmoid()."""
  y = torch.as_tensor(y)
  return -torch.log(1. / y - 1.)
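

# Illustrative usage sketch (not part of the original API): inv_softplus() and
# logit() invert torch.nn.functional.softplus() and torch.sigmoid()
# respectively, up to floating-point precision.
def _example_inverses():
  x = torch.tensor([-2., 0., 3.])
  assert torch.allclose(inv_softplus(torch.nn.functional.softplus(x)), x,
                        atol=1e-5)
  assert torch.allclose(logit(torch.sigmoid(x)), x, atol=1e-5)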


def affine_sigmoid(logits, lo=0, hi=1):
  """Maps reals to (lo, hi), where 0 maps to (lo+hi)/2."""
  if not lo < hi:
    raise ValueError('`lo` (%g) must be < `hi` (%g)' % (lo, hi))
  logits = torch.as_tensor(logits)
  lo = torch.as_tensor(lo)
  hi = torch.as_tensor(hi)
  alpha = torch.sigmoid(logits) * (hi - lo) + lo
  return alpha


def inv_affine_sigmoid(probs, lo=0, hi=1):
  """The inverse of affine_sigmoid(., lo, hi)."""
  if not lo < hi:
    raise ValueError('`lo` (%g) must be < `hi` (%g)' % (lo, hi))
  probs = torch.as_tensor(probs)
  lo = torch.as_tensor(lo)
  hi = torch.as_tensor(hi)
  logits = logit((probs - lo) / (hi - lo))
  return logits
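

# Illustrative usage sketch (not part of the original API): affine_sigmoid()
# squashes the reals into (lo, hi), with 0 mapping to the midpoint, and
# inv_affine_sigmoid() undoes the mapping.
def _example_affine_sigmoid():
  logits = torch.tensor([-4., 0., 4.])
  alpha = affine_sigmoid(logits, lo=0., hi=2.)
  assert torch.allclose(alpha[1], torch.tensor(1.))  # 0 maps to (lo + hi) / 2.
  assert torch.allclose(inv_affine_sigmoid(alpha, lo=0., hi=2.), logits,
                        atol=1e-4)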


def affine_softplus(x, lo=0, ref=1):
  """Maps real numbers to (lo, infinity), where 0 maps to ref."""
  if not lo < ref:
    raise ValueError('`lo` (%g) must be < `ref` (%g)' % (lo, ref))
  x = torch.as_tensor(x)
  lo = torch.as_tensor(lo)
  ref = torch.as_tensor(ref)
  shift = inv_softplus(torch.tensor(1.))
  y = (ref - lo) * torch.nn.functional.softplus(x + shift) + lo
  return y


def inv_affine_softplus(y, lo=0, ref=1):
  """The inverse of affine_softplus(., lo, ref)."""
  if not lo < ref:
    raise ValueError('`lo` (%g) must be < `ref` (%g)' % (lo, ref))
  y = torch.as_tensor(y)
  lo = torch.as_tensor(lo)
  ref = torch.as_tensor(ref)
  shift = inv_softplus(torch.tensor(1.))
  x = inv_softplus((y - lo) / (ref - lo)) - shift
  return x
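

# Illustrative usage sketch (not part of the original API): affine_softplus()
# maps the reals onto (lo, infinity) with 0 mapping exactly to `ref`, and
# inv_affine_softplus() inverts it.
def _example_affine_softplus():
  x = torch.tensor([-3., 0., 3.])
  y = affine_softplus(x, lo=0., ref=1.)
  assert torch.allclose(y[1], torch.tensor(1.))  # 0 maps to ref.
  assert torch.allclose(inv_affine_softplus(y, lo=0., ref=1.), x, atol=1e-4)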


def students_t_nll(x, df, scale):
  """The NLL of a Generalized Student's T distribution (w/o including TFP)."""
  x = torch.as_tensor(x)
  df = torch.as_tensor(df)
  scale = torch.as_tensor(scale)
  log_partition = (
      torch.log(torch.abs(scale)) + torch.lgamma(0.5 * df) -
      torch.lgamma(0.5 * df + torch.tensor(0.5)) +
      torch.tensor(0.5 * np.log(np.pi)))
  return 0.5 * ((df + 1.) * torch.log1p((x / scale)**2. / df) +
                torch.log(df)) + log_partition
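

# Illustrative sanity check (not part of the original API): with df=1 the
# generalized Student's t-distribution reduces to a Cauchy distribution, whose
# NLL is log(pi * scale * (1 + (x / scale)**2)).
def _example_students_t_nll():
  x = torch.tensor([0., 1., 2.])
  nll = students_t_nll(x, df=1., scale=1.)
  nll_cauchy = torch.log(np.pi * (1. + x**2))
  assert torch.allclose(nll, nll_cauchy, atol=1e-5)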


# A constant scale that makes tf.image.rgb_to_yuv() volume preserving.
_VOLUME_PRESERVING_YUV_SCALE = 1.580227820074


def rgb_to_syuv(rgb):
  """A volume preserving version of tf.image.rgb_to_yuv().

  By "volume preserving" we mean that rgb_to_syuv() is in the "special linear
  group", or equivalently, that the Jacobian determinant of the transformation
  is 1.

  Args:
    rgb: A tensor whose last dimension corresponds to RGB channels and is of
      size 3.

  Returns:
    A scaled YUV version of the input tensor, such that this transformation is
    volume-preserving.
  """
  rgb = torch.as_tensor(rgb)
  kernel = torch.tensor([[0.299, -0.14714119, 0.61497538],
                         [0.587, -0.28886916, -0.51496512],
                         [0.114, 0.43601035, -0.10001026]]).to(rgb)
  yuv = torch.reshape(
      torch.matmul(torch.reshape(rgb, [-1, 3]), kernel), rgb.shape)
  return _VOLUME_PRESERVING_YUV_SCALE * yuv


def syuv_to_rgb(yuv):
  """A volume preserving version of tf.image.yuv_to_rgb().

  By "volume preserving" we mean that rgb_to_syuv() is in the "special linear
  group", or equivalently, that the Jacobian determinant of the transformation
  is 1.

  Args:
    yuv: A tensor whose last dimension corresponds to scaled YUV channels and is
      of size 3 (i.e., the output of rgb_to_syuv()).

  Returns:
    An RGB version of the input tensor, such that this transformation is
    volume-preserving.
  """
  yuv = torch.as_tensor(yuv)
  kernel = torch.tensor([[1, 1, 1], [0, -0.394642334, 2.03206185],
                         [1.13988303, -0.58062185, 0]]).to(yuv)
  rgb = torch.reshape(
      torch.matmul(torch.reshape(yuv, [-1, 3]), kernel), yuv.shape)
  return rgb / _VOLUME_PRESERVING_YUV_SCALE
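

# Illustrative sanity check (not part of the original API): rgb_to_syuv() and
# syuv_to_rgb() are inverses of each other, and because the map is linear,
# applying it to the identity matrix recovers (the transpose of) its Jacobian,
# whose determinant should be 1. See also the compute_jacobian() example below.
def _example_syuv_roundtrip():
  rgb = torch.rand(4, 4, 3)
  assert torch.allclose(syuv_to_rgb(rgb_to_syuv(rgb)), rgb, atol=1e-5)
  assert torch.allclose(torch.det(rgb_to_syuv(torch.eye(3))),
                        torch.tensor(1.), atol=1e-4)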


def image_dct(image):
  """Does a type-II DCT (aka "The DCT") on axes 1 and 2 of a rank-3 tensor."""
  image = torch.as_tensor(image)
  dct_y = torch.transpose(torch_dct.dct(image, norm='ortho'), 1, 2)
  dct_x = torch.transpose(torch_dct.dct(dct_y, norm='ortho'), 1, 2)
  return dct_x


def image_idct(dct_x):
  """Inverts image_dct(), by performing a type-III DCT."""
  dct_x = torch.as_tensor(dct_x)
  dct_y = torch_dct.idct(torch.transpose(dct_x, 1, 2), norm='ortho')
  image = torch_dct.idct(torch.transpose(dct_y, 1, 2), norm='ortho')
  return image
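

# Illustrative sanity check (not part of the original API): image_idct()
# inverts image_dct(), since the orthonormal type-III DCT is the inverse of
# the orthonormal type-II DCT.
def _example_image_dct_roundtrip():
  image = torch.rand(2, 8, 8)
  assert torch.allclose(image_idct(image_dct(image)), image, atol=1e-5)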


def compute_jacobian(f, x):
  """Computes the Jacobian of function `f` with respect to input `x`.

  Assumes that f(x) has the same number of elements as x. Returns a
  [x.numel(), x.numel()] numpy array whose [j, i] element is the partial
  derivative of the i-th element of f(x) w.r.t. the j-th element of x.
  """
  vec = lambda z: torch.reshape(z, [-1])
  jacobian = []
  for i in range(np.prod(x.shape)):
    # Backpropagate from one output element at a time to get one column of the
    # Jacobian.
    var_x = torch.as_tensor(x).clone().detach().requires_grad_(True)
    y = vec(f(var_x))[i]
    y.backward()
    jacobian.append(np.array(vec(var_x.grad)))
  return np.stack(jacobian, 1)
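

# Illustrative usage sketch (not part of the original API): compute_jacobian()
# can numerically verify that rgb_to_syuv() is volume preserving, i.e. that
# the determinant of its Jacobian is 1.
def _example_compute_jacobian():
  rgb = torch.rand(1, 1, 3)
  jacobian = compute_jacobian(rgb_to_syuv, rgb)
  assert np.allclose(np.linalg.det(jacobian), 1., atol=1e-4)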