/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.lucene.queries.payloads;

import java.io.IOException;
import java.io.StringReader;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.MockTokenizer;
import org.apache.lucene.analysis.TokenFilter;
import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.Tokenizer;
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
import org.apache.lucene.analysis.tokenattributes.PayloadAttribute;
import org.apache.lucene.analysis.tokenattributes.PositionIncrementAttribute;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.TextField;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.PostingsEnum;
import org.apache.lucene.index.RandomIndexWriter;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.similarities.ClassicSimilarity;
import org.apache.lucene.search.similarities.Similarity;
import org.apache.lucene.search.spans.SpanCollector;
import org.apache.lucene.search.spans.SpanFirstQuery;
import org.apache.lucene.search.spans.SpanNearQuery;
import org.apache.lucene.search.spans.SpanNotQuery;
import org.apache.lucene.search.spans.SpanQuery;
import org.apache.lucene.search.spans.Spans;
import org.apache.lucene.search.spans.SpanTermQuery;
import org.apache.lucene.search.spans.SpanWeight;
import org.apache.lucene.store.Directory;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.LuceneTestCase;

public class TestPayloadSpans extends LuceneTestCase {
  // Searcher over the shared index; assigned in setUp() from PayloadHelper.setUp(...).
  private IndexSearcher searcher;
  // Similarity passed to PayloadHelper and to the index writer configs built in this class.
  private Similarity similarity = new ClassicSimilarity();
  // Reader behind {@link #searcher}; assigned in setUp().
  protected IndexReader indexReader;
  // Reader opened by getSearcher(); the tests that call getSearcher() close it themselves.
  private IndexReader closeIndexReader;
  // Directory created by getSearcher(); the tests that call getSearcher() close it themselves.
  private Directory directory;

  /**
   * Prepares the shared searcher and reader used by the tests that run against the
   * index produced by {@link PayloadHelper}.
   */
  @Override
  public void setUp() throws Exception {
    super.setUp();
    // NOTE(review): the helper instance is not retained, so any cleanup it may offer
    // is not invoked from this class -- confirm against PayloadHelper.
    searcher = new PayloadHelper().setUp(random(), similarity, 1000);
    indexReader = searcher.getIndexReader();
  }

  /**
   * Checks a plain span term query against the shared index, once on the field that
   * carries payloads and once on the field that does not.
   */
  public void testSpanTermQuery() throws Exception {
    // Field with payloads: every matching position is expected to expose one payload.
    SpanTermQuery withPayloads = new SpanTermQuery(new Term(PayloadHelper.FIELD, "seventy"));
    SpanWeight payloadWeight = withPayloads.createWeight(searcher, ScoreMode.COMPLETE_NO_SCORES, 1f);
    Spans payloadSpans = payloadWeight.getSpans(searcher.getIndexReader().leaves().get(0), SpanWeight.Postings.PAYLOADS);
    assertNotNull("spans is null and it shouldn't be", payloadSpans);
    checkSpans(payloadSpans, 100, 1, 1, 1);

    // Field without payloads: same number of spans, but nothing to collect.
    SpanTermQuery withoutPayloads = new SpanTermQuery(new Term(PayloadHelper.NO_PAYLOAD_FIELD, "seventy"));
    SpanWeight plainWeight = withoutPayloads.createWeight(searcher, ScoreMode.COMPLETE_NO_SCORES, 1f);
    Spans plainSpans = plainWeight.getSpans(searcher.getIndexReader().leaves().get(0), SpanWeight.Postings.PAYLOADS);
    assertNotNull("spans is null and it shouldn't be", plainSpans);
    checkSpans(plainSpans, 100, 0, 0, 0);
  }

