Java Code Examples for org.nd4j.linalg.factory.Nd4j#expandDims()

The following examples show how to use org.nd4j.linalg.factory.Nd4j#expandDims(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may check out the related API usage on the sidebar.
Example 1
Source File: Gan4Exemple.java    From dl4j-tutorials with MIT License 6 votes vote down vote up
/**
 * Copies a flat pixel buffer into an INDArray and returns it reshaped to
 * {@code {channel, height, width}} with an extra leading dimension of size 1.
 *
 * <p>The buffer is read plane by plane: plane {@code c} starts at offset
 * {@code c * width * height}. Only the first three planes are copied, and the
 * second and third only when {@code channel} is large enough.
 *
 * @param dataDb flat pixel values, one plane of {@code width * height} values per channel
 * @return the values as an array of shape {@code {1, channel, height, width}}
 */
private static INDArray toINDArray(double[] dataDb) {
    final int planeSize = width * height;
    final INDArray flat = Nd4j.create(1, planeSize * channel);

    for (int row = 0; row < height; row++) {
        final int rowStart = row * width;
        for (int col = 0; col < width; col++) {
            final int first = rowStart + col;
            flat.putScalar(first, dataDb[first]);

            if (channel > 1) {
                final int second = first + planeSize;
                flat.putScalar(second, dataDb[second]);
            }
            if (channel > 2) {
                final int third = first + planeSize * 2;
                flat.putScalar(third, dataDb[third]);
            }
        }
    }

    final INDArray image = flat.reshape(new int[] {channel, height, width});
    return Nd4j.expandDims(image, 0);
}
 
Example 2
Source File: Gan4Exemple.java    From dl4j-tutorials with MIT License 6 votes vote down vote up
/**
 * Converts packed colour pixels into an INDArray and returns it reshaped to
 * {@code {channel, height, width}} with an extra leading dimension of size 1.
 *
 * <p>Each pixel is split with {@code trimRGBColor}; entries 1, 2 and 3 of the
 * result (masked to 0-255) are written to planes 0, 1 and 2 respectively.
 * Planes 1 and 2 are only written when {@code channel} provides room for them,
 * matching the guards used by the {@code double[]} overload. Previously they
 * were written unconditionally, which indexed past the end of the buffer when
 * {@code channel < 3}.
 *
 * @param colorArr packed pixel colours in row-major order; assumed to hold
 *                 {@code width * height} entries, since its length is used as
 *                 the plane stride (TODO confirm with callers)
 * @return the pixel data as an array of shape {@code {1, channel, height, width}}
 */
private static INDArray toINDArray(int[] colorArr) {
    int[] shape = new int[] {channel, height, width};

    INDArray ret2 = Nd4j.create(1, width * height * channel);
    // Stride between consecutive channel planes in the flat buffer.
    final int planeStride = colorArr.length;

    for (int y = 0; y < height; y++) {
        for (int x = 0; x < width; x++) {
            int idx = y * width + x;
            int[] argb = trimRGBColor(colorArr[idx]);
            ret2.putScalar(idx, (argb[1]) & 0xFF);
            // The buffer only holds `channel` planes, so skip planes that do not exist.
            if (channel > 1) {
                ret2.putScalar(idx + planeStride, (argb[2]) & 0xFF);
            }
            if (channel > 2) {
                ret2.putScalar(idx + planeStride * 2, (argb[3]) & 0xFF);
            }
        }
    }
    return Nd4j.expandDims(ret2.reshape(shape), 0);
}
 
Example 3
Source File: ImageUtils.java    From dl4j-tutorials with MIT License 5 votes vote down vote up
/**
 * Converts a {@link BufferedImage} into an INDArray of shape
 * {@code {1, bands, height, width}}, where {@code bands} is the number of
 * bands reported by the image raster.
 *
 * <p>Pixels are fetched with {@code getRGB} and split with
 * {@code trimRGBColor}; entries 1, 2 and 3 of the split colour (masked to
 * 0-255) go to planes 0, 1 and 2. Planes 1 and 2 are only filled when the
 * image has that many bands.
 *
 * @param image the image to convert
 * @return the pixel data with a leading dimension of size 1
 */
public static INDArray toINDArrayBGR(BufferedImage image) {
    final int rows = image.getHeight();
    final int cols = image.getWidth();
    final int channels = image.getRaster().getNumBands();

    final int[] rgb = getRGB(image, 0, 0, cols, rows, new int[cols * rows]);
    final int planeStride = rgb.length;

    final INDArray flat = Nd4j.create(1, cols * rows * channels);

    for (int row = 0; row < rows; row++) {
        final int rowStart = row * cols;
        for (int col = 0; col < cols; col++) {
            final int offset = rowStart + col;
            final int[] parts = trimRGBColor(rgb[offset]);

            flat.putScalar(offset, (parts[1]) & 0xFF);
            if (channels > 1) {
                flat.putScalar(offset + planeStride, (parts[2]) & 0xFF);
            }
            if (channels > 2) {
                flat.putScalar(offset + planeStride * 2, (parts[3]) & 0xFF);
            }
        }
    }

    final INDArray shaped = flat.reshape(new int[] {channels, rows, cols});
    return Nd4j.expandDims(shaped, 0);
}
 
