/* ******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 * Copyright (c) 2019 Konduit K.K.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package org.deeplearning4j.ui.module.train;

import com.fasterxml.jackson.annotation.JsonIgnore;
import lombok.AllArgsConstructor;
import lombok.Data;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.graph.GraphVertex;
import org.deeplearning4j.nn.conf.graph.LayerVertex;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder;
import org.deeplearning4j.nn.params.VariationalAutoencoderParamInitializer;

import java.util.*;

/**
 * Utility methods for {@link TrainModule}.
 * <p>
 * The {@code buildGraphInfo} overloads convert a network configuration into a flat, index-based
 * description of its vertices ({@link GraphInfo}): one entry per vertex, with each vertex referring
 * to its inputs by position in the same lists.
 *
 * @author Alex Black
 */
public class TrainModuleUtils {

    /** Suffix stripped from layer class names when deriving the displayed layer type. */
    private static final String LAYER_SUFFIX = "Layer";

    /**
     * Index-based description of a network graph. All lists are parallel: position {@code i} in each
     * list describes vertex {@code i}.
     */
    @AllArgsConstructor
    @Data
    public static class GraphInfo {

        /** Display name of each vertex. */
        private List<String> vertexNames;
        /** Display type of each vertex (for layers: the class simple name without a trailing "Layer"). */
        private List<String> vertexTypes;
        /** For each vertex, the indices (into these lists) of the vertices that feed into it. */
        private List<List<Integer>> vertexInputs;
        /** For each vertex, an ordered map of human-readable properties (may be empty). */
        private List<Map<String, String>> vertexInfo;

        /**
         * Identifier of each vertex in the source configuration: the layer index as a string for
         * multi-layer configurations, the vertex name for computation graphs, and a short key
         * (e.g. "e0", "d0") for the components of a variational autoencoder. Not serialized to JSON.
         */
        @JsonIgnore
        private List<String> originalVertexName;
    }

    /**
     * Accumulates the five parallel lists of a {@link GraphInfo} one vertex at a time, so that the
     * lists cannot get out of step with each other.
     */
    private static final class GraphInfoBuilder {
        private final List<String> names = new ArrayList<>();
        private final List<String> types = new ArrayList<>();
        private final List<List<Integer>> inputs = new ArrayList<>();
        private final List<Map<String, String>> info = new ArrayList<>();
        private final List<String> originalNames = new ArrayList<>();

        /** Appends one vertex to all lists. */
        void addVertex(String name, String originalName, String type, List<Integer> vertexInputs,
                        Map<String, String> vertexInfo) {
            names.add(name);
            originalNames.add(originalName);
            types.add(type);
            inputs.add(vertexInputs);
            info.add(vertexInfo);
        }

        /** Appends the synthetic "Input" vertex used for configurations without named inputs. */
        void addInputVertex() {
            addVertex("Input", null, "Input", Collections.<Integer>emptyList(),
                            Collections.<String, String>emptyMap());
        }

        /** Appends a vertex whose single input is the most recently added vertex. */
        void addChainedVertex(String name, String originalName, String type, Map<String, String> vertexInfo) {
            addVertex(name, originalName, type, Collections.singletonList(names.size() - 1), vertexInfo);
        }

        GraphInfo build() {
            return new GraphInfo(names, types, inputs, info, originalNames);
        }
    }

    /**
     * Builds the graph description for a multi-layer (sequential) configuration: a synthetic "Input"
     * vertex at index 0 followed by one vertex per layer, each taking the previous vertex as input.
     *
     * @param config the multi-layer configuration to describe
     * @return the graph description; layer {@code i} of the configuration is vertex {@code i + 1}
     */
    public static GraphInfo buildGraphInfo(MultiLayerConfiguration config) {
        GraphInfoBuilder graph = new GraphInfoBuilder();
        graph.addInputVertex();

        //layerIdx is the vertex index of the current layer (vertex 0 is the input)
        int layerIdx = 1;
        for (NeuralNetConfiguration c : config.getConfs()) {
            Layer layer = c.getLayer();
            String layerName = layer.getLayerName();
            if (layerName == null) {
                layerName = "layer" + layerIdx;
            }

            graph.addVertex(layerName, String.valueOf(layerIdx - 1), layerTypeName(layer),
                            Collections.singletonList(layerIdx - 1), getLayerInfo(c, layer));
            layerIdx++;
        }

        return graph.build();
    }

