com.jstarcraft.ai.environment.EnvironmentContext Java Examples

The following examples show how to use com.jstarcraft.ai.environment.EnvironmentContext. Each example is labeled with the source file it was taken from, the project that file belongs to, and that project's license.
Example #1
Source File: LearnerTestCase.java    From jstarcraft-ai with Apache License 2.0 6 votes vote down vote up
@Test
public void testGradient() throws Exception {
    // Runs the old updater and the new learner side by side on the same gradient values
    // inside the environment's task executor, then waits for the task to finish.
    EnvironmentContext environment = EnvironmentFactory.getContext();
    Future<?> future = environment.doTask(() -> {
        long[] dimensions = { 5L, 2L };
        INDArray referenceArray = Nd4j.linspace(-2.5D, 2.0D, 10).reshape(dimensions);
        GradientUpdater<?> referenceUpdater = getOldFunction(dimensions);
        DenseMatrix candidateMatrix = getMatrix(referenceArray);
        Map<String, MathMatrix> gradientMap = new HashMap<>();
        gradientMap.put("gradients", candidateMatrix);
        Learner learner = getNewFunction(dimensions);
        learner.doCache(gradientMap);

        // After every step both implementations must still hold equal values.
        for (int step = 0; step < 10; step++) {
            referenceUpdater.applyUpdater(referenceArray, step, 0);
            learner.learn(gradientMap, step, 0);

            System.out.println(referenceArray);
            System.out.println(gradientMap);

            Assert.assertTrue(equalMatrix(candidateMatrix, referenceArray));
        }
    });
    future.get();
}
 
Example #2
Source File: HMMModel.java    From jstarcraft-rns with Apache License 2.0 6 votes vote down vote up
/**
 * Computes the gamma/rho statistics for every user concurrently and blocks until all users are done.
 *
 * @throws ModelException if the waiting thread is interrupted
 * @throws RuntimeException the first failure raised by a per-user computation, rethrown on this thread
 */
@Override
protected void eStep() {
    EnvironmentContext context = EnvironmentContext.getContext();
    // Concurrent computation: one task per user, joined through the latch.
    CountDownLatch latch = new CountDownLatch(userSize);
    // First failure raised by a worker task; rethrown here once every task has finished.
    java.util.concurrent.atomic.AtomicReference<RuntimeException> failure = new java.util.concurrent.atomic.AtomicReference<>();
    for (int userIndex = 0; userIndex < userSize; userIndex++) {
        int user = userIndex;
        context.doAlgorithmByAny(userIndex, () -> {
            try {
                calculateGammaRho(user, dataMatrixes[user]);
            } catch (RuntimeException error) {
                failure.compareAndSet(null, error);
            } finally {
                // Count down even on failure; otherwise latch.await() below would block forever.
                latch.countDown();
            }
        });
    }
    try {
        latch.await();
    } catch (InterruptedException exception) {
        // Preserve the interrupt status so callers can still observe it.
        Thread.currentThread().interrupt();
        throw new ModelException(exception);
    }
    RuntimeException firstFailure = failure.get();
    if (firstFailure != null) {
        throw firstFailure;
    }
}
 
Example #3
Source File: ActivationFunctionTestCase.java    From jstarcraft-ai with Apache License 2.0 6 votes vote down vote up
@Test
public void testForward() throws Exception {
    EnvironmentContext environment = EnvironmentFactory.getContext();
    Future<?> future = environment.doTask(() -> {
        // Reference result produced by the old activation implementation.
        INDArray source = Nd4j.linspace(-2.5D, 2.0D, 10).reshape(5, 2);
        IActivation reference = getOldFunction();
        INDArray expected = reference.getActivation(source.dup(), true);

        // The same input pushed through the new implementation as a whole matrix.
        DenseMatrix inputMatrix = getMatrix(source);
        int rows = inputMatrix.getRowSize();
        int columns = inputMatrix.getColumnSize();
        DenseMatrix outputMatrix = DenseMatrix.valueOf(rows, columns);
        ActivationFunction candidate = getNewFunction();
        candidate.forward(inputMatrix, outputMatrix);

        System.out.println(expected);
        System.out.println(outputMatrix);
        Assert.assertTrue(equalMatrix(outputMatrix, expected));

        // Row-by-row forwarding must agree with the whole-matrix result.
        DenseVector rowBuffer = DenseVector.valueOf(columns);
        for (int row = 0; row < rows; row++) {
            candidate.forward(inputMatrix.getRowVector(row), rowBuffer);
            Assert.assertTrue(equalVector(rowBuffer, outputMatrix.getRowVector(row)));
        }
    });
    future.get();
}
 
Example #4
Source File: MatrixTestCase.java    From jstarcraft-ai with Apache License 2.0 6 votes vote down vote up
@Test
public void testSize() throws Exception {
    EnvironmentContext environment = EnvironmentFactory.getContext();
    Future<?> future = environment.doTask(() -> {
        int dimension = 10;
        MathMatrix matrix = getRandomMatrix(dimension);

        // Known plus unknown cells must cover the full row x column grid.
        Assert.assertThat(matrix.getKnownSize() + matrix.getUnknownSize(), CoreMatchers.equalTo(matrix.getRowSize() * matrix.getColumnSize()));

        // Iterating the matrix must visit exactly getElementSize() scalars whose values add up to getSum(false).
        int visited = 0;
        float total = 0F;
        for (MatrixScalar scalar : matrix) {
            total += scalar.getValue();
            visited++;
        }
        Assert.assertThat(visited, CoreMatchers.equalTo(matrix.getElementSize()));
        Assert.assertThat(total, CoreMatchers.equalTo(matrix.getSum(false)));
    });
    future.get();
}
 
Example #5
Source File: MatrixTestCase.java    From jstarcraft-ai with Apache License 2.0 6 votes vote down vote up
@Test
public void testSum() throws Exception {
    EnvironmentContext environment = EnvironmentFactory.getContext();
    Future<?> future = environment.doTask(() -> {
        int dimension = 10;
        MathMatrix matrix = getRandomMatrix(dimension);

        // Scaling every value by 2 must double the sum.
        float before = matrix.getSum(false);
        matrix.scaleValues(2F);
        float after = matrix.getSum(false);
        Assert.assertThat(after, CoreMatchers.equalTo(before * 2F));

        // Shifting every value by 1 must add getElementSize() to the sum.
        before = after;
        matrix.shiftValues(1F);
        after = matrix.getSum(false);
        Assert.assertThat(after, CoreMatchers.equalTo(before + matrix.getElementSize()));

        // Setting every value to 0 must zero the sum.
        matrix.setValues(0F);
        after = matrix.getSum(false);
        Assert.assertThat(after, CoreMatchers.equalTo(0F));
    });
    future.get();
}
 
Example #6
Source File: MatrixTestCase.java    From jstarcraft-ai with Apache License 2.0 6 votes vote down vote up
@Test
public void testCodec() throws Exception {
    EnvironmentContext context = EnvironmentFactory.getContext();
    Future<?> task = context.doTask(() -> {
        // A dimension of 100 makes the matrix large enough to gauge encode/decode efficiency.
        int dimension = 100;
        MathMatrix oldMatrix = getRandomMatrix(dimension);

        for (ModemCodec codec : ModemCodec.values()) {
            // Durations use System.nanoTime(), which is monotonic. currentTimeMillis() follows the
            // wall clock and can jump (NTP/manual adjustments), yielding wrong or negative durations.
            long encodeStart = System.nanoTime();
            byte[] data = codec.encodeModel(oldMatrix);
            long encodeMillis = (System.nanoTime() - encodeStart) / 1_000_000L;
            String encodeMessage = StringUtility.format("编码{}数据的时间:{}毫秒", codec, encodeMillis);
            logger.info(encodeMessage);

            long decodeStart = System.nanoTime();
            MathMatrix newMatrix = (MathMatrix) codec.decodeModel(data);
            long decodeMillis = (System.nanoTime() - decodeStart) / 1_000_000L;
            String decodeMessage = StringUtility.format("解码{}数据的时间:{}毫秒", codec, decodeMillis);
            logger.info(decodeMessage);

            // The encode/decode round trip must reproduce the original matrix.
            Assert.assertThat(newMatrix, CoreMatchers.equalTo(oldMatrix));
        }
    });
    task.get();
}
 
Example #7
Source File: MatrixTestCase.java    From jstarcraft-ai with Apache License 2.0 5 votes vote down vote up
@Test
public void testPerformance() throws Exception {
    // Performance test placeholder: the submitted task body is still empty, so this only
    // submits a no-op task to the environment context and waits for it.
    EnvironmentFactory.getContext().doTask(() -> {
    }).get();
}
 
