package com.sematext.lucene.query.extractor; import static com.sematext.lucene.query.extractor.TestQueryExtractor.DEFAULT_EXTRACTORS; import java.util.ArrayList; import java.util.HashSet; import java.util.List; import java.util.Set; import org.apache.lucene.index.Term; import org.apache.lucene.search.DisjunctionMaxQuery; import org.apache.lucene.search.Query; import org.apache.lucene.search.TermQuery; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; import org.junit.Test; import static org.mockito.Mockito.mock; public class TestDisjunctionQueryExtracotr extends TestQueryExtractor { @Test public void testExtractTwoSubqueries() { Query q1 = mock(Query.class); Query q2 = mock(Query.class); DisjunctionQueryExtractor disjunctionQueryExtracotr = new DisjunctionQueryExtractor(); List<Query> disjunctQueries = new ArrayList<>(); disjunctQueries.add(q1); disjunctQueries.add(q2); DisjunctionMaxQuery disjunctionMaxQuery = new DisjunctionMaxQuery(disjunctQueries, 0.0f); List<Query> extractedQueries = new ArrayList<>(); disjunctionQueryExtracotr.extract(disjunctionMaxQuery, DEFAULT_EXTRACTORS, extractedQueries); assertEquals(2, extractedQueries.size()); assertEquals(q1, extractedQueries.get(0)); assertEquals(q2, extractedQueries.get(1)); } @Test public void testExtractSubqueryField() { Query q1 = new TermQuery(new Term("field1", "value1")); Query q2 = new TermQuery(new Term("field2", "value2")); DisjunctionQueryExtractor disjunctionQueryExtracotr = new DisjunctionQueryExtractor(); List<Query> disjunctQueries = new ArrayList<>(); disjunctQueries.add(q1); disjunctQueries.add(q2); DisjunctionMaxQuery disjunctionMaxQuery = new DisjunctionMaxQuery(disjunctQueries, 0.0f); Set<String> extractedFieldNames = new HashSet<>(); disjunctionQueryExtracotr.extractSubQueriesFields(disjunctionMaxQuery, DEFAULT_EXTRACTORS, extractedFieldNames); assertEquals(2, extractedFieldNames.size()); assertTrue(extractedFieldNames.contains("field1")); assertTrue(extractedFieldNames.contains("field2")); } }