Java Code Examples for org.nd4j.autodiff.samediff.SameDiff#stridedSlice()

The following examples show how to use org.nd4j.autodiff.samediff.SameDiff#stridedSlice(). You can vote up the examples you like or vote down the ones you don't, and go to the original project or source file by following the links above each example. Related API usage is listed in the sidebar.
Example 1
Source File: ShapeOpValidation.java — from deeplearning4j (Apache License 2.0), 6 votes
/**
 * Basic 2d stridedSlice checks on a 3x4 input: the full range, a contiguous
 * sub-block, and a stride-2 sub-sampling of both dimensions.
 */
@Test
public void testStridedSlice2dBasic() {
    INDArray inArr = Nd4j.linspace(1, 12, 12).reshape('c', 3, 4);

    SameDiff sd = SameDiff.create();
    SDVariable in = sd.var("in", inArr);
    //[0:3, 0:4] -> whole input
    SDVariable slice_full = sd.stridedSlice(in,new long[]{0, 0},new long[]{3, 4},new long[]{1, 1});
    //[1:3, 2:4] -> contiguous sub-block
    SDVariable subPart = sd.stridedSlice(in,new long[]{1, 2},new long[]{3, 4},new long[]{1, 1});
    //[0:3:2, 0:4:2] -> rows {0, 2}, columns {0, 2}; replaces a previously commented-out
    //stride-2 case whose end {4, 5} lay outside the 3x4 input
    SDVariable subPart2 = sd.stridedSlice(in,new long[]{0, 0},new long[]{3, 4},new long[]{2, 2});

    sd.outputAll(null);

    assertEquals(inArr, slice_full.getArr());
    assertEquals(inArr.get(interval(1, 3), interval(2, 4)), subPart.getArr());
    //interval(begin, stride, end) with exclusive end
    assertEquals(inArr.get(interval(0, 2, 3), interval(0, 2, 4)), subPart2.getArr());
}
 
Example 2
Source File: ShapeOpValidation.java — from deeplearning4j (Apache License 2.0), 6 votes
/**
 * Checks stridedSlice's ellipsisMask on a 3x4x5 input. Bit i of the mask marks
 * entry i of the begin/end/strides spec as an ellipsis ("all remaining dimensions");
 * the begin/end/stride values at that entry are not used as a range.
 */
@Test
public void testStridedSliceEllipsisMask() {
    INDArray inArr = Nd4j.linspace(1, 60, 60).reshape('c', 3, 4, 5);
    SameDiff sd = SameDiff.create();
    SDVariable in = sd.var("in", inArr);

    //[1:3,...] -> [1:3,:,:]  (one-entry spec; mask bit 1 refers to an entry that is not present)
    SDVariable slice = sd.stridedSlice(in,new long[]{1},new long[]{3},new long[]{1}, 0, 0, 1 << 1, 0, 0);
    //Two-entry spec with mask bit 1: entry 1 IS the ellipsis, so its 1:4 range is not applied.
    //Effective slice is [1:3,...] -> [1:3,:,:], as asserted below (this is NOT [1:3,...,1:4]).
    SDVariable slice2 = sd.stridedSlice(in,new long[]{1, 1},new long[]{3, 4},new long[]{1, 1}, 0, 0, 1 << 1, 0, 0);
    //[1:3,...,1:4] -> [1:3,:,1:4]: ellipsis at entry 1 (values there are placeholders), range in entry 2
    SDVariable slice3 = sd.stridedSlice(in,new long[]{1, 0, 1},new long[]{3, 0, 4},new long[]{1, 1, 1}, 0, 0, 1 << 1, 0, 0);

    sd.outputAll(Collections.emptyMap());

    assertEquals(inArr.get(interval(1, 3), all(), all()), slice.getArr());
    assertEquals(inArr.get(interval(1, 3), all(), all()), slice2.getArr());
    assertEquals(inArr.get(interval(1, 3), all(), interval(1, 4)), slice3.getArr());
}
 
