Java Code Examples for org.nd4j.linalg.api.shape.Shape#getReducedShape()

The following examples show how to use org.nd4j.linalg.api.shape.Shape#getReducedShape(). You can vote up the examples you like or vote down the ones you don't, and go to the original project or source file by following the links above each example. You can also check out related API usage in the sidebar.
Example 1
Source File: Variance.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
/**
 * Calculates the output shape descriptor of this reduction.
 *
 * <p>The input array is taken from {@code oc} (input 0) when a context is supplied, otherwise
 * from {@code x()}. The shape reported by {@code arg()} is used unless it is {@code null} or a
 * placeholder shape, in which case the input array's shape is used instead.
 *
 * @param oc op context to read the input array from; may be {@code null}
 * @return a single-element list holding the reduced shape with {@code resultType()} as data type,
 *         or an empty list when no usable input shape is available
 * @throws ND4JIllegalStateException if {@code oc} is {@code null} and the op has no arguments
 */
@Override
public List<LongShapeDescriptor> calculateOutputShape(OpContext oc) {
    INDArray x = oc != null ? oc.getInputArray(0) : x();

    if (oc == null && args().length < 1) {
        throw new ND4JIllegalStateException("Unable to compute input shape. No arguments found.");
    }

    long[] argShape = arg().getShape();
    // The argument's own shape cannot be used when it is absent or still a placeholder.
    boolean needsArrayShape = argShape == null || Shape.isPlaceholderShape(argShape);
    if (needsArrayShape && x == null) {
        // Nothing to derive a shape from. Covers the original (null shape, null array) case and
        // additionally the placeholder-shape case, which used to fail with a NullPointerException.
        return Collections.emptyList();
    }
    long[] inputShape = needsArrayShape ? x.shape() : argShape;

    List<LongShapeDescriptor> ret = new ArrayList<>(1);
    long[] reducedShape = Shape.getReducedShape(inputShape, dimensions, isKeepDims());
    ret.add(LongShapeDescriptor.fromShape(reducedShape, resultType()));
    return ret;
}
 
Example 2
Source File: StandardDeviation.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
/**
 * Calculates the output shape descriptor of this reduction.
 *
 * <p>The shape reported by {@code arg()} is used unless it is {@code null} or a placeholder
 * shape, in which case the shape of the {@code x()} array is used instead.
 *
 * @return a single-element list holding the reduced shape with {@code resultType()} as data type,
 *         or an empty list when no usable input shape is available
 * @throws ND4JIllegalStateException if the op has no arguments
 */
@Override
public List<LongShapeDescriptor> calculateOutputShape() {
    if (args().length < 1) {
        throw new ND4JIllegalStateException("Unable to compute input shape. No arguments found.");
    }

    long[] argShape = arg().getShape();
    long[] inputShape;
    if (argShape == null || Shape.isPlaceholderShape(argShape)) {
        // The argument's own shape is absent or still a placeholder: fall back to the array.
        INDArray input = x();
        if (input == null) {
            // Nothing to derive a shape from. Covers the original (null shape, null array) case
            // and additionally the placeholder-shape case, which used to fail with a
            // NullPointerException on x().shape().
            return Collections.emptyList();
        }
        inputShape = input.shape();
    } else {
        inputShape = argShape;
    }

    List<LongShapeDescriptor> ret = new ArrayList<>(1);
    long[] reducedShape = Shape.getReducedShape(inputShape, dimensions, isKeepDims());
    ret.add(LongShapeDescriptor.fromShape(reducedShape, resultType()));
    return ret;
}
 
Example 3
Source File: ShapeTestC.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Reducing shape {5, 5, 5} over axes {1, 0, 1} with both boolean flags set to {@code true}
 * is expected to yield {1, 1, 5}.
 */
