org.tensorflow.SavedModelBundle Java Examples

The following examples show how to use org.tensorflow.SavedModelBundle. You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may check out the related API usage on the sidebar.
Example #1
Source File: ModelPredictTest.java    From DeepMachineLearning with Apache License 2.0 7 votes vote down vote up
/**
 * Loads the case in {@code testdata/cases/ieee14.ieee}, runs the saved "voltage" model on the
 * generated network input and prints the mismatch information for the prediction.
 *
 * <p>The saved model bundle and the input/output tensors own native memory, so they are
 * opened in a try-with-resources block and released once the prediction has been copied
 * into a Java array.
 */
@Test
public void test() throws IOException, InterruptedException, IpssCacheException, ODMException, ExecutionException {
	AclfTrainDataGenerator gateway = new AclfTrainDataGenerator();
	//read case
	String filename = "testdata/cases/ieee14.ieee";
	gateway.loadCase(filename, "BusVoltLoadChangeTrainCaseBuilder");
	//run loadflow
	gateway.trainCaseBuilder.createTestCase();
	//generate input
	double[] inputs = gateway.trainCaseBuilder.getNetInput();
	float[][] inputs_f  = new float[1][inputs.length];
	for (int i = 0; i < inputs.length; i++) {
		inputs_f[0][i] =(float) inputs[i];
	}
	//read model and predict; every native resource is closed when the block exits
	float[][] output;
	try (SavedModelBundle bundle = SavedModelBundle.load("py/c_graph/single_net/model", "voltage");
			Tensor<?> inputTensor = Tensor.create(inputs_f);
			Tensor<?> outputTensor = bundle.session().runner().feed("x", inputTensor).fetch("z").run().get(0)) {
		output = outputTensor.copyTo(new float[1][28]);
	}
	// NOTE(review): the copy below is bounded by inputs.length while the prediction has
	// 28 columns — this assumes inputs.length <= 28; confirm against the model.
	double[][] output_d = new double[1][inputs.length];
	for (int i = 0; i < inputs.length; i++) {
		output_d[0][i] = output[0][i];
	}
	//print out mismatch 
	System.out.println("Model out mismatch: "+gateway.getMismatchInfo(output_d[0]));
}
 
Example #2
Source File: GraphImporter.java    From vespa with Apache License 2.0 6 votes vote down vote up
/**
 * Imports the operation named {@code nodeName} into {@code intermediateGraph}, importing its
 * regular inputs and control inputs along the way. An operation that is already imported is
 * returned as-is.
 */
private static IntermediateOperation importOperation(String nodeName,
                                                     GraphDef tfGraph,
                                                     IntermediateGraph intermediateGraph,
                                                     SavedModelBundle bundle) {
    if (intermediateGraph.alreadyImported(nodeName))
        return intermediateGraph.get(nodeName);

    NodeDef tfNode = getTensorFlowNodeFromGraph(IntermediateOperation.namePartOf(nodeName), tfGraph);
    List<IntermediateOperation> dataInputs = importOperationInputs(tfNode, tfGraph, intermediateGraph, bundle);
    IntermediateOperation imported = mapOperation(tfNode, dataInputs, intermediateGraph);
    // Registered before the control inputs are imported, as in the original ordering.
    intermediateGraph.put(nodeName, imported);

    List<IntermediateOperation> controlDependencies = importControlInputs(tfNode, tfGraph, intermediateGraph, bundle);
    if ( ! controlDependencies.isEmpty())
        imported.setControlInputs(controlDependencies);

    // Constant values are read from the bundle only when the value function is invoked.
    if (imported.isConstant())
        imported.setConstantValueFunction(
                valueType -> new TensorValue(TensorConverter.toVespaTensor(readVariable(nodeName, bundle), valueType)));

    return imported;
}
 
Example #3
Source File: InProcessClassification.java    From hazelcast-jet-demos with Apache License 2.0 6 votes vote down vote up
/**
 * Builds a pipeline that reads the review texts from {@code reviewsMap}, classifies each one
 * with the saved model and logs the resulting sentiment rating.
 */
