""" Define some tool functions Written by Heng Fan """ import torch import numpy as np import cv2 def create_logisticloss_label(label_size, rPos, rNeg): """ construct label for logistic loss (same for all pairs) """ label_side = int(label_size[0]) logloss_label = torch.zeros(label_side, label_side) label_origin = np.array([np.ceil(label_side / 2), np.ceil(label_side / 2)]) for i in range(label_side): for j in range(label_side): dist_from_origin = np.sqrt((i - label_origin[0]) ** 2 + (j - label_origin[1]) ** 2) if dist_from_origin <= rPos: logloss_label[i, j] = +1 else: if dist_from_origin <= rNeg: logloss_label[i, j] = 0 return logloss_label def create_label(fixed_label_size, config, use_gpu): """ create label with weight """ rPos = config.rPos / config.stride rNeg = config.rNeg / config.stride half = int(np.floor(fixed_label_size[0] / 2) + 1) if config.label_weight_method == "balanced": fixed_label = create_logisticloss_label(fixed_label_size, rPos, rNeg) # plt.imshow(fixed_label) # plt.colorbar() # plt.show() instance_weight = torch.ones(fixed_label.shape[0], fixed_label.shape[1]) tmp_idx_P = np.where(fixed_label == 1) sumP = tmp_idx_P[0].size tmp_idx_N = np.where(fixed_label == 0) sumN = tmp_idx_N[0].size instance_weight[tmp_idx_P] = 0.5 * instance_weight[tmp_idx_P] / sumP instance_weight[tmp_idx_N] = 0.5 * instance_weight[tmp_idx_N] / sumN # plt.imshow(instance_weight) # plt.colorbar() # plt.show() # reshape label fixed_label = torch.reshape(fixed_label, (1, 1, fixed_label.shape[0], fixed_label.shape[1])) # copy label to match batchsize fixed_label = fixed_label.repeat(config.batch_size, 1, 1, 1) # reshape weight instance_weight = torch.reshape(instance_weight, (1, instance_weight.shape[0], instance_weight.shape[1])) if use_gpu: return fixed_label.cuda(), instance_weight.cuda() else: return fixed_label, instance_weight def cv2_brg2rgb(bgr_img): """ convert brg image to rgb """ b, g, r = cv2.split(bgr_img) rgb_img = cv2.merge([r, g, b]) return rgb_img def float32_to_uint8(img): """ convert float32 array to uint8 """ beyong_255 = np.where(img > 255) img[beyong_255] = 255 less_0 = np.where(img < 0) img[less_0] = 0 img = np.round(img) return img.astype(np.uint8)