Java Code Examples for org.tensorflow.Graph#importGraphDef()

The following examples show how to use org.tensorflow.Graph#importGraphDef(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may check out the related API usage on the sidebar.
Example 1
Source File: TensorFlowGraphModel.java | Project: zoltar | License: Apache License 2.0 | Votes: 6
/**
 * Note: Please use Models from zoltar-models module.
 *
 * <p>Creates a TensorFlow model based on a frozen, serialized TensorFlow {@link Graph}.
 *
 * <p>A new {@link Graph} and a {@link Session} bound to it are allocated here. If anything fails
 * before the model is constructed (for example {@code importGraphDef} rejecting a malformed
 * {@code graphDef}), both native resources are closed before the exception propagates, so a
 * failed load does not leak native memory.
 *
 * @param id model id {@link Model.Id}.
 * @param graphDef byte array representing the TensorFlow {@link Graph} definition.
 * @param config ConfigProto config for TensorFlow {@link Session}; when {@code null}, no config
 *     bytes are passed to the session.
 * @param prefix a prefix that will be prepended to names in graphDef; when {@code null}, the
 *     graph definition is imported without a prefix.
 * @return a model wrapping {@code id}, the loaded graph and the session bound to it.
 */
public static TensorFlowGraphModel create(
    final Model.Id id,
    final byte[] graphDef,
    @Nullable final ConfigProto config,
    @Nullable final String prefix) {
  final Graph graph = new Graph();
  Session session = null;
  try {
    session = new Session(graph, config != null ? config.toByteArray() : null);
    final long loadStart = System.currentTimeMillis();
    if (prefix == null) {
      LOG.debug("Loading graph definition without prefix");
      graph.importGraphDef(graphDef);
    } else {
      LOG.debug("Loading graph definition with prefix: {}", prefix);
      graph.importGraphDef(graphDef, prefix);
    }
    LOG.info("TensorFlow graph loaded in {} ms", System.currentTimeMillis() - loadStart);
    // Kept inside the try: the generated constructor may also throw (e.g. on a null id).
    return new AutoValue_TensorFlowGraphModel(id, graph, session);
  } catch (RuntimeException e) {
    // Graph.close() waits until every Session bound to the graph is released, so the session
    // must be closed first or the cleanup itself would block.
    if (session != null) {
      session.close();
    }
    graph.close();
    throw e;
  }
}
 
Example 2
Source File: AbstractClassifier.java | Project: tensorboot | License: Apache License 2.0 | Votes: 5
/**
 * Initialize classifier by importing a serialized TensorFlow graph and recording the layer
 * names used for inference.
 *
 * <p>The graph is built in a local variable and only assigned to {@code model} after
 * {@code importGraphDef} succeeds. If the import throws, the partially built graph is closed
 * and the classifier's previous state is left untouched.
 *
 * <p>NOTE(review): a graph previously held in {@code model} is not closed when {@code init} is
 * called again; confirm whether re-initialization is expected.
 *
 * @param graphBytes Model graph binary data
 * @param inputLayerName  Input layer name
 * @param outputLayerName Output layer name
 */
public void init(byte[] graphBytes, String inputLayerName, String outputLayerName) {
    Assert.notNull(graphBytes, "Model data shouldn't be null");
    Assert.notNull(inputLayerName, "Input layer name shouldn't be null");
    Assert.notNull(outputLayerName, "Output layer name shouldn't be null");

    Graph graph = new Graph();
    try {
        graph.importGraphDef(graphBytes);
    } catch (RuntimeException e) {
        // Release the native memory held by the half-built graph before propagating.
        graph.close();
        throw e;
    }
    model = graph;
    this.inputLayerName = inputLayerName;
    this.outputLayerName = outputLayerName;
}
 
Example 3
Source File: TensorFlowService.java | Project: tensorflow-spring-cloud-stream-app-starters | License: Apache License 2.0 | Votes: 5
public TensorFlowService(Resource modelLocation) throws IOException {
	try (InputStream is = modelLocation.getInputStream()) {
		graph = new Graph();
		logger.info("Loading TensorFlow graph model: " + modelLocation);
		graph.importGraphDef(toByteArray(buffer(is)));
		logger.info("TensorFlow Graph Model Ready To Serve!");
	}
}
 
Example 4
Source File: RNTensorFlowGraphModule.java | Project: react-native-tensorflow | License: Apache License 2.0 | Votes: 5
/**
 * Imports a Base64-encoded graph definition into the graph registered under {@code id},
 * prepending {@code prefix} to the imported names.
 *
 * <p>Resolves {@code promise} with {@code true} on success. Rejects it with a descriptive
 * {@link IllegalArgumentException} when no graph is registered under {@code id}, and with the
 * thrown exception if the lookup, decoding or import fails.
 *
 * @param id key of a graph previously stored in {@code graphs}
 * @param graphDef Base64-encoded serialized graph definition
 * @param prefix prefix passed to {@code Graph.importGraphDef(byte[], String)}
 * @param promise React Native promise that receives the result
 */
