Java Code Examples for org.tensorflow.Session#Runner

The following examples show how to use org.tensorflow.Session#Runner. You can vote up the examples you like or vote down the ones you don't, and go to the original project or source file by following the links above each example. You can also check out related API usage in the sidebar.
Example 1
Source File: TfSymbolBlock.java    From djl with Apache License 2.0 6 votes vote down vote up
/**
 * {@inheritDoc}
 *
 * <p>Feeds the i-th input array under the i-th name from {@code describeInput()}, fetches every
 * name from {@code describeOutput()}, runs the session, and wraps each resulting tensor in an
 * array created by the manager of the first input.
 */
@Override
public NDList forward(
        ParameterStore parameterStore,
        NDList inputs,
        boolean training,
        PairList<String, Object> params) {
    Session.Runner runner = session.runner();
    PairList<String, Shape> feeds = describeInput();
    PairList<String, Shape> fetches = describeOutput();

    // Inputs are bound positionally to the described input names.
    int feedCount = feeds.size();
    for (int idx = 0; idx < feedCount; idx++) {
        TfNDArray array = (TfNDArray) inputs.get(idx);
        runner.feed(feeds.get(idx).getKey(), array.getTensor());
    }

    int fetchCount = fetches.size();
    for (int idx = 0; idx < fetchCount; idx++) {
        String outputName = fetches.get(idx).getKey();
        runner.fetch(outputName);
    }

    List<Tensor<?>> outputs = runner.run();

    // The manager of the first input array is used to create all output arrays.
    TfNDManager manager = (TfNDManager) inputs.head().getManager();
    NDList wrapped = new NDList();
    for (int idx = 0; idx < outputs.size(); idx++) {
        wrapped.add(manager.create(outputs.get(idx)));
    }
    return wrapped;
}
 
Example 2
Source File: TensorFlowExtras.java    From zoltar with Apache License 2.0 6 votes vote down vote up
/**
 * Registers the given operations as fetches on a {@link Session.Runner}, runs it, converts each
 * output {@link Tensor} to a {@link JTensor} and closes all output tensors afterwards.
 *
 * @param runner {@link Session.Runner} to fetch operations and extract outputs from.
 * @param fetchOps operations to fetch.
 * @return a {@link Map} of operations and output {@link JTensor}s. Map keys are in the same order
 *     as {@code fetchOps}.
 */
public static Map<String, JTensor> runAndExtract(
    final Session.Runner runner, final String... fetchOps) {
  for (int i = 0; i < fetchOps.length; i++) {
    runner.fetch(fetchOps[i]);
  }
  final Map<String, JTensor> extracted = Maps.newLinkedHashMapWithExpectedSize(fetchOps.length);
  final List<Tensor<?>> outputs = runner.run();
  try {
    // Outputs are paired with fetch operations by position.
    int position = 0;
    for (final String op : fetchOps) {
      final Tensor<?> output = outputs.get(position);
      extracted.put(op, JTensor.create(output));
      position++;
    }
  } finally {
    // Close the output tensors whether or not the conversion succeeded.
    for (final Tensor<?> output : outputs) {
      output.close();
    }
  }
  return extracted;
}
 
Example 3
Source File: TensorFlowExtrasTest.java    From zoltar with Apache License 2.0 5 votes vote down vote up
/**
 * Feeds 10.0 as "input", extracts the single operation {@code mul2}, and checks that the result
 * contains exactly that key with the scalar value 20.0.
 */
@Test
public void testExtract1() {
  // try-with-resources closes the session and then the graph even when an assertion fails.
  try (final Graph graph = createDummyGraph();
      final Session session = new Session(graph)) {
    final Session.Runner runner = session.runner();
    runner.feed("input", Tensors.create(10.0));
    final Map<String, JTensor> result = TensorFlowExtras.runAndExtract(runner, mul2);
    assertEquals(Sets.newHashSet(mul2), result.keySet());
    assertScalar(result.get(mul2), 20.0);
  }
}
 
Example 4
Source File: TensorFlowExtrasTest.java    From zoltar with Apache License 2.0 5 votes vote down vote up
/**
 * Feeds 10.0 as "input", extracts {@code mul2} then {@code mul3}, and checks that the result keys
 * come back in that order with the scalar values 20.0 and 30.0.
 */
@Test
public void testExtract2a() {
  // try-with-resources closes the session and then the graph even when an assertion fails.
  try (final Graph graph = createDummyGraph();
      final Session session = new Session(graph)) {
    final Session.Runner runner = session.runner();
    runner.feed("input", Tensors.create(10.0));
    final Map<String, JTensor> result = TensorFlowExtras.runAndExtract(runner, mul2, mul3);
    assertEquals(Lists.newArrayList(mul2, mul3), new ArrayList<>(result.keySet()));
    assertScalar(result.get(mul2), 20.0);
    assertScalar(result.get(mul3), 30.0);
  }
}
 