@Test
public void testKeepDimsShape_2_T() {
    final int[] inputShape = {5, 5, 5};
    final int[] reduceAxes = {1, 0, 1};

    final long[] actual = Shape.getReducedShape(inputShape, reduceAxes, true, true);

    final long[] expected = {1, 1, 5};
    assertArrayEquals(expected, actual);
}
 
Example 4
Source File: BaseReduceBoolOp.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Calculates the output shape descriptor of this boolean reduction.
 *
 * @param oc op context to read input 0 from; when {@code null}, {@code x()} is used instead
 * @return an empty list if no input array is available, otherwise a singleton list with the
 *         reduced shape and data type {@code BOOL}
 */
@Override
public List<LongShapeDescriptor> calculateOutputShape(OpContext oc) {
    final INDArray input;
    if (oc == null) {
        input = x();
    } else {
        input = oc.getInputArray(0);
    }

    if (input == null) {
        return Collections.emptyList();
    }

    // A rank-0 (scalar) input is passed through unchanged: reducing a scalar gives a scalar.
    final long[] outShape;
    if (input.rank() == 0) {
        outShape = input.shape();
    } else {
        outShape = Shape.getReducedShape(input.shape(), dimensions, isKeepDims());
    }

    return Collections.singletonList(LongShapeDescriptor.fromShape(outShape, DataType.BOOL));
}
 
Example 5
Source File: BaseReduceFloatOp.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Calculates the output shape descriptor of this floating-point reduction.
 *
 * @param oc op context to read input 0 from; when {@code null}, {@code x()} is used instead
 * @return an empty list if no input array is available, otherwise a singleton list with the
 *         reduced shape; the data type is that of {@code arg()} when it is a floating-point
 *         type, else {@code Nd4j.defaultFloatingPointType()}
 */
@Override
public List<LongShapeDescriptor> calculateOutputShape(OpContext oc) {
    final INDArray input;
    if (oc == null) {
        input = x();
    } else {
        input = oc.getInputArray(0);
    }

    if (input == null) {
        return Collections.emptyList();
    }

    // A rank-0 (scalar) input is passed through unchanged: reducing a scalar gives a scalar.
    final long[] outShape;
    if (input.rank() == 0) {
        outShape = input.shape();
    } else {
        outShape = Shape.getReducedShape(input.shape(), dimensions, isKeepDims());
    }

    // Keep the argument's type only when it is already floating point.
    final DataType argType = arg().dataType();
    final DataType outType = argType.isFPType() ? argType : Nd4j.defaultFloatingPointType();

    return Collections.singletonList(LongShapeDescriptor.fromShape(outShape, outType));
}
 
Example 6
Source File: BaseReduceSameOp.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Calculates the output shape descriptor of this same-type reduction.
 *
 * @param oc op context to read input 0 from; when {@code null}, {@code x()} is used instead
 * @return an empty list if no input array is available, otherwise a singleton list with the
 *         reduced shape and the data type from {@code resultType(oc)} (context given) or
 *         {@code resultType()} (no context)
 */
@Override
public List<LongShapeDescriptor> calculateOutputShape(OpContext oc) {
    final boolean hasContext = oc != null;

    final INDArray input = hasContext ? oc.getInputArray(0) : x();
    if (input == null) {
        return Collections.emptyList();
    }

    // A rank-0 (scalar) input is passed through unchanged: reducing a scalar gives a scalar.
    final long[] outShape;
    if (input.rank() == 0) {
        outShape = input.shape();
    } else {
        outShape = Shape.getReducedShape(input.shape(), dimensions, isKeepDims());
    }

    final DataType outType;
    if (hasContext) {
        outType = resultType(oc);
    } else {
        outType = resultType();
    }

    return Collections.singletonList(LongShapeDescriptor.fromShape(outShape, outType));
}
 
Example 7
Source File: BaseIndexAccumulation.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Calculates the output shape descriptor of this index accumulation.
 *
 * @param oc op context to read input 0 from; when {@code null}, {@code x()} is used instead
 * @return an empty list if no input array is available, otherwise a singleton list with the
 *         reduced shape and data type {@code LONG}
 */
