/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.lucene.classification.utils;

import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Objects;

import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.MultiTerms;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.TermStates;
import org.apache.lucene.index.Terms;
import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.BoostQuery;
import org.apache.lucene.search.FuzzyTermsEnum;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.QueryVisitor;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.PriorityQueue;
import org.apache.lucene.util.automaton.LevenshteinAutomata;

/**
 * Simplification of FuzzyLikeThisQuery, to be used in the context of KNN classification.
 */
public class NearestFuzzyQuery extends Query {

  // Inputs registered through addTerms(String, String); iterated by rewrite().
  private final ArrayList<FieldVals> fieldVals = new ArrayList<>();
  // Analyzer used to tokenize each registered query string.
  private final Analyzer analyzer;

  // fixed parameters
  // Capacity of the per-token queue that holds fuzzy variants of one analyzed token.
  private static final int MAX_VARIANTS_PER_TERM = 50;
  // Cast to int in addTerms(String, String) and used as the maximum edit distance.
  private static final float MIN_SIMILARITY = 1f;
  // Prefix length copied into every FieldVals and handed to the fuzzy terms enumeration.
  private static final int PREFIX_LENGTH = 2;
  // Capacity of the global queue of scored terms built in rewrite().
  private static final int MAX_NUM_TERMS = 300;

  /**
   * Creates an empty query; text to be "fuzzified" is registered afterwards
   * through {@link #addTerms(String, String)}.
   *
   * @param analyzer the analyzer used to process the query text
   */
  public NearestFuzzyQuery(Analyzer analyzer) {
    this.analyzer = analyzer;
  }

  /**
   * Immutable holder for one registered input: the raw query text, the field it
   * targets, and the fuzzy-matching parameters (max edits and prefix length).
   * Two instances are equal when all four values are equal.
   */
  static class FieldVals {
    final String queryString;
    final String fieldName;
    final int maxEdits;
    final int prefixLength;

    FieldVals(String name, int maxEdits, String queryString) {
      this.queryString = queryString;
      this.fieldName = name;
      this.maxEdits = maxEdits;
      // the prefix length is not configurable; it always comes from the enclosing query
      this.prefixLength = NearestFuzzyQuery.PREFIX_LENGTH;
    }

    @Override
    public int hashCode() {
      // same 31-based fold over (fieldName, maxEdits, prefixLength, queryString), nulls hash to 0
      return Objects.hash(fieldName, maxEdits, prefixLength, queryString);
    }

    @Override
    public boolean equals(Object obj) {
      if (this == obj) {
        return true;
      }
      if (obj == null || getClass() != obj.getClass()) {
        return false;
      }
      final FieldVals that = (FieldVals) obj;
      return Objects.equals(fieldName, that.fieldName)
          && maxEdits == that.maxEdits
          && prefixLength == that.prefixLength
          && Objects.equals(queryString, that.queryString);
    }

  }

  /**
   * Adds user input for "fuzzification".
   *
   * @param queryString the text that will be run through the analyzer; fuzzy variants are
   *                    generated for each distinct token it produces
   * @param fieldName   the field the analyzed tokens (and their variants) are matched against
   * @throws IllegalArgumentException if {@code MIN_SIMILARITY} is not an integer value between 0
   *                                  and {@link LevenshteinAutomata#MAXIMUM_SUPPORTED_DISTANCE}
   */
  public void addTerms(String queryString, String fieldName) {
    int maxEdits = (int) MIN_SIMILARITY;
    // reject fractional values as well as values outside the range the message promises
    if (maxEdits != MIN_SIMILARITY || maxEdits < 0 || maxEdits > LevenshteinAutomata.MAXIMUM_SUPPORTED_DISTANCE) {
      throw new IllegalArgumentException("MIN_SIMILARITY must be an integer value between 0 and " + LevenshteinAutomata.MAXIMUM_SUPPORTED_DISTANCE + ", inclusive; got " + MIN_SIMILARITY);
    }
    fieldVals.add(new FieldVals(fieldName, maxEdits, queryString));
  }


