/*
 * Open Advancement Question Answering (OAQA) Project Copyright 2016 Carnegie Mellon University
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
 * in compliance with the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software distributed under the License
 * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied. See the License for the specific language governing permissions and limitations
 * under the License.
 */

package edu.cmu.lti.oaqa.baseqa.passage.rerank.scorers;

import com.google.common.collect.*;
import com.google.common.io.Resources;
import edu.cmu.lti.oaqa.baseqa.learning_base.AbstractScorer;
import edu.cmu.lti.oaqa.type.kb.Concept;
import edu.cmu.lti.oaqa.type.kb.ConceptMention;
import edu.cmu.lti.oaqa.type.kb.ConceptType;
import edu.cmu.lti.oaqa.type.nlp.Token;
import edu.cmu.lti.oaqa.type.retrieval.Passage;
import edu.cmu.lti.oaqa.util.TypeUtil;
import org.apache.lucene.analysis.standard.StandardAnalyzer;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.StringField;
import org.apache.lucene.document.TextField;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.queryparser.classic.ParseException;
import org.apache.lucene.queryparser.classic.QueryParser;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.store.RAMDirectory;
import org.apache.uima.analysis_engine.AnalysisEngineProcessException;
import org.apache.uima.jcas.JCas;
import org.apache.uima.resource.ResourceInitializationException;
import org.apache.uima.resource.ResourceSpecifier;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.net.URL;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;

import static java.nio.charset.StandardCharsets.UTF_8;
import static java.util.stream.Collectors.*;

/**
 * An instance of an {@link AbstractScorer} for {@link Passage}s that scores candidate
 * {@link Passage}s, by using various different ways of constructing the queries, from only the
 * tokens to a full combination of tokens and concepts with all the synonyms, to query an
 * in-memory Lucene index that is created on the fly.
 * Each query construction method contributes to a dimension of the feature, and the retrieval
 * scores are used as the values.
 *
 * <p>Life cycle: {@link #initialize} reads the parameters once, {@link #prepare} indexes the
 * ranked passages of a CAS and runs every query configuration, and {@link #score} then looks up
 * the rank/score features recorded for a single passage.
 *
 * @see edu.cmu.lti.oaqa.baseqa.passage.retrieval.LuceneInMemorySentenceRetrievalExecutor
 *
 * @author <a href="mailto:[email protected]">Zi Yang</a> created on 4/5/16
 */
public class LuceneInMemoryPassageScorer extends AbstractScorer<Passage> {

  private static final Logger LOG = LoggerFactory.getLogger(LuceneInMemoryPassageScorer.class);

  // TODO: Moved to separate files
  private static final Set<String> FORBIDDEN_CTYPES = ImmutableSet
          .of("lingpipe:other_name", "umls:qlco", "umls:qnco", "umls:ftcn", "umls:geoa");

  /** Name of the indexed (analyzed, not stored) field holding the passage text. */
  private static final String TEXT_FIELD = "text";

  /** Name of the stored field holding the passage URI/offset key. */
  private static final String URI_FIELD = "uri";

  /** Guards the check-then-update sequences on the two result tables done by worker threads. */
  private final Object tableLock = new Object();

  private int hits;

  private Set<String> stoplist;

  private Table<String, String, Double> uri2conf2score;

  private Table<String, String, Integer> uri2conf2rank;

  private Set<String> confs;

  private StandardAnalyzer analyzer;

  private IndexReader reader;

  private IndexSearcher searcher;

  /**
   * Reads the {@code hits} and {@code stoplist-path} parameters, loads the stoplist from the
   * classpath (one trimmed entry per line) and creates the analyzer used for both indexing and
   * query parsing.
   *
   * @return always {@code true}
   * @throws ResourceInitializationException if the stoplist resource is missing or unreadable
   */
  @Override
  public boolean initialize(ResourceSpecifier aSpecifier, Map<String, Object> aAdditionalParams)
          throws ResourceInitializationException {
    super.initialize(aSpecifier, aAdditionalParams);
    hits = Integer.class.cast(getParameterValue("hits"));
    // query constructor
    String stoplistPath = String.class.cast(getParameterValue("stoplist-path"));
    URL stoplistUrl = getClass().getResource(stoplistPath);
    if (stoplistUrl == null) {
      // Fail with a message naming the path instead of a bare NullPointerException.
      throw new ResourceInitializationException(
              new IllegalArgumentException("Stoplist resource not found: " + stoplistPath));
    }
    try {
      stoplist = Resources.readLines(stoplistUrl, UTF_8).stream().map(String::trim)
              .collect(toSet());
    } catch (IOException e) {
      throw new ResourceInitializationException(e);
    }
    analyzer = new StandardAnalyzer();
    return true;
  }

