Source code for gojo.deepl.callback

# Module containing the callbacks used by the training loops of deep learning models.
#
# Author: Fernando García Gutiérrez
# Email: ga.gu.fernando.concat@gmail.com
#
# STATUS: completed and documented; testing to be done.
#
import os
import numpy as np
import pandas as pd
import warnings
import torch
from abc import ABCMeta, abstractmethod
from pathlib import Path

from ..util.validation import (
    checkMultiInputTypes,
    checkInputType
)
from ..util.io import saveTorchModel


class Callback(object):
    """ Base class (interface) used to define the callbacks to be executed in each iteration of the
    training loop of the neural networks defined in :func:`gojo.deepl.loop.fitNeuralNetwork`. These
    callbacks provide directives to modify the training of the models. A classic example would be the
    early stopping callback (defined in :class:`gojo.deepl.callback.EarlyStopping`).

    Subclasses must define the following methods:

        - evaluate()
            This method receives the following arguments, used (and updated) in the current iteration
            of the :func:`gojo.deepl.loop.fitNeuralNetwork` training loop:

                model : :class:`gojo.core.base.TorchSKInterface` or :class:`gojo.core.base.ParametrizedTorchSKInterface`
                    Model to be trained.

                train_metrics : list
                    Training metrics computed up to the last epoch.

                valid_metrics : list
                    Validation metrics computed up to the last epoch.

                train_loss : list
                    Training loss values computed up to the last epoch.

                valid_loss : list
                    Validation loss values computed up to the last epoch.

            This method has to return a directive (as a string) that will be interpreted by the
            :func:`gojo.deepl.loop.fitNeuralNetwork` inner loop.

        - resetState()
            This method should reset the inner state of the callback.
    """
    __metaclass__ = ABCMeta

    def __init__(self, name: str):
        checkMultiInputTypes(
            ('name', name, [str]))

        self._name = name

    def __repr__(self):
        return self._name

    def __str__(self):
        return self.__repr__()

    def __call__(self, *args, **kwargs) -> str:
        command = self.evaluate(*args, **kwargs)

        checkInputType('gojo.deepl.callback.Callback.__call__()', command, [str, type(None)])

        return command

    @abstractmethod
    def evaluate(self, *args, **kwargs) -> str:
        raise NotImplementedError

    @abstractmethod
    def resetState(self):
        raise NotImplementedError
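

# Illustrative sketch (not part of the original module): a minimal custom callback built only on the
# interface defined above. evaluate() receives the loop state as keyword arguments and returns a
# directive string (or None); resetState() clears any internal state. The 'stop' directive and the
# per-epoch dictionaries keyed by 'loss (mean)' mirror what EarlyStopping (below) reads from
# `valid_loss`; assuming the same layout for `train_loss`, and that the
# :func:`gojo.deepl.loop.fitNeuralNetwork` inner loop honours the directive for custom callbacks,
# are simplifications of this sketch.
class _ThresholdStoppingSketch(Callback):
    """ Stops training once the training loss falls below a fixed threshold (illustrative only). """

    def __init__(self, threshold: float):
        super().__init__(name='ThresholdStoppingSketch')
        self.threshold = threshold

    def evaluate(self, train_loss: list, **_) -> str:
        # issue the 'stop' directive once the last recorded training loss is below the threshold
        if len(train_loss) > 0 and train_loss[-1].get('loss (mean)', np.inf) < self.threshold:
            return 'stop'
        return None

    def resetState(self):
        # stateless callback, nothing to reset
        pass
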
class EarlyStopping(Callback):
    """ Callback used to perform an early stopping of the :func:`gojo.deepl.loop.fitNeuralNetwork`
    training loop.

    Parameters
    ----------
    it_without_improve : int
        Number of iterations the model may go without improving the loss value on the validation set
        (relative to the mean of, or to each of, the last epochs, as defined by parameter `track`)
        before the training loop is stopped early.

    track : str, default='mean'
        Method used to compare the latest value of the loss on the validation set with the historical
        values. Methods currently available:

            - 'mean': compare the current value against the average of the last `it_without_improve` epochs.
            - 'count': compare the current value against each of the last `it_without_improve` epochs.

    ref_metric : str, default=None
        Metric computed on the validation set to be used as reference. By default, the loss value
        will be used.

    smooth_n : int, default=None
        If provided, instead of comparing the last loss value directly against the historical ones,
        the compared value is smoothed as the average of the current loss and the last `smooth_n - 1`
        recorded values.
    """
    VALID_TRACKING_OPTS = ['mean', 'count']
    _LOSS_IDENTIFICATION_KEY = 'loss (mean)'   # HACK. Hard-coded key used to identify the average loss values
    DIRECTIVE = 'stop'

    def __init__(self, it_without_improve: int, track: str = 'mean', ref_metric: str = None,
                 smooth_n: int = None):
        super().__init__(name='EarlyStopping')

        assert track in EarlyStopping.VALID_TRACKING_OPTS

        # check smooth parameter
        if smooth_n is not None and smooth_n <= 1:
            raise ValueError('If provided, parameter "smooth_n" must be greater than 1.')

        self.it_without_improve = it_without_improve
        self.ref_metric = ref_metric if ref_metric is not None else self._LOSS_IDENTIFICATION_KEY
        self.track = track
        self.smooth_n = smooth_n

        self._saved_valid_loss = []

    def _getLastLossValue(self, stats: list) -> float:
        """ Function used to get and check the current loss value. """
        curr_loss = stats[-1].get(self.ref_metric, np.nan)

        # check for NaNs in the current loss
        if pd.isna(curr_loss):
            warnings.warn('Current average loss value is NaN.')

        return curr_loss

    def evaluate(self, valid_loss: list, **_) -> str:
        """ Early stopping inner logic. """
        # check input type
        checkInputType('gojo.deepl.callback.EarlyStopping.evaluate(valid_loss)', valid_loss, [list])

        command = None
        if len(self._saved_valid_loss) < self.it_without_improve:
            # not enough iterations performed
            self._saved_valid_loss.append(self._getLastLossValue(valid_loss))
        else:
            # enough iterations have been performed to check for loss improvements
            # get the saved losses
            saved_valid_loss = np.array(self._saved_valid_loss)

            # get the current loss
            curr_loss = self._getLastLossValue(valid_loss)

            # get the loss to compare
            if self.smooth_n is None:
                loss_to_comp = curr_loss
            else:
                loss_to_comp = np.mean(list(saved_valid_loss[-1*self.smooth_n+1:]) + [curr_loss])

            if self.track == 'count':
                if np.all(loss_to_comp > saved_valid_loss[-1 * self.it_without_improve:]):
                    command = self.DIRECTIVE
            elif self.track == 'mean':
                if loss_to_comp > np.mean(saved_valid_loss[-1 * self.it_without_improve:]):
                    command = self.DIRECTIVE
            else:
                raise NotImplementedError()

            # save the current loss
            self._saved_valid_loss.append(curr_loss)

        return command

    def resetState(self):
        """ Reset callback """
        self._saved_valid_loss = []
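

# Illustrative usage sketch (not part of the original module): driving EarlyStopping by hand with
# synthetic per-epoch validation statistics. In practice the callback is evaluated by the
# :func:`gojo.deepl.loop.fitNeuralNetwork` inner loop; the per-epoch dictionaries and the
# 'loss (mean)' key below simply mirror what EarlyStopping.evaluate() reads from `valid_loss`.
def _earlyStoppingUsageSketch():
    early_stopping = EarlyStopping(it_without_improve=3, track='mean')

    valid_loss = []
    for epoch_loss in [1.00, 0.90, 0.80, 0.95]:
        valid_loss.append({'loss (mean)': epoch_loss})
        command = early_stopping(valid_loss=valid_loss)   # __call__ delegates to evaluate()
        print('loss = %.2f -> command = %r' % (epoch_loss, command))

    # the last call returns 'stop': 0.95 is worse than the mean (0.90) of the three previously
    # recorded validation losses

    early_stopping.resetState()   # clear the saved loss history before reusing the callback
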
class SaveCheckPoint(Callback):
    """ Callback used to save the model parameters during training.

    Parameters
    ----------
    output_dir : str
        Output directory used to store the model parameters. If it does not exist, it will be created
        automatically.

    key : str
        Key used to identify the model.

    each_epoch : int
        A model checkpoint will be saved every `each_epoch` epochs.

    verbose : bool, default=True
        Parameter that indicates whether to display a message on the screen when a checkpoint is saved.
    """
    DIRECTIVE = None

    def __init__(
            self,
            output_dir: str,
            key: str,
            each_epoch: int,
            verbose: bool = True
    ):
        super().__init__(name='SaveCheckPoint')

        self.output_dir = output_dir
        self.key = key
        self.each_epoch = each_epoch
        self.verbose = verbose

    def evaluate(self, n_epoch: int, model: torch.nn.Module, **_):
        if n_epoch > 0:
            # create the output directory if it does not exist
            if not os.path.exists(self.output_dir):
                Path(self.output_dir).mkdir(parents=True)

            # save model
            if n_epoch % self.each_epoch == 0:
                out_file = saveTorchModel(
                    base_path=self.output_dir,
                    key='%s_checkpoint_%d' % (self.key, int(n_epoch)),
                    model=model
                )
                if self.verbose:
                    self.message(out_file)

        return self.DIRECTIVE

    @staticmethod
    def message(out_file: str):
        print('\nSaved model %s\n' % out_file)

    def resetState(self):
        """ Reset callback """
        pass
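

# Illustrative usage sketch (not part of the original module): calling SaveCheckPoint by hand once per
# epoch, with a toy torch model and a temporary directory standing in for the real training setup.
# Inside :func:`gojo.deepl.loop.fitNeuralNetwork` the inner loop is assumed to perform equivalent
# calls; the files are written by :func:`gojo.util.io.saveTorchModel`, exactly as used in
# SaveCheckPoint.evaluate() above.
def _saveCheckPointUsageSketch():
    import tempfile

    model = torch.nn.Linear(4, 1)   # toy model used only for illustration

    with tempfile.TemporaryDirectory() as tmp_dir:
        checkpoint = SaveCheckPoint(
            output_dir=os.path.join(tmp_dir, 'checkpoints'),
            key='toy_model',
            each_epoch=5,       # save a checkpoint every 5 epochs
            verbose=True
        )

        for n_epoch in range(1, 16):
            # parameters are written to disk only when n_epoch is a multiple of each_epoch
            checkpoint(n_epoch=n_epoch, model=model)

        checkpoint.resetState()   # SaveCheckPoint keeps no internal state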