org.tensorflow.framework.ConfigProto Java Examples

The following examples show how to use org.tensorflow.framework.ConfigProto. You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may check out the related API usage on the sidebar.
Example #1
Source File: GraphRunner.java    From deeplearning4j with Apache License 2.0    6 votes
/**
 * Builds a session {@link org.tensorflow.framework.ConfigProto} aligned with the active ND4J backend.
 *
 * <p>The result always carries a device filter for
 * {@code TensorflowConversion.defaultDeviceForThread()}. When the ND4J backend class name contains
 * {@code "jcu"} (case-insensitive), GPU options with allow-growth enabled and a per-process GPU
 * memory fraction of 0.5 are added as well. For any other backend no further options are set.
 * Failures while inspecting the backend are logged and the config is returned without GPU options.
 *
 * @return the assembled config proto (never {@code null})
 */
public static org.tensorflow.framework.ConfigProto getAlignedWithNd4j() {
    ConfigProto.Builder builder = org.tensorflow.framework.ConfigProto.getDefaultInstance()
            .toBuilder()
            .addDeviceFilters(TensorflowConversion.defaultDeviceForThread());
    try {
        // Locale.ROOT keeps the backend detection independent of the JVM's default locale.
        String backendClassName = Nd4j.getBackend().getClass().getName().toLowerCase(java.util.Locale.ROOT);
        if (backendClassName.contains("jcu")) {
            // CUDA backend: grow GPU memory on demand and cap TensorFlow at half of the device
            // (presumably to leave room for ND4J's own allocations).
            builder.setGpuOptions(org.tensorflow.framework.GPUOptions.newBuilder()
                    .setAllowGrowth(true)
                    .setPerProcessGpuMemoryFraction(0.5)
                    .build());
        }
        // Other (CPU) backends need no extra options.
    } catch (Exception e) {
        log.error("Unable to inspect the ND4J backend; returning config without GPU options", e);
    }

    return builder.build();
}
 
Example #2
Source File: GraphRunner.java    From deeplearning4j with Apache License 2.0    6 votes
/**
 * Convert a JSON string written out by {@link org.nd4j.shade.protobuf.util.JsonFormat}
 * to a {@link org.tensorflow.framework.ConfigProto}.
 *
 * @param json the JSON to read
 * @return the parsed config proto, or {@code null} if {@code json} could not be parsed
 *         (the failure is logged, not thrown)
 */
public static org.tensorflow.framework.ConfigProto fromJson(String json) {
    org.tensorflow.framework.ConfigProto.Builder builder = org.tensorflow.framework.ConfigProto.newBuilder();
    try {
        org.nd4j.shade.protobuf.util.JsonFormat.parser().merge(json, builder);
        // build() already yields the final immutable message; serializing it and re-parsing the
        // bytes produced an equal message and was pure overhead.
        return builder.build();
    } catch (Exception e) {
        log.error("Unable to parse ConfigProto from JSON", e);
        return null;
    }
}
 
Example #3
Source File: TensorFlowGraphModel.java    From zoltar with Apache License 2.0    6 votes
/**
 * Note: Please use Models from zoltar-models module.
 *
 * <p>Creates a TensorFlow model based on a frozen, serialized TensorFlow {@link Graph}.
 *
 * <p>If creating the {@link Session} or importing the graph definition throws, the session (if
 * created) and the graph are closed before the exception is rethrown.
 *
 * @param id model id {@link Model.Id}.
 * @param graphDef byte array representing the TensorFlow {@link Graph} definition.
 * @param config ConfigProto config for TensorFlow {@link Session}; may be {@code null}.
 * @param prefix a prefix that will be prepended to names in graphDef; may be {@code null}.
 */