private static Pipeline buildPipeline(IMap<Long, String> reviewsMap) {
    // The context (model bundle plus word index) is loaded on each member and shared by
    // all parallel processors on that member; the bundle is closed when the context is destroyed.
    ServiceFactory<Tuple2<SavedModelBundle, WordIndex>, Tuple2<SavedModelBundle, WordIndex>> modelContext =
            ServiceFactory
                    .withCreateContextFn(ctx -> {
                        File dataDir = ctx.attachedDirectory("data");
                        String modelPath = dataDir.toPath().resolve("model/1").toString();
                        return tuple2(SavedModelBundle.load(modelPath, "serve"), new WordIndex(dataDir));
                    })
                    .withDestroyContextFn(loaded -> loaded.f0().close())
                    .withCreateServiceFn((ctx, loaded) -> loaded);

    Pipeline pipeline = Pipeline.create();
    pipeline.readFrom(Sources.map(reviewsMap))
            .map(Map.Entry::getValue)
            .mapUsingService(modelContext,
                    (loaded, reviewText) -> classify(reviewText, loaded.f0(), loaded.f1()))
            // TensorFlow executes models in parallel, so 2 local threads are used to maximize throughput.
            .setLocalParallelism(2)
            .writeTo(Sinks.logger(
                    rated -> String.format("Sentiment rating for review \"%s\" is %.2f", rated.f0(), rated.f1())));
    return pipeline;
}
 
Example #4
Source File: BlogEvaluationBenchmark.java    From vespa with Apache License 2.0 6 votes vote down vote up
/**
 * Imports the model in {@code modelDir}, benchmarks the imported ranking expression before and
 * after optimization, and finally benchmarks the same model evaluated directly by TensorFlow.
 *
 * <p>The bundle is opened in a try-with-resources block so its native resources are released
 * even when one of the benchmark steps throws.
 */
public static void main(String[] args) throws ParseException {
    try (SavedModelBundle tensorFlowModel = SavedModelBundle.load(modelDir, "serve")) {
        ImportedModel model = new TensorFlowImporter().importModel("blog", modelDir, tensorFlowModel);

        Context context = TestableTensorFlowModel.contextFrom(model);
        Tensor u = generateInputTensor();
        Tensor d = generateInputTensor();
        context.put("input_u", new TensorValue(u));
        context.put("input_d", new TensorValue(d));

        // Parse the ranking expression from imported string to force primitive tensor functions.
        RankingExpression expression = new RankingExpression(model.expressions().get("y").getRoot().toString());
        benchmarkJava(expression, context, 20, 200);

        System.out.println("*** Optimizing expression ***");
        ExpressionOptimizer optimizer = new ExpressionOptimizer();
        OptimizationReport report = optimizer.optimize(expression, (ContextIndex)context);
        System.out.println(report.toString());

        benchmarkJava(expression, context, 2000, 20000);
        benchmarkTensorFlow(tensorFlowModel, 2000, 20000);
    }
}
 
Example #5
Source File: TestableTensorFlowModel.java    From vespa with Apache License 2.0 5 votes vote down vote up
/**
 * Feeds a generated float or double argument (chosen by {@code floatInput}) to {@code inputName},
 * fetches {@code operationName} and returns the single result converted to a Vespa tensor.
 */
private Tensor tensorFlowExecute(SavedModelBundle model, String inputName, String operationName) {
    Session.Runner execution = model.session().runner();

    org.tensorflow.Tensor<?> argument;
    if (floatInput)
        argument = tensorFlowFloatInputArgument();
    else
        argument = tensorFlowDoubleInputArgument();

    List<org.tensorflow.Tensor<?>> fetched = execution.feed(inputName, argument)
                                                      .fetch(operationName)
                                                      .run();
    assertEquals(1, fetched.size());
    return TensorConverter.toVespaTensor(fetched.get(0));
}
 
Example #6
Source File: TestableTensorFlowModel.java    From vespa with Apache License 2.0 5 votes vote down vote up
public TestableTensorFlowModel(String modelName, String modelDir, int d0Size, int d1Size, boolean floatInput) {
    tensorFlowModel = SavedModelBundle.load(modelDir, "serve");
    model = new TensorFlowImporter().importModel(modelName, modelDir, tensorFlowModel);
    this.d0Size = d0Size;
    this.d1Size = d1Size;
    this.floatInput = floatInput;
}
 
Example #7
Source File: VariableConverter.java    From vespa with Apache License 2.0 5 votes vote down vote up
/**
 * Reads the tensor with the given TensorFlow name at the given model location,
 * and encodes it as UTF-8 Vespa document tensor JSON having the given ordered tensor type.
 * Note that order of dimensions in the tensor type does matter as the TensorFlow
 * tensor dimensions are implicitly ordered.
 *
 * <p>Both the bundle and the fetched TensorFlow tensor are closed once the JSON has been
 * produced, so no native memory is retained after this method returns.
 *
 * @throws IllegalArgumentException if loading or converting the model fails with an
 *         {@code IllegalArgumentException}; the original exception is kept as the cause
 */
