package com.jstarcraft.rns.model.context.ranking;

import java.util.ArrayList;
import java.util.LinkedHashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Set;

import com.jstarcraft.ai.data.DataInstance;
import com.jstarcraft.ai.data.DataModule;
import com.jstarcraft.ai.data.DataSpace;
import com.jstarcraft.ai.math.structure.DefaultScalar;
import com.jstarcraft.ai.math.structure.MathCalculator;
import com.jstarcraft.ai.math.structure.vector.DenseVector;
import com.jstarcraft.ai.math.structure.vector.SparseVector;
import com.jstarcraft.ai.math.structure.vector.VectorScalar;
import com.jstarcraft.core.common.configuration.Configurator;
import com.jstarcraft.core.utility.RandomUtility;
import com.jstarcraft.rns.model.SocialModel;
import com.jstarcraft.rns.utility.LogisticUtility;

import it.unimi.dsi.fastutil.ints.IntSet;

/**
 * 
 * SBPR recommender.
 * 
 * <pre>
 * Social Bayesian Personalized Ranking (SBPR)
 * Leveraging Social Connections to Improve Personalized Ranking for Collaborative Filtering
 * Based on the implementation by the LibRec team.
 * </pre>
 * 
 * <p>
 * For every sampled user the model ranks three kinds of items: an item the user
 * rated (positive), an item rated only by the user's social neighbours
 * (social), and an item rated by neither (negative). Users without any
 * social-only item are trained with the plain BPR pairwise update.
 * </p>
 * 
 * @author Birdy
 *
 */
// TODO still needs refactoring
public class SBPRModel extends SocialModel {
    /**
     * items biases vector
     */
    private DenseVector itemBiases;

    /**
     * bias regularization
     */
    protected float regBias;

    /**
     * For each user, the items rated by at least one social neighbour but not by
     * the user. Entries are distinct and kept in first-seen order.
     */
    // TODO consider refactoring to List<IntSet>
    private List<List<Integer>> socialItemList;

    /**
     * For each user, the set of items returned by {@code getUserItemSet(scoreMatrix)}.
     */
    private List<IntSet> userItemSet;

    /**
     * Reads the bias regularization, randomly initialises the item biases and
     * pre-computes, for every user, the list of social-only items.
     *
     * @param configuration source of the {@code recommender.bias.regularization}
     *                      setting (defaults to {@code 0.01})
     * @param model         forwarded to the parent implementation
     * @param space         forwarded to the parent implementation
     */
    @Override
    public void prepare(Configurator configuration, DataModule model, DataSpace space) {
        super.prepare(configuration, model, space);
        regBias = configuration.getFloat("recommender.bias.regularization", 0.01F);

        itemBiases = DenseVector.valueOf(itemSize);
        itemBiases.iterateElement(MathCalculator.SERIAL, (scalar) -> {
            scalar.setValue(RandomUtility.randomFloat(1F));
        });

        userItemSet = getUserItemSet(scoreMatrix);

        // TODO consider refactoring
        // find items rated by trusted neighbors only
        socialItemList = new ArrayList<>(userSize);
        for (int userIndex = 0; userIndex < userSize; userIndex++) {
            IntSet itemSet = userItemSet.get(userIndex);
            // A LinkedHashSet de-duplicates in constant time per insert while
            // preserving first-seen order, so the resulting list is identical to
            // the one produced by a linear "contains then add" on a list.
            Set<Integer> socialItems = new LinkedHashSet<>();
            for (VectorScalar term : socialMatrix.getRowVector(userIndex)) {
                SparseVector neighbourVector = scoreMatrix.getRowVector(term.getIndex());
                for (VectorScalar entry : neighbourVector) {
                    int itemIndex = entry.getIndex();
                    // keep only the neighbour's items the user has not rated
                    if (!itemSet.contains(itemIndex)) {
                        socialItems.add(itemIndex);
                    }
                }
            }
            socialItemList.add(new ArrayList<>(socialItems));
        }
    }

