/*
 * #%L
 * Alfresco Search Services
 * %%
 * Copyright (C) 2005 - 2020 Alfresco Software Limited
 * %%
 * This file is part of the Alfresco software.
 * If the software was purchased under a paid Alfresco license, the terms of
 * the paid license agreement will prevail. Otherwise, the software is
 * provided under the following open source license terms:
 *
 * Alfresco is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * Alfresco is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public License
 * along with Alfresco. If not, see <http://www.gnu.org/licenses/>.
 * #L%
 */
package org.alfresco.solr.query;

import java.io.IOException;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;

import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.LeafCollector;
import org.apache.lucene.search.MatchAllDocsQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.QueryRescorer;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.Sort;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TopDocsCollector;
import org.apache.lucene.search.TopFieldCollector;
import org.apache.lucene.search.TopScoreDocCollector;
import org.apache.lucene.search.Weight;
import org.apache.lucene.util.BytesRef;
import org.apache.solr.common.SolrException;
import org.apache.solr.common.params.CommonParams;
import org.apache.solr.common.params.SolrParams;
import org.apache.solr.common.util.NamedList;
import org.apache.solr.handler.component.MergeStrategy;
import org.apache.solr.handler.component.QueryElevationComponent;
import org.apache.solr.request.SolrQueryRequest;
import org.apache.solr.request.SolrRequestInfo;
import org.apache.solr.search.QParser;
import org.apache.solr.search.QParserPlugin;
import org.apache.solr.search.RankQuery;
import org.apache.solr.search.QueryCommand;
import org.apache.solr.search.SolrIndexSearcher;
import org.apache.solr.search.SyntaxError;

import com.carrotsearch.hppc.IntFloatHashMap;
import com.carrotsearch.hppc.IntIntHashMap;

/*
 * Query parser plugin that re-ranks the top documents of a main query using a
 * second query.
 *
 * Syntax: q=*:*&rq={!rerank reRankQuery=$rqq reRankDocs=300 reRankWeight=3}
 *
 * Local parameters read by the parser:
 *   reRankQuery  - the query used for re-ranking (mandatory)
 *   reRankDocs   - number of top documents to re-rank (default 200, minimum 1)
 *   reRankWeight - multiplier applied to the re-rank score (default 2.0)
 *   scale        - when true, the returned scores are rescaled (default false)
 */
public class AlfrescoReRankQParserPlugin extends QParserPlugin
{
    public static final String NAME = "rerank";

    /** Main query used until {@link ReRankQuery#wrap(Query)} supplies the real one. */
    private static Query defaultQuery = new MatchAllDocsQuery();

    @Override
    public void init(NamedList args)
    {
        // No configuration is read from the plugin arguments.
    }

    @Override
    public QParser createParser(String query, SolrParams localParams, SolrParams params, SolrQueryRequest req)
    {
        return new ReRankQParser(query, localParams, params, req);
    }

    /**
     * Builds the rescorer shared by explanation and collection: the final score is the
     * first pass score plus, when the re-rank query matches,
     * {@code reRankWeight * secondPassScore}.
     */
    private static QueryRescorer createRescorer(Query query, final double reRankWeight)
    {
        return new QueryRescorer(query)
        {
            @Override
            protected float combine(float firstPassScore, boolean secondPassMatches, float secondPassScore)
            {
                float score = firstPassScore;
                if (secondPassMatches)
                {
                    score += reRankWeight * secondPassScore;
                }
                return score;
            }
        };
    }

    /**
     * Parses the {@code {!rerank ...}} local parameters into a {@link ReRankQuery}.
     */
    private class ReRankQParser extends QParser
    {
        public ReRankQParser(String query, SolrParams localParams, SolrParams params, SolrQueryRequest req)
        {
            super(query, localParams, params, req);
        }