static byte[] importVariable(String modelDir, String tensorFlowVariableName, String orderedTypeSpec) {
    try (SavedModelBundle bundle = SavedModelBundle.load(modelDir, "serve");
         org.tensorflow.Tensor<?> variable = GraphImporter.readVariable(tensorFlowVariableName, bundle)) {
        return JsonFormat.encode(TensorConverter.toVespaTensor(variable,
                                                               OrderedTensorType.fromSpec(orderedTypeSpec)));
    }
    catch (IllegalArgumentException e) {
        throw new IllegalArgumentException("Could not import TensorFlow model from directory '" + modelDir + "'", e);
    }
}
 
Example #8
Source File: GraphImporter.java    From vespa with Apache License 2.0 5 votes vote down vote up
/**
 * Fetches {@code name} from the bundle's session and returns the single resulting tensor.
 *
 * @throws IllegalStateException if the fetch does not produce exactly one tensor
 */
static org.tensorflow.Tensor<?> readVariable(String name, SavedModelBundle bundle) {
    List<org.tensorflow.Tensor<?>> fetchedTensors = bundle.session().runner().fetch(name).run();
    int fetchedCount = fetchedTensors.size();
    if (fetchedCount == 1) {
        return fetchedTensors.get(0);
    }
    throw new IllegalStateException("Expected 1 tensor from fetching " + name + ", but got " + fetchedCount);
}
 
Example #9
Source File: GraphImporter.java    From vespa with Apache License 2.0 5 votes vote down vote up
/**
 * Imports every input of {@code node} that is a control dependency and returns the imported
 * operations in input order.
 */
private static List<IntermediateOperation> importControlInputs(NodeDef node,
                                                               GraphDef tfGraph,
                                                               IntermediateGraph intermediateGraph,
                                                               SavedModelBundle bundle) {
    List<String> inputNames = node.getInputList();
    return inputNames.stream()
                     .filter(inputName -> isControlDependency(inputName))
                     .map(controlInput -> importOperation(controlInput, tfGraph, intermediateGraph, bundle))
                     .collect(Collectors.toList());
}
 
Example #10
Source File: GraphImporter.java    From vespa with Apache License 2.0 5 votes vote down vote up
/**
 * Imports every input of {@code node} that is not a control dependency and returns the
 * imported operations in input order.
 */
private static List<IntermediateOperation> importOperationInputs(NodeDef node,
                                                                 GraphDef tfGraph,
                                                                 IntermediateGraph intermediateGraph,
                                                                 SavedModelBundle bundle) {
    List<String> inputNames = node.getInputList();
    return inputNames.stream()
                     .filter(inputName -> ! isControlDependency(inputName))
                     .map(dataInput -> importOperation(dataInput, tfGraph, intermediateGraph, bundle))
                     .collect(Collectors.toList());
}
 
Example #11
Source File: GraphImporter.java    From vespa with Apache License 2.0 5 votes vote down vote up
/**
 * Imports the operation behind every output of every signature known to
 * {@code intermediateGraph}.
 */
private static void importOperations(MetaGraphDef tfGraph,
                                     IntermediateGraph intermediateGraph,
                                     SavedModelBundle bundle) {
    GraphDef graphDef = tfGraph.getGraphDef();
    for (String signature : intermediateGraph.signatures()) {
        intermediateGraph.outputs(signature).values()
                         .forEach(output -> importOperation(output, graphDef, intermediateGraph, bundle));
    }
}
 
Example #12
Source File: GraphImporter.java    From vespa with Apache License 2.0 5 votes vote down vote up
/**
 * Parses the bundle's meta graph definition and builds an {@link IntermediateGraph} named
 * {@code modelName} from it: signatures first, then operations, then output type verification.
 *
 * @throws IOException if the meta graph definition cannot be parsed
 */
static IntermediateGraph importGraph(String modelName, SavedModelBundle bundle) throws IOException {
    MetaGraphDef metaGraph = MetaGraphDef.parseFrom(bundle.metaGraphDef());
    IntermediateGraph imported = new IntermediateGraph(modelName);

    // The three passes run in this fixed order.
    importSignatures(metaGraph, imported);
    importOperations(metaGraph, imported, bundle);
    verifyOutputTypes(metaGraph, imported);
    return imported;
}
 
Example #13
Source File: TensorFlowImporter.java    From vespa with Apache License 2.0 5 votes vote down vote up
/**
 * Imports a TensorFlow model from an already loaded bundle.
 *
 * @param modelName the name to give the imported model
 * @param modelDir the directory the bundle was loaded from
 * @param model the loaded TensorFlow bundle to import
 * @return the imported model
 * @throws IllegalArgumentException if importing the graph fails with an {@code IOException};
 *         the message names the model and its directory and the cause is preserved
 */