public static TensorFlowGraphModel create(
    final Model.Id id,
    final byte[] graphDef,
    @Nullable final ConfigProto config,
    @Nullable final String prefix) {
  final Graph graph = new Graph();
  Session session = null;
  try {
    session = new Session(graph, config != null ? config.toByteArray() : null);
    final long loadStart = System.currentTimeMillis();
    if (prefix == null) {
      LOG.debug("Loading graph definition without prefix");
      graph.importGraphDef(graphDef);
    } else {
      LOG.debug("Loading graph definition with prefix: {}", prefix);
      graph.importGraphDef(graphDef, prefix);
    }
    LOG.info("TensorFlow graph loaded in {} ms", System.currentTimeMillis() - loadStart);
    return new AutoValue_TensorFlowGraphModel(id, graph, session);
  } catch (final RuntimeException e) {
    // Nothing else holds a reference to these yet, so release them here instead of leaking.
    if (session != null) {
      session.close();
    }
    graph.close();
    throw e;
  }
}
 
Example #4
Source File: GraphRunnerTest.java    From deeplearning4j with Apache License 2.0    5 votes
/**
 * Shared checks for a two-input / one-output addition graph.
 *
 * <p>Verifies that the session options survive a JSON round trip (when the runner exposes them),
 * that input/output orders are populated with the expected sizes, and that two consecutive runs
 * both produce {@code input_0 + input_1} under the {@code "output"} key.
 *
 * @param graphRunner the runner under test
 * @throws Exception if JSON parsing or graph execution fails
 */
private void runGraphRunnerTest(GraphRunner graphRunner) throws Exception {
    final String sessionJson = graphRunner.sessionOptionsToJson();
    if (sessionJson != null) {
        org.tensorflow.framework.ConfigProto.Builder parsed = org.tensorflow.framework.ConfigProto.newBuilder();
        JsonFormat.parser().merge(sessionJson, parsed);
        assertEquals(parsed.build(), graphRunner.getSessionOptionsConfigProto());
    }

    assertNotNull(graphRunner.getInputOrder());
    assertNotNull(graphRunner.getOutputOrder());

    // With no JSON available the runner's config is expected to be null as well.
    org.tensorflow.framework.ConfigProto viaFromJson = null;
    if (sessionJson != null) {
        viaFromJson = GraphRunner.fromJson(sessionJson);
    }
    assertEquals(graphRunner.getSessionOptionsConfigProto(), viaFromJson);
    assertEquals(2, graphRunner.getInputOrder().size());
    assertEquals(1, graphRunner.getOutputOrder().size());

    final INDArray first = Nd4j.linspace(1, 4, 4).reshape(4);
    final INDArray second = Nd4j.linspace(1, 4, 4).reshape(4);

    final Map<String, INDArray> feed = new LinkedHashMap<>();
    feed.put("input_0", first);
    feed.put("input_1", second);

    // Two passes (presumably to check the runner can be reused across calls).
    int remaining = 2;
    while (remaining-- > 0) {
        Map<String, INDArray> result = graphRunner.run(feed);
        INDArray expected = first.add(second);
        assertEquals(expected, result.get("output"));
    }
}
 
Example #5
Source File: GraphRunnerTest.java    From deeplearning4j with Apache License 2.0    5 votes
/**
 * Returns a session config for the CUDA backend, or {@code null} for any other backend.
 *
 * <p>For CUDA the config filters devices to {@code TensorflowConversion.defaultDeviceForThread()}
 * and enables GPU allow-growth with a 0.5 per-process memory fraction.
 *
 * @return the CUDA config proto, or {@code null} when the ND4J backend is not CUDA
 */
public static ConfigProto getConfig(){
    String backendName = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend");
    if (!"CUDA".equalsIgnoreCase(backendName)) {
        return null;
    }

    ConfigProto.Builder withDevice = org.tensorflow.framework.ConfigProto.getDefaultInstance()
            .toBuilder()
            .addDeviceFilters(TensorflowConversion.defaultDeviceForThread());
    GPUOptions gpu = GPUOptions.newBuilder()
            .setAllowGrowth(true)
            .setPerProcessGpuMemoryFraction(0.5)
            .build();
    return withDevice.setGpuOptions(gpu).build();
}
 
