Java Code Examples for org.nd4j.linalg.ops.transforms.Transforms#tanh()

The following examples show how to use org.nd4j.linalg.ops.transforms.Transforms#tanh(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may check out the related API usage on the sidebar.
Example 1
Source File: FailingSameDiffTests.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
/**
 * Evaluates a single tanh op twice, swapping the array behind the input variable
 * from a 3x4 to a 5x4 array in between, and checks that both results match
 * {@code Transforms.tanh} and that the second result has shape [5,4].
 */
@Test
public void testExecutionDifferentShapesTransform(){
    OpValidationSuite.ignoreFailing();

    // Graph under test: tanh applied to the variable "in", which starts out as a 3x4 array
    SameDiff graph = SameDiff.create();
    INDArray firstInput = Nd4j.linspace(1,12,12, DataType.DOUBLE).reshape(3,4);
    SDVariable input = graph.var("in", firstInput);
    SDVariable tanhOut = graph.math().tanh(input);

    // Reference value is taken from the array currently attached to the variable
    INDArray expectedFirst = Transforms.tanh(input.getArr(), true);
    INDArray actualFirst = tanhOut.eval();
    assertEquals(expectedFirst, actualFirst);

    // Attach a 5x4 array to the same variable and evaluate the same op again
    INDArray secondInput = Nd4j.linspace(1,20,20, DataType.DOUBLE).reshape(5,4);
    input.setArray(secondInput);
    INDArray actualSecond = tanhOut.eval();
    assertArrayEquals(new long[]{5,4}, actualSecond.shape());

    INDArray expectedSecond = Transforms.tanh(input.getArr(), true);
    assertEquals(expectedSecond, actualSecond);
}
 
Example 2
Source File: SameDiffTests.java    From nd4j with Apache License 2.0 4 votes vote down vote up
/**
 * Compares the forward output and the gradients of several SameDiff activation ops
 * with reference values computed through {@code Transforms} and
 * {@code Activation#getActivationFunction()}.
 *
 * For each activation a fresh graph is built: out = f(in), followed by a
 * sum-of-squared-differences loss against a fixed label array.
 */
@Test
    public void testActivationBackprop() {

        // NOTE(review): the trailing remarks on CUBE / RELU / LEAKYRELU describe known problems,
        // yet all three are still part of the tested set - confirm whether the remarks are stale.
        Activation[] afns = new Activation[]{
                Activation.TANH,
                Activation.SIGMOID,
                Activation.ELU,
                Activation.SOFTPLUS,
                Activation.SOFTSIGN,
                Activation.HARDTANH,
                Activation.CUBE,            //WRONG output - see issue https://github.com/deeplearning4j/nd4j/issues/2426
                Activation.RELU,            //JVM crash
                Activation.LEAKYRELU        //JVM crash
        };

        for (Activation a : afns) {

            SameDiff sd = SameDiff.create();
            INDArray inArr = Nd4j.linspace(-3, 3, 7);
            INDArray labelArr = Nd4j.linspace(-3, 3, 7).muli(0.5);
            //The graph gets its own copy so inArr stays available for the reference calculations
            SDVariable in = sd.var("in", inArr.dup());

            //Build the op under test ("out") and the matching reference output
            INDArray outExp;
            SDVariable out;
            switch (a) {
                case ELU:
                    out = sd.elu("out", in);
                    outExp = Transforms.elu(inArr, true);
                    break;
                case HARDTANH:
                    out = sd.hardTanh("out", in);
                    outExp = Transforms.hardTanh(inArr, true);
                    break;
                case LEAKYRELU:
                    out = sd.leakyRelu("out", in, 0.01);
                    outExp = Transforms.leakyRelu(inArr, true);
                    break;
                case RELU:
                    out = sd.relu("out", in, 0.0);
                    outExp = Transforms.relu(inArr, true);
                    break;
                case SIGMOID:
                    out = sd.sigmoid("out", in);
                    outExp = Transforms.sigmoid(inArr, true);
                    break;
                case SOFTPLUS:
                    out = sd.softplus("out", in);
                    outExp = Transforms.softPlus(inArr, true);
                    break;
                case SOFTSIGN:
                    out = sd.softsign("out", in);
                    outExp = Transforms.softsign(inArr, true);
                    break;
                case TANH:
                    out = sd.tanh("out", in);
                    outExp = Transforms.tanh(inArr, true);
                    break;
                case CUBE:
                    out = sd.cube("out", in);
                    outExp = Transforms.pow(inArr, 3, true);
                    break;
                default:
                    throw new RuntimeException(a.toString());
            }

            //Sum squared error loss:
            SDVariable label = sd.var("label", labelArr.dup());
            SDVariable diff = label.sub("diff", out);
            SDVariable sqDiff = diff.mul("sqDiff", diff);
            //Only the graph node is needed; the returned variable is never read in this test
            sd.sum("totSum", sqDiff, Integer.MAX_VALUE);    //Loss function...

            sd.exec();
            INDArray outAct = sd.getVariable("out").getArr();
            assertEquals(a.toString(), outExp, outAct);

            // L = sum_i (label - out)^2
            //dL/dOut = 2(out - label)
            INDArray dLdOutExp = outExp.sub(labelArr).mul(2);
            //Copies are handed to backprop so inArr and dLdOutExp can still be used afterwards
            INDArray dLdInExp = a.getActivationFunction().backprop(inArr.dup(), dLdOutExp.dup()).getFirst();

            sd.execBackwards();
            SameDiff gradFn = sd.getFunction("grad");
            INDArray dLdOutAct = gradFn.getVariable("out-grad").getArr();
            INDArray dLdInAct = gradFn.getVariable("in-grad").getArr();

            assertEquals(a.toString(), dLdOutExp, dLdOutAct);
            assertEquals(a.toString(), dLdInExp, dLdInAct);
        }
    }
 