public ImportedModel importModel(String modelName, String modelDir, SavedModelBundle model) {
    try {
        IntermediateGraph graph = GraphImporter.importGraph(modelName, model);
        return convertIntermediateGraphToModel(graph, modelDir);
    }
    catch (IOException e) {
        // Report the model name and directory; concatenating the bundle itself would only
        // print its object identity.
        throw new IllegalArgumentException("Could not import TensorFlow model '" + modelName +
                                           "' from directory '" + modelDir + "'", e);
    }
}
 
Example #14
Source File: TensorFlowImporter.java    From vespa with Apache License 2.0 5 votes vote down vote up
/**
 * Imports a saved TensorFlow model from a directory.
 * The model should be saved as a .pbtxt or .pb file.
 *
 * @param modelName the name of the model to import, consisting of characters in [A-Za-z0-9_]
 * @param modelDir the directory containing the TensorFlow model files to import
 */
@Override
public ImportedModel importModel(String modelName, String modelDir) {
    // Temporary (for testing): a path containing "tf_2_onnx" is converted to ONNX and imported as such.
    boolean importViaOnnx = modelDir.contains("tf_2_onnx");
    if (importViaOnnx)
        return convertToOnnxAndImport(modelName, modelDir);

    try (SavedModelBundle bundle = SavedModelBundle.load(modelDir, "serve")) {
        return importModel(modelName, modelDir, bundle);
    }
    catch (IllegalArgumentException cause) {
        throw new IllegalArgumentException("Could not import TensorFlow model from directory '" + modelDir + "'", cause);
    }
}
 
Example #15
Source File: TensorFlowModelProducer.java    From samantha with MIT License 5 votes vote down vote up
/**
 * Creates a {@link TensorFlowModel} from the saved model in {@code exportDir}. When no saved
 * model could be loaded, the model is created with a {@code null} graph and session.
 */
public TensorFlowModel createTensorFlowModelModelFromExportDir(
        String modelName,
        SpaceMode spaceMode,
        String exportDir,
        List<String> groupKeys,
        List<List<String>> equalSizeChecks,
        List<String> indexKeys,
        List<FeatureExtractor> featureExtractors,
        String predItemFea,
        String lossOper,
        String updateOper,
        String outputOper,
        String initOper,
        String topKOper,
        String topKId,
        String topKValue,
        String itemIndex) {
    IndexSpace indexSpace = getIndexSpace(modelName, spaceMode, indexKeys);
    VariableSpace variableSpace = getVariableSpace(modelName, spaceMode);

    SavedModelBundle bundle = loadTensorFlowSavedModel(exportDir);
    boolean bundleLoaded = bundle != null;
    Session session = bundleLoaded ? bundle.session() : null;
    Graph graph = bundleLoaded ? bundle.graph() : null;

    return new TensorFlowModel(
            graph, session, null, exportDir, indexSpace, variableSpace,
            featureExtractors, predItemFea, lossOper, updateOper, topKId, itemIndex,
            topKValue, outputOper, topKOper, initOper, groupKeys, equalSizeChecks);
}
 
Example #16
Source File: TensorFlowModelProducer.java    From samantha with MIT License 5 votes vote down vote up
/**
 * Loads the saved model in {@code exportDir} using {@code TensorFlowModel.SAVED_MODEL_TAG}.
 *
 * @return the loaded bundle, or {@code null} (after logging a warning) when the directory
 *         does not exist
 */
static public SavedModelBundle loadTensorFlowSavedModel(String exportDir) {
    if (!new File(exportDir).exists()) {
        logger.warn("TensorFlow exported model dir does not exist: {}.", exportDir);
        return null;
    }
    return SavedModelBundle.load(exportDir, TensorFlowModel.SAVED_MODEL_TAG);
}
 
Example #17
Source File: BlogEvaluationTestCase.java    From vespa with Apache License 2.0 5 votes vote down vote up
/**
 * Verifies that the "serving_default.y" signature exists after import and has no inputs.
 * The bundle is closed when the test finishes, whether it passes or fails.
 */
@Test
public void testImport() {
    try (SavedModelBundle tensorFlowModel = SavedModelBundle.load(modelDir, "serve")) {
        ImportedModel model = new TensorFlowImporter().importModel("blog", modelDir, tensorFlowModel);

        ImportedModel.Signature y = model.signature("serving_default.y");
        assertNotNull(y);
        assertEquals(0, y.inputs().size());
    }
}
 
Example #18
Source File: BlogEvaluationBenchmark.java    From vespa with Apache License 2.0 5 votes vote down vote up
/**
 * Runs {@code warmup} untimed evaluations followed by {@code iterations} timed evaluations of
 * the TensorFlow model, and prints the total and average evaluation time in milliseconds.
 *
 * <p>The two generated input tensors are closed when the benchmark ends. Timestamps are kept
 * as {@code long} nanoseconds so that no precision is lost before the difference is taken.
 */
