Java Code Examples for org.nd4j.linalg.checkutil.NDArrayCreationUtil#getAllTestMatricesWithShape()

The following examples show how to use org.nd4j.linalg.checkutil.NDArrayCreationUtil#getAllTestMatricesWithShape(). You can vote up the examples you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also check out the related API usage in the sidebar.
Example 1
Source File: ElementWiseStrideTests.java (from nd4j, Apache License 2.0)
/**
 * Diagnostic scan that asserts nothing: for every 2d-6d test array whose element-wise stride
 * (EWS) is reported as -1, attempt a no-copy reshape to shape [1, length] and, if that reshape
 * yields an array with a defined EWS, print the array's details for inspection.
 */
@Test
public void testEWS1() throws Exception {
    final int seed = 12345;
    List<Pair<INDArray, String>> candidates = NDArrayCreationUtil.getAllTestMatricesWithShape(4, 5, seed);
    candidates.addAll(NDArrayCreationUtil.getAll3dTestArraysWithShape(seed, 4, 5, 6));
    candidates.addAll(NDArrayCreationUtil.getAll4dTestArraysWithShape(seed, 4, 5, 6, 7));
    candidates.addAll(NDArrayCreationUtil.getAll5dTestArraysWithShape(seed, 4, 5, 6, 7, 8));
    candidates.addAll(NDArrayCreationUtil.getAll6dTestArraysWithShape(seed, 4, 5, 6, 7, 8, 9));

    for (Pair<INDArray, String> entry : candidates) {
        INDArray arr = entry.getFirst();
        int strideBefore = Shape.elementWiseStride(arr.shapeInfo());
        // newShapeNoCopy may return null (checked below), presumably when no view is possible.
        INDArray rowView = Shape.newShapeNoCopy(arr, new int[]{1, arr.length()}, Nd4j.order() == 'f');

        // Only report arrays that had no EWS but gained one through the no-copy reshape.
        boolean gainedStride = rowView != null && strideBefore == -1 && rowView.elementWiseStride() != -1;
        if (!gainedStride) {
            continue;
        }

        System.out.println("NDArrayCreationUtil." + entry.getSecond());
        System.out.println("ews before: " + strideBefore);
        System.out.println(arr.shapeInfoToString());
        System.out.println("ews returned by elementWiseStride(): " + arr.elementWiseStride());
        System.out.println("ews returned by reshape(): " + rowView.elementWiseStride());
        System.out.println();
    }
}
 
Example 2
Source File: SameDiffTests.java (from nd4j, Apache License 2.0)
/**
 * Verifies that {@code expandDims} followed by {@code squeeze} on the same axis reproduces the
 * input, for every 3x4 test-matrix layout from {@link NDArrayCreationUtil} and for each axis 0..2.
 */
@Test
public void testExpandSqueezeChain() {

    val origShape = new long[]{3, 4};

    for (int i = 0; i < 3; i++) {
        for (Pair<INDArray, String> p : NDArrayCreationUtil.getAllTestMatricesWithShape(origShape[0], origShape[1], 12345)) {
            // muli scales the test matrix in place by 100.
            INDArray inArr = p.getFirst().muli(100);

            SameDiff sd = SameDiff.create();
            SDVariable in = sd.var("in", inArr);
            SDVariable expand = sd.expandDims(in, i);
            // Last op added to the graph; execAndEndResult() is expected to return its output.
            SDVariable squeeze = sd.squeeze(expand, i);

            INDArray out = sd.execAndEndResult();

            String msg = "expand/Squeeze=" + i + ", source=" + p.getSecond();

            // JUnit convention is assertEquals(message, expected, actual): the input is expected.
            assertEquals(msg, inArr, out);  //expand -> squeeze: should be opposite ops
        }
    }
}
 
Example 3
Source File: TestInvertMatrices.java (from nd4j, Apache License 2.0)
/**
 * Cross-checks {@link InvertMatrix#invert} against the inverse computed by Apache Commons Math's
 * {@link LUDecomposition} for every 10x10 test-matrix layout. Each test matrix is first filled
 * in place with random values, so the layout under test is preserved.
 */
