# Copyright 2021-2024 Cambridge Quantum Computing Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Trainer
=======
Module that contains the base class for a lambeq trainer.
Subclass :py:class:`Trainer` to define a custom trainer.
"""
from __future__ import annotations
from abc import ABC, abstractmethod
from collections.abc import Mapping
from datetime import datetime
from enum import Enum
from math import ceil
import os
import random
import socket
import sys
import time
from typing import Any, Callable, TYPE_CHECKING
from tqdm.auto import tqdm, trange
if TYPE_CHECKING:
from torch.utils.tensorboard.writer import SummaryWriter
from lambeq.backend.numerical_backend import backend
from lambeq.core.globals import VerbosityLevel
from lambeq.core.utils import normalise_duration
from lambeq.training.checkpoint import Checkpoint
from lambeq.training.dataset import Dataset
from lambeq.training.model import Model
from lambeq.typing import StrPathT
def _import_tensorboard_writer() -> None:
global SummaryWriter
try:
from torch.utils.tensorboard.writer import SummaryWriter
except ImportError as e: # pragma: no cover
raise ImportError(
'tensorboard not found. Please install it using '
'`pip install tensorboard`.'
) from e
# Evaluation metric signature: called as ``func(y_hat, y)`` and
# returning a score.
EvalFuncT = Callable[[Any, Any], Any]


class EvalMode(Enum):
    """Evaluation mode.

    ``EPOCH`` evaluates after multiples of ``eval_interval`` epochs;
    ``STEP`` evaluates after multiples of ``eval_interval`` steps.
    """

    EPOCH = 'epoch'
    STEP = 'step'
class Trainer(ABC):
    """Base class for a lambeq trainer.

    :py:meth:`fit` runs the training loop (progress reporting,
    evaluation, checkpointing and early stopping). Subclasses provide
    :py:meth:`training_step`, :py:meth:`validation_step` and the
    checkpoint hooks :py:meth:`_add_extra_checkpoint_info` and
    :py:meth:`_load_extra_checkpoint_info`.
    """

    def __init__(self,
                 model: Model,
                 loss_function: Callable[..., Any],
                 epochs: int,
                 evaluate_functions: Mapping[str, EvalFuncT] | None = None,
                 evaluate_on_train: bool = True,
                 use_tensorboard: bool = False,
                 log_dir: StrPathT | None = None,
                 from_checkpoint: bool = False,
                 verbose: str = VerbosityLevel.TEXT.value,
                 seed: int | None = None) -> None:
        """Initialise a lambeq trainer.

        Parameters
        ----------
        model : :py:class:`.Model`
            A lambeq Model.
        loss_function : callable
            A loss function to compare the prediction to the true label.
        epochs : int
            Number of training epochs.
        evaluate_functions : mapping of str to callable, optional
            Mapping of evaluation metric functions from their names.
        evaluate_on_train : bool, default: True
            Evaluate the metrics on the train dataset.
        use_tensorboard : bool, default: False
            Use Tensorboard for visualisation of the training logs.
        log_dir : str or PathLike, optional
            Location of model checkpoints (and tensorboard log).
            Default is `runs/**CURRENT_DATETIME_HOSTNAME**`.
        from_checkpoint : bool, default: False
            Starts training from the checkpoint, saved in the log_dir.
        verbose : str, default: 'text',
            See :py:class:`VerbosityLevel` for options.
        seed : int, optional
            Random seed.

        Raises
        ------
        ValueError
            If `verbose` is not a supported :py:class:`VerbosityLevel`
            value.

        """
        # Validate before touching the filesystem, so an invalid flag
        # does not leave an empty log directory behind.
        if not VerbosityLevel.has_value(verbose):
            raise ValueError(f'`{verbose}` flag is not supported by '
                             'this trainer.')

        if log_dir is None:
            current_time = datetime.now().strftime('%b%d_%H-%M-%S')
            log_dir = os.path.join(
                'runs', current_time + '_' + socket.gethostname())
        self.log_dir = log_dir
        os.makedirs(self.log_dir, exist_ok=True)

        self.backend = 'numpy'
        self.model = model
        self.loss_function = loss_function
        self.epochs = epochs
        self.evaluate_functions = evaluate_functions
        self.evaluate_on_train = evaluate_on_train
        self.use_tensorboard = use_tensorboard
        self.from_checkpoint = from_checkpoint
        self.verbose = verbose
        self.seed = seed

        # Training history (all of it is stored in checkpoints).
        # NOTE(review): nothing in this class appends to `train_costs`;
        # presumably subclass `training_step` implementations do, so
        # confirm this when writing a new subclass.
        self.train_costs: list[float] = []
        # One entry per training step.
        self.train_durations: list[float] = []
        self.train_epoch_costs: list[float] = []
        self.train_epoch_durations: list[float] = []
        self.train_eval_results: dict[str, list[Any]] = {}
        # (batch_size, metric) pairs accumulated between evaluations.
        self._train_eval_running: dict[str, list[tuple[int, Any]]] = {}

        # One entry per validation pass.
        self.val_costs: list[float] = []
        # One entry per validation *batch*.
        self.val_durations: list[float] = []
        self.val_eval_results: dict[str, list[Any]] = {}
        self._val_eval_running: dict[str, list[tuple[int, Any]]] = {}

        if self.evaluate_functions is not None:
            for name in self.evaluate_functions:
                self.val_eval_results[name] = []
                self._val_eval_running[name] = []
                self.train_eval_results[name] = []
                self._train_eval_running[name] = []

        if self.seed is not None:
            random.seed(self.seed)

        if self.use_tensorboard:
            _import_tensorboard_writer()
            self.writer = SummaryWriter(log_dir=self.log_dir)

        # load checkpoint
        self.start_epoch = 1
        self.start_step = 0
        if self.from_checkpoint:
            self.checkpoint = self.load_training_checkpoint(self.log_dir)
        else:
            self.model.initialise_weights()

    def _to_tensorboard(self, *args: Any) -> None:
        """Write to tensorboard if `use_tensorboard` is set to `True`."""
        if self.use_tensorboard:
            self.writer.add_scalar(*args)

    def _generate_stat_report(self,
                              train_loss: float | None = None,
                              val_loss: float | None = None,
                              train_duration: float | None = None,
                              val_duration: float | None = None,
                              train_duration_mean: float | None = None,
                              val_duration_mean: float | None = None,
                              eval_mode: str = EvalMode.EPOCH.value,
                              full_timing_report: bool = False) -> str:
        """Generate the text to display with the progress bar.

        Parameters
        ----------
        train_loss : float, optional
            Current training loss.
        val_loss : float, optional
            Current validation loss.
        train_duration: float, optional
            Accumulated training time for the logging interval.
        val_duration: float, optional
            Accumulated validation time for the logging interval.
        train_duration_mean: float, optional
            Mean training time per epoch/step for the logging interval.
        val_duration_mean: float, optional
            Mean validation time per evaluation for the logging interval.
        eval_mode: :py:class:`EvalMode`, default: 'epoch'
            The evaluation mode passed to the :py:meth:`.fit` method.
        full_timing_report: bool, default: False
            Include the mean durations in the report (when provided).

        Returns
        -------
        str
            Formatted text to be displayed

        """
        report = []
        for name, value in [('train/loss', train_loss),
                            ('valid/loss', val_loss),]:
            str_value = f'{value:.4f}' if value is not None else '-----'
            report.append(f'{name}: {str_value}')

        for name, value in [('train/time', train_duration),
                            ('valid/time', val_duration)]:
            str_value = (normalise_duration(value)
                         if value is not None else '-----')
            report.append(f'{name}: {str_value}')

        if full_timing_report:
            # Mean durations are optional - they're mostly important
            # when verbose='text'
            for name, value in [(f'train/time_per_{eval_mode}',
                                 train_duration_mean),
                                ('valid/time_per_eval', val_duration_mean)]:
                if value is not None:
                    str_value = normalise_duration(value)
                    report.append(f'{name}: {str_value}')

        if self.evaluate_on_train and self.evaluate_functions is not None:
            for name in self.train_eval_results:
                str_value = (f'{self.train_eval_results[name][-1]:.4f}'
                             if self.train_eval_results[name] else '-----')
                report.append(f'train/{name}: {str_value}')

        if self.evaluate_functions is not None:
            for name in self.val_eval_results:
                str_value = (f'{self.val_eval_results[name][-1]:.4f}'
                             if self.val_eval_results[name] else '-----')
                report.append(f'valid/{name}: {str_value}')
        return ' '.join(report)

    def load_training_checkpoint(self, log_dir: StrPathT) -> Checkpoint:
        """Load model from a checkpoint.

        Parameters
        ----------
        log_dir : str or PathLike
            The directory containing the `model.lt` checkpoint file.

        Returns
        -------
        py:class:`.Checkpoint`
            Checkpoint containing the model weights, symbols and the
            training history.

        Raises
        ------
        FileNotFoundError
            If the file does not exist.

        """
        if self.verbose == VerbosityLevel.TEXT.value:
            print('Restore last checkpoint...', file=sys.stderr)
        checkpoint_path = os.path.join(log_dir, 'model.lt')
        checkpoint = Checkpoint.from_file(checkpoint_path)
        # load model from checkpoint
        self.model._load_checkpoint(checkpoint)

        # load the training history
        self.train_costs = checkpoint['train_costs']
        self.train_durations = checkpoint['train_durations']
        self.train_epoch_costs = checkpoint['train_epoch_costs']
        self.train_epoch_durations = checkpoint['train_epoch_durations']
        self.train_eval_results = checkpoint['train_eval_results']
        self.val_costs = checkpoint['val_costs']
        self.val_durations = checkpoint['val_durations']
        self.val_eval_results = checkpoint['val_eval_results']
        # Resume from the epoch after the one that was checkpointed.
        self.start_epoch = checkpoint['epoch'] + 1
        self.start_step = checkpoint['step']
        if self.seed is not None:
            random.setstate(checkpoint['random_state'])
        if self.verbose == VerbosityLevel.TEXT.value:
            print('Checkpoint restored successfully!',  # pragma: no cover
                  file=sys.stderr)
        return checkpoint

    def save_checkpoint(self,
                        save_dict: Mapping[str, Any],
                        log_dir: StrPathT,
                        prefix: str = '') -> None:
        """Save checkpoint.

        Parameters
        ----------
        save_dict : mapping of str to any
            Mapping containing the checkpoint information.
        log_dir : str or PathLike
            The path where to store the `model.lt` checkpoint file.
        prefix : str, default: ''
            Prefix for the checkpoint file name.

        """
        checkpoint = self.model._make_checkpoint()
        checkpoint.add_many(save_dict)
        self._add_extra_checkpoint_info(checkpoint)
        checkpoint.to_file(os.path.join(log_dir, prefix + 'model.lt'))

    def _checkpoint_history(self, epoch: int, step: int) -> dict[str, Any]:
        """Collect the training history that every checkpoint stores."""
        return {'epoch': epoch,
                'train_costs': self.train_costs,
                'train_durations': self.train_durations,
                'train_epoch_costs': self.train_epoch_costs,
                'train_eval_results': self.train_eval_results,
                'train_epoch_durations': self.train_epoch_durations,
                'val_costs': self.val_costs,
                'val_durations': self.val_durations,
                'val_eval_results': self.val_eval_results,
                'random_state': random.getstate(),
                'step': step}

    @abstractmethod
    def _add_extra_checkpoint_info(self, checkpoint: Checkpoint) -> None:
        """Add any additional information to the training checkpoint.

        These might include model-specific information like the random
        state of the backend or the state of the optimizer.

        Use `checkpoint.add_many()` to add multiple items.

        Parameters
        ----------
        checkpoint : :py:class:`.Checkpoint`
            The checkpoint to add information to.
        """

    @abstractmethod
    def _load_extra_checkpoint_info(self, checkpoint: Checkpoint) -> None:
        """Load additional checkpoint information.

        This includes data previously added by
        `_add_extra_checkpoint_info()`.

        Parameters
        ----------
        checkpoint : mapping of str to any
            Mapping containing the checkpoint information.

        """

    @abstractmethod
    def training_step(self,
                      batch: tuple[list[Any], Any]) -> tuple[Any, float]:
        """Perform a training step.

        Parameters
        ----------
        batch : tuple of list and any
            Current batch.

        Returns
        -------
        Tuple of any and float
            The model predictions and the calculated loss.

        """

    @abstractmethod
    def validation_step(
            self, batch: tuple[list[Any], Any]) -> tuple[Any, float]:
        """Perform a validation step.

        Parameters
        ----------
        batch : tuple of list and any
            Current batch.

        Returns
        -------
        Tuple of any and float
            The model predictions and the calculated loss.

        """

    def _get_weighted_mean(self,
                           metric_running: list[tuple[int, Any]]) -> Any:
        """Calculate the batch-size-weighted mean of running results.

        Raises
        ------
        ZeroDivisionError
            If `metric_running` is empty.
        """
        total, batches = 0.0, 0
        for (batch_size, metric) in metric_running:
            total += batch_size * metric
            batches += batch_size
        return total / batches

    def _step_and_eval(self,
                       batch: tuple[list[Any], Any],
                       step_func: Callable,
                       losses: list[tuple[int, Any]],
                       eval_results: dict[str, list[Any]],
                       step_durations: list[Any],
                       evaluate: bool = True) -> Any:
        """Perform a forward step and evaluate the metrics.

        Appends ``(batch_size, loss)`` to `losses`, one
        ``(batch_size, score)`` per metric to `eval_results` (when
        `evaluate` is set and metrics exist) and the step's wall-clock
        duration to `step_durations`. Returns the step loss.
        """
        step_start = time.time()
        batch_size = len(batch[0])
        y_hat, loss = step_func(batch)
        losses.append((batch_size, loss))
        if self.evaluate_functions is not None and evaluate:
            for metr, func in self.evaluate_functions.items():
                res = func(y_hat, batch[1])
                eval_results[metr].append((batch_size, res))
        step_end = time.time()
        step_duration = step_end - step_start
        step_durations.append(step_duration)
        return loss

    def _summarize_metric(self,
                          eval_results: dict[str, list[tuple[int, Any]]],
                          results: dict[str, list[Any]],
                          interval: int,
                          status_bar: tqdm,
                          mode: str,
                          full_timing_report: bool = False) -> None:
        """Calculate the metric results and write them to tensorboard and
        command-line.

        The running `eval_results` are reduced to their weighted mean,
        appended to `results` and then cleared.
        """
        for name in eval_results:
            results[name].append(self._get_weighted_mean(eval_results[name]))
            eval_results[name] = []  # reset
            self._to_tensorboard(f'{mode}/{name}', results[name][-1], interval)
        status_bar.set_description(
            self._generate_stat_report(
                train_loss=(self.train_costs[-1] if self.train_costs
                            else None),
                val_loss=self.val_costs[-1] if self.val_costs else None,
                train_duration=(self.train_durations[-1] if
                                self.train_durations else None),
                val_duration=(self.val_durations[-1] if self.val_durations
                              else None),
                full_timing_report=full_timing_report,
            )
        )

    def _check_early_stopping(self,
                              early_stopping_criterion: str | None = None,
                              early_stopping_interval: int | None = None,
                              minimize_criterion: bool = True) -> bool:
        """Determine if training should be stopped based on the specified
        early stopping configuration.

        Training stops when none of the last `early_stopping_interval`
        criterion values is strictly better than the value recorded
        just before them. An equal value does not count as an
        improvement.

        Parameters
        ----------
        early_stopping_criterion : str, optional
            If specified, the value of this on `val_dataset` (if provided)
            will be used as the stopping criterion instead of
            the (default) validation loss.
        early_stopping_interval : int, optional
            If specified, training is stopped if the validation loss does
            not improve for `early_stopping_interval` validation cycles.
            Values smaller than 1 disable early stopping.
        minimize_criterion: bool, default: True
            Flag indicating if we should minimize or maximize the early
            stopping criterion.

        Returns
        -------
        Boolean
            Flag if early stopping should be performed.

        """
        if early_stopping_interval is None or early_stopping_interval < 1:
            return False

        factor = 1 if minimize_criterion else -1
        criterion_vals = self.val_costs
        if early_stopping_criterion is not None:
            criterion_vals = self.val_eval_results[early_stopping_criterion]

        if len(criterion_vals) <= early_stopping_interval:
            return False

        # Signed so that smaller always means better.
        reference = factor * criterion_vals[-early_stopping_interval - 1]
        latter_vals = [factor * val
                       for val in criterion_vals[-early_stopping_interval:]]
        # `<=`: a plateau is "no improvement" and must trigger stopping.
        return bool(reference <= min(latter_vals))

    def _format_timing_summary(self) -> str:
        """Build the end-of-training timing summary line."""
        total_training_time = sum(self.train_durations)
        # Guard against empty histories (e.g. `epochs=0`) instead of
        # raising ZeroDivisionError at the very end of training.
        training_time_per_epoch = (
            normalise_duration(
                total_training_time / len(self.train_epoch_durations))
            if self.train_epoch_durations else None)
        training_time_per_step = (
            normalise_duration(
                total_training_time / len(self.train_durations))
            if self.train_durations else None)
        total_training_time_s = normalise_duration(total_training_time)

        total_validation_time = None
        validation_time_per_eval = None
        if self.val_durations:
            total_validation_time = sum(self.val_durations)
            # `val_durations` has one entry per validation batch, while
            # `val_costs` gains one entry per validation pass.
            n_evaluations = len(self.val_costs) or len(self.val_durations)
            validation_time_per_eval = normalise_duration(
                total_validation_time / n_evaluations)
        total_validation_time_s = normalise_duration(total_validation_time)

        return (
            f'train/time: {total_training_time_s}'
            f' train/time_per_epoch: {training_time_per_epoch}'
            f' train/time_per_step: {training_time_per_step}'
            f' valid/time: {total_validation_time_s}'
            f' valid/time_per_eval: {validation_time_per_eval}'
        )

    def fit(self,
            train_dataset: Dataset,
            val_dataset: Dataset | None = None,
            log_interval: int = 1,
            eval_interval: int = 1,
            eval_mode: str = EvalMode.EPOCH.value,
            early_stopping_criterion: str | None = None,
            early_stopping_interval: int | None = None,
            minimize_criterion: bool = True,
            full_timing_report: bool = False) -> None:
        """Fit the model on the training data and, optionally,
        evaluate it on the validation data.

        Parameters
        ----------
        train_dataset : :py:class:`Dataset`
            Dataset used for training.
        val_dataset : :py:class:`Dataset`, optional
            Validation dataset.
        log_interval : int, default: 1
            Number of evaluation intervals between printed training
            statistics. Only used if `verbose = 'text'`.
        eval_interval : int, default: 1
            Number of epochs (`eval_mode='epoch'`) or steps
            (`eval_mode='step'`) between evaluations of the metrics.
        eval_mode : :py:class:`EvalMode`, default: 'epoch'
            Sets the evaluation mode. If `'epoch'`, the metrics are
            evaluated after multiples of `eval_interval` epochs. If
            `'step'`, the metrics are evaluated after multiples of
            `eval_interval` steps. Ignored if `val_dataset` is
            `None`.
        early_stopping_criterion : str, optional
            If specified, the value of this on `val_dataset` (if provided)
            will be used as the stopping criterion instead of
            the (default) validation loss. Must be a key of
            `evaluate_functions`.
        early_stopping_interval : int, optional
            If specified, training is stopped if the validation loss does
            not improve for `early_stopping_interval` validation cycles.
        minimize_criterion: bool, default: True
            Flag indicating if we should minimize or maximize the early
            stopping criterion.
        full_timing_report: bool, default: False
            Flag for including mean timing statistics in the logs.

        Raises
        ------
        ValueError
            If `eval_mode` is not a valid :py:class:`EvalMode`, or if
            `early_stopping_criterion` is not one of the evaluation
            metric names.

        """
        if self.from_checkpoint:
            self._load_extra_checkpoint_info(self.checkpoint)

        # calculate evaluation step
        if eval_mode == EvalMode.EPOCH.value:
            evaluation_step = eval_interval * train_dataset.batches_per_epoch
        elif eval_mode == EvalMode.STEP.value:
            evaluation_step = eval_interval
        else:
            raise ValueError(f'Invalid evaluation mode: {eval_mode}.')

        # The criterion must name an evaluation metric. Otherwise the
        # `val_eval_results` lookup fails with a KeyError mid-training,
        # including the case where no metrics were given at all.
        if early_stopping_criterion is not None:
            available_metrics = (list(self.evaluate_functions)
                                 if self.evaluate_functions is not None
                                 else [])
            if early_stopping_criterion not in available_metrics:
                raise ValueError('Invalid early stopping criterion: '
                                 f'{early_stopping_criterion}. '
                                 'Should be one of '
                                 f'{available_metrics}')

        # Used for early stopping
        factor = 1 if minimize_criterion else -1
        best_epoch = 0
        best_step = 0

        logging_step = log_interval * evaluation_step
        total_steps = self.epochs * train_dataset.batches_per_epoch

        # initialise progress bar
        step = self.start_step
        if val_dataset is not None:
            batches_per_validation = ceil(
                len(val_dataset) / val_dataset.batch_size)
        # Wall-clock duration of each full validation pass in this run;
        # `self.val_durations` only records individual batches.
        val_pass_durations: list[float] = []

        disable_tqdm = self.verbose != VerbosityLevel.PROGRESS.value
        status_bar = tqdm(total=float('inf'),
                          bar_format='{desc}',
                          desc=self._generate_stat_report(),
                          disable=disable_tqdm,
                          leave=True,
                          position=0)

        # start training loop
        with backend(self.backend):
            early_stopping = False
            best_val_criterion = float('inf')
            for epoch in trange(self.start_epoch,
                                self.epochs + 1,
                                desc='Epoch',
                                disable=disable_tqdm,
                                leave=False,
                                position=1):
                epoch_start = time.time()
                train_losses: list[tuple[int, Any]] = []
                for batch in tqdm(train_dataset,
                                  desc='Batch',
                                  total=train_dataset.batches_per_epoch,
                                  disable=disable_tqdm,
                                  leave=False,
                                  position=2):
                    step += 1
                    t_loss = self._step_and_eval(
                        batch,
                        self.training_step,
                        train_losses,
                        self._train_eval_running,
                        self.train_durations,
                        self.evaluate_on_train
                    )
                    self._to_tensorboard('train/step_loss', t_loss, step)
                    status_bar.set_description(
                        self._generate_stat_report(
                            train_loss=t_loss,
                            val_loss=(self.val_costs[-1] if self.val_costs
                                      else None),
                            train_duration=self.train_durations[-1],
                            val_duration=(self.val_durations[-1] if
                                          self.val_durations else None),
                            full_timing_report=full_timing_report,
                        )
                    )
                    self._to_tensorboard('train/time',
                                         self.train_durations[-1],
                                         step)

                    # calculate metrics on train dataset
                    if self.evaluate_on_train and step % evaluation_step == 0:
                        self._summarize_metric(
                            self._train_eval_running,
                            self.train_eval_results,
                            epoch,
                            status_bar,
                            mode='train',
                            full_timing_report=full_timing_report,
                        )

                    # evaluate metrics on validation data
                    if val_dataset is not None and step % evaluation_step == 0:
                        val_loss: list[tuple[int, Any]] = []
                        # Index where this pass's per-batch durations start.
                        pass_start = len(self.val_durations)
                        for v_batch in tqdm(val_dataset,
                                            desc='Validation batch',
                                            total=batches_per_validation,
                                            disable=disable_tqdm,
                                            leave=False,
                                            position=2):
                            v_loss = self._step_and_eval(
                                v_batch,
                                self.validation_step,
                                val_loss,
                                self._val_eval_running,
                                self.val_durations,
                            )
                            status_bar.set_description(
                                self._generate_stat_report(
                                    train_loss=t_loss,
                                    val_loss=v_loss,
                                    train_duration=self.train_durations[-1],
                                    val_duration=self.val_durations[-1],
                                    full_timing_report=full_timing_report,
                                )
                            )
                        val_pass_durations.append(
                            sum(self.val_durations[pass_start:]))

                        self.val_costs.append(
                            self._get_weighted_mean(val_loss)
                        )

                        status_bar.set_description(
                            self._generate_stat_report(
                                train_loss=t_loss,
                                val_loss=self.val_costs[-1],
                                train_duration=self.train_durations[-1],
                                val_duration=self.val_durations[-1],
                                full_timing_report=full_timing_report,
                            )
                        )
                        self._to_tensorboard('val/loss',
                                             self.val_costs[-1],
                                             epoch)
                        self._to_tensorboard('val/time',
                                             self.val_durations[-1],
                                             epoch)
                        self._summarize_metric(
                            self._val_eval_running,
                            self.val_eval_results,
                            epoch,
                            status_bar,
                            mode='val',
                            full_timing_report=full_timing_report,
                        )

                        # save best model
                        criterion_vals = self.val_costs
                        if early_stopping_criterion is not None:
                            criterion_vals = self.val_eval_results[
                                early_stopping_criterion
                            ]
                        criterion_val = factor * criterion_vals[-1]
                        if criterion_val < best_val_criterion:
                            best_val_criterion = criterion_val
                            best_epoch = epoch
                            best_step = step
                            self.save_checkpoint(
                                self._checkpoint_history(epoch, step),
                                self.log_dir,
                                prefix='best_'
                            )

                    # print training stats if verbose is set to 'text'
                    if (self.verbose
                            == VerbosityLevel.TEXT.value):  # pragma: no cover
                        if step % logging_step == 0:
                            prefix = ''
                            if eval_mode == EvalMode.EPOCH.value:
                                space = (len(str(self.epochs))
                                         - len(str(epoch)) + 2) * ' '
                                prefix += f'Epoch {epoch}:' + space
                            if eval_mode == EvalMode.STEP.value:
                                step_space = (len(str(total_steps))
                                              - len(str(step)) + 2) * ' '
                                prefix += f'Step {step}:' + step_space

                            # `train_durations` holds one entry per step.
                            train_duration = (
                                sum(self.train_durations[-logging_step:])
                                if self.train_durations else None
                            )
                            train_duration_mean = (
                                train_duration
                                / (log_interval * eval_interval)
                                if train_duration is not None else None
                            )
                            # Sum whole validation passes, not batches.
                            recent_passes = val_pass_durations[-log_interval:]
                            val_duration = (sum(recent_passes)
                                            if recent_passes else None)
                            val_duration_mean = (
                                val_duration / len(recent_passes)
                                if val_duration is not None else None
                            )
                            print(
                                prefix + self._generate_stat_report(
                                    train_loss=(self.train_costs[-1]
                                                if self.train_costs else None),
                                    val_loss=(self.val_costs[-1]
                                              if self.val_costs else None),
                                    train_duration=train_duration,
                                    val_duration=val_duration,
                                    train_duration_mean=train_duration_mean,
                                    val_duration_mean=val_duration_mean,
                                    eval_mode=eval_mode,
                                    full_timing_report=full_timing_report,
                                ),
                                file=sys.stderr
                            )

                    # check for early stopping
                    early_stopping = self._check_early_stopping(
                        early_stopping_criterion,
                        early_stopping_interval,
                        minimize_criterion
                    )
                    if early_stopping:
                        break  # inner epoch loop

                epoch_end = time.time()
                epoch_duration = epoch_end - epoch_start
                self.train_epoch_durations.append(epoch_duration)

                # calculate epoch loss
                self.train_epoch_costs.append(
                    self._get_weighted_mean(train_losses))
                self._to_tensorboard('train/epoch_loss',
                                     self.train_epoch_costs[-1],
                                     epoch)
                self._to_tensorboard('train/time_per_epoch',
                                     self.train_epoch_durations[-1],
                                     epoch)

                # save training stats checkpoint
                self.save_checkpoint(
                    self._checkpoint_history(epoch, step),
                    self.log_dir)
                if early_stopping:
                    if self.verbose == VerbosityLevel.TEXT.value:
                        print('Early stopping!\n'
                              f'Best model (epoch={best_epoch}, '
                              f'step={best_step}) saved to\n'
                              f'{os.path.join(self.log_dir, "best_model.lt")}',
                              file=sys.stderr)
                    break  # break outer epoch loop

        status_bar.close()

        # Summarize timing statistics
        timing_summary_desc = self._format_timing_summary()

        if self.verbose == VerbosityLevel.TEXT.value:
            print('\nTraining completed!', file=sys.stderr)

        # Display timing summary regardless of verbosity
        print(timing_summary_desc, file=sys.stderr)