org.nd4j.linalg.indexing.SpecifiedIndex Java Examples

The following examples show how to use org.nd4j.linalg.indexing.SpecifiedIndex. You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may check out the related API usage on the sidebar.
Example #1
Source File: SlicingTestsC.java — from nd4j (Apache License 2.0), 6 votes
@Test
public void testGetRow() {
    // Selecting the second row of a 2x3 matrix via getRow and via point/all indexing must agree.
    INDArray matrix = Nd4j.linspace(1, 6, 6).reshape(2, 3);
    INDArray rowByGetRow = matrix.getRow(1);
    INDArray rowByIndex = matrix.get(NDArrayIndex.point(1), NDArrayIndex.all());
    assertEquals(Nd4j.create(new double[] {4, 5, 6}), rowByGetRow);
    assertEquals(rowByGetRow, rowByIndex);

    // Values written through the indexed view must read back unchanged.
    rowByIndex.assign(Nd4j.linspace(1, 3, 3));
    assertEquals(Nd4j.linspace(1, 3, 3), rowByIndex);

    // A SpecifiedIndex on the first axis picks rows 1 and 2 of a 3x3 matrix holding 1..9.
    INDArray expectedRows = Nd4j.create(new double[][] {{4, 5, 6}, {7, 8, 9}});
    INDArray selectedRows = Nd4j.linspace(1, 9, 9).reshape(3, 3)
            .get(new SpecifiedIndex(1, 2), NDArrayIndex.all());
    assertEquals(expectedRows, selectedRows);
}
 
Example #2
Source File: SlicingTestsC.java — from deeplearning4j (Apache License 2.0), 6 votes
@Test
public void testGetRow() {
    // getRow(1) and a point/all index must both yield the second row of a 2x3 double matrix.
    INDArray matrix = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(2, 3);
    INDArray rowByGetRow = matrix.getRow(1);
    INDArray rowByIndex = matrix.get(NDArrayIndex.point(1), NDArrayIndex.all());
    assertEquals(Nd4j.create(new double[] {4, 5, 6}), rowByGetRow);
    assertEquals(rowByGetRow, rowByIndex);

    // Writing through the indexed view, then reading it back.
    rowByIndex.assign(Nd4j.linspace(1, 3, 3, DataType.DOUBLE));
    assertEquals(Nd4j.linspace(1, 3, 3, DataType.DOUBLE), rowByIndex);

    // SpecifiedIndex(1, 2) on the row axis keeps the last two rows of a 3x3 matrix holding 1..9.
    INDArray expectedRows = Nd4j.create(new double[][] {{4, 5, 6}, {7, 8, 9}});
    INDArray selectedRows = Nd4j.linspace(1, 9, 9, DataType.DOUBLE).reshape(3, 3)
            .get(new SpecifiedIndex(1, 2), NDArrayIndex.all());
    assertEquals(expectedRows, selectedRows);
}
 
Example #3
Source File: IndexingTestsC.java — from nd4j (Apache License 2.0), 5 votes
@Test
public void testIntervalLowerBound() {
    // From a 4x2x3 tensor holding 1..24: slices 1 and 2 (interval end is exclusive),
    // row 0 only, columns 0 and 2.
    INDArray expected = Nd4j.create(new double[][] {{7, 9}, {13, 15}});
    INDArray cube = Nd4j.linspace(1, 24, 24).reshape(4, 2, 3);
    SpecifiedIndex firstRow = new SpecifiedIndex(new int[] {0});
    SpecifiedIndex outerColumns = new SpecifiedIndex(new int[] {0, 2});

    assertEquals(expected, cube.get(interval(1, 3), firstRow, outerColumns));
}
 
Example #4
Source File: IndexingTestsC.java — from nd4j (Apache License 2.0), 5 votes
@Test
public void testSpecifiedIndexVector() {
    // 4x4 matrix holding 1..16; the expected values below show row-major layout.
    // (The previously created but never used 2x2x2x2 array has been removed.)
    INDArray rootMatrix = Nd4j.linspace(1, 16, 16).reshape(4, 4);

    // Columns 0 and 2 of every row.
    INDArray get = rootMatrix.get(all(), new SpecifiedIndex(0, 2));
    INDArray assertion = Nd4j.create(new double[][] {{1, 3}, {5, 7}, {9, 11}, {13, 15}});
    assertEquals(assertion, get);

    // Columns 0, 2 and 3 of every row.
    INDArray assertion2 = Nd4j.create(new double[][] {{1, 3, 4}, {5, 7, 8}, {9, 11, 12}, {13, 15, 16}});
    INDArray get2 = rootMatrix.get(all(), new SpecifiedIndex(0, 2, 3));
    assertEquals(assertion2, get2);
}
 
