Java Code Examples for org.nd4j.linalg.factory.Nd4j#trueScalar()

The following examples show how to use org.nd4j.linalg.factory.Nd4j#trueScalar(). You can vote up the examples you like or vote down the ones you don't, and you can go to the original project or source file by following the links above each example. Related API usage is listed in the sidebar.
Example 1
Source File: TensorFlowImportTest.java — from nd4j (Apache License 2.0), 6 votes
/**
 * Imports the {@code simplewhile_0} TensorFlow graph, binds the scalar 4.0 to {@code input_1},
 * and serializes the graph to a FlatBuffers file in the sibling libnd4j test resources.
 *
 * <p>This behaves as a fixture generator rather than a behavioral test: the graph is not
 * executed, and only a successful import is asserted.
 */
@Test
public void testWhileMapping2() throws Exception {
    // NOTE(review): presumably forces ND4J backend initialization before the import — confirm.
    Nd4j.create(1);
    // try-with-resources so the classpath stream is always closed.
    try (java.io.InputStream graphStream =
            new ClassPathResource("tf_graphs/examples/simplewhile_0/frozen_model.pb").getInputStream()) {
        val tg = TFGraphMapper.getInstance().importGraph(graphStream);
        assertNotNull(tg);
        val input = Nd4j.trueScalar(4.0);
        tg.associateArrayWithVariable(input, tg.getVariable("input_1"));

        // NOTE(review): this relative path only resolves when the working directory sits three
        // levels below a checkout that contains libnd4j; otherwise the write presumably fails.
        tg.asFlatFile(new File("../../../libnd4j/tests_cpu/resources/simplewhile_0_4.fb"));
    }
    // TODO: execution is not verified here. A previously disabled check expected
    // tg.execAndEndResult() to equal Nd4j.create(2, 2).assign(2); confirm before re-enabling.
}
 
Example 2
Source File: TensorFlowImportTest.java — from nd4j (Apache License 2.0), 6 votes
/**
 * Imports the {@code simplewhile_0} TensorFlow graph, feeds the scalar 9.0 into
 * {@code input_1}, executes it, and checks that the final result equals a 2x2 array filled
 * with 4.
 */
@Test
public void testWhileMapping3() throws Exception {
    // NOTE(review): presumably forces ND4J backend initialization before the import — confirm.
    Nd4j.create(1);
    // try-with-resources so the classpath stream is always closed.
    try (java.io.InputStream graphStream =
            new ClassPathResource("tf_graphs/examples/simplewhile_0/frozen_model.pb").getInputStream()) {
        val tg = TFGraphMapper.getInstance().importGraph(graphStream);
        assertNotNull(tg);
        val input = Nd4j.trueScalar(9.0);
        tg.associateArrayWithVariable(input, tg.getVariable("input_1"));

        // To regenerate the libnd4j fixture, call:
        // tg.asFlatFile(new File("../../../libnd4j/tests_cpu/resources/simplewhile_0.fb"));

        val array = tg.execAndEndResult();
        val exp = Nd4j.create(2, 2).assign(4);
        assertNotNull(array);
        assertEquals(exp, array);
    }
}
 
Example 3
Source File: TensorFlowImportTest.java — from nd4j (Apache License 2.0), 6 votes
/**
 * Imports the two-input {@code simplewhile_1} TensorFlow graph, feeds a 2x2 array of -4.0 into
 * {@code input_0} and the scalar 1.0 into {@code input_1}, executes it, and checks that the
 * final result equals a 2x2 array filled with -1.
 */
@Test
public void testWhileDualMapping1() throws Exception {
    // NOTE(review): presumably forces ND4J backend initialization before the import — confirm.
    Nd4j.create(1);
    // try-with-resources so the classpath stream is always closed.
    try (java.io.InputStream graphStream =
            new ClassPathResource("tf_graphs/examples/simplewhile_1/frozen_model.pb").getInputStream()) {
        val tg = TFGraphMapper.getInstance().importGraph(graphStream);
        assertNotNull(tg);
        val input0 = Nd4j.create(2, 2).assign(-4.0);
        val input1 = Nd4j.trueScalar(1.0);
        tg.associateArrayWithVariable(input0, tg.getVariable("input_0"));
        tg.associateArrayWithVariable(input1, tg.getVariable("input_1"));

        // To regenerate the libnd4j fixture, call:
        // tg.asFlatFile(new File("../../../libnd4j/tests_cpu/resources/simplewhile_1.fb"));

        val array = tg.execAndEndResult();
        val exp = Nd4j.create(2, 2).assign(-1);
        assertNotNull(array);
        assertEquals(exp, array);
    }
}
 
