/* * Copyright 2018 * Ubiquitous Knowledge Processing (UKP) Lab * Technische Universität Darmstadt * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package de.tudarmstadt.ukp.inception.conceptlinking; import static de.tudarmstadt.ukp.clarin.webanno.api.casstorage.CasAccessMode.EXCLUSIVE_WRITE_ACCESS; import static de.tudarmstadt.ukp.inception.support.test.recommendation.RecommenderTestHelper.getPredictions; import static java.util.Arrays.asList; import static org.apache.uima.fit.factory.CollectionReaderFactory.createReader; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assumptions.assumeThat; import static org.dkpro.core.api.datasets.DatasetValidationPolicy.CONTINUE; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyInt; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; import java.io.File; import java.io.FileNotFoundException; import java.io.IOException; import java.util.ArrayList; import java.util.Collections; import java.util.List; import java.util.Optional; import org.apache.uima.UIMAException; import org.apache.uima.cas.CAS; import org.apache.uima.collection.CollectionReader; import org.apache.uima.fit.factory.JCasFactory; import org.apache.uima.jcas.JCas; import org.dkpro.core.api.datasets.Dataset; import org.dkpro.core.api.datasets.DatasetFactory; import org.dkpro.core.io.conll.Conll2002Reader; import org.dkpro.core.testing.DkproTestContext; import org.junit.After; import org.junit.Before; import org.junit.Test; import de.tudarmstadt.ukp.clarin.webanno.api.annotation.feature.FeatureSupport; import de.tudarmstadt.ukp.clarin.webanno.api.annotation.feature.FeatureSupportRegistry; import de.tudarmstadt.ukp.clarin.webanno.api.dao.casstorage.CasStorageSession; import de.tudarmstadt.ukp.clarin.webanno.model.AnnotationFeature; import de.tudarmstadt.ukp.clarin.webanno.model.AnnotationLayer; import de.tudarmstadt.ukp.dkpro.core.api.ner.type.NamedEntity; import de.tudarmstadt.ukp.inception.conceptlinking.recommender.NamedEntityLinker; import de.tudarmstadt.ukp.inception.conceptlinking.recommender.NamedEntityLinkerTraits; import de.tudarmstadt.ukp.inception.conceptlinking.service.ConceptLinkingServiceImpl; import de.tudarmstadt.ukp.inception.kb.ConceptFeatureTraits; import de.tudarmstadt.ukp.inception.kb.ConceptFeatureValueType; import de.tudarmstadt.ukp.inception.kb.IriConstants; import de.tudarmstadt.ukp.inception.kb.KnowledgeBaseService; import de.tudarmstadt.ukp.inception.kb.graph.KBHandle; import de.tudarmstadt.ukp.inception.kb.model.KnowledgeBase; import de.tudarmstadt.ukp.inception.recommendation.api.model.Recommender; import de.tudarmstadt.ukp.inception.recommendation.api.recommender.RecommenderContext; import de.tudarmstadt.ukp.inception.support.test.recommendation.RecommenderTestHelper; public class NamedEntityLinkerTest { private static File cache = DkproTestContext.getCacheFolder(); private static DatasetFactory loader = new DatasetFactory(cache); private RecommenderContext context; private Recommender recommender; private CasStorageSession casStorageSession; @Before public void setUp() { casStorageSession = CasStorageSession.open(); context = new RecommenderContext(); recommender = buildRecommender(); } @After public void tearDown() { CasStorageSession.get().close(); } @Test public void thatTrainingWorks() throws Exception { NamedEntityLinker sut = new NamedEntityLinker(recommender, new NamedEntityLinkerTraits(), mock(KnowledgeBaseService.class), mock(ConceptLinkingServiceImpl.class), mock(FeatureSupportRegistry.class), new ConceptFeatureTraits()); List<CAS> casList = loadDevelopmentData(); sut.train(context, casList); assertThat(context.get(NamedEntityLinker.KEY_MODEL)) .as("Model has been set") .isNotNull(); } @Test public void thatPredictionWorks() throws Exception { List<KBHandle> mockResult = asList( new KBHandle("https://www.wikidata.org/wiki/Q76", "Barack Obama", "44th President of the United States of America"), new KBHandle("https://www.wikidata.org/wiki/Q26446735", "Obama", "Japanese Family Name"), new KBHandle("https://www.wikidata.org/wiki/Q18355807", "Obama", "genus of worms"), new KBHandle("https://www.wikidata.org/wiki/Q41773", "Obama", "city in Fukui prefecture, Japan")); KnowledgeBaseService kbService = mock(KnowledgeBaseService.class); KnowledgeBase kb = new KnowledgeBase(); kb.setFullTextSearchIri(IriConstants.FTS_VIRTUOSO); when(kbService.getKnowledgeBaseById(any(), anyString())).thenReturn(Optional.of(kb)); when(kbService.getEnabledKnowledgeBases(any())).thenReturn(Collections.singletonList(kb)); when(kbService.read(any(), any())).thenReturn(mockResult); ConceptLinkingServiceImpl clService = mock(ConceptLinkingServiceImpl.class); when(clService.disambiguate(any(), anyString(), any(ConceptFeatureValueType.class), anyString(), anyString(), anyInt(), any())).thenReturn(mockResult); FeatureSupportRegistry fsRegistry = mock(FeatureSupportRegistry.class); FeatureSupport fs = mock(FeatureSupport.class); when(fsRegistry.getFeatureSupport(recommender.getFeature())).thenReturn(fs); when(fs.readTraits(recommender.getFeature())).thenReturn(new ConceptFeatureTraits()); NamedEntityLinker sut = new NamedEntityLinker(recommender, new NamedEntityLinkerTraits(), kbService, clService, fsRegistry, new ConceptFeatureTraits()); List<CAS> casList = loadDevelopmentData(); CAS cas = casList.get(0); casStorageSession.add("cas", EXCLUSIVE_WRITE_ACCESS, cas); sut.train(context, Collections.singletonList(cas)); RecommenderTestHelper.addScoreFeature(cas, NamedEntity.class, "value"); sut.predict(context, cas); List<NamedEntity> predictions = getPredictions(cas, NamedEntity.class); assertThat(predictions).as("Predictions have been written to CAS") .isNotEmpty(); } private List<CAS> loadDevelopmentData() throws IOException, UIMAException { Dataset ds = null; try { ds = loader.load("germeval2014-de", CONTINUE); } catch (Exception e) { // Workaround for https://github.com/dkpro/dkpro-core/issues/1469 assumeThat(e).isNotInstanceOf(FileNotFoundException.class); throw e; } return loadData(ds, ds.getDefaultSplit().getDevelopmentFiles()); } private List<CAS> loadData(Dataset ds, File ... files) throws UIMAException, IOException { CollectionReader reader = createReader( Conll2002Reader.class, Conll2002Reader.PARAM_PATTERNS, files, Conll2002Reader.PARAM_LANGUAGE, ds.getLanguage(), Conll2002Reader.PARAM_COLUMN_SEPARATOR, Conll2002Reader.ColumnSeparators.TAB.getName(), Conll2002Reader.PARAM_HAS_TOKEN_NUMBER, true, Conll2002Reader.PARAM_HAS_HEADER, true, Conll2002Reader.PARAM_HAS_EMBEDDED_NAMED_ENTITY, true); List<CAS> casList = new ArrayList<>(); while (reader.hasNext()) { JCas cas = JCasFactory.createJCas(); reader.getNext(cas.getCas()); casList.add(cas.getCas()); } return casList; } private static Recommender buildRecommender() { AnnotationLayer layer = new AnnotationLayer(); layer.setName(NamedEntity.class.getName()); AnnotationFeature feature = new AnnotationFeature(); feature.setName("identifier"); Recommender recommender = new Recommender(); recommender.setLayer(layer); recommender.setFeature(feature); recommender.setMaxRecommendations(3); return recommender; } }