import os
import sys
import math
import torch
import argparse
import numpy as np
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from imageio import imread
from PIL import Image
from tqdm import tqdm

sys.path.append(os.getcwd())  # HACK add the root folder so `lib` is importable
from lib.enet import create_enet_for_3d
from lib.config import CONF

# scannet data
# NOTE: read only!
# SCANNET_FRAME_ROOT is a format template; this script fills it with
# (scene_id, "color") to get the directory holding a scene's color frames.
SCANNET_FRAME_ROOT = CONF.SCANNET_FRAMES
# Same template plus one "{}" for the file name; filled with
# (scene_id, "color", "<frame_id>.jpg").
SCANNET_FRAME_PATH = os.path.join(SCANNET_FRAME_ROOT, "{}")  # name of the file
SCANNET_LIST = CONF.SCANNETV2_LIST

ENET_PATH = CONF.ENET_WEIGHTS
# Output templates: ENET_FEATURE_ROOT is filled with scene_id (a directory that
# is created on demand), ENET_FEATURE_PATH with (scene_id, frame_id) (the
# per-frame target passed to np.save).
ENET_FEATURE_ROOT = CONF.ENET_FEATURES_SUBROOT
ENET_FEATURE_PATH = CONF.ENET_FEATURES_PATH

# Network input size as [width, height] (see _resize_crop_image).
IMAGE_DIMS = [328, 256]
# Per-channel mean/std applied after scaling pixels to [0, 1].
# NOTE(review): presumably the statistics ENet was trained with - confirm.
IMAGE_MEAN = [0.496342, 0.466664, 0.440796]
IMAGE_STD = [0.277856, 0.28623, 0.291129]


class EnetDataset(Dataset):
    """Enumerate every color frame of every scene listed in SCANNET_LIST.

    Each item is ``(scene_id, frame_id, image)``. ``image`` is the frame
    resized to IMAGE_DIMS height, center-cropped to IMAGE_DIMS, and normalized
    into a CHW float tensor.
    """

    def __init__(self):
        self._init_resources()

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        scene_id, frame_id = self.data[idx]
        image = self._load_image(
            SCANNET_FRAME_PATH.format(scene_id, "color", "{}.jpg".format(frame_id)),
            IMAGE_DIMS,
        )
        return scene_id, frame_id, image

    def _init_resources(self):
        """Fill ``self.scene_list`` and ``self.data``, a list of ``(scene_id, frame_id)``.

        Frames within a scene are ordered numerically by file stem.
        """
        self._get_scene_list()
        self.data = []
        for scene_id in self.scene_list:
            frame_dir = SCANNET_FRAME_ROOT.format(scene_id, "color")
            stems = (name.split(".")[0] for name in os.listdir(frame_dir))
            # Frame files are named "<int>.<ext>"; skip anything else (e.g.
            # hidden files) instead of crashing inside int().
            frame_ids = sorted(int(stem) for stem in stems if stem.isdecimal())
            self.data.extend((scene_id, frame_id) for frame_id in frame_ids)

    def _get_scene_list(self):
        """Read SCANNET_LIST into a sorted, de-duplicated ``self.scene_list``."""
        with open(SCANNET_LIST, 'r') as f:
            # Strip stray whitespace and drop blank lines so they don't become
            # bogus (empty) scene ids that later fail in os.listdir.
            scene_ids = {line.strip() for line in f.read().splitlines()}
        scene_ids.discard("")
        self.scene_list = sorted(scene_ids)

    def _resize_crop_image(self, image, new_image_dims):
        """Resize *image* to the target height (keeping aspect ratio), then center-crop.

        Args:
            image: numpy image array; shape[0] is treated as height and
                shape[1] as width.
            new_image_dims: target size as ``[width, height]``.

        Returns:
            The processed image as a numpy array. It is returned unchanged
            (as an array) when it already has the target size.
        """
        new_width, new_height = new_image_dims
        image_dims = [image.shape[1], image.shape[0]]  # [width, height]
        if image_dims != [new_width, new_height]:
            # Scale so the height matches exactly; width follows the aspect ratio.
            resize_width = int(math.floor(new_height * float(image_dims[0]) / float(image_dims[1])))
            image = transforms.Resize([new_height, resize_width], interpolation=Image.NEAREST)(Image.fromarray(image))
            # torchvision sizes are (height, width).
            image = transforms.CenterCrop([new_height, new_width])(image)
        return np.array(image)

    def _load_image(self, file, image_dims):
        """Read an image file and preprocess it for ENet.

        Args:
            file: path of the image to read.
            image_dims: target size as ``[width, height]``.

        Returns:
            For a 3-D (color) image, a normalized CHW float tensor. For a 2-D
            (label) image, a 1xHxW numpy array.

        Raises:
            ValueError: if the decoded image is neither 2-D nor 3-D.
        """
        image = imread(file)
        # preprocess
        image = self._resize_crop_image(image, image_dims)
        if len(image.shape) == 3:
            # color image
            image = np.transpose(image, [2, 0, 1])  # move feature to front
            image = transforms.Normalize(mean=IMAGE_MEAN, std=IMAGE_STD)(
                torch.Tensor(image.astype(np.float32) / 255.0)
            )
        elif len(image.shape) == 2:
            # label image
            image = np.expand_dims(image, 0)
        else:
            raise ValueError("unsupported image shape {} in {}".format(image.shape, file))
        return image

    def collate_fn(self, data):
        """Batch ``(scene_id, frame_id, image)`` items.

        Returns lists of ids and a single stacked image tensor already moved to
        the GPU. This only works for tensor images, i.e. color frames.
        """
        scene_ids, frame_ids, images = zip(*data)
        images = torch.stack(images, 0).cuda()
        return list(scene_ids), list(frame_ids), images


def create_enet():
    """Build the pretrained ENet (fixed + trainable parts) frozen in eval mode on the GPU."""
    # Arguments are forwarded unchanged; see lib.enet.create_enet_for_3d for
    # the meaning of 41 and 21.
    enet_fixed, enet_trainable, _ = create_enet_for_3d(41, ENET_PATH, 21)
    enet = nn.Sequential(
        enet_fixed,
        enet_trainable
    ).cuda()
    enet.eval()
    for param in enet.parameters():
        param.requires_grad = False

    return enet


def main():
    """Extract ENet features for every listed ScanNet frame and save one .npy per frame."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu', type=str, help='gpu', default='0')
    args = parser.parse_args()

    # setting -- done before any .cuda() call below
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

    # init
    dataset = EnetDataset()
    dataloader = DataLoader(dataset, batch_size=256, shuffle=False, collate_fn=dataset.collate_fn)
    enet = create_enet()

    # feed
    print("extracting multiview features from ENet...")
    with torch.no_grad():  # pure inference: never record autograd history
        for scene_ids, frame_ids, images in tqdm(dataloader):
            # One device->host copy per batch instead of one per frame.
            features = enet(images).cpu().numpy()
            for scene_id, frame_id, feature in zip(scene_ids, frame_ids, features):
                os.makedirs(ENET_FEATURE_ROOT.format(scene_id), exist_ok=True)
                np.save(ENET_FEATURE_PATH.format(scene_id, frame_id), feature)

    print("done!")


if __name__ == "__main__":
    main()