#!/usr/bin/env python
# coding=utf-8
from __future__ import division, print_function, unicode_literals

import numpy as np
import tensorflow as tf
import tensorflow.contrib.slim as slim

from sacred import Ingredient
from tensorflow.contrib.rnn import RNNCell
from utils import ACTIVATION_FUNCTIONS

net = Ingredient('network')


@net.config
def cfg():
    # Default network. This is a sacred config scope: every local assigned
    # below becomes a config entry of the 'network' ingredient, so the local
    # names (including `input`, which shadows the builtin) are config keys
    # and must match the parameters captured by `build_network`.

    # input: encoder applied to each input. Inputs are reshaped to 64x64x1;
    # the three stride-2 convolutions then give 32x32x16, 16x16x32 and 8x8x64
    # (assuming slim.conv2d's default 'SAME' padding), which is flattened and
    # projected to 512 units.
    input = [
        {'name': 'reshape', 'shape': (64, 64, 1)},
        {'name': 'conv', 'size': 16, 'act': 'elu', 'stride': [2, 2], 'kernel': (4, 4), 'ln': True},
        {'name': 'conv', 'size': 32, 'act': 'elu', 'stride': [2, 2], 'kernel': (4, 4), 'ln': True},
        {'name': 'conv', 'size': 64, 'act': 'elu', 'stride': [2, 2], 'kernel': (4, 4), 'ln': True},
        {'name': 'reshape', 'shape': -1},
        {'name': 'fc', 'size': 512, 'act': 'elu', 'ln': True},
    ]
    # recurrent: a single R-NEM cell with a 250-unit state; build_network adds
    # layer norm on its output and the sigmoid on both output and state. The
    # nested lists are the layer specs of its encoder, core, context and
    # attention sub-networks; the attention branch ends in one sigmoid unit
    # that weights each pair of groups.
    recurrent = [
        {'name': 'r_nem', 'size': 250, 'act': 'sigmoid', 'ln': True,
         'encoder': [
             {'name': 'fc', 'size': 250, 'act': 'relu', 'ln': True},
         ],
         'core': [
             {'name': 'fc', 'size': 250, 'act': 'relu', 'ln': True},
         ],
         'context': [
             {'name': 'fc', 'size': 250, 'act': 'relu', 'ln': True},
         ],
         'attention': [
             {'name': 'fc', 'size': 100, 'act': 'tanh', 'ln': True},
             {'name': 'fc', 'size': 1, 'act': 'sigmoid'},
         ]}
    ]
    # output: decoder from the recurrent output back to image space. The
    # 8x8x64 map is upsampled by a factor 2 three times (8 -> 16 -> 32 -> 64)
    # via resize + convolution ('r_conv'), ending in a single sigmoid channel
    # that is flattened to 64*64 values.
    output = [
        {'name': 'fc', 'size': 512, 'act': 'relu', 'ln': True},
        {'name': 'fc', 'size': 8*8*64, 'act': 'relu', 'ln': True},
        {'name': 'reshape', 'shape': (8, 8, 64)},
        {'name': 'r_conv', 'size': 32, 'act': 'relu', 'stride': [2, 2], 'kernel': (4, 4), 'ln': True},
        {'name': 'r_conv', 'size': 16, 'act': 'relu', 'stride': [2, 2], 'kernel': (4, 4), 'ln': True},
        {'name': 'r_conv', 'size': 1, 'act': 'sigmoid', 'stride': [2, 2], 'kernel': (4, 4)},
        {'name': 'reshape', 'shape': -1},
    ]

# Named configs: matching encoder/decoder pairs (each overrides `input` and `output`)