Example #8
Source File: DiversityEvaluatorTestCase.java    From jstarcraft-ai with Apache License 2.0 5 votes vote down vote up
/**
 * Builds a {@link DiversityEvaluator} from cosine similarities computed over the feature matrix
 * ({@code calculateCoefficients} is called with {@code transpose = true}).
 *
 * @param featureMatrix feature matrix whose similarity matrix is computed
 * @return the diversity evaluator; NOTE(review): the meaning of the constant {@code 10} is defined by
 *         {@code DiversityEvaluator}'s constructor — confirm there
 * @throws RuntimeException wrapping any failure or interruption while computing the similarities
 */
@Override
protected Evaluator<IntSet, IntList> getEvaluator(SparseMatrix featureMatrix) {
    MathCorrelation correlation = new CosineSimilarity();
    EnvironmentContext context = EnvironmentFactory.getContext();
    // The similarity matrix is computed inside the environment's task executor.
    Future<SymmetryMatrix> task = context.doTask(() -> {
        SymmetryMatrix symmetryMatrix = new SymmetryMatrix(featureMatrix.getColumnSize());
        correlation.calculateCoefficients(featureMatrix, true, symmetryMatrix::setValue);
        return symmetryMatrix;
    });
    try {
        return new DiversityEvaluator(10, task.get());
    } catch (InterruptedException exception) {
        // Restore the interrupt flag so callers further up the stack can still observe it.
        Thread.currentThread().interrupt();
        throw new RuntimeException(exception);
    } catch (Exception exception) {
        throw new RuntimeException(exception);
    }
}
 
Example #9
Source File: ActivationFunctionTestCase.java    From jstarcraft-ai with Apache License 2.0 5 votes vote down vote up
@Test
public void testBackward() throws Exception {
    EnvironmentContext environment = EnvironmentFactory.getContext();
    Future<?> future = environment.doTask(() -> {
        // Reference: back-propagate a fixed error through the old activation implementation.
        INDArray source = Nd4j.linspace(-2.5D, 2.0D, 10).reshape(5, 2);
        IActivation reference = getOldFunction();
        INDArray epsilon = Nd4j.linspace(-2.5D, 2.5D, 10).reshape(5, 2);
        Pair<INDArray, INDArray> referenceResult = reference.backprop(source.dup(), epsilon);

        // Candidate: the same input and error through the new implementation as a whole matrix.
        DenseMatrix inputMatrix = getMatrix(source);
        int rows = inputMatrix.getRowSize();
        int columns = inputMatrix.getColumnSize();
        DenseMatrix gradientMatrix = DenseMatrix.valueOf(rows, columns);
        DenseMatrix errorMatrix = getMatrix(epsilon);
        ActivationFunction candidate = getNewFunction();
        candidate.backward(inputMatrix, errorMatrix, gradientMatrix);

        INDArray expectedGradient = referenceResult.getKey();
        System.out.println(expectedGradient);
        System.out.println(gradientMatrix);
        Assert.assertTrue(equalMatrix(gradientMatrix, expectedGradient));

        // Row-by-row back-propagation must agree with the whole-matrix result.
        DenseVector rowBuffer = DenseVector.valueOf(columns);
        for (int row = 0; row < rows; row++) {
            candidate.backward(inputMatrix.getRowVector(row), errorMatrix.getRowVector(row), rowBuffer);
            Assert.assertTrue(equalVector(rowBuffer, gradientMatrix.getRowVector(row)));
        }
    });
    future.get();
}
 
Example #10
Source File: LossFunctionTestCase.java    From jstarcraft-ai with Apache License 2.0 5 votes vote down vote up
@Test
public void testGradient() throws Exception {
    EnvironmentContext environment = EnvironmentFactory.getContext();
    Future<?> future = environment.doTask(() -> {
        // Pairs of (reference activation, new activation) to check the loss gradient with.
        LinkedList<KeyValue<IActivation, ActivationFunction>> activations = new LinkedList<>();
        activations.add(new KeyValue<>(new ActivationSigmoid(), new SigmoidActivationFunction()));
        activations.add(new KeyValue<>(new ActivationSoftmax(), new SoftMaxActivationFunction()));
        for (KeyValue<IActivation, ActivationFunction> pair : activations) {
            INDArray source = Nd4j.linspace(-2.5D, 2.0D, 10).reshape(5, 2);
            INDArray labels = Nd4j.create(new double[] { 0D, 1D, 0D, 1D, 0D, 1D, 0D, 1D, 0D, 1D }).reshape(5, 2);
            ILossFunction reference = getOldFunction();
            INDArray expected = reference.computeGradient(labels, source.dup(), pair.getKey(), null);

            // Forward through the new activation, compute the loss gradient, then back-propagate it.
            DenseMatrix inputMatrix = getMatrix(source);
            int rows = inputMatrix.getRowSize();
            int columns = inputMatrix.getColumnSize();
            DenseMatrix outputMatrix = DenseMatrix.valueOf(rows, columns);
            ActivationFunction activation = pair.getValue();
            activation.forward(inputMatrix, outputMatrix);
            DenseMatrix gradientMatrix = DenseMatrix.valueOf(rows, columns);
            LossFunction candidate = getNewFunction(activation);
            candidate.doCache(getMatrix(labels), outputMatrix);
            candidate.computeGradient(getMatrix(labels), outputMatrix, null, gradientMatrix);
            // NOTE(review): backward() presumably writes its result into outputMatrix, which is compared below.
            activation.backward(inputMatrix, gradientMatrix, outputMatrix);
            System.out.println(expected);
            System.out.println(outputMatrix);
            Assert.assertTrue(equalMatrix(outputMatrix, expected));
        }
    });
    future.get();
}
 
Example #11
Source File: LossFunctionTestCase.java    From jstarcraft-ai with Apache License 2.0 5 votes vote down vote up
@Test
public void testScore() throws Exception {
    // Compares the loss score of the reference (old) implementation with the new one for each
    // activation pairing, inside the environment's task executor.
    EnvironmentContext context = EnvironmentFactory.getContext();
    Future<?> task = context.doTask(() -> {
        LinkedList<KeyValue<IActivation, ActivationFunction>> activationList = new LinkedList<>();
        activationList.add(new KeyValue<>(new ActivationSigmoid(), new SigmoidActivationFunction()));
        activationList.add(new KeyValue<>(new ActivationSoftmax(), new SoftMaxActivationFunction()));
        for (KeyValue<IActivation, ActivationFunction> keyValue : activationList) {
            INDArray array = Nd4j.linspace(-2.5D, 2.0D, 10).reshape(5, 2);
            INDArray marks = Nd4j.create(new double[] { 0D, 1D, 0D, 1D, 0D, 1D, 0D, 1D, 0D, 1D }).reshape(5, 2);
            ILossFunction oldFunction = getOldFunction();
            double value = oldFunction.computeScore(marks, array.dup(), keyValue.getKey(), null, false);

            DenseMatrix input = getMatrix(array);
            DenseMatrix output = DenseMatrix.valueOf(input.getRowSize(), input.getColumnSize());
            ActivationFunction function = keyValue.getValue();
            function.forward(input, output);
            LossFunction newFunction = getNewFunction(function);
            newFunction.doCache(getMatrix(marks), output);
            double score = newFunction.computeScore(getMatrix(marks), output, null);

            System.out.println(value);
            System.out.println(score);

            // assertEquals with a delta fails when exactly one side is NaN; the former
            // "Math.abs(value - score) > EPSILON" check evaluated to false and silently passed.
            Assert.assertEquals("score mismatch for " + keyValue.getKey(), value, score, MathUtility.EPSILON);
        }
    });
    task.get();
}
 
Example #12
Source File: MixtureDensityLossFunctionTestCase.java    From jstarcraft-ai with Apache License 2.0 5 votes vote down vote up
@Test
@Override
public void testGradient() throws Exception {
    EnvironmentContext environment = EnvironmentFactory.getContext();
    Future<?> future = environment.doTask(() -> {
        // Pairs of (reference activation, new activation) to check the loss gradient with.
        LinkedList<KeyValue<IActivation, ActivationFunction>> activations = new LinkedList<>();
        activations.add(new KeyValue<>(new ActivationSigmoid(), new SigmoidActivationFunction()));
        activations.add(new KeyValue<>(new ActivationSoftmax(), new SoftMaxActivationFunction()));
        for (KeyValue<IActivation, ActivationFunction> pair : activations) {
            // Note: the input is 5x4 while the labels are 5x2.
            INDArray source = Nd4j.linspace(-2.5D, 2.0D, 20).reshape(5, 4);
            INDArray labels = Nd4j.create(new double[] { 0D, 1D, 0D, 1D, 0D, 1D, 0D, 1D, 0D, 1D }).reshape(5, 2);
            ILossFunction reference = getOldFunction();
            INDArray expected = reference.computeGradient(labels, source.dup(), pair.getKey(), null);

            // Fresh copies of input and labels are made for every use.
            MathMatrix inputMatrix = getMatrix(source.rows(), source.columns()).copyMatrix(getMatrix(source), false);
            int rows = inputMatrix.getRowSize();
            int columns = inputMatrix.getColumnSize();
            MathMatrix outputMatrix = getMatrix(rows, columns);
            ActivationFunction activation = pair.getValue();
            activation.forward(inputMatrix, outputMatrix);
            MathMatrix gradientMatrix = getMatrix(rows, columns);
            LossFunction candidate = getNewFunction(activation);
            candidate.doCache(getMatrix(labels.rows(), labels.columns()).copyMatrix(getMatrix(labels), false), outputMatrix);
            candidate.computeGradient(getMatrix(labels.rows(), labels.columns()).copyMatrix(getMatrix(labels), false), outputMatrix, null, gradientMatrix);
            activation.backward(inputMatrix, gradientMatrix, outputMatrix);
            System.out.println(expected);
            System.out.println(outputMatrix);
            Assert.assertTrue(equalMatrix(outputMatrix, expected));
        }
    });
    future.get();
}
 
