package com.jstarcraft.rns.model.collaborative.rating; import com.jstarcraft.ai.data.DataModule; import com.jstarcraft.ai.data.DataSpace; import com.jstarcraft.ai.math.MathUtility; import com.jstarcraft.ai.math.structure.DefaultScalar; import com.jstarcraft.ai.math.structure.MathCalculator; import com.jstarcraft.ai.math.structure.matrix.DenseMatrix; import com.jstarcraft.ai.math.structure.matrix.MatrixScalar; import com.jstarcraft.ai.math.structure.vector.ArrayVector; import com.jstarcraft.ai.math.structure.vector.DenseVector; import com.jstarcraft.ai.math.structure.vector.SparseVector; import com.jstarcraft.core.common.configuration.Configurator; import com.jstarcraft.core.utility.RandomUtility; import com.jstarcraft.rns.model.MatrixFactorizationModel; /** * * NMF推荐器 * * <pre> * Algorithms for Non-negative Matrix Factorization * 参考LibRec团队 * </pre> * * @author Birdy * */ public class NMFModel extends MatrixFactorizationModel { @Override public void prepare(Configurator configuration, DataModule model, DataSpace space) { super.prepare(configuration, model, space); userFactors = DenseMatrix.valueOf(userSize, factorSize); userFactors.iterateElement(MathCalculator.SERIAL, (scalar) -> { scalar.setValue(RandomUtility.randomFloat(0.01F)); }); itemFactors = DenseMatrix.valueOf(itemSize, factorSize); itemFactors.iterateElement(MathCalculator.SERIAL, (scalar) -> { scalar.setValue(RandomUtility.randomFloat(0.01F)); }); } @Override protected void doPractice() { DefaultScalar scalar = DefaultScalar.getInstance(); for (int epocheIndex = 0; epocheIndex < epocheSize; ++epocheIndex) { // update userFactors by fixing itemFactors for (int userIndex = 0; userIndex < userSize; userIndex++) { SparseVector userVector = scoreMatrix.getRowVector(userIndex); if (userVector.getElementSize() == 0) { continue; } int user = userIndex; ArrayVector predictVector = new ArrayVector(userVector); predictVector.iterateElement(MathCalculator.SERIAL, (element) -> { element.setValue(predict(user, element.getIndex())); }); for (int factorIndex = 0; factorIndex < factorSize; factorIndex++) { DenseVector factorVector = itemFactors.getColumnVector(factorIndex); float score = scalar.dotProduct(factorVector, userVector).getValue(); float predict = scalar.dotProduct(factorVector, predictVector).getValue() + MathUtility.EPSILON; userFactors.setValue(userIndex, factorIndex, userFactors.getValue(userIndex, factorIndex) * (score / predict)); } } // update itemFactors by fixing userFactors for (int itemIndex = 0; itemIndex < itemSize; itemIndex++) { SparseVector itemVector = scoreMatrix.getColumnVector(itemIndex); if (itemVector.getElementSize() == 0) { continue; } int item = itemIndex; ArrayVector predictVector = new ArrayVector(itemVector); predictVector.iterateElement(MathCalculator.SERIAL, (element) -> { element.setValue(predict(element.getIndex(), item)); }); for (int factorIndex = 0; factorIndex < factorSize; factorIndex++) { DenseVector factorVector = userFactors.getColumnVector(factorIndex); float score = scalar.dotProduct(factorVector, itemVector).getValue(); float predict = scalar.dotProduct(factorVector, predictVector).getValue() + MathUtility.EPSILON; itemFactors.setValue(itemIndex, factorIndex, itemFactors.getValue(itemIndex, factorIndex) * (score / predict)); } } // compute errors totalError = 0F; for (MatrixScalar term : scoreMatrix) { int userIndex = term.getRow(); int itemIndex = term.getColumn(); float score = term.getValue(); if (score > 0) { float error = predict(userIndex, itemIndex) - score; totalError += error * error; } } totalError *= 0.5F; if (isConverged(epocheIndex) && isConverged) { break; } currentError = totalError; } } }