  /** Escapes Lucene query syntax in {@code name} and wraps it in double quotes (phrase query). */
  private static String normalizeQuoteName(String name) {
    return "\"" + QueryParser.escape(name) + "\"";
  }

  /**
   * Indexes the ranked passages of {@code jcas} in memory, builds the query strings from the
   * tokens and concepts of the CAS, and runs all query configurations concurrently, recording
   * for each (passage URI, configuration) pair the best rank and the highest retrieval score.
   *
   * <p>A failure inside a single query task is logged and does not abort the call; passages
   * simply get no entry for the affected configuration(s).
   *
   * @throws AnalysisEngineProcessException if the index cannot be built, the calling thread is
   *         interrupted, or the query tasks do not finish within one minute
   */
  @Override
  public void prepare(JCas jcas) throws AnalysisEngineProcessException {
    uri2conf2score = HashBasedTable.create();
    uri2conf2rank = HashBasedTable.create();
    // index
    buildIndex(TypeUtil.getRankedPassages(jcas));
    // queries
    List<String> tokens = TypeUtil.getOrderedTokens(jcas).stream().map(Token::getCoveredText)
            .map(QueryParser::escape)
            .filter(name -> !name.isEmpty() && !stoplist.contains(name.toLowerCase()))
            .collect(toList());
    // For every allowed concept type: the quoted names / matched mentions of each concept,
    // one space-joined string per concept.
    Multimap<String, String> ctype2names = HashMultimap.create();
    Multimap<String, String> ctype2mentions = HashMultimap.create();
    for (Concept concept : TypeUtil.getConcepts(jcas)) {
      Set<String> ctypes = TypeUtil.getConceptTypes(concept).stream()
              .map(ConceptType::getAbbreviation)
              .filter(t -> !FORBIDDEN_CTYPES.contains(t)).collect(toSet());
      String cnames = TypeUtil.getConceptNames(concept).stream()
              .map(LuceneInMemoryPassageScorer::normalizeQuoteName).distinct()
              .collect(joining(" "));
      String cmentions = TypeUtil.getConceptMentions(concept).stream()
              .map(ConceptMention::getMatchedName)
              .map(LuceneInMemoryPassageScorer::normalizeQuoteName).distinct()
              .collect(joining(" "));
      for (String ctype : ctypes) {
        ctype2names.put(ctype, cnames);
        ctype2mentions.put(ctype, cmentions);
      }
    }
    Multimap<String, String> ctypepre2names = groupByTypePrefix(ctype2names);
    // Must be derived from ctype2mentions; grouping the (empty) target map itself would leave
    // the per-prefix mention queries permanently empty.
    Multimap<String, String> ctypepre2mentions = groupByTypePrefix(ctype2mentions);
    LOG.debug("Query strings");
    // NOTE(review): several configurations below pass the identical feature-name literal to
    // search(), so their hits are merged into one feature (best rank / highest score across
    // them). The literals look like they were altered by text extraction - confirm against the
    // upstream source before relying on the feature names.
    List<Runnable> tasks = new ArrayList<>();
    // execute against all tokens
    tasks.add(() -> {
      String concatTokens = String.join(" ", tokens);
      LOG.debug(" - Concatenated tokens: {}", concatTokens);
      search(concatTokens, "[email protected]");
    });
    // execute against concatenated concept names
    tasks.add(() -> {
      String concatCnames = String.join(" ", ctype2names.values());
      LOG.debug(" - Concatenated concept names: {}", concatCnames);
      search(concatCnames, "[email protected]");
    });
    // execute against concatenated concept mentions
    tasks.add(() -> {
      String concatCmentions = String.join(" ", ctype2mentions.values());
      LOG.debug(" - Concatenated concept mentions: {}", concatCmentions);
      search(concatCmentions, "[email protected]");
    });
    // execute against concept names for each concept
    tasks.add(() -> {
      for (String cnames : ImmutableSet.copyOf(ctype2names.values())) {
        LOG.debug(" - Concatenated concept names: {}", cnames);
        search(cnames, "[email protected]");
      }
    });
    // execute against concept names for each concept type
    tasks.add(() -> {
      for (String ctype : ctype2names.keySet()) {
        String concatCnames = String.join(" ", ctype2names.get(ctype));
        LOG.debug(" - Concatenated concept names for {}: {}", ctype, concatCnames);
        search(concatCnames, "[email protected]" + ctype + "@all");
      }
    });
    // execute against concept names for each concept type prefix
    tasks.add(() -> {
      for (String ctypepre : ctypepre2names.keySet()) {
        String concatCnames = String.join(" ", ctypepre2names.get(ctypepre));
        LOG.debug(" - Concatenated concept names for {}: {}", ctypepre, concatCnames);
        search(concatCnames, "[email protected]" + ctypepre + "@all");
      }
    });
    // execute against concept mentions for each concept
    tasks.add(() -> {
      for (String cmentions : ImmutableSet.copyOf(ctype2mentions.values())) {
        LOG.debug(" - Concatenated concept mentions: {}", cmentions);
        search(cmentions, "[email protected]");
      }
    });
    // execute against concept mentions for each concept type
    tasks.add(() -> {
      for (String ctype : ctype2mentions.keySet()) {
        String concatCmentions = String.join(" ", ctype2mentions.get(ctype));
        LOG.debug(" - Concatenated concept mentions for {}: {}", ctype, concatCmentions);
        search(concatCmentions, "[email protected]" + ctype + "@all");
      }
    });
    // execute against concept mentions for each concept type prefix
    tasks.add(() -> {
      for (String ctypepre : ctypepre2mentions.keySet()) {
        String concatCmentions = String.join(" ", ctypepre2mentions.get(ctypepre));
        LOG.debug(" - Concatenated concept mentions for {}: {}", ctypepre, concatCmentions);
        search(concatCmentions, "[email protected]" + ctypepre + "@all");
      }
    });
    runConcurrently(tasks);
    // Snapshot rather than a live view, so score() iterates over a fixed set of configurations.
    confs = ImmutableSet.copyOf(uri2conf2score.columnKeySet());
  }

