import numpy as np
import tensorflow as tf
from tensorflow.python.framework import ops

from utils import logger

# Lazily-loaded handle to the custom Hungarian op library (see f_segm_match).
hungarian_module = None
log = logger.get()

# Declare that the "Hungarian" op has no gradient.
ops.NoGradient("Hungarian")


def get_device_fn(device):
    """Choose device for different ops.

    Args:
        device: Device string returned for every op whose type is not in the
            CPU-only list below.

    Returns:
        A function mapping an op to a device string, "/cpu:0" for the listed
        op types and `device` otherwise.
    """
    OPS_ON_CPU = set([
        'ResizeBilinear', 'ResizeBilinearGrad', 'Mod', 'Hungarian',
        'SparseToDense', 'Print', 'Gather', 'Reverse'
    ])

    def _device_fn(op):
        if op.type in OPS_ON_CPU:
            return "/cpu:0"
        else:
            return device

    return _device_fn


def get_identity_match(num_ex, timespan, s_gt):
    """Builds an identity matching masked by the groundtruth scores.

    Args:
        num_ex: Number of examples (first dim of the result).
        timespan: Python int, size of the square matching matrix.
        s_gt: Groundtruth score mask; expanded at axes 1 and 2, so it is
            expected to be [B, timespan].

    Returns:
        match: [num_ex, timespan, timespan], identity matrix multiplied by
            the row and column masks derived from s_gt.
    """
    zeros = tf.zeros(tf.pack([num_ex, timespan, timespan]))
    eye = tf.expand_dims(tf.constant(np.eye(timespan), dtype='float32'), 0)
    mask_x = tf.expand_dims(s_gt, 1)
    mask_y = tf.expand_dims(s_gt, 2)
    match = zeros + eye
    match = match * mask_x * mask_y
    return match


def f_cum_min(s, d):
    """Calculates cumulative minimum.

    Args:
        s: Input matrix [B, D].
        d: Second dim (Python int, must be at least 1).

    Returns:
        s_min: [B, D], cumulative minimum across the second dim.
    """
    s_min_list = [None] * d
    s_min_list[0] = s[:, 0:1]
    for ii in range(1, d):
        s_min_list[ii] = tf.minimum(s_min_list[ii - 1], s[:, ii:ii + 1])
    return tf.concat(1, s_min_list)


def f_cum_max(s, d):
    """Calculates cumulative maximum.

    Args:
        s: Input matrix [B, D].
        d: Second dim (Python int, must be at least 1).

    Returns:
        s_max: [B, D], cumulative maximum across the second dim, reversed
            (accumulated from the last column towards the first).
    """
    s_max_list = [None] * d
    s_max_list[-1] = s[:, d - 1:d]
    for ii in range(d - 2, -1, -1):
        s_max_list[ii] = tf.maximum(s_max_list[ii + 1], s[:, ii:ii + 1])
    return tf.concat(1, s_max_list)


def f_dice(a, b, timespan, pairwise=False):
    """Computes DICE score.

    Args:
        a: [B, N, H, W], or [N, H, W], or [H, W]
        b: [B, N, H, W], or [N, H, W], or [H, W]
            in pairwise mode, the second dimension can be different,
            e.g. [B, M, H, W], or [M, H, W], or [H, W]
        timespan: N, Python int. Only used in pairwise mode, where `a` is
            split into `timespan` slices along axis 1.
        pairwise: whether the inputs are already aligned, outputs [B, N] or
            the inputs are orderless, outputs [B, N, M].

    """
    if pairwise:
        # N * [B, 1, M]
        y_list = [None] * timespan
        # [B, N, H, W] => [B, N, 1, H, W]
        a = tf.expand_dims(a, 2)
        # [B, N, 1, H, W] => N * [B, 1, 1, H, W]
        a_list = tf.split(1, timespan, a)
        # [B, M, H, W] => [B, 1, M, H, W]
        b = tf.expand_dims(b, 1)
        card_b = tf.reduce_sum(b + 1e-5, [3, 4])
        for ii in range(timespan):
            # [B, 1, M]
            y_list[ii] = 2 * f_inter(a_list[ii], b) / \
                (tf.reduce_sum(a_list[ii] + 1e-5, [3, 4]) + card_b)
        # N * [B, 1, M] => [B, N, M]
        return tf.concat(1, y_list)
    else:
        card_a = tf.reduce_sum(a + 1e-5, _get_reduction_indices(a))
        card_b = tf.reduce_sum(b + 1e-5, _get_reduction_indices(b))
        return 2 * f_inter(a, b) / (card_a + card_b)


def f_inter(a, b):
    """Computes intersection (sum of a * b over the last two axes of a)."""
    reduction_indices = _get_reduction_indices(a)
    return tf.reduce_sum(a * b, reduction_indices=reduction_indices)