Example 3
Source File: SameDiffTests.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Compares the forward output and the gradients of several SameDiff activation ops
 * with reference values computed through {@code Transforms} and
 * {@code Activation#getActivationFunction()}.
 *
 * For each activation a fresh graph is built: out = f(in), followed by a
 * sum-of-squared-differences loss against a fixed label array.
 */
@Test
    public void testActivationBackprop() {

        // NOTE(review): the trailing remarks on CUBE / RELU / LEAKYRELU describe known problems,
        // yet all three are still part of the tested set - confirm whether the remarks are stale.
        Activation[] afns = new Activation[]{
                Activation.TANH,
                Activation.SIGMOID,
                Activation.ELU,
                Activation.SOFTPLUS,
                Activation.SOFTSIGN,
                Activation.HARDTANH,
                Activation.CUBE,            //WRONG output - see issue https://github.com/deeplearning4j/nd4j/issues/2426
                Activation.RELU,            //JVM crash
                Activation.LEAKYRELU        //JVM crash
        };

        for (Activation a : afns) {

            SameDiff sd = SameDiff.create();
            INDArray inArr = Nd4j.linspace(-3, 3, 7);
            INDArray labelArr = Nd4j.linspace(-3, 3, 7).muli(0.5);
            //The graph gets its own copy so inArr stays available for the reference calculations
            SDVariable in = sd.var("in", inArr.dup());

            //Build the op under test ("out") and the matching reference output
            INDArray outExp;
            SDVariable out;
            switch (a) {
                case ELU:
                    out = sd.nn().elu("out", in);
                    outExp = Transforms.elu(inArr, true);
                    break;
                case HARDTANH:
                    out = sd.nn().hardTanh("out", in);
                    outExp = Transforms.hardTanh(inArr, true);
                    break;
                case LEAKYRELU:
                    out = sd.nn().leakyRelu("out", in, 0.01);
                    outExp = Transforms.leakyRelu(inArr, true);
                    break;
                case RELU:
                    out = sd.nn().relu("out", in, 0.0);
                    outExp = Transforms.relu(inArr, true);
                    break;
                case SIGMOID:
                    out = sd.nn().sigmoid("out", in);
                    outExp = Transforms.sigmoid(inArr, true);
                    break;
                case SOFTPLUS:
                    out = sd.nn().softplus("out", in);
                    outExp = Transforms.softPlus(inArr, true);
                    break;
                case SOFTSIGN:
                    out = sd.nn().softsign("out", in);
                    outExp = Transforms.softsign(inArr, true);
                    break;
                case TANH:
                    out = sd.math().tanh("out", in);
                    outExp = Transforms.tanh(inArr, true);
                    break;
                case CUBE:
                    out = sd.math().cube("out", in);
                    outExp = Transforms.pow(inArr, 3, true);
                    break;
                default:
                    throw new RuntimeException(a.toString());
            }

            //Sum squared error loss:
            SDVariable label = sd.var("label", labelArr.dup());
            SDVariable diff = label.sub("diff", out);
            SDVariable sqDiff = diff.mul("sqDiff", diff);
            //Only the graph node is needed; the returned variable is never read in this test
            sd.sum("totSum", sqDiff, Integer.MAX_VALUE);    //Loss function...

            //Forward pass: no placeholders, request only "out"
            Map<String,INDArray> m = sd.output(Collections.emptyMap(), "out");
            INDArray outAct = m.get("out");
            assertEquals(a.toString(), outExp, outAct);

            // L = sum_i (label - out)^2
            //dL/dOut = 2(out - label)
            INDArray dLdOutExp = outExp.sub(labelArr).mul(2);
            //Copies are handed to backprop so inArr and dLdOutExp can still be used afterwards
            INDArray dLdInExp = a.getActivationFunction().backprop(inArr.dup(), dLdOutExp.dup()).getFirst();

            //Gradients are requested by variable name and returned keyed by that same name
            Map<String,INDArray> grads = sd.calculateGradients(null, "out", "in");
            INDArray dLdOutAct = grads.get("out");
            INDArray dLdInAct = grads.get("in");

            assertEquals(a.toString(), dLdOutExp, dLdOutAct);
            assertEquals(a.toString(), dLdInExp, dLdInAct);
        }
    }
 
Example 4
Source File: RnnOpValidation.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Forward-pass check of the SameDiff LSTM cell op with peephole connections.
 *
 * Random inputs and weights are created from a fixed seed, the cell is executed once,
 * and every one of its seven outputs (i, c, f, o, z, h, y) is compared with a value
 * recomputed here from plain INDArray operations.
 *
 * Several expected arrays are built with in-place operations and then reused by later
 * sections (e.g. zExp, iExp, fExp feed cExp), so the order of the sections matters.
 */
@Test
public void testRnnBlockCell(){
    Nd4j.getRandom().setSeed(12345);
    int mb = 2;
    int nIn = 3;
    int nOut = 4;

    //All inputs and parameters are constants drawn in this order from the seeded RNG
    SameDiff sd = SameDiff.create();
    SDVariable x = sd.constant(Nd4j.rand(DataType.FLOAT, mb, nIn));
    SDVariable cLast = sd.constant(Nd4j.rand(DataType.FLOAT, mb, nOut));
    SDVariable yLast = sd.constant(Nd4j.rand(DataType.FLOAT, mb, nOut));
    //W stacks input weights (rows 0..nIn) on top of recurrent weights (rows nIn..nIn+nOut); 4 gate blocks of nOut columns
    SDVariable W = sd.constant(Nd4j.rand(DataType.FLOAT, (nIn+nOut), 4*nOut));
    SDVariable Wci = sd.constant(Nd4j.rand(DataType.FLOAT, nOut));
    SDVariable Wcf = sd.constant(Nd4j.rand(DataType.FLOAT, nOut));
    SDVariable Wco = sd.constant(Nd4j.rand(DataType.FLOAT, nOut));
    SDVariable b = sd.constant(Nd4j.rand(DataType.FLOAT, 4*nOut));

    //fb is also added manually to the expected forget gate below
    double fb = 1.0;
    LSTMConfiguration conf = LSTMConfiguration.builder()
            .peepHole(true)
            .forgetBias(fb)
            .clippingCellValue(0.0)     //NOTE(review): presumably 0.0 means "no clipping" - no clipping is applied to cExp below; confirm
            .build();

    LSTMWeights weights = LSTMWeights.builder().weights(W).bias(b)
            .inputPeepholeWeights(Wci).forgetPeepholeWeights(Wcf).outputPeepholeWeights(Wco).build();

    LSTMCellOutputs v = new LSTMCellOutputs(sd.rnn().lstmCell(x, cLast, yLast, weights, conf));  //Output order: i, c, f, o, z, h, y
    //Collect the output variable names; toExec indices below follow the output order above
    List<String> toExec = new ArrayList<>();
    for(SDVariable sdv : v.getAllOutputs()){
        toExec.add(sdv.name());
    }

    //Test forward pass:
    Map<String,INDArray> m = sd.output(null, toExec);

    //Weights and bias column-block order, as sliced below: [i, z, f, o]

    //Block input (z) - post tanh: (second column block, nOut..2*nOut)
    INDArray wz_x = W.getArr().get(NDArrayIndex.interval(0,nIn), NDArrayIndex.interval(nOut, 2*nOut));           //Input weights
    INDArray wz_r = W.getArr().get(NDArrayIndex.interval(nIn,nIn+nOut), NDArrayIndex.interval(nOut, 2*nOut));    //Recurrent weights
    INDArray bz = b.getArr().get(NDArrayIndex.interval(nOut, 2*nOut));

    INDArray zExp = x.getArr().mmul(wz_x).addiRowVector(bz);        //[mb,nIn]*[nIn, nOut] + [nOut]
    zExp.addi(yLast.getArr().mmul(wz_r));   //[mb,nOut]*[nOut,nOut]
    //Return value is discarded: zExp itself is expected to hold the tanh result afterwards
    Transforms.tanh(zExp, false);

    INDArray zAct = m.get(toExec.get(4));
    assertEquals(zExp, zAct);

    //Input modulation gate (post sigmoid) - i: (note: peephole input - last time step) (first column block, 0..nOut)
    INDArray wi_x = W.getArr().get(NDArrayIndex.interval(0,nIn), NDArrayIndex.interval(0, nOut));           //Input weights
    INDArray wi_r = W.getArr().get(NDArrayIndex.interval(nIn,nIn+nOut), NDArrayIndex.interval(0, nOut));    //Recurrent weights
    INDArray bi = b.getArr().get(NDArrayIndex.interval(0, nOut));

    INDArray iExp = x.getArr().mmul(wi_x).addiRowVector(bi);        //[mb,nIn]*[nIn, nOut] + [nOut]
    iExp.addi(yLast.getArr().mmul(wi_r));   //[mb,nOut]*[nOut,nOut]
    iExp.addi(cLast.getArr().mulRowVector(Wci.getArr()));    //Peephole
    Transforms.sigmoid(iExp, false);
    assertEquals(iExp, m.get(toExec.get(0)));

    //Forget gate (post sigmoid): (note: peephole input - last time step) (third column block, 2*nOut..3*nOut)
    INDArray wf_x = W.getArr().get(NDArrayIndex.interval(0,nIn), NDArrayIndex.interval(2*nOut, 3*nOut));           //Input weights
    INDArray wf_r = W.getArr().get(NDArrayIndex.interval(nIn,nIn+nOut), NDArrayIndex.interval(2*nOut, 3*nOut));    //Recurrent weights
    INDArray bf = b.getArr().get(NDArrayIndex.interval(2*nOut, 3*nOut));

    INDArray fExp = x.getArr().mmul(wf_x).addiRowVector(bf);        //[mb,nIn]*[nIn, nOut] + [nOut]
    fExp.addi(yLast.getArr().mmul(wf_r));   //[mb,nOut]*[nOut,nOut]
    fExp.addi(cLast.getArr().mulRowVector(Wcf.getArr()));   //Peephole
    fExp.addi(fb);      //Forget bias is added before the sigmoid
    Transforms.sigmoid(fExp, false);
    assertEquals(fExp, m.get(toExec.get(2)));

    //Cell state (pre tanh): tanh(z) .* sigmoid(i) + sigmoid(f) .* cLast
    //(zExp, iExp and fExp already hold their post-activation values at this point)
    INDArray cExp = zExp.mul(iExp).add(fExp.mul(cLast.getArr()));
    INDArray cAct = m.get(toExec.get(1));
    assertEquals(cExp, cAct);

    //Output gate (post sigmoid): (note: peephole input: current time step) (fourth column block, 3*nOut..4*nOut)
    INDArray wo_x = W.getArr().get(NDArrayIndex.interval(0,nIn), NDArrayIndex.interval(3*nOut, 4*nOut));           //Input weights
    INDArray wo_r = W.getArr().get(NDArrayIndex.interval(nIn,nIn+nOut), NDArrayIndex.interval(3*nOut, 4*nOut));    //Recurrent weights
    INDArray bo = b.getArr().get(NDArrayIndex.interval(3*nOut, 4*nOut));

    INDArray oExp = x.getArr().mmul(wo_x).addiRowVector(bo);        //[mb,nIn]*[nIn, nOut] + [nOut]
    oExp.addi(yLast.getArr().mmul(wo_r));   //[mb,nOut]*[nOut,nOut]
    oExp.addi(cExp.mulRowVector(Wco.getArr())); //Peephole (uses the new cell state cExp, not cLast)
    Transforms.sigmoid(oExp, false);
    assertEquals(oExp, m.get(toExec.get(3)));

    //Cell state, post tanh
    INDArray hExp = Transforms.tanh(cExp, true);
    assertEquals(hExp, m.get(toExec.get(5)));

    //Final output
    INDArray yExp = hExp.mul(oExp);
    assertEquals(yExp, m.get(toExec.get(6)));
}
 
Example 5
Source File: RnnOpValidation.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Forward-pass check of the SameDiff GRU cell op.
 *
 * Random inputs and weights are created from a fixed seed, the cell is executed once,
 * and each of its four outputs (r, u, c, h - in that order) is compared with a value
 * recomputed here from plain INDArray operations.
 */
@Test
public void testGRUCell(){
    Nd4j.getRandom().setSeed(12345);
    int mb = 2;
    int nIn = 3;
    int nOut = 4;

    //All inputs and parameters are constants drawn in this order from the seeded RNG
    SameDiff sd = SameDiff.create();
    SDVariable x = sd.constant(Nd4j.rand(DataType.FLOAT, mb, nIn));
    SDVariable hLast = sd.constant(Nd4j.rand(DataType.FLOAT, mb, nOut));
    SDVariable Wru = sd.constant(Nd4j.rand(DataType.FLOAT, (nIn+nOut), 2*nOut));
    SDVariable Wc = sd.constant(Nd4j.rand(DataType.FLOAT, (nIn+nOut), nOut));
    SDVariable bru = sd.constant(Nd4j.rand(DataType.FLOAT, 2*nOut));
    SDVariable bc = sd.constant(Nd4j.rand(DataType.FLOAT, nOut));

    GRUWeights weights = GRUWeights.builder()
            .ruWeight(Wru)
            .cWeight(Wc)
            .ruBias(bru)
            .cBias(bc)
            .build();

    SDVariable[] v = sd.rnn().gruCell(x, hLast, weights);
    //Collect the output variable names; toExec indices below follow the op's output order
    List<String> toExec = new ArrayList<>();
    for(SDVariable sdv : v){
        toExec.add(sdv.name());
    }

    //Test forward pass:
    Map<String,INDArray> m = sd.output(null, toExec);

    //Weights and bias order: [r, u], [c]

    //Reset gate:
    INDArray wr_x = Wru.getArr().get(NDArrayIndex.interval(0,nIn), NDArrayIndex.interval(0, nOut));           //Input weights
    INDArray wr_r = Wru.getArr().get(NDArrayIndex.interval(nIn,nIn+nOut), NDArrayIndex.interval(0, nOut));    //Recurrent weights
    INDArray br = bru.getArr().get(NDArrayIndex.interval(0, nOut));

    INDArray rExp = x.getArr().mmul(wr_x).addiRowVector(br);        //[mb,nIn]*[nIn, nOut] + [nOut]
    rExp.addi(hLast.getArr().mmul(wr_r));   //[mb,nOut]*[nOut,nOut]
    //Return value is discarded: rExp itself is expected to hold the sigmoid result afterwards
    Transforms.sigmoid(rExp,false);

    INDArray rAct = m.get(toExec.get(0));
    assertEquals(rExp, rAct);

    //Update gate:
    INDArray wu_x = Wru.getArr().get(NDArrayIndex.interval(0,nIn), NDArrayIndex.interval(nOut, 2*nOut));           //Input weights
    INDArray wu_r = Wru.getArr().get(NDArrayIndex.interval(nIn,nIn+nOut), NDArrayIndex.interval(nOut, 2*nOut));    //Recurrent weights
    INDArray bu = bru.getArr().get(NDArrayIndex.interval(nOut, 2*nOut));

    INDArray uExp = x.getArr().mmul(wu_x).addiRowVector(bu);        //[mb,nIn]*[nIn, nOut] + [nOut]
    uExp.addi(hLast.getArr().mmul(wu_r));   //[mb,nOut]*[nOut,nOut]
    Transforms.sigmoid(uExp,false);

    INDArray uAct = m.get(toExec.get(1));
    assertEquals(uExp, uAct);

    //c = tanh(x * Wcx + (hLast .* r) * Wcr + bc)
    INDArray Wcx = Wc.getArr().get(NDArrayIndex.interval(0,nIn), NDArrayIndex.all());
    INDArray Wcr = Wc.getArr().get(NDArrayIndex.interval(nIn, nIn+nOut), NDArrayIndex.all());
    INDArray cExp = x.getArr().mmul(Wcx);
    cExp.addi(hLast.getArr().mul(rExp).mmul(Wcr));
    cExp.addiRowVector(bc.getArr());
    Transforms.tanh(cExp, false);

    assertEquals(cExp, m.get(toExec.get(2)));

    //h = u * hLast + (1-u) * c
    INDArray hExp = uExp.mul(hLast.getArr()).add(uExp.rsub(1.0).mul(cExp));
    assertEquals(hExp, m.get(toExec.get(3)));
}
 
Example 6
Source File: TestSimpleRnn.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Runs a single SimpleRnn layer (tanh activation) over a random sequence and checks
 * every time step of the network output against a step-by-step recomputation
 * tanh(in_t * W + out_{t-1} * RW + b) built from the layer's own parameters.
 * The input/output layout follows {@code rnnDataFormat}. Finally the model
 * serialization helper is invoked on the network.
 */
@Test
public void testSimpleRnn(){
    Nd4j.getRandom().setSeed(12345);

    int m = 3;
    int nIn = 5;
    int layerSize = 6;
    int tsLength = 7;

    // NCW puts the time axis last; otherwise time is the middle axis
    boolean timeLast = (rnnDataFormat == RNNFormat.NCW);
    INDArray in = timeLast
            ? Nd4j.rand(DataType.FLOAT, m, nIn, tsLength)
            : Nd4j.rand(DataType.FLOAT, m, tsLength, nIn);

    SimpleRnn rnnLayer = new SimpleRnn.Builder().nIn(nIn).nOut(layerSize).dataFormat(rnnDataFormat).build();
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .updater(new NoOp())
            .weightInit(WeightInit.XAVIER)
            .activation(Activation.TANH)
            .list()
            .layer(rnnLayer)
            .build();

    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.init();

    INDArray out = net.output(in);

    // Layer parameters used for the manual recomputation
    INDArray w = net.getParam("0_W");
    INDArray rw = net.getParam("0_RW");
    INDArray b = net.getParam("0_b");

    INDArray previousStep = null;
    for (int t = 0; t < tsLength; t++) {
        INDArray inputStep = timeLast
                ? in.get(all(), all(), point(t))
                : in.get(all(), point(t), all());

        // Input contribution, plus the recurrent contribution from the second step onwards
        INDArray expectedStep = inputStep.mmul(w);
        if (previousStep != null) {
            expectedStep.addi(previousStep.mmul(rw));
        }
        expectedStep.addiRowVector(b);

        // Result is not captured: expectedStep itself is used as the activated value below
        Transforms.tanh(expectedStep, false);

        INDArray actualStep = timeLast
                ? out.get(all(), all(), point(t))
                : out.get(all(), point(t), all());
        assertEquals(String.valueOf(t), expectedStep, actualStep);

        previousStep = expectedStep;
    }

    TestUtils.testModelSerialization(net);
}