# 84x84 single-channel frames (Atari). Assuming slim.conv2d's default 'SAME'
# padding, the encoder maps 84x84 -> 21x21x16 (stride 4) -> 11x11x32 (stride 2)
# before flattening. The decoder starts from a 10x10x32 map; 'r_conv' resizes
# by `stride` and adds `offset` pixels, so it grows 10 -> 2*10+1 = 21 -> 4*21 = 84.
_ENC_84_ATARI = [
    {'name': 'reshape', 'shape': (84, 84, 1)},
    {'name': 'conv', 'size': 16, 'act': 'elu', 'stride': [4, 4], 'kernel': (8, 8), 'ln': True},
    {'name': 'conv', 'size': 32, 'act': 'elu', 'stride': [2, 2], 'kernel': (4, 4), 'ln': True},
    {'name': 'reshape', 'shape': -1},
    {'name': 'fc', 'size': 250, 'act': 'elu', 'ln': True},
]

_DEC_84_ATARI = [
    {'name': 'fc', 'size': 250, 'act': 'relu', 'ln': True},
    {'name': 'fc', 'size': 10*10*32, 'act': 'relu', 'ln': True},
    {'name': 'reshape', 'shape': (10, 10, 32)},
    {'name': 'r_conv', 'size': 16, 'act': 'relu', 'stride': [2, 2], 'kernel': (4, 4), 'ln': True, 'offset': 1},
    {'name': 'r_conv', 'size': 1, 'act': 'sigmoid', 'stride': [4, 4], 'kernel': (8, 8)},
    {'name': 'reshape', 'shape': -1},
]

net.add_named_config('enc_dec_84_atari', {'input': _ENC_84_ATARI, 'output': _DEC_84_ATARI})


# Named configs: recurrent cells (each overrides `recurrent`)

def _single_recurrent(cell_name):
    """Return a `recurrent` spec with one 250-unit, layer-normed, sigmoid cell of type `cell_name`."""
    return {'recurrent': [{'name': cell_name, 'size': 250, 'act': 'sigmoid', 'ln': True}]}


net.add_named_config('rnn_250', _single_recurrent('rnn'))
# NOTE: build_network ignores 'act' for LSTMs (it prints a warning and keeps tanh).
net.add_named_config('lstm_250', _single_recurrent('lstm'))

def _r_nem_recurrent(with_attention=True, with_actions=False):
    """Return a fresh `recurrent` spec holding one 250-unit R-NEM layer.

    All R-NEM variants share the same encoder/core/context branches (a single
    layer-normed 250-unit relu fc layer each). They differ only in whether the
    attention branch (tanh fc 100 -> sigmoid fc 1) is present and whether an
    action-embedding branch (relu fc 10) is added. New dicts/lists are built
    on every call, so configs never share mutable objects.
    """
    def fc(size, act, ln=True):
        layer = {'name': 'fc', 'size': size, 'act': act}
        if ln:
            layer['ln'] = True
        return layer

    attention = [fc(100, 'tanh'), fc(1, 'sigmoid', ln=False)] if with_attention else []
    spec = {'name': 'r_nem', 'size': 250, 'act': 'sigmoid', 'ln': True,
            'encoder': [fc(250, 'relu')],
            'core': [fc(250, 'relu')],
            'context': [fc(250, 'relu')],
            'attention': attention}
    if with_actions:
        spec['actions'] = [fc(10, 'relu')]
    return {'recurrent': [spec]}


net.add_named_config('r_nem', _r_nem_recurrent())


net.add_named_config('r_nem_no_attention', _r_nem_recurrent(with_attention=False))


net.add_named_config('r_nem_actions', _r_nem_recurrent(with_actions=True))


# GENERIC WRAPPERS