Example #13
Source File: MixtureDensityLossFunctionTestCase.java    From jstarcraft-ai with Apache License 2.0 5 votes vote down vote up
@Test
@Override
public void testScore() throws Exception {
    EnvironmentContext environment = EnvironmentFactory.getContext();
    Future<?> future = environment.doTask(() -> {
        // Pairs of (reference activation, new activation) to compare loss scores with.
        LinkedList<KeyValue<IActivation, ActivationFunction>> activations = new LinkedList<>();
        activations.add(new KeyValue<>(new ActivationSigmoid(), new SigmoidActivationFunction()));
        activations.add(new KeyValue<>(new ActivationSoftmax(), new SoftMaxActivationFunction()));
        for (KeyValue<IActivation, ActivationFunction> pair : activations) {
            // Note: the input is 5x4 while the labels are 5x2.
            INDArray source = Nd4j.linspace(-2.5D, 2.0D, 20).reshape(5, 4);
            INDArray labels = Nd4j.create(new double[] { 0D, 1D, 0D, 1D, 0D, 1D, 0D, 1D, 0D, 1D }).reshape(5, 2);
            ILossFunction reference = getOldFunction();
            float expected = (float) reference.computeScore(labels, source.dup(), pair.getKey(), null, false);

            MathMatrix inputMatrix = getMatrix(source.rows(), source.columns()).copyMatrix(getMatrix(source), false);
            MathMatrix outputMatrix = getMatrix(inputMatrix.getRowSize(), inputMatrix.getColumnSize());
            ActivationFunction activation = pair.getValue();
            activation.forward(inputMatrix, outputMatrix);
            LossFunction candidate = getNewFunction(activation);
            candidate.doCache(getMatrix(labels.rows(), labels.columns()).copyMatrix(getMatrix(labels), false), outputMatrix);
            float actual = candidate.computeScore(getMatrix(labels.rows(), labels.columns()).copyMatrix(getMatrix(labels), false), outputMatrix, null);

            System.out.println(expected);
            System.out.println(actual);
            Assert.assertTrue(MathUtility.equal(expected, actual));
        }
    });
    future.get();
}
 
Example #14
Source File: MatrixUtilityTestCase.java    From jstarcraft-ai with Apache License 2.0 5 votes vote down vote up
@Test
public void test() throws Exception {
    EnvironmentContext environment = EnvironmentFactory.getContext();
    Future<?> future = environment.doTask(() -> {
        DenseMatrix matrix = DenseMatrix.valueOf(4, 4);
        // Prints three things in order: the matrix, a copy made with the transpose flag set,
        // and dotProduct(matrix, true, matrix, false).
        Runnable report = () -> {
            System.out.println(matrix);
            System.out.println(DenseMatrix.valueOf(4, 4).copyMatrix(matrix, true));
            System.out.println(DenseMatrix.valueOf(4, 4).dotProduct(matrix, true, matrix, false, MathCalculator.SERIAL));
        };

        // Only row 1 holds sampled values.
        matrix.setValues(0F);
        matrix.getRowVector(1).iterateElement(MathCalculator.SERIAL, (scalar) -> {
            scalar.setValue(probability.sample().floatValue());
        });
        report.run();

        // Only column 1 holds sampled values.
        matrix.setValues(0F);
        matrix.getColumnVector(1).iterateElement(MathCalculator.SERIAL, (scalar) -> {
            scalar.setValue(probability.sample().floatValue());
        });
        report.run();

        // Only the diagonal holds sampled values; every other cell is set to 0.
        matrix.iterateElement(MathCalculator.PARALLEL, (scalar) -> {
            boolean diagonal = scalar.getRow() == scalar.getColumn();
            scalar.setValue(diagonal ? probability.sample().floatValue() : 0F);
        });
        report.run();
    });
    future.get();
}
 
Example #15
Source File: MovieModelConfigurer.java    From jstarcraft-example with Apache License 2.0 5 votes vote down vote up
private Model getModel(Class<? extends Model> clazz, DataSpace dataSpace, DataModule dataModule) throws Exception {
    // Instantiate reflectively, then prepare and train the model inside the environment's
    // task executor, blocking until training has finished.
    Model instance = ReflectionUtility.getInstance(clazz);
    EnvironmentFactory.getContext().doTask(() -> {
        instance.prepare(configuration, dataModule, dataSpace);
        instance.practice();
    }).get();
    return instance;
}
 
Example #16
Source File: AbstractModel.java    From jstarcraft-rns with Apache License 2.0 5 votes vote down vote up
/**
 * Trains the model: runs {@code constructEnvironment} through {@code doAlgorithmByEvery}, then
 * {@link #doPractice()}, then {@code destructEnvironment} through {@code doAlgorithmByEvery}.
 * NOTE(review): {@code doAlgorithmByEvery} presumably runs the callback on every worker — confirm.
 */
@Override
public final void practice() {
    EnvironmentContext context = EnvironmentContext.getContext();
    context.doAlgorithmByEvery(this::constructEnvironment);
    try {
        doPractice();
    } finally {
        // Tear down even if training throws, so the environment set up above is not leaked.
        context.doAlgorithmByEvery(this::destructEnvironment);
    }
}
 
Example #17
Source File: LossFunctionTestCase.java    From jstarcraft-rns with Apache License 2.0 5 votes vote down vote up
@Test
    public void testScore() throws Exception {
        // Compares the loss score of the reference (old) implementation with the new one for each
        // activation pairing, inside the environment's task executor.
        EnvironmentContext context = EnvironmentFactory.getContext();
        Future<?> task = context.doTask(() -> {
            LinkedList<KeyValue<IActivation, ActivationFunction>> activationList = new LinkedList<>();
            activationList.add(new KeyValue<>(new ActivationSigmoid(), new SigmoidActivationFunction()));
            // Softmax pairing is currently disabled:
            // activationList.add(new KeyValue<>(new ActivationSoftmax(), new SoftMaxActivationFunction()));
            for (KeyValue<IActivation, ActivationFunction> keyValue : activationList) {
                INDArray array = Nd4j.linspace(-2.5D, 2.0D, 10).reshape(5, 2);
                INDArray marks = Nd4j.create(new double[] { 0D, 1D, 0D, 1D, 0D, 1D, 0D, 1D, 0D, 1D }).reshape(5, 2);
                ILossFunction oldFunction = getOldFunction(marks);
                double value = oldFunction.computeScore(marks, array.dup(), keyValue.getKey(), null, false);

                Nd4jMatrix input = getMatrix(array.dup());
                Nd4jMatrix output = new Nd4jMatrix(Nd4j.zeros(input.getRowSize(), input.getColumnSize()));
                ActivationFunction function = keyValue.getValue();
                function.forward(input, output);
                LossFunction newFunction = getNewFunction(marks, function);
                newFunction.doCache(getMatrix(marks), output);
                double score = newFunction.computeScore(getMatrix(marks), output, null);

                System.out.println(value);
                System.out.println(score);

                // assertEquals with a delta fails when exactly one side is NaN; the former
                // "Math.abs(value - score) > EPSILON" check evaluated to false and silently passed.
                Assert.assertEquals("score mismatch for " + keyValue.getKey(), value, score, MathUtility.EPSILON);
            }
        });
        task.get();
    }
 