def f_union(a, b, eps=1e-5):
    """Computes union (sum of a + b - a * b + eps over the last two axes)."""
    reduction_indices = _get_reduction_indices(a)
    return tf.reduce_sum(
        a + b - (a * b) + eps, reduction_indices=reduction_indices)


def _get_reduction_indices(a):
    """Gets the list of axes to sum over (the last two axes of `a`)."""
    dim = tf.shape(tf.shape(a))
    return tf.concat(0, [dim - 2, dim - 1])


def f_iou(a, b, timespan=None, pairwise=False):
    """Computes IOU score.

    Args:
        a: [B, N, H, W], or [N, H, W], or [H, W]
        b: [B, N, H, W], or [N, H, W], or [H, W]
            in pairwise mode, the second dimension can be different,
            e.g. [B, M, H, W], or [M, H, W], or [H, W]
        timespan: N, Python int. Required in pairwise mode, where `a` is
            split into `timespan` slices along axis 1.
        pairwise: whether the inputs are already aligned, outputs [B, N] or
            the inputs are orderless, outputs [B, N, M].

    Returns:
        iou: [B, N]
    """
    if pairwise:
        # N * [B, 1, M]
        y_list = [None] * timespan
        # [B, N, H, W] => [B, N, 1, H, W]
        a = tf.expand_dims(a, 2)
        # [B, N, 1, H, W] => N * [B, 1, 1, H, W]
        a_list = tf.split(1, timespan, a)
        # [B, M, H, W] => [B, 1, M, H, W]
        b = tf.expand_dims(b, 1)
        for ii in range(timespan):
            # [B, 1, M]
            y_list[ii] = f_inter(a_list[ii], b) / f_union(a_list[ii], b)
        # N * [B, 1, M] => [B, N, M]
        return tf.concat(1, y_list)
    else:
        return f_inter(a, b) / f_union(a, b)


def f_iou_pair_new(a, b):
    """Computes pairwise IOU between every slice of `a` and of `b`.

    The union is clamped to at least 1 before dividing.

    Args:
        a: [B, N, H, W]
        b: [B, N, H, W]
    """
    a = tf.tile(tf.expand_dims(a, 2), tf.pack([1, 1, tf.shape(b)[1], 1, 1]))
    b = tf.expand_dims(b, 1)
    inter = tf.reduce_sum(a * b, [3, 4])
    union = tf.reduce_sum(a + b, [3, 4])
    union = tf.maximum(union - inter, 1)
    return inter / union


def f_iou_all(a, b):
    """Computes total IOU score.

    Args:
        a: Any shape
        b: Any shape

    Returns:
        iou: float
    """
    inter = tf.reduce_sum(a * b)
    union = tf.reduce_sum(a) + tf.reduce_sum(b) - inter + 1e-5
    return inter / union


def f_inter_box(top_left_a, bot_right_a, top_left_b, bot_right_b):
    """Computes intersection area with boxes.

    Args:
        top_left_a: [B, T, 2] or [B, 2]
        bot_right_a: [B, T, 2] or [B, 2]
        top_left_b: [B, T, 2] or [B, 2]
        bot_right_b: [B, T, 2] or [B, 2]

    Returns:
        area: [B, T]
    """
    top_left_max = tf.maximum(top_left_a, top_left_b)
    bot_right_min = tf.minimum(bot_right_a, bot_right_b)
    ndims = tf.shape(tf.shape(top_left_a))

    # Check if the resulting box is valid.
    overlap = tf.to_float(top_left_max < bot_right_min)
    overlap = tf.reduce_prod(overlap, ndims - 1)
    area = tf.reduce_prod(bot_right_min - top_left_max, ndims - 1)
    area = overlap * tf.abs(area)
    return area


def f_iou_box(top_left_a, bot_right_a, top_left_b, bot_right_b):
    """Compute IOU of boxes.

    Coordinates are read as (y, x) from the last axis. No epsilon is added
    to the union, so two zero-area boxes divide by zero.

    Args:
        top_left_a: [B, T, 2]
        bot_right_a: [B, T, 2]
        top_left_b: [B, T, 2]
        bot_right_b: [B, T, 2]

    Returns:
        iou: [B, T] or [B]
    """
    y1A = top_left_a[:, :, 0]
    x1A = top_left_a[:, :, 1]
    y2A = bot_right_a[:, :, 0]
    x2A = bot_right_a[:, :, 1]
    y1B = top_left_b[:, :, 0]
    x1B = top_left_b[:, :, 1]
    y2B = bot_right_b[:, :, 0]
    x2B = bot_right_b[:, :, 1]

    # compute intersection
    x1_max = tf.maximum(x1A, x1B)
    y1_max = tf.maximum(y1A, y1B)
    x2_min = tf.minimum(x2A, x2B)
    y2_min = tf.minimum(y2A, y2B)
    overlap_flag = tf.to_float(x1_max < x2_min) * tf.to_float(y1_max < y2_min)
    overlap_area = overlap_flag * (x2_min - x1_max) * (y2_min - y1_max)

    # compute union
    areaA = (x2A - x1A) * (y2A - y1A)
    areaB = (x2B - x1B) * (y2B - y1B)
    union_area = areaA + areaB - overlap_area
    return tf.div(overlap_area, union_area)