Example #5
Source File: IndexingTestsC.java — from nd4j (Apache License 2.0), 5 votes
@Test
public void testGetRows() {
    // Rows 1-2 crossed with columns 0-1 of a 3x3 matrix holding 1..9.
    INDArray expected = Nd4j.create(new double[][] {{4, 5}, {7, 8}});
    INDArray source = Nd4j.linspace(1, 9, 9).reshape(3, 3);
    INDArray block = source.get(new SpecifiedIndex(1, 2), new SpecifiedIndex(0, 1));
    assertEquals(expected, block);
}
 
Example #6
Source File: IndexingTestsC.java — from nd4j (Apache License 2.0), 5 votes
@Test
public void testMultiRow() {
    INDArray expected = Nd4j.create(new double[][] {{4, 7}});
    // Rows 1 and 2, column 0 only, of a 3x3 matrix holding 1..9.
    INDArray selected = Nd4j.linspace(1, 9, 9).reshape(3, 3)
            .get(new SpecifiedIndex(1, 2), NDArrayIndex.interval(0, 1));
    // NOTE(review): the selection covers two rows of one column, while the expectation is
    // written as a single row {4, 7}; this relies on INDArray equality accepting that — confirm.
    assertEquals(expected, selected);
}
 
Example #7
Source File: IndexingTests.java — from nd4j (Apache License 2.0), 5 votes
@Test
public void testGetRows() {
    // Rows 1-2 crossed with columns 1-2 of a 3x3 matrix built from 1..9.
    // NOTE(review): the expected values imply column-major ('f') element order in this
    // test class — confirm against the class's ordering setup.
    INDArray source = Nd4j.linspace(1, 9, 9).reshape(3, 3);
    INDArray corner = source.get(new SpecifiedIndex(1, 2), new SpecifiedIndex(1, 2));
    assertEquals(Nd4j.create(new double[][] {{5, 8}, {6, 9}}), corner);
}
 
Example #8
Source File: BaseSparseNDArray.java — from nd4j (Apache License 2.0), 5 votes
@Override
public INDArray put(INDArray indices, INDArray element) {
    // Turn each slice of the index array into a SpecifiedIndex and delegate to put(INDArrayIndex[], ...).
    // NOTE(review): the number of indices is taken from indices.rank(), not indices.rows();
    // for a 2-D index matrix that is always 2 — confirm this is intended.
    int count = indices.rank();
    INDArrayIndex[] specified = new INDArrayIndex[count];
    for (int dim = 0; dim < count; dim++) {
        // dup() copies the slice first, presumably so data() holds only this slice's values.
        int[] positions = indices.slice(dim).dup().data().asInt();
        specified[dim] = new SpecifiedIndex(positions);
    }
    return put(specified, element);
}
 
Example #9
Source File: IndexingTestsC.java — from deeplearning4j (Apache License 2.0), 5 votes
@Test
public void testGetRows() {
    // Rows 1-2 crossed with columns 0-1 of a 3x3 double matrix holding 1..9.
    INDArray expected = Nd4j.create(new double[][] {{4, 5}, {7, 8}});
    INDArray source = Nd4j.linspace(1, 9, 9, DataType.DOUBLE).reshape(3, 3);
    INDArray block = source.get(new SpecifiedIndex(1, 2), new SpecifiedIndex(0, 1));
    assertEquals(expected, block);
}
 
Example #10
Source File: IndexingTestsC.java — from deeplearning4j (Apache License 2.0), 5 votes
@Test
public void testMultiRow() {
    INDArray expected = Nd4j.create(new double[][] {{4, 7}});
    // Rows 1 and 2, column 0 only, of a 3x3 double matrix holding 1..9.
    INDArray selected = Nd4j.linspace(1, 9, 9, DataType.DOUBLE).reshape(3, 3)
            .get(new SpecifiedIndex(1, 2), NDArrayIndex.interval(0, 1));
    // NOTE(review): the selection covers two rows of one column, while the expectation is
    // written as a single row {4, 7}; this relies on INDArray equality accepting that — confirm.
    assertEquals(expected, selected);
}
 
