import csv
import os
import os.path
import tarfile
from urllib.parse import urlparse

import numpy as np
import torch
import torch.utils.data as data
from PIL import Image

from experiment import util

# Object classes, in the column order used by the label vectors and CSV files.
object_categories = ['aeroplane', 'bicycle', 'bird', 'boat',
                     'bottle', 'bus', 'car', 'cat', 'chair',
                     'cow', 'diningtable', 'dog', 'horse',
                     'motorbike', 'person', 'pottedplant',
                     'sheep', 'sofa', 'train', 'tvmonitor']

# Archives fetched by download_voc2007().
urls = {
    'devkit': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCdevkit_18-May-2011.tar',
    'trainval_2007': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar',
    'test_images_2007': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtest_06-Nov-2007.tar',
    'test_anno_2007': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtestnoimgs_06-Nov-2007.tar',
}


def read_image_label(file):
    """Parse one per-category label file into a ``{image name: label}`` dict.

    Every non-blank line is split on whitespace; the first token is taken as
    the image name and the last token is converted to ``int`` as its label.

    Args:
        file: path of the text file to read.

    Returns:
        dict mapping image name (str) to label (int).

    Raises:
        ValueError: if the last token of a line is not an integer.
    """
    print('[dataset] read ' + file)
    labels_by_name = dict()
    with open(file, 'r') as f:
        for line in f:
            # split() (rather than split(' ')) copes with repeated spaces and
            # strips the trailing newline
            fields = line.split()
            if not fields:
                # tolerate blank lines instead of failing on int('')
                continue
            labels_by_name[fields[0]] = int(fields[-1])
    return labels_by_name


def read_object_labels(root, dataset, set):
    """Build one label vector per image from the per-category label files.

    Reads ``<root>/VOCdevkit/<dataset>/ImageSets/Main/<category>_<set>.txt``
    for every entry of ``object_categories``.

    Args:
        root: dataset root directory.
        dataset: name of the directory below ``VOCdevkit`` (e.g. 'VOC2007').
        set: name of the image set, used as the file-name suffix.

    Returns:
        dict mapping image name to a numpy array of length
        ``len(object_categories)`` holding the label read for each category.

    Raises:
        KeyError: if a later category file lists an image that is absent from
            the first category file.
    """
    path_labels = os.path.join(root, 'VOCdevkit', dataset, 'ImageSets', 'Main')
    labeled_data = dict()
    num_classes = len(object_categories)

    for i, category in enumerate(object_categories):
        file = os.path.join(path_labels, category + '_' + set + '.txt')
        category_labels = read_image_label(file)

        for name, label in category_labels.items():
            if i == 0:
                # the first category file defines the set of known images
                labeled_data[name] = np.zeros(num_classes)
            labeled_data[name][i] = label

    return labeled_data


def write_object_labels_csv(file, labeled_data):
    """Write the label vectors to a CSV file with a header row.

    The columns are ``name`` followed by one column per entry of
    ``object_categories``; label values are written as integers.

    Args:
        file: path of the CSV file to create (overwritten if it exists).
        labeled_data: dict mapping image name to an indexable sequence holding
            at least ``len(object_categories)`` labels.
    """
    print('[dataset] write file %s' % file)

    fieldnames = ['name']
    fieldnames.extend(object_categories)

    # newline='' lets the csv module control line endings itself; without it
    # rows are separated by '\r\r\n' on Windows
    with open(file, 'w', newline='') as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
        writer.writeheader()
        for name, labels in labeled_data.items():
            example = {'name': name}
            for i, category in enumerate(object_categories):
                example[category] = int(labels[i])
            writer.writerow(example)


def read_object_labels_csv(file, header=True):
    """Load ``(name, labels)`` pairs from a CSV file.

    The number of label columns is taken from the first data row; every row
    is then read up to that many label columns.

    Args:
        file: path of the CSV file to read.
        header: when true, the first row is treated as a header and skipped.

    Returns:
        list of ``(name, labels)`` tuples where ``labels`` is a float32
        ``torch`` tensor built from the label columns of the row.
    """
    images = []
    num_categories = 0
    print('[dataset] read', file)
    with open(file, 'r', newline='') as f:
        reader = csv.reader(f)
        for rownum, row in enumerate(reader):
            if header and rownum == 0:
                continue
            if not row:
                # skip empty rows (e.g. produced by doubled line endings)
                continue
            if num_categories == 0:
                num_categories = len(row) - 1
            name = row[0]
            labels = np.asarray(row[1:num_categories + 1]).astype(np.float32)
            images.append((name, torch.from_numpy(labels)))
    return images


def find_images_classification(root, dataset, set):
    """Return the image names listed in an image-set file.

    Reads ``<root>/VOCdevkit/<dataset>/ImageSets/Main/<set>.txt`` and returns
    one entry per non-blank line, with surrounding whitespace (including the
    trailing newline) removed.

    Args:
        root: dataset root directory.
        dataset: name of the directory below ``VOCdevkit``.
        set: name of the image set (file name without ``.txt``).

    Returns:
        list of str.
    """
    path_labels = os.path.join(root, 'VOCdevkit', dataset, 'ImageSets', 'Main')
    file = os.path.join(path_labels, set + '.txt')
    images = []
    with open(file, 'r') as f:
        for line in f:
            name = line.strip()
            if name:
                images.append(name)
    return images