private static void benchmarkTensorFlow(SavedModelBundle tensorFlowModel, int warmup, int iterations) {
    try (org.tensorflow.Tensor<?> u = generateInputTensorFlow();
         org.tensorflow.Tensor<?> d = generateInputTensorFlow()) {

        System.out.println("*** TensorFlow evaluation - warmup ***");
        evaluateTensorflow(tensorFlowModel, u, d, warmup);

        System.gc();
        System.out.println("*** TensorFlow evaluation - " + iterations + " iterations ***");
        long startTime = System.nanoTime();
        evaluateTensorflow(tensorFlowModel, u, d, iterations);
        long elapsedNanos = System.nanoTime() - startTime;

        // Convert to milliseconds only after subtracting, in floating point for the report.
        double elapsedMillis = elapsedNanos / (1000.0 * 1000.0);
        System.out.println("Model evaluation time is " + elapsedMillis + " ms");
        System.out.println("Average model evaluation time is " + elapsedMillis / iterations + " ms");
    }
}
 
Example #19
Source File: BlogEvaluationBenchmark.java    From vespa with Apache License 2.0 5 votes vote down vote up
/**
 * Evaluates output "y" of the model {@code iterations} times with {@code u} fed to "input_u"
 * and {@code d} fed to "input_d".
 *
 * <p>The tensors returned by each run are closed as soon as the sum has been read, so the
 * loop does not accumulate native memory.
 *
 * @return the sum of the result tensor from the last iteration, or 0 when
 *         {@code iterations} is not positive
 */
private static double evaluateTensorflow(SavedModelBundle tensorFlowModel, org.tensorflow.Tensor<?> u, org.tensorflow.Tensor<?> d, int iterations) {
    double result = 0;
    for (int i = 0 ; i < iterations; i++) {
        Session.Runner runner = tensorFlowModel.session().runner();
        runner.feed("input_u", u);
        runner.feed("input_d", d);
        List<org.tensorflow.Tensor<?>> results = runner.fetch("y").run();
        try {
            result = TensorConverter.toVespaTensor(results.get(0)).sum().asDouble();
        }
        finally {
            // The double has been extracted above; the fetched tensors are no longer needed.
            for (org.tensorflow.Tensor<?> fetched : results) {
                fetched.close();
            }
        }
    }
    return result;
}
 
Example #20
Source File: Tf2OnnxImportTestCase.java    From vespa with Apache License 2.0 5 votes vote down vote up
/**
 * Converts {@code tfModel} to ONNX with the given opset, imports both the TensorFlow and the
 * ONNX model, and compares the TensorFlow evaluation, the imported TensorFlow evaluation and
 * the imported ONNX evaluation of the first output.
 *
 * <p>The TensorFlow bundle is opened in a try-with-resources block so it is closed on every
 * return path, including the early failure returns.
 *
 * @return the value of {@code reportAndSucceed} when all three evaluations agree, otherwise
 *         the value of {@code reportAndFail} for the first failing step
 */
private boolean testModelWithOpset(Report report, int opset, String tfModel) throws IOException {
    String onnxModel = Paths.get(testFolder.getRoot().getAbsolutePath(), "converted.onnx").toString();

    var res = tf2onnxConvert(tfModel, onnxModel, opset);
    if (res.getFirst() != 0) {
        return reportAndFail(report, opset, tfModel, "tf2onnx conversion failed: " + res.getSecond());
    }

    try (SavedModelBundle tensorFlowModel = SavedModelBundle.load(tfModel, "serve")) {
        ImportedModel model = new TensorFlowImporter().importModel("test", tfModel, tensorFlowModel);
        ImportedModel onnxImportedModel = new OnnxImporter().importModel("test", onnxModel);

        if (model.signature("serving_default").skippedOutputs().size() > 0) {
            return reportAndFail(report, opset, tfModel, "Failed to import model from TensorFlow due to skipped outputs");
        }
        if (onnxImportedModel.signature("default").skippedOutputs().size() > 0) {
            return reportAndFail(report, opset, tfModel, "Failed to import model from ONNX due to skipped outputs");
        }

        ImportedModel.Signature sig = model.signatures().values().iterator().next();
        String output = sig.outputs().values().iterator().next();
        String onnxOutput = onnxImportedModel.signatures().values().iterator().next().outputs().values().iterator().next();

        Tensor tfResult = evaluateTF(tensorFlowModel, output, model.inputs());
        Tensor vespaResult = evaluateVespa(model, output, model.inputs());
        Tensor onnxResult = evaluateVespa(onnxImportedModel, onnxOutput, model.inputs());

        if ( ! tfResult.equals(vespaResult) ) {
            return reportAndFail(report, opset, tfModel, "Diff between tf and imported tf evaluation:\n\t" + tfResult + "\n\t" + vespaResult);
        }
        if ( ! vespaResult.equals(onnxResult) ) {
            return reportAndFail(report, opset, tfModel, "Diff between imported tf eval and onnx eval:\n\t" + vespaResult + "\n\t" + onnxResult);
        }

        return reportAndSucceed(report, opset, tfModel, "Ok");
    }
}
 
