import os
import time
import random
from copy import copy

import numpy as np
import torch

from tensorboardX import SummaryWriter

from torchsupport.data.io import netwrite, to_device
from torchsupport.data.episodic import SupportData
from torchsupport.data.collate import DataLoader
from torchsupport.training.state import (
  TrainingState, NetState, State, SaveStateError
)

class Training(object):
  """Abstract training process class."""
  checkpoint_parameters = []
  torch_rng_state = torch.random.get_rng_state()
  np_rng_state = np.random.get_state()
  random_rng_state = random.getstate()
  save_interval = 600
  last_tick = 0

  def __init__(self):
    pass

  def each_step(self):
    self.save_tick()

  def each_validate(self):
    pass

  def each_epoch(self):
    pass

  def each_checkpoint(self):
    pass

  def train(self):
    pass

  def validate(self):
    pass

  def save_path(self):
    raise NotImplementedError("Abstract.")

  def write(self, path):
    """Serialises RNG states and all checkpoint parameters to `path`."""
    data = {}
    data["_torch_rng_state"] = torch.random.get_rng_state()
    data["_np_rng_state"] = np.random.get_state()
    data["_random_rng_state"] = random.getstate()
    for param in self.checkpoint_parameters:
      param.write_action(self, data)
    torch.save(data, path)

  def read(self, path):
    """Restores RNG states and all checkpoint parameters from `path`."""
    data = torch.load(path)
    torch.random.set_rng_state(data["_torch_rng_state"])
    np.random.set_state(data["_np_rng_state"])
    random.setstate(data["_random_rng_state"])
    for param in self.checkpoint_parameters:
      param.read_action(self, data)

  def save(self, path=None):
    path = path or self.save_path()
    self.write(path)

  def save_tick(self, step=None):
    """Saves the training state if at least `step` seconds have elapsed.

    If saving fails, the last good state is reloaded while the current RNG
    states are preserved, so training can continue unperturbed.
    """
    step = step or self.save_interval
    this_tick = time.monotonic()
    if this_tick - self.last_tick > step:
      try:
        self.save()
        self.last_tick = this_tick
      except SaveStateError:
        torch_rng_state = torch.random.get_rng_state()
        np_rng_state = np.random.get_state()
        random_rng_state = random.getstate()
        self.load()
        torch.random.set_rng_state(torch_rng_state)
        np.random.set_state(np_rng_state)
        random.setstate(random_rng_state)

  def load(self, path=None):
    path = path or self.save_path()
    if os.path.isfile(path):
      self.read(path)
    return self

