import tensorflow as tf
from functools import reduce
from tensorflow.contrib.layers.python.layers import initializers

'''
Reference: carpedm20 / deep-rl-tensorflow
https://github.com/carpedm20/deep-rl-tensorflow/blob/master/networks/layers.py
'''

def conv2d(x,
           output_dim,
           kernel_size,
           stride,
           weights_initializer=tf.contrib.layers.xavier_initializer(),
           biases_initializer=tf.zeros_initializer(),
           activation_fn=tf.nn.relu,
           data_format='NHWC',
           padding='VALID',
           name='conv2d',
           trainable=True):
  with tf.variable_scope(name):
    # Stride and kernel shapes depend on the data layout.
    if data_format == 'NCHW':
      stride = [1, 1, stride[0], stride[1]]
      kernel_shape = [kernel_size[0], kernel_size[1], x.get_shape()[1], output_dim]
    else:  # 'NHWC'
      stride = [1, stride[0], stride[1], 1]
      kernel_shape = [kernel_size[0], kernel_size[1], x.get_shape()[-1], output_dim]

    w = tf.get_variable('w', kernel_shape, 
        tf.float32, initializer=weights_initializer, trainable=trainable)
    conv = tf.nn.conv2d(x, w, stride, padding, data_format=data_format)

    b = tf.get_variable('b', [output_dim],
        tf.float32, initializer=biases_initializer, trainable=trainable)
    out = tf.nn.bias_add(conv, b, data_format)

  if activation_fn is not None:
    out = activation_fn(out)

  return out, w, b
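
# Minimal usage sketch (not part of the referenced repo); the placeholder name
# `state` and the 32-filter, 8x8-kernel, stride-4 configuration are illustrative
# DQN-style choices:
#
#   state = tf.placeholder(tf.float32, [None, 84, 84, 4], name='state')
#   l1, l1_w, l1_b = conv2d(state, 32, [8, 8], [4, 4], name='l1')
#   # With 'VALID' padding, l1 has shape [None, 20, 20, 32].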

def linear(input_,
           output_size,
           weights_initializer=initializers.xavier_initializer(),
           biases_initializer=tf.zeros_initializer(),
           activation_fn=None,
           trainable=True,
           name='linear'):
  shape = input_.get_shape().as_list()

  if len(shape) > 2:
    # Flatten everything but the batch dimension, e.g. conv feature maps.
    input_ = tf.reshape(input_, [-1, reduce(lambda x, y: x * y, shape[1:])])
    shape = input_.get_shape().as_list()

  with tf.variable_scope(name):
    w = tf.get_variable('w', [shape[1], output_size], tf.float32,
        initializer=weights_initializer, trainable=trainable)
    b = tf.get_variable('b', [output_size],
        initializer=biases_initializer, trainable=trainable)
    out = tf.nn.bias_add(tf.matmul(input_, w), b)

    if activation_fn is not None:
      return activation_fn(out), w, b
    else:
      return out, w, b
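
# Minimal usage sketch (not part of the referenced repo); `l1` is a conv output
# as above and `num_actions` is an illustrative integer. `linear` flattens `l1`
# automatically because its rank is greater than 2:
#
#   h, h_w, h_b = linear(l1, 512, activation_fn=tf.nn.relu, name='fc')
#   q, q_w, q_b = linear(h, num_actions, name='q')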

def batch_sample(probs, name='batch_sample'):
  with tf.variable_scope(name):
    # Per row, subtract i.i.d. uniform noise from the probabilities and take
    # the argmax, yielding one sampled index per batch element.
    uniform = tf.random_uniform(tf.shape(probs), minval=0, maxval=1)
    samples = tf.argmax(probs - uniform, axis=1)
  return samples
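
# Minimal usage sketch (not part of the referenced repo); `logits` is an
# illustrative [batch_size, num_actions] tensor:
#
#   policy = tf.nn.softmax(logits)   # shape [batch_size, num_actions]
#   actions = batch_sample(policy)   # shape [batch_size], dtype int64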