# coding: utf-8

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
from utils import dtype


def _zero_variables(variables, name=None):
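    """Return a grouped op that resets every variable in `variables` to zeros."""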
    ops = []

    for var in variables:
        with tf.device(var.device):
            op = var.assign(tf.zeros_like(var))
        ops.append(op)

    return tf.group(*ops, name=name or "zero_variables")


def _replicate_variables(variables, device=None, suffix="Replica"):
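    """Create non-trainable, zero-initialized copies of `variables`.

    Each replica is placed on `device` if given, otherwise on the device
    of the variable it mirrors, and is named "<var_name>/<suffix>".
    """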
    new_vars = []

    for var in variables:
        # Fall back to each variable's own device when no device is given;
        # a local name keeps one variable's device from leaking to the next.
        target_device = device if device is not None else var.device
        with tf.device(target_device):
            name = var.op.name + "/{}".format(suffix)
            new_vars.append(tf.Variable(tf.zeros_like(var),
                                        name=name, trainable=False))

    return new_vars


def _collect_gradients(gradients, variables):
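    """Return an op that adds `gradients` into the accumulator `variables`.

    Dense gradients (`tf.Tensor`) are added with `assign_add`; sparse
    gradients (`tf.IndexedSlices`) are added with `scatter_add`.
    """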
    ops = []

    for grad, var in zip(gradients, variables):
        if isinstance(grad, tf.Tensor):
            ops.append(tf.assign_add(var, grad))
        else:
            ops.append(tf.scatter_add(var, grad.indices, grad.values))

    return tf.group(*ops, name="collect_gradients")


def create_train_op(named_scalars, grads_and_vars, optimizer, global_step, params):
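    """Build the training ops, with optional gradient accumulation and EMA.

    Args:
        named_scalars: dict mapping names to scalar tensors (e.g. losses)
            that are accumulated alongside the gradients.
        grads_and_vars: list of (gradient, variable) pairs, as returned by
            an optimizer's `compute_gradients`.
        optimizer: a `tf.train.Optimizer` instance.
        global_step: the global step variable.
        params: hyper-parameter object; `update_cycle`, `clip_grad_norm`
            and `ema_decay` are read here.

    Returns:
        A `(scalars, ops)` pair. `scalars` extends `named_scalars` with
        "gradient_norm" and "parameter_norm". `ops` holds "zero_op",
        "collect_op" and "train_op", plus "ema_backup_op",
        "ema_restore_op" and "ema_assign_op" when `params.ema_decay > 0`.
    """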
    # Create new variables under the configured default float type
    tf.get_variable_scope().set_dtype(tf.as_dtype(dtype.floatx()))

    gradients = [item[0] for item in grads_and_vars]
    variables = [item[1] for item in grads_and_vars]

    if params.update_cycle == 1:
        # No gradient accumulation: both ops degenerate to no-ops
        zero_variables_op = tf.no_op("zero_variables")
        collect_op = tf.no_op("collect_op")
    else:
        named_vars = {}
        for name in named_scalars:
            named_var = tf.Variable(tf.zeros([], dtype=tf.float32),
                                    name="{}/CTrainOpReplica".format(name),
                                    trainable=False)
            named_vars[name] = named_var
        count_var = tf.Variable(tf.zeros([], dtype=tf.as_dtype(dtype.floatx())),
                                name="count/CTrainOpReplica",
                                trainable=False)
        slot_variables = _replicate_variables(variables, suffix="CTrainOpReplica")
        zero_variables_op = _zero_variables(
            slot_variables + [count_var] + list(named_vars.values()))

        collect_ops = []
        # collect gradients
        collect_grads_op = _collect_gradients(gradients, slot_variables)
        collect_ops.append(collect_grads_op)

        # collect other scalars
        for name in named_scalars:
            scalar = named_scalars[name]
            named_var = named_vars[name]
            collect_op = tf.assign_add(named_var, scalar)
            collect_ops.append(collect_op)
        # collect the accumulation-step counter
        collect_count_op = tf.assign_add(count_var, 1.0)
        collect_ops.append(collect_count_op)

        collect_op = tf.group(*collect_ops, name="collect_op")
        # Average over all accumulated steps; the "+ 1.0" accounts for the
        # current batch, whose gradient has not been added to the slots yet.
        scale = 1.0 / (tf.cast(count_var, tf.float32) + 1.0)
        gradients = [scale * (g + s)
                     for (g, s) in zip(gradients, slot_variables)]

        for name in named_scalars:
            named_scalars[name] = scale * (
                    named_scalars[name] + named_vars[name])

    grad_norm = tf.global_norm(gradients)
    param_norm = tf.global_norm(variables)

    # Gradient clipping, applied only when clip_grad_norm is a positive float
    if isinstance(params.clip_grad_norm, float) and params.clip_grad_norm > 0.0:
        gradients, _ = tf.clip_by_global_norm(gradients,
                                              params.clip_grad_norm,
                                              use_norm=grad_norm)

    # Update variables
    grads_and_vars = list(zip(gradients, variables))
    train_op = optimizer.apply_gradients(grads_and_vars, global_step)

    ops = {
        "zero_op": zero_variables_op,
        "collect_op": collect_op,
        "train_op": train_op
    }

    # Apply an exponential moving average (EMA) to the trainable variables
    if params.ema_decay > 0.:
        tf.logging.info("Using exponential moving average to train the model "
                        "with decay {}.".format(params.ema_decay))
        ema = tf.train.ExponentialMovingAverage(decay=params.ema_decay, num_updates=global_step)
        ema_op = ema.apply(variables)
        with tf.control_dependencies([ops["train_op"]]):
            ops["train_op"] = tf.group(ema_op)
        bck_vars = _replicate_variables(variables, suffix="CTrainOpBackUpReplica")

        # Back up the raw weights, restore them, or overwrite them with the
        # EMA shadow values (e.g. around evaluation).
        ops["ema_backup_op"] = tf.group(*(tf.assign(bck, var.read_value())
                                          for bck, var in zip(bck_vars, variables)))
        ops["ema_restore_op"] = tf.group(*(tf.assign(var, bck.read_value())
                                           for bck, var in zip(bck_vars, variables)))
        ops["ema_assign_op"] = tf.group(*(tf.assign(var, ema.average(var).read_value())
                                          for var in variables))

    # Note: this extends the caller's `named_scalars` dict in place
    ret = named_scalars
    ret.update({
        "gradient_norm": grad_norm,
        "parameter_norm": param_norm,
    })

    return ret, ops
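
# A minimal usage sketch, not part of the module: `loss`, `optimizer`,
# `global_step`, `sess` and a `params` object carrying `update_cycle`,
# `clip_grad_norm` and `ema_decay` are assumed to exist.
#
#     grads_and_vars = optimizer.compute_gradients(loss)
#     scalars, ops = create_train_op({"loss": loss}, grads_and_vars,
#                                    optimizer, global_step, params)
#
#     sess.run(ops["zero_op"])                  # reset the accumulators
#     for _ in range(params.update_cycle - 1):  # accumulate micro-batches
#         sess.run(ops["collect_op"])
#     sess.run(ops["train_op"])                 # apply the averaged update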