Example #6
Source File: TFServer.java    From TensorFlowOnYARN with Apache License 2.0    5 votes
/**
 * Creates a single-task server for a cluster with one {@code "worker"} job at
 * {@code localhost:0}, using task index 0, the {@code "grpc"} protocol and the default
 * {@link ConfigProto}.
 *
 * @return the new local server
 */
public static TFServer createLocalServer() {
  List<String> workerAddresses = new ArrayList<String>();
  workerAddresses.add("localhost:0");

  HashMap<String, List<String>> jobs = new HashMap<String, List<String>>();
  jobs.put("worker", workerAddresses);

  ClusterSpec spec = new ClusterSpec(jobs);
  return new TFServer(spec, "worker", 0, "grpc", ConfigProto.getDefaultInstance());
}
 
Example #7
Source File: MtcnnService.java    From mtcnn-java with Apache License 2.0    5 votes
/**
 * Loads the serialized TensorFlow model at {@code tensorflowModelUri} and wraps it in a
 * {@link GraphRunner} with the single input {@code inputLabel} and the default {@link ConfigProto}.
 *
 * @param tensorflowModelUri resource location resolved through {@code DefaultResourceLoader}
 * @param inputLabel name of the graph input
 * @return the graph runner for the loaded model
 * @throws IllegalStateException if the model resource cannot be read
 */
private GraphRunner createGraphRunner(String tensorflowModelUri, String inputLabel) {
	// try-with-resources guarantees the resource stream is closed; it previously leaked.
	try (java.io.InputStream modelStream =
			new DefaultResourceLoader().getResource(tensorflowModelUri).getInputStream()) {
		return new GraphRunner(
				IOUtils.toByteArray(modelStream),
				Arrays.asList(inputLabel),
				ConfigProto.getDefaultInstance());
	}
	catch (IOException e) {
		throw new IllegalStateException(String.format("Failed to load TF model [%s] and input [%s]:",
				tensorflowModelUri, inputLabel), e);
	}
}
 
Example #8
Source File: GraphRunner.java    From deeplearning4j with Apache License 2.0    4 votes
/**
 * The constructor for creating a graph runner via builder.
 *
 * <p>Session options come from the first non-null of {@code sessionOptionsConfigProto},
 * {@code sessionOptionsProtoBytes} and {@code sessionOptionsProtoPath}. The graph comes from the
 * first non-null of {@code graph}, {@code graphBytes} and {@code graphPath}; otherwise a new empty
 * graph is created via {@code TF_NewGraph()}.
 *
 * @param inputNames the input names to use (if null or empty and a saved model is given, they are
 *                   taken from the saved model)
 * @param outputNames the output names to use (if null or empty and a saved model is given, they are
 *                    taken from the saved model)
 * @param savedModelConfig the saved model configuration to load from (note this can not be used in conjunction
 *                         with graph path)
 * @param sessionOptionsConfigProto the session options for running the model (this may be null)
 * @param sessionOptionsProtoBytes the serialized proto bytes equivalent of the session configuration
 * @param sessionOptionsProtoPath the file path to a serialized session configuration proto file
 * @param graph the tensorflow graph to use
 * @param graphPath the path to the graph
 * @param graphBytes the in memory bytes of the graph
 * @param inputDataTypes the expected input data types
 * @param outputDataTypes the expected output data types
 * @throws IllegalArgumentException wrapping any exception raised while loading session options,
 *                                  the graph, the saved model or the session
 */