class SupervisedTraining(Training):
  """Standard supervised training process.

  A minimal usage sketch is included at the bottom of this module.

  Args:
    net (Module): a trainable network module.
    train_data (Dataset): a dataset providing the training examples;
      it is wrapped in a :class:`DataLoader` internally.
    validate_data (Dataset): a dataset providing the validation examples;
      it is wrapped in a :class:`DataLoader` internally.
    losses (list): one loss callable per (prediction, label) pair.
    optimizer (Optimizer): an optimizer for the network. Defaults to Adam.
    schedule (Schedule): a learning rate schedule. Defaults to decay when
      stagnated.
    max_epochs (int): the maximum number of epochs to train.
    batch_size (int): the number of examples per batch.
    accumulate (int): number of steps over which to accumulate gradients
      before an optimizer update. Defaults to no accumulation.
    device (str): the device to run on.
    network_name (str): name used for logging and checkpoint files.
    path_prefix (str): directory prefix for checkpoint files.
    report_interval (int): number of steps between validation reports.
    checkpoint_interval (int): number of steps between checkpoints.
    valid_callback (callable): called after each validation step with the
      training object, the validation batch and the network outputs.
""" checkpoint_parameters = Training.checkpoint_parameters + [ TrainingState(), NetState("net"), NetState("optimizer") ] def __init__(self, net, train_data, validate_data, losses, optimizer=torch.optim.Adam, schedule=None, max_epochs=50, batch_size=128, accumulate=None, device="cpu", network_name="network", path_prefix=".", report_interval=10, checkpoint_interval=1000, valid_callback=lambda x: None): super(SupervisedTraining, self).__init__() self.valid_callback = valid_callback self.network_name = network_name self.writer = SummaryWriter(network_name) self.device = device self.accumulate = accumulate self.optimizer = optimizer(net.parameters()) if schedule is None: self.schedule = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, patience=10) else: self.schedule = schedule self.losses = losses self.train_data = DataLoader( train_data, batch_size=batch_size, num_workers=8, shuffle=True, drop_last=True ) self.validate_data = DataLoader( validate_data, batch_size=batch_size, num_workers=8, shuffle=True, drop_last=True ) self.net = net.to(self.device) self.max_epochs = max_epochs self.checkpoint_path = f"{path_prefix}/{network_name}-checkpoint" self.report_interval = report_interval self.checkpoint_interval = checkpoint_interval self.step_id = 0 self.epoch_id = 0 self.validation_losses = [0 for _ in range(len(self.losses))] self.training_losses = [0 for _ in range(len(self.losses))] self.best = None def save_path(self): return self.checkpoint_path + "-save.torch" def checkpoint(self): the_net = self.net if isinstance(the_net, torch.nn.DataParallel): the_net = the_net.module netwrite( self.net, f"{self.checkpoint_path}-epoch-{self.epoch_id}-step-{self.step_id}.torch" ) self.each_checkpoint() def run_networks(self, data): inputs, *labels = data if not isinstance(inputs, (list, tuple)): inputs = [inputs] predictions = self.net(*inputs) if not isinstance(predictions, (list, tuple)): predictions = [predictions] return [combined for combined in zip(predictions, labels)] def loss(self, inputs): loss_val = torch.tensor(0.0).to(self.device) for idx, the_input in enumerate(inputs): this_loss_val = self.losses[idx](*the_input) self.training_losses[idx] = float(this_loss_val) loss_val += this_loss_val return loss_val def valid_loss(self, inputs): training_cache = list(self.training_losses) loss_val = self.loss(inputs) self.validation_losses = self.training_losses self.training_losses = training_cache return loss_val def step(self, data): if self.accumulate is None: self.optimizer.zero_grad() outputs = self.run_networks(data) loss_val = self.loss(outputs) loss_val.backward() torch.nn.utils.clip_grad_norm_(self.net.parameters(), 5.0) if self.accumulate is None: self.optimizer.step() elif self.step_id % self.accumulate == 0: self.optimizer.step() self.optimizer.zero_grad() self.each_step() def validate(self, data): with torch.no_grad(): self.net.eval() outputs = self.run_networks(data) self.valid_loss(outputs) self.each_validate() self.valid_callback( self, to_device(data, "cpu"), to_device(outputs, "cpu") ) self.net.train() def schedule_step(self): self.schedule.step(sum(self.validation_losses)) def each_step(self): Training.each_step(self) for idx, loss in enumerate(self.training_losses): self.writer.add_scalar(f"training loss {idx}", loss, self.step_id) self.writer.add_scalar(f"training loss total", sum(self.training_losses), self.step_id) def each_validate(self): for idx, loss in enumerate(self.validation_losses): self.writer.add_scalar(f"validation loss {idx}", loss, self.step_id) 
self.writer.add_scalar(f"validation loss total", sum(self.validation_losses), self.step_id) def train(self): for epoch_id in range(self.max_epochs): self.epoch_id = epoch_id valid_iter = iter(self.validate_data) for data in self.train_data: data = to_device(data, self.device) self.step(data) if self.step_id % self.report_interval == 0: vdata = None try: vdata = next(valid_iter) except StopIteration: valid_iter = iter(self.validate_data) vdata = next(valid_iter) vdata = to_device(vdata, self.device) self.validate(vdata) if self.step_id % self.checkpoint_interval == 0: self.checkpoint() self.step_id += 1 self.schedule_step() self.each_epoch() return self.net class MaskedSupervisedTraining(SupervisedTraining): def run_networks(self, data): inputs, labels_masks = data labels = [label for (label, mask) in labels_masks] masks = [mask for (label, mask) in labels_masks] predictions = self.net(inputs) return list(zip(predictions, labels, masks)) class FewShotTraining(SupervisedTraining): def __init__(self, net, train_data, validate_data, losses, optimizer=torch.optim.Adam, schedule=None, max_epochs=50, batch_size=128, device="cpu", network_name="network", path_prefix=".", report_interval=10, checkpoint_interval=1000, valid_callback=lambda x: None): super(FewShotTraining, self).__init__( net, train_data, validate_data, losses, optimizer=optimizer, schedule=schedule, max_epochs=max_epochs, batch_size=batch_size, device=device, network_name=network_name, path_prefix=path_prefix, report_interval=report_interval, checkpoint_interval=checkpoint_interval, valid_callback=valid_callback ) support_data = copy(train_data) train_data.data_mode = type(train_data.data_mode)(1) support_data = SupportData(train_data, shots=5) validate_support_data = SupportData(validate_data, shots=5) self.support_loader = iter(DataLoader(support_data)) self.valid_support_loader = iter(DataLoader(validate_support_data)) def run_networks(self, data, support, support_label): predictions = self.net(data, support) return list(zip(predictions, support_label)) def step(self, inputs): data, label = inputs self.optimizer.zero_grad() permutation = [0, 1, 2] random.shuffle(permutation) support, support_label = next(self.support_loader) lv = label[0].reshape(-1) for idx, val in enumerate(lv): lv[idx] = permutation[int(val[0])] lv = support_label.reshape(-1) for idx, val in enumerate(lv): lv[idx] = permutation[int(val[0])] support = support[0].to(self.device) support_label = support_label[0].to(self.device) outputs = self.run_networks(data, support, support_label) loss_val = self.loss(outputs) loss_val.backward() self.optimizer.step() self.each_step() def validate(self): with torch.no_grad(): self.net.eval() vit = iter(self.validate_data) inputs, *label = next(vit) inputs, label = inputs.to(self.device), list(map(lambda x: x.to(self.device), label)) support, support_label = next(self.valid_support_loader) support = support[0].to(self.device) support_label = support_label[0].to(self.device) outputs = self.run_networks(inputs, support, support_label) self.valid_loss(outputs) self.each_validate() self.valid_callback(self, to_device(inputs, "cpu"), to_device(outputs, "cpu")) self.net.train()