#
# mcts_base.py, doom-net
#
# Created by Andrey Kolishchak on 04/29/18.
#
import os
import time
import datetime
from torch.multiprocessing import Process
import numpy as np
import h5py
from simulator import Simulator
from mcts import MCTS
from mcts_dataset import MCTSDataset
import torch
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from device import device
from model import Model
import vizdoom


class MCTSBase(Model):
    def __init__(self):
        super().__init__()

    def run_train(self, args):
        print("training...")
        model = self
        sim = Simulator(model)

        games = []
        for i in range(1):
            games.append(
                args.instance_class(
                    args.vizdoom_config, args.wad_path, args.skiprate, actions=args.action_set, id=i)
            )

        for iteration in range(100):
            print("iteration: ", iteration)
            #
            # generate data: one self-play process per game instance
            #
            processes = []
            for game in games:
                process = Process(target=self.generate_data, args=(game, sim, args))
                process.start()
                processes.append(process)

            for process in processes:
                process.join()
            #
            # train model with new data
            #
            self.train_model(model, args)

    def run_test(self, args):
        print("testing...")
        model = self
        sim = Simulator(model)
        model.eval()

        game = args.instance_class(
            args.vizdoom_config, args.wad_path, args.skiprate,
            visible=True, mode=vizdoom.Mode.ASYNC_PLAYER, actions=args.action_set)
        step_state = game.get_state_normalized()

        while True:
            state = sim.get_state(step_state)
            # compute an action
            action = sim.get_action(state)
            # render
            step_state, _, finished = game.step_normalized(action[0][0])
            if finished:
                print("episode return: {}".format(game.get_episode_return()))

    def generate_data(self, game, sim, args, episode_num=100):
        model = sim.get_policy_model()
        model.eval()

        target_states, target_actions, target_rewards = [], [], []
        mean_reward = 0
        for i in range(episode_num):
            states, actions, rewards = self.get_episode_targets(game, sim, 10)
            # note: the lists accumulate across episodes, so each file written
            # below holds all episodes generated so far by this process
            target_states.extend(states)
            target_actions.extend(actions)
            target_rewards.extend(rewards)

            mean_reward += sum(rewards) / len(rewards)
            #
            # save episode data to file
            #
            filename = os.path.join(args.h5_path, '{:%Y-%m-%d %H-%M-%S}-{}'.format(datetime.datetime.now(), i))
            with h5py.File(filename, 'w') as file:
                file.create_dataset('states', data=target_states, dtype='float32', compression='gzip')
                # 'int64' instead of 'long': h5py/numpy do not accept 'long' as a dtype string
                file.create_dataset('actions', data=target_actions, dtype='int64', compression='gzip')
                file.create_dataset('rewards', data=target_rewards, dtype='float32', compression='gzip')

        mean_reward /= episode_num
        print("mean reward = ", mean_reward)

    def get_episode_targets(self, game, sim, max_length):
        # run an MCTS-guided episode and collect (state, action, reward) targets
        mcts = MCTS(sim, 1000, c_puct=1)
        target_states, target_actions, target_rewards = [], [], []
        step = 0
        state = game.get_state_normalized()
        while step < max_length:
            # sample an action from the visit-count distribution at temperature 1
            prob, state = mcts.get_action_prob(state, 1)
            action = np.random.choice(len(prob), p=prob)
            target_states.append(state.policy_state)
            target_actions.append(action)
            state, reward, finished = game.step_normalized(action)
            target_rewards.append(reward)
            if finished:
                break
            step += 1

        return target_states, target_actions, target_rewards

    def train_model(self, model, args, epoch_num=10):
        dataset = MCTSDataset(args)
        training_data_loader = DataLoader(dataset=dataset, num_workers=1, batch_size=args.batch_size, shuffle=True)
        model.train()
        optimizer = optim.AdamW(model.parameters(), lr=5e-3, weight_decay=1e-4, amsgrad=True)

        mean_value_loss = 0
        mean_policy_loss = 0
        mean_accuracy = 0
        updates = 0
        batch_time = time.time()
        for epoch in range(epoch_num):
            for batch, (state, target_action, target_value) in enumerate(training_data_loader):
                state, target_action, target_value = \
                    state.to(device), target_action.to(device), target_value.to(device)
                optimizer.zero_grad()
                value, log_action = model(state)
                value_loss = F.mse_loss(value, target_value[:, None])
                policy_loss = F.nll_loss(log_action, target_action)
                loss = value_loss + policy_loss
                loss.backward()
                optimizer.step()

                # diagnostics: gradient and weight norms
                grads = []
                weights = []
                for p in model.parameters():
                    if p.grad is not None:
                        grads.append(p.grad.data.view(-1))
                        weights.append(p.data.view(-1))
                grads = torch.cat(grads, 0)
                weights = torch.cat(weights, 0)
                grads_norm = grads.norm()
                weights_norm = weights.norm()
                # NaN check: NaN is the only value that does not equal itself
                assert grads_norm == grads_norm

                _, pred_action = log_action.max(1)
                accuracy = (pred_action == target_action.data).float().mean()

                # report statistics averaged over the last epoch only
                if epoch == epoch_num - 1:
                    mean_value_loss += value_loss.item()
                    mean_policy_loss += policy_loss.item()
                    mean_accuracy += accuracy.item()
                    updates += 1

        mean_value_loss /= updates
        mean_policy_loss /= updates
        mean_accuracy /= updates
        print(
            "value_loss = {:f} policy_loss = {:f} accuracy = {:f}, train_time = {:.3f}".format(
                mean_value_loss, mean_policy_loss, mean_accuracy, time.time() - batch_time))

        torch.save(model.state_dict(), args.checkpoint_file)
        torch.save(optimizer.state_dict(), args.checkpoint_file + '_optimizer.pth')
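
#
# Usage sketch (illustrative, not part of the original module): a minimal
# driver for an MCTSBase subclass. The `args` fields mirror the attributes
# accessed above (instance_class, vizdoom_config, wad_path, skiprate,
# action_set, h5_path, batch_size, checkpoint_file); the DoomInstance class
# and the file paths are assumptions made for illustration only.
#
# from argparse import Namespace
# from doom_instance import DoomInstance  # assumed game wrapper
#
# args = Namespace(
#     instance_class=DoomInstance,        # environment factory used by run_train/run_test
#     vizdoom_config='env/basic.cfg',     # assumed config path
#     wad_path='env/doom2.wad',           # assumed WAD path
#     skiprate=4,
#     action_set=None,
#     h5_path='dataset/',                 # where generate_data writes episode files
#     batch_size=64,
#     checkpoint_file='mcts_model.pth',
# )
#
# model = MCTSBase()  # in practice, a subclass implementing the policy/value network
# model.run_train(args)
# model.run_test(args)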