@Builder
public GraphRunner(List<String> inputNames,
                   List<String> outputNames,
                   SavedModelConfig savedModelConfig,
                   org.tensorflow.framework.ConfigProto sessionOptionsConfigProto,
                   byte[] sessionOptionsProtoBytes,
                   File sessionOptionsProtoPath,
                   TF_Graph graph,
                   File graphPath,
                   byte[] graphBytes,
                   Map<String, TensorDataType> inputDataTypes,
                   Map<String, TensorDataType> outputDataTypes) {
    try {
        // Session options: explicit proto first, then serialized bytes, then a proto file on disk.
        if(sessionOptionsConfigProto != null) {
            this.sessionOptionsConfigProto = sessionOptionsConfigProto;
        }
        else if(sessionOptionsProtoBytes != null) {
            // Must test the bytes themselves; the config proto is already known to be null here,
            // so testing it again would make this branch unreachable.
            this.sessionOptionsConfigProto = ConfigProto.parseFrom(sessionOptionsProtoBytes);
        }
        else if(sessionOptionsProtoPath != null) {
            byte[] load = FileUtils.readFileToByteArray(sessionOptionsProtoPath);
            this.sessionOptionsConfigProto = ConfigProto.parseFrom(load);
        }


        this.inputDataTypes = inputDataTypes;
        this.outputDataTypes = outputDataTypes;
        //note that the input and output order, maybe null here
        //if the names are specified, we should defer to those instead
        this.inputOrder = inputNames;
        this.outputOrder = outputNames;
        initOptionsIfNeeded();

        // NOTE(review): `status` is passed to loadGraph before initSessionAndStatusIfNeeded runs
        // below — confirm it is already initialized at this point.
        if(graph != null) {
            this.graph = graph;
        }
        else if(graphBytes != null) {
            this.graph = conversion.loadGraph(graphBytes, status);
        }
        else if(graphPath != null) {
            graphBytes = IOUtils.toByteArray(graphPath.toURI());
            this.graph = conversion.loadGraph(graphBytes, status);
        }
        else
            this.graph = TF_NewGraph();

        if(savedModelConfig != null) {
            this.savedModelConfig = savedModelConfig;
            Map<String,String> inputsMap = new LinkedHashMap<>();
            Map<String,String> outputsMap = new LinkedHashMap<>();

            // loadSavedModel fills inputsMap/outputsMap, which back-fill any missing input/output order.
            this.session = conversion.loadSavedModel(savedModelConfig, options, null, this.graph, inputsMap, outputsMap, status);

            if(inputOrder == null || inputOrder.isEmpty())
                inputOrder = new ArrayList<>(inputsMap.values());
            if(outputOrder == null || outputOrder.isEmpty())
                outputOrder = new ArrayList<>(outputsMap.values());

            savedModelConfig.setSavedModelInputOrder(new ArrayList<>(inputsMap.values()));
            savedModelConfig.setSaveModelOutputOrder(new ArrayList<>(outputsMap.values()));
            log.info("Loaded input names from saved model configuration " + inputOrder);
            log.info("Loaded output names from saved model configuration " + outputOrder);

        }


        initSessionAndStatusIfNeeded(graphBytes);
    } catch (Exception e) {
        throw new IllegalArgumentException("Unable to parse protobuf",e);
    }
}
 
Example #9
Source File: TFServer.java    From TensorFlowOnYARN with Apache License 2.0    4 votes
/**
 * Builds a {@link ServerDef} describing one task of {@code clusterSpec}.
 *
 * @param clusterSpec the cluster the task belongs to
 * @param jobName the job name of the task
 * @param taskIndex the index of the task within the job
 * @param proto the protocol name (e.g. {@code "grpc"})
 * @param config the default session config for the server
 * @return the assembled server definition
 */
public static ServerDef makeServerDef(ClusterSpec clusterSpec, String jobName,
    int taskIndex, String proto, ConfigProto config) {
  ServerDef.Builder def = ServerDef.newBuilder();
  def.setCluster(clusterSpec.as_cluster_def());
  def.setJobName(jobName);
  def.setProtocol(proto);
  def.setTaskIndex(taskIndex);
  def.setDefaultSessionConfig(config);
  return def.build();
}
 