Example 4
Source File: TensorFlowImportTest.java — from nd4j (Apache License 2.0), 6 votes
/**
 * Imports the two-input {@code simplewhile_1} TensorFlow graph, feeds a 2x2 array of -9.0 into
 * {@code input_0} and the scalar 1.0 into {@code input_1}, executes it, and checks that the
 * final result equals a 2x2 array filled with -3.
 */
@Test
public void testWhileDualMapping2() throws Exception {
    // NOTE(review): presumably forces ND4J backend initialization before the import — confirm.
    Nd4j.create(1);
    // try-with-resources so the classpath stream is always closed.
    try (java.io.InputStream graphStream =
            new ClassPathResource("tf_graphs/examples/simplewhile_1/frozen_model.pb").getInputStream()) {
        val tg = TFGraphMapper.getInstance().importGraph(graphStream);
        assertNotNull(tg);
        val input0 = Nd4j.create(2, 2).assign(-9.0);
        val input1 = Nd4j.trueScalar(1.0);
        tg.associateArrayWithVariable(input0, tg.getVariable("input_0"));
        tg.associateArrayWithVariable(input1, tg.getVariable("input_1"));

        // To regenerate the libnd4j fixture, call:
        // tg.asFlatFile(new File("../../../libnd4j/tests_cpu/resources/simplewhile_1.fb"));

        val array = tg.execAndEndResult();
        val exp = Nd4j.create(2, 2).assign(-3);
        assertNotNull(array);
        assertEquals(exp, array);
    }
}
 
Example 5
Source File: TensorFlowImportTest.java — from nd4j (Apache License 2.0), 6 votes
/**
 * Regression scaffold for the {@code simpleif_0} graph: imports it, then binds a 2x2 float array
 * to {@code input_0} and the scalar 11 to {@code input_1}.
 *
 * <p>Only the import is asserted; the graph is neither executed nor serialized.
 */
@Test
@Ignore
public void testCrash_119_simpleif_0() throws Exception {
    // NOTE(review): presumably forces ND4J backend initialization before the import — confirm.
    Nd4j.create(1);

    // try-with-resources so the classpath stream is always closed.
    try (java.io.InputStream graphStream =
            new ClassPathResource("tf_graphs/examples/simpleif_0/frozen_model.pb").getInputStream()) {
        val tg = TFGraphMapper.getInstance().importGraph(graphStream);
        assertNotNull(tg);

        val input0 = Nd4j.create(new float[] {1, 2, 3, 4}, new int[] {2, 2});
        val input1 = Nd4j.trueScalar(11f);

        tg.associateArrayWithVariable(input0, tg.getVariable("input_0"));
        tg.associateArrayWithVariable(input1, tg.getVariable("input_1"));

        // To regenerate the libnd4j fixture, call:
        // tg.asFlatFile(new File("../../../libnd4j/tests_cpu/resources/simpleif_0.fb"));
    }
}
 
Example 6
Source File: TFGraphTestAllHelper.java — from nd4j (Apache License 2.0), 5 votes
/**
 * Loads a TensorFlow test model's variables from classpath shape/CSV dumps.
 *
 * <p>For every classpath resource matching {@code <base_dir>/<modelName>/<pattern>.shape}:
 * <ul>
 *   <li>the {@code .shape} file is read as comma-separated integers, giving the array shape;</li>
 *   <li>the sibling file, with the trailing {@code .shape} replaced by {@code .csv}, is read as
 *       comma-separated floats, giving the contents;</li>
 *   <li>the variable name is the file name minus its last two dot-separated components, with
 *       every {@code "____"} mapped back to {@code "/"}.</li>
 * </ul>
 * A one-element shape of {@code [0]} yields a scalar from the first value. Any other
 * one-element shape yields a vector. All other shapes use {@code Nd4j.create(values, shape)}.
 *
 * @param modelName model directory name under {@code base_dir}
 * @param base_dir  classpath directory that contains the model directories
 * @param pattern   resource pattern, without the {@code .shape} suffix, used to find shape files
 * @return map from variable name to its array; variables whose CSV contents raise a
 *         {@link NumberFormatException} during parsing are skipped
 * @throws IOException if resources cannot be listed or read
 */
