package common;

import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.stream.IntStream;

import com.google.common.util.concurrent.AtomicDoubleArray;

public class MLSparseMatrixFlat implements MLSparseMatrix {

	private static final long serialVersionUID = -7708714593085005498L;
	public static final int MISSING_ROW = -1;

	private int[] indexes;
	private float[] values;
	private int nCols;

	public MLSparseMatrixFlat(final int nRowsP, final int nColsP) {
		this.indexes = new int[nRowsP];
		Arrays.fill(this.indexes, MISSING_ROW);
		this.values = new float[nRowsP];
		this.nCols = nColsP;
	}

	public MLSparseMatrixFlat(final int[] indexesP, final float[] valuesP,
			final int nColsP) {
		this.indexes = indexesP;
		this.values = valuesP;
		this.nCols = nColsP;
	}

	@Override
	public void applyColNorm(final MLDenseVector colNorm) {
		IntStream.range(0, this.getNRows()).parallel().forEach(rowIndex -> {
			if (this.indexes[rowIndex] == MISSING_ROW) {
				return;
			}

			float norm = colNorm.getValue(this.indexes[rowIndex]);
			if (norm > 1e-10f) {
				this.values[rowIndex] /= norm;
			}
		});
	}

	@Override
	public void applyColSelector(final Map<Integer, Integer> selectedColMap,
			final int nColsSelected) {
		if (this.nCols == nColsSelected) {
			boolean noChanges = true;
			for (Map.Entry<Integer, Integer> entry : selectedColMap
					.entrySet()) {
				if (entry.getValue() != entry.getKey()) {
					noChanges = false;
					break;
				}
			}

			if (noChanges == true) {
				// nothing to do
				return;
			}
		}

		IntStream.range(0, this.getNRows()).parallel().forEach(rowIndex -> {
			Integer index = this.indexes[rowIndex];
			if (index == MISSING_ROW) {
				return;
			}

			index = selectedColMap.get(index);
			if (index == null) {
				// not in the map so remove this row
				this.removeRow(rowIndex);

			} else {
				this.indexes[rowIndex] = index;
			}
		});

		this.setNCols(nColsSelected);
	}

	@Override
	public void applyRowNorm(final MLDenseVector rowNorm) {
		IntStream.range(0, this.getNRows()).parallel().forEach(rowIndex -> {
			if (this.indexes[rowIndex] == MISSING_ROW) {
				return;
			}

			float norm = rowNorm.getValue(rowIndex);
			if (norm > 1e-5f) {
				this.values[rowIndex] /= norm;
			}
		});

	}

	@Override
	public void binarizeValues() {
		Arrays.fill(this.values, 1f);
	}

	@Override
	public MLSparseMatrix deepCopy() {
		return new MLSparseMatrixFlat(this.indexes.clone(), this.values.clone(),
				this.nCols);
	}

	@Override
	public MLDenseVector getColNNZ() {
		float[] colNNZ = new float[this.getNCols()];
		IntStream.range(0, this.getNRows()).parallel().forEach(rowIndex -> {
			int colIndex = this.indexes[rowIndex];
			if (colIndex == MISSING_ROW) {
				return;
			}
			synchronized (colNNZ) {
				colNNZ[colIndex] += 1;
			}
		});
		return new MLDenseVector(colNNZ);
	}

	@Override
	public MLDenseVector getColNorm(final int p) {
		// compute L^p norm
		final float[] colNorm = new float[this.getNCols()];
		IntStream.range(0, this.getNRows()).parallel().forEach(rowIndex -> {
			int colIndex = this.indexes[rowIndex];
			if (colIndex == MISSING_ROW) {
				return;
			}
			synchronized (colNorm) {
				if (p == 1) {
					colNorm[colIndex] += Math.abs(this.values[rowIndex]);
				} else {
					colNorm[colIndex] += Math.pow(this.values[rowIndex], p);
				}
			}
		});

		if (p != 1) {
			for (int i = 0; i < this.getNCols(); i++) {
				// take p'th root
				colNorm[i] = (float) Math.pow(colNorm[i], 1.0 / p);
			}
		}
		return new MLDenseVector(colNorm);
	}