def f_iou_box_old(top_left_a, bot_right_a, top_left_b, bot_right_b):
    """Computes IoU of boxes.

    Intersection, union and the final ratio are all clamped away from zero,
    and the ratio is capped at 1.0.

    Args:
        top_left_a: [B, T, 2] or [B, 2]
        bot_right_a: [B, T, 2] or [B, 2]
        top_left_b: [B, T, 2] or [B, 2]
        bot_right_b: [B, T, 2] or [B, 2]

    Returns:
        iou: [B, T]
    """
    inter_area = f_inter_box(top_left_a, bot_right_a, top_left_b, bot_right_b)
    inter_area = tf.maximum(inter_area, 1e-6)
    ndims = tf.shape(tf.shape(top_left_a))
    # A box only contributes area when top-left < bottom-right on every axis.
    check_a = tf.reduce_prod(tf.to_float(top_left_a < bot_right_a), ndims - 1)
    area_a = check_a * tf.reduce_prod(bot_right_a - top_left_a, ndims - 1)
    check_b = tf.reduce_prod(tf.to_float(top_left_b < bot_right_b), ndims - 1)
    area_b = check_b * tf.reduce_prod(bot_right_b - top_left_b, ndims - 1)
    union_area = (area_a + area_b - inter_area + 1e-5)
    union_area = tf.maximum(union_area, 1e-5)
    iou = inter_area / union_area
    iou = tf.maximum(iou, 1e-5)
    iou = tf.minimum(iou, 1.0)
    return iou


def f_coverage(iou):
    """Coverage function proposed in [1]

    [1] N. Silberman, D. Sontag, R. Fergus. Instance segmentation of indoor
    scenes using a coverage loss. ECCV 2015.

    Args:
        iou: [B, N, N]. Pairwise IoU.

    Returns:
        Maximum of `iou` over axis 1.
    """
    return tf.reduce_max(iou, [1])


def f_coverage_weight(y_gt):
    """Compute the normalized weight for each groundtruth instance."""
    # [B, T]
    y_gt_sum = tf.reduce_sum(y_gt, [2, 3])
    # Plus one (for zero cardinality instances only) to avoid dividing by
    # zero. The resulting weight will be zero for any zero cardinality
    # instance.
    # [B, 1] + [B, T] => [B, T]
    y_gt_sum_sum = tf.reduce_sum(
        y_gt_sum, [1], keep_dims=True) + tf.to_float(tf.equal(y_gt_sum, 0))
    # [B, T]
    return y_gt_sum / y_gt_sum_sum


def f_weighted_coverage(iou, y_gt):
    """Weighted coverage score.

    Args:
        iou: [B, N, N]. Pairwise IoU.
        y_gt: [B, N, H, W]. Groundtruth segmentations.
    """
    cov = f_coverage(iou)
    wt = f_coverage_weight(y_gt)
    num_ex = tf.to_float(tf.shape(y_gt)[0])
    return tf.reduce_sum(cov * wt) / num_ex


def f_unweighted_coverage(iou, count):
    """Unweighted coverage score.

    Args:
        iou: [B, N, N]. Pairwise IoU.
        count: Divisor applied to each example's summed coverage
            (presumably the per-example instance count, [B] -- confirm
            against callers).
    """
    # [B, N]
    cov = f_coverage(iou)
    num_ex = tf.to_float(tf.shape(iou)[0])
    return tf.reduce_sum(tf.reduce_sum(cov, [1]) / count) / num_ex


def f_conf_loss(s_out, match, timespan, use_cum_min=True):
    """Loss function for confidence score sequence.

    Args:
        s_out: [B, N], output confidence scores.
        match: [B, N, M], matching matrix; summed over its last axis to get
            the per-output target.
        timespan: N, Python int. Only used when use_cum_min is True.
        use_cum_min: Whether to compare the cumulative min / max of the
            scores (f_bce_minmax) instead of the raw scores (f_bce).

    Returns:
        loss: Summed cross entropy divided by B and N.
    """
    s_out_shape = tf.shape(s_out)
    num_ex = tf.to_float(s_out_shape[0])
    max_num_obj = tf.to_float(s_out_shape[1])
    match_sum = tf.reduce_sum(match, reduction_indices=[2])

    # Loss for confidence scores.
    if use_cum_min:
        # [B, N]
        s_out_min = f_cum_min(s_out, timespan)
        s_out_max = f_cum_max(s_out, timespan)
        # [B, N]
        s_bce = f_bce_minmax(s_out_min, s_out_max, match_sum)
    else:
        s_bce = f_bce(s_out, match_sum)
    loss = tf.reduce_sum(s_bce) / num_ex / max_num_obj
    return loss


