Python tensorflow.case() Examples
The following are 30 code examples of tensorflow.case().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module tensorflow, or try the search function.
Example #1
Source File: common_video.py From BERT with Apache License 2.0 | 6 votes |
def finish(self):
    """Finish transcoding and return the encoded video.

    Closes the encoder's stdin, waits for the stdout/stderr reader threads to
    drain, closes the pipes and reaps the subprocess so its exit status is
    actually known before it is checked.

    Returns:
      bytes: everything the process wrote to stdout, or None when no process
      was started (``self.proc`` is None).

    Raises:
      IOError: if the process exited with a non-zero status. The message holds
        the command line followed by the process's stderr output.
    """
    if self.proc is None:
        return None
    # Closing stdin signals end-of-input to the encoder.
    self.proc.stdin.close()
    for thread in (self._out_thread, self._err_thread):
        thread.join()
    (out, err) = [
        b"".join(chunks) for chunks in (self._out_chunks, self._err_chunks)
    ]
    self.proc.stdout.close()
    self.proc.stderr.close()
    # For a subprocess.Popen (which the stdin/stdout/stderr pipes suggest),
    # ``returncode`` stays None until the process is reaped with wait() or
    # poll(); checking it without this call never detects a failure.
    returncode = self.proc.wait()
    if returncode:
        # errors="replace" so undecodable stderr cannot mask the IOError.
        err = "\n".join([" ".join(self.cmd), err.decode("utf8", errors="replace")])
        raise IOError(err)
    del self.proc
    self.proc = None
    return out
Example #2
Source File: base.py From training_results_v0.5 with Apache License 2.0 | 6 votes |
def video_features(
        self, all_frames, all_actions, all_rewards, all_raw_frames):
    """Hook for features computed over the whole video.

    Subclasses that need every frame at once override this, for example to
    approximate a single latent for the entire video. Whatever they return is
    made available as ``video_features`` in ``next_frame``.

    Args:
      all_frames: every frame, input and target alike.
      all_actions: every action, input and target alike.
      all_rewards: every reward, input and target alike.
      all_raw_frames: every frame as it was before the modalities ran.

    Returns:
      A dict of video-wide features. The base implementation has none and
      returns None.
    """
    # The default hook deliberately ignores its inputs.
    del all_frames
    del all_actions
    del all_rewards
    del all_raw_frames
    return None
Example #3
Source File: base.py From training_results_v0.5 with Apache License 2.0 | 6 votes |
def video_extra_loss(self, frames_predicted, frames_target,
                     internal_states, video_features):
    """Hook for an extra loss computed across the whole predicted video.

    Override this when the model needs a loss spanning all predicted frames,
    for instance a video-level GAN loss.

    Args:
      frames_predicted: every predicted frame.
      frames_target: every target frame.
      internal_states: the video's internal states.
      video_features: the output of ``video_features``.

    Returns:
      The extra loss. The base implementation contributes nothing (0.0).
    """
    # The default hook deliberately ignores its inputs.
    del frames_predicted
    del frames_target
    del internal_states
    del video_features
    return 0.0
Example #4
Source File: dataset_util.py From multi-label-classification with MIT License | 6 votes |
def _augment(image):
    """Apply random data augmentation: saturation, contrast, brightness, noise.

    Noise is always added first. Then one of the three colour-adjustment
    orderings, or none of them, is chosen uniformly at random.

    :param image: image to augment, shape (H, W, ?)
    :return: the augmented image, clipped to [0.0, 1.0]
    """
    noisy = DatasetUtil._add_noise(image)
    # Augmentation order: a draw from [0, 4). Values 0-2 pick one of the
    # orderings below; 3 falls through to the default (no colour change).
    ordering = tf.random_uniform([], minval=0, maxval=4, dtype=tf.int32)
    orderings = (DatasetUtil._augment_cond_0,
                 DatasetUtil._augment_cond_1,
                 DatasetUtil._augment_cond_2)
    # Bind each function as a default argument so every branch keeps its own.
    branches = [(tf.equal(ordering, idx), lambda fn=fn: fn(noisy))
                for idx, fn in enumerate(orderings)]
    augmented = tf.case(pred_fn_pairs=branches, default=lambda: noisy)
    # Keep the augmented values from leaving the valid range.
    return tf.clip_by_value(augmented, 0.0, 1.0)
Example #5
Source File: common_video.py From training_results_v0.5 with Apache License 2.0 | 6 votes |
def finish(self):
    """Finish transcoding and return the encoded video.

    Closes the encoder's stdin, waits for the stdout/stderr reader threads to
    drain, closes the pipes and reaps the subprocess so its exit status is
    actually known before it is checked.

    Returns:
      bytes: everything the process wrote to stdout, or None when no process
      was started (``self.proc`` is None).

    Raises:
      IOError: if the process exited with a non-zero status. The message holds
        the command line followed by the process's stderr output.
    """
    if self.proc is None:
        return None
    # Closing stdin signals end-of-input to the encoder.
    self.proc.stdin.close()
    for thread in (self._out_thread, self._err_thread):
        thread.join()
    (out, err) = [
        b"".join(chunks) for chunks in (self._out_chunks, self._err_chunks)
    ]
    # Release the pipe file descriptors once the reader threads are done.
    self.proc.stdout.close()
    self.proc.stderr.close()
    # For a subprocess.Popen (which the stdin/stdout/stderr pipes suggest),
    # ``returncode`` stays None until the process is reaped with wait() or
    # poll(); checking it without this call never detects a failure.
    returncode = self.proc.wait()
    if returncode:
        # errors="replace" so undecodable stderr cannot mask the IOError.
        err = "\n".join([" ".join(self.cmd), err.decode("utf8", errors="replace")])
        raise IOError(err)
    del self.proc
    self.proc = None
    return out