Example #21
Source File: TestableModel.java    From vespa with Apache License 2.0 5 votes vote down vote up
/**
 * Runs {@code operationName} in the given TensorFlow model and converts the single fetched
 * result to a Vespa tensor.
 *
 * @param tensorFlowModel the loaded model whose session is used
 * @param operationName the operation to fetch
 * @param inputs input names mapped to their tensor types; every entry is fed before running
 * @return the fetched result, converted by {@code TensorConverter.toVespaTensor}
 */
Tensor evaluateTF(SavedModelBundle tensorFlowModel, String operationName, Map<String, TensorType> inputs) {
    Session.Runner runner = tensorFlowModel.session().runner();
    for (Map.Entry<String, TensorType> entry : inputs.entrySet()) {
        // Each argument is built from the constant 1 and the declared size of the type's
        // dimension at index 1.
        try {
            runner.feed(entry.getKey(), tensorFlowFloatInputArgument(1, entry.getValue().dimensions().get(1).size().get().intValue()));
        } catch (Exception e) {
            // Any exception from the float attempt falls back to a double argument.
            // NOTE(review): presumably meant for inputs declared as double — confirm which
            // call is expected to throw, since the same size lookup is repeated here.
            runner.feed(entry.getKey(), tensorFlowDoubleInputArgument(1, entry.getValue().dimensions().get(1).size().get().intValue()));
        }
    }
    List<org.tensorflow.Tensor<?>> results = runner.fetch(operationName).run();
    // Exactly one result tensor is expected from the single fetch.
    assertEquals(1, results.size());
    return TensorConverter.toVespaTensor(results.get(0));
}
 
Example #22
Source File: InProcessClassification.java    From hazelcast-jet-demos with Apache License 2.0 5 votes vote down vote up
/**
 * Runs the model on the tensor input created from {@code review} and returns the review
 * paired with the single value read from the "dense_1/Sigmoid:0" output.
 */
private static Tuple2<String, Float> classify(
        String review, SavedModelBundle model, WordIndex wordIndex
) {
    // Both tensors are closed when the try block exits.
    try (Tensor<Float> encodedReview = Tensors.create(wordIndex.createTensorInput(review));
         Tensor<?> prediction = model.session()
                                     .runner()
                                     .feed("embedding_input:0", encodedReview)
                                     .fetch("dense_1/Sigmoid:0")
                                     .run()
                                     .get(0)
    ) {
        float[][] scores = prediction.copyTo(new float[1][1]);
        return tuple2(review, scores[0][0]);
    }
}
 
Example #23
Source File: TensorFlowModel.java    From zoltar with Apache License 2.0 5 votes vote down vote up
/**
 * Parses the meta graph definition stored in the given bundle.
 *
 * @throws TensorflowMetaGraphDefParsingException if the serialized definition is not a valid
 *     protocol buffer; the parsing failure is kept as the cause
 */
private static MetaGraphDef extractMetaGraphDefinition(final SavedModelBundle bundle)
    throws TensorflowMetaGraphDefParsingException {
  try {
    return MetaGraphDef.parseFrom(bundle.metaGraphDef());
  } catch (InvalidProtocolBufferException parseFailure) {
    throw new TensorflowMetaGraphDefParsingException(
        "Failed parsing tensorflow metagraph definition", parseFailure);
  }
}
 
Example #24
Source File: TensorFlowModel.java    From zoltar with Apache License 2.0 5 votes vote down vote up
/**
 * Note: Please use Models from zoltar-models module.
 *
 * <p>Returns a TensorFlow model with metadata given {@link SavedModelBundle} export directory URI
 * and {@link Options}.
 *
 * <p>If anything fails after the bundle has been loaded (meta graph parsing, signature lookup
 * or model construction), the bundle is closed before the exception propagates so its native
 * resources are not leaked.
 *
 * @throws IOException if the model cannot be downloaded or its meta graph definition cannot
 *     be parsed
 */