Example 3
Source File: ShapeOpValidation.java — from deeplearning4j (Apache License 2.0), 6 votes
/**
 * Checks stridedSlice's shrinkAxisMask on a 3x4x5 input: every dimension whose bit
 * is set is indexed at its begin value and removed from the output. The value -999
 * fills end entries on shrunk dimensions, where only begin is used.
 */
@Test
public void testStridedSliceShrinkAxisMask() {
    final long placeholder = -999;
    final int shrinkFirst = 1;
    final int shrinkFirstTwo = 1 | 1 << 1;

    INDArray inArr = Nd4j.linspace(1, 60, 60).reshape('c', 3, 4, 5);
    SameDiff sd = SameDiff.create();
    SDVariable in = sd.var("in", inArr);

    //in[0, :, :] -> shape [4, 5]
    SDVariable plane0 = sd.stridedSlice(in, new long[]{0, 0, 0}, new long[]{placeholder, 4, 5},
            new long[]{1, 1, 1}, 0, 0, 0, 0, shrinkFirst);
    //in[2, :, :] -> shape [4, 5]
    SDVariable plane2 = sd.stridedSlice(in, new long[]{2, 0, 0}, new long[]{placeholder, 4, 5},
            new long[]{1, 1, 1}, 0, 0, 0, 0, shrinkFirst);
    //in[1, 2, 1:5] -> shape [4]
    SDVariable rowSegment = sd.stridedSlice(in, new long[]{1, 2, 1}, new long[]{placeholder, placeholder, 5},
            new long[]{1, 1, 1}, 0, 0, 0, 0, shrinkFirstTwo);

    sd.outputAll(null);

    INDArray expectedPlane0 = inArr.get(point(0), all(), all());
    INDArray expectedPlane2 = inArr.get(point(2), all(), all());
    INDArray expectedSegment = inArr.get(point(1), point(2), interval(1, 5)).reshape(4);
    assertEquals(expectedPlane0, plane0.getArr());
    assertEquals(expectedPlane2, plane2.getArr());
    assertEquals(expectedSegment, rowSegment.getArr());
}
 
Example 4
Source File: GradCheckMisc.java — from nd4j (Apache License 2.0), 5 votes
/**
 * Gradient-checks stridedSlice over a table of slice specifications: plain ranges,
 * strides, begin/end masks, ellipsis, new-axis and shrink-axis masks.
 * <p>
 * Each case builds a fresh graph {@code standardDeviation(stridedSlice(in))} and runs
 * {@code GradCheckUtil.checkGradients}. Failing cases are collected and reported
 * together, so one bad case does not hide the others. Previously the check's result
 * was discarded and the test could not fail on a gradient mismatch.
 */