Example #6
Source File: optimizer.py From kfac with Apache License 2.0 | 6 votes |
def _get_qmodel_quantities(self, grads_and_vars):
    """Return the preconditioned gradients and the quadratic-model terms.

    Args:
      grads_and_vars: sequence of ``(gradient, variable)`` pairs.

    Returns:
      A tuple ``(precon_grads_and_vars, m, c, b)``. The first item is the
      output of ``self._multiply_preconditioner``; the other three are
      whatever ``self._compute_qmodel`` returns.
    """
    # The "preconditioned gradient": the preconditioner applied to the grads.
    preconditioned = self._multiply_preconditioner(grads_and_vars)
    variables = tuple(v for _, v in grads_and_vars)
    previous_updates = self._compute_prev_updates(variables)
    # The previous updates may be zero, which makes parts of this call look
    # wasteful. The original author's reasoning is that only the non-zero
    # part of the solution is extracted, so TensorFlow should not compute the
    # rest. That claim is unverified.
    m, c, b = self._compute_qmodel(
        preconditioned, previous_updates, grads_and_vars)
    return preconditioned, m, c, b
Example #7
Source File: DataAugmentation.py From DeepDenoiser with Apache License 2.0 | 6 votes |
def permute_rgb(inputs, permute, data_format='channels_last'):
    """Reorder the three colour channels of ``inputs`` as selected by ``permute``.

    ``permute`` values 1 to 5 select the channel orders [0, 2, 1], [1, 0, 2],
    [1, 2, 0], [2, 0, 1] and [2, 1, 0] respectively. Any other value (for
    example 0) returns ``inputs`` unchanged via the tf.case default.
    """
    assert Conv2dUtilities.has_valid_shape(inputs)
    assert Conv2dUtilities.number_of_channels(inputs, data_format) == 3

    orderings = (
        (1, [0, 2, 1]),
        (2, [1, 0, 2]),
        (3, [1, 2, 0]),
        (4, [2, 0, 1]),
        (5, [2, 1, 0]),
    )

    def reorder(order):
        axis = Conv2dUtilities.channel_axis(inputs, data_format)
        channels = tf.split(inputs, [1, 1, 1], axis)
        return tf.concat([channels[idx] for idx in order], axis)

    # Bind each order as a default argument so every branch keeps its own.
    branches = [(tf.equal(permute, key), lambda order=order: reorder(order))
                for key, order in orderings]
    return tf.case(branches, default=lambda: inputs, exclusive=True)
Example #8
Source File: ops.py From shuttleNet with GNU General Public License v3.0 | 6 votes |
def adjust_max(start, stop, start_value, stop_value, name=None):
    """Piecewise-linear schedule driven by the global step.

    The schedule returns ``start_value`` while ``global_step <= start``. It
    then interpolates linearly (``polynomial_decay`` with ``power=1.0``) from
    ``start_value`` to ``stop_value`` over ``(start, stop]``, and returns
    ``stop_value`` after ``stop``.

    Args:
      start: step at which the interpolation begins (converted to int64).
      stop: step at which the interpolation ends (converted to int64).
      start_value: value before ``start`` (converted to float32).
      stop_value: value after ``stop`` (converted to float32).
      name: optional name scope; defaults to "AdjustMax".

    Returns:
      A scalar float32 tensor, or None when no global step exists.
    """
    with ops.name_scope(name, "AdjustMax", [start, stop, name]) as name:
        global_step = tf.train.get_global_step()
        if global_step is None:
            return None
        start = tf.convert_to_tensor(start, dtype=tf.int64)
        stop = tf.convert_to_tensor(stop, dtype=tf.int64)
        start_value = tf.convert_to_tensor(start_value, dtype=tf.float32)
        stop_value = tf.convert_to_tensor(stop_value, dtype=tf.float32)
        # An ordered list of (predicate, fn) pairs rather than a dict keyed by
        # tensors. Tensors are unhashable under TF2 behaviour, and the list
        # keeps predicate order explicit.
        pred_fn_pairs = [
            (global_step <= start, lambda: start_value),
            ((global_step > start) & (global_step <= stop),
             lambda: tf.train.polynomial_decay(
                 start_value, global_step - start, stop - start,
                 end_learning_rate=stop_value, power=1.0, cycle=False)),
        ]
        default = lambda: stop_value
        return tf.case(pred_fn_pairs, default, exclusive=True)