Example 4
Source File: MtcnnService.java    From mtcnn-java with Apache License 2.0 4 votes vote down vote up
/**
 * STAGE 2 (refinement) of the detection pipeline.
 *
 * <p>Runs the candidate boxes through the refine network ("rnet"), keeps the boxes whose score
 * is greater than {@code stepsThreshold[1]}, appends that score as a fifth column, and then
 * applies non-max suppression (threshold 0.7, Union), bounding-box regression and
 * {@code rerec} to the survivors. The inline {@code //} comments quote the Python reference
 * implementation that each Java statement mirrors.
 *
 * @param image the input image, forwarded to {@code computeTempImage}
 * @param totalBoxes candidate boxes, one per row; only columns 0-3 are carried over
 *                   (presumably the box coordinates - confirm against the caller)
 * @param padResult padding information, forwarded to {@code computeTempImage}
 * @return the refined boxes; {@code totalBoxes} itself when it contains no boxes; or
 *         {@code Nd4j.empty()} when no score passes the threshold
 * @throws IOException declared by the signature; NOTE(review): presumably raised while
 *                     preparing or running the network input - confirm
 */
private INDArray refinementStage(INDArray image, INDArray totalBoxes, MtcnnUtil.PadResult padResult) throws IOException {

	// num_boxes = total_boxes.shape[0]
	int numBoxes = totalBoxes.isEmpty() ? 0 : (int) totalBoxes.shape()[0];
	// if num_boxes == 0:
	//   return total_boxes, stage_status
	if (numBoxes == 0) {
		return totalBoxes;
	}

	// Build the network input for all boxes; 24 is presumably the rnet input size - TODO confirm.
	INDArray tempImg1 = computeTempImage(image, numBoxes, padResult, 24);

	// Earlier SameDiff-based invocation, kept for reference:
	//this.refineNetGraph.associateArrayWithVariable(tempImg1, this.refineNetGraph.variableMap().get("rnet/input"));
	//List<DifferentialFunction> refineNetResults = this.refineNetGraph.exec().getRight();
	//INDArray out0 = refineNetResults.stream().filter(df -> df.getOwnName().equalsIgnoreCase("rnet/fc2-2/fc2-2"))
	//		.findFirst().get().outputVariable().getArr();
	//INDArray out1 = refineNetResults.stream().filter(df -> df.getOwnName().equalsIgnoreCase("rnet/prob1"))
	//		.findFirst().get().outputVariable().getArr();

	// Run the refine network; out0 is later used as the regression input (mv), out1 provides the scores.
	Map<String, INDArray> resultMap = this.refineNetGraphRunner.run(Collections.singletonMap("rnet/input", tempImg1));
	//INDArray out0 = resultMap.get("rnet/fc2-2/fc2-2");  // for ipazc/mtcnn model
	INDArray out0 = resultMap.get("rnet/conv5-2/conv5-2");
	INDArray out1 = resultMap.get("rnet/prob1");

	//  score = out1[1, :]
	// Java side: take column 1 of out1 for every row, then transpose in place.
	INDArray score = out1.get(all(), point(1)).transposei();

	// ipass = np.where(score > self.__steps_threshold[1])
	// Indices of the boxes whose score is strictly greater than the stage-2 threshold.
	INDArray ipass = MtcnnUtil.getIndexWhereVector(score.transpose(), s -> s > stepsThreshold[1]);
	//INDArray ipass = MtcnnUtil.getIndexWhereVector2(score.transpose(), Conditions.greaterThan(stepsThreshold[1]));

	// Nothing passed the threshold: report an empty result.
	if (ipass.isEmpty()) {
		totalBoxes = Nd4j.empty();
		return totalBoxes;
	}
	// total_boxes = np.hstack([total_boxes[ipass[0], 0:4].copy(), np.expand_dims(score[ipass].copy(), 1)])
	// b1: columns 0-3 of the passing boxes.
	INDArray b1 = totalBoxes.get(new SpecifiedIndex(ipass.toLongVector()), interval(0, 4));
	// b2: the passing scores as a column; a single passing index is reshaped to 1x1 explicitly.
	INDArray b2 = ipass.isScalar() ? score.get(ipass).reshape(1, 1)
			: Nd4j.expandDims(score.get(ipass), 1);
	totalBoxes = Nd4j.hstack(b1, b2);

	// mv = out0[:, ipass[0]]
	// Java side: select the passing rows of out0, then transpose in place.
	INDArray mv = out0.get(new SpecifiedIndex(ipass.toLongVector()), all()).transposei();

	// if total_boxes.shape[0] > 0:
	if (!totalBoxes.isEmpty() && totalBoxes.shape()[0] > 0) {
		// pick = self.__nms(total_boxes, 0.7, 'Union')
		// NMS receives a copy (dup) of totalBoxes.
		INDArray pick = MtcnnUtil.nonMaxSuppression(totalBoxes.dup(), 0.7, MtcnnUtil.NonMaxSuppressionType.Union).transpose();

		// total_boxes = total_boxes[pick, :]
		totalBoxes = totalBoxes.get(new SpecifiedIndex(pick.toLongVector()), all());

		// total_boxes = self.__bbreg(total_boxes.copy(), np.transpose(mv[:, pick]))
		totalBoxes = MtcnnUtil.bbreg(totalBoxes, mv.get(all(), new SpecifiedIndex(pick.toLongVector())).transpose());

		// total_boxes = self.__rerec(total_boxes.copy())
		totalBoxes = MtcnnUtil.rerec(totalBoxes, false);
	}

	return totalBoxes;
}
 
