Java Code Examples for org.nd4j.autodiff.samediff.SameDiff#expandDims()

The following examples show how to use org.nd4j.autodiff.samediff.SameDiff#expandDims(). Each example links to its original project and source file, and related API usage is listed in the sidebar.
Example 1
Source file: ShapeOpValidation.java, from the deeplearning4j project (Apache License 2.0). Votes: 4.
/**
 * Validates {@code SameDiff#expandDims} on a 3x4 input for every insertion axis (0, 1 and 2)
 * and for every test-matrix variant returned by
 * {@code NDArrayCreationUtil.getAllTestMatricesWithShape}.
 * <p>
 * For each case the expanded shape is asserted directly, then the expand output and a
 * downstream standard deviation ("out") are validated through {@code OpValidation}. Standard
 * deviation is used as the loss because mean/sum would backprop the same gradient to every
 * input element. Failures are collected and reported together at the end rather than stopping
 * at the first one.
 */
@Test
public void testExpandDimsGradient() {
    val origShape = new long[]{3, 4};

    List<String> failed = new ArrayList<>();

    for (int i = 0; i < 3; i++) {

        // Expected result: a size-1 dimension inserted at position i of [3, 4].
        long[] expExpandShape;
        switch (i) {
            case 0:
                expExpandShape = new long[]{1, 3, 4};
                break;
            case 1:
                expExpandShape = new long[]{3, 1, 4};
                break;
            case 2:
                expExpandShape = new long[]{3, 4, 1};
                break;
            default:
                // Unreachable: the loop only visits axes 0..2.
                throw new IllegalStateException("Unexpected expand axis: " + i);
        }

        for (Pair<INDArray, String> p : NDArrayCreationUtil.getAllTestMatricesWithShape(origShape[0], origShape[1], 12345, DataType.DOUBLE)) {
            INDArray inArr = p.getFirst().muli(100);

            SameDiff sd = SameDiff.create();
            SDVariable in = sd.var("in", inArr);
            SDVariable expand = sd.expandDims(in, i);
            // Using stdev here: mean/sum would backprop the same gradient for each input...
            // Only the registered "out" variable is needed, so the return value is not kept.
            sd.standardDeviation("out", expand, true);

            Map<String,INDArray> m = sd.outputAll(null);
            INDArray expOut = in.getArr().std(true);

            assertArrayEquals(expExpandShape, m.get(expand.name()).shape());
            // expandDims must not reorder elements, so a C-order reshape is the reference.
            INDArray expExpand = inArr.dup('c').reshape(expExpandShape);

            String msg = "expandDim=" + i + ", source=" + p.getSecond();
            log.info("Starting: " + msg);

            TestCase tc = new TestCase(sd);
            tc.testName(msg)
                    .expectedOutput("out", expOut)
                    .expectedOutput(expand.name(), expExpand);

            String error = OpValidation.validate(tc);
            if(error != null){
                failed.add(error);
            }
        }
    }
    assertEquals(failed.toString(), 0, failed.size());
}
 
Example 2
Source file: RecurrentAttentionLayer.java, from the deeplearning4j project (Apache License 2.0). Votes: 4.
/**
 * Builds the recurrent attention layer graph one time step at a time.
 * <p>
 * The input is unstacked along dimension 2 into per-step slices. Each step's pre-activation is
 * {@code x_t * W}, plus {@code b} when {@code hasBias} is set. From the second step onward, an
 * attention read-out is also added. It uses the previous step's output as the query and the
 * full layer input as keys and values. The read-out is squeezed on dimension 2 and multiplied
 * by {@code R} before being added. The configured activation is then applied, and the result is
 * re-expanded on dimension 2. The per-step outputs are finally concatenated back along that
 * dimension.
 *
 * @param sameDiff   graph being built
 * @param layerInput layer input; dimension 2 is treated as time (presumably [mb, nIn, T] — confirm)
 * @param paramTable parameter variables keyed by this layer's parameter keys
 * @param mask       mask forwarded unchanged to the attention op
 * @return the per-step outputs concatenated along dimension 2
 */