Example #9
Source File: _transforms.py From tensorfx with Apache License 2.0 | 6 votes |
def _bucketize(instances, feature, schema, metadata):
    """Apply the bucketize transform to a numeric field.

    The field's value is one-hot encoded by bucket. N comma-separated
    boundaries define N + 1 buckets.

    Args:
      instances: mapping from field name to its value tensor.
      feature: feature spec with ``field`` and ``transform['boundaries']``
        (a comma-separated string of numbers).
      schema: mapping from field name to a field spec exposing ``numeric``.
      metadata: unused.

    Returns:
      A float tensor of static shape ``(None, len(boundaries) + 1)``.

    Raises:
      ValueError: if the field is not numeric.
    """
    field = schema[feature.field]
    if not field.numeric:
        raise ValueError('A bucketize transform cannot be applied to non-numerical field "%s".' %
                         feature.field)

    transform = feature.transform
    # Materialize as a list. A Python 3 ``map`` object is a one-shot iterator
    # with no len(), but the boundaries are consumed by bucketize and then
    # measured twice below.
    boundaries = [float(b) for b in transform['boundaries'].split(',')]

    # TODO: Figure out how to use tf.case instead of this contrib op
    from tensorflow.contrib.layers.python.ops.bucketization_op import bucketize

    # One-hot encode the bucket index: N boundaries give N + 1 buckets. The
    # squeeze on axis 1 removes the extra dimension the encoding adds.
    value = instances[feature.field]
    value = tf.squeeze(tf.one_hot(bucketize(value, boundaries, name='bucket'),
                                  depth=len(boundaries) + 1,
                                  on_value=1.0,
                                  off_value=0.0,
                                  name='one_hot'),
                       axis=1, name='bucketize')
    value.set_shape((None, len(boundaries) + 1))
    return value
Example #10
Source File: functions.py From neuralmonkey with BSD 3-Clause "New" or "Revised" License | 6 votes |
def piecewise_function(param, values, changepoints, name=None,
                       dtype=tf.float32):
    """Evaluate a piecewise-constant function of ``param``.

    The result is ``values[i]`` for the first ``i`` with
    ``param < changepoints[i]``, and ``values[-1]`` when no such ``i`` exists.

    Arguments:
        param: The function parameter.
        values: Function values (numbers or tensors), one per segment.
        changepoints: Sorted points at which the function switches to the
            next value. Must be exactly one item shorter than ``values``.
        name: Optional name scope; defaults to "PiecewiseFunction".
        dtype: dtype the values are converted to.
    """
    expected = len(values) - 1
    if len(changepoints) != expected:
        raise ValueError("changepoints has length {}, expected {} (values "
                         "has length {})".format(len(changepoints),
                                                 expected,
                                                 len(values)))

    with tf.name_scope(name, "PiecewiseFunction",
                       [param, values, changepoints]) as scope_name:
        tensors = [tf.convert_to_tensor(v, dtype=dtype) for v in values]
        # Bind each tensor as a default argument; a plain closure would make
        # every branch return the last tensor.
        branches = [(lambda t=t: t) for t in tensors]
        conditions = [tf.less(param, point) for point in changepoints]
        pairs = list(zip(conditions, branches[:-1]))
        return tf.case(pairs, branches[-1], name=scope_name)
Example #11
Source File: post_export_metrics.py From model-analysis with Apache License 2.0 | 6 votes |
def _metric_key(self, base_key: Text) -> Text:
    """Return the metric key, tagged with ``self._metric_tag`` when one is set.

    A multi-headed model may evaluate the same metric against several
    predictions or labels. Tagging the key keeps those instances from
    colliding.

    Args:
      base_key: The untagged key, usually taken from ``metric_keys``.

    Returns:
      ``base_key`` itself when no (truthy) tag is set, otherwise
      ``metric_keys.tagged_key(base_key, self._metric_tag)``.
    """
    tag = self._metric_tag
    return metric_keys.tagged_key(base_key, tag) if tag else base_key
Example #12
Source File: post_export_metrics.py From model-analysis with Apache License 2.0 | 6 votes |
def __init__(self,
             example_weight_key: Optional[Text] = None,
             target_prediction_keys: Optional[List[Text]] = None,
             labels_key: Optional[Text] = None,
             metric_tag: Optional[Text] = None,
             tensor_index: Optional[int] = None):
    """Create a metric that computes calibration.

    Args:
      example_weight_key: The key of the example weight column in the features
        dict. If None, all predictions are given a weight of 1.0.
      target_prediction_keys: If provided, the prediction keys to look for in
        order.
      labels_key: If provided, a custom label key.
      metric_tag: If provided, a custom metric tag. Only necessary to
        disambiguate instances of the same metric on different predictions.
      tensor_index: Optional index to specify class predictions to calculate
        metrics on in the case of multi-class models.
    """
    self._example_weight_key = example_weight_key
    # Forward tensor_index so the documented argument takes effect, matching
    # how the confusion-matrix metric passes it to its base class.
    # NOTE(review): assumes the base __init__ accepts ``tensor_index``.
    super(_Calibration, self).__init__(
        target_prediction_keys=target_prediction_keys,
        labels_key=labels_key,
        metric_tag=metric_tag,
        tensor_index=tensor_index)