def f_sem_loss(s_out,
               match,
               c_gt,
               timespan,
               num_semantic_classes,
               use_cum_min=True):
    """Confidence loss plus matched semantic cross entropy.

    Args:
        s_out: [B, T, C], output class scores. `1 - s_out[:, :, 0]` is fed
            to f_conf_loss as the confidence score.
        match: [B, T, T], matching matrix.
        c_gt: [B, T, C], groundtruth class labels.
        timespan: T, Python int.
        num_semantic_classes: C, Python int.
        use_cum_min: Forwarded to f_conf_loss.

    Returns:
        Sum of the confidence loss and the semantic loss.
    """
    # General monotonic score loss.
    c_loss = f_conf_loss(
        1 - s_out[:, :, 0], match, timespan, use_cum_min=use_cum_min)

    # Match [B, T, T]
    # C_gt [B, T, C] => [B, 1, T, C]
    # C_gt' [B, T, T] * [B, 1, T, C] = [B, T, T, C] => [B, T, C]
    m2 = tf.tile(tf.expand_dims(match, 3), [1, 1, 1, num_semantic_classes])
    c_gt2 = tf.reduce_sum(m2 * tf.expand_dims(c_gt, 1), [2])
    s_out_shape = tf.shape(s_out)
    num_ex = tf.to_float(s_out_shape[0])
    max_num_obj = tf.to_float(s_out_shape[1])
    s_loss = tf.reduce_sum(f_ce(s_out, c_gt2)) / num_ex / max_num_obj
    return c_loss + s_loss


def f_greedy_match(score, matched):
    """Compute greedy matching given the IOU, and matched.

    Ties for the maximum share the match equally (each row sums to one).

    Args:
        score: [B, N] relatedness score, positive.
        matched: [B, N] binary mask

    Returns:
        match: [B, N] binary mask
    """
    score = score * (1.0 - matched)
    max_score = tf.reshape(tf.reduce_max(score, reduction_indices=[1]), [-1, 1])
    match = tf.to_float(tf.equal(score, max_score))
    match_sum = tf.reshape(tf.reduce_sum(match, reduction_indices=[1]), [-1, 1])
    return match / match_sum


def f_segm_match(iou, s_gt):
    """Matching between segmentation output and groundtruth.

    Loads the custom op library './hungarian.so' on first use and caches it
    in the module-level `hungarian_module`.

    Args:
        iou: [B, N, M], pairwise IoU between outputs and groundtruth.
        s_gt: [B, T], groudtruth score sequence

    Returns:
        match: First output of the Hungarian op, masked by s_gt along both
            matching axes.
    """
    global hungarian_module
    if hungarian_module is None:
        mod_name = './hungarian.so'
        hungarian_module = tf.load_op_library(mod_name)
        log.info('Loaded library "{}"'.format(mod_name))

    # Mask X, [B, M] => [B, 1, M]
    mask_x = tf.expand_dims(s_gt, dim=1)
    # Mask Y, [B, M] => [B, N, 1]
    mask_y = tf.expand_dims(s_gt, dim=2)
    iou_mask = iou * mask_x * mask_y

    # Keep certain precision so that we can get optimal matching within
    # reasonable time.
    eps = 1e-5
    precision = 1e6
    iou_mask = tf.round(iou_mask * precision) / precision
    match_eps = hungarian_module.hungarian(iou_mask + eps)[0]

    # Mask the graph algorithm output.
    match = match_eps * mask_x * mask_y
    return match


def f_ce(y_out, y_gt):
    """Multiclass cross entropy."""
    eps = 1e-5
    return -y_gt * tf.log(y_out + eps)


def f_bce(y_out, y_gt):
    """Binary cross entropy."""
    eps = 1e-5
    return -y_gt * tf.log(y_out + eps) - (1 - y_gt) * tf.log(1 - y_out + eps)


def f_bce_minmax(y_out_min, y_out_max, y_gt):
    """Binary cross entropy (encourages monotonic decreasing).

    Use minimum (cumulative from start) to compare against 1.
    Use maximum (cumulative till end) to compare against 0.
    """
    eps = 1e-5
    return (-y_gt * tf.log(y_out_min + eps) -
            (1 - y_gt) * tf.log(1 - y_out_max + eps))


