Java Code Examples for org.tensorflow.SavedModelBundle#load()

The following examples show how to use org.tensorflow.SavedModelBundle#load(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also check out the related API usage in the sidebar.
Example 1
Source File: ModelPredictTest.java    From DeepMachineLearning with Apache License 2.0 7 votes vote down vote up
@Test
public void test() throws IOException, InterruptedException, IpssCacheException, ODMException, ExecutionException {
	AclfTrainDataGenerator gateway = new AclfTrainDataGenerator();
	//read case
	String filename = "testdata/cases/ieee14.ieee";
	gateway.loadCase(filename, "BusVoltLoadChangeTrainCaseBuilder");
	//run loadflow
	gateway.trainCaseBuilder.createTestCase();
	//generate input
	double[] inputs = gateway.trainCaseBuilder.getNetInput();
	float[][] inputs_f  = new float[1][inputs.length];
	for (int i = 0; i < inputs.length; i++) {
		inputs_f[0][i] =(float) inputs[i];
	}
	//read model
	SavedModelBundle bundle = SavedModelBundle.load("py/c_graph/single_net/model", "voltage");
	//predict
	float[][] output = bundle.session().runner().feed("x", Tensor.create(inputs_f)).fetch("z").run().get(0)
			.copyTo(new float[1][28]);
	double[][] output_d = new double[1][inputs.length];
	for (int i = 0; i < inputs.length; i++) {
		output_d[0][i] = output[0][i];
	}
	//print out mismatch 
	System.out.println("Model out mismatch: "+gateway.getMismatchInfo(output_d[0]));
}
 
Example 2
Source File: BlogEvaluationBenchmark.java    From vespa with Apache License 2.0 6 votes vote down vote up
public static void main(String[] args) throws ParseException {
    SavedModelBundle tensorFlowModel = SavedModelBundle.load(modelDir, "serve");
    ImportedModel model = new TensorFlowImporter().importModel("blog", modelDir, tensorFlowModel);

    Context context = TestableTensorFlowModel.contextFrom(model);
    Tensor u = generateInputTensor();
    Tensor d = generateInputTensor();
    context.put("input_u", new TensorValue(u));
    context.put("input_d", new TensorValue(d));

    // Parse the ranking expression from imported string to force primitive tensor functions.
    RankingExpression expression = new RankingExpression(model.expressions().get("y").getRoot().toString());
    benchmarkJava(expression, context, 20, 200);

    System.out.println("*** Optimizing expression ***");
    ExpressionOptimizer optimizer = new ExpressionOptimizer();
    OptimizationReport report = optimizer.optimize(expression, (ContextIndex)context);
    System.out.println(report.toString());

    benchmarkJava(expression, context, 2000, 20000);
    benchmarkTensorFlow(tensorFlowModel, 2000, 20000);
}
 
Example 3
Source File: ObjectDetector.java    From OpenLabeler with Apache License 2.0 5 votes vote down vote up
private Void update(Path path) {
    try {
        File savedModelFile = new File(Settings.getTFSavedModelDir());
        if (savedModelFile.exists() && (path == null || "saved_model".equals(path.toString()))) {
            if (path != null) { // coming from watched file
                Thread.sleep(5000); // Wait a while for model to be exported
            }
            synchronized (ObjectDetector.this) {
                if (model != null) {
                    model.close();
                }
                model = SavedModelBundle.load(savedModelFile.getAbsolutePath(), "serve");
                String message = MessageFormat.format(bundle.getString("msg.loadedSavedModel"), savedModelFile);
                LOG.info(message);
                printSignature(model);
                Platform.runLater(() -> statusProperty.set(message));
            }
        }
        else if (!savedModelFile.exists() && path == null) {
            LOG.info(savedModelFile.toString() + " does not exist");
        }
    }
    catch (Exception ex) {
        LOG.log(Level.SEVERE, "Unable to update " + path, ex);
    }
    return null;
}
 