  /**
   * Builds a fresh in-memory index with one document per passage (stored URI key, analyzed
   * text) and points {@link #reader} / {@link #searcher} at it, closing the reader left over
   * from a previous call.
   */
  private void buildIndex(List<Passage> passages) throws AnalysisEngineProcessException {
    try {
      if (reader != null) {
        reader.close();
        reader = null;
      }
      RAMDirectory index = new RAMDirectory();
      try (IndexWriter writer = new IndexWriter(index, new IndexWriterConfig(analyzer))) {
        for (Passage passage : passages) {
          Document doc = new Document();
          doc.add(new StringField(URI_FIELD, TypeUtil.getUriOffsets(passage, ":"),
                  Field.Store.YES));
          doc.add(new TextField(TEXT_FIELD, passage.getText(), Field.Store.NO));
          writer.addDocument(doc);
        }
      }
      // The writer is closed (and thereby committed) before the reader is opened.
      reader = DirectoryReader.open(index);
      searcher = new IndexSearcher(reader);
    } catch (IOException e) {
      throw new AnalysisEngineProcessException(e);
    }
  }

  /**
   * Regroups a concept-type keyed multimap by the part of each key before the first
   * {@code ':'}.
   */
  private static Multimap<String, String> groupByTypePrefix(Multimap<String, String> byType) {
    Multimap<String, String> byPrefix = HashMultimap.create();
    byType.asMap().forEach((ctype, values) -> byPrefix.putAll(ctype.split(":")[0], values));
    return byPrefix;
  }