  /**
   * Checks payload collection through {@link SpanFirstQuery}, first around a single
   * term and then around an ordered and an unordered near query.
   */
  public void testSpanFirst() throws IOException {
    SpanQuery single = new SpanTermQuery(new Term(PayloadHelper.FIELD, "one"));
    SpanFirstQuery firstOfSingle = new SpanFirstQuery(single, 2);
    Spans spans = firstOfSingle
        .createWeight(searcher, ScoreMode.COMPLETE_NO_SCORES, 1f)
        .getSpans(searcher.getIndexReader().leaves().get(0), SpanWeight.Postings.PAYLOADS);
    checkSpans(spans, 109, 1, 1, 1);

    // Test more complicated subclause: "one" next to "hundred", ordered first, then unordered.
    SpanQuery[] clauses = {
        new SpanTermQuery(new Term(PayloadHelper.FIELD, "one")),
        new SpanTermQuery(new Term(PayloadHelper.FIELD, "hundred"))
    };
    for (boolean inOrder : new boolean[] {true, false}) {
      SpanQuery near = new SpanNearQuery(clauses, 0, inOrder);
      SpanFirstQuery firstOfNear = new SpanFirstQuery(near, 2);
      Spans nearSpans = firstOfNear
          .createWeight(searcher, ScoreMode.COMPLETE_NO_SCORES, 1f)
          .getSpans(searcher.getIndexReader().leaves().get(0), SpanWeight.Postings.PAYLOADS);
      checkSpans(nearSpans, 100, 2, 1, 1);
    }
  }
  
  /**
   * Indexes a single document and checks the payloads collected for
   * ("one" before "three" within slop 5) NOT containing "two".
   */
  public void testSpanNot() throws Exception {
    SpanQuery include = new SpanNearQuery(
        new SpanQuery[] {
            new SpanTermQuery(new Term(PayloadHelper.FIELD, "one")),
            new SpanTermQuery(new Term(PayloadHelper.FIELD, "three"))
        },
        5, true);
    SpanQuery exclude = new SpanTermQuery(new Term(PayloadHelper.FIELD, "two"));
    SpanNotQuery snq = new SpanNotQuery(include, exclude);

    // Private one-document index for this test only.
    Directory dir = newDirectory();
    RandomIndexWriter indexWriter = new RandomIndexWriter(random(), dir,
                                                          newIndexWriterConfig(new PayloadAnalyzer()).setSimilarity(similarity));

    Document document = new Document();
    document.add(newTextField(PayloadHelper.FIELD, "one two three one four three", Field.Store.YES));
    indexWriter.addDocument(document);
    IndexReader leaf = getOnlyLeafReader(indexWriter.getReader());
    indexWriter.close();

    Spans spans = snq
        .createWeight(newSearcher(leaf, false), ScoreMode.COMPLETE_NO_SCORES, 1f)
        .getSpans(leaf.leaves().get(0), SpanWeight.Postings.PAYLOADS);
    checkSpans(spans, 1, new int[]{2});
    leaf.close();
    dir.close();
  }
  
  /**
   * Exercises payload collection on near queries over the small index built by
   * {@link #getSearcher()}, including a near query nested inside another one.
   */
  public void testNestedSpans() throws Exception {
    final IndexSearcher local = getSearcher();

    // A term that is absent from the index yields no spans at all.
    SpanTermQuery absent = new SpanTermQuery(new Term(PayloadHelper.FIELD, "mark"));
    Spans spans = absent.createWeight(local, ScoreMode.COMPLETE_NO_SCORES, 1f)
        .getSpans(local.getIndexReader().leaves().get(0), SpanWeight.Postings.PAYLOADS);
    assertNull(spans);

    // rr, yy, xx in any order within slop 12
    SpanNearQuery unordered = new SpanNearQuery(
        new SpanQuery[] {
            new SpanTermQuery(new Term(PayloadHelper.FIELD, "rr")),
            new SpanTermQuery(new Term(PayloadHelper.FIELD, "yy")),
            new SpanTermQuery(new Term(PayloadHelper.FIELD, "xx"))
        },
        12, false);
    spans = unordered.createWeight(local, ScoreMode.COMPLETE_NO_SCORES, 1f)
        .getSpans(local.getIndexReader().leaves().get(0), SpanWeight.Postings.PAYLOADS);
    assertNotNull("spans is null and it shouldn't be", spans);
    checkSpans(spans, 2, new int[]{3,3});

    // xx, rr, yy in that order within slop 6
    SpanNearQuery ordered = new SpanNearQuery(
        new SpanQuery[] {
            new SpanTermQuery(new Term(PayloadHelper.FIELD, "xx")),
            new SpanTermQuery(new Term(PayloadHelper.FIELD, "rr")),
            new SpanTermQuery(new Term(PayloadHelper.FIELD, "yy"))
        },
        6, true);
    spans = ordered.createWeight(local, ScoreMode.COMPLETE_NO_SCORES, 1f)
        .getSpans(local.getIndexReader().leaves().get(0), SpanWeight.Postings.PAYLOADS);
    assertNotNull("spans is null and it shouldn't be", spans);
    checkSpans(spans, 1, new int[]{3});

    // xx within 6 of rr
    SpanNearQuery xxThenRr = new SpanNearQuery(
        new SpanQuery[] {
            new SpanTermQuery(new Term(PayloadHelper.FIELD, "xx")),
            new SpanTermQuery(new Term(PayloadHelper.FIELD, "rr"))
        },
        6, true);

    // yy within 6 of xx within 6 of rr
    SpanNearQuery nestedSpanNearQuery = new SpanNearQuery(
        new SpanQuery[] {
            new SpanTermQuery(new Term(PayloadHelper.FIELD, "yy")),
            xxThenRr
        },
        6, false);
    spans = nestedSpanNearQuery.createWeight(local, ScoreMode.COMPLETE_NO_SCORES, 1f)
        .getSpans(local.getIndexReader().leaves().get(0), SpanWeight.Postings.PAYLOADS);
    assertNotNull("spans is null and it shouldn't be", spans);
    checkSpans(spans, 2, new int[]{3,3});

    closeIndexReader.close();
    directory.close();
  }
  