public static TensorFlowModel create(
    final Model.Id id,
    final URI modelResource,
    final Options options,
    final String signatureDefinition)
    throws IOException {
  // GCS requires that directory URIs have a trailing slash, so add the slash if it's missing
  // and the URI starts with 'gs'.
  final URI normalizedUri =
      !CloudStorageFileSystem.URI_SCHEME.equalsIgnoreCase(modelResource.getScheme())
              || modelResource.toString().endsWith("/")
          ? modelResource
          : URI.create(modelResource.toString() + "/");
  final URI localDir = FileSystemExtras.downloadIfNonLocal(normalizedUri);
  final SavedModelBundle model =
      SavedModelBundle.load(localDir.toString(), options.tags().toArray(new String[0]));

  // Ownership of the bundle passes to the returned model only on success.
  boolean handedOff = false;
  try {
    final MetaGraphDef metaGraphDef = extractMetaGraphDefinition(model);
    final SignatureDef signatureDef = metaGraphDef.getSignatureDefOrThrow(signatureDefinition);

    final TensorFlowModel created =
        new AutoValue_TensorFlowModel(
            id,
            model,
            options,
            metaGraphDef,
            signatureDef,
            toNameMap(signatureDef.getInputsMap()),
            toNameMap(signatureDef.getOutputsMap()));
    handedOff = true;
    return created;
  } catch (TensorflowMetaGraphDefParsingException e) {
    throw new IOException(e);
  } finally {
    if (!handedOff) {
      model.close();
    }
  }
}
 
Example #25
Source File: Bert.java    From easy-bert with MIT License 5 votes vote down vote up
/**
 * Stores the bundle and model details, creates the tokenizer from the vocabulary file, and
 * resolves the ids of the start and separator tokens.
 */
private Bert(final SavedModelBundle bundle, final ModelDetails model, final Path vocabulary) {
    this.bundle = bundle;
    this.model = model;
    this.tokenizer = new FullTokenizer(vocabulary, model.doLowerCase);

    // Both marker tokens are converted in one call; the ids come back in the same order.
    final String[] markerTokens = {START_TOKEN, SEPARATOR_TOKEN};
    final int[] markerIds = this.tokenizer.convert(markerTokens);
    this.startTokenId = markerIds[0];
    this.separatorTokenId = markerIds[1];
}
 
Example #26
Source File: Bert.java    From easy-bert with MIT License 5 votes vote down vote up
/**
 * Loads a pre-trained BERT model from a TensorFlow saved model saved by the easy-bert Python utilities
 *
 * @param path
 *        the path to load the model from
 * @return a ready-to-use BERT model
 * @since 1.0.3
 */
public static Bert load(Path path) {
    final Path root = path.toAbsolutePath();
    final Path assets = root.resolve("assets");

    // The model details are read from the assets directory; a read failure is rethrown unchecked.
    final ModelDetails details;
    try {
        details = new ObjectMapper().readValue(assets.resolve(MODEL_DETAILS).toFile(), ModelDetails.class);
    } catch(final IOException readFailure) {
        throw new RuntimeException(readFailure);
    }

    final SavedModelBundle bundle = SavedModelBundle.load(root.toString(), "serve");
    return new Bert(bundle, details, assets.resolve(VOCAB_FILE));
}
 
Example #27
Source File: TfModel.java    From djl with Apache License 2.0 5 votes vote down vote up
/** {@inheritDoc} */
@Override
public void load(Path modelPath, String prefix, Map<String, Object> options)
        throws FileNotFoundException {
    modelDir = modelPath.toAbsolutePath();

    // Look for the model under the given prefix (or the model name), then under the
    // default "saved_model.pb" name.
    String lookupName = (prefix != null) ? prefix : modelName;
    Path exportDir = findModleDir(lookupName);
    if (exportDir == null) {
        exportDir = findModleDir("saved_model.pb");
    }
    if (exportDir == null) {
        throw new FileNotFoundException("No TensorFlow model found in: " + modelDir);
    }

    // Optional settings; each stays null when no options map is supplied.
    boolean hasOptions = options != null;
    String[] tags = hasOptions ? (String[]) options.get("Tags") : null;
    ConfigProto proto = hasOptions ? (ConfigProto) options.get("ConfigProto") : null;
    RunOptions runOptions = hasOptions ? (RunOptions) options.get("RunOptions") : null;
    if (tags == null) {
        tags = new String[] {"serve"};
    }

    SavedModelBundle.Loader loader = SavedModelBundle.loader(exportDir.toString()).withTags(tags);
    if (proto != null) {
        loader.withConfigProto(proto);
    }
    if (runOptions != null) {
        loader.withRunOptions(runOptions);
    }

    block = new TfSymbolBlock(loader.load());
}
 
