from __future__ import print_function, division, absolute_import import csv import os import os.path import tarfile from six.moves.urllib.parse import urlparse import numpy as np import torch import torch.utils.data as data from PIL import Image from . import utils object_categories = ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'] urls = { 'devkit': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCdevkit_18-May-2011.tar', 'trainval_2007': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar', 'test_images_2007': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtest_06-Nov-2007.tar', 'test_anno_2007': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtestnoimgs_06-Nov-2007.tar', } def read_image_label(file): print('[dataset] read ' + file) data = dict() with open(file, 'r') as f: for line in f: tmp = line.split(' ') name = tmp[0] label = int(tmp[-1]) data[name] = label # data.append([name, label]) # print('%s %d' % (name, label)) return data def read_object_labels(root, dataset, set): path_labels = os.path.join(root, 'VOCdevkit', dataset, 'ImageSets', 'Main') labeled_data = dict() num_classes = len(object_categories) for i in range(num_classes): file = os.path.join(path_labels, object_categories[i] + '_' + set + '.txt') data = read_image_label(file) if i == 0: for (name, label) in data.items(): labels = np.zeros(num_classes) labels[i] = label labeled_data[name] = labels else: for (name, label) in data.items(): labeled_data[name][i] = label return labeled_data def write_object_labels_csv(file, labeled_data): # write a csv file print('[dataset] write file %s' % file) with open(file, 'w') as csvfile: fieldnames = ['name'] fieldnames.extend(object_categories) writer = csv.DictWriter(csvfile, fieldnames=fieldnames) writer.writeheader() for (name, labels) in labeled_data.items(): example = {'name': name} for i in range(20): example[fieldnames[i + 1]] = int(labels[i]) writer.writerow(example) csvfile.close() def read_object_labels_csv(file, header=True): images = [] num_categories = 0 print('[dataset] read', file) with open(file, 'r') as f: reader = csv.reader(f) rownum = 0 for row in reader: if header and rownum == 0: header = row else: if num_categories == 0: num_categories = len(row) - 1 name = row[0] labels = (np.asarray(row[1:num_categories + 1])).astype(np.float32) labels = torch.from_numpy(labels) item = (name, labels) images.append(item) rownum += 1 return images def find_images_classification(root, dataset, set): path_labels = os.path.join(root, 'VOCdevkit', dataset, 'ImageSets', 'Main') images = [] file = os.path.join(path_labels, set + '.txt') with open(file, 'r') as f: for line in f: images.append(line) return images def download_voc2007(root): path_devkit = os.path.join(root, 'VOCdevkit') path_images = os.path.join(root, 'VOCdevkit', 'VOC2007', 'JPEGImages') tmpdir = os.path.join(root, 'tmp') # create directory if not os.path.exists(root): os.makedirs(root) if not os.path.exists(path_devkit): if not os.path.exists(tmpdir): os.makedirs(tmpdir) parts = urlparse(urls['devkit']) filename = os.path.basename(parts.path) cached_file = os.path.join(tmpdir, filename) if not os.path.exists(cached_file): print('Downloading: "{}" to {}\n'.format(urls['devkit'], cached_file)) utils.download_url(urls['devkit'], cached_file) # extract file print('[dataset] Extracting tar file {file} to {path}'.format(file=cached_file, path=root)) cwd = os.getcwd() tar = tarfile.open(cached_file, "r") os.chdir(root) tar.extractall() tar.close() os.chdir(cwd) print('[dataset] Done!') # train/val images/annotations if not os.path.exists(path_images): # download train/val images/annotations parts = urlparse(urls['trainval_2007']) filename = os.path.basename(parts.path) cached_file = os.path.join(tmpdir, filename) if not os.path.exists(cached_file): print('Downloading: "{}" to {}\n'.format(urls['trainval_2007'], cached_file)) utils.download_url(urls['trainval_2007'], cached_file) # extract file print('[dataset] Extracting tar file {file} to {path}'.format(file=cached_file, path=root)) cwd = os.getcwd() tar = tarfile.open(cached_file, "r") os.chdir(root) tar.extractall() tar.close() os.chdir(cwd) print('[dataset] Done!') # test annotations test_anno = os.path.join(path_devkit, 'VOC2007/ImageSets/Main/aeroplane_test.txt') if not os.path.exists(test_anno): # download test annotations parts = urlparse(urls['test_images_2007']) filename = os.path.basename(parts.path) cached_file = os.path.join(tmpdir, filename) if not os.path.exists(cached_file): print('Downloading: "{}" to {}\n'.format(urls['test_images_2007'], cached_file)) utils.download_url(urls['test_images_2007'], cached_file) # extract file print('[dataset] Extracting tar file {file} to {path}'.format(file=cached_file, path=root)) cwd = os.getcwd() tar = tarfile.open(cached_file, "r") os.chdir(root) tar.extractall() tar.close() os.chdir(cwd) print('[dataset] Done!') # test images test_image = os.path.join(path_devkit, 'VOC2007/JPEGImages/000001.jpg') if not os.path.exists(test_image): # download test images parts = urlparse(urls['test_anno_2007']) filename = os.path.basename(parts.path) cached_file = os.path.join(tmpdir, filename) if not os.path.exists(cached_file): print('Downloading: "{}" to {}\n'.format(urls['test_anno_2007'], cached_file)) utils.download_url(urls['test_anno_2007'], cached_file) # extract file print('[dataset] Extracting tar file {file} to {path}'.format(file=cached_file, path=root)) cwd = os.getcwd() tar = tarfile.open(cached_file, "r") os.chdir(root) tar.extractall() tar.close() os.chdir(cwd) print('[dataset] Done!') class Voc2007Classification(data.Dataset): def __init__(self, root, set, transform=None, target_transform=None): self.root = root self.path_devkit = os.path.join(root, 'VOCdevkit') self.path_images = os.path.join(root, 'VOCdevkit', 'VOC2007', 'JPEGImages') self.set = set self.transform = transform self.target_transform = target_transform # download dataset download_voc2007(self.root) # define path of csv file path_csv = os.path.join(self.root, 'files', 'VOC2007') # define filename of csv file file_csv = os.path.join(path_csv, 'classification_' + set + '.csv') # create the csv file if necessary if not os.path.exists(file_csv): if not os.path.exists(path_csv): # create dir if necessary os.makedirs(path_csv) # generate csv file labeled_data = read_object_labels(self.root, 'VOC2007', self.set) # write csv file write_object_labels_csv(file_csv, labeled_data) self.classes = object_categories self.images = read_object_labels_csv(file_csv) print('[dataset] VOC 2007 classification set=%s number of classes=%d number of images=%d' % ( set, len(self.classes), len(self.images))) def __getitem__(self, index): path, target = self.images[index] img = Image.open(os.path.join(self.path_images, path + '.jpg')).convert('RGB') if self.transform is not None: img = self.transform(img) if self.target_transform is not None: target = self.target_transform(target) return img, path, target def __len__(self): return len(self.images) def get_number_classes(self): return len(self.classes)