package org.apache.solr.search.xjoin;

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

import java.util.Iterator;

import org.apache.commons.collections.IteratorUtils;
import org.apache.commons.collections.Transformer;
import org.apache.commons.collections.iterators.TransformIterator;
import org.apache.lucene.index.Term;
import org.apache.lucene.queries.TermsFilter;
import org.apache.lucene.search.AutomatonQuery;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.FieldCacheTermsFilter;
import org.apache.lucene.search.Filter;
import org.apache.lucene.search.MultiTermQueryWrapperFilter;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.QueryWrapperFilter;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.automaton.Automata;
import org.apache.lucene.util.automaton.Automaton;
import org.apache.solr.common.params.SolrParams;
import org.apache.solr.common.util.NamedList;
import org.apache.solr.request.SolrQueryRequest;
import org.apache.solr.schema.FieldType;
import org.apache.solr.search.QParser;
import org.apache.solr.search.QParserPlugin;
import org.apache.solr.search.QueryParsing;
import org.apache.solr.search.SolrConstantScoreQuery;
import org.apache.solr.search.SyntaxError;

/**
 * QParserPlugin for extracting join ids from the results stored in XJoin search
 * components.
 */
public class XJoinQParserPlugin extends QParserPlugin {

  public static final String NAME = "xjoin";

  /** For choosing the internal algorithm */
  public static final String METHOD = "method";

  @Override
  @SuppressWarnings("rawtypes")
  public void init(NamedList args) {
    // nothing to do
  }

  /**
   * Strategies for turning the join ids (already converted to indexed terms)
   * into a Lucene {@link Filter} on the join field.
   */
  // this code is modified from TermsQParserPlugin
  private static enum Method {
    termsFilter {
      @Override
      @SuppressWarnings("unchecked")
      Filter makeFilter(String fname, Iterator<BytesRef> it) {
        return new TermsFilter(fname, IteratorUtils.toList(it));
      }
    },
    booleanQuery {
      @Override
      Filter makeFilter(String fname, Iterator<BytesRef> it) {
        // 'true' disables coord, as in the original implementation
        BooleanQuery bq = new BooleanQuery(true);
        while (it.hasNext()) {
          bq.add(new TermQuery(new Term(fname, it.next())), BooleanClause.Occur.SHOULD);
        }
        return new QueryWrapperFilter(bq);
      }
    },
    automaton {
      @Override
      @SuppressWarnings("unchecked")
      Filter makeFilter(String fname, Iterator<BytesRef> it) {
        Automaton union = Automata.makeStringUnion(IteratorUtils.toList(it));
        // instantiated through an (empty) anonymous subclass, as in the original code
        return new MultiTermQueryWrapperFilter<AutomatonQuery>(
            new AutomatonQuery(new Term(fname), union)) {
        };
      }
    },
    docValuesTermsFilter { //on 4x this is FieldCacheTermsFilter but we use the 5x name any way
      //note: limited to one val per doc
      @Override
      Filter makeFilter(String fname, Iterator<BytesRef> it) {
        return new FieldCacheTermsFilter(fname,
            (BytesRef[]) IteratorUtils.toArray(it, BytesRef.class));
      }
    };

    /**
     * Builds a filter on the given field from the given terms.
     *
     * @param fname name of the field the terms belong to
     * @param it    the terms; consumed by this call
     * @return a filter built from the supplied terms
     */
    abstract Filter makeFilter(String fname, Iterator<BytesRef> it);
  }

  /**
   * Creates a transformer from Object to BytesRef (using the given FieldType).
   *
   * @param ft field type used to convert the readable value to its indexed form;
   *           when null the string form of the join id is used directly
   * @return a transformer yielding a fresh BytesRef for every join id
   */
  static private Transformer transformer(final FieldType ft) {
    return new Transformer() {

      // scratch buffer reused between calls; a deep copy is handed out each time
      BytesRef term = new BytesRef();

      @Override
      public BytesRef transform(Object joinId) {
        if (joinId == null) {
          throw new RuntimeException("joinId is null! (weird)");
        }
        String joinStr = joinId.toString();
        // logic same as TermQParserPlugin
        if (ft != null) {
          ft.readableToIndexed(joinStr, term);
        } else {
          term.copyChars(joinStr);
        }
        return BytesRef.deepCopyOf(term);
      }

    };
  }

