import torch from torch.utils.data import Dataset from PIL import Image import os.path import random import numpy as np from tool import imutils class PolyOptimizer(torch.optim.SGD): def __init__(self, params, lr, weight_decay, max_step, momentum=0.9): super().__init__(params, lr, weight_decay) self.global_step = 0 self.max_step = max_step self.momentum = momentum self.__initial_lr = [group['lr'] for group in self.param_groups] def step(self, closure=None): if self.global_step < self.max_step: lr_mult = (1 - self.global_step / self.max_step) ** self.momentum for i in range(len(self.param_groups)): self.param_groups[i]['lr'] = self.__initial_lr[i] * lr_mult super().step(closure) self.global_step += 1 class BatchNorm2dFixed(torch.nn.Module): def __init__(self, num_features, eps=1e-5): super(BatchNorm2dFixed, self).__init__() self.num_features = num_features self.eps = eps self.weight = torch.nn.Parameter(torch.Tensor(num_features)) self.bias = torch.nn.Parameter(torch.Tensor(num_features)) self.register_buffer('running_mean', torch.zeros(num_features)) self.register_buffer('running_var', torch.ones(num_features)) def forward(self, input): return F.batch_norm( input, self.running_mean, self.running_var, self.weight, self.bias, False, eps=self.eps) def __call__(self, x): return self.forward(x) class SegmentationDataset(Dataset): def __init__(self, img_name_list_path, img_dir, label_dir, rescale=None, flip=False, cropsize=None, img_transform=None, mask_transform=None): self.img_name_list_path = img_name_list_path self.img_dir = img_dir self.label_dir = label_dir self.img_transform = img_transform self.mask_transform = mask_transform self.img_name_list = open(self.img_name_list_path).read().splitlines() self.rescale = rescale self.flip = flip self.cropsize = cropsize def __len__(self): return len(self.img_name_list) def __getitem__(self, idx): name = self.img_name_list[idx] img = Image.open(os.path.join(self.img_dir, name + '.jpg')).convert("RGB") mask = Image.open(os.path.join(self.label_dir, name + '.png')) if self.rescale is not None: s = self.rescale[0] + random.random() * (self.rescale[1] - self.rescale[0]) adj_size = (round(img.size[0]*s/8)*8, round(img.size[1]*s/8)*8) img = img.resize(adj_size, resample=Image.CUBIC) mask = img.resize(adj_size, resample=Image.NEAREST) if self.img_transform is not None: img = self.img_transform(img) if self.mask_transform is not None: mask = self.mask_transform(mask) if self.cropsize is not None: img, mask = imutils.random_crop([img, mask], self.cropsize, (0, 255)) mask = imutils.RescaleNearest(0.125)(mask) if self.flip is True and bool(random.getrandbits(1)): img = np.flip(img, 1).copy() mask = np.flip(mask, 1).copy() img = np.transpose(img, (2, 0, 1)) return name, img, mask class ExtractAffinityLabelInRadius(): def __init__(self, cropsize, radius=5): self.radius = radius self.search_dist = [] for x in range(1, radius): self.search_dist.append((0, x)) for y in range(1, radius): for x in range(-radius+1, radius): if x*x + y*y < radius*radius: self.search_dist.append((y, x)) self.radius_floor = radius-1 self.crop_height = cropsize - self.radius_floor self.crop_width = cropsize - 2 * self.radius_floor return def __call__(self, label): labels_from = label[:-self.radius_floor, self.radius_floor:-self.radius_floor] labels_from = np.reshape(labels_from, [-1]) labels_to_list = [] valid_pair_list = [] for dy, dx in self.search_dist: labels_to = label[dy:dy+self.crop_height, self.radius_floor+dx:self.radius_floor+dx+self.crop_width] labels_to = np.reshape(labels_to, [-1]) valid_pair = np.logical_and(np.less(labels_to, 255), np.less(labels_from, 255)) labels_to_list.append(labels_to) valid_pair_list.append(valid_pair) bc_labels_from = np.expand_dims(labels_from, 0) concat_labels_to = np.stack(labels_to_list) concat_valid_pair = np.stack(valid_pair_list) pos_affinity_label = np.equal(bc_labels_from, concat_labels_to) bg_pos_affinity_label = np.logical_and(pos_affinity_label, np.equal(bc_labels_from, 0)).astype(np.float32) fg_pos_affinity_label = np.logical_and(np.logical_and(pos_affinity_label, np.not_equal(bc_labels_from, 0)), concat_valid_pair).astype(np.float32) neg_affinity_label = np.logical_and(np.logical_not(pos_affinity_label), concat_valid_pair).astype(np.float32) return bg_pos_affinity_label, fg_pos_affinity_label, neg_affinity_label class AffinityFromMaskDataset(SegmentationDataset): def __init__(self, img_name_list_path, img_dir, label_dir, rescale=None, flip=False, cropsize=None, img_transform=None, mask_transform=None, radius=5): super().__init__(img_name_list_path, img_dir, label_dir, rescale, flip, cropsize, img_transform, mask_transform) self.radius = radius self.extract_aff_lab_func = ExtractAffinityLabelInRadius(cropsize=cropsize//8, radius=radius) def __getitem__(self, idx): name, img, mask = super().__getitem__(idx) aff_label = self.extract_aff_lab_func(mask) return name, img, aff_label