#
# imitation_frames.py, doom-net
#
# Created by Andrey Kolishchak on 01/21/17.
#
import os
import time
import argparse

import h5py
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

from device import device
from doom_instance import *
from aac import BaseModel


def data_generator(args, screens, variables, labels, episodes, step_size):
    """Yield an endless stream of training batches for behavior cloning.

    Each of the ``args.batch_size`` slots walks through a randomly chosen
    episode, advancing ``step_size`` frames per batch. When a slot runs
    past its episode's last frame, the episode is rewound to a fresh random
    offset, the slot is reassigned to a new random episode, and the slot's
    terminal flag for that batch is set to 0 (1 otherwise).

    Args:
        args: namespace providing ``episode_size`` and ``batch_size``.
        screens: frame-indexed dataset of raw screen buffers (uint8-like).
        variables: frame-indexed dataset of game variables.
        labels: frame-indexed dataset of action-class labels.
        episodes: (N, 2) array of [first_frame, last_frame] per episode.
        step_size: number of frames skipped between consecutive samples.

    Yields:
        Tuple of torch tensors: (screens, variables, labels, terminals).
        NOTE: the underlying numpy buffers are reused between yields, so
        consumers must not hold references across iterations — TODO confirm
        this matches the training loop's expectations.
    """
    # Drop episodes too short to supply a full training episode.
    episode_min_size = args.episode_size * step_size
    episodes = episodes[episodes[:, 1] - episodes[:, 0] > episode_min_size]
    episodes_num = len(episodes)
    # Start each episode at a random offset within its first step window.
    # NOTE(review): this assignment was commented out in the mangled source
    # but step_idx is indexed below — restored as live code.
    step_idx = episodes[:, 0].copy() + np.random.randint(step_size, size=episodes_num)
    step_screens = np.ndarray(shape=(args.batch_size, *screens.shape[1:]), dtype=np.float32)
    step_variables = np.ndarray(shape=(args.batch_size, *variables.shape[1:]), dtype=np.float32)
    # np.int was removed in NumPy 1.24; int64 is also what CrossEntropyLoss
    # expects for class targets.
    step_labels = np.ndarray(shape=(args.batch_size,), dtype=np.int64)
    step_terminals = np.ones(shape=(args.batch_size,), dtype=np.float32)
    # select episodes for the initial batch
    batch_episodes = np.random.randint(episodes_num, size=args.batch_size)
    while True:
        for i in range(args.batch_size):
            idx = batch_episodes[i]
            # Normalize pixels from [0, 255] to [-1, 1]; variables by 100.
            step_screens[i, :] = screens[step_idx[idx]] / 127.5 - 1.0
            step_variables[i, :] = variables[step_idx[idx]] / 100
            step_labels[i] = labels[step_idx[idx]]
            step_idx[idx] += step_size
            if step_idx[idx] > episodes[idx][1]:
                # Episode exhausted: rewind it to a fresh random offset...
                step_idx[idx] = episodes[idx][0] + np.random.randint(step_size)
                step_terminals[i] = 0
                # ...and move this slot to a new episode.
                batch_episodes[i] = np.random.randint(episodes_num)
            else:
                step_terminals[i] = 1

        yield torch.from_numpy(step_screens), \
              torch.from_numpy(step_variables), \
              torch.from_numpy(step_labels), \
              torch.from_numpy(step_terminals)


def train(args):
    """Train BaseModel by imitation on a recorded VizDoom h5 dataset.

    Loads the dataset, optionally warm-starts from ``args.load``, then runs
    an unbounded training loop (the data generator never terminates),
    logging every ``args.episode_length`` batches and checkpointing every
    ``args.checkpoint_rate`` batches.
    """
    data_file = h5py.File(args.h5_path, 'r')
    try:
        screens = data_file['screens']
        variables = data_file['variables']
        labels = data_file['action_labels']
        print('Dataset size =', len(screens))
        action_sets = data_file['action_sets'][:]
        episodes = data_file['episodes'][:]
        input_shape = screens[0].shape
        train_generator = data_generator(args, screens, variables, labels, episodes, args.skiprate)
        # Persist the action set so inference code can decode predictions.
        np.save('action_set', action_sets)

        model = BaseModel(
            input_shape[0] * args.frame_num,
            len(action_sets),
            variables.shape[1],
            args.frame_num,
        ).to(device)
        if args.load is not None and os.path.isfile(args.load):
            print("loading model parameters {}".format(args.load))
            # NOTE(review): torch.load of a whole pickled model — only safe
            # on trusted checkpoint files.
            source_model = torch.load(args.load)
            model.load_state_dict(source_model.state_dict())
            del source_model

        criterion = nn.CrossEntropyLoss()
        optimizer = optim.AdamW(model.parameters(), lr=5e-4)
        optimizer.zero_grad()

        running_loss = 0
        running_accuracy = 0
        batch_time = time.time()
        # The generator is infinite; training runs until interrupted.
        for batch, (screens, variables, labels, terminals) in enumerate(train_generator):
            labels = labels.to(device)
            outputs, _ = model(*model.transform_input(screens, variables))
            loss = criterion(outputs, labels)
            model.set_terminal(terminals)

            running_loss += loss.item()
            _, pred = outputs.max(1)
            # .item() so we accumulate a Python float, not a 0-d tensor.
            running_accuracy += (pred == labels).float().mean().item()

            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            if batch % args.episode_length == args.episode_length - 1:
                running_loss /= args.episode_length
                running_accuracy /= args.episode_length
                print(
                    '[{:d}] loss: {:.3f}, accuracy: {:.3f}, time: {:.6f}'.format(
                        batch + 1, running_loss, running_accuracy, time.time() - batch_time
                    )
                )
                running_loss = 0
                running_accuracy = 0
                batch_time = time.time()

            # Guard: checkpoint_file defaults to None; torch.save(None) raises.
            if args.checkpoint_file is not None and \
                    batch % args.checkpoint_rate == args.checkpoint_rate - 1:
                torch.save(model, args.checkpoint_file)
    finally:
        data_file.close()


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Doom Recorder')
    parser.add_argument('--episode_size', type=int, default=20,
                        help='number of steps in an episode')
    parser.add_argument('--batch_size', type=int, default=64,
                        help='number of game instances running in parallel')
    parser.add_argument('--load', default=None,
                        help='path to model file')
    parser.add_argument('--h5_path',
                        default=os.path.expanduser('~') + '/test/datasets/vizdoom/cig_map01/flat.h5',
                        help='hd5 file path')
    parser.add_argument('--skiprate', type=int, default=2,
                        help='number of skipped frames')
    parser.add_argument('--episode_length', type=int, default=30,
                        help='episode length')
    parser.add_argument('--frame_num', type=int, default=4,
                        help='number of frames per input')
    parser.add_argument('--checkpoint_file', default=None,
                        help='check point file name')
    parser.add_argument('--checkpoint_rate', type=int, default=5000,
                        help='number of batches per checkpoint')
    args = parser.parse_args()

    train(args)