/******************************************************************************* * Copyright (c) 2015-2019 Skymind, Inc. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at * https://www.apache.org/licenses/LICENSE-2.0. * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the * License for the specific language governing permissions and limitations * under the License. * * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ package org.deeplearning4j.nn.modelexport.solr.handler; import java.io.File; import java.io.InputStream; import java.io.IOException; import java.util.List; import java.util.Map; import org.apache.solr.client.solrj.io.Tuple; import org.apache.solr.client.solrj.io.comp.StreamComparator; import org.apache.solr.client.solrj.io.stream.StreamContext; import org.apache.solr.client.solrj.io.stream.TupleStream; import org.apache.solr.client.solrj.io.stream.expr.Explanation.ExpressionType; import org.apache.solr.client.solrj.io.stream.expr.Explanation; import org.apache.solr.client.solrj.io.stream.expr.Expressible; import org.apache.solr.client.solrj.io.stream.expr.StreamExplanation; import org.apache.solr.client.solrj.io.stream.expr.StreamExpression; import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionNamedParameter; import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionParameter; import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionValue; import org.apache.solr.client.solrj.io.stream.expr.StreamFactory; import org.apache.solr.core.SolrResourceLoader; import org.apache.solr.handler.SolrDefaultStreamFactory; import org.deeplearning4j.nn.api.Model; import org.deeplearning4j.core.util.ModelGuesser; import org.deeplearning4j.util.NetworkUtils; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; /** * A <a href="https://lucene.apache.org/solr/7_7_1/solr-solrj/org/apache/solr/client/solrj/io/stream/TupleStream.html"> * org.apache.solr.client.solrj.io.stream.TupleStream</a> that uses a {@link Model} to compute output scores. * <a href="https://lucene.apache.org/solr/7_7_1/solr-solrj/org/apache/solr/client/solrj/io/Tuple.html">Tuple</a> * fields are the model inputs and the model output(s) are added to the returned tuple. * <p> * Illustrative configuration snippet: * <pre> <expressible name="emailModel" class="org.deeplearning4j.nn.modelexport.solr.handler.ModelTupleStream"/> </pre> * <p> * Illustrative expression snippet: * <pre> emailModel(search(myCollection, q="*:*", fl="id,fieldX,fieldY,fieldZ", sort="id asc", qt="/export"), serializedModelFileName="mySerializedModel", inputKeys="fieldX,fieldY,fieldZ", outputKeys="modelScoreField1,modelScoreField2") </pre> * <p> * Apache Solr Reference Guide: * <ul> * <li> <a href="https://lucene.apache.org/solr/guide/7_7/streaming-expressions.html">Streaming Expressions</a> * </ul> */ public class ModelTupleStream extends TupleStream implements Expressible { final private static String SERIALIZED_MODEL_FILE_NAME_PARAM = "serializedModelFileName"; final private static String INPUT_KEYS_PARAM = "inputKeys"; final private static String OUTPUT_KEYS_PARAM = "outputKeys"; final private TupleStream tupleStream; final private String serializedModelFileName; final private String inputKeysParam; final private String outputKeysParam; final private String[] inputKeys; final private String[] outputKeys; final private SolrResourceLoader solrResourceLoader; final private Model model; public ModelTupleStream(StreamExpression streamExpression, StreamFactory streamFactory) throws IOException { final List<StreamExpression> streamExpressions = streamFactory.getExpressionOperandsRepresentingTypes(streamExpression, Expressible.class, TupleStream.class); if (streamExpressions.size() == 1) { this.tupleStream = streamFactory.constructStream(streamExpressions.get(0)); } else { throw new IOException("Expected exactly one stream in expression: "+streamExpression); } this.serializedModelFileName = getOperandValue(streamExpression, streamFactory, SERIALIZED_MODEL_FILE_NAME_PARAM); this.inputKeysParam = getOperandValue(streamExpression, streamFactory, INPUT_KEYS_PARAM); this.inputKeys = inputKeysParam.split(","); this.outputKeysParam = getOperandValue(streamExpression, streamFactory, OUTPUT_KEYS_PARAM); this.outputKeys = outputKeysParam.split(","); if (!(streamFactory instanceof SolrDefaultStreamFactory)) { throw new IOException(this.getClass().getName()+" requires a "+SolrDefaultStreamFactory.class.getName()+" StreamFactory"); } this.solrResourceLoader = ((SolrDefaultStreamFactory)streamFactory).getSolrResourceLoader(); this.model = restoreModel(openInputStream()); } private static String getOperandValue(StreamExpression streamExpression, StreamFactory streamFactory, String operandName) throws IOException { final StreamExpressionNamedParameter namedParameter = streamFactory.getNamedOperand(streamExpression, operandName); String operandValue = null; if (namedParameter != null && namedParameter.getParameter() instanceof StreamExpressionValue) { operandValue = ((StreamExpressionValue)namedParameter.getParameter()).getValue(); } if (operandValue == null) { throw new IOException("Expected '"+operandName+"' in expression: "+streamExpression); } else { return operandValue; } } public Map toMap(Map<String, Object> map) { // We (ModelTupleStream) extend TupleStream which implements MapWriter which extends MapSerializable. // MapSerializable says to have a toMap method. // org.apache.solr.common.MapWriter has a toMap method which has 'default' visibility. // So MapWriter.toMap here is not 'visible' but it is 'callable' it seems. return super.toMap(map); } public void setStreamContext(StreamContext streamContext) { tupleStream.setStreamContext(streamContext); } public List<TupleStream> children() { return tupleStream.children(); } public void open() throws IOException { tupleStream.open(); } public void close() throws IOException { tupleStream.close(); } public Tuple read() throws IOException { final Tuple tuple = tupleStream.read(); if (tuple.EOF) { return tuple; } else { final INDArray inputs = getInputsFromTuple(tuple); final INDArray outputs = NetworkUtils.output(model, inputs); return applyOutputsToTuple(tuple, outputs); } } public StreamComparator getStreamSort() { return tupleStream.getStreamSort(); } public Explanation toExplanation(StreamFactory streamFactory) throws IOException { return new StreamExplanation(getStreamNodeId().toString()) .withChildren(new Explanation[]{ tupleStream.toExplanation(streamFactory) }) .withExpressionType(ExpressionType.STREAM_DECORATOR) .withFunctionName(streamFactory.getFunctionName(this.getClass())) .withImplementingClass(this.getClass().getName()) .withExpression(toExpression(streamFactory, false).toString()); } public StreamExpressionParameter toExpression(StreamFactory streamFactory) throws IOException { return toExpression(streamFactory, true /* includeStreams */); } private StreamExpression toExpression(StreamFactory streamFactory, boolean includeStreams) throws IOException { final String functionName = streamFactory.getFunctionName(this.getClass()); final StreamExpression streamExpression = new StreamExpression(functionName); if (includeStreams) { if (this.tupleStream instanceof Expressible) { streamExpression.addParameter(((Expressible)this.tupleStream).toExpression(streamFactory)); } else { throw new IOException("This "+this.getClass().getName()+" contains a non-Expressible TupleStream "+this.tupleStream.getClass().getName()); } } else { streamExpression.addParameter("<stream>"); } streamExpression.addParameter(new StreamExpressionNamedParameter(SERIALIZED_MODEL_FILE_NAME_PARAM, this.serializedModelFileName)); streamExpression.addParameter(new StreamExpressionNamedParameter(INPUT_KEYS_PARAM, this.inputKeysParam)); streamExpression.addParameter(new StreamExpressionNamedParameter(OUTPUT_KEYS_PARAM, this.outputKeysParam)); return streamExpression; } protected InputStream openInputStream() throws IOException { return solrResourceLoader.openResource(serializedModelFileName); } /** * Uses the {@link ModelGuesser#loadModelGuess(InputStream)} method. */ protected Model restoreModel(InputStream inputStream) throws IOException { final File instanceDir = solrResourceLoader.getInstancePath().toFile(); try { return ModelGuesser.loadModelGuess(inputStream, instanceDir); } catch (Exception e) { throw new IOException("Failed to restore model from given file (" + serializedModelFileName + ")", e); } } protected INDArray getInputsFromTuple(Tuple tuple) { final double[] inputs = new double[inputKeys.length]; for (int ii=0; ii<inputKeys.length; ++ii) { inputs[ii] = tuple.getDouble(inputKeys[ii]).doubleValue(); } return Nd4j.create(new double[][]{ inputs }); } protected Tuple applyOutputsToTuple(Tuple tuple, INDArray output) { for (int ii=0; ii<outputKeys.length; ++ii) { tuple.put(outputKeys[ii], output.getFloat(ii)); } return tuple; } }