  /**
   * Like fq={!xjoin}xjoin_component_name OR xjoin_component_name2
   */
  @Override
  @SuppressWarnings("rawtypes")
  public QParser createParser(String qstr, SolrParams localParams, SolrParams params, SolrQueryRequest req) {
    return new XJoinQParser(qstr, localParams, params, req);
  }

  /**
   * Query parser which resolves the XJoin components named in the query string
   * and builds a constant score query over their join ids.
   */
  static class XJoinQParser<T extends Comparable<T>> extends QParser implements JoinSpec.Iterable {

    // record the join field when retrieving external results
    // must be the same for all external sources referenced in our query
    private String joinField;

    public XJoinQParser(String qstr, SolrParams localParams, SolrParams params, SolrQueryRequest req) {
      super(qstr, localParams, params, req);
      joinField = null;
    }

    /**
     * Parses the join specification and builds the resulting query.
     *
     * @return a constant score query over the join ids, or an empty
     *         BooleanQuery (matching nothing) when there are no join ids
     * @throws SyntaxError if the requested method is not a known {@link Method}
     * @throws Exception   (unchecked) if no XJoin component is referenced
     */
    @Override
    @SuppressWarnings("unchecked")
    public Query parse() throws SyntaxError {
      Method method = resolveMethod();

      // localParams is null when the parser is invoked without a {!xjoin ...}
      // prefix; the query string itself is then the join specification
      String spec = localParams != null ? localParams.get(QueryParsing.V) : qstr;

      JoinSpec<T> js = JoinSpec.parse(spec);
      Iterator<T> it = js.iterator(this);
      if (joinField == null) {
        throw new Exception("No XJoin component referenced by query");
      }

      FieldType ft = req.getSchema().getFieldTypeNoEx(joinField);
      Iterator<BytesRef> bytesRefs = new TransformIterator(it, transformer(ft));
      if (! bytesRefs.hasNext()) {
        return new BooleanQuery(); // matches nothing
      }
      return new SolrConstantScoreQuery(method.makeFilter(joinField, bytesRefs));
    }

    /**
     * Determines the filter construction method from the local params,
     * defaulting to {@link Method#termsFilter}.
     */
    private Method resolveMethod() throws SyntaxError {
      String name = localParams != null ? localParams.get(METHOD) : null;
      if (name == null) {
        return Method.termsFilter;
      }
      try {
        return Method.valueOf(name);
      } catch (IllegalArgumentException e) {
        throw new SyntaxError("Unknown " + NAME + " " + METHOD + " '" + name + "'", e);
      }
    }

    /**
     * Retrieves the join ids stored in the request context by the named XJoin
     * search component, recording (and checking) the join field on the way.
     *
     * @param componentName name of a search component registered on the core
     * @return iterator over the join ids of that component's results
     * @throws Exception (unchecked) if the component is not an XJoin component,
     *                   its join field differs from a previously seen one, or
     *                   no results are present in the request context
     */
    @Override
    @SuppressWarnings("unchecked")
    public Iterator<T> iterator(String componentName) {
      // held as Object so that a missing or non-XJoin component is reported
      // clearly rather than through a ClassCastException / NullPointerException
      Object component = req.getCore().getSearchComponent(componentName);
      if (! (component instanceof XJoinSearchComponent)) {
        throw new Exception("'" + componentName + "' is not an XJoin search component");
      }
      XJoinSearchComponent xJoin = (XJoinSearchComponent) component;

      if (joinField == null) {
        joinField = xJoin.getJoinField();
      } else if (! xJoin.getJoinField().equals(joinField)) {
        throw new Exception("XJoin components used in the same query must have same join field");
      }

      XJoinResults<T> results = (XJoinResults<T>) req.getContext().get(xJoin.getResultsTag());
      if (results == null) {
        throw new Exception("No xjoin results in request context");
      }
      return results.getJoinIds().iterator();
    }

  }

  /**
   * Unchecked exception raised for XJoin specific query problems.
   * Note that within this class it shadows java.lang.Exception.
   */
  @SuppressWarnings("serial")
  static class Exception extends RuntimeException {
    public Exception(String message) {
      super(message);
    }
  }

}