package weka.dl4j.scripts.keras_downloading;

import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
import org.deeplearning4j.nn.modelimport.keras.KerasModelImport;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import weka.dl4j.layers.lambda.CustomBroadcast;

import java.io.*;
import java.lang.reflect.Method;
import java.nio.charset.StandardCharsets;
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/**
 * Loads a folder of Keras model files (.h5) and converts them, one by one, into the native DL4J
 * format (.zip). Working with the DL4J format is safer than importing from Keras files every time,
 * and converting once is fine here because WDL4J defines a fixed set of models.
 *
 * <p>Usage: {@code KerasModelConverter <h5 folder path> <model summary folder path>}. Converted
 * models are written to a {@code dl4j_format} folder created next to the h5 folder. Before any model
 * is imported, the model summary files are scanned for layers named {@code broadcast_w<width>...},
 * and a {@link CustomBroadcast} lambda layer is registered under each such name.
 *
 * <p>Requires a deeplearning4j-nn build whose {@link InputType} exposes a static
 * {@code setDefaultCNN2DFormat(CNN2DFormat)} method (looked up reflectively).
 */
public class KerasModelConverter {

    /** Folder containing the Keras model summary text files; set from the second CLI argument. */
    private static String modelSummariesPath = "";

    /** Matches lines / layer names of broadcast lambda layers; group 1 captures the width digits. */
    private static final String broadcastLayerRegex = "^broadcast_w(\\d+).*";

    /** Compiled once instead of once per scanned summary line. */
    private static final Pattern BROADCAST_LAYER_PATTERN = Pattern.compile(broadcastLayerRegex);

    /**
     * Imports one Keras .h5 model, runs a test forward pass on a zero input, and saves the result as
     * {@code Keras<ModelName>.zip} inside {@code outputFolder}.
     *
     * <p>Any exception during conversion is reported on stderr and swallowed, so the caller can
     * continue with the next model. If {@code InputType.setDefaultCNN2DFormat} is not available, the
     * process exits with status 1.
     *
     * @param modelFile    the Keras .h5 file to convert
     * @param outputFolder the folder the .zip file is written to
     */
    private static void saveH5File(File modelFile, File outputFolder) {
        try {
            INDArray testShape = Nd4j.zeros(1, 3, 224, 224);
            String modelName = modelFile.getName();

            Method setDefaultFormat;
            try {
                setDefaultFormat = InputType.class.getMethod("setDefaultCNN2DFormat", CNN2DFormat.class);
            } catch (NoSuchMethodException ex) {
                System.err.println("setDefaultCNN2DFormat() not found on InputType class... " +
                        "Are you using the custom built deeplearning4j-nn.jar?");
                System.exit(1);
                return; // unreachable; keeps setDefaultFormat definitely assigned
            }
            // Reset to channels-first for every model, because an EfficientNet conversion below
            // switches the global default to NHWC.
            setDefaultFormat.invoke(null, CNN2DFormat.NCHW);

            if (modelName.contains("EfficientNet")) {
                // Fixes for EfficientNet family of models: channels-last input
                testShape = Nd4j.zeros(1, 224, 224, 3);
                setDefaultFormat.invoke(null, CNN2DFormat.NHWC);
                // We don't want the resulting .zip files to have 'Fixed' in the name, so strip it off here
                modelName = modelName.replace("Fixed", "");
            }

            ComputationGraph kerasModel = KerasModelImport.importKerasModelAndWeights(modelFile.getAbsolutePath());
            kerasModel.feedForward(testShape, false);

            // e.g. ResNet50.h5 -> KerasResNet50.zip (only the trailing extension is replaced)
            String baseName = modelName.endsWith(".h5")
                    ? modelName.substring(0, modelName.length() - ".h5".length())
                    : modelName;
            modelName = "Keras" + baseName + ".zip";
            String newZip = Paths.get(outputFolder.getPath(), modelName).toString();
            kerasModel.save(new File(newZip));
            System.out.println("Saved file " + newZip);
        } catch (Exception e) {
            System.err.println("\n\nCouldn't save " + modelFile.getName());
            e.printStackTrace();
        }
    }