Example #10
Source File: TFServer.java    From TensorFlowOnYARN with Apache License 2.0    4 votes
/**
 * Derives a {@link ServerDef} from {@code serverDef}, overriding the task-specific fields.
 *
 * @param serverDef the server definition to start from (all its fields are merged in first)
 * @param jobName the job name to set
 * @param taskIndex the task index to set
 * @param proto the protocol name to set
 * @param config the default session config to set
 * @return the new server definition
 */
public static ServerDef makeServerDef(ServerDef serverDef, String jobName,
    int taskIndex, String proto, ConfigProto config) {
  ServerDef.Builder def = ServerDef.newBuilder();
  def.mergeFrom(serverDef);
  def.setJobName(jobName);
  def.setTaskIndex(taskIndex);
  def.setProtocol(proto);
  def.setDefaultSessionConfig(config);
  return def.build();
}
 
Example #11
Source File: TFServer.java    From TensorFlowOnYARN with Apache License 2.0    4 votes
/**
 * Creates a server from a job-name to address-list map, using the {@code "grpc"} protocol and the
 * default {@link ConfigProto}.
 *
 * @param clusterSpec map from job name to the addresses of its tasks
 * @param jobName the job this server runs
 * @param taskIndex the task index within {@code jobName}
 * @throws TFServerException if the delegated constructor fails
 */
public TFServer(Map<String, List<String>> clusterSpec, String jobName, int taskIndex)
    throws TFServerException {
  this(
      new ClusterSpec(clusterSpec),
      jobName,
      taskIndex,
      "grpc",
      ConfigProto.getDefaultInstance());
}
 
Example #12
Source File: TensorFlowGraphLoader.java    From zoltar with Apache License 2.0    3 votes
/**
 * Returns a TensorFlow model loader for an in-memory, serialized TensorFlow {@link Graph}.
 *
 * <p>The model is built by {@code TensorFlowGraphModel.create(id, graphDef, config, prefix)}
 * inside the supplier handed to {@code create(...)}; when that supplier runs is decided there.
 *
 * @param id model id ({@link Model.Id}).
 * @param graphDef byte array representing the TensorFlow {@link Graph} definition.
 * @param config optional TensorFlow {@link ConfigProto} config.
 * @param prefix optional prefix that will be prepended to names in the graph.
 */
static TensorFlowGraphLoader create(
    final Model.Id id,
    final byte[] graphDef,
    @Nullable final ConfigProto config,
    @Nullable final String prefix) {
  return create(
      () -> {
        return TensorFlowGraphModel.create(id, graphDef, config, prefix);
      });
}
 
Example #13
Source File: TensorFlowGraphLoader.java    From zoltar with Apache License 2.0    3 votes
/**
 * Returns a TensorFlow model loader for a serialized TensorFlow {@link Graph} stored at a URI.
 *
 * <p>{@code modelUri} is converted with {@link URI#create(String)} inside the supplier handed to
 * {@code create(...)}, so an invalid URI surfaces when that supplier runs.
 *
 * @param id model id ({@link Model.Id}).
 * @param modelUri should point to a serialized TensorFlow {@link org.tensorflow.Graph} file on
 *     local filesystem, resource, GCS etc.
 * @param config optional TensorFlow {@link ConfigProto} config.
 * @param prefix optional prefix that will be prepended to names in the graph.
 */
static TensorFlowGraphLoader create(
    final Model.Id id,
    final String modelUri,
    @Nullable final ConfigProto config,
    @Nullable final String prefix) {
  return create(
      () -> {
        return TensorFlowGraphModel.create(id, URI.create(modelUri), config, prefix);
      });
}
 