@ReactMethod
public void importGraphDefWithPrefix(String id, String graphDef, String prefix, Promise promise) {
    try {
        Graph graph = graphs.get(id);
        if (graph == null) {
            // Without this check an unknown id reaches JS as an uninformative NullPointerException.
            promise.reject(new IllegalArgumentException("No graph registered with id: " + id));
            return;
        }
        graph.importGraphDef(Base64.decode(graphDef, Base64.DEFAULT), prefix);
        promise.resolve(true);
    } catch (Exception e) {
        promise.reject(e);
    }
}
 
Example 5
Source File: RNTensorflowInference.java | Project: react-native-tensorflow | License: Apache License 2.0 | Votes: 5
/**
 * Loads the named model resource and builds a {@code TfContext} holding a new graph, a session
 * on that graph, and a runner obtained from the session.
 *
 * <p>If importing the graph or creating the session/context throws, any native resources
 * already allocated are closed before the exception propagates.
 *
 * @param reactContext context used to construct the {@code ResourceManager}
 * @param model resource identifier passed to {@code ResourceManager.loadResource}
 * @return a context wrapping the session, its runner and the graph
 * @throws IOException if the resource cannot be loaded
 */
private static TfContext createContext(ReactContext reactContext, String model) throws IOException {
    byte[] b = new ResourceManager(reactContext).loadResource(model);

    Graph graph = new Graph();
    Session session = null;
    try {
        graph.importGraphDef(b);
        session = new Session(graph);
        Session.Runner runner = session.runner();
        return new TfContext(session, runner, graph);
    } catch (RuntimeException e) {
        // Close the session before its graph: Graph.close() waits for bound sessions.
        if (session != null) {
            session.close();
        }
        graph.close();
        throw e;
    }
}
 
Example 6
Source File: TensorFlowService.java | Project: tensorflow | License: Apache License 2.0 | Votes: 5
/**
 * Loads a TensorFlow graph from {@code modelLocation} via {@code ModelExtractor}.
 *
 * <p>The model bytes are extracted before any native {@link Graph} is allocated, and the graph
 * is closed if {@code importGraphDef} rejects them. A failed load therefore leaks nothing.
 *
 * @param modelLocation resource the model bytes are extracted from
 */
public TensorFlowService(Resource modelLocation) {
	if (logger.isInfoEnabled()) {
		logger.info("Loading TensorFlow graph model: " + modelLocation);
	}
	byte[] model = new ModelExtractor().getModel(modelLocation);
	Graph loaded = new Graph();
	try {
		loaded.importGraphDef(model);
	} catch (RuntimeException e) {
		loaded.close();
		throw e;
	}
	graph = loaded;
}
 
Example 7
Source File: YOLO.java | Project: cineast | License: MIT License | Votes: 5
/**
 * Loads the YOLO graph from {@code resources/YOLO/yolo-voc.pb} and opens a session on it. Also
 * builds a preprocessing graph that takes a float tensor fed to placeholder {@code "T"}, adds a
 * leading dimension at axis 0, and resizes it to 416x416 with bilinear interpolation.
 *
 * @throws RuntimeException if the graph file cannot be read; the originating
 *     {@link IOException} is attached as the cause
 */
public YOLO() {
  final String graphFile = "resources/YOLO/yolo-voc.pb";
  final byte[] graphDef;
  try {
    graphDef = Files.readAllBytes(Paths.get(graphFile));
  } catch (IOException e) {
    // Chain the IOException so the real failure and its stack trace are preserved.
    throw new RuntimeException("could not load graph for YOLO from " + graphFile, e);
  }
  yoloGraph = new Graph();
  yoloGraph.importGraphDef(graphDef);
  yoloSession = new Session(yoloGraph);

  preprocessingGraph = new Graph();
  GraphBuilder graphBuilder = new GraphBuilder(preprocessingGraph);

  Output<Float> imageFloat = graphBuilder.placeholder("T", Float.class);
  final int[] size = new int[]{416, 416};

  // Add a batch dimension at axis 0, then resize with bilinear interpolation.
  final Output<Float> output =
      graphBuilder.resizeBilinear(
          graphBuilder.expandDims(imageFloat, graphBuilder.constant("make_batch", 0)),
          graphBuilder.constant("size", size));

  imageOutName = output.op().name();

  preprocessingSession = new Session(preprocessingGraph);
}
 
Example 8
Source File: FaceRecognizer.java | Project: server_face_recognition | License: GNU General Public License v3.0 | Votes: 4
private FaceRecognizer() {
    graph = new Graph();

    graph.importGraphDef(loadGraphDef());
    faceDetector = UserFaceDetector.create();
}
 