        /**
         * @return the rank query described by the local parameters
         * @throws SyntaxError if the re-rank query cannot be parsed
         * @throws SolrException (BAD_REQUEST) if {@code reRankQuery} is missing or blank
         */
        @Override
        public Query parse() throws SyntaxError
        {
            String reRankQueryString = localParams.get("reRankQuery");
            // Fail fast with a client error: without this check a missing parameter only
            // surfaces later, when the (null) re-rank query is dereferenced.
            if (reRankQueryString == null || reRankQueryString.trim().isEmpty())
            {
                throw new SolrException(SolrException.ErrorCode.BAD_REQUEST, "reRankQuery parameter is mandatory");
            }

            boolean scale = localParams.getBool("scale", false);

            QParser reRankParser = QParser.getParser(reRankQueryString, null, req);
            Query reRankQuery = reRankParser.parse();

            int reRankDocs = localParams.getInt("reRankDocs", 200);
            reRankDocs = Math.max(1, reRankDocs);

            double reRankWeight = localParams.getDouble("reRankWeight", 2.0d);

            // "length" is the number of documents needed to serve the requested page.
            int start = params.getInt(CommonParams.START, 0);
            int rows = params.getInt(CommonParams.ROWS, 10);
            int length = start + rows;

            return new ReRankQuery(reRankQuery, reRankDocs, reRankWeight, length, scale);
        }
    }

    /**
     * Rank query that delegates matching to the wrapped main query and re-orders its top
     * documents through a {@link ReRankCollector}.
     */
    private class ReRankQuery extends RankQuery
    {
        private Query mainQuery = defaultQuery;
        private final Query reRankQuery;
        private final int reRankDocs;
        private final int length;
        private final double reRankWeight;
        private final boolean scale;
        // Looked up lazily from the request context on first collector creation.
        private Map<BytesRef, Integer> boostedPriority;

        @Override
        public int hashCode()
        {
            return mainQuery.hashCode() + reRankQuery.hashCode() + (int) reRankWeight + reRankDocs + (scale ? 1 : 0);
        }

        @Override
        public boolean equals(Object o)
        {
            if (o instanceof ReRankQuery)
            {
                ReRankQuery rrq = (ReRankQuery) o;
                return (mainQuery.equals(rrq.mainQuery)
                        && reRankQuery.equals(rrq.reRankQuery)
                        && reRankWeight == rrq.reRankWeight
                        && reRankDocs == rrq.reRankDocs
                        && scale == rrq.scale);
            }
            return false;
        }

        public ReRankQuery(Query reRankQuery, int reRankDocs, double reRankWeight, int length, boolean scale)
        {
            this.reRankQuery = reRankQuery;
            this.reRankDocs = reRankDocs;
            this.reRankWeight = reRankWeight;
            this.length = length;
            this.scale = scale;
        }

        /**
         * Sets the main query; a {@code null} argument leaves the current one in place.
         *
         * @return this query
         */
        @Override
        public RankQuery wrap(Query _mainQuery)
        {
            if (_mainQuery != null)
            {
                this.mainQuery = _mainQuery;
            }
            return this;
        }

        @Override
        public MergeStrategy getMergeStrategy()
        {
            return null;
        }

        /**
         * Creates the collector that gathers the main query hits and re-ranks them. The
         * {@code len} argument is not used; the sizes computed at parse time are used instead.
         */
        @Override
        public TopDocsCollector getTopDocsCollector(int len, QueryCommand cmd, IndexSearcher searcher) throws IOException
        {
            if (this.boostedPriority == null)
            {
                SolrRequestInfo info = SolrRequestInfo.getRequestInfo();
                if (info != null)
                {
                    Map context = info.getReq().getContext();
                    this.boostedPriority = (Map<BytesRef, Integer>) context.get(QueryElevationComponent.BOOSTED_PRIORITY);
                }
            }

            return new ReRankCollector(reRankDocs, length, reRankQuery, reRankWeight, cmd, searcher, boostedPriority, scale);
        }

