#!/usr/bin/python
#
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""CNN definition helpers.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from lsi.nnutils import helpers as nn_helpers
import tensorflow as tf
from tensorflow.contrib import slim
from tensorflow.contrib.layers.python.layers import utils


def encoder_simple(inp_img, nz=1000, is_training=True, reuse=False):
  """Creates a simple encoder CNN.

  Args:
    inp_img: TensorFlow node for input with size B X H X W X C
    nz: number of units in last layer, default=1000
    is_training: whether batch_norm should be in train mode
    reuse: Whether to reuse weights from an already defined net
  Returns:
    enc: B X nz features computed by the encoder CNN.
    end_points: dict of intermediate activations.
  """
  batch_norm_params = {'is_training': is_training}
  with tf.variable_scope('encoder', reuse=reuse) as sc:
    end_points_collection = sc.original_name_scope + '_end_points'
    with slim.arg_scope(
        [slim.conv2d, slim.fully_connected],
        normalizer_fn=slim.batch_norm,
        normalizer_params=batch_norm_params,
        weights_regularizer=slim.l2_regularizer(0.05),
        activation_fn=tf.nn.relu,
        outputs_collections=end_points_collection):
      cnv1 = slim.conv2d(inp_img, 32, [7, 7], stride=2, scope='cnv1')
      cnv1b = slim.conv2d(cnv1, 32, [7, 7], stride=1, scope='cnv1b')
      cnv2 = slim.conv2d(cnv1b, 64, [5, 5], stride=2, scope='cnv2')
      cnv2b = slim.conv2d(cnv2, 64, [5, 5], stride=1, scope='cnv2b')
      cnv3 = slim.conv2d(cnv2b, 128, [3, 3], stride=2, scope='cnv3')
      cnv3b = slim.conv2d(cnv3, 128, [3, 3], stride=1, scope='cnv3b')
      cnv4 = slim.conv2d(cnv3b, 256, [3, 3], stride=2, scope='cnv4')
      cnv4b = slim.conv2d(cnv4, 256, [3, 3], stride=1, scope='cnv4b')
      cnv5 = slim.conv2d(cnv4b, 512, [3, 3], stride=2, scope='cnv5')
      cnv5b = slim.conv2d(cnv5, 512, [3, 3], stride=1, scope='cnv5b')
      cnv6 = slim.conv2d(cnv5b, 512, [3, 3], stride=2, scope='cnv6')
      cnv6b = slim.conv2d(cnv6, 512, [3, 3], stride=1, scope='cnv6b')
      cnv7 = slim.conv2d(cnv6b, 512, [3, 3], stride=2, scope='cnv7')
      cnv7b = slim.conv2d(cnv7, 512, [3, 3], stride=1, scope='cnv7b')
      cnv7b_flat = slim.flatten(cnv7b, scope='cnv7b_flat')
      enc = slim.stack(
          cnv7b_flat, slim.fully_connected, [2 * nz, nz, nz], scope='fc')

    end_points = utils.convert_collection_to_dict(end_points_collection)
    return enc, end_points
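

# A minimal usage sketch of encoder_simple (illustrative; not part of the
# original API). The batch size and 128 X 128 resolution are assumptions:
# any spatial size divisible by 2**7 works, since the encoder halves the
# resolution at each of its seven strided convolutions before flattening.
def _example_encoder_simple():
  images = tf.placeholder(tf.float32, [8, 128, 128, 3])
  enc, end_points = encoder_simple(images, nz=1000, is_training=False)
  # enc: [8, 1000] bottleneck code; end_points maps scope names to tensors.
  return enc, end_points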


def decoder_simple(feat, nconv=7, is_training=True, skip_feat=None,
                   reuse=False):
  """Creates a simple encoder CNN.

  Args:
    feat: Input features with size B X nz or B X H X W X nz
    nconv: number of deconv layers
    is_training: whether batch_norm should be in train mode
    skip_feat: additional skip-features per upconv layer
    reuse: Whether to reuse weights from an already defined net
  Returns:
    feat: features after nconv up-convolution (upsampling) layers.
    end_points: dict of intermediate activations.
  """
  batch_norm_params = {'is_training': is_training}
  n_filters = [32, 64, 128, 256]
  if nconv > 4:
    for _ in range(nconv - 4):
      n_filters.append(512)

  with tf.variable_scope('decoder', reuse=reuse) as sc:
    end_points_collection = sc.original_name_scope + '_end_points'
    with slim.arg_scope(
        [slim.conv2d, slim.conv2d_transpose],
        normalizer_fn=slim.batch_norm,
        normalizer_params=batch_norm_params,
        weights_regularizer=slim.l2_regularizer(0.05),
        activation_fn=tf.nn.relu,
        outputs_collections=end_points_collection):
      if feat.get_shape().ndims == 2:
        feat = tf.expand_dims(tf.expand_dims(feat, 1), 1)
      for nc in range(nconv, 0, -1):
        n_filt = n_filters[nc - 1]
        feat = slim.conv2d_transpose(
            feat, n_filt, [4, 4], stride=2, scope='upcnv' + str(nc))
        if (nc > 1) and (skip_feat is not None):
          feat = tf.concat([feat, skip_feat[-nc + 1]], axis=3)
        feat = slim.conv2d(
            feat, n_filt, [3, 3], stride=1, scope='upcnv' + str(nc) + 'b')

    end_points = utils.convert_collection_to_dict(end_points_collection)
    return feat, end_points
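

# A minimal usage sketch of decoder_simple (illustrative; not part of the
# original API); the B X nz input shape below is an assumption. A 2-D input
# is reshaped to B X 1 X 1 X nz, and nconv=7 stride-2 up-convolutions then
# yield a B X 128 X 128 X 32 feature map (spatial size 2**nconv).
def _example_decoder_simple():
  code = tf.placeholder(tf.float32, [8, 1000])
  feat, end_points = decoder_simple(code, nconv=7, is_training=False)
  return feat, end_points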


def pixelwise_predictor(feat,
                        nc=3,
                        n_layers=1,
                        n_layerwise_steps=0,
                        skip_feat=None,
                        reuse=False,
                        is_training=True):
  """Predicts texture images and probilistic masks.

  Args:
    feat: B X H X W X C feature vectors
    nc: number of output channels
    n_layers: number of plane equations to predict (denoted as L)
    n_layerwise_steps: Number of independent per-layer up-conv steps
    skip_feat: List of features useful for skip connections. Used if lws>0.
    reuse: Whether to reuse weights from an already defined net
    is_training: whether batch_norm should be in train mode
  Returns:
    preds: L X B X H X W X nc predictions.
    end_points: dict of intermediate activations.
  """
  with tf.variable_scope('pixelwise_pred', reuse=reuse) as sc:
    end_points_collection = sc.original_name_scope + '_end_points'
    with slim.arg_scope(
        [slim.conv2d],
        normalizer_fn=None,
        weights_regularizer=slim.l2_regularizer(0.05),
        activation_fn=tf.nn.sigmoid,
        outputs_collections=end_points_collection):
      preds = []
      for l in range(n_layers):
        with tf.variable_scope('upsample_' + str(l), reuse=reuse):
          feat_l, _ = decoder_simple(
              feat,
              nconv=n_layerwise_steps,
              skip_feat=skip_feat,
              reuse=reuse,
              is_training=is_training)
          pred = slim.conv2d(
              feat_l, nc, [3, 3], stride=1, scope='pred_' + str(l))
          preds.append(pred)

      end_points = utils.convert_collection_to_dict(end_points_collection)
      preds = tf.stack(preds, axis=0)

      return preds, end_points
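

# A minimal usage sketch of pixelwise_predictor (illustrative; not part of
# the original API); the feature shape is an assumption. With the default
# n_layerwise_steps=0 the per-layer decoder is a no-op, so each of the L
# heads is a single sigmoid conv over the shared B X H X W X C features.
def _example_pixelwise_predictor():
  feat = tf.placeholder(tf.float32, [8, 128, 128, 32])
  preds, end_points = pixelwise_predictor(
      feat, nc=3, n_layers=2, is_training=False)
  # preds: [2, 8, 128, 128, 3], with values in (0, 1) from the sigmoid.
  return preds, end_points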


def ldi_predictor(feat,
                  n_layers=1,
                  reuse=False,
                  n_layerwise_steps=0,
                  skip_feat=None,
                  pred_masks=False,
                  is_training=True):
  """Predicts ldi : [textures, masks, disps].

  Args:
    feat: B X H X W X C feature vectors
    n_layers: number of layers to predict (denoted as L)
    reuse: Whether to reuse weights from an already defined net
    n_layerwise_steps: Number of independent per-layer up-conv steps
    skip_feat: List of features useful for skip connections. Used if lws>0.
    pred_masks: Whether to predict masks or use all 1s
    is_training: whether batch_norm should be in train mode
  Returns:
    ldi : [textures, masks, disps]
        textures : L X B X H X W X 3.
        masks : L X B X H X W X 1 (all ones if pred_masks is False).
        disps : L X B X H X W X 1.
  """
  with tf.variable_scope('ldi_tex_disp', reuse=reuse):
    nc = 3 + 1
    if pred_masks:
      nc += 1
    tex_disp_pred, _ = pixelwise_predictor(
        feat,
        nc=nc,
        n_layers=n_layers,
        n_layerwise_steps=n_layerwise_steps,
        skip_feat=skip_feat,
        reuse=reuse,
        is_training=is_training)
    if pred_masks:
      tex_pred, masks_ldi, disps_pred = tf.split(
          tex_disp_pred, [3, 1, 1], axis=4)
      masks_ldi = nn_helpers.enforce_bg_occupied(tf.nn.sigmoid(masks_ldi))
    else:
      tex_pred, disps_pred = tf.split(tex_disp_pred, [3, 1], axis=4)
      masks_ldi = tf.ones(disps_pred.get_shape())

    ldi = [tex_pred, masks_ldi, disps_pred]
    return ldi
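

# A minimal usage sketch of ldi_predictor (illustrative; not part of the
# original API); the feature shape is an assumption. With pred_masks=False
# the net emits 4 channels per layer, split into RGB textures and 1-channel
# disparities, and the masks default to all ones.
def _example_ldi_predictor():
  feat = tf.placeholder(tf.float32, [8, 128, 128, 32])
  textures, masks, disps = ldi_predictor(feat, n_layers=2, is_training=False)
  # textures: [2, 8, 128, 128, 3]; masks and disps: [2, 8, 128, 128, 1].
  return textures, masks, disps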


def encoder_decoder_simple(inp_img,
                           nz=1000,
                           nupconv=8,
                           is_training=True,
                           reuse=False,
                           nl_diff_enc_dec=0):
  """Creates a simple encoder-decoder CNN.

  Args:
    inp_img: TensorFlow node for input with size B X H X W X C
    nz: number of units in last layer, default=1000
    nupconv: number of upconv layers in the decoder
    is_training: whether batch_norm should be in train mode
    reuse: Whether to reuse weights from an already defined net
    nl_diff_enc_dec: the decoder uses nupconv - nl_diff_enc_dec upconv layers
  Returns:
    feat: A bottleneck representation with nz units.
    feat_dec: decoded features after nupconv - nl_diff_enc_dec upconv layers.
    skip_feat: None (the simple decoder provides no skip features).
    end_points: intermediate activations
  """
  feat, enc_intermediate = encoder_simple(
      inp_img, is_training=is_training, nz=nz, reuse=reuse)
  feat_dec, dec_intermediate = decoder_simple(
      feat,
      nconv=nupconv - nl_diff_enc_dec,
      is_training=is_training,
      reuse=reuse)
  enc_dec_int = dict(enc_intermediate, **dec_intermediate)
  skip_feat = None
  return feat, feat_dec, skip_feat, enc_dec_int
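

# A minimal usage sketch of encoder_decoder_simple (illustrative; not part
# of the original API). With the default nupconv=8 the decoder upsamples
# the 1 X 1 bottleneck by 2**8, so the 256 X 256 input resolution assumed
# here decodes back to a 256 X 256 feature map.
def _example_encoder_decoder_simple():
  images = tf.placeholder(tf.float32, [4, 256, 256, 3])
  feat, feat_dec, skip_feat, end_points = encoder_decoder_simple(
      images, nz=1000, is_training=False)
  # feat: [4, 1000]; feat_dec: [4, 256, 256, 32]; skip_feat is None.
  return feat, feat_dec, skip_feat, end_points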


def encoder_decoder_unet(inp_img,
                         nz=1000,
                         is_training=True,
                         reuse=False,
                         nl_diff_enc_dec=0):
  """Creates a Unet-like CNN with + features extracted from bottleneck.

  Args:
    inp_img: TensorFlow node for input with size B X H X W X C
    nz: number of units in last layer, default=1000
    is_training: whether batch_norm should be in train mode
    reuse: Whether to reuse weights from an already defined net
    nl_diff_enc_dec: the decoder uses num_enc_layers - nl_diff_enc_dec layers
  Returns:
    feat: A bottleneck representation with nz units.
    feat_dec: decoded features, with size image_size / 2^nl_diff_enc_dec.
    skip_feat: initial layer features useful for layerwise steps
    end_points: intermediate activations
  """
  batch_norm_params = {'is_training': is_training}
  with tf.variable_scope('encoder_decoder_unet', reuse=reuse) as sc:
    end_points_collection = sc.original_name_scope + '_end_points'
    with slim.arg_scope(
        [slim.conv2d, slim.conv2d_transpose, slim.fully_connected],
        normalizer_fn=slim.batch_norm,
        normalizer_params=batch_norm_params,
        weights_regularizer=slim.l2_regularizer(0.05),
        activation_fn=tf.nn.relu,
        outputs_collections=end_points_collection):
      cnv1 = slim.conv2d(inp_img, 32, [7, 7], stride=2, scope='cnv1')
      cnv1b = slim.conv2d(cnv1, 32, [7, 7], stride=1, scope='cnv1b')
      cnv2 = slim.conv2d(cnv1b, 64, [5, 5], stride=2, scope='cnv2')
      cnv2b = slim.conv2d(cnv2, 64, [5, 5], stride=1, scope='cnv2b')
      cnv3 = slim.conv2d(cnv2b, 128, [3, 3], stride=2, scope='cnv3')
      cnv3b = slim.conv2d(cnv3, 128, [3, 3], stride=1, scope='cnv3b')
      cnv4 = slim.conv2d(cnv3b, 256, [3, 3], stride=2, scope='cnv4')
      cnv4b = slim.conv2d(cnv4, 256, [3, 3], stride=1, scope='cnv4b')
      cnv5 = slim.conv2d(cnv4b, 512, [3, 3], stride=2, scope='cnv5')
      cnv5b = slim.conv2d(cnv5, 512, [3, 3], stride=1, scope='cnv5b')
      cnv6 = slim.conv2d(cnv5b, 512, [3, 3], stride=2, scope='cnv6')
      cnv6b = slim.conv2d(cnv6, 512, [3, 3], stride=1, scope='cnv6b')
      cnv7 = slim.conv2d(cnv6b, 512, [3, 3], stride=2, scope='cnv7')
      cnv7b = slim.conv2d(cnv7, 512, [3, 3], stride=1, scope='cnv7b')

      ## features via fc layers on bottleneck
      cnv7b_flat = slim.flatten(cnv7b, scope='cnv7b_flat')
      feat = slim.stack(
          cnv7b_flat, slim.fully_connected, [2 * nz, nz, nz], scope='fc')

      feats_dec = []  # decoded features at different layers
      skip_feat = []  # initial layer features useful for layerwise steps

      upcnv7 = slim.conv2d_transpose(
          cnv7b, 512, [4, 4], stride=2, scope='upcnv7')
      # There might be dimension mismatch due to uneven down/up-sampling
      # upcnv7 = resize_like(upcnv7, cnv6b)
      i7_in = tf.concat([upcnv7, cnv6b], axis=3)
      icnv7 = slim.conv2d(i7_in, 512, [3, 3], stride=1, scope='icnv7')
      feats_dec.append(icnv7)
      skip_feat.append(cnv6b)

      upcnv6 = slim.conv2d_transpose(
          icnv7, 512, [4, 4], stride=2, scope='upcnv6')
      # upcnv6 = resize_like(upcnv6, cnv5b)
      i6_in = tf.concat([upcnv6, cnv5b], axis=3)
      icnv6 = slim.conv2d(i6_in, 512, [3, 3], stride=1, scope='icnv6')
      feats_dec.append(icnv6)
      skip_feat.append(cnv5b)

      upcnv5 = slim.conv2d_transpose(
          icnv6, 256, [4, 4], stride=2, scope='upcnv5')
      # upcnv5 = resize_like(upcnv5, cnv4b)
      i5_in = tf.concat([upcnv5, cnv4b], axis=3)
      icnv5 = slim.conv2d(i5_in, 256, [3, 3], stride=1, scope='icnv5')
      feats_dec.append(icnv5)
      skip_feat.append(cnv4b)

      upcnv4 = slim.conv2d_transpose(
          icnv5, 128, [4, 4], stride=2, scope='upcnv4')
      i4_in = tf.concat([upcnv4, cnv3b], axis=3)
      icnv4 = slim.conv2d(i4_in, 128, [3, 3], stride=1, scope='icnv4')
      feats_dec.append(icnv4)
      skip_feat.append(cnv3b)

      upcnv3 = slim.conv2d_transpose(
          icnv4, 64, [4, 4], stride=2, scope='upcnv3')
      i3_in = tf.concat([upcnv3, cnv2b], axis=3)
      icnv3 = slim.conv2d(i3_in, 64, [3, 3], stride=1, scope='icnv3')
      feats_dec.append(icnv3)
      skip_feat.append(cnv2b)

      upcnv2 = slim.conv2d_transpose(
          icnv3, 32, [4, 4], stride=2, scope='upcnv2')
      i2_in = tf.concat([upcnv2, cnv1b], axis=3)
      icnv2 = slim.conv2d(i2_in, 32, [3, 3], stride=1, scope='icnv2')
      feats_dec.append(icnv2)
      skip_feat.append(cnv1b)

      upcnv1 = slim.conv2d_transpose(
          icnv2, 32, [4, 4], stride=2, scope='upcnv1')
      icnv1 = slim.conv2d(upcnv1, 32, [3, 3], stride=1, scope='icnv1')
      feats_dec.append(icnv1)

      end_points = utils.convert_collection_to_dict(end_points_collection)
      return feat, feats_dec[-1 - nl_diff_enc_dec], skip_feat, end_points
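

# A minimal usage sketch of encoder_decoder_unet (illustrative; not part of
# the original API). The 128 X 128 resolution is an assumption; spatial
# dims divisible by 2**7 let the skip connections concatenate cleanly.
def _example_encoder_decoder_unet():
  images = tf.placeholder(tf.float32, [4, 128, 128, 3])
  feat, feat_dec, skip_feat, end_points = encoder_decoder_unet(
      images, nz=1000, is_training=False)
  # feat: [4, 1000]; feat_dec: [4, 128, 128, 32] when nl_diff_enc_dec=0;
  # skip_feat holds [cnv6b, ..., cnv1b] for layerwise up-conv steps.
  return feat, feat_dec, skip_feat, end_points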