/*
 * Copyright (C) 2016 RankSys http://ranksys.org
 *
 * This Source Code Form is subject to the terms of the Mozilla Public
 * License, v. 2.0. If a copy of the MPL was not distributed with this
 * file, You can obtain one at http://mozilla.org/MPL/2.0/.
 */
package org.ranksys.javafm.learner.gd;

import it.unimi.dsi.fastutil.ints.Int2ObjectMap.Entry;
import static java.lang.Math.log;
import java.util.List;
import java.util.logging.Logger;
import java.util.stream.DoubleStream;
import java.util.stream.IntStream;
import org.ranksys.javafm.FM;
import org.ranksys.javafm.FMInstance;
import org.ranksys.javafm.learner.FMLearner;
import org.ranksys.javafm.data.ListWiseFMData;

/**
 * List-wise gradient-descent learner for factorisation machines.
 * <p>
 * For every group of instances, the targets and the model predictions are each
 * turned into a probability distribution with a softmax. The loss of a group is
 * the cross entropy between both distributions, and the model is updated group
 * by group following the gradient of that loss plus L2 regularisation.
 *
 * @author Saúl Vargas ([email protected])
 */
public class ListRank implements FMLearner<ListWiseFMData> {

    private static final Logger LOG = Logger.getLogger(ListRank.class.getName());

    private final double learnRate;
    private final int numIter;
    private final double regB;
    private final double[] regW;
    private final double[] regM;

    /**
     * Creates a list-wise learner.
     *
     * @param learnRate step size of every gradient update
     * @param numIter number of passes over the training data
     * @param regB regularisation weight of the bias term
     * @param regW regularisation weights of the linear terms, indexed by feature
     * @param regM regularisation weights of the factor vectors, indexed by feature
     */
    public ListRank(double learnRate, int numIter, double regB, double[] regW, double[] regM) {
        this.learnRate = learnRate;
        this.numIter = numIter;
        this.regB = regB;
        this.regW = regW;
        this.regM = regM;
    }

    /**
     * Returns the value subtracted from every score before exponentiating.
     * <p>
     * Shifting by the maximum keeps {@code exp} from overflowing without
     * changing the softmax. A non-finite maximum is not used as shift, since
     * subtracting it would only create additional NaN values.
     *
     * @param scores non-empty array of scores
     * @return the maximum score, or 0 if that maximum is infinite
     */
    private static double stableShift(double[] scores) {
        double max = DoubleStream.of(scores).max().getAsDouble();
        return Double.isInfinite(max) ? 0.0 : max;
    }

    /**
     * Replaces, in place, the scores by their softmax.
     *
     * @param scores scores to normalise, overwritten with the result
     * @return the same array, now holding values that sum to one
     */
    private static double[] softmax(double[] scores) {
        if (scores.length == 0) {
            return scores;
        }

        double shift = stableShift(scores);
        for (int i = 0; i < scores.length; i++) {
            scores[i] = Math.exp(scores[i] - shift);
        }

        double norm = DoubleStream.of(scores).sum();
        for (int i = 0; i < scores.length; i++) {
            scores[i] /= norm;
        }

        return scores;
    }

    /**
     * Replaces, in place, the scores by the logarithm of their softmax.
     * <p>
     * Computed as {@code s - shift - log(sum(exp(s - shift)))} so that a
     * probability that underflows to zero does not turn into {@code log(0)}.
     *
     * @param scores scores to normalise, overwritten with the result
     * @return the same array, now holding log-probabilities
     */
    private static double[] logSoftmax(double[] scores) {
        if (scores.length == 0) {
            return scores;
        }

        double shift = stableShift(scores);
        double logNorm = log(DoubleStream.of(scores)
                .map(s -> Math.exp(s - shift))
                .sum());
        for (int i = 0; i < scores.length; i++) {
            scores[i] = scores[i] - shift - logNorm;
        }

        return scores;
    }

    /**
     * Softmax distribution of the targets of a group.
     *
     * @param group instances of one group
     * @return probability of each instance, in group order
     */
    private double[] getP(List<? extends FMInstance> group) {
        return softmax(group.stream()
                .mapToDouble(FMInstance::getTarget)
                .toArray());
    }