  /**
   * Checks payload collection for nested near queries whose leading terms
   * ("nopayload", "np") are indexed without a payload by {@link PayloadFilter}.
   */
  public void testFirstClauseWithoutPayload() throws Exception {
    final IndexSearcher local = getSearcher();

    // innermost: nopayload, qq, ss in order within slop 6
    SpanNearQuery innermost = new SpanNearQuery(
        new SpanQuery[] {
            new SpanTermQuery(new Term(PayloadHelper.FIELD, "nopayload")),
            new SpanTermQuery(new Term(PayloadHelper.FIELD, "qq")),
            new SpanTermQuery(new Term(PayloadHelper.FIELD, "ss"))
        },
        6, true);

    // middle: pp near the innermost match, any order
    SpanNearQuery middle = new SpanNearQuery(
        new SpanQuery[] {
            new SpanTermQuery(new Term(PayloadHelper.FIELD, "pp")),
            innermost
        },
        6, false);

    // outermost: np near the middle match, any order
    SpanNearQuery nestedSpanNearQuery = new SpanNearQuery(
        new SpanQuery[] {
            new SpanTermQuery(new Term(PayloadHelper.FIELD, "np")),
            middle
        },
        6, false);

    Spans spans = nestedSpanNearQuery.createWeight(local, ScoreMode.COMPLETE_NO_SCORES, 1f)
        .getSpans(local.getIndexReader().leaves().get(0), SpanWeight.Postings.PAYLOADS);
    assertNotNull("spans is null and it shouldn't be", spans);
    checkSpans(spans, 1, new int[]{3});

    closeIndexReader.close();
    directory.close();
  }
  
  /**
   * Checks payload collection for a three-level nesting of near queries over the
   * index built by {@link #getSearcher()}.
   */
  public void testHeavilyNestedSpanQuery() throws Exception {
    final IndexSearcher local = getSearcher();

    // one, two, three in order within slop 5
    SpanNearQuery oneTwoThree = new SpanNearQuery(
        new SpanQuery[] {
            new SpanTermQuery(new Term(PayloadHelper.FIELD, "one")),
            new SpanTermQuery(new Term(PayloadHelper.FIELD, "two")),
            new SpanTermQuery(new Term(PayloadHelper.FIELD, "three"))
        },
        5, true);

    // (one two three), five, six in order within slop 6
    SpanNearQuery throughSix = new SpanNearQuery(
        new SpanQuery[] {
            oneTwoThree,
            new SpanTermQuery(new Term(PayloadHelper.FIELD, "five")),
            new SpanTermQuery(new Term(PayloadHelper.FIELD, "six"))
        },
        6, true);

    // eleven and ten in any order within slop 2
    SpanNearQuery elevenTen = new SpanNearQuery(
        new SpanQuery[] {
            new SpanTermQuery(new Term(PayloadHelper.FIELD, "eleven")),
            new SpanTermQuery(new Term(PayloadHelper.FIELD, "ten"))
        },
        2, false);

    // nine, the six-term match and the eleven/ten match, any order within slop 6
    SpanNearQuery nestedSpanNearQuery = new SpanNearQuery(
        new SpanQuery[] {
            new SpanTermQuery(new Term(PayloadHelper.FIELD, "nine")),
            throughSix,
            elevenTen
        },
        6, false);

    Spans spans = nestedSpanNearQuery.createWeight(local, ScoreMode.COMPLETE_NO_SCORES, 1f)
        .getSpans(local.getIndexReader().leaves().get(0), SpanWeight.Postings.PAYLOADS);
    assertNotNull("spans is null and it shouldn't be", spans);
    checkSpans(spans, 2, new int[]{8, 8});

    closeIndexReader.close();
    directory.close();
  }
  