	@Override
	public MLDenseVector getColSum() {
		float[] colSum = new float[this.getNCols()];
		IntStream.range(0, this.getNRows()).parallel().forEach(rowIndex -> {
			int colIndex = this.indexes[rowIndex];
			if (colIndex == MISSING_ROW) {
				return;
			}
			synchronized (colSum) {
				colSum[colIndex] += this.values[rowIndex];
			}
		});
		return new MLDenseVector(colSum);
	}

	@Override
	public MLDenseVector getRowSum() {
		return new MLDenseVector(this.values);
	}

	@Override
	public int getNCols() {
		return this.nCols;
	}

	@Override
	public long getNNZ() {
		long nnz = 0;
		for (int i = 0; i < this.indexes.length; i++) {
			if (this.indexes[i] != MISSING_ROW) {
				nnz++;
			}
		}
		return nnz;
	}

	@Override
	public int getNRows() {
		return this.indexes.length;
	}

	@Override
	public MLSparseVector getRow(final int rowIndex) {
		int colIndex = this.indexes[rowIndex];
		if (colIndex == MISSING_ROW) {
			return null;
		}

		return new MLSparseVector(new int[] { colIndex },
				new float[] { this.values[rowIndex] }, null, this.nCols);
	}

	@Override
	public MLSparseVector getRow(final int rowIndex, boolean returnEmpty) {
		MLSparseVector row = this.getRow(rowIndex);
		if (row == null && returnEmpty == true) {
			// return empty row instead of null
			row = new MLSparseVector(new int[] {}, new float[] {}, null,
					this.getNCols());
		}
		return row;
	}

	@Override
	public MLDenseVector getRowNNZ() {
		float[] rowNNZ = new float[this.getNRows()];
		IntStream.range(0, this.getNRows()).parallel().forEach(rowIndex -> {
			int colIndex = this.indexes[rowIndex];
			if (colIndex == MISSING_ROW) {
				return;
			}
			rowNNZ[rowIndex] = 1;
		});
		return new MLDenseVector(rowNNZ);
	}

	@Override
	public MLDenseVector getRowNorm(final int p) {
		final float[] rowNorm = new float[this.getNRows()];
		System.arraycopy(values, 0, rowNorm, 0, rowNorm.length);
		return new MLDenseVector(rowNorm);
	}

	@Override
	public boolean hasDates() {
		return false;
	}

	@Override
	public void inferAndSetNCols() {
		// infer number of columns if it wasn't known during constructor
		int nColsNew = 0;

		for (int i = 0; i < this.indexes.length; i++) {
			int colIndex = this.indexes[i];
			if (colIndex == MISSING_ROW) {
				continue;
			}

			if (colIndex + 1 > nColsNew) {
				// nCols is 1 + largest col index
				nColsNew = colIndex + 1;
			}
		}

		this.setNCols(nColsNew);
	}

	@Override
	public MLSparseMatrix mult(final MLSparseMatrix another) {
		if (this.getNCols() != another.getNRows()) {
			throw new IllegalArgumentException(
					"this.getNCols() != another.getNRows()");
		}
		MLSparseVector[] resultRows = new MLSparseVector[this.getNRows()];
		IntStream.range(0, this.getNRows()).parallel().forEach(i -> {
			int index = this.indexes[i];
			if (index == MISSING_ROW) {
				return;
			}
			float value = this.values[i];
			float[] resultRow = new float[another.getNCols()];
			MLSparseVector rowAnother = another.getRow(index);
			if (rowAnother != null) {

				int[] indexesAnother = rowAnother.getIndexes();
				float[] valuesAnother = rowAnother.getValues();
				for (int k = 0; k < indexesAnother.length; k++) {
					resultRow[indexesAnother[k]] += value * valuesAnother[k];
				}
			}
			MLSparseVector resultRowSparse = MLSparseVector
					.fromDense(new MLDenseVector(resultRow));
			if (resultRowSparse.getIndexes() != null) {
				resultRows[i] = resultRowSparse;
			}
		});

		return new MLSparseMatrixAOO(resultRows, another.getNCols());
	}

