package com.jstarcraft.rns.model.extend.ranking; import java.util.List; import com.jstarcraft.ai.data.DataModule; import com.jstarcraft.ai.data.DataSpace; import com.jstarcraft.ai.math.algorithm.correlation.MathCorrelation; import com.jstarcraft.ai.math.structure.matrix.SymmetryMatrix; import com.jstarcraft.ai.math.structure.vector.DenseVector; import com.jstarcraft.ai.math.structure.vector.SparseVector; import com.jstarcraft.ai.math.structure.vector.VectorScalar; import com.jstarcraft.core.common.configuration.Configurator; import com.jstarcraft.core.common.reflection.ReflectionUtility; import com.jstarcraft.core.utility.RandomUtility; import com.jstarcraft.rns.model.collaborative.ranking.RankSGDModel; import com.jstarcraft.rns.model.exception.ModelException; import com.jstarcraft.rns.utility.SampleUtility; import it.unimi.dsi.fastutil.ints.IntSet; /** * * PRankD推荐器 * * <pre> * Personalised ranking with diversity * 参考LibRec团队 * </pre> * * @author Birdy * */ public class PRankDModel extends RankSGDModel { /** * item importance */ private DenseVector itemWeights; /** * item correlations */ private SymmetryMatrix itemCorrelations; /** * similarity filter */ private float similarityFilter; /** * initialization * * @throws ModelException if error occurs */ @Override public void prepare(Configurator configuration, DataModule model, DataSpace space) { super.prepare(configuration, model, space); similarityFilter = configuration.getFloat("recommender.sim.filter", 4F); float denominator = 0F; itemWeights = DenseVector.valueOf(itemSize); for (int itemIndex = 0; itemIndex < itemSize; itemIndex++) { float numerator = scoreMatrix.getColumnScope(itemIndex); denominator = denominator < numerator ? numerator : denominator; itemWeights.setValue(itemIndex, numerator); } // compute item relative importance for (int itemIndex = 0; itemIndex < itemSize; itemIndex++) { itemWeights.setValue(itemIndex, itemWeights.getValue(itemIndex) / denominator); } // compute item correlations by cosine similarity // TODO 修改为配置枚举 try { Class<MathCorrelation> correlationClass = (Class<MathCorrelation>) Class.forName(configuration.getString("recommender.correlation.class")); MathCorrelation correlation = ReflectionUtility.getInstance(correlationClass); itemCorrelations = new SymmetryMatrix(scoreMatrix.getColumnSize()); correlation.calculateCoefficients(scoreMatrix, true, itemCorrelations::setValue); } catch (Exception exception) { throw new RuntimeException(exception); } } /** * train model * * @throws ModelException if error occurs */ @Override protected void doPractice() { List<IntSet> userItemSet = getUserItemSet(scoreMatrix); for (int epocheIndex = 0; epocheIndex < epocheSize; epocheIndex++) { totalError = 0F; // for each rated user-item (u,i) pair for (int userIndex = 0; userIndex < userSize; userIndex++) { SparseVector userVector = scoreMatrix.getRowVector(userIndex); if (userVector.getElementSize() == 0) { continue; } IntSet itemSet = userItemSet.get(userIndex); for (VectorScalar term : userVector) { // each rated item i int positiveItemIndex = term.getIndex(); float positiveScore = term.getValue(); int negativeItemIndex = -1; do { // draw an item j with probability proportional to // popularity negativeItemIndex = SampleUtility.binarySearch(itemProbabilities, 0, itemProbabilities.getElementSize() - 1, RandomUtility.randomFloat(itemProbabilities.getValue(itemProbabilities.getElementSize() - 1))); // ensure that it is unrated by user u } while (itemSet.contains(negativeItemIndex)); float negativeScore = 0F; // compute predictions float positivePredict = predict(userIndex, positiveItemIndex), negativePredict = predict(userIndex, negativeItemIndex); float distance = (float) Math.sqrt(1 - Math.tanh(itemCorrelations.getValue(positiveItemIndex, negativeItemIndex) * similarityFilter)); float itemWeight = itemWeights.getValue(negativeItemIndex); float error = itemWeight * (positivePredict - negativePredict - distance * (positiveScore - negativeScore)); totalError += error * error; // update vectors float learnFactor = learnRatio * error; for (int factorIndex = 0; factorIndex < factorSize; factorIndex++) { float userFactor = userFactors.getValue(userIndex, factorIndex); float positiveItemFactor = itemFactors.getValue(positiveItemIndex, factorIndex); float negativeItemFactor = itemFactors.getValue(negativeItemIndex, factorIndex); userFactors.shiftValue(userIndex, factorIndex, -learnFactor * (positiveItemFactor - negativeItemFactor)); itemFactors.shiftValue(positiveItemIndex, factorIndex, -learnFactor * userFactor); itemFactors.shiftValue(negativeItemIndex, factorIndex, learnFactor * userFactor); } } } totalError *= 0.5F; if (isConverged(epocheIndex) && isConverged) { break; } isLearned(epocheIndex); currentError = totalError; } } }