package weka.dl4j.zoo;

import lombok.extern.log4j.Log4j2;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
import org.deeplearning4j.nn.conf.layers.GlobalPoolingLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.transferlearning.FineTuneConfiguration;
import org.deeplearning4j.nn.transferlearning.TransferLearning;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import weka.classifiers.functions.Dl4jMlpClassifier;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.OptionMetadata;
import weka.dl4j.PretrainedType;
import weka.gui.ProgrammaticProperty;

import java.io.Serializable;
import java.util.Enumeration;
import java.util.HashSet;
import java.util.Set;

/**
 * This class contains the logic necessary to load the pretrained weights for a given zoo model.
 *
 * It also handles adding and removing output layers so the model can be trained in DL4J.
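 *
 * <p>A minimal sketch of how a concrete subclass might implement {@code init(...)}, using
 * DL4J's ResNet50 zoo model for illustration (builder parameters may differ between DL4J
 * versions):</p>
 * <pre>{@code
 * public ComputationGraph init(int numLabels, long seed, int[] shape, boolean filterMode) {
 *     org.deeplearning4j.zoo.model.ResNet50 net = org.deeplearning4j.zoo.model.ResNet50.builder()
 *             .numClasses(numLabels)
 *             .seed(seed)
 *             .inputShape(shape)
 *             .build();
 *     ComputationGraph defaultNet = net.init();
 *     return attemptToLoadWeights(net, defaultNet, seed, numLabels, filterMode);
 * }
 * }</pre>
 *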
 * @author Rhys Compton
 */
@Log4j2
public abstract class AbstractZooModel implements OptionHandler, Serializable {

    private static final long serialVersionUID = -4598529061609767660L;

    /**
     * Dataset that the pretrained weights are trained on
     */
    protected PretrainedType m_pretrainedType = PretrainedType.NONE;
    /**
     * Output layer of the model, to be removed and replaced by our custom output layer
     */
    protected String m_outputLayer;
    /**
     * Feature extraction layer of the model, to which we attach our custom output layer
     * and whose activations we use for feature extraction
     */
    protected String m_featureExtractionLayer;
    /**
     * Name of the output layer we attach to the end of the model
     */
    protected String m_predictionLayerName = "weka_predictions";
    /**
     * Names of extraneous layers to remove completely (their weights are not kept)
     */
    protected String[] m_extraLayersToRemove = new String[0];
    /**
     * Number of output dimensions of the feature extraction layer
     */
    protected int m_numFExtractOutputs;
    /**
     * Random seed
     */
    private long seed;
    /**
     * Number of output labels we want to classify
     */
    private long numLabels;
    /**
     * Are we loading this model for use as a feature extractor
     * (in which case no output layer needs to be added)?
     */
    protected boolean filterMode;
    /**
     * Do the activations from the feature extraction layer require pooling
     * (i.e., is their dimensionality too high)?
     */
    protected boolean requiresPooling = false;
    /**
     * Is the model in channels-last order? (i.e., data must come in [minibatch, height, width, channels]
     * instead of the default [minibatch, channels, height, width])
     */
    protected boolean channelsLast = false;

    /**
     * Initialize the zoo model as a ComputationGraph
     *
     * @param numLabels Number of output labels to adjust the output layer for
     * @param seed Random seed
     * @param shape Input shape
     * @param filterMode True if creating the model for feature extraction
     * @return ComputationGraph of the specified zoo model
     * @throws UnsupportedOperationException if init(...) is not supported (currently only by CustomNet)
     */
    public abstract ComputationGraph init(int numLabels, long seed, int[] shape, boolean filterMode)
            throws UnsupportedOperationException;

    /**
     * Get the input shape of this zoo model
     *
     * @return Input shape of this zoo model
     */
    public abstract int[][] getShape();

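    /**
     * Get the variation of this zoo model, if any; the base implementation returns null.
     *
     * @return Variation of this zoo model, or null if the model has no variations
     */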
    public Enum getVariation() {
        return null;
    }

    @OptionMetadata(
            displayName = "Image channels last",
            description = "Set to true to supply image channels last. " +
                    "The default value will usually be correct, so as an end user you shouldn't need to change this setting. " +
                    "If you do be aware that it may break the model.",
            commandLineParamName = "channelsLast",
            commandLineParamSynopsis = "-channelsLast <boolean>"
    )
    public boolean getChannelsLast() {
        return channelsLast;
    }

    public void setChannelsLast(boolean channelsLast) {
        this.channelsLast = channelsLast;
    }

    @ProgrammaticProperty
    public boolean isRequiresPooling() {
        return requiresPooling;
    }

    public void setRequiresPooling(boolean requiresPooling) {
        this.requiresPooling = requiresPooling;
    }

    /**
     * Attempts to load pretrained weights into the given zoo model, falling back to the
     * supplied default network if no pretrained weights are requested or they cannot be loaded
     *
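     * <p>Subclasses are expected to have configured the transfer-learning fields beforehand,
     * for example (layer names and sizes here are illustrative only):</p>
     * <pre>{@code
     * setOutputLayer("fc1000");
     * setFeatureExtractionLayer("flatten_1");
     * setNumFExtractOutputs(2048);
     * setPretrainedType(PretrainedType.IMAGENET);
     * }</pre>
     *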
     * @param zooModel Zoo model family to use
     * @param defaultNet Default ComputationGraph to use if loading weights fails
     * @param seed Random seed to initialize with
     * @param numLabels Number of output labels
     * @param filterMode True if using this zoo model for a filter - output layers don't need to be set up
     * @return ComputationGraph initialized with the pretrained weights if everything succeeds; the default
     *         network if no weights are requested or downloading fails; or null if the requested weights
     *         aren't available for this model
     */
    public ComputationGraph attemptToLoadWeights(org.deeplearning4j.zoo.ZooModel zooModel,
                                                 ComputationGraph defaultNet,
                                                 long seed,
                                                 int numLabels,
                                                 boolean filterMode) {

        this.seed = seed;
        this.numLabels = numLabels;
        this.filterMode = filterMode;

        // If no pretrained weights specified, simply return the standard model
        if (m_pretrainedType == PretrainedType.NONE)
            return finish(defaultNet);

        // If the specified pretrained weights aren't available for this model, give up
        if (!checkPretrained(zooModel)) {
            return null;
        }

        // If downloading the weights fails, return the standard model
        ComputationGraph pretrainedModel = downloadWeights(zooModel);
        if (pretrainedModel == null)
            return finish(defaultNet);

        // If all has gone well, we have the pretrained weights
        return finish(pretrainedModel);
    }

    /**
     * Final processing step applied to the ComputationGraph before it is returned
     * @param computationGraph Input ComputationGraph
     * @return Finalized ComputationGraph
     */
    private ComputationGraph finish(ComputationGraph computationGraph) {
        log.debug(computationGraph.summary());
        return addFinalOutputLayer(computationGraph);
    }

    /**
     * Adds a final output layer if one is needed, applying global pooling beforehand if necessary
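     *
     * <p>For example, given a hypothetical pretrained graph {@code ... -> flatten_1 -> fc1000}
     * with {@code m_featureExtractionLayer = "flatten_1"} and {@code m_outputLayer = "fc1000"},
     * the result is {@code ... -> flatten_1 -> weka_predictions}, where {@code weka_predictions}
     * is a fresh softmax output layer sized to the number of class labels.</p>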
     * @param computationGraph Input ComputationGraph
     * @return Finalized ComputationGraph
     */
    protected ComputationGraph addFinalOutputLayer(ComputationGraph computationGraph) {
        org.deeplearning4j.nn.conf.layers.Layer lastLayer =
                computationGraph.getLayers()[computationGraph.getNumLayers() - 1].conf().getLayer();
        if (!Dl4jMlpClassifier.noOutputLayer(filterMode, lastLayer)) {
            log.debug("No need to add output layer, ignoring");
            return computationGraph;
        }
        try {
            TransferLearning.GraphBuilder graphBuilder;

            if (requiresPooling) {
                graphBuilder = new TransferLearning.GraphBuilder(computationGraph)
                        .fineTuneConfiguration(getFineTuneConfig())
                        .addLayer("intermediate_pooling", new GlobalPoolingLayer.Builder().build(), m_featureExtractionLayer)
                        .addLayer(m_predictionLayerName, createOutputLayer(), "intermediate_pooling")
                        .setOutputs(m_predictionLayerName);
            } else {
                graphBuilder = new TransferLearning.GraphBuilder(computationGraph)
                        .fineTuneConfiguration(getFineTuneConfig())
                        .addLayer(m_predictionLayerName, createOutputLayer(), m_featureExtractionLayer)
                        .setOutputs(m_predictionLayerName);
            }

            // Remove the old output layer, but keep the connections
            graphBuilder.removeVertexKeepConnections(m_outputLayer);
            // Remove any other layers we don't want
            for (String layer : m_extraLayersToRemove) {
                graphBuilder.removeVertexAndConnections(layer);
            }

            log.debug("Finished adding output layer");
            return graphBuilder.build();
        } catch (Exception ex) {
            log.error("Couldn't add the output layer to the model", ex);
            log.error(computationGraph.summary());
            return computationGraph;
        }
    }

    /**
     * Checks whether pretrained weights will be loaded for this model
     *
     * @return True if the current model is pretrained
     */
    public boolean isPretrained() {
        return m_pretrainedType != PretrainedType.NONE;
    }

    /**
     * Creates the fine-tuning configuration used during transfer learning
     * @return Default fine tuning config
     */
    protected FineTuneConfiguration getFineTuneConfig() {
        return new FineTuneConfiguration.Builder()
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .updater(new Nesterovs(5e-5))
                .seed(seed)
                .build();
    }

    /**
     * Attempts to download weights for the given zoo model
     * @param zooModel Model to download weights for
     * @return New ComputationGraph initialized with the given PretrainedType, or null if the download fails
     */
    protected ComputationGraph downloadWeights(org.deeplearning4j.zoo.ZooModel zooModel) {
        try {
            log.info("Downloading {} weights", m_pretrainedType);
            Object pretrained = zooModel.initPretrained(m_pretrainedType.getBackend());
            if (pretrained == null) {
                throw new Exception("Error while initialising model");
            }
            if (pretrained instanceof MultiLayerNetwork) {
                return ((MultiLayerNetwork) pretrained).toComputationGraph();
            } else {
                return (ComputationGraph) pretrained;
            }
        } catch (Exception ex) {
            log.error("Failed to download or initialize the pretrained weights", ex);
            return null;
        }
    }

    /**
     * Creates a new output layer with the correct number of outputs
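     *
     * <p>The custom weight initialization draws from a normal distribution with mean 0 and
     * standard deviation 0.2 * 2 / (4096 + numLabels); with 10 class labels, for example,
     * that is 0.2 * 2 / 4106, roughly 1e-4.</p>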
     * @return Default output layer
     */
    protected OutputLayer createOutputLayer() {
        return new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                .nIn(m_numFExtractOutputs).nOut(numLabels)
                .weightInit(new NormalDistribution(0, 0.2 * (2.0 / (4096 + numLabels)))) //This weight init dist gave better results than Xavier
                .activation(Activation.SOFTMAX).build();
    }

    /**
     * Checks if the zoo model has the specific pretrained type available
     * @param dl4jModelType ZooModel to check
     * @return True if model supports `m_pretrainedType` weights
     */
    protected boolean checkPretrained(org.deeplearning4j.zoo.ZooModel dl4jModelType) {
        Set<PretrainedType> availableTypes = getAvailablePretrainedWeights(dl4jModelType);
        if (availableTypes.isEmpty()) {
            log.error("Sorry, no pretrained weights are available for this model, " +
                    "please explicitly set pretrained type to NONE");
            return false;
        }
        if (!availableTypes.contains(m_pretrainedType) && m_pretrainedType != PretrainedType.NONE){
            log.error("{} weights are not available for this model, please try one of: {}",
                    m_pretrainedType, availableTypes);
            return false;
        }
        return true;
    }

    /**
     * Get all Pretrained types this ZooModel supports
     * @param zooModel ZooModel to check
     * @return Set of pretrained types the model supports
     */
    private Set<PretrainedType> getAvailablePretrainedWeights(org.deeplearning4j.zoo.ZooModel zooModel) {
        Set<PretrainedType> availableTypes = new HashSet<>();
        for (PretrainedType pretrainedType : PretrainedType.values()) {
            if (pretrainedType == PretrainedType.NONE)
                continue;

            if (zooModel.pretrainedAvailable(pretrainedType.getBackend())) {
                availableTypes.add(pretrainedType);
            }
        }
        return availableTypes;
    }

    @OptionMetadata(
            description = "Pretrained Type (IMAGENET, VGGFACE, MNIST)",
            displayName = "Pretrained Type",
            commandLineParamName = "pretrained",
            commandLineParamSynopsis = "-pretrained <string>"
    )
    public PretrainedType getPretrainedType() {
        return m_pretrainedType;
    }

    public void setPretrainedType(PretrainedType pretrainedType) {
        this.m_pretrainedType = pretrainedType;
    }

    @ProgrammaticProperty
    public String getOutputLayer() {
        return m_outputLayer;
    }

    public void setOutputLayer(String outputLayer) {
        this.m_outputLayer = outputLayer;
    }

    @ProgrammaticProperty
    public String getFeatureExtractionLayer() {
        return m_featureExtractionLayer;
    }

    public void setFeatureExtractionLayer(String featureExtractionLayer) {
        this.m_featureExtractionLayer = featureExtractionLayer;
    }

    @ProgrammaticProperty
    public String[] getExtraLayersToRemove() {
        return m_extraLayersToRemove;
    }

    public void setExtraLayersToRemove(String[] extraLayersToRemove) {
        this.m_extraLayersToRemove = extraLayersToRemove;
    }

    @ProgrammaticProperty
    public int getNumFExtractOutputs() {
        return m_numFExtractOutputs;
    }

    public void setNumFExtractOutputs(int numFExtractOutputs) {
        this.m_numFExtractOutputs = numFExtractOutputs;
    }

    /**
     * Returns an enumeration describing the available options.
     *
     * @return an enumeration of all the available options.
     */
    @Override
    public Enumeration<Option> listOptions() {
        return Option.listOptionsForClass(this.getClass()).elements();
    }

    /**
     * Gets the current settings of this zoo model.
     *
     * @return an array of strings suitable for passing to setOptions
     */
    @Override
    public String[] getOptions() {
        return Option.getOptions(this, this.getClass());
    }

    /**
     * Parses a given list of options.
     *
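     * <p>For example, for an instance {@code zooModel} of a concrete subclass, using the option
     * names declared in the {@code @OptionMetadata} annotations above:</p>
     * <pre>{@code
     * zooModel.setOptions(new String[] {"-pretrained", "IMAGENET"});
     * }</pre>
     *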
     * @param options the list of options as an array of strings
     * @throws Exception if an option is not supported
     */
    @Override
    public void setOptions(String[] options) throws Exception {
        Option.setOptions(options, this, this.getClass());
    }
}