/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.lucene.classification.utils;


import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;

import org.apache.lucene.classification.ClassificationResult;
import org.apache.lucene.classification.Classifier;
import org.apache.lucene.document.Document;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TermRangeQuery;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.NamedThreadFactory;

/**
 * Utility class to generate the confusion matrix of a {@link Classifier}.
 */
public class ConfusionMatrixGenerator {

  /**
   * Maximum time, in seconds, a single classification may take before it is abandoned and counted as a timeout.
   */
  private static final long CLASSIFICATION_TIMEOUT_SECONDS = 5;

  private ConfusionMatrixGenerator() {
    // static utility class, not meant to be instantiated
  }

  /**
   * get the {@link org.apache.lucene.classification.utils.ConfusionMatrixGenerator.ConfusionMatrix} of a given {@link Classifier},
   * generated on the given {@link IndexReader}, class and text fields.
   * <p>
   * Every document having at least one value in {@code classFieldName} and a value in {@code textFieldName} is
   * classified. A single classification that does not complete within 5 seconds is cancelled and charged 5 seconds
   * against the time budget. Documents for which the classifier assigns no class are not counted as evaluated.
   *
   * @param reader              the {@link IndexReader} containing the index used for creating the {@link Classifier}
   * @param classifier          the {@link Classifier} whose confusion matrix has to be generated
   * @param classFieldName      the name of the Lucene field used as the classifier's output
   * @param textFieldName       the name of the Lucene field used as the classifier's input
   * @param timeoutMilliseconds budget, in milliseconds, of accumulated classification time after which no further
   *                            documents are classified; values {@code <= 0} disable the budget
   * @param <T>                 the return type of the {@link ClassificationResult} returned by the given {@link Classifier}
   * @return a {@link org.apache.lucene.classification.utils.ConfusionMatrixGenerator.ConfusionMatrix}
   * @throws IOException if problems occur while reading the index
   * @throws RuntimeException if the classifier fails (wrapping the {@link ExecutionException}) or the calling thread
   *                          is interrupted (the interrupt status is restored)
   */
  public static <T> ConfusionMatrix getConfusionMatrix(IndexReader reader, Classifier<T> classifier, String classFieldName,
                                                       String textFieldName, long timeoutMilliseconds) throws IOException {

    ExecutorService executorService = Executors.newFixedThreadPool(1, new NamedThreadFactory("confusion-matrix-gen-"));

    try {

      Map<String, Map<String, Long>> counts = new HashMap<>();
      IndexSearcher indexSearcher = new IndexSearcher(reader);
      // open-ended range: matches every document having some term in the class field
      TopDocs topDocs = indexSearcher.search(new TermRangeQuery(classFieldName, null, null, true, true), Integer.MAX_VALUE);
      double time = 0d;

      int counter = 0;
      for (ScoreDoc scoreDoc : topDocs.scoreDocs) {

        // stop once the accumulated classification time exhausts the caller's budget
        if (timeoutMilliseconds > 0 && time >= timeoutMilliseconds) {
          break;
        }

        Document doc = reader.document(scoreDoc.doc);
        String[] correctAnswers = doc.getValues(classFieldName);
        String text = doc.get(textFieldName);
        if (correctAnswers == null || correctAnswers.length == 0 || text == null) {
          continue;
        }
        // sorted so that the assigned class can be looked up with a binary search below
        Arrays.sort(correctAnswers);

        ClassificationResult<T> result;
        long start = System.currentTimeMillis();
        Future<ClassificationResult<T>> future = executorService.submit(() -> classifier.assignClass(text));
        try {
          result = future.get(CLASSIFICATION_TIMEOUT_SECONDS, TimeUnit.SECONDS);
        } catch (TimeoutException timeoutException) {
          // abandon the slow classification so it does not keep the single worker thread busy and delay
          // every following document; charge the full timeout against the time budget
          future.cancel(true);
          time += TimeUnit.SECONDS.toMillis(CLASSIFICATION_TIMEOUT_SECONDS);
          continue;
        } catch (InterruptedException interruptedException) {
          // restore the interrupt status so callers further up the stack can observe it
          Thread.currentThread().interrupt();
          throw new RuntimeException(interruptedException);
        } catch (ExecutionException executionException) {
          throw new RuntimeException(executionException);
        }
        time += System.currentTimeMillis() - start;

        T assignedClass = result != null ? result.getAssignedClass() : null;
        if (assignedClass == null) {
          continue;
        }
        counter++;
        String classified = assignedClass instanceof BytesRef ? ((BytesRef) assignedClass).utf8ToString() : assignedClass.toString();

        // a hit on any of the document's correct answers counts as correct; a miss is attributed to the
        // first correct answer in sorted order
        String correctAnswer = Arrays.binarySearch(correctAnswers, classified) >= 0 ? classified : correctAnswers[0];
        counts.computeIfAbsent(correctAnswer, k -> new HashMap<>()).merge(classified, 1L, Long::sum);
      }
      // avoid reporting NaN when no document could be evaluated
      double avgClassificationTime = counter > 0 ? time / counter : 0d;
      return new ConfusionMatrix(counts, avgClassificationTime, counter);
    } finally {
      executorService.shutdown();
    }
  }

  /**
   * a confusion matrix, backed by a {@link Map} representing the linearized matrix.
   * <p>
   * The matrix exposed by {@link #getLinearizedMatrix()} is read-only at both levels, and accuracy is computed once
   * at construction time.
   */
  public static class ConfusionMatrix {

    private final Map<String, Map<String, Long>> linearizedMatrix;
    private final double avgClassificationTime;
    private final int numberOfEvaluatedDocs;
    private final double accuracy;