Example #18
Source File: LossFunctionTestCase.java    From jstarcraft-rns with Apache License 2.0 5 votes vote down vote up
@Test
    public void testGradient() throws Exception {
        EnvironmentContext environment = EnvironmentFactory.getContext();
        Future<?> future = environment.doTask(() -> {
            // Pairs of (reference activation, new activation) to check the loss gradient with.
            LinkedList<KeyValue<IActivation, ActivationFunction>> activations = new LinkedList<>();
            activations.add(new KeyValue<>(new ActivationSigmoid(), new SigmoidActivationFunction()));
            // Softmax pairing is currently disabled:
            // activations.add(new KeyValue<>(new ActivationSoftmax(), new SoftMaxActivationFunction()));
            for (KeyValue<IActivation, ActivationFunction> pair : activations) {
                INDArray source = Nd4j.linspace(-2.5D, 2.0D, 10).reshape(5, 2);
                INDArray labels = Nd4j.create(new double[] { 0D, 1D, 0D, 1D, 0D, 1D, 0D, 1D, 0D, 1D }).reshape(5, 2);
                ILossFunction reference = getOldFunction(labels);
                INDArray expected = reference.computeGradient(labels, source.dup(), pair.getKey(), null);

                // Forward through the new activation, compute the loss gradient, then back-propagate it.
                Nd4jMatrix inputMatrix = getMatrix(source.dup());
                int rows = inputMatrix.getRowSize();
                int columns = inputMatrix.getColumnSize();
                Nd4jMatrix outputMatrix = new Nd4jMatrix(Nd4j.zeros(rows, columns));
                ActivationFunction activation = pair.getValue();
                activation.forward(inputMatrix, outputMatrix);
                Nd4jMatrix gradientMatrix = new Nd4jMatrix(Nd4j.zeros(rows, columns));
                LossFunction candidate = getNewFunction(labels, activation);
                candidate.doCache(getMatrix(labels), outputMatrix);
                candidate.computeGradient(getMatrix(labels), outputMatrix, null, gradientMatrix);
                activation.backward(inputMatrix, gradientMatrix, outputMatrix);
                System.out.println(expected);
                System.out.println(outputMatrix);
                Assert.assertTrue(equalMatrix(outputMatrix, expected));
            }
        });
        future.get();
    }
 
Example #19
Source File: MathCorrelation.java    From jstarcraft-ai with Apache License 2.0 5 votes vote down vote up
/**
 * Computes correlation coefficients from a score matrix and reports each one to {@code monitor}.
 * <p>
 * Rows are correlated, or columns when {@code transpose} is {@code true}. Vectors with no elements
 * are skipped entirely. For each remaining left index {@code i}:
 * <ul>
 * <li>the pair {@code (i, i)} is reported with {@link #getIdentical()};</li>
 * <li>every pair {@code (i, j)} with {@code j > i} is dispatched through
 * {@code EnvironmentContext.doAlgorithmByAny} and reported only when its coefficient is not NaN.</li>
 * </ul>
 * The method waits until all pairs of one left index are done before moving to the next.
 *
 * @param scoreMatrix matrix whose rows (or columns) are correlated
 * @param transpose   {@code true} to correlate columns, {@code false} to correlate rows
 * @param monitor     receives {@code (leftIndex, rightIndex, coefficient)}; NOTE(review): it is called
 *                    from the dispatched tasks, so it presumably must be thread-safe — confirm
 * @throws RuntimeException the first failure raised by a coefficient task, or a wrapper around
 *                          {@link InterruptedException} if the waiting thread is interrupted
 */
default void calculateCoefficients(MathMatrix scoreMatrix, boolean transpose, CorrelationMonitor monitor) {
    EnvironmentContext context = EnvironmentContext.getContext();
    Semaphore semaphore = new Semaphore(0);
    // First failure raised by a dispatched task; rethrown on this thread after the wait.
    java.util.concurrent.atomic.AtomicReference<RuntimeException> failure = new java.util.concurrent.atomic.AtomicReference<>();
    int count = transpose ? scoreMatrix.getColumnSize() : scoreMatrix.getRowSize();
    for (int leftIndex = 0; leftIndex < count; leftIndex++) {
        MathVector thisVector = transpose ? scoreMatrix.getColumnVector(leftIndex) : scoreMatrix.getRowVector(leftIndex);
        if (thisVector.getElementSize() == 0) {
            continue;
        }
        monitor.notifyCoefficientCalculated(leftIndex, leftIndex, getIdentical());
        // Remaining pairs (leftIndex, rightIndex) with rightIndex > leftIndex; the self pair was reported above.
        int permits = 0;
        for (int rightIndex = leftIndex + 1; rightIndex < count; rightIndex++) {
            MathVector thatVector = transpose ? scoreMatrix.getColumnVector(rightIndex) : scoreMatrix.getRowVector(rightIndex);
            if (thatVector.getElementSize() == 0) {
                continue;
            }
            int leftCursor = leftIndex;
            int rightCursor = rightIndex;
            // NOTE(review): the key leftIndex * rightIndex is 0 for every pair when leftIndex == 0;
            // if the key selects the worker, that whole row runs on a single worker — confirm.
            context.doAlgorithmByAny(leftIndex * rightIndex, () -> {
                try {
                    float coefficient = getCoefficient(thisVector, thatVector);
                    if (!Float.isNaN(coefficient)) {
                        monitor.notifyCoefficientCalculated(leftCursor, rightCursor, coefficient);
                    }
                } catch (RuntimeException error) {
                    failure.compareAndSet(null, error);
                } finally {
                    // Release even on failure; otherwise acquire(permits) below would block forever.
                    semaphore.release();
                }
            });
            permits++;
        }
        try {
            semaphore.acquire(permits);
        } catch (InterruptedException exception) {
            // Preserve the interrupt status so callers can still observe it.
            Thread.currentThread().interrupt();
            throw new RuntimeException(exception);
        }
        RuntimeException firstFailure = failure.get();
        if (firstFailure != null) {
            throw firstFailure;
        }
    }
}
 
Example #20
Source File: SymmetryMatrixTestCase.java    From jstarcraft-ai with Apache License 2.0 5 votes vote down vote up
@Test
public void testProduct() throws Exception {
    EnvironmentContext environment = EnvironmentFactory.getContext();
    Future<?> future = environment.doTask(() -> {
        int dimension = 10;
        MathMatrix random = getRandomMatrix(dimension);
        MathMatrix actualMatrix = getZeroMatrix(dimension);
        // A dense matrix computed serially holds the expected result.
        MathMatrix expectedMatrix = DenseMatrix.valueOf(dimension, dimension);
        MathVector actualVector = actualMatrix.getRowVector(0);
        MathVector expectedVector = expectedMatrix.getRowVector(0);

        // Matrix x matrix, transpose flag set on the left operand; serial and parallel must both match.
        expectedMatrix.dotProduct(random, true, random, false, MathCalculator.SERIAL);
        actualMatrix.dotProduct(random, true, random, false, MathCalculator.SERIAL);
        Assert.assertTrue(equalMatrix(actualMatrix, expectedMatrix));
        actualMatrix.dotProduct(random, true, random, false, MathCalculator.PARALLEL);
        Assert.assertTrue(equalMatrix(actualMatrix, expectedMatrix));

        MathVector leftVector = random.getRowVector(RandomUtility.randomInteger(dimension));
        MathVector rightVector = random.getRowVector(RandomUtility.randomInteger(dimension));

        // Matrix x vector.
        expectedVector.dotProduct(random, false, rightVector, MathCalculator.SERIAL);
        actualVector.dotProduct(random, false, rightVector, MathCalculator.SERIAL);
        Assert.assertTrue(equalVector(actualVector, expectedVector));
        actualVector.dotProduct(random, false, rightVector, MathCalculator.PARALLEL);
        Assert.assertTrue(equalVector(actualVector, expectedVector));

        // Vector x matrix.
        expectedVector.dotProduct(leftVector, random, false, MathCalculator.SERIAL);
        actualVector.dotProduct(leftVector, random, false, MathCalculator.SERIAL);
        Assert.assertTrue(equalVector(actualVector, expectedVector));
        actualVector.dotProduct(leftVector, random, false, MathCalculator.PARALLEL);
        Assert.assertTrue(equalVector(actualVector, expectedVector));
    });
    future.get();
}
 
Example #21
Source File: MatrixTestCase.java    From jstarcraft-ai with Apache License 2.0 5 votes vote down vote up
@Test
public void testFourArithmeticOperation() throws Exception {
    EnvironmentContext environment = EnvironmentFactory.getContext();
    Future<?> future = environment.doTask(() -> {
        // Fixed seed so the random fill is reproducible.
        RandomUtility.setSeed(0L);
        int dimension = 10;
        MathMatrix matrix = getZeroMatrix(dimension);
        matrix.iterateElement(MathCalculator.SERIAL, (scalar) -> {
            scalar.setValue(RandomUtility.randomFloat(10F));
        });
        MathMatrix snapshot = getZeroMatrix(dimension);
        float originalSum = matrix.getSum(false);

        // Copy: the snapshot must carry the same sum.
        snapshot.copyMatrix(matrix, false);
        Assert.assertThat(snapshot.getSum(false), CoreMatchers.equalTo(originalSum));

        // Subtracting the snapshot must leave a zero sum.
        matrix.subtractMatrix(snapshot, false);
        Assert.assertThat(matrix.getSum(false), CoreMatchers.equalTo(0F));

        // Adding it back must restore the original sum.
        matrix.addMatrix(snapshot, false);
        Assert.assertThat(matrix.getSum(false), CoreMatchers.equalTo(originalSum));

        // Dividing by the snapshot is expected to turn every element into 1.
        matrix.divideMatrix(snapshot, false);
        Assert.assertThat(matrix.getSum(false), CoreMatchers.equalTo(matrix.getElementSize() + 0F));

        // Multiplying by the snapshot must restore the original sum again.
        matrix.multiplyMatrix(snapshot, false);
        Assert.assertThat(matrix.getSum(false), CoreMatchers.equalTo(originalSum));
    });
    future.get();
}
 