Example 5
Source File: SameDiffTests.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Builds a two-step "manual recurrence" (each time slice adds the previous
 * slice's output) once with eager INDArray ops and once as a SameDiff graph,
 * then validates the graph's "out" value, its serialization and its gradients.
 */
@Test
public void testMultiGradientManualRecurrent() {
    // Expected value, computed eagerly.
    final INDArray input = Nd4j.rand(DataType.DOUBLE, new int[]{3, 4, 2});
    final int sliceCount = (int) input.size(2);
    final INDArray[] eagerSlices = new INDArray[sliceCount];
    for (int step = 0; step < sliceCount; step++) {
        INDArray current = input.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.point(step));
        if (step > 0) {
            // Previous output carries a trailing unit dimension; drop it before adding.
            current = current.add(Nd4j.squeeze(eagerSlices[step - 1], 2));
        }
        eagerSlices[step] = Nd4j.expandDims(current, 2);
    }
    final INDArray expected = Nd4j.concat(2, eagerSlices).norm2();

    // Same computation expressed as a SameDiff graph.
    SameDiff sd = SameDiff.create();
    final SDVariable sdInput = sd.var("input", input);

    final int timeSteps = (int) sdInput.getShape()[2];
    final SDVariable[] unstacked = sd.unstack(new String[]{"X_0", "X_1"}, sdInput, 2, 2);
    final SDVariable[] graphSlices = new SDVariable[timeSteps];

    graphSlices[0] = sd.expandDims("X_0-e", unstacked[0], 2);

    final SDVariable previous = sd.squeeze("X_0-s", graphSlices[0], 2);
    graphSlices[1] = sd.expandDims("X_1-e", unstacked[1].add(previous), 2);

    sd.concat(2, graphSlices).norm2("out");

    final TestCase testCase = new TestCase(sd)
            .testFlatBufferSerialization(TestCase.TestSerialization.BOTH)
            .expectedOutput("out", expected)
            .gradientCheck(true);
    String err = OpValidation.validate(testCase);

    assertNull(err);
}
 
Example 6
Source File: LossWasserstein.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Computes the score array by averaging the values returned by
 * {@code scoreArray} along dimension 1, then re-inserting a unit dimension at
 * index 1.
 *
 * @param labels       passed through to {@code scoreArray}
 * @param preOutput    passed through to {@code scoreArray}
 * @param activationFn passed through to {@code scoreArray}
 * @param mask         passed through to {@code scoreArray}
 * @return the dimension-1 mean of the scores with a unit dimension restored at index 1
 */
@Override
public INDArray computeScoreArray(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) {
    final INDArray meanAlongDim1 = scoreArray(labels, preOutput, activationFn, mask).mean(1);
    return Nd4j.expandDims(meanAlongDim1, 1);
}
 
Example 7
Source File: INDArrayHelper.java    From deeplearning4j with Apache License 2.0 3 votes vote down vote up
/**
 * Force the input source to have the correct shape:
 * <ul>
 *     <li>DL4J requires it to be at least 2D</li>
 *     <li>RL4J has a convention to have the batch size on dimension 0 to all INDArrays</li>
 * </ul>
 *
 * <p>An array whose rank is greater than 1 and whose first dimension is 1 is
 * returned as is. Every other array gets a new leading dimension of size 1
 * via {@code Nd4j.expandDims(source, 0)}.
 *
 * @param source The {@link INDArray} to be corrected.
 * @return The corrected INDArray
 */
public static INDArray forceCorrectShape(INDArray source) {
    // Check the rank before touching shape()[0]: a rank-0 array has an empty
    // shape, so indexing it first would throw ArrayIndexOutOfBoundsException.
    final boolean hasBatchDimension = source.rank() > 1 && source.shape()[0] == 1;

    return hasBatchDimension
            ? source
            : Nd4j.expandDims(source, 0);
}