    private ConfusionMatrix(Map<String, Map<String, Long>> linearizedMatrix, double avgClassificationTime, int numberOfEvaluatedDocs) {
      // wrap the rows too, so callers of getLinearizedMatrix() cannot alter the counts (and the cached accuracy)
      Map<String, Map<String, Long>> rows = new HashMap<>();
      for (Map.Entry<String, Map<String, Long>> row : linearizedMatrix.entrySet()) {
        rows.put(row.getKey(), Collections.unmodifiableMap(row.getValue()));
      }
      this.linearizedMatrix = Collections.unmodifiableMap(rows);
      this.avgClassificationTime = avgClassificationTime;
      this.numberOfEvaluatedDocs = numberOfEvaluatedDocs;
      this.accuracy = computeAccuracy(this.linearizedMatrix);
    }

    /**
     * get the linearized confusion matrix as a {@link Map}
     *
     * @return an unmodifiable {@link Map} whose keys are the correct classification answers and whose values map
     * each actually assigned answer to its count
     */
    public Map<String, Map<String, Long>> getLinearizedMatrix() {
      return linearizedMatrix;
    }

    /**
     * calculate precision on the given class, i.e. {@code tp / (tp + fp)}
     *
     * @param klass the class to calculate the precision for
     * @return the precision for the given class, or {@code 0} if the class was never correctly assigned
     */
    public double getPrecision(String klass) {
      Map<String, Long> classifications = linearizedMatrix.get(klass);
      if (classifications == null) {
        return 0;
      }
      long tp = classifications.getOrDefault(klass, 0L);
      if (tp == 0) {
        return 0;
      }
      long assigned = 0; // tp + fp: documents assigned klass, whatever their correct class
      for (Map<String, Long> row : linearizedMatrix.values()) {
        assigned += row.getOrDefault(klass, 0L);
      }
      return (double) tp / assigned;
    }

    /**
     * calculate recall on the given class, i.e. {@code tp / (tp + fn)}
     *
     * @param klass the class to calculate the recall for
     * @return the recall for the given class, or {@code 0} if no document has it as correct answer
     */
    public double getRecall(String klass) {
      Map<String, Long> classifications = linearizedMatrix.get(klass);
      if (classifications == null) {
        return 0;
      }
      long tp = classifications.getOrDefault(klass, 0L);
      long actual = 0; // tp + fn: documents whose correct class is klass
      for (long count : classifications.values()) {
        actual += count;
      }
      return actual > 0 ? (double) tp / actual : 0;
    }

    /**
     * get the F-1 measure of the given class
     *
     * @param klass the class to calculate the F-1 measure for
     * @return the F-1 measure for the given class
     */
    public double getF1Measure(String klass) {
      return f1(getPrecision(klass), getRecall(klass));
    }

    /**
     * get the F-1 measure on this confusion matrix
     *
     * @return the F-1 measure, computed from the macro averaged precision and recall
     */
    public double getF1Measure() {
      return f1(getPrecision(), getRecall());
    }

    /**
     * Calculate accuracy on this confusion matrix using the formula:
     * {@literal accuracy = correctly-classified / (correctly-classified + wrongly-classified)}
     *
     * @return the accuracy, or {@code 0} if the matrix holds no classification
     */
    public double getAccuracy() {
      return accuracy;
    }

    /**
     * get the macro averaged precision (see {@link #getPrecision(String)}) over all the classes.
     *
     * @return the macro averaged precision as computed from the confusion matrix, or {@code 0} if the matrix is empty
     */
    public double getPrecision() {
      if (linearizedMatrix.isEmpty()) {
        return 0;
      }
      double p = 0;
      for (String klass : linearizedMatrix.keySet()) {
        p += getPrecision(klass);
      }
      return p / linearizedMatrix.size();
    }

    /**
     * get the macro averaged recall (see {@link #getRecall(String)}) over all the classes
     *
     * @return the recall as computed from the confusion matrix, or {@code 0} if the matrix is empty
     */
    public double getRecall() {
      if (linearizedMatrix.isEmpty()) {
        return 0;
      }
      double r = 0;
      for (String klass : linearizedMatrix.keySet()) {
        r += getRecall(klass);
      }
      return r / linearizedMatrix.size();
    }

    @Override
    public String toString() {
      return "ConfusionMatrix{" +
          "linearizedMatrix=" + linearizedMatrix +
          ", avgClassificationTime=" + avgClassificationTime +
          ", numberOfEvaluatedDocs=" + numberOfEvaluatedDocs +
          '}';
    }

    /**
     * get the average classification time in milliseconds
     *
     * @return the avg classification time, or {@code 0} if no document was evaluated
     */
    public double getAvgClassificationTime() {
      return avgClassificationTime;
    }

    /**
     * get the no. of documents evaluated while generating this confusion matrix
     *
     * @return the no. of documents evaluated
     */
    public int getNumberOfEvaluatedDocs() {
      return numberOfEvaluatedDocs;
    }

    /**
     * harmonic mean of precision and recall, or {@code 0} when either of them is not positive
     */
    private static double f1(double precision, double recall) {
      return precision > 0 && recall > 0 ? 2 * precision * recall / (precision + recall) : 0;
    }

    /**
     * fraction of classified documents whose assigned class is their correct class: the diagonal of the matrix
     * over the sum of all its cells, or {@code 0} for an empty matrix
     */
    private static double computeAccuracy(Map<String, Map<String, Long>> matrix) {
      long correct = 0;
      long total = 0;
      for (Map.Entry<String, Map<String, Long>> row : matrix.entrySet()) {
        String klass = row.getKey();
        for (Map.Entry<String, Long> cell : row.getValue().entrySet()) {
          long count = cell.getValue();
          total += count;
          if (klass.equals(cell.getKey())) {
            correct += count;
          }
        }
      }
      return total > 0 ? (double) correct / total : 0d;
    }
  }
}