""" CFUN The main CFUN model implementation. """ import time import math import os import re import numpy as np import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim import torch.utils.data from torch.autograd import Variable import skimage.transform import utils import backbone import mask_branch ############################################################ # Logging Utility Functions ############################################################ def log(text, array=None): """Prints a text message. And, optionally, if a Numpy array is provided it prints it's shape, min, and max values. """ if array is not None: text = text.ljust(25) text += ("shape: {:20} min: {:10.5f} max: {:10.5f}".format( str(array.shape), array.min() if array.size else "", array.max() if array.size else "")) print(text) def print_progress_bar(iteration, total, prefix='', suffix='', decimals=1, length=100, fill=''): """Call in a loop to create terminal progress bar @params: iteration - Required : current iteration (Int) total - Required : total iterations (Int) prefix - Optional : prefix string (Str) suffix - Optional : suffix string (Str) decimals - Optional : positive number of decimals in percent complete (Int) length - Optional : character length of bar (Int) fill - Optional : bar fill character (Str) """ percent = ("{0:." + str(decimals) + "f}").format(100 * (iteration / float(total))) filled_length = int(length * iteration // total) bar = fill * filled_length + '-' * (length - filled_length) print('\r%s |%s| %s%% %s' % (prefix, bar, percent, suffix), end='\n') # Print New Line on Complete if iteration == total: print() ############################################################ # Pytorch Utility Functions ############################################################ def unique1d(tensor): if tensor.size()[0] == 0 or tensor.size()[0] == 1: return tensor tensor = tensor.sort()[0] unique_bool = tensor[1:] != tensor[:-1] first_element = Variable(torch.ByteTensor([True]), requires_grad=False) if tensor.is_cuda: first_element = first_element.cuda() unique_bool = torch.cat((first_element, unique_bool), dim=0) return tensor[unique_bool.detach()] def intersect1d(tensor1, tensor2): aux = torch.cat((tensor1, tensor2), dim=0) aux = aux.sort()[0] return aux[:-1][(aux[1:] == aux[:-1]).detach()] def log2(x): """Implementation of log2. Pytorch doesn't have a native implementation.""" ln2 = torch.log(torch.FloatTensor([2.0])) if x.is_cuda: ln2 = ln2.cuda() return torch.log(x) / ln2 def compute_backbone_shapes(config, image_shape): """Computes the depth, width and height of each stage of the backbone network. Returns: [N, (depth, height, width)]. Where N is the number of stages """ H, W, D = image_shape[:3] return np.array( [[int(math.ceil(D / stride)), int(math.ceil(H / stride)), int(math.ceil(W / stride))] for stride in config.BACKBONE_STRIDES]) ############################################################ # FPN Graph ############################################################ class FPN(nn.Module): def __init__(self, C1, C2, C3, out_channels, config): super(FPN, self).__init__() self.out_channels = out_channels self.C1 = C1 self.C2 = C2 self.C3 = C3 self.P3_conv1 = nn.Conv3d(config.BACKBONE_CHANNELS[1] * 4, self.out_channels, kernel_size=1, stride=1) self.P3_conv2 = nn.Conv3d(self.out_channels, self.out_channels, kernel_size=3, stride=1, padding=1) self.P2_conv1 = nn.Conv3d(config.BACKBONE_CHANNELS[0] * 4, self.out_channels, kernel_size=1, stride=1) self.P2_conv2 = nn.Conv3d(self.out_channels, self.out_channels, kernel_size=3, stride=1, padding=1) def forward(self, x): x = self.C1(x) x = self.C2(x) c2_out = x x = self.C3(x) c3_out = x p3_out = self.P3_conv1(c3_out) p2_out = self.P2_conv1(c2_out) + F.upsample(p3_out, scale_factor=2) p3_out = self.P3_conv2(p3_out) p2_out = self.P2_conv2(p2_out) return [p2_out, p3_out] ############################################################ # Proposal Layer ############################################################ def apply_box_deltas(boxes, deltas): """Applies the given deltas to the given boxes. boxes: [N, 6] where each row is z1, y1, x1, z2, y2, x2 deltas: [N, 6] where each row is [dz, dy, dx, log(dd), log(dh), log(dw)] """ # Convert to z, y, x, d, h, w depth = boxes[:, 3] - boxes[:, 0] height = boxes[:, 4] - boxes[:, 1] width = boxes[:, 5] - boxes[:, 2] center_z = boxes[:, 0] + 0.5 * depth center_y = boxes[:, 1] + 0.5 * height center_x = boxes[:, 2] + 0.5 * width # Apply deltas center_z += deltas[:, 0] * depth center_y += deltas[:, 1] * height center_x += deltas[:, 2] * width depth *= torch.exp(deltas[:, 3]) height *= torch.exp(deltas[:, 4]) width *= torch.exp(deltas[:, 5]) # Convert back to z1, y1, x1, z2, y2, x2 z1 = center_z - 0.5 * depth y1 = center_y - 0.5 * height x1 = center_x - 0.5 * width z2 = z1 + depth y2 = y1 + height x2 = x1 + width result = torch.stack([z1, y1, x1, z2, y2, x2], dim=1) return result def clip_boxes(boxes, window): """boxes: [N, 6] each col is z1, y1, x1, z2, y2, x2 window: [6] in the form z1, y1, x1, z2, y2, x2 """ boxes = torch.stack( [boxes[:, 0].clamp(float(window[0]), float(window[3])), boxes[:, 1].clamp(float(window[1]), float(window[4])), boxes[:, 2].clamp(float(window[2]), float(window[5])), boxes[:, 3].clamp(float(window[0]), float(window[3])), boxes[:, 4].clamp(float(window[1]), float(window[4])), boxes[:, 5].clamp(float(window[2]), float(window[5]))], 1) return boxes def proposal_layer(inputs, proposal_count, nms_threshold, anchors, config=None): """Receives anchor scores and selects a subset to pass as proposals to the second stage. Filtering is done based on anchor scores and non-max suppression to remove overlaps. It also applies bounding box refinement deltas to anchors. Inputs: rpn_probs: [batch, anchors, (bg prob, fg prob)] rpn_bbox: [batch, anchors, (dz, dy, dx, log(dd), log(dh), log(dw))] Returns: Proposals in normalized coordinates [batch, rois, (z1, y1, x1, z2, y2, x2)] """ # Currently only supports batchsize 1 inputs[0] = inputs[0].squeeze(0) inputs[1] = inputs[1].squeeze(0) # Box Scores. Use the foreground class confidence. [Batch, num_rois, 1] scores = inputs[0][:, 1] # Box deltas [batch, num_rois, 6] deltas = inputs[1] std_dev = torch.from_numpy(np.reshape(config.RPN_BBOX_STD_DEV, [1, 6])).float() if config.GPU_COUNT: std_dev = std_dev.cuda() deltas = deltas * std_dev # Improve performance by trimming to top anchors by score # and doing the rest on the smaller subset. pre_nms_limit = min(config.PRE_NMS_LIMIT, anchors.size()[0]) scores, order = scores.sort(descending=True) order = order[:pre_nms_limit] scores = scores[:pre_nms_limit] deltas = deltas[order.detach(), :] anchors = anchors[order.detach(), :] # Apply deltas to anchors to get refined anchors. # [batch, N, (z1, y1, x1, z2, y2, x2)] boxes = apply_box_deltas(anchors, deltas) # Clip to image boundaries. [batch, N, (z1, y1, x1, z2, y2, x2)] height, width, depth = config.IMAGE_SHAPE[:3] window = np.array([0, 0, 0, depth, height, width]).astype(np.float32) boxes = clip_boxes(boxes, window) # Non-max suppression keep = utils.non_max_suppression(boxes.cpu().detach().numpy(), scores.cpu().detach().numpy(), nms_threshold, proposal_count) keep = torch.from_numpy(keep).long() boxes = boxes[keep, :] # Normalize dimensions to range of 0 to 1. norm = torch.from_numpy(np.array([depth, height, width, depth, height, width])).float() if config.GPU_COUNT: norm = norm.cuda() normalized_boxes = boxes / norm # Add back batch dimension normalized_boxes = normalized_boxes.unsqueeze(0) return normalized_boxes ############################################################ # ROIAlign Layer ############################################################ def RoI_Align(feature_map, pool_size, boxes): """Implementation of 3D RoI Align (actually it's just pooling rather than align). feature_map: [channels, depth, height, width]. Generated from FPN. pool_size: [D, H, W]. The shape of the output. boxes: [num_boxes, (z1, y1, x1, z2, y2, x2)]. """ boxes = utils.denorm_boxes_graph(boxes, (feature_map.size()[1], feature_map.size()[2], feature_map.size()[3])) boxes[:, 0] = boxes[:, 0].floor() boxes[:, 1] = boxes[:, 1].floor() boxes[:, 2] = boxes[:, 2].floor() boxes[:, 3] = boxes[:, 3].ceil() boxes[:, 4] = boxes[:, 4].ceil() boxes[:, 5] = boxes[:, 5].ceil() boxes = boxes.long() output = torch.zeros((boxes.size()[0], feature_map.size()[0], pool_size[0], pool_size[1], pool_size[2])).cuda() for i in range(boxes.size()[0]): try: output[i] = F.interpolate((feature_map[:, boxes[i][0]:boxes[i][3], boxes[i][1]:boxes[i][4], boxes[i][2]:boxes[i][5]]).unsqueeze(0), size=pool_size, mode='trilinear', align_corners=True).cuda() except: # print("RoI_Align error!") # print("box:", boxes[i], "feature_map size:", feature_map.size()) pass return output.cuda() def pyramid_roi_align(inputs, pool_size, test_flag=False): """Implements ROI Pooling on multiple levels of the feature pyramid. Params: - pool_size: [depth, height, width] of the output pooled regions. Usually [7, 7, 7] - image_shape: [height, width, depth, channels]. Shape of input image in pixels Inputs: - boxes: [batch, num_boxes, (z1, y1, x1, z2, y2, x2)] in normalized coordinates. - Feature maps: List of feature maps from different levels of the pyramid. Each is [batch, channels, depth, height, width] Output: Pooled regions in the shape: [num_boxes, channels, depth, height, width]. The width, height and depth are those specific in the pool_shape in the layer constructor. """ # Currently only supports batchsize 1 if test_flag: for i in range(0, len(inputs)): inputs[i] = inputs[i].squeeze(0) else: for i in range(1, len(inputs)): inputs[i] = inputs[i].squeeze(0) # Crop boxes [batch, num_boxes, (y1, x1, z1, y2, x2, z2)] in normalized coordinates boxes = inputs[0] # Feature Maps. List of feature maps from different level of the # feature pyramid. Each is [batch, channels, depth, height, width] feature_maps = inputs[1:] # Assign each ROI to a level in the pyramid based on the ROI volume. z1, y1, x1, z2, y2, x2 = boxes.chunk(6, dim=1) d = z2 - z1 h = y2 - y1 w = x2 - x1 # Equation 1 in the Feature Pyramid Networks paper. # Account for the fact that our coordinates are normalized here. # TODO: change the equation here roi_level = 4 + (1. / 3.) * log2(h * w * d) roi_level = roi_level.round().int() roi_level = roi_level.clamp(2, 3) # Loop through levels and apply ROI pooling to P2 or P3. pooled = [] box_to_level = [] for i, level in enumerate(range(2, 4)): ix = (roi_level == level) if not ix.any(): continue ix = torch.nonzero(ix)[:, 0] level_boxes = boxes[ix.detach(), :] # Keep track of which box is mapped to which level box_to_level.append(ix.detach()) # Stop gradient propagation to ROI proposals level_boxes = level_boxes.detach() # Crop and Resize # From Mask R-CNN paper: "We sample four regular locations, so that we can evaluate # either max or average pooling. In fact, interpolating only a single value at each bin center # (without pooling) is nearly as effective." # Here we use the simplified approach of a single value per bin. # Result: [batch * num_boxes, channels, pool_depth, pool_height, pool_width] pooled_features = RoI_Align(feature_maps[i], pool_size, level_boxes) pooled.append(pooled_features) # Pack pooled features into one tensor pooled = torch.cat(pooled, dim=0) # Pack box_to_level mapping into one array and add another # column representing the order of pooled boxes box_to_level = torch.cat(box_to_level, dim=0) # Rearrange pooled features to match the order of the original boxes _, box_to_level = torch.sort(box_to_level) pooled = pooled[box_to_level, :, :, :] return pooled ############################################################ # Detection Target Layer ############################################################ def bbox_overlaps(boxes1, boxes2): """Computes IoU overlaps between two sets of boxes. boxes1, boxes2: [N, (z1, y1, x1, z2, y2, x2)]. """ # 1. Tile boxes2 and repeat boxes1. This allows us to compare # every boxes1 against every boxes2 without loops. boxes1_repeat = boxes2.size()[0] boxes2_repeat = boxes1.size()[0] boxes1 = boxes1.repeat(1, boxes1_repeat).view(-1, 6) boxes2 = boxes2.repeat(boxes2_repeat, 1) # 2. Compute intersections b1_z1, b1_y1, b1_x1, b1_z2, b1_y2, b1_x2 = boxes1.chunk(6, dim=1) b2_z1, b2_y1, b2_x1, b2_z2, b2_y2, b2_x2 = boxes2.chunk(6, dim=1) z1 = torch.max(b1_z1, b2_z1)[:, 0] y1 = torch.max(b1_y1, b2_y1)[:, 0] x1 = torch.max(b1_x1, b2_x1)[:, 0] z2 = torch.min(b1_z2, b2_z2)[:, 0] y2 = torch.min(b1_y2, b2_y2)[:, 0] x2 = torch.min(b1_x2, b2_x2)[:, 0] zeros = Variable(torch.zeros(z1.size()[0]), requires_grad=False) if z1.is_cuda: zeros = zeros.cuda() intersection = torch.max(x2 - x1, zeros) * torch.max(y2 - y1, zeros) * torch.max(z2 - z1, zeros) # 3. Compute unions b1_volume = (b1_z2 - b1_z1) * (b1_y2 - b1_y1) * (b1_x2 - b1_x1) b2_volume = (b2_z2 - b2_z1) * (b2_y2 - b2_y1) * (b2_x2 - b2_x1) union = b1_volume[:, 0] + b2_volume[:, 0] - intersection # 4. Compute IoU and reshape to [boxes1, boxes2] iou = intersection / union overlaps = iou.view(boxes2_repeat, boxes1_repeat) return overlaps def detection_target_layer(proposals, gt_class_ids, gt_boxes, gt_masks, config): """Subsamples proposals and generates target box refinement, class_ids, and masks for each. Inputs: proposals: [batch, N, (z1, y1, x1, z2, y2, x2)] in normalized coordinates. Might be zero padded if there are not enough proposals. gt_class_ids: [batch, (0, 1, ..., num_classes - 1)] Integer class IDs. gt_boxes: [batch, num_classes - 1, (z1, y1, x1, z2, y2, x2)] in normalized coordinates. gt_masks: [batch, num_classes, depth, height, width] of np.int32 type Returns: Target ROIs and corresponding class IDs, bounding box shifts, and masks. rois: [batch, TRAIN_ROIS_PER_IMAGE, (z1, y1, x1, z2, y2, x2)] in normalized coordinates target_class_ids: [batch, TRAIN_ROIS_PER_IMAGE]. Integer class IDs. target_deltas: [batch, TRAIN_ROIS_PER_IMAGE, NUM_CLASSES, (dz, dy, dx, log(dd), log(dh), log(dw), class_id)] Class-specific bbox refinements. target_mask: [batch, TRAIN_ROIS_PER_IMAGE, depth, height, width) Masks cropped to bbox boundaries and resized to neural network output size. """ # import pdb # Currently only supports batchsize 1 proposals = proposals.squeeze(0) gt_class_ids = gt_class_ids.squeeze(0) gt_boxes = gt_boxes.squeeze(0) gt_masks = gt_masks.squeeze(0) # Compute overlaps matrix [proposals, gt_boxes] overlaps = bbox_overlaps(proposals, gt_boxes) # pdb.set_trace() # Determine positive and negative ROIs roi_iou_max = torch.max(overlaps, dim=1)[0] # print("rpn_roi_iou_max:", roi_iou_max.max()) # 1. Positive ROIs are those with >= 0.5 IoU with a GT box positive_roi_bool = roi_iou_max >= config.DETECTION_TARGET_IOU_THRESHOLD # Subsample ROIs. Aim for 33% positive # Positive ROIs if torch.nonzero(positive_roi_bool).size()[0] != 0: positive_indices = torch.nonzero(positive_roi_bool)[:, 0] positive_count = int(round(config.TRAIN_ROIS_PER_IMAGE * config.ROI_POSITIVE_RATIO)) rand_idx = torch.randperm(positive_indices.size()[0]) rand_idx = rand_idx[:positive_count] if config.GPU_COUNT: rand_idx = rand_idx.cuda() positive_indices = positive_indices[rand_idx] positive_count = positive_indices.size()[0] positive_rois = proposals[positive_indices.detach(), :] # Assign positive ROIs to GT boxes. positive_overlaps = overlaps[positive_indices.detach(), :] roi_gt_box_assignment = torch.max(positive_overlaps, dim=1)[1] roi_gt_boxes = gt_boxes[roi_gt_box_assignment.detach(), :] roi_gt_class_ids = gt_class_ids[roi_gt_box_assignment.detach()] # Compute bbox refinement for positive ROIs deltas = Variable(utils.box_refinement(positive_rois.detach(), roi_gt_boxes.detach()), requires_grad=False) std_dev = torch.from_numpy(config.BBOX_STD_DEV).float() if config.GPU_COUNT: std_dev = std_dev.cuda() deltas /= std_dev # Assign positive ROIs to GT masks # Permute masks to [N, depth, height, width] # Pick the right mask for each ROI roi_gt_masks = np.zeros((positive_rois.shape[0], gt_masks.shape[0],) + config.MASK_SHAPE) for i in range(0, positive_rois.shape[0]): z1 = int(gt_masks.shape[1] * positive_rois[i, 0]) z2 = int(gt_masks.shape[1] * positive_rois[i, 3]) y1 = int(gt_masks.shape[2] * positive_rois[i, 1]) y2 = int(gt_masks.shape[2] * positive_rois[i, 4]) x1 = int(gt_masks.shape[3] * positive_rois[i, 2]) x2 = int(gt_masks.shape[3] * positive_rois[i, 5]) crop_mask = gt_masks[:, z1:z2, y1:y2, x1:x2].cpu().numpy() crop_mask = skimage.transform.resize(crop_mask, (gt_masks.shape[0],) + config.MASK_SHAPE, order=0, preserve_range=True, mode="constant", anti_aliasing=False) roi_gt_masks[i, :, :, :, :] = crop_mask roi_gt_masks = torch.from_numpy(roi_gt_masks).cuda() roi_gt_masks = roi_gt_masks.type(torch.DoubleTensor) masks = roi_gt_masks else: positive_count = 0 # 2. Negative ROIs are those with < 0.5 with every GT box. negative_roi_bool = roi_iou_max < config.DETECTION_TARGET_IOU_THRESHOLD negative_roi_bool = negative_roi_bool # Negative ROIs. Add enough to maintain positive:negative ratio. if torch.nonzero(negative_roi_bool).size()[0] != 0 and positive_count > 0: negative_indices = torch.nonzero(negative_roi_bool)[:, 0] r = 1.0 / config.ROI_POSITIVE_RATIO negative_count = int(round(r * positive_count - positive_count)) rand_idx = torch.randperm(negative_indices.size()[0]) rand_idx = rand_idx[:negative_count] if config.GPU_COUNT: rand_idx = rand_idx.cuda() negative_indices = negative_indices[rand_idx] negative_count = negative_indices.size()[0] negative_rois = proposals[negative_indices.detach(), :] else: negative_count = 0 # Append negative ROIs and pad bbox deltas and masks that # are not used for negative ROIs with zeros. if positive_count > 0 and negative_count > 0: rois = torch.cat((positive_rois, negative_rois), dim=0) zeros = Variable(torch.zeros(negative_count), requires_grad=False).long() if config.GPU_COUNT: zeros = zeros.cuda() roi_gt_class_ids = torch.cat([roi_gt_class_ids.long(), zeros], dim=0) zeros = Variable(torch.zeros(negative_count, 6), requires_grad=False) if config.GPU_COUNT: zeros = zeros.cuda() deltas = torch.cat([deltas, zeros], dim=0) zeros = Variable(torch.zeros((negative_count,) + config.MASK_SHAPE), requires_grad=False) if config.GPU_COUNT: zeros = zeros.cuda() masks = roi_gt_masks elif positive_count > 0: rois = positive_rois elif negative_count > 0: positive_rois = Variable(torch.FloatTensor(), requires_grad=False) rois = negative_rois zeros = Variable(torch.zeros(negative_count), requires_grad=False) if config.GPU_COUNT: zeros = zeros.cuda() positive_rois = positive_rois.cuda() roi_gt_class_ids = zeros zeros = Variable(torch.zeros(negative_count, 6), requires_grad=False).int() if config.GPU_COUNT: zeros = zeros.cuda() deltas = zeros zeros = Variable(torch.zeros((negative_count,) + config.MASK_SHAPE), requires_grad=False) if config.GPU_COUNT: zeros = zeros.cuda() masks = zeros else: positive_rois = Variable(torch.FloatTensor(), requires_grad=False) rois = Variable(torch.FloatTensor(), requires_grad=False) roi_gt_class_ids = Variable(torch.IntTensor(), requires_grad=False) deltas = Variable(torch.FloatTensor(), requires_grad=False) masks = Variable(torch.FloatTensor(), requires_grad=False) if config.GPU_COUNT: positive_rois = positive_rois.cuda() rois = rois.cuda() roi_gt_class_ids = roi_gt_class_ids.cuda() deltas = deltas.cuda() masks = masks.cuda() return positive_rois, rois, roi_gt_class_ids, deltas, masks ############################################################ # Detection Layer ############################################################ def clip_to_window(window, boxes): """window: (z1, y1, x1, z2, y2, x2). The window in the image we want to clip to. boxes: [N, (z1, y1, x1, z2, y2, x2)] """ boxes[:, 0] = boxes[:, 0].clamp(float(window[0]), float(window[3])) boxes[:, 1] = boxes[:, 1].clamp(float(window[1]), float(window[4])) boxes[:, 2] = boxes[:, 2].clamp(float(window[2]), float(window[5])) boxes[:, 3] = boxes[:, 3].clamp(float(window[0]), float(window[3])) boxes[:, 4] = boxes[:, 4].clamp(float(window[1]), float(window[4])) boxes[:, 5] = boxes[:, 5].clamp(float(window[2]), float(window[5])) return boxes def refine_detections(rois, probs, deltas, window, config): """Refine classified proposals and filter overlaps and return final detections. Inputs: rois: [N, (z1, y1, x1, z2, y2, x2)] in normalized coordinates probs: [N, num_classes]. Class probabilities. deltas: [N, num_classes, (dz, dy, dx, log(dd), log(dh), log(dw))]. Class-specific bounding box deltas. window: (z1, y1, x1, z2, y2, x2) in image coordinates. The part of the image that contains the image excluding the padding. Returns detections shaped: [N, (z1, y1, x1, z2, y2, x2, class_id, score)] """ # import pdb # Class IDs per ROI _, class_ids = torch.max(probs, dim=1) # pdb.set_trace() # Class probability of the top class of each ROI # Class-specific bounding box deltas idx = torch.arange(class_ids.size()[0]).long() if config.GPU_COUNT: idx = idx.cuda() class_scores = probs[idx, class_ids.detach()] deltas_specific = deltas[idx, class_ids.detach()] # pdb.set_trace() # Apply bounding box deltas # Shape: [boxes, (z1, y1, x1, z2, y2, x2)] in normalized coordinates std_dev = torch.from_numpy(np.reshape(config.RPN_BBOX_STD_DEV, [1, 6])).float() if config.GPU_COUNT: std_dev = std_dev.cuda() refined_rois = apply_box_deltas(rois, deltas_specific * std_dev) # pdb.set_trace() # Convert coordinates to image domain height, width, depth = config.IMAGE_SHAPE[:3] scale = torch.from_numpy(np.array([depth, height, width, depth, height, width])).float() if config.GPU_COUNT: scale = scale.cuda() refined_rois *= scale # pdb.set_trace() # Clip boxes to image window refined_rois = clip_to_window(window, refined_rois) # Round and cast to int since we're dealing with pixels now refined_rois = torch.round(refined_rois) # Filter out background boxes keep_bool = class_ids > 0 # Filter out low confidence boxes if config.DETECTION_MIN_CONFIDENCE: keep_bool = keep_bool & (class_scores >= config.DETECTION_MIN_CONFIDENCE) keep = torch.nonzero(keep_bool)[:, 0] # pdb.set_trace() # Apply per-class NMS pre_nms_class_ids = class_ids[keep.detach()] pre_nms_scores = class_scores[keep.detach()] pre_nms_rois = refined_rois[keep.detach()] # pdb.set_trace() for i, class_id in enumerate(unique1d(pre_nms_class_ids)): # Pick detections of this class ixs = torch.nonzero(pre_nms_class_ids == class_id)[:, 0] # Sort ix_rois = pre_nms_rois[ixs.detach()] ix_scores = pre_nms_scores[ixs] ix_scores, order = ix_scores.sort(descending=True) ix_rois = ix_rois[order.detach(), :] class_keep = utils.non_max_suppression(ix_rois.cpu().detach().numpy(), ix_scores.cpu().detach().numpy(), config.DETECTION_NMS_THRESHOLD, config.DETECTION_MAX_INSTANCES) class_keep = torch.from_numpy(class_keep).long() # pdb.set_trace() # Map indices class_keep = keep[ixs[order[class_keep].detach()].detach()] # pdb.set_trace() if i == 0: nms_keep = class_keep else: nms_keep = unique1d(torch.cat((nms_keep, class_keep))) keep = intersect1d(keep, nms_keep) # Keep top detections roi_count = config.DETECTION_MAX_INSTANCES roi_count = min(roi_count, keep.size()[0]) # pdb.set_trace() top_ids = class_scores[keep.detach()].sort(descending=True)[1][:roi_count] keep = keep[top_ids.detach()] # Arrange output as [N, (z1, y1, x1, z2, y2, x2, class_id, score)] # Coordinates are in image domain. result = torch.cat((refined_rois[keep.detach()], class_ids[keep.detach()].unsqueeze(1).float(), class_scores[keep.detach()].unsqueeze(1)), dim=1) return result def detection_layer(config, rois, mrcnn_class, mrcnn_bbox, image_meta): """Takes classified proposal boxes and their bounding box deltas and returns the final detection boxes. Returns: [batch, num_detections, (z1, y1, x1, z2, y2, x2, class_score)] in pixels """ # Currently only supports batchsize 1 rois = rois.squeeze(0) _, _, window, _ = parse_image_meta(image_meta) window = window[0] detections = refine_detections(rois, mrcnn_class, mrcnn_bbox, window, config) return detections ############################################################ # Region Proposal Network ############################################################ class RPN(nn.Module): """Builds the model of Region Proposal Network. anchors_per_location: number of anchors per pixel in the feature map anchor_stride: Controls the density of anchors. Typically 1 (anchors for every pixel in the feature map), or 2 (every other pixel). Returns: rpn_logits: [batch, D, H, W, 2] Anchor classifier logits (before softmax) rpn_probs: [batch, D, H, W, 2] Anchor classifier probabilities. rpn_bbox: [batch, D, H, W, (dz, dy, dx, log(dd), log(dh), log(dw))] Deltas to be applied to anchors. """ def __init__(self, anchors_per_location, anchor_stride, channel, conv_channel): super(RPN, self).__init__() self.conv_shared = nn.Conv3d(channel, conv_channel, kernel_size=3, stride=anchor_stride, padding=1) self.relu = nn.ReLU(inplace=True) self.conv_class = nn.Conv3d(conv_channel, 2 * anchors_per_location, kernel_size=1, stride=1) self.softmax = nn.Softmax(dim=2) self.conv_bbox = nn.Conv3d(conv_channel, 6 * anchors_per_location, kernel_size=1, stride=1) def forward(self, x): # Shared convolutional base of the RPN x = self.relu(self.conv_shared(x)) # Anchor Score. [batch, anchors per location * 2, depth, height, width]. rpn_class_logits = self.conv_class(x) # Reshape to [batch, anchors, 2] rpn_class_logits = rpn_class_logits.permute(0, 2, 3, 4, 1) rpn_class_logits = rpn_class_logits.contiguous() rpn_class_logits = rpn_class_logits.view(x.size()[0], -1, 2) # Softmax on last dimension of BG/FG. rpn_probs = self.softmax(rpn_class_logits) # Bounding box refinement. [batch, anchors per location * 6, D, H, W] # where 6 == delta [z, y, x, log(d), log(h), log(w)] rpn_bbox = self.conv_bbox(x) # Reshape to [batch, anchors, 6] rpn_bbox = rpn_bbox.permute(0, 2, 3, 4, 1) rpn_bbox = rpn_bbox.contiguous() rpn_bbox = rpn_bbox.view(x.size()[0], -1, 6) return [rpn_class_logits, rpn_probs, rpn_bbox] ############################################################ # Feature Pyramid Network Heads ############################################################ class Classifier(nn.Module): def __init__(self, channel, pool_size, image_shape, num_classes, fc_size, test_flag=False): super(Classifier, self).__init__() self.pool_size = pool_size self.image_shape = image_shape self.fc_size = fc_size self.test_flag = test_flag self.conv1 = nn.Conv3d(channel, fc_size, kernel_size=self.pool_size, stride=1) self.bn1 = nn.BatchNorm3d(fc_size, eps=0.001, momentum=0.01) self.conv2 = nn.Conv3d(fc_size, fc_size, kernel_size=1, stride=1) self.bn2 = nn.BatchNorm3d(fc_size, eps=0.001, momentum=0.01) self.relu = nn.ReLU(inplace=True) self.linear_class = nn.Linear(fc_size, num_classes) self.softmax = nn.Softmax(dim=1) self.linear_bbox = nn.Linear(fc_size, num_classes * 6) def forward(self, x, rois): x = pyramid_roi_align([rois] + x, self.pool_size, self.test_flag) x = self.conv1(x) x = self.bn1(x) x = self.relu(x) x = self.conv2(x) x = self.bn2(x) x = self.relu(x) x = x.view(-1, self.fc_size) mrcnn_class_logits = self.linear_class(x) mrcnn_probs = self.softmax(mrcnn_class_logits) mrcnn_bbox = self.linear_bbox(x) mrcnn_bbox = mrcnn_bbox.view(mrcnn_bbox.size()[0], -1, 6) return [mrcnn_class_logits, mrcnn_probs, mrcnn_bbox] class Mask(nn.Module): def __init__(self, channel, pool_size, num_classes, conv_channel, stage, test_flag=False): super(Mask, self).__init__() self.pool_size = pool_size self.test_flag = test_flag self.modified_u_net = mask_branch.Modified3DUNet(channel, num_classes, stage, conv_channel) self.softmax = nn.Softmax(dim=1) def forward(self, x, rois): x = pyramid_roi_align([rois] + x, self.pool_size, self.test_flag) x = self.modified_u_net(x) output = self.softmax(x) return x, output ############################################################ # Loss Functions ############################################################ def compute_rpn_class_loss(rpn_match, rpn_class_logits): """RPN anchor classifier loss. rpn_match: [batch, anchors, 1]. Anchor match type. 1=positive, -1=negative, 0=neutral anchor. rpn_class_logits: [batch, anchors, 2]. RPN classifier logits for FG/BG. """ # Squeeze last dim to simplify rpn_match = rpn_match.squeeze(2) # Get anchor classes. Convert the -1/+1 match to 0/1 values. anchor_class = (rpn_match == 1).long() # Positive and Negative anchors contribute to the loss, # but neutral anchors (match value = 0) don't. indices = torch.nonzero(rpn_match != 0) # Pick rows that contribute to the loss and filter out the rest. rpn_class_logits = rpn_class_logits[indices.detach()[:, 0], indices.detach()[:, 1], :] anchor_class = anchor_class[indices.detach()[:, 0], indices.detach()[:, 1]] # print("rpn size:", rpn_class_logits.shape) # Cross-entropy loss loss = F.cross_entropy(rpn_class_logits, anchor_class) return loss def compute_rpn_bbox_loss(target_bbox, rpn_match, rpn_bbox): """Return the RPN bounding box loss graph. target_bbox: [batch, max positive anchors, (dz, dy, dx, log(dd), log(dh), log(dw))]. Uses 0 padding to fill in unused bbox deltas. rpn_match: [batch, anchors, 1]. Anchor match type. 1=positive, -1=negative, 0=neutral anchor. rpn_bbox: [batch, anchors, (dz, dy, dx, log(dd), log(dh), log(dw))] """ # Squeeze last dim to simplify rpn_match = rpn_match.squeeze(2) # Positive anchors contribute to the loss, but negative and # neutral anchors (match value of 0 or -1) don't. indices = torch.nonzero(rpn_match == 1) # Pick bbox deltas that contribute to the loss rpn_bbox = rpn_bbox[indices.detach()[:, 0], indices.detach()[:, 1]] # Trim target bounding box deltas to the same length as rpn_bbox. target_bbox = target_bbox[0, :rpn_bbox.size()[0], :] # Smooth L1 loss loss = F.smooth_l1_loss(rpn_bbox, target_bbox) return loss def compute_mrcnn_class_loss(target_class_ids, pred_class_logits): """Loss for the classifier head of Mask RCNN. target_class_ids: [batch, num_rois]. Integer class IDs. Uses zero padding to fill in the array. pred_class_logits: [batch, num_rois, num_classes] """ # Loss if target_class_ids.size()[0] != 0: loss = F.cross_entropy(pred_class_logits, target_class_ids.long()) else: loss = Variable(torch.FloatTensor([0]), requires_grad=False) if target_class_ids.is_cuda: loss = loss.cuda() return loss def compute_mrcnn_bbox_loss(target_bbox, target_class_ids, pred_bbox): """Loss for Mask R-CNN bounding box refinement. target_bbox: [batch, num_rois, (dz, dy, dx, log(dd), log(dh), log(dw))] target_class_ids: [batch, num_rois]. Integer class IDs. pred_bbox: [batch, num_rois, num_classes, (dz, dy, dx, log(dd), log(dh), log(dw))] """ if target_class_ids.size()[0] != 0: # Only positive ROIs contribute to the loss. And only # the right class_id of each ROI. Get their indices. positive_roi_ix = torch.nonzero(target_class_ids > 0)[:, 0] positive_roi_class_ids = target_class_ids[positive_roi_ix.detach()].long() indices = torch.stack((positive_roi_ix, positive_roi_class_ids), dim=1) # Gather the deltas (predicted and true) that contribute to loss target_bbox = target_bbox[indices[:, 0].detach(), :] pred_bbox = pred_bbox[indices[:, 0].detach(), indices[:, 1].detach(), :] # Smooth L1 loss loss = F.smooth_l1_loss(pred_bbox, target_bbox) else: loss = Variable(torch.FloatTensor([0]), requires_grad=False) if target_class_ids.is_cuda: loss = loss.cuda() return loss def compute_mrcnn_mask_loss(target_masks, target_class_ids, pred_masks): """Mask binary cross-entropy loss for the masks head. target_masks: [batch, num_rois, depth, height, width]. A float32 tensor of values 0 or 1. Uses zero padding to fill array. target_class_ids: [batch, num_rois]. Integer class IDs. Zero padded. pred_masks: [batch, proposals, num_classes, depth, height, width] float32 tensor with values from 0 to 1. """ if target_class_ids.size()[0] != 0: # Only positive ROIs contribute to the loss. And only the class specific mask of each ROI. positive_ix = torch.nonzero(target_class_ids > 0)[:, 0] positive_class_ids = target_class_ids[positive_ix.detach()].long() indices = torch.stack((positive_ix, positive_class_ids), dim=1) # Gather the masks (predicted and true) that contribute to loss y_true_ = target_masks[indices[:, 0], :, :, :] y_true = y_true_.long().cuda() y_true = torch.argmax(y_true, dim=1) y_pred = pred_masks[indices[:, 0].detach(), :, :, :, :] # Binary cross entropy loss_fn = nn.CrossEntropyLoss(weight=torch.FloatTensor([1., 1., 100.]).cuda()).cuda() loss = loss_fn(y_pred, y_true) else: loss = Variable(torch.FloatTensor([0]), requires_grad=False) if target_class_ids.is_cuda: loss = loss.cuda() return loss def compute_mrcnn_mask_edge_loss(target_masks, target_class_ids, pred_masks): """Mask edge mean square error loss for the Edge Agreement Head. Here I use the Sobel kernel without smoothing the ground_truth masks. target_masks: [batch, num_rois, depth, height, width]. target_class_ids: [batch, num_rois]. Integer class IDs. Zero padded. pred_masks: [batch, proposals, num_classes, depth, height, width] float32 tensor with values from 0 to 1. """ if target_class_ids.size()[0] != 0: # Generate the xyz dimension Sobel kernels kernel_x = np.array([[[1, 2, 1], [0, 0, 0], [-1, -2, -1]], [[2, 4, 2], [0, 0, 0], [-2, -4, -2]], [[1, 2, 1], [0, 0, 0], [-1, -2, -1]]]) kernel_y = kernel_x.transpose((1, 0, 2)) kernel_z = kernel_x.transpose((0, 2, 1)) kernel = torch.from_numpy(np.array([kernel_x, kernel_y, kernel_z]).reshape((3, 1, 3, 3, 3))).float().cuda() # Only positive ROIs contribute to the loss. And only the class specific mask of each ROI. positive_ix = torch.nonzero(target_class_ids > 0)[:, 0] positive_class_ids = target_class_ids[positive_ix.detach()].long() indices = torch.stack((positive_ix, positive_class_ids), dim=1) # Gather the masks (predicted and true) that contribute to loss y_true = target_masks[:indices.size()[0], 1:, :, :] y_pred = pred_masks[indices[:, 0].detach(), 1:, :, :, :] # Implement the edge detection convolution loss_fn = nn.MSELoss() loss = torch.FloatTensor([0]).cuda() for i in range(indices.size()[0]): # only compute the tumor's edge y_true_ = y_true[i] y_pred_ = y_pred[i].unsqueeze(0) # [N, 2, 64, 64, 64] for j in range(y_true_.shape[0]): y_true_final = F.conv3d(y_true_[j, :, :, :].unsqueeze(0).unsqueeze(0).cuda().float(), kernel) y_pred_final = F.conv3d(y_pred_[:, j, :, :, :].unsqueeze(1), kernel) # y_true_final = torch.sqrt(torch.pow(y_true_final[:, 0], 2) + torch.pow(y_true_final[:, 1], 2) + # torch.pow(y_true_final[:, 0], 2)) # y_pred_final = torch.sqrt(torch.pow(y_pred_final [:, 0], 2) + torch.pow(y_pred_final [:, 1], 2) + # torch.pow(y_pred_final [:, 0], 2)) # Mean Square Error loss += loss_fn(y_pred_final, y_true_final) # import pdb # pdb.set_trace() loss /= indices.size()[0] else: loss = Variable(torch.FloatTensor([0]), requires_grad=False).cuda() return loss def compute_losses(rpn_match, rpn_bbox, rpn_class_logits, rpn_pred_bbox, target_class_ids, mrcnn_class_logits, target_deltas, mrcnn_bbox, target_mask, mrcnn_mask, mrcnn_mask_logits, stage): if stage == "beginning": rpn_class_loss = compute_rpn_class_loss(rpn_match, rpn_class_logits) rpn_bbox_loss = compute_rpn_bbox_loss(rpn_bbox, rpn_match, rpn_pred_bbox) mrcnn_class_loss = compute_mrcnn_class_loss(torch.from_numpy(np.where(target_class_ids > 0, 1, 0)).cuda(), mrcnn_class_logits) mrcnn_bbox_loss = compute_mrcnn_bbox_loss(target_deltas, torch.from_numpy(np.where(target_class_ids > 0, 1, 0)).cuda(), mrcnn_bbox) mrcnn_mask_loss = Variable(torch.FloatTensor([0]), requires_grad=False).cuda() mrcnn_mask_edge_loss = Variable(torch.FloatTensor([0]), requires_grad=False).cuda() else: rpn_class_loss = Variable(torch.FloatTensor([0]), requires_grad=False).cuda() rpn_bbox_loss = Variable(torch.FloatTensor([0]), requires_grad=False).cuda() mrcnn_class_loss = Variable(torch.FloatTensor([0]), requires_grad=False).cuda() mrcnn_bbox_loss = Variable(torch.FloatTensor([0]), requires_grad=False).cuda() mrcnn_mask_loss = compute_mrcnn_mask_loss(target_mask, target_class_ids, mrcnn_mask_logits) mrcnn_mask_edge_loss = compute_mrcnn_mask_edge_loss(target_mask, target_class_ids, mrcnn_mask) return [rpn_class_loss, rpn_bbox_loss, mrcnn_class_loss, mrcnn_bbox_loss, mrcnn_mask_loss, mrcnn_mask_edge_loss] ############################################################ # Data Generator ############################################################ def load_image_gt(mask, config, anchors): """Generate the ground truth data for a mask. mask: [D, H, W] Returns: image: [1, D, H, W] class_ids: [instance_count] Integer class IDs bbox: [instance_count, (z1, y1, x1, z2, y2, x2)] mask: [num_classes, D, H, W] rpn_match: [batch, N] Integer (1=positive anchor, -1=negative, 0=neutral) rpn_bbox: [batch, N, (dz, dy, dx, log(dd), log(dh), log(dw))] Anchor bbox deltas """ # Bounding boxes: [num_instances, (z1, y1, x1, z2, y2, x2)] bbox = utils.extract_bboxes(mask) # we here use the whole liver + tumor as the gt-bbox bbox = utils.extend_bbox(bbox, mask.shape) # extend the gt_bbox with 5% ratio in each dimension bbox = np.tile(bbox, (config.NUM_CLASSES - 1, 1)) # [num_classes - 1, (z1, y1, x1, z2, y2, x2)] # RPN Targets rpn_match, rpn_bbox = build_rpn_targets(anchors, np.array([bbox[0]]), config) # Add to batch rpn_match = rpn_match[:, np.newaxis] return rpn_match, rpn_bbox, bbox def build_rpn_targets(anchors, gt_boxes, config): """Given the anchors and GT boxes, compute overlaps and identify positive anchors and deltas to refine them to match their corresponding GT boxes. anchors: [num_anchors, (z1, y1, x1, z2, y2, x2)] gt_class_ids: [num_gt_boxes] Integer class IDs. gt_boxes: [num_gt_boxes, (z1, y1, x1, z2, y2, x2)] Returns: rpn_match: [N] (int32) matches between anchors and GT boxes. 1 = positive anchor, -1 = negative anchor, 0 = neutral rpn_bbox: [N, (dz, dy, dx, log(dd), log(dh), log(dw))] Anchor bbox deltas. """ # RPN Match: 1 = positive anchor, -1 = negative anchor, 0 = neutral rpn_match = np.zeros([anchors.shape[0]], dtype=np.int32) # RPN bounding boxes: [max anchors per image, (dz, dy, dx, log(dd), log(dh), log(dw))] rpn_bbox = np.zeros((config.RPN_TRAIN_ANCHORS_PER_IMAGE, 6)) # Compute overlaps [num_anchors, num_gt_boxes] overlaps = utils.compute_overlaps(anchors, gt_boxes) # Match anchors to GT Boxes # If an anchor overlaps a GT box with IoU >= 0.7 then it's positive. # If an anchor overlaps a GT box with IoU < 0.3 then it's negative. # Neutral anchors are those that don't match the conditions above, and they don't influence the loss function. # However, don't keep any GT box unmatched (rare, but happens). # Instead, match it to the closest anchor (even if its max IoU is < 0.3). # 1. Set negative anchors first. They get overwritten below if a GT box is # matched to them. Skip boxes in crowd areas. anchor_iou_argmax = np.argmax(overlaps, axis=1) anchor_iou_max = overlaps[np.arange(overlaps.shape[0]), anchor_iou_argmax] rpn_match[anchor_iou_max < 0.3] = -1 # 2. Set an anchor for each GT box (regardless of IoU value). gt_iou_argmax = np.argmax(overlaps, axis=0) rpn_match[gt_iou_argmax] = 1 # 3. Set anchors with high overlap as positive. rpn_match[anchor_iou_max >= 0.7] = 1 # Subsample to balance positive and negative anchors # Don't let positives be more than half the anchors ids = np.where(rpn_match == 1)[0] extra = len(ids) - (config.RPN_TRAIN_ANCHORS_PER_IMAGE // 2) if extra > 0: # Reset the extra ones to neutral ids = np.random.choice(ids, extra, replace=False) rpn_match[ids] = 0 # Same for negative proposals ids = np.where(rpn_match == -1)[0] extra = len(ids) - (config.RPN_TRAIN_ANCHORS_PER_IMAGE - np.sum(rpn_match == 1)) if extra > 0: # Rest the extra ones to neutral ids = np.random.choice(ids, extra, replace=False) rpn_match[ids] = 0 # For positive anchors, compute shift and scale needed to transform them to match the corresponding GT boxes. ids = np.where(rpn_match == 1)[0] ix = 0 # index into rpn_bbox for i, a in zip(ids, anchors[ids]): # Closest gt box (it might have IoU < 0.7) gt = gt_boxes[anchor_iou_argmax[i]] # Convert coordinates to center plus width/height. # GT Box gt_d = gt[3] - gt[0] gt_h = gt[4] - gt[1] gt_w = gt[5] - gt[2] gt_center_z = gt[0] + 0.5 * gt_d gt_center_y = gt[1] + 0.5 * gt_h gt_center_x = gt[2] + 0.5 * gt_w # Anchor a_d = a[3] - a[0] a_h = a[4] - a[1] a_w = a[5] - a[2] a_center_z = a[0] + 0.5 * a_d a_center_y = a[1] + 0.5 * a_h a_center_x = a[2] + 0.5 * a_w # Compute the bbox refinement that the RPN should predict. rpn_bbox[ix] = [ (gt_center_z - a_center_z) / a_d, (gt_center_y - a_center_y) / a_h, (gt_center_x - a_center_x) / a_w, np.log(gt_d / a_d), np.log(gt_h / a_h), np.log(gt_w / a_w), ] # Normalize rpn_bbox[ix] /= config.RPN_BBOX_STD_DEV ix += 1 return rpn_match, rpn_bbox class Dataset(torch.utils.data.Dataset): def __init__(self, dataset, config, augmentation=False): """A data_generator that returns the image, mask and other ground truth for training.""" self.image_ids = np.copy(dataset.image_ids) self.dataset = dataset self.config = config self.augmentation = augmentation # Anchors # [anchor_count, (z1, y1, x1, z2, y2, x2)] self.anchors = utils.generate_pyramid_anchors(config.RPN_ANCHOR_SCALES, config.RPN_ANCHOR_RATIOS, compute_backbone_shapes(config, config.IMAGE_SHAPE), config.BACKBONE_STRIDES, config.RPN_ANCHOR_STRIDE) def __getitem__(self, image_index): image_id = self.image_ids[image_index] # Load image, which is [H, W, D] first. image = self.dataset.load_image(image_id) image = preprocess_image(image) # Load mask, which is [H, W, D] first. mask = self.dataset.load_mask(image_id) # np.int32 # Apply some augmentation image_shape = image.shape assert image_shape == mask.shape whole_image = np.zeros(self.config.PAD_IMAGE_SHAPE) whole_mask = np.zeros(self.config.PAD_IMAGE_SHAPE) start_x = int((whole_image.shape[0] - image_shape[0]) / 2.) start_y = int((whole_image.shape[1] - image_shape[1]) / 2.) start_z = int((whole_image.shape[2] - image_shape[2]) / 2.) whole_image[start_x:start_x + image_shape[0], start_y:start_y + image_shape[1], start_z:start_z + image_shape[2]] = image whole_mask[start_x:start_x + image_shape[0], start_y:start_y + image_shape[1], start_z:start_z + image_shape[2]] = mask image = whole_image mask = whole_mask del whole_image del whole_mask if self.augmentation: # apply image augmentation # randomly crop or pad the image """ choice = float(torch.rand(1)) if choice < 0.5: # crop the image box = utils.extract_bboxes(mask).astype(np.float32) min_x = box[0] / image_shape[0] max_x = (image_shape[0] - box[3]) / image_shape[0] min_y = box[1] / image_shape[1] max_y = (image_shape[1] - box[4]) / image_shape[1] min_z = box[2] / image_shape[2] max_z = (image_shape[2] - box[5]) / image_shape[2] ratio = min(min_x, max_x, min_y, max_y, min_z, max_z, self.config.CROP_PAD_RATIO) ratio = ratio * float(torch.rand(1)) dx = int(image_shape[0] * ratio) dy = int(image_shape[1] * ratio) dz = int(image_shape[2] * ratio) dx_1 = int(torch.randint(0, dx + 1, (1,))) dx_2 = dx - dx_1 dy_1 = int(torch.randint(0, dy + 1, (1,))) dy_2 = dy - dy_1 dz_1 = int(torch.randint(0, dz + 1, (1,))) dz_2 = dz - dz_1 image = image[dx_1:image_shape[0] - dx_2, dy_1:image_shape[1] - dy_2, dz_1:image_shape[2] - dz_2] mask = mask[dx_1:image_shape[0] - dx_2, dy_1:image_shape[1] - dy_2, dz_1:image_shape[2] - dz_2] else: # pad the image ratio = self.config.CROP_PAD_RATIO * float(torch.rand(1)) dx = int(image_shape[0] * ratio) dy = int(image_shape[1] * ratio) dz = int(image_shape[2] * ratio) dx_1 = int(torch.randint(0, dx + 1, (1,))) dx_2 = dx - dx_1 dy_1 = int(torch.randint(0, dy + 1, (1,))) dy_2 = dy - dy_1 dz_1 = int(torch.randint(0, dz + 1, (1,))) dz_2 = dz - dz_1 small_image = skimage.transform.resize(image, (image_shape[0] - dx, image_shape[1] - dy, image_shape[2] - dz), order=0, preserve_range=True, mode='constant', anti_aliasing=False) small_mask = skimage.transform.resize(mask, (image_shape[0] - dx, image_shape[1] - dy, image_shape[2] - dz), order=0, preserve_range=True, mode='constant', anti_aliasing=False) image = np.zeros(image_shape) image[dx_1:image_shape[0] - dx_2, dy_1:image_shape[1] - dy_2, dz_1:image_shape[2] - dz_2] = small_image mask = np.zeros(image_shape, dtype=np.int32) mask[dx_1:image_shape[0] - dx_2, dy_1:image_shape[1] - dy_2, dz_1:image_shape[2] - dz_2] = small_mask """ # randomly rotate between -30 to 30 degree angle = int(torch.randint(self.config.ROTATE_ANGLE[0], self.config.ROTATE_ANGLE[1], (1,))) image = skimage.transform.rotate(image, angle=angle, order=0, mode='constant', preserve_range=True) # [H, W, D] mask = skimage.transform.rotate(mask, angle=angle, order=0, mode='constant', preserve_range=True) # [H, W, D] # resize the image and mask to standard input shape image = skimage.transform.resize(image, self.config.IMAGE_SHAPE[:3], order=0, preserve_range=True, mode='constant', anti_aliasing=False) mask = skimage.transform.resize(mask, self.config.IMAGE_SHAPE[:3], order=0, preserve_range=True, mode='constant', anti_aliasing=False) else: image = skimage.transform.resize(image, self.config.IMAGE_SHAPE[:3], order=0, preserve_range=True, mode='constant', anti_aliasing=False) mask = skimage.transform.resize(mask, self.config.IMAGE_SHAPE[:3], order=0, preserve_range=True, mode='constant', anti_aliasing=False) # Note that window has already been (z1, y1, x1, z2, y2, x2) here. window = (start_z * self.config.IMAGE_SHAPE[2] / self.config.PAD_IMAGE_SHAPE[2], start_x * self.config.IMAGE_SHAPE[0] / self.config.PAD_IMAGE_SHAPE[0], start_y * self.config.IMAGE_SHAPE[1] / self.config.PAD_IMAGE_SHAPE[1], self.config.IMAGE_MIN_DIM - start_z * self.config.IMAGE_SHAPE[2] / self.config.PAD_IMAGE_SHAPE[2], self.config.IMAGE_MAX_DIM - start_x * self.config.IMAGE_SHAPE[0] / self.config.PAD_IMAGE_SHAPE[0], self.config.IMAGE_MAX_DIM - start_y * self.config.IMAGE_SHAPE[1] / self.config.PAD_IMAGE_SHAPE[1]) # Active classes # Different datasets have different classes, so track the classes supported in the dataset of this image. active_class_ids = np.zeros([self.dataset.num_classes], dtype=np.int32) source_class_ids = self.dataset.source_class_ids[self.dataset.image_info[image_id]["source"]] active_class_ids[source_class_ids] = 1 # Image meta data image_meta = compose_image_meta(image_id, image.shape, window, active_class_ids) # Generate different ground truth for the image image = np.expand_dims(image.transpose((2, 0, 1)), axis=0) # [C, D, H, W] mask = mask.transpose((2, 0, 1)) # [D, H, W] rpn_match, rpn_bbox, bbox = load_image_gt(mask, self.config, self.anchors) # Get the instance-specific masks and class_ids. masks, class_ids = self.dataset.process_mask(mask) return image, image_meta, rpn_match, rpn_bbox, class_ids, bbox, masks def __len__(self): return self.image_ids.shape[0] ############################################################ # MaskRCNN Class ############################################################ class MaskRCNN(nn.Module): """Encapsulates the 3D-Mask-RCNN model functionality.""" def __init__(self, config, model_dir, test_flag=False): """config: A Sub-class of the Config class model_dir: Directory to save training logs and trained weights """ super(MaskRCNN, self).__init__() self.epoch = 0 self.config = config self.model_dir = model_dir self.build(config=config, test_flag=test_flag) self.initialize_weights() def build(self, config, test_flag=False): """Build 3D-Mask-RCNN architecture.""" # Image size must be dividable by 2 multiple times h, w, d = config.IMAGE_SHAPE[:3] if h / 16 != int(h / 16) or w / 16 != int(w / 16) or d / 16 != int(d / 16): raise Exception("Image size must be dividable by 16. Use 256, 320, 512, ... etc.") # Build the shared convolutional layers. # Returns a list of the last layers of each stage, 3 in total. if self.config.BACKBONE == "P3D19": print("using P3D19 as backbone") P3D_Resnet = backbone.P3D19(config=config) else: print("using P3D35 as backbone") P3D_Resnet = backbone.P3D35(config=config) C1, C2, C3 = P3D_Resnet.stages() # Top-down Layers self.fpn = FPN(C1, C2, C3, out_channels=config.TOP_DOWN_PYRAMID_SIZE, config=config) # Generate Anchors self.anchors = Variable(torch.from_numpy(utils.generate_pyramid_anchors(config.RPN_ANCHOR_SCALES, config.RPN_ANCHOR_RATIOS, compute_backbone_shapes( config, config.IMAGE_SHAPE), config.BACKBONE_STRIDES, config.RPN_ANCHOR_STRIDE)).float(), requires_grad=False) if self.config.GPU_COUNT: self.anchors = self.anchors.cuda() # RPN self.rpn = RPN(len(config.RPN_ANCHOR_RATIOS), config.RPN_ANCHOR_STRIDE, config.TOP_DOWN_PYRAMID_SIZE, config.RPN_CONV_CHANNELS) if self.config.STAGE != 'beginning': for p in self.parameters(): p.requires_grad = False # FPN Classifier self.classifier = Classifier(config.TOP_DOWN_PYRAMID_SIZE, config.POOL_SIZE, config.IMAGE_SHAPE, 2, config.FPN_CLASSIFY_FC_LAYERS_SIZE, test_flag) # FPN Mask self.mask = Mask(1, config.MASK_POOL_SIZE, config.NUM_CLASSES, config.UNET_MASK_BRANCH_CHANNEL, self.config.STAGE, test_flag) if test_flag: for p in self.parameters(): p.requires_grad = False # Fix batch norm layers def set_bn_fix(m): classname = m.__class__.__name__ if classname.find('BatchNorm') != -1: for p in m.parameters(): p.requires_grad = False if not config.TRAIN_BN: self.apply(set_bn_fix) def initialize_weights(self): """Initialize model weights.""" for m in self.modules(): if isinstance(m, nn.Conv3d): nn.init.xavier_uniform_(m.weight) if m.bias is not None: m.bias.data.zero_() elif isinstance(m, nn.BatchNorm3d): m.weight.data.fill_(1) m.bias.data.zero_() elif isinstance(m, nn.Linear): m.weight.data.normal_(0, 0.01) m.bias.data.zero_() def set_trainable(self, layer_regex): """Sets model layers as trainable if their names match the given regular expression.""" for param in self.named_parameters(): layer_name = param[0] trainable = bool(re.fullmatch(layer_regex, layer_name)) if not trainable: param[1].requires_grad = False def load_weights(self, file_path): """Modified version of the corresponding Keras function with the addition of multi-GPU support and the ability to exclude some layers from loading. exclude: list of layer names to exclude """ if os.path.exists(file_path): pretrained_dict = torch.load(file_path) model_dict = self.state_dict() pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} model_dict.update(pretrained_dict) self.load_state_dict(model_dict, strict=True) print("Pre-trained weights load success!") else: print("Weight file not found ...") def detect(self, images): """Runs the detection pipeline. images: List of images, potentially of different sizes. [1, height, width, depth] Returns a list of dicts, one dict per image. The dict contains: rois: [N, (y1, x1, z1, y2, x2, z2)] detection bounding boxes class_ids: [N] int class IDs scores: [N] float probability scores for the class IDs masks: [H, W, D, N] instance binary masks Transform all outputs from pytorch shape to normal shape here. """ # Mold inputs to format expected by the neural network # Has been transformed to pytorch shapes. start_time = time.time() molded_images, image_metas, windows = self.mold_inputs(images) # Convert images to torch tensor molded_images = torch.from_numpy(molded_images).float() # To GPU if self.config.GPU_COUNT: molded_images = molded_images.cuda() # Run object detection detections, mrcnn_mask = self.predict([molded_images, image_metas], mode='inference') # Convert to numpy detections = detections.detach().cpu().numpy() mrcnn_mask = mrcnn_mask.permute(0, 1, 3, 4, 5, 2).detach().cpu().numpy() print("detect done, using time", time.time() - start_time) # import pdb # pdb.set_trace() # Process detections results = [] for i, image in enumerate(images): final_rois, final_class_ids, final_scores, final_mask = \ self.unmold_detections(detections[i], mrcnn_mask[i], [1, image.shape[2], image.shape[0], image.shape[1]], windows[i]) results.append({ "rois": final_rois, "class_ids": final_class_ids, "scores": final_scores, "mask": final_mask, }) return results def predict(self, inputs, mode): molded_images = inputs[0] image_metas = inputs[1] if mode == 'inference': self.eval() elif mode == 'training': self.train() # Set batchnorm always in eval mode during training def set_bn_eval(m): classname = m.__class__.__name__ if classname.find('BatchNorm') != -1: m.eval() self.apply(set_bn_eval) # Feature extraction p2_out, p3_out = self.fpn(molded_images) rpn_feature_maps = [p2_out, p3_out] mrcnn_classifier_feature_maps = [p2_out, p3_out] mrcnn_mask_feature_maps = [molded_images, molded_images] # Loop through pyramid layers layer_outputs = [] # list of lists for p in rpn_feature_maps: layer_outputs.append(self.rpn(p)) # Concatenate layer outputs # Convert from list of lists of level outputs to list of lists # of outputs across levels. # e.g. [[a1, b1, c1], [a2, b2, c2]] => [[a1, a2], [b1, b2], [c1, c2]] outputs = list(zip(*layer_outputs)) outputs = [torch.cat(list(o), dim=1) for o in outputs] rpn_class_logits, rpn_class, rpn_bbox = outputs # Generate proposals # Proposals are [batch, N, (z1, y1, x1, z2, y2, x2)] in normalized coordinates # and zero padded. proposal_count = self.config.POST_NMS_ROIS_TRAINING if mode == "training" \ else self.config.POST_NMS_ROIS_INFERENCE rpn_rois = proposal_layer([rpn_class, rpn_bbox], proposal_count=proposal_count, nms_threshold=self.config.RPN_NMS_THRESHOLD, anchors=self.anchors, config=self.config) if mode == 'inference': # Network Heads # Proposal classifier and bbox regressor heads mrcnn_class_logits, mrcnn_class, mrcnn_bbox = self.classifier(mrcnn_classifier_feature_maps, rpn_rois) # Detections # output is [batch, num_detections, (z1, y1, x1, z2, y2, x2, class_id, score)] in image coordinates detections = detection_layer(self.config, rpn_rois, mrcnn_class, mrcnn_bbox, image_metas) # Convert boxes to normalized coordinates h, w, d = self.config.IMAGE_SHAPE[:3] scale = torch.from_numpy(np.array([d, h, w, d, h, w])).float() if self.config.GPU_COUNT: scale = scale.cuda() detection_boxes = detections[:, :6] / scale # Add back batch dimension detection_boxes = detection_boxes.unsqueeze(0) # Create masks for detections if self.config.STAGE != 'beginning': _, mrcnn_mask = self.mask(mrcnn_mask_feature_maps, detection_boxes) mrcnn_mask = mrcnn_mask.unsqueeze(0) else: mrcnn_mask = torch.zeros((1, detection_boxes.shape[1], 3,) + self.config.MINI_MASK_SHAPE).cuda() # Add back batch dimension detections = detections.unsqueeze(0) return [detections, mrcnn_mask] elif mode == 'training': gt_class_ids = inputs[2] # [1, 2, ..., num_classes - 1] gt_boxes = inputs[3] # [num_classes - 1, (z1, y1, x1, z2, y2, x2)] gt_masks = inputs[4] # multi_classes masks [num_classes, D, H, W] # Normalize coordinates h, w, d = self.config.IMAGE_SHAPE[:3] scale = torch.from_numpy(np.array([d, h, w, d, h, w])).float() if self.config.GPU_COUNT: scale = scale.cuda() gt_boxes = gt_boxes / scale # Generate detection targets # Subsamples proposals and generates target outputs for training # import pdb # pdb.set_trace() p_rois, rois, target_class_ids, target_deltas, target_mask = \ detection_target_layer(rpn_rois, gt_class_ids, gt_boxes, gt_masks, self.config) # print("rois size:", rois.shape, "p_rois size:", p_rois.shape) if rois.size()[0] == 0: mrcnn_class_logits = Variable(torch.FloatTensor()) mrcnn_class = Variable(torch.IntTensor()) mrcnn_bbox = Variable(torch.FloatTensor()) mrcnn_mask = Variable(torch.FloatTensor()) mrcnn_mask_logits = Variable(torch.FloatTensor()) if self.config.GPU_COUNT: mrcnn_class_logits = mrcnn_class_logits.cuda() mrcnn_class = mrcnn_class.cuda() mrcnn_bbox = mrcnn_bbox.cuda() mrcnn_mask = mrcnn_mask.cuda() mrcnn_mask_logits = mrcnn_mask_logits.cuda() elif p_rois.size()[0] == 0 or self.config.STAGE == 'beginning': # Network Heads # Proposal classifier and BBox regressor heads mrcnn_class_logits, mrcnn_class, mrcnn_bbox = self.classifier(mrcnn_classifier_feature_maps, rois) mrcnn_mask = Variable(torch.FloatTensor()) mrcnn_mask_logits = Variable(torch.FloatTensor()) if self.config.GPU_COUNT: mrcnn_mask = mrcnn_mask.cuda() mrcnn_mask_logits = mrcnn_mask_logits.cuda() else: # only train the mask branch # Network Heads # Proposal classifier and BBox regressor heads mrcnn_class_logits = Variable(torch.FloatTensor()) mrcnn_class = Variable(torch.IntTensor()) mrcnn_bbox = Variable(torch.FloatTensor()) if self.config.GPU_COUNT: mrcnn_class_logits = mrcnn_class_logits.cuda() mrcnn_class = mrcnn_class.cuda() mrcnn_bbox = mrcnn_bbox.cuda() # Create masks for detections mrcnn_mask_logits, mrcnn_mask = self.mask(mrcnn_mask_feature_maps, p_rois) return [rpn_class_logits, rpn_bbox, target_class_ids, mrcnn_class_logits, target_deltas, mrcnn_bbox, target_mask, mrcnn_mask, mrcnn_mask_logits] def train_model(self, train_dataset, val_dataset, learning_rate, epochs): """Train the model. train_dataset, val_dataset: Training and validation Dataset objects. learning_rate: The learning rate to train with epochs: Number of training epochs. Note that previous training epochs are considered to be done already, so this actually determines the epochs to train in total rather than in this particular call. """ layers = ".*" # set all the layers trainable # Data generators train_set = Dataset(train_dataset, self.config, augmentation=self.config.AUGMENTATION) train_generator = torch.utils.data.DataLoader(train_set, batch_size=self.config.BATCH_SIZE, shuffle=self.config.SHUFFLE_DATASET, num_workers=self.config.TRAIN_NUM_WORKERS) val_set = Dataset(val_dataset, self.config, augmentation=False) val_generator = torch.utils.data.DataLoader(val_set, batch_size=self.config.BATCH_SIZE, shuffle=self.config.SHUFFLE_DATASET, num_workers=self.config.VAL_NUM_WORKERS) # Train self.set_trainable(layers) # Optimizer object # Add L2 Regularization # Skip gamma and beta weights of batch normalization layers. trainables_wo_bn = [param for name, param in self.named_parameters() if param.requires_grad and 'bn' not in name] trainables_only_bn = [param for name, param in self.named_parameters() if param.requires_grad and 'bn' in name] optimizer = optim.SGD([ {'params': trainables_wo_bn, 'weight_decay': self.config.WEIGHT_DECAY}, {'params': trainables_only_bn} ], lr=learning_rate, momentum=self.config.LEARNING_MOMENTUM) total_start_time = time.time() start_datetime = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) if not os.path.exists("./logs/" + str(start_datetime)): os.makedirs("./logs/" + str(start_datetime)) for epoch in range(self.epoch + 1, epochs + 1): log("Epoch {}/{}.".format(epoch, epochs)) start_time = time.time() # Training loss, loss_rpn_class, loss_rpn_bbox, \ loss_mrcnn_class, loss_mrcnn_bbox, loss_mrcnn_mask, loss_mrcnn_mask_edge = \ self.one_epoch(train_generator, optimizer, self.config.STEPS_PER_EPOCH) print("One Training Epoch time:", int(time.time() - start_time), "Total time:", int(time.time() - total_start_time)) torch.cuda.empty_cache() if epoch % self.config.SAVE_EPOCH == 0: # Validation val_loss, val_loss_rpn_class, val_loss_rpn_bbox, \ val_loss_mrcnn_class, val_loss_mrcnn_bbox, val_loss_mrcnn_mask, val_loss_mrcnn_mask_edge = \ self.one_epoch(val_generator, None, self.config.VALIDATION_STEPS) torch.save(self.state_dict(), "./logs/" + str(start_datetime) + "/model" + str(epoch) + "_loss: " + str(round(loss, 4)) + "_val: " + str(round(val_loss, 4))) torch.cuda.empty_cache() self.epoch = epochs def one_epoch(self, datagenerator, optimizer, steps): batch_count = 0 loss_sum = 0 loss_rpn_class_sum = 0 loss_rpn_bbox_sum = 0 loss_mrcnn_class_sum = 0 loss_mrcnn_bbox_sum = 0 loss_mrcnn_mask_sum = 0 loss_mrcnn_mask_edge_sum = 0 step = 0 if optimizer is not None: optimizer.zero_grad() for inputs in datagenerator: batch_count += 1 images = inputs[0] # [batch, C, D, H, W] image_metas = inputs[1] rpn_match = inputs[2] rpn_bbox = inputs[3] gt_class_ids = inputs[4] gt_boxes = inputs[5] # [batch, num_classes - 1, (z1, y1, x1, z2, y2, x2)] gt_masks = inputs[6] # [batch, num_classes, D, H, W] # image_metas as numpy array image_metas = image_metas.numpy() # To GPU if self.config.GPU_COUNT: images = images.cuda().float() rpn_match = rpn_match.cuda() rpn_bbox = rpn_bbox.cuda().float() gt_class_ids = gt_class_ids.cuda() gt_boxes = gt_boxes.cuda().float() gt_masks = gt_masks.cuda().float() # import pdb # pdb.set_trace() # Run object detection rpn_class_logits, rpn_pred_bbox, target_class_ids, \ mrcnn_class_logits, target_deltas, mrcnn_bbox, target_mask, mrcnn_mask, mrcnn_mask_logits = \ self.predict([images, image_metas, gt_class_ids, gt_boxes, gt_masks], mode='training') # Compute losses rpn_class_loss, rpn_bbox_loss, \ mrcnn_class_loss, mrcnn_bbox_loss, mrcnn_mask_loss, mrcnn_mask_edge_loss = \ compute_losses(rpn_match, rpn_bbox, rpn_class_logits, rpn_pred_bbox, target_class_ids, mrcnn_class_logits, target_deltas, mrcnn_bbox, target_mask, mrcnn_mask, mrcnn_mask_logits, self.config.STAGE) loss = self.config.LOSS_WEIGHTS["rpn_class_loss"] * rpn_class_loss + \ self.config.LOSS_WEIGHTS["rpn_bbox_loss"] * rpn_bbox_loss + \ self.config.LOSS_WEIGHTS["mrcnn_class_loss"] * mrcnn_class_loss + \ self.config.LOSS_WEIGHTS["mrcnn_bbox_loss"] * mrcnn_bbox_loss + \ self.config.LOSS_WEIGHTS["mrcnn_mask_loss"] * mrcnn_mask_loss + \ self.config.LOSS_WEIGHTS["mrcnn_mask_edge_loss"] * mrcnn_mask_edge_loss # Back propagation if optimizer is not None: try: loss.backward() torch.nn.utils.clip_grad_norm_(self.parameters(), 5.0) if (batch_count % self.config.BATCH_SIZE) == 0: optimizer.step() optimizer.zero_grad() batch_count = 0 except: # print("backward error! loss is", loss) optimizer.zero_grad() # Progress print_progress_bar(step + 1, steps, prefix="\t{}/{}".format(step + 1, steps), suffix="Complete - loss: {:.5f} - rpn_class_loss: {:.5f} - rpn_bbox_loss: {:.5f}" " - mrcnn_class_loss: {:.5f} - mrcnn_bbox_loss: {:.5f} - mrcnn_mask_loss: {:.5f}" " - mrcnn_mask_edge_loss: {:.5f}" .format(loss.detach().cpu().item(), self.config.LOSS_WEIGHTS["rpn_class_loss"] * rpn_class_loss.detach().cpu().item(), self.config.LOSS_WEIGHTS["rpn_bbox_loss"] * rpn_bbox_loss.detach().cpu().item(), self.config.LOSS_WEIGHTS["mrcnn_class_loss"] * mrcnn_class_loss.detach().cpu().item(), self.config.LOSS_WEIGHTS["mrcnn_bbox_loss"] * mrcnn_bbox_loss.detach().cpu().item(), self.config.LOSS_WEIGHTS["mrcnn_mask_loss"] * mrcnn_mask_loss.detach().cpu().item(), self.config.LOSS_WEIGHTS["mrcnn_mask_edge_loss"] * mrcnn_mask_edge_loss.detach().cpu().item()), length=50) # Statistics loss_sum += loss.detach().cpu().item() / steps loss_rpn_class_sum += rpn_class_loss.detach().cpu().item() / steps loss_rpn_bbox_sum += rpn_bbox_loss.detach().cpu().item() / steps loss_mrcnn_class_sum += mrcnn_class_loss.detach().cpu().item() / steps loss_mrcnn_bbox_sum += mrcnn_bbox_loss.detach().cpu().item() / steps loss_mrcnn_mask_sum += mrcnn_mask_loss.detach().cpu().item() / steps loss_mrcnn_mask_edge_sum += mrcnn_mask_edge_loss.detach().cpu().item() / steps # Break after 'steps' steps if step == steps - 1: break step += 1 if step % 20 == 0: torch.cuda.empty_cache() return loss_sum, loss_rpn_class_sum, loss_rpn_bbox_sum, \ loss_mrcnn_class_sum, loss_mrcnn_bbox_sum, loss_mrcnn_mask_sum, loss_mrcnn_mask_edge_sum def mold_inputs(self, images): """Takes a list of images and modifies them to the format expected as an input to the neural network. images: List of image matrices [height, width, depth]. Images can have different sizes. Returns 3 Numpy matrices: molded_images: [N, c, d, h, w]. Images resized and normalized. image_metas: [N, length of meta data]. Details about each image. windows: [N, (z1, y1, x1, z2, y2, x2)]. The portion of the image that has the original image (padding excluded). """ molded_images = [] image_metas = [] windows = [] for image in images: image = preprocess_image(image) image_shape = image.shape whole_image = np.zeros(self.config.PAD_IMAGE_SHAPE) start_x = int((whole_image.shape[0] - image_shape[0]) / 2.) start_y = int((whole_image.shape[1] - image_shape[1]) / 2.) start_z = int((whole_image.shape[2] - image_shape[2]) / 2.) whole_image[start_x:start_x + image_shape[0], start_y:start_y + image_shape[1], start_z:start_z + image_shape[2]] = image image = whole_image del whole_image image = skimage.transform.resize(image, self.config.IMAGE_SHAPE[:3], order=0, preserve_range=True, mode='constant', anti_aliasing=False) # Note that window has already been (z1, y1, x1, z2, y2, x2) here. window = (start_z * self.config.IMAGE_SHAPE[2] / self.config.PAD_IMAGE_SHAPE[2], start_x * self.config.IMAGE_SHAPE[0] / self.config.PAD_IMAGE_SHAPE[0], start_y * self.config.IMAGE_SHAPE[1] / self.config.PAD_IMAGE_SHAPE[1], self.config.IMAGE_MIN_DIM - start_z * self.config.IMAGE_SHAPE[2] / self.config.PAD_IMAGE_SHAPE[2], self.config.IMAGE_MAX_DIM - start_x * self.config.IMAGE_SHAPE[0] / self.config.PAD_IMAGE_SHAPE[0], self.config.IMAGE_MAX_DIM - start_y * self.config.IMAGE_SHAPE[1] / self.config.PAD_IMAGE_SHAPE[1]) molded_image = np.expand_dims(image.transpose((2, 0, 1)), axis=0) # [C, D, H, W] # Build image_meta image_meta = compose_image_meta( 0, image.shape, window, np.zeros([self.config.NUM_CLASSES], dtype=np.int32)) # Append molded_images.append(molded_image) windows.append(window) image_metas.append(image_meta) # Pack into arrays molded_images = np.stack(molded_images) image_metas = np.stack(image_metas) windows = np.stack(windows) return molded_images, image_metas, windows def unmold_detections(self, detections, mrcnn_mask, image_shape, window): """Reformat the detections of one image from the format of the neural network output to a format suitable for use in the rest of the application. detections: [N, (z1, y1, x1, z2, y2, x2, class_id, score)] mrcnn_mask: [N, depth, height, width, num_classes] image_shape: [channels, depth, height, width] Original size of the image before resizing window: [z1, y1, x1, z2, y2, x2] Box in the image where the real image is excluding the padding. Returns: boxes: [N, (y1, x1, z1, y2, x2, z2)] Bounding boxes in pixels class_ids: [N] Integer class IDs for each bounding box scores: [N] Float probability scores of the class_id masks: [height, width, depth] normal shape full mask """ # import pdb start_time = time.time() # How many detections do we have? # Detections array is padded with zeros. Find the first class_id == 0. zero_ix = np.where(detections[:, 6] == 0)[0] N = zero_ix[0] if zero_ix.shape[0] > 0 else detections.shape[0] # Extract boxes, class_ids, scores, and class-specific masks boxes = detections[:N, :6].astype(np.int32) class_ids = detections[:N, 6].astype(np.int32) scores = detections[:N, 7] masks = mrcnn_mask[np.arange(N), :, :, :, :] # Compute scale and shift to translate the bounding boxes to image domain. d_scale = image_shape[1] / (window[3] - window[0]) h_scale = image_shape[2] / (window[4] - window[1]) w_scale = image_shape[3] / (window[5] - window[2]) shift = window[:3] # z, y, x scales = np.array([d_scale, h_scale, w_scale, d_scale, h_scale, w_scale]) shifts = np.array([shift[0], shift[1], shift[2], shift[0], shift[1], shift[2]]) # pdb.set_trace() boxes = np.multiply(boxes - shifts, scales).astype(np.int32) # pdb.set_trace() # Filter out detections with zero area. Often only happens in early # stages of training when the network weights are still a bit random. exclude_ix = np.where( (boxes[:, 3] - boxes[:, 0]) * (boxes[:, 4] - boxes[:, 1]) * (boxes[:, 5] - boxes[:, 2]) <= 0 )[0] if exclude_ix.shape[0] > 0: boxes = np.delete(boxes, exclude_ix, axis=0) class_ids = np.delete(class_ids, exclude_ix, axis=0) scores = np.delete(scores, exclude_ix, axis=0) masks = np.delete(masks, exclude_ix, axis=0) # Resize masks to original image size. # box: [N, (z1, y1, x1, z2, y2, x2)] in image coordinates # mask: [N, depth, height, width, num_instances] full_masks = utils.unmold_mask(masks, boxes, image_shape) full_mask = np.argmax(full_masks, axis=3) # Transform the shapes of boxes to normal shape. boxes[:, [0, 1, 2, 3, 4, 5]] = boxes[:, [1, 2, 0, 4, 5, 3]] print("unmold done, using time", time.time() - start_time) return boxes, np.arange(1, 3), scores, full_mask.transpose((1, 2, 0)) ############################################################ # Data Formatting ############################################################ def compose_image_meta(image_id, image_shape, window, active_class_ids): """Takes attributes of an image and puts them in one 1D array. Use parse_image_meta() to parse the values back. image_id: An int ID of the image. Useful for debugging. image_shape: [channels, depth, height, width] window: (z1, y1, x1, z2, y2, x2) in pixels. The volume of the image where the real image is (excluding the padding) active_class_ids: List of class_ids available in the dataset from which the image came. Useful if training on images from multiple datasets where not all classes are present in all datasets. """ meta = np.array( [image_id] + # size = 1 list(image_shape) + # size = 3: [H, W, C] list(window) + # size = 6: (z1, y1, x1, z2, y2, x2) in image coordinates list(active_class_ids) # size = num_classes ) return meta def parse_image_meta(meta): """Parses an image info Numpy array to its components. See compose_image_meta() for more details. """ image_id = meta[:, 0] image_shape = meta[:, 1:4] # [H, W, D] window = meta[:, 4:10] # (z1, y1, x1, z2, y2, x2) window of image in in pixels active_class_ids = meta[:, 10:] return image_id, image_shape, window, active_class_ids def preprocess_image(image): """Preprocessing the image to [0., 1.].""" MIN_BOUND = 300 MAX_BOUND = -300 image = (image - MIN_BOUND) / (MAX_BOUND - MIN_BOUND) image[image > 1.] = 1. image[image < 0.] = 0. return image