# ------------------------------------------------------------------
# TensorFlow implementation of
#  "Learning Dynamic Memory Networks for Object Tracking", ECCV, 2018
# Licensed under The MIT License [see LICENSE for details]
# Written by Tianyu Yang (tianyu-yang.com)
# ------------------------------------------------------------------

import config
import tensorflow as tf
import numpy as np
from feature import extract_feature
from input import generate_labels_overlap, generate_labels_dist
from memnet.memnet import MemNet
from memnet.rnn import rnn
import collections

class ModeKeys():
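  """Mode strings used throughout this module (mirrors the names in tf.estimator.ModeKeys)."""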
  TRAIN = 'train'
  EVAL = 'eval'
  PREDICT = 'predict'

EstimatorSpec = collections.namedtuple('EstimatorSpec', ['predictions', 'loss', 'dist_error', 'train', 'summary', 'saver'])

def get_cnn_feature(input, reuse, mode):
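    """Extract CNN features for a batch of image patches.

    Inputs with a leading time dimension ([batch, time, H, W, C]) are folded
    into the batch dimension before feature extraction and unfolded afterwards.
    The 'feature_extraction' variable scope is shared across calls via `reuse`.
    """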

    input_shape = input.get_shape().as_list()
    if len(input_shape) > 4:
        input = tf.reshape(input, [-1] + input_shape[2:])

    is_train = (mode == ModeKeys.TRAIN)
    with tf.variable_scope('feature_extraction', reuse=reuse):
        cnn_feature = extract_feature(is_train, input)

    if len(input_shape) > 4:
        cnn_feature_shape = cnn_feature.get_shape().as_list()
        cnn_feature = tf.reshape(cnn_feature, input_shape[0:2]+cnn_feature_shape[1:])

    return cnn_feature

def batch_conv(A, B, mode):
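    """Correlate each search feature map with its own per-example filter.

    A holds the search feature maps and B the filters; any leading time
    dimension is folded into the batch. tf.map_fn runs one VALID conv2d per
    example, the responses are batch-normalized, and the singleton channel
    dimension is squeezed away.
    """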

    a_shape = A.get_shape().as_list()
    if len(a_shape) > 4:
        A = tf.reshape(A, [-1] + a_shape[2:])
    b_shape = B.get_shape().as_list()
    if len(b_shape) > 4:
        B = tf.reshape(B, [-1] + b_shape[2:])
    batch_size = A.get_shape().as_list()[0]

    output = tf.map_fn(lambda inputs: tf.nn.conv2d(tf.expand_dims(inputs[0], 0), tf.expand_dims(inputs[1], 3), [1,1,1,1], 'VALID'),
                       elems=[A, B],
                       dtype=tf.float32,
                       parallel_iterations=batch_size)
    is_train = (mode == ModeKeys.TRAIN)
    output = tf.layers.batch_normalization(tf.squeeze(output, [1]), training=is_train, name='bn_response')
    return tf.squeeze(output, [3])

def get_predictions(query_feature, search_feature, mode):
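    """Unroll the memory network over the sequence and produce response maps.

    The memory cell is initialized from the first query frame, the rnn unrolls
    it over (search_feature, query_feature), and the resulting templates are
    correlated with the search features via batch_conv.
    """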

    with tf.variable_scope('mann'):
        mann_cell = MemNet(config.hidden_size, config.memory_size, config.slot_size, True)

    initial_state = mann_cell.initial_state(query_feature[:, 0])

    inputs = (search_feature, query_feature)
    outputs, final_state = rnn(cell=mann_cell, inputs=inputs, initial_state=initial_state)

    response = batch_conv(search_feature, outputs, mode)

    return response


def focal_loss(labels, predictions, gamma=2, epsilon=1e-7, scope=None):
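    """Element-wise focal loss: -(1 - p_t)**gamma * log(p_t + epsilon).

    `predictions` are expected to be probabilities (e.g. sigmoid outputs) and
    `labels` binary targets; no reduction is applied here.
    """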

    with tf.name_scope(scope, "focal_loss", (predictions, labels)) as scope:
        predictions = tf.to_float(predictions)
        labels = tf.to_float(labels)
        predictions.get_shape().assert_is_compatible_with(labels.get_shape())
        preds = tf.where(
            tf.equal(labels, 1), predictions, 1. - predictions)
        losses = -(1. - preds) ** gamma * tf.log(preds + epsilon)
        return losses

def get_loss(outputs, labels, mode):
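    """Build per-pixel response labels and average the loss over the batch.

    Labels come from generate_labels_overlap or generate_labels_dist depending
    on config.label_type; the loss is weighted sigmoid cross-entropy, or the
    focal loss above when config.use_focal_loss is set.
    """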

    if mode == ModeKeys.PREDICT:
        return None
    outputs_shape = outputs.get_shape().as_list()
    if config.label_type == 0:
        labels_response, weights = generate_labels_overlap(np.array(outputs_shape[1:3]), labels)
    else:
        labels_response, weights = generate_labels_dist(outputs_shape[0], np.array(outputs_shape[1:3]))
    if config.use_focal_loss:
        loss = tf.reduce_sum(weights * focal_loss(labels=labels_response, predictions=tf.nn.sigmoid(outputs))) / outputs_shape[0]
    else:
        loss = tf.reduce_sum(weights * tf.nn.sigmoid_cross_entropy_with_logits(labels=labels_response, logits=outputs)) / outputs_shape[0]
    tf.summary.scalar('loss', loss)
    return loss

def get_dist_error(outputs, mode):
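    """Mean squared distance between the response peak and the map center.

    Assumes a square response map whose ground-truth target lies at the
    center, so the argmax location is compared against (W/2, W/2).
    """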

    if mode == ModeKeys.PREDICT:
        return None
    outputs_shape = outputs.get_shape().as_list()
    outputs = tf.reshape(outputs, [outputs_shape[0], -1])
    pred_loc_idx = tf.argmax(outputs, 1)
    loc_x = pred_loc_idx % outputs_shape[1]
    loc_y = pred_loc_idx // outputs_shape[1]
    pred_loc = tf.stack([loc_x, loc_y], 1)
    gt_loc = tf.tile(tf.expand_dims([outputs_shape[1]/2, outputs_shape[1]/2], 0), [outputs_shape[0], 1])
    dist_error = tf.losses.mean_squared_error(predictions=pred_loc, labels=gt_loc)
    tf.summary.scalar('dist_error', dist_error)
    return dist_error

def get_train_op(loss, mode):
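    """Create the Adam training op.

    Uses an exponentially decayed learning rate, L2 regularization over all
    trainable variables, global-norm gradient clipping, and runs the batch-norm
    update ops as a dependency of applying the gradients.
    """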

    if mode != ModeKeys.TRAIN:
        return None

    global_step = tf.train.get_or_create_global_step()
    learning_rate = tf.train.exponential_decay(config.learning_rate, global_step, config.decay_circles, config.lr_decay, staircase=True)
    tf.summary.scalar('learning_rate', learning_rate)

    tvars = tf.trainable_variables()
    regularizer = tf.contrib.layers.l2_regularizer(config.weight_decay)
    regularizer_loss = tf.contrib.layers.apply_regularization(regularizer, tvars)
    loss += regularizer_loss
    grads, _ = tf.clip_by_global_norm(tf.gradients(loss, tvars), config.clip_gradients)
    # optimizer = tf.train.GradientDescentOptimizer(learning_rate)
    optimizer = tf.train.AdamOptimizer(learning_rate)

    batchnorm_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(batchnorm_update_ops):
        train_op = optimizer.apply_gradients(zip(grads, tvars), global_step)

    return train_op

def get_summary(mode):

    if mode == ModeKeys.PREDICT:
        return None
    return tf.summary.merge_all()

def get_saver():

    return tf.train.Saver(tf.global_variables(), max_to_keep=15)

def model_fn(features, labels, mode):
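    """Assemble the training/eval graph and return its handles as an EstimatorSpec.

    The returned namedtuple bundles the response maps, loss, distance error,
    train op, merged summaries, and a Saver, mirroring tf.estimator.EstimatorSpec.
    """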
    # get cnn feature for query and search
    query_feature = get_cnn_feature(features['query'], None, mode)
    search_feature = get_cnn_feature(features['search'], True, mode)

    predictions = get_predictions(query_feature, search_feature, mode)
    loss = get_loss(predictions, labels, mode)
    dist_error = get_dist_error(predictions, mode)
    train_op = get_train_op(loss, mode)
    summary = get_summary(mode)
    saver = get_saver()

    return EstimatorSpec(predictions, loss, dist_error, train_op, summary, saver)

def build_initial_state(init_query, mem_cell, mode):
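    """Extract features from the initial query patch and build the memory cell's initial state."""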

    query_feature = get_cnn_feature(init_query, None, mode)
    return mem_cell.initial_state(query_feature[:,0])

def build_model(query, search, mem_cell, initial_state, mode):
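    """Unroll an externally created memory cell over (search, query) features.

    Returns the correlation response maps, a Saver, and the cell's final state.
    """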
    # get cnn feature for query and search
    query_feature = get_cnn_feature(query, True, mode)
    search_feature = get_cnn_feature(search, True, mode)

    inputs = (search_feature, query_feature)
    outputs, final_state = rnn(cell=mem_cell, inputs=inputs, initial_state=initial_state)

    response = batch_conv(search_feature, outputs, mode)
    saver = get_saver()

    return response, saver, final_state


if __name__ == '__main__':
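    # Smoke test: build the training graph with dummy placeholders for
    # query/search patch sequences and 4-D box labels.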
    query_patch = tf.placeholder(tf.float32, [10, 5, config.z_exemplar_size, config.z_exemplar_size, 3])
    search_patch = tf.placeholder(tf.float32, [10, 5, config.x_instance_size, config.x_instance_size, 3])
    features = {
        'query': query_patch,
        'search': search_patch
    }
    labels = tf.placeholder(tf.float32, [10, 5, 4])
    mode = ModeKeys.TRAIN

    esti_spec = model_fn(features, labels, mode)
    pass