Example #22
Source File: RowHashMatrixTestCase.java    From jstarcraft-ai with Apache License 2.0 4 votes vote down vote up
@Override
public void testProduct() throws Exception {
    EnvironmentContext context = EnvironmentFactory.getContext();
    Future<?> task = context.doTask(() -> {
        int dimension = 10;
        MathMatrix leftMatrix = getRandomMatrix(dimension);
        MathMatrix rightMatrix = getRandomMatrix(dimension);
        MathMatrix dataMatrix = getZeroMatrix(dimension);
        // A dense matrix computed serially holds the expected result.
        MathMatrix markMatrix = DenseMatrix.valueOf(dimension, dimension);
        MathVector dataVector = dataMatrix.getRowVector(0);
        MathVector markVector = markMatrix.getRowVector(0);

        // Equivalent to transposeProductThis.
        markMatrix.dotProduct(leftMatrix, false, leftMatrix, true, MathCalculator.SERIAL);
        dataMatrix.dotProduct(leftMatrix, false, leftMatrix, true, MathCalculator.SERIAL);
        Assert.assertTrue(equalMatrix(dataMatrix, markMatrix));
        dataMatrix.dotProduct(leftMatrix, false, leftMatrix, true, MathCalculator.PARALLEL);
        Assert.assertTrue(equalMatrix(dataMatrix, markMatrix));

        // Equivalent to transposeProductThat.
        markMatrix.dotProduct(leftMatrix, false, rightMatrix, true, MathCalculator.SERIAL);
        dataMatrix.dotProduct(leftMatrix, false, rightMatrix, true, MathCalculator.SERIAL);
        Assert.assertTrue(equalMatrix(dataMatrix, markMatrix));
        dataMatrix.dotProduct(leftMatrix, false, rightMatrix, true, MathCalculator.PARALLEL);
        Assert.assertTrue(equalMatrix(dataMatrix, markMatrix));

        MathVector leftVector = leftMatrix.getRowVector(RandomUtility.randomInteger(dimension));
        MathVector rightVector = rightMatrix.getRowVector(RandomUtility.randomInteger(dimension));
        markMatrix.dotProduct(leftVector, rightVector, MathCalculator.SERIAL);
        dataMatrix.dotProduct(leftVector, rightVector, MathCalculator.SERIAL);
        Assert.assertTrue(equalMatrix(dataMatrix, markMatrix));
        dataMatrix.dotProduct(leftVector, rightVector, MathCalculator.PARALLEL);
        Assert.assertTrue(equalMatrix(dataMatrix, markMatrix));

        markVector.dotProduct(leftMatrix, false, rightVector, MathCalculator.SERIAL);
        dataVector.dotProduct(leftMatrix, false, rightVector, MathCalculator.SERIAL);
        Assert.assertTrue(equalVector(dataVector, markVector));
        dataVector.dotProduct(leftMatrix, false, rightVector, MathCalculator.PARALLEL);
        Assert.assertTrue(equalVector(dataVector, markVector));

        markVector.dotProduct(leftVector, rightMatrix, true, MathCalculator.SERIAL);
        dataVector.dotProduct(leftVector, rightMatrix, true, MathCalculator.SERIAL);
        Assert.assertTrue(equalVector(dataVector, markVector));
        dataVector.dotProduct(leftVector, rightMatrix, true, MathCalculator.PARALLEL);
        Assert.assertTrue(equalVector(dataVector, markVector));

        // Exploit the symmetry of the transposed product with a SymmetryMatrix target.
        dataMatrix = new SymmetryMatrix(dimension);
        markMatrix.dotProduct(leftMatrix, false, leftMatrix, true, MathCalculator.SERIAL);
        dataMatrix.dotProduct(leftMatrix, false, leftMatrix, true, MathCalculator.SERIAL);
        Assert.assertTrue(equalMatrix(dataMatrix, markMatrix));
    });
    // Wait for the task: without get(), assertion failures thrown inside the task are never
    // reported and the test passes regardless of the outcome.
    task.get();
}
 
Example #23
Source File: AssociationRuleModel.java    From jstarcraft-rns with Apache License 2.0 4 votes vote down vote up
/**
 * Builds the pairwise item association matrix.
 * <p>
 * Simple rule X =&gt; Y, where each user vector is regarded as a transaction: for item i the
 * column vector of {@code scoreMatrix} lists the users (by index) that scored it. For every
 * unordered item pair (left, right) the number of users shared by both columns is counted and
 * two confidences are stored:
 * {@code associationMatrix[left][right] = count / |left|} and
 * {@code associationMatrix[right][left] = count / |right|}. Pairs where either column is empty
 * are left untouched.
 * <p>
 * Pair computations are submitted to the {@link EnvironmentContext}; after submitting all pairs
 * for one left item the method waits on a semaphore for exactly that many completions.
 *
 * @throws ModelException if the thread is interrupted while waiting (the interrupt flag is restored)
 */
@Override
protected void doPractice() {
    EnvironmentContext context = EnvironmentContext.getContext();
    Semaphore semaphore = new Semaphore(0);
    for (int leftItemIndex = 0; leftItemIndex < itemSize; leftItemIndex++) {
        // all transactions (users) containing the left item
        SparseVector leftVector = scoreMatrix.getColumnVector(leftItemIndex);
        for (int rightItemIndex = leftItemIndex + 1; rightItemIndex < itemSize; rightItemIndex++) {
            SparseVector rightVector = scoreMatrix.getColumnVector(rightItemIndex);
            int leftIndex = leftItemIndex;
            int rightIndex = rightItemIndex;
            context.doAlgorithmByAny(leftItemIndex * rightItemIndex, () -> {
                try {
                    int leftSize = leftVector.getElementSize();
                    int rightSize = rightVector.getElementSize();
                    if (leftSize != 0 && rightSize != 0) {
                        int count = countCommonIndexes(leftVector, rightVector);
                        associationMatrix.setValue(leftIndex, rightIndex, (count + 0F) / leftSize);
                        associationMatrix.setValue(rightIndex, leftIndex, (count + 0F) / rightSize);
                    }
                } finally {
                    // Always release: a task that throws must not leave the waiting thread blocked forever.
                    semaphore.release();
                }
            });
        }
        try {
            // wait for every pair task submitted for this left item
            semaphore.acquire(itemSize - leftItemIndex - 1);
        } catch (InterruptedException exception) {
            Thread.currentThread().interrupt();
            throw new ModelException(exception);
        }
    }
}

/**
 * Counts how many element indexes occur in both vectors, using a single merge pass.
 * <p>
 * The merge assumes both vectors iterate their elements in ascending index order (two sorted
 * sequences). Both vectors must be non-empty, because the first element of each is read
 * unconditionally.
 *
 * @param leftVector  non-empty sparse vector
 * @param rightVector non-empty sparse vector
 * @return number of indexes present in both vectors
 */
private static int countCommonIndexes(SparseVector leftVector, SparseVector rightVector) {
    int leftSize = leftVector.getElementSize();
    int rightSize = rightVector.getElementSize();
    Iterator<VectorScalar> leftIterator = leftVector.iterator();
    Iterator<VectorScalar> rightIterator = rightVector.iterator();
    VectorScalar leftTerm = leftIterator.next();
    VectorScalar rightTerm = rightIterator.next();
    int leftCursor = 0;
    int rightCursor = 0;
    int count = 0;
    // The cursors, not hasNext(), bound the loop: when an iterator is exhausted its term is kept
    // but its cursor still advances, which terminates the merge.
    while (leftCursor < leftSize && rightCursor < rightSize) {
        int leftPosition = leftTerm.getIndex();
        int rightPosition = rightTerm.getIndex();
        if (leftPosition == rightPosition) {
            count++;
            if (leftIterator.hasNext()) {
                leftTerm = leftIterator.next();
            }
            if (rightIterator.hasNext()) {
                rightTerm = rightIterator.next();
            }
            leftCursor++;
            rightCursor++;
        } else if (leftPosition > rightPosition) {
            if (rightIterator.hasNext()) {
                rightTerm = rightIterator.next();
            }
            rightCursor++;
        } else {
            if (leftIterator.hasNext()) {
                leftTerm = leftIterator.next();
            }
            leftCursor++;
        }
    }
    return count;
}
 