@Test
public void testInverseComparison() {
    for (Pair<INDArray, String> testCase : NDArrayCreationUtil.getAllTestMatricesWithShape(10, 10, 12345)) {
        INDArray matrix = testCase.getFirst();
        matrix.assign(Nd4j.rand(matrix.shape()));

        INDArray actualInverse = InvertMatrix.invert(matrix, false);

        RealMatrix apacheInverse = new LUDecomposition(CheckUtil.convertToApacheMatrix(matrix))
                .getSolver()
                .getInverse();
        INDArray expectedInverse = CheckUtil.convertFromApacheMatrix(apacheInverse);

        assertTrue(testCase.getSecond(),
                CheckUtil.checkEntries(expectedInverse, actualInverse, 1e-3, 1e-4));
    }
}
 
Example 4
Source File: TestInvertMatrices.java (from deeplearning4j, Apache License 2.0)
/**
 * Compares {@link InvertMatrix#invert} with the inverse from Apache Commons Math's
 * {@link LUDecomposition} for every 10x10 DOUBLE test-matrix layout. The reference result is
 * converted back using the test matrix's own data type before entries are compared.
 */
@Test
public void testInverseComparison() {
    final List<Pair<INDArray, String>> cases =
            NDArrayCreationUtil.getAllTestMatricesWithShape(10, 10, 12345, DataType.DOUBLE);

    for (int idx = 0; idx < cases.size(); idx++) {
        final Pair<INDArray, String> current = cases.get(idx);
        final INDArray source = current.getFirst();
        // Fill in place so the view/ordering produced by NDArrayCreationUtil is what gets inverted.
        source.assign(Nd4j.rand(source.shape()));

        final INDArray nd4jInverse = InvertMatrix.invert(source, false);
        final RealMatrix reference = new LUDecomposition(CheckUtil.convertToApacheMatrix(source))
                .getSolver()
                .getInverse();
        final INDArray referenceAsNd4j = CheckUtil.convertFromApacheMatrix(reference, source.dataType());

        assertTrue(current.getSecond(),
                CheckUtil.checkEntries(referenceAsNd4j, nd4jInverse, 1e-3, 1e-4));
    }
}
 
Example 5
Source File: NDArrayTestsFortran.java (from deeplearning4j, Apache License 2.0)
/**
 * Checks that {@code dup()}, {@code dup('c')} and {@code dup('f')} return copies equal to the
 * source for every 4x5 DOUBLE test-matrix layout. Each copy must also have the requested
 * ordering; {@code dup()} must use {@link Nd4j#order()}.
 */
@Test
public void testDupAndDupWithOrder() {
    List<Pair<INDArray, String>> testInputs = NDArrayCreationUtil.getAllTestMatricesWithShape(4, 5, 123, DataType.DOUBLE);
    for (Pair<INDArray, String> pair : testInputs) {
        String msg = pair.getSecond();
        INDArray in = pair.getFirst();
        INDArray dup = in.dup();
        INDArray dupc = in.dup('c');
        INDArray dupf = in.dup('f');

        // JUnit convention: assertEquals(message, expected, actual).
        assertEquals(msg, in, dup);
        assertEquals(msg, (char) Nd4j.order(), dup.ordering());
        assertEquals(msg, 'c', dupc.ordering());
        assertEquals(msg, 'f', dupf.ordering());
        assertEquals(msg, in, dupc);
        assertEquals(msg, in, dupf);
    }
}
 
Example 6
Source File: Nd4jTestsComparisonFortran.java (from nd4j, Apache License 2.0)
/**
 * Runs {@link CheckUtil#checkAdd} and {@link CheckUtil#checkSubtract} on every pairing of 3x5
 * test-matrix layouts. Both checks return {@code true} on success.
 */
@Test
public void testAddSubtractWithOpsCommonsMath() {
    List<Pair<INDArray, String>> lhsCases = NDArrayCreationUtil.getAllTestMatricesWithShape(3, 5, SEED);
    List<Pair<INDArray, String>> rhsCases = NDArrayCreationUtil.getAllTestMatricesWithShape(3, 5, SEED);
    for (int a = 0; a < lhsCases.size(); a++) {
        Pair<INDArray, String> lhs = lhsCases.get(a);
        for (int b = 0; b < rhsCases.size(); b++) {
            Pair<INDArray, String> rhs = rhsCases.get(b);
            String addMsg = getTestWithOpsErrorMsg(a, b, "add", lhs, rhs);
            String subMsg = getTestWithOpsErrorMsg(a, b, "sub", lhs, rhs);

            boolean addOk = CheckUtil.checkAdd(lhs.getFirst(), rhs.getFirst(), 1e-4, 1e-6);
            assertTrue(addMsg, addOk);

            boolean subOk = CheckUtil.checkSubtract(lhs.getFirst(), rhs.getFirst(), 1e-4, 1e-6);
            assertTrue(subMsg, subOk);
        }
    }
}
 
