package org.nd4j.linalg.util; import lombok.extern.slf4j.Slf4j; import lombok.val; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.shape.Tile; import org.nd4j.linalg.api.shape.Shape; import org.nd4j.linalg.exception.ND4JIllegalStateException; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; /** * @author Adam Gibson */ @Slf4j @RunWith(Parameterized.class) public class ShapeTestC extends BaseNd4jTest { public ShapeTestC(Nd4jBackend backend) { super(backend); } @Test public void testToOffsetZero() { INDArray matrix = Nd4j.rand(3, 5); INDArray rowOne = matrix.getRow(1); INDArray row1Copy = Shape.toOffsetZero(rowOne); assertEquals(rowOne, row1Copy); INDArray rows = matrix.getRows(1, 2); INDArray rowsOffsetZero = Shape.toOffsetZero(rows); assertEquals(rows, rowsOffsetZero); INDArray tensor = Nd4j.rand(new int[] {3, 3, 3}); INDArray getTensor = tensor.slice(1).slice(1); INDArray getTensorZero = Shape.toOffsetZero(getTensor); assertEquals(getTensor, getTensorZero); } @Test public void testTile() { INDArray arr = Nd4j.scalar(1.0); //INDArray[] inputs, INDArray[] outputs, int[] axis INDArray result = Nd4j.createUninitialized(2,2); Tile tile = new Tile(new INDArray[]{arr},new INDArray[]{result},new int[] {2,2}); Nd4j.getExecutioner().exec(tile); INDArray tiled = Nd4j.tile(arr,2,2); assertEquals(tiled,result); } @Test public void testElementWiseCompareOnesInMiddle() { INDArray arr = Nd4j.linspace(1, 6, 6).reshape(2, 3); INDArray onesInMiddle = Nd4j.linspace(1, 6, 6).reshape(2, 1, 3); for (int i = 0; i < arr.length(); i++) assertEquals(arr.getDouble(i), onesInMiddle.getDouble(i), 1e-3); } @Test public void testKeepDimsShape_1_T() throws Exception { val shape = new int[]{5, 5}; val axis = new int[]{1, 0, 1}; val result = Shape.getReducedShape(shape, axis, true, true); assertArrayEquals(new long[]{1, 1}, result); } @Test public void testKeepDimsShape_1_F() throws Exception { val shape = new int[]{5, 5}; val axis = new int[]{0, 0, 1}; val result = Shape.getReducedShape(shape, axis, false, true); assertArrayEquals(new long[]{}, result); } @Test public void testKeepDimsShape_2_T() throws Exception { val shape = new int[]{5, 5, 5}; val axis = new int[]{1, 0, 1}; val result = Shape.getReducedShape(shape, axis, true, true); assertArrayEquals(new long[]{1, 1, 5}, result); } @Test public void testKeepDimsShape_2_F() throws Exception { val shape = new int[]{5, 5, 5}; val axis = new int[]{0, 0, 1}; val result = Shape.getReducedShape(shape, axis, false, true); assertArrayEquals(new long[]{5}, result); } @Test public void testKeepDimsShape_3_T() throws Exception { val shape = new int[]{1, 1}; val axis = new int[]{1, 0, 1}; val result = Shape.getReducedShape(shape, axis, true, true); assertArrayEquals(new long[]{1, 1}, result); } @Test public void testKeepDimsShape_3_F() throws Exception { val shape = new int[]{1, 1}; val axis = new int[]{0, 0}; val result = Shape.getReducedShape(shape, axis, false, true); log.info("Result: {}", result); assertArrayEquals(new long[]{1}, result); } @Test public void testKeepDimsShape_4_F() throws Exception { val shape = new int[]{4, 4}; val axis = new int[]{0, 0}; val result = Shape.getReducedShape(shape, axis, false, true); log.info("Result: {}", result); assertArrayEquals(new long[]{4}, result); } @Test public void testAxisNormalization_1() throws Exception { val axis = new int[] {1, -2}; val rank = 2; val exp = new int[] {0, 1}; val norm = Shape.normalizeAxis(rank, axis); assertArrayEquals(exp, norm); } @Test public void testAxisNormalization_2() throws Exception { val axis = new int[] {1, -2, 0}; val rank = 2; val exp = new int[] {0, 1}; val norm = Shape.normalizeAxis(rank, axis); assertArrayEquals(exp, norm); } @Test(expected = ND4JIllegalStateException.class) public void testAxisNormalization_3() throws Exception { val axis = new int[] {1, -2, 2}; val rank = 2; val exp = new int[] {0, 1}; val norm = Shape.normalizeAxis(rank, axis); assertArrayEquals(exp, norm); } @Test public void testAxisNormalization_4() throws Exception { val axis = new int[] {1, 2, 0}; val rank = 3; val exp = new int[] {0, 1, 2}; val norm = Shape.normalizeAxis(rank, axis); assertArrayEquals(exp, norm); } @Override public char ordering() { return 'c'; } }