  /**
   * "a" followed by "k" within slop 1: only the last "a k" pair of the text qualifies,
   * so exactly the payloads of those two tokens must be collected.
   */
  public void testShrinkToAfterShortestMatch() throws IOException {
    Set<String> payloadSet = collectOrderedNearPayloads("a b c d e f g h i j a k", 1);
    assertEquals(2, payloadSet.size());
    assertTrue(payloadSet.contains("a:Noise:10"));
    assertTrue(payloadSet.contains("k:Noise:11"));
  }

  /**
   * "a" immediately followed by "k" (slop 0) in a text with several "a" and "k"
   * tokens: only the adjacent pair at positions 10/11 must contribute payloads.
   */
  public void testShrinkToAfterShortestMatch2() throws IOException {
    Set<String> payloadSet = collectOrderedNearPayloads("a b a d k f a h i k a k", 0);
    assertEquals(2, payloadSet.size());
    assertTrue(payloadSet.contains("a:Noise:10"));
    assertTrue(payloadSet.contains("k:Noise:11"));
  }

  /**
   * Same as {@link #testShrinkToAfterShortestMatch2()} on a longer text with more
   * decoy "a" and "k" tokens around the single adjacent pair.
   */
  public void testShrinkToAfterShortestMatch3() throws IOException {
    Set<String> payloadSet = collectOrderedNearPayloads("j k a l f k k p a t a k l k t a", 0);
    assertEquals(2, payloadSet.size());
    if (VERBOSE) {
      for (final String payload : payloadSet) {
        System.out.println("match:" +  payload);
      }
    }
    assertTrue(payloadSet.contains("a:Noise:10"));
    assertTrue(payloadSet.contains("k:Noise:11"));
  }

  /**
   * Indexes {@code text} as the only document (field "content", analyzed with
   * {@link TestPayloadAnalyzer}), runs the ordered near query "a" ... "k" with the
   * given slop and returns the string form of every payload collected from its spans.
   *
   * <p>The directory, writer and reader are scoped with try-with-resources so they
   * are released on every exit path.
   *
   * @param text whitespace separated tokens to index
   * @param slop slop passed to the ordered {@link SpanNearQuery}
   * @return the distinct payloads (as produced by {@link Term#toString(BytesRef)})
   *         seen across all matching spans
   */
  private Set<String> collectOrderedNearPayloads(String text, int slop) throws IOException {
    Set<String> payloadSet = new HashSet<>();
    try (Directory dir = newDirectory();
         RandomIndexWriter writer = new RandomIndexWriter(random(), dir,
                                                          newIndexWriterConfig(new TestPayloadAnalyzer()))) {
      Document doc = new Document();
      doc.add(new TextField("content", new StringReader(text)));
      writer.addDocument(doc);

      try (IndexReader reader = writer.getReader()) {
        IndexSearcher is = newSearcher(getOnlyLeafReader(reader), false);

        SpanQuery[] sqs = {
            new SpanTermQuery(new Term("content", "a")),
            new SpanTermQuery(new Term("content", "k"))
        };
        SpanNearQuery snq = new SpanNearQuery(sqs, slop, true);
        Spans spans = snq.createWeight(is, ScoreMode.COMPLETE_NO_SCORES, 1f)
            .getSpans(is.getIndexReader().leaves().get(0), SpanWeight.Postings.PAYLOADS);
        assertNotNull("spans is null and it shouldn't be", spans);

        // Exactly one document was indexed, so the query must produce exactly one hit.
        TopDocs topDocs = is.search(snq, 1);
        assertEquals("the single indexed document should match", 1, topDocs.scoreDocs.length);

        VerifyingCollector collector = new VerifyingCollector();
        while (spans.nextDoc() != Spans.NO_MORE_DOCS) {
          while (spans.nextStartPosition() != Spans.NO_MORE_POSITIONS) {
            collector.reset();
            spans.collect(collector);
            for (final BytesRef payload : collector.payloads) {
              payloadSet.add(Term.toString(payload));
            }
          }
        }
      }
    }
    return payloadSet;
  }

