from . import lmdb_datasets from torchvision import datasets,transforms import os import os.path as osp import torch.utils.data import numpy as np import cv2 from . import lmdb_data_pb2 as pb2 import Queue import time import multiprocessing DATASET_SIZE=100 mean=np.array([[[0.485]], [[0.456]], [[0.406]]])*255 std=np.array([[[0.229]], [[0.224]], [[0.225]]])*255 class Imagenet_LMDB(lmdb_datasets.LMDB): def __init__(self,imagenet_dir,train=False): self.train_name='imagenet_train_lmdb' self.val_name='imagenet_val_lmdb' self.train=train super(Imagenet_LMDB, self).__init__(osp.join(imagenet_dir,train and self.train_name or self.val_name)) txn=self.env.begin() self.cur=txn.cursor() self.data = Queue.Queue(DATASET_SIZE*2) self.target = Queue.Queue(DATASET_SIZE*2) self.point=0 # self._read_from_lmdb() def data_transfrom(self,data,other): data=data.astype(np.float32) if self.train: shape=np.fromstring(other[0],np.uint16) data=data.reshape(shape) # Random crop _, w, h = data.shape x1 = np.random.randint(0, w - 224) y1 = np.random.randint(0, h - 224) data=data[:,x1:x1+224 ,y1:y1 + 224] # HorizontalFlip #TODO horizontal flip else: data = data.reshape([3, 224, 224]) data = (data - mean) / std tensor = torch.Tensor(data) del data return tensor def target_transfrom(self,target): return target def _read_from_lmdb(self): self.cur.next() if not self.cur.key(): self.cur.first() dataset = pb2.Dataset().FromString(self.cur.value()) for datum in dataset.datums: data = np.fromstring(datum.data, np.uint8) try: data = self.data_transfrom(data, datum.other) except: print 'cannot trans ', data.shape continue target = int(datum.target) target = self.target_transfrom(target) self.data.put(data) self.target.put(target) # print 'read_from_lmdb', time.time()-r del dataset # def read_from_lmdb(self): # process=multiprocessing.Process(target=self._read_from_lmdb) # process.start() def __getitem__(self,index): if self.data.qsize()<DATASET_SIZE: self._read_from_lmdb() data,target=self.data.get(),self.target.get() return data,target def __len__(self): return self.env.stat()['entries']*DATASET_SIZE def Imagenet_LMDB_generate(imagenet_dir, output_dir, make_val=False, make_train=False): # the imagenet_dir should have direction named 'train' or 'val',with 1000 folders of raw jpeg photos train_name = 'imagenet_train_lmdb' val_name = 'imagenet_val_lmdb' def target_trans(target): return target if make_val: val_lmdb=lmdb_datasets.LMDB_generator(osp.join(output_dir,val_name)) def trans_val_data(dir): tensor = transforms.Compose([ transforms.Scale(256), transforms.CenterCrop(224), transforms.ToTensor() ])(dir) tensor=(tensor.numpy()*255).astype(np.uint8) return tensor val = datasets.ImageFolder(osp.join(imagenet_dir,'val'), trans_val_data,target_trans) val_lmdb.write_classification_lmdb(val, num_per_dataset=DATASET_SIZE) if make_train: train_lmdb = lmdb_datasets.LMDB_generator(osp.join(output_dir, train_name)) def trans_train_data(dir): tensor = transforms.Compose([ transforms.Scale(256), transforms.ToTensor() ])(dir) tensor=(tensor.numpy()*255).astype(np.uint8) return tensor train = datasets.ImageFolder(osp.join(imagenet_dir, 'train'), trans_train_data, target_trans) train.imgs=np.random.permutation(train.imgs) train_lmdb.write_classification_lmdb(train, num_per_dataset=DATASET_SIZE, write_shape=True)