  /**
   * Runs the given tasks on a temporary thread pool and waits up to one minute for all of them.
   * Exceptions thrown by individual tasks are logged instead of being silently dropped.
   *
   * @throws AnalysisEngineProcessException if the wait is interrupted or times out
   */
  private static void runConcurrently(List<Runnable> tasks)
          throws AnalysisEngineProcessException {
    ExecutorService service = Executors.newCachedThreadPool();
    List<Future<?>> futures = new ArrayList<>(tasks.size());
    try {
      for (Runnable task : tasks) {
        futures.add(service.submit(task));
      }
      service.shutdown();
      if (!service.awaitTermination(1, TimeUnit.MINUTES)) {
        // Workers would otherwise keep mutating the result tables while they are being read.
        throw new AnalysisEngineProcessException(
                new IllegalStateException("Lucene query tasks did not finish within 1 minute"));
      }
      for (Future<?> future : futures) {
        try {
          future.get();
        } catch (ExecutionException e) {
          LOG.error("Lucene query task failed; results for its configuration(s) may be missing",
                  e.getCause());
        }
      }
    } catch (InterruptedException e) {
      Thread.currentThread().interrupt();
      throw new AnalysisEngineProcessException(e);
    } finally {
      // No-op after normal termination; interrupts stragglers on timeout/interrupt.
      service.shutdownNow();
    }
  }

  /**
   * Runs {@code queryString} against the in-memory index and records, for each returned
   * passage, its rank and score under the configuration name {@code conf}. Blank queries are
   * skipped. Called from worker threads.
   *
   * @throws RuntimeException wrapping any {@link ParseException} or {@link IOException}
   */
  private void search(String queryString, String conf) throws RuntimeException {
    if (queryString.trim().isEmpty()) {
      return;
    }
    ScoreDoc[] results;
    try {
      // QueryParser instances must not be shared between threads, so each call gets its own.
      Query query = new QueryParser(TEXT_FIELD, analyzer).parse(queryString);
      results = searcher.search(query, hits).scoreDocs;
    } catch (ParseException | IOException e) {
      throw new RuntimeException("Query failed for configuration " + conf, e);
    }
    for (int rank = 0; rank < results.length; rank++) {
      String uri;
      try {
        uri = reader.document(results[rank].doc).get(URI_FIELD);
      } catch (IOException e) {
        throw new RuntimeException("Cannot load hit for configuration " + conf, e);
      }
      recordHit(uri, conf, rank, results[rank].score);
    }
  }

  /**
   * Stores {@code rank} / {@code score} for the (uri, conf) cell unless a better value (lower
   * rank, higher score) is already present. The whole check-then-update runs under one lock
   * because the backing tables are shared by all worker threads.
   */
  private void recordHit(String uri, String conf, int rank, double score) {
    synchronized (tableLock) {
      Integer bestRank = uri2conf2rank.get(uri, conf);
      if (bestRank == null || bestRank > rank) {
        uri2conf2rank.put(uri, conf, rank);
      }
      Double bestScore = uri2conf2score.get(uri, conf);
      if (bestScore == null || bestScore < score) {
        uri2conf2score.put(uri, conf, score);
      }
    }
  }

  /**
   * Returns, for every configuration recorded by the last {@link #prepare} call, two features:
   * {@code <conf>/rank} (reciprocal of the 1-based best rank, or 0.0 if the passage was not
   * retrieved) and {@code <conf>/score} (highest retrieval score, or 0.0 if not retrieved).
   */
  @Override
  public Map<String, Double> score(JCas jcas, Passage result) {
    ImmutableMap.Builder<String, Double> ret = new ImmutableMap.Builder<>();
    String uri = TypeUtil.getUriOffsets(result, ":");
    for (String conf : confs) {
      Integer rank = uri2conf2rank.get(uri, conf);
      ret.put(conf + "/rank", rank == null ? 0.0 : 1.0 / (rank + 1.0));
      Double score = uri2conf2score.get(uri, conf);
      ret.put(conf + "/score", score == null ? 0.0 : score);
    }
    return ret.build();
  }

}