def f_match_loss(y_out, y_gt, match, timespan, loss_fn, model=None):
    """Binary cross entropy with matching.

    Args:
        y_out: [B, N, H, W] or [B, N, D]
        y_gt: [B, N, H, W] or [B, N, D]
        match: [B, N, N]
        timespan: N
        loss_fn: Elementwise loss, called as loss_fn(y_out_slice, y_gt).
        model: Unused; kept for interface compatibility.

    Returns:
        Matched loss, normalized per example by the match count (at least
        1), then by B and by the product of the trailing dims of y_out.
    """
    # N * [B, 1, H, W]
    y_out_list = tf.split(1, timespan, y_out)
    # N * [B, 1, N]
    match_list = tf.split(1, timespan, match)
    err_list = [None] * timespan
    shape = tf.shape(y_out)
    num_ex = tf.to_float(shape[0])
    num_dim = tf.to_float(tf.reduce_prod(tf.to_float(shape[2:])))
    sshape = tf.size(shape)

    # [B, N, M] => [B, N]
    match_sum = tf.reduce_sum(match, reduction_indices=[2])
    # [B, N] => [B]
    match_count = tf.reduce_sum(match_sum, reduction_indices=[1])
    match_count = tf.maximum(match_count, 1)

    # All axes after the first two; identical for every time step.
    red_idx = tf.range(2, sshape)
    for ii in range(timespan):
        # [B, 1, H, W] * [B, N, H, W] => [B, N, H, W] => [B, N]
        # [B, N] * [B, N] => [B]
        # [B] => [B, 1]
        err_list[ii] = tf.expand_dims(
            tf.reduce_sum(
                tf.reduce_sum(loss_fn(y_out_list[ii], y_gt), red_idx) *
                tf.reshape(match_list[ii], [-1, timespan]), [1]), 1)

    # N * [B, 1] => [B, N] => [B]
    err_total = tf.reduce_sum(tf.concat(1, err_list), reduction_indices=[1])
    return tf.reduce_sum(err_total / match_count) / num_ex / num_dim


def f_count_acc(s_out, s_gt):
    """Counting accuracy.

    Args:
        s_out: [B, N], output scores; entries above 0.5 are counted.
        s_gt: [B, N], groundtruth scores; summed to get the true count.

    Returns:
        Fraction of examples whose predicted count equals the true count.
    """
    num_ex = tf.to_float(tf.shape(s_out)[0])
    count_out = tf.reduce_sum(tf.to_float(s_out > 0.5), reduction_indices=[1])
    count_gt = tf.reduce_sum(s_gt, reduction_indices=[1])
    count_acc = tf.reduce_sum(tf.to_float(tf.equal(count_out,
                                                   count_gt))) / num_ex
    return count_acc


def f_dic(s_out, s_gt, abs=False):
    """Difference in count.

    Args:
        s_out: [B, N], output scores; entries above 0.5 are counted.
        s_gt: [B, N], groundtruth scores; summed to get the true count.
        abs: Whether to average the absolute difference instead of the
            signed difference.

    Returns:
        Mean (over examples) of predicted count minus true count.
    """
    num_ex = tf.to_float(tf.shape(s_out)[0])
    count_out = tf.reduce_sum(tf.to_float(s_out > 0.5), reduction_indices=[1])
    count_gt = tf.reduce_sum(s_gt, reduction_indices=[1])
    count_diff = count_out - count_gt
    if abs:
        count_diff = tf.abs(count_diff)
    count_diff = tf.reduce_sum(tf.to_float(count_diff)) / num_ex
    return count_diff


def f_huber(y_out, y_gt, threshold=1.0):
    """Huber loss. Smooth combination of L2 and L1 loss for robustness.

    Elementwise: 0.5 * err^2 where |err| <= threshold, and
    threshold * |err| - 0.5 * threshold^2 elsewhere, so the two branches
    meet at |err| == threshold.

    Args:
        y_out: Predictions.
        y_gt: Targets, broadcastable against y_out.
        threshold: Absolute error at which the loss turns linear.

    Returns:
        huber: Elementwise loss, same shape as y_out - y_gt.
    """
    err = y_out - y_gt
    abs_err = tf.abs(err)
    # The quadratic region is symmetric around zero and set by `threshold`.
    ind = tf.to_float(abs_err <= threshold)
    squared_err = 0.5 * err * err
    l1_err = threshold * abs_err - 0.5 * (threshold**2)
    huber = squared_err * ind + l1_err * (1 - ind)
    return huber


def f_squared_err(y_out, y_gt):
    """Elementwise halved squared error (L2) loss."""
    err = y_out - y_gt
    squared_err = 0.5 * err * err
    return squared_err


def build_skip_conn_inner(cnn_channels, h_cnn, x):
    """Build skip connection.

    Args:
        cnn_channels: Channel counts, indexed so that entry
            `len(cnn_channels) - jj - 2` belongs to the jj-th skip layer.
        h_cnn: List of CNN layer outputs; all but the last are used, in
            reverse order.
        x: Network input, used as the final skip layer.

    Returns:
        skip: [None] followed by the skip layers.
        skip_ch: [0] followed by the channel count of each skip layer.
    """
    skip = [None]
    skip_ch = [0]
    for jj, layer in enumerate(h_cnn[-2::-1] + [x]):
        # Layers are passed through unchanged (no time replication here).
        skip.append(layer)
        ch_idx = len(cnn_channels) - jj - 2
        skip_ch.append(cnn_channels[ch_idx])
    return skip, skip_ch