@Override
public SDVariable defineLayer(SameDiff sameDiff, SDVariable layerInput, Map<String, SDVariable> paramTable, SDVariable mask) {
    final SDVariable inputWeights = paramTable.get(WEIGHT_KEY);
    final SDVariable recurrentWeights = paramTable.get(RECURRENT_WEIGHT_KEY);
    final SDVariable bias = paramTable.get(BIAS_KEY);

    final long[] inputShape = layerInput.getShape();
    Preconditions.checkState(inputShape != null, "Null shape for input placeholder");
    final SDVariable[] steps = sameDiff.unstack(layerInput, 2, (int) inputShape[2]);
    this.timeSteps = steps.length;

    final SDVariable[] stepOutputs = new SDVariable[steps.length];
    SDVariable previousOutput = null;
    for (int t = 0; t < steps.length; t++) {
        SDVariable z = steps[t].mmul(inputWeights);
        if (hasBias) {
            z = z.add(bias);
        }

        if (previousOutput != null) {
            final String attentionName = getLayerName() + "_attention_" + t;
            final SDVariable context = projectInput
                    ? sameDiff.nn.multiHeadDotProductAttention(attentionName, previousOutput, layerInput, layerInput,
                            paramTable.get(WEIGHT_KEY_QUERY_PROJECTION),
                            paramTable.get(WEIGHT_KEY_KEY_PROJECTION),
                            paramTable.get(WEIGHT_KEY_VALUE_PROJECTION),
                            paramTable.get(WEIGHT_KEY_OUT_PROJECTION),
                            mask, true)
                    : sameDiff.nn.dotProductAttention(attentionName, previousOutput, layerInput, layerInput, mask, true);
            // Drop the singleton query axis before projecting through the recurrent weights.
            z = z.add(sameDiff.squeeze(context, 2).mmul(recurrentWeights));
        }

        z = activation.asSameDiff(sameDiff, z);
        // Restore a size-1 time axis so the step can act as the next query and be concatenated.
        previousOutput = sameDiff.expandDims(z, 2);
        stepOutputs[t] = previousOutput;
    }
    return sameDiff.concat(2, stepOutputs);
}
 
Example 3
Source file: CapsuleLayer.java, from the deeplearning4j project (Apache License 2.0). Votes: 4.
/**
 * Builds the capsule layer graph using iterative routing between input and output capsules.
 * <p>
 * The input is expanded and tiled so that every input capsule is multiplied by the weight
 * tensor, giving prediction vectors {@code uHat}. The routing logits {@code b} start at zero.
 * On each of {@code routings} iterations, they are softmaxed over the output-capsule axis into
 * coupling coefficients {@code c}. These weight a sum of {@code uHat} over input capsules, with
 * the optional bias added. The sum is squashed into the output activations {@code v}. On every
 * iteration except the last, {@code b} is then increased by the agreement
 * {@code sum(uHat * v)}. The final iteration's {@code v}, with its singleton axes squeezed, is
 * the layer output.
 *
 * @param sd         graph being built
 * @param input      layer input, [mb, inputCapsules, inputCapsuleDimensions]
 * @param paramTable parameter variables (weights and, when {@code hasBias}, bias)
 * @param mask       not used by this layer
 * @return output capsule activations after the last routing iteration
 * @throws IllegalStateException if {@code routings} is not positive; the original silently
 *                               returned {@code null} in that case
 */
@Override
public SDVariable defineLayer(SameDiff sd, SDVariable input, Map<String, SDVariable> paramTable, SDVariable mask) {
    // Fail fast: with no routing iterations the loop below would never produce an output.
    if (routings <= 0) {
        throw new IllegalStateException("Number of routings must be positive, got " + routings);
    }

    // input: [mb, inputCapsules, inputCapsuleDimensions]

    // [mb, inputCapsules, 1, inputCapsuleDimensions, 1]
    SDVariable expanded = sd.expandDims(sd.expandDims(input, 2), 4);

    // [mb, inputCapsules, capsules  * capsuleDimensions, inputCapsuleDimensions, 1]
    SDVariable tiled = sd.tile(expanded, 1, 1, capsules * capsuleDimensions, 1, 1);

    // [1, inputCapsules, capsules * capsuleDimensions, inputCapsuleDimensions]
    SDVariable weights = paramTable.get(WEIGHT_PARAM);

    // uHat is the matrix of prediction vectors between two capsules
    // [mb, inputCapsules, capsules, capsuleDimensions, 1]
    SDVariable uHat = weights.times(tiled).sum(true, 3)
            .reshape(-1, inputCapsules, capsules, capsuleDimensions, 1);

    // b is the logits of the routing procedure, initialised to zero
    // [mb, inputCapsules, capsules, 1, 1]
    SDVariable b = sd.zerosLike(uHat).get(SDIndex.all(), SDIndex.all(), SDIndex.all(), SDIndex.interval(0, 1), SDIndex.interval(0, 1));

    for(int i = 0 ; i < routings ; i++){

        // c is the coupling coefficient, i.e. the edge weight between the 2 capsules
        // [mb, inputCapsules, capsules, 1, 1]
        SDVariable c = sd.nn.softmax(b, 2);

        // [mb, 1, capsules, capsuleDimensions, 1]
        SDVariable s = c.times(uHat).sum(true, 1);
        if(hasBias){
            s = s.plus(paramTable.get(BIAS_PARAM));
        }

        // v is the per capsule activations.  On the last routing iteration, this is output
        // [mb, 1, capsules, capsuleDimensions, 1]
        SDVariable v = CapsuleUtils.squash(sd, s, 3);

        if(i == routings - 1){
            return sd.squeeze(sd.squeeze(v, 1), 3);
        }

        // [mb, inputCapsules, capsules, capsuleDimensions, 1]
        SDVariable vTiled = sd.tile(v, 1, (int) inputCapsules, 1, 1, 1);

        // Agreement update: [mb, inputCapsules, capsules, 1, 1]
        b = b.plus(uHat.times(vTiled).sum(true, 3));
    }

    // Unreachable: routings > 0 guarantees the final loop iteration returns.
    throw new IllegalStateException("Routing loop finished without producing an output");
}