@Test
public void testStridedSliceGradient() {
    Nd4j.getRandom().setSeed(12345);

    //Order here: original shape, begin, end, strides, then optional masks.
    //-999 fills positions whose value is masked out (begin/end mask, new-axis or shrink-axis entries).
    List<SSCase> testCases = new ArrayList<>();
    testCases.add(SSCase.builder().shape(3, 4).begin(0, 0).end(3, 4).strides(1, 1).build());
    testCases.add(SSCase.builder().shape(3, 4).begin(1, 1).end(2, 3).strides(1, 1).build());
    testCases.add(SSCase.builder().shape(3, 4).begin(-999, 0).end(3, 4).strides(1, 1).beginMask(1).build());
    testCases.add(SSCase.builder().shape(3, 4).begin(1, 1).end(3, -999).strides(1, 1).endMask(1 << 1).build());
    testCases.add(SSCase.builder().shape(3, 4).begin(-999, 0).end(-999, 4).strides(1, 1).beginMask(1).endMask(1).build());
    testCases.add(SSCase.builder().shape(3, 4).begin(-999, 0, 0).end(-999, 3, 4).strides(1, 1).newAxisMask(1).build());

    testCases.add(SSCase.builder().shape(3, 4, 5).begin(0, 0, 0).end(3, 4, 5).strides(1, 1, 1).build());
    testCases.add(SSCase.builder().shape(3, 4, 5).begin(1, 2, 3).end(3, 4, 5).strides(1, 1, 1).build());
    testCases.add(SSCase.builder().shape(3, 4, 5).begin(0, 0, 0).end(3, 3, 5).strides(1, 2, 2).build());
    testCases.add(SSCase.builder().shape(3, 4, 5).begin(1, -999, 1).end(3, 3, 4).strides(1, 1, 1).beginMask(1 << 1).build());
    testCases.add(SSCase.builder().shape(3, 4, 5).begin(1, -999, 1).end(3, 3, -999).strides(1, 1, 1).beginMask(1 << 1).endMask(1 << 2).build());
    testCases.add(SSCase.builder().shape(3, 4, 5).begin(1, 2).end(3, 4).strides(1, 1).ellipsisMask(1 << 1).build());   //ellipsis at spec entry 1
    testCases.add(SSCase.builder().shape(3, 4, 5).begin(1, -999, 1, 2).end(3, -999, 3, 4).strides(1, -999, 1, 2).newAxisMask(1 << 1).build());
    testCases.add(SSCase.builder().shape(3, 4, 5).begin(1, 0, 1).end(3, -999, 4).strides(1, 1, 1).shrinkAxisMask(1 << 1).build());
    testCases.add(SSCase.builder().shape(3, 4, 5).begin(1, 1, 1).end(3, -999, 4).strides(1, 1, 1).shrinkAxisMask(1 << 1).build());

    List<String> failed = new ArrayList<>();

    for (int i = 0; i < testCases.size(); i++) {
        SSCase t = testCases.get(i);
        INDArray arr = Nd4j.rand(t.getShape());

        SameDiff sd = SameDiff.create();
        SDVariable in = sd.var("in", arr);
        SDVariable slice = sd.stridedSlice(in, t.getBegin(), t.getEnd(), t.getStrides(), t.getBeginMask(),
                t.getEndMask(), t.getEllipsisMask(), t.getNewAxisMask(), t.getShrinkAxisMask());
        //Reduces the slice so the graph has an output to differentiate; the returned variable is not read here
        sd.standardDeviation(slice, true);

        String msg = "i=" + i + ": " + t;
        log.info("Starting test: " + msg);
        boolean ok = GradCheckUtil.checkGradients(sd);
        if (!ok) {
            failed.add(msg);
        }
    }
    org.junit.Assert.assertTrue("Failed gradient checks: " + failed, failed.isEmpty());
}
 
Example 5
Source File: ShapeOpValidation.java — from deeplearning4j (Apache License 2.0), 5 votes
/**
 * Checks that stridedSlice's beginMask/endMask make the slice ignore the begin/end
 * value of a dimension. Bit i of each mask refers to dimension i; -999 is a
 * placeholder in the masked position.
 */
@Test
public void testStridedSliceBeginEndMask() {
    INDArray inArr = Nd4j.linspace(1, 12, 12).reshape('c', 3, 4);

    SameDiff sd = SameDiff.create();
    SDVariable in = sd.var("in", inArr);
    //beginMask bit 0: begin[0] (-999) is ignored -> [:2, 0:4].
    //Previously 1 << 1 was passed, which masked dimension 1 and left -999 as a live begin for dimension 0.
    SDVariable slice1 = sd.stridedSlice(in,new long[]{-999, 0},new long[]{2, 4},new long[]{1, 1}, 1, 0, 0, 0, 0);
    //endMask bit 0: end[0] (-999) is ignored -> [1:, 0:4]
    SDVariable slice2 = sd.stridedSlice(in,new long[]{1, 0},new long[]{-999, 4},new long[]{1, 1}, 0, 1, 0, 0, 0);

    sd.outputAll(null);

    assertEquals(inArr.get(NDArrayIndex.interval(0, 2), NDArrayIndex.all()), slice1.getArr());
    assertEquals(inArr.get(NDArrayIndex.interval(1, 3), NDArrayIndex.all()), slice2.getArr());
}
 
Example 6
Source File: ShapeOpValidation.java — from deeplearning4j (Apache License 2.0), 5 votes
/**
 * Checks stridedSlice's newAxisMask: with bit 0 set, spec entry 0 inserts a size-1
 * leading axis (its -999 begin/end/stride are placeholders). The remaining entries
 * take the whole 3x4x5 input.
 */
