""" Copyright (c) 2019-present NAVER Corp. MIT License """ import os import time import math import torch from model import LMConfig, LanguageModel class TrainLogger(object): def __init__(self): self.init() def init(self): self.start = time.time() self.cnt_add = 0 self.tot_loss = 0. self.cnt_query = 0 self.cnt_token = 0 def add(self, loss, n_query, n_token): self.cnt_add += 1 self.tot_loss += loss * n_query self.cnt_query += n_query self.cnt_token += n_token def average(self): loss_query = self.tot_loss / self.cnt_query if self.cnt_query != 0 else 0. loss_token = self.tot_loss / self.cnt_token if self.cnt_token != 0 else 0. return loss_query, loss_token def elapsed_time(self): return time.time() - self.start def print_str(self, time_avg_=False): loss_query, loss_token = self.average() time_str = f"{self.elapsed_time() * 1000. / self.cnt_add:6.2f} ms/batch" if time_avg_ else \ f"{self.elapsed_time():6.2f} s" return f"{time_str} | loss_query {loss_query:6.2f} | token_ppl {math.exp(loss_token):6.2f}" def get_params(model): return list(filter(lambda p: p.requires_grad, model.parameters())) def get_model(model): return model.module if hasattr(model, 'module') else model def model_save(path, model, optimizer=None): model_to_save = get_model(model) open(os.path.join(path, 'config.json'), 'w').write(str(model_to_save.config)) torch.save(model_to_save.state_dict(), open(os.path.join(path, 'model.pt'), 'wb')) if optimizer: torch.save(optimizer.state_dict(), open(os.path.join(path, 'optimizer.pt'), 'wb')) def model_load(path, model=None, optimizer=None): config = LMConfig(os.path.join(path, 'config.json')) if model is None: model_to_load = LanguageModel(config) else: model_to_load = get_model(model) model_to_load.__init__(config) model_state_dict = torch.load(open(os.path.join(path, 'model.pt'), 'rb'), map_location=lambda s, l: s) model_to_load.load_state_dict(model_state_dict) if optimizer: optimizer_state_dict = torch.load(open(os.path.join(path, 'optimizer.pt'), 'rb'), map_location=lambda s, l: s) optimizer.load_state_dict(optimizer_state_dict) return model_to_load