# Source code for lambeq.training.trainer

# Copyright 2021-2023 Cambridge Quantum Computing Ltd.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
Module that contains the base class for a lambeq trainer.

Subclass :py:class:`Trainer` to define a custom trainer.
"""

from __future__ import annotations

from abc import ABC, abstractmethod
from collections.abc import Mapping
from datetime import datetime
from math import ceil
import os
import random
import socket
import sys
from typing import Any, Callable, TYPE_CHECKING

from discopy import Tensor
from tqdm.auto import tqdm, trange

if TYPE_CHECKING:
    from torch.utils.tensorboard.writer import SummaryWriter

from lambeq.core.globals import VerbosityLevel
from lambeq.training.checkpoint import Checkpoint
from lambeq.training.dataset import Dataset
from lambeq.training.model import Model
from lambeq.typing import StrPathT

def _import_tensorboard_writer() -> None:
    global SummaryWriter
        from torch.utils.tensorboard.writer import SummaryWriter
    except ImportError as e:  # pragma: no cover
        raise ImportError(
            'tensorboard not found. Please install it using '
            '`pip install tensorboard`.'
        ) from e

# Type of an evaluation metric function. `Trainer.fit` calls it as
# `func(y_hat, y_label)` on each batch and multiplies the result by the
# batch size before averaging it over the whole dataset.
EvalFuncT = Callable[[Any, Any], Any]

class Trainer(ABC):
    """Base class for a lambeq trainer."""

    def __init__(self,
                 model: Model,
                 loss_function: Callable[..., Any],
                 epochs: int,
                 evaluate_functions: Mapping[str, EvalFuncT] | None = None,
                 evaluate_on_train: bool = True,
                 use_tensorboard: bool = False,
                 log_dir: StrPathT | None = None,
                 from_checkpoint: bool = False,
                 verbose: str = VerbosityLevel.TEXT.value,
                 seed: int | None = None) -> None:
        """Initialise a lambeq trainer.

        Parameters
        ----------
        model : :py:class:`.Model`
            A lambeq Model.
        loss_function : callable
            A loss function to compare the prediction to the true label.
        epochs : int
            Number of training epochs.
        evaluate_functions : mapping of str to callable, optional
            Mapping of evaluation metric functions from their names.
        evaluate_on_train : bool, default: True
            Evaluate the metrics on the train dataset.
        use_tensorboard : bool, default: False
            Use Tensorboard for visualisation of the training logs.
        log_dir : str or PathLike, optional
            Location of model checkpoints (and tensorboard log). Default is
            `runs/**CURRENT_DATETIME_HOSTNAME**`.
        from_checkpoint : bool, default: False
            Starts training from the checkpoint, saved in the log_dir.
        verbose : str, default: 'text'
            See :py:class:`VerbosityLevel` for options.
        seed : int, optional
            Random seed.

        Raises
        ------
        ValueError
            If `verbose` is not a supported :py:class:`VerbosityLevel`
            value.

        """
        # Validate arguments before touching the file system so that an
        # invalid verbosity level does not leave an empty log dir behind.
        if not VerbosityLevel.has_value(verbose):
            raise ValueError(f'`{verbose}` flag is not supported by '
                             'this trainer.')

        if log_dir is None:
            current_time = datetime.now().strftime('%b%d_%H-%M-%S')
            log_dir = os.path.join(
                'runs', current_time + '_' + socket.gethostname())
        self.log_dir = log_dir
        os.makedirs(self.log_dir, exist_ok=True)

        self.backend = 'numpy'
        self.model = model
        self.loss_function = loss_function
        self.epochs = epochs
        self.evaluate_functions = evaluate_functions
        self.evaluate_on_train = evaluate_on_train
        self.use_tensorboard = use_tensorboard
        self.from_checkpoint = from_checkpoint
        self.verbose = verbose
        self.seed = seed

        # Training/validation history. The `_*_results_epoch` dicts
        # collect batch-size-weighted metric values during one epoch;
        # `fit` averages them into `*_results` and then clears them.
        self.train_costs: list[float] = []
        self.train_epoch_costs: list[float] = []
        self.train_results: dict[str, list[Any]] = {}
        self._train_results_epoch: dict[str, list[Any]] = {}
        self.val_costs: list[float] = []
        self.val_results: dict[str, list[Any]] = {}
        self._val_results_epoch: dict[str, list[Any]] = {}

        if self.evaluate_functions is not None:
            for name in self.evaluate_functions:
                self.val_results[name] = []
                self._val_results_epoch[name] = []
                self.train_results[name] = []
                self._train_results_epoch[name] = []

        if self.seed is not None:
            random.seed(self.seed)

        if self.use_tensorboard:
            _import_tensorboard_writer()
            self.writer = SummaryWriter(log_dir=self.log_dir)

        # load checkpoint
        self.start_epoch = 0
        self.start_step = 0
        if self.from_checkpoint:
            self.checkpoint = self.load_training_checkpoint(self.log_dir)
        else:
            self.model.initialise_weights()

    def _generate_stat_report(self,
                              train_loss: float | None = None,
                              val_loss: float | None = None) -> str:
        """Generate the text to display with the progress bar.

        Parameters
        ----------
        train_loss : float, optional
            Current training loss.
        val_loss : float, optional
            Current validation loss.

        Returns
        -------
        str
            Formatted text to be displayed. Missing values are shown as
            `-----`.

        """
        report = []
        for name, value in [('train/loss', train_loss),
                            ('valid/loss', val_loss)]:
            str_value = f'{value:.4f}' if value is not None else '-----'
            report.append(f'{name}: {str_value}')
        if self.evaluate_on_train and self.evaluate_functions is not None:
            for name in self.train_results:
                str_value = (f'{self.train_results[name][-1]:.4f}'
                             if self.train_results[name] else '-----')
                report.append(f'train/{name}: {str_value}')
        if self.evaluate_functions is not None:
            for name in self.val_results:
                str_value = (f'{self.val_results[name][-1]:.4f}'
                             if self.val_results[name] else '-----')
                report.append(f'valid/{name}: {str_value}')
        return '   '.join(report)

    def load_training_checkpoint(self, log_dir: StrPathT) -> Checkpoint:
        """Load model from a checkpoint.

        Parameters
        ----------
        log_dir : str or PathLike
            The directory containing the `model.lt` checkpoint file.

        Returns
        -------
        :py:class:`.Checkpoint`
            Checkpoint containing the model weights, symbols and the
            training history.

        Raises
        ------
        FileNotFoundError
            If the file does not exist.

        """
        if self.verbose == VerbosityLevel.TEXT.value:
            print('Restore last checkpoint...', file=sys.stderr)
        checkpoint_path = os.path.join(log_dir, 'model.lt')
        checkpoint = Checkpoint.from_file(checkpoint_path)
        # load model from checkpoint
        self.model._load_checkpoint(checkpoint)

        # load the training history
        self.train_costs = checkpoint['train_costs']
        self.train_epoch_costs = checkpoint['train_epoch_costs']
        self.train_results = checkpoint['train_results']
        self.val_costs = checkpoint['val_costs']
        self.val_results = checkpoint['val_results']
        self.start_epoch = checkpoint['epoch']
        self.start_step = checkpoint['step']
        if self.seed is not None:
            random.setstate(checkpoint['random_state'])
        if self.verbose == VerbosityLevel.TEXT.value:
            print('Checkpoint restored successfully!',  # pragma: no cover
                  file=sys.stderr)
        return checkpoint

    def save_checkpoint(self,
                        save_dict: Mapping[str, Any],
                        log_dir: StrPathT) -> None:
        """Save checkpoint.

        Parameters
        ----------
        save_dict : mapping of str to any
            Mapping containing the checkpoint information.
        log_dir : str or PathLike
            The directory in which to store the `model.lt` checkpoint
            file.

        """
        checkpoint = self.model._make_checkpoint()
        checkpoint.add_many(save_dict)
        self._add_extra_checkpoint_info(checkpoint)
        checkpoint.to_file(os.path.join(log_dir, 'model.lt'))

    @abstractmethod
    def _add_extra_checkpoint_info(self, checkpoint: Checkpoint) -> None:
        """Add any additional information to the training checkpoint.

        These might include model-specific information like the random
        state of the backend or the state of the optimizer.

        Use `checkpoint.add_many()` to add multiple items.

        Parameters
        ----------
        checkpoint : :py:class:`.Checkpoint`
            The checkpoint to add information to.

        """

    @abstractmethod
    def _load_extra_checkpoint_info(self, checkpoint: Checkpoint) -> None:
        """Load additional checkpoint information.

        This includes data previously added by
        `_add_extra_checkpoint_info()`.

        Parameters
        ----------
        checkpoint : :py:class:`.Checkpoint`
            The checkpoint to read the information from.

        """

    @abstractmethod
    def training_step(self,
                      batch: tuple[list[Any], Any]) -> tuple[Any, float]:
        """Perform a training step.

        Parameters
        ----------
        batch : tuple of list and any
            Current batch.

        Returns
        -------
        Tuple of any and float
            The model predictions and the calculated loss.

        """

    @abstractmethod
    def validation_step(
            self, batch: tuple[list[Any], Any]) -> tuple[Any, float]:
        """Perform a validation step.

        Parameters
        ----------
        batch : tuple of list and any
            Current batch.

        Returns
        -------
        Tuple of any and float
            The model predictions and the calculated loss.

        """

    def fit(self,
            train_dataset: Dataset,
            val_dataset: Dataset | None = None,
            evaluation_step: int = 1,
            logging_step: int = 1) -> None:
        """Fit the model on the training data and, optionally,
        evaluate it on the validation data.

        Parameters
        ----------
        train_dataset : :py:class:`Dataset`
            Dataset used for training.
        val_dataset : :py:class:`Dataset`, optional
            Validation dataset.
        evaluation_step : int, default: 1
            Sets the intervals at which the metrics are evaluated on the
            validation dataset. Validation runs on epochs whose
            zero-based index is a multiple of this value.
        logging_step : int, default: 1
            Sets the intervals at which the training statistics are
            printed if `verbose = 'text'` (otherwise ignored).

        Notes
        -----
        A training checkpoint is written to `log_dir` at the end of
        every epoch.

        """
        if self.from_checkpoint:
            self._load_extra_checkpoint_info(self.checkpoint)

        def writer_helper(*args: Any) -> None:
            # Forward a scalar to tensorboard only when it is enabled.
            if self.use_tensorboard:
                self.writer.add_scalar(*args)

        progress_disabled = self.verbose != VerbosityLevel.PROGRESS.value

        # initialise progress bar
        step = self.start_step
        batches_per_epoch = ceil(len(train_dataset)/train_dataset.batch_size)
        status_bar = tqdm(total=float('inf'),
                          bar_format='{desc}',
                          desc=self._generate_stat_report(),
                          disable=progress_disabled,
                          leave=True,
                          position=0)

        try:
            # start training loop
            for epoch in trange(self.start_epoch,
                                self.epochs,
                                desc='Epoch',
                                disable=progress_disabled,
                                leave=False,
                                position=1):
                train_loss = 0.0
                with Tensor.backend(self.backend):
                    for batch in tqdm(train_dataset,
                                      desc='Batch',
                                      total=batches_per_epoch,
                                      disable=progress_disabled,
                                      leave=False,
                                      position=2):
                        step += 1
                        x, y_label = batch
                        y_hat, loss = self.training_step(batch)
                        if (self.evaluate_on_train
                                and self.evaluate_functions is not None):
                            for metr, func in self.evaluate_functions.items():
                                res = func(y_hat, y_label)
                                self._train_results_epoch[metr].append(
                                    len(x) * res)
                        # Weight by batch size so the epoch loss is a
                        # per-sample mean even if the last batch is short.
                        train_loss += len(x) * loss
                        writer_helper('train/step_loss', loss, step)
                        status_bar.set_description(
                            self._generate_stat_report(
                                train_loss=loss,
                                val_loss=(self.val_costs[-1]
                                          if self.val_costs else None)))

                train_loss /= len(train_dataset)
                self.train_epoch_costs.append(train_loss)
                writer_helper('train/epoch_loss', train_loss, epoch + 1)

                # evaluate on train
                if (self.evaluate_on_train
                        and self.evaluate_functions is not None):
                    for name in self._train_results_epoch:
                        self.train_results[name].append(
                            sum(self._train_results_epoch[name])
                            / len(train_dataset))
                        self._train_results_epoch[name] = []  # reset
                        writer_helper(f'train/{name}',
                                      self.train_results[name][-1],
                                      epoch + 1)

                if self.verbose == VerbosityLevel.PROGRESS.value:
                    status_bar.set_description(  # pragma: no cover
                        self._generate_stat_report(
                            train_loss=train_loss,
                            val_loss=(self.val_costs[-1]
                                      if self.val_costs else None)))

                # evaluate metrics on validation data
                if val_dataset is not None and epoch % evaluation_step == 0:
                    val_loss = 0.0
                    seen_so_far = 0
                    batches_per_validation = ceil(
                        len(val_dataset) / val_dataset.batch_size)
                    with Tensor.backend(self.backend):
                        for v_batch in tqdm(val_dataset,
                                            desc='Validation batch',
                                            total=batches_per_validation,
                                            disable=progress_disabled,
                                            leave=False,
                                            position=2):
                            x_val, y_label_val = v_batch
                            y_hat_val, cur_loss = self.validation_step(
                                v_batch)
                            val_loss += cur_loss * len(x_val)
                            seen_so_far += len(x_val)
                            if self.evaluate_functions is not None:
                                for metr, func in (
                                        self.evaluate_functions.items()):
                                    res = func(y_hat_val, y_label_val)
                                    self._val_results_epoch[metr].append(
                                        len(x_val) * res)
                            # Show the running per-sample validation loss.
                            status_bar.set_description(
                                self._generate_stat_report(
                                    train_loss=train_loss,
                                    val_loss=val_loss / seen_so_far))

                    val_loss /= len(val_dataset)
                    self.val_costs.append(val_loss)
                    status_bar.set_description(
                        self._generate_stat_report(
                            train_loss=train_loss,
                            val_loss=val_loss))
                    writer_helper('val/loss', val_loss, epoch + 1)

                    if self.evaluate_functions is not None:
                        for name in self._val_results_epoch:
                            self.val_results[name].append(
                                sum(self._val_results_epoch[name])
                                / len(val_dataset))
                            self._val_results_epoch[name] = []  # reset
                            writer_helper(f'val/{name}',
                                          self.val_results[name][-1],
                                          epoch + 1)
                        status_bar.set_description(
                            self._generate_stat_report(
                                train_loss=train_loss,
                                val_loss=val_loss))

                # save training stats checkpoint
                trainer_stats = {'epoch': epoch + 1,
                                 'train_costs': self.train_costs,
                                 'train_epoch_costs': self.train_epoch_costs,
                                 'train_results': self.train_results,
                                 'val_costs': self.val_costs,
                                 'val_results': self.val_results,
                                 'random_state': random.getstate(),
                                 'step': step}
                self.save_checkpoint(trainer_stats, self.log_dir)

                if self.verbose == VerbosityLevel.TEXT.value:  # pragma: no cover
                    if epoch == 0 or (epoch + 1) % logging_step == 0:
                        # Pad so the stats line up across epoch numbers.
                        space = (len(str(self.epochs))
                                 - len(str(epoch + 1)) + 2) * ' '
                        prefix = f'Epoch {epoch+1}:' + space
                        print(prefix + self._generate_stat_report(
                            train_loss=train_loss,
                            val_loss=(self.val_costs[-1]
                                      if self.val_costs else None)),
                            file=sys.stderr)
        finally:
            # Always release the progress bar, even when a training or
            # validation step raises.
            status_bar.close()

        if self.verbose == VerbosityLevel.TEXT.value:
            print('\nTraining completed!', file=sys.stderr)  # pragma: no cover