Example #24
Source File: GraphTestCase.java, from jstarcraft-ai (Apache License 2.0), 4 votes
/**
 * Trains a reference DL4J {@code ComputationGraph} and the equivalent jstarcraft {@code Graph}
 * on the same small two-input data set, then checks that both graphs predict the same outputs.
 */
@Test
public void testPropagate() throws Exception {
    MathCache cache = new Nd4jCache();

    EnvironmentContext environment = EnvironmentFactory.getContext();
    Future<?> future = environment.doTask(() -> {
        ComputationGraph referenceGraph = getOldFunction();
        Graph candidateGraph = getNewFunction(cache, referenceGraph);

        // size x 1 data set: random integer inputs for both branches, every label set to 5.
        int size = 5;
        INDArray leftInputs = Nd4j.zeros(size, 1);
        INDArray rightInputs = Nd4j.zeros(size, 1);
        INDArray marks = Nd4j.zeros(size, 1).assign(5);
        for (int row = 0; row < size; row++) {
            leftInputs.put(row, 0, RandomUtility.randomInteger(size));
            rightInputs.put(row, 0, RandomUtility.randomInteger(size));
        }

        // Reference training: 50 rounds, each running fit() twice on the same inputs and labels.
        for (int round = 0; round < 50; round++) {
            referenceGraph.setInputs(leftInputs, rightInputs);
            referenceGraph.setLabels(marks);
            for (int iteration = 0; iteration < 2; iteration++) {
                referenceGraph.fit();
                double referenceScore = referenceGraph.score();
                System.out.println(referenceScore);
            }
        }
        INDArray referenceOutputs = referenceGraph.outputSingle(leftInputs, rightInputs);
        System.out.println(referenceOutputs);

        // Bind this thread to device 0, then convert the ND4J arrays to jstarcraft matrices.
        Nd4j.getAffinityManager().attachThreadToDevice(Thread.currentThread(), 0);
        MathMatrix candidateLeftInputs = getMatrix(cache, leftInputs);
        MathMatrix candidateRightInputs = getMatrix(cache, rightInputs);
        MathMatrix candidateMarks = getMatrix(cache, marks);
        MathMatrix candidateOutputs = getMatrix(cache, referenceOutputs);

        // Candidate training: 50 calls of practice(2, ...), mirroring the reference schedule.
        for (int round = 0; round < 50; round++) {
            double candidateScore = candidateGraph.practice(2, new MathMatrix[] { candidateLeftInputs, candidateRightInputs }, new MathMatrix[] { candidateMarks });
            System.out.println(candidateScore);
        }

        // Predict into the buffer created from the reference outputs, then compare with them.
        candidateGraph.predict(new MathMatrix[] { candidateLeftInputs, candidateRightInputs }, new MathMatrix[] { candidateOutputs });
        System.out.println(candidateOutputs);
        Assert.assertTrue(equalMatrix(candidateOutputs, referenceOutputs));
    });
    future.get();
}
 
Example #25
Source File: ColumnArrayMatrixTestCase.java, from jstarcraft-ai (Apache License 2.0), 4 votes
/**
 * Checks the products of the matrix under test (obtained from {@code getZeroMatrix}).
 * <p>
 * Each product is first computed serially into an expected {@code DenseMatrix} or vector. It is
 * then computed into the matrix under test, first serially and then in parallel, and the
 * result is compared with the expected value after each computation.
 */
@Override
public void testProduct() throws Exception {
    EnvironmentContext environment = EnvironmentFactory.getContext();
    Future<?> future = environment.doTask(() -> {
        int size = 10;
        MathMatrix left = getRandomMatrix(size);
        MathMatrix right = getRandomMatrix(size);
        MathMatrix actualMatrix = getZeroMatrix(size);
        MathMatrix expectedMatrix = DenseMatrix.valueOf(size, size);
        // first column of each result matrix, used as the destination of matrix-vector products
        MathVector actualVector = actualMatrix.getColumnVector(0);
        MathVector expectedVector = expectedMatrix.getColumnVector(0);

        // left^T * left (equivalent to transposeProductThis)
        expectedMatrix.dotProduct(left, true, left, false, MathCalculator.SERIAL);
        actualMatrix.dotProduct(left, true, left, false, MathCalculator.SERIAL);
        Assert.assertTrue(equalMatrix(actualMatrix, expectedMatrix));
        actualMatrix.dotProduct(left, true, left, false, MathCalculator.PARALLEL);
        Assert.assertTrue(equalMatrix(actualMatrix, expectedMatrix));

        // left^T * right (equivalent to transposeProductThat)
        expectedMatrix.dotProduct(left, true, right, false, MathCalculator.SERIAL);
        actualMatrix.dotProduct(left, true, right, false, MathCalculator.SERIAL);
        Assert.assertTrue(equalMatrix(actualMatrix, expectedMatrix));
        actualMatrix.dotProduct(left, true, right, false, MathCalculator.PARALLEL);
        Assert.assertTrue(equalMatrix(actualMatrix, expectedMatrix));

        // outer product of two randomly chosen columns
        MathVector leftColumn = left.getColumnVector(RandomUtility.randomInteger(size));
        MathVector rightColumn = right.getColumnVector(RandomUtility.randomInteger(size));
        expectedMatrix.dotProduct(leftColumn, rightColumn, MathCalculator.SERIAL);
        actualMatrix.dotProduct(leftColumn, rightColumn, MathCalculator.SERIAL);
        Assert.assertTrue(equalMatrix(actualMatrix, expectedMatrix));
        actualMatrix.dotProduct(leftColumn, rightColumn, MathCalculator.PARALLEL);
        Assert.assertTrue(equalMatrix(actualMatrix, expectedMatrix));

        // left^T * rightColumn
        expectedVector.dotProduct(left, true, rightColumn, MathCalculator.SERIAL);
        actualVector.dotProduct(left, true, rightColumn, MathCalculator.SERIAL);
        Assert.assertTrue(equalVector(actualVector, expectedVector));
        actualVector.dotProduct(left, true, rightColumn, MathCalculator.PARALLEL);
        Assert.assertTrue(equalVector(actualVector, expectedVector));

        // leftColumn * right
        expectedVector.dotProduct(leftColumn, right, false, MathCalculator.SERIAL);
        actualVector.dotProduct(leftColumn, right, false, MathCalculator.SERIAL);
        Assert.assertTrue(equalVector(actualVector, expectedVector));
        actualVector.dotProduct(leftColumn, right, false, MathCalculator.PARALLEL);
        Assert.assertTrue(equalVector(actualVector, expectedVector));

        // exploit the symmetry of the transpose product: left^T * left into a SymmetryMatrix
        actualMatrix = new SymmetryMatrix(size);
        expectedMatrix.dotProduct(left, true, left, false, MathCalculator.SERIAL);
        actualMatrix.dotProduct(left, true, left, false, MathCalculator.SERIAL);
        Assert.assertTrue(equalMatrix(actualMatrix, expectedMatrix));
    });
    future.get();
}
 
Example #26
Source File: RandomLayer.java, from jstarcraft-ai (Apache License 2.0), 4 votes
/**
 * Forward pass over rows that each carry a variable number of input groups.
 * <p>
 * Column 0 of every input row holds the number of groups in that row; it is copied unchanged to
 * column 0 of the output. For each group {@code c} of a row:
 * <ul>
 * <li>input slice: {@code [c * numberOfInputs + 1, (c + 1) * numberOfInputs + 1)}</li>
 * <li>middle slice: {@code [c * numberOfOutputs, (c + 1) * numberOfOutputs)}, computed as
 * input slice x weights (+ bias row 0 when a bias parameter exists)</li>
 * <li>output slice: {@code [c * numberOfOutputs + 1, (c + 1) * numberOfOutputs + 1)}, the
 * activation of the middle slice</li>
 * </ul>
 * Rows are processed concurrently through the {@link EnvironmentContext}; the method waits for
 * all rows and then zeroes the middle and output error matrices.
 *
 * @throws RuntimeException if the thread is interrupted while waiting (the interrupt flag is restored)
 */