Example 7
Source File: Nd4jTestsComparisonFortran.java (from deeplearning4j, Apache License 2.0)
/**
 * Validates element-wise addition and subtraction for every pairing of 3x5 DOUBLE test-matrix
 * layouts, using {@link CheckUtil#checkAdd} and {@link CheckUtil#checkSubtract} (true = pass).
 */
@Test
public void testAddSubtractWithOpsCommonsMath() {
    final List<Pair<INDArray, String>> lefts =
            NDArrayCreationUtil.getAllTestMatricesWithShape(3, 5, SEED, DataType.DOUBLE);
    final List<Pair<INDArray, String>> rights =
            NDArrayCreationUtil.getAllTestMatricesWithShape(3, 5, SEED, DataType.DOUBLE);

    for (int li = 0; li < lefts.size(); li++) {
        for (int ri = 0; ri < rights.size(); ri++) {
            final Pair<INDArray, String> left = lefts.get(li);
            final Pair<INDArray, String> right = rights.get(ri);
            final String addMessage = getTestWithOpsErrorMsg(li, ri, "add", left, right);
            final String subMessage = getTestWithOpsErrorMsg(li, ri, "sub", left, right);
            assertTrue(addMessage, CheckUtil.checkAdd(left.getFirst(), right.getFirst(), 1e-4, 1e-6));
            assertTrue(subMessage, CheckUtil.checkSubtract(left.getFirst(), right.getFirst(), 1e-4, 1e-6));
        }
    }
}
 
Example 8
Source File: ReductionOpValidation.java (from deeplearning4j, Apache License 2.0)
/**
 * Validates SameDiff's {@code std} against {@code INDArray.std(boolean)} for every 3x4 DOUBLE
 * test-matrix layout, with and without bias correction. All validation failures are collected
 * and reported together in a single assertion at the end.
 */
@Test
public void testStdev() {
    List<String> failures = new ArrayList<>();

    List<Pair<INDArray, String>> inputs = NDArrayCreationUtil.getAllTestMatricesWithShape(3, 4, 12345, DataType.DOUBLE);
    for (Pair<INDArray, String> input : inputs) {
        INDArray arr = input.getFirst();
        for (boolean bias : new boolean[]{false, true}) {
            SameDiff graph = SameDiff.create();
            SDVariable variable = graph.var("in", arr);
            SDVariable stdOut = variable.std(bias);

            INDArray expected = arr.std(bias);

            String caseName = input.getSecond() + " - biasCorrected=" + bias;
            TestCase testCase = new TestCase(graph)
                    .testName(caseName)
                    .expected(stdOut, expected)
                    .gradientCheck(false);

            // validate returns null on success, otherwise an error description.
            String failure = OpValidation.validate(testCase);
            if (failure != null) {
                failures.add(failure);
            }
        }
    }
    assertEquals(failures.toString(), 0, failures.size());
}
 
Example 9
Source File: NDArrayTestsFortran.java (from nd4j, Apache License 2.0)
/**
 * Checks that {@code dup()}, {@code dup('c')} and {@code dup('f')} return copies equal to the
 * source for every 4x5 test-matrix layout. Each copy must also have the requested ordering;
 * {@code dup()} must use {@link Nd4j#order()}.
 */
@Test
public void testDupAndDupWithOrder() {
    List<Pair<INDArray, String>> testInputs = NDArrayCreationUtil.getAllTestMatricesWithShape(4, 5, 123);
    for (Pair<INDArray, String> pair : testInputs) {
        String msg = pair.getSecond();
        INDArray in = pair.getFirst();
        INDArray dup = in.dup();
        INDArray dupc = in.dup('c');
        INDArray dupf = in.dup('f');

        // JUnit convention: assertEquals(message, expected, actual).
        assertEquals(msg, in, dup);
        assertEquals(msg, (char) Nd4j.order(), dup.ordering());
        assertEquals(msg, 'c', dupc.ordering());
        assertEquals(msg, 'f', dupf.ordering());
        assertEquals(msg, in, dupc);
        assertEquals(msg, in, dupf);
    }
}
 