Example #13
Source File: functions.py From neuralmonkey with BSD 3-Clause "New" or "Revised" License | 6 votes |
def piecewise_function(param, values, changepoints, name=None,
                       dtype=tf.float32):
    """Evaluate a piecewise-constant function of ``param``.

    The result is ``values[i]`` for the first ``i`` with
    ``param < changepoints[i]``, and ``values[-1]`` when no such ``i`` exists.

    Arguments:
        param: The function parameter.
        values: Function values (numbers or tensors), one per segment.
        changepoints: Sorted points at which the function switches to the
            next value. Must be exactly one item shorter than ``values``.
        name: Optional name scope; defaults to "PiecewiseFunction".
        dtype: dtype the values are converted to.
    """
    expected = len(values) - 1
    if len(changepoints) != expected:
        raise ValueError("changepoints has length {}, expected {} (values "
                         "has length {})".format(len(changepoints),
                                                 expected,
                                                 len(values)))

    with tf.name_scope(name, "PiecewiseFunction",
                       [param, values, changepoints]) as scope_name:
        tensors = [tf.convert_to_tensor(v, dtype=dtype) for v in values]
        # Bind each tensor as a default argument; a plain closure would make
        # every branch return the last tensor.
        branches = [(lambda t=t: t) for t in tensors]
        conditions = [tf.less(param, point) for point in changepoints]
        pairs = list(zip(conditions, branches[:-1]))
        return tf.case(pairs, branches[-1], name=scope_name)
Example #14
Source File: base.py From BERT with Apache License 2.0 | 6 votes |
def video_extra_loss(self, frames_predicted, frames_target,
                     internal_states, video_features):
    """Hook for an extra loss computed across the whole predicted video.

    Override this when the model needs a loss spanning all predicted frames,
    for instance a video-level GAN loss.

    Args:
      frames_predicted: every predicted frame.
      frames_target: every target frame.
      internal_states: the video's internal states.
      video_features: the output of ``video_features``.

    Returns:
      The extra loss. The base implementation contributes nothing (0.0).
    """
    # The default hook deliberately ignores its inputs.
    del frames_predicted
    del frames_target
    del internal_states
    del video_features
    return 0.0
Example #15
Source File: base.py From BERT with Apache License 2.0 | 6 votes |
def video_features(
        self, all_frames, all_actions, all_rewards, all_raw_frames):
    """Hook for features computed over the whole video.

    Subclasses that need every frame at once override this, for example to
    approximate a single latent for the entire video. Whatever they return is
    made available as ``video_features`` in ``next_frame``.

    Args:
      all_frames: every frame, input and target alike.
      all_actions: every action, input and target alike.
      all_rewards: every reward, input and target alike.
      all_raw_frames: every frame as it was before the modalities ran.

    Returns:
      A dict of video-wide features. The base implementation has none and
      returns None.
    """
    # The default hook deliberately ignores its inputs.
    del all_frames
    del all_actions
    del all_rewards
    del all_raw_frames
    return None
Example #16
Source File: functions.py From neuralmonkey with BSD 3-Clause "New" or "Revised" License | 6 votes |
def piecewise_function(param, values, changepoints, name=None,
                       dtype=tf.float32):
    """Evaluate a piecewise-constant function of ``param``.

    The result is ``values[i]`` for the first ``i`` with
    ``param < changepoints[i]``, and ``values[-1]`` when no such ``i`` exists.

    Arguments:
        param: The function parameter.
        values: Function values (numbers or tensors), one per segment.
        changepoints: Sorted points at which the function switches to the
            next value. Must be exactly one item shorter than ``values``.
        name: Optional name scope; defaults to "PiecewiseFunction".
        dtype: dtype the values are converted to.
    """
    expected = len(values) - 1
    if len(changepoints) != expected:
        raise ValueError("changepoints has length {}, expected {} (values "
                         "has length {})".format(len(changepoints),
                                                 expected,
                                                 len(values)))

    with tf.name_scope(name, "PiecewiseFunction",
                       [param, values, changepoints]) as scope_name:
        tensors = [tf.convert_to_tensor(v, dtype=dtype) for v in values]
        # Bind each tensor as a default argument; a plain closure would make
        # every branch return the last tensor.
        branches = [(lambda t=t: t) for t in tensors]
        conditions = [tf.less(param, point) for point in changepoints]
        pairs = list(zip(conditions, branches[:-1]))
        return tf.case(pairs, branches[-1], name=scope_name)