  /**
   * {@link SpanCollector} that keeps a private copy of every non-null payload seen
   * since the last {@link #reset()} so tests can count and inspect them.
   */
  static class VerifyingCollector implements SpanCollector {

    /** Copies of the payloads collected since the last {@link #reset()}. */
    final List<BytesRef> payloads = new ArrayList<>();

    @Override
    public void collectLeaf(PostingsEnum postings, int position, Term term) throws IOException {
      // Fetch once and deep-copy so the stored bytes do not alias the postings' buffer.
      final BytesRef payload = postings.getPayload();
      if (payload != null) {
        payloads.add(BytesRef.deepCopyOf(payload));
      }
    }

    @Override
    public void reset() {
      payloads.clear();
    }

    /**
     * Asserts that every collected payload has the expected length and, when that
     * length is non-zero, the expected first byte.
     *
     * @param expectedLength    required {@code length} of each payload
     * @param expectedFirstByte required value of the first valid byte of each payload
     */
    public void verify(int expectedLength, int expectedFirstByte) {
      for (BytesRef payload : payloads) {
        assertEquals("Incorrect payload length", expectedLength, payload.length);
        // An empty payload has no first byte to compare.
        if (expectedLength > 0) {
          // The valid bytes of a BytesRef start at 'offset', not necessarily at index 0.
          assertEquals("Incorrect first byte", expectedFirstByte, payload.bytes[payload.offset]);
        }
      }
    }
  }

  /**
   * Walks every position of {@code spans} and asserts that each one exposes
   * {@code expectedNumPayloads} payloads of the given length and first byte, and that
   * {@code expectedNumSpans} positions are visited in total.
   */
  private void checkSpans(Spans spans, int expectedNumSpans, int expectedNumPayloads,
                          int expectedPayloadLength, int expectedFirstByte) throws IOException {
    assertNotNull("spans is null and it shouldn't be", spans);
    // Each position match should have a span associated with it; since there is just one
    // underlying term query, there should only be one entry in the span.
    final VerifyingCollector payloadCollector = new VerifyingCollector();
    int visited = 0;
    for (int doc = spans.nextDoc(); doc != Spans.NO_MORE_DOCS; doc = spans.nextDoc()) {
      for (int start = spans.nextStartPosition();
           start != Spans.NO_MORE_POSITIONS;
           start = spans.nextStartPosition()) {
        payloadCollector.reset();
        spans.collect(payloadCollector);
        payloadCollector.verify(expectedPayloadLength, expectedFirstByte);
        assertEquals("expectedNumPayloads", expectedNumPayloads, payloadCollector.payloads.size());
        visited++;
      }
    }
    assertEquals("expectedNumSpans", expectedNumSpans, visited);
  }


  /**
   * Builds a five-document index analyzed with {@link PayloadAnalyzer}, force-merged
   * to one segment, and returns a searcher over it. The directory and reader are
   * stored in {@link #directory} and {@link #closeIndexReader}; the caller closes both.
   */
  private IndexSearcher getSearcher() throws Exception {
    directory = newDirectory();
    final String[] texts = {
        "xx rr yy mm  pp",
        "xx yy mm rr pp",
        "nopayload qq ss pp np",
        "one two three four five six seven eight nine ten eleven",
        "nine one two three four five six seven eight eleven ten"
    };
    RandomIndexWriter indexWriter = new RandomIndexWriter(random(), directory,
                                                          newIndexWriterConfig(new PayloadAnalyzer()).setSimilarity(similarity));

    for (String text : texts) {
      Document document = new Document();
      document.add(newTextField(PayloadHelper.FIELD, text, Field.Store.YES));
      indexWriter.addDocument(document);
    }

    indexWriter.forceMerge(1);
    closeIndexReader = indexWriter.getReader();
    indexWriter.close();

    return newSearcher(closeIndexReader, false);
  }
  
