# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utility to convert the output of batch prediction into a CSV submission.

It converts the JSON files created by the command
'gcloud beta ml jobs submit prediction' into a CSV file ready for submission.
"""

import json
import tensorflow as tf

from builtins import range
from tensorflow import app
from tensorflow import flags
from tensorflow import gfile
from tensorflow import logging

FLAGS = flags.FLAGS

# The flags are only defined when this file is executed as a script; merely
# importing the module registers nothing.
if __name__ == "__main__":

    flags.DEFINE_string(
        "json_prediction_files_pattern", None,
        "Pattern specifying the list of JSON files that the command "
        "'gcloud beta ml jobs submit prediction' outputs. These files are "
        "located in the output path of the prediction command and are prefixed "
        "with 'prediction.results'.")
    flags.DEFINE_string(
        "csv_output_file", None,
        "The file to save the predictions converted to the CSV format.")


def get_csv_header():
    """Returns the header line (newline-terminated) of the CSV file."""
    return "VideoId,LabelConfidencePairs\n"


def to_csv_row(json_data):
    """Converts one decoded JSON prediction record into a CSV row.

    Args:
      json_data: A dict with the keys "video_id", "class_indexes" and
        "predictions". When "video_id" is a list, the first element of each of
        the three values is used instead of the value itself.

    Returns:
      A newline-terminated string of the form
      "<video_id>,<index> <score> <index> <score> ...".

    Raises:
      ValueError: If the number of class indexes differs from the number of
        predictions.
    """
    video_id = json_data["video_id"]
    class_indexes = json_data["class_indexes"]
    predictions = json_data["predictions"]

    if isinstance(video_id, list):
        video_id = video_id[0]
        class_indexes = class_indexes[0]
        predictions = predictions[0]

    if len(class_indexes) != len(predictions):
        raise ValueError(
            "The number of indexes (%s) and predictions (%s) must be equal."
            % (len(class_indexes), len(predictions)))

    # json.loads() yields text, which has no .decode() on Python 3 (and whose
    # Python 2 .decode() breaks on non-ASCII ids). Only raw bytes need decoding.
    if isinstance(video_id, bytes):
        video_id = video_id.decode("utf-8")

    label_confidence_pairs = " ".join(
        "%i %f" % (class_index, prediction)
        for class_index, prediction in zip(class_indexes, predictions))
    return video_id + "," + label_confidence_pairs + "\n"


def main(unused_argv):
    """Converts every JSON file matching the input pattern into one CSV file.

    Each non-blank line of each input file is parsed as a JSON record and
    written as one CSV row after the header line.

    Args:
      unused_argv: Ignored.

    Raises:
      ValueError: If --json_prediction_files_pattern or --csv_output_file is
        not set.
    """
    logging.set_verbosity(tf.logging.INFO)

    if not FLAGS.json_prediction_files_pattern:
        raise ValueError(
            "The flag --json_prediction_files_pattern must be specified.")

    if not FLAGS.csv_output_file:
        raise ValueError("The flag --csv_output_file must be specified.")

    logging.info("Looking for prediction files with pattern: %s",
                 FLAGS.json_prediction_files_pattern)

    file_paths = gfile.Glob(FLAGS.json_prediction_files_pattern)
    logging.info("Found files: %s", file_paths)

    logging.info("Writing submission file to: %s", FLAGS.csv_output_file)
    with gfile.Open(FLAGS.csv_output_file, "w+") as output_file:
        output_file.write(get_csv_header())

        for file_path in file_paths:
            logging.info("processing file: %s", file_path)

            with gfile.Open(file_path) as input_file:
                for line in input_file:
                    # A blank (e.g. trailing) line is not a JSON document and
                    # would make json.loads() raise; skip it.
                    if not line.strip():
                        continue
                    output_file.write(to_csv_row(json.loads(line)))

        output_file.flush()
    logging.info("done")


if __name__ == "__main__":
    app.run()