Example #14
Source File: TensorFlowGraphModel.java    From zoltar with Apache License 2.0    3 votes
/**
 * Note: Please use Models from zoltar-models module.
 *
 * <p>Reads a frozen, serialized TensorFlow {@link Graph} from {@code graphUri} and builds a model
 * from its bytes.
 *
 * @param id model id ({@link Model.Id}).
 * @param graphUri URI to the TensorFlow graph definition, resolved via
 *     {@code FileSystemExtras.path}.
 * @param config config for TensorFlow {@link Session}; may be {@code null}.
 * @param prefix optional prefix that will be prepended to names in the graph.
 * @throws IOException if the graph file cannot be read
 */
public static TensorFlowGraphModel create(
    final Model.Id id,
    final URI graphUri,
    @Nullable final ConfigProto config,
    @Nullable final String prefix)
    throws IOException {
  return create(id, Files.readAllBytes(FileSystemExtras.path(graphUri)), config, prefix);
}
 
Example #15
Source File: Models.java    From zoltar with Apache License 2.0    3 votes
/**
 * Returns a TensorFlow model loader for an in-memory, serialized TensorFlow {@link Graph}.
 *
 * <p>Thin public entry point delegating to {@code TensorFlowGraphLoader.create} with the same
 * arguments.
 *
 * @param id model id ({@link Model.Id}).
 * @param graphDef byte array representing the TensorFlow {@link Graph} definition.
 * @param config optional TensorFlow {@link ConfigProto} config.
 * @param prefix optional prefix that will be prepended to names in the graph.
 */
public static TensorFlowGraphLoader tensorFlowGraph(
    final Model.Id id, final byte[] graphDef,
    @Nullable final ConfigProto config, @Nullable final String prefix) {
  return TensorFlowGraphLoader.create(id, graphDef, config, prefix);
}
 
Example #16
Source File: Models.java    From zoltar with Apache License 2.0    3 votes
/**
 * Returns a TensorFlow model loader for a serialized TensorFlow {@link Graph} stored at a URI.
 *
 * <p>Thin public entry point delegating to {@code TensorFlowGraphLoader.create} with the same
 * arguments.
 *
 * @param id model id ({@link Model.Id}).
 * @param modelUri should point to a serialized TensorFlow {@link org.tensorflow.Graph} file on
 *     local filesystem, resource, GCS etc.
 * @param config optional TensorFlow {@link ConfigProto} config.
 * @param prefix optional prefix that will be prepended to names in the graph.
 */
public static TensorFlowGraphLoader tensorFlowGraph(
    final Model.Id id, final String modelUri,
    @Nullable final ConfigProto config, @Nullable final String prefix) {
  return TensorFlowGraphLoader.create(id, modelUri, config, prefix);
}
 
Example #17
Source File: Models.java    From zoltar with Apache License 2.0    2 votes
/**
 * Returns a TensorFlow model loader, without an explicit model id, for a serialized TensorFlow
 * {@link Graph} stored at a URI.
 *
 * <p>Thin public entry point delegating to {@code TensorFlowGraphLoader.create} with the same
 * arguments.
 *
 * @param modelUri should point to a serialized TensorFlow {@link org.tensorflow.Graph} file on
 *     local filesystem, resource, GCS etc.
 * @param config optional TensorFlow {@link ConfigProto} config.
 * @param prefix optional prefix that will be prepended to names in the graph.
 */
public static TensorFlowGraphLoader tensorFlowGraph(
    final String modelUri,
    @Nullable final ConfigProto config,
    @Nullable final String prefix) {
  return TensorFlowGraphLoader.create(modelUri, config, prefix);
}
 
Example #18
Source File: TensorFlowGraphLoader.java    From zoltar with Apache License 2.0    2 votes
/**
 * Returns a TensorFlow model loader, without an explicit model id, for an in-memory, serialized
 * TensorFlow {@link Graph}.
 *
 * <p>The model is built by {@code TensorFlowGraphModel.create(graphDef, config, prefix)} inside
 * the supplier handed to {@code create(...)}.
 *
 * @param graphDef byte array representing the TensorFlow {@link Graph} definition.
 * @param config optional TensorFlow {@link ConfigProto} config.
 * @param prefix optional prefix that will be prepended to names in the graph.
 */