	@Override
	public MLDenseVector multCol(final MLDenseVector vector) {

		// multiply 1 x nRows dense vector with this matrix
		if (this.getNRows() != vector.getLength()) {
			throw new IllegalArgumentException(
					"this.getNRows() != vector.getLength()");
		}

		AtomicDoubleArray result = new AtomicDoubleArray(this.nCols);
		IntStream.range(0, this.getNRows()).parallel().forEach(rowIndex -> {
			float val = vector.getValue(rowIndex);
			if (val == 0) {
				return;
			}

			int colIndex = this.indexes[rowIndex];
			if (colIndex == MISSING_ROW) {
				return;
			}
			float colValue = this.values[rowIndex];
			result.addAndGet(colIndex, val * colValue);
		});
		float[] temp = new float[this.nCols];
		for (int i = 0; i < temp.length; i++) {
			temp[i] = (float) result.get(i);
		}

		return new MLDenseVector(temp);
	}

	@Override
	public MLDenseVector multCol(final MLSparseVector vector) {

		// multiply 1 x nRows sparse vector with this matrix
		if (this.getNRows() != vector.getLength()) {
			throw new IllegalArgumentException(
					"this.getNRows() != vector.getLength()");
		}

		AtomicDoubleArray result = new AtomicDoubleArray(this.nCols);
		int[] vectorIndexes = vector.getIndexes();
		float[] vectorValues = vector.getValues();
		IntStream.range(0, vectorIndexes.length).parallel()
				.forEach(rowIndex -> {
					int ind = vectorIndexes[rowIndex];
					float val = vectorValues[rowIndex];

					int colIndex = this.indexes[ind];
					if (colIndex == MISSING_ROW) {
						return;
					}
					float colValue = this.values[ind];

					result.addAndGet(colIndex, val * colValue);
				});

		float[] temp = new float[this.nCols];
		for (int i = 0; i < temp.length; i++) {
			temp[i] = (float) result.get(i);
		}

		return new MLDenseVector(temp);
	}

	@Override
	public MLDenseVector multRow(final MLDenseVector vector) {

		// multiply this matrix with nCols x 1 dense vector
		if (this.getNCols() != vector.getLength()) {
			throw new IllegalArgumentException(
					"this.getNCols() != vector.getLength()");
		}

		float[] result = new float[this.getNRows()];
		IntStream.range(0, this.getNRows()).parallel().forEach(rowIndex -> {
			int colIndex = this.indexes[rowIndex];
			if (colIndex == MISSING_ROW) {
				return;
			}
			float colValue = this.values[rowIndex];

			result[rowIndex] = vector.getValue(colIndex) * colValue;
		});

		return new MLDenseVector(result);
	}

	@Override
	public MLDenseVector multRow(final MLSparseVector vector) {

		// multiply this matrix with nCols x 1 sparse vector
		if (this.getNCols() != vector.getLength()) {
			throw new IllegalArgumentException(
					"this.getNCols() != vector.getLength()");
		}

		int[] vecIndexes = vector.getIndexes();
		float[] vecValues = vector.getValues();
		float[] result = new float[this.getNRows()];
		IntStream.range(0, this.getNRows()).parallel().forEach(rowIndex -> {
			int colIndex = this.indexes[rowIndex];
			if (colIndex == MISSING_ROW) {
				return;
			}
			float colValue = this.values[rowIndex];
			int matchIndex = Arrays.binarySearch(vecIndexes, colIndex);
			if (matchIndex >= 0) {
				result[rowIndex] = vecValues[matchIndex] * colValue;
			}
		});

		return new MLDenseVector(result);
	}

	private void removeRow(final int rowIndex) {
		this.indexes[rowIndex] = MISSING_ROW;
	}

	@Override
	public Map<Integer, Integer> selectCols(final int nnzCutOff) {
		Map<Integer, Integer> selectedColMap = new HashMap<Integer, Integer>(
				this.nCols);

		MLDenseVector colNNZ = this.getColNNZ();
		int newIndex = 0;
		for (int colIndex = 0; colIndex < this.nCols; colIndex++) {
			if (colNNZ.getValue(colIndex) > nnzCutOff) {
				selectedColMap.put(colIndex, newIndex);
				newIndex++;
			}
		}

		return selectedColMap;
	}

