package de.mpg.mpi_inf.ambiversenlu.nlu.entitylinking.uima.custom.entitysalience; import de.mpg.mpi_inf.ambiversenlu.nlu.entitylinking.uima.type.SalientEntity; import de.mpg.mpi_inf.ambiversenlu.nlu.entitysalience.featureextraction.extractor.FeatureExtractor; import de.mpg.mpi_inf.ambiversenlu.nlu.entitysalience.featureextraction.extractor.NYTEntitySalienceFeatureExtractor; import de.mpg.mpi_inf.ambiversenlu.nlu.entitysalience.featureextraction.featureset.FeatureSetFactory; import de.mpg.mpi_inf.ambiversenlu.nlu.entitysalience.featureextraction.util.EntityInstance; import de.mpg.mpi_inf.ambiversenlu.nlu.entitysalience.featureextraction.util.FeatureValueInstanceUtils; import de.mpg.mpi_inf.ambiversenlu.nlu.entitysalience.settings.TrainingSettings; import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.ml.PipelineModel; import org.apache.spark.ml.classification.RandomForestClassificationModel; import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.sql.SQLContext; import org.apache.uima.UimaContext; import org.apache.uima.analysis_engine.AnalysisEngineProcessException; import org.apache.uima.fit.component.JCasAnnotator_ImplBase; import org.apache.uima.fit.descriptor.ConfigurationParameter; import org.apache.uima.jcas.JCas; import org.apache.uima.resource.ResourceInitializationException; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.util.List; /** * Entity Salience Analysis engine that loads a already trained data model from a file in libsvm format. * It makes prediction for a single document and writes the predictions back to jCas. * <p> * See example from here: * https://github.com/apache/spark/blob/v1.6.3/examples/src/main/java/org/apache/spark/examples/ml/JavaRandomForestClassifierExample.java */ public class EntitySalienceSpark extends JCasAnnotator_ImplBase { private Logger logger = LoggerFactory.getLogger(EntitySalienceSpark.class); protected JavaSparkContext jsc; protected SQLContext sqlContext; private PipelineModel trainingModel; public static final String PARAM_MODEL_PATH = "modelPath"; @ConfigurationParameter( name = "modelPath", mandatory = true ) private String modelPath; @Override public void initialize(UimaContext context) throws ResourceInitializationException { super.initialize(context); synchronized (EntitySalienceSpark.class) { SparkConf conf = new SparkConf() .setAppName("EntitySalienceTagger") .set("spark.driver.allowMultipleContexts","true") .setMaster("local"); jsc = new JavaSparkContext(conf); //Load the training model //trainingModel = PipelineModel.load(modelPath); trainingModel = (PipelineModel) jsc.objectFile(modelPath).first(); jsc.close(); jsc.stop(); } } @Override public void process(JCas jCas) throws AnalysisEngineProcessException { long startTime = System.currentTimeMillis(); FeatureExtractor fe = new NYTEntitySalienceFeatureExtractor(); List<EntityInstance> entityInstances; try { entityInstances = fe.getEntityInstances(jCas, TrainingSettings.FeatureExtractor.ENTITY_SALIENCE); final int featureVectorSize = FeatureSetFactory.createFeatureSet(TrainingSettings.FeatureExtractor.ENTITY_SALIENCE).getFeatureVectorSize(); //TODO: For each model create separate implementation. RandomForestClassificationModel rfm = (RandomForestClassificationModel)trainingModel.stages()[2]; for(EntityInstance ei : entityInstances) { Vector vei = FeatureValueInstanceUtils.convertToSparkMLVector(ei, featureVectorSize); double label = rfm.predict(vei); Vector probabilities = rfm.predictProbability(vei); double salience = probabilities.toArray()[1]; SalientEntity salientEntity = new SalientEntity(jCas, 0, 0); salientEntity.setLabel(label); salientEntity.setID(ei.getEntityId()); salientEntity.setSalience(salience); salientEntity.addToIndexes(); } long endTime = System.currentTimeMillis() - startTime; logger.debug("Annotating salient entities finished in {}ms.", endTime); } catch (Exception e) { throw new AnalysisEngineProcessException(e); } } @Override public void destroy() { synchronized (EntitySalienceSpark.class) { jsc.stop(); } } }