        @Override
        public String toString(String s)
        {
            // The "reRankWeigh" spelling is kept as-is so the emitted text does not change.
            return "{!rerank mainQuery='" + mainQuery.toString() +
                    "' reRankQuery='" + reRankQuery.toString() +
                    "' reRankDocs=" + reRankDocs +
                    " reRankWeigh=" + reRankWeight + "}";
        }

        /**
         * Rewrites the main query; returns this instance when the rewrite is a no-op,
         * otherwise a copy wrapping the rewritten main query.
         */
        @Override
        public Query rewrite(IndexReader reader) throws IOException
        {
            Query q = mainQuery.rewrite(reader);
            if (q == mainQuery)
            {
                return this;
            }
            else
            {
                return clone().wrap(q);
            }
        }

        /**
         * Copies the re-rank settings only; the main query of the copy is the default one
         * until {@link #wrap(Query)} is called.
         */
        public ReRankQuery clone()
        {
            ReRankQuery clonedQuery = new ReRankQuery(reRankQuery, reRankDocs, reRankWeight, length, scale);
            return clonedQuery;
        }

        @Override
        public Weight createWeight(IndexSearcher searcher, boolean needsScores) throws IOException
        {
            return new ReRankWeight(mainQuery, reRankQuery, reRankWeight, searcher);
        }
    }

    /**
     * Weight that delegates scoring to the main query's weight and only adds the re-rank
     * contribution when explaining a document.
     */
    private class ReRankWeight extends Weight
    {
        private Query reRankQuery;
        private IndexSearcher searcher;
        private Weight mainWeight;
        private double reRankWeight;

        public ReRankWeight(Query mainQuery, Query reRankQuery, double reRankWeight, IndexSearcher searcher) throws IOException
        {
            super(reRankQuery);
            this.reRankQuery = reRankQuery;
            this.searcher = searcher;
            this.reRankWeight = reRankWeight;
            this.mainWeight = mainQuery.createWeight(searcher, true);
        }

        @Override
        public void extractTerms(Set<Term> terms)
        {
            this.mainWeight.extractTerms(terms);
        }

        @Override
        public float getValueForNormalization() throws IOException
        {
            return mainWeight.getValueForNormalization();
        }

        @Override
        public Scorer scorer(LeafReaderContext context) throws IOException
        {
            return mainWeight.scorer(context);
        }

        @Override
        public void normalize(float norm, float topLevelBoost)
        {
            mainWeight.normalize(norm, topLevelBoost);
        }

        /**
         * Explains the main query score for the document and lets the rescorer extend that
         * explanation; {@code context.docBase + doc} is the index-wide document id.
         */
        @Override
        public Explanation explain(LeafReaderContext context, int doc) throws IOException
        {
            Explanation mainExplain = mainWeight.explain(context, doc);
            return createRescorer(getQuery(), reRankWeight).explain(searcher, mainExplain, context.docBase + doc);
        }
    }

    /**
     * Collector that gathers the top documents of the main query and, when asked for the
     * results, rescores the first {@code reRankDocs} of them with the re-rank query.
     */
    private class ReRankCollector extends TopDocsCollector
    {
        private Query reRankQuery;
        private TopDocsCollector mainCollector;
        private IndexSearcher searcher;
        private int reRankDocs;
        private int length;
        private double reRankWeight;
        private Map<BytesRef, Integer> boostedPriority;
        private boolean scale;

        public ReRankCollector(int reRankDocs,
                               int length,
                               Query reRankQuery,
                               double reRankWeight,
                               QueryCommand cmd,
                               IndexSearcher searcher,
                               Map<BytesRef, Integer> boostedPriority,
                               boolean scale) throws IOException
        {
            super(null);
            this.reRankQuery = reRankQuery;
            this.reRankDocs = reRankDocs;
            this.length = length;
            this.boostedPriority = boostedPriority;
            this.scale = scale;

            // Collect enough hits for both the re-rank window and the requested page.
            Sort sort = cmd.getSort();
            if (sort == null)
            {
                this.mainCollector = TopScoreDocCollector.create(Math.max(this.reRankDocs, length), null);
            }
            else
            {
                sort = sort.rewrite(searcher);
                this.mainCollector = TopFieldCollector.create(sort, Math.max(this.reRankDocs, length), null, false, true, true);
            }
            this.searcher = searcher;
            this.reRankWeight = reRankWeight;
        }