Example 10
Source File: Nd4jTest.java (from deeplearning4j, Apache License 2.0)
/**
 * Squeezes axis 1 out of column-shaped test matrices (3x1 created with 'c' order, 7x1 with 'f'
 * order) and checks the result's shape, ordering and contents. It also checks that the squeezed
 * array reflects later writes to the source array.
 */
@Test
public void testSqueeze(){
    final List<Pair<INDArray, String>> testMatricesC = NDArrayCreationUtil.getAllTestMatricesWithShape('c', 3, 1, 0xDEAD, DataType.DOUBLE);
    final List<Pair<INDArray, String>> testMatricesF = NDArrayCreationUtil.getAllTestMatricesWithShape('f', 7, 1, 0xBEEF, DataType.DOUBLE);

    final List<Pair<INDArray, String>> testMatrices = new ArrayList<>(testMatricesC);
    testMatrices.addAll(testMatricesF);

    for (Pair<INDArray, String> testMatrixPair : testMatrices) {
        final String recreation = testMatrixPair.getSecond();
        final INDArray testMatrix = testMatrixPair.getFirst();
        final char ordering = testMatrix.ordering();
        val shape = testMatrix.shape();
        final INDArray squeezed = Nd4j.squeeze(testMatrix, 1);
        final long[] expShape = ArrayUtil.removeIndex(shape, 1);
        final String message = "Squeezing in dimension 1; Shape before squeezing: " + Arrays.toString(shape) + " " + ordering + " Order; Shape after squeezing: " + Arrays.toString(squeezed.shape()) + " " + squeezed.ordering() + "; Input Created via: " + recreation;

        assertArrayEquals(message, expShape, squeezed.shape());
        assertEquals(message, ordering, squeezed.ordering());
        assertEquals(message, testMatrix.ravel(), squeezed.ravel());

        // Overwrite the source; the squeezed array must still match, i.e. it shares the data.
        testMatrix.assign(Nd4j.rand(shape));
        assertEquals(message, testMatrix.ravel(), squeezed.ravel());
    }
}
 
Example 11
Source File: SameDiffTests.java (from nd4j, Apache License 2.0)
/**
 * Checks SameDiff's {@code expandDims} on every 3x4 test-matrix layout for axes 0..2. The
 * expected result is a 'c'-ordered copy of the input reshaped with a size-1 dimension inserted
 * at the chosen axis.
 */
@Test
public void testExpandDims2d() {
    val origShape = new long[]{3, 4};

    for (int i = 0; i < 3; i++) {
        for (Pair<INDArray, String> p : NDArrayCreationUtil.getAllTestMatricesWithShape(origShape[0], origShape[1], 12345)) {
            // muli scales the test matrix in place by 100.
            INDArray inArr = p.getFirst().muli(100);

            SameDiff sd = SameDiff.create();
            SDVariable in = sd.var("in", inArr);
            // Last op added to the graph; execAndEndResult() is expected to return its output.
            SDVariable expand = sd.f().expandDims(in, i);

            INDArray out = sd.execAndEndResult();

            INDArray expOut;
            switch (i) {
                case 0:
                    expOut = inArr.dup('c').reshape('c', 1, origShape[0], origShape[1]);
                    break;
                case 1:
                    expOut = inArr.dup('c').reshape('c', origShape[0], 1, origShape[1]);
                    break;
                case 2:
                    expOut = inArr.dup('c').reshape('c', origShape[0], origShape[1], 1);
                    break;
                default:
                    throw new RuntimeException("Unexpected expandDims axis: " + i);
            }

            String msg = "expandDim=" + i + ", source=" + p.getSecond();

            // JUnit convention: assertEquals(message, expected, actual).
            assertEquals(msg, expOut, out);
        }
    }
}
 
