Python tensorflow.keras.backend.set_session() Examples

The following are 10 code examples of tensorflow.keras.backend.set_session(), drawn from open-source projects. The original project and source file are listed above each example. You may also want to check out all available functions/classes of the module tensorflow.keras.backend.
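set_session() registers a TensorFlow session as the one the Keras backend uses, so that model construction, training, and prediction all run against the same session and graph, which is the pattern most of the examples below rely on. A minimal sketch of typical TensorFlow 1.x-style usage; the ConfigProto options and the tiny model are illustrative placeholders, not taken from any of the projects below:

import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.layers import Dense, Input
from tensorflow.keras.models import Model

# Create a session with explicit options and register it with Keras.
config = tf.ConfigProto(allow_soft_placement=True)
sess = tf.Session(config=config)
K.set_session(sess)

# Any model built afterwards runs inside the registered session.
inp = Input((10,))
model = Model(inp, Dense(1)(inp))
model.compile(optimizer='sgd', loss='mse')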
Example #1
Source File: networks.py    From rltrader with MIT License
def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # graph and sess are module-level globals shared by all networks;
        # registering sess keeps every Keras call inside the same session.
        with graph.as_default():
            if sess is not None:
                set_session(sess)
            inp = None
            output = None
            if self.shared_network is None:
                inp = Input((self.input_dim,))
                output = self.get_network_head(inp).output
            else:
                inp = self.shared_network.input
                output = self.shared_network.output
            output = Dense(
                self.output_dim, activation=self.activation, 
                kernel_initializer='random_normal')(output)
            self.model = Model(inp, output)
            self.model.compile(
                optimizer=SGD(lr=self.lr), loss=self.loss) 
Example #2
Source File: networks.py    From rltrader with MIT License
def __init__(self, *args, num_steps=1, **kwargs):
        super().__init__(*args, **kwargs)
        with graph.as_default():
            if sess is not None:
                set_session(sess)
            self.num_steps = num_steps
            inp = None
            output = None
            if self.shared_network is None:
                inp = Input((self.num_steps, self.input_dim, 1))
                output = self.get_network_head(inp).output
            else:
                inp = self.shared_network.input
                output = self.shared_network.output
            output = Dense(
                self.output_dim, activation=self.activation,
                kernel_initializer='random_normal')(output)
            self.model = Model(inp, output)
            self.model.compile(
                optimizer=SGD(lr=self.lr), loss=self.loss) 
Example #3
Source File: main.py    From stacks-usecase with Apache License 2.0
def cpu_config(first=False):
    # intel optimizations
    num_cores, num_sockets = get_cpuinfo()
    if first:
        print("system info::")
        print("Number of physical cores:: ", num_cores)
        print("Number of sockets::", num_sockets)
    backend.set_session(
        tf.Session(
            config=tf.ConfigProto(
                intra_op_parallelism_threads=num_cores,
                inter_op_parallelism_threads=num_sockets,
            )
        )
    )
###########################################################
# Training
########################################################### 
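Under TensorFlow 2.x the same thread configuration can still be applied, but ConfigProto, Session, and set_session have moved to the tf.compat.v1 namespace. A minimal sketch assuming TF 2.x; cpu_config_tf2 and its hard-coded defaults are illustrative, not part of the stacks-usecase project:

import tensorflow as tf

def cpu_config_tf2(num_cores=4, num_sockets=1):
    # Same intra-op / inter-op thread mapping as cpu_config above,
    # written against the v1 compatibility API shipped with TF 2.x.
    config = tf.compat.v1.ConfigProto(
        intra_op_parallelism_threads=num_cores,
        inter_op_parallelism_threads=num_sockets,
    )
    tf.compat.v1.keras.backend.set_session(tf.compat.v1.Session(config=config))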
Example #4
Source File: networks.py    From rltrader with MIT License
def set_session(sess): pass 
Example #5
Source File: networks.py    From rltrader with MIT License
def predict(self, sample):
        with self.lock:
            with graph.as_default():
                if sess is not None:
                    set_session(sess)
                return self.model.predict(sample).flatten() 
Example #6
Source File: networks.py    From rltrader with MIT License
def train_on_batch(self, x, y):
        loss = 0.
        with self.lock:
            with graph.as_default():
                if sess is not None:
                    set_session(sess)
                loss = self.model.train_on_batch(x, y)
        return loss 
Example #7
Source File: networks.py    From rltrader with MIT License
def get_shared_network(cls, net='dnn', num_steps=1, input_dim=0):
        with graph.as_default():
            if sess is not None:
                set_session(sess)
            if net == 'dnn':
                return DNN.get_network_head(Input((input_dim,)))
            elif net == 'lstm':
                return LSTMNetwork.get_network_head(
                    Input((num_steps, input_dim)))
            elif net == 'cnn':
                return CNN.get_network_head(
                    Input((1, num_steps, input_dim))) 
Example #8
Source File: train.py    From rl with MIT License
def _run(FLAGS):
  hparams = init_hparams(FLAGS)
  init_random_seeds(hparams)

  for run in range(hparams.copies):
    log_start_of_run(FLAGS, hparams, run)

    with tf.Session() as sess:
      K.set_session(sess)
      agent, checkpoint = init_agent(sess, hparams)

      restored = checkpoint.restore()
      if not restored:
        sess.run(tf.global_variables_initializer())

      if not hparams.test_only:
        log_graph()

        agent.clone_weights()

        if hparams.num_workers == 1:
          train(0, agent, hparams, checkpoint)
        else:
          workers = [
              threading.Thread(
                  target=train, args=(worker_id, agent, hparams, checkpoint))
              for worker_id in range(hparams.num_workers)
          ]

          for worker in workers:
            worker.start()

          for worker in workers:
            worker.join()
      else:
        test(hparams, agent)

    hparams = init_hparams(FLAGS) 
Example #9
Source File: experiment_engine.py    From brainstorm with MIT License
def configure_gpus(gpus):
    # set gpu id and tf settings
    os.environ['CUDA_VISIBLE_DEVICES'] = ','.join([str(g) for g in gpus])

    config = tf.ConfigProto(allow_soft_placement=True)
    config.gpu_options.allow_growth = True

    K.set_session(tf.Session(config=config))


# loads a saved experiment using the saved parameters.
# runs all initialization steps so that we can use the models right away 
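In TensorFlow 2.x the equivalent per-process GPU setup (restricting visible devices and growing GPU memory on demand) no longer needs a session at all. A minimal sketch assuming TF 2.x; configure_gpus_tf2 is an illustrative name, not part of the brainstorm project:

import os
import tensorflow as tf

def configure_gpus_tf2(gpus):
    # Restrict CUDA to the requested device ids, as configure_gpus does above.
    os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(str(g) for g in gpus)
    # Allocate GPU memory on demand instead of reserving it all up front;
    # this must run before any GPU has been initialized.
    for gpu in tf.config.list_physical_devices('GPU'):
        tf.config.experimental.set_memory_growth(gpu, True)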
Example #10
Source File: train.py    From MultiPlanarUNet with MIT License
def run(project_dir, gpu_mon, logger, args):
    """
    Runs training of a model in a mpunet project directory.

    Args:
        project_dir: A path to a mpunet project
        gpu_mon: An initialized GPUMonitor object
        logger: A mpunet logging object
        args: argparse arguments
    """
    # Read in hyperparameters from YAML file
    from mpunet.hyperparameters import YAMLHParams
    hparams = YAMLHParams(project_dir + "/train_hparams.yaml", logger=logger)
    validate_hparams(hparams)

    # Wait for PID to terminate before continuing?
    if args.wait_for:
        from mpunet.utils import await_PIDs
        await_PIDs(args.wait_for)

    # Prepare sequence generators and potential model specific hparam changes
    train, val, hparams = get_data_sequences(project_dir=project_dir,
                                             hparams=hparams,
                                             logger=logger,
                                             args=args)

    # Set GPU visibility and create model with MirroredStrategy
    set_gpu(gpu_mon, args)
    import tensorflow as tf
    with tf.distribute.MirroredStrategy().scope():
        model = get_model(project_dir=project_dir, train_seq=train,
                          hparams=hparams, logger=logger, args=args)

        # Get trainer and compile model
        from mpunet.train import Trainer
        trainer = Trainer(model, logger=logger)
        trainer.compile_model(n_classes=hparams["build"].get("n_classes"),
                              reduction=tf.keras.losses.Reduction.NONE,
                              **hparams["fit"])

    # Debug mode?
    if args.debug:
        from tensorflow.python import debug as tfdbg
        from tensorflow.keras import backend as K
        K.set_session(tfdbg.LocalCLIDebugWrapperSession(K.get_session()))

    # Fit the model
    _ = trainer.fit(train=train, val=val,
                    train_im_per_epoch=args.train_images_per_epoch,
                    val_im_per_epoch=args.val_images_per_epoch,
                    hparams=hparams, no_im=args.no_images, **hparams["fit"])
    save_final_weights(model, project_dir, logger)