    /**
     * Entry point: registers broadcast lambda layers from the summary folder, then converts every
     * {@code .h5} file in the model folder (in sorted path order).
     *
     * @param args {@code args[0]} = folder with Keras .h5 files,
     *             {@code args[1]} = folder with Keras model summary text files
     * @throws IllegalArgumentException if either folder cannot be listed
     * @throws Exception                if a summary file cannot be read
     */
    public static void main(String[] args) throws Exception {
        if (args.length != 2) {
            System.err.println("Usage: KerasModelConverter <h5 folder path> <model summary folder path>");
            System.exit(1);
        }

        // Folder holding the Keras .h5 models to convert
        String modelFolderPath = args[0];
        modelSummariesPath = args[1];

        File modelFolder = new File(modelFolderPath);
        // Resolve via the absolute path: getParent() on a bare relative name (e.g. "h5") is null.
        File outputFolder = new File(modelFolder.getAbsoluteFile().getParentFile(), "dl4j_format");
        if (outputFolder.mkdirs()) {
            System.out.println("Created DL4J format folder at " + outputFolder.getPath());
        }

        File[] modelFiles = modelFolder.listFiles();
        if (modelFiles == null) {
            throw new IllegalArgumentException("Invalid folder name: " + modelFolderPath);
        }
        Arrays.sort(modelFiles);

        // Lambda layers must be registered before any model referencing them is imported.
        loadLambdaLayers();

        for (File fileEntry : modelFiles) {
            if (fileEntry.getPath().endsWith(".h5")) {
                saveH5File(fileEntry, outputFolder);
            }
        }
    }

    /**
     * Returns whether a summary line starts with a broadcast layer name ({@code broadcast_w<digits>...}).
     *
     * @param line a single line from a model summary file
     * @return true if the whole line matches {@link #broadcastLayerRegex}
     */
    private static boolean isBroadcastLayer(String line) {
        return BROADCAST_LAYER_PATTERN.matcher(line).matches();
    }

    /**
     * Extracts the width encoded in a broadcast layer name, e.g. {@code broadcast_w65_d144_2 -> 65}.
     *
     * @param layerName the broadcast layer name
     * @return the integer following {@code broadcast_w}
     * @throws IllegalArgumentException if the name does not match the broadcast pattern
     */
    private static int getWidth(String layerName) {
        Matcher m = BROADCAST_LAYER_PATTERN.matcher(layerName);
        if (m.find()) {
            return Integer.parseInt(m.group(1));
        }
        throw new IllegalArgumentException("Couldn't find width in layerName " + layerName);
    }

    /**
     * Scans every regular file in {@link #modelSummariesPath} (in sorted order) and registers a
     * {@link CustomBroadcast} lambda layer for each line that starts with a broadcast layer name.
     *
     * @throws IllegalArgumentException if the summary folder cannot be listed
     * @throws IOException              if a summary file cannot be read
     */
    private static void loadLambdaLayers() throws Exception {
        File[] modelSummaries = new File(modelSummariesPath).listFiles();
        // listFiles() returns null for a missing or non-directory path; an assert is skipped without -ea.
        if (modelSummaries == null) {
            throw new IllegalArgumentException("Invalid model summary folder name: " + modelSummariesPath);
        }
        Arrays.sort(modelSummaries);

        for (File f : modelSummaries) {
            if (!f.isFile()) {
                continue; // sub-folders cannot be read as summaries
            }
            String modelName = f.getName();
            try (BufferedReader br = new BufferedReader(
                    new InputStreamReader(new FileInputStream(f), StandardCharsets.UTF_8))) {
                String line;
                while ((line = br.readLine()) != null) {
                    // Only lines whose first token is a broadcast layer name are of interest, e.g. a
                    // line beginning with "broadcast_w65_d144_2 ..."; all other summary lines
                    // (separators, other layers) are skipped.
                    if (!isBroadcastLayer(line)) {
                        continue;
                    }
                    String layerName = line.split("\\s+")[0]; // -> broadcast_w65_d144_2
                    int width = getWidth(layerName);
                    KerasLayer.registerLambdaLayer(layerName, new CustomBroadcast(width));
                    System.out.println(String.format("Registered %s layer %s with width %d", modelName, layerName, width));
                }
            }
        }
    }
}