static TensorFlowGraphLoader create(
    final byte[] graphDef,
    @Nullable final ConfigProto config,
    @Nullable final String prefix) {
  return create(
      () -> {
        return TensorFlowGraphModel.create(graphDef, config, prefix);
      });
}
 
Example #19
Source File: TensorFlowGraphLoader.java    From zoltar with Apache License 2.0    2 votes
/**
 * Returns a TensorFlow model loader, without an explicit model id, for a serialized TensorFlow
 * {@link Graph} stored at a URI.
 *
 * <p>{@code modelUri} is converted with {@link URI#create(String)} inside the supplier handed to
 * {@code create(...)}, so an invalid URI surfaces when that supplier runs.
 *
 * @param modelUri should point to a serialized TensorFlow {@link org.tensorflow.Graph} file on
 *     local filesystem, resource, GCS etc.
 * @param config optional TensorFlow {@link ConfigProto} config.
 * @param prefix optional prefix that will be prepended to names in the graph.
 */
static TensorFlowGraphLoader create(
    final String modelUri,
    @Nullable final ConfigProto config,
    @Nullable final String prefix) {
  return create(
      () -> {
        return TensorFlowGraphModel.create(URI.create(modelUri), config, prefix);
      });
}
 
Example #20
Source File: TensorFlowGraphModel.java    From zoltar with Apache License 2.0    2 votes
/**
 * Note: Please use Models from zoltar-models module.
 *
 * <p>Creates a TensorFlow model, under {@code DEFAULT_ID}, from a frozen, serialized TensorFlow
 * {@link Graph}.
 *
 * <p>The {@code throws IOException} clause is kept for source compatibility with existing callers;
 * the byte-array overload this delegates to declares no checked exceptions.
 *
 * @param graphDef byte array representing the TensorFlow {@link Graph} definition.
 * @param config ConfigProto config for TensorFlow {@link Session}; may be {@code null}.
 * @param prefix a prefix that will be prepended to names in graphDef; may be {@code null}.
 */
public static TensorFlowGraphModel create(
    final byte[] graphDef,
    @Nullable final ConfigProto config,
    @Nullable final String prefix)
    throws IOException {
  return create(DEFAULT_ID, graphDef, config, prefix);
}
 
Example #21
Source File: TensorFlowGraphModel.java    From zoltar with Apache License 2.0    2 votes
/**
 * Note: Please use Models from zoltar-models module.
 *
 * <p>Creates a TensorFlow model, under {@code DEFAULT_ID}, from a frozen, serialized TensorFlow
 * {@link Graph} read from {@code graphUri}.
 *
 * @param graphUri URI to the TensorFlow graph definition.
 * @param config config for TensorFlow {@link Session}; may be {@code null}.
 * @param prefix optional prefix that will be prepended to names in the graph.
 * @throws IOException if the graph definition cannot be read
 */
public static TensorFlowGraphModel create(
    final URI graphUri,
    @Nullable final ConfigProto config,
    @Nullable final String prefix)
    throws IOException {
  return create(DEFAULT_ID, graphUri, config, prefix);
}
 
Example #22
Source File: Models.java    From zoltar with Apache License 2.0    2 votes
/**
 * Returns a TensorFlow model loader, without an explicit model id, for an in-memory, serialized
 * TensorFlow {@link Graph}.
 *
 * <p>Thin public entry point delegating to {@code TensorFlowGraphLoader.create} with the same
 * arguments.
 *
 * @param graphDef byte array representing the TensorFlow {@link Graph} definition.
 * @param config optional TensorFlow {@link ConfigProto} config.
 * @param prefix optional prefix that will be prepended to names in the graph.
 */
public static TensorFlowGraphLoader tensorFlowGraph(
    final byte[] graphDef,
    @Nullable final ConfigProto config,
    @Nullable final String prefix) {
  return TensorFlowGraphLoader.create(graphDef, config, prefix);
}