import copy
import os
from typing import Any, Dict, Optional, Union

import torch
from torch import nn, optim


class CheckpointManager(object):
    r"""
    A :class:`CheckpointManager` periodically serializes models and optimizer
    as ``.pth`` files during training, and keeps track of the best-performing
    checkpoint based on an observed metric.

    Extended Summary
    ----------------
    It saves state dicts of models and optimizer as ``.pth`` files in a
    specified directory. This class closely follows the API of PyTorch
    optimizers and learning rate schedulers.

    Notes
    -----
    For :class:`~torch.nn.DataParallel` objects, ``.module.state_dict()`` is
    called instead of ``.state_dict()``.

    Parameters
    ----------
    models: Dict[str, torch.nn.Module]
        Models which need to be serialized as a checkpoint.
    optimizer: torch.optim.Optimizer
        Optimizer which needs to be serialized as a checkpoint.
    serialization_dir: str
        Path to an empty or non-existent directory to save checkpoints.
    mode: str, optional (default="max")
        One of ``min``, ``max``. In ``min`` mode, the best checkpoint is
        recorded when the metric hits a lower value; in ``max`` mode, when it
        hits a higher value.
    filename_prefix: str, optional (default="checkpoint")
        Prefix of the to-be-saved checkpoint files.

    Examples
    --------
    >>> model = torch.nn.Linear(10, 2)
    >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    >>> ckpt_manager = CheckpointManager({"model": model}, optimizer, "/tmp/ckpt", mode="min")
    >>> num_epochs = 20
    >>> for epoch in range(num_epochs):
    ...     train(model)
    ...     val_loss = validate(model)
    ...     ckpt_manager.step(val_loss, epoch)
    """

    def __init__(
        self,
        models: Union[nn.Module, Dict[str, nn.Module]],
        optimizer: optim.Optimizer,
        serialization_dir: str,
        mode: str = "max",
        filename_prefix: str = "checkpoint",
    ):
        # Convert a single model to a dict for uniform handling.
        if isinstance(models, nn.Module):
            models = {"model": models}

        for key in models:
            if not isinstance(models[key], nn.Module):
                raise TypeError("{} is not a Module".format(type(models[key]).__name__))

        if not isinstance(optimizer, optim.Optimizer):
            raise TypeError("{} is not an Optimizer".format(type(optimizer).__name__))

        if mode not in ("min", "max"):
            raise ValueError("mode must be either 'min' or 'max', got {}".format(mode))

        self._models = models
        self._optimizer = optimizer
        self._serialization_dir = serialization_dir
        self._mode = mode
        self._filename_prefix = filename_prefix

        # Create the serialization directory if it does not exist; a
        # non-existent path is explicitly allowed by the docstring.
        os.makedirs(self._serialization_dir, exist_ok=True)

        # Initialize members to hold state dict of best checkpoint and its
        # performance.
        self._best_metric: Optional[Union[float, torch.Tensor]] = None
        self._best_ckpt: Dict[str, Any] = {}

    def step(self, metric: Union[float, torch.Tensor], epoch_or_iteration: int):
        r"""Serialize checkpoint and update best checkpoint based on metric and mode."""

        # Collect state dicts of all models, unwrapping DataParallel objects.
        models_state_dict: Dict[str, Any] = {}
        for key in self._models:
            if isinstance(self._models[key], nn.DataParallel):
                models_state_dict[key] = self._models[key].module.state_dict()
            else:
                models_state_dict[key] = self._models[key].state_dict()

        # Update best checkpoint based on metric and metric mode. Compare
        # against ``None`` explicitly so a metric of 0.0 is not mistaken for
        # "unset"; the first observed metric always counts as the best so far.
        is_best = (
            self._best_metric is None
            or (self._mode == "min" and metric < self._best_metric)
            or (self._mode == "max" and metric > self._best_metric)
        )
        if is_best:
            self._best_metric = metric
            # Deep-copy so the best checkpoint snapshots tensor values; a
            # shallow copy would keep references to live parameters and be
            # silently overwritten by subsequent training steps.
            self._best_ckpt = copy.deepcopy(models_state_dict)

        # Serialize checkpoint corresponding to current epoch (or iteration).
        torch.save(
            {**models_state_dict, "optimizer": self._optimizer.state_dict()},
            os.path.join(
                self._serialization_dir, f"{self._filename_prefix}_{epoch_or_iteration}.pth"
            ),
        )

        # Serialize the best-performing checkpoint observed so far.
        torch.save(
            self._best_ckpt,
            os.path.join(self._serialization_dir, f"{self._filename_prefix}_best.pth"),
        )
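

# ---------------------------------------------------------------------------
# A minimal, self-contained usage sketch of ``CheckpointManager``. The model,
# data, loss, and directory below are invented for illustration; in real
# training, ``val_loss`` would come from a proper validation loop rather than
# the training loss reused here.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model = nn.Linear(10, 2)
    optimizer = optim.SGD(model.parameters(), lr=0.01)
    ckpt_manager = CheckpointManager(
        {"model": model}, optimizer, "/tmp/ckpt", mode="min"
    )

    for epoch in range(3):
        # Fake "training": one SGD step on random data.
        inputs, targets = torch.randn(4, 10), torch.randn(4, 2)
        loss = nn.functional.mse_loss(model(inputs), targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Fake "validation" metric; lower is better, matching mode="min".
        val_loss = loss.item()
        ckpt_manager.step(val_loss, epoch)

    # After the loop, /tmp/ckpt contains checkpoint_{0,1,2}.pth plus
    # checkpoint_best.pth holding the model state at the lowest val_loss.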