  /**
   * Analyzes the query string of {@code f}, enumerates fuzzy variants of every distinct token
   * in the index, and offers the best variants to the global queue {@code q}.
   *
   * <p>Does nothing when the query string is {@code null} or the field has no terms.
   *
   * @param reader the reader whose terms are enumerated and whose doc counts feed the IDF factor
   * @param f      the registered field / query string / fuzzy parameters
   * @param q      the global queue receiving the scored variants
   * @throws IOException if reading the index or consuming the token stream fails
   */
  private void addTerms(IndexReader reader, FieldVals f, ScoreTermQueue q) throws IOException {
    if (f.queryString == null) {
      return;
    }
    final Terms terms = MultiTerms.getTerms(reader, f.fieldName);
    if (terms == null) {
      return;
    }
    try (TokenStream ts = analyzer.tokenStream(f.fieldName, f.queryString)) {
      CharTermAttribute termAtt = ts.addAttribute(CharTermAttribute.class);

      int corpusNumDocs = reader.numDocs();
      HashSet<String> processedTerms = new HashSet<>();
      ts.reset();
      while (ts.incrementToken()) {
        String term = termAtt.toString();
        if (!processedTerms.add(term)) {
          continue; // token already expanded for this field
        }
        ScoreTermQueue variantsQ = new ScoreTermQueue(MAX_VARIANTS_PER_TERM); //maxNum variants considered for any one term
        float minScore = 0;
        Term startTerm = new Term(f.fieldName, term);
        FuzzyTermsEnum fe = new FuzzyTermsEnum(terms, startTerm, f.maxEdits, f.prefixLength, true);
        //store the df so all variants use same idf
        int df = reader.docFreq(startTerm);
        int numVariants = 0;
        // long: the sum of many per-variant document frequencies can exceed Integer.MAX_VALUE
        long totalVariantDocFreqs = 0;
        BytesRef possibleMatch;
        while ((possibleMatch = fe.next()) != null) {
          numVariants++;
          totalVariantDocFreqs += fe.docFreq();
          float score = fe.getBoost();
          if (variantsQ.size() < MAX_VARIANTS_PER_TERM || score > minScore) {
            ScoreTerm st = new ScoreTerm(new Term(startTerm.field(), BytesRef.deepCopyOf(possibleMatch)), score, startTerm);
            variantsQ.insertWithOverflow(st);
            minScore = variantsQ.top().score; // maintain minScore
          }
          // once the queue is full, variants scoring below its minimum cannot enter it
          fe.setMaxNonCompetitiveBoost(variantsQ.size() >= MAX_VARIANTS_PER_TERM ? minScore : Float.NEGATIVE_INFINITY);
        }

        if (numVariants == 0) {
          continue;
        }
        if (df == 0) {
          // no direct match we can use as df for all variants: use avg df of all variants
          // (an average of int values always fits in an int)
          df = (int) (totalVariantDocFreqs / numVariants);
        }

        // take the top variants (scored by edit distance) and reset the score
        // to include an IDF factor then add to the global queue for ranking
        // overall top query terms; the factor is the same for every variant of this token
        final float idfFactor = idf(df, corpusNumDocs);
        int size = variantsQ.size();
        for (int i = 0; i < size; i++) {
          ScoreTerm st = variantsQ.pop();
          if (st != null) {
            st.score = (st.score * st.score) * idfFactor;
            q.insertWithOverflow(st);
          }
        }
      }
      ts.end();
    }
  }

  /**
   * Computes the inverse-document-frequency factor
   * {@code ln((docCount + 1) / (docFreq + 1)) + 1}.
   *
   * @param docFreq  number of documents containing the term
   * @param docCount number of documents in the corpus
   * @return the IDF factor, narrowed to {@code float}
   */
  private float idf(int docFreq, int docCount) {
    final double ratio = (docCount + 1) / (double) (docFreq + 1);
    return (float) (1.0 + Math.log(ratio));
  }

  /**
   * Builds a {@link TermQuery} for {@code term} backed by an artificial {@link TermStates}
   * whose overall document frequency and total term frequency are meant to be 1.
   *
   * @param reader the reader whose leaves are probed for the term
   * @param term   the term to query
   * @return a term query carrying the artificial term states
   * @throws IOException if a leaf's terms cannot be read
   */
  private Query newTermQuery(IndexReader reader, Term term) throws IOException {
    final TermStates states = new TermStates(reader.getContext());
    for (LeafReaderContext leaf : reader.leaves()) {
      final Terms leafTerms = leaf.reader().terms(term.field());
      if (leafTerms == null) {
        continue; // this segment has no such field
      }
      final TermsEnum cursor = leafTerms.iterator();
      if (!cursor.seekExact(term.bytes())) {
        continue; // term absent from this segment
      }
      // register whatever is still missing to reach a total df and ttf of 1
      final int remaining = 1 - states.docFreq();
      states.register(cursor.termState(), leaf.ord, remaining, remaining);
    }
    return new TermQuery(term, states);
  }