class InputWrapper(RNNCell):
    """Wrap ``cell`` so its inputs first go through a linear projection.

    ``spec`` is a layer dict. ``{'name': 'fc', 'size': N}`` applies a
    fully-connected layer with N outputs; ``{'name': 'conv', 'size': N,
    'kernel': ..., 'stride': ...}`` applies a 2-D convolution. No activation
    is used here (``activation_fn=None``); any other keys in ``spec`` are
    ignored by this wrapper.
    """

    def __init__(self, cell, spec, name="InputWrapper"):
        self._name = name
        self._spec = spec
        self._cell = cell

    @property
    def state_size(self):
        """Forwarded unchanged from the wrapped cell."""
        return self._cell.state_size

    @property
    def output_size(self):
        """Forwarded unchanged from the wrapped cell."""
        return self._cell.output_size

    def __call__(self, inputs, state, scope=None):
        """Project ``inputs`` inside variable scope ``scope`` (or the wrapper name), then run the cell.

        Raises:
            ValueError: if ``spec['name']`` is neither 'fc' nor 'conv'.
        """
        spec = self._spec
        with tf.variable_scope(scope or self._name):
            layer_type = spec['name']
            if layer_type == 'conv':
                projected = slim.conv2d(inputs, spec['size'], spec['kernel'], spec['stride'],
                                        activation_fn=None)
            elif layer_type == 'fc':
                projected = slim.fully_connected(inputs, spec['size'], activation_fn=None)
            else:
                raise ValueError('Unknown layer name "{}"'.format(layer_type))

        return self._cell(projected, state)