Example 5
Source File: TensorFlowExtrasTest.java    From zoltar with Apache License 2.0 5 votes vote down vote up
/**
 * Feeds 10.0 as "input", extracts {@code mul3} then {@code mul2}, and checks that the result keys
 * come back in that (reversed) order with the scalar values 30.0 and 20.0.
 */
@Test
public void testExtract2b() {
  // try-with-resources closes the session and then the graph even when an assertion fails.
  try (final Graph graph = createDummyGraph();
      final Session session = new Session(graph)) {
    final Session.Runner runner = session.runner();
    runner.feed("input", Tensors.create(10.0));
    final Map<String, JTensor> result = TensorFlowExtras.runAndExtract(runner, mul3, mul2);
    assertEquals(Lists.newArrayList(mul3, mul2), new ArrayList<>(result.keySet()));
    assertScalar(result.get(mul2), 20.0);
    assertScalar(result.get(mul3), 30.0);
  }
}
 
Example 6
Source File: TensorFlowPredictFn.java    From zoltar with Apache License 2.0 5 votes vote down vote up
/**
 * Builds a prediction function that serializes each vector's list of {@code Example}s into a
 * string tensor, feeds it as {@code "input_example_tensor"}, runs the requested fetch operations
 * and maps the extracted tensors to a value. One asynchronous task is started per vector.
 *
 * @param outTensorExtractor Function to extract the output value from JTensor's
 * @param fetchOps operations to fetch.
 * @deprecated Use {@link #example(Function, String...)}
 */
@Deprecated
static <InputT, ValueT> TensorFlowPredictFn<InputT, List<Example>, ValueT> exampleBatch(
    final Function<Map<String, JTensor>, ValueT> outTensorExtractor, final String... fetchOps) {
  // Runs the model once for a single batch of examples.
  final BiFunction<TensorFlowModel, List<Example>, ValueT> predictFn =
      (model, examples) -> {
        final byte[][] serialized =
            examples.stream().map(example -> example.toByteArray()).toArray(byte[][]::new);

        // The input tensor is closed as soon as the outputs have been extracted.
        try (final Tensor<String> inputTensor = Tensors.create(serialized)) {
          final Session.Runner runner =
              model.instance().session().runner().feed("input_example_tensor", inputTensor);
          final Map<String, JTensor> extracted =
              TensorFlowExtras.runAndExtract(runner, fetchOps);

          return outTensorExtractor.apply(extracted);
        }
      };

  return (model, vectors) -> {
    final List<CompletableFuture<Prediction<InputT, ValueT>>> pending =
        vectors
            .stream()
            .map(
                vector -> {
                  final CompletableFuture<ValueT> value =
                      CompletableFuture.supplyAsync(() -> predictFn.apply(model, vector.value()));
                  return value.thenApply(v -> Prediction.create(vector.input(), v));
                })
            .collect(Collectors.toList());

    return CompletableFutures.allAsList(pending);
  };
}
 
Example 7
Source File: RNTensorflowInference.java    From react-native-tensorflow with Apache License 2.0 5 votes vote down vote up
/**
 * Loads the named model resource, imports it into a new {@link Graph}, opens a {@link Session}
 * on that graph and bundles session, runner and graph into a {@code TfContext}.
 *
 * <p>If importing the graph definition or creating the session fails, the native resources
 * allocated so far are closed before the exception is propagated.
 *
 * @param reactContext context used to resolve the model resource
 * @param model name of the resource holding the serialized graph definition
 * @return a context holding the open session, a runner for it and the graph
 * @throws IOException declared for the resource loading step
 */
private static TfContext createContext(ReactContext reactContext, String model) throws IOException {
    byte[] b = new ResourceManager(reactContext).loadResource(model);

    Graph graph = new Graph();
    Session session = null;
    try {
        graph.importGraphDef(b);
        session = new Session(graph);
        Session.Runner runner = session.runner();

        return new TfContext(session, runner, graph);
    } catch (RuntimeException e) {
        // Nothing else owns these yet, so release them here instead of leaking native handles.
        if (session != null) {
            session.close();
        }
        graph.close();
        throw e;
    }
}
 
Example 8
Source File: GraphImporter.java    From vespa with Apache License 2.0 5 votes vote down vote up
/**
 * Fetches the named operation from the bundle's session and returns its single output tensor.
 * The caller receives the tensor as returned by the run.
 *
 * @param name the operation to fetch
 * @param bundle the saved model whose session is used
 * @return the one tensor produced by fetching {@code name}
 * @throws IllegalStateException if the run does not produce exactly one tensor; any tensors that
 *     were produced are closed before throwing
 */