def _download_and_extract(url, root, tmpdir):
    """Fetch the tar archive at *url* into *tmpdir* and unpack it under *root*.

    The download is skipped when a file with the archive's base name is
    already present in *tmpdir*.
    """
    # every download needs the cache directory, not only the first one
    os.makedirs(tmpdir, exist_ok=True)

    filename = os.path.basename(urlparse(url).path)
    cached_file = os.path.join(tmpdir, filename)

    if not os.path.exists(cached_file):
        print('Downloading: "{}" to {}\n'.format(url, cached_file))
        util.download_url(url, cached_file)

    # extract file
    print('[dataset] Extracting tar file {file} to {path}'.format(file=cached_file, path=root))
    # Extracting with an explicit target avoids changing the process working
    # directory, and the context manager closes the archive even on failure.
    # NOTE(review): extractall() trusts the member paths in the archive;
    # consider an extraction filter if the download source is not trusted.
    with tarfile.open(cached_file, "r") as tar:
        tar.extractall(path=root)
    print('[dataset] Done!')


def download_voc2007(root):
    """Download and unpack whichever VOC 2007 parts are missing below *root*.

    Each part is checked independently through a marker path and fetched only
    when that path does not exist. Archives are cached in ``<root>/tmp``.

    Args:
        root: directory that contains (or will contain) ``VOCdevkit``.
    """
    path_devkit = os.path.join(root, 'VOCdevkit')
    path_images = os.path.join(root, 'VOCdevkit', 'VOC2007', 'JPEGImages')
    tmpdir = os.path.join(root, 'tmp')

    # create directory
    os.makedirs(root, exist_ok=True)

    # development kit
    if not os.path.exists(path_devkit):
        _download_and_extract(urls['devkit'], root, tmpdir)

    # train/val images/annotations
    if not os.path.exists(path_images):
        _download_and_extract(urls['trainval_2007'], root, tmpdir)

    # test images: checked first because this archive is the only one that
    # can supply them. The previous code also relied on it to provide the
    # test image-set files, so a fresh install still needs a single download.
    test_image = os.path.join(path_devkit, 'VOC2007/JPEGImages/000001.jpg')
    if not os.path.exists(test_image):
        _download_and_extract(urls['test_images_2007'], root, tmpdir)

    # test annotations: only still missing when the images were already there
    test_anno = os.path.join(path_devkit, 'VOC2007/ImageSets/Main/aeroplane_test.txt')
    if not os.path.exists(test_anno):
        _download_and_extract(urls['test_anno_2007'], root, tmpdir)


class Voc2007Classification(data.Dataset):
    """VOC 2007 multi-label classification dataset.

    On construction the dataset is downloaded if necessary, and the labels of
    the requested image set are cached in
    ``<root>/files/VOC2007/classification_<set>.csv``.
    """

    def __init__(self, root, set, transform=None, target_transform=None):
        """Prepare the dataset files and load the image list.

        Args:
            root: dataset root directory.
            set: name of the image set to load.
            transform: optional callable applied to each loaded image.
            target_transform: optional callable applied to each label tensor.
        """
        self.root = root
        self.path_devkit = os.path.join(root, 'VOCdevkit')
        self.path_images = os.path.join(root, 'VOCdevkit', 'VOC2007', 'JPEGImages')
        self.set = set
        self.transform = transform
        self.target_transform = target_transform

        # download dataset
        download_voc2007(self.root)

        # define path of csv file
        path_csv = os.path.join(self.root, 'files', 'VOC2007')
        # define filename of csv file
        file_csv = os.path.join(path_csv, 'classification_' + set + '.csv')

        # create the csv file if necessary
        if not os.path.exists(file_csv):
            # create dir if necessary
            os.makedirs(path_csv, exist_ok=True)
            # generate csv file
            labeled_data = read_object_labels(self.root, 'VOC2007', self.set)
            # write csv file
            write_object_labels_csv(file_csv, labeled_data)

        self.classes = object_categories
        self.images = read_object_labels_csv(file_csv)

        print('[dataset] VOC 2007 classification set=%s number of classes=%d  number of images=%d' % (
            set, len(self.classes), len(self.images)))

    def __getitem__(self, index):
        """Return ``((image, name), target)`` for the sample at *index*.

        The image is loaded from ``<path_images>/<name>.jpg`` and converted to
        RGB; ``transform`` and ``target_transform`` are applied when set.
        """
        path, target = self.images[index]
        # convert() returns an independent image, so the file handle can be
        # released as soon as the pixels have been read
        with Image.open(os.path.join(self.path_images, path + '.jpg')) as raw:
            img = raw.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return (img, path), target

    def __len__(self):
        """Return the number of images in the loaded set."""
        return len(self.images)

    def get_number_classes(self):
        """Return the number of object categories."""
        return len(self.classes)