/*
 * To change this template, choose Tools | Templates
 * and open the template in the editor.
 */
package umcg.genetica.math.stats.concurrent;

import cern.colt.function.tdouble.DoubleProcedure;
import cern.colt.matrix.tdouble.DoubleMatrix1D;
import cern.colt.matrix.tdouble.DoubleMatrix2D;
import cern.colt.matrix.tdouble.impl.DenseDoubleMatrix2D;
import cern.colt.matrix.tdouble.impl.DenseLargeDoubleMatrix2D;

import java.util.concurrent.CompletionService;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorCompletionService;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.stream.IntStream;

import umcg.genetica.console.ProgressBar;
import umcg.genetica.containers.Pair;
import umcg.genetica.math.matrix.SymmetricFloatDistanceMatrix;
import umcg.genetica.math.matrix2.DoubleMatrixDataset;
import umcg.genetica.math.stats.Correlation;
import umcg.genetica.math.stats.Descriptives;

/**
 * Computes all-pairs correlation matrices between the rows of an input matrix, spreading the
 * per-row work over multiple threads.
 *
 * @author harmjan
 */
public class ConcurrentCorrelation {

	/** Number of worker threads used by {@link #pairwiseCorrelationDoubleMatrix(double[][])}. */
	private int nrThreads = Runtime.getRuntime().availableProcessors();

	/** Creates an instance that uses one worker thread per available processor. */
	public ConcurrentCorrelation() {
	}

	/**
	 * Creates an instance with an explicit worker-thread count.
	 *
	 * @param nrProcs number of threads used by {@link #pairwiseCorrelationDoubleMatrix(double[][])};
	 *                must be positive when that method is called
	 */
	public ConcurrentCorrelation(int nrProcs) {
		nrThreads = nrProcs;
	}

	/**
	 * Correlates every row of {@code in} with every other row using {@code Correlation.correlate}.
	 * <p>
	 * The work is distributed with a parallel {@link IntStream}, i.e. on the common fork/join pool;
	 * the thread count configured on this instance is NOT used here. The diagonal is set to 1
	 * without computing it.
	 *
	 * @param in dataset whose rows are correlated against each other
	 * @return a symmetric rows x rows dataset whose row and column objects are the row objects of
	 * {@code in}
	 * @throws Exception propagated from the dataset/progress-bar API
	 */
	public DoubleMatrixDataset<String, String> pairwiseCorrelation(DoubleMatrixDataset<String, String> in) throws Exception {
		final int nrRows = in.rows();

		DoubleMatrixDataset<String, String> output = new DoubleMatrixDataset<>(nrRows, nrRows);
		output.setRowObjects(in.getRowObjects());
		output.setColObjects(in.getRowObjects());
		output.getMatrix().assign(Double.NaN);

		ProgressBar pb = new ProgressBar(nrRows, "Calculating correlation matrix");

		// Copy once to plain arrays so the workers read rows without going through the matrix API.
		double[][] data = in.getMatrix().toArray();
		DoubleMatrix2D matrix = output.getMatrix();

		// Each task writes exactly one row (its own), so no two tasks touch the same cell.
		// Only the diagonal and the upper triangle are computed; the rest of the row is 0 for now.
		IntStream.range(0, nrRows).parallel().forEach(row -> {
			double[] xarr = data[row];
			double[] correl = new double[nrRows];
			correl[row] = 1d;
			for (int j = row + 1; j < nrRows; j++) {
				correl[j] = Correlation.correlate(xarr, data[j]);
			}
			matrix.viewRow(row).assign(correl);
			pb.iterateSynched();
		});
		pb.close();

		// Mirror the upper triangle into the lower triangle to make the matrix symmetric.
		for (int row = 0; row < nrRows; row++) {
			for (int col = row + 1; col < nrRows; col++) {
				matrix.setQuick(col, row, matrix.getQuick(row, col));
			}
		}

		return output;
	}

	/**
	 * Computes a correlation matrix over the rows of {@code in} by submitting one
	 * {@code ConcurrentCorrelationTask} per row to a fixed pool of {@code nrThreads} threads.
	 * <p>
	 * Only the values each task reports for columns at or above its own row are kept; the lower
	 * triangle is rebuilt from the upper triangle afterwards.
	 *
	 * @param in input rows; {@code in[i]} is correlated against the other rows
	 * @return an {@code in.length x in.length} symmetric matrix (a large-matrix implementation is
	 * used when the element count does not fit in an int-indexed array)
	 * @throws IllegalStateException if a row task fails, or if the calling thread is interrupted
	 *                               while waiting (the interrupt flag is restored in that case)
	 */
	public DoubleMatrix2D pairwiseCorrelationDoubleMatrix(double[][] in) {
		final int n = in.length;

		double[] meanOfSamples = new double[n];
		for (int i = 0; i < n; ++i) {
			meanOfSamples[i] = Descriptives.mean(in[i]);
		}

		DoubleMatrix2D correlationMatrix;
		if ((n * (long) n) > (Integer.MAX_VALUE - 2)) {
			correlationMatrix = new DenseLargeDoubleMatrix2D(n, n);
		} else {
			correlationMatrix = new DenseDoubleMatrix2D(n, n);
		}

		ExecutorService threadPool = Executors.newFixedThreadPool(nrThreads);
		try {
			CompletionService<Pair<Integer, double[]>> pool = new ExecutorCompletionService<>(threadPool);
			for (int i = 0; i < n; i++) {
				pool.submit(new ConcurrentCorrelationTask(in, meanOfSamples, i));
			}

			ProgressBar pb = new ProgressBar(n, "Calculation of correlation matrix: " + n + " x " + n);

			// Every submitted task is counted exactly once, whatever its outcome, so this loop can
			// never block on take() after all tasks have completed.
			for (int returned = 0; returned < n; returned++) {
				Pair<Integer, double[]> result;
				try {
					result = pool.take().get();
				} catch (InterruptedException e) {
					Thread.currentThread().interrupt();
					throw new IllegalStateException("Interrupted while calculating correlation matrix", e);
				} catch (ExecutionException e) {
					throw new IllegalStateException("A correlation row task failed", e);
				}

				if (result != null) {
					int rownr = result.getLeft();
					// NOTE(review): negative row numbers are skipped; presumably a task uses them to
					// signal "row not included" - confirm against ConcurrentCorrelationTask.
					if (rownr >= 0) {
						double[] doubles = result.getRight();
						for (int i = 0; i < doubles.length; ++i) {
							correlationMatrix.setQuick(rownr, i, doubles[i]);
						}
					}
				}
				pb.iterate();
			}

			// Mirror the upper triangle into the lower triangle to make the matrix symmetric.
			for (int r = 1; r < correlationMatrix.rows(); r++) {
				for (int c = 0; c < r; c++) {
					correlationMatrix.setQuick(r, c, correlationMatrix.getQuick(c, r));
				}
			}

			pb.close();
		} finally {
			// Always release the worker threads; on failure this also cancels still-running tasks.
			threadPool.shutdownNow();
		}
		return correlationMatrix;
	}
}