Example 4
Source File: Bert.java    From easy-bert with MIT License 5 votes vote down vote up
/**
 * Loads a pre-trained BERT model from a TensorFlow saved model saved by the easy-bert Python utilities
 *
 * @param path
 *        the path to load the model from
 * @return a ready-to-use BERT model
 * @since 1.0.3
 */
public static Bert load(Path path) {
    path = path.toAbsolutePath();
    ModelDetails model;
    try {
        model = new ObjectMapper().readValue(path.resolve("assets").resolve(MODEL_DETAILS).toFile(), ModelDetails.class);
    } catch(final IOException e) {
        throw new RuntimeException(e);
    }

    return new Bert(SavedModelBundle.load(path.toString(), "serve"), model, path.resolve("assets").resolve(VOCAB_FILE));
}
 
Example 5
Source File: TensorFlowModel.java    From zoltar with Apache License 2.0 5 votes vote down vote up
/**
 * Note: Please use Models from zoltar-models module.
 *
 * <p>Returns a TensorFlow model with metadata given {@link SavedModelBundle} export directory URI
 * and {@link Options}.
 */
public static TensorFlowModel create(
    final Model.Id id,
    final URI modelResource,
    final Options options,
    final String signatureDefinition)
    throws IOException {
  // GCS requires that directory URIs have a trailing slash, so add the slash if it's missing
  // and the URI starts with 'gs'.
  final URI normalizedUri =
      !CloudStorageFileSystem.URI_SCHEME.equalsIgnoreCase(modelResource.getScheme())
              || modelResource.toString().endsWith("/")
          ? modelResource
          : URI.create(modelResource.toString() + "/");
  final URI localDir = FileSystemExtras.downloadIfNonLocal(normalizedUri);
  final SavedModelBundle model =
      SavedModelBundle.load(localDir.toString(), options.tags().toArray(new String[0]));
  final MetaGraphDef metaGraphDef;
  try {
    metaGraphDef = extractMetaGraphDefinition(model);
  } catch (TensorflowMetaGraphDefParsingException e) {
    throw new IOException(e);
  }
  final SignatureDef signatureDef = metaGraphDef.getSignatureDefOrThrow(signatureDefinition);

  return new AutoValue_TensorFlowModel(
      id,
      model,
      options,
      metaGraphDef,
      signatureDef,
      toNameMap(signatureDef.getInputsMap()),
      toNameMap(signatureDef.getOutputsMap()));
}
 
Example 6
Source File: TensorFlowModelProducer.java    From samantha with MIT License 5 votes vote down vote up
static public SavedModelBundle loadTensorFlowSavedModel(String exportDir) {
    SavedModelBundle savedModel = null;
    if (new File(exportDir).exists()) {
        savedModel = SavedModelBundle.load(exportDir, TensorFlowModel.SAVED_MODEL_TAG);
    } else {
        logger.warn("TensorFlow exported model dir does not exist: {}.", exportDir);
    }
    return savedModel;
}
 
Example 7
Source File: TensorFlowImporter.java    From vespa with Apache License 2.0 5 votes vote down vote up
/**
 * Imports a saved TensorFlow model from a directory.
 * The model should be saved as a .pbtxt or .pb file.
 *
 * @param modelName the name of the model to import, consisting of characters in [A-Za-z0-9_]
 * @param modelDir the directory containing the TensorFlow model files to import
 */
@Override
public ImportedModel importModel(String modelName, String modelDir) {
    // Temporary (for testing): if path contains "tf_2_onnx", convert to ONNX then import that model.
    if (modelDir.contains("tf_2_onnx")) {
        return convertToOnnxAndImport(modelName, modelDir);
    }
    try (SavedModelBundle model = SavedModelBundle.load(modelDir, "serve")) {
        return importModel(modelName, modelDir, model);
    }
    catch (IllegalArgumentException e) {
        throw new IllegalArgumentException("Could not import TensorFlow model from directory '" + modelDir + "'", e);
    }
}
 