Example #17
Source File: ops.py From basenji with Apache License 2.0 | 6 votes |
def adjust_max(start, stop, start_value, stop_value, name=None):
    """Piecewise-linear schedule driven by the global step.

    The schedule returns ``start_value`` while ``global_step <= start``. It
    then interpolates linearly (``polynomial_decay`` with ``power=1.0``) from
    ``start_value`` to ``stop_value`` over ``(start, stop]``, and returns
    ``stop_value`` after ``stop``. The global step is created if it does not
    exist yet.

    Args:
      start: step at which the interpolation begins (converted to int64).
      stop: step at which the interpolation ends (converted to int64).
      start_value: value before ``start`` (converted to float32).
      stop_value: value after ``stop`` (converted to float32).
      name: optional name scope; defaults to "AdjustMax".

    Returns:
      A scalar float32 tensor, or None if no global step is available.
    """
    with tf.name_scope(name, "AdjustMax", [start, stop, name]) as name:
        global_step = tf.train.get_or_create_global_step()
        if global_step is None:
            return None
        start = tf.convert_to_tensor(start, dtype=tf.int64)
        stop = tf.convert_to_tensor(stop, dtype=tf.int64)
        start_value = tf.convert_to_tensor(start_value, dtype=tf.float32)
        stop_value = tf.convert_to_tensor(stop_value, dtype=tf.float32)
        # An ordered list of (predicate, fn) pairs rather than a dict keyed by
        # tensors. Tensors are unhashable under TF2 behaviour, and the list
        # keeps predicate order explicit.
        pred_fn_pairs = [
            (global_step <= start, lambda: start_value),
            ((global_step > start) & (global_step <= stop),
             lambda: tf.train.polynomial_decay(
                 start_value, global_step - start, stop - start,
                 end_learning_rate=stop_value, power=1.0, cycle=False)),
        ]
        default = lambda: stop_value
        return tf.case(pred_fn_pairs, default, exclusive=True)
Example #18
Source File: post_export_metrics.py From model-analysis with Apache License 2.0 | 5 votes |
def _string_labels_to_class_ids(labels_tensor: tf.Tensor,
                                classes_tensor: tf.Tensor) -> tf.Tensor:
    """Convert string labels into integer class IDs.

    Each label is compared against ``classes_tensor``, and the position of
    the matching class becomes its ID. The result is reshaped back to the
    shape of ``labels_tensor``.
    """
    # A lookup table (lookup_ops.index_table_from_tensor) would be preferable.
    # However, the classes are only known dynamically, so the table would need
    # tf.compat.v1.tables_initializer() to run, which risks re-initializing
    # other tables a second time.
    original_shape = tf.shape(input=labels_tensor)

    # Promote rank-1 labels (N) to (N, 1); other ranks pass through unchanged.
    labels_2d = tf.case(
        [(tf.equal(tf.rank(labels_tensor), 1),
          lambda: tf.expand_dims(labels_tensor, axis=-1))],
        default=lambda: labels_tensor)

    # One-hot match of each label against the class vocabulary, e.g.
    #   classes_tensor = [['a', 'b', 'c'], ['a', 'b', 'c']]
    #   labels_2d      = [['b'], ['a']]
    #   one_hot        = [[0, 1, 0], [1, 0, 0]]
    one_hot = tf.compat.v1.where(
        tf.equal(classes_tensor, labels_2d),
        tf.ones(tf.shape(input=classes_tensor), dtype=tf.int64),
        tf.zeros(tf.shape(input=classes_tensor), dtype=tf.int64))

    # With one label column, argmax over axis 1 turns each one-hot row into
    # its class index. Assuming a rank-2 one_hot, this yields a rank-1 tensor.
    class_ids = tf.case(
        [(tf.equal(tf.shape(input=labels_2d)[1], 1),
          lambda: tf.argmax(input=one_hot, axis=1))],
        default=lambda: one_hot)

    # Restore the caller's original label shape, e.g. back to (N).
    return tf.reshape(class_ids, original_shape)
Example #19
Source File: Architecture.py From ECCV2018-FaceDeSpoofing with MIT License | 5 votes |
def distorted_inputsB(a):
    """Build a batch of distorted training inputs.

    Args:
      a: reader selector. 1 uses ``cifar10_input.distorted_inputs``; any other
        value uses ``cifar10_input.distorted_inputsA``.

    Returns:
      ``(images, dmaps, labels, sizes, slabels)`` as produced by the reader.
      ``images`` and ``dmaps`` are cast to float16 when ``FLAGS.use_fp16`` is
      set.

    Raises:
      ValueError: if ``FLAGS.data_dir`` is empty.
    """
    if not FLAGS.data_dir:
        raise ValueError('Please supply a data_dir')
    data_dir = FLAGS.data_dir
    if a == 1:
        images, dmaps, labels, sizes, slabels = cifar10_input.distorted_inputs(
            data_dir=data_dir, batch_size=FLAGS.batch_size)
    else:
        images, dmaps, labels, sizes, slabels = cifar10_input.distorted_inputsA(
            data_dir=data_dir, batch_size=FLAGS.batch_size)
    if FLAGS.use_fp16:
        images = tf.cast(images, tf.float16)
        # Cast the depth maps themselves. The previous tf.case(images, ...)
        # called the wrong op on the wrong tensor.
        dmaps = tf.cast(dmaps, tf.float16)
    return images, dmaps, labels, sizes, slabels
