Python tensorflow.train Examples

The following are 30 code examples of the tensorflow.train module (TensorFlow 1.x). Note that tf.train is a module of optimizers, checkpointing, and input-pipeline utilities, not a callable. The source file, originating project, and license are listed above each example. You may also want to check out all available functions and classes of the tensorflow module.
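Before the examples, here is a minimal, self-contained sketch (not taken from any of the projects below) of the TF1-style workflow they all build on: an optimizer picked from tf.train, a decayed learning rate tied to a global step, and a tf.train.Saver writing checkpoints. All names, paths, and values are illustrative.

import tensorflow as tf  # TensorFlow 1.x APIs

# Toy linear model; every name here is illustrative.
x = tf.placeholder(tf.float32, shape=[None, 1])
y = tf.placeholder(tf.float32, shape=[None, 1])
w = tf.Variable(tf.zeros([1, 1]))
b = tf.Variable(tf.zeros([1]))
loss = tf.reduce_mean(tf.square(tf.matmul(x, w) + b - y))

global_step = tf.train.get_or_create_global_step()
learning_rate = tf.train.exponential_decay(
    0.1, global_step, decay_steps=1000, decay_rate=0.96, staircase=True)
train_op = tf.train.GradientDescentOptimizer(learning_rate).minimize(
    loss, global_step=global_step)

saver = tf.train.Saver(max_to_keep=3)
with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  sess.run(train_op, feed_dict={x: [[1.0]], y: [[2.0]]})  # one toy step
  saver.save(sess, "/tmp/model.ckpt", global_step=global_step)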
Example #1
Source File: train_autoencoder.py    From youtube-8m with Apache License 2.0
def __init__(self, cluster, task, train_dir, log_device_placement=True):
    """"Creates a Trainer.

    Args:
      cluster: A tf.train.ClusterSpec if the execution is distributed.
        None otherwise.
      task: A TaskSpec describing the job type and the task index.
    """

    self.cluster = cluster
    self.task = task
    self.is_master = (task.type == "master" and task.index == 0)
    self.train_dir = train_dir
    self.config = tf.ConfigProto(log_device_placement=log_device_placement)

    if self.task.type == "master" and self.task.index > 0:
      raise ValueError("%s: Only one replica of master expected." %
                       task_as_string(self.task))
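For context, a sketch of how this Trainer might be instantiated for a single-machine run. The real repository builds the task description by parsing the TF_CONFIG environment variable; the namedtuple below is an assumption that merely satisfies the fields the constructor reads.

import collections

TaskSpec = collections.namedtuple("TaskSpec", ["type", "index"])

trainer = Trainer(cluster=None,
                  task=TaskSpec(type="master", index=0),
                  train_dir="/tmp/yt8m_train",
                  log_device_placement=False)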
Example #2
Source File: train.py    From Youtube-8M-WILLOW with Apache License 2.0
def get_meta_filename(self, start_new_model, train_dir):
    if start_new_model:
      logging.info("%s: Flag 'start_new_model' is set. Building a new model.",
                   task_as_string(self.task))
      return None
    
    latest_checkpoint = tf.train.latest_checkpoint(train_dir)
    if not latest_checkpoint: 
      logging.info("%s: No checkpoint file found. Building a new model.",
                   task_as_string(self.task))
      return None
    
    meta_filename = latest_checkpoint + ".meta"
    if not gfile.Exists(meta_filename):
      logging.info("%s: No meta graph file found. Building a new model.",
                   task_as_string(self.task))
      return None
    else:
      return meta_filename 
Example #3
Source File: train.py    From Youtube-8M-WILLOW with Apache License 2.0
def start_server_if_distributed(self):
    """Starts a server if the execution is distributed."""

    if self.cluster:
      logging.info("%s: Starting trainer within cluster %s.",
                   task_as_string(self.task), self.cluster.as_dict())
      server = start_server(self.cluster, self.task)
      target = server.target
      device_fn = tf.train.replica_device_setter(
          ps_device="/job:ps",
          worker_device="/job:%s/task:%d" % (self.task.type, self.task.index),
          cluster=self.cluster)
    else:
      target = ""
      device_fn = ""
    return (target, device_fn) 
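A sketch of how the returned pair is typically consumed (assuming trainer is an instance of the trainer class these methods belong to): device_fn routes variables to the parameter servers while the graph is built, and target selects between an in-process session and the gRPC server.

target, device_fn = trainer.start_server_if_distributed()
with tf.Graph().as_default():
  with tf.device(device_fn):
    w = tf.Variable(tf.zeros([10]))  # placed on /job:ps when distributed
    loss = tf.reduce_sum(tf.square(w - 1.0))
    train_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss)
  with tf.Session(target, config=trainer.config) as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(train_op)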
Example #4
Source File: write_tfrecords.py    From scGAN with MIT License
def close(self):
        """
        Closes the files associated to the TFRecordWriter objects.

        Returns
        -------

        """

        try:
            self.test.close()
        except Exception as e:
            pass

        try:
            self.valid.close()
        except Exception as e:
            pass

        for f in self.train:
            try:
                f.close()
            except Exception as e:
                pass 
Example #5
Source File: train.py    From Youtube-8M-WILLOW with Apache License 2.0
def build_model(self, model, reader):
    """Find the model and build the graph."""

    label_loss_fn = find_class_by_name(FLAGS.label_loss, [losses])()
    optimizer_class = find_class_by_name(FLAGS.optimizer, [tf.train])
  
    build_graph(reader=reader,
                 model=model,
                 optimizer_class=optimizer_class,
                 clip_gradient_norm=FLAGS.clip_gradient_norm,
                 train_data_pattern=FLAGS.train_data_pattern,
                 label_loss_fn=label_loss_fn,
                 base_learning_rate=FLAGS.base_learning_rate,
                 learning_rate_decay=FLAGS.learning_rate_decay,
                 learning_rate_decay_examples=FLAGS.learning_rate_decay_examples,
                 regularization_penalty=FLAGS.regularization_penalty,
                 num_readers=FLAGS.num_readers,
                 batch_size=FLAGS.batch_size,
                 num_epochs=FLAGS.num_epochs)
  
    return tf.train.Saver(max_to_keep=0, keep_checkpoint_every_n_hours=5) 
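The returned Saver is what the training loop later uses to write checkpoints. A minimal standalone sketch of that follow-up step (the variable, path, and step value are illustrative):

v = tf.Variable(0, name="dummy")  # a Saver requires at least one variable
saver = tf.train.Saver(max_to_keep=0, keep_checkpoint_every_n_hours=5)
with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  tf.gfile.MakeDirs("/tmp/yt8m_train")
  # Writes model.ckpt-1000.data/.index plus model.ckpt-1000.meta, the meta
  # graph file that get_meta_filename() in the other examples looks for.
  saver.save(sess, "/tmp/yt8m_train/model.ckpt", global_step=1000)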
Example #6
Source File: train.py    From youtube-8m with Apache License 2.0
def get_meta_filename(self, start_new_model, train_dir):
    if start_new_model:
      logging.info("%s: Flag 'start_new_model' is set. Building a new model.",
                   task_as_string(self.task))
      return None
    
    latest_checkpoint = tf.train.latest_checkpoint(train_dir)
    if not latest_checkpoint: 
      logging.info("%s: No checkpoint file found. Building a new model.",
                   task_as_string(self.task))
      return None
    
    meta_filename = latest_checkpoint + ".meta"
    if not gfile.Exists(meta_filename):
      logging.info("%s: No meta graph file found. Building a new model.",
                   task_as_string(self.task))
      return None
    else:
      return meta_filename 
Example #7
Source File: graph_builder.py    From DOTA_models with Apache License 2.0
def _create_learning_rate(hyperparams, step_var):
  """Creates learning rate var, with decay and switching for CompositeOptimizer.

  Args:
    hyperparams: a GridPoint proto containing optimizer spec, particularly
      learning_method to determine optimizer class to use.
    step_var: tf.Variable, global training step.

  Returns:
    a scalar `Tensor`, the learning rate based on current step and hyperparams.
  """
  if hyperparams.learning_method != 'composite':
    base_rate = hyperparams.learning_rate
  else:
    spec = hyperparams.composite_optimizer_spec
    switch = tf.less(step_var, spec.switch_after_steps)
    base_rate = tf.cond(switch, lambda: tf.constant(spec.method1.learning_rate),
                        lambda: tf.constant(spec.method2.learning_rate))
  return tf.train.exponential_decay(
      base_rate,
      step_var,
      hyperparams.decay_steps,
      hyperparams.decay_base,
      staircase=hyperparams.decay_staircase) 
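With staircase=True the decayed rate is base_rate * decay_base ** floor(step / decay_steps). A quick check with illustrative hyperparameter values:

base_rate, decay_steps, decay_base = 0.1, 1000, 0.96
step_var = tf.Variable(2500, trainable=False, dtype=tf.int64)
lr = tf.train.exponential_decay(base_rate, step_var, decay_steps, decay_base,
                                staircase=True)
with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  print(sess.run(lr))  # 0.1 * 0.96 ** floor(2500 / 1000) = 0.09216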
Example #8
Source File: train.py    From Youtube-8M-WILLOW with Apache License 2.0
def start_server(cluster, task):
  """Creates a Server.

  Args:
    cluster: A tf.train.ClusterSpec if the execution is distributed.
      None otherwise.
    task: A TaskSpec describing the job type and the task index.
  """

  if not task.type:
    raise ValueError("%s: The task type must be specified." %
                     task_as_string(task))
  if task.index is None:
    raise ValueError("%s: The task index must be specified." %
                     task_as_string(task))

  # Create and start a server.
  return tf.train.Server(
      tf.train.ClusterSpec(cluster),
      protocol="grpc",
      job_name=task.type,
      task_index=task.index) 
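tf.train.ClusterSpec accepts a plain dict mapping job names to address lists, so a minimal two-process cluster could be described as follows (addresses are illustrative, and TaskSpec is the namedtuple from the sketch after Example #1). Each process calls start_server with the same dict but its own task:

cluster = {
    "ps": ["localhost:2222"],
    "master": ["localhost:2223"],
}
# On the parameter-server process:
#   server = start_server(cluster, TaskSpec(type="ps", index=0))
#   server.join()  # block and serve variables
# On the master process:
#   server = start_server(cluster, TaskSpec(type="master", index=0))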
Example #9
Source File: train.py    From youtube-8m with Apache License 2.0
def __init__(self, cluster, task, train_dir, log_device_placement=True):
    """"Creates a Trainer.

    Args:
      cluster: A tf.train.ClusterSpec if the execution is distributed.
        None otherwise.
      task: A TaskSpec describing the job type and the task index.
    """

    self.cluster = cluster
    self.task = task
    self.is_master = (task.type == "master" and task.index == 0)
    self.train_dir = train_dir
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=FLAGS.gpu)
    self.config = tf.ConfigProto(log_device_placement=log_device_placement,
                                 gpu_options=gpu_options)

    if self.task.type == "master" and self.task.index > 0:
      raise ValueError("%s: Only one replica of master expected." %
                       task_as_string(self.task))
Example #10
Source File: train_autoencoder.py    From youtube-8m with Apache License 2.0
def start_server(cluster, task):
  """Creates a Server.

  Args:
    cluster: A tf.train.ClusterSpec if the execution is distributed.
      None otherwise.
    task: A TaskSpec describing the job type and the task index.
  """

  if not task.type:
    raise ValueError("%s: The task type must be specified." %
                     task_as_string(task))
  if task.index is None:
    raise ValueError("%s: The task index must be specified." %
                     task_as_string(task))

  # Create and start a server.
  return tf.train.Server(
      tf.train.ClusterSpec(cluster),
      protocol="grpc",
      job_name=task.type,
      task_index=task.index) 
Example #11
Source File: train.py    From youtube-8m with Apache License 2.0
def get_input_data_tensors(reader,
                           data_pattern,
                           batch_size=256,
                           num_epochs=None):
  logging.info("Using batch size of " + str(batch_size) + " for training.")
  with tf.name_scope("train_input"):
    files = gfile.Glob(data_pattern)
    if not files:
      raise IOError("Unable to find training files. data_pattern='" +
                    data_pattern + "'.")
    logging.info("Number of training files: %s.", str(len(files)))
    files.sort()
    filename_queue = tf.train.string_input_producer(
        files, num_epochs=num_epochs, shuffle=False)
    training_data = reader.prepare_reader(filename_queue)

    return tf.train.batch(
        training_data,
        batch_size=batch_size,
        capacity=batch_size * 4,
        allow_smaller_final_batch=True,
        enqueue_many=True)
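string_input_producer and tf.train.batch only produce data once queue runners are started, so callers of this function need the usual TF1 coordinator boilerplate. A sketch, assuming a reader object and the repository's flags:

examples = get_input_data_tensors(reader, FLAGS.train_data_pattern,
                                  batch_size=FLAGS.batch_size)
with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  sess.run(tf.local_variables_initializer())  # num_epochs uses a local counter
  coord = tf.train.Coordinator()
  threads = tf.train.start_queue_runners(sess=sess, coord=coord)
  try:
    while not coord.should_stop():
      batch = sess.run(examples)
      # ... one training step on batch ...
  except tf.errors.OutOfRangeError:
    pass  # input exhausted after num_epochs passes over the files
  finally:
    coord.request_stop()
    coord.join(threads)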
Example #12
Source File: train.py    From youtube-8m with Apache License 2.0
def __init__(self, cluster, task, train_dir, log_device_placement=True):
    """"Creates a Trainer.

    Args:
      cluster: A tf.train.ClusterSpec if the execution is distributed.
        None otherwise.
      task: A TaskSpec describing the job type and the task index.
    """

    self.cluster = cluster
    self.task = task
    self.is_master = (task.type == "master" and task.index == 0)
    self.train_dir = train_dir
    self.config = tf.ConfigProto(log_device_placement=log_device_placement)

    if self.task.type == "master" and self.task.index > 0:
      raise ValueError("%s: Only one replica of master expected." %
                       task_as_string(self.task))
Example #13
Source File: train.py    From youtube-8m with Apache License 2.0
def start_server_if_distributed(self):
    """Starts a server if the execution is distributed."""

    if self.cluster:
      logging.info("%s: Starting trainer within cluster %s.",
                   task_as_string(self.task), self.cluster.as_dict())
      server = start_server(self.cluster, self.task)
      target = server.target
      device_fn = tf.train.replica_device_setter(
          ps_device="/job:ps",
          worker_device="/job:%s/task:%d" % (self.task.type, self.task.index),
          cluster=self.cluster)
    else:
      target = ""
      device_fn = ""
    return (target, device_fn) 
Example #14
Source File: train_ensemble.py    From youtube-8m with Apache License 2.0
def start_server_if_distributed(self):
    """Starts a server if the execution is distributed."""

    if self.cluster:
      logging.info("%s: Starting trainer within cluster %s.",
                   task_as_string(self.task), self.cluster.as_dict())
      server = start_server(self.cluster, self.task)
      target = server.target
      device_fn = tf.train.replica_device_setter(
          ps_device="/job:ps",
          worker_device="/job:%s/task:%d" % (self.task.type, self.task.index),
          cluster=self.cluster)
    else:
      target = ""
      device_fn = ""
    return (target, device_fn) 
Example #15
Source File: train.py    From youtube-8m with Apache License 2.0
def get_meta_filename(self, start_new_model, train_dir):
    if start_new_model:
      logging.info("%s: Flag 'start_new_model' is set. Building a new model.",
                   task_as_string(self.task))
      return None
    
    latest_checkpoint = tf.train.latest_checkpoint(train_dir)
    if not latest_checkpoint: 
      logging.info("%s: No checkpoint file found. Building a new model.",
                   task_as_string(self.task))
      return None
    
    meta_filename = latest_checkpoint + ".meta"
    if not gfile.Exists(meta_filename):
      logging.info("%s: No meta graph file found. Building a new model.",
                   task_as_string(self.task))
      return None
    else:
      return meta_filename 
Example #16
Source File: train_embedding.py    From youtube-8m with Apache License 2.0
def start_server(cluster, task):
  """Creates a Server.

  Args:
    cluster: A tf.train.ClusterSpec if the execution is distributed.
      None otherwise.
    task: A TaskSpec describing the job type and the task index.
  """

  if not task.type:
    raise ValueError("%s: The task type must be specified." %
                     task_as_string(task))
  if task.index is None:
    raise ValueError("%s: The task index must be specified." %
                     task_as_string(task))

  # Create and start a server.
  return tf.train.Server(
      tf.train.ClusterSpec(cluster),
      protocol="grpc",
      job_name=task.type,
      task_index=task.index) 
Example #17
Source File: train_embedding.py    From youtube-8m with Apache License 2.0
def start_server_if_distributed(self):
    """Starts a server if the execution is distributed."""

    if self.cluster:
      logging.info("%s: Starting trainer within cluster %s.",
                   task_as_string(self.task), self.cluster.as_dict())
      server = start_server(self.cluster, self.task)
      target = server.target
      device_fn = tf.train.replica_device_setter(
          ps_device="/job:ps",
          worker_device="/job:%s/task:%d" % (self.task.type, self.task.index),
          cluster=self.cluster)
    else:
      target = ""
      device_fn = ""
    return (target, device_fn) 
Example #18
Source File: train_embedding.py    From youtube-8m with Apache License 2.0
def __init__(self, cluster, task, train_dir, log_device_placement=True):
    """"Creates a Trainer.

    Args:
      cluster: A tf.train.ClusterSpec if the execution is distributed.
        None otherwise.
      task: A TaskSpec describing the job type and the task index.
    """

    self.cluster = cluster
    self.task = task
    self.is_master = (task.type == "master" and task.index == 0)
    self.train_dir = train_dir
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.2)
    self.config = tf.ConfigProto(log_device_placement=log_device_placement,
                                 gpu_options=gpu_options)

    if self.task.type == "master" and self.task.index > 0:
      raise ValueError("%s: Only one replica of master expected." %
                       task_as_string(self.task))
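per_process_gpu_memory_fraction=0.2 caps this process at roughly 20% of each visible GPU's memory up front. When the footprint is unknown, allow_growth is the usual alternative; a sketch:

gpu_options = tf.GPUOptions(allow_growth=True)  # allocate on demand instead
config = tf.ConfigProto(log_device_placement=False, gpu_options=gpu_options)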
Example #19
Source File: train_autoencoder.py    From youtube-8m with Apache License 2.0
def start_server_if_distributed(self):
    """Starts a server if the execution is distributed."""

    if self.cluster:
      logging.info("%s: Starting trainer within cluster %s.",
                   task_as_string(self.task), self.cluster.as_dict())
      server = start_server(self.cluster, self.task)
      target = server.target
      device_fn = tf.train.replica_device_setter(
          ps_device="/job:ps",
          worker_device="/job:%s/task:%d" % (self.task.type, self.task.index),
          cluster=self.cluster)
    else:
      target = ""
      device_fn = ""
    return (target, device_fn) 
Example #20
Source File: train-with-rebuild.py    From youtube-8m with Apache License 2.0
def get_meta_filename(self, start_new_model, train_dir):
        if start_new_model:
            logging.info("%s: Flag 'start_new_model' is set. Building a new model.",
                         task_as_string(self.task))
            return None

        latest_checkpoint = tf.train.latest_checkpoint(train_dir)
        if not latest_checkpoint:
            logging.info("%s: No checkpoint file found. Building a new model.",
                         task_as_string(self.task))
            return None

        meta_filename = latest_checkpoint + ".meta"
        if not gfile.Exists(meta_filename):
            logging.info("%s: No meta graph file found. Building a new model.",
                         task_as_string(self.task))
            return None
        else:
            return meta_filename 
Example #21
Source File: train-with-rebuild.py    From youtube-8m with Apache License 2.0
def start_server_if_distributed(self):
        """Starts a server if the execution is distributed."""

        if self.cluster:
            logging.info("%s: Starting trainer within cluster %s.",
                         task_as_string(self.task), self.cluster.as_dict())
            server = start_server(self.cluster, self.task)
            target = server.target
            device_fn = tf.train.replica_device_setter(
                ps_device="/job:ps",
                worker_device="/job:%s/task:%d" % (self.task.type, self.task.index),
                cluster=self.cluster)
        else:
            target = ""
            device_fn = ""
        return (target, device_fn) 
Example #22
Source File: train-with-rebuild.py    From youtube-8m with Apache License 2.0
def __init__(self, cluster, task, train_dir, log_device_placement=True):
        """"Creates a Trainer.

        Args:
          cluster: A tf.train.ClusterSpec if the execution is distributed.
            None otherwise.
          task: A TaskSpec describing the job type and the task index.
        """

        self.cluster = cluster
        self.task = task
        self.is_master = (task.type == "master" and task.index == 0)
        self.train_dir = train_dir
        self.config = tf.ConfigProto(log_device_placement=log_device_placement)

        if self.task.type == "master" and self.task.index > 0:
            raise ValueError("%s: Only one replica of master expected." %
                             task_as_string(self.task))
Example #23
Source File: train_autoencoder.py    From youtube-8m with Apache License 2.0
def get_meta_filename(self, start_new_model, train_dir):
    if start_new_model:
      logging.info("%s: Flag 'start_new_model' is set. Building a new model.",
                   task_as_string(self.task))
      return None
    
    latest_checkpoint = tf.train.latest_checkpoint(train_dir)
    if not latest_checkpoint: 
      logging.info("%s: No checkpoint file found. Building a new model.",
                   task_as_string(self.task))
      return None
    
    meta_filename = latest_checkpoint + ".meta"
    if not gfile.Exists(meta_filename):
      logging.info("%s: No meta graph file found. Building a new model.",
                   task_as_string(self.task))
      return None
    else:
      return meta_filename 
Example #24
Source File: train_ensemble.py    From youtube-8m with Apache License 2.0
def start_server(cluster, task):
  """Creates a Server.

  Args:
    cluster: A tf.train.ClusterSpec if the execution is distributed.
      None otherwise.
    task: A TaskSpec describing the job type and the task index.
  """

  if not task.type:
    raise ValueError("%s: The task type must be specified." %
                     task_as_string(task))
  if task.index is None:
    raise ValueError("%s: The task index must be specified." %
                     task_as_string(task))

  # Create and start a server.
  return tf.train.Server(
      tf.train.ClusterSpec(cluster),
      protocol="grpc",
      job_name=task.type,
      task_index=task.index) 
Example #25
Source File: train_ensemble.py    From youtube-8m with Apache License 2.0
def get_meta_filename(self, start_new_model, train_dir):
    if start_new_model:
      logging.info("%s: Flag 'start_new_model' is set. Building a new model.",
                   task_as_string(self.task))
      return None
    
    latest_checkpoint = tf.train.latest_checkpoint(train_dir)
    if not latest_checkpoint: 
      logging.info("%s: No checkpoint file found. Building a new model.",
                   task_as_string(self.task))
      return None
    
    meta_filename = latest_checkpoint + ".meta"
    if not gfile.Exists(meta_filename):
      logging.info("%s: No meta graph file found. Building a new model.",
                   task_as_string(self.task))
      return None
    else:
      return meta_filename 
Example #26
Source File: train.py    From youtube-8m with Apache License 2.0
def start_server(cluster, task):
  """Creates a Server.

  Args:
    cluster: A tf.train.ClusterSpec if the execution is distributed.
      None otherwise.
    task: A TaskSpec describing the job type and the task index.
  """

  if not task.type:
    raise ValueError("%s: The task type must be specified." %
                     task_as_string(task))
  if task.index is None:
    raise ValueError("%s: The task index must be specified." %
                     task_as_string(task))

  # Create and start a server.
  return tf.train.Server(
      tf.train.ClusterSpec(cluster),
      protocol="grpc",
      job_name=task.type,
      task_index=task.index) 
Example #27
Source File: train_autoencoder.py    From youtube-8m with Apache License 2.0
def remove_training_directory(self, train_dir):
    """Removes the training directory."""
    try:
      logging.info(
          "%s: Removing existing train directory.",
          task_as_string(self.task))
      gfile.DeleteRecursively(train_dir)
    except Exception:
      logging.error(
          "%s: Failed to delete directory %s when starting a new model. "
          "Please delete it manually and try again.",
          task_as_string(self.task), train_dir)
Example #28
Source File: train-with-rebuild.py    From youtube-8m with Apache License 2.0
def recover_model(self, meta_filename):
        logging.info("%s: Restoring from meta graph file %s",
                     task_as_string(self.task), meta_filename)
        return tf.train.import_meta_graph(meta_filename) 
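import_meta_graph only rebuilds the graph structure and returns a Saver; the variable values still have to be restored from the matching checkpoint. A sketch with an illustrative train directory:

train_dir = "/tmp/yt8m_train"
latest = tf.train.latest_checkpoint(train_dir)  # e.g. ".../model.ckpt-1000"
saver = tf.train.import_meta_graph(latest + ".meta")  # rebuilds the graph
with tf.Session() as sess:
  saver.restore(sess, latest)  # loads the variable values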
Example #29
Source File: train-with-rebuild.py    From youtube-8m with Apache License 2.0
def get_latest_checkpoint(self, start_new_model, train_dir):
        if start_new_model:
            logging.info("%s: Flag 'start_new_model' is set. Building a new model.",
                         task_as_string(self.task))
            return None

        latest_checkpoint = tf.train.latest_checkpoint(train_dir)
        if not latest_checkpoint:
            logging.info("%s: No checkpoint file found. Building a new model.",
                         task_as_string(self.task))
            return None

        return latest_checkpoint 
Example #30
Source File: train.py    From soccer-matlab with BSD 2-Clause "Simplified" License
def main(_):
  """Create or load configuration and launch the trainer."""
  utility.set_up_logging()
  if not FLAGS.config:
    raise KeyError('You must specify a configuration.')
  logdir = FLAGS.logdir and os.path.expanduser(os.path.join(
      FLAGS.logdir, '{}-{}'.format(FLAGS.timestamp, FLAGS.config)))
  try:
    config = utility.load_config(logdir)
  except IOError:
    config = tools.AttrDict(getattr(configs, FLAGS.config)())
    config = utility.save_config(config, logdir)
  for score in train(config, FLAGS.env_processes):
    tf.logging.info('Score {}.'.format(score))