# coding=utf-8
# Copyright 2020 The Tensor2Robot Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Reused modules for building actors/critics for grasping task.
"""

import gin
import tensorflow.compat.v1 as tf
from tensorflow.contrib import slim


@gin.configurable
def argscope(is_training=None, normalizer_fn=slim.layer_norm):
  """Builds the default slim arg_scope for convnet-based grasping models.

  Three arg_scopes are nested, outermost first:
    1. batch_norm / dropout receive `is_training`.
    2. conv2d / fully_connected receive a truncated-normal weight initializer
       (stddev=0.01), a ReLU activation and `normalizer_fn`.
    3. conv2d / max_pool2d receive stride=2 and 'VALID' padding.

  Args:
    is_training: Whether this argscope is for training or inference.
    normalizer_fn: Normalizer applied to conv and fully connected layers.
  Returns:
    Dictionary of argument overrides, as yielded by the innermost arg_scope.
  """
  # Ops grouped by the set of default arguments they share.
  mode_dependent_ops = [slim.batch_norm, slim.dropout]
  weighted_layers = [slim.conv2d, slim.fully_connected]
  downsampling_ops = [slim.conv2d, slim.max_pool2d]
  small_normal_init = tf.truncated_normal_initializer(stddev=0.01)

  with slim.arg_scope(mode_dependent_ops, is_training=is_training):
    with slim.arg_scope(
        weighted_layers,
        weights_initializer=small_normal_init,
        activation_fn=tf.nn.relu,
        normalizer_fn=normalizer_fn):
      with slim.arg_scope(
          downsampling_ops, stride=2, padding='VALID') as combined_scope:
        return combined_scope


def tile_to_match_context(net, context):
  """Tiles net along a new axis=1 to match context.

  A new axis is inserted at position 1 of `net` and each minibatch element is
  repeated along it, once per example found in axis 1 of `context`.

  Args:
    net: Tensor of shape [num_batch_net, ...]. Its rank must be statically
      known, since the tiling multiples are built as a Python list.
    context: Tensor of shape [num_batch_net, num_examples, context_size]. Only
      the size of axis 1 is read.
  Returns:
    Tensor of shape [num_batch_net, num_examples, ...], where each minibatch
    element of net has been tiled num_examples times.
  """
  with tf.name_scope('tile_to_context'):
    # Read dynamically so num_examples need not be known at graph build time.
    examples_per_element = tf.shape(context)[1]
    expanded = tf.expand_dims(net, 1)  # [batch_size, 1, ...]

    rank = len(expanded.get_shape().as_list())
    # Repeat only along the newly inserted axis; all other axes are kept.
    tile_counts = [1, examples_per_element] + [1] * (rank - 2)
    tiled = tf.tile(expanded, tile_counts)
  return tiled


def add_context(net, context):
  """Merges visual perception with context using elementwise addition.

  Each minibatch element of `net` is repeated once per corresponding context
  example. The context vectors are reshaped to [-1, 1, 1, C] and added to the
  feature maps by broadcasting element-wise across the H, W extent.

  Args:
    net: Tensor of shape [batch_size, H, W, C]. H, W and C are read from the
      static shape and used as Python values in the reshape.
    context: Tensor of shape [batch_size * num_examples, C].
  Returns:
    Tensor with shape [batch_size * num_examples, H, W, C].
  Raises:
    ValueError: If the channel depths of `net` and `context` differ.
  """
  num_batch_net = tf.shape(net)[0]
  _, h, w, d1 = net.get_shape().as_list()
  _, d2 = context.get_shape().as_list()
  # Explicit check rather than `assert`, which is stripped under `python -O`;
  # without it a depth of 1 on either side would silently broadcast below.
  if d1 != d2:
    raise ValueError(
        'Channel depth of net ({}) must match context size ({}).'.format(
            d1, d2))
  # Group context rows per net element so the number of examples per element
  # can be read from axis 1.
  context = tf.reshape(context, [num_batch_net, -1, d2])
  net_examples = tile_to_match_context(net, context)
  # Flatten first two dimensions.
  net = tf.reshape(net_examples, [-1, h, w, d1])
  context = tf.reshape(context, [-1, 1, 1, d2])
  # Broadcasting add: same result as tiling context to [N, H, W, C] first, but
  # without materializing the tiled tensor.
  return tf.add(net, context)