package com.jstarcraft.ai.math.structure.matrix;

import java.util.concurrent.Future;

import org.hamcrest.CoreMatchers;
import org.junit.Assert;
import org.junit.Test;

import com.jstarcraft.ai.environment.EnvironmentContext;
import com.jstarcraft.ai.environment.EnvironmentFactory;
import com.jstarcraft.ai.math.structure.MathCalculator;
import com.jstarcraft.ai.math.structure.message.SumMessage;
import com.jstarcraft.ai.math.structure.vector.ArrayVector;
import com.jstarcraft.ai.math.structure.vector.MathVector;
import com.jstarcraft.core.utility.RandomUtility;

import it.unimi.dsi.fastutil.longs.Long2FloatRBTreeMap;

public class RowArrayMatrixTestCase extends MatrixTestCase {

    /**
     * Builds a {@link RowArrayMatrix} whose rows follow the cells set in {@code table}.
     *
     * <p>The table is wrapped in a {@link SparseMatrix}, each row vector is copied into its
     * own {@link ArrayVector}, and the rows are assembled with {@link RowArrayMatrix#valueOf}.
     *
     * @param dimension number of rows and columns
     * @param table cells explicitly set here define each row's (presumably stored) elements
     * @return the assembled row-array matrix
     */
    private static RowArrayMatrix toRowArrayMatrix(int dimension, HashMatrix table) {
        SparseMatrix data = SparseMatrix.valueOf(dimension, dimension, table);
        ArrayVector[] vectors = new ArrayVector[dimension];
        for (int rowIndex = 0; rowIndex < dimension; rowIndex++) {
            vectors[rowIndex] = new ArrayVector(data.getRowVector(rowIndex));
        }
        return RowArrayMatrix.valueOf(dimension, vectors);
    }

    /**
     * Creates a square matrix with a random sparsity pattern and random values.
     *
     * <p>Each cell is set in the backing table when {@code RandomUtility.randomBoolean()}
     * is true. Every element of the resulting matrix is then overwritten with
     * {@code RandomUtility.randomInteger(dimension)}.
     */
    @Override
    protected RowArrayMatrix getRandomMatrix(int dimension) {
        HashMatrix table = new HashMatrix(true, dimension, dimension, new Long2FloatRBTreeMap());
        for (int rowIndex = 0; rowIndex < dimension; rowIndex++) {
            for (int columnIndex = 0; columnIndex < dimension; columnIndex++) {
                if (RandomUtility.randomBoolean()) {
                    table.setValue(rowIndex, columnIndex, 0F);
                }
            }
        }
        RowArrayMatrix matrix = toRowArrayMatrix(dimension, table);
        matrix.iterateElement(MathCalculator.SERIAL, (scalar) -> {
            scalar.setValue(RandomUtility.randomInteger(dimension));
        });
        return matrix;
    }

    /**
     * Creates a square matrix in which every cell is explicitly set to zero, so every
     * row covers all columns.
     */
    @Override
    protected RowArrayMatrix getZeroMatrix(int dimension) {
        HashMatrix table = new HashMatrix(true, dimension, dimension, new Long2FloatRBTreeMap());
        for (int rowIndex = 0; rowIndex < dimension; rowIndex++) {
            for (int columnIndex = 0; columnIndex < dimension; columnIndex++) {
                table.setValue(rowIndex, columnIndex, 0F);
            }
        }
        return toRowArrayMatrix(dimension, table);
    }