Example 9
Source File: DLSegment.java | Project: orbit-image-analysis | License: GNU General Public License v3.0 | Votes: 4
/**
 * Imports {@code graphDef} into a new {@link Graph} and returns a {@link Session} bound to it.
 *
 * <p>If the import or session creation throws, the graph is closed before the exception
 * propagates.
 *
 * <p>NOTE(review): only the Session is returned, so callers have no handle with which to close
 * the Graph. It stays allocated even after the session is closed; consider returning both if
 * this is called repeatedly.
 *
 * @param graphDef serialized graph definition
 * @return a new session on the imported graph
 */
public static Session buildSessionBytes(byte[] graphDef) {
    Graph g = new Graph();
    try {
        g.importGraphDef(graphDef);
        return new Session(g);
    } catch (RuntimeException e) {
        g.close();
        throw e;
    }
}
 
Example 10
Source File: MRCNNBrainDetector.java | Project: orbit-image-analysis | License: GNU General Public License v3.0 | Votes: 4
/**
 * Logs the TensorFlow runtime version and imports {@code graphDef} into a new {@link Graph}.
 *
 * <p>If the import throws, the new graph is closed before the exception propagates.
 *
 * @param graphDef serialized graph definition
 * @return the populated graph; it is not closed here, so the caller must close it
 */
public Graph loadGraph(byte[] graphDef) {
    logger.info("TF version "+TensorFlow.version());
    Graph g = new Graph();
    try {
        g.importGraphDef(graphDef);
    } catch (RuntimeException e) {
        g.close();
        throw e;
    }
    return g;
}
 
Example 11
Source File: MRCNNCorpusCallosum.java | Project: orbit-image-analysis | License: GNU General Public License v3.0 | Votes: 4
/**
 * Logs the TensorFlow runtime version and imports {@code graphDef} into a new {@link Graph}.
 *
 * <p>If the import throws, the new graph is closed before the exception propagates.
 *
 * @param graphDef serialized graph definition
 * @return the populated graph; it is not closed here, so the caller must close it
 */
public Graph loadGraph(byte[] graphDef) {
    logger.info("TF version "+TensorFlow.version());
    Graph g = new Graph();
    try {
        g.importGraphDef(graphDef);
    } catch (RuntimeException e) {
        g.close();
        throw e;
    }
    return g;
}
 
Example 12
Source File: InstSegMaskRCNN.java | Project: orbit-image-analysis | License: GNU General Public License v3.0 | Votes: 4
/**
 * Logs the TensorFlow runtime version and imports {@code graphDef} into a new {@link Graph}.
 *
 * <p>If the import throws, the new graph is closed before the exception propagates.
 *
 * @param graphDef serialized graph definition
 * @return the populated graph; it is not closed here, so the caller must close it
 */
public Graph loadGraph(byte[] graphDef) {
    logger.info("TF version "+TensorFlow.version());
    Graph g = new Graph();
    try {
        g.importGraphDef(graphDef);
    } catch (RuntimeException e) {
        g.close();
        throw e;
    }
    return g;
}
 
Example 13
Source File: Inception5h.java | Project: cineast | License: MIT License | Votes: 4
/**
 * Loads the Inception5h graph from
 * {@code resources/inception5h/tensorflow_inception_graph.pb}, opens a classification session on
 * it, and builds a preprocessing graph. That graph takes a float tensor fed to placeholder
 * {@code "T"}, adds a leading dimension at axis 0, and resizes it to 224x224 with bilinear
 * interpolation.
 *
 * @param outputOperations names of the operations to fetch from the classification graph,
 *     filtered through {@code GraphHelper.filterOperations}; if {@code null} or empty, the
 *     single default operation {@code "output2"} is used
 * @throws RuntimeException if the graph file cannot be read; the originating
 *     {@link IOException} is attached as the cause
 */
public Inception5h(List<String> outputOperations) {
    final String graphFile = "resources/inception5h/tensorflow_inception_graph.pb";
    final byte[] graphDef;
    try {
        graphDef = Files.readAllBytes(Paths.get(graphFile));
    } catch (IOException e) {
        // Chain the IOException so the real failure and its stack trace are preserved.
        throw new RuntimeException("could not load graph for Inception5h from " + graphFile, e);
    }
    classificationGraph = new Graph();
    classificationGraph.importGraphDef(graphDef);
    classificationSession = new Session(classificationGraph);

    preprocessingGraph = new Graph();
    GraphBuilder b = new GraphBuilder(preprocessingGraph);
    preProcessingSession = new Session(preprocessingGraph);

    final int H = 224;
    final int W = 224;

    // Add a batch dimension at axis 0, then resize to H x W with bilinear interpolation.
    Output<Float> imageFloat = b.placeholder("T", Float.class);
    output =
        b.resizeBilinear(
            b.expandDims(
                imageFloat,
                b.constant("make_batch", 0)),
            b.constant("size", new int[]{H, W}));

    if (outputOperations != null && !outputOperations.isEmpty()) {
        this.outputOperations =
            new ArrayList<>(GraphHelper.filterOperations(outputOperations, classificationGraph));
    } else {
        this.outputOperations = new ArrayList<>(1);
        this.outputOperations.add("output2"); //default output
    }
}