    /**
     * Softmax distribution of the model predictions for a group.
     *
     * @param fm factorisation machine producing the predictions
     * @param group instances of one group
     * @return probability of each instance, in group order
     */
    private double[] getQ(FM fm, List<? extends FMInstance> group) {
        return softmax(group.stream()
                .mapToDouble(fm::predict)
                .toArray());
    }

    /**
     * Average over groups of the cross entropy between the softmax of the
     * targets and the softmax of the predictions.
     *
     * @param fm factorisation machine to evaluate
     * @param test data to evaluate on
     * @return mean cross entropy per group, or {@code NaN} if there are no groups
     */
    @Override
    public double error(FM fm, ListWiseFMData test) {
        return test.streamByGroup().map(Entry::getValue)
                .mapToDouble((List<? extends FMInstance> group) -> {
                    double[] p = getP(group);
                    double[] logQ = logSoftmax(group.stream()
                            .mapToDouble(fm::predict)
                            .toArray());

                    return IntStream.range(0, group.size())
                            .mapToDouble(i -> -p[i] * logQ[i])
                            .sum();
                })
                .average().orElse(Double.NaN);
    }

    /**
     * Trains the model for {@code numIter} passes over the training data,
     * shuffling it before each pass and updating the model once per group.
     *
     * @param fm factorisation machine to train, modified in place
     * @param train training data, shuffled at every iteration
     * @param test data used only to report the error at FINE log level
     */
    @Override
    public void learn(FM fm, ListWiseFMData train, ListWiseFMData test) {
        LOG.fine(() -> String.format("iteration n = %3d e = %.6f e = %.6f", 0, error(fm, train), error(fm, test)));

        for (int t = 1; t <= numIter; t++) {
            long time0 = System.nanoTime();

            train.shuffle();

            train.streamByGroup().map(Entry::getValue).forEach(group -> updateGroup(fm, group));

            int iter = t;
            long time1 = System.nanoTime() - time0;

            LOG.info(() -> String.format("iteration n = %3d t = %.2fs", iter, time1 / 1_000_000_000.0));
            LOG.fine(() -> String.format("iteration n = %3d e = %.6f e = %.6f", iter, error(fm, train), error(fm, test)));
        }
    }

    /**
     * Applies one gradient step per instance of a group.
     * <p>
     * Both distributions are computed once, before any parameter of the group
     * is touched, so every instance uses the gradient at the same model state.
     *
     * @param fm factorisation machine to update
     * @param group instances of one group
     */
    private void updateGroup(FM fm, List<? extends FMInstance> group) {
        // No setter is called for w and m below: this assumes getW()/getM()
        // expose the model's own arrays - TODO confirm against FM.
        double[] w = fm.getW();
        double[][] m = fm.getM();

        double[] p = getP(group);
        double[] q = getQ(fm, group);

        for (int k = 0; k < group.size(); k++) {
            FMInstance x = group.get(k);

            // derivative of the group cross entropy w.r.t. the k-th prediction
            double lambda = q[k] - p[k];

            // Read the bias again for every instance: updating from a value
            // captured once per group would discard all steps but the last.
            double b = fm.getB();
            fm.setB(b - learnRate * (lambda + regB * b));

            // xm[j] = sum_i x_i * m[i][j], taken before the factors change
            double[] xm = new double[m[0].length];
            x.consume((i, xi) -> {
                for (int j = 0; j < xm.length; j++) {
                    xm[j] += xi * m[i][j];
                }

                w[i] -= learnRate * (lambda * xi + regW[i] * w[i]);
            });

            x.consume((i, xi) -> {
                for (int j = 0; j < m[i].length; j++) {
                    // the xi * xi * m[i][j] term removes the feature's own
                    // contribution from xm[j]
                    m[i][j] -= learnRate * (lambda * xi * xm[j]
                            - lambda * xi * xi * m[i][j]
                            + regM[i] * m[i][j]);
                }
            });
        }
    }

}