Example 8
Source File: VariableConverter.java    From vespa with Apache License 2.0 5 votes vote down vote up
/**
 * Reads the tensor with the given TensorFlow name at the given model location,
 * and encodes it as UTF-8 Vespa document tensor JSON having the given ordered tensor type.
 * Note that order of dimensions in the tensor type does matter as the TensorFlow tensor
 * tensor dimensions are implicitly ordered.
 */
static byte[] importVariable(String modelDir, String tensorFlowVariableName, String orderedTypeSpec) {
    try (SavedModelBundle bundle = SavedModelBundle.load(modelDir, "serve")) {
        return JsonFormat.encode(TensorConverter.toVespaTensor(GraphImporter.readVariable(tensorFlowVariableName,
                                                                                               bundle),
                                                               OrderedTensorType.fromSpec(orderedTypeSpec)));
    }
    catch (IllegalArgumentException e) {
        throw new IllegalArgumentException("Could not import TensorFlow model from directory '" + modelDir + "'", e);
    }
}
 
Example 9
Source File: TestableTensorFlowModel.java    From vespa with Apache License 2.0 5 votes vote down vote up
public TestableTensorFlowModel(String modelName, String modelDir, int d0Size, int d1Size, boolean floatInput) {
    tensorFlowModel = SavedModelBundle.load(modelDir, "serve");
    model = new TensorFlowImporter().importModel(modelName, modelDir, tensorFlowModel);
    this.d0Size = d0Size;
    this.d1Size = d1Size;
    this.floatInput = floatInput;
}
 
Example 10
Source File: BlogEvaluationTestCase.java    From vespa with Apache License 2.0 5 votes vote down vote up
@Test
public void testImport() {
    SavedModelBundle tensorFlowModel = SavedModelBundle.load(modelDir, "serve");
    ImportedModel model = new TensorFlowImporter().importModel("blog", modelDir, tensorFlowModel);

    ImportedModel.Signature y = model.signature("serving_default.y");
    assertNotNull(y);
    assertEquals(0, y.inputs().size());
}
 
Example 11
Source File: Tf2OnnxImportTestCase.java    From vespa with Apache License 2.0 5 votes vote down vote up
private boolean testModelWithOpset(Report report, int opset, String tfModel) throws IOException {
    String onnxModel = Paths.get(testFolder.getRoot().getAbsolutePath(), "converted.onnx").toString();

    var res = tf2onnxConvert(tfModel, onnxModel, opset);
    if (res.getFirst() != 0) {
        return reportAndFail(report, opset, tfModel, "tf2onnx conversion failed: " + res.getSecond());
    }

    SavedModelBundle tensorFlowModel = SavedModelBundle.load(tfModel, "serve");
    ImportedModel model = new TensorFlowImporter().importModel("test", tfModel, tensorFlowModel);
    ImportedModel onnxImportedModel = new OnnxImporter().importModel("test", onnxModel);

    if (model.signature("serving_default").skippedOutputs().size() > 0) {
        return reportAndFail(report, opset, tfModel, "Failed to import model from TensorFlow due to skipped outputs");
    }
    if (onnxImportedModel.signature("default").skippedOutputs().size() > 0) {
        return reportAndFail(report, opset, tfModel, "Failed to import model from ONNX due to skipped outputs");
    }

    ImportedModel.Signature sig = model.signatures().values().iterator().next();
    String output = sig.outputs().values().iterator().next();
    String onnxOutput = onnxImportedModel.signatures().values().iterator().next().outputs().values().iterator().next();

    Tensor tfResult = evaluateTF(tensorFlowModel, output, model.inputs());
    Tensor vespaResult = evaluateVespa(model, output, model.inputs());
    Tensor onnxResult = evaluateVespa(onnxImportedModel, onnxOutput, model.inputs());

    if ( ! tfResult.equals(vespaResult) ) {
        return reportAndFail(report, opset, tfModel, "Diff between tf and imported tf evaluation:\n\t" + tfResult + "\n\t" + vespaResult);
    }
    if ( ! vespaResult.equals(onnxResult) ) {
        return reportAndFail(report, opset, tfModel, "Diff between imported tf eval and onnx eval:\n\t" + vespaResult + "\n\t" + onnxResult);
    }

    return reportAndSucceed(report, opset, tfModel, "Ok");
}
 
