import os

import cv2
import numpy as np
import torch.utils.data as data

# Example list file: /mnt/lustre/share/dingmingyu/new_list_lane.txt


class MyDataset(data.Dataset):
    """Dataset of (image, per-instance mask) samples listed in a text file.

    Each non-blank line of the list file holds whitespace-separated tokens:
    the first is an image path resolved against ``dir_path``, the second is
    the path of a label image whose non-zero pixel values identify
    instances. Extra tokens on a line are ignored.

    Args:
        file: Path of the list file.
        dir_path: Directory that image paths are joined onto.
        new_width, new_height: Size the input image is resized to.
        label_width, label_height: Size the label image is resized to.
        max_instances: Number of instance-mask channels produced per sample
            (default 4, the value that used to be hard-coded).
    """

    # Per-channel values subtracted from the image after resizing.
    # Presumably the Caffe-style BGR mean (cv2.imread yields BGR) -- confirm
    # against the model's training preprocessing.
    MEAN = (104, 117, 123)

    def __init__(self, file, dir_path, new_width, new_height,
                 label_width, label_height, max_instances=4):
        imgs = []
        # Context manager so the list file handle is closed after reading.
        with open(file, 'r') as fh:
            for lineno, line in enumerate(fh, start=1):
                words = line.split()
                if not words:
                    continue  # tolerate blank / trailing empty lines
                if len(words) < 2:
                    raise ValueError(
                        f"{file}:{lineno}: expected '<image> <label>', "
                        f"got {line.strip()!r}")
                imgs.append((words[0], words[1]))
        self.imgs = imgs
        self.dir_path = dir_path
        self.height = new_height
        self.width = new_width
        self.label_height = label_height
        self.label_width = label_width
        self.max_instances = max_instances

    def __getitem__(self, index):
        """Load, resize and encode the sample at position ``index``.

        Returns:
            A 4-tuple ``(img, target_ins, num_instances, gt)``:
            - ``img``: float32, channels-first array of shape
              (C, new_height, new_width), mean-subtracted.
            - ``target_ins``: uint8 array of shape
              (max_instances, label_height, label_width). Channel ``i`` is 1
              where ``gt`` equals the i-th smallest non-zero label value and
              0 elsewhere; unused channels are all zero.
            - ``num_instances``: number of distinct non-zero values in ``gt``.
            - ``gt``: the resized 2-D label map (first channel if the label
              image has several).

        Raises:
            IndexError: ``index`` is out of range.
            FileNotFoundError: the image or label file cannot be read.
            ValueError: the label holds more than ``max_instances`` instances.
        """
        rel_path, label = self.imgs[index]
        path = os.path.join(self.dir_path, rel_path)

        raw = cv2.imread(path)
        if raw is None:  # cv2.imread signals failure by returning None
            raise FileNotFoundError(f"could not read image: {path}")
        img = raw.astype(np.float32)
        img = img[:, :, :3]
        img = cv2.resize(img, (self.width, self.height))  # dsize is (w, h)
        img -= self.MEAN
        img = img.transpose(2, 0, 1)  # HWC -> CHW

        # NOTE(review): unlike the image, the label path is NOT joined with
        # dir_path; presumably the list file stores absolute label paths --
        # confirm before changing.
        gt = cv2.imread(label, -1)  # -1: load unchanged (keep raw label ids)
        if gt is None:
            raise FileNotFoundError(f"could not read label: {label}")
        # Nearest-neighbour so resizing never invents new label values.
        gt = cv2.resize(gt, (self.label_width, self.label_height),
                        interpolation=cv2.INTER_NEAREST)
        if gt.ndim == 3:
            gt = gt[:, :, 0]

        # Every distinct non-zero value is one instance; 0 is background.
        instance_ids = [value for value in np.unique(gt) if value != 0]
        if len(instance_ids) > self.max_instances:
            # Raise ValueError, not IndexError: an IndexError here would be
            # read as "end of data" by Python's __getitem__ iteration protocol.
            raise ValueError(
                f"label {label} has {len(instance_ids)} instances, "
                f"more than max_instances={self.max_instances}")

        target_ins = np.zeros(
            (self.max_instances, gt.shape[0], gt.shape[1]), dtype=np.uint8)
        for channel, value in enumerate(instance_ids):
            target_ins[channel] = gt == value
        return img, target_ins, len(instance_ids), gt

    def __len__(self):
        """Return the number of samples read from the list file."""
        return len(self.imgs)