static org.tensorflow.Tensor<?> readVariable(String name, SavedModelBundle bundle) {
    Session.Runner fetched = bundle.session().runner().fetch(name);
    List<org.tensorflow.Tensor<?>> importedTensors = fetched.run();
    if (importedTensors.size() != 1) {
        int actualCount = importedTensors.size();
        // No caller will ever see these tensors, so release them before failing.
        for (org.tensorflow.Tensor<?> unexpected : importedTensors) {
            unexpected.close();
        }
        throw new IllegalStateException("Expected 1 tensor from fetching " + name +
                                        ", but got " + actualCount);
    }
    return importedTensors.get(0);
}
 
Example 9
Source File: TestableTensorFlowModel.java    From vespa with Apache License 2.0 5 votes vote down vote up
/**
 * Runs one operation of the model with a generated input and converts the single result.
 *
 * @param model the saved model whose session is used
 * @param inputName the name under which the generated input tensor is fed
 * @param operationName the operation to fetch
 * @return the single fetched result, converted by {@code TensorConverter.toVespaTensor}
 */
private Tensor tensorFlowExecute(SavedModelBundle model, String inputName, String operationName) {
    Session.Runner runner = model.session().runner();

    // The floatInput flag selects between the float and the double input argument.
    org.tensorflow.Tensor<?> input;
    if (floatInput) {
        input = tensorFlowFloatInputArgument();
    } else {
        input = tensorFlowDoubleInputArgument();
    }
    runner.feed(inputName, input);

    Session.Runner fetching = runner.fetch(operationName);
    List<org.tensorflow.Tensor<?>> results = fetching.run();
    assertEquals(1, results.size());

    org.tensorflow.Tensor<?> onlyResult = results.get(0);
    return TensorConverter.toVespaTensor(onlyResult);
}
 
Example 10
Source File: BlogEvaluationBenchmark.java    From vespa with Apache License 2.0 5 votes vote down vote up
/**
 * Evaluates operation "y" of the model {@code iterations} times, feeding {@code u} as "input_u"
 * and {@code d} as "input_d" each time.
 *
 * @param tensorFlowModel the saved model whose session is used
 * @param u tensor fed as "input_u"
 * @param d tensor fed as "input_d"
 * @param iterations number of evaluations to perform
 * @return the sum of the first output of the last iteration, or 0 if {@code iterations} is not
 *     positive
 */
private static double evaluateTensorflow(SavedModelBundle tensorFlowModel, org.tensorflow.Tensor<?> u, org.tensorflow.Tensor<?> d, int iterations) {
    double result = 0;
    for (int i = 0 ; i < iterations; i++) {
        Session.Runner runner = tensorFlowModel.session().runner();
        runner.feed("input_u", u);
        runner.feed("input_d", d);
        List<org.tensorflow.Tensor<?>> results = runner.fetch("y").run();
        try {
            result = TensorConverter.toVespaTensor(results.get(0)).sum().asDouble();
        } finally {
            // The value is already a plain double, so the output tensors can be released now
            // instead of accumulating native memory over all iterations.
            for (org.tensorflow.Tensor<?> output : results) {
                output.close();
            }
        }
    }
    return result;
}
 
Example 11
Source File: TestableModel.java    From vespa with Apache License 2.0 5 votes vote down vote up
/**
 * Runs one operation of the model, feeding a generated argument for every declared input.
 *
 * <p>For each input a float argument is tried first; if building or feeding it throws, a double
 * argument of the same size is fed instead. The argument size is taken from the second dimension
 * of the input's tensor type.
 *
 * @param tensorFlowModel the saved model whose session is used
 * @param operationName the operation to fetch
 * @param inputs input names mapped to their tensor types
 * @return the single fetched result, converted by {@code TensorConverter.toVespaTensor}
 */
Tensor evaluateTF(SavedModelBundle tensorFlowModel, String operationName, Map<String, TensorType> inputs) {
    Session.Runner runner = tensorFlowModel.session().runner();
    for (Map.Entry<String, TensorType> input : inputs.entrySet()) {
        String inputName = input.getKey();
        TensorType inputType = input.getValue();
        try {
            int size = inputType.dimensions().get(1).size().get().intValue();
            runner.feed(inputName, tensorFlowFloatInputArgument(1, size));
        } catch (Exception e) {
            int size = inputType.dimensions().get(1).size().get().intValue();
            runner.feed(inputName, tensorFlowDoubleInputArgument(1, size));
        }
    }
    Session.Runner fetching = runner.fetch(operationName);
    List<org.tensorflow.Tensor<?>> results = fetching.run();
    assertEquals(1, results.size());
    org.tensorflow.Tensor<?> onlyResult = results.get(0);
    return TensorConverter.toVespaTensor(onlyResult);
}
 