    /**
     * Builds the graph description for a computation graph configuration. Network inputs come first
     * (in the order returned by the configuration), followed by the vertices in the iteration order
     * of the configuration's vertex map.
     *
     * @param config the computation graph configuration to describe
     * @return the graph description
     */
    public static GraphInfo buildGraphInfo(ComputationGraphConfiguration config) {
        GraphInfoBuilder graph = new GraphInfoBuilder();

        Map<String, GraphVertex> vertices = config.getVertices();
        Map<String, List<String>> vertexInputs = config.getVertexInputs();

        Map<String, Integer> vertexToIndexMap = new HashMap<>();
        int vertexCount = 0;
        for (String inputName : config.getNetworkInputs()) {
            vertexToIndexMap.put(inputName, vertexCount++);
            //NOTE(review): the input's own name is used as its type here, whereas the other overloads
            //use "Input" - kept as-is since consumers may rely on it
            graph.addVertex(inputName, inputName, inputName, Collections.<Integer>emptyList(),
                            Collections.<String, String>emptyMap());
        }

        //Assign indices to all vertices up front, so that inputs can be resolved regardless of order
        for (String vertexName : vertices.keySet()) {
            vertexToIndexMap.put(vertexName, vertexCount++);
        }

        for (Map.Entry<String, GraphVertex> entry : vertices.entrySet()) {
            String vertexName = entry.getKey();
            GraphVertex gv = entry.getValue();

            List<Integer> inputIndexes = new ArrayList<>();
            List<String> inputsThisVertex = vertexInputs.get(vertexName);
            if (inputsThisVertex != null) { //A vertex without registered inputs gets an empty input list
                for (String inputName : inputsThisVertex) {
                    inputIndexes.add(vertexToIndexMap.get(inputName));
                }
            }

            String vertexType;
            Map<String, String> thisVertexInfo;
            if (gv instanceof LayerVertex) {
                NeuralNetConfiguration c = ((LayerVertex) gv).getLayerConf();
                Layer layer = c.getLayer();
                vertexType = layerTypeName(layer);
                thisVertexInfo = getLayerInfo(c, layer);
            } else {
                vertexType = gv.getClass().getSimpleName();
                thisVertexInfo = Collections.emptyMap(); //TODO
            }

            graph.addVertex(vertexName, vertexName, vertexType, inputIndexes, thisVertexInfo);
        }

        return graph.build();
    }

    /**
     * Builds the graph description for a single-layer configuration. A variational autoencoder is
     * expanded into a chain of vertices (encoder layers, latent variable "z", decoder layers,
     * reconstruction "x"); any other layer is shown as a single vertex after the "Input" vertex.
     *
     * @param config the single-layer configuration to describe
     * @return the graph description
     */
    public static GraphInfo buildGraphInfo(NeuralNetConfiguration config) {
        GraphInfoBuilder graph = new GraphInfoBuilder();
        graph.addInputVertex();

        if (config.getLayer() instanceof VariationalAutoencoder) {
            //Special case like this is a bit ugly - but it works
            addVaeVertices(graph, (VariationalAutoencoder) config.getLayer());
        } else {
            //Any other (non-VAE) layer: a single vertex fed by the input
            Layer layer = config.getLayer();
            String layerName = layer.getLayerName();
            if (layerName == null) {
                layerName = "layer0";
            }

            graph.addVertex(layerName, "0", layerTypeName(layer), Collections.singletonList(0),
                            getLayerInfo(config, layer));
        }

        return graph.build();
    }

