/*
 * Copyright (C) 2019 Spotify AB
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.spotify.zoltar.tf;

import java.io.IOException;
import java.net.URI;
import java.nio.file.Files;
import java.util.concurrent.TimeUnit;

import javax.annotation.Nullable;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.framework.ConfigProto;

import com.google.auto.value.AutoValue;

import com.spotify.zoltar.Model;
import com.spotify.zoltar.fs.FileSystemExtras;

/**
 * This model can be used to load the protobuf definition of a TensorFlow {@link Graph}.
 *
 * <p>For an easy model freezing function see <a
 * href="https://github.com/spotify/spotify-tensorflow/blob/master/spotify_tensorflow/freeze_graph.py">spotify-tensorflow</a>
 *
 * <p>TensorFlowGraphModel is thread-safe.
 */
@AutoValue
public abstract class TensorFlowGraphModel implements Model<Session>, AutoCloseable {

  private static final Logger LOG = LoggerFactory.getLogger(TensorFlowGraphModel.class);
  private static final Model.Id DEFAULT_ID = Id.create("tensorflow-graph");

  /**
   * Note: Please use Models from zoltar-models module.
   *
   * <p>Creates a TensorFlow model based on a frozen, serialized TensorFlow {@link Graph}, using
   * the default model id.
   *
   * @param graphUri URI to the TensorFlow graph definition.
   * @param config config for TensorFlow {@link Session}.
   * @param prefix optional prefix that will be prepended to names in the graph.
   * @return a model owning the loaded graph and its session; close it when done.
   * @throws IOException if the graph definition cannot be read from {@code graphUri}.
   */
  public static TensorFlowGraphModel create(
      final URI graphUri, @Nullable final ConfigProto config, @Nullable final String prefix)
      throws IOException {
    return create(DEFAULT_ID, graphUri, config, prefix);
  }

  /**
   * Note: Please use Models from zoltar-models module.
   *
   * <p>Creates a TensorFlow model based on a frozen, serialized TensorFlow {@link Graph}.
   *
   * @param id model id {@link Model.Id}.
   * @param graphUri URI to the TensorFlow graph definition.
   * @param config config for TensorFlow {@link Session}.
   * @param prefix optional prefix that will be prepended to names in the graph.
   * @return a model owning the loaded graph and its session; close it when done.
   * @throws IOException if the graph definition cannot be read from {@code graphUri}.
   */
  public static TensorFlowGraphModel create(
      final Model.Id id,
      final URI graphUri,
      @Nullable final ConfigProto config,
      @Nullable final String prefix)
      throws IOException {
    final byte[] graphBytes = Files.readAllBytes(FileSystemExtras.path(graphUri));
    return create(id, graphBytes, config, prefix);
  }

  /**
   * Note: Please use Models from zoltar-models module.
   *
   * <p>Creates a TensorFlow model based on a frozen, serialized TensorFlow {@link Graph}, using
   * the default model id.
   *
   * @param graphDef byte array representing the TensorFlow {@link Graph} definition.
   * @param config ConfigProto config for TensorFlow {@link Session}.
   * @param prefix a prefix that will be prepended to names in graphDef.
   * @return a model owning the loaded graph and its session; close it when done.
   * @throws IOException declared for source compatibility with existing callers; this overload
   *     performs no I/O itself.
   */
  public static TensorFlowGraphModel create(
      final byte[] graphDef, @Nullable final ConfigProto config, @Nullable final String prefix)
      throws IOException {
    return create(DEFAULT_ID, graphDef, config, prefix);
  }

  /**
   * Note: Please use Models from zoltar-models module.
   *
   * <p>Creates a TensorFlow model based on a frozen, serialized TensorFlow {@link Graph}.
   *
   * <p>If loading fails, the partially constructed {@link Session} and {@link Graph} are closed
   * before the failure is propagated, so no TensorFlow resources are leaked.
   *
   * @param id model id {@link Model.Id}.
   * @param graphDef byte array representing the TensorFlow {@link Graph} definition.
   * @param config ConfigProto config for TensorFlow {@link Session}.
   * @param prefix a prefix that will be prepended to names in graphDef.
   * @return a model owning the loaded graph and its session; close it when done.
   */
  public static TensorFlowGraphModel create(
      final Model.Id id,
      final byte[] graphDef,
      @Nullable final ConfigProto config,
      @Nullable final String prefix) {
    final Graph graph = new Graph();
    // Non-final so the failure path can tell whether the session was created.
    Session session = null;
    try {
      session = new Session(graph, config != null ? config.toByteArray() : null);

      // nanoTime is monotonic, so the measured duration is immune to wall-clock adjustments.
      final long loadStartNanos = System.nanoTime();
      if (prefix == null) {
        LOG.debug("Loading graph definition without prefix");
        graph.importGraphDef(graphDef);
      } else {
        LOG.debug("Loading graph definition with prefix: {}", prefix);
        graph.importGraphDef(graphDef, prefix);
      }
      LOG.info(
          "TensorFlow graph loaded in {} ms",
          TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - loadStartNanos));

      return new AutoValue_TensorFlowGraphModel(id, graph, session);
    } catch (final RuntimeException | Error e) {
      // No model instance exists to own these objects, so nobody else will ever call close()
      // on them; release them here before propagating the original failure.
      closeAfterFailedCreate(session, graph, e);
      throw e;
    }
  }

  /**
   * Closes the session (if it was created) and then the graph after a failed {@code create}. Any
   * exception raised while closing is attached to {@code failure} as a suppressed exception so
   * the original cause stays primary.
   */
  private static void closeAfterFailedCreate(
      @Nullable final Session session, final Graph graph, final Throwable failure) {
    if (session != null) {
      try {
        session.close();
      } catch (final RuntimeException closeFailure) {
        failure.addSuppressed(closeFailure);
      }
    }
    try {
      graph.close();
    } catch (final RuntimeException closeFailure) {
      failure.addSuppressed(closeFailure);
    }
  }

  /**
   * Close the model: closes the TensorFlow {@link Session} first and then the {@link Graph}. The
   * graph is closed even if closing the session throws.
   */
  @Override
  public void close() {
    try {
      if (instance() != null) {
        LOG.debug("Closing TensorFlow session");
        instance().close();
      }
    } finally {
      if (graph() != null) {
        LOG.debug("Closing TensorFlow graph");
        graph().close();
      }
    }
  }

  /** Returns TensorFlow graph. */
  public abstract Graph graph();

  /** Returns TensorFlow {@link Session}. */
  @Override
  public abstract Session instance();
}