Example #11
Source File: IndexingTests.java — from deeplearning4j (Apache License 2.0), 5 votes
@Test
public void testGetRows() {
    // Rows 1-2 crossed with columns 1-2 of a 3x3 double matrix built from 1..9.
    // NOTE(review): the expected values imply column-major ('f') element order in this
    // test class — confirm against the class's ordering setup.
    INDArray source = Nd4j.linspace(1, 9, 9, DataType.DOUBLE).reshape(3, 3);
    INDArray corner = source.get(new SpecifiedIndex(1, 2), new SpecifiedIndex(1, 2));
    assertEquals(Nd4j.create(new double[][] {{5, 8}, {6, 9}}), corner);
}
 
Example #12
Source File: MtcnnService.java — from mtcnn-java (Apache License 2.0), 4 votes
/**
 *  STAGE 2: refinement stage.
 *
 * Crops every candidate box out of {@code image}, runs the crops through the refine-net graph,
 * keeps the boxes whose score is above {@code stepsThreshold[1]}, and then applies non-max
 * suppression (0.7, Union), bounding-box regression ({@code MtcnnUtil.bbreg}) and
 * {@code MtcnnUtil.rerec}. The inline comments quote the equivalent Python statements
 * that this code follows.
 *
 * @param image      the image the candidate boxes refer to; cropped via {@code computeTempImage}
 * @param totalBoxes candidate boxes, one per row; columns 0..3 are carried forward as the box
 *                   (presumably x1, y1, x2, y2 — TODO confirm)
 * @param padResult  padding information forwarded to {@code computeTempImage}
 * @return the refined boxes: columns 0..3 of the surviving boxes with their score appended as an
 *         extra column, after NMS, regression and rerec; {@code totalBoxes} unchanged when it is
 *         empty or has no rows; {@code Nd4j.empty()} when no box passes the score threshold
 * @throws IOException declared by this method; presumably propagated from cropping or the
 *         graph run — confirm
 */
private INDArray refinementStage(INDArray image, INDArray totalBoxes, MtcnnUtil.PadResult padResult) throws IOException {

	// num_boxes = total_boxes.shape[0]
	int numBoxes = totalBoxes.isEmpty() ? 0 : (int) totalBoxes.shape()[0];
	// if num_boxes == 0:
	//   return total_boxes, stage_status
	if (numBoxes == 0) {
		return totalBoxes;
	}

	// Crop each candidate box out of the image; 24 is presumably the R-Net input side length — confirm.
	INDArray tempImg1 = computeTempImage(image, numBoxes, padResult, 24);

	// Earlier execution path through refineNetGraph, kept commented out for reference:
	//this.refineNetGraph.associateArrayWithVariable(tempImg1, this.refineNetGraph.variableMap().get("rnet/input"));
	//List<DifferentialFunction> refineNetResults = this.refineNetGraph.exec().getRight();
	//INDArray out0 = refineNetResults.stream().filter(df -> df.getOwnName().equalsIgnoreCase("rnet/fc2-2/fc2-2"))
	//		.findFirst().get().outputVariable().getArr();
	//INDArray out1 = refineNetResults.stream().filter(df -> df.getOwnName().equalsIgnoreCase("rnet/prob1"))
	//		.findFirst().get().outputVariable().getArr();

	Map<String, INDArray> resultMap = this.refineNetGraphRunner.run(Collections.singletonMap("rnet/input", tempImg1));
	// out0 feeds the bounding-box regression below; out1 supplies the per-box scores.
	//INDArray out0 = resultMap.get("rnet/fc2-2/fc2-2");  // for ipazc/mtcnn model
	INDArray out0 = resultMap.get("rnet/conv5-2/conv5-2");
	INDArray out1 = resultMap.get("rnet/prob1");

	//  score = out1[1, :]
	// Column 1 of out1 is used as the per-box score.
	INDArray score = out1.get(all(), point(1)).transposei();

	// ipass = np.where(score > self.__steps_threshold[1])
	// Indices of the boxes whose score exceeds the stage-2 threshold.
	INDArray ipass = MtcnnUtil.getIndexWhereVector(score.transpose(), s -> s > stepsThreshold[1]);
	//INDArray ipass = MtcnnUtil.getIndexWhereVector2(score.transpose(), Conditions.greaterThan(stepsThreshold[1]));

	if (ipass.isEmpty()) {
		totalBoxes = Nd4j.empty();
		return totalBoxes;
	}
	// total_boxes = np.hstack([total_boxes[ipass[0], 0:4].copy(), np.expand_dims(score[ipass].copy(), 1)])
	// Keep columns 0..3 of the passing boxes and append each box's score as an extra column.
	INDArray b1 = totalBoxes.get(new SpecifiedIndex(ipass.toLongVector()), interval(0, 4));
	// A single passing index is special-cased: its score is reshaped to 1x1 rather than expanded.
	INDArray b2 = ipass.isScalar() ? score.get(ipass).reshape(1, 1)
			: Nd4j.expandDims(score.get(ipass), 1);
	totalBoxes = Nd4j.hstack(b1, b2);

	// mv = out0[:, ipass[0]]
	// Rows of out0 for the passing boxes, transposed so each box becomes a column (selected by column below).
	INDArray mv = out0.get(new SpecifiedIndex(ipass.toLongVector()), all()).transposei();

	// if total_boxes.shape[0] > 0:
	if (!totalBoxes.isEmpty() && totalBoxes.shape()[0] > 0) {
		// pick = self.__nms(total_boxes, 0.7, 'Union')
		// dup() presumably shields totalBoxes from in-place changes inside nonMaxSuppression — confirm.
		INDArray pick = MtcnnUtil.nonMaxSuppression(totalBoxes.dup(), 0.7, MtcnnUtil.NonMaxSuppressionType.Union).transpose();

		// total_boxes = total_boxes[pick, :]
		totalBoxes = totalBoxes.get(new SpecifiedIndex(pick.toLongVector()), all());

		// total_boxes = self.__bbreg(total_boxes.copy(), np.transpose(mv[:, pick]))
		totalBoxes = MtcnnUtil.bbreg(totalBoxes, mv.get(all(), new SpecifiedIndex(pick.toLongVector())).transpose());

		// total_boxes = self.__rerec(total_boxes.copy())
		totalBoxes = MtcnnUtil.rerec(totalBoxes, false);
	}

	return totalBoxes;
}
 
