Python keras.backend.get_session() Examples

The following are 30 code examples of keras.backend.get_session(), drawn from open-source projects. You can go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module keras.backend, or try the search function.
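To set the stage, here is a minimal sketch of what keras.backend.get_session() gives you (assumes TensorFlow 1.x with the standalone keras package): it returns the tf.Session that Keras itself uses, so you can run arbitrary tensors in the same graph that holds your model's variables.

import tensorflow as tf
from keras import backend as K

x = tf.constant([1.0, 2.0, 3.0])
sess = K.get_session()               # the session Keras runs its ops in
print(sess.run(tf.reduce_sum(x)))    # 6.0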
Example #1
Source File: model_wrappers.py    From nips-2017-adversarial with MIT License    7 votes
def load_ckpt(ckpt_name, var_scope_name, scope, constructor, input_tensor, label_offset, load_weights, **kwargs):
    """
    Arguments
        ckpt_name       file name of the checkpoint
        var_scope_name  name of the variable scope
        scope           arg_scope
        constructor     constructor of the model
        input_tensor    tensor of the input image
        label_offset    whether the model has 1000 or 1001 classes; if 1001, class 0 is removed
        load_weights    whether to load weights
        kwargs
            is_training
            create_aux_logits
    """
    with slim.arg_scope(scope):
        logits, endpoints = constructor(
                input_tensor, num_classes=1000 + label_offset,
                scope=var_scope_name, **kwargs)
    if load_weights:
        init_fn = slim.assign_from_checkpoint_fn(
                ckpt_name, slim.get_model_variables(var_scope_name))
        init_fn(K.get_session())
    return logits, endpoints
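slim.assign_from_checkpoint_fn returns a closure that expects a live tf.Session, so passing K.get_session() loads the checkpoint into the same session Keras will use afterwards. A hedged sketch of the calling side (the slim model zoo import, checkpoint path, and images tensor are all hypothetical):

from nets import inception  # hypothetical TF-slim model zoo import

logits, endpoints = load_ckpt(
    ckpt_name='inception_v3.ckpt',
    var_scope_name='InceptionV3',
    scope=inception.inception_v3_arg_scope(),
    constructor=inception.inception_v3,
    input_tensor=images,  # float tensor of shape [batch, 299, 299, 3]
    label_offset=1,       # Inception checkpoints use 1001 classes
    load_weights=True,
    is_training=False)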
Example #2
Source File: tfrecord_model.py    From sample-cnn with MIT License    6 votes
def predict_tfrecord(self, x_batch):
    if self.uses_learning_phase and not isinstance(K.learning_phase(), int):
      ins = [0.]
    else:
      ins = []
    self._make_tfrecord_predict_function()

    try:
      sess = K.get_session()
      coord = tf.train.Coordinator()
      threads = tf.train.start_queue_runners(sess=sess, coord=coord)

      outputs = self.predict_function(ins)

    finally:
      # TODO: If you close the queue, you can't open it again..
      # if stop_queue_runners:
      #   coord.request_stop()
      #   coord.join(threads)
      pass

    if len(outputs) == 1:
      return outputs[0]
    return outputs 
Example #3
Source File: util.py    From deeplift with MIT License    6 votes
def compile_func(inputs, outputs):
    if not isinstance(inputs, list):
        print("Wrapping the inputs in a list...")
        inputs = [inputs]
    assert isinstance(inputs, list)
    def func_to_return(inp):
        if len(inp) > len(inputs) and len(inputs) == 1:
            print("Wrapping the inputs in a list...")
            inp = [inp]
        assert len(inp) == len(inputs), \
            ("length of provided list should be "
             + str(len(inputs)) + " for tensors " + str(inputs)
             + " but got input of length " + str(len(inp)))
        feed_dict = {}
        for input_tensor, input_val in zip(inputs, inp):
            feed_dict[input_tensor] = input_val
        sess = get_session()
        return sess.run(outputs, feed_dict=feed_dict)
    return func_to_return
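The returned closure mirrors a Keras-style predict function over raw TensorFlow tensors. A small usage sketch (placeholder shapes are illustrative):

import tensorflow as tf

inp = tf.placeholder(tf.float32, shape=(None, 4))
out = tf.reduce_mean(inp, axis=1)
f = compile_func([inp], out)
print(f([[[1.0, 2.0, 3.0, 4.0]]]))  # runs `out` in Keras' session -> [2.5]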
Example #4
Source File: image_classifier_tf.py    From aiexamples with Apache License 2.0    6 votes
def keras_to_tensorflow(keras_model, output_dir, model_name, out_prefix="output_", log_tensorboard=True):

    if not os.path.exists(output_dir):
        os.mkdir(output_dir)

    out_nodes = []

    for i in range(len(keras_model.outputs)):
        out_nodes.append(out_prefix + str(i + 1))
        tf.identity(keras_model.outputs[i], out_prefix + str(i + 1))

    # Freeze the graph once, after all output identity nodes exist.
    sess = K.get_session()

    init_graph = sess.graph.as_graph_def()

    main_graph = graph_util.convert_variables_to_constants(sess, init_graph, out_nodes)

    graph_io.write_graph(main_graph, output_dir, name=model_name, as_text=False)

    if log_tensorboard:
        import_pb_to_tensorboard.import_to_tensorboard(os.path.join(output_dir, model_name), output_dir)
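The tf.identity calls give the frozen graph predictable output node names (output_1, output_2, ...) regardless of the underlying Keras layer names. A usage sketch (paths assumed):

keras_to_tensorflow(model, output_dir='exported', model_name='model.pb')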
Example #5
Source File: test_shap.py    From AIX360 with Apache License 2.0    6 votes
def test_ShapGradientExplainer(self):

    #     model = VGG16(weights='imagenet', include_top=True)
    #     X, y = shap.datasets.imagenet50()
    #     to_explain = X[[39, 41]]
    #
    #     url = "https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json"
    #     fname = shap.datasets.cache(url)
    #     with open(fname) as f:
    #         class_names = json.load(f)
    #
    #     def map2layer(x, layer):
    #         feed_dict = dict(zip([model.layers[0].input], [preprocess_input(x.copy())]))
    #         return K.get_session().run(model.layers[layer].input, feed_dict)
    #
    #     e = GradientExplainer((model.layers[7].input, model.layers[-1].output),
    #                           map2layer(preprocess_input(X.copy()), 7))
    #     shap_values, indexes = e.explain_instance(map2layer(to_explain, 7), ranked_outputs=2)
    #
          print("Skipped Shap GradientExplainer") 
Example #6
Source File: mnist_dnn.py    From tensorflow_examples with Apache License 2.0    6 votes
def export_savedmodel(model):
  print("input: {}, output: {}".format(model.input, model.output))
  model_signature = tf.saved_model.signature_def_utils.predict_signature_def(
      inputs={'input': model.input}, outputs={'output': model.output})

  model_path = "model"
  model_version = 1
  export_path = os.path.join(
      compat.as_bytes(model_path), compat.as_bytes(str(model_version)))
  logging.info("Export the model to {}".format(export_path))

  builder = tf.saved_model.builder.SavedModelBuilder(export_path)
  builder.add_meta_graph_and_variables(
      sess=K.get_session(),
      tags=[tf.saved_model.tag_constants.SERVING],
      clear_devices=True,
      signature_def_map={
          tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
          model_signature
      })
  builder.save() 
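A sketch of loading the exported SavedModel back (the "model/1" path follows the model_path/model_version layout built above):

import tensorflow as tf

with tf.Session(graph=tf.Graph()) as sess:
    tf.saved_model.loader.load(
        sess, [tf.saved_model.tag_constants.SERVING], "model/1")
    # The default serving signature maps 'input' -> 'output' as defined above.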
Example #7
Source File: callbacks.py    From training_results_v0.6 with Apache License 2.0    6 votes
def _average_metrics_in_place(self, logs):
        logs = logs or {}
        reduced_logs = {}
        # Reduce every metric among workers. Sort metrics by name
        # to ensure consistent order.
        for metric, value in sorted(logs.items()):
            if metric not in self.variables:
                self.variables[metric], self.allreduce_ops[metric] = \
                    self._make_variable(metric, value)
            else:
                K.set_value(self.variables[metric], value)
            reduced_logs[metric] = \
                K.get_session().run(self.allreduce_ops[metric])
        # Write the reduced values back into the logs dictionary
        # for other callbacks to use.
        for metric, value in reduced_logs.items():
            logs[metric] = value 
Example #8
Source File: query_based_attack.py    From blackbox-attacks with MIT License    6 votes
def CW_est(logits, x, x_plus_i, x_minus_i, curr_sample, curr_target):
    curr_logits = K.get_session().run([logits], feed_dict={x: curr_sample})[0]
    # So that when max is taken, it returns max among classes apart from the
    # target
    curr_logits[np.arange(BATCH_SIZE), list(curr_target)] = -1e4
    max_indices = np.argmax(curr_logits, 1)
    logit_plus = K.get_session().run([logits], feed_dict={x: x_plus_i})[0]
    logit_plus_t = logit_plus[np.arange(BATCH_SIZE), list(curr_target)]
    logit_plus_max = logit_plus[np.arange(BATCH_SIZE), list(max_indices)]

    logit_minus = K.get_session().run([logits], feed_dict={x: x_minus_i})[0]
    logit_minus_t = logit_minus[np.arange(BATCH_SIZE), list(curr_target)]
    logit_minus_max = logit_minus[np.arange(BATCH_SIZE), list(max_indices)]

    # Two-sided (central) differences: dividing by args.delta here and by 2.0
    # in the return yields (f(x+delta) - f(x-delta)) / (2*delta).
    logit_t_grad_est = (logit_plus_t - logit_minus_t)/args.delta
    logit_max_grad_est = (logit_plus_max - logit_minus_max)/args.delta

    return logit_t_grad_est/2.0, logit_max_grad_est/2.0
Example #9
Source File: log_utils.py    From rpg_public_dronet with MIT License    6 votes
def on_epoch_end(self, epoch, logs={}):
        
        # Save training and validation losses
        logz.log_tabular('train_loss', logs.get('loss'))
        logz.log_tabular('val_loss', logs.get('val_loss'))
        logz.dump_tabular()

        # Save model every 'period' epochs
        if (epoch+1) % self.period == 0:
            filename = self.filepath + '/model_weights_' + str(epoch) + '.h5'
            print("Saved model at {}".format(filename))
            self.model.save_weights(filename, overwrite=True)

        # Hard mining: anneal the number of samples fed to each loss,
        # decaying from batch_size towards 10 (MSE) and 5 (entropy)
        # once training passes roughly epoch 30.
        sess = K.get_session()
        mse_function = self.batch_size-(self.batch_size-10)*(np.maximum(0.0,1.0-np.exp(-1.0/30.0*(epoch-30.0))))
        entropy_function = self.batch_size-(self.batch_size-5)*(np.maximum(0.0,1.0-np.exp(-1.0/30.0*(epoch-30.0))))
        self.model.k_mse.load(int(np.round(mse_function)), sess)
        self.model.k_entropy.load(int(np.round(entropy_function)), sess)
Example #10
Source File: cifar10_query_based.py    From blackbox-attacks with MIT License    6 votes
def one_shot_method(prediction, x, curr_sample, curr_target, p_t):
    grad_est = np.zeros((BATCH_SIZE, IMAGE_ROWS, IMAGE_COLS, NUM_CHANNELS))
    # Random +/-1 (Rademacher) perturbation per pixel, as in
    # simultaneous-perturbation gradient estimation.
    DELTA = np.random.randint(2, size=(BATCH_SIZE, IMAGE_ROWS, IMAGE_COLS, NUM_CHANNELS))
    np.place(DELTA, DELTA==0, -1)

    y_plus = np.clip(curr_sample + args.delta * DELTA, CLIP_MIN, CLIP_MAX)
    y_minus = np.clip(curr_sample - args.delta * DELTA, CLIP_MIN, CLIP_MAX)

    if args.CW_loss == 0:
        pred_plus = K.get_session().run([prediction], feed_dict={x: y_plus, K.learning_phase(): 0})[0]
        pred_plus_t = pred_plus[np.arange(BATCH_SIZE), list(curr_target)]

        pred_minus = K.get_session().run([prediction], feed_dict={x: y_minus, K.learning_phase(): 0})[0]
        pred_minus_t = pred_minus[np.arange(BATCH_SIZE), list(curr_target)]

        num_est = (pred_plus_t - pred_minus_t)

    grad_est = num_est[:, None, None, None]/(args.delta * DELTA)

    # Getting gradient of the loss
    if args.CW_loss == 0:
        loss_grad = -1.0 * grad_est/p_t[:, None, None, None]

    return loss_grad 
Example #11
Source File: yolov3.py    From keras-onnx with MIT License    6 votes
def __init__(self, model_path='model_data/yolo.h5', anchors_path='model_data/yolo_anchors.txt', yolo3_dir=None):
        self.yolo3_dir = yolo3_dir
        self.model_path = model_path
        self.anchors_path = anchors_path
        self.classes_path = 'model_data/coco_classes.txt'
        self.score = 0.3
        self.iou = 0.45
        self.class_names = self._get_class()
        self.anchors = self._get_anchors()
        self.sess = K.get_session()
        self.model_image_size = (416, 416)  # fixed size or (None, None), hw
        self.session = None
        self.final_model = None

        # Generate colors for drawing bounding boxes.
        hsv_tuples = [(x / len(self.class_names), 1., 1.)
                      for x in range(len(self.class_names))]
        self.colors = list(map(lambda x: colorsys.hsv_to_rgb(*x), hsv_tuples))
        self.colors = list(
            map(lambda x: (int(x[0] * 255), int(x[1] * 255), int(x[2] * 255)),
                self.colors))
        np.random.seed(10101)  # Fixed seed for consistent colors across runs.
        np.random.shuffle(self.colors)  # Shuffle colors to decorrelate adjacent classes.
        np.random.seed(None)  # Reset seed to default.
        K.set_learning_phase(0) 
Example #12
Source File: callbacks.py    From deform-conv with MIT License    5 votes
def set_model(self, model):
        self.model = model
        self.sess = K.get_session()
        total_loss = self.model.total_loss
        if self.histogram_freq and self.merged is None:
            for layer in self.model.layers:
                for weight in layer.weights:
                    # dense_1/bias:0 > dense_1/bias_0
                    name = weight.name.replace(':', '_')
                    tf.summary.histogram(name, weight)
                    tf.summary.histogram(
                        '{}_gradients'.format(name),
                        K.gradients(total_loss, [weight])[0]
                    )
                    if self.write_images:
                        w_img = tf.squeeze(weight)
                        shape = w_img.get_shape()
                        if len(shape) > 1 and shape[0] > shape[1]:
                            w_img = tf.transpose(w_img)
                        if len(shape) == 1:
                            w_img = tf.expand_dims(w_img, 0)
                        w_img = tf.expand_dims(tf.expand_dims(w_img, 0), -1)
                        tf.summary.image(name, w_img)

                if hasattr(layer, 'output'):
                    tf.summary.histogram('{}_out'.format(layer.name),
                                         layer.output)
        self.merged = tf.summary.merge_all()

        if self.write_graph:
            self.writer = tf.summary.FileWriter(self.log_dir,
                                                self.sess.graph)
        else:
            self.writer = tf.summary.FileWriter(self.log_dir) 
Example #13
Source File: keras_models.py    From gentun with Apache License 2.0    5 votes
def reset_weights(self):
        """Initialize model weights."""
        session = K.get_session()
        # Note: only kernels are re-initialized here; biases keep their values.
        for layer in self.model.layers:
            if hasattr(layer, 'kernel_initializer'):
                layer.kernel.initializer.run(session=session)
Example #14
Source File: train_mrcnn.py    From maskrcnn with MIT License    5 votes
def set_debugger_session():
    sess = K.get_session()
    sess = tf_debug.LocalCLIDebugWrapperSession(sess)
    # name_filter is a tensor-filter callable defined elsewhere in this file.
    sess.add_tensor_filter('name_filter', name_filter)
    K.set_session(sess)
Example #15
Source File: yolo.py    From keras-YOLOv3-mobilenet with MIT License    5 votes
def __init__(self, **kwargs):
        self.__dict__.update(self._defaults) # set up default values
        self.__dict__.update(kwargs) # and update with user overrides
        self.class_names = self._get_class()
        self.anchors = self._get_anchors()
        self.sess = K.get_session()
        self.boxes, self.scores, self.classes = self.generate() 
Example #16
Source File: yolo.py    From yolo3_keras_Flag_Detection with MIT License    5 votes
def __init__(self, **kwargs):
        self.__dict__.update(self._defaults)  # set up default values
        self.__dict__.update(kwargs)  # and update with user overrides
        self.class_names = self._get_class()
        self.anchors = self._get_anchors()
        self.sess = K.get_session()
        self.boxes, self.scores, self.classes = self.generate() 
Example #17
Source File: __init__.py    From training_results_v0.6 with Apache License 2.0    5 votes
def allgather(value, name=None):
    """
    Perform an allgather on a tensor-compatible value.

    The concatenation is done on the first dimension, so the input values on the
    different processes must have the same rank and shape, except for the first
    dimension, which is allowed to be different.

    Arguments:
        value: A tensor-compatible value to gather.
        name: Optional name prefix for the constants created by this operation.
    """
    allgather_op = hvd.allgather(tf.constant(value, name=name))
    return K.get_session().run(allgather_op) 
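A hedged usage sketch (assumes Horovod is initialized, with hvd as in the module above): each worker contributes its own rank, and every worker receives the concatenation.

gathered = allgather([float(hvd.rank())])  # on 4 workers -> [0., 1., 2., 3.]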
Example #18
Source File: network_utils.py    From nips-2017-adversarial with MIT License    5 votes
def restore_source_model(saved_pb_name, grad_dict=None):
    print('restoring', saved_pb_name)
    with open(saved_pb_name + '.pickle', 'rb') as f:
        info = pickle.load(f)
    print(info)
    sess = K.get_session() 
    print('restoring frozen graph def')
    with open(saved_pb_name + '.pb', 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    tf.import_graph_def(graph_def, name='')

    # prepare set of tensors that needs to be found
    tensor_search_name = set()
    for gname, (input_names, output_names) in info.items():
        tensor_search_name = tensor_search_name.union(set(input_names+output_names))
    found_tensors = {}
    
    # search for tensors
    ops = tf.get_default_graph().get_operations()
    for op in ops:
        if len(op.outputs) != 1:
            continue
        if op.outputs[0].name in tensor_search_name:
            found_tensors[op.outputs[0].name] = op.outputs[0]

    flag = True
    for t in tensor_search_name:
        if t not in found_tensors:
            print('Tensor not found:', t)
            flag = False
    if not flag:
        return

    print('all nodes found')
    for gname, (input_names, output_names) in info.items():
        input_list = [found_tensors[tname] for tname in input_names]
        output_list = [found_tensors[tname] for tname in output_names]
        print('{0}\n  Input: {1}\n  Output: {2}\n'.format(gname, input_list, output_list))
        grad_dict[gname] = (input_list, output_list, K.function(input_list, output_list))
    print('restore finished') 
Example #19
Source File: network_utils.py    From nips-2017-adversarial with MIT License    5 votes
def restore_source_model(saved_ckpt_name, grad_dict=None, from_frozen=False):
    print('restoring', saved_ckpt_name)
    with open(saved_ckpt_name + '.pickle', 'rb') as f:
        info = pickle.load(f)
    print(info[0])
    print(info[1])
    sess = K.get_session()
    if not from_frozen:
        print('restoring graph')
        saver = tf.train.import_meta_graph(saved_ckpt_name + '.ckpt.meta')
        print('restoring variables')
        saver.restore(sess, saved_ckpt_name + '.ckpt')
    else:
        print('restoring frozen graph def')
        with open(saved_ckpt_name + '.pb', 'rb') as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
        tf.import_graph_def(graph_def, name='')
    ops = tf.get_default_graph().get_operations()
    tensor_search_name = set(info[0] + list(info[1].values()))
    found_tensors = {}
    for op in ops:
        if len(op.outputs) != 1:
            continue
        if op.outputs[0].name.startswith('src_input'):
            print(op.outputs[0].name)
        if op.outputs[0].name in tensor_search_name:
            found_tensors[op.outputs[0].name] = op.outputs[0]
    input_tensors = [found_tensors[nm] for nm in info[0]]
    pred_input_tensors = input_tensors[:2] + [input_tensors[3]]
    print(input_tensors)
    for model_name, tensor_name in info[1].items():
        grad_dict[model_name] = K.function(\
                input_tensors if model_name != 'PRED' else pred_input_tensors, \
                [found_tensors[tensor_name]])
    print('restore finished') 
Example #20
Source File: network_utils.py    From nips-2017-adversarial with MIT License    5 votes
def restore_source_model(saved_ckpt_name, grad_dict=None, from_frozen=False):
    print('restoring', saved_ckpt_name)
    with open(saved_ckpt_name + '.pickle', 'rb') as f:
        info = pickle.load(f)
    print(info[0])
    print(info[1])
    sess = K.get_session()
    if not from_frozen:
        print('restoring graph')
        saver = tf.train.import_meta_graph(saved_ckpt_name + '.ckpt.meta')
        print('restoring variables')
        saver.restore(sess, saved_ckpt_name + '.ckpt')
    else:
        print('restoring frozen graph def')
        with open(saved_ckpt_name + '.pb', 'rb') as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
        tf.import_graph_def(graph_def, name='')
    ops = tf.get_default_graph().get_operations()
    tensor_search_name = set(info[0] + list(info[1].values()))
    found_tensors = {}
    for op in ops:
        if len(op.outputs) != 1:
            continue
        if op.outputs[0].name.startswith('src_input'):
            print(op.outputs[0].name)
        if op.outputs[0].name in tensor_search_name:
            found_tensors[op.outputs[0].name] = op.outputs[0]
    input_tensors = [found_tensors[nm] for nm in info[0]]
    # TODO this part is not really right. need to integrate with
    # attack better
    # also in the end learning phase is added twice, that's why
    # it worked for the frozen graphs, but will break for the non-frozen ones
    pred_input_tensors = input_tensors + [input_tensors[2]]
    print(input_tensors)
    for model_name, tensor_name in info[1].items():
        grad_dict[model_name] = K.function(\
                input_tensors if model_name != 'PRED' else pred_input_tensors, \
                [found_tensors[tensor_name]])
    print('restore finished') 
Example #21
Source File: model_wrappers.py    From nips-2017-adversarial with MIT License    5 votes
def load_ckpt(ckpt_name, var_scope_name, scope, constructor, input_tensor, label_offset, load_weights, **kwargs):
    """ kwargs are is_training and create_aux_logits """
    print(var_scope_name)
    with slim.arg_scope(scope):
        logits, endpoints = constructor(\
                input_tensor, num_classes=1000+label_offset, \
                scope=var_scope_name, **kwargs)
    if load_weights:
        init_fn = slim.assign_from_checkpoint_fn(\
                ckpt_name, slim.get_model_variables(var_scope_name))
        init_fn(K.get_session())
    return logits, endpoints 
Example #22
Source File: network_utils.py    From nips-2017-adversarial with MIT License    5 votes
def restore_source_model(saved_ckpt_name, grad_dict=None, from_frozen=False):
    print('restoring', saved_ckpt_name)
    with open(saved_ckpt_name + '.pickle', 'rb') as f:
        info = pickle.load(f)
    print(info[0])
    print(info[1])
    sess = K.get_session()
    if not from_frozen:
        print('restoring graph')
        saver = tf.train.import_meta_graph(saved_ckpt_name + '.ckpt.meta')
        print('restoring variables')
        saver.restore(sess, saved_ckpt_name + '.ckpt')
    else:
        print('restoring frozen graph def')
        with open(saved_ckpt_name + '.pb', 'rb') as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
        tf.import_graph_def(graph_def, name='')
    ops = tf.get_default_graph().get_operations()
    tensor_search_name = set(info[0] + list(info[1].values()))
    found_tensors = {}
    for op in ops:
        if len(op.outputs) != 1:
            continue
        if op.outputs[0].name.startswith('src_input'):
            print(op.outputs[0].name)
        if op.outputs[0].name in tensor_search_name:
            found_tensors[op.outputs[0].name] = op.outputs[0]
    input_tensors = [found_tensors[nm] for nm in info[0]]
    pred_input_tensors = input_tensors[:2] + [input_tensors[3]]
    print(input_tensors)
    for model_name, tensor_name in info[1].items():
        grad_dict[model_name] = K.function(\
                input_tensors if model_name != 'PRED' else pred_input_tensors, \
                [found_tensors[tensor_name]])
    print('restore finished') 
Example #23
Source File: model_wrappers.py    From nips-2017-adversarial with MIT License    5 votes
def load_ckpt(ckpt_name, var_scope_name, scope, constructor, input_tensor, label_offset, load_weights, **kwargs):
    """ kwargs are is_training and create_aux_logits """
    print(var_scope_name)
    with slim.arg_scope(scope):
        logits, endpoints = constructor(\
                input_tensor, num_classes=1000+label_offset, \
                scope=var_scope_name, **kwargs)
    if load_weights:
        init_fn = slim.assign_from_checkpoint_fn(\
                ckpt_name, slim.get_model_variables(var_scope_name))
        init_fn(K.get_session())
    return logits, endpoints 
Example #24
Source File: __init__.py    From training_results_v0.6 with Apache License 2.0    5 votes
def allreduce(value, name=None, average=True):
    """
    Perform an allreduce on a tensor-compatible value.

    Arguments:
        value: A tensor-compatible value to reduce.
               The shape of the input must be identical across all ranks.
        name: Optional name for the constants created by this operation.
        average: If True, computes the average over all ranks.
                 Otherwise, computes the sum over all ranks.
    """
    allreduce_op = hvd.allreduce(tf.constant(value, name=name), average=average)
    return K.get_session().run(allreduce_op) 
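For instance, averaging a per-worker scalar (a hedged sketch; local_loss is a hypothetical per-worker float):

mean_loss = allreduce([local_loss])[0]  # average of local_loss across all ranks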
Example #25
Source File: utils.py    From dts with MIT License    5 votes
def get_flops(model):
    run_meta = tf.RunMetadata()
    opts = tf.profiler.ProfileOptionBuilder.float_operation()

    # We use the Keras session graph in the call to the profiler.
    flops = tf.profiler.profile(graph=K.get_session().graph,
                                run_meta=run_meta, cmd='op', options=opts)

    return flops.total_float_ops  # the total float ops ("FLOPs") of the model's graph
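A usage sketch with a toy model (assumes a fresh TF1 graph, so the profile covers only this model):

from keras.models import Sequential
from keras.layers import Dense

model = Sequential([Dense(10, input_shape=(20,))])
print(get_flops(model))  # counts every float op registered in K.get_session().graph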
Example #26
Source File: __init__.py    From BERT-keras with GNU General Public License v3.0    5 votes
def tpu_compatible():
    '''Work around TPU problems encountered while using the Keras TPU model.'''
    if not hasattr(tpu_compatible, 'once'):
        tpu_compatible.once = True
    else:
        return
    import tensorflow as tf
    import tensorflow.keras.backend as K
    _version = tf.__version__.split('.')
    is_correct_version = int(_version[0]) >= 1 and (int(_version[0]) >= 2 or int(_version[1]) >= 13)
    from tensorflow.contrib.tpu.python.tpu.keras_support import KerasTPUModel
    def initialize_uninitialized_variables():
        sess = K.get_session()
        uninitialized_variables = set([i.decode('ascii') for i in sess.run(tf.report_uninitialized_variables())])
        init_op = tf.variables_initializer(
            [v for v in tf.global_variables() if v.name.split(':')[0] in uninitialized_variables]
        )
        sess.run(init_op)

    _tpu_compile = KerasTPUModel.compile

    def tpu_compile(self,
                    optimizer,
                    loss=None,
                    metrics=None,
                    loss_weights=None,
                    sample_weight_mode=None,
                    weighted_metrics=None,
                    target_tensors=None,
                    **kwargs):
        if not is_correct_version:
            raise ValueError('You need tensorflow >= 1.13 for better keras tpu support!')
        _tpu_compile(self, optimizer, loss, metrics, loss_weights,
                     sample_weight_mode, weighted_metrics,
                     target_tensors, **kwargs)
        initialize_uninitialized_variables()  # for unknown reasons, this sometimes needs to run after compile

    KerasTPUModel.compile = tpu_compile 
Example #27
Source File: utils.py    From voxelmorph with GNU General Public License v3.0    5 votes
def reset_weights(model, session=None):
    """
    reset weights of model with the appropriate initializer.
    Note: only uses "kernel_initializer" and "bias_initializer"
    does not close session.

    Reference:
    https://www.codementor.io/nitinsurya/how-to-re-initialize-keras-model-weights-et41zre2g

    Parameters:
        model: keras model to reset
        session (optional): the current session
    """

    if session is None:
        session = K.get_session()

    for layer in model.layers: 
        reset = False
        if hasattr(layer, 'kernel_initializer'):
            layer.kernel.initializer.run(session=session)
            reset = True
        
        if hasattr(layer, 'bias_initializer'):
            layer.bias.initializer.run(session=session)
            reset = True
        
        if not reset:
            print('Could not find initializer for layer %s, skipping' % layer.name)
Example #28
Source File: __init__.py    From training_results_v0.6 with Apache License 2.0    5 votes
def broadcast(value, root_rank, name=None):
    """
    Perform a broadcast on a tensor-compatible value.

    Arguments:
        value: A tensor-compatible value to reduce.
               The shape of the input must be identical across all ranks.
        root_rank: Rank of the process from which global variables will be
                   broadcasted to all other processes.
        name: Optional name for the constants created by this operation.
    """
    bcast_op = hvd.broadcast(tf.constant(value, name=name), root_rank)
    return K.get_session().run(bcast_op) 
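A hedged usage sketch: synchronizing a Python-side value so every worker adopts rank 0's copy (initial_lr is a hypothetical per-worker float).

lr = broadcast(initial_lr, root_rank=0)  # every worker now holds rank 0's value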
Example #29
Source File: word_vectors.py    From keras-image-captioning with MIT License    5 votes
def vectorize_words(self, words):
        vectors = []
        for word in words:
            vector = self._word_vector_of.get(word)
            vectors.append(vector)

        num_unknowns = len([v for v in vectors if v is None])  # filter() is lazy in Python 3, so avoid len(filter(...))
        inits = self._initializer(shape=(num_unknowns, self._embedding_size))
        inits = K.get_session().run(inits)
        inits = iter(inits)
        for i in range(len(vectors)):
            if vectors[i] is None:
                vectors[i] = next(inits)

        return np.array(vectors) 
Example #30
Source File: callbacks.py    From training_results_v0.6 with Apache License 2.0    5 votes
def on_train_begin(self, logs=None):
        with tf.device(self.device):
            bcast_op = hvd.broadcast_global_variables(self.root_rank)
            K.get_session().run(bcast_op)
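This callback runs Horovod's broadcast op inside Keras' own session when training starts, so every worker begins from the root rank's initial weights. In user code the stock callback is wired up roughly like this (a sketch; hvd.init() called beforehand):

import horovod.keras as hvd

callbacks = [hvd.callbacks.BroadcastGlobalVariablesCallback(0)]
model.fit(x_train, y_train, callbacks=callbacks)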