# Lint as: python3
# Copyright 2019 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Converts an image between PNG and TFCI formats.

Use this script to compress images with pre-trained models as published. See
the 'models' subcommand for a list of available models.
"""

import argparse
import os
import sys
import urllib
import urllib.request

from absl import app
from absl.flags import argparse_flags
import tensorflow.compat.v1 as tf
import tensorflow_compression as tfc  # pylint:disable=unused-import

# Default URL to fetch metagraphs from.
URL_PREFIX = "https://storage.googleapis.com/tensorflow_compression/metagraphs"
# Default location to store cached metagraphs.
METAGRAPH_CACHE = "/tmp/tfc_metagraphs"


def read_png(filename):
  """Creates graph to load a PNG image file.

  Args:
    filename: Path of the image file to read.

  Returns:
    The decoded image tensor with a batch dimension of size 1 prepended.
  """
  string = tf.io.read_file(filename)
  image = tf.image.decode_image(string)
  image = tf.expand_dims(image, 0)
  return image


def write_png(filename, image):
  """Creates graph to write a PNG image file.

  Args:
    filename: Path of the PNG file to write.
    image: Image tensor with a leading batch dimension of size 1. Floating
      point values are rounded; values of any non-`uint8` dtype are saturated
      into the `uint8` range before encoding.

  Returns:
    The op that encodes and writes the PNG file when run.
  """
  image = tf.squeeze(image, 0)
  if image.dtype.is_floating:
    image = tf.round(image)
  if image.dtype != tf.uint8:
    image = tf.saturate_cast(image, tf.uint8)
  string = tf.image.encode_png(image)
  return tf.io.write_file(filename, string)


def load_cached(filename):
  """Downloads and caches files from web storage.

  The file is read from `METAGRAPH_CACHE` if present there. Otherwise it is
  fetched from `URL_PREFIX` and written into the cache directory so later calls
  can read it locally.

  Args:
    filename: File name, relative to both the cache directory and URL prefix.

  Returns:
    The file contents.
  """
  pathname = os.path.join(METAGRAPH_CACHE, filename)
  try:
    with tf.io.gfile.GFile(pathname, "rb") as f:
      string = f.read()
  except tf.errors.NotFoundError:
    url = URL_PREFIX + "/" + filename
    # `with` closes the response on success and on error. Unlike a
    # `try/finally: request.close()`, it does not reference an unbound name
    # (masking the real error) when `urlopen` itself raises.
    with urllib.request.urlopen(url) as request:
      string = request.read()
    tf.io.gfile.makedirs(os.path.dirname(pathname))
    with tf.io.gfile.GFile(pathname, "wb") as f:
      f.write(string)
  return string


def import_metagraph(model):
  """Imports a trained model metagraph into the current graph.

  Args:
    model: Model identifier; the file `<model>.metagraph` is loaded through
      `load_cached`.

  Returns:
    The `signature_def` map of the imported metagraph.
  """
  string = load_cached(model + ".metagraph")
  metagraph = tf.MetaGraphDef()
  metagraph.ParseFromString(string)
  tf.train.import_meta_graph(metagraph)
  return metagraph.signature_def


def instantiate_signature(signature_def):
  """Fetches tensors defined in a signature from the graph.

  Args:
    signature_def: A signature whose `inputs` and `outputs` map keys to tensor
      infos carrying a tensor `name`.

  Returns:
    A tuple `(inputs, outputs)` of dicts mapping the signature keys to the
    corresponding tensors of the default graph.
  """
  graph = tf.get_default_graph()
  inputs = {
      k: graph.get_tensor_by_name(v.name)
      for k, v in signature_def.inputs.items()
  }
  outputs = {
      k: graph.get_tensor_by_name(v.name)
      for k, v in signature_def.outputs.items()
  }
  return inputs, outputs


def compress_image(model, input_image):
  """Compresses an image array into a bitstring.

  Args:
    model: Identifier of the model whose metagraph is loaded.
    input_image: Image array fed to the "input_image" input of the model's
      "sender" signature.

  Returns:
    The serialized `tfc.PackedTensors`, holding the model identifier and the
    values of all "channel:*" outputs of the sender signature.
  """
  with tf.Graph().as_default():
    # Load model metagraph.
    signature_defs = import_metagraph(model)
    inputs, outputs = instantiate_signature(signature_defs["sender"])

    # Just one input tensor.
    inputs = inputs["input_image"]
    # Multiple output tensors, ordered alphabetically, without names. The
    # receiver side in `decompress` sorts its inputs the same way.
    outputs = [outputs[k] for k in sorted(outputs) if k.startswith("channel:")]

    # Run encoder.
    with tf.Session() as sess:
      arrays = sess.run(outputs, feed_dict={inputs: input_image})

    # Pack data into bitstring.
    packed = tfc.PackedTensors()
    packed.model = model
    packed.pack(outputs, arrays)
    return packed.string


def _search_models_for_bpp(models, encode, num_pixels, target_bpp,
                           bpp_strict):
  """Binary-searches `models` for the bitstring whose bpp is closest to target.

  The search assumes bits per pixel grow with the index into `models`: a bpp
  below the target moves the search to higher indices, otherwise to lower
  ones. Each model tried is a candidate; the admissible candidate closest to
  `target_bpp` wins, and ties keep the earlier candidate.

  Args:
    models: Sequence of model identifiers.
    encode: Callable mapping a model identifier to a compressed bitstring.
    num_pixels: Number of pixels of the image, used to compute bpp.
    target_bpp: Target bits per pixel.
    bpp_strict: If true, only candidates with bpp <= `target_bpp` are
      admissible.

  Returns:
    The best admissible bitstring.

  Raises:
    ValueError: If `models` is empty.
    RuntimeError: If `bpp_strict` is set and every tried model exceeds
      `target_bpp`.
  """
  if not models:
    raise ValueError("No models available to search over.")
  lower = -1
  upper = len(models)
  bpp = None
  best_bitstring = None
  best_bpp = None
  while bpp != target_bpp and upper - lower > 1:
    i = (upper + lower) // 2
    bitstring = encode(models[i])
    bpp = 8 * len(bitstring) / num_pixels
    is_admissible = bpp <= target_bpp or not bpp_strict
    is_better = (best_bpp is None or
                 abs(bpp - target_bpp) < abs(best_bpp - target_bpp))
    if is_admissible and is_better:
      best_bitstring = bitstring
      best_bpp = bpp
    # Every iteration shrinks the interval, so the loop terminates even when
    # no comparison with `target_bpp` holds (e.g. a NaN target). An exact hit
    # ends the loop via the `while` condition regardless of this update.
    if bpp < target_bpp:
      lower = i
    else:
      upper = i
  if best_bpp is None:
    # Without `bpp_strict` the first candidate is always admissible, so this
    # is only reachable in strict mode.
    raise RuntimeError(
        "Could not compress image to less than {} bpp.".format(target_bpp))
  return best_bitstring


def compress(model, input_file, output_file, target_bpp=None, bpp_strict=False):
  """Compresses a PNG file to a TFCI file.

  Args:
    model: Model identifier. If `target_bpp` is given, this is the identifier
      without the trailing index, and the candidate models are read from the
      file `<model>.models`.
    input_file: Path of the PNG file to compress.
    output_file: Path of the TFCI file to write. If empty or `None`, ".tfci" is
      appended to `input_file`.
    target_bpp: Optional target bits per pixel. If not `None` (including 0),
      a binary search over the candidate models picks the closest rate.
    bpp_strict: If true, never pick a model exceeding `target_bpp`. Ignored if
      `target_bpp` is `None`.

  Raises:
    ValueError: If the candidate model list is empty.
    RuntimeError: If `bpp_strict` is set and no candidate meets `target_bpp`.
  """
  if not output_file:
    output_file = input_file + ".tfci"

  # Load image.
  with tf.Graph().as_default():
    with tf.Session() as sess:
      input_image = sess.run(read_png(input_file))
  # Shape is (1, height, width, channels) after `read_png`'s expand_dims.
  num_pixels = input_image.shape[-2] * input_image.shape[-3]

  # `is None` rather than falsiness: a target of 0 bpp is a valid request and
  # must not fall back to treating the index-less prefix as a model name.
  if target_bpp is None:
    # Just compress with a specific model.
    bitstring = compress_image(model, input_image)
  else:
    # Get model list.
    models = load_cached(model + ".models").decode("ascii").split()
    bitstring = _search_models_for_bpp(
        models, lambda m: compress_image(m, input_image), num_pixels,
        target_bpp, bpp_strict)

  # Write bitstring to disk.
  with tf.io.gfile.GFile(output_file, "wb") as f:
    f.write(bitstring)


def decompress(input_file, output_file):
  """Decompresses a TFCI file and writes a PNG file.

  Args:
    input_file: Path of the TFCI file to read.
    output_file: Path of the PNG file to write. If empty or `None`, ".png" is
      appended to `input_file`.
  """
  if not output_file:
    output_file = input_file + ".png"

  with tf.Graph().as_default():
    # Unserialize packed data from disk.
    with tf.io.gfile.GFile(input_file, "rb") as f:
      packed = tfc.PackedTensors(f.read())

    # Load model metagraph.
    signature_defs = import_metagraph(packed.model)
    inputs, outputs = instantiate_signature(signature_defs["receiver"])

    # Multiple input tensors, ordered alphabetically, without names. This
    # mirrors the output ordering used by `compress_image`.
    inputs = [inputs[k] for k in sorted(inputs) if k.startswith("channel:")]
    # Just one output operation.
    outputs = write_png(output_file, outputs["output_image"])

    # Unpack data.
    arrays = packed.unpack(inputs)

    # Run decoder.
    with tf.Session() as sess:
      sess.run(outputs, feed_dict=dict(zip(inputs, arrays)))


def list_models():
  """Prints the list of available models fetched from `URL_PREFIX`."""
  url = URL_PREFIX + "/models.txt"
  # `with` avoids referencing an unbound response object if `urlopen` raises.
  with urllib.request.urlopen(url) as request:
    print(request.read().decode("utf-8"))


def parse_args(argv):
  """Parses command line arguments.

  Args:
    argv: Full argument vector; the program name in `argv[0]` is skipped.

  Returns:
    The parsed argument namespace. Exits with status 2 after printing usage if
    no subcommand was given.
  """
  parser = argparse_flags.ArgumentParser(
      formatter_class=argparse.ArgumentDefaultsHelpFormatter)

  # High-level options.
  parser.add_argument(
      "--url_prefix",
      default=URL_PREFIX,
      help="URL prefix for downloading model metagraphs.")
  parser.add_argument(
      "--metagraph_cache",
      default=METAGRAPH_CACHE,
      help="Directory where to cache model metagraphs.")
  subparsers = parser.add_subparsers(
      title="commands", dest="command",
      help="Invoke '<command> -h' for more information.")

  # 'compress' subcommand.
  compress_cmd = subparsers.add_parser(
      "compress",
      formatter_class=argparse.ArgumentDefaultsHelpFormatter,
      description="Reads a PNG file, compresses it using the given model, and "
                  "writes a TFCI file.")
  compress_cmd.add_argument(
      "model",
      help="Unique model identifier. See 'models' command for options. If "
           "'target_bpp' is provided, don't specify the index at the end of "
           "the model identifier.")
  compress_cmd.add_argument(
      "--target_bpp", type=float,
      help="Target bits per pixel. If provided, a binary search is used to try "
           "to match the given bpp as close as possible. In this case, don't "
           "specify the index at the end of the model identifier. It will be "
           "automatically determined.")
  compress_cmd.add_argument(
      "--bpp_strict", action="store_true",
      help="Try never to exceed 'target_bpp'. Ignored if 'target_bpp' is not "
           "set.")

  # 'decompress' subcommand.
  decompress_cmd = subparsers.add_parser(
      "decompress",
      formatter_class=argparse.ArgumentDefaultsHelpFormatter,
      description="Reads a TFCI file, reconstructs the image using the model "
                  "it was compressed with, and writes back a PNG file.")

  # Arguments for both 'compress' and 'decompress'.
  for cmd, ext in ((compress_cmd, ".tfci"), (decompress_cmd, ".png")):
    cmd.add_argument(
        "input_file",
        help="Input filename.")
    cmd.add_argument(
        "output_file", nargs="?",
        help="Output filename (optional). If not provided, appends '{}' to "
             "the input filename.".format(ext))

  # 'models' subcommand.
  subparsers.add_parser(
      "models",
      formatter_class=argparse.ArgumentDefaultsHelpFormatter,
      description="Lists available trained models. Requires an internet "
                  "connection.")

  # Parse arguments.
  args = parser.parse_args(argv[1:])
  if args.command is None:
    parser.print_usage()
    sys.exit(2)
  return args


def main(args):
  """Applies global overrides and dispatches to the selected subcommand.

  Args:
    args: Namespace returned by `parse_args`.
  """
  # Command line can override these defaults.
  global URL_PREFIX, METAGRAPH_CACHE
  URL_PREFIX = args.url_prefix
  METAGRAPH_CACHE = args.metagraph_cache

  # Invoke subcommand.
  if args.command == "compress":
    compress(args.model, args.input_file, args.output_file,
             args.target_bpp, args.bpp_strict)
  elif args.command == "decompress":
    decompress(args.input_file, args.output_file)
  elif args.command == "models":
    list_models()


if __name__ == "__main__":
  app.run(main, flags_parser=parse_args)