package org.vitrivr.cineast.core.features.neuralnet.tf.models.yolo;


import java.io.IOException;
import java.nio.FloatBuffer;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.PriorityQueue;
import org.apache.commons.math3.analysis.function.Sigmoid;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.tensorflow.Graph;
import org.tensorflow.Output;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.vitrivr.cineast.core.color.RGBContainer;
import org.vitrivr.cineast.core.data.MultiImage;
import org.vitrivr.cineast.core.features.neuralnet.tf.GraphBuilder;
import org.vitrivr.cineast.core.features.neuralnet.tf.models.yolo.util.BoundingBox;
import org.vitrivr.cineast.core.features.neuralnet.tf.models.yolo.util.BoxPosition;
import org.vitrivr.cineast.core.features.neuralnet.tf.models.yolo.util.Recognition;
import org.vitrivr.cineast.core.util.LogHelper;
import org.vitrivr.cineast.core.util.MathHelper;
import org.vitrivr.cineast.core.util.MathHelper.ArgMaxResult;


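/**
 * Object detector that runs the YOLO TensorFlow graph stored in {@code resources/YOLO/yolo-voc.pb}
 * on an image and returns the recognised objects together with their bounding boxes.
 * <p>
 * Based on https://github.com/szaza/tensorflow-example-java
 */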
public class YOLO implements AutoCloseable {

  /**
   * A candidate is suppressed when its intersection with an already accepted box covers more than
   * this fraction of the accepted box's area.
   */
  private static final float OVERLAP_THRESHOLD = 0.5f;
  /**
   * Anchor priors stored as (width, height) pairs, one pair per bounding box predicted in a cell.
   * They are multiplied by {@link #CELL_SIZE} when a box is decoded.
   */
  private static final double[] ANCHORS = {1.08, 1.19, 3.42, 4.41, 6.63, 11.38, 9.42, 5.11, 16.62,
      10.52};
  /** Number of grid cells along each image axis. */
  private static final int SIZE = 13;
  /** Edge length in pixels of one grid cell of the network input. */
  private static final int CELL_SIZE = 32;
  /** Edge length in pixels of the (square) network input: 13 cells of 32 pixels = 416. */
  private static final int INPUT_SIZE = SIZE * CELL_SIZE;
  /** Initial capacity of the candidate queue. */
  private static final int MAX_RECOGNIZED_CLASSES = 24;
  /** Minimum (class probability * box confidence) a box needs to become a candidate. */
  private static final float THRESHOLD = 0.5f;
  /** Maximum number of candidates examined in addition to the best one. */
  private static final int MAX_RESULTS = 24;
  /** Number of bounding boxes predicted per grid cell. */
  private static final int NUMBER_OF_BOUNDING_BOX = 5;
  /** Location of the serialized detection graph. */
  private static final String MODEL_PATH = "resources/YOLO/yolo-voc.pb";
  /** Shared logistic function; created once instead of once per decoded box. */
  private static final Sigmoid SIGMOID = new Sigmoid();
  private static final Logger LOGGER = LogManager.getLogger();
  /** Class names, indexed by the class index produced by the network. */
  private static final String[] LABELS = {
      "aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow",
      "diningtable", "dog", "horse", "motorbike", "person", "pottedplant", "sheep", "sofa", "train",
      "tvmonitor"
  };

  private final Graph preprocessingGraph;
  private final Session preprocessingSession;
  private final String imageOutName;
  private final Graph yoloGraph;
  private final Session yoloSession;

  /**
   * Loads the detection graph from {@link #MODEL_PATH} and builds the preprocessing graph that
   * resizes an input image to the network input size.
   *
   * @throws RuntimeException if the graph file cannot be read (the {@link IOException} is attached
   * as cause) or if one of the graphs/sessions cannot be created
   */
  public YOLO() {
    final byte[] graphDef;
    try {
      graphDef = Files.readAllBytes(Paths.get(MODEL_PATH));
    } catch (IOException e) {
      // message kept as before for callers that only report getMessage(); cause is now chained
      throw new RuntimeException(
          "could not load graph for YOLO: " + LogHelper.getStackTrace(e), e);
    }

    Graph detectionGraph = null;
    Session detectionSession = null;
    Graph resizeGraph = null;
    Session resizeSession = null;
    String resizeOutputName;
    try {
      detectionGraph = new Graph();
      detectionGraph.importGraphDef(graphDef);
      detectionSession = new Session(detectionGraph);

      resizeGraph = new Graph();
      resizeOutputName = buildPreprocessingGraph(resizeGraph);
      resizeSession = new Session(resizeGraph);
    } catch (RuntimeException e) {
      // do not leak the handles that were already created when a later step fails
      release(resizeSession, resizeGraph);
      release(detectionSession, detectionGraph);
      throw e;
    }

    yoloGraph = detectionGraph;
    yoloSession = detectionSession;
    preprocessingGraph = resizeGraph;
    preprocessingSession = resizeSession;
    imageOutName = resizeOutputName;
  }

  /**
   * Adds the preprocessing operations to the given graph: a float placeholder named "T" that is
   * expanded by a leading batch dimension and resized bilinearly to the network input size.
   *
   * @param graph the (empty) graph to populate
   * @return the name of the operation producing the resized image
   */
  private static String buildPreprocessingGraph(Graph graph) {
    final GraphBuilder graphBuilder = new GraphBuilder(graph);
    final Output<Float> imageFloat = graphBuilder.placeholder("T", Float.class);
    final int[] size = new int[]{INPUT_SIZE, INPUT_SIZE};

    final Output<Float> output =
        graphBuilder.resizeBilinear( // Resize using bilinear interpolation
            graphBuilder.expandDims( // Increase the output tensors dimension
                imageFloat,
                graphBuilder.constant("make_batch", 0)),
            graphBuilder.constant("size", size)
        );

    return output.op().name();
  }

  /**
   * Closes a session and its graph, skipping whichever of the two has not been created yet.
   */
  private static void release(Session session, Graph graph) {
    if (session != null) {
      session.close();
    }
    if (graph != null) {
      graph.close();
    }
  }

  /**
   * Converts an image into a float tensor of shape [height][width][3] whose channel values (red,
   * green, blue) are scaled by 1/255.
   *
   * @param img the image to convert
   * @return the tensor; the caller is responsible for closing it
   */
  private static Tensor<Float> readImage(MultiImage img) {
    final int width = img.getWidth();
    final int height = img.getHeight();
    final float[][][] pixels = new float[height][width][3];
    final int[] colors = img.getColors();

    // walk the image row by row so that both arrays are accessed sequentially
    for (int y = 0; y < height; ++y) {
      final float[][] row = pixels[y];
      final int rowStart = width * y;
      for (int x = 0; x < width; ++x) {
        final int c = colors[rowStart + x];
        row[x][0] = ((float) RGBContainer.getRed(c)) / 255f;
        row[x][1] = ((float) RGBContainer.getGreen(c)) / 255f;
        row[x][2] = ((float) RGBContainer.getBlue(c)) / 255f;
      }
    }

    return Tensor.create(pixels, Float.class);
  }

  /**
   * Computes how many float values the output tensor is expected to hold: the size of its fourth
   * dimension multiplied by the number of grid cells (SIZE * SIZE).
   * NOTE(review): this ignores the first (batch) dimension, i.e. presumably assumes a batch of one.
   *
   * @param result - the tensorflow output
   * @return the number of values to read from the tensor
   */
  private static int getOutputSizeByShape(Tensor<Float> result) {
    return (int) result.shape()[3] * SIZE * SIZE;
  }

  /**
   * It classifies the object/objects on the image
   *
   * @param tensorFlowOutput output from the TensorFlow, it is a 13x13x((num_class + 5) * 5) tensor
   * flattened cell by cell; each of the 5 boxes of a cell holds Tx, Ty, Tw, Th, To followed by one
   * score per class
   * @param labels a string vector with the labels
   * @return a list of recognition objects
   * @throws IllegalArgumentException if the length of the output does not fit that layout
   */
  private List<Recognition> classifyImage(final float[] tensorFlowOutput, final String[] labels) {
    final int boxCount = SIZE * SIZE * NUMBER_OF_BOUNDING_BOX;
    final int valuesPerBox = tensorFlowOutput.length / boxCount;
    if (tensorFlowOutput.length % boxCount != 0 || valuesPerBox < 5) {
      // decoding a tensor of unexpected size would silently read misaligned values
      throw new IllegalArgumentException(
          "unexpected YOLO output length " + tensorFlowOutput.length + " for " + boxCount
              + " bounding boxes");
    }
    final int numClass = valuesPerBox - 5;

    final PriorityQueue<Recognition> priorityQueue = new PriorityQueue<>(
        MAX_RECOGNIZED_CLASSES,
        new RecognitionComparator());

    int offset = 0;
    for (int cy = 0; cy < SIZE; cy++) {        // SIZE * SIZE cells
      for (int cx = 0; cx < SIZE; cx++) {
        for (int b = 0; b < NUMBER_OF_BOUNDING_BOX; b++) {   // 5 bounding boxes per each cell
          final BoundingBox box = getModel(tensorFlowOutput, cx, cy, b, numClass, offset);
          calculateTopPredictions(box, priorityQueue, labels);
          offset += valuesPerBox;
        }
      }
    }

    return getRecognition(priorityQueue);
  }

  /**
   * Decodes one bounding box from the flattened network output.
   *
   * @param tensorFlowOutput the flattened network output
   * @param cx column of the grid cell
   * @param cy row of the grid cell
   * @param b index of the box within the cell, selects the anchor pair
   * @param numClass number of class scores stored for the box
   * @param offset index of the first value of this box in {@code tensorFlowOutput}
   * @return the box with centre, size, confidence and raw class scores
   */
  private BoundingBox getModel(final float[] tensorFlowOutput, int cx, int cy, int b, int numClass,
      int offset) {
    final BoundingBox model = new BoundingBox();
    model.setX((cx + SIGMOID.value(tensorFlowOutput[offset])) * CELL_SIZE);
    model.setY((cy + SIGMOID.value(tensorFlowOutput[offset + 1])) * CELL_SIZE);
    model.setWidth(Math.exp(tensorFlowOutput[offset + 2]) * ANCHORS[2 * b] * CELL_SIZE);
    model.setHeight(Math.exp(tensorFlowOutput[offset + 3]) * ANCHORS[2 * b + 1] * CELL_SIZE);
    model.setConfidence(SIGMOID.value(tensorFlowOutput[offset + 4]));

    final double[] classes = new double[numClass];
    final int firstClass = offset + 5;
    for (int probIndex = 0; probIndex < numClass; probIndex++) {
      classes[probIndex] = tensorFlowOutput[firstClass + probIndex];
    }
    model.setClasses(classes);

    return model;
  }

  /**
   * Adds the most probable class of the given box to the queue if the product of its softmax
   * probability and the box confidence exceeds {@link #THRESHOLD}. At most one entry is added per
   * box; a box without class scores adds nothing.
   *
   * @param boundingBox the decoded box
   * @param predictionQueue queue collecting the candidates
   * @param labels class names indexed by class index
   */
  private void calculateTopPredictions(final BoundingBox boundingBox,
      final PriorityQueue<Recognition> predictionQueue,
      final String[] labels) {
    final double[] classes = boundingBox.getClasses();
    if (classes.length == 0) {
      return;
    }

    final ArgMaxResult argMax = MathHelper.argMax(MathHelper.softmax(classes));
    final double confidenceInClass = argMax.getMaxValue() * boundingBox.getConfidence();
    if (confidenceInClass > THRESHOLD) {
      // the box stores its centre, BoxPosition is created from the top left corner
      predictionQueue.add(
          new Recognition(argMax.getIndex(), labels[argMax.getIndex()], (float) confidenceInClass,
              new BoxPosition((float) (boundingBox.getX() - boundingBox.getWidth() / 2),
                  (float) (boundingBox.getY() - boundingBox.getHeight() / 2),
                  (float) boundingBox.getWidth(),
                  (float) boundingBox.getHeight())));
    }
  }

  /**
   * Drains the candidate queue in queue order and keeps every candidate that does not overlap an
   * already kept recognition by more than {@link #OVERLAP_THRESHOLD}. The head of the queue is
   * always kept; after it at most {@link #MAX_RESULTS} further candidates are examined.
   *
   * @param priorityQueue the candidates; polled elements are removed from it
   * @return the kept recognitions, empty if the queue was empty
   */
  private List<Recognition> getRecognition(final PriorityQueue<Recognition> priorityQueue) {
    final List<Recognition> recognitions = new ArrayList<>();

    // Best recognition
    final Recognition bestRecognition = priorityQueue.poll();
    if (bestRecognition == null) {
      return recognitions;
    }
    recognitions.add(bestRecognition);

    // fix the bound before polling: the queue shrinks with every poll()
    final int candidates = Math.min(priorityQueue.size(), MAX_RESULTS);
    for (int i = 0; i < candidates; ++i) {
      final Recognition recognition = priorityQueue.poll();
      if (!overlapsAny(recognitions, recognition)) {
        recognitions.add(recognition);
      }
    }

    return recognitions;
  }

  /**
   * Tells whether the candidate overlaps any of the accepted recognitions by more than
   * {@link #OVERLAP_THRESHOLD}.
   */
  private static boolean overlapsAny(final List<Recognition> accepted,
      final Recognition candidate) {
    for (Recognition previousRecognition : accepted) {
      if (getIntersectionProportion(previousRecognition.getLocation(),
          candidate.getLocation()) > OVERLAP_THRESHOLD) {
        return true;
      }
    }
    return false;
  }

  /**
   * Computes which fraction of the primary box is covered by its intersection with the secondary
   * box.
   *
   * @param primaryShape the box whose area is the reference
   * @param secondaryShape the box intersected with the primary one
   * @return intersection area divided by the primary area; 0 if the boxes do not overlap or the
   * primary box has no area
   */
  private static float getIntersectionProportion(BoxPosition primaryShape,
      BoxPosition secondaryShape) {
    if (!BoxPosition.overlaps(primaryShape, secondaryShape)) {
      return 0f;
    }

    final float intersectionWidth = Math.max(0,
        Math.min(primaryShape.getRight(), secondaryShape.getRight()) - Math
            .max(primaryShape.getLeft(), secondaryShape.getLeft()));
    final float intersectionHeight = Math.max(0,
        Math.min(primaryShape.getBottom(), secondaryShape.getBottom()) - Math
            .max(primaryShape.getTop(), secondaryShape.getTop()));

    final float surfacePrimary = Math.abs(primaryShape.getRight() - primaryShape.getLeft()) * Math
        .abs(primaryShape.getBottom() - primaryShape.getTop());
    if (surfacePrimary == 0f) {
      // avoid 0/0 = NaN for degenerate boxes
      return 0f;
    }

    return intersectionWidth * intersectionHeight / surfacePrimary;
  }

  /**
   * Closes both sessions and both graphs.
   */
  @Override
  public void close() {
    yoloSession.close();
    yoloGraph.close();
    preprocessingSession.close();
    preprocessingGraph.close();
  }

  /**
   * Detects objects in the given image.
   *
   * @param img the image to analyse
   * @return the recognised objects, possibly empty
   */
  public List<Recognition> detect(MultiImage img) {

    try (Tensor<Float> normalizedImage = normalizeImage(img)) {
      return classifyImage(executeYOLOGraph(normalizedImage), LABELS);
    }
  }

  /**
   * Pre-process input. It converts the image to a float tensor with channel values scaled by
   * 1/255 and resizes it to the network input size.
   *
   * @return Tensor<Float> with shape [1][416][416][3]; the caller is responsible for closing it
   */
  private Tensor<Float> normalizeImage(MultiImage img) {

    try (Tensor<Float> image = readImage(img)) {
      return preprocessingSession.runner().feed("T", image).fetch(imageOutName).run().get(0)
          .expect(Float.class);
    }
  }

  /**
   * Executes graph on the given preprocessed image
   *
   * @param image preprocessed image
   * @return the values of the output tensor returned by tensorFlow, copied into a flat array
   */
  private float[] executeYOLOGraph(final Tensor<Float> image) {

    // try-with-resources releases the native tensor even if copying it out fails
    try (Tensor<Float> result = yoloSession.runner().feed("input", image).fetch("output").run()
        .get(0).expect(Float.class)) {
      final float[] outputTensor = new float[getOutputSizeByShape(result)];
      result.writeTo(FloatBuffer.wrap(outputTensor));
      return outputTensor;
    }
  }

  // Intentionally reversed to put high confidence at the head of the queue.
  private static final class RecognitionComparator implements Comparator<Recognition> {

    @Override
    public int compare(final Recognition recognition1, final Recognition recognition2) {
      return Float.compare(recognition2.getConfidence(), recognition1.getConfidence());
    }
  }
}