    /**
     * Compares matrix/vector dot products computed into a {@link RowArrayMatrix} (serially and
     * in parallel) against the same products computed into a {@link DenseMatrix}.
     *
     * <p>The work runs inside the environment context. The returned future is awaited so that
     * assertion failures raised on the worker are propagated to JUnit instead of being lost.
     */
    @Override
    public void testProduct() throws Exception {
        EnvironmentContext context = EnvironmentFactory.getContext();
        Future<?> task = context.doTask(() -> {
            int dimension = 10;
            MathMatrix leftMatrix = getRandomMatrix(dimension);
            MathMatrix rightMatrix = getRandomMatrix(dimension);
            MathMatrix dataMatrix = getZeroMatrix(dimension);
            MathMatrix markMatrix = DenseMatrix.valueOf(dimension, dimension);
            MathVector dataVector = dataMatrix.getRowVector(0);
            MathVector markVector = markMatrix.getRowVector(0);

            // Equivalent to transposeProductThis: left * left^T.
            markMatrix.dotProduct(leftMatrix, false, leftMatrix, true, MathCalculator.SERIAL);
            dataMatrix.dotProduct(leftMatrix, false, leftMatrix, true, MathCalculator.SERIAL);
            Assert.assertTrue(equalMatrix(dataMatrix, markMatrix));
            dataMatrix.dotProduct(leftMatrix, false, leftMatrix, true, MathCalculator.PARALLEL);
            Assert.assertTrue(equalMatrix(dataMatrix, markMatrix));

            // Equivalent to transposeProductThat: left * right^T.
            markMatrix.dotProduct(leftMatrix, false, rightMatrix, true, MathCalculator.SERIAL);
            dataMatrix.dotProduct(leftMatrix, false, rightMatrix, true, MathCalculator.SERIAL);
            Assert.assertTrue(equalMatrix(dataMatrix, markMatrix));
            dataMatrix.dotProduct(leftMatrix, false, rightMatrix, true, MathCalculator.PARALLEL);
            Assert.assertTrue(equalMatrix(dataMatrix, markMatrix));

            // Vector x vector product written into a matrix.
            MathVector leftVector = leftMatrix.getRowVector(RandomUtility.randomInteger(dimension));
            MathVector rightVector = rightMatrix.getRowVector(RandomUtility.randomInteger(dimension));
            markMatrix.dotProduct(leftVector, rightVector, MathCalculator.SERIAL);
            dataMatrix.dotProduct(leftVector, rightVector, MathCalculator.SERIAL);
            Assert.assertTrue(equalMatrix(dataMatrix, markMatrix));
            dataMatrix.dotProduct(leftVector, rightVector, MathCalculator.PARALLEL);
            Assert.assertTrue(equalMatrix(dataMatrix, markMatrix));

            // Matrix x vector product written into a vector.
            markVector.dotProduct(leftMatrix, false, rightVector, MathCalculator.SERIAL);
            dataVector.dotProduct(leftMatrix, false, rightVector, MathCalculator.SERIAL);
            Assert.assertTrue(equalVector(dataVector, markVector));
            dataVector.dotProduct(leftMatrix, false, rightVector, MathCalculator.PARALLEL);
            Assert.assertTrue(equalVector(dataVector, markVector));

            // Vector x transposed matrix product written into a vector.
            markVector.dotProduct(leftVector, rightMatrix, true, MathCalculator.SERIAL);
            dataVector.dotProduct(leftVector, rightMatrix, true, MathCalculator.SERIAL);
            Assert.assertTrue(equalVector(dataVector, markVector));
            dataVector.dotProduct(leftVector, rightMatrix, true, MathCalculator.PARALLEL);
            Assert.assertTrue(equalVector(dataVector, markVector));

            // Exploit the symmetry of the transpose product (left * left^T is symmetric).
            dataMatrix = new SymmetryMatrix(dimension);
            markMatrix.dotProduct(leftMatrix, false, leftMatrix, true, MathCalculator.SERIAL);
            dataMatrix.dotProduct(leftMatrix, false, leftMatrix, true, MathCalculator.SERIAL);
            Assert.assertTrue(equalMatrix(dataMatrix, markMatrix));
        });
        // Block until the task finishes; failures inside it are rethrown (wrapped in
        // ExecutionException) so the test cannot silently pass.
        task.get();
    }

    /**
     * Checks that row-only access is enforced, that resizing a row notifies the matrix monitor
     * with the matching size delta, and that the element iterators agree with getSum.
     */
    @Test
    public void testNotify() {
        int dimension = 10;
        RowArrayMatrix matrix = getRandomMatrix(dimension);
        matrix.setValues(1F);

        try {
            matrix.getColumnVector(RandomUtility.randomInteger(dimension));
            Assert.fail();
        } catch (UnsupportedOperationException ignored) {
            // Expected: this matrix only exposes row vectors.
        }

        ArrayVector vector = matrix.getRowVector(RandomUtility.randomInteger(dimension));
        int oldSize = vector.getElementSize();
        int newSize = RandomUtility.randomInteger(oldSize);
        int[] indexes = new int[newSize];
        for (int index = 0; index < newSize; index++) {
            indexes[index] = index;
        }
        SumMessage message = new SumMessage(false);
        matrix.attachMonitor((iterator, oldElementSize, newElementSize, oldKnownSize, newKnownSize, oldUnknownSize, newUnknownSize) -> {
            Assert.assertThat(newElementSize - oldElementSize, CoreMatchers.equalTo(newSize - oldSize));
            // Record that the monitor fired, so the assertion below can confirm it.
            message.accumulateValue(oldSize + newSize);
        });
        vector.modifyIndexes(indexes);
        vector.setValues(1F);
        Assert.assertThat(message.getValue(), CoreMatchers.equalTo(oldSize + newSize + 0F));
        Assert.assertThat(matrix.getSum(false), CoreMatchers.equalTo(matrix.getElementSize() + 0F));

        // Reset the accumulator, then sum via iterateElement.
        message.accumulateValue(-message.getValue());
        matrix.iterateElement(MathCalculator.SERIAL, (scalar) -> {
            message.accumulateValue(scalar.getValue());
        });
        Assert.assertThat(message.getValue(), CoreMatchers.equalTo(matrix.getSum(false)));

        // Reset the accumulator, then sum via the Iterable interface.
        message.accumulateValue(-message.getValue());
        for (MatrixScalar term : matrix) {
            message.accumulateValue(term.getValue());
        }
        Assert.assertThat(message.getValue(), CoreMatchers.equalTo(matrix.getSum(false)));
    }

}