class OutputWrapper(RNNCell):
    """Wrap ``cell`` so its output goes through a linear projection.

    ``spec['name']`` selects the projection:
      * 'fc': fully-connected layer with ``spec['size']`` outputs;
      * 'r_conv': nearest-neighbour resize of the output's spatial dims
        (indices 1 and 2) to ``stride * size + offset`` (``offset`` defaults
        to 0), followed by a convolution with ``spec['size']`` filters of
        size ``spec['kernel']``. Assumes a rank-4 (batch, height, width,
        channels) output with static spatial dims -- TODO confirm.
    No activation is applied (``activation_fn=None``). ``n_out`` is stored
    but not used by this class.
    """

    def __init__(self, cell, spec, n_out=1, name="OutputWrapper"):
        self._cell = cell
        self._spec = spec
        self._name = name
        self._n_out = n_out

    @property
    def state_size(self):
        """Forwarded unchanged from the wrapped cell."""
        return self._cell.state_size

    @property
    def output_size(self):
        # NOTE(review): for 'r_conv' this is only the channel count, not the
        # full output size -- confirm whether any caller relies on it.
        return self._spec['size']

    def __call__(self, inputs, state, scope=None):
        """Run the wrapped cell, then project its output; the state passes through.

        Raises:
            ValueError: if ``spec['name']`` is neither 'fc' nor 'r_conv'.
        """
        cell_output, new_state = self._cell(inputs, state)
        spec = self._spec

        with tf.variable_scope(scope or self._name):
            layer_type = spec['name']
            if layer_type == 'r_conv':
                stride_h, stride_w = spec['stride'][0], spec['stride'][1]
                extra = spec.get('offset', 0)
                static_shape = cell_output.get_shape()
                target_size = (stride_h * static_shape[1].value + extra,
                               stride_w * static_shape[2].value + extra)
                # upsample first, then a stride-1 convolution ("resize-conv")
                upsampled = tf.image.resize_images(cell_output, target_size,
                                                   method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
                projected = slim.layers.conv2d(upsampled, spec['size'], spec['kernel'], activation_fn=None)
            elif layer_type == 'fc':
                projected = slim.fully_connected(cell_output, spec['size'], activation_fn=None)
            else:
                raise ValueError('Unknown layer name "{}"'.format(layer_type))

        return projected, new_state


class ReshapeWrapper(RNNCell):
    """Reshape the input, output or state of ``cell`` while keeping the batch axis.

    Args:
        cell: the wrapped cell.
        shape: target shape without the batch dimension (list, tuple or a
            single int), or -1 / 'flatten' to flatten everything after the
            batch dimension.
        apply_to: 'input' (reshape before calling the cell), 'output' or
            'state' (reshape what the cell returns).

    ``state_size`` and ``output_size`` are forwarded from ``cell`` unchanged,
    i.e. they do not reflect the reshape.
    """

    def __init__(self, cell, shape='flatten', apply_to='output'):
        self._cell = cell
        self._shape = shape
        self._apply_to = apply_to

    @property
    def state_size(self):
        return self._cell.state_size

    @property
    def output_size(self):
        return self._cell.output_size

    def _reshape(self, tensor, batch_size):
        """Flatten ``tensor`` or reshape it to ``[batch_size] + shape``."""
        shape = self._shape
        # The default 'flatten' previously fell through to tf.reshape and
        # raised a TypeError; treat it exactly like -1.
        if shape == -1 or shape == 'flatten':
            return slim.flatten(tensor)
        # Config literals use tuples (e.g. (64, 64, 1)) and `list + tuple`
        # raises a TypeError, so normalise to a list (a bare int is one dim).
        dims = [shape] if isinstance(shape, int) else list(shape)
        return tf.reshape(tensor, [batch_size] + dims)

    def __call__(self, inputs, state, scope=None):
        """Apply the reshape to the configured target and run the cell.

        Raises:
            ValueError: if ``apply_to`` is not 'input', 'output' or 'state'.
        """
        if self._apply_to not in ('input', 'output', 'state'):
            raise ValueError('Unknown apply_to: "{}"'.format(self._apply_to))

        # The batch size is always read from the incoming inputs; this assumes
        # the wrapped cell keeps the batch dimension -- TODO confirm for new cells.
        batch_size = tf.shape(inputs)[0]

        if self._apply_to == 'input':
            return self._cell(self._reshape(inputs, batch_size), state)

        output, res_state = self._cell(inputs, state)
        if self._apply_to == 'output':
            output = self._reshape(output, batch_size)
        else:
            res_state = self._reshape(res_state, batch_size)
        return output, res_state


class ActivationFunctionWrapper(RNNCell):
    """Apply an activation from ``ACTIVATION_FUNCTIONS`` around ``cell``.

    ``apply_to`` picks what is transformed: 'input' (before calling the
    cell), or 'output' / 'state' (after calling it). The activation is
    looked up at construction time, so an unknown name raises ``KeyError``
    there; an unknown ``apply_to`` raises ``ValueError`` when the cell is
    called.
    """

    def __init__(self, cell, activation='linear', apply_to='output'):
        self._cell = cell
        self._activation = ACTIVATION_FUNCTIONS[activation]
        self._apply_to = apply_to

    @property
    def state_size(self):
        """Forwarded unchanged from the wrapped cell."""
        return self._cell.state_size

    @property
    def output_size(self):
        """Forwarded unchanged from the wrapped cell."""
        return self._cell.output_size

    def __call__(self, inputs, state, scope=None):
        target = self._apply_to
        if target not in ('input', 'output', 'state'):
            raise ValueError('Unknown apply_to: "{}"'.format(target))

        if target == 'input':
            return self._cell(self._activation(inputs), state)

        output, res_state = self._cell(inputs, state)
        if target == 'output':
            return self._activation(output), res_state
        return output, self._activation(res_state)


class LayerNormWrapper(RNNCell):
    """Apply ``slim.layer_norm`` to the input, output or state of ``cell``.

    The layer-norm variables live in variable scope ``scope`` (when given to
    ``__call__``) or ``name``. An unknown ``apply_to`` raises ``ValueError``
    when the cell is called, before the wrapped cell runs.
    """

    def __init__(self, cell, apply_to='output', name="LayerNorm"):
        self._cell = cell
        self._name = name
        self._apply_to = apply_to

    @property
    def state_size(self):
        """Forwarded unchanged from the wrapped cell."""
        return self._cell.state_size

    @property
    def output_size(self):
        """Forwarded unchanged from the wrapped cell."""
        return self._cell.output_size

    def __call__(self, inputs, state, scope=None):
        target = self._apply_to
        if target not in ('input', 'output', 'state'):
            raise ValueError('Unknown apply_to: "{}"'.format(target))

        norm_scope = scope or self._name
        if target == 'input':
            # normalise first, then call the wrapped cell outside the scope
            with tf.variable_scope(norm_scope):
                normed_inputs = slim.layer_norm(inputs)
            return self._cell(normed_inputs, state)

        output, res_state = self._cell(inputs, state)
        with tf.variable_scope(norm_scope):
            if target == 'output':
                output = slim.layer_norm(output)
            else:
                res_state = slim.layer_norm(res_state)
        return output, res_state
            

# R-NEM CELL

class R_NEM(RNNCell):
    """Relational recurrent cell (the inner RNN of R-NEM).

    The K groups of every example are stacked along the leading axis:
    ``inputs`` is (B*K, M) and ``state`` is (B*K, H). ``__call__`` reshapes to
    (B, K, ...), so row ``b * K + k`` must hold group ``k`` of example ``b``.

    One step:
      1. (optional) embed the action and append it to the state;
      2. encode the state with ``encoder`` -> (B*K, h1);
      3. pair each group's encoding ("focus") with the encoding of each of the
         K-1 other groups ("context") -> (B*K*(K-1), 2*h1);
      4. run every pair through the shared ``core`` network;
      5. map the core output through ``context`` -> (B*K, K-1, h2) and, if
         ``attention`` layers are given, weight each pair by the output of
         the attention branch (also fed by the core output);
      6. sum over the K-1 partners -> (B*K, h2);
      7. concatenate [encoding, summed effect, inputs (, action embedding)]
         and apply a linear layer with ``size`` units -> new state.

    Args:
        encoder, core, context: lists of layer specs for ``_build_layer``;
            an empty list makes that stage the identity.
        attention: list of layer specs; empty disables attention weighting.
        actions: list of layer specs for the action embedding, or an
            empty/None value to disable actions.
        size: size of the new state / output.
        K: number of groups; must be greater than 1.
        name: variable scope used when ``__call__`` gets no ``scope``.

    Raises:
        ValueError: if ``K <= 1``.
    """

    def __init__(self, encoder, core, context, attention, actions, size, K, name='NPE'):
        # explicit check instead of ``assert`` so it is not stripped by ``python -O``
        if K <= 1:
            raise ValueError('R_NEM requires K > 1, got K={}'.format(K))

        self._encoder = encoder
        self._core = core
        self._context = context
        self._attention = attention
        self._actions = actions

        self._size = size
        self._K = K
        self._name = name

    @property
    def state_size(self):
        return self._size

    @property
    def output_size(self):
        return self._size

    def get_shapes(self, inputs):
        """Return ``(B, K, M)`` for ``inputs`` of shape (B*K, M).

        B and M are scalar tensors (from ``tf.shape``); K is the Python int
        given to the constructor.
        """
        bk = tf.shape(inputs)[0]
        m = tf.shape(inputs)[1]

        return bk // self._K, self._K, m

    def __call__(self, inputs, state, scope=None):
        """Advance all B*K groups by one step.

        Args:
            inputs: (B*K, M) tensor.
            state: (B*K, H) tensor or, when ``actions`` layers are configured,
                a dict with keys 'state' (that tensor) and 'action'.
            scope: optional variable scope name (defaults to the cell name).

        Returns:
            ``(new_state, new_state)``: the linear (B*K, size) update, used as
            both output and next state.
        """
        with tf.variable_scope(scope or self._name):
            b, k, _ = self.get_shapes(inputs)

            # 1. optional action embedding, appended to the state
            action_embedding = None
            if self._actions:
                action = state['action']
                state = state['state']
                action_embedding = self._embed_actions(action, k)        # (b*k, e)
                state = tf.concat((state, action_embedding), axis=1)

            # 2. encode theta
            state1 = self._apply_layers(state, self._encoder)           # (b*k, h1)
            h1 = state1.get_shape().as_list()[1]

            # 3. + 4. all (focus, context) pairs through the shared core
            pairs = self._pair_with_others(state1, b, k, h1)            # (b*k*(k-1), 2*h1)
            core_out = self._apply_layers(pairs, self._core)

            # 5. context branch; h2 comes from the static shape, so an empty
            # context (or core) stack is handled as well
            context = self._apply_layers(core_out, self._context)
            h2 = context.get_shape().as_list()[1]
            contextr = tf.reshape(context, [b * k, k - 1, h2])         # (b*k, k-1, h2)

            # 6. effect = sum over partners of (context [* attention])
            if self._attention:
                attention = self._apply_layers(core_out, self._attention)
                attentionr = tf.reshape(attention, [b * k, k - 1, 1])
                effectrsum = tf.reduce_sum(contextr * attentionr, axis=1)  # (b*k, h2)
            else:
                effectrsum = tf.reduce_sum(contextr, axis=1)               # (b*k, h2)

            # 7. new state from [encoding, effect, input (, action embedding)]
            parts = [state1, effectrsum, inputs]
            if action_embedding is not None:
                parts.append(action_embedding)
            total = tf.concat(parts, axis=1)                           # (b*k, h1 + h2 + m [+ e])

            new_state = slim.fully_connected(total, self._size, activation_fn=None)  # (b*k, size)

            return new_state, new_state

    def _embed_actions(self, action, k):
        """Embed ``action[:, 0]`` and repeat it for each of the k groups of its example.

        Returns a (b*k, e) tensor laid out batch-major (row ``b_idx * k + i``),
        matching the layout of ``state``.
        """
        # NOTE(review): assumes `action` has one entry per example along its
        # first axis, i.e. (B, ...) -- TODO confirm with the caller.
        embedding = self._apply_layers(action[:, 0], self._actions)    # (b, e)
        embed_size = embedding.get_shape().as_list()[-1]
        # Repeat every example's embedding k times consecutively. A plain
        # tf.tile(embedding, [k, 1]) repeats the whole batch k times
        # (row i * b + b_idx), handing groups the action of a different
        # example whenever b > 1.
        embedding = tf.tile(tf.expand_dims(embedding, axis=1), [1, k, 1])  # (b, k, e)
        return tf.reshape(embedding, [-1, embed_size])                 # (b*k, e)

    @staticmethod
    def _pair_with_others(state1, b, k, h1):
        """Pair each group's encoding with the encodings of the other k-1 groups.

        Returns a (b*k*(k-1), 2*h1) tensor whose rows are [focus, context];
        for group i the partners appear in increasing group order, skipping i.
        """
        state1r = tf.reshape(state1, [b, k, h1])                       # (b, k, h1)

        # focus: each group's encoding repeated once per partner
        focus = tf.tile(tf.reshape(state1r, [b, k, 1, h1]), [1, 1, k - 1, 1])  # (b, k, k-1, h1)

        # context: for group i, the encodings of every group j != i
        per_group = tf.unstack(state1r, axis=1)                        # k x (b, h1)
        context = tf.stack(
            [tf.stack([per_group[j] for j in range(k) if j != i], axis=1) for i in range(k)],
            axis=1)                                                    # (b, k, k-1, h1)

        # all k*(k-1) pairs go through the same network, so fold them into the batch
        focus = tf.reshape(focus, [b * k * (k - 1), h1])
        context = tf.reshape(context, [b * k * (k - 1), h1])
        return tf.concat([focus, context], axis=1)

    @classmethod
    def _apply_layers(cls, inputs, layers):
        """Feed ``inputs`` through each spec in ``layers`` in order (identity if empty)."""
        out = inputs
        for layer in layers:
            out = cls._build_layer(out, layer)
        return out

    @staticmethod
    def _build_layer(inputs, layer):
        """Apply one layer spec: fc, then optional layer norm, then optional activation.

        ``layer`` needs 'name' (only 'fc' is supported) and 'size'; 'ln' (bool)
        and 'act' (a key of ACTIVATION_FUNCTIONS) are optional.

        Raises:
            KeyError: for an unknown layer name.
        """
        # apply transformation
        if layer['name'] == 'fc':
            out = slim.fully_connected(inputs, layer['size'], activation_fn=None)
        else:
            raise KeyError('Unknown layer "{}"'.format(layer['name']))

        # apply layer normalisation
        if layer.get('ln', False):
            out = slim.layer_norm(out)

        # apply activation function
        if layer.get('act', False):
            out = ACTIVATION_FUNCTIONS[layer['act']](out)

        return out


# NETWORK BUILDER

def _wrap_recurrent(cell, layer, index):
    """Add optional output layer norm plus the layer's activation on state and output."""
    if layer.get('ln'):
        cell = LayerNormWrapper(cell, apply_to='output', name='LayerNormR{}'.format(index))
    cell = ActivationFunctionWrapper(cell, activation=layer['act'], apply_to='state')
    return ActivationFunctionWrapper(cell, activation=layer['act'], apply_to='output')


def _build_recurrent_cell(K, index, layer):
    """Build the (wrapped) recurrent cell described by one `recurrent` config entry.

    Raises:
        ValueError: for an unknown recurrent layer name.
    """
    kind = layer['name']
    if kind == 'rnn':
        # linear core; layer norm and the activation are added by the wrappers
        cell = tf.contrib.rnn.BasicRNNCell(layer['size'], activation=ACTIVATION_FUNCTIONS['linear'])
        return _wrap_recurrent(cell, layer, index)

    if kind == 'lstm':
        cell = tf.contrib.rnn.LayerNormBasicLSTMCell(layer['size'], layer_norm=layer.get('ln', False))

        if layer.get('act'):
            print("WARNING: activation function arg for LSTM Cell is ignored. Default (tanh) is used in stead.")
        return cell

    if kind == 'r_nem':
        cell = R_NEM(encoder=layer['encoder'],
                     core=layer['core'],
                     context=layer['context'],
                     attention=layer['attention'],
                     actions=layer.get('actions', None),
                     size=layer['size'],
                     K=K)
        return _wrap_recurrent(cell, layer, index)

    raise ValueError('Unknown recurrent name "{}"'.format(kind))


@net.capture
def build_network(K, input, recurrent, output):
    """Compose the whole network into a single RNNCell.

    The arguments are injected by sacred from the 'network' config, which is
    why the parameter is called ``input`` (a config key) despite shadowing
    the builtin.

    Args:
        K: number of groups; forwarded to R_NEM.
        input: layer specs applied, first to last, to the cell's inputs:
            'reshape' entries, or projections ('fc'/'conv') with 'act' and
            optional 'ln'.
        recurrent: recurrent layer specs ('rnn', 'lstm' or 'r_nem'). Every
            entry is built, but each one replaces the previous, so only the
            last entry ends up in the network.
        output: layer specs applied, first to last, to the recurrent output:
            'reshape' entries, or projections ('fc'/'r_conv') with 'act' and
            optional 'ln'.

    Returns:
        The fully wrapped RNNCell.

    Raises:
        ValueError: if ``recurrent`` is empty or contains an unknown layer name.
    """
    # Without this check an empty list ended in an UnboundLocalError on `cell`.
    if not recurrent:
        raise ValueError('build_network needs at least one "recurrent" layer, got none')

    with tf.name_scope('inner_RNN'):
        # build recurrent
        for i, layer in enumerate(recurrent):
            cell = _build_recurrent_cell(K, i, layer)

        # build input: wrap from the last spec outwards so the first spec is
        # the outermost wrapper and sees the raw inputs first
        for i, layer in reversed(list(enumerate(input))):
            if layer['name'] == 'reshape':
                cell = ReshapeWrapper(cell, layer['shape'], apply_to='input')
            else:
                # wrapped innermost-first, so at call time the order is
                # projection -> layer norm -> activation
                cell = ActivationFunctionWrapper(cell, layer['act'], apply_to='input')
                if layer.get('ln'):
                    cell = LayerNormWrapper(cell, apply_to='input', name='LayerNormI{}'.format(i))
                cell = InputWrapper(cell, layer, name="InputWrapper{}".format(i))

        # build output: each wrapper post-processes what the one inside it
        # returns, so at call time the order is projection -> layer norm -> activation
        for i, layer in enumerate(output):
            if layer['name'] == 'reshape':
                cell = ReshapeWrapper(cell, layer['shape'])
            else:
                n_out = layer.get('n_out', 1)
                cell = OutputWrapper(cell, layer, n_out=n_out, name="OutputWrapper{}".format(i))
                if layer.get('ln'):
                    cell = LayerNormWrapper(cell, apply_to='output', name='LayerNormO{}'.format(i))
                cell = ActivationFunctionWrapper(cell, layer['act'], apply_to='output')

        return cell