        @Override
        public int getTotalHits()
        {
            return mainCollector.getTotalHits();
        }

        /**
         * Returns the re-ranked results. The {@code start} argument is not used: the returned
         * array always begins at the first collected document and holds at most
         * {@code howMany} entries.
         *
         * @throws SolrException (BAD_REQUEST) wrapping any exception raised while rescoring
         */
        @Override
        public TopDocs topDocs(int start, int howMany)
        {
            try
            {
                TopDocs mainDocs = mainCollector.topDocs(0, Math.max(reRankDocs, length));

                if (mainDocs.totalHits == 0 || mainDocs.scoreDocs.length == 0)
                {
                    return mainDocs;
                }

                if (reRankDocs == 0)
                {
                    scaleScores(mainDocs, new HashMap<Integer, Float>());
                    return mainDocs;
                }

                // Elevated documents are resolved before rescoring, as in the original flow.
                boolean elevated = boostedPriority != null;
                IntIntHashMap boostedDocs = null;
                if (elevated)
                {
                    SolrRequestInfo info = SolrRequestInfo.getRequestInfo();
                    Map requestContext = null;
                    if (info != null)
                    {
                        requestContext = info.getReq().getContext();
                    }
                    boostedDocs = QueryElevationComponent.getBoostDocs((SolrIndexSearcher) searcher, boostedPriority, requestContext);
                }

                // Restrict the first pass results to the re-rank window.
                ScoreDoc[] mainScoreDocs = mainDocs.scoreDocs;
                ScoreDoc[] reRankScoreDocs = Arrays.copyOf(mainScoreDocs, Math.min(mainScoreDocs.length, reRankDocs));
                mainDocs.scoreDocs = reRankScoreDocs;

                // Remember the first pass scores so scaleScores can tell which ones changed.
                Map<Integer, Float> scoreMap = getScoreMap(reRankScoreDocs, reRankScoreDocs.length);

                TopDocs rescoredDocs = createRescorer(reRankQuery, reRankWeight)
                        .rescore(searcher, mainDocs, reRankScoreDocs.length);

                if (elevated)
                {
                    Arrays.sort(rescoredDocs.scoreDocs,
                            new BoostedComp(boostedDocs, mainDocs.scoreDocs, rescoredDocs.getMaxScore()));
                }

                // Lower howMany if we've collected fewer documents.
                howMany = Math.min(howMany, mainScoreDocs.length);

                int rescoredCount = rescoredDocs.scoreDocs.length;
                if (howMany > rescoredCount)
                {
                    // We need to return more than we've re-ranked, so create the combined page:
                    // lay down the initial docs, then overlay the re-ranked docs.
                    ScoreDoc[] page = new ScoreDoc[howMany];
                    System.arraycopy(mainScoreDocs, 0, page, 0, page.length);
                    System.arraycopy(rescoredDocs.scoreDocs, 0, page, 0, rescoredCount);
                    rescoredDocs.scoreDocs = page;
                }
                else if (howMany < rescoredCount)
                {
                    // We've rescored more than we need to return.
                    ScoreDoc[] page = new ScoreDoc[howMany];
                    System.arraycopy(rescoredDocs.scoreDocs, 0, page, 0, howMany);
                    rescoredDocs.scoreDocs = page;
                }
                // Otherwise the rescored docs are exactly the page to return.

                if (scale)
                {
                    scaleScores(rescoredDocs, scoreMap);
                }
                return rescoredDocs;
            }
            catch (Exception e)
            {
                throw new SolrException(SolrException.ErrorCode.BAD_REQUEST, e);
            }
        }

