Java Code Examples for org.nd4j.autodiff.samediff.SameDiff#oneHot()

The following examples show how to use org.nd4j.autodiff.samediff.SameDiff#oneHot(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may check out the related API usage on the sidebar.
Example 1
Source File: MiscOpValidation.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
@Test
public void testOneHot2() {

    INDArray indicesArr = Nd4j.createFromArray(0, 2, -1, 1);

    SameDiff sd = SameDiff.create();
    SDVariable indices = sd.constant("indices", indicesArr);
    int depth = 3;
    int axis = -1;
    SDVariable oneHot = sd.oneHot("oneHot", indices, depth, axis, 5.0, 0.0, DataType.DOUBLE);

    INDArray exp = Nd4j.create(new double[][]{{5, 0, 0}, {0,0,5}, {0,0,0}, {0, 5, 0}});

    String err = OpValidation.validate(new TestCase(sd)
            .expected(oneHot, exp)
            .gradientCheck(false));

    assertNull(err);
}
 
Example 2
Source File: MiscOpValidation.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
@Test
public void testOneHot4() {

    INDArray indicesArr = Nd4j.createFromArray(0, 2, -1, 1);

    SameDiff sd = SameDiff.create();
    SDVariable indices = sd.constant("indices", indicesArr);
    int depth = 3;
    int axis = -1;
    SDVariable oneHot = sd.oneHot("oneHot", indices, depth, axis, 5.0, 0.0, DataType.INT32);

    INDArray exp = Nd4j.create(new int[][]{{5, 0, 0}, {0,0,5}, {0,0,0}, {0, 5, 0}});

    String err = OpValidation.validate(new TestCase(sd)
            .expected(oneHot, exp)
            .gradientCheck(false));

    assertNull(err);
}
 
Example 3
Source File: MiscOpValidation.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
@Test
public void testOneHot1(){
    List<String> failed = new ArrayList<>();

    //Because it's on the diagonal, should be the same for all axis args...
    for( int i=-1; i<=0; i++ ) {
        INDArray indicesArr = Nd4j.createFromArray(0, 1, 2);
        int depth = 3;

        SameDiff sd = SameDiff.create();
        SDVariable indices = sd.constant(indicesArr);
        SDVariable oneHot = sd.oneHot(indices, depth, i, 1.0, 0.0, DataType.DOUBLE);

        INDArray exp = Nd4j.eye(3).castTo(DataType.DOUBLE);

        String msg = "Axis: " + i;
        log.info("Test case: " + msg);

        String err = OpValidation.validate(new TestCase(sd)
                .testName(msg)
                .gradientCheck(false)
                .expected(oneHot, exp));

        if(err != null){
            failed.add(err);
        }
    }
    assertEquals(failed.toString(), 0, failed.size());
}