@Override
public void doForward() {
    MathMatrix weightParameters = parameters.get(WEIGHT_KEY);
    MathMatrix biasParameters = parameters.get(BIAS_KEY);

    MathMatrix inputData = inputKeyValue.getKey();
    MathMatrix middleData = middleKeyValue.getKey();
    middleData.setValues(0F);
    MathMatrix outputData = outputKeyValue.getKey();
    outputData.setValues(0F);
    // carry the per-row group count (column 0) over to the output
    outputData.getColumnVector(0).copyVector(inputData.getColumnVector(0));
    int numberOfRows = inputData.getRowSize();
    EnvironmentContext context = EnvironmentContext.getContext();
    CountDownLatch latch = new CountDownLatch(numberOfRows);
    for (int rowIndex = 0; rowIndex < numberOfRows; rowIndex++) {
        MathVector inputMajorData = inputData.getRowVector(rowIndex);
        MathVector middleMajorData = middleData.getRowVector(rowIndex);
        MathVector outputMajorData = outputData.getRowVector(rowIndex);
        int numberOfColumns = (int) inputMajorData.getValue(0);
        context.doStructureByAny(rowIndex, () -> {
            try {
                for (int columnIndex = 0; columnIndex < numberOfColumns; columnIndex++) {
                    // input/output slices are shifted by 1 to skip the group count stored in column 0
                    MathVector inputMinorData = GlobalVector.detachOf(GlobalVector.class.cast(inputMajorData), columnIndex * numberOfInputs + 1, (columnIndex + 1) * numberOfInputs + 1);
                    MathVector middleMinorData = GlobalVector.detachOf(GlobalVector.class.cast(middleMajorData), columnIndex * numberOfOutputs, (columnIndex + 1) * numberOfOutputs);
                    MathVector outputMinorData = GlobalVector.detachOf(GlobalVector.class.cast(outputMajorData), columnIndex * numberOfOutputs + 1, (columnIndex + 1) * numberOfOutputs + 1);
                    middleMinorData.dotProduct(inputMinorData, weightParameters, false, MathCalculator.SERIAL);
                    if (biasParameters != null) {
                        middleMinorData.iterateElement(MathCalculator.SERIAL, (scalar) -> {
                            int index = scalar.getIndex();
                            float value = scalar.getValue();
                            scalar.setValue(value + biasParameters.getValue(0, index));
                        });
                    }
                    // Activate this group's own middle slice (not the whole middle row) into its output slice.
                    function.forward(middleMinorData, outputMinorData);
                }
            } finally {
                // Always count down, so a failing row cannot block latch.await() forever.
                latch.countDown();
            }
        });
    }
    try {
        latch.await();
    } catch (InterruptedException exception) {
        Thread.currentThread().interrupt();
        throw new RuntimeException(exception);
    }

    MathMatrix middleError = middleKeyValue.getValue();
    middleError.setValues(0F);

    MathMatrix innerError = outputKeyValue.getValue();
    innerError.setValues(0F);
}
 
Example #27
Source File: MatrixUtilityTestCase.java, from jstarcraft-ai (Apache License 2.0), 4 votes
/**
 * Tests {@code SymmetryMatrix} element storage and several {@code MatrixUtility} operations on a
 * symmetric matrix: transpose invariance, Cholesky decomposition, covariance and inversion.
 * <p>
 * NOTE(review): the scratch matrices {@code left}, {@code right}, {@code transpose}, {@code copy}
 * and {@code inverse} are overwritten by successive calls, so the order of statements matters.
 */
@Test
public void testSymmetryMatrix() throws Exception {
    EnvironmentContext context = EnvironmentFactory.getContext();
    Future<?> task = context.doTask(() -> {
        // A 5x5 SymmetryMatrix reports and iterates 15 elements (5 * 6 / 2), presumably one triangle including the diagonal.
        SymmetryMatrix matrix = new SymmetryMatrix(5);
        Assert.assertThat(matrix.getElementSize(), CoreMatchers.equalTo(15));
        int times = 0;
        // Write 1..15 through the iterator and check each value reads back via the scalar and via getValue(row, column).
        for (MatrixScalar term : matrix) {
            times++;
            term.setValue(times);
            Assert.assertTrue(term.getValue() == times);
            Assert.assertTrue(matrix.getValue(term.getRow(), term.getColumn()) == times);
        }
        Assert.assertThat(times, CoreMatchers.equalTo(15));

        // A symmetric matrix is unchanged by transposition, no matter how many times it is transposed.
        // Entries not set below ((1, 2) and (2, 1)) stay 0.
        DenseMatrix symmetry = DenseMatrix.valueOf(3, 3);
        symmetry.setValue(0, 0, 25F);
        symmetry.setValue(0, 1, 15F);
        symmetry.setValue(1, 0, 15F);
        symmetry.setValue(1, 1, 18F);
        symmetry.setValue(0, 2, -5F);
        symmetry.setValue(2, 0, -5F);
        symmetry.setValue(2, 2, 11F);
        DenseMatrix transpose = DenseMatrix.valueOf(symmetry.getColumnSize(), symmetry.getRowSize());
        Assert.assertThat(transpose.copyMatrix(symmetry, true), CoreMatchers.equalTo(symmetry));

        // Scratch destinations for the dotProduct calls below.
        DenseMatrix left = DenseMatrix.valueOf(symmetry.getRowSize(), symmetry.getColumnSize());
        DenseMatrix right = DenseMatrix.valueOf(symmetry.getRowSize(), symmetry.getColumnSize());

        // Cholesky decomposition: http://www.qiujiawei.com/linear-algebra-11/
        // (Cholesky decomposition factors A as A = L * L^T, where L is lower triangular.)
        // For this matrix L has integer entries (5; 3, 3; -1, 1, 3), which is why the exact equalTo comparison can hold.
        DenseMatrix cholesky = DenseMatrix.valueOf(3, 3);
        MatrixUtility.cholesky(symmetry, cholesky);
        Assert.assertThat(left.dotProduct(cholesky, false, transpose.copyMatrix(cholesky, true), false, MathCalculator.SERIAL), CoreMatchers.equalTo(symmetry));

        // Covariance matrix (inherently symmetric), computed here from the Cholesky factor; it must equal its own transpose.
        DenseVector outerMeans = DenseVector.valueOf(3);
        DenseVector innerMeans = DenseVector.valueOf(3);
        DenseMatrix covariance = DenseMatrix.valueOf(3, 3);
        MatrixUtility.covariance(cholesky, outerMeans, innerMeans, covariance);
        Assert.assertThat(transpose.copyMatrix(covariance, true), CoreMatchers.equalTo(covariance));

        // The inverse B of A satisfies AB = BA = E (E is the identity: ones on the diagonal, zeros elsewhere).
        // Because of floating-point error, the accessor checks every element is approximately the identity value
        // (failing otherwise) and then snaps it to exactly 1 or 0 so the two products can be compared with equalTo.
        // NOTE(review): the accessor runs under MathCalculator.PARALLEL; confirm an AssertionError raised on a worker thread reaches this thread.
        MathAccessor<MatrixScalar> accessor = (scalar) -> {
            int row = scalar.getRow();
            int column = scalar.getColumn();
            float value = scalar.getValue();
            if (row == column) {
                if (!MathUtility.equal(value, 1F)) {
                    System.err.println(value);
                    Assert.fail();
                }
                scalar.setValue(1F);
            } else {
                if (!MathUtility.equal(value, 0F)) {
                    System.err.println(value);
                    Assert.fail();
                }
                scalar.setValue(0F);
            }
        };
        // `copy` is passed as an extra working matrix to MatrixUtility.inverse (presumably scratch space — TODO confirm).
        DenseMatrix inverse = DenseMatrix.valueOf(3, 3);
        DenseMatrix copy = DenseMatrix.valueOf(3, 3);
        MatrixUtility.inverse(symmetry, copy, inverse);
        Assert.assertThat(left.dotProduct(symmetry, false, inverse, false, MathCalculator.SERIAL).iterateElement(MathCalculator.PARALLEL, accessor), CoreMatchers.equalTo(right.dotProduct(inverse, false, symmetry, false, MathCalculator.SERIAL).iterateElement(MathCalculator.PARALLEL, accessor)));

        inverse = MatrixUtility.inverse(cholesky, copy, inverse);
        Assert.assertThat(left.dotProduct(cholesky, false, inverse, false, MathCalculator.SERIAL).iterateElement(MathCalculator.PARALLEL, accessor), CoreMatchers.equalTo(right.dotProduct(inverse, false, cholesky, false, MathCalculator.SERIAL).iterateElement(MathCalculator.PARALLEL, accessor)));

        inverse = MatrixUtility.inverse(covariance, copy, inverse);
        Assert.assertThat(left.dotProduct(covariance, false, inverse, false, MathCalculator.SERIAL).iterateElement(MathCalculator.PARALLEL, accessor), CoreMatchers.equalTo(right.dotProduct(inverse, false, covariance, false, MathCalculator.SERIAL).iterateElement(MathCalculator.PARALLEL, accessor)));
    });
    task.get();
}
 
Example #28
Source File: MatrixUtilityTestCase.java, from jstarcraft-ai (Apache License 2.0), 4 votes
/**
 * Tests singular value decomposition: the product U * diag(S) * V^T must reproduce the original
 * matrix element by element, within {@code MathUtility.equal} tolerance.
 */