protected static Map<String, INDArray> readVars(String modelName, String base_dir, String pattern) throws IOException {
    final String shapeSuffix = ".shape";
    Map<String, INDArray> varMap = new HashMap<>();
    String modelDir = base_dir + "/" + modelName;
    ResourcePatternResolver resolver = new PathMatchingResourcePatternResolver(new ClassPathResource(modelDir).getClassLoader());
    Resource[] resources = resolver.getResources("classpath*:" + modelDir + "/" + pattern + shapeSuffix);
    for (Resource resource : resources) {
        String fileName = resource.getFilename();
        String varPath = modelDir + "/" + fileName;
        String[] varNameArr = fileName.split("\\.");
        // Drop the last two components (the pattern's type suffix and "shape").
        String varName = String.join(".", Arrays.copyOfRange(varNameArr, 0, varNameArr.length - 2));

        int[] varShape;
        try (java.io.InputStream shapeIn = new ClassPathResource(varPath).getInputStream()) {
            varShape = Nd4j.readNumpy(shapeIn, ",").data().asInt();
        }

        // Swap only the trailing suffix: a plain replace would also rewrite ".shape"
        // appearing earlier in the path.
        String csvPath = varPath.endsWith(shapeSuffix)
                ? varPath.substring(0, varPath.length() - shapeSuffix.length()) + ".csv"
                : varPath.replace(shapeSuffix, ".csv");

        float[] varContents;
        try (java.io.InputStream csvIn = new ClassPathResource(csvPath).getInputStream()) {
            varContents = Nd4j.readNumpy(csvIn, ",").data().asFloat();
        } catch (NumberFormatException e) {
            // FIXME: we can't parse boolean arrays right now; skip such variables (best-effort).
            continue;
        }

        INDArray varValue;
        if (varShape.length == 1) {
            if (varShape[0] == 0) {
                varValue = Nd4j.trueScalar(varContents[0]);
            } else {
                varValue = Nd4j.trueVector(varContents);
            }
        } else {
            varValue = Nd4j.create(varContents, varShape);
        }
        // "____" in a file name stands for "/" (intermediate node outputs); a literal
        // replace is a no-op for names without it.
        varMap.put(varName.replace("____", "/"), varValue);
    }
    return varMap;
}
 
Example 7
Source File: ByteOrderTests.java — from nd4j (Apache License 2.0), 4 votes
/**
 * Round-trips a rank-0 scalar through its FlatBuffers {@code FlatArray} encoding and checks
 * that the decoded array equals the original.
 */
@Test
public void testScalarEncoding() {
    val original = Nd4j.trueScalar(2.0f);

    // Encode: build the FlatArray table and finish the buffer with it as the root.
    FlatBufferBuilder builder = new FlatBufferBuilder(0);
    builder.finish(original.toFlatArray(builder));

    // Decode: read the root FlatArray back from the finished buffer.
    val decoded = Nd4j.createFromFlatArray(FlatArray.getRootAsFlatArray(builder.dataBuffer()));

    assertEquals(original, decoded);
}
 
Example 8
Source File: BaseNDArray.java — from nd4j (Apache License 2.0), 4 votes
/**
 * Returns the element at the given indexes as a new rank-0 array.
 *
 * <p>The element is read with {@link #getDouble(long...)} and wrapped with
 * {@code Nd4j.trueScalar}, so the result holds a copy of the value rather than a view.
 *
 * @param indexes coordinates of the element to extract
 * @return a scalar array holding that element's value
 */
@Override
public INDArray getScalar(long... indexes) {
    double element = getDouble(indexes);
    return Nd4j.trueScalar(element);
}