@Test
public void testStridedSliceNewAxisMask() {
    INDArray inArr = Nd4j.linspace(1, 60, 60).reshape('c', 3, 4, 5);
    SameDiff sd = SameDiff.create();
    SDVariable in = sd.var("in", inArr);

    long[] begin = {-999, 0, 0, 0};
    long[] end = {-999, 3, 4, 5};
    long[] strides = {-999, 1, 1, 1};
    int newAxisMask = 1;
    SDVariable expanded = sd.stridedSlice(in, begin, end, strides, 0, 0, 0, newAxisMask, 0);

    INDArray result = expanded.eval();

    long[] expectedShape = {1, 3, 4, 5};
    assertArrayEquals(expectedShape, result.shape());
    //Indexing away the inserted axis must give back the input unchanged
    INDArray withoutNewAxis = result.get(point(0), all(), all(), all());
    assertEquals(inArr, withoutNewAxis);
}
 
Example 7
Source File: ShapeOpValidation.java — from deeplearning4j (Apache License 2.0), 5 votes
/**
 * Checks a newAxisMask entry in the middle of the spec:
 * [1:3, 1:3, newaxis, 1:4] on a 3x4x5 input must produce shape [2, 2, 1, 3]
 * holding the values of in[1:3, 1:3, 1:4].
 */
@Test
public void testStridedSliceNewAxisMask2() {
    INDArray inArr = Nd4j.linspace(1, 60, 60).reshape('c', 3, 4, 5);
    SameDiff sd = SameDiff.create();
    SDVariable in = sd.var("in", inArr);
    //newAxisMask bit 2: spec entry 2 inserts a size-1 axis; its -999 values are placeholders
    SDVariable slice = sd.stridedSlice(in,new long[]{1, 1, -999, 1},new long[]{3, 3, -999, 4},new long[]{1, 1, -999, 1}, 0, 0, 0, 1 << 2, 0);
    INDArray out = slice.eval();

    assertArrayEquals(new long[]{2, 2, 1, 3}, out.shape());
    //Besides the shape, the values must match the plain 3d sub-block once the new axis is indexed away
    assertEquals(inArr.get(interval(1, 3), interval(1, 3), interval(1, 4)), out.get(all(), all(), point(0), all()));
}
 
Example 8
Source File: ShapeOpValidation.java — from deeplearning4j (Apache License 2.0), 4 votes
/**
 * Validates stridedSlice forward output and gradients over a table of slice specifications.
 * <p>
 * Each case builds {@code standardDeviation(stridedSlice(in))} and runs it through
 * {@code OpValidation.validate}. Where {@code indices} has an entry for the case index,
 * the slice output is also compared against an equivalent INDArray indexing of the input.
 * All failures are collected and reported together at the end.
 */
