import logging

from ignite.engine import Engine


class EarlyStopping(object):
    """EarlyStopping handler can be used to stop the training if no improvement after a given number of events.

    The handler is called with the engine it is attached to (typically an
    evaluator), computes a score from it, and terminates ``trainer`` once the
    score has failed to improve for ``patience`` consecutive calls.

    Args:
        patience (int):
            Number of events to wait if no improvement and then stop the training.
        score_function (Callable):
            It should be a function taking a single argument, an `ignite.engine.Engine` object,
            and return a score `float`. An improvement is considered if the score is higher.
        trainer (Engine):
            Engine whose ``terminate()`` method is called when there is no improvement
            for ``patience`` consecutive calls.

    Examples:

    .. code-block:: python

        from ignite.engine import Engine, Events
        from ignite.handlers import EarlyStopping

        def score_function(engine):
            val_loss = engine.state.metrics['nll']
            return -val_loss

        handler = EarlyStopping(patience=10, score_function=score_function, trainer=trainer)
        evaluator.add_event_handler(Events.COMPLETED, handler)

    """

    def __init__(self, patience, score_function, trainer):
        # Validation is kept as assertions so that callers relying on
        # AssertionError for bad arguments keep working.
        assert callable(score_function), "Argument score_function should be a function"
        assert patience > 0, "Argument patience should be positive"
        assert isinstance(trainer, Engine), "Argument trainer should be an instance of Engine"

        self.score_function = score_function
        self.patience = patience
        self.trainer = trainer
        # Number of consecutive calls without improvement.
        self.counter = 0
        # Highest score seen so far; None until the first call.
        self.best_score = None

        self._logger = logging.getLogger(__name__ + "." + self.__class__.__name__)
        # The logger is shared by every instance of this class, so attach the
        # NullHandler only once instead of stacking one per instance.
        if not self._logger.handlers:
            self._logger.addHandler(logging.NullHandler())

    def __call__(self, engine):
        """Update the early-stopping state from ``engine`` and stop the trainer if needed.

        Args:
            engine: object passed unchanged to ``score_function``.
        """
        score = self.score_function(engine)

        if self.best_score is None:
            # First observation: nothing to compare against yet.
            self.best_score = score
        elif score <= self.best_score:
            # Only a strictly higher score is an improvement; a plateau
            # (equal score) must also count towards the patience budget.
            self.counter += 1
            self._logger.debug("EarlyStopping: %i / %i", self.counter, self.patience)
            if self.counter >= self.patience:
                self._logger.info("EarlyStopping: Stop training")
                self.trainer.terminate()
        else:
            self.best_score = score
            self.counter = 0