def build_skip_conn(cnn_channels, h_cnn, x, timespan):
    """Build skip connection, replicating each layer `timespan` times.

    Each 4-D layer [B, H, W, C] is broadcast to [B, timespan, H, W, C] and
    reshaped to [B * timespan, H, W, C].

    Args:
        cnn_channels: Channel counts, indexed so that entry
            `len(cnn_channels) - jj - 2` belongs to the jj-th skip layer.
        h_cnn: List of CNN layer outputs; all but the last are used, in
            reverse order.
        x: Network input, used as the final skip layer.
        timespan: Number of replications per example.

    Returns:
        skip: [None] followed by the replicated skip layers.
        skip_ch: [0] followed by the channel count of each skip layer.
    """
    skip = [None]
    skip_ch = [0]
    for jj, layer in enumerate(h_cnn[-2::-1] + [x]):
        ss = tf.shape(layer)
        zeros = tf.zeros(tf.pack([ss[0], timespan, ss[1], ss[2], ss[3]]))
        new_shape = tf.pack([ss[0] * timespan, ss[1], ss[2], ss[3]])
        layer_reshape = tf.reshape(tf.expand_dims(layer, 1) + zeros, new_shape)
        skip.append(layer_reshape)
        ch_idx = len(cnn_channels) - jj - 2
        skip_ch.append(cnn_channels[ch_idx])
    return skip, skip_ch


def build_skip_conn_attn(cnn_channels, h_cnn_time, x_time, timespan):
    """Build skip connection for attention based model.

    Args:
        cnn_channels: Channel counts; entry `lidx + 1` belongs to CNN layer
            `lidx` and entry 0 to the input.
        h_cnn_time: Per-time-step lists of CNN layer outputs.
        x_time: Per-time-step inputs, used as the final skip layer.
        timespan: Ignored; the number of time steps is taken from
            len(h_cnn_time).

    Returns:
        skip: [None] followed by the skip layers, each stacked over time and
            flattened to 4-D.
        skip_ch: [0] followed by the channel count of each skip layer.
    """
    skip = [None]
    skip_ch = [0]
    nlayers = len(h_cnn_time[0])
    # NOTE: deliberately kept from the original -- the argument is overridden.
    timespan = len(h_cnn_time)
    for jj in range(nlayers):
        lidx = nlayers - jj - 2
        if lidx >= 0:
            ll = [h_cnn_time[tt][lidx] for tt in range(timespan)]
        else:
            ll = x_time
        layer = tf.concat(1, [tf.expand_dims(l, 1) for l in ll])
        ss = tf.shape(layer)
        layer = tf.reshape(layer, tf.pack([-1, ss[2], ss[3], ss[4]]))
        skip.append(layer)
        ch_idx = lidx + 1
        skip_ch.append(cnn_channels[ch_idx])
    return skip, skip_ch


def get_gaussian_filter(center, size, lg_var, image_size, filter_size):
    """Get Gaussian-based attention filter along one dimension

    Args:
        center: center of one dimension (mean), [B]
        size: size of one dimension (delta), [B]
        lg_var: log variance of the filter, [B]
        image_size: image size of one dimension (passed to tf.range, so a
            scalar)
        filter_size: filter size of one dimension (passed to tf.range, so a
            scalar)

    Returns:
        filt: [B, L, F], where L is image_size and F is filter_size.
    """
    # [1, 1, F].
    span_filter = tf.to_float(tf.reshape(tf.range(filter_size), [1, 1, -1]))

    # [B, 1, 1]
    center = tf.reshape(center, [-1, 1, 1])
    size = tf.reshape(size, [-1, 1, 1])

    # [B, 1, 1] + [B, 1, 1] * [1, 1, F] = [B, 1, F]
    mu = center + (size + 1) / filter_size * \
        (span_filter - (filter_size - 1) / 2.0)

    # [B, 1, 1]
    lg_var = tf.reshape(lg_var, [-1, 1, 1])

    # [1, L, 1]
    span = tf.to_float(
        tf.reshape(tf.range(image_size), tf.pack([1, image_size, 1])))

    # [1, L, 1] - [B, 1, F] = [B, L, F]
    filt = tf.mul(1 / tf.sqrt(tf.exp(lg_var)) / tf.sqrt(2 * np.pi),
                  tf.exp(-0.5 * (span - mu) * (span - mu) / tf.exp(lg_var)))
    return filt