Example 12
Source File: NDArrayTestsFortran.java (from deeplearning4j, Apache License 2.0)
/**
 * For every 4x5 DOUBLE test-matrix layout, checks the copies returned by
 * {@code Shape.toOffsetZeroCopy} (default, 'c' and 'f' variants) and
 * {@code Shape.toOffsetZeroCopyAnyOrder}. Each copy must:
 * <ul>
 *   <li>equal the input;</li>
 *   <li>have the requested ordering where one was given;</li>
 *   <li>start at offset 0;</li>
 *   <li>have a length equal to its backing buffer's length.</li>
 * </ul>
 */
@Test
public void testToOffsetZeroCopy() {
    List<Pair<INDArray, String>> testInputs = NDArrayCreationUtil.getAllTestMatricesWithShape(4, 5, 123, DataType.DOUBLE);

    int cnt = 0;
    for (Pair<INDArray, String> pair : testInputs) {
        String msg = pair.getSecond();
        INDArray in = pair.getFirst();
        INDArray dup = Shape.toOffsetZeroCopy(in);
        INDArray dupc = Shape.toOffsetZeroCopy(in, 'c');
        INDArray dupf = Shape.toOffsetZeroCopy(in, 'f');
        INDArray dupany = Shape.toOffsetZeroCopyAnyOrder(in);

        // JUnit convention: assertEquals(message, expected, actual).
        assertEquals(msg + ": " + cnt, in, dup);
        assertEquals(msg, in, dupc);
        assertEquals(msg, in, dupf);
        assertEquals(msg, 'c', dupc.ordering());
        assertEquals(msg, 'f', dupf.ordering());
        assertEquals(msg, in, dupany);

        assertEquals(msg, 0, dup.offset());
        assertEquals(msg, 0, dupc.offset());
        assertEquals(msg, 0, dupf.offset());
        assertEquals(msg, 0, dupany.offset());
        // Each copy should occupy its whole backing buffer.
        assertEquals(msg, dup.length(), dup.data().length());
        assertEquals(msg, dupc.length(), dupc.data().length());
        assertEquals(msg, dupf.length(), dupf.data().length());
        assertEquals(msg, dupany.length(), dupany.data().length());
        cnt++;
    }
}
 
Example 13
Source File: TestNdArrReadWriteTxt.java (from deeplearning4j, Apache License 2.0)
/**
 * Round-trips 2d, 3d and 4d test arrays of several shapes and layouts through
 * {@code Nd4j.writeTxt} / {@code Nd4j.readTxt}. The array read back must equal the original.
 *
 * @throws Exception if the temporary folder or a text file cannot be created or read
 */
@Test
public void testNd4jReadWriteText() throws Exception {

    File dir = testDir.newFolder();
    int count = 0;
    for (val testShape : new long[][]{{1,1}, {3,1}, {4,5}, {1,2,3}, {2,1,3}, {2,3,1}, {2,3,4}, {1,2,3,4}, {2,3,4,2}}) {
        List<Pair<INDArray, String>> l;
        switch (testShape.length) {
            case 2:
                l = NDArrayCreationUtil.getAllTestMatricesWithShape(testShape[0], testShape[1], 12345, Nd4j.defaultFloatingPointType());
                break;
            case 3:
                l = NDArrayCreationUtil.getAll3dTestArraysWithShape(12345, testShape, Nd4j.defaultFloatingPointType());
                break;
            case 4:
                l = NDArrayCreationUtil.getAll4dTestArraysWithShape(12345, testShape, Nd4j.defaultFloatingPointType());
                break;
            default:
                throw new RuntimeException("Unsupported test shape rank: " + testShape.length);
        }

        for (Pair<INDArray, String> p : l) {
            // One uniquely numbered file per array.
            File f = new File(dir, (count++) + ".txt");
            Nd4j.writeTxt(p.getFirst(), f.getAbsolutePath());

            INDArray read = Nd4j.readTxt(f.getAbsolutePath());

            // The message identifies which layout failed the round trip.
            assertEquals(p.getSecond(), p.getFirst(), read);
        }
    }
}
 
Example 14
Source File: Nd4jTestsComparisonFortran.java (from deeplearning4j, Apache License 2.0)
/**
 * Checks element-wise multiply and divide for every pairing of 3x5 DOUBLE test-matrix layouts,
 * using {@link CheckUtil#checkMulManually} and {@link CheckUtil#checkDivManually} (true = pass).
 */