Example 12
Source File: TensorFlowProcessor.java    From datacollector with Apache License 2.0 5 votes vote down vote up
/**
 * Runs the model once over the whole batch, emits the outputs as a single event record and then
 * passes every record of the batch through unchanged. An empty batch does nothing.
 *
 * @param batch the incoming batch of records
 * @param singleLaneBatchMaker receives every record of the batch after the model has run
 * @throws StageException declared by the stage processing contract
 */
private void processUseEntireBatch(Batch batch, SingleLaneBatchMaker singleLaneBatchMaker) throws StageException {
  Session.Runner runner = this.session.runner();
  Iterator<Record> batchRecords = batch.getRecords();
  if (!batchRecords.hasNext()) {
    return;
  }

  Map<Pair<String, Integer>, Tensor> inputs = convertBatch(batch, conf.inputConfigs);
  try {
    // Each key is an (operation name, index) pair identifying the input to feed.
    for (Map.Entry<Pair<String, Integer>, Tensor> feed : inputs.entrySet()) {
      Pair<String, Integer> target = feed.getKey();
      runner.feed(target.getLeft(), target.getRight(), feed.getValue());
    }

    for (TensorConfig outputConfig : conf.outputConfigs) {
      runner.fetch(outputConfig.operation, outputConfig.index);
    }

    List<Tensor<?>> tensorOutput = runner.run();
    LinkedHashMap<String, Field> outputFields = createOutputFieldValue(tensorOutput);
    EventRecord eventRecord = TensorFlowEvents.TENSOR_FLOW_OUTPUT_CREATOR.create(getContext()).create();
    eventRecord.set(Field.createListMap(outputFields));
    getContext().toEvent(eventRecord);
  } finally {
    // Input tensors are released whether or not the run succeeded.
    for (Tensor inputTensor : inputs.values()) {
      inputTensor.close();
    }
  }

  for (Iterator<Record> it = batch.getRecords(); it.hasNext(); ) {
    singleLaneBatchMaker.addRecord(it.next());
  }
}
 
Example 13
Source File: TensorFlowProcessor.java    From datacollector with Apache License 2.0 5 votes vote down vote up
/**
 * Runs the model once per record, stores the outputs in the configured output field of that
 * record and adds the record to the batch maker. Records whose conversion to input tensors
 * fails with {@code OnRecordErrorException} are handed to the error handler and skipped.
 *
 * @param batch the incoming batch of records
 * @param singleLaneBatchMaker receives each successfully processed record
 * @throws StageException declared by the stage processing contract
 */
public void processUseRecordByRecord(Batch batch, SingleLaneBatchMaker singleLaneBatchMaker) throws StageException {
  for (Iterator<Record> records = batch.getRecords(); records.hasNext(); ) {
    Record record = records.next();
    setInputConfigFields(record);
    Session.Runner runner = this.session.runner();

    Map<Pair<String, Integer>, Tensor> inputs;
    try {
      inputs = convertRecord(record, conf.inputConfigs);
    } catch (OnRecordErrorException ex) {
      errorRecordHandler.onError(ex);
      continue;
    }

    try {
      // Each key is an (operation name, index) pair identifying the input to feed.
      for (Map.Entry<Pair<String, Integer>, Tensor> feed : inputs.entrySet()) {
        Pair<String, Integer> target = feed.getKey();
        runner.feed(target.getLeft(), target.getRight(), feed.getValue());
      }

      for (TensorConfig outputConfig : conf.outputConfigs) {
        runner.fetch(outputConfig.operation, outputConfig.index);
      }

      List<Tensor<?>> tensorOutput = runner.run();
      LinkedHashMap<String, Field> outputFields = createOutputFieldValue(tensorOutput);
      record.set(conf.outputField, Field.create(outputFields));
      singleLaneBatchMaker.addRecord(record);
    } finally {
      // Input tensors are released whether or not the run succeeded.
      for (Tensor inputTensor : inputs.values()) {
        inputTensor.close();
      }
    }
  }
}
 
Example 14
Source File: RNTensorflowInference.java    From react-native-tensorflow with Apache License 2.0 4 votes vote down vote up
/**
 * Creates a context holding the given session, runner and graph, with an initially empty map of
 * output tensors.
 *
 * @param session the session to store
 * @param runner the runner to store
 * @param graph the graph to store
 */
TfContext(Session session, Session.Runner runner, Graph graph) {
    this.graph = graph;
    this.session = session;
    this.runner = runner;
    this.outputTensors = new HashMap<>();
}