Example #20
Source File: feature_normalization.py From monopsr with MIT License | 5 votes |
def tf_normalize_box_height_by_mean(unnormalized_box_height, class_strs):
    """Normalizes the 2D box height by dividing by the mean box height value of the class.

    Args:
        unnormalized_box_height: Unnormalized box height; divided by a
            ``(num_boxes, 1)`` tensor of per-box class means.
        class_strs: tf.string class name(s), one per box ('Car',
            'Pedestrian' or 'Cyclist').

    Returns:
        normalized_box_height: Normalized box height
    """

    # Per-class mean box heights. See box_means.py
    def box_h_mean_car():
        return 61.594734

    def box_h_mean_pedestrian():
        return 95.95055

    def box_h_mean_cyclist():
        return 76.85717

    # Flatten to 1-D. Unlike tf.squeeze, this keeps a single-box input as a
    # length-1 vector, so tf.map_fn still has an axis to iterate over.
    class_strs = tf.reshape(class_strs, [-1])

    # An ordered list of (predicate, fn) pairs instead of a dict keyed by
    # tensors, which are unhashable under TF2 behaviour. There is no default,
    # so a class name outside these three presumably fails tf.case's runtime
    # check. NOTE(review): verify against the TF version in use.
    mean_box_h = tf.map_fn(
        lambda x: tf.case([
            (tf.equal(x, 'Car'), box_h_mean_car),
            (tf.equal(x, 'Pedestrian'), box_h_mean_pedestrian),
            (tf.equal(x, 'Cyclist'), box_h_mean_cyclist)]),
        class_strs, dtype=tf.float32)

    normalized_box_height = unnormalized_box_height / tf.expand_dims(mean_box_h, 1)
    return normalized_box_height
Example #21
Source File: autoaugment_v1.py From mobilenetv2-yolov3 with MIT License | 5 votes |
def select_and_apply_random_policy(policies, image, bboxes):
    """Pick one of ``policies`` uniformly at random and apply it.

    Args:
      policies: sequence of callables ``policy(image, bboxes) -> (image, bboxes)``.
      image: image tensor to augment.
      bboxes: bounding boxes passed alongside the image.

    Returns:
      ``(image, bboxes)`` after applying the selected policy.
    """
    chosen = tf.random_uniform([], maxval=len(policies), dtype=tf.int32)
    # One tf.cond per policy rather than a single tf.case. Per the original
    # authors, tf.case produces much larger graphs and can break export for
    # large policy sets.
    for index, candidate in enumerate(policies):
        def apply_candidate(candidate=candidate):
            return candidate(image, bboxes)

        def keep_unchanged():
            return (image, bboxes)

        image, bboxes = tf.cond(
            tf.equal(index, chosen), apply_candidate, keep_unchanged)
    return (image, bboxes)
Example #22
Source File: post_export_metrics.py From model-analysis with Apache License 2.0 | 5 votes |
def __init__(self,
             thresholds: List[float],
             example_weight_key: Optional[Text] = None,
             target_prediction_keys: Optional[List[Text]] = None,
             labels_key: Optional[Text] = None,
             metric_tag: Optional[Text] = None,
             tensor_index: Optional[int] = None) -> None:
    """Create a metric computing the confusion matrix at the given thresholds.

    Accepted predictions:
      (a) a single float in [0, 1];
      (b) a dict containing the LOGISTIC key;
      (c) a dict containing the PREDICTIONS key, with a value in [0, 1].
    Labels are a single float in [0, 1]. String labels are converted to 0 or
    1 using the ALL_CLASSES tensor when it is present.

    Args:
      thresholds: Thresholds at which to compute the confusion matrix. They
        are stored in ascending order.
      example_weight_key: Key of the example weight column in the features
        dict. When None, every prediction has weight 1.0.
      target_prediction_keys: Optional prediction keys, looked up in order.
      labels_key: Optional custom label key.
      metric_tag: Optional custom metric tag. Only needed to tell apart
        instances of the same metric on different predictions.
      tensor_index: Optional class index to compute metrics on for
        multi-class models.
    """
    self._example_weight_key = example_weight_key
    ascending = sorted(thresholds)
    self._thresholds = ascending
    super(_ConfusionMatrixBasedMetric, self).__init__(
        target_prediction_keys,
        labels_key,
        metric_tag,
        tensor_index=tensor_index)
Example #23
Source File: plan.py From fold with Apache License 2.0 | 5 votes |
def _tf_nth(fns, n):
    """Run only ``fns[n]``, where ``n`` is a scalar integer tensor.

    The last function serves as the tf.case default. It is guarded by a
    ``tf.Assert`` (named 'nth_index_error') checking that ``n`` really
    selects it, so an out-of-range ``n`` fails that assertion at run time.
    With a single function, it is called directly under the same guard.
    """
    branches = [(tf.equal(tf.constant(idx, n.dtype), n), fn)
                for idx, fn in enumerate(fns)]
    last_pred, last_fn = branches.pop()

    def guarded_last():
        check = tf.Assert(last_pred, [n, len(fns)], name='nth_index_error')
        with tf.control_dependencies([check]):
            return last_fn()

    if len(fns) == 1:
        return guarded_last()
    return tf.case(branches, guarded_last)
