/*
 * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.tensorflow.framework.optimizers;

import java.util.List;
import org.tensorflow.Graph;
import org.tensorflow.Operand;
import org.tensorflow.Output;
import org.tensorflow.op.Op;
import org.tensorflow.op.core.Assign;
import org.tensorflow.op.core.Variable;
import org.tensorflow.tools.Shape;
import org.tensorflow.types.TInt64;
import org.tensorflow.types.family.TType;

/**
 * Optimizer that implements the AdaGrad Dual-Averaging (AdaGrad-DA) algorithm.
 * <p>
 * AdaGrad-DA keeps a per-variable accumulator of gradients and of squared gradients, together with
 * a shared global step, and recomputes each variable from these accumulated statistics (optionally
 * applying L1 and L2 regularization) rather than applying incremental deltas.
 * <p>
 * See <a href="http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf">Duchi, Hazan and Singer,
 * "Adaptive Subgradient Methods for Online Learning and Stochastic Optimization"</a>.
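 * <p>
 * A minimal usage sketch (assuming a scalar {@code loss} operand built elsewhere in the same graph,
 * and the {@code minimize} method inherited from {@link Optimizer}):
 * <pre>{@code
 * try (Graph graph = new Graph()) {
 *   Ops tf = Ops.create(graph);
 *   // ... build a model in the graph that produces a scalar `loss` operand ...
 *   AdaGradDA optimizer = new AdaGradDA(graph, 0.01f);
 *   Op trainOp = optimizer.minimize(loss); // run this op once per training step
 * }
 * }</pre>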
 */
public class AdaGradDA extends Optimizer {

  /** Name of the slot that holds the per-variable gradient accumulator. */
  public static final String ACCUMULATOR = "gradient_accumulator";
  /** Name of the slot that holds the per-variable squared-gradient accumulator. */
  public static final String SQUARED_ACCUMULATOR = "gradient_squared_accumulator";
  private final float learningRate;
  private final float initialAccumulatorValue;
  private final float l1Strength;
  private final float l2Strength;
  private Variable<TInt64> globalStep;

  /**
   * Creates an AdaGradDA optimizer with an initial accumulator value of 0.1 and no L1 or L2
   * regularization.
   *
   * @param graph the TensorFlow graph to operate on
   * @param learningRate the learning rate
   */
  public AdaGradDA(Graph graph, float learningRate) {
    this(graph, learningRate, 0.1f, 0.0f, 0.0f);
  }

  /**
   * Creates an AdaGradDA optimizer.
   *
   * @param graph the TensorFlow graph to operate on
   * @param learningRate the learning rate
   * @param initialAccumulatorValue the starting value for the squared-gradient accumulators; it
   *     must be greater than zero
   * @param l1Strength the L1 regularization strength
   * @param l2Strength the L2 regularization strength
   */
  public AdaGradDA(Graph graph, float learningRate, float initialAccumulatorValue, float l1Strength,
      float l2Strength) {
    super(graph);
    this.learningRate = learningRate;
    this.initialAccumulatorValue = initialAccumulatorValue;
    this.l1Strength = l1Strength;
    this.l2Strength = l2Strength;
  }

  /**
   * Creates a named AdaGradDA optimizer with an initial accumulator value of 0.1 and no L1 or L2
   * regularization.
   *
   * @param graph the TensorFlow graph to operate on
   * @param name the name of this optimizer
   * @param learningRate the learning rate
   */
  public AdaGradDA(Graph graph, String name, float learningRate) {
    this(graph, name, learningRate, 0.1f, 0.0f, 0.0f);
  }

  /**
   * Creates a named AdaGradDA optimizer.
   *
   * @param graph the TensorFlow graph to operate on
   * @param name the name of this optimizer
   * @param learningRate the learning rate
   * @param initialAccumulatorValue the starting value for the squared-gradient accumulators; it
   *     must be greater than zero
   * @param l1Strength the L1 regularization strength
   * @param l2Strength the L2 regularization strength
   */
  public AdaGradDA(Graph graph, String name, float learningRate, float initialAccumulatorValue, float l1Strength,
      float l2Strength) {
    super(graph, name);
    this.learningRate = learningRate;
    this.initialAccumulatorValue = initialAccumulatorValue;
    this.l1Strength = l1Strength;
    this.l2Strength = l2Strength;
  }

  /**
   * Creates the gradient and squared-gradient accumulator slots for each variable, along with the
   * shared global step variable and its initializer.
   *
   * @param variables the variables to create slots for
   */
  @Override
  protected void createSlots(List<Output<? extends TType>> variables) {
    for (Output<? extends TType> v : variables) {
      createAdaGradDASlot(v);
    }
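    // A single global step counter is shared by all variables; it is incremented once per applied
    // update in finish().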
    globalStep = tf.withName("adagrad-da-global-step").variable(Shape.scalar(), TInt64.DTYPE);
    Assign<TInt64> globalStepInitializer = tf.assign(globalStep, tf.constant(0L));
    graph.addInitializer(globalStepInitializer);
  }

  /**
   * Creates the two accumulator slots for a single variable: the gradient accumulator, initialized
   * to zero, and the squared-gradient accumulator, initialized to {@code initialAccumulatorValue}.
   *
   * @param <T> the data type of the variable
   * @param v the variable to create slots for
   */
  private <T extends TType> void createAdaGradDASlot(Output<T> v) {
    Operand<T> initializer = tf
        .fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f), v.dataType()));
    createSlot(v.asOutput(), ACCUMULATOR, initializer);
    Operand<T> sqInitializer = tf.fill(tf.shape(v),
        tf.dtypes.cast(tf.constant(initialAccumulatorValue), v.dataType()));
    createSlot(v.asOutput(), SQUARED_ACCUMULATOR, sqInitializer);
  }

  /**
   * Applies the AdaGrad-DA update to a single variable, using its accumulator slots and the shared
   * global step.
   *
   * @param <T> the data type of the variable
   * @param gradient the gradient to apply
   * @param variable the variable to update
   * @return the operation that applies the update
   */
  @Override
  protected <T extends TType> Op applyDense(Output<T> gradient, Output<T> variable) {
    Variable<T> gradSlot = getSlot(variable, ACCUMULATOR).get();
    Variable<T> gradSquaredSlot = getSlot(variable, SQUARED_ACCUMULATOR).get();
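    // The ApplyAdagradDA kernel folds the incoming gradient into both accumulators and then
    // recomputes the variable from the accumulated statistics, the L1/L2 strengths, and the
    // global step.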
    return tf.train.applyAdagradDa(variable, gradSlot, gradSquaredSlot, gradient,
        tf.dtypes.cast(tf.constant(learningRate), gradient.dataType()),
        tf.dtypes.cast(tf.constant(l1Strength), gradient.dataType()),
        tf.dtypes.cast(tf.constant(l2Strength), gradient.dataType()),
        globalStep);
  }

  /**
   * Gathers up the update operations into a single op that can be used as a run target.
   * <p>
   * Adds the global step update to the end of the updates list.
   *
   * @param updateOperations The update operations.
   * @param name             The name of the run target.
   * @return A NoOp with a control dependency on each update operation.
   */
  @Override
  protected Op finish(List<Op> updateOperations, String name) {
    updateOperations.add(tf.assignAdd(globalStep, tf.constant(1L)));
    return super.finish(updateOperations, name);
  }

  @Override
  public String toString() {
    return "AdaGradDA{" +
        "globalStep=" + globalStep +
        ", learningRate=" + learningRate +
        ", initialAccumulatorValue=" + initialAccumulatorValue +
        ", l1Strength=" + l1Strength +
        ", l2Strength=" + l2Strength +
        '}';
  }

  @Override
  public String getOptimizerName() {
    return "adagrad-da";
  }
}