	@Override
	public void setNCols(final int nColsP) {
		this.nCols = nColsP;

	}

	public void setRow(final int index, final float value, final int rowIndex) {
		this.indexes[rowIndex] = index;
		this.values[rowIndex] = value;
	}

	@Override
	public void setRow(final MLSparseVector row, final int rowIndex) {
		if (row == null || row.getIndexes().length == 0) {
			this.removeRow(rowIndex);
			return;
		}

		int[] rowIndexes = row.getIndexes();
		if (rowIndexes.length != 1) {
			throw new IllegalArgumentException(
					"can't add row with != 1 element");
		}
		float[] rowValues = row.getValues();

		this.indexes[rowIndex] = rowIndexes[0];
		this.values[rowIndex] = rowValues[0];
	}

	@Override
	public void toBinFile(final String outFile) throws Exception {
		throw new UnsupportedOperationException("unsupported function");
	}

	@Override
	public MLSparseMatrix transpose() {
		/**
		 * convert to csr
		 */
		final int nnz = (int) this.getNNZ();
		final int nRows = this.getNRows();
		final int[] jaP = new int[nnz];
		final float[] aP = new float[nnz];
		for (int i = 0, inz = 0; i < nRows; i++) {
			int jaPi = this.indexes[i];
			if (jaPi != MISSING_ROW) {
				jaP[inz] = jaPi;
				aP[inz] = this.values[i];
				inz++;
			}
		}
		/**
		 * perform transpose
		 */
		final int nnzT = nnz;
		final int nRowsT = this.getNCols();
		final int nColsT = this.getNRows();
		final int[] rowIndexT = new int[nRowsT + 1];
		final int[] jaPT = new int[nnzT];
		final float[] aPT = new float[nnzT];

		// count nnz in each row
		for (int i = 0; i < nnzT; i++) {
			int jaPi = jaP[i];
			if (jaPi != MISSING_ROW) {
				rowIndexT[jaPi]++;
			}
		}

		// Fill starting point of the previous row to begin tally
		int r, j;
		rowIndexT[nRowsT] = nnzT - rowIndexT[nRowsT - 1];
		for (r = nRowsT - 1; r > 0; r--) {
			rowIndexT[r] = rowIndexT[r + 1] - rowIndexT[r - 1];
		}
		rowIndexT[0] = 0;

		// assign the new columns and values
		// synchronously tally
		// this is the place to insert extra values like dates
		for (int c = 0, i = 0; c < nColsT; c++) {
			// don't need to walk through row, there's only 0/1 vals per row
			if (this.indexes[c] == MISSING_ROW) {
				continue;
			}
			r = jaP[i];
			j = rowIndexT[r + 1]++;
			jaPT[j] = c;
			aPT[j] = aP[i];
			i++;
		}
		// TODO: CSR3 -> NIST and parsing can probably be more clever
		int[] pntrBT = new int[nRowsT];
		int[] pntrET = new int[nRowsT];
		int lastActivePtrE = 0;
		for (int i = 0; i < nRowsT; i++) {
			if (lastActivePtrE == rowIndexT[i + 1]) {
				continue;
			}
			pntrET[i] = rowIndexT[i + 1];
			pntrBT[i] = lastActivePtrE;
			lastActivePtrE = rowIndexT[i + 1];
		}

		/**
		 * consolidate csr (NIST) back to mlsparse
		 */
		final MLSparseVector[] rows = new MLSparseVector[nRowsT];
		IntStream.range(0, nRowsT).parallel().forEach(i -> {
			int rownnz = pntrET[i] - pntrBT[i];
			if (rownnz == 0) {
				return;
			}
			int[] rowColInds = new int[rownnz];
			float[] rowVals = new float[rownnz];
			for (int jj = 0, k = pntrBT[i]; jj < rownnz; jj++, k++) {
				rowColInds[jj] = jaPT[k];
				rowVals[jj] = aPT[k];
			}
			MLSparseVector rowVec = new MLSparseVector(rowColInds, rowVals,
					null, nColsT);
			rows[i] = rowVec;
		});
		return new MLSparseMatrixAOO(rows, nColsT);
	}
}