Example #24
Source File: layers.py From fold with Apache License 2.0 | 5 votes |
def _instantiate_subnet(self, batch, block_idx, seq_prefix):
    """Recursively build the fractal sub-network identified by ``seq_prefix``.

    At full depth (``len(seq_prefix) == self._fractal_block_depth``) this is
    just the child layer ``self._children[block_idx, seq_prefix]`` applied to
    ``batch``. Otherwise a drop-path choice tensor selects one of:
      - the base path only,
      - the recursive path only (two deeper sub-networks applied in sequence),
      - both, combined with ``self._mixer``.

    Args:
      batch: input tensor. The result gets the same static shape.
      block_idx: index of the fractal block.
      seq_prefix: tuple of 0/1 path steps identifying this sub-network.

    Returns:
      A tensor with the static shape of ``batch``.
    """
    def zeros_fn():
        return tf.zeros_like(batch)

    def base_case_fn():
        return self._children[block_idx, seq_prefix](batch)

    def recursive_case_fn():
        # Two sub-networks one level deeper, applied in sequence (0 then 1).
        first_subnet = self._instantiate_subnet(
            batch, block_idx, seq_prefix + (0,))
        return self._instantiate_subnet(
            first_subnet, block_idx, seq_prefix + (1,))

    if len(seq_prefix) == self._fractal_block_depth:
        return base_case_fn()
    else:
        choice = self._drop_path_choices[self._choice_id[(block_idx, seq_prefix)]]
        # Zeros stand in for the base path when only recursion is chosen.
        base_case = tf.cond(
            tf.not_equal(choice, self._JUST_RECURSE), base_case_fn, zeros_fn)
        # Re-apply batch's static shape to the tf.cond output.
        base_case.set_shape(batch.get_shape())
        # Zeros stand in for the recursive path when only the base is chosen.
        recursive_case = tf.cond(
            tf.not_equal(choice, self._JUST_BASE), recursive_case_fn, zeros_fn)
        recursive_case.set_shape(batch.get_shape())
        cases = [
            (tf.equal(choice, self._BOTH),
             lambda: self._mixer(base_case, recursive_case)),
            (tf.equal(choice, self._JUST_BASE), lambda: base_case),
            (tf.equal(choice, self._JUST_RECURSE), lambda: recursive_case)]
        # base_case is the default if choice matches none of the constants.
        result = tf.case(cases, lambda: base_case)
        result.set_shape(batch.get_shape())
        return result
Example #25
Source File: layers.py From fold with Apache License 2.0 | 5 votes |
def _create_variables(self):
    """Create the variables owned by this layer (none for the base layer).

    Called at most once. This happens either the first time the layer is
    invoked (its input type is set by then) or on the first call to the
    public ``create_variables`` method. The variable scope is this layer's
    ``vscope``.

    Raises:
      TypeError: in overrides, if ``input_type`` is unset or invalid for
        this layer.
    """
Example #26
Source File: Architecture.py From ECCV2018-FaceDeSpoofing with MIT License | 5 votes |
def inputs(testset):
    """Build a batch of evaluation inputs.

    Args:
      testset: forwarded to ``cifar10_input.inputs`` to choose the split.

    Returns:
      ``(images, dmaps, labels, sizes, slabels)`` as produced by the reader.
      ``images`` and ``dmaps`` are cast to float16 when ``FLAGS.use_fp16`` is
      set.

    Raises:
      ValueError: if ``FLAGS.data_dir`` is empty.
    """
    if not FLAGS.data_dir:
        raise ValueError('Please supply a data_dir')
    data_dir = FLAGS.data_dir
    images, dmaps, labels, sizes, slabels = cifar10_input.inputs(
        testset=testset, data_dir=data_dir, batch_size=FLAGS.batch_size)
    if FLAGS.use_fp16:
        images = tf.cast(images, tf.float16)
        # Cast the depth maps themselves. The previous tf.case(images, ...)
        # called the wrong op on the wrong tensor.
        dmaps = tf.cast(dmaps, tf.float16)
    return images, dmaps, labels, sizes, slabels
Example #27
Source File: autoaugment_utils.py From Live-feed-object-device-identification-using-Tensorflow-and-OpenCV with Apache License 2.0 | 5 votes |
def select_and_apply_random_policy(policies, image, bboxes):
    """Pick one of ``policies`` uniformly at random and apply it.

    Args:
      policies: sequence of callables ``policy(image, bboxes) -> (image, bboxes)``.
      image: image tensor to augment.
      bboxes: bounding boxes passed alongside the image.

    Returns:
      ``(image, bboxes)`` after applying the selected policy.
    """
    chosen = tf.random_uniform([], maxval=len(policies), dtype=tf.int32)
    # One tf.cond per policy rather than a single tf.case. Per the original
    # authors, tf.case produces much larger graphs and can break export for
    # large policy sets.
    for index, candidate in enumerate(policies):
        def apply_candidate(candidate=candidate):
            return candidate(image, bboxes)

        def keep_unchanged():
            return (image, bboxes)

        image, bboxes = tf.cond(
            tf.equal(index, chosen), apply_candidate, keep_unchanged)
    return (image, bboxes)