def extract_patch(x, f_y, f_x, nchannels, normalize=False):
    """Extracts a patch by applying the filters to every channel of x.

    Args:
        x: [B, H, W, D]
        f_y: [B, H, FH]
        f_x: [B, W, FW]
        nchannels: D
        normalize: Unused.

    Returns:
        patch: [B, FH, FW, D]
    """
    patch = [None] * nchannels
    fsize_h = tf.shape(f_y)[2]
    fsize_w = tf.shape(f_x)[2]
    hh = tf.shape(x)[1]
    ww = tf.shape(x)[2]
    for dd in range(nchannels):
        # [B, H, W]
        x_ch = tf.reshape(
            tf.slice(x, [0, 0, 0, dd], [-1, -1, -1, 1]), tf.pack([-1, hh, ww]))
        # f_y^T * x_ch * f_x: [B, FH, H] * [B, H, W] * [B, W, FW]
        patch[dd] = tf.reshape(
            tf.batch_matmul(
                tf.batch_matmul(
                    f_y, x_ch, adj_x=True), f_x),
            tf.pack([-1, fsize_h, fsize_w, 1]))
    return tf.concat(3, patch)


def get_gt_attn(y_gt,
                filter_height,
                filter_width,
                padding_ratio=0.0,
                center_shift_ratio=0.0,
                min_padding=10.0):
    """Get groundtruth attention box given segmentation.

    Returns:
        Tuple (ctr, size, lg_var, lg_gamma, box, top_left, bot_right).
    """
    top_left, bot_right, box = get_gt_box(
        y_gt,
        padding_ratio=padding_ratio,
        center_shift_ratio=center_shift_ratio,
        min_padding=min_padding)
    ctr, size = get_box_ctr_size(top_left, bot_right)
    lg_var = get_normalized_var(size, filter_height, filter_width)
    lg_gamma = get_normalized_gamma(size, filter_height, filter_width)
    return ctr, size, lg_var, lg_gamma, box, top_left, bot_right


def get_gt_box(y_gt, padding_ratio=0.0, center_shift_ratio=0.0,
               min_padding=10.0):
    """Get groundtruth bounding box given segmentation.

    Current only support [B, T, H, W] as input!!!

    Args:
        y_gt: Groundtruth segmentation [B, T, H, W], or [B, H, W]
        padding_ratio: Padding as a fraction of the box size.
        center_shift_ratio: Shift of both corners as a fraction of the box
            size.
        min_padding: Lower bound on the padding added on each side.

    Returns:
        top_left: Bounding box top left coordinates [B, T, 2], or [B, 2]
        bot_right: Bounding box bottom right coordinates [B, T, 2], or [B, 2]
        box: Filled box mask computed from the padded coordinates.
    """
    s = tf.shape(y_gt)
    # [B, T, H, W, 2]
    idx = get_idx_map(s)
    y_gt_not_zero = tf.to_float(tf.reduce_sum(y_gt, [2, 3]) > 0)
    y_gt_not_zero = tf.expand_dims(y_gt_not_zero, 2)
    # Push background pixels beyond any valid index so they never win the min.
    idx_min = idx + tf.expand_dims((1.0 - y_gt) * tf.to_float(s[2] * s[3]), 4)
    idx_max = idx * tf.expand_dims(y_gt, 4)
    # [B, T, 2]
    top_left = tf.reduce_min(idx_min, reduction_indices=[2, 3])
    bot_right = tf.reduce_max(idx_max, reduction_indices=[2, 3])

    # Enlarge the groundtruth box by some padding.
    size = bot_right - top_left
    top_left += center_shift_ratio * size
    top_left -= tf.maximum(padding_ratio * size, min_padding)
    bot_right += center_shift_ratio * size
    bot_right += tf.maximum(padding_ratio * size, min_padding)
    box = get_filled_box_idx(idx, top_left, bot_right)

    # If the segmentation is zero, then fix to top left corner.
    top_left *= y_gt_not_zero
    bot_right = y_gt_not_zero * bot_right + \
        (1 - y_gt_not_zero) * (2 * min_padding)
    return top_left, bot_right, box


def get_idx_map(shape):
    """Get index map for a image.

    Args:
        shape: [B, T, H, W] or [B, H, W]

    Returns:
        idx: [B, T, H, W, 2], or [B, H, W, 2]
    """
    s = shape
    ndims = tf.shape(s)
    wdim = ndims - 1
    hdim = ndims - 2
    idx_shape = tf.concat(0, [s, tf.constant([1])])
    ones_h = tf.ones(hdim - 1, dtype='int32')
    ones_w = tf.ones(wdim - 1, dtype='int32')
    h_shape = tf.concat(0, [ones_h, tf.constant([-1]), tf.constant([1, 1])])
    w_shape = tf.concat(0, [ones_w, tf.constant([-1]), tf.constant([1])])
    idx_y = tf.zeros(idx_shape, dtype='float')
    idx_x = tf.zeros(idx_shape, dtype='float')
    h = tf.slice(s, ndims - 2, [1])
    w = tf.slice(s, ndims - 1, [1])
    # Broadcast the row / column indices over the zero tensors.
    idx_y += tf.reshape(tf.to_float(tf.range(h[0])), h_shape)
    idx_x += tf.reshape(tf.to_float(tf.range(w[0])), w_shape)
    idx = tf.concat(ndims[0], [idx_y, idx_x])
    return idx


