/* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.solr.search; import java.io.IOException; import java.util.Arrays; import java.util.Comparator; import java.util.Map; import java.util.Set; import com.carrotsearch.hppc.IntFloatHashMap; import com.carrotsearch.hppc.IntIntHashMap; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.LeafCollector; import org.apache.lucene.search.Query; import org.apache.lucene.search.Rescorer; import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.ScoreMode; import org.apache.lucene.search.Sort; import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TopDocsCollector; import org.apache.lucene.search.TopFieldCollector; import org.apache.lucene.search.TopScoreDocCollector; import org.apache.lucene.util.BytesRef; import org.apache.solr.common.SolrException; import org.apache.solr.handler.component.QueryElevationComponent; import org.apache.solr.request.SolrRequestInfo; /* A TopDocsCollector used by reranking queries. */ @SuppressWarnings({"rawtypes"}) public class ReRankCollector extends TopDocsCollector { final private TopDocsCollector<?> mainCollector; final private IndexSearcher searcher; final private int reRankDocs; final private int length; final private Set<BytesRef> boostedPriority; // order is the "priority" final private Rescorer reRankQueryRescorer; final private Sort sort; final private Query query; @SuppressWarnings({"unchecked"}) public ReRankCollector(int reRankDocs, int length, Rescorer reRankQueryRescorer, QueryCommand cmd, IndexSearcher searcher, Set<BytesRef> boostedPriority) throws IOException { super(null); this.reRankDocs = reRankDocs; this.length = length; this.boostedPriority = boostedPriority; this.query = cmd.getQuery(); Sort sort = cmd.getSort(); if(sort == null) { this.sort = null; this.mainCollector = TopScoreDocCollector.create(Math.max(this.reRankDocs, length), cmd.getMinExactCount()); } else { this.sort = sort = sort.rewrite(searcher); //scores are needed for Rescorer (regardless of whether sort needs it) this.mainCollector = TopFieldCollector.create(sort, Math.max(this.reRankDocs, length), cmd.getMinExactCount()); } this.searcher = searcher; this.reRankQueryRescorer = reRankQueryRescorer; } public int getTotalHits() { return mainCollector.getTotalHits(); } @Override public LeafCollector getLeafCollector(LeafReaderContext context) throws IOException { return mainCollector.getLeafCollector(context); } @Override public ScoreMode scoreMode() { return this.mainCollector.scoreMode(); } @SuppressWarnings({"unchecked"}) public TopDocs topDocs(int start, int howMany) { try { TopDocs mainDocs = mainCollector.topDocs(0, Math.max(reRankDocs, length)); if(mainDocs.totalHits.value == 0 || mainDocs.scoreDocs.length == 0) { return mainDocs; } if (sort != null) { TopFieldCollector.populateScores(mainDocs.scoreDocs, searcher, query); } ScoreDoc[] mainScoreDocs = mainDocs.scoreDocs; ScoreDoc[] reRankScoreDocs = new ScoreDoc[Math.min(mainScoreDocs.length, reRankDocs)]; System.arraycopy(mainScoreDocs, 0, reRankScoreDocs, 0, reRankScoreDocs.length); mainDocs.scoreDocs = reRankScoreDocs; TopDocs rescoredDocs = reRankQueryRescorer .rescore(searcher, mainDocs, mainDocs.scoreDocs.length); //Lower howMany to return if we've collected fewer documents. howMany = Math.min(howMany, mainScoreDocs.length); if(boostedPriority != null) { SolrRequestInfo info = SolrRequestInfo.getRequestInfo(); Map requestContext = null; if(info != null) { requestContext = info.getReq().getContext(); } IntIntHashMap boostedDocs = QueryElevationComponent.getBoostDocs((SolrIndexSearcher)searcher, boostedPriority, requestContext); float maxScore = rescoredDocs.scoreDocs.length == 0 ? Float.NaN : rescoredDocs.scoreDocs[0].score; Arrays.sort(rescoredDocs.scoreDocs, new BoostedComp(boostedDocs, mainDocs.scoreDocs, maxScore)); } if(howMany == rescoredDocs.scoreDocs.length) { return rescoredDocs; // Just return the rescoredDocs } else if(howMany > rescoredDocs.scoreDocs.length) { //We need to return more then we've reRanked, so create the combined page. ScoreDoc[] scoreDocs = new ScoreDoc[howMany]; System.arraycopy(mainScoreDocs, 0, scoreDocs, 0, scoreDocs.length); //lay down the initial docs System.arraycopy(rescoredDocs.scoreDocs, 0, scoreDocs, 0, rescoredDocs.scoreDocs.length);//overlay the re-ranked docs. rescoredDocs.scoreDocs = scoreDocs; return rescoredDocs; } else { //We've rescored more then we need to return. ScoreDoc[] scoreDocs = new ScoreDoc[howMany]; System.arraycopy(rescoredDocs.scoreDocs, 0, scoreDocs, 0, howMany); rescoredDocs.scoreDocs = scoreDocs; return rescoredDocs; } } catch (Exception e) { throw new SolrException(SolrException.ErrorCode.BAD_REQUEST, e); } } @SuppressWarnings({"rawtypes"}) public static class BoostedComp implements Comparator { IntFloatHashMap boostedMap; public BoostedComp(IntIntHashMap boostedDocs, ScoreDoc[] scoreDocs, float maxScore) { this.boostedMap = new IntFloatHashMap(boostedDocs.size()*2); for(int i=0; i<scoreDocs.length; i++) { final int idx; if((idx = boostedDocs.indexOf(scoreDocs[i].doc)) >= 0) { boostedMap.put(scoreDocs[i].doc, maxScore+boostedDocs.indexGet(idx)); } else { break; } } } public int compare(Object o1, Object o2) { ScoreDoc doc1 = (ScoreDoc) o1; ScoreDoc doc2 = (ScoreDoc) o2; float score1 = doc1.score; float score2 = doc2.score; int idx; if((idx = boostedMap.indexOf(doc1.doc)) >= 0) { score1 = boostedMap.indexGet(idx); } if((idx = boostedMap.indexOf(doc2.doc)) >= 0) { score2 = boostedMap.indexGet(idx); } return -Float.compare(score1, score2); } } }