@Test
public void testSVD() throws Exception {
    EnvironmentContext environment = EnvironmentFactory.getContext();
    Future<?> future = environment.doTask(() -> {
        // Note: the matrix must have at least as many rows as columns.
        int rowSize = 5;
        int columnSize = 3;
        // Fill the matrix with 0, 1, 2, ... in SERIAL iteration order.
        AtomicInteger sequence = new AtomicInteger();
        DenseMatrix original = DenseMatrix.valueOf(rowSize, columnSize);
        original.iterateElement(MathCalculator.SERIAL, (scalar) -> scalar.setValue(sequence.getAndIncrement()));

        SingularValueDecomposition decomposition = new SingularValueDecomposition(original);
        System.out.println("print matrix U:");
        System.out.println(decomposition.getU().toString());

        System.out.println("print vector S:");
        System.out.println(decomposition.getS().toString());

        System.out.println("print matrix V:");
        System.out.println(decomposition.getV().toString());

        // Expand the singular values into a square diagonal matrix.
        // TODO switch to a sparse/hash based matrix
        DenseVector singularValues = decomposition.getS();
        int singularSize = singularValues.getElementSize();
        DenseMatrix diagonal = DenseMatrix.valueOf(singularSize, singularSize);
        for (int index = 0; index < singularSize; index++) {
            diagonal.setValue(index, index, singularValues.getValue(index));
        }

        // reconstructed = (U * diag(S)) * V^T
        DenseMatrix middle = DenseMatrix.valueOf(rowSize, columnSize);
        DenseMatrix transpose = DenseMatrix.valueOf(decomposition.getV().getColumnSize(), decomposition.getV().getRowSize());
        DenseMatrix reconstructed = DenseMatrix.valueOf(rowSize, columnSize);
        middle.dotProduct(decomposition.getU(), false, diagonal, false, MathCalculator.SERIAL);
        reconstructed.dotProduct(middle, false, transpose.copyMatrix(decomposition.getV(), true), false, MathCalculator.SERIAL);

        System.out.println("print matrix OLD:");
        System.out.println(original.toString());
        System.out.println("print matrix NEW:");
        System.out.println(reconstructed.toString());

        // The computation loses precision, so compare element-wise with MathUtility.equal instead of DenseMatrix.equals.
        for (int row = 0; row < rowSize; row++) {
            for (int column = 0; column < columnSize; column++) {
                float expected = original.getValue(row, column);
                float actual = reconstructed.getValue(row, column);
                if (MathUtility.equal(expected, actual)) {
                    continue;
                }
                System.out.println(expected);
                System.out.println(actual);
                Assert.fail();
            }
        }
    });
    future.get();
}
 
Example #29
Source File: MatrixUtilityTestCase.java, from jstarcraft-ai (Apache License 2.0), 4 votes
/**
 * Tests {@code MatrixUtility.pseudoinverse} on a random 4x3 matrix.
 * <p>
 * The pseudoinverse X of A must have the shape of A^T and satisfy AXA = A and XAX = X. These
 * identities are verified by checking that every element of the residuals AXA - A and
 * XAX - X has magnitude below {@code MathUtility.EPSILON}.
 */
@Test
public void testPseudoinverse() throws Exception {
    EnvironmentContext context = EnvironmentFactory.getContext();
    Future<?> task = context.doTask(() -> {
        // random 4x3 input matrix
        DenseMatrix matrix = DenseMatrix.valueOf(4, 3);
        matrix.iterateElement(MathCalculator.SERIAL, (scalar) -> {
            scalar.setValue(probability.sample().floatValue());
        });
        System.out.println("print matrix Matrix:");
        System.out.println(matrix.toString());
        DenseMatrix singular = DenseMatrix.valueOf(matrix.getColumnSize(), matrix.getColumnSize());
        DenseMatrix pseudoinverse = DenseMatrix.valueOf(matrix.getColumnSize(), matrix.getRowSize());
        MatrixUtility.pseudoinverse(matrix, singular, pseudoinverse);
        System.out.println("print pseudoinverse Matrix:");
        System.out.println(pseudoinverse.toString());

        // X must have the same row/column sizes as A^T.
        Assert.assertThat(pseudoinverse.getRowSize(), CoreMatchers.equalTo(matrix.getColumnSize()));
        Assert.assertThat(pseudoinverse.getColumnSize(), CoreMatchers.equalTo(matrix.getRowSize()));

        // Because of floating-point error the residuals are not exactly 0; each element must be within EPSILON of 0.
        MathAccessor<MatrixScalar> accessor = (scalar) -> {
            float residual = scalar.getValue();
            // Compare the magnitude so that large negative residuals fail just like large positive ones.
            if (Math.abs(residual) >= MathUtility.EPSILON) {
                Assert.fail("residual " + residual + " at (" + scalar.getRow() + ", " + scalar.getColumn() + ") is not within EPSILON of 0");
            }
        };
        DenseMatrix left;
        DenseMatrix right;

        // AXA - A ~ 0
        left = DenseMatrix.valueOf(matrix.getRowSize(), pseudoinverse.getColumnSize());
        right = DenseMatrix.valueOf(matrix.getRowSize(), matrix.getColumnSize());
        left.dotProduct(matrix, false, pseudoinverse, false, MathCalculator.SERIAL);
        right.dotProduct(left, false, matrix, false, MathCalculator.SERIAL);
        System.out.println(right);
        right.subtractMatrix(matrix, false).iterateElement(MathCalculator.PARALLEL, accessor);

        // XAX - X ~ 0
        left = DenseMatrix.valueOf(pseudoinverse.getRowSize(), matrix.getColumnSize());
        right = DenseMatrix.valueOf(pseudoinverse.getRowSize(), pseudoinverse.getColumnSize());
        left.dotProduct(pseudoinverse, false, matrix, false, MathCalculator.SERIAL);
        right.dotProduct(left, false, pseudoinverse, false, MathCalculator.SERIAL);
        System.out.println(right);
        right.subtractMatrix(pseudoinverse, false).iterateElement(MathCalculator.PARALLEL, accessor);
    });
    task.get();
}
 
Example #30
Source File: MatrixUtilityTestCase.java, from jstarcraft-ai (Apache License 2.0), 4 votes
/**
 * Tests transposed matrix multiplication.
 * <p>
 * {@code dotProduct(matrix, true, other, false)} must equal the product of an explicitly
 * transposed copy of {@code matrix} with {@code other}. This is checked for both a dense and a
 * sparse right-hand operand holding the same values.
 */
@Test
public void testTransposeProduct() throws Exception {
    EnvironmentContext environment = EnvironmentFactory.getContext();
    Future<?> future = environment.doTask(() -> {
        DenseMatrix matrix = DenseMatrix.valueOf(4, 3);
        matrix.iterateElement(MathCalculator.SERIAL, (scalar) -> scalar.setValue(probability.sample().floatValue()));

        System.out.println("print matrix Matrix:");
        System.out.println(matrix.toString());

        // Positions of the stored entries of the 4x4 sparse operand, filled in this order with random values.
        int[][] positions = { { 0, 0 }, { 0, 1 }, { 0, 3 }, { 1, 0 }, { 2, 0 }, { 2, 1 }, { 3, 1 }, { 3, 3 } };
        Table<Integer, Integer, Float> dataTable = HashBasedTable.create();
        for (int[] position : positions) {
            dataTable.put(position[0], position[1], RandomUtility.randomFloat(1F));
        }
        // sparse operand
        SparseMatrix sparse = SparseMatrix.valueOf(4, 4, dataTable);
        // dense operand holding the same values
        DenseMatrix dense = DenseMatrix.valueOf(4, 4);
        for (MatrixScalar term : sparse) {
            dense.setValue(term.getRow(), term.getColumn(), term.getValue());
        }

        DenseMatrix transposed = DenseMatrix.valueOf(matrix.getColumnSize(), dense.getRowSize());
        DenseMatrix left = DenseMatrix.valueOf(matrix.getColumnSize(), dense.getColumnSize());
        DenseMatrix right = DenseMatrix.valueOf(matrix.getColumnSize(), dense.getColumnSize());
        // matrix^T * dense versus copy(matrix^T) * dense
        Assert.assertThat(left.dotProduct(matrix, true, dense, false, MathCalculator.SERIAL), CoreMatchers.equalTo(right.dotProduct(transposed.copyMatrix(matrix, true), false, dense, false, MathCalculator.SERIAL)));
        // matrix^T * sparse versus copy(matrix^T) * dense
        Assert.assertThat(left.dotProduct(matrix, true, sparse, false, MathCalculator.SERIAL), CoreMatchers.equalTo(right.dotProduct(transposed.copyMatrix(matrix, true), false, dense, false, MathCalculator.SERIAL)));
        // recompute each side in place and compare again in both directions
        left.dotProduct(matrix, true, sparse, false, MathCalculator.SERIAL);
        Assert.assertThat(left, CoreMatchers.equalTo(right));
        right.dotProduct(matrix, true, dense, false, MathCalculator.SERIAL);
        Assert.assertThat(right, CoreMatchers.equalTo(left));
    });
    future.get();
}