package org.broadinstitute.hellbender.utils;

import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.stat.descriptive.moment.StandardDeviation;
import org.apache.commons.math3.stat.descriptive.rank.Median;

import java.util.stream.IntStream;

/**
 * Static utility class providing per-row and per-column matrix summary statistics that are not
 * readily available in Apache Commons Math, Spark, etc.
 */
public class MatrixSummaryUtils {

    /** Not instantiable: this class only exposes static helpers. */
    private MatrixSummaryUtils() {}

    /**
     * Return an array containing the median for each column in the given matrix.
     *
     * @param m Not {@code null}.  Size MxN, where neither dimension is zero.  If any entry is NaN, it is
     *          disregarded in the calculation.
     * @return array of size N.  Never {@code null}
     */
    public static double[] getColumnMedians(final RealMatrix m) {
        Utils.nonNull(m, "Cannot calculate medians on a null matrix.");
        final Median medianCalculator = new Median();
        // Iterate the primitive index stream directly; boxing the indices buys nothing.
        return IntStream.range(0, m.getColumnDimension())
                .mapToDouble(i -> medianCalculator.evaluate(m.getColumn(i)))
                .toArray();
    }

    /**
     * Return an array containing the median for each row in the given matrix.
     *
     * @param m Not {@code null}.  Size MxN.  If any entry is NaN, it is disregarded in the calculation.
     * @return array of size M.  Never {@code null}
     */
    public static double[] getRowMedians(final RealMatrix m) {
        Utils.nonNull(m, "Cannot calculate medians on a null matrix.");
        final Median medianCalculator = new Median();
        return IntStream.range(0, m.getRowDimension())
                .mapToDouble(i -> medianCalculator.evaluate(m.getRow(i)))
                .toArray();
    }

    /**
     * Return an array containing the variance for each row in the given matrix.
     *
     * <p>Each variance is computed as the square of the value returned by a default-constructed
     * Commons Math {@link StandardDeviation} evaluated on the row.</p>
     *
     * @param m Not {@code null}.  Size MxN.  If any entry is NaN, the corresponding rows will have a
     *          variance of NaN.
     * @return array of size M.  Never {@code null}.  If there is only one column (i.e. only one entry
     *         per row), every variance is zero, since Commons Math defines the standard deviation of a
     *         single-value sample as 0.
     */
    public static double[] getRowVariances(final RealMatrix m) {
        Utils.nonNull(m, "Cannot calculate variances on a null matrix.");
        final StandardDeviation std = new StandardDeviation();
        return IntStream.range(0, m.getRowDimension())
                .mapToDouble(i -> Math.pow(std.evaluate(m.getRow(i)), 2))
                .toArray();
    }
}