Example #13
Source File: MtcnnUtil.java — from mtcnn-java (Apache License 2.0), 4 votes
private static INDArrayIndex[] toUpdateIndex(INDArray array) {
	// Wrap the array's values (read as longs) into a single-element index array.
	long[] positions = array.toLongVector();
	INDArrayIndex only = new SpecifiedIndex(positions);
	return new INDArrayIndex[] { only };
}
 
Example #14
Source File: BaseSparseNDArray.java — from nd4j (Apache License 2.0), 4 votes
/**
 * Writes {@code element} into this array at the given indices.
 *
 * <p>When there is one index list per dimension ({@code indices.size() == rank()}), the values of
 * {@code element} are written one by one, in the order of an {@code NdIndexIterator} over
 * {@code element}'s shape. They go to the coordinates produced by {@code SpecifiedIndex.iterate}.
 *
 * <p>Otherwise, every integer in every list is treated as a slice index. The whole of
 * {@code element} is assigned to that slice. An empty {@code indices} list is a no-op.
 *
 * @param indices per-dimension index lists, or slice indices (see above)
 * @param element the values to write
 * @return this array
 */
@Override
public INDArray put(List<List<Integer>> indices, INDArray element) {
    if(indices.size() == rank()) {
        NdIndexIterator ndIndexIterator = new NdIndexIterator(element.shape());
        INDArrayIndex[] indArrayIndices = new INDArrayIndex[indices.size()];
        for(int i = 0; i < indArrayIndices.length; i++) {
            indArrayIndices[i] = new SpecifiedIndex(Ints.toArray(indices.get(i)));
        }
        Generator<List<List<Long>>> iterate = SpecifiedIndex.iterate(indArrayIndices);
        while(true) {
            List<List<Long>> next;
            try {
                next = iterate.next();
            }
            catch(NoSuchElementException e) {
                // The generator signals exhaustion by throwing from next(). Only this call is
                // guarded, so failures in putScalar / the element iterator are no longer hidden.
                break;
            }
            for(List<Long> coords : next) {
                putScalar(Ints.toArray(coords), element.getDouble(ndIndexIterator.next()));
            }
        }
    }
    else if(!indices.isEmpty()) {
        // Fewer index lists than dimensions: assign element to every listed slice.
        // Previously the single-list case only collected slices into an unused list (and looked
        // at just the first entry), so element was never written; it now matches the multi-list case.
        for(List<Integer> row : indices) {
            for(int sliceIndex : row) {
                INDArray slice = slice(sliceIndex);
                Nd4j.getExecutioner().exec(new Assign(new INDArray[]{slice, element}, new INDArray[]{slice}));
            }
        }
    }

    return this;
}