/*
 * WekaDeeplearning4j is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * WekaDeeplearning4j is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with WekaDeeplearning4j.  If not, see <https://www.gnu.org/licenses/>.
 *
 * RnnOutputLayer.java
 * Copyright (C) 2017-2018 University of Waikato, Hamilton, New Zealand
 */

package weka.dl4j.layers;

import java.io.Serializable;
import java.util.Enumeration;
import org.nd4j.linalg.lossfunctions.ILossFunction;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.OptionMetadata;
import weka.dl4j.activations.ActivationSoftmax;
import weka.dl4j.lossfunctions.LossFunction;
import weka.dl4j.lossfunctions.LossMCXENT;

/**
 * A version of DeepLearning4j's RnnOutputLayer that implements WEKA option handling.
 *
 * <p>The loss function is stored in the wrapped DeepLearning4j backend layer and exposed here
 * through a WEKA-style getter/setter pair annotated with {@link OptionMetadata}.
 *
 * @author Steven Lang
 */
public class RnnOutputLayer extends
    FeedForwardLayer<org.deeplearning4j.nn.conf.layers.RnnOutputLayer>
    implements OptionHandler, Serializable {

  /**
   * SerialVersionUID
   */
  private static final long serialVersionUID = 6333123285853131216L;

  /**
   * Constructor for setting some defaults: the layer name "RnnOutput layer", a softmax
   * activation, the {@link LossMCXENT} loss function and two outputs.
   */
  public RnnOutputLayer() {
    super();
    // NOTE(review): the setters below need an initialized backend; presumably the super
    // constructor triggers initializeBackend() -- confirm against FeedForwardLayer/Layer.
    setLayerName("RnnOutput layer");
    setActivationFunction(new ActivationSoftmax());
    setLossFn(new LossMCXENT());
    setNOut(2);
  }

  /**
   * Replaces the backend with a fresh DeepLearning4j RnnOutputLayer configuration.
   */
  @Override
  public void initializeBackend() {
    backend = new org.deeplearning4j.nn.conf.layers.RnnOutputLayer();
  }

  /**
   * Global info.
   *
   * @return string describing this class.
   */
  public String globalInfo() {
    return "A RnnOutput layer from DeepLearning4J.";
  }

  /**
   * Gets the loss function currently configured on the backend layer.
   *
   * @return the backend's loss function, wrapped via {@code LossFunction.create}
   */
  @OptionMetadata(
      displayName = "loss function",
      description = "The loss function to use (default = LossMCXENT).",
      commandLineParamName = "lossFn",
      commandLineParamSynopsis = "-lossFn <specification>",
      displayOrder = 1
  )
  public LossFunction<? extends ILossFunction> getLossFn() {
    return LossFunction.create(backend.getLossFn());
  }

  /**
   * Sets the loss function on the backend layer.
   *
   * @param lossFn wrapper whose backend loss function is handed to the backend layer
   */
  public void setLossFn(LossFunction<? extends ILossFunction> lossFn) {
    backend.setLossFn(lossFn.getBackend());
  }

  /**
   * Returns an enumeration describing the available options.
   *
   * @return an enumeration of all the available options.
   */
  @Override
  public Enumeration<Option> listOptions() {
    // Object.getClass() is final and always yields the runtime class, so the previous
    // "super.getClass()" was only a misleading spelling of getClass(); behavior is unchanged.
    // NOTE(review): the second argument is presumably the ancestor at which WEKA stops
    // collecting options. Passing the runtime class means the walk is not bounded at
    // FeedForwardLayer -- confirm that is intended before substituting a fixed class literal.
    return Option.listOptionsForClassHierarchy(getClass(), getClass()).elements();
  }

  /**
   * Gets the current settings of this layer.
   *
   * @return an array of strings suitable for passing to setOptions
   */
  @Override
  public String[] getOptions() {
    // getClass() is equivalent to the former super.getClass(); see listOptions().
    return Option.getOptionsForHierarchy(this, getClass());
  }

  /**
   * Parses a given list of options.
   *
   * @param options the list of options as an array of strings
   * @throws Exception if an option is not supported
   */
  @Override
  public void setOptions(String[] options) throws Exception {
    // getClass() is equivalent to the former super.getClass(); see listOptions().
    Option.setOptionsForHierarchy(options, this, getClass());
  }
}