Example #28
Source File: ntm.py From ntm_keras with BSD 3-Clause "New" or "Revised" License | 5 votes |
def build(self, input_shape):
    """Build the controller, the shift matrix and the recurrent state specs.

    Args:
      input_shape: 3-tuple unpacked as ``(batch_size, input_length,
        input_dim)``.
    """
    bs, input_length, input_dim = input_shape
    self.controller_input_dim, self.controller_output_dim = controller_input_output_shape(
            input_dim, self.units, self.m_depth, self.n_slots, self.shift_range, self.read_heads,
            self.write_heads)
    # Now that the controller's shape is known, we have to add it to the
    # layer/model. With no user-supplied controller, a linear Dense layer is
    # created.
    if self.controller is None:
        self.controller = Dense(
            name = "controller",
            activation = 'linear',
            bias_initializer = 'zeros',
            units = self.controller_output_dim,
            input_shape = (bs, input_length, self.controller_input_dim))
        # NOTE(review): builds with self.batch_size rather than bs; confirm
        # that this is intended.
        self.controller.build(input_shape=(self.batch_size, input_length, self.controller_input_dim))
        # The default Dense controller carries no recurrent state.
        self.controller_with_state = False
    # This is a fixed shift matrix
    self.C = _circulant(self.n_slots, self.shift_range)
    # NOTE(review): assigning trainable_weights directly assumes a Keras
    # version where the attribute is writable.
    self.trainable_weights = self.controller.trainable_weights
    # We need to declare the number of states we want to carry around.
    # In our case the dimension seems to be 6 (LSTM) or 5 (GRU) or 4 (FF),
    # see self.get_initial_states. Those correspond to:
    # [old_ntm_output] + [init_M, init_wr, init_ww] + [init_h] (LSTM and GRU) + [init_c] (LSTM only)
    # old_ntm_output does not make sense in our world, but is required by the definition of the step function we
    # intend to use.
    # WARNING: What self.state_spec does is only poorly understood;
    # it was copied from keras/recurrent.py.
    self.states = [None, None, None, None]
    self.state_spec = [InputSpec(shape=(None, self.output_dim)),                            # old_ntm_output
                        InputSpec(shape=(None, self.n_slots, self.m_depth)),                # Memory
                        InputSpec(shape=(None, self.read_heads, self.n_slots)),             # weights_read
                        InputSpec(shape=(None, self.write_heads, self.n_slots))]            # weights_write
    super(NeuralTuringMachine, self).build(input_shape)
Example #29
Source File: ntm.py From ntm_keras with BSD 3-Clause "New" or "Revised" License | 5 votes |
def _cosine_distance(M, k):
    """Cosine similarity between the rows of memory ``M`` and key ``k``.

    This is equation (6) of the NTM paper. Despite the name, the value
    returned is a similarity: the batched dot product of ``M`` and ``k``
    after both are L2-normalised along their last axis. Normalising first is
    better conditioned than dividing by the norms afterwards.
    """
    # TODO: The original author flagged this as a source of NaNs. No NaN
    # handling is implemented yet; consider a library cosine op.
    unit_key = K.l2_normalize(k, axis=-1)
    unit_memory = K.l2_normalize(M, axis=-1)
    return K.batch_dot(unit_memory, unit_key)
Example #30
Source File: 4pp_eusr.py From tf-perceptual-eusr with Apache License 2.0 | 5 votes |
def _generator(self, input_list, scale, reuse=False):
    """Build the generator network and return its output tensor.

    ``scale`` (a tensor) picks, via tf.case, one of the x2/x4/x8 branches for
    both the scale-specific residual blocks and the upsampling stage. The
    cases are exclusive and there is no explicit default, so a scale other
    than 2, 4 or 8 matches no branch.
    """
    def switch_on_scale(build_fn, features):
        # One branch per supported scale. Each factor is bound as a default
        # argument so that every lambda keeps its own value.
        branches = [
            (tf.equal(scale, factor),
             lambda factor=factor: build_fn(features, scale=factor))
            for factor in (2, 4, 8)]
        return tf.case(branches, exclusive=True)

    with tf.variable_scope('generator', reuse=reuse):
        # pre-process
        net = tf.cast(input_list, tf.float32)
        net = self._mean_shift(net)

        # first convolutional layer
        with tf.variable_scope('first_conv'):
            net = self._conv2d(net, num_features=self.num_conv_features,
                               kernel_size=(3, 3))

        # scale-specific local residual blocks
        with tf.variable_scope('initial_blocks'):
            net = switch_on_scale(self._scale_specific_processing, net)

        # shared residual module
        with tf.variable_scope('shared'):
            net = self._residual_module(net,
                                        num_features=self.num_conv_features,
                                        num_blocks=self.num_shared_blocks)

        # scale-specific upsampling
        with tf.variable_scope('upscaling'):
            net = switch_on_scale(self._scale_specific_upsampling, net)

        # post-process
        return self._mean_inverse_shift(net)