@Test
public void testMulDivOnCheckUtilMatrices() {
    final List<Pair<INDArray, String>> lefts = NDArrayCreationUtil.getAllTestMatricesWithShape(3, 5, SEED, DataType.DOUBLE);
    final List<Pair<INDArray, String>> rights = NDArrayCreationUtil.getAllTestMatricesWithShape(3, 5, SEED, DataType.DOUBLE);

    int leftIdx = 0;
    for (Pair<INDArray, String> left : lefts) {
        int rightIdx = 0;
        for (Pair<INDArray, String> right : rights) {
            String mulMsg = getTestWithOpsErrorMsg(leftIdx, rightIdx, "mul", left, right);
            String divMsg = getTestWithOpsErrorMsg(leftIdx, rightIdx, "div", left, right);

            boolean mulOk = CheckUtil.checkMulManually(left.getFirst(), right.getFirst(), 1e-4, 1e-6);
            assertTrue(mulMsg, mulOk);
            boolean divOk = CheckUtil.checkDivManually(left.getFirst(), right.getFirst(), 1e-4, 1e-6);
            assertTrue(divMsg, divOk);

            rightIdx++;
        }
        leftIdx++;
    }
}
 
Example 15
Source File: SameDiffTests.java (from deeplearning4j, Apache License 2.0)
/**
 * Checks SameDiff's {@code expandDims} on every 3x4 FLOAT test-matrix layout for axes 0..2. The
 * expected result is a 'c'-ordered copy of the input reshaped with a size-1 dimension inserted
 * at the chosen axis.
 */
@Test
public void testExpandDims2d() {
    val origShape = new long[]{3, 4};

    for (int i = 0; i < 3; i++) {
        for (Pair<INDArray, String> p : NDArrayCreationUtil
                .getAllTestMatricesWithShape(origShape[0], origShape[1], 12345, DataType.FLOAT)) {
            // muli scales the test matrix in place by 100.
            INDArray inArr = p.getFirst().muli(100);

            SameDiff sd = SameDiff.create();
            SDVariable in = sd.var("in", inArr);
            SDVariable expand = sd.expandDims(in, i);

            INDArray out = expand.eval();

            INDArray expOut;
            switch (i) {
                case 0:
                    expOut = inArr.dup('c').reshape('c', 1, origShape[0], origShape[1]);
                    break;
                case 1:
                    expOut = inArr.dup('c').reshape('c', origShape[0], 1, origShape[1]);
                    break;
                case 2:
                    expOut = inArr.dup('c').reshape('c', origShape[0], origShape[1], 1);
                    break;
                default:
                    throw new RuntimeException("Unexpected expandDims axis: " + i);
            }

            String msg = "expandDim=" + i + ", source=" + p.getSecond();

            // JUnit convention: assertEquals(message, expected, actual).
            assertEquals(msg, expOut, out);
        }
    }
}
 
Example 16
Source File: Nd4jTestsComparisonFortran.java (from nd4j, Apache License 2.0)
/**
 * Checks element-wise multiply and divide for every pairing of 3x5 test-matrix layouts, using
 * {@link CheckUtil#checkMulManually} and {@link CheckUtil#checkDivManually} (true = pass).
 */
@Test
public void testMulDivOnCheckUtilMatrices() {
    List<Pair<INDArray, String>> lhs = NDArrayCreationUtil.getAllTestMatricesWithShape(3, 5, SEED);
    List<Pair<INDArray, String>> rhs = NDArrayCreationUtil.getAllTestMatricesWithShape(3, 5, SEED);
    for (int row = 0; row < lhs.size(); row++) {
        Pair<INDArray, String> a = lhs.get(row);
        for (int col = 0; col < rhs.size(); col++) {
            Pair<INDArray, String> b = rhs.get(col);
            String mulError = getTestWithOpsErrorMsg(row, col, "mul", a, b);
            String divError = getTestWithOpsErrorMsg(row, col, "div", a, b);

            INDArray aArr = a.getFirst();
            INDArray bArr = b.getFirst();
            assertTrue(mulError, CheckUtil.checkMulManually(aArr, bArr, 1e-4, 1e-6));
            assertTrue(divError, CheckUtil.checkDivManually(aArr, bArr, 1e-4, 1e-6));
        }
    }
}
 
