import numpy as np from torch.utils.data import TensorDataset,DataLoader import torch import os,pickle def load_dataset(batch_size,dir='data',n_workers=0,test_size=16384,total_size=None): print ("Loading data...") data = np.load(os.path.join(dir,'captcha.npz')) image = data['img'].astype(np.float32)/127.5-1 text = data['text'] print ("Loading dictionary...") vocab = pickle.load(open(os.path.join(dir,'captcha.vocab_dict'),'rb'),encoding='utf8') print ("Convert to tensor...") if total_size is None: image = torch.Tensor(image).permute(0,3,1,2) text = torch.LongTensor(text) else: image = torch.Tensor(image[:total_size]).permute(0,3,1,2) text = torch.LongTensor(text[:total_size]) image_train = image[:-test_size] image_test = image[-test_size:] text_train = text[:-test_size] text_test = text[-test_size:] print ("Build dataset...") dataset_train = TensorDataset(image_train,text_train) dataset_test = TensorDataset(image_test,text_test) if torch.cuda.is_available(): pm = True else: pm = False print ("Build dataloader...") dataloader_train = DataLoader(dataset_train,batch_size,True,num_workers=n_workers) dataloader_test = DataLoader(dataset_test,batch_size,shuffle=False) print ("data ready!") return dataloader_train,dataloader_test,vocab if __name__=='__main__': dl_train,dl_test,vocab = load_dataset(32)