Python tensorflow.Estimator() Examples

The following are 30 code examples of tf.estimator.Estimator(). You can go to the original project or source file by following the links above each example. You may also want to check out all available functions and classes of the tensorflow module, or try the search function.
Example #1
Source File: base_estimator.py    From multilabel-image-classification-tensorflow with MIT License
def construct_input_fn(self, records, is_training):
    """Builds an estimator input_fn.

    The input_fn is used to pass feature and target data to the train,
    evaluate, and predict methods of the Estimator.

    Method to be overridden by implementations.

    Args:
      records: A list of Strings, paths to TFRecords with image data.
      is_training: Boolean, whether or not we're training.

    Returns:
      Function that has the signature ()->(dict of features, target).
        features is a dict mapping feature names to `Tensors`
        containing the corresponding feature data (typically, just a single
        key/value pair 'raw_data' -> image `Tensor` for TCN).
        target is a 1-D int32 `Tensor` holding labels.
    """
    pass 
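For reference, here is a minimal sketch of what a subclass override might return, assuming TF1-style image TFRecords; the feature keys ('image/encoded', 'image/label'), the image size, and the pipeline details are illustrative assumptions, not part of the original code.

import tensorflow as tf

def construct_input_fn(self, records, is_training):
    """Sketch only: build an input_fn over image TFRecords."""
    def _parse_example(serialized):
        # Hypothetical feature spec; the real keys depend on how the
        # records were written.
        spec = {
            'image/encoded': tf.FixedLenFeature([], tf.string),
            'image/label': tf.FixedLenFeature([], tf.int64),
        }
        parsed = tf.parse_single_example(serialized, spec)
        image = tf.image.decode_jpeg(parsed['image/encoded'], channels=3)
        image = tf.image.convert_image_dtype(image, tf.float32)
        # Fixed spatial size (assumed) so that batching works.
        image = tf.image.resize_images(image, [224, 224])
        target = tf.cast(parsed['image/label'], tf.int32)
        return {'raw_data': image}, target

    def input_fn():
        dataset = tf.data.TFRecordDataset(records)
        if is_training:
            dataset = dataset.shuffle(buffer_size=1024).repeat()
        dataset = dataset.map(_parse_example)
        dataset = dataset.batch(self._batch_size)
        return dataset.make_one_shot_iterator().get_next()

    return input_fn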
Example #2
Source File: base_estimator.py    From Gun-Detector with Apache License 2.0
def _input_fn_inference(self, input_fn, checkpoint_path, predict_keys=None):
    """Mode 1: tf.Estimator inference.

    Args:
      input_fn: Function that has the signature ()->(dict of features, None).
        This is a function called by the estimator to get input tensors (stored
        in the features dict) to do inference over.
      checkpoint_path: String, path to a specific checkpoint to restore.
      predict_keys: List of strings, the keys of the `Tensors` in the features
        dict (returned by the input_fn) to evaluate during inference.
    Returns:
      predictions: An Iterator, yielding evaluated values of `Tensors`
        specified in `predict_keys`.
    """
    # Create the estimator.
    estimator = self._build_estimator(is_training=False)

    # Create an iterator of predicted embeddings.
    predictions = estimator.predict(input_fn=input_fn,
                                    checkpoint_path=checkpoint_path,
                                    predict_keys=predict_keys)
    return predictions 
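The returned iterator is lazy: the checkpoint is restored and the graph is run only once you start pulling values from it. A hedged usage sketch (the checkpoint path, input_fn, and the 'embeddings' key are assumptions):

# Sketch only: consume the predictions iterator.
predictions = model._input_fn_inference(
    input_fn=my_input_fn,
    checkpoint_path='/tmp/tcn/model.ckpt-10000',
    predict_keys=['embeddings'])
for i, pred in enumerate(predictions):
    print('example %d embedding shape: %s' % (i, pred['embeddings'].shape))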
Example #3
Source File: base_estimator.py    From Gun-Detector with Apache License 2.0
def evaluate(self):
    """Runs `Estimator` validation.
    """
    config = self._config

    # Get a list of validation tfrecords.
    validation_dir = config.data.validation
    validation_records = util.GetFilesRecursively(validation_dir)

    # Define batch size.
    self._batch_size = config.data.batch_size

    # Create a subclass-defined validation input function.
    validation_input_fn = self.construct_input_fn(
        validation_records, False)

    # Create the estimator.
    estimator = self._build_estimator(is_training=False)

    # Run validation.
    eval_batch_size = config.data.batch_size
    num_eval_samples = config.val.num_eval_samples
    num_eval_batches = int(num_eval_samples / eval_batch_size)
    estimator.evaluate(input_fn=validation_input_fn, steps=num_eval_batches) 
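Note that int(num_eval_samples / eval_batch_size) floors, so a final partial batch is silently skipped, and the metrics dict returned by estimator.evaluate is discarded here. A small illustrative check:

# With, say, 10000 eval samples and a batch size of 32:
num_eval_batches = int(10000 / 32)  # 312 batches; the last 16 samples are never evaluated
# Capturing the return value would expose the metrics, e.g.:
# metrics = estimator.evaluate(input_fn=validation_input_fn, steps=num_eval_batches)
# metrics is a dict such as {'loss': ..., 'global_step': ...} (keys depend on the model_fn).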
Example #4
Source File: base_estimator.py    From Gun-Detector with Apache License 2.0
def construct_input_fn(self, records, is_training):
    """Builds an estimator input_fn.

    The input_fn is used to pass feature and target data to the train,
    evaluate, and predict methods of the Estimator.

    Method to be overridden by implementations.

    Args:
      records: A list of Strings, paths to TFRecords with image data.
      is_training: Boolean, whether or not we're training.

    Returns:
      Function that has the signature ()->(dict of features, target).
        features is a dict mapping feature names to `Tensors`
        containing the corresponding feature data (typically, just a single
        key/value pair 'raw_data' -> image `Tensor` for TCN).
        target is a 1-D int32 `Tensor` holding labels.
    """
    pass 
Example #5
Source File: dual_net.py    From training with Apache License 2.0
def bootstrap():
    """Initialize a tf.Estimator run with random initial weights."""
    # a bit hacky - forge an initial checkpoint with the name that subsequent
    # Estimator runs will expect to find.
    #
    # Estimator will do this automatically when you call train(), but calling
    # train() requires data, and I didn't feel like creating training data in
    # order to run the full train pipeline for 1 step.
    maybe_set_seed()
    initial_checkpoint_name = 'model.ckpt-1'
    save_file = os.path.join(FLAGS.work_dir, initial_checkpoint_name)
    sess = tf.Session(graph=tf.Graph())
    with sess.graph.as_default():
        features, labels = get_inference_input()
        model_fn(features, labels, tf.estimator.ModeKeys.PREDICT,
                 params=FLAGS.flag_values_dict())
        sess.run(tf.global_variables_initializer())
        tf.train.Saver().save(sess, save_file) 
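Because tf.train.Saver.save also updates the checkpoint metadata file in the directory, a subsequent Estimator pointed at the same work_dir restores from this forged checkpoint rather than re-initializing. A quick sanity check, as a sketch:

# Sketch only: confirm the forged checkpoint is discoverable.
latest = tf.train.latest_checkpoint(FLAGS.work_dir)
assert latest and latest.endswith('model.ckpt-1')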
Example #6
Source File: dual_net.py    From training with Apache License 2.0
def export_model(model_path):
    """Take the latest checkpoint and copy it to model_path.

    Assumes that all relevant model files are prefixed by the same name.
    (For example, foo.index, foo.meta and foo.data-00000-of-00001).

    Args:
        model_path: The path (can be a gs:// path) to export the model to
    """
    estimator = tf.estimator.Estimator(model_fn, model_dir=FLAGS.work_dir,
                                       params=FLAGS.flag_values_dict())
    latest_checkpoint = estimator.latest_checkpoint()
    all_checkpoint_files = tf.gfile.Glob(latest_checkpoint + '*')
    for filename in all_checkpoint_files:
        suffix = filename.partition(latest_checkpoint)[2]
        destination_path = model_path + suffix
        print('Copying {} to {}'.format(filename, destination_path))
        tf.gfile.Copy(filename, destination_path) 
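The str.partition call splits each filename around the checkpoint prefix, so only the trailing suffix (.index, .meta, .data-...) is appended to model_path. A self-contained illustration:

# str.partition returns (before, match, after); index [2] is the suffix.
name = '/work/model.ckpt-100.index'
suffix = name.partition('/work/model.ckpt-100')[2]
assert suffix == '.index'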
Example #7
Source File: dual_net.py    From training_results_v0.5 with Apache License 2.0
def export_model(working_dir, model_path):
    """Take the latest checkpoint and export it to model_path for selfplay.

    Assumes that all relevant model files are prefixed by the same name.
    (For example, foo.index, foo.meta and foo.data-00000-of-00001).

    Args:
        working_dir: The directory where tf.estimator keeps its checkpoints
        model_path: The path (can be a gs:// path) to export model to
    """
    estimator = tf.estimator.Estimator(model_fn, model_dir=working_dir,
                                       params='ignored')
    latest_checkpoint = estimator.latest_checkpoint()
    all_checkpoint_files = tf.gfile.Glob(latest_checkpoint + '*')
    for filename in all_checkpoint_files:
        suffix = filename.partition(latest_checkpoint)[2]
        destination_path = model_path + suffix
        print("Copying {} to {}".format(filename, destination_path))
        tf.gfile.Copy(filename, destination_path) 
Example #8
Source File: base_estimator.py    From object_detection_with_tensorflow with MIT License
def construct_input_fn(self, records, is_training):
    """Builds an estimator input_fn.

    The input_fn is used to pass feature and target data to the train,
    evaluate, and predict methods of the Estimator.

    Method to be overridden by implementations.

    Args:
      records: A list of Strings, paths to TFRecords with image data.
      is_training: Boolean, whether or not we're training.

    Returns:
      Function that has the signature ()->(dict of features, target).
        features is a dict mapping feature names to `Tensors`
        containing the corresponding feature data (typically, just a single
        key/value pair 'raw_data' -> image `Tensor` for TCN).
        target is a 1-D int32 `Tensor` holding labels.
    """
    pass 
Example #9
Source File: base_estimator.py    From object_detection_with_tensorflow with MIT License
def evaluate(self):
    """Runs `Estimator` validation.
    """
    config = self._config

    # Get a list of validation tfrecords.
    validation_dir = config.data.validation
    validation_records = util.GetFilesRecursively(validation_dir)

    # Define batch size.
    self._batch_size = config.data.batch_size

    # Create a subclass-defined validation input function.
    validation_input_fn = self.construct_input_fn(
        validation_records, False)

    # Create the estimator.
    estimator = self._build_estimator(is_training=False)

    # Run validation.
    eval_batch_size = config.data.batch_size
    num_eval_samples = config.val.num_eval_samples
    num_eval_batches = int(num_eval_samples / eval_batch_size)
    estimator.evaluate(input_fn=validation_input_fn, steps=num_eval_batches) 
Example #10
Source File: base_estimator.py    From yolo_v2 with Apache License 2.0
def _input_fn_inference(self, input_fn, checkpoint_path, predict_keys=None):
    """Mode 1: tf.Estimator inference.

    Args:
      input_fn: Function that has the signature ()->(dict of features, None).
        This is a function called by the estimator to get input tensors (stored
        in the features dict) to do inference over.
      checkpoint_path: String, path to a specific checkpoint to restore.
      predict_keys: List of strings, the keys of the `Tensors` in the features
        dict (returned by the input_fn) to evaluate during inference.
    Returns:
      predictions: An Iterator, yielding evaluated values of `Tensors`
        specified in `predict_keys`.
    """
    # Create the estimator.
    estimator = self._build_estimator(is_training=False)

    # Create an iterator of predicted embeddings.
    predictions = estimator.predict(input_fn=input_fn,
                                    checkpoint_path=checkpoint_path,
                                    predict_keys=predict_keys)
    return predictions 
Example #11
Source File: base_estimator.py    From yolo_v2 with Apache License 2.0
def evaluate(self):
    """Runs `Estimator` validation.
    """
    config = self._config

    # Get a list of validation tfrecords.
    validation_dir = config.data.validation
    validation_records = util.GetFilesRecursively(validation_dir)

    # Define batch size.
    self._batch_size = config.data.batch_size

    # Create a subclass-defined validation input function.
    validation_input_fn = self.construct_input_fn(
        validation_records, False)

    # Create the estimator.
    estimator = self._build_estimator(is_training=False)

    # Run validation.
    eval_batch_size = config.data.batch_size
    num_eval_samples = config.val.num_eval_samples
    num_eval_batches = int(num_eval_samples / eval_batch_size)
    estimator.evaluate(input_fn=validation_input_fn, steps=num_eval_batches) 
Example #12
Source File: base_estimator.py    From yolo_v2 with Apache License 2.0
def construct_input_fn(self, records, is_training):
    """Builds an estimator input_fn.

    The input_fn is used to pass feature and target data to the train,
    evaluate, and predict methods of the Estimator.

    Method to be overridden by implementations.

    Args:
      records: A list of Strings, paths to TFRecords with image data.
      is_training: Boolean, whether or not we're training.

    Returns:
      Function that has the signature ()->(dict of features, target).
        features is a dict mapping feature names to `Tensors`
        containing the corresponding feature data (typically, just a single
        key/value pair 'raw_data' -> image `Tensor` for TCN).
        target is a 1-D int32 `Tensor` holding labels.
    """
    pass 
Example #13
Source File: calculator.py    From PiNN with BSD 3-Clause "New" or "Revised" License
def __init__(self, model=None, atoms=None, to_eV=1.0,
                 properties=['energy', 'forces', 'stress']):
        """PiNN interface with ASE as a calculator

        Args:
            model: tf.Estimator object
            atoms: optional, ase Atoms object
            to_eV: conversion factor from the model's energy unit to eV
            properties: properties to calculate. The set of properties is
                fixed for each calculator, to avoid resetting the predictor
                during get_* calls.
        """
        Calculator.__init__(self)
        self.implemented_properties = properties
        self.model = model
        self.pbc = False
        self.atoms = atoms
        self.predictor = None
        self.to_eV = to_eV 
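A hedged usage sketch with ASE, assuming the class is exported as PiNN_calc and that `estimator` is a trained tf.estimator.Estimator (both names are assumptions here):

from ase.build import molecule

# Sketch only: attach the calculator and query properties through ASE.
calc = PiNN_calc(model=estimator, to_eV=27.2114)  # e.g. Hartree -> eV
atoms = molecule('H2O')
atoms.set_calculator(calc)
energy = atoms.get_potential_energy()  # lazily builds the predictor
forces = atoms.get_forces()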
Example #14
Source File: base_estimator.py    From g-tensorflow-models with Apache License 2.0
def construct_input_fn(self, records, is_training):
    """Builds an estimator input_fn.

    The input_fn is used to pass feature and target data to the train,
    evaluate, and predict methods of the Estimator.

    Method to be overridden by implementations.

    Args:
      records: A list of Strings, paths to TFRecords with image data.
      is_training: Boolean, whether or not we're training.

    Returns:
      Function that has the signature ()->(dict of features, target).
        features is a dict mapping feature names to `Tensors`
        containing the corresponding feature data (typically, just a single
        key/value pair 'raw_data' -> image `Tensor` for TCN).
        target is a 1-D int32 `Tensor` holding labels.
    """
    pass 
Example #15
Source File: base_estimator.py    From g-tensorflow-models with Apache License 2.0
def evaluate(self):
    """Runs `Estimator` validation.
    """
    config = self._config

    # Get a list of validation tfrecords.
    validation_dir = config.data.validation
    validation_records = util.GetFilesRecursively(validation_dir)

    # Define batch size.
    self._batch_size = config.data.batch_size

    # Create a subclass-defined validation input function.
    validation_input_fn = self.construct_input_fn(
        validation_records, False)

    # Create the estimator.
    estimator = self._build_estimator(is_training=False)

    # Run validation.
    eval_batch_size = config.data.batch_size
    num_eval_samples = config.val.num_eval_samples
    num_eval_batches = int(num_eval_samples / eval_batch_size)
    estimator.evaluate(input_fn=validation_input_fn, steps=num_eval_batches) 
Example #16
Source File: base_estimator.py    From g-tensorflow-models with Apache License 2.0
def _input_fn_inference(self, input_fn, checkpoint_path, predict_keys=None):
    """Mode 1: tf.Estimator inference.

    Args:
      input_fn: Function that has the signature ()->(dict of features, None).
        This is a function called by the estimator to get input tensors (stored
        in the features dict) to do inference over.
      checkpoint_path: String, path to a specific checkpoint to restore.
      predict_keys: List of strings, the keys of the `Tensors` in the features
        dict (returned by the input_fn) to evaluate during inference.
    Returns:
      predictions: An Iterator, yielding evaluated values of `Tensors`
        specified in `predict_keys`.
    """
    # Create the estimator.
    estimator = self._build_estimator(is_training=False)

    # Create an iterator of predicted embeddings.
    predictions = estimator.predict(input_fn=input_fn,
                                    checkpoint_path=checkpoint_path,
                                    predict_keys=predict_keys)
    return predictions 
Example #17
Source File: base_estimator.py    From object_detection_with_tensorflow with MIT License
def _input_fn_inference(self, input_fn, checkpoint_path, predict_keys=None):
    """Mode 1: tf.Estimator inference.

    Args:
      input_fn: Function that has the signature ()->(dict of features, None).
        This is a function called by the estimator to get input tensors (stored
        in the features dict) to do inference over.
      checkpoint_path: String, path to a specific checkpoint to restore.
      predict_keys: List of strings, the keys of the `Tensors` in the features
        dict (returned by the input_fn) to evaluate during inference.
    Returns:
      predictions: An Iterator, yielding evaluated values of `Tensors`
        specified in `predict_keys`.
    """
    # Create the estimator.
    estimator = self._build_estimator(is_training=False)

    # Create an iterator of predicted embeddings.
    predictions = estimator.predict(input_fn=input_fn,
                                    checkpoint_path=checkpoint_path,
                                    predict_keys=predict_keys)
    return predictions 
Example #18
Source File: base_estimator.py    From models with Apache License 2.0
def construct_input_fn(self, records, is_training):
    """Builds an estimator input_fn.

    The input_fn is used to pass feature and target data to the train,
    evaluate, and predict methods of the Estimator.

    Method to be overridden by implementations.

    Args:
      records: A list of Strings, paths to TFRecords with image data.
      is_training: Boolean, whether or not we're training.

    Returns:
      Function that has the signature ()->(dict of features, target).
        features is a dict mapping feature names to `Tensors`
        containing the corresponding feature data (typically, just a single
        key/value pair 'raw_data' -> image `Tensor` for TCN).
        target is a 1-D int32 `Tensor` holding labels.
    """
    pass 
Example #19
Source File: base_estimator.py    From models with Apache License 2.0
def evaluate(self):
    """Runs `Estimator` validation.
    """
    config = self._config

    # Get a list of validation tfrecords.
    validation_dir = config.data.validation
    validation_records = util.GetFilesRecursively(validation_dir)

    # Define batch size.
    self._batch_size = config.data.batch_size

    # Create a subclass-defined validation input function.
    validation_input_fn = self.construct_input_fn(
        validation_records, False)

    # Create the estimator.
    estimator = self._build_estimator(is_training=False)

    # Run validation.
    eval_batch_size = config.data.batch_size
    num_eval_samples = config.val.num_eval_samples
    num_eval_batches = int(num_eval_samples / eval_batch_size)
    estimator.evaluate(input_fn=validation_input_fn, steps=num_eval_batches) 
Example #20
Source File: base_estimator.py    From models with Apache License 2.0
def _input_fn_inference(self, input_fn, checkpoint_path, predict_keys=None):
    """Mode 1: tf.Estimator inference.

    Args:
      input_fn: Function that has the signature ()->(dict of features, None).
        This is a function called by the estimator to get input tensors (stored
        in the features dict) to do inference over.
      checkpoint_path: String, path to a specific checkpoint to restore.
      predict_keys: List of strings, the keys of the `Tensors` in the features
        dict (returned by the input_fn) to evaluate during inference.
    Returns:
      predictions: An Iterator, yielding evaluated values of `Tensors`
        specified in `predict_keys`.
    """
    # Create the estimator.
    estimator = self._build_estimator(is_training=False)

    # Create an iterator of predicted embeddings.
    predictions = estimator.predict(input_fn=input_fn,
                                    checkpoint_path=checkpoint_path,
                                    predict_keys=predict_keys)
    return predictions 
Example #21
Source File: dual_net.py    From training_results_v0.5 with Apache License 2.0
def get_estimator(working_dir, **hparams):
    hparams = get_default_hyperparams(**hparams)
    return tf.estimator.Estimator(
        model_fn,
        model_dir=working_dir,
        params=hparams) 
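Usage is a thin wrapper around the Estimator constructor; keyword overrides flow through get_default_hyperparams into params and on to the model_fn. A sketch (train_batch_size is an assumed hyperparameter name):

# Sketch only.
estimator = get_estimator('./work_dir', train_batch_size=16)
# estimator.train(input_fn=..., steps=...)  # the usual Estimator loop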
Example #22
Source File: dual_net.py    From training_results_v0.5 with Apache License 2.0
def bootstrap(working_dir, **hparams):
    """Initialize a tf.Estimator run with random initial weights.

    Args:
        working_dir: The directory where tf.estimator will drop logs,
            checkpoints, and so on
        hparams: hyperparams of the model.
    """
    hparams = get_default_hyperparams(**hparams)
    # a bit hacky - forge an initial checkpoint with the name that subsequent
    # Estimator runs will expect to find.
    #
    # Estimator will do this automatically when you call train(), but calling
    # train() requires data, and I didn't feel like creating training data in
    # order to run the full train pipeline for 1 step.
    estimator_initial_checkpoint_name = 'model.ckpt-1'
    save_file = os.path.join(working_dir, estimator_initial_checkpoint_name)
    sess = tf.Session(graph=tf.Graph())
    with sess.graph.as_default():
        features, labels = get_inference_input()
        model_fn(features, labels, tf.estimator.ModeKeys.PREDICT, hparams)
        sess.run(tf.global_variables_initializer())
        tf.train.Saver().save(sess, save_file)

    with open("./minigo.pbtxt", "w") as f:
        f.write(str(sess.graph.as_graph_def())) 
Example #23
Source File: dual_net.py    From training_results_v0.5 with Apache License 2.0
def get_estimator(working_dir, **hparams):
    hparams = get_default_hyperparams(**hparams)
    return tf.estimator.Estimator(
        model_fn,
        model_dir=working_dir,
        params=hparams) 
Example #24
Source File: translate.py    From models with Apache License 2.0
def main(unused_argv):
  from official.transformer import transformer_main

  tf.logging.set_verbosity(tf.logging.INFO)

  if FLAGS.text is None and FLAGS.file is None:
    tf.logging.warn("Nothing to translate. Make sure to call this script using "
                    "flags --text or --file.")
    return

  subtokenizer = tokenizer.Subtokenizer(FLAGS.vocab_file)

  # Set up estimator and params
  params = transformer_main.PARAMS_MAP[FLAGS.param_set]
  params["beam_size"] = _BEAM_SIZE
  params["alpha"] = _ALPHA
  params["extra_decode_length"] = _EXTRA_DECODE_LENGTH
  params["batch_size"] = _DECODE_BATCH_SIZE
  estimator = tf.estimator.Estimator(
      model_fn=transformer_main.model_fn, model_dir=FLAGS.model_dir,
      params=params)

  if FLAGS.text is not None:
    tf.logging.info("Translating text: %s" % FLAGS.text)
    translate_text(estimator, subtokenizer, FLAGS.text)

  if FLAGS.file is not None:
    input_file = os.path.abspath(FLAGS.file)
    tf.logging.info("Translating file: %s" % input_file)
    if not tf.gfile.Exists(FLAGS.file):
      raise ValueError("File does not exist: %s" % input_file)

    output_file = None
    if FLAGS.file_out is not None:
      output_file = os.path.abspath(FLAGS.file_out)
      tf.logging.info("File output specified: %s" % output_file)

    translate_file(estimator, subtokenizer, input_file, output_file) 
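For context, translate_text ultimately reduces to a one-shot estimator.predict call. A minimal hedged sketch, assuming the model_fn emits decoded token ids under an 'outputs' key (the key name and the EOS handling are assumptions):

def translate_text(estimator, subtokenizer, txt):
    """Sketch only: encode one string, predict, decode the result."""
    encoded = subtokenizer.encode(txt, add_eos=True)

    def input_fn():
        ds = tf.data.Dataset.from_tensor_slices([encoded])
        ds = ds.batch(_DECODE_BATCH_SIZE)
        return ds.make_one_shot_iterator().get_next()

    for prediction in estimator.predict(input_fn):
        translation = subtokenizer.decode(prediction['outputs'])
        tf.logging.info('Translation: "%s"' % translation)
        break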
Example #25
Source File: premade_lib.py    From lattice with Apache License 2.0
def _get_lattice_weights(prefitting_model, lattice_index):
  """Gets the weights of the lattice at the specfied index."""
  if isinstance(prefitting_model, tf.keras.Model):
    lattice_layer_name = '{}_{}'.format(LATTICE_LAYER_NAME, lattice_index)
    weights = tf.keras.backend.get_value(
        prefitting_model.get_layer(lattice_layer_name).weights[0])
  else:
    # We have already checked the types by this point, so if prefitting_model
    # is not a keras Model it must be an Estimator.
    lattice_kernel_variable_name = '{}_{}/{}'.format(
        LATTICE_LAYER_NAME, lattice_index, lattice_layer.LATTICE_KERNEL_NAME)
    weights = prefitting_model.get_variable_value(lattice_kernel_variable_name)
  return weights 
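For the Estimator branch, the variable name is simply the lattice layer name joined with the kernel name. Assuming the module constants are 'lattice' and 'lattice_kernel' (hypothetical values for illustration), index 2 resolves as:

# Hypothetical constant values; the real ones live in premade_lib and
# lattice_layer.
LATTICE_LAYER_NAME = 'lattice'
LATTICE_KERNEL_NAME = 'lattice_kernel'
name = '{}_{}/{}'.format(LATTICE_LAYER_NAME, 2, LATTICE_KERNEL_NAME)
assert name == 'lattice_2/lattice_kernel'
# weights = prefitting_model.get_variable_value(name)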
Example #26
Source File: translate.py    From models with Apache License 2.0
def main(unused_argv):
  from official.transformer import transformer_main

  tf.logging.set_verbosity(tf.logging.INFO)

  if FLAGS.text is None and FLAGS.file is None:
    tf.logging.warn("Nothing to translate. Make sure to call this script using "
                    "flags --text or --file.")
    return

  subtokenizer = tokenizer.Subtokenizer(FLAGS.vocab_file)

  # Set up estimator and params
  params = transformer_main.PARAMS_MAP[FLAGS.param_set]
  params["beam_size"] = _BEAM_SIZE
  params["alpha"] = _ALPHA
  params["extra_decode_length"] = _EXTRA_DECODE_LENGTH
  params["batch_size"] = _DECODE_BATCH_SIZE
  estimator = tf.estimator.Estimator(
      model_fn=transformer_main.model_fn, model_dir=FLAGS.model_dir,
      params=params)

  if FLAGS.text is not None:
    tf.logging.info("Translating text: %s" % FLAGS.text)
    translate_text(estimator, subtokenizer, FLAGS.text)

  if FLAGS.file is not None:
    input_file = os.path.abspath(FLAGS.file)
    tf.logging.info("Translating file: %s" % input_file)
    if not tf.gfile.Exists(FLAGS.file):
      raise ValueError("File does not exist: %s" % input_file)

    output_file = None
    if FLAGS.file_out is not None:
      output_file = os.path.abspath(FLAGS.file_out)
      tf.logging.info("File output specified: %s" % output_file)

    translate_file(estimator, subtokenizer, input_file, output_file) 
Example #27
Source File: translate.py    From models with Apache License 2.0
def main(unused_argv):
  from official.transformer import transformer_main

  tf.logging.set_verbosity(tf.logging.INFO)

  if FLAGS.text is None and FLAGS.file is None:
    tf.logging.warn("Nothing to translate. Make sure to call this script using "
                    "flags --text or --file.")
    return

  subtokenizer = tokenizer.Subtokenizer(FLAGS.vocab_file)

  # Set up estimator and params
  params = transformer_main.PARAMS_MAP[FLAGS.param_set]
  params["beam_size"] = _BEAM_SIZE
  params["alpha"] = _ALPHA
  params["extra_decode_length"] = _EXTRA_DECODE_LENGTH
  params["batch_size"] = _DECODE_BATCH_SIZE
  estimator = tf.estimator.Estimator(
      model_fn=transformer_main.model_fn, model_dir=FLAGS.model_dir,
      params=params)

  if FLAGS.text is not None:
    tf.logging.info("Translating text: %s" % FLAGS.text)
    translate_text(estimator, subtokenizer, FLAGS.text)

  if FLAGS.file is not None:
    input_file = os.path.abspath(FLAGS.file)
    tf.logging.info("Translating file: %s" % input_file)
    if not tf.gfile.Exists(FLAGS.file):
      raise ValueError("File does not exist: %s" % input_file)

    output_file = None
    if FLAGS.file_out is not None:
      output_file = os.path.abspath(FLAGS.file_out)
      tf.logging.info("File output specified: %s" % output_file)

    translate_file(estimator, subtokenizer, input_file, output_file) 
Example #28
Source File: translate.py    From Live-feed-object-device-identification-using-Tensorflow-and-OpenCV with Apache License 2.0
def main(unused_argv):
  from official.transformer import transformer_main

  tf.logging.set_verbosity(tf.logging.INFO)

  if FLAGS.text is None and FLAGS.file is None:
    tf.logging.warn("Nothing to translate. Make sure to call this script using "
                    "flags --text or --file.")
    return

  subtokenizer = tokenizer.Subtokenizer(FLAGS.vocab_file)

  # Set up estimator and params
  params = transformer_main.PARAMS_MAP[FLAGS.param_set]
  params["beam_size"] = _BEAM_SIZE
  params["alpha"] = _ALPHA
  params["extra_decode_length"] = _EXTRA_DECODE_LENGTH
  params["batch_size"] = _DECODE_BATCH_SIZE
  estimator = tf.estimator.Estimator(
      model_fn=transformer_main.model_fn, model_dir=FLAGS.model_dir,
      params=params)

  if FLAGS.text is not None:
    tf.logging.info("Translating text: %s" % FLAGS.text)
    translate_text(estimator, subtokenizer, FLAGS.text)

  if FLAGS.file is not None:
    input_file = os.path.abspath(FLAGS.file)
    tf.logging.info("Translating file: %s" % input_file)
    if not tf.gfile.Exists(FLAGS.file):
      raise ValueError("File does not exist: %s" % input_file)

    output_file = None
    if FLAGS.file_out is not None:
      output_file = os.path.abspath(FLAGS.file_out)
      tf.logging.info("File output specified: %s" % output_file)

    translate_file(estimator, subtokenizer, input_file, output_file) 
Example #29
Source File: dual_net.py    From training with Apache License 2.0
def _get_nontpu_estimator():
    session_config = tf.ConfigProto()
    session_config.gpu_options.allow_growth = True
    run_config = tf.estimator.RunConfig(
        save_summary_steps=FLAGS.summary_steps,
        keep_checkpoint_max=FLAGS.keep_checkpoint_max,
        session_config=session_config)
    return tf.estimator.Estimator(
        model_fn,
        model_dir=FLAGS.work_dir,
        config=run_config,
        params=FLAGS.flag_values_dict()) 
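allow_growth keeps TensorFlow from reserving all GPU memory at startup, which matters when selfplay and training share a GPU. A hedged usage sketch (train_input_fn and FLAGS.train_steps are assumptions):

# Sketch only.
estimator = _get_nontpu_estimator()
# estimator.train(input_fn=train_input_fn, steps=FLAGS.train_steps)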
Example #30
Source File: translate.py    From g-tensorflow-models with Apache License 2.0
def main(unused_argv):
  from official.transformer import transformer_main

  tf.logging.set_verbosity(tf.logging.INFO)

  if FLAGS.text is None and FLAGS.file is None:
    tf.logging.warn("Nothing to translate. Make sure to call this script using "
                    "flags --text or --file.")
    return

  subtokenizer = tokenizer.Subtokenizer(FLAGS.vocab_file)

  # Set up estimator and params
  params = transformer_main.PARAMS_MAP[FLAGS.param_set]
  params["beam_size"] = _BEAM_SIZE
  params["alpha"] = _ALPHA
  params["extra_decode_length"] = _EXTRA_DECODE_LENGTH
  params["batch_size"] = _DECODE_BATCH_SIZE
  estimator = tf.estimator.Estimator(
      model_fn=transformer_main.model_fn, model_dir=FLAGS.model_dir,
      params=params)

  if FLAGS.text is not None:
    tf.logging.info("Translating text: %s" % FLAGS.text)
    translate_text(estimator, subtokenizer, FLAGS.text)

  if FLAGS.file is not None:
    input_file = os.path.abspath(FLAGS.file)
    tf.logging.info("Translating file: %s" % input_file)
    if not tf.gfile.Exists(FLAGS.file):
      raise ValueError("File does not exist: %s" % input_file)

    output_file = None
    if FLAGS.file_out is not None:
      output_file = os.path.abspath(FLAGS.file_out)
      tf.logging.info("File output specified: %s" % output_file)

    translate_file(estimator, subtokenizer, input_file, output_file)