Example 17
Source File: Nd4jTestsComparisonFortran.java (from nd4j, Apache License 2.0)
/**
 * Checks matrix multiplication of every 3x5 test-matrix layout with every 5x4 layout, using
 * {@link CheckUtil#checkMmul} (true = pass).
 */
@Test
public void testMmulWithOpsCommonsMath() {
    List<Pair<INDArray, String>> leftOperands = NDArrayCreationUtil.getAllTestMatricesWithShape(3, 5, SEED);
    List<Pair<INDArray, String>> rightOperands = NDArrayCreationUtil.getAllTestMatricesWithShape(5, 4, SEED);

    int li = 0;
    for (Pair<INDArray, String> left : leftOperands) {
        int ri = 0;
        for (Pair<INDArray, String> right : rightOperands) {
            String failureMsg = getTestWithOpsErrorMsg(li, ri, "mmul", left, right);
            boolean matches = CheckUtil.checkMmul(left.getFirst(), right.getFirst(), 1e-4, 1e-6);
            assertTrue(failureMsg, matches);
            ri++;
        }
        li++;
    }
}
 
Example 18
Source File: Nd4jTestsComparisonFortran.java (from deeplearning4j, Apache License 2.0)
/**
 * Checks matrix multiplication of every 3x5 DOUBLE test-matrix layout with every 5x4 DOUBLE
 * layout, using {@link CheckUtil#checkMmul} (true = pass).
 */
@Test
public void testMmulWithOpsCommonsMath() {
    final List<Pair<INDArray, String>> as =
            NDArrayCreationUtil.getAllTestMatricesWithShape(3, 5, SEED, DataType.DOUBLE);
    final List<Pair<INDArray, String>> bs =
            NDArrayCreationUtil.getAllTestMatricesWithShape(5, 4, SEED, DataType.DOUBLE);

    for (int x = 0; x < as.size(); x++) {
        final Pair<INDArray, String> aPair = as.get(x);
        for (int y = 0; y < bs.size(); y++) {
            final Pair<INDArray, String> bPair = bs.get(y);
            final String message = getTestWithOpsErrorMsg(x, y, "mmul", aPair, bPair);
            assertTrue(message,
                    CheckUtil.checkMmul(aPair.getFirst(), bPair.getFirst(), 1e-4, 1e-6));
        }
    }
}
 
Example 19
Source File: Nd4jTestsComparisonFortran.java (from deeplearning4j, Apache License 2.0)
/**
 * Runs {@link CheckUtil#checkGemm} for the following combinations:
 * <ul>
 *   <li>every pairing of test-matrix layouts;</li>
 *   <li>all four transpose combinations (A: 3x5 or its 5x3 counterpart, B: 5x4 or its 4x5
 *       counterpart);</li>
 *   <li>every (alpha, beta) pair from the two scalar tables.</li>
 * </ul>
 * C starts from a fixed, seeded random 3x4 matrix.
 */