  /**
   * Expands every registered field/text pair into its best fuzzy term variants and combines
   * them into a boolean query of optional (SHOULD) clauses. Variants that stem from the same
   * analyzed source term are grouped in a nested boolean query, unless there is only one.
   *
   * @param reader the reader used to enumerate and score term variants
   * @return the expanded boolean query
   * @throws IOException if the index cannot be read
   */
  @Override
  public Query rewrite(IndexReader reader) throws IOException {
    // collect the overall best-scoring variants across all registered inputs
    final ScoreTermQueue topTerms = new ScoreTermQueue(MAX_NUM_TERMS);
    for (FieldVals vals : fieldVals) {
      addTerms(reader, vals, topTerms);
    }

    // Step 1: group the selected variants by the source term they were derived from
    final HashMap<Term, ArrayList<ScoreTerm>> bySourceTerm = new HashMap<>();
    for (int remaining = topTerms.size(); remaining > 0; remaining--) {
      final ScoreTerm scored = topTerms.pop();
      if (scored == null) {
        continue;
      }
      ArrayList<ScoreTerm> group = bySourceTerm.get(scored.fuzziedSourceTerm);
      if (group == null) {
        group = new ArrayList<>();
        bySourceTerm.put(scored.fuzziedSourceTerm, group);
      }
      group.add(scored);
    }

    // Step 2: one optional clause per source term
    final BooleanQuery.Builder result = new BooleanQuery.Builder();
    for (ArrayList<ScoreTerm> group : bySourceTerm.values()) {
      if (group.size() == 1) {
        // a single variant needs no nested boolean query
        result.add(boostedTermQuery(reader, group.get(0)), BooleanClause.Occur.SHOULD);
        continue;
      }
      final BooleanQuery.Builder nested = new BooleanQuery.Builder();
      for (ScoreTerm scored : group) {
        nested.add(boostedTermQuery(reader, scored), BooleanClause.Occur.SHOULD);
      }
      result.add(nested.build(), BooleanClause.Occur.SHOULD);
    }
    //TODO possible alternative step 3 - organize above booleans into a new layer of field-based
    // booleans with a minimum-should-match of NumFields-1?
    return result.build();
  }

  // Wraps the term query for a variant in a boost equal to the variant's score.
  private Query boostedTermQuery(IndexReader reader, ScoreTerm scored) throws IOException {
    return new BoostQuery(newTermQuery(reader, scored.term), scored.score);
  }

  /**
   * Holds info for a fuzzy term variant. The score is first set to the boost reported by the
   * fuzzy terms enumeration (used to rank the best variants of one token); it is then reset to
   * the squared score multiplied by an IDF factor, for ranking against all other terms/fields.
   */
  private static class ScoreTerm {
    // the variant term found in the index
    public final Term term;
    // mutable: overwritten once the IDF factor is known
    public float score;
    // the analyzed query term this variant was derived from; rewrite() groups variants by it
    final Term fuzziedSourceTerm;

    ScoreTerm(Term term, float score, Term fuzziedSourceTerm) {
      this.term = term;
      this.score = score;
      this.fuzziedSourceTerm = fuzziedSourceTerm;
    }
  }

  /**
   * Bounded priority queue of {@link ScoreTerm}s ordered by score; ties are broken by term
   * order, the greater term counting as the lesser element.
   */
  private static class ScoreTermQueue extends PriorityQueue<ScoreTerm> {
    ScoreTermQueue(int size) {
      super(size);
    }

    /**
     * Orders primarily by score and, for equal scores, by reverse term order.
     *
     * @see org.apache.lucene.util.PriorityQueue#lessThan(java.lang.Object, java.lang.Object)
     */
    @Override
    protected boolean lessThan(ScoreTerm termA, ScoreTerm termB) {
      if (termA.score != termB.score) {
        return termA.score < termB.score;
      }
      // equal scores: fall back to comparing the terms themselves
      return termA.term.compareTo(termB.term) > 0;
    }

  }

  /**
   * Returns a human-readable description of this query listing every registered
   * {@code field:text~maxEdits} triple, instead of {@code null}.
   *
   * @param field ignored; each registered input names its own field explicitly
   * @return a non-null description of this query
   */
  @Override
  public String toString(String field) {
    StringBuilder sb = new StringBuilder("NearestFuzzyQuery(");
    String separator = "";
    for (FieldVals f : fieldVals) {
      sb.append(separator)
          .append(f.fieldName).append(':').append(f.queryString)
          .append('~').append(f.maxEdits);
      separator = ", ";
    }
    return sb.append(')').toString();
  }

  /**
   * Combines the class hash with the hashes of the analyzer and the registered field values,
   * consistent with {@link #equals(Object)}.
   */
  @Override
  public int hashCode() {
    int hash = classHash();
    hash = 31 * hash + Objects.hashCode(analyzer);
    return 31 * hash + Objects.hashCode(fieldVals);
  }

  /**
   * Two queries are equal when they are of the same class and have equal analyzers and
   * equal registered field values.
   */
  @Override
  public boolean equals(Object other) {
    if (!sameClassAs(other)) {
      return false;
    }
    return equalsTo(getClass().cast(other));
  }

  // Field-by-field comparison used by equals(Object) once the class check has passed.
  private boolean equalsTo(NearestFuzzyQuery other) {
    if (!Objects.equals(analyzer, other.analyzer)) {
      return false;
    }
    return Objects.equals(fieldVals, other.fieldVals);
  }

  /**
   * Reports this query to the visitor as a leaf.
   *
   * @param visitor the visitor to notify
   */
  @Override
  public void visit(QueryVisitor visitor) {
    visitor.visitLeaf(this);
  }

}