  /**
   * Walks every position of {@code spans} and asserts that the i-th position visited
   * exposes {@code numPayloads[i]} payloads, and that {@code numSpans} positions are
   * visited in total.
   *
   * @param spans       spans to drain; must not be null
   * @param numSpans    expected total number of positions
   * @param numPayloads expected payload count per position, in iteration order
   */
  private void checkSpans(Spans spans, int numSpans, int[] numPayloads) throws IOException {
    // Same guard as the five-argument overload: a null Spans is a test failure, not an NPE.
    assertNotNull("spans is null and it shouldn't be", spans);
    int cnt = 0;
    VerifyingCollector collector = new VerifyingCollector();
    while (spans.nextDoc() != Spans.NO_MORE_DOCS) {
      while (spans.nextStartPosition() != Spans.NO_MORE_POSITIONS) {
        if (VERBOSE) {
          System.out.println("\nSpans Dump --");
        }
        collector.reset();
        spans.collect(collector);
        // Fail with a readable message instead of an ArrayIndexOutOfBoundsException
        // when the query yields more spans than the caller listed.
        assertTrue("more spans than expected payload counts, span #" + cnt, cnt < numPayloads.length);
        assertEquals("payload size", numPayloads[cnt], collector.payloads.size());

        cnt++;
      }
    }

    assertEquals("expected numSpans", numSpans, cnt);
  }

  /** Analyzer that chains a simple lower-casing {@link MockTokenizer} into a {@link PayloadFilter}. */
  static final class PayloadAnalyzer extends Analyzer {

    @Override
    public TokenStreamComponents createComponents(String fieldName) {
      final Tokenizer tokenizer = new MockTokenizer(MockTokenizer.SIMPLE, true);
      final TokenStream withPayloads = new PayloadFilter(tokenizer);
      return new TokenStreamComponents(tokenizer, withPayloads);
    }
  }

  /**
   * Token filter that attaches a payload of the form {@code <term>:Entity:<pos>} or
   * {@code <term>:Noise:<pos>} to each token, except for the terms listed in
   * {@link #nopayload}, which are passed through untouched.
   */
  static final class PayloadFilter extends TokenFilter {
    /** Terms tagged as "Entity"; every other payload-carrying term is tagged "Noise". */
    Set<String> entities = new HashSet<>();
    /** Terms for which this filter sets no payload. */
    Set<String> nopayload = new HashSet<>();
    /** Running position counter embedded in the payload text; restarted by {@link #reset()}. */
    int pos;
    CharTermAttribute termAtt = addAttribute(CharTermAttribute.class);
    PositionIncrementAttribute posIncrAtt = addAttribute(PositionIncrementAttribute.class);
    PayloadAttribute payloadAtt = addAttribute(PayloadAttribute.class);

    public PayloadFilter(TokenStream input) {
      super(input);
      pos = 0;
      entities.add("xx");
      entities.add("one");
      nopayload.add("nopayload");
      nopayload.add("np");
    }

    @Override
    public boolean incrementToken() throws IOException {
      if (!input.incrementToken()) {
        return false;
      }
      final String term = termAtt.toString();
      if (!nopayload.contains(term)) {
        final String tag = entities.contains(term) ? ":Entity:" : ":Noise:";
        payloadAtt.setPayload(new BytesRef(term + tag + pos));
      }
      // The payload carries the counter value from before this token's increment is added.
      pos += posIncrAtt.getPositionIncrement();
      return true;
    }

    @Override
    public void reset() throws IOException {
      super.reset();
      pos = 0;
    }
  }
  
  /**
   * Public analyzer with the same chain as {@link PayloadAnalyzer}: a simple
   * lower-casing {@link MockTokenizer} followed by a {@link PayloadFilter}.
   */
  public static final class TestPayloadAnalyzer extends Analyzer {

    @Override
    public TokenStreamComponents createComponents(String fieldName) {
      final Tokenizer source = new MockTokenizer(MockTokenizer.SIMPLE, true);
      final TokenStream sink = new PayloadFilter(source);
      return new TokenStreamComponents(source, sink);
    }
  }
}