    /** Appends the encoder, latent variable, decoder and reconstruction vertices of a VAE as a chain. */
    private static void addVaeVertices(GraphInfoBuilder graph, VariationalAutoencoder va) {
        int[] encLayerSizes = va.getEncoderLayerSizes();
        int[] decLayerSizes = va.getDecoderLayerSizes();

        for (int i = 0; i < encLayerSizes.length; i++) {
            long inputSize = (i == 0 ? va.getNIn() : encLayerSizes[i - 1]);
            long outputSize = encLayerSizes[i];

            Map<String, String> encoderInfo = vaeSizeInfo(inputSize, outputSize, (inputSize + 1) * outputSize);
            putIfNotNull(encoderInfo, "Activation Function", va.getActivationFn());
            graph.addChainedVertex("encoder_" + i, "e" + i, "VAE-Encoder", encoderInfo);
        }

        long inputSize = encLayerSizes[encLayerSizes.length - 1];
        long outputSize = va.getNOut();
        //Factor 2 as in the original calculation for the latent variable vertex
        Map<String, String> latentInfo = vaeSizeInfo(inputSize, outputSize, (inputSize + 1) * outputSize * 2);
        putIfNotNull(latentInfo, "Activation Function", va.getPzxActivationFn());
        graph.addChainedVertex("z", VariationalAutoencoderParamInitializer.PZX_PREFIX, "VAE-LatentVariable",
                        latentInfo);

        for (int i = 0; i < decLayerSizes.length; i++) {
            inputSize = (i == 0 ? va.getNOut() : decLayerSizes[i - 1]);
            outputSize = decLayerSizes[i];

            Map<String, String> decoderInfo = vaeSizeInfo(inputSize, outputSize, (inputSize + 1) * outputSize);
            putIfNotNull(decoderInfo, "Activation Function", va.getActivationFn());
            graph.addChainedVertex("decoder_" + i, "d" + i, "VAE-Decoder", decoderInfo);
        }

        inputSize = decLayerSizes[decLayerSizes.length - 1];
        outputSize = va.getNIn();
        //Parameter count depends on how many values the output distribution needs per input feature
        long numParams = (inputSize + 1) * va.getOutputDistribution().distributionInputSize((int) va.getNIn());
        Map<String, String> reconstructionInfo = vaeSizeInfo(inputSize, outputSize, numParams);
        reconstructionInfo.put("Distribution", va.getOutputDistribution().toString());
        graph.addChainedVertex("x", VariationalAutoencoderParamInitializer.PXZ_PREFIX, "VAE-Reconstruction",
                        reconstructionInfo);
    }

    /** Creates the size/parameter-count entries shared by all VAE vertices, in display order. */
    private static Map<String, String> vaeSizeInfo(long inputSize, long layerSize, long numParams) {
        Map<String, String> info = new LinkedHashMap<>();
        info.put("Input Size", String.valueOf(inputSize));
        info.put("Layer Size", String.valueOf(layerSize));
        info.put("Num Parameters", String.valueOf(numParams));
        return info;
    }

    /** Returns the layer's class simple name with a trailing "Layer" removed (e.g. DenseLayer -> Dense). */
    private static String layerTypeName(Layer layer) {
        String simpleName = layer.getClass().getSimpleName();
        if (simpleName.endsWith(LAYER_SUFFIX)) {
            return simpleName.substring(0, simpleName.length() - LAYER_SUFFIX.length());
        }
        return simpleName;
    }

    /** Adds {@code value.toString()} under {@code key}, or nothing at all if {@code value} is null. */
    private static void putIfNotNull(Map<String, String> map, String key, Object value) {
        if (value != null) {
            map.put(key, value.toString());
        }
    }

    /**
     * Collects the human-readable properties of a layer, in display order. Properties whose value is
     * not set on the layer (activation function, pooling type, loss function) are omitted.
     *
     * @param c     the configuration containing the layer (used to count parameters)
     * @param layer the layer to describe
     * @return an ordered, possibly empty, map of property name to value
     */
    private static Map<String, String> getLayerInfo(NeuralNetConfiguration c, Layer layer) {
        Map<String, String> map = new LinkedHashMap<>();

        if (layer instanceof FeedForwardLayer) {
            FeedForwardLayer ffLayer = (FeedForwardLayer) layer;
            map.put("Input size", String.valueOf(ffLayer.getNIn()));
            map.put("Output size", String.valueOf(ffLayer.getNOut()));
            map.put("Num Parameters", String.valueOf(ffLayer.initializer().numParams(c)));
            putIfNotNull(map, "Activation Function", ffLayer.getActivationFn());
        }

        //Type-specific properties, appended after the generic feed-forward ones (if any)
        if (layer instanceof ConvolutionLayer) {
            ConvolutionLayer convLayer = (ConvolutionLayer) layer;
            map.put("Kernel size", Arrays.toString(convLayer.getKernelSize()));
            map.put("Stride", Arrays.toString(convLayer.getStride()));
            map.put("Padding", Arrays.toString(convLayer.getPadding()));
        } else if (layer instanceof SubsamplingLayer) {
            SubsamplingLayer poolLayer = (SubsamplingLayer) layer;
            map.put("Kernel size", Arrays.toString(poolLayer.getKernelSize()));
            map.put("Stride", Arrays.toString(poolLayer.getStride()));
            map.put("Padding", Arrays.toString(poolLayer.getPadding()));
            putIfNotNull(map, "Pooling Type", poolLayer.getPoolingType());
        } else if (layer instanceof BaseOutputLayer) {
            putIfNotNull(map, "Loss Function", ((BaseOutputLayer) layer).getLossFn());
        }

        return map;
    }
}