Example 12
Source File: EstimatorTest.java    From jpmml-tensorflow with GNU Affero General Public License v3.0 4 votes vote down vote up
@Override
protected ArchiveBatch createBatch(String name, String dataset, Predicate<FieldName> predicate){
	ArchiveBatch result = new IntegrationTestBatch(name, dataset, predicate){

		@Override
		public IntegrationTest getIntegrationTest(){
			return EstimatorTest.this;
		}

		@Override
		public PMML getPMML() throws Exception {
			File savedModelDir = getSavedModelDir();

			SavedModelBundle bundle = SavedModelBundle.load(savedModelDir.getAbsolutePath(), "serve");

			try(SavedModel savedModel = new SavedModel(bundle)){
				EstimatorFactory estimatorFactory = EstimatorFactory.newInstance();

				Estimator estimator = estimatorFactory.newEstimator(savedModel);

				PMML pmml = estimator.encodePMML();

				ensureValidity(pmml);

				return pmml;
			}
		}

		private File getSavedModelDir() throws IOException, URISyntaxException {
			ClassLoader classLoader = (EstimatorTest.this.getClass()).getClassLoader();

			String protoPath = ("savedmodel/" + getName() + getDataset() + "/saved_model.pbtxt");

			URL protoResource = classLoader.getResource(protoPath);
			if(protoResource == null){
				throw new NoSuchFileException(protoPath);
			}

			File protoFile = (Paths.get(protoResource.toURI())).toFile();

			return protoFile.getParentFile();
		}
	};

	return result;
}
 
Example 13
Source File: TensorFlowProcessor.java    From datacollector with Apache License 2.0 4 votes vote down vote up
@Override
protected List<ConfigIssue> init() {
  List<ConfigIssue> issues = super.init();
  String[] modelTags = new String[conf.modelTags.size()];
  modelTags = conf.modelTags.toArray(modelTags);

  if (Strings.isNullOrEmpty(conf.modelPath)) {
    issues.add(getContext().createConfigIssue(
        Groups.TENSOR_FLOW.name(),
        TensorFlowConfigBean.MODEL_PATH_CONFIG,
        Errors.TENSOR_FLOW_01
    ));
    return issues;
  }

  try {
    File exportedModelDir = new File(conf.modelPath);
    if (!exportedModelDir.isAbsolute()) {
      exportedModelDir = new File(getContext().getResourcesDirectory(), conf.modelPath).getAbsoluteFile();
    }
    this.savedModel = SavedModelBundle.load(exportedModelDir.getAbsolutePath(), modelTags);
  } catch (TensorFlowException ex) {
    issues.add(getContext().createConfigIssue(
        Groups.TENSOR_FLOW.name(),
        TensorFlowConfigBean.MODEL_PATH_CONFIG,
        Errors.TENSOR_FLOW_02,
        ex
    ));
    return issues;
  }

  this.session = this.savedModel.session();
  this.conf.inputConfigs.forEach(inputConfig -> {
        Pair<String, Integer> key = Pair.of(inputConfig.operation, inputConfig.index);
        inputConfigMap.put(key, inputConfig);
      }
  );

  fieldPathEval = getContext().createELEval("conf.inputConfigs");
  fieldPathVars = getContext().createELVars();

  errorRecordHandler = new DefaultErrorRecordHandler(getContext());

  return issues;
}
 
Example 14
Source File: TensorflowSavedModel.java    From tutorials with MIT License 4 votes vote down vote up
public static void main(String[] args) {
	SavedModelBundle model = SavedModelBundle.load("./model", "serve");
	Tensor<Integer> tensor = model.session().runner().fetch("z").feed("x", Tensor.<Integer>create(3, Integer.class))
			.feed("y", Tensor.<Integer>create(3, Integer.class)).run().get(0).expect(Integer.class);
	System.out.println(tensor.intValue());
}