@Test
public void testStridedSliceGradient() {
    Nd4j.getRandom().setSeed(12345);

    //Order here: original shape, begin, end, strides, then optional masks.
    //-999 fills positions whose value is masked out (begin/end mask, new-axis or shrink-axis entries).
    List<SSCase> testCases = new ArrayList<>();
    testCases.add(SSCase.builder().shape(3, 4).begin(0, 0).end(3, 4).strides(1, 1).build());
    testCases.add(SSCase.builder().shape(3, 4).begin(1, 1).end(2, 3).strides(1, 1).build());
    testCases.add(SSCase.builder().shape(3, 4).begin(-999, 0).end(3, 4).strides(1, 1).beginMask(1).build());
    testCases.add(SSCase.builder().shape(3, 4).begin(1, 1).end(3, -999).strides(1, 1).endMask(1 << 1).build());
    testCases.add(SSCase.builder().shape(3, 4).begin(-999, 0).end(-999, 4).strides(1, 1).beginMask(1).endMask(1).build());

    testCases.add(SSCase.builder().shape(3, 4, 5).begin(0, 0, 0).end(3, 4, 5).strides(1, 1, 1).build());
    testCases.add(SSCase.builder().shape(3, 4, 5).begin(1, 2, 3).end(3, 4, 5).strides(1, 1, 1).build());
    testCases.add(SSCase.builder().shape(3, 4, 5).begin(0, 0, 0).end(3, 3, 5).strides(1, 2, 2).build());
    testCases.add(SSCase.builder().shape(3, 4, 5).begin(1, -999, 1).end(3, 3, 4).strides(1, 1, 1).beginMask(1 << 1).build());
    testCases.add(SSCase.builder().shape(3, 4, 5).begin(1, -999, 1).end(3, 3, -999).strides(1, 1, 1).beginMask(1 << 1).endMask(1 << 2).build());
    testCases.add(SSCase.builder().shape(3, 4, 5).begin(1, 2).end(3, 4).strides(1, 1).ellipsisMask(1 << 1).build());   //ellipsis at spec entry 1
    testCases.add(SSCase.builder().shape(3, 4, 5).begin(1, -999, 1, 2).end(3, -999, 3, 4).strides(1, -999, 1, 2).newAxisMask(1 << 1).build());
    testCases.add(SSCase.builder().shape(3, 4, 5).begin(1, 0, 1).end(3, -999, 4).strides(1, 1, 1).shrinkAxisMask(1 << 1).build());
    testCases.add(SSCase.builder().shape(3, 4, 5).begin(1, 1, 1).end(3, -999, 4).strides(1, 1, 1).shrinkAxisMask(1 << 1).build());

    //Expected forward output, keyed by test-case index, as INDArray indexing of the input.
    //interval(begin, end) and interval(begin, stride, end) use exclusive end.
    Map<Integer,INDArrayIndex[]> indices = new HashMap<>();
    indices.put(0, new INDArrayIndex[]{all(), all()});
    indices.put(1, new INDArrayIndex[]{interval(1,2), interval(1,3)});
    indices.put(2, new INDArrayIndex[]{interval(0,3), interval(0,4)});
    indices.put(3, new INDArrayIndex[]{interval(1,3), interval(1,4)});
    //begin and end of dimension 0 both masked; dimension 1 spans 0:4
    indices.put(4, new INDArrayIndex[]{all(), all()});

    indices.put(5, new INDArrayIndex[]{all(), all(), all()});
    indices.put(6, new INDArrayIndex[]{interval(1,3), interval(2,4), interval(3,5)});
    indices.put(7, new INDArrayIndex[]{interval(0,1,3), interval(0,2,3), interval(0,2,5)});
    //beginMask bit 1: dimension 1 starts at 0
    indices.put(8, new INDArrayIndex[]{interval(1,3), interval(0,3), interval(1,4)});
    //beginMask bit 1 and endMask bit 2: dimension 1 starts at 0, dimension 2 runs to its end
    indices.put(9, new INDArrayIndex[]{interval(1,3), interval(0,3), interval(1,5)});


    List<String> failed = new ArrayList<>();

    for (int i = 0; i < testCases.size(); i++) {
        SSCase t = testCases.get(i);
        INDArray arr = Nd4j.rand(t.getShape());

        SameDiff sd = SameDiff.create();
        SDVariable in = sd.var("in", arr);
        SDVariable slice = sd.stridedSlice(in, t.getBegin(), t.getEnd(), t.getStrides(), t.getBeginMask(),
                t.getEndMask(), t.getEllipsisMask(), t.getNewAxisMask(), t.getShrinkAxisMask());
        //Reduces the slice so the graph has an output to differentiate; the returned variable is not read here
        sd.standardDeviation(slice, true);

        String msg = "i=" + i + ": " + t;
        log.info("Starting test: " + msg);

        TestCase tc = new TestCase(sd);
        tc.testName(msg);

        INDArrayIndex[] expectedIdx = indices.get(i);
        if(expectedIdx != null){
            tc.expected(slice, arr.get(expectedIdx).dup());
        }

        String error = OpValidation.validate(tc, true);
        if(error != null){
            failed.add(error);
        }
    }
    assertEquals(failed.toString(), 0, failed.size());
}