    /**
     * Runs stochastic gradient training: {@code userSize * 100} samples per epoch
     * for at most {@code epocheSize} epochs, accumulating the loss of each epoch
     * in {@code totalError}.
     */
    @Override
    protected void doPractice() {
        for (int epocheIndex = 0; epocheIndex < epocheSize; epocheIndex++) {
            totalError = 0F;
            for (int sampleIndex = 0, sampleTimes = userSize * 100; sampleIndex < sampleTimes; sampleIndex++) {
                // uniformly draw (user, positive item, social item, negative item)
                int userIndex, positiveItemIndex, negativeItemIndex;
                // redraw until the user has at least one rated item
                SparseVector userVector;
                do {
                    userIndex = RandomUtility.randomInteger(userSize);
                    userVector = scoreMatrix.getRowVector(userIndex);
                } while (userVector.getElementSize() == 0);

                // positive item index
                positiveItemIndex = userVector.getIndex(RandomUtility.randomInteger(userVector.getElementSize()));
                float positiveScore = predict(userIndex, positiveItemIndex);

                // social Items List
                // TODO an IntSet would be a better fit: contains() below is a linear scan.
                List<Integer> socialList = socialItemList.get(userIndex);
                IntSet itemSet = userItemSet.get(userIndex);
                // The two collections are disjoint and duplicate-free (see prepare),
                // so when together they cover the catalogue no negative item exists
                // and rejection sampling below would never terminate.
                if (itemSet.size() + socialList.size() >= itemSize) {
                    continue;
                }
                do {
                    negativeItemIndex = RandomUtility.randomInteger(itemSize);
                } while (itemSet.contains(negativeItemIndex) || socialList.contains(negativeItemIndex));
                float negativeScore = predict(userIndex, negativeItemIndex);

                if (socialList.size() > 0) {
                    // if having social neighbors
                    int socialItemIndex = socialList.get(RandomUtility.randomInteger(socialList.size()));
                    float socialScore = predict(userIndex, socialItemIndex);
                    learnSocialSample(userIndex, positiveItemIndex, socialItemIndex, negativeItemIndex, positiveScore, socialScore, negativeScore);
                } else {
                    // if no social neighbors, the same as BPR
                    learnPlainSample(userIndex, positiveItemIndex, negativeItemIndex, positiveScore, negativeScore);
                }
            }

            if (isConverged(epocheIndex) && isConverged) {
                break;
            }
            isLearned(epocheIndex);
            currentError = totalError;
        }
    }

    /**
     * Counts how many social neighbours of the user have the given item in their
     * item set.
     *
     * @param userIndex the user whose neighbours are inspected
     * @param itemIndex the item to look for
     * @return the number of neighbours having the item, as a float
     */
    private float getSocialWeight(int userIndex, int itemIndex) {
        float socialWeight = 0F;
        for (VectorScalar term : socialMatrix.getRowVector(userIndex)) {
            if (userItemSet.get(term.getIndex()).contains(itemIndex)) {
                socialWeight += 1F;
            }
        }
        return socialWeight;
    }

    /**
     * Applies one SBPR update for a (positive, social, negative) triple and adds
     * the sample loss and regularization terms to {@code totalError}.
     *
     * @param userIndex         sampled user
     * @param positiveItemIndex item rated by the user
     * @param socialItemIndex   item rated only by the user's neighbours
     * @param negativeItemIndex item rated by neither
     * @param positiveScore     prediction for the positive item before the update
     * @param socialScore       prediction for the social item before the update
     * @param negativeScore     prediction for the negative item before the update
     */
    private void learnSocialSample(int userIndex, int positiveItemIndex, int socialItemIndex, int negativeItemIndex, float positiveScore, float socialScore, float negativeScore) {
        float socialWeight = getSocialWeight(userIndex, socialItemIndex);

        // positive-vs-social margin is damped by the number of supporting neighbours
        float positiveError = (positiveScore - socialScore) / (1 + socialWeight);
        float negativeError = socialScore - negativeScore;
        float positiveGradient = LogisticUtility.getValue(-positiveError), negativeGradient = LogisticUtility.getValue(-negativeError);
        // 1 - gradient is the logistic value of the margin, so this is the negative
        // log-likelihood of both pairwise preferences
        float error = (float) (-Math.log(1 - positiveGradient) - Math.log(1 - negativeGradient));
        totalError += error;

        // update bi, bk, bj
        float positiveBias = itemBiases.getValue(positiveItemIndex);
        itemBiases.shiftValue(positiveItemIndex, learnRatio * (positiveGradient / (1F + socialWeight) - regBias * positiveBias));
        totalError += regBias * positiveBias * positiveBias;
        float socialBias = itemBiases.getValue(socialItemIndex);
        itemBiases.shiftValue(socialItemIndex, learnRatio * (-positiveGradient / (1F + socialWeight) + negativeGradient - regBias * socialBias));
        totalError += regBias * socialBias * socialBias;
        float negativeBias = itemBiases.getValue(negativeItemIndex);
        itemBiases.shiftValue(negativeItemIndex, learnRatio * (-negativeGradient - regBias * negativeBias));
        totalError += regBias * negativeBias * negativeBias;

        // update P, Q; all four factor values are read before any of them is shifted
        for (int factorIndex = 0; factorIndex < factorSize; factorIndex++) {
            float userFactor = userFactors.getValue(userIndex, factorIndex);
            float positiveFactor = itemFactors.getValue(positiveItemIndex, factorIndex);
            float socialFactor = itemFactors.getValue(socialItemIndex, factorIndex);
            float negativeFactor = itemFactors.getValue(negativeItemIndex, factorIndex);
            float delta = positiveGradient * (positiveFactor - socialFactor) / (1F + socialWeight) + negativeGradient * (socialFactor - negativeFactor);
            userFactors.shiftValue(userIndex, factorIndex, learnRatio * (delta - userRegularization * userFactor));
            itemFactors.shiftValue(positiveItemIndex, factorIndex, learnRatio * (positiveGradient * userFactor / (1F + socialWeight) - itemRegularization * positiveFactor));
            itemFactors.shiftValue(negativeItemIndex, factorIndex, learnRatio * (negativeGradient * (-userFactor) - itemRegularization * negativeFactor));
            delta = positiveGradient * (-userFactor / (1F + socialWeight)) + negativeGradient * userFactor;
            itemFactors.shiftValue(socialItemIndex, factorIndex, learnRatio * (delta - itemRegularization * socialFactor));
            totalError += userRegularization * userFactor * userFactor + itemRegularization * positiveFactor * positiveFactor + itemRegularization * negativeFactor * negativeFactor + itemRegularization * socialFactor * socialFactor;
        }
    }