@Test
public void testGemmWithOpsCommonsMath() {
    List<Pair<INDArray, String>> first = NDArrayCreationUtil.getAllTestMatricesWithShape(3, 5, SEED, DataType.DOUBLE);
    List<Pair<INDArray, String>> firstT = NDArrayCreationUtil.getAllTestMatricesWithShape(5, 3, SEED, DataType.DOUBLE);
    List<Pair<INDArray, String>> second = NDArrayCreationUtil.getAllTestMatricesWithShape(5, 4, SEED, DataType.DOUBLE);
    List<Pair<INDArray, String>> secondT = NDArrayCreationUtil.getAllTestMatricesWithShape(4, 5, SEED, DataType.DOUBLE);
    double[] alpha = {1.0, -0.5, 2.5};
    double[] beta = {0.0, -0.25, 1.5};
    INDArray cOrig = Nd4j.create(new int[] {3, 4});
    Random r = new Random(12345);
    for (int i = 0; i < cOrig.size(0); i++) {
        for (int j = 0; j < cOrig.size(1); j++) {
            cOrig.putScalar(new int[] {i, j}, r.nextDouble());
        }
    }

    for (int i = 0; i < first.size(); i++) {
        for (int j = 0; j < second.size(); j++) {
            for (int k = 0; k < alpha.length; k++) {
                for (int m = 0; m < beta.length; m++) {
                    // Independent 'f'-ordered copy of cOrig for each transpose combination.
                    INDArray cff = Nd4j.create(cOrig.shape(), 'f');
                    cff.assign(cOrig);
                    INDArray cft = Nd4j.create(cOrig.shape(), 'f');
                    cft.assign(cOrig);
                    INDArray ctf = Nd4j.create(cOrig.shape(), 'f');
                    ctf.assign(cOrig);
                    INDArray ctt = Nd4j.create(cOrig.shape(), 'f');
                    ctt.assign(cOrig);

                    double a = alpha[k];
                    // Index with m (the beta loop variable) so every (alpha, beta) pair is covered;
                    // indexing with k would tie each alpha to one beta and make this loop redundant.
                    double b = beta[m];
                    Pair<INDArray, String> p1 = first.get(i);
                    Pair<INDArray, String> p1T = firstT.get(i);
                    Pair<INDArray, String> p2 = second.get(j);
                    Pair<INDArray, String> p2T = secondT.get(j);
                    String errorMsgff = getGemmErrorMsg(i, j, false, false, a, b, p1, p2);
                    String errorMsgft = getGemmErrorMsg(i, j, false, true, a, b, p1, p2T);
                    String errorMsgtf = getGemmErrorMsg(i, j, true, false, a, b, p1T, p2);
                    String errorMsgtt = getGemmErrorMsg(i, j, true, true, a, b, p1T, p2T);

                    assertTrue(errorMsgff, CheckUtil.checkGemm(p1.getFirst(), p2.getFirst(), cff, false, false, a,
                                    b, 1e-4, 1e-6));
                    assertTrue(errorMsgft, CheckUtil.checkGemm(p1.getFirst(), p2T.getFirst(), cft, false, true, a,
                                    b, 1e-4, 1e-6));
                    assertTrue(errorMsgtf, CheckUtil.checkGemm(p1T.getFirst(), p2.getFirst(), ctf, true, false, a,
                                    b, 1e-4, 1e-6));
                    assertTrue(errorMsgtt, CheckUtil.checkGemm(p1T.getFirst(), p2T.getFirst(), ctt, true, true, a,
                                    b, 1e-4, 1e-6));
                }
            }
        }
    }
}
 
Example 20
Source File: GradCheckMisc.java (from nd4j, Apache License 2.0)
/**
 * For every 3x4 test-matrix layout and each axis 0..2, applies {@code expandDims} followed by a
 * full standard deviation, then checks three things:
 * <ul>
 *   <li>the forward output;</li>
 *   <li>the expanded shape and values;</li>
 *   <li>the gradients, via {@code GradCheckUtil.checkGradients}.</li>
 * </ul>
 */
@Test
public void testExpandDimsGradient() {
    val origShape = new long[]{3, 4};

    for (int i = 0; i < 3; i++) {

        long[] expExpandShape;
        switch (i) {
            case 0:
                expExpandShape = new long[]{1, 3, 4};
                break;
            case 1:
                expExpandShape = new long[]{3, 1, 4};
                break;
            case 2:
                expExpandShape = new long[]{3, 4, 1};
                break;
            default:
                throw new RuntimeException("Unexpected expandDims axis: " + i);
        }

        for (Pair<INDArray, String> p : NDArrayCreationUtil.getAllTestMatricesWithShape(origShape[0], origShape[1], 12345)) {
            // muli scales the test matrix in place by 100.
            INDArray inArr = p.getFirst().muli(100);
            String msg = "expandDim=" + i + ", source=" + p.getSecond();

            SameDiff sd = SameDiff.create();
            SDVariable in = sd.var("in", inArr);
            SDVariable expand = sd.f().expandDims(in, i);
            //Using stdev here: mean/sum would backprop the same gradient for each input...
            SDVariable stdev = sd.standardDeviation("out", expand, true);

            INDArray out = sd.execAndEndResult();
            INDArray expOut = in.getArr().std(true, Integer.MAX_VALUE);
            assertEquals(msg, expOut, out);

            assertArrayEquals(msg, expExpandShape, expand.getArr().shape());
            INDArray expExpand = inArr.dup('c').reshape(expExpandShape);
            assertEquals(msg, expExpand, expand.getArr());

            log.info("Starting: " + msg);
            boolean ok = GradCheckUtil.checkGradients(sd);
            assertTrue(msg, ok);
        }
    }
}