@Override
public List<LongShapeDescriptor> calculateOutputShape(OpContext oc) {
    final INDArray input;
    if (oc == null) {
        input = x();
    } else {
        input = oc.getInputArray(0);
    }

    if (input == null) {
        return Collections.emptyList();
    }

    final long[] outShape = Shape.getReducedShape(input.shape(), dimensions, keepDims);
    final LongShapeDescriptor descriptor = LongShapeDescriptor.fromShape(outShape, DataType.LONG);
    return Collections.singletonList(descriptor);
}
 
Example 8
Source File: ShapeTestC.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Reducing shape {4, 4} over axes {0, 0} with flags (false, true) is expected to yield {4}.
 */
@Test
public void testKeepDimsShape_4_F() {
    final int[] inputShape = {4, 4};
    final int[] reduceAxes = {0, 0};

    final long[] actual = Shape.getReducedShape(inputShape, reduceAxes, false, true);
    log.info("Result: {}", actual);

    final long[] expected = {4};
    assertArrayEquals(expected, actual);
}
 
Example 9
Source File: ShapeTestC.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Reducing shape {1, 1} over axes {0, 0} with flags (false, true) is expected to yield {1}.
 */
@Test
public void testKeepDimsShape_3_F() {
    final long[] reduced = Shape.getReducedShape(new int[]{1, 1}, new int[]{0, 0}, false, true);

    log.info("Result: {}", reduced);

    assertArrayEquals(new long[]{1}, reduced);
}
 
Example 10
Source File: ShapeTestC.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Reducing shape {1, 1} over axes {1, 0, 1} with both boolean flags set to {@code true}
 * is expected to yield {1, 1}.
 */
@Test
public void testKeepDimsShape_3_T() {
    final int[] inputShape = {1, 1};
    final int[] reduceAxes = {1, 0, 1};

    final long[] actual = Shape.getReducedShape(inputShape, reduceAxes, true, true);

    final long[] expected = {1, 1};
    assertArrayEquals(expected, actual);
}
 
Example 11
Source File: ShapeTestC.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Reducing shape {5, 5, 5} over axes {0, 0, 1} with flags (false, true) is expected to yield {5}.
 */
@Test
public void testKeepDimsShape_2_F() {
    final long[] reduced =
            Shape.getReducedShape(new int[]{5, 5, 5}, new int[]{0, 0, 1}, false, true);

    assertArrayEquals(new long[]{5}, reduced);
}
 
Example 12
Source File: ShapeTestC.java    From nd4j with Apache License 2.0 5 votes vote down vote up
/**
 * Reducing shape {5, 5} over axes {1, 0, 1} with both boolean flags set to {@code true}
 * is expected to yield {1, 1}.
 */
@Test
public void testKeepDimsShape_1_T() throws Exception {
    final int[] inputShape = {5, 5};
    final int[] reduceAxes = {1, 0, 1};

    final long[] actual = Shape.getReducedShape(inputShape, reduceAxes, true, true);

    final long[] expected = {1, 1};
    assertArrayEquals(expected, actual);
}
 
Example 13
Source File: ShapeTestC.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Reducing shape {5, 5} over axes {0, 0, 1} with flags (false, true) is expected to yield
 * an empty shape array.
 */
@Test
public void testKeepDimsShape_1_F() {
    final int[] inputShape = {5, 5};
    final int[] reduceAxes = {0, 0, 1};

    final long[] actual = Shape.getReducedShape(inputShape, reduceAxes, false, true);

    final long[] expected = new long[0];
    assertArrayEquals(expected, actual);
}
 
Example 14
Source File: ShapeTestC.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Reducing shape {5, 5} over axes {1, 0, 1} with both boolean flags set to {@code true}
 * is expected to yield {1, 1}.
 */