    /**
     * Applies one plain BPR update for a (positive, negative) pair and adds the
     * sample loss and regularization terms to {@code totalError}.
     *
     * @param userIndex         sampled user
     * @param positiveItemIndex item rated by the user
     * @param negativeItemIndex item not rated by the user
     * @param positiveScore     prediction for the positive item before the update
     * @param negativeScore     prediction for the negative item before the update
     */
    private void learnPlainSample(int userIndex, int positiveItemIndex, int negativeItemIndex, float positiveScore, float negativeScore) {
        float difference = positiveScore - negativeScore;
        // The loss is -ln(sigmoid(difference)), matching the social branch.
        // Accumulating the raw difference would make the reported loss grow as
        // the ranking improves.
        totalError += (float) -Math.log(LogisticUtility.getValue(difference));
        float gradient = LogisticUtility.getValue(-difference);

        // update bi, bj
        float positiveBias = itemBiases.getValue(positiveItemIndex);
        itemBiases.shiftValue(positiveItemIndex, learnRatio * (gradient - regBias * positiveBias));
        totalError += regBias * positiveBias * positiveBias;
        float negativeBias = itemBiases.getValue(negativeItemIndex);
        itemBiases.shiftValue(negativeItemIndex, learnRatio * (-gradient - regBias * negativeBias));
        totalError += regBias * negativeBias * negativeBias;

        // update user factors, item factors
        for (int factorIndex = 0; factorIndex < factorSize; factorIndex++) {
            float userFactor = userFactors.getValue(userIndex, factorIndex);
            float positiveFactor = itemFactors.getValue(positiveItemIndex, factorIndex);
            float negativeFactor = itemFactors.getValue(negativeItemIndex, factorIndex);
            userFactors.shiftValue(userIndex, factorIndex, learnRatio * (gradient * (positiveFactor - negativeFactor) - userRegularization * userFactor));
            itemFactors.shiftValue(positiveItemIndex, factorIndex, learnRatio * (gradient * userFactor - itemRegularization * positiveFactor));
            itemFactors.shiftValue(negativeItemIndex, factorIndex, learnRatio * (gradient * (-userFactor) - itemRegularization * negativeFactor));
            totalError += userRegularization * userFactor * userFactor + itemRegularization * positiveFactor * positiveFactor + itemRegularization * negativeFactor * negativeFactor;
        }
    }

    /**
     * Scores an item for a user as the item bias plus the dot product of the
     * user's and the item's factor rows.
     *
     * @param userIndex row of {@code userFactors}
     * @param itemIndex row of {@code itemFactors} and index into the item biases
     * @return the predicted preference score
     */
    @Override
    protected float predict(int userIndex, int itemIndex) {
        DefaultScalar scalar = DefaultScalar.getInstance();
        DenseVector userVector = userFactors.getRowVector(userIndex);
        DenseVector itemVector = itemFactors.getRowVector(itemIndex);
        return itemBiases.getValue(itemIndex) + scalar.dotProduct(userVector, itemVector).getValue();
    }

    /**
     * Reads the user and item indexes from the instance and stores the predicted
     * score as its quantity mark.
     *
     * @param instance the instance to score; updated in place
     */
    @Override
    public void predict(DataInstance instance) {
        int userIndex = instance.getQualityFeature(userDimension);
        int itemIndex = instance.getQualityFeature(itemDimension);
        instance.setQuantityMark(predict(userIndex, itemIndex));
    }

}