Example #28
Source File: ObjectDetector.java    From OpenLabeler with Apache License 2.0 5 votes vote down vote up
/**
 * Reloads the saved model from the configured directory when it exists and either no path is
 * given or the given path is "saved_model".
 *
 * <p>The replacement bundle is loaded before the current one is closed, so a failed load
 * leaves the previously loaded model in place instead of a closed one. Failures are logged
 * rather than propagated; if the thread is interrupted while waiting, its interrupt status is
 * restored.
 *
 * @param path the changed path reported by the file watcher, or {@code null} when not
 *        triggered by a watched file
 * @return always {@code null}
 */
private Void update(Path path) {
    try {
        File savedModelFile = new File(Settings.getTFSavedModelDir());
        if (savedModelFile.exists() && (path == null || "saved_model".equals(path.toString()))) {
            if (path != null) { // coming from watched file
                Thread.sleep(5000); // Wait a while for model to be exported
            }
            synchronized (ObjectDetector.this) {
                // Load first: if this throws, the current model is still open and usable.
                SavedModelBundle loaded = SavedModelBundle.load(savedModelFile.getAbsolutePath(), "serve");
                if (model != null) {
                    model.close();
                }
                model = loaded;
                String message = MessageFormat.format(bundle.getString("msg.loadedSavedModel"), savedModelFile);
                LOG.info(message);
                printSignature(model);
                Platform.runLater(() -> statusProperty.set(message));
            }
        }
        else if (!savedModelFile.exists() && path == null) {
            LOG.info(savedModelFile.toString() + " does not exist");
        }
    }
    catch (InterruptedException ex) {
        // Restore the interrupt status so the owner of this thread can observe it.
        Thread.currentThread().interrupt();
        LOG.log(Level.SEVERE, "Unable to update " + path, ex);
    }
    catch (Exception ex) {
        LOG.log(Level.SEVERE, "Unable to update " + path, ex);
    }
    return null;
}
 
Example #29
Source File: TensorFlowProcessor.java    From datacollector with Apache License 2.0 4 votes vote down vote up
/**
 * Validates the configured model path, loads the saved model with the configured tags and
 * prepares the session, the input configuration lookup, the EL evaluators and the error
 * record handler.
 *
 * @return the configuration issues found; loading stops at the first issue added here
 */
@Override
protected List<ConfigIssue> init() {
  List<ConfigIssue> issues = super.init();
  String[] modelTags = conf.modelTags.toArray(new String[0]);

  // A missing model path is reported without attempting to load anything.
  if (Strings.isNullOrEmpty(conf.modelPath)) {
    issues.add(getContext().createConfigIssue(
        Groups.TENSOR_FLOW.name(),
        TensorFlowConfigBean.MODEL_PATH_CONFIG,
        Errors.TENSOR_FLOW_01
    ));
    return issues;
  }

  try {
    // Relative paths are resolved against the resources directory.
    File configuredDir = new File(conf.modelPath);
    File exportedModelDir = configuredDir.isAbsolute()
        ? configuredDir
        : new File(getContext().getResourcesDirectory(), conf.modelPath).getAbsoluteFile();
    this.savedModel = SavedModelBundle.load(exportedModelDir.getAbsolutePath(), modelTags);
  } catch (TensorFlowException loadFailure) {
    issues.add(getContext().createConfigIssue(
        Groups.TENSOR_FLOW.name(),
        TensorFlowConfigBean.MODEL_PATH_CONFIG,
        Errors.TENSOR_FLOW_02,
        loadFailure
    ));
    return issues;
  }

  this.session = this.savedModel.session();

  // Index the input configurations by (operation, index).
  this.conf.inputConfigs.forEach(input -> {
    Pair<String, Integer> operationAndIndex = Pair.of(input.operation, input.index);
    inputConfigMap.put(operationAndIndex, input);
  });

  fieldPathEval = getContext().createELEval("conf.inputConfigs");
  fieldPathVars = getContext().createELVars();
  errorRecordHandler = new DefaultErrorRecordHandler(getContext());

  return issues;
}
 
Example #30
Source File: TensorflowSavedModel.java    From tutorials with MIT License 4 votes vote down vote up
/**
 * Loads the saved model in {@code ./model}, feeds the value 3 to both "x" and "y", fetches
 * "z" and prints it.
 *
 * <p>The bundle and all three tensors are opened in a try-with-resources block so their
 * native memory is released when the method ends.
 */
public static void main(String[] args) {
	try (SavedModelBundle model = SavedModelBundle.load("./model", "serve");
			Tensor<Integer> x = Tensor.<Integer>create(3, Integer.class);
			Tensor<Integer> y = Tensor.<Integer>create(3, Integer.class);
			Tensor<Integer> tensor = model.session().runner().fetch("z").feed("x", x)
					.feed("y", y).run().get(0).expect(Integer.class)) {
		System.out.println(tensor.intValue());
	}
}