@Test
public void testKeepDimsShape_1_T() {
    final long[] reduced =
            Shape.getReducedShape(new int[]{5, 5}, new int[]{1, 0, 1}, true, true);

    assertArrayEquals(new long[]{1, 1}, reduced);
}
 
Example 15
Source File: BaseAccumulation.java    From nd4j with Apache License 2.0 5 votes vote down vote up
/**
 * Calculates the output shape of this accumulation from the shape reported by {@code arg()}.
 *
 * @return an empty list when {@code arg().getShape()} is {@code null}, otherwise a
 *         single-element list holding the reduced shape
 * @throws ND4JIllegalStateException if the op has no arguments
 */
@Override
public List<long[]> calculateOutputShape() {
    final boolean noArguments = args().length < 1;
    if (noArguments) {
        throw new ND4JIllegalStateException("Unable to compute input shape. No arguments found.");
    }

    // Shape not known yet: nothing can be computed.
    if (arg().getShape() == null) {
        return Collections.emptyList();
    }

    final List<long[]> shapes = new ArrayList<>(1);
    shapes.add(Shape.getReducedShape(arg().getShape(), dimensions, isKeepDims(), newFormat));
    return shapes;
}
 
Example 16
Source File: ShapeTestC.java    From nd4j with Apache License 2.0 5 votes vote down vote up
/**
 * Reducing shape {1, 1} over axes {0, 0} with flags (false, true) is expected to yield {1}.
 */
@Test
public void testKeepDimsShape_3_F() throws Exception {
    final int[] inputShape = {1, 1};
    final int[] reduceAxes = {0, 0};

    final long[] actual = Shape.getReducedShape(inputShape, reduceAxes, false, true);
    log.info("Result: {}", actual);

    final long[] expected = {1};
    assertArrayEquals(expected, actual);
}
 
Example 17
Source File: ShapeTestC.java    From nd4j with Apache License 2.0 5 votes vote down vote up
/**
 * Reducing shape {1, 1} over axes {1, 0, 1} with both boolean flags set to {@code true}
 * is expected to yield {1, 1}.
 */
@Test
public void testKeepDimsShape_3_T() throws Exception {
    final long[] reduced =
            Shape.getReducedShape(new int[]{1, 1}, new int[]{1, 0, 1}, true, true);

    assertArrayEquals(new long[]{1, 1}, reduced);
}
 
Example 18
Source File: ShapeTestC.java    From nd4j with Apache License 2.0 5 votes vote down vote up
/**
 * Reducing shape {5, 5, 5} over axes {0, 0, 1} with flags (false, true) is expected to yield {5}.
 */
@Test
public void testKeepDimsShape_2_F() throws Exception {
    final int[] inputShape = {5, 5, 5};
    final int[] reduceAxes = {0, 0, 1};

    final long[] actual = Shape.getReducedShape(inputShape, reduceAxes, false, true);

    final long[] expected = {5};
    assertArrayEquals(expected, actual);
}
 
Example 19
Source File: ShapeTestC.java    From nd4j with Apache License 2.0 5 votes vote down vote up
/**
 * Reducing shape {5, 5, 5} over axes {1, 0, 1} with both boolean flags set to {@code true}
 * is expected to yield {1, 1, 5}.
 */
@Test
public void testKeepDimsShape_2_T() throws Exception {
    final long[] reduced =
            Shape.getReducedShape(new int[]{5, 5, 5}, new int[]{1, 0, 1}, true, true);

    assertArrayEquals(new long[]{1, 1, 5}, reduced);
}
 
Example 20
Source File: ShapeTestC.java    From nd4j with Apache License 2.0 5 votes vote down vote up
/**
 * Reducing shape {5, 5} over axes {0, 0, 1} with flags (false, true) is expected to yield
 * an empty shape array.
 */
@Test
public void testKeepDimsShape_1_F() throws Exception {
    final long[] reduced =
            Shape.getReducedShape(new int[]{5, 5}, new int[]{0, 0, 1}, false, true);

    assertArrayEquals(new long[0], reduced);
}