        @Override
        public LeafCollector getLeafCollector(LeafReaderContext context) throws IOException
        {
            return mainCollector.getLeafCollector(context);
        }

        @Override
        public boolean needsScores()
        {
            return true;
        }
    }

    /**
     * Rescales the scores of {@code topDocs} in place and updates its max score.
     *
     * Each score is divided by the current max score (left untouched when that max is 0).
     * A document whose score differs from the one recorded in {@code scoreMap} is considered
     * rescored and gets 1 added, which places the rescored documents in the (1,2] range.
     *
     * @param topDocs  the results to rescale
     * @param scoreMap document id to score before rescoring
     */
    private void scaleScores(TopDocs topDocs, Map<Integer, Float> scoreMap)
    {
        float maxScore = topDocs.getMaxScore();
        float newMax = -Float.MAX_VALUE;

        for (ScoreDoc scoreDoc : topDocs.scoreDocs)
        {
            float score = scoreDoc.score;
            Float oldScore = scoreMap.get(scoreDoc.doc);
            // check if the score has been changed after rescoring
            boolean rescored = oldScore != null && score != oldScore;

            // If maxScore is different from 0, the score is divided by maxScore
            scoreDoc.score = score / (maxScore != 0 ? maxScore : 1);

            // If the document has been rescored, the score is increased by 1.
            // This results in having all the rescored element scores in (1,2] range.
            if (rescored)
            {
                scoreDoc.score += 1;
            }

            if (scoreDoc.score > newMax)
            {
                newMax = scoreDoc.score;
            }
        }

        assert (newMax <= 2);
        topDocs.setMaxScore(newMax);
    }

    /**
     * Maps document id to current score for the first {@code num} entries of {@code scoreDocs}.
     */
    private Map<Integer, Float> getScoreMap(ScoreDoc[] scoreDocs, int num)
    {
        Map<Integer, Float> scoreMap = new HashMap<Integer, Float>();
        for (int i = 0; i < num; i++)
        {
            ScoreDoc doc = scoreDocs[i];
            scoreMap.put(doc.doc, doc.score);
        }
        return scoreMap;
    }

    /**
     * Orders {@link ScoreDoc}s by descending score, substituting
     * {@code maxScore + priority} for the score of boosted documents.
     */
    public class BoostedComp implements Comparator
    {
        IntFloatHashMap boostedMap;

        /**
         * Only the leading run of boosted documents in {@code scoreDocs} is registered:
         * the scan stops at the first document that is not in {@code boostedDocs}.
         */
        public BoostedComp(IntIntHashMap boostedDocs, ScoreDoc[] scoreDocs, float maxScore)
        {
            this.boostedMap = new IntFloatHashMap(boostedDocs.size() * 2);

            for (int i = 0; i < scoreDocs.length; i++)
            {
                if (boostedDocs.containsKey(scoreDocs[i].doc))
                {
                    boostedMap.put(scoreDocs[i].doc, maxScore + boostedDocs.get(scoreDocs[i].doc));
                }
                else
                {
                    break;
                }
            }
        }

        public int compare(Object o1, Object o2)
        {
            ScoreDoc doc1 = (ScoreDoc) o1;
            ScoreDoc doc2 = (ScoreDoc) o2;
            float score1 = doc1.score;
            float score2 = doc2.score;
            if (boostedMap.containsKey(doc1.doc))
            {
                score1 = boostedMap.get(doc1.doc);
            }

            if (boostedMap.containsKey(doc2.doc))
            {
                score2 = boostedMap.get(doc2.doc);
            }

            if (score1 > score2)
            {
                return -1;
            }
            else if (score1 < score2)
            {
                return 1;
            }
            else
            {
                return 0;
            }
        }
    }
}