Python tensorflow.python.lib.io.file_io.FileIO() Examples

The following are 29 code examples of tensorflow.python.lib.io.file_io.FileIO(), drawn from open-source projects. Each example lists its source file, the project it comes from, and that project's license. You may also want to check out the other available functions and classes of the module tensorflow.python.lib.io.file_io.
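As a quick orientation before the examples: file_io.FileIO is a file-like wrapper whose paths may be either local files or Google Cloud Storage URIs (gs://...), which is why it appears so often in Cloud ML training code. A minimal read/write sketch (the bucket path below is hypothetical):

from tensorflow.python.lib.io import file_io

# Write text to a path; local paths and gs:// URIs both work.
with file_io.FileIO('gs://my-bucket/notes.txt', mode='w') as f:
    f.write('hello\n')

# Read it back.
with file_io.FileIO('gs://my-bucket/notes.txt', mode='r') as f:
    print(f.read())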
Example #1
Source File: _local_predict.py    From pydatalab with Apache License 2.0
def _download_images(data, img_cols):
  """Download images given image columns."""

  images = collections.defaultdict(list)
  for d in data:
    for img_col in img_cols:
      if d.get(img_col, None):
        if isinstance(d[img_col], Image.Image):
          # If it is already an Image, use it as is.
          images[img_col].append(d[img_col])
        else:
          # Otherwise it is an image URL. Load the image.
          with file_io.FileIO(d[img_col], 'rb') as fi:
            im = Image.open(fi)
            im.load()  # Force PIL to read the data before the file is closed.
          images[img_col].append(im)
      else:
        images[img_col].append('')

  return images 
Example #2
Source File: trainer.py    From MiniCat with Apache License 2.0
def _write_vocabulary(vocab_counter, vocab_size, destination):
    """Write the top vocab_size number of words to a file.

    Returns: A word-to-index mapping dictionary for the vocabulary.
    """
    # Remove words that occur fewer than 5 times
    # (.items() rather than .iteritems() so this also runs on Python 3).
    vocab_counter = collections.Counter(
        {k: v for k, v in vocab_counter.items() if v > 4})
    # Filter top words
    vocab_list = vocab_counter.most_common(
        min(len(vocab_counter), vocab_size - 1))
    # Add __UNK__ token to the start of the top_words
    vocab_list.insert(0, (__UNK__, 0))
    # Write the top_words to destination (line by line fashion)
    with file_io.FileIO(destination, 'w+') as f:
        for word in vocab_list:
            f.write(u'{} {}\n'.format(word[0], word[1]))
    # Create a rev_vocab dictionary that returns the index of each word
    return dict([(word, i)
                 for (i, (word, word_count)) in enumerate(vocab_list)]) 
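A hypothetical call, standing in a definition for the module-level __UNK__ constant used above; the returned dictionary maps each word to its line index, with __UNK__ fixed at index 0:

import collections

__UNK__ = '__UNK__'  # stand-in for the module-level constant

counts = collections.Counter({'the': 120, 'cat': 7, 'rare': 2})
rev_vocab = _write_vocabulary(counts, vocab_size=10000,
                              destination='gs://my-bucket/vocab.txt')
# 'rare' is dropped (count < 5), so:
# rev_vocab == {'__UNK__': 0, 'the': 1, 'cat': 2}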
Example #3
Source File: trainer.py    From MiniCat with Apache License 2.0
def _check_params(gcs_working_dir, version):
    """Check if the data already exists by checking for file 'params.json'."""

    data_dir = '{}/v{}/data'.format(gcs_working_dir, version)

    # Prefix matching for the path
    bucket_name, prefix = data_dir[5:].split('/', 1)

    storage_client = storage.Client()
    bucket = storage_client.get_bucket(bucket_name)
    blobs = bucket.list_blobs(prefix=prefix)
    for blob in blobs:
        if blob.name.rsplit('/', 1)[-1] == PARAMS_FILE_NAME:
            with file_io.FileIO('{}/{}'.format(data_dir, PARAMS_FILE_NAME),
                                'r') as f:
                return json.load(f) 
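The bucket/prefix split above relies on gs:// URIs having the form gs://bucket/path; a standalone illustration with a hypothetical path:

data_dir = 'gs://my-bucket/v1/data'
bucket_name, prefix = data_dir[5:].split('/', 1)  # strip 'gs://', split once
print(bucket_name)  # 'my-bucket'
print(prefix)       # 'v1/data'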
Example #4
Source File: metadata_io.py    From transform with Apache License 2.0
def read_metadata(path):
  """Load metadata in JSON format from a path into a new DatasetMetadata."""
  schema_file = os.path.join(path, 'schema.pbtxt')
  legacy_schema_file = os.path.join(path, 'v1-json', 'schema.json')
  if file_io.file_exists(schema_file):
    text_proto = file_io.FileIO(schema_file, 'r').read()
    schema_proto = text_format.Parse(text_proto, schema_pb2.Schema(),
                                     allow_unknown_extension=True)
  elif file_io.file_exists(legacy_schema_file):
    schema_json = file_io.FileIO(legacy_schema_file, 'r').read()
    schema_proto = _parse_schema_json(schema_json)
  else:
    raise IOError(
        'Schema file {} does not exist, and neither does the legacy format '
        'file {}'.format(schema_file, legacy_schema_file))
  return dataset_metadata.DatasetMetadata(schema_proto) 
Example #5
Source File: preprocess.py    From MVSNet with MIT License
def write_cam(file, cam):
    # Use file_io.FileIO rather than the built-in open() so GCS paths also work.
    f = file_io.FileIO(file, "w")

    f.write('extrinsic\n')
    for i in range(0, 4):
        for j in range(0, 4):
            f.write(str(cam[0][i][j]) + ' ')
        f.write('\n')
    f.write('\n')

    f.write('intrinsic\n')
    for i in range(0, 3):
        for j in range(0, 3):
            f.write(str(cam[1][i][j]) + ' ')
        f.write('\n')

    f.write('\n' + str(cam[1][3][0]) + ' ' + str(cam[1][3][1]) + ' ' + str(cam[1][3][2]) + ' ' + str(cam[1][3][3]) + '\n')

    f.close() 
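From the write loop, cam is expected to be a pair of 4x4 arrays: cam[0] the extrinsic matrix, and cam[1] the intrinsic matrix padded with a fourth row of depth-sampling parameters (MVSNet's cam.txt convention). A hypothetical call, with made-up depth values:

import numpy as np

extrinsic = np.eye(4)                          # 4x4 world-to-camera matrix
intrinsic = np.zeros((4, 4))
intrinsic[:3, :3] = np.eye(3)                  # 3x3 intrinsic matrix
intrinsic[3, :] = [425.0, 2.5, 192.0, 902.5]   # depth sampling parameters
write_cam('cam.txt', [extrinsic, intrinsic])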
Example #6
Source File: _preprocess.py    From pydatalab with Apache License 2.0
def process(self, element):
    from tensorflow.python.lib.io import file_io as tf_file_io

    uri, label_id = element
    try:
      with tf_file_io.FileIO(uri, 'r') as f:
        img = Image.open(f).convert('RGB')
    # A variety of different calling libraries throw different exceptions here.
    # They all correspond to an unreadable file so we treat them equivalently.
    # pylint: disable=broad-except
    except Exception as e:
      logging.exception('Error processing image %s: %s', uri, str(e))
      error_count.inc()
      return

    # Convert to desired format and output.
    output = cStringIO.StringIO()
    img.save(output, 'jpeg')
    image_bytes = output.getvalue()
    yield uri, label_id, image_bytes 
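Note this example is Python 2 code (cStringIO). A sketch of the same re-encoding step for Python 3 would use io.BytesIO instead:

import io

def _reencode_jpeg(img):
    # Python 3 equivalent of the cStringIO re-encoding above.
    output = io.BytesIO()
    img.save(output, 'jpeg')
    return output.getvalue()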
Example #7
Source File: tfr.py    From PiNN with BSD 3-Clause "New" or "Revised" License
def load_tfrecord(fname):
    """Load tfrecord dataset.

    Args:
       fname (str): filename of the .yml metadata file to be loaded.
    """
    # dataset
    with FileIO(fname, 'r') as f:
        format_dict = (yaml.safe_load(f)['format'])
    dtypes = {k: format_dict[k]['dtype'] for k in format_dict.keys()}
    shapes = {k: format_dict[k]['shape'] for k in format_dict.keys()}

    feature_dict = {k: tf.FixedLenFeature([], tf.string) for k in dtypes}

    def parser(example): return tf.parse_single_example(example, feature_dict)

    def converter(tensors):
        tensors = {k: tf.parse_tensor(v, dtypes[k])
                   for k, v in tensors.items()}
        [v.set_shape(shapes[k]) for k, v in tensors.items()]
        return tensors
    tfr = '.'.join(fname.split('.')[:-1]+['tfr'])
    dataset = tf.data.TFRecordDataset(tfr).map(parser).map(converter)
    return dataset 
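For reference, after yaml.safe_load the code above expects the metadata to carry a 'format' mapping with per-key dtype and shape entries; a hypothetical example of that structure as a Python dict (keys and values are illustrative only):

# What yaml.safe_load(f) should return for load_tfrecord to work:
metadata = {
    'format': {
        'coord': {'dtype': 'float32', 'shape': [None, 3]},
        'e_data': {'dtype': 'float32', 'shape': []},
    }
}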
Example #8
Source File: data_generator.py    From fritz-models with MIT License
def load_class_labels(label_filename):
        """Load class labels.

        Assumes the data directory is left unchanged from the original zip.

        Args:
            label_filename (str): path to the class labels file

        Returns:
            arr: an array of class labels
        """
        class_labels = []
        header = True
        with file_io.FileIO(label_filename, mode='r') as file:
            for line in file.readlines():
                if header:
                    header = False
                    continue
                line = line.rstrip()
                label = line.split('\t')[-1]
                class_labels.append(label)
        return numpy.array(class_labels) 
Example #9
Source File: _util.py    From pydatalab with Apache License 2.0
def load_images(image_files, resize=True):
  """Load images from files and optionally resize it."""

  images = []
  for image_file in image_files:
    with file_io.FileIO(image_file, 'r') as ff:
      images.append(ff.read())
  if not resize:
    return images

  # To resize, run a tf session so we can reuse 'decode_and_resize()'
  # which is used in prediction graph. This makes sure we don't lose
  # any quality in prediction, while decreasing the size of the images
  # submitted to the model over network.
  image_str_tensor = tf.placeholder(tf.string, shape=[None])
  image = tf.map_fn(resize_image, image_str_tensor, back_prop=False)
  feed_dict = collections.defaultdict(list)
  feed_dict[image_str_tensor.name] = images
  with tf.Session() as sess:
    images_resized = sess.run(image, feed_dict=feed_dict)
  return images_resized 
Example #10
Source File: create_tfrecord_dataset.py    From fritz-models with MIT License
def _load_class_labels(label_filename):
    """Load class labels.

    Assumes the data directory is left unchanged from the original zip.

    Args:
        label_filename (str): path to the class labels file

    Returns:
        List[(int, str)]: a list of class ids and labels
    """
    class_labels = []
    header = True
    with file_io.FileIO(label_filename, mode='r') as file:
        for line in file.readlines():
            if header:
                class_labels.append((0, 'none'))
                header = False
                continue
            line = line.rstrip()
            line = line.split('\t')
            label = line[-1]
            label_id = int(line[0])
            class_labels.append((label_id, label))
    return class_labels 
Example #11
Source File: feature_transforms.py    From pydatalab with Apache License 2.0
def read_vocab_file(file_path):
  """Reads a vocab file to memeory.

  Args:
    file_path: Each line of the vocab is in the form "token,example_count"

  Returns:
    Two lists, one for the vocab, and one for just the example counts.
  """
  with file_io.FileIO(file_path, 'r') as f:
    vocab_pd = pd.read_csv(
        f,
        header=None,
        names=['vocab', 'count'],
        dtype=str,  # Prevent pd from converting numerical categories.
        na_filter=False)  # Prevent pd from converting 'NA' to a NaN.

  vocab = vocab_pd['vocab'].tolist()
  ex_count = vocab_pd['count'].astype(int).tolist()

  return vocab, ex_count 
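For a hypothetical vocab file containing the three lines red,10 / blue,7 / NA,2, the call would behave as sketched below; na_filter=False is what keeps the literal string 'NA' from becoming NaN:

vocab, ex_count = read_vocab_file('gs://my-bucket/vocab.csv')
# vocab    == ['red', 'blue', 'NA']
# ex_count == [10, 7, 2]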
Example #12
Source File: pipeline.py    From cloudml-samples with Apache License 2.0
def make_request_json(self, uri, output_json):
    """Produces a JSON request suitable to send to CloudML Prediction API.

    Args:
      uri: The input image URI.
      output_json: Path of the output JSON file where the request will be written.
    """
    def _open_file_read_binary(uri):
      try:
        return file_io.FileIO(uri, mode='rb')
      except errors.InvalidArgumentError:
        return file_io.FileIO(uri, mode='r')

    with open(output_json, 'w') as outf:
      with _open_file_read_binary(uri) as f:
        image_bytes = f.read()
        image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
        image = image.resize((299, 299), Image.BILINEAR)
        resized_image = io.BytesIO()
        image.save(resized_image, format='JPEG')
        # b64encode returns bytes on Python 3; decode so json.dumps accepts it.
        encoded_image = base64.b64encode(resized_image.getvalue()).decode('ascii')
        row = json.dumps({'key': uri, 'image_bytes': {'b64': encoded_image}})
        outf.write(row)
        outf.write('\n') 
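Each line written to output_json is one prediction request instance; for the URI above it has the following shape (payload abbreviated, values hypothetical):

# {"key": "gs://my-bucket/cat.jpg", "image_bytes": {"b64": "/9j/4AAQ..."}}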
Example #13
Source File: meta_graph.py    From Serverless-Deep-Learning-with-TensorFlow-and-AWS-Lambda with MIT License
def read_meta_graph_file(filename):
  """Reads a file containing `MetaGraphDef` and returns the protocol buffer.

  Args:
    filename: `meta_graph_def` filename including the path.

  Returns:
    A `MetaGraphDef` protocol buffer.

  Raises:
    IOError: If the file doesn't exist, or cannot be successfully parsed.
  """
  meta_graph_def = meta_graph_pb2.MetaGraphDef()
  if not file_io.file_exists(filename):
    raise IOError("File %s does not exist." % filename)
  # First try to read it as a binary file.
  file_content = file_io.FileIO(filename, "rb").read()
  try:
    meta_graph_def.ParseFromString(file_content)
    return meta_graph_def
  except Exception:  # pylint: disable=broad-except
    pass

  # Next try to read it as a text file.
  try:
    text_format.Merge(file_content.decode("utf-8"), meta_graph_def)
  except text_format.ParseError as e:
    raise IOError("Cannot parse file %s: %s." % (filename, str(e)))

  return meta_graph_def 
Example #14
Source File: plugin.py    From keras-lambda with MIT License
def _serve_bookmarks(self, request, query_params):
    run = query_params.get('run')
    if not run:
      request.respond('query parameter "run" is required', 'text/plain', 400)
      return

    name = query_params.get('name')
    if name is None:
      request.respond('query parameter "name" is required', 'text/plain', 400)
      return

    if run not in self.configs:
      request.respond('Unknown run: %s' % run, 'text/plain', 400)
      return

    config = self.configs[run]
    fpath = self._get_bookmarks_file_for_tensor(name, config)
    if not fpath:
      request.respond(
          'No bookmarks file found for tensor %s in the config file %s' %
          (name, self.config_fpaths[run]), 'text/plain', 400)
      return
    if not file_io.file_exists(fpath) or file_io.is_directory(fpath):
      request.respond('%s is not a file' % fpath, 'text/plain', 400)
      return

    bookmarks_json = None
    with file_io.FileIO(fpath, 'r') as f:
      bookmarks_json = f.read()
    request.respond(bookmarks_json, 'application/json') 
Example #15
Source File: meta_graph.py    From Serverless-Deep-Learning-with-TensorFlow-and-AWS-Lambda with MIT License
def _read_file(filename):
  """Reads a file containing `GraphDef` and returns the protocol buffer.

  Args:
    filename: `graph_def` filename including the path.

  Returns:
    A `GraphDef` protocol buffer.

  Raises:
    IOError: If the file doesn't exist, or cannot be successfully parsed.
  """
  graph_def = graph_pb2.GraphDef()
  if not file_io.file_exists(filename):
    raise IOError("File %s does not exist." % filename)
  # First try to read it as a binary file.
  file_content = file_io.FileIO(filename, "rb").read()
  try:
    graph_def.ParseFromString(file_content)
    return graph_def
  except Exception:  # pylint: disable=broad-except
    pass

  # Next try to read it as a text file.
  try:
    text_format.Merge(file_content.decode("utf-8"), graph_def)
  except text_format.ParseError as e:
    raise IOError("Cannot parse file %s: %s." % (filename, str(e)))

  return graph_def 
Example #16
Source File: trainer.py    From fritz-models with MIT License
def copy_file_to_gcs(job_dir, file_path):
    with file_io.FileIO(file_path, mode='rb') as input_f:
        # Binary mode so the bytes read above write cleanly on Python 3.
        with file_io.FileIO(
                os.path.join(job_dir, file_path), mode='wb+') as output_f:
            output_f.write(input_f.read())
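A hypothetical usage, copying a locally written checkpoint up to the GCS job directory:

# Copies ./model.h5 to gs://my-bucket/jobs/run-42/model.h5
copy_file_to_gcs('gs://my-bucket/jobs/run-42', 'model.h5')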
Example #17
Source File: plugin.py    From keras-lambda with MIT License
def _serve_sprite_image(self, request, query_params):
    run = query_params.get('run')
    if not run:
      request.respond('query parameter "run" is required', 'text/plain', 400)
      return

    name = query_params.get('name')
    if name is None:
      request.respond('query parameter "name" is required', 'text/plain', 400)
      return

    if run not in self.configs:
      request.respond('Unknown run: %s' % run, 'text/plain', 400)
      return

    config = self.configs[run]
    embedding_info = self._get_embedding(name, config)

    if not embedding_info or not embedding_info.sprite.image_path:
      request.respond(
          'No sprite image file found for tensor %s in the config file %s' %
          (name, self.config_fpaths[run]), 'text/plain', 400)
      return

    fpath = embedding_info.sprite.image_path
    if not file_io.file_exists(fpath) or file_io.is_directory(fpath):
      request.respond(
          '%s does not exist or is directory' % fpath, 'text/plain', 400)
      return
    f = file_io.FileIO(fpath, 'r')
    encoded_image_string = f.read()
    f.close()
    image_type = imghdr.what(None, encoded_image_string)
    mime_type = _IMGHDR_TO_MIMETYPE.get(image_type, _DEFAULT_IMAGE_MIMETYPE)
    request.respond(encoded_image_string, mime_type) 
Example #18
Source File: train.py    From fritz-models with MIT License
def _copy_file_to_gcs(job_dir, file_path):
        gcs_url = os.path.join(job_dir, file_path)
        logger.info('Saving models to GCS: %s' % gcs_url)
        with file_io.FileIO(file_path, mode='rb') as input_f:
            # Binary mode so the bytes read above write cleanly on Python 3.
            with file_io.FileIO(gcs_url, mode='wb+') as output_f:
                output_f.write(input_f.read())
Example #19
Source File: task.py    From MiniCat with Apache License 2.0
def get_embeds(data_dir):
    embed_file = '{}/{}'.format(data_dir, 'embeddings.csv')
    embeddings = []
    with file_io.FileIO(embed_file, 'r') as f:
        for line in f.readlines():
            row = line.split(',')
            embeddings.append([float(x) for x in row])
    embeddings.append([0.] * len(embeddings[0]))  # Embedding for pads
    return embeddings 
Example #20
Source File: plugin.py    From keras-lambda with MIT License
def _read_tensor_file(fpath):
  with file_io.FileIO(fpath, 'r') as f:
    tensor = []
    for line in f:
      if line:
        # List comprehension rather than map() so this also works on Python 3.
        tensor.append([float(x) for x in line.rstrip('\n').split('\t')])
  return np.array(tensor, dtype='float32') 
Example #21
Source File: main.py    From DeepMind-alphafold-repl with BSD 3-Clause "New" or "Revised" License
def parse_model_config(json_file):
    # The built-in 'open' function does not support Google Cloud Storage paths.
    with file_io.FileIO(json_file, 'r') as f:
        config = json.load(f)
    return config 
Example #22
Source File: util.py    From pipelines with Apache License 2.0
def load_data(data_file_url, window_size):
    """Loads data into preprocessed (train_X, train_y, eval_X, eval_y) dataframes.

    Returns:
      A tuple (train_X, train_y, eval_X, eval_y), where train_X and eval_X are
      Pandas dataframes with features for training and train_y and eval_y are
      numpy arrays with the corresponding labels.
    """
    # The % of data we should use for training
    TRAINING_SPLIT = 0.8

    # Download CSV and import into Pandas DataFrame
    file_stream = file_io.FileIO(data_file_url, mode='r')
    df = pd.read_csv(StringIO(file_stream.read()))
    df.index = df[df.columns[0]]
    df = df[['count']]

    scaler = StandardScaler()

    # Time series: split latest data into test set
    train = df.values[:int(TRAINING_SPLIT * len(df)), :]
    print(train)
    train = scaler.fit_transform(train)
    test = df.values[int(TRAINING_SPLIT * len(df)):, :]
    test = scaler.transform(test)

    # Create test and training sets
    train_X, train_y = create_dataset(train, window_size)
    test_X, test_y = create_dataset(test, window_size)

    # Reshape input data
    train_X = np.reshape(train_X, (train_X.shape[0], 1, train_X.shape[1]))
    test_X = np.reshape(test_X, (test_X.shape[0], 1, test_X.shape[1]))

    return train_X, train_y, test_X, test_y 
Example #23
Source File: inference.py    From youtube-8m with Apache License 2.0
def main(unused_argv):
  logging.set_verbosity(tf.logging.INFO)
  if FLAGS.input_model_tgz:
    if FLAGS.train_dir:
      raise ValueError("You cannot supply --train_dir if supplying "
                       "--input_model_tgz")
    # Untar.
    if not os.path.exists(FLAGS.untar_model_dir):
      os.makedirs(FLAGS.untar_model_dir)
    tarfile.open(FLAGS.input_model_tgz).extractall(FLAGS.untar_model_dir)
    FLAGS.train_dir = FLAGS.untar_model_dir

  flags_dict_file = os.path.join(FLAGS.train_dir, "model_flags.json")
  if not file_io.file_exists(flags_dict_file):
    raise IOError("Cannot find %s. Did you run eval.py?" % flags_dict_file)
  flags_dict = json.loads(file_io.FileIO(flags_dict_file, "r").read())

  # convert feature_names and feature_sizes to lists of values
  feature_names, feature_sizes = utils.GetListOfFeatureNamesAndSizes(
      flags_dict["feature_names"], flags_dict["feature_sizes"])

  if flags_dict["frame_features"]:
    reader = readers.YT8MFrameFeatureReader(feature_names=feature_names,
                                            feature_sizes=feature_sizes)
  else:
    reader = readers.YT8MAggregatedFeatureReader(feature_names=feature_names,
                                                 feature_sizes=feature_sizes)

  if not FLAGS.output_file:
    raise ValueError("'output_file' was not specified. "
                     "Unable to continue with inference.")

  if not FLAGS.input_data_pattern:
    raise ValueError("'input_data_pattern' was not specified. "
                     "Unable to continue with inference.")

  inference(reader, FLAGS.train_dir, FLAGS.input_data_pattern,
            FLAGS.output_file, FLAGS.batch_size, FLAGS.top_k) 
Example #24
Source File: plugin.py    From deep_image_model with Apache License 2.0
def _serve_sprite_image(self, request, query_params):
    run = query_params.get('run')
    if not run:
      request.respond('query parameter "run" is required', 'text/plain', 400)
      return

    name = query_params.get('name')
    if name is None:
      request.respond('query parameter "name" is required', 'text/plain', 400)
      return

    if run not in self.configs:
      request.respond('Unknown run: %s' % run, 'text/plain', 400)
      return

    config = self.configs[run]
    embedding_info = self._get_embedding(name, config)

    if not embedding_info or not embedding_info.sprite.image_path:
      request.respond(
          'No sprite image file found for tensor %s in the config file %s' %
          (name, self.config_fpaths[run]), 'text/plain', 400)
      return

    fpath = embedding_info.sprite.image_path
    if not file_io.file_exists(fpath) or file_io.is_directory(fpath):
      request.respond(
          '%s does not exist or is directory' % fpath, 'text/plain', 400)
      return
    f = file_io.FileIO(fpath, 'r')
    encoded_image_string = f.read()
    f.close()
    image_type = imghdr.what(None, encoded_image_string)
    mime_type = _IMGHDR_TO_MIMETYPE.get(image_type, _DEFAULT_IMAGE_MIMETYPE)
    request.respond(encoded_image_string, mime_type) 
Example #25
Source File: plugin.py    From deep_image_model with Apache License 2.0
def _serve_bookmarks(self, request, query_params):
    run = query_params.get('run')
    if not run:
      request.respond('query parameter "run" is required', 'text/plain', 400)
      return

    name = query_params.get('name')
    if name is None:
      request.respond('query parameter "name" is required', 'text/plain', 400)
      return

    if run not in self.configs:
      request.respond('Unknown run: %s' % run, 'text/plain', 400)
      return

    config = self.configs[run]
    fpath = self._get_bookmarks_file_for_tensor(name, config)
    if not fpath:
      request.respond(
          'No bookmarks file found for tensor %s in the config file %s' %
          (name, self.config_fpaths[run]), 'text/plain', 400)
      return
    if not file_io.file_exists(fpath) or file_io.is_directory(fpath):
      request.respond('%s is not a file' % fpath, 'text/plain', 400)
      return

    bookmarks_json = None
    with file_io.FileIO(fpath, 'r') as f:
      bookmarks_json = f.read()
    request.respond(bookmarks_json, 'application/json') 
Example #26
Source File: plugin.py    From deep_image_model with Apache License 2.0
def _read_tensor_file(fpath):
  with file_io.FileIO(fpath, 'r') as f:
    tensor = []
    for line in f:
      if line:
        # List comprehension rather than map() so this also works on Python 3.
        tensor.append([float(x) for x in line.rstrip('\n').split('\t')])
  return np.array(tensor, dtype='float32') 
Example #27
Source File: tf_schema_utils.py    From spotify-tensorflow with Apache License 2.0
def parse_schema_file(schema_path):  # type: (str) -> Schema
    """
    Read a schema file and return the proto object.
    """
    assert file_io.file_exists(schema_path), "File not found: {}".format(schema_path)
    schema = Schema()
    with file_io.FileIO(schema_path, "rb") as f:
        schema.ParseFromString(f.read())
    return schema 
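A hypothetical usage, assuming Schema here is the tensorflow_metadata schema_pb2.Schema proto (which carries a repeated feature field):

schema = parse_schema_file('gs://my-bucket/schema.pb')
print([f.name for f in schema.feature])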
Example #28
Source File: _local_predict.py    From pydatalab with Apache License 2.0
def local_batch_predict(model_dir, csv_file_pattern, output_dir, output_format, batch_size=100):
  """ Batch Predict with a specified model.

  It does batch prediction, saves results to output files, and also creates an
  output schema file. The output file names are the input file names prefixed
  with 'predict_results_'.

  Args:
    model_dir: The model directory containing a SavedModel (usually saved_model.pb).
    csv_file_pattern: a pattern of csv files as batch prediction source.
    output_dir: the path of the output directory.
    output_format: csv or json.
    batch_size: Larger batch_size improves performance but may
        cause more memory usage.
  """

  file_io.recursive_create_dir(output_dir)
  csv_files = file_io.get_matching_files(csv_file_pattern)
  if len(csv_files) == 0:
    raise ValueError('No files found given ' + csv_file_pattern)

  with tf.Graph().as_default(), tf.Session() as sess:
    input_alias_map, output_alias_map = _tf_load_model(sess, model_dir)
    csv_tensor_name = list(input_alias_map.values())[0]
    output_schema = _get_output_schema(sess, output_alias_map)
    for csv_file in csv_files:
      output_file = os.path.join(
          output_dir,
          'predict_results_' +
          os.path.splitext(os.path.basename(csv_file))[0] + '.' + output_format)
      with file_io.FileIO(output_file, 'w') as f:
        prediction_source = _batch_csv_reader(csv_file, batch_size)
        for batch in prediction_source:
          batch = [l.rstrip() for l in batch if l]
          predict_results = sess.run(fetches=output_alias_map, feed_dict={csv_tensor_name: batch})
          formatted_results = _format_results(output_format, output_schema, predict_results)
          f.write('\n'.join(formatted_results) + '\n')

  file_io.write_string_to_file(os.path.join(output_dir, 'predict_results_schema.json'),
                               json.dumps(output_schema, indent=2)) 
Example #29
Source File: _local_predict.py    From pydatalab with Apache License 2.0
def _batch_csv_reader(csv_file, n):
  # Yield from inside the 'with' block so the file is still open while the
  # caller consumes the batches.
  with file_io.FileIO(csv_file, 'r') as f:
    args = [f] * n
    for batch in six.moves.zip_longest(*args):
      yield batch
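The [f] * n construction repeats the same iterator n times, so zip_longest pulls n consecutive lines per batch and pads the last batch with None (the standard "grouper" idiom). A self-contained illustration on a plain iterator:

import six

it = iter(['a', 'b', 'c', 'd', 'e'])
print(list(six.moves.zip_longest(*[it] * 2)))
# [('a', 'b'), ('c', 'd'), ('e', None)]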