def get_filled_box_idx(idx, top_left, bot_right):
    """Fill a box with top left and bottom right coordinates.

    Args:
        idx: [B, T, H, W, 2] or [B, H, W, 2] or [H, W, 2]
        top_left: [B, T, 2] or [B, 2] or [2]
        bot_right: [B, T, 2] or [B, 2] or [2]

    Returns:
        box: Mask that is 1 where both coordinates of `idx` lie within
            [top_left, bot_right] (inclusive) and 0 elsewhere.
    """
    ss = tf.shape(idx)
    ndims = tf.shape(ss)
    batch = tf.slice(ss, [0], ndims - 3)
    coord_shape = tf.concat(0, [batch, tf.constant([1, 1, 2])])
    top_left = tf.reshape(top_left, coord_shape)
    bot_right = tf.reshape(bot_right, coord_shape)
    lower = tf.reduce_prod(tf.to_float(idx >= top_left), ndims - 1)
    upper = tf.reduce_prod(tf.to_float(idx <= bot_right), ndims - 1)
    box = lower * upper
    return box


def get_unnormalized_center(ctr_norm, inp_height, inp_width):
    """Get unnormalized center coordinates

    Args:
        ctr_norm: [B, T, 2] or [B, 2] or [2], normalized within range [-1, +1]
        inp_height: int, image height
        inp_width: int, image width

    Returns:
        ctr: [B, 2]
    """
    img_size = tf.to_float(tf.pack([inp_height, inp_width]))
    img_size = img_size / 2.0
    ctr = (ctr_norm + 1.0) * img_size
    return ctr


def get_normalized_center(ctr, inp_height, inp_width):
    """Get normalized center coordinates

    Args:
        ctr: [B, T, 2] or [B, 2] or [2]
        inp_height: int, image height
        inp_width: int, image width

    Returns:
        ctr: [B, 2], normalized within range [-1, +1]
    """
    img_size = tf.to_float(tf.pack([inp_height, inp_width]))
    img_size = img_size / 2.0
    ctr = ctr / img_size - 1
    return ctr


def get_normalized_var(size, filter_height, filter_width):
    """Get normalized variance.

    Args:
        size: [B, T, 2] or [B, 2] or [2]
        filter_height: int
        filter_width: int

    Returns:
        lg_var: [B, T, 2] or [B, 2] or [2]
    """
    filter_size = tf.to_float(tf.pack([filter_height, filter_width]))
    lg_var = tf.log(size) - tf.log(filter_size)
    return lg_var


def get_normalized_gamma(size, filter_height, filter_width):
    """Get normalized gamma.

    Args:
        size: [B, T, 2] or [B, 2] or [2]
        filter_height: int
        filter_width: int

    Returns:
        lg_gamma: [B, T] or [B] or float
    """
    rank = tf.shape(tf.shape(size))
    filter_area = filter_height * filter_width
    area = tf.reduce_prod(size, rank - 1)
    lg_gamma = tf.log(float(filter_area)) - tf.log(area)
    return lg_gamma


def get_unnormalized_size(lg_size, inp_height, inp_width):
    """Get unnormalized patch size.

    Args:
        lg_size: [B, T, 2] or [B, 2] or [2], logarithm of delta.
        inp_height: int, image height.
        inp_width: int, image width.

    Returns:
        size: [B, T, 2] or [B, 2] or [2], patch size.
    """
    size = tf.exp(lg_size)
    img_size = tf.to_float(tf.pack([inp_height, inp_width]))
    size *= img_size
    return size


def get_normalized_size(size, inp_height, inp_width):
    """Get normalized patch size.

    Args:
        size: [B, 2], patch size.
        inp_height: int, image height.
        inp_width: int, image width.

    Returns:
        lg_delta: [B, 2], logarithm of delta.
    """
    img_size = tf.to_float(tf.pack([inp_height, inp_width]))
    lg_size = tf.log(size / img_size)
    return lg_size


def get_unnormalized_attn(ctr, lg_size, inp_height, inp_width):
    """Unnormalize the attention parameters to image size."""
    ctr = get_unnormalized_center(ctr, inp_height, inp_width)
    size = get_unnormalized_size(lg_size, inp_height, inp_width)
    return ctr, size


def get_box_coord(ctr, size, truncate=True):
    """Get box coordinates given parameters.

    Args:
        ctr: Box center.
        size: Box size.
        truncate: Unused.

    Returns:
        Tuple (top_left, bot_right) = (ctr - size / 2, ctr + size / 2).
    """
    return ctr - size / 2.0, ctr + size / 2.0


def get_box_ctr_size(top_left, bot_right):
    """Get box center and size given its two corners."""
    return (top_left + bot_right) / 2.0, (bot_right - top_left)