Java Code Examples for org.nd4j.autodiff.samediff.SameDiff#expandDims()

The following examples show how to use org.nd4j.autodiff.samediff.SameDiff#expandDims().
Example 1
/**
 * Exercises {@code SameDiff#expandDims} for each of the three insertion axes of a {@code [3, 4]} input.
 * <p>
 * For every axis {@code i} in {@code 0..2} and every test matrix returned by
 * {@code NDArrayCreationUtil.getAllTestMatricesWithShape}, the test builds the graph
 * {@code in -> expandDims(in, i) -> standardDeviation("out", ...)}, registers the expected
 * outputs on a {@code TestCase} and passes it to {@code OpValidation.validate}. The final
 * assertion requires the {@code failed} list to be empty.
 * <p>
 * NOTE(review): this excerpt is visibly truncated (unbalanced braces, half-written statements).
 * The notes below mark each gap; restore the missing text from the upstream source before
 * compiling. Presumably the validation includes a gradient check, given the test name - confirm.
 */
public void testExpandDimsGradient() {
    val origShape = new long[]{3, 4};

    // Collects validation error messages; asserted empty at the end of the test.
    List<String> failed = new ArrayList<>();

    // NOTE(review): 'first' is never read anywhere in the visible excerpt.
    boolean first = true;
    for (int i = 0; i < 3; i++) {

        // Expected shape once a size-1 dimension has been inserted at axis i.
        long[] expExpandShape;
        // NOTE(review): no 'break' statements, no 'default:' label and no closing brace are
        // visible. As shown, every case falls through to the throw below - these lines
        // appear to have been lost in extraction; confirm against the upstream source.
        switch (i) {
            case 0:
                expExpandShape = new long[]{1, 3, 4};
            case 1:
                expExpandShape = new long[]{3, 1, 4};
            case 2:
                expExpandShape = new long[]{3, 4, 1};
                throw new RuntimeException();

        for (Pair<INDArray, String> p : NDArrayCreationUtil.getAllTestMatricesWithShape(origShape[0], origShape[1], 12345, DataType.DOUBLE)) {
            INDArray inArr = p.getFirst().muli(100);

            SameDiff sd = SameDiff.create();
            SDVariable in = sd.var("in", inArr);
            SDVariable expand = sd.expandDims(in, i);
            //Using stdev here: mean/sum would backprop the same gradient for each input...
            SDVariable stdev = sd.standardDeviation("out", expand, true);

            Map<String,INDArray> m = sd.outputAll(null);
            // Expected value for the "out" variable, computed directly on the input array.
            INDArray expOut = in.getArr().std(true);

            // NOTE(review): the argument to m.get(...) and the rest of this assertion are
            // missing; presumably it compared the shape of the expandDims output - confirm.
            assertArrayEquals(expExpandShape, m.get(;
            // Expected expandDims result: a 'c'-order copy of the input reshaped to the expected shape.
            INDArray expExpand = inArr.dup('c').reshape(expExpandShape);

            String msg = "expandDim=" + i + ", source=" + p.getSecond();
            // NOTE(review): the start of the next statement (the call that receives this
            // string) is missing from the excerpt.
  "Starting: " + msg);

            // NOTE(review): the statement below ends with ';' but is followed by a dangling
            // method chain, and the first argument of the second expectedOutput call is
            // missing - at least one line was dropped here.
            TestCase tc = new TestCase(sd);
                    .expectedOutput("out", expOut)
                    .expectedOutput(, expExpand);

            String error = OpValidation.validate(tc);
            // NOTE(review): the body of this 'if' and the closing braces of the 'if', both
            // loops and the method are not visible; presumably the error is added to 'failed'.
            if(error != null){
    assertEquals(failed.toString(), 0, failed.size());
Example 2
/**
 * Builds the SameDiff graph for this layer by processing the input one slice at a time.
 * <p>
 * The input is unstacked along dimension 2 into {@code shape[2]} slices. Each slice is multiplied
 * by the weight matrix (and has the bias added). From the second slice onward, an attention
 * result computed from the previous step's output and the full layer input is squeezed on
 * dimension 2, multiplied by the recurrent weights and added in. The activation is then
 * applied, a dimension is re-inserted at axis 2, and all slices are concatenated along dimension 2.
 * <p>
 * Side effect: sets {@code this.timeSteps} to the number of unstacked slices.
 * <p>
 * NOTE(review): this excerpt is visibly truncated (missing guard conditions and closing braces);
 * see the inline notes and restore from the upstream source before compiling.
 *
 * @param sameDiff   graph instance the operations are added to
 * @param layerInput input variable; its shape must be available, and index 2 of the shape is
 *                   used as the number of slices
 * @param paramTable parameters looked up by the layer's weight/bias key constants
 * @param mask       mask variable forwarded unchanged to the attention ops
 * @return the per-step outputs concatenated along dimension 2
 */
public SDVariable defineLayer(SameDiff sameDiff, SDVariable layerInput, Map<String, SDVariable> paramTable, SDVariable mask) {
    final val W = paramTable.get(WEIGHT_KEY);
    final val R = paramTable.get(RECURRENT_WEIGHT_KEY);
    final val b = paramTable.get(BIAS_KEY);

    long[] shape = layerInput.getShape();
    // Fails fast when the input shape is unavailable (presumably an IllegalStateException - confirm).
    Preconditions.checkState(shape != null, "Null shape for input placeholder");
    // One slice per index along dimension 2.
    SDVariable[] inputSlices = sameDiff.unstack(layerInput, 2, (int)shape[2]);
    this.timeSteps = inputSlices.length;
    SDVariable[] outputSlices = new SDVariable[timeSteps];
    // Output of the previous step; null on the first iteration, which therefore skips attention.
    SDVariable prev = null;
    for (int i = 0; i < timeSteps; i++) {
        final val x_i = inputSlices[i];
        outputSlices[i] = x_i.mmul(W);
        // NOTE(review): the extra indentation suggests a guarding 'if' (e.g. a bias flag) was
        // lost in extraction; as shown the bias is always added - confirm.
            outputSlices[i] = outputSlices[i].add(b);

        if(prev != null){
            SDVariable attn;
            // NOTE(review): the indentation suggests an if/else choosing between projected
            // multi-head attention and plain dot-product attention was lost in extraction. As
            // shown, both calls run and the second assignment overwrites the first - confirm.
                val Wq = paramTable.get(WEIGHT_KEY_QUERY_PROJECTION);
                val Wk = paramTable.get(WEIGHT_KEY_KEY_PROJECTION);
                val Wv = paramTable.get(WEIGHT_KEY_VALUE_PROJECTION);
                val Wo = paramTable.get(WEIGHT_KEY_OUT_PROJECTION);

                attn = sameDiff.nn.multiHeadDotProductAttention(getLayerName()+"_attention_"+i, prev, layerInput, layerInput, Wq, Wk, Wv, Wo, mask, true);
                attn = sameDiff.nn.dotProductAttention(getLayerName()+"_attention_"+i, prev, layerInput, layerInput, mask, true);

            // Drop dimension 2 of the attention result so it can be multiplied by R.
            attn = sameDiff.squeeze(attn, 2);

            outputSlices[i] = outputSlices[i].add(attn.mmul(R));

        // NOTE(review): the closing brace of the 'if(prev != null)' block is not visible;
        // the statements below appear to belong to the loop body, not the 'if' - confirm.
        outputSlices[i] = activation.asSameDiff(sameDiff, outputSlices[i]);
        // Re-insert a size-1 dimension at axis 2 so the slices can be concatenated along it.
        outputSlices[i] = sameDiff.expandDims(outputSlices[i], 2);
        prev = outputSlices[i];
    // NOTE(review): the closing braces of the loop and the method are not visible in this excerpt.
    return sameDiff.concat(2, outputSlices);
Example 3
/**
 * Builds the SameDiff graph for this layer, running {@code routings} iterations of the routing loop.
 * <p>
 * The input is expanded and tiled, multiplied element-wise by the weight parameter and summed
 * over dimension 3 to form the prediction vectors {@code uHat}. Each iteration takes a softmax
 * of the logits {@code b} over dimension 2, forms the weighted sum {@code s}, and squashes it
 * into {@code v}. On the last iteration, {@code v} is squeezed on dimensions 1 and 3 and
 * returned. Shapes quoted in the inline comments are those stated by the original author.
 * <p>
 * NOTE(review): this excerpt is visibly truncated (two half-written assignments and missing
 * closing braces); see the inline notes and restore from the upstream source before compiling.
 *
 * @param sd         graph instance the operations are added to
 * @param input      input variable, {@code [mb, inputCapsules, inputCapsuleDimensions]} per the comment below
 * @param paramTable parameters; {@code WEIGHT_PARAM} is read here
 * @param mask       not referenced anywhere in the visible code
 * @return the squeezed activations from the last routing iteration, or {@code null} if the loop
 *         body never runs
 */
public SDVariable defineLayer(SameDiff sd, SDVariable input, Map<String, SDVariable> paramTable, SDVariable mask) {

    // input: [mb, inputCapsules, inputCapsuleDimensions]

    // [mb, inputCapsules, 1, inputCapsuleDimensions, 1]
    SDVariable expanded = sd.expandDims(sd.expandDims(input, 2), 4);

    // [mb, inputCapsules, capsules  * capsuleDimensions, inputCapsuleDimensions, 1]
    SDVariable tiled = sd.tile(expanded, 1, 1, capsules * capsuleDimensions, 1, 1);

    // [1, inputCapsules, capsules * capsuleDimensions, inputCapsuleDimensions]
    SDVariable weights = paramTable.get(WEIGHT_PARAM);

    // uHat is the matrix of prediction vectors between two capsules
    // [mb, inputCapsules, capsules, capsuleDimensions, 1]
    SDVariable uHat = weights.times(tiled).sum(true, 3)
            .reshape(-1, inputCapsules, capsules, capsuleDimensions, 1);

    // b is the logits of the routing procedure
    // [mb, inputCapsules, capsules, 1, 1]
    // Built as an all-zero variable shaped like uHat, cut down to the first index of the last two dimensions.
    SDVariable b = sd.zerosLike(uHat).get(SDIndex.all(), SDIndex.all(), SDIndex.all(), SDIndex.interval(0, 1), SDIndex.interval(0, 1));

    for(int i = 0 ; i < routings ; i++){

        // c is the coupling coefficient, i.e. the edge weight between the 2 capsules
        // [mb, inputCapsules, capsules, 1, 1]
        SDVariable c = sd.nn.softmax(b, 2);

        // [mb, 1, capsules, capsuleDimensions, 1]
        SDVariable s = c.times(uHat).sum(true, 1);
        // NOTE(review): the right-hand side of this assignment is missing, and the extra
        // indentation suggests a guarding 'if' was also lost (possibly a bias addition) - confirm.
            s =;

        // v is the per capsule activations.  On the last routing iteration, this is output
        // [mb, 1, capsules, capsuleDimensions, 1]
        SDVariable v = CapsuleUtils.squash(sd, s, 3);

        if(i == routings - 1){
            // Squeezing dimension 1 and then dimension 3 removes both size-1 axes listed in the shape comment above.
            return sd.squeeze(sd.squeeze(v, 1), 3);

        // NOTE(review): the closing brace of the 'if' above is not visible; the statements
        // below appear to belong to the loop body - confirm.
        // [mb, inputCapsules, capsules, capsuleDimensions, 1]
        SDVariable vTiled = sd.tile(v, 1, (int) inputCapsules, 1, 1, 1);

        // [mb, inputCapsules, capsules, 1, 1]
        // NOTE(review): most of this logits-update expression is missing; only the trailing
        // ", 3));" survives. vTiled is presumably consumed here - confirm.
        b =, 3));

    // NOTE(review): this line is reached whenever the loop body never runs (routings < 1), so
    // null can be returned - confirm that 'routings' is validated elsewhere. The closing braces
    // of the loop and the method are not visible in this excerpt.
    return null; // will always return in the loop