import os.path
import random
import warnings

import numpy as np
from PIL import Image

from data.base_dataset import BaseDataset, get_label_transform, get_transform
from data.image_folder import make_cs_labels, make_dataset
from data.cityscapes import remap_labels_to_train_ids

# Label value assigned to pixels that have no valid training class.
ignore_label = 255

# Raw SYNTHIA label id -> training id, where a training id indexes `classes`
# below. Ids mapped to `ignore_label` have no counterpart among those classes.
id2label = {
    0: ignore_label, 1: 10, 2: 2, 3: 0, 4: 1, 5: 4, 6: 8, 7: 5, 8: 13, 9: 7,
    10: 11, 11: 18, 12: 17, 13: ignore_label, 14: ignore_label, 15: 6,
    16: 9, 17: 12, 18: 14, 19: 15, 20: 16, 21: 3, 22: ignore_label,
}

# Training class names, indexed by training id (19 classes).
classes = [
    'road', 'sidewalk', 'building', 'wall', 'fence', 'pole', 'traffic light',
    'traffic sign', 'vegetation', 'terrain', 'sky', 'person', 'rider', 'car',
    'truck', 'bus', 'train', 'motorcycle', 'bicycle',
]


def syn_relabel(arr):
    """Map an array of raw SYNTHIA label ids to training ids.

    Args:
        arr: integer array of raw SYNTHIA label ids.

    Returns:
        A uint8 array with the same shape as ``arr``. Pixels whose value is a
        key of ``id2label`` receive the mapped training id; every other pixel
        is ``ignore_label``.
    """
    out = np.full(arr.shape, ignore_label, dtype=np.uint8)
    for raw_id, train_id in id2label.items():
        # Masks are always computed on the untouched input, so the iteration
        # order of the mapping cannot chain one remap into another.
        out[arr == raw_id] = train_id
    return out


def _rgb_to_gray(img):
    """Collapse a channel-first RGB tensor to one channel (BT.601 luma weights).

    Returns a tensor with a leading channel dimension of size 1.
    """
    gray = img[0, ...] * 0.299 + img[1, ...] * 0.587 + img[2, ...] * 0.114
    return gray.unsqueeze(0)


def _load_label(path, remap):
    """Read a label image, remap its ids with ``remap``, return an 'L' image.

    The file is closed before returning; ``np.asarray`` has already copied
    the pixel data out of it.
    """
    with Image.open(path) as img:
        arr = np.asarray(img)
    return Image.fromarray(remap(arr), 'L')


def _load_rgb(path):
    """Read an image file and return it converted to RGB, with the file closed."""
    with Image.open(path) as img:
        # convert() returns a new, fully loaded image independent of the file.
        return img.convert('RGB')


class SynthiaCityscapesDataset(BaseDataset):
    """Unaligned SYNTHIA (domain A) / Cityscapes (domain B) dataset with labels.

    Directory layout under ``opt.dataroot``:
        synthia/RGB, synthia/GT/parsed_LABELS        -> domain A images/labels
        cityscapes/leftImg8bit, cityscapes/gtFine    -> domain B images/labels

    Each item is a dict with keys ``'A'``, ``'B'``, ``'A_paths'``,
    ``'B_paths'``, ``'A_label'`` and ``'B_label'``. The B sample is picked at
    random unless ``opt.serial_batches`` is set.
    """

    def initialize(self, opt):
        """Collect the image/label file lists and build the transforms.

        Raises:
            RuntimeError: if either image folder contains no images. Without
                this check, item access fails later with an unrelated
                ZeroDivisionError/ValueError.
        """
        self.opt = opt
        self.root = opt.dataroot
        self.dir_A = os.path.join(opt.dataroot, 'synthia', 'RGB')
        self.dir_B = os.path.join(opt.dataroot, 'cityscapes', 'leftImg8bit')
        self.dir_A_label = os.path.join(opt.dataroot, 'synthia', 'GT', 'parsed_LABELS')
        self.dir_B_label = os.path.join(opt.dataroot, 'cityscapes', 'gtFine')

        # Images and labels are paired purely by position after sorting, so the
        # two lists of a domain must enumerate the same samples in the same order.
        self.A_paths = sorted(make_dataset(self.dir_A))
        self.B_paths = sorted(make_dataset(self.dir_B))
        self.A_size = len(self.A_paths)
        self.B_size = len(self.B_paths)

        self.A_labels = sorted(make_dataset(self.dir_A_label))
        self.B_labels = sorted(make_cs_labels(self.dir_B_label))

        self._check_file_lists()

        self.transform = get_transform(opt)
        self.label_transform = get_label_transform(opt)

    def _check_file_lists(self):
        """Fail on empty image folders; warn when image and label counts differ."""
        for domain, directory, size in (('A', self.dir_A, self.A_size),
                                        ('B', self.dir_B, self.B_size)):
            if size == 0:
                raise RuntimeError(
                    'No images found for domain {} in {}'.format(domain, directory))
        for domain, n_images, n_labels in (('A', self.A_size, len(self.A_labels)),
                                           ('B', self.B_size, len(self.B_labels))):
            if n_images != n_labels:
                warnings.warn(
                    'Domain {}: {} images but {} labels; positional '
                    'image/label pairing may be wrong'.format(domain, n_images, n_labels))

    def __getitem__(self, index):
        """Return one (A, B) sample with labels; see the class docstring for keys."""
        index_A = index % self.A_size
        if self.opt.serial_batches:
            index_B = index % self.B_size
        else:
            # Unaligned sampling: B is drawn independently of A.
            index_B = random.randint(0, self.B_size - 1)

        A_path = self.A_paths[index_A]
        B_path = self.B_paths[index_B]

        A_label = _load_label(self.A_labels[index_A], syn_relabel)
        B_label = _load_label(self.B_labels[index_B], remap_labels_to_train_ids)
        A_img = _load_rgb(A_path)
        B_img = _load_rgb(B_path)

        # NOTE(review): images and labels go through separate transform objects.
        # If either applies random crops/flips, an image and its label may be
        # augmented differently -- confirm against get_transform /
        # get_label_transform. The call order below is kept as-is because it
        # fixes the order in which any RNG is consumed.
        A = self.transform(A_img)
        B = self.transform(B_img)
        A_label = self.label_transform(A_label)
        B_label = self.label_transform(B_label)

        if self.opt.which_direction == 'BtoA':
            input_nc, output_nc = self.opt.output_nc, self.opt.input_nc
        else:
            input_nc, output_nc = self.opt.input_nc, self.opt.output_nc

        if input_nc == 1:
            A = _rgb_to_gray(A)
        if output_nc == 1:
            B = _rgb_to_gray(B)

        return {'A': A, 'B': B,
                'A_paths': A_path, 'B_paths': B_path,
                'A_label': A_label, 'B_label': B_label}

    def __len__(self):
        """Return the larger of the two domain sizes; indices wrap via modulo."""
        return max(self.A_size, self.B_size)

    def name(self):
        return 'Synthia_Cityscapes'