package com.jstarcraft.rns.task; import java.util.ArrayList; import java.util.Collection; import java.util.Collections; import java.util.LinkedList; import java.util.List; import com.jstarcraft.ai.data.DataInstance; import com.jstarcraft.ai.data.module.ArrayInstance; import com.jstarcraft.ai.data.module.ReferenceModule; import com.jstarcraft.ai.evaluate.Evaluator; import com.jstarcraft.ai.evaluate.ranking.AUCEvaluator; import com.jstarcraft.ai.evaluate.ranking.MAPEvaluator; import com.jstarcraft.ai.evaluate.ranking.MRREvaluator; import com.jstarcraft.ai.evaluate.ranking.NDCGEvaluator; import com.jstarcraft.ai.evaluate.ranking.NoveltyEvaluator; import com.jstarcraft.ai.evaluate.ranking.PrecisionEvaluator; import com.jstarcraft.ai.evaluate.ranking.RecallEvaluator; import com.jstarcraft.ai.math.structure.matrix.SparseMatrix; import com.jstarcraft.core.common.configuration.Configurator; import com.jstarcraft.core.utility.Integer2FloatKeyValue; import com.jstarcraft.rns.model.Model; import it.unimi.dsi.fastutil.ints.IntArrayList; import it.unimi.dsi.fastutil.ints.IntList; import it.unimi.dsi.fastutil.ints.IntOpenHashSet; import it.unimi.dsi.fastutil.ints.IntSet; /** * 排序任务 * * @author Birdy * */ public class RankingTask extends AbstractTask<IntSet, IntList> { public RankingTask(Model recommender, Configurator configuration) { super(recommender, configuration); } public RankingTask(Class<? extends Model> clazz, Configurator configuration) { super(clazz, configuration); } @Override protected Collection<Evaluator> getEvaluators(SparseMatrix featureMatrix) { Collection<Evaluator> evaluators = new LinkedList<>(); int size = configurator.getInteger("recommender.recommender.ranking.topn", 10); evaluators.add(new AUCEvaluator(size)); evaluators.add(new MAPEvaluator(size)); evaluators.add(new MRREvaluator(size)); evaluators.add(new NDCGEvaluator(size)); evaluators.add(new NoveltyEvaluator(size, featureMatrix)); evaluators.add(new PrecisionEvaluator(size)); evaluators.add(new RecallEvaluator(size)); return evaluators; } @Override protected IntSet check(int userIndex) { ReferenceModule testModule = testModules[userIndex]; IntSet itemSet = new IntOpenHashSet(); for (DataInstance instance : testModule) { itemSet.add(instance.getQualityFeature(itemDimension)); } return itemSet; } @Override protected IntList recommend(Model recommender, int userIndex) { ReferenceModule trainModule = trainModules[userIndex]; ReferenceModule testModule = testModules[userIndex]; IntSet itemSet = new IntOpenHashSet(); for (DataInstance instance : trainModule) { itemSet.add(instance.getQualityFeature(itemDimension)); } // TODO 此处代码需要重构 ArrayInstance copy = new ArrayInstance(trainMarker.getQualityOrder(), trainMarker.getQuantityOrder()); copy.copyInstance(testModule.getInstance(0)); copy.setQualityFeature(userDimension, userIndex); List<Integer2FloatKeyValue> rankList = new ArrayList<>(itemSize - itemSet.size()); for (int itemIndex = 0; itemIndex < itemSize; itemIndex++) { if (itemSet.contains(itemIndex)) { continue; } copy.setQualityFeature(itemDimension, itemIndex); recommender.predict(copy); rankList.add(new Integer2FloatKeyValue(itemIndex, copy.getQuantityMark())); } Collections.sort(rankList, (left, right) -> { return Float.compare(right.getValue(), left.getValue()); }); IntList recommendList = new IntArrayList(rankList.size()); for (Integer2FloatKeyValue keyValue : rankList) { recommendList.add(keyValue.getKey()); } return recommendList; } }