#
# imitation.py, doom-net
#
# Created by Andrey Kolishchak on 01/21/17.
#
import os
import time
import argparse

import h5py
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

from device import device
from doom_instance import *
from aac import BaseModel


class DoomDataset(Dataset):
    """Imitation-learning dataset backed by an HDF5 recording of gameplay.

    The file is expected to contain the datasets 'screens', 'action_labels',
    'variables' and 'action_sets' (schema assumed from usage — confirm against
    the recorder that produced the file).
    """

    def __init__(self, h5_path):
        super(DoomDataset, self).__init__()
        self.h5_path = h5_path
        # Open once up front only to read metadata (size, shape, action sets);
        # the handle used for reading samples is opened lazily below.
        with h5py.File(self.h5_path, 'r') as data:
            inputs = data['screens']
            print('Dataset size =', len(inputs))
            self.action_sets = data['action_sets'][:]
            self.input_shape = inputs[0].shape
            self.length = len(inputs)
        #
        # hd5 has issues with fork at DataLoader, so file is opened on first getitem()
        # https://groups.google.com/forum/#!topic/h5py/bJVtWdFtZQM
        #
        self.data = None
        self.inputs = None
        self.labels = None
        self.variables = None

    def __getitem__(self, index):
        # Lazy open so each DataLoader worker gets its own h5py handle
        # (h5py handles do not survive fork — see note in __init__).
        if self.data is None:
            self.data = h5py.File(self.h5_path, 'r')
            self.inputs = self.data['screens']
            self.labels = self.data['action_labels']
            self.variables = self.data['variables']
        # Screens scaled from [0, 255] to [-1, 1]; variables scaled by 1/100;
        # labels as int64, the dtype nn.CrossEntropyLoss targets require.
        # NOTE: np.int was removed in NumPy 1.24 — np.int64 preserves the intent.
        return self.inputs[index].astype(np.float32) / 127.5 - 1.0, \
               self.variables[index].astype(np.float32) / 100, \
               self.labels[index].astype(np.int64)

    def __len__(self):
        return self.length


def train(args):
    """Train BaseModel by behavioral cloning on a recorded HDF5 dataset.

    Loads (or resumes) the model, then runs a cross-entropy classification
    loop over (screen, variables) -> action-label pairs, printing running
    loss/accuracy and periodically checkpointing the whole model.
    """
    train_set = DoomDataset(args.h5_path)
    # Persist the action sets so inference code can map label -> button state.
    np.save('action_set', train_set.action_sets)
    training_data_loader = DataLoader(dataset=train_set,
                                      num_workers=2,
                                      batch_size=args.batch_size,  # was hard-coded to 100, ignoring the CLI flag
                                      shuffle=True)

    model = BaseModel(train_set.input_shape[0], len(train_set.action_sets), 3, args.frame_num).to(device)
    if args.load is not None and os.path.isfile(args.load):
        print("loading model parameters {}".format(args.load))
        source_model = torch.load(args.load)
        model.load_state_dict(source_model.state_dict())
        del source_model

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=5e-4)

    for epoch in range(1500000):
        running_loss = 0
        running_accuracy = 0
        batch_time = time.time()
        for batch, (screens, variables, labels) in enumerate(training_data_loader):
            screens, variables, labels = screens.to(device), variables.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(screens, variables)[0]
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            _, pred = outputs.max(1)
            accuracy = (pred == labels).float().mean()
            # .item() keeps the accumulator a plain float instead of a
            # 0-dim device tensor (avoids device syncs at print time).
            running_accuracy += accuracy.item()

            batches_per_print = 10
            if batch % batches_per_print == batches_per_print - 1:  # print every batches_per_print mini-batches
                print(
                    '[{:d}, {:5d}] loss: {:.3f}, accuracy: {:.3f}, time: {:.6f}'.format(
                        epoch + 1, batch + 1,
                        running_loss / batches_per_print,
                        running_accuracy / batches_per_print,
                        (time.time() - batch_time) / batches_per_print
                    )
                )
                running_loss = 0
                running_accuracy = 0
                batch_time = time.time()

        # NOTE(review): despite the flag's help text, checkpointing happens per
        # epoch, not per batch — preserved as-is. Guard so the default
        # checkpoint_file=None no longer crashes torch.save.
        if epoch % args.checkpoint_rate == args.checkpoint_rate - 1 and args.checkpoint_file is not None:
            torch.save(model, args.checkpoint_file)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Doom Recorder')
    parser.add_argument('--batch_size', type=int, default=100,
                        help='number of game instances running in parallel')
    parser.add_argument('--load', default=None,
                        help='path to model file')
    parser.add_argument('--h5_path',
                        default=os.path.expanduser('~') + '/test/datasets/vizdoom/cig_map01/flat.h5',
                        help='hd5 file path')
    parser.add_argument('--skiprate', type=int, default=1,
                        help='number of skipped frames')
    parser.add_argument('--frame_num', type=int, default=1,
                        help='number of frames per input')
    parser.add_argument('--checkpoint_file', default=None,
                        help='check point file name')
    parser.add_argument('--checkpoint_